feat(ai): Implement Zod-based tool validation and improve Agent API

- Replace JSON Schema with Zod schemas for tool parameter definitions
- Add runtime validation for all tool calls at provider level
- Create shared validation module with detailed error formatting
- Update Agent API with comprehensive event system
- Add agent tests with calculator tool for multi-turn execution
- Add abort test to verify proper handling of aborted requests
- Update documentation with detailed event flow examples
- Rename generate.ts to stream.ts for clarity
This commit is contained in:
Mario Zechner 2025-09-09 14:58:54 +02:00
parent 594b0dac6c
commit 35fe8f21e9
24 changed files with 1069 additions and 221 deletions

View file

@ -4,6 +4,7 @@ A collection of tools for managing LLM deployments and building AI agents.
## Packages
- **[@mariozechner/pi-ai](packages/ai)** - Unified multi-provider LLM API
- **[@mariozechner/pi-tui](packages/tui)** - Terminal UI library with differential rendering
- **[@mariozechner/pi-agent](packages/agent)** - General-purpose agent with tool calling and session persistence
- **[@mariozechner/pi](packages/pods)** - CLI for managing vLLM deployments on GPU pods

58
package-lock.json generated
View file

@ -20,9 +20,9 @@
}
},
"node_modules/@anthropic-ai/sdk": {
"version": "0.60.0",
"resolved": "https://registry.npmjs.org/@anthropic-ai/sdk/-/sdk-0.60.0.tgz",
"integrity": "sha512-9zu/TXaUy8BZhXedDtt1wT3H4LOlpKDO1/ftiFpeR3N1PCr3KJFKkxxlQWWt1NNp08xSwUNJ3JNY8yhl8av6eQ==",
"version": "0.61.0",
"resolved": "https://registry.npmjs.org/@anthropic-ai/sdk/-/sdk-0.61.0.tgz",
"integrity": "sha512-GnlOXrPxow0uoaVB3DGNh9EJBU1MyagCBCLpU+bwDVlj/oOPYIwoiasMWlykkfYcQOrDP2x/zHnRD0xN7PeZPw==",
"license": "MIT",
"bin": {
"anthropic-ai-sdk": "bin/cli"
@ -634,9 +634,9 @@
}
},
"node_modules/@google/genai": {
"version": "1.15.0",
"resolved": "https://registry.npmjs.org/@google/genai/-/genai-1.15.0.tgz",
"integrity": "sha512-4CSW+hRTESWl3xVtde7pkQ3E+dDFhDq+m4ztmccRctZfx1gKy3v0M9STIMGk6Nq0s6O2uKMXupOZQ1JGorXVwQ==",
"version": "1.17.0",
"resolved": "https://registry.npmjs.org/@google/genai/-/genai-1.17.0.tgz",
"integrity": "sha512-r/OZWN9D8WvYrte3bcKPoLODrZ+2TjfxHm5OOyVHUbdFYIp1C4yJaXX4+sCS8I/+CbN9PxLjU5zm1cgmS7qz+A==",
"license": "Apache-2.0",
"dependencies": {
"google-auth-library": "^9.14.2",
@ -646,7 +646,7 @@
"node": ">=20.0.0"
},
"peerDependencies": {
"@modelcontextprotocol/sdk": "^1.11.0"
"@modelcontextprotocol/sdk": "^1.11.4"
},
"peerDependenciesMeta": {
"@modelcontextprotocol/sdk": {
@ -1310,9 +1310,9 @@
}
},
"node_modules/chalk": {
"version": "5.5.0",
"resolved": "https://registry.npmjs.org/chalk/-/chalk-5.5.0.tgz",
"integrity": "sha512-1tm8DTaJhPBG3bIkVeZt1iZM9GfSX2lzOeDVZH9R9ffRHpmHvxZ/QhgQH/aDTkswQVt+YHdXAdS/In/30OjCbg==",
"version": "5.6.2",
"resolved": "https://registry.npmjs.org/chalk/-/chalk-5.6.2.tgz",
"integrity": "sha512-7NzBL0rN6fMUW+f7A6Io4h40qQlG+xGmtMxfbnH/K7TAtt8JQWVQK+6g0UXKMeVJoyV5EkkNsErQ8pVD3bLHbA==",
"license": "MIT",
"engines": {
"node": "^12.17.0 || ^14.13 || >=16.0.0"
@ -1907,9 +1907,9 @@
}
},
"node_modules/openai": {
"version": "5.15.0",
"resolved": "https://registry.npmjs.org/openai/-/openai-5.15.0.tgz",
"integrity": "sha512-kcUdws8K/A8m02I+IqFBwO51gS+87GP89yWEufGbzEi8anBz4FB/bti2QxaJdGwwY4mwJGzx85XO7TuL/Tpu1w==",
"version": "5.20.0",
"resolved": "https://registry.npmjs.org/openai/-/openai-5.20.0.tgz",
"integrity": "sha512-Bmc2zLM/YWgFrDpXr9hwXqGGDdMmMpE9+qoZPsaHpn0Y/Qk1Vu26hNqXo7+nHdli+sLsXINvS1f8kR3NKhGKmA==",
"license": "Apache-2.0",
"bin": {
"openai": "bin/cli"
@ -2714,12 +2714,31 @@
}
}
},
"node_modules/zod": {
"version": "3.25.76",
"resolved": "https://registry.npmjs.org/zod/-/zod-3.25.76.tgz",
"integrity": "sha512-gzUt/qt81nXsFGKIFcC3YnfEAx5NkunCfnDlvuBSSFS02bcXu4Lmea0AFIUwbLWxWPx3d9p8S5QoaujKcNQxcQ==",
"license": "MIT",
"peer": true,
"funding": {
"url": "https://github.com/sponsors/colinhacks"
}
},
"node_modules/zod-to-json-schema": {
"version": "3.24.6",
"resolved": "https://registry.npmjs.org/zod-to-json-schema/-/zod-to-json-schema-3.24.6.tgz",
"integrity": "sha512-h/z3PKvcTcTetyjl1fkj79MHNEjm+HpD6NXheWjzOekY7kV+lwDYnHw+ivHkijnCSMz1yJaWBD9vu/Fcmk+vEg==",
"license": "ISC",
"peerDependencies": {
"zod": "^3.24.1"
}
},
"packages/agent": {
"name": "@mariozechner/pi-agent",
"version": "0.5.31",
"license": "MIT",
"dependencies": {
"@mariozechner/pi-tui": "^0.5.30",
"@mariozechner/pi-tui": "^0.5.31",
"@types/glob": "^8.1.0",
"chalk": "^5.5.0",
"glob": "^11.0.3",
@ -3101,10 +3120,11 @@
"version": "0.5.31",
"license": "MIT",
"dependencies": {
"@anthropic-ai/sdk": "^0.60.0",
"@google/genai": "^1.15.0",
"chalk": "^5.5.0",
"openai": "^5.15.0"
"@anthropic-ai/sdk": "^0.61.0",
"@google/genai": "^1.17.0",
"chalk": "^5.6.2",
"openai": "^5.20.0",
"zod-to-json-schema": "^3.24.6"
},
"devDependencies": {
"@types/node": "^24.3.0",
@ -3137,7 +3157,7 @@
"version": "0.5.31",
"license": "MIT",
"dependencies": {
"@mariozechner/pi-agent": "^0.5.30",
"@mariozechner/pi-agent": "^0.5.31",
"chalk": "^5.5.0"
},
"bin": {

View file

@ -24,20 +24,18 @@ npm install @mariozechner/pi-ai
## Quick Start
```typescript
import { getModel, stream, complete, Context, Tool } from '@mariozechner/pi-ai';
import { getModel, stream, complete, Context, Tool, z } from '@mariozechner/pi-ai';
// Fully typed with auto-complete support for both providers and models
const model = getModel('openai', 'gpt-4o-mini');
// Define tools
// Define tools with Zod schemas for type safety and validation
const tools: Tool[] = [{
name: 'get_time',
description: 'Get the current time',
parameters: {
type: 'object',
properties: {},
required: []
}
parameters: z.object({
timezone: z.string().optional().describe('Optional timezone (e.g., America/New_York)')
})
}];
// Build a conversation context (easily serializable and transferable between models)
@ -94,7 +92,11 @@ const toolCalls = finalMessage.content.filter(b => b.type === 'toolCall');
for (const call of toolCalls) {
// Execute the tool
const result = call.name === 'get_time'
? new Date().toISOString()
? new Date().toLocaleString('en-US', {
timeZone: call.arguments.timezone || 'UTC',
dateStyle: 'full',
timeStyle: 'long'
})
: 'Unknown tool';
// Add tool result to context
@ -102,7 +104,7 @@ for (const call of toolCalls) {
role: 'toolResult',
toolCallId: call.id,
toolName: call.name,
content: result,
output: result,
isError: false
});
}
@ -129,6 +131,70 @@ for (const block of response.content) {
}
```
## Tools
Tools enable LLMs to interact with external systems. This library uses Zod schemas for type-safe tool definitions with automatic validation.
### Defining Tools
```typescript
import { z, Tool } from '@mariozechner/pi-ai';
// Define tool parameters with Zod
const weatherTool: Tool = {
name: 'get_weather',
description: 'Get current weather for a location',
parameters: z.object({
location: z.string().describe('City name or coordinates'),
units: z.enum(['celsius', 'fahrenheit']).default('celsius')
})
};
// Complex validation with Zod refinements
const bookMeetingTool: Tool = {
name: 'book_meeting',
description: 'Schedule a meeting',
parameters: z.object({
title: z.string().min(1),
startTime: z.string().datetime(),
endTime: z.string().datetime(),
attendees: z.array(z.string().email()).min(1)
}).refine(
data => new Date(data.endTime) > new Date(data.startTime),
{ message: 'End time must be after start time' }
)
};
```
### Handling Tool Calls
```typescript
const context: Context = {
messages: [{ role: 'user', content: 'What is the weather in London?' }],
tools: [weatherTool]
};
const response = await complete(model, context);
// Check for tool calls in the response
for (const block of response.content) {
if (block.type === 'toolCall') {
// Arguments are automatically validated against the Zod schema
// If validation fails, an error event is emitted
const result = await executeWeatherApi(block.arguments);
// Add tool result to continue the conversation
context.messages.push({
role: 'toolResult',
toolCallId: block.id,
toolName: block.name,
output: JSON.stringify(result),
isError: false
});
}
}
```
## Image Input
Models with vision capabilities can process images. You can check if a model supports images via the `input` property. If you pass images to a non-vision model, they are silently ignored.
@ -260,7 +326,7 @@ for await (const event of s) {
## Errors & Abort Signal
When a request ends with an error (including aborts), the API returns an `AssistantMessage` with:
When a request ends with an error (including aborts and tool call validation errors), the API returns an `AssistantMessage` with:
- `stopReason: 'error'` - Indicates the request ended with an error
- `error: string` - Error message describing what happened
- `content: array` - **Partial content** accumulated before the error
@ -503,6 +569,189 @@ const continuation = await complete(newModel, restored);
> **Note**: If the context contains images (encoded as base64 as shown in the Image Input section), those will also be serialized.
## Agent API
The Agent API provides a higher-level interface for building agents with tools. It handles tool execution, validation, and provides detailed event streaming for interactive applications.
### Event System
The Agent API streams events during execution, allowing you to build reactive UIs and track agent progress. The agent processes prompts in **turns**, where each turn consists of:
1. An assistant message (the LLM's response)
2. Optional tool executions if the assistant calls tools
3. Tool result messages that are fed back to the LLM
This continues until the assistant produces a response without tool calls.
### Event Flow Example
Given a prompt asking to calculate two expressions and sum them:
```typescript
import { prompt, AgentContext, calculateTool } from '@mariozechner/pi-ai';
const context: AgentContext = {
systemPrompt: 'You are a helpful math assistant.',
messages: [],
tools: [calculateTool]
};
const stream = prompt(
{ role: 'user', content: 'Calculate 15 * 20 and 30 * 40, then sum the results' },
context,
{ model: getModel('openai', 'gpt-4o-mini') }
);
// Expected event sequence:
// 1. agent_start - Agent begins processing
// 2. turn_start - First turn begins
// 3. message_start - User message starts
// 4. message_end - User message ends
// 5. message_start - Assistant message starts
// 6. message_update - Assistant streams response with tool calls
// 7. message_end - Assistant message ends
// 8. tool_execution_start - First calculation (15 * 20)
// 9. tool_execution_end - Result: 300
// 10. tool_execution_start - Second calculation (30 * 40)
// 11. tool_execution_end - Result: 1200
// 12. message_start - Tool result message for first calculation
// 13. message_end - Tool result message ends
// 14. message_start - Tool result message for second calculation
// 15. message_end - Tool result message ends
// 16. turn_end - First turn ends with 2 tool results
// 17. turn_start - Second turn begins
// 18. message_start - Assistant message starts
// 19. message_update - Assistant streams response with sum calculation
// 20. message_end - Assistant message ends
// 21. tool_execution_start - Sum calculation (300 + 1200)
// 22. tool_execution_end - Result: 1500
// 23. message_start - Tool result message for sum
// 24. message_end - Tool result message ends
// 25. turn_end - Second turn ends with 1 tool result
// 26. turn_start - Third turn begins
// 27. message_start - Final assistant message starts
// 28. message_update - Assistant streams final answer
// 29. message_end - Final assistant message ends
// 30. turn_end - Third turn ends with 0 tool results
// 31. agent_end - Agent completes with all messages
```
### Handling Events
```typescript
for await (const event of stream) {
switch (event.type) {
case 'agent_start':
console.log('Agent started');
break;
case 'turn_start':
console.log('New turn started');
break;
case 'message_start':
console.log(`${event.message.role} message started`);
break;
case 'message_update':
// Only for assistant messages during streaming
if (event.message.content.some(c => c.type === 'text')) {
console.log('Assistant:', event.message.content);
}
break;
case 'tool_execution_start':
console.log(`Calling ${event.toolName} with:`, event.args);
break;
case 'tool_execution_end':
if (event.isError) {
console.error(`Tool failed:`, event.result);
} else {
console.log(`Tool result:`, event.result.output);
}
break;
case 'turn_end':
console.log(`Turn ended with ${event.toolResults.length} tool calls`);
break;
case 'agent_end':
console.log(`Agent completed with ${event.messages.length} new messages`);
break;
}
}
// Get all messages generated during this agent execution
// These include the user message and can be directly appended to context.messages
const messages = await stream.result();
context.messages.push(...messages);
```
### Defining Tools with Zod
Tools use Zod schemas for runtime validation and type inference:
```typescript
import { z } from 'zod';
import { AgentTool, AgentToolResult } from '@mariozechner/pi-ai';
const weatherSchema = z.object({
city: z.string().min(1, 'City is required'),
units: z.enum(['celsius', 'fahrenheit']).default('celsius')
});
const weatherTool: AgentTool<typeof weatherSchema, { temp: number }> = {
label: 'Get Weather',
name: 'get_weather',
description: 'Get current weather for a city',
parameters: weatherSchema,
execute: async (toolCallId, args) => {
// args is fully typed: { city: string, units: 'celsius' | 'fahrenheit' }
const temp = Math.round(Math.random() * 30);
return {
output: `Temperature in ${args.city}: ${temp}°${args.units[0].toUpperCase()}`,
details: { temp }
};
}
};
```
### Validation and Error Handling
Tool arguments are automatically validated using the Zod schema. Invalid arguments result in detailed error messages:
```typescript
// If the LLM calls with invalid arguments:
// get_weather({ city: '', units: 'kelvin' })
// The tool execution will fail with:
/*
Validation failed for tool "get_weather":
- city: City is required
- units: Invalid enum value. Expected 'celsius' | 'fahrenheit', received 'kelvin'
Received arguments:
{
"city": "",
"units": "kelvin"
}
*/
```
### Built-in Example Tools
The library includes example tools for common operations:
```typescript
import { calculateTool, getCurrentTimeTool } from '@mariozechner/pi-ai';
const context: AgentContext = {
systemPrompt: 'You are a helpful assistant.',
messages: [],
tools: [calculateTool, getCurrentTimeTool]
};
```
## Browser Usage
The library supports browser environments. You must pass the API key explicitly since environment variables are not available in browsers:
@ -533,6 +782,7 @@ GEMINI_API_KEY=...
GROQ_API_KEY=gsk_...
CEREBRAS_API_KEY=csk-...
XAI_API_KEY=xai-...
ZAI_API_KEY=...
OPENROUTER_API_KEY=sk-or-...
```
@ -549,6 +799,21 @@ const response = await complete(model, context, {
});
```
### Programmatic API Key Management
You can also set and get API keys programmatically:
```typescript
import { setApiKey, getApiKey } from '@mariozechner/pi-ai';
// Set API key for a provider
setApiKey('openai', 'sk-...');
setApiKey('anthropic', 'sk-ant-...');
// Get API key for a provider (checks both programmatic and env vars)
const key = getApiKey('openai');
```
## License
MIT

View file

@ -19,10 +19,11 @@
"prepublishOnly": "npm run clean && npm run build"
},
"dependencies": {
"@anthropic-ai/sdk": "^0.60.0",
"@google/genai": "^1.15.0",
"chalk": "^5.5.0",
"openai": "^5.15.0"
"@anthropic-ai/sdk": "^0.61.0",
"@google/genai": "^1.17.0",
"chalk": "^5.6.2",
"openai": "^5.20.0",
"zod-to-json-schema": "^3.24.6"
},
"keywords": [
"ai",

View file

@ -1,97 +1,68 @@
import { EventStream } from "../event-stream";
import { streamSimple } from "../generate.js";
import type {
AssistantMessage,
Context,
Message,
Model,
SimpleGenerateOptions,
ToolResultMessage,
UserMessage,
} from "../types.js";
import type { AgentContext, AgentTool, AgentToolResult } from "./types";
// Event types
export type AgentEvent =
| { type: "message_start"; message: Message }
| { type: "message_update"; message: AssistantMessage }
| { type: "message_complete"; message: Message }
| { type: "tool_execution_start"; toolCallId: string; toolName: string; args: any }
| {
type: "tool_execution_complete";
toolCallId: string;
toolName: string;
result: AgentToolResult<any> | string;
isError: boolean;
}
| { type: "turn_complete"; messages: AgentContext["messages"] };
// Configuration for prompt execution
export interface PromptConfig {
model: Model<any>;
apiKey: string;
enableThinking?: boolean;
preprocessor?: (messages: AgentContext["messages"], abortSignal?: AbortSignal) => Promise<AgentContext["messages"]>;
}
import { streamSimple } from "../stream.js";
import type { AssistantMessage, Context, Message, ToolResultMessage, UserMessage } from "../types.js";
import { validateToolArguments } from "../validation.js";
import type { AgentContext, AgentEvent, AgentTool, AgentToolResult, PromptConfig } from "./types";
// Main prompt function - returns a stream of events
export function prompt(
prompt: UserMessage,
context: AgentContext,
config: PromptConfig,
prompt: UserMessage,
signal?: AbortSignal,
): EventStream<AgentEvent, AgentContext["messages"]> {
const stream = new EventStream<AgentEvent, AgentContext["messages"]>(
(event) => event.type === "turn_complete",
(event) => (event.type === "turn_complete" ? event.messages : []),
(event) => event.type === "agent_end",
(event) => (event.type === "agent_end" ? event.messages : []),
);
// Run the prompt async
(async () => {
try {
// Track new messages generated during this prompt
const newMessages: AgentContext["messages"] = [];
// Track new messages generated during this prompt
const newMessages: AgentContext["messages"] = [];
// Create user message for the prompt
const messages = [...context.messages, prompt];
newMessages.push(prompt);
// Create user message
const messages = [...context.messages, prompt];
newMessages.push(prompt);
stream.push({ type: "agent_start" });
stream.push({ type: "turn_start" });
stream.push({ type: "message_start", message: prompt });
stream.push({ type: "message_end", message: prompt });
stream.push({ type: "message_start", message: prompt });
stream.push({ type: "message_complete", message: prompt });
// Update context with new messages
const currentContext: AgentContext = {
...context,
messages,
};
// Update context with new messages
const currentContext: AgentContext = {
...context,
messages,
};
// Keep looping while we have tool calls
let hasMoreToolCalls = true;
while (hasMoreToolCalls) {
// Stream assistant response
const assistantMessage = await streamAssistantResponse(currentContext, config, signal, stream);
newMessages.push(assistantMessage);
// Check for tool calls
const toolCalls = assistantMessage.content.filter((c) => c.type === "toolCall");
hasMoreToolCalls = toolCalls.length > 0;
if (hasMoreToolCalls) {
// Execute tool calls
const toolResults = await executeToolCalls(currentContext.tools, assistantMessage, signal, stream);
newMessages.push(...toolResults);
// Add tool results to context
currentContext.messages = [...currentContext.messages, ...toolResults];
}
// Keep looping while we have tool calls
let hasMoreToolCalls = true;
let firstTurn = true;
while (hasMoreToolCalls) {
if (!firstTurn) {
stream.push({ type: "turn_start" });
} else {
firstTurn = false;
}
// Stream assistant response
const assistantMessage = await streamAssistantResponse(currentContext, config, signal, stream);
newMessages.push(assistantMessage);
stream.push({ type: "turn_complete", messages: newMessages });
} catch (error) {
// End stream on error
stream.end([]);
throw error;
// Check for tool calls
const toolCalls = assistantMessage.content.filter((c) => c.type === "toolCall");
hasMoreToolCalls = toolCalls.length > 0;
const toolResults: ToolResultMessage[] = [];
if (hasMoreToolCalls) {
// Execute tool calls
toolResults.push(...(await executeToolCalls(currentContext.tools, assistantMessage, signal, stream)));
currentContext.messages.push(...toolResults);
newMessages.push(...toolResults);
}
stream.push({ type: "turn_end", assistantMessage, toolResults: toolResults });
}
stream.push({ type: "agent_end", messages: newMessages });
stream.end(newMessages);
})();
return stream;
@ -122,16 +93,7 @@ async function streamAssistantResponse(
tools: context.tools, // AgentTool extends Tool, so this works
};
const options: SimpleGenerateOptions = {
apiKey: config.apiKey,
signal,
};
if (config.model.reasoning && config.enableThinking) {
options.reasoning = "medium";
}
const response = await streamSimple(config.model, processedContext, options);
const response = await streamSimple(config.model, processedContext, { ...config, signal });
let partialMessage: AssistantMessage | null = null;
let addedPartial = false;
@ -147,14 +109,17 @@ async function streamAssistantResponse(
case "text_start":
case "text_delta":
case "text_end":
case "thinking_start":
case "thinking_delta":
case "thinking_end":
case "toolcall_start":
case "toolcall_delta":
case "toolcall_end":
if (partialMessage) {
partialMessage = event.partial;
context.messages[context.messages.length - 1] = partialMessage;
stream.push({ type: "message_update", message: { ...partialMessage } });
stream.push({ type: "message_update", assistantMessageEvent: event, message: { ...partialMessage } });
}
break;
@ -166,7 +131,7 @@ async function streamAssistantResponse(
} else {
context.messages.push(finalMessage);
}
stream.push({ type: "message_complete", message: finalMessage });
stream.push({ type: "message_end", message: finalMessage });
return finalMessage;
}
}
@ -176,7 +141,7 @@ async function streamAssistantResponse(
}
async function executeToolCalls<T>(
tools: AgentTool<T>[] | undefined,
tools: AgentTool<any, T>[] | undefined,
assistantMessage: AssistantMessage,
signal: AbortSignal | undefined,
stream: EventStream<AgentEvent, Message[]>,
@ -199,14 +164,19 @@ async function executeToolCalls<T>(
try {
if (!tool) throw new Error(`Tool ${toolCall.name} not found`);
resultOrError = await tool.execute(toolCall.arguments, toolCall.id, signal);
// Validate arguments using shared validation function
const validatedArgs = validateToolArguments(tool, toolCall);
// Execute with validated, typed arguments
resultOrError = await tool.execute(toolCall.id, validatedArgs, signal);
} catch (e) {
resultOrError = `Error: ${e instanceof Error ? e.message : String(e)}`;
resultOrError = e instanceof Error ? e.message : String(e);
isError = true;
}
stream.push({
type: "tool_execution_complete",
type: "tool_execution_end",
toolCallId: toolCall.id,
toolName: toolCall.name,
result: resultOrError,
@ -224,7 +194,7 @@ async function executeToolCalls<T>(
results.push(toolResultMessage);
stream.push({ type: "message_start", message: toolResultMessage });
stream.push({ type: "message_complete", message: toolResultMessage });
stream.push({ type: "message_end", message: toolResultMessage });
}
return results;

View file

@ -1,3 +1,3 @@
export { type AgentEvent, type PromptConfig, prompt } from "./agent";
export { prompt } from "./agent";
export * from "./tools";
export type { AgentContext, AgentTool } from "./types";
export type { AgentContext, AgentEvent, AgentTool, PromptConfig } from "./types";

View file

@ -1,3 +1,4 @@
import { z } from "zod";
import type { AgentTool } from "../../agent";
export interface CalculateResult {
@ -14,21 +15,16 @@ export function calculate(expression: string): CalculateResult {
}
}
export const calculateTool: AgentTool<undefined> = {
const calculateSchema = z.object({
expression: z.string().describe("The mathematical expression to evaluate"),
});
export const calculateTool: AgentTool<typeof calculateSchema, undefined> = {
label: "Calculator",
name: "calculate",
description: "Evaluate mathematical expressions",
parameters: {
type: "object",
properties: {
expression: {
type: "string",
description: "The mathematical expression to evaluate",
},
},
required: ["expression"],
},
execute: async (args: { expression: string }) => {
parameters: calculateSchema,
execute: async (_toolCallId, args) => {
return calculate(args.expression);
},
};

View file

@ -1,3 +1,4 @@
import { z } from "zod";
import type { AgentTool } from "../../agent";
import type { AgentToolResult } from "../types";
@ -25,20 +26,16 @@ export async function getCurrentTime(timezone?: string): Promise<GetCurrentTimeR
};
}
export const getCurrentTimeTool: AgentTool<{ utcTimestamp: number }> = {
const getCurrentTimeSchema = z.object({
timezone: z.string().optional().describe("Optional timezone (e.g., 'America/New_York', 'Europe/London')"),
});
export const getCurrentTimeTool: AgentTool<typeof getCurrentTimeSchema, { utcTimestamp: number }> = {
label: "Current Time",
name: "get_current_time",
description: "Get the current date and time",
parameters: {
type: "object",
properties: {
timezone: {
type: "string",
description: "Optional timezone (e.g., 'America/New_York', 'Europe/London')",
},
},
},
execute: async (args: { timezone?: string }) => {
parameters: getCurrentTimeSchema,
execute: async (_toolCallId, args) => {
return getCurrentTime(args.timezone);
},
};

View file

@ -1,4 +1,13 @@
import type { Message, Tool } from "../types.js";
import type { ZodSchema, z } from "zod";
import type {
AssistantMessage,
AssistantMessageEvent,
Message,
Model,
SimpleStreamOptions,
Tool,
ToolResultMessage,
} from "../types.js";
export interface AgentToolResult<T> {
// Output of the tool to be given to the LLM in ToolResultMessage.content
@ -8,10 +17,14 @@ export interface AgentToolResult<T> {
}
// AgentTool extends Tool but adds the execute function
export interface AgentTool<TDetails> extends Tool {
export interface AgentTool<TParameters extends ZodSchema = ZodSchema, TDetails = any> extends Tool<TParameters> {
// A human-readable label for the tool to be displayed in UI
label: string;
execute: (params: any, toolCallId: string, signal?: AbortSignal) => Promise<AgentToolResult<TDetails>>;
execute: (
toolCallId: string,
params: z.infer<TParameters>,
signal?: AbortSignal,
) => Promise<AgentToolResult<TDetails>>;
}
// AgentContext is like Context but uses AgentTool
@ -20,3 +33,37 @@ export interface AgentContext {
messages: Message[];
tools?: AgentTool<any>[];
}
// Event types
export type AgentEvent =
// Emitted when the agent starts. An agent can emit multiple turns
| { type: "agent_start" }
// Emitted when a turn starts. A turn can emit an optional user message (initial prompt), an assistant message (response) and multiple tool result messages
| { type: "turn_start" }
// Emitted when a user, assistant or tool result message starts
| { type: "message_start"; message: Message }
// Emitted when an assistant message is updated due to streaming
| { type: "message_update"; assistantMessageEvent: AssistantMessageEvent; message: AssistantMessage }
// Emitted when a user, assistant or tool result message is complete
| { type: "message_end"; message: Message }
// Emitted when a tool execution starts
| { type: "tool_execution_start"; toolCallId: string; toolName: string; args: any }
// Emitted when a tool execution completes
| {
type: "tool_execution_end";
toolCallId: string;
toolName: string;
result: AgentToolResult<any> | string;
isError: boolean;
}
// Emitted when a full turn completes
| { type: "turn_end"; assistantMessage: AssistantMessage; toolResults: ToolResultMessage[] }
// Emitted when the agent has completed all its turns. All messages from every turn are
// contained in messages, which can be appended to the context
| { type: "agent_end"; messages: AgentContext["messages"] };
// Configuration for prompt execution
export interface PromptConfig extends SimpleStreamOptions {
model: Model<any>;
preprocessor?: (messages: AgentContext["messages"], abortSignal?: AbortSignal) => Promise<AgentContext["messages"]>;
}

View file

@ -1,8 +1,9 @@
export { z } from "zod";
export * from "./agent/index.js";
export * from "./generate.js";
export * from "./models.js";
export * from "./providers/anthropic.js";
export * from "./providers/google.js";
export * from "./providers/openai-completions.js";
export * from "./providers/openai-responses.js";
export * from "./stream.js";
export * from "./types.js";

View file

@ -1413,6 +1413,23 @@ export const MODELS = {
} satisfies Model<"anthropic-messages">,
},
openrouter: {
"nvidia/nemotron-nano-9b-v2": {
id: "nvidia/nemotron-nano-9b-v2",
name: "NVIDIA: Nemotron Nano 9B V2",
api: "openai-completions",
provider: "openrouter",
baseUrl: "https://openrouter.ai/api/v1",
reasoning: true,
input: ["text"],
cost: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
},
contextWindow: 128000,
maxTokens: 4096,
} satisfies Model<"openai-completions">,
"openrouter/sonoma-dusk-alpha": {
id: "openrouter/sonoma-dusk-alpha",
name: "Sonoma Dusk Alpha",

View file

@ -4,32 +4,34 @@ import type {
MessageCreateParamsStreaming,
MessageParam,
} from "@anthropic-ai/sdk/resources/messages.js";
import { zodToJsonSchema } from "zod-to-json-schema";
import { AssistantMessageEventStream } from "../event-stream.js";
import { calculateCost } from "../models.js";
import type {
Api,
AssistantMessage,
Context,
GenerateFunction,
GenerateOptions,
Message,
Model,
StopReason,
StreamFunction,
StreamOptions,
TextContent,
ThinkingContent,
Tool,
ToolCall,
ToolResultMessage,
} from "../types.js";
import { validateToolArguments } from "../validation.js";
import { transformMessages } from "./transorm-messages.js";
export interface AnthropicOptions extends GenerateOptions {
export interface AnthropicOptions extends StreamOptions {
thinkingEnabled?: boolean;
thinkingBudgetTokens?: number;
toolChoice?: "auto" | "any" | "none" | { type: "tool"; name: string };
}
export const streamAnthropic: GenerateFunction<"anthropic-messages"> = (
export const streamAnthropic: StreamFunction<"anthropic-messages"> = (
model: Model<"anthropic-messages">,
context: Context,
options?: AnthropicOptions,
@ -159,6 +161,15 @@ export const streamAnthropic: GenerateFunction<"anthropic-messages"> = (
});
} else if (block.type === "toolCall") {
block.arguments = JSON.parse(block.partialJson);
// Validate tool arguments if tool definition is available
if (context.tools) {
const tool = context.tools.find((t) => t.name === block.name);
if (tool) {
block.arguments = validateToolArguments(tool, block);
}
}
delete (block as any).partialJson;
stream.push({
type: "toolcall_end",
@ -390,7 +401,7 @@ function convertMessages(messages: Message[], model: Model<"anthropic-messages">
content: blocks,
});
} else if (msg.role === "toolResult") {
// Collect all consecutive toolResult messages
// Collect all consecutive toolResult messages, needed for z.ai Anthropic endpoint
const toolResults: ContentBlockParam[] = [];
// Add the current tool result
@ -430,15 +441,19 @@ function convertMessages(messages: Message[], model: Model<"anthropic-messages">
function convertTools(tools: Tool[]): Anthropic.Messages.Tool[] {
if (!tools) return [];
return tools.map((tool) => ({
name: tool.name,
description: tool.description,
input_schema: {
type: "object" as const,
properties: tool.parameters.properties || {},
required: tool.parameters.required || [],
},
}));
return tools.map((tool) => {
const jsonSchema = zodToJsonSchema(tool.parameters, { $refStrategy: "none" }) as any;
return {
name: tool.name,
description: tool.description,
input_schema: {
type: "object" as const,
properties: jsonSchema.properties || {},
required: jsonSchema.required || [],
},
};
});
}
function mapStopReason(reason: Anthropic.Messages.StopReason): StopReason {

View file

@ -7,24 +7,26 @@ import {
GoogleGenAI,
type Part,
} from "@google/genai";
import { zodToJsonSchema } from "zod-to-json-schema";
import { AssistantMessageEventStream } from "../event-stream.js";
import { calculateCost } from "../models.js";
import type {
Api,
AssistantMessage,
Context,
GenerateFunction,
GenerateOptions,
Model,
StopReason,
StreamFunction,
StreamOptions,
TextContent,
ThinkingContent,
Tool,
ToolCall,
} from "../types.js";
import { validateToolArguments } from "../validation.js";
import { transformMessages } from "./transorm-messages.js";
export interface GoogleOptions extends GenerateOptions {
export interface GoogleOptions extends StreamOptions {
toolChoice?: "auto" | "none" | "any";
thinking?: {
enabled: boolean;
@ -35,7 +37,7 @@ export interface GoogleOptions extends GenerateOptions {
// Counter for generating unique tool call IDs
let toolCallCounter = 0;
export const streamGoogle: GenerateFunction<"google-generative-ai"> = (
export const streamGoogle: StreamFunction<"google-generative-ai"> = (
model: Model<"google-generative-ai">,
context: Context,
options?: GoogleOptions,
@ -159,6 +161,15 @@ export const streamGoogle: GenerateFunction<"google-generative-ai"> = (
name: part.functionCall.name || "",
arguments: part.functionCall.args as Record<string, any>,
};
// Validate tool arguments if tool definition is available
if (context.tools) {
const tool = context.tools.find((t) => t.name === toolCall.name);
if (tool) {
toolCall.arguments = validateToolArguments(tool, toolCall);
}
}
output.content.push(toolCall);
stream.push({ type: "toolcall_start", contentIndex: blockIndex(), partial: output });
stream.push({
@ -380,7 +391,7 @@ function convertTools(tools: Tool[]): any[] {
functionDeclarations: tools.map((tool) => ({
name: tool.name,
description: tool.description,
parameters: tool.parameters,
parameters: zodToJsonSchema(tool.parameters, { $refStrategy: "none" }),
})),
},
];

View file

@ -7,28 +7,30 @@ import type {
ChatCompletionContentPartText,
ChatCompletionMessageParam,
} from "openai/resources/chat/completions.js";
import { zodToJsonSchema } from "zod-to-json-schema";
import { AssistantMessageEventStream } from "../event-stream.js";
import { calculateCost } from "../models.js";
import type {
AssistantMessage,
Context,
GenerateFunction,
GenerateOptions,
Model,
StopReason,
StreamFunction,
StreamOptions,
TextContent,
ThinkingContent,
Tool,
ToolCall,
} from "../types.js";
import { validateToolArguments } from "../validation.js";
import { transformMessages } from "./transorm-messages.js";
export interface OpenAICompletionsOptions extends GenerateOptions {
export interface OpenAICompletionsOptions extends StreamOptions {
toolChoice?: "auto" | "none" | "required" | { type: "function"; function: { name: string } };
reasoningEffort?: "minimal" | "low" | "medium" | "high";
}
export const streamOpenAICompletions: GenerateFunction<"openai-completions"> = (
export const streamOpenAICompletions: StreamFunction<"openai-completions"> = (
model: Model<"openai-completions">,
context: Context,
options?: OpenAICompletionsOptions,
@ -79,6 +81,15 @@ export const streamOpenAICompletions: GenerateFunction<"openai-completions"> = (
});
} else if (block.type === "toolCall") {
block.arguments = JSON.parse(block.partialArgs || "{}");
// Validate tool arguments if tool definition is available
if (context.tools) {
const tool = context.tools.find((t) => t.name === block.name);
if (tool) {
block.arguments = validateToolArguments(tool, block);
}
}
delete block.partialArgs;
stream.push({
type: "toolcall_end",
@ -381,7 +392,7 @@ function convertTools(tools: Tool[]): OpenAI.Chat.Completions.ChatCompletionTool
function: {
name: tool.name,
description: tool.description,
parameters: tool.parameters,
parameters: zodToJsonSchema(tool.parameters, { $refStrategy: "none" }),
},
}));
}

View file

@ -10,25 +10,27 @@ import type {
ResponseOutputMessage,
ResponseReasoningItem,
} from "openai/resources/responses/responses.js";
import { zodToJsonSchema } from "zod-to-json-schema";
import { AssistantMessageEventStream } from "../event-stream.js";
import { calculateCost } from "../models.js";
import type {
Api,
AssistantMessage,
Context,
GenerateFunction,
GenerateOptions,
Model,
StopReason,
StreamFunction,
StreamOptions,
TextContent,
ThinkingContent,
Tool,
ToolCall,
} from "../types.js";
import { validateToolArguments } from "../validation.js";
import { transformMessages } from "./transorm-messages.js";
// OpenAI Responses-specific options
export interface OpenAIResponsesOptions extends GenerateOptions {
export interface OpenAIResponsesOptions extends StreamOptions {
reasoningEffort?: "minimal" | "low" | "medium" | "high";
reasoningSummary?: "auto" | "detailed" | "concise" | null;
}
@ -36,7 +38,7 @@ export interface OpenAIResponsesOptions extends GenerateOptions {
/**
* Generate function for OpenAI Responses API
*/
export const streamOpenAIResponses: GenerateFunction<"openai-responses"> = (
export const streamOpenAIResponses: StreamFunction<"openai-responses"> = (
model: Model<"openai-responses">,
context: Context,
options?: OpenAIResponsesOptions,
@ -238,6 +240,15 @@ export const streamOpenAIResponses: GenerateFunction<"openai-responses"> = (
name: item.name,
arguments: JSON.parse(item.arguments),
};
// Validate tool arguments if tool definition is available
if (context.tools) {
const tool = context.tools.find((t) => t.name === toolCall.name);
if (tool) {
toolCall.arguments = validateToolArguments(tool, toolCall);
}
}
stream.push({ type: "toolcall_end", contentIndex: blockIndex(), toolCall, partial: output });
}
}
@ -451,7 +462,7 @@ function convertTools(tools: Tool[]): OpenAITool[] {
type: "function",
name: tool.name,
description: tool.description,
parameters: tool.parameters,
parameters: zodToJsonSchema(tool.parameters, { $refStrategy: "none" }),
strict: null,
}));
}

View file

@ -11,7 +11,7 @@ import type {
Model,
OptionsForApi,
ReasoningEffort,
SimpleGenerateOptions,
SimpleStreamOptions,
} from "./types.js";
const apiKeys: Map<string, string> = new Map();
@ -90,7 +90,7 @@ export async function complete<TApi extends Api>(
export function streamSimple<TApi extends Api>(
model: Model<TApi>,
context: Context,
options?: SimpleGenerateOptions,
options?: SimpleStreamOptions,
): AssistantMessageEventStream {
const apiKey = options?.apiKey || getApiKey(model.provider);
if (!apiKey) {
@ -104,7 +104,7 @@ export function streamSimple<TApi extends Api>(
export async function completeSimple<TApi extends Api>(
model: Model<TApi>,
context: Context,
options?: SimpleGenerateOptions,
options?: SimpleStreamOptions,
): Promise<AssistantMessage> {
const s = streamSimple(model, context, options);
return s.result();
@ -112,7 +112,7 @@ export async function completeSimple<TApi extends Api>(
function mapOptionsForApi<TApi extends Api>(
model: Model<TApi>,
options?: SimpleGenerateOptions,
options?: SimpleStreamOptions,
apiKey?: string,
): OptionsForApi<TApi> {
const base = {

View file

@ -16,11 +16,11 @@ export interface ApiOptionsMap {
}
// Compile-time exhaustiveness check - this will fail if ApiOptionsMap doesn't have all KnownApi keys
type _CheckExhaustive = ApiOptionsMap extends Record<Api, GenerateOptions>
? Record<Api, GenerateOptions> extends ApiOptionsMap
type _CheckExhaustive = ApiOptionsMap extends Record<Api, StreamOptions>
? Record<Api, StreamOptions> extends ApiOptionsMap
? true
: ["ApiOptionsMap is missing some KnownApi values", Exclude<Api, keyof ApiOptionsMap>]
: ["ApiOptionsMap doesn't extend Record<KnownApi, GenerateOptions>"];
: ["ApiOptionsMap doesn't extend Record<KnownApi, StreamOptions>"];
const _exhaustive: _CheckExhaustive = true;
// Helper type to get options for a specific API
@ -32,20 +32,20 @@ export type Provider = KnownProvider | string;
export type ReasoningEffort = "minimal" | "low" | "medium" | "high";
// Base options all providers share
export interface GenerateOptions {
export interface StreamOptions {
temperature?: number;
maxTokens?: number;
signal?: AbortSignal;
apiKey?: string;
}
// Unified options with reasoning (what public generate() accepts)
export interface SimpleGenerateOptions extends GenerateOptions {
// Unified options with reasoning passed to streamSimple() and completeSimple()
export interface SimpleStreamOptions extends StreamOptions {
reasoning?: ReasoningEffort;
}
// Generic GenerateFunction with typed options
export type GenerateFunction<TApi extends Api> = (
// Generic StreamFunction with typed options
export type StreamFunction<TApi extends Api> = (
model: Model<TApi>,
context: Context,
options: OptionsForApi<TApi>,
@ -119,10 +119,12 @@ export interface ToolResultMessage<TDetails = any> {
export type Message = UserMessage | AssistantMessage | ToolResultMessage;
export interface Tool {
import type { ZodSchema } from "zod";
export interface Tool<TParameters extends ZodSchema = ZodSchema> {
name: string;
description: string;
parameters: Record<string, any>; // JSON Schema
parameters: TParameters;
}
export interface Context {

View file

@ -0,0 +1,32 @@
import { z } from "zod";
import type { Tool, ToolCall } from "./types.js";
/**
 * Validates a tool call's arguments against the tool's Zod schema.
 *
 * @param tool The tool definition whose `parameters` is a Zod schema
 * @param toolCall The tool call produced by the LLM
 * @returns The parsed (and possibly coerced) arguments
 * @throws Error with a human-readable summary when validation fails; any
 *         non-Zod error from parsing is rethrown untouched
 */
export function validateToolArguments(tool: Tool, toolCall: ToolCall): any {
    try {
        return tool.parameters.parse(toolCall.arguments);
    } catch (error) {
        // Only Zod validation failures get the friendly formatting.
        if (!(error instanceof z.ZodError)) throw error;
        const lines: string[] = [];
        for (const issue of error.issues) {
            const where = issue.path.length > 0 ? issue.path.join(".") : "root";
            lines.push(` - ${where}: ${issue.message}`);
        }
        const received = JSON.stringify(toolCall.arguments, null, 2);
        throw new Error(
            `Validation failed for tool "${toolCall.name}":\n${lines.join("\n")}\n\nReceived arguments:\n${received}`,
        );
    }
}

View file

@ -1,6 +1,6 @@
import { describe, expect, it } from "vitest";
import { complete, stream } from "../src/generate.js";
import { getModel } from "../src/models.js";
import { complete, stream } from "../src/stream.js";
import type { Api, Context, Model, OptionsForApi } from "../src/types.js";
async function testAbortSignal<TApi extends Api>(llm: Model<TApi>, options: OptionsForApi<TApi> = {}) {

View file

@ -0,0 +1,347 @@
import { describe, expect, it } from "vitest";
import { prompt } from "../src/agent/agent.js";
import { calculateTool } from "../src/agent/tools/calculate.js";
import type { AgentContext, AgentEvent, PromptConfig } from "../src/agent/types.js";
import { getModel } from "../src/models.js";
import type { Api, Message, Model, OptionsForApi, UserMessage } from "../src/types.js";
/**
 * Drives a multi-turn agent session against `model` using the calculator tool:
 * the agent must compute two products (ideally via parallel tool calls), sum
 * them with the tool, and answer with only the final integer.
 *
 * Fix: corrected the "mulit-step" typo in the prompt string sent to the model.
 *
 * @param model The model under test
 * @param options Provider-specific options forwarded into the prompt config
 * @returns Summary stats: turns, tool call count, tool results, extracted final answer, raw events
 */
async function calculateTest<TApi extends Api>(model: Model<TApi>, options: OptionsForApi<TApi> = {}) {
    // Create the agent context with the calculator tool
    const context: AgentContext = {
        systemPrompt:
            "You are a helpful assistant that performs mathematical calculations. When asked to calculate multiple expressions, you can use parallel tool calls if the model supports it. In your final answer, output ONLY the final sum as a single integer number, nothing else.",
        messages: [],
        tools: [calculateTool],
    };

    // Create the prompt config
    const config: PromptConfig = {
        model,
        ...options,
    };

    // Create the user prompt asking for multiple calculations
    const userPrompt: UserMessage = {
        role: "user",
        content: `Use the calculator tool to complete the following multi-step task.
1. Calculate 3485 * 4234 and 88823 * 3482 in parallel
2. Calculate the sum of the two results using the calculator tool
3. Output ONLY the final sum as a single integer number, nothing else.`,
    };

    // Calculate expected results (using integers)
    const expectedFirst = 3485 * 4234; // = 14755490
    const expectedSecond = 88823 * 3482; // = 309281786
    const expectedSum = expectedFirst + expectedSecond; // = 324037276

    // Track events for verification
    const events: AgentEvent[] = [];
    let turns = 0;
    let toolCallCount = 0;
    const toolResults: number[] = [];
    let finalAnswer: number | undefined;

    // Execute the prompt and consume the event stream
    const stream = prompt(userPrompt, context, config);
    for await (const event of stream) {
        events.push(event);
        switch (event.type) {
            case "turn_start":
                turns++;
                console.log(`\n=== Turn ${turns} started ===`);
                break;
            case "turn_end":
                console.log(`=== Turn ${turns} ended with ${event.toolResults.length} tool results ===`);
                console.log(event.assistantMessage);
                break;
            case "tool_execution_end":
                if (!event.isError && typeof event.result === "object" && event.result.output) {
                    toolCallCount++;
                    // Extract number from output like "expression = result"
                    const match = event.result.output.match(/=\s*([\d.]+)/);
                    if (match) {
                        const value = parseFloat(match[1]);
                        toolResults.push(value);
                        console.log(`Tool ${toolCallCount}: ${event.result.output}`);
                    }
                }
                break;
            case "message_end":
                // Just track the message end event, don't extract answer here
                break;
        }
    }

    // Get the final messages
    const finalMessages = await stream.result();

    // Verify the results
    expect(finalMessages).toBeDefined();
    expect(finalMessages.length).toBeGreaterThan(0);
    const finalMessage = finalMessages[finalMessages.length - 1];
    expect(finalMessage).toBeDefined();
    expect(finalMessage.role).toBe("assistant");
    if (finalMessage.role !== "assistant") throw new Error("Final message is not from assistant");

    // Extract the final answer from the last assistant message
    const content = finalMessage.content
        .filter((c) => c.type === "text")
        .map((c) => (c.type === "text" ? c.text : ""))
        .join(" ");

    // Look for integers in the response that might be the final answer
    const numbers = content.match(/\b\d+\b/g);
    if (numbers) {
        // Check if any of the numbers matches our expected sum
        for (const num of numbers) {
            const value = parseInt(num, 10);
            if (Math.abs(value - expectedSum) < 10) {
                finalAnswer = value;
                break;
            }
        }
        // If no exact match, take the last large number as likely the answer
        if (finalAnswer === undefined) {
            const largeNumbers = numbers.map((n) => parseInt(n, 10)).filter((n) => n > 1000000);
            if (largeNumbers.length > 0) {
                finalAnswer = largeNumbers[largeNumbers.length - 1];
            }
        }
    }

    // Should have executed at least 3 tool calls: 2 for the initial calculations, 1 for the sum
    // (or possibly 2 if the model calculates the sum itself without a tool)
    expect(toolCallCount).toBeGreaterThanOrEqual(2);

    // Must be at least 3 turns: first to calculate the expressions, then to sum them, then give the answer
    // Could be 3 turns if model does parallel calls, or 4 turns if sequential calculation of expressions
    expect(turns).toBeGreaterThanOrEqual(3);
    expect(turns).toBeLessThanOrEqual(4);

    // Verify the individual calculations are in the results
    const hasFirstCalc = toolResults.some((r) => r === expectedFirst);
    const hasSecondCalc = toolResults.some((r) => r === expectedSecond);
    expect(hasFirstCalc).toBe(true);
    expect(hasSecondCalc).toBe(true);

    // Verify the final sum
    if (finalAnswer !== undefined) {
        expect(finalAnswer).toBe(expectedSum);
        console.log(`Final answer: ${finalAnswer} (expected: ${expectedSum})`);
    } else {
        // If we couldn't extract the final answer from text, check if it's in the tool results
        const hasSum = toolResults.some((r) => r === expectedSum);
        expect(hasSum).toBe(true);
    }

    // Log summary
    console.log(`\nTest completed with ${turns} turns and ${toolCallCount} tool calls`);
    if (turns === 3) {
        console.log("Model used parallel tool calls for initial calculations");
    } else {
        console.log("Model used sequential tool calls");
    }

    return {
        turns,
        toolCallCount,
        toolResults,
        finalAnswer,
        events,
    };
}
/**
 * Verifies abort handling: the AbortSignal is fired after the first successful
 * tool execution and the final assistant message must carry stopReason "error".
 *
 * Fix: the event-consuming async IIFE was a floating promise — it was never
 * awaited, so a rejection in the event loop would surface as an unhandled
 * rejection and trailing events could go unprocessed. It is now settled before
 * the assertions run.
 *
 * @param model The model under test
 * @param options Provider-specific options forwarded into the prompt config
 * @returns Summary stats: tool call count, raw events, final messages
 */
async function abortTest<TApi extends Api>(model: Model<TApi>, options: OptionsForApi<TApi> = {}) {
    // Create the agent context with the calculator tool
    const context: AgentContext = {
        systemPrompt:
            "You are a helpful assistant that performs mathematical calculations. Always use the calculator tool for each calculation.",
        messages: [],
        tools: [calculateTool],
    };

    // Create the prompt config
    const config: PromptConfig = {
        model,
        ...options,
    };

    // A prompt requiring several calculations so there is time to abort mid-run
    const userPrompt: UserMessage = {
        role: "user",
        content: "Calculate 100 * 200, then 300 * 400, then 500 * 600, then sum all three results.",
    };

    // Create abort controller
    const abortController = new AbortController();

    // Track events for verification
    const events: AgentEvent[] = [];
    let toolCallCount = 0;
    const errorReceived = false; // kept for result-shape compatibility; never set here
    let finalMessages: Message[] | undefined;

    // Execute the prompt
    const stream = prompt(userPrompt, context, config, abortController.signal);

    // Consume events concurrently; abort after the first successful tool execution
    const abortPromise = (async () => {
        for await (const event of stream) {
            events.push(event);
            if (event.type === "tool_execution_end" && !event.isError) {
                toolCallCount++;
                if (toolCallCount === 1) {
                    console.log("Aborting after first tool execution");
                    abortController.abort();
                }
            }
            if (event.type === "agent_end") {
                finalMessages = event.messages;
            }
        }
    })();

    finalMessages = await stream.result();
    // Settle the event loop before asserting; log (don't fail) if iteration rejected,
    // matching the previous tolerant behavior while avoiding an unhandled rejection.
    await abortPromise.catch((e) => console.error("Event loop error during abort test:", e));

    // Verify abort behavior
    console.log(`\nAbort test completed with ${toolCallCount} tool calls`);
    const assistantMessage = finalMessages[finalMessages.length - 1];
    if (!assistantMessage) throw new Error("No final message received");
    expect(assistantMessage).toBeDefined();
    expect(assistantMessage.role).toBe("assistant");
    if (assistantMessage.role !== "assistant") throw new Error("Final message is not from assistant");

    // Should have executed 1 tool call before abort
    expect(toolCallCount).toBeGreaterThanOrEqual(1);
    expect(assistantMessage.stopReason).toBe("error");

    return {
        toolCallCount,
        events,
        errorReceived,
        finalMessages,
    };
}
describe("Agent Calculator Tests", () => {
describe.skipIf(!process.env.GEMINI_API_KEY)("Google Provider Agent", () => {
const model = getModel("google", "gemini-2.5-flash");
it("should calculate multiple expressions and sum the results", async () => {
const result = await calculateTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
}, 30000);
it("should handle abort during tool execution", async () => {
const result = await abortTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
}, 30000);
});
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Completions Provider Agent", () => {
const model = getModel("openai", "gpt-4o-mini");
it("should calculate multiple expressions and sum the results", async () => {
const result = await calculateTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
}, 30000);
it("should handle abort during tool execution", async () => {
const result = await abortTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
}, 30000);
});
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Responses Provider Agent", () => {
const model = getModel("openai", "gpt-5-mini");
it("should calculate multiple expressions and sum the results", async () => {
const result = await calculateTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
}, 30000);
it("should handle abort during tool execution", async () => {
const result = await abortTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
}, 30000);
});
describe.skipIf(!process.env.ANTHROPIC_API_KEY)("Anthropic Provider Agent", () => {
const model = getModel("anthropic", "claude-3-5-haiku-20241022");
it("should calculate multiple expressions and sum the results", async () => {
const result = await calculateTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
}, 30000);
it("should handle abort during tool execution", async () => {
const result = await abortTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
}, 30000);
});
describe.skipIf(!process.env.XAI_API_KEY)("xAI Provider Agent", () => {
const model = getModel("xai", "grok-3");
it("should calculate multiple expressions and sum the results", async () => {
const result = await calculateTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
}, 30000);
it("should handle abort during tool execution", async () => {
const result = await abortTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
}, 30000);
});
describe.skipIf(!process.env.GROQ_API_KEY)("Groq Provider Agent", () => {
const model = getModel("groq", "openai/gpt-oss-20b");
it("should calculate multiple expressions and sum the results", async () => {
const result = await calculateTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
}, 30000);
it("should handle abort during tool execution", async () => {
const result = await abortTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
}, 30000);
});
describe.skipIf(!process.env.CEREBRAS_API_KEY)("Cerebras Provider Agent", () => {
const model = getModel("cerebras", "gpt-oss-120b");
it("should calculate multiple expressions and sum the results", async () => {
const result = await calculateTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
}, 30000);
it("should handle abort during tool execution", async () => {
const result = await abortTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
}, 30000);
});
describe.skipIf(!process.env.ZAI_API_KEY)("zAI Provider Agent", () => {
const model = getModel("zai", "glm-4.5-air");
it("should calculate multiple expressions and sum the results", async () => {
const result = await calculateTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
}, 30000);
it("should handle abort during tool execution", async () => {
const result = await abortTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
}, 30000);
});
});

View file

@ -1,6 +1,6 @@
import { describe, expect, it } from "vitest";
import { complete } from "../src/generate.js";
import { getModel } from "../src/models.js";
import { complete } from "../src/stream.js";
import type { Api, AssistantMessage, Context, Model, OptionsForApi, UserMessage } from "../src/types.js";
async function testEmptyMessage<TApi extends Api>(llm: Model<TApi>, options: OptionsForApi<TApi> = {}) {

View file

@ -3,8 +3,9 @@ import { readFileSync } from "fs";
import { dirname, join } from "path";
import { fileURLToPath } from "url";
import { afterAll, beforeAll, describe, expect, it } from "vitest";
import { complete, stream } from "../src/generate.js";
import { z } from "zod";
import { getModel } from "../src/models.js";
import { complete, stream } from "../src/stream.js";
import type { Api, Context, ImageContent, Model, OptionsForApi, Tool, ToolResultMessage } from "../src/types.js";
const __filename = fileURLToPath(import.meta.url);
@ -14,19 +15,13 @@ const __dirname = dirname(__filename);
// Calculator tool shared by the provider test suites; parameters are declared
// as a Zod schema, which providers convert to JSON Schema on the wire.
// Fix: the old JSON-Schema `parameters` literal had been left above the zod
// schema, producing a duplicate object-literal property — only the zod version
// is kept.
const calculatorTool: Tool = {
    name: "calculator",
    description: "Perform basic arithmetic operations",
    parameters: z.object({
        a: z.number().describe("First number"),
        b: z.number().describe("Second number"),
        operation: z
            .enum(["add", "subtract", "multiply", "divide"])
            .describe("The operation to perform. One of 'add', 'subtract', 'multiply', 'divide'."),
    }),
};
async function basicTextGeneration<TApi extends Api>(model: Model<TApi>, options?: OptionsForApi<TApi>) {

View file

@ -1,19 +1,16 @@
import { describe, expect, it } from "vitest";
import { complete } from "../src/generate.js";
import { z } from "zod";
import { getModel } from "../src/models.js";
import { complete } from "../src/stream.js";
import type { Api, AssistantMessage, Context, Message, Model, Tool, ToolResultMessage } from "../src/types.js";
// Tool for testing
// Weather tool used to exercise tool-call round trips in context tests.
// Fix: the old JSON-Schema `parameters` literal had been left above the zod
// schema, producing a duplicate object-literal property — only the zod version
// is kept.
const weatherTool: Tool = {
    name: "get_weather",
    description: "Get the weather for a location",
    parameters: z.object({
        location: z.string().describe("City name"),
    }),
};
// Pre-built contexts representing typical outputs from each provider

View file

@ -0,0 +1,112 @@
import { describe, expect, it } from "vitest";
import { z } from "zod";
import type { AgentTool } from "../src/agent/types.js";
describe("Tool Validation with Zod", () => {
    // Define a test tool with Zod schema.
    // Exercises: required non-empty string, bounded integer, email format,
    // and an optional string array.
    const testSchema = z.object({
        name: z.string().min(1, "Name is required"),
        age: z.number().int().min(0).max(150),
        email: z.string().email("Invalid email format"),
        tags: z.array(z.string()).optional(),
    });

    // AgentTool ties the schema to a typed execute function; `args` is inferred
    // from the schema's output type.
    const testTool: AgentTool<typeof testSchema, void> = {
        label: "Test Tool",
        name: "test_tool",
        description: "A test tool for validation",
        parameters: testSchema,
        execute: async (_toolCallId, args) => {
            return {
                output: `Processed: ${args.name}, ${args.age}, ${args.email}`,
                details: undefined,
            };
        },
    };

    it("should validate correct input", () => {
        const validInput = {
            name: "John Doe",
            age: 30,
            email: "john@example.com",
            tags: ["developer", "typescript"],
        };
        // This should not throw; parse returns the validated (identical) object.
        const result = testTool.parameters.parse(validInput);
        expect(result).toEqual(validInput);
    });

    it("should reject invalid email", () => {
        const invalidInput = {
            name: "John Doe",
            age: 30,
            email: "not-an-email",
        };
        expect(() => testTool.parameters.parse(invalidInput)).toThrowError(z.ZodError);
    });

    it("should reject missing required fields", () => {
        // `name` is absent — schema requires it.
        const invalidInput = {
            age: 30,
            email: "john@example.com",
        };
        expect(() => testTool.parameters.parse(invalidInput)).toThrowError(z.ZodError);
    });

    it("should reject invalid age", () => {
        // Negative age violates the .min(0) bound.
        const invalidInput = {
            name: "John Doe",
            age: -5,
            email: "john@example.com",
        };
        expect(() => testTool.parameters.parse(invalidInput)).toThrowError(z.ZodError);
    });

    it("should format validation errors nicely", () => {
        // Every field fails here; the formatting below mirrors what the shared
        // validation module produces for tool-call errors.
        // NOTE(review): the expected strings are tied to Zod's built-in error
        // messages — confirm they match the installed Zod major version.
        const invalidInput = {
            name: "",
            age: 200,
            email: "invalid",
        };
        try {
            testTool.parameters.parse(invalidInput);
            // Should not reach here
            expect(true).toBe(false);
        } catch (e) {
            if (e instanceof z.ZodError) {
                const errors = e.issues
                    .map((err) => {
                        const path = err.path.length > 0 ? err.path.join(".") : "root";
                        return ` - ${path}: ${err.message}`;
                    })
                    .join("\n");
                expect(errors).toContain("name: Name is required");
                expect(errors).toContain("age: Number must be less than or equal to 150");
                expect(errors).toContain("email: Invalid email format");
            } else {
                throw e;
            }
        }
    });

    it("should have type-safe execute function", async () => {
        const validInput = {
            name: "John Doe",
            age: 30,
            email: "john@example.com",
        };
        // Validate and execute; the validated value type-checks against execute's args.
        const validated = testTool.parameters.parse(validInput);
        const result = await testTool.execute("test-id", validated);
        expect(result.output).toBe("Processed: John Doe, 30, john@example.com");
        expect(result.details).toBeUndefined();
    });
});