diff --git a/README.md b/README.md index f577490a..f3fe3686 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,7 @@ A collection of tools for managing LLM deployments and building AI agents. ## Packages +- **[@mariozechner/pi-ai](packages/ai)** - Unified multi-provider LLM API - **[@mariozechner/pi-tui](packages/tui)** - Terminal UI library with differential rendering - **[@mariozechner/pi-agent](packages/agent)** - General-purpose agent with tool calling and session persistence - **[@mariozechner/pi](packages/pods)** - CLI for managing vLLM deployments on GPU pods diff --git a/package-lock.json b/package-lock.json index fefd38ed..25c18133 100644 --- a/package-lock.json +++ b/package-lock.json @@ -20,9 +20,9 @@ } }, "node_modules/@anthropic-ai/sdk": { - "version": "0.60.0", - "resolved": "https://registry.npmjs.org/@anthropic-ai/sdk/-/sdk-0.60.0.tgz", - "integrity": "sha512-9zu/TXaUy8BZhXedDtt1wT3H4LOlpKDO1/ftiFpeR3N1PCr3KJFKkxxlQWWt1NNp08xSwUNJ3JNY8yhl8av6eQ==", + "version": "0.61.0", + "resolved": "https://registry.npmjs.org/@anthropic-ai/sdk/-/sdk-0.61.0.tgz", + "integrity": "sha512-GnlOXrPxow0uoaVB3DGNh9EJBU1MyagCBCLpU+bwDVlj/oOPYIwoiasMWlykkfYcQOrDP2x/zHnRD0xN7PeZPw==", "license": "MIT", "bin": { "anthropic-ai-sdk": "bin/cli" @@ -634,9 +634,9 @@ } }, "node_modules/@google/genai": { - "version": "1.15.0", - "resolved": "https://registry.npmjs.org/@google/genai/-/genai-1.15.0.tgz", - "integrity": "sha512-4CSW+hRTESWl3xVtde7pkQ3E+dDFhDq+m4ztmccRctZfx1gKy3v0M9STIMGk6Nq0s6O2uKMXupOZQ1JGorXVwQ==", + "version": "1.17.0", + "resolved": "https://registry.npmjs.org/@google/genai/-/genai-1.17.0.tgz", + "integrity": "sha512-r/OZWN9D8WvYrte3bcKPoLODrZ+2TjfxHm5OOyVHUbdFYIp1C4yJaXX4+sCS8I/+CbN9PxLjU5zm1cgmS7qz+A==", "license": "Apache-2.0", "dependencies": { "google-auth-library": "^9.14.2", @@ -646,7 +646,7 @@ "node": ">=20.0.0" }, "peerDependencies": { - "@modelcontextprotocol/sdk": "^1.11.0" + "@modelcontextprotocol/sdk": "^1.11.4" }, "peerDependenciesMeta": { 
"@modelcontextprotocol/sdk": { @@ -1310,9 +1310,9 @@ } }, "node_modules/chalk": { - "version": "5.5.0", - "resolved": "https://registry.npmjs.org/chalk/-/chalk-5.5.0.tgz", - "integrity": "sha512-1tm8DTaJhPBG3bIkVeZt1iZM9GfSX2lzOeDVZH9R9ffRHpmHvxZ/QhgQH/aDTkswQVt+YHdXAdS/In/30OjCbg==", + "version": "5.6.2", + "resolved": "https://registry.npmjs.org/chalk/-/chalk-5.6.2.tgz", + "integrity": "sha512-7NzBL0rN6fMUW+f7A6Io4h40qQlG+xGmtMxfbnH/K7TAtt8JQWVQK+6g0UXKMeVJoyV5EkkNsErQ8pVD3bLHbA==", "license": "MIT", "engines": { "node": "^12.17.0 || ^14.13 || >=16.0.0" @@ -1907,9 +1907,9 @@ } }, "node_modules/openai": { - "version": "5.15.0", - "resolved": "https://registry.npmjs.org/openai/-/openai-5.15.0.tgz", - "integrity": "sha512-kcUdws8K/A8m02I+IqFBwO51gS+87GP89yWEufGbzEi8anBz4FB/bti2QxaJdGwwY4mwJGzx85XO7TuL/Tpu1w==", + "version": "5.20.0", + "resolved": "https://registry.npmjs.org/openai/-/openai-5.20.0.tgz", + "integrity": "sha512-Bmc2zLM/YWgFrDpXr9hwXqGGDdMmMpE9+qoZPsaHpn0Y/Qk1Vu26hNqXo7+nHdli+sLsXINvS1f8kR3NKhGKmA==", "license": "Apache-2.0", "bin": { "openai": "bin/cli" @@ -2714,12 +2714,31 @@ } } }, + "node_modules/zod": { + "version": "3.25.76", + "resolved": "https://registry.npmjs.org/zod/-/zod-3.25.76.tgz", + "integrity": "sha512-gzUt/qt81nXsFGKIFcC3YnfEAx5NkunCfnDlvuBSSFS02bcXu4Lmea0AFIUwbLWxWPx3d9p8S5QoaujKcNQxcQ==", + "license": "MIT", + "peer": true, + "funding": { + "url": "https://github.com/sponsors/colinhacks" + } + }, + "node_modules/zod-to-json-schema": { + "version": "3.24.6", + "resolved": "https://registry.npmjs.org/zod-to-json-schema/-/zod-to-json-schema-3.24.6.tgz", + "integrity": "sha512-h/z3PKvcTcTetyjl1fkj79MHNEjm+HpD6NXheWjzOekY7kV+lwDYnHw+ivHkijnCSMz1yJaWBD9vu/Fcmk+vEg==", + "license": "ISC", + "peerDependencies": { + "zod": "^3.24.1" + } + }, "packages/agent": { "name": "@mariozechner/pi-agent", "version": "0.5.31", "license": "MIT", "dependencies": { - "@mariozechner/pi-tui": "^0.5.30", + "@mariozechner/pi-tui": "^0.5.31", "@types/glob": 
"^8.1.0", "chalk": "^5.5.0", "glob": "^11.0.3", @@ -3101,10 +3120,11 @@ "version": "0.5.31", "license": "MIT", "dependencies": { - "@anthropic-ai/sdk": "^0.60.0", - "@google/genai": "^1.15.0", - "chalk": "^5.5.0", - "openai": "^5.15.0" + "@anthropic-ai/sdk": "^0.61.0", + "@google/genai": "^1.17.0", + "chalk": "^5.6.2", + "openai": "^5.20.0", + "zod-to-json-schema": "^3.24.6" }, "devDependencies": { "@types/node": "^24.3.0", @@ -3137,7 +3157,7 @@ "version": "0.5.31", "license": "MIT", "dependencies": { - "@mariozechner/pi-agent": "^0.5.30", + "@mariozechner/pi-agent": "^0.5.31", "chalk": "^5.5.0" }, "bin": { diff --git a/packages/ai/README.md b/packages/ai/README.md index 942506c9..5c2af079 100644 --- a/packages/ai/README.md +++ b/packages/ai/README.md @@ -24,20 +24,18 @@ npm install @mariozechner/pi-ai ## Quick Start ```typescript -import { getModel, stream, complete, Context, Tool } from '@mariozechner/pi-ai'; +import { getModel, stream, complete, Context, Tool, z } from '@mariozechner/pi-ai'; // Fully typed with auto-complete support for both providers and models const model = getModel('openai', 'gpt-4o-mini'); -// Define tools +// Define tools with Zod schemas for type safety and validation const tools: Tool[] = [{ name: 'get_time', description: 'Get the current time', - parameters: { - type: 'object', - properties: {}, - required: [] - } + parameters: z.object({ + timezone: z.string().optional().describe('Optional timezone (e.g., America/New_York)') + }) }]; // Build a conversation context (easily serializable and transferable between models) @@ -94,7 +92,11 @@ const toolCalls = finalMessage.content.filter(b => b.type === 'toolCall'); for (const call of toolCalls) { // Execute the tool const result = call.name === 'get_time' - ? new Date().toISOString() + ? 
new Date().toLocaleString('en-US', { + timeZone: call.arguments.timezone || 'UTC', + dateStyle: 'full', + timeStyle: 'long' + }) : 'Unknown tool'; // Add tool result to context @@ -102,7 +104,7 @@ for (const call of toolCalls) { role: 'toolResult', toolCallId: call.id, toolName: call.name, - content: result, + output: result, isError: false }); } @@ -129,6 +131,70 @@ for (const block of response.content) { } ``` +## Tools + +Tools enable LLMs to interact with external systems. This library uses Zod schemas for type-safe tool definitions with automatic validation. + +### Defining Tools + +```typescript +import { z, Tool } from '@mariozechner/pi-ai'; + +// Define tool parameters with Zod +const weatherTool: Tool = { + name: 'get_weather', + description: 'Get current weather for a location', + parameters: z.object({ + location: z.string().describe('City name or coordinates'), + units: z.enum(['celsius', 'fahrenheit']).default('celsius') + }) +}; + +// Complex validation with Zod refinements +const bookMeetingTool: Tool = { + name: 'book_meeting', + description: 'Schedule a meeting', + parameters: z.object({ + title: z.string().min(1), + startTime: z.string().datetime(), + endTime: z.string().datetime(), + attendees: z.array(z.string().email()).min(1) + }).refine( + data => new Date(data.endTime) > new Date(data.startTime), + { message: 'End time must be after start time' } + ) +}; +``` + +### Handling Tool Calls + +```typescript +const context: Context = { + messages: [{ role: 'user', content: 'What is the weather in London?' 
}], + tools: [weatherTool] +}; + +const response = await complete(model, context); + +// Check for tool calls in the response +for (const block of response.content) { + if (block.type === 'toolCall') { + // Arguments are automatically validated against the Zod schema + // If validation fails, an error event is emitted + const result = await executeWeatherApi(block.arguments); + + // Add tool result to continue the conversation + context.messages.push({ + role: 'toolResult', + toolCallId: block.id, + toolName: block.name, + output: JSON.stringify(result), + isError: false + }); + } +} +``` + ## Image Input Models with vision capabilities can process images. You can check if a model supports images via the `input` property. If you pass images to a non-vision model, they are silently ignored. @@ -260,7 +326,7 @@ for await (const event of s) { ## Errors & Abort Signal -When a request ends with an error (including aborts), the API returns an `AssistantMessage` with: +When a request ends with an error (including aborts and tool call validation errors), the API returns an `AssistantMessage` with: - `stopReason: 'error'` - Indicates the request ended with an error - `error: string` - Error message describing what happened - `content: array` - **Partial content** accumulated before the error @@ -503,6 +569,189 @@ const continuation = await complete(newModel, restored); > **Note**: If the context contains images (encoded as base64 as shown in the Image Input section), those will also be serialized. +## Agent API + +The Agent API provides a higher-level interface for building agents with tools. It handles tool execution, validation, and provides detailed event streaming for interactive applications. + +### Event System + +The Agent API streams events during execution, allowing you to build reactive UIs and track agent progress. The agent processes prompts in **turns**, where each turn consists of: +1. An assistant message (the LLM's response) +2. 
Optional tool executions if the assistant calls tools +3. Tool result messages that are fed back to the LLM + +This continues until the assistant produces a response without tool calls. + +### Event Flow Example + +Given a prompt asking to calculate two expressions and sum them: + +```typescript +import { prompt, getModel, AgentContext, calculateTool } from '@mariozechner/pi-ai'; + +const context: AgentContext = { + systemPrompt: 'You are a helpful math assistant.', + messages: [], + tools: [calculateTool] +}; + +const stream = prompt( + { role: 'user', content: 'Calculate 15 * 20 and 30 * 40, then sum the results' }, + context, + { model: getModel('openai', 'gpt-4o-mini') } +); + +// Expected event sequence: +// 1. agent_start - Agent begins processing +// 2. turn_start - First turn begins +// 3. message_start - User message starts +// 4. message_end - User message ends +// 5. message_start - Assistant message starts +// 6. message_update - Assistant streams response with tool calls +// 7. message_end - Assistant message ends +// 8. tool_execution_start - First calculation (15 * 20) +// 9. tool_execution_end - Result: 300 +// 10. tool_execution_start - Second calculation (30 * 40) +// 11. tool_execution_end - Result: 1200 +// 12. message_start - Tool result message for first calculation +// 13. message_end - Tool result message ends +// 14. message_start - Tool result message for second calculation +// 15. message_end - Tool result message ends +// 16. turn_end - First turn ends with 2 tool results +// 17. turn_start - Second turn begins +// 18. message_start - Assistant message starts +// 19. message_update - Assistant streams response with sum calculation +// 20. message_end - Assistant message ends +// 21. tool_execution_start - Sum calculation (300 + 1200) +// 22. tool_execution_end - Result: 1500 +// 23. message_start - Tool result message for sum +// 24. message_end - Tool result message ends +// 25. turn_end - Second turn ends with 1 tool result +// 26. 
turn_start - Third turn begins +// 27. message_start - Final assistant message starts +// 28. message_update - Assistant streams final answer +// 29. message_end - Final assistant message ends +// 30. turn_end - Third turn ends with 0 tool results +// 31. agent_end - Agent completes with all messages +``` + +### Handling Events + +```typescript +for await (const event of stream) { + switch (event.type) { + case 'agent_start': + console.log('Agent started'); + break; + + case 'turn_start': + console.log('New turn started'); + break; + + case 'message_start': + console.log(`${event.message.role} message started`); + break; + + case 'message_update': + // Only for assistant messages during streaming + if (event.message.content.some(c => c.type === 'text')) { + console.log('Assistant:', event.message.content); + } + break; + + case 'tool_execution_start': + console.log(`Calling ${event.toolName} with:`, event.args); + break; + + case 'tool_execution_end': + if (event.isError) { + console.error(`Tool failed:`, event.result); + } else { + console.log(`Tool result:`, event.result.output); + } + break; + + case 'turn_end': + console.log(`Turn ended with ${event.toolResults.length} tool calls`); + break; + + case 'agent_end': + console.log(`Agent completed with ${event.messages.length} new messages`); + break; + } +} + +// Get all messages generated during this agent execution +// These include the user message and can be directly appended to context.messages +const messages = await stream.result(); +context.messages.push(...messages); +``` + +### Defining Tools with Zod + +Tools use Zod schemas for runtime validation and type inference: + +```typescript +import { z } from 'zod'; +import { AgentTool, AgentToolResult } from '@mariozechner/pi-ai'; + +const weatherSchema = z.object({ + city: z.string().min(1, 'City is required'), + units: z.enum(['celsius', 'fahrenheit']).default('celsius') +}); + +const weatherTool: AgentTool = { + label: 'Get Weather', + name: 'get_weather', 
+ description: 'Get current weather for a city', + parameters: weatherSchema, + execute: async (toolCallId, args) => { + // args is fully typed: { city: string, units: 'celsius' | 'fahrenheit' } + const temp = Math.round(Math.random() * 30); + return { + output: `Temperature in ${args.city}: ${temp}°${args.units[0].toUpperCase()}`, + details: { temp } + }; + } +}; +``` + +### Validation and Error Handling + +Tool arguments are automatically validated using the Zod schema. Invalid arguments result in detailed error messages: + +```typescript +// If the LLM calls with invalid arguments: +// get_weather({ city: '', units: 'kelvin' }) + +// The tool execution will fail with: +/* +Validation failed for tool "get_weather": + - city: City is required + - units: Invalid enum value. Expected 'celsius' | 'fahrenheit', received 'kelvin' + +Received arguments: +{ + "city": "", + "units": "kelvin" +} +*/ +``` + +### Built-in Example Tools + +The library includes example tools for common operations: + +```typescript +import { calculateTool, getCurrentTimeTool } from '@mariozechner/pi-ai'; + +const context: AgentContext = { + systemPrompt: 'You are a helpful assistant.', + messages: [], + tools: [calculateTool, getCurrentTimeTool] +}; +``` + ## Browser Usage The library supports browser environments. You must pass the API key explicitly since environment variables are not available in browsers: @@ -533,6 +782,7 @@ GEMINI_API_KEY=... GROQ_API_KEY=gsk_... CEREBRAS_API_KEY=csk-... XAI_API_KEY=xai-... +ZAI_API_KEY=... OPENROUTER_API_KEY=sk-or-... 
``` @@ -549,6 +799,21 @@ const response = await complete(model, context, { }); ``` +### Programmatic API Key Management + +You can also set and get API keys programmatically: + +```typescript +import { setApiKey, getApiKey } from '@mariozechner/pi-ai'; + +// Set API key for a provider +setApiKey('openai', 'sk-...'); +setApiKey('anthropic', 'sk-ant-...'); + +// Get API key for a provider (checks both programmatic and env vars) +const key = getApiKey('openai'); +``` + ## License MIT \ No newline at end of file diff --git a/packages/ai/package.json b/packages/ai/package.json index 5fee305c..628843a1 100644 --- a/packages/ai/package.json +++ b/packages/ai/package.json @@ -19,10 +19,11 @@ "prepublishOnly": "npm run clean && npm run build" }, "dependencies": { - "@anthropic-ai/sdk": "^0.60.0", - "@google/genai": "^1.15.0", - "chalk": "^5.5.0", - "openai": "^5.15.0" + "@anthropic-ai/sdk": "^0.61.0", + "@google/genai": "^1.17.0", + "chalk": "^5.6.2", + "openai": "^5.20.0", + "zod-to-json-schema": "^3.24.6" }, "keywords": [ "ai", diff --git a/packages/ai/src/agent/agent.ts b/packages/ai/src/agent/agent.ts index 2ec05833..540995a5 100644 --- a/packages/ai/src/agent/agent.ts +++ b/packages/ai/src/agent/agent.ts @@ -1,97 +1,68 @@ import { EventStream } from "../event-stream"; -import { streamSimple } from "../generate.js"; -import type { - AssistantMessage, - Context, - Message, - Model, - SimpleGenerateOptions, - ToolResultMessage, - UserMessage, -} from "../types.js"; -import type { AgentContext, AgentTool, AgentToolResult } from "./types"; - -// Event types -export type AgentEvent = - | { type: "message_start"; message: Message } - | { type: "message_update"; message: AssistantMessage } - | { type: "message_complete"; message: Message } - | { type: "tool_execution_start"; toolCallId: string; toolName: string; args: any } - | { - type: "tool_execution_complete"; - toolCallId: string; - toolName: string; - result: AgentToolResult | string; - isError: boolean; - } - | { type: 
"turn_complete"; messages: AgentContext["messages"] }; - -// Configuration for prompt execution -export interface PromptConfig { - model: Model; - apiKey: string; - enableThinking?: boolean; - preprocessor?: (messages: AgentContext["messages"], abortSignal?: AbortSignal) => Promise; -} +import { streamSimple } from "../stream.js"; +import type { AssistantMessage, Context, Message, ToolResultMessage, UserMessage } from "../types.js"; +import { validateToolArguments } from "../validation.js"; +import type { AgentContext, AgentEvent, AgentTool, AgentToolResult, PromptConfig } from "./types"; // Main prompt function - returns a stream of events export function prompt( + prompt: UserMessage, context: AgentContext, config: PromptConfig, - prompt: UserMessage, signal?: AbortSignal, ): EventStream { const stream = new EventStream( - (event) => event.type === "turn_complete", - (event) => (event.type === "turn_complete" ? event.messages : []), + (event) => event.type === "agent_end", + (event) => (event.type === "agent_end" ? 
event.messages : []), ); // Run the prompt async (async () => { - try { - // Track new messages generated during this prompt - const newMessages: AgentContext["messages"] = []; + // Track new messages generated during this prompt + const newMessages: AgentContext["messages"] = []; + // Create user message for the prompt + const messages = [...context.messages, prompt]; + newMessages.push(prompt); - // Create user message - const messages = [...context.messages, prompt]; - newMessages.push(prompt); + stream.push({ type: "agent_start" }); + stream.push({ type: "turn_start" }); + stream.push({ type: "message_start", message: prompt }); + stream.push({ type: "message_end", message: prompt }); - stream.push({ type: "message_start", message: prompt }); - stream.push({ type: "message_complete", message: prompt }); + // Update context with new messages + const currentContext: AgentContext = { + ...context, + messages, + }; - // Update context with new messages - const currentContext: AgentContext = { - ...context, - messages, - }; - - // Keep looping while we have tool calls - let hasMoreToolCalls = true; - while (hasMoreToolCalls) { - // Stream assistant response - const assistantMessage = await streamAssistantResponse(currentContext, config, signal, stream); - newMessages.push(assistantMessage); - - // Check for tool calls - const toolCalls = assistantMessage.content.filter((c) => c.type === "toolCall"); - hasMoreToolCalls = toolCalls.length > 0; - - if (hasMoreToolCalls) { - // Execute tool calls - const toolResults = await executeToolCalls(currentContext.tools, assistantMessage, signal, stream); - newMessages.push(...toolResults); - - // Add tool results to context - currentContext.messages = [...currentContext.messages, ...toolResults]; - } + // Keep looping while we have tool calls + let hasMoreToolCalls = true; + let firstTurn = true; + while (hasMoreToolCalls) { + if (!firstTurn) { + stream.push({ type: "turn_start" }); + } else { + firstTurn = false; } + // Stream 
assistant response + const assistantMessage = await streamAssistantResponse(currentContext, config, signal, stream); + newMessages.push(assistantMessage); - stream.push({ type: "turn_complete", messages: newMessages }); - } catch (error) { - // End stream on error - stream.end([]); - throw error; + // Check for tool calls + const toolCalls = assistantMessage.content.filter((c) => c.type === "toolCall"); + hasMoreToolCalls = toolCalls.length > 0; + + const toolResults: ToolResultMessage[] = []; + if (hasMoreToolCalls) { + // Execute tool calls + toolResults.push(...(await executeToolCalls(currentContext.tools, assistantMessage, signal, stream))); + currentContext.messages.push(...toolResults); + newMessages.push(...toolResults); + } + stream.push({ type: "turn_end", assistantMessage, toolResults: toolResults }); } + stream.push({ type: "agent_end", messages: newMessages }); + stream.end(newMessages); })(); return stream; @@ -122,16 +93,7 @@ async function streamAssistantResponse( tools: context.tools, // AgentTool extends Tool, so this works }; - const options: SimpleGenerateOptions = { - apiKey: config.apiKey, - signal, - }; - - if (config.model.reasoning && config.enableThinking) { - options.reasoning = "medium"; - } - - const response = await streamSimple(config.model, processedContext, options); + const response = await streamSimple(config.model, processedContext, { ...config, signal }); let partialMessage: AssistantMessage | null = null; let addedPartial = false; @@ -147,14 +109,17 @@ async function streamAssistantResponse( case "text_start": case "text_delta": + case "text_end": case "thinking_start": case "thinking_delta": + case "thinking_end": case "toolcall_start": case "toolcall_delta": + case "toolcall_end": if (partialMessage) { partialMessage = event.partial; context.messages[context.messages.length - 1] = partialMessage; - stream.push({ type: "message_update", message: { ...partialMessage } }); + stream.push({ type: "message_update", 
assistantMessageEvent: event, message: { ...partialMessage } }); } break; @@ -166,7 +131,7 @@ async function streamAssistantResponse( } else { context.messages.push(finalMessage); } - stream.push({ type: "message_complete", message: finalMessage }); + stream.push({ type: "message_end", message: finalMessage }); return finalMessage; } } @@ -176,7 +141,7 @@ async function streamAssistantResponse( } async function executeToolCalls( - tools: AgentTool[] | undefined, + tools: AgentTool[] | undefined, assistantMessage: AssistantMessage, signal: AbortSignal | undefined, stream: EventStream, @@ -199,14 +164,19 @@ async function executeToolCalls( try { if (!tool) throw new Error(`Tool ${toolCall.name} not found`); - resultOrError = await tool.execute(toolCall.arguments, toolCall.id, signal); + + // Validate arguments using shared validation function + const validatedArgs = validateToolArguments(tool, toolCall); + + // Execute with validated, typed arguments + resultOrError = await tool.execute(toolCall.id, validatedArgs, signal); } catch (e) { - resultOrError = `Error: ${e instanceof Error ? e.message : String(e)}`; + resultOrError = e instanceof Error ? 
e.message : String(e); isError = true; } stream.push({ - type: "tool_execution_complete", + type: "tool_execution_end", toolCallId: toolCall.id, toolName: toolCall.name, result: resultOrError, @@ -224,7 +194,7 @@ async function executeToolCalls( results.push(toolResultMessage); stream.push({ type: "message_start", message: toolResultMessage }); - stream.push({ type: "message_complete", message: toolResultMessage }); + stream.push({ type: "message_end", message: toolResultMessage }); } return results; diff --git a/packages/ai/src/agent/index.ts b/packages/ai/src/agent/index.ts index df4b022f..af776d60 100644 --- a/packages/ai/src/agent/index.ts +++ b/packages/ai/src/agent/index.ts @@ -1,3 +1,3 @@ -export { type AgentEvent, type PromptConfig, prompt } from "./agent"; +export { prompt } from "./agent"; export * from "./tools"; -export type { AgentContext, AgentTool } from "./types"; +export type { AgentContext, AgentEvent, AgentTool, PromptConfig } from "./types"; diff --git a/packages/ai/src/agent/tools/calculate.ts b/packages/ai/src/agent/tools/calculate.ts index c0eff265..92f71dd1 100644 --- a/packages/ai/src/agent/tools/calculate.ts +++ b/packages/ai/src/agent/tools/calculate.ts @@ -1,3 +1,4 @@ +import { z } from "zod"; import type { AgentTool } from "../../agent"; export interface CalculateResult { @@ -14,21 +15,16 @@ export function calculate(expression: string): CalculateResult { } } -export const calculateTool: AgentTool = { +const calculateSchema = z.object({ + expression: z.string().describe("The mathematical expression to evaluate"), +}); + +export const calculateTool: AgentTool = { label: "Calculator", name: "calculate", description: "Evaluate mathematical expressions", - parameters: { - type: "object", - properties: { - expression: { - type: "string", - description: "The mathematical expression to evaluate", - }, - }, - required: ["expression"], - }, - execute: async (args: { expression: string }) => { + parameters: calculateSchema, + execute: async 
(_toolCallId, args) => { return calculate(args.expression); }, }; diff --git a/packages/ai/src/agent/tools/get-current-time.ts b/packages/ai/src/agent/tools/get-current-time.ts index a4774302..31a5f068 100644 --- a/packages/ai/src/agent/tools/get-current-time.ts +++ b/packages/ai/src/agent/tools/get-current-time.ts @@ -1,3 +1,4 @@ +import { z } from "zod"; import type { AgentTool } from "../../agent"; import type { AgentToolResult } from "../types"; @@ -25,20 +26,16 @@ export async function getCurrentTime(timezone?: string): Promise = { +const getCurrentTimeSchema = z.object({ + timezone: z.string().optional().describe("Optional timezone (e.g., 'America/New_York', 'Europe/London')"), +}); + +export const getCurrentTimeTool: AgentTool = { label: "Current Time", name: "get_current_time", description: "Get the current date and time", - parameters: { - type: "object", - properties: { - timezone: { - type: "string", - description: "Optional timezone (e.g., 'America/New_York', 'Europe/London')", - }, - }, - }, - execute: async (args: { timezone?: string }) => { + parameters: getCurrentTimeSchema, + execute: async (_toolCallId, args) => { return getCurrentTime(args.timezone); }, }; diff --git a/packages/ai/src/agent/types.ts b/packages/ai/src/agent/types.ts index d8ddcac9..42866694 100644 --- a/packages/ai/src/agent/types.ts +++ b/packages/ai/src/agent/types.ts @@ -1,4 +1,13 @@ -import type { Message, Tool } from "../types.js"; +import type { ZodSchema, z } from "zod"; +import type { + AssistantMessage, + AssistantMessageEvent, + Message, + Model, + SimpleStreamOptions, + Tool, + ToolResultMessage, +} from "../types.js"; export interface AgentToolResult { // Output of the tool to be given to the LLM in ToolResultMessage.content @@ -8,10 +17,14 @@ export interface AgentToolResult { } // AgentTool extends Tool but adds the execute function -export interface AgentTool extends Tool { +export interface AgentTool extends Tool { // A human-readable label for the tool to be 
displayed in UI label: string; - execute: (params: any, toolCallId: string, signal?: AbortSignal) => Promise>; + execute: ( + toolCallId: string, + params: z.infer, + signal?: AbortSignal, + ) => Promise>; } // AgentContext is like Context but uses AgentTool @@ -20,3 +33,37 @@ export interface AgentContext { messages: Message[]; tools?: AgentTool[]; } + +// Event types +export type AgentEvent = + // Emitted when the agent starts. An agent can emit multiple turns + | { type: "agent_start" } + // Emitted when a turn starts. A turn can emit an optional user message (initial prompt), an assistant message (response) and multiple tool result messages + | { type: "turn_start" } + // Emitted when a user, assistant or tool result message starts + | { type: "message_start"; message: Message } + // Emitted when an assistant message is updated due to streaming + | { type: "message_update"; assistantMessageEvent: AssistantMessageEvent; message: AssistantMessage } + // Emitted when a user, assistant or tool result message is complete + | { type: "message_end"; message: Message } + // Emitted when a tool execution starts + | { type: "tool_execution_start"; toolCallId: string; toolName: string; args: any } + // Emitted when a tool execution completes + | { + type: "tool_execution_end"; + toolCallId: string; + toolName: string; + result: AgentToolResult | string; + isError: boolean; + } + // Emitted when a full turn completes + | { type: "turn_end"; assistantMessage: AssistantMessage; toolResults: ToolResultMessage[] } + // Emitted when the agent has completed all its turns. 
All messages from every turn are + // contained in messages, which can be appended to the context + | { type: "agent_end"; messages: AgentContext["messages"] }; + +// Configuration for prompt execution +export interface PromptConfig extends SimpleStreamOptions { + model: Model; + preprocessor?: (messages: AgentContext["messages"], abortSignal?: AbortSignal) => Promise; +} diff --git a/packages/ai/src/index.ts b/packages/ai/src/index.ts index e0e24534..3c8073ed 100644 --- a/packages/ai/src/index.ts +++ b/packages/ai/src/index.ts @@ -1,8 +1,9 @@ +export { z } from "zod"; export * from "./agent/index.js"; -export * from "./generate.js"; export * from "./models.js"; export * from "./providers/anthropic.js"; export * from "./providers/google.js"; export * from "./providers/openai-completions.js"; export * from "./providers/openai-responses.js"; +export * from "./stream.js"; export * from "./types.js"; diff --git a/packages/ai/src/models.generated.ts b/packages/ai/src/models.generated.ts index 05a0cbbf..67694c6b 100644 --- a/packages/ai/src/models.generated.ts +++ b/packages/ai/src/models.generated.ts @@ -1413,6 +1413,23 @@ export const MODELS = { } satisfies Model<"anthropic-messages">, }, openrouter: { + "nvidia/nemotron-nano-9b-v2": { + id: "nvidia/nemotron-nano-9b-v2", + name: "NVIDIA: Nemotron Nano 9B V2", + api: "openai-completions", + provider: "openrouter", + baseUrl: "https://openrouter.ai/api/v1", + reasoning: true, + input: ["text"], + cost: { + input: 0, + output: 0, + cacheRead: 0, + cacheWrite: 0, + }, + contextWindow: 128000, + maxTokens: 4096, + } satisfies Model<"openai-completions">, "openrouter/sonoma-dusk-alpha": { id: "openrouter/sonoma-dusk-alpha", name: "Sonoma Dusk Alpha", diff --git a/packages/ai/src/providers/anthropic.ts b/packages/ai/src/providers/anthropic.ts index 683f0ed1..832c48cc 100644 --- a/packages/ai/src/providers/anthropic.ts +++ b/packages/ai/src/providers/anthropic.ts @@ -4,32 +4,34 @@ import type { MessageCreateParamsStreaming, 
MessageParam, } from "@anthropic-ai/sdk/resources/messages.js"; +import { zodToJsonSchema } from "zod-to-json-schema"; import { AssistantMessageEventStream } from "../event-stream.js"; import { calculateCost } from "../models.js"; import type { Api, AssistantMessage, Context, - GenerateFunction, - GenerateOptions, Message, Model, StopReason, + StreamFunction, + StreamOptions, TextContent, ThinkingContent, Tool, ToolCall, ToolResultMessage, } from "../types.js"; +import { validateToolArguments } from "../validation.js"; import { transformMessages } from "./transorm-messages.js"; -export interface AnthropicOptions extends GenerateOptions { +export interface AnthropicOptions extends StreamOptions { thinkingEnabled?: boolean; thinkingBudgetTokens?: number; toolChoice?: "auto" | "any" | "none" | { type: "tool"; name: string }; } -export const streamAnthropic: GenerateFunction<"anthropic-messages"> = ( +export const streamAnthropic: StreamFunction<"anthropic-messages"> = ( model: Model<"anthropic-messages">, context: Context, options?: AnthropicOptions, @@ -159,6 +161,15 @@ export const streamAnthropic: GenerateFunction<"anthropic-messages"> = ( }); } else if (block.type === "toolCall") { block.arguments = JSON.parse(block.partialJson); + + // Validate tool arguments if tool definition is available + if (context.tools) { + const tool = context.tools.find((t) => t.name === block.name); + if (tool) { + block.arguments = validateToolArguments(tool, block); + } + } + delete (block as any).partialJson; stream.push({ type: "toolcall_end", @@ -390,7 +401,7 @@ function convertMessages(messages: Message[], model: Model<"anthropic-messages"> content: blocks, }); } else if (msg.role === "toolResult") { - // Collect all consecutive toolResult messages + // Collect all consecutive toolResult messages, needed for z.ai Anthropic endpoint const toolResults: ContentBlockParam[] = []; // Add the current tool result @@ -430,15 +441,19 @@ function convertMessages(messages: Message[], model: 
Model<"anthropic-messages"> function convertTools(tools: Tool[]): Anthropic.Messages.Tool[] { if (!tools) return []; - return tools.map((tool) => ({ - name: tool.name, - description: tool.description, - input_schema: { - type: "object" as const, - properties: tool.parameters.properties || {}, - required: tool.parameters.required || [], - }, - })); + return tools.map((tool) => { + const jsonSchema = zodToJsonSchema(tool.parameters, { $refStrategy: "none" }) as any; + + return { + name: tool.name, + description: tool.description, + input_schema: { + type: "object" as const, + properties: jsonSchema.properties || {}, + required: jsonSchema.required || [], + }, + }; + }); } function mapStopReason(reason: Anthropic.Messages.StopReason): StopReason { diff --git a/packages/ai/src/providers/google.ts b/packages/ai/src/providers/google.ts index 8a90f9eb..aa184c3d 100644 --- a/packages/ai/src/providers/google.ts +++ b/packages/ai/src/providers/google.ts @@ -7,24 +7,26 @@ import { GoogleGenAI, type Part, } from "@google/genai"; +import { zodToJsonSchema } from "zod-to-json-schema"; import { AssistantMessageEventStream } from "../event-stream.js"; import { calculateCost } from "../models.js"; import type { Api, AssistantMessage, Context, - GenerateFunction, - GenerateOptions, Model, StopReason, + StreamFunction, + StreamOptions, TextContent, ThinkingContent, Tool, ToolCall, } from "../types.js"; +import { validateToolArguments } from "../validation.js"; import { transformMessages } from "./transorm-messages.js"; -export interface GoogleOptions extends GenerateOptions { +export interface GoogleOptions extends StreamOptions { toolChoice?: "auto" | "none" | "any"; thinking?: { enabled: boolean; @@ -35,7 +37,7 @@ export interface GoogleOptions extends GenerateOptions { // Counter for generating unique tool call IDs let toolCallCounter = 0; -export const streamGoogle: GenerateFunction<"google-generative-ai"> = ( +export const streamGoogle: StreamFunction<"google-generative-ai"> = ( 
model: Model<"google-generative-ai">, context: Context, options?: GoogleOptions, @@ -159,6 +161,15 @@ export const streamGoogle: GenerateFunction<"google-generative-ai"> = ( name: part.functionCall.name || "", arguments: part.functionCall.args as Record, }; + + // Validate tool arguments if tool definition is available + if (context.tools) { + const tool = context.tools.find((t) => t.name === toolCall.name); + if (tool) { + toolCall.arguments = validateToolArguments(tool, toolCall); + } + } + output.content.push(toolCall); stream.push({ type: "toolcall_start", contentIndex: blockIndex(), partial: output }); stream.push({ @@ -380,7 +391,7 @@ function convertTools(tools: Tool[]): any[] { functionDeclarations: tools.map((tool) => ({ name: tool.name, description: tool.description, - parameters: tool.parameters, + parameters: zodToJsonSchema(tool.parameters, { $refStrategy: "none" }), })), }, ]; diff --git a/packages/ai/src/providers/openai-completions.ts b/packages/ai/src/providers/openai-completions.ts index dcd8eb01..d21c63a3 100644 --- a/packages/ai/src/providers/openai-completions.ts +++ b/packages/ai/src/providers/openai-completions.ts @@ -7,28 +7,30 @@ import type { ChatCompletionContentPartText, ChatCompletionMessageParam, } from "openai/resources/chat/completions.js"; +import { zodToJsonSchema } from "zod-to-json-schema"; import { AssistantMessageEventStream } from "../event-stream.js"; import { calculateCost } from "../models.js"; import type { AssistantMessage, Context, - GenerateFunction, - GenerateOptions, Model, StopReason, + StreamFunction, + StreamOptions, TextContent, ThinkingContent, Tool, ToolCall, } from "../types.js"; +import { validateToolArguments } from "../validation.js"; import { transformMessages } from "./transorm-messages.js"; -export interface OpenAICompletionsOptions extends GenerateOptions { +export interface OpenAICompletionsOptions extends StreamOptions { toolChoice?: "auto" | "none" | "required" | { type: "function"; function: { name: 
string } }; reasoningEffort?: "minimal" | "low" | "medium" | "high"; } -export const streamOpenAICompletions: GenerateFunction<"openai-completions"> = ( +export const streamOpenAICompletions: StreamFunction<"openai-completions"> = ( model: Model<"openai-completions">, context: Context, options?: OpenAICompletionsOptions, @@ -79,6 +81,15 @@ export const streamOpenAICompletions: GenerateFunction<"openai-completions"> = ( }); } else if (block.type === "toolCall") { block.arguments = JSON.parse(block.partialArgs || "{}"); + + // Validate tool arguments if tool definition is available + if (context.tools) { + const tool = context.tools.find((t) => t.name === block.name); + if (tool) { + block.arguments = validateToolArguments(tool, block); + } + } + delete block.partialArgs; stream.push({ type: "toolcall_end", @@ -381,7 +392,7 @@ function convertTools(tools: Tool[]): OpenAI.Chat.Completions.ChatCompletionTool function: { name: tool.name, description: tool.description, - parameters: tool.parameters, + parameters: zodToJsonSchema(tool.parameters, { $refStrategy: "none" }), }, })); } diff --git a/packages/ai/src/providers/openai-responses.ts b/packages/ai/src/providers/openai-responses.ts index 484caa2f..7e15b93e 100644 --- a/packages/ai/src/providers/openai-responses.ts +++ b/packages/ai/src/providers/openai-responses.ts @@ -10,25 +10,27 @@ import type { ResponseOutputMessage, ResponseReasoningItem, } from "openai/resources/responses/responses.js"; +import { zodToJsonSchema } from "zod-to-json-schema"; import { AssistantMessageEventStream } from "../event-stream.js"; import { calculateCost } from "../models.js"; import type { Api, AssistantMessage, Context, - GenerateFunction, - GenerateOptions, Model, StopReason, + StreamFunction, + StreamOptions, TextContent, ThinkingContent, Tool, ToolCall, } from "../types.js"; +import { validateToolArguments } from "../validation.js"; import { transformMessages } from "./transorm-messages.js"; // OpenAI Responses-specific options 
-export interface OpenAIResponsesOptions extends GenerateOptions { +export interface OpenAIResponsesOptions extends StreamOptions { reasoningEffort?: "minimal" | "low" | "medium" | "high"; reasoningSummary?: "auto" | "detailed" | "concise" | null; } @@ -36,7 +38,7 @@ export interface OpenAIResponsesOptions extends GenerateOptions { /** * Generate function for OpenAI Responses API */ -export const streamOpenAIResponses: GenerateFunction<"openai-responses"> = ( +export const streamOpenAIResponses: StreamFunction<"openai-responses"> = ( model: Model<"openai-responses">, context: Context, options?: OpenAIResponsesOptions, @@ -238,6 +240,15 @@ export const streamOpenAIResponses: GenerateFunction<"openai-responses"> = ( name: item.name, arguments: JSON.parse(item.arguments), }; + + // Validate tool arguments if tool definition is available + if (context.tools) { + const tool = context.tools.find((t) => t.name === toolCall.name); + if (tool) { + toolCall.arguments = validateToolArguments(tool, toolCall); + } + } + stream.push({ type: "toolcall_end", contentIndex: blockIndex(), toolCall, partial: output }); } } @@ -451,7 +462,7 @@ function convertTools(tools: Tool[]): OpenAITool[] { type: "function", name: tool.name, description: tool.description, - parameters: tool.parameters, + parameters: zodToJsonSchema(tool.parameters, { $refStrategy: "none" }), strict: null, })); } diff --git a/packages/ai/src/generate.ts b/packages/ai/src/stream.ts similarity index 97% rename from packages/ai/src/generate.ts rename to packages/ai/src/stream.ts index 54eeed0f..c46e7da5 100644 --- a/packages/ai/src/generate.ts +++ b/packages/ai/src/stream.ts @@ -11,7 +11,7 @@ import type { Model, OptionsForApi, ReasoningEffort, - SimpleGenerateOptions, + SimpleStreamOptions, } from "./types.js"; const apiKeys: Map = new Map(); @@ -90,7 +90,7 @@ export async function complete( export function streamSimple( model: Model, context: Context, - options?: SimpleGenerateOptions, + options?: 
SimpleStreamOptions, ): AssistantMessageEventStream { const apiKey = options?.apiKey || getApiKey(model.provider); if (!apiKey) { @@ -104,7 +104,7 @@ export function streamSimple( export async function completeSimple( model: Model, context: Context, - options?: SimpleGenerateOptions, + options?: SimpleStreamOptions, ): Promise { const s = streamSimple(model, context, options); return s.result(); @@ -112,7 +112,7 @@ export async function completeSimple( function mapOptionsForApi( model: Model, - options?: SimpleGenerateOptions, + options?: SimpleStreamOptions, apiKey?: string, ): OptionsForApi { const base = { diff --git a/packages/ai/src/types.ts b/packages/ai/src/types.ts index 6af1945e..3608f307 100644 --- a/packages/ai/src/types.ts +++ b/packages/ai/src/types.ts @@ -16,11 +16,11 @@ export interface ApiOptionsMap { } // Compile-time exhaustiveness check - this will fail if ApiOptionsMap doesn't have all KnownApi keys -type _CheckExhaustive = ApiOptionsMap extends Record - ? Record extends ApiOptionsMap +type _CheckExhaustive = ApiOptionsMap extends Record + ? Record extends ApiOptionsMap ? 
true : ["ApiOptionsMap is missing some KnownApi values", Exclude] - : ["ApiOptionsMap doesn't extend Record"]; + : ["ApiOptionsMap doesn't extend Record"]; const _exhaustive: _CheckExhaustive = true; // Helper type to get options for a specific API @@ -32,20 +32,20 @@ export type Provider = KnownProvider | string; export type ReasoningEffort = "minimal" | "low" | "medium" | "high"; // Base options all providers share -export interface GenerateOptions { +export interface StreamOptions { temperature?: number; maxTokens?: number; signal?: AbortSignal; apiKey?: string; } -// Unified options with reasoning (what public generate() accepts) -export interface SimpleGenerateOptions extends GenerateOptions { +// Unified options with reasoning passed to streamSimple() and completeSimple() +export interface SimpleStreamOptions extends StreamOptions { reasoning?: ReasoningEffort; } -// Generic GenerateFunction with typed options -export type GenerateFunction = ( +// Generic StreamFunction with typed options +export type StreamFunction = ( model: Model, context: Context, options: OptionsForApi, @@ -119,10 +119,12 @@ export interface ToolResultMessage { export type Message = UserMessage | AssistantMessage | ToolResultMessage; -export interface Tool { +import type { ZodSchema } from "zod"; + +export interface Tool { name: string; description: string; - parameters: Record; // JSON Schema + parameters: TParameters; } export interface Context { diff --git a/packages/ai/src/validation.ts b/packages/ai/src/validation.ts new file mode 100644 index 00000000..a3afd92a --- /dev/null +++ b/packages/ai/src/validation.ts @@ -0,0 +1,32 @@ +import { z } from "zod"; +import type { Tool, ToolCall } from "./types.js"; + +/** + * Validates tool call arguments against the tool's Zod schema + * @param tool The tool definition with Zod schema + * @param toolCall The tool call from the LLM + * @returns The validated arguments + * @throws ZodError with formatted message if validation fails + */ +export 
function validateToolArguments(tool: Tool, toolCall: ToolCall): any { + try { + // Validate arguments with Zod schema + return tool.parameters.parse(toolCall.arguments); + } catch (e) { + if (e instanceof z.ZodError) { + // Format validation errors nicely + const errors = e.issues + .map((err) => { + const path = err.path.length > 0 ? err.path.join(".") : "root"; + return ` - ${path}: ${err.message}`; + }) + .join("\n"); + + const errorMessage = `Validation failed for tool "${toolCall.name}":\n${errors}\n\nReceived arguments:\n${JSON.stringify(toolCall.arguments, null, 2)}`; + + // Throw a new error with the formatted message + throw new Error(errorMessage); + } + throw e; + } +} diff --git a/packages/ai/test/abort.test.ts b/packages/ai/test/abort.test.ts index 7de2892d..ecee1186 100644 --- a/packages/ai/test/abort.test.ts +++ b/packages/ai/test/abort.test.ts @@ -1,6 +1,6 @@ import { describe, expect, it } from "vitest"; -import { complete, stream } from "../src/generate.js"; import { getModel } from "../src/models.js"; +import { complete, stream } from "../src/stream.js"; import type { Api, Context, Model, OptionsForApi } from "../src/types.js"; async function testAbortSignal(llm: Model, options: OptionsForApi = {}) { diff --git a/packages/ai/test/agent.test.ts b/packages/ai/test/agent.test.ts new file mode 100644 index 00000000..fcaa3096 --- /dev/null +++ b/packages/ai/test/agent.test.ts @@ -0,0 +1,347 @@ +import { describe, expect, it } from "vitest"; +import { prompt } from "../src/agent/agent.js"; +import { calculateTool } from "../src/agent/tools/calculate.js"; +import type { AgentContext, AgentEvent, PromptConfig } from "../src/agent/types.js"; +import { getModel } from "../src/models.js"; +import type { Api, Message, Model, OptionsForApi, UserMessage } from "../src/types.js"; + +async function calculateTest(model: Model, options: OptionsForApi = {}) { + // Create the agent context with the calculator tool + const context: AgentContext = { + systemPrompt: + 
"You are a helpful assistant that performs mathematical calculations. When asked to calculate multiple expressions, you can use parallel tool calls if the model supports it. In your final answer, output ONLY the final sum as a single integer number, nothing else.", + messages: [], + tools: [calculateTool], + }; + + // Create the prompt config + const config: PromptConfig = { + model, + ...options, + }; + + // Create the user prompt asking for multiple calculations + const userPrompt: UserMessage = { + role: "user", + content: `Use the calculator tool to complete the following mulit-step task. +1. Calculate 3485 * 4234 and 88823 * 3482 in parallel +2. Calculate the sum of the two results using the calculator tool +3. Output ONLY the final sum as a single integer number, nothing else.`, + }; + + // Calculate expected results (using integers) + const expectedFirst = 3485 * 4234; // = 14755490 + const expectedSecond = 88823 * 3482; // = 309281786 + const expectedSum = expectedFirst + expectedSecond; // = 324037276 + + // Track events for verification + const events: AgentEvent[] = []; + let turns = 0; + let toolCallCount = 0; + const toolResults: number[] = []; + let finalAnswer: number | undefined; + + // Execute the prompt + const stream = prompt(userPrompt, context, config); + + for await (const event of stream) { + events.push(event); + + switch (event.type) { + case "turn_start": + turns++; + console.log(`\n=== Turn ${turns} started ===`); + break; + + case "turn_end": + console.log(`=== Turn ${turns} ended with ${event.toolResults.length} tool results ===`); + console.log(event.assistantMessage); + break; + + case "tool_execution_end": + if (!event.isError && typeof event.result === "object" && event.result.output) { + toolCallCount++; + // Extract number from output like "expression = result" + const match = event.result.output.match(/=\s*([\d.]+)/); + if (match) { + const value = parseFloat(match[1]); + toolResults.push(value); + console.log(`Tool 
${toolCallCount}: ${event.result.output}`); + } + } + break; + + case "message_end": + // Just track the message end event, don't extract answer here + break; + } + } + + // Get the final messages + const finalMessages = await stream.result(); + + // Verify the results + expect(finalMessages).toBeDefined(); + expect(finalMessages.length).toBeGreaterThan(0); + + const finalMessage = finalMessages[finalMessages.length - 1]; + expect(finalMessage).toBeDefined(); + expect(finalMessage.role).toBe("assistant"); + if (finalMessage.role !== "assistant") throw new Error("Final message is not from assistant"); + + // Extract the final answer from the last assistant message + const content = finalMessage.content + .filter((c) => c.type === "text") + .map((c) => (c.type === "text" ? c.text : "")) + .join(" "); + + // Look for integers in the response that might be the final answer + const numbers = content.match(/\b\d+\b/g); + if (numbers) { + // Check if any of the numbers matches our expected sum + for (const num of numbers) { + const value = parseInt(num, 10); + if (Math.abs(value - expectedSum) < 10) { + finalAnswer = value; + break; + } + } + // If no exact match, take the last large number as likely the answer + if (finalAnswer === undefined) { + const largeNumbers = numbers.map((n) => parseInt(n, 10)).filter((n) => n > 1000000); + if (largeNumbers.length > 0) { + finalAnswer = largeNumbers[largeNumbers.length - 1]; + } + } + } + + // Should have executed at least 3 tool calls: 2 for the initial calculations, 1 for the sum + // (or possibly 2 if the model calculates the sum itself without a tool) + expect(toolCallCount).toBeGreaterThanOrEqual(2); + + // Must be at least 3 turns: first to calculate the expressions, then to sum them, then give the answer + // Could be 3 turns if model does parallel calls, or 4 turns if sequential calculation of expressions + expect(turns).toBeGreaterThanOrEqual(3); + expect(turns).toBeLessThanOrEqual(4); + + // Verify the individual 
calculations are in the results + const hasFirstCalc = toolResults.some((r) => r === expectedFirst); + const hasSecondCalc = toolResults.some((r) => r === expectedSecond); + expect(hasFirstCalc).toBe(true); + expect(hasSecondCalc).toBe(true); + + // Verify the final sum + if (finalAnswer !== undefined) { + expect(finalAnswer).toBe(expectedSum); + console.log(`Final answer: ${finalAnswer} (expected: ${expectedSum})`); + } else { + // If we couldn't extract the final answer from text, check if it's in the tool results + const hasSum = toolResults.some((r) => r === expectedSum); + expect(hasSum).toBe(true); + } + + // Log summary + console.log(`\nTest completed with ${turns} turns and ${toolCallCount} tool calls`); + if (turns === 3) { + console.log("Model used parallel tool calls for initial calculations"); + } else { + console.log("Model used sequential tool calls"); + } + + return { + turns, + toolCallCount, + toolResults, + finalAnswer, + events, + }; +} + +async function abortTest(model: Model, options: OptionsForApi = {}) { + // Create the agent context with the calculator tool + const context: AgentContext = { + systemPrompt: + "You are a helpful assistant that performs mathematical calculations. 
Always use the calculator tool for each calculation.", + messages: [], + tools: [calculateTool], + }; + + // Create the prompt config + const config: PromptConfig = { + model, + ...options, + }; + + // Create a prompt that will require multiple calculations + const userPrompt: UserMessage = { + role: "user", + content: "Calculate 100 * 200, then 300 * 400, then 500 * 600, then sum all three results.", + }; + + // Create abort controller + const abortController = new AbortController(); + + // Track events for verification + const events: AgentEvent[] = []; + let toolCallCount = 0; + const errorReceived = false; + let finalMessages: Message[] | undefined; + + // Execute the prompt + const stream = prompt(userPrompt, context, config, abortController.signal); + + // Abort after first tool execution + const abortPromise = (async () => { + for await (const event of stream) { + events.push(event); + + if (event.type === "tool_execution_end" && !event.isError) { + toolCallCount++; + // Abort after first successful tool execution + if (toolCallCount === 1) { + console.log("Aborting after first tool execution"); + abortController.abort(); + } + } + + if (event.type === "agent_end") { + finalMessages = event.messages; + } + } + })(); + + finalMessages = await stream.result(); + + // Verify abort behavior + console.log(`\nAbort test completed with ${toolCallCount} tool calls`); + const assistantMessage = finalMessages[finalMessages.length - 1]; + if (!assistantMessage) throw new Error("No final message received"); + expect(assistantMessage).toBeDefined(); + expect(assistantMessage.role).toBe("assistant"); + if (assistantMessage.role !== "assistant") throw new Error("Final message is not from assistant"); + + // Should have executed 1 tool call before abort + expect(toolCallCount).toBeGreaterThanOrEqual(1); + expect(assistantMessage.stopReason).toBe("error"); + + return { + toolCallCount, + events, + errorReceived, + finalMessages, + }; +} + +describe("Agent Calculator Tests", 
() => { + describe.skipIf(!process.env.GEMINI_API_KEY)("Google Provider Agent", () => { + const model = getModel("google", "gemini-2.5-flash"); + + it("should calculate multiple expressions and sum the results", async () => { + const result = await calculateTest(model); + expect(result.toolCallCount).toBeGreaterThanOrEqual(2); + }, 30000); + + it("should handle abort during tool execution", async () => { + const result = await abortTest(model); + expect(result.toolCallCount).toBeGreaterThanOrEqual(1); + }, 30000); + }); + + describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Completions Provider Agent", () => { + const model = getModel("openai", "gpt-4o-mini"); + + it("should calculate multiple expressions and sum the results", async () => { + const result = await calculateTest(model); + expect(result.toolCallCount).toBeGreaterThanOrEqual(2); + }, 30000); + + it("should handle abort during tool execution", async () => { + const result = await abortTest(model); + expect(result.toolCallCount).toBeGreaterThanOrEqual(1); + }, 30000); + }); + + describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Responses Provider Agent", () => { + const model = getModel("openai", "gpt-5-mini"); + + it("should calculate multiple expressions and sum the results", async () => { + const result = await calculateTest(model); + expect(result.toolCallCount).toBeGreaterThanOrEqual(2); + }, 30000); + + it("should handle abort during tool execution", async () => { + const result = await abortTest(model); + expect(result.toolCallCount).toBeGreaterThanOrEqual(1); + }, 30000); + }); + + describe.skipIf(!process.env.ANTHROPIC_API_KEY)("Anthropic Provider Agent", () => { + const model = getModel("anthropic", "claude-3-5-haiku-20241022"); + + it("should calculate multiple expressions and sum the results", async () => { + const result = await calculateTest(model); + expect(result.toolCallCount).toBeGreaterThanOrEqual(2); + }, 30000); + + it("should handle abort during tool execution", async () => { + 
const result = await abortTest(model); + expect(result.toolCallCount).toBeGreaterThanOrEqual(1); + }, 30000); + }); + + describe.skipIf(!process.env.XAI_API_KEY)("xAI Provider Agent", () => { + const model = getModel("xai", "grok-3"); + + it("should calculate multiple expressions and sum the results", async () => { + const result = await calculateTest(model); + expect(result.toolCallCount).toBeGreaterThanOrEqual(2); + }, 30000); + + it("should handle abort during tool execution", async () => { + const result = await abortTest(model); + expect(result.toolCallCount).toBeGreaterThanOrEqual(1); + }, 30000); + }); + + describe.skipIf(!process.env.GROQ_API_KEY)("Groq Provider Agent", () => { + const model = getModel("groq", "openai/gpt-oss-20b"); + + it("should calculate multiple expressions and sum the results", async () => { + const result = await calculateTest(model); + expect(result.toolCallCount).toBeGreaterThanOrEqual(2); + }, 30000); + + it("should handle abort during tool execution", async () => { + const result = await abortTest(model); + expect(result.toolCallCount).toBeGreaterThanOrEqual(1); + }, 30000); + }); + + describe.skipIf(!process.env.CEREBRAS_API_KEY)("Cerebras Provider Agent", () => { + const model = getModel("cerebras", "gpt-oss-120b"); + + it("should calculate multiple expressions and sum the results", async () => { + const result = await calculateTest(model); + expect(result.toolCallCount).toBeGreaterThanOrEqual(2); + }, 30000); + + it("should handle abort during tool execution", async () => { + const result = await abortTest(model); + expect(result.toolCallCount).toBeGreaterThanOrEqual(1); + }, 30000); + }); + + describe.skipIf(!process.env.ZAI_API_KEY)("zAI Provider Agent", () => { + const model = getModel("zai", "glm-4.5-air"); + + it("should calculate multiple expressions and sum the results", async () => { + const result = await calculateTest(model); + expect(result.toolCallCount).toBeGreaterThanOrEqual(2); + }, 30000); + + it("should handle 
abort during tool execution", async () => { + const result = await abortTest(model); + expect(result.toolCallCount).toBeGreaterThanOrEqual(1); + }, 30000); + }); +}); diff --git a/packages/ai/test/empty.test.ts b/packages/ai/test/empty.test.ts index 8549fed3..30568cde 100644 --- a/packages/ai/test/empty.test.ts +++ b/packages/ai/test/empty.test.ts @@ -1,6 +1,6 @@ import { describe, expect, it } from "vitest"; -import { complete } from "../src/generate.js"; import { getModel } from "../src/models.js"; +import { complete } from "../src/stream.js"; import type { Api, AssistantMessage, Context, Model, OptionsForApi, UserMessage } from "../src/types.js"; async function testEmptyMessage(llm: Model, options: OptionsForApi = {}) { diff --git a/packages/ai/test/generate.test.ts b/packages/ai/test/generate.test.ts index e70ecef2..53a434af 100644 --- a/packages/ai/test/generate.test.ts +++ b/packages/ai/test/generate.test.ts @@ -3,8 +3,9 @@ import { readFileSync } from "fs"; import { dirname, join } from "path"; import { fileURLToPath } from "url"; import { afterAll, beforeAll, describe, expect, it } from "vitest"; -import { complete, stream } from "../src/generate.js"; +import { z } from "zod"; import { getModel } from "../src/models.js"; +import { complete, stream } from "../src/stream.js"; import type { Api, Context, ImageContent, Model, OptionsForApi, Tool, ToolResultMessage } from "../src/types.js"; const __filename = fileURLToPath(import.meta.url); @@ -14,19 +15,13 @@ const __dirname = dirname(__filename); const calculatorTool: Tool = { name: "calculator", description: "Perform basic arithmetic operations", - parameters: { - type: "object", - properties: { - a: { type: "number", description: "First number" }, - b: { type: "number", description: "Second number" }, - operation: { - type: "string", - enum: ["add", "subtract", "multiply", "divide"], - description: "The operation to perform. 
One of 'add', 'subtract', 'multiply', 'divide'.", - }, - }, - required: ["a", "b", "operation"], - }, + parameters: z.object({ + a: z.number().describe("First number"), + b: z.number().describe("Second number"), + operation: z + .enum(["add", "subtract", "multiply", "divide"]) + .describe("The operation to perform. One of 'add', 'subtract', 'multiply', 'divide'."), + }), }; async function basicTextGeneration(model: Model, options?: OptionsForApi) { diff --git a/packages/ai/test/handoff.test.ts b/packages/ai/test/handoff.test.ts index cced9703..7eabbde0 100644 --- a/packages/ai/test/handoff.test.ts +++ b/packages/ai/test/handoff.test.ts @@ -1,19 +1,16 @@ import { describe, expect, it } from "vitest"; -import { complete } from "../src/generate.js"; +import { z } from "zod"; import { getModel } from "../src/models.js"; +import { complete } from "../src/stream.js"; import type { Api, AssistantMessage, Context, Message, Model, Tool, ToolResultMessage } from "../src/types.js"; // Tool for testing const weatherTool: Tool = { name: "get_weather", description: "Get the weather for a location", - parameters: { - type: "object", - properties: { - location: { type: "string", description: "City name" }, - }, - required: ["location"], - }, + parameters: z.object({ + location: z.string().describe("City name"), + }), }; // Pre-built contexts representing typical outputs from each provider diff --git a/packages/ai/test/tool-validation.test.ts b/packages/ai/test/tool-validation.test.ts new file mode 100644 index 00000000..09827b27 --- /dev/null +++ b/packages/ai/test/tool-validation.test.ts @@ -0,0 +1,112 @@ +import { describe, expect, it } from "vitest"; +import { z } from "zod"; +import type { AgentTool } from "../src/agent/types.js"; + +describe("Tool Validation with Zod", () => { + // Define a test tool with Zod schema + const testSchema = z.object({ + name: z.string().min(1, "Name is required"), + age: z.number().int().min(0).max(150), + email: z.string().email("Invalid email 
format"), + tags: z.array(z.string()).optional(), + }); + + const testTool: AgentTool = { + label: "Test Tool", + name: "test_tool", + description: "A test tool for validation", + parameters: testSchema, + execute: async (_toolCallId, args) => { + return { + output: `Processed: ${args.name}, ${args.age}, ${args.email}`, + details: undefined, + }; + }, + }; + + it("should validate correct input", () => { + const validInput = { + name: "John Doe", + age: 30, + email: "john@example.com", + tags: ["developer", "typescript"], + }; + + // This should not throw + const result = testTool.parameters.parse(validInput); + expect(result).toEqual(validInput); + }); + + it("should reject invalid email", () => { + const invalidInput = { + name: "John Doe", + age: 30, + email: "not-an-email", + }; + + expect(() => testTool.parameters.parse(invalidInput)).toThrowError(z.ZodError); + }); + + it("should reject missing required fields", () => { + const invalidInput = { + age: 30, + email: "john@example.com", + }; + + expect(() => testTool.parameters.parse(invalidInput)).toThrowError(z.ZodError); + }); + + it("should reject invalid age", () => { + const invalidInput = { + name: "John Doe", + age: -5, + email: "john@example.com", + }; + + expect(() => testTool.parameters.parse(invalidInput)).toThrowError(z.ZodError); + }); + + it("should format validation errors nicely", () => { + const invalidInput = { + name: "", + age: 200, + email: "invalid", + }; + + try { + testTool.parameters.parse(invalidInput); + // Should not reach here + expect(true).toBe(false); + } catch (e) { + if (e instanceof z.ZodError) { + const errors = e.issues + .map((err) => { + const path = err.path.length > 0 ? 
err.path.join(".") : "root"; + return ` - ${path}: ${err.message}`; + }) + .join("\n"); + + expect(errors).toContain("name: Name is required"); + expect(errors).toContain("age: Number must be less than or equal to 150"); + expect(errors).toContain("email: Invalid email format"); + } else { + throw e; + } + } + }); + + it("should have type-safe execute function", async () => { + const validInput = { + name: "John Doe", + age: 30, + email: "john@example.com", + }; + + // Validate and execute + const validated = testTool.parameters.parse(validInput); + const result = await testTool.execute("test-id", validated); + + expect(result.output).toBe("Processed: John Doe, 30, john@example.com"); + expect(result.details).toBeUndefined(); + }); +});