mirror of
https://github.com/getcompanion-ai/co-mono.git
synced 2026-04-17 08:00:59 +00:00
feat(ai): Implement Zod-based tool validation and improve Agent API
- Replace JSON Schema with Zod schemas for tool parameter definitions - Add runtime validation for all tool calls at provider level - Create shared validation module with detailed error formatting - Update Agent API with comprehensive event system - Add agent tests with calculator tool for multi-turn execution - Add abort test to verify proper handling of aborted requests - Update documentation with detailed event flow examples - Rename generate.ts to stream.ts for clarity
This commit is contained in:
parent
594b0dac6c
commit
35fe8f21e9
24 changed files with 1069 additions and 221 deletions
|
|
@ -1,6 +1,6 @@
|
|||
import { describe, expect, it } from "vitest";
|
||||
import { complete, stream } from "../src/generate.js";
|
||||
import { getModel } from "../src/models.js";
|
||||
import { complete, stream } from "../src/stream.js";
|
||||
import type { Api, Context, Model, OptionsForApi } from "../src/types.js";
|
||||
|
||||
async function testAbortSignal<TApi extends Api>(llm: Model<TApi>, options: OptionsForApi<TApi> = {}) {
|
||||
|
|
|
|||
347
packages/ai/test/agent.test.ts
Normal file
347
packages/ai/test/agent.test.ts
Normal file
|
|
@ -0,0 +1,347 @@
|
|||
import { describe, expect, it } from "vitest";
|
||||
import { prompt } from "../src/agent/agent.js";
|
||||
import { calculateTool } from "../src/agent/tools/calculate.js";
|
||||
import type { AgentContext, AgentEvent, PromptConfig } from "../src/agent/types.js";
|
||||
import { getModel } from "../src/models.js";
|
||||
import type { Api, Message, Model, OptionsForApi, UserMessage } from "../src/types.js";
|
||||
|
||||
async function calculateTest<TApi extends Api>(model: Model<TApi>, options: OptionsForApi<TApi> = {}) {
|
||||
// Create the agent context with the calculator tool
|
||||
const context: AgentContext = {
|
||||
systemPrompt:
|
||||
"You are a helpful assistant that performs mathematical calculations. When asked to calculate multiple expressions, you can use parallel tool calls if the model supports it. In your final answer, output ONLY the final sum as a single integer number, nothing else.",
|
||||
messages: [],
|
||||
tools: [calculateTool],
|
||||
};
|
||||
|
||||
// Create the prompt config
|
||||
const config: PromptConfig = {
|
||||
model,
|
||||
...options,
|
||||
};
|
||||
|
||||
// Create the user prompt asking for multiple calculations
|
||||
const userPrompt: UserMessage = {
|
||||
role: "user",
|
||||
content: `Use the calculator tool to complete the following mulit-step task.
|
||||
1. Calculate 3485 * 4234 and 88823 * 3482 in parallel
|
||||
2. Calculate the sum of the two results using the calculator tool
|
||||
3. Output ONLY the final sum as a single integer number, nothing else.`,
|
||||
};
|
||||
|
||||
// Calculate expected results (using integers)
|
||||
const expectedFirst = 3485 * 4234; // = 14755490
|
||||
const expectedSecond = 88823 * 3482; // = 309281786
|
||||
const expectedSum = expectedFirst + expectedSecond; // = 324037276
|
||||
|
||||
// Track events for verification
|
||||
const events: AgentEvent[] = [];
|
||||
let turns = 0;
|
||||
let toolCallCount = 0;
|
||||
const toolResults: number[] = [];
|
||||
let finalAnswer: number | undefined;
|
||||
|
||||
// Execute the prompt
|
||||
const stream = prompt(userPrompt, context, config);
|
||||
|
||||
for await (const event of stream) {
|
||||
events.push(event);
|
||||
|
||||
switch (event.type) {
|
||||
case "turn_start":
|
||||
turns++;
|
||||
console.log(`\n=== Turn ${turns} started ===`);
|
||||
break;
|
||||
|
||||
case "turn_end":
|
||||
console.log(`=== Turn ${turns} ended with ${event.toolResults.length} tool results ===`);
|
||||
console.log(event.assistantMessage);
|
||||
break;
|
||||
|
||||
case "tool_execution_end":
|
||||
if (!event.isError && typeof event.result === "object" && event.result.output) {
|
||||
toolCallCount++;
|
||||
// Extract number from output like "expression = result"
|
||||
const match = event.result.output.match(/=\s*([\d.]+)/);
|
||||
if (match) {
|
||||
const value = parseFloat(match[1]);
|
||||
toolResults.push(value);
|
||||
console.log(`Tool ${toolCallCount}: ${event.result.output}`);
|
||||
}
|
||||
}
|
||||
break;
|
||||
|
||||
case "message_end":
|
||||
// Just track the message end event, don't extract answer here
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Get the final messages
|
||||
const finalMessages = await stream.result();
|
||||
|
||||
// Verify the results
|
||||
expect(finalMessages).toBeDefined();
|
||||
expect(finalMessages.length).toBeGreaterThan(0);
|
||||
|
||||
const finalMessage = finalMessages[finalMessages.length - 1];
|
||||
expect(finalMessage).toBeDefined();
|
||||
expect(finalMessage.role).toBe("assistant");
|
||||
if (finalMessage.role !== "assistant") throw new Error("Final message is not from assistant");
|
||||
|
||||
// Extract the final answer from the last assistant message
|
||||
const content = finalMessage.content
|
||||
.filter((c) => c.type === "text")
|
||||
.map((c) => (c.type === "text" ? c.text : ""))
|
||||
.join(" ");
|
||||
|
||||
// Look for integers in the response that might be the final answer
|
||||
const numbers = content.match(/\b\d+\b/g);
|
||||
if (numbers) {
|
||||
// Check if any of the numbers matches our expected sum
|
||||
for (const num of numbers) {
|
||||
const value = parseInt(num, 10);
|
||||
if (Math.abs(value - expectedSum) < 10) {
|
||||
finalAnswer = value;
|
||||
break;
|
||||
}
|
||||
}
|
||||
// If no exact match, take the last large number as likely the answer
|
||||
if (finalAnswer === undefined) {
|
||||
const largeNumbers = numbers.map((n) => parseInt(n, 10)).filter((n) => n > 1000000);
|
||||
if (largeNumbers.length > 0) {
|
||||
finalAnswer = largeNumbers[largeNumbers.length - 1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Should have executed at least 3 tool calls: 2 for the initial calculations, 1 for the sum
|
||||
// (or possibly 2 if the model calculates the sum itself without a tool)
|
||||
expect(toolCallCount).toBeGreaterThanOrEqual(2);
|
||||
|
||||
// Must be at least 3 turns: first to calculate the expressions, then to sum them, then give the answer
|
||||
// Could be 3 turns if model does parallel calls, or 4 turns if sequential calculation of expressions
|
||||
expect(turns).toBeGreaterThanOrEqual(3);
|
||||
expect(turns).toBeLessThanOrEqual(4);
|
||||
|
||||
// Verify the individual calculations are in the results
|
||||
const hasFirstCalc = toolResults.some((r) => r === expectedFirst);
|
||||
const hasSecondCalc = toolResults.some((r) => r === expectedSecond);
|
||||
expect(hasFirstCalc).toBe(true);
|
||||
expect(hasSecondCalc).toBe(true);
|
||||
|
||||
// Verify the final sum
|
||||
if (finalAnswer !== undefined) {
|
||||
expect(finalAnswer).toBe(expectedSum);
|
||||
console.log(`Final answer: ${finalAnswer} (expected: ${expectedSum})`);
|
||||
} else {
|
||||
// If we couldn't extract the final answer from text, check if it's in the tool results
|
||||
const hasSum = toolResults.some((r) => r === expectedSum);
|
||||
expect(hasSum).toBe(true);
|
||||
}
|
||||
|
||||
// Log summary
|
||||
console.log(`\nTest completed with ${turns} turns and ${toolCallCount} tool calls`);
|
||||
if (turns === 3) {
|
||||
console.log("Model used parallel tool calls for initial calculations");
|
||||
} else {
|
||||
console.log("Model used sequential tool calls");
|
||||
}
|
||||
|
||||
return {
|
||||
turns,
|
||||
toolCallCount,
|
||||
toolResults,
|
||||
finalAnswer,
|
||||
events,
|
||||
};
|
||||
}
|
||||
|
||||
async function abortTest<TApi extends Api>(model: Model<TApi>, options: OptionsForApi<TApi> = {}) {
|
||||
// Create the agent context with the calculator tool
|
||||
const context: AgentContext = {
|
||||
systemPrompt:
|
||||
"You are a helpful assistant that performs mathematical calculations. Always use the calculator tool for each calculation.",
|
||||
messages: [],
|
||||
tools: [calculateTool],
|
||||
};
|
||||
|
||||
// Create the prompt config
|
||||
const config: PromptConfig = {
|
||||
model,
|
||||
...options,
|
||||
};
|
||||
|
||||
// Create a prompt that will require multiple calculations
|
||||
const userPrompt: UserMessage = {
|
||||
role: "user",
|
||||
content: "Calculate 100 * 200, then 300 * 400, then 500 * 600, then sum all three results.",
|
||||
};
|
||||
|
||||
// Create abort controller
|
||||
const abortController = new AbortController();
|
||||
|
||||
// Track events for verification
|
||||
const events: AgentEvent[] = [];
|
||||
let toolCallCount = 0;
|
||||
const errorReceived = false;
|
||||
let finalMessages: Message[] | undefined;
|
||||
|
||||
// Execute the prompt
|
||||
const stream = prompt(userPrompt, context, config, abortController.signal);
|
||||
|
||||
// Abort after first tool execution
|
||||
const abortPromise = (async () => {
|
||||
for await (const event of stream) {
|
||||
events.push(event);
|
||||
|
||||
if (event.type === "tool_execution_end" && !event.isError) {
|
||||
toolCallCount++;
|
||||
// Abort after first successful tool execution
|
||||
if (toolCallCount === 1) {
|
||||
console.log("Aborting after first tool execution");
|
||||
abortController.abort();
|
||||
}
|
||||
}
|
||||
|
||||
if (event.type === "agent_end") {
|
||||
finalMessages = event.messages;
|
||||
}
|
||||
}
|
||||
})();
|
||||
|
||||
finalMessages = await stream.result();
|
||||
|
||||
// Verify abort behavior
|
||||
console.log(`\nAbort test completed with ${toolCallCount} tool calls`);
|
||||
const assistantMessage = finalMessages[finalMessages.length - 1];
|
||||
if (!assistantMessage) throw new Error("No final message received");
|
||||
expect(assistantMessage).toBeDefined();
|
||||
expect(assistantMessage.role).toBe("assistant");
|
||||
if (assistantMessage.role !== "assistant") throw new Error("Final message is not from assistant");
|
||||
|
||||
// Should have executed 1 tool call before abort
|
||||
expect(toolCallCount).toBeGreaterThanOrEqual(1);
|
||||
expect(assistantMessage.stopReason).toBe("error");
|
||||
|
||||
return {
|
||||
toolCallCount,
|
||||
events,
|
||||
errorReceived,
|
||||
finalMessages,
|
||||
};
|
||||
}
|
||||
|
||||
describe("Agent Calculator Tests", () => {
|
||||
describe.skipIf(!process.env.GEMINI_API_KEY)("Google Provider Agent", () => {
|
||||
const model = getModel("google", "gemini-2.5-flash");
|
||||
|
||||
it("should calculate multiple expressions and sum the results", async () => {
|
||||
const result = await calculateTest(model);
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
|
||||
}, 30000);
|
||||
|
||||
it("should handle abort during tool execution", async () => {
|
||||
const result = await abortTest(model);
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
|
||||
}, 30000);
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Completions Provider Agent", () => {
|
||||
const model = getModel("openai", "gpt-4o-mini");
|
||||
|
||||
it("should calculate multiple expressions and sum the results", async () => {
|
||||
const result = await calculateTest(model);
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
|
||||
}, 30000);
|
||||
|
||||
it("should handle abort during tool execution", async () => {
|
||||
const result = await abortTest(model);
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
|
||||
}, 30000);
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Responses Provider Agent", () => {
|
||||
const model = getModel("openai", "gpt-5-mini");
|
||||
|
||||
it("should calculate multiple expressions and sum the results", async () => {
|
||||
const result = await calculateTest(model);
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
|
||||
}, 30000);
|
||||
|
||||
it("should handle abort during tool execution", async () => {
|
||||
const result = await abortTest(model);
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
|
||||
}, 30000);
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.ANTHROPIC_API_KEY)("Anthropic Provider Agent", () => {
|
||||
const model = getModel("anthropic", "claude-3-5-haiku-20241022");
|
||||
|
||||
it("should calculate multiple expressions and sum the results", async () => {
|
||||
const result = await calculateTest(model);
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
|
||||
}, 30000);
|
||||
|
||||
it("should handle abort during tool execution", async () => {
|
||||
const result = await abortTest(model);
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
|
||||
}, 30000);
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.XAI_API_KEY)("xAI Provider Agent", () => {
|
||||
const model = getModel("xai", "grok-3");
|
||||
|
||||
it("should calculate multiple expressions and sum the results", async () => {
|
||||
const result = await calculateTest(model);
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
|
||||
}, 30000);
|
||||
|
||||
it("should handle abort during tool execution", async () => {
|
||||
const result = await abortTest(model);
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
|
||||
}, 30000);
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.GROQ_API_KEY)("Groq Provider Agent", () => {
|
||||
const model = getModel("groq", "openai/gpt-oss-20b");
|
||||
|
||||
it("should calculate multiple expressions and sum the results", async () => {
|
||||
const result = await calculateTest(model);
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
|
||||
}, 30000);
|
||||
|
||||
it("should handle abort during tool execution", async () => {
|
||||
const result = await abortTest(model);
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
|
||||
}, 30000);
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.CEREBRAS_API_KEY)("Cerebras Provider Agent", () => {
|
||||
const model = getModel("cerebras", "gpt-oss-120b");
|
||||
|
||||
it("should calculate multiple expressions and sum the results", async () => {
|
||||
const result = await calculateTest(model);
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
|
||||
}, 30000);
|
||||
|
||||
it("should handle abort during tool execution", async () => {
|
||||
const result = await abortTest(model);
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
|
||||
}, 30000);
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.ZAI_API_KEY)("zAI Provider Agent", () => {
|
||||
const model = getModel("zai", "glm-4.5-air");
|
||||
|
||||
it("should calculate multiple expressions and sum the results", async () => {
|
||||
const result = await calculateTest(model);
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
|
||||
}, 30000);
|
||||
|
||||
it("should handle abort during tool execution", async () => {
|
||||
const result = await abortTest(model);
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
|
||||
}, 30000);
|
||||
});
|
||||
});
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
import { describe, expect, it } from "vitest";
|
||||
import { complete } from "../src/generate.js";
|
||||
import { getModel } from "../src/models.js";
|
||||
import { complete } from "../src/stream.js";
|
||||
import type { Api, AssistantMessage, Context, Model, OptionsForApi, UserMessage } from "../src/types.js";
|
||||
|
||||
async function testEmptyMessage<TApi extends Api>(llm: Model<TApi>, options: OptionsForApi<TApi> = {}) {
|
||||
|
|
|
|||
|
|
@ -3,8 +3,9 @@ import { readFileSync } from "fs";
|
|||
import { dirname, join } from "path";
|
||||
import { fileURLToPath } from "url";
|
||||
import { afterAll, beforeAll, describe, expect, it } from "vitest";
|
||||
import { complete, stream } from "../src/generate.js";
|
||||
import { z } from "zod";
|
||||
import { getModel } from "../src/models.js";
|
||||
import { complete, stream } from "../src/stream.js";
|
||||
import type { Api, Context, ImageContent, Model, OptionsForApi, Tool, ToolResultMessage } from "../src/types.js";
|
||||
|
||||
const __filename = fileURLToPath(import.meta.url);
|
||||
|
|
@ -14,19 +15,13 @@ const __dirname = dirname(__filename);
|
|||
const calculatorTool: Tool = {
|
||||
name: "calculator",
|
||||
description: "Perform basic arithmetic operations",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
a: { type: "number", description: "First number" },
|
||||
b: { type: "number", description: "Second number" },
|
||||
operation: {
|
||||
type: "string",
|
||||
enum: ["add", "subtract", "multiply", "divide"],
|
||||
description: "The operation to perform. One of 'add', 'subtract', 'multiply', 'divide'.",
|
||||
},
|
||||
},
|
||||
required: ["a", "b", "operation"],
|
||||
},
|
||||
parameters: z.object({
|
||||
a: z.number().describe("First number"),
|
||||
b: z.number().describe("Second number"),
|
||||
operation: z
|
||||
.enum(["add", "subtract", "multiply", "divide"])
|
||||
.describe("The operation to perform. One of 'add', 'subtract', 'multiply', 'divide'."),
|
||||
}),
|
||||
};
|
||||
|
||||
async function basicTextGeneration<TApi extends Api>(model: Model<TApi>, options?: OptionsForApi<TApi>) {
|
||||
|
|
|
|||
|
|
@ -1,19 +1,16 @@
|
|||
import { describe, expect, it } from "vitest";
|
||||
import { complete } from "../src/generate.js";
|
||||
import { z } from "zod";
|
||||
import { getModel } from "../src/models.js";
|
||||
import { complete } from "../src/stream.js";
|
||||
import type { Api, AssistantMessage, Context, Message, Model, Tool, ToolResultMessage } from "../src/types.js";
|
||||
|
||||
// Tool for testing
|
||||
const weatherTool: Tool = {
|
||||
name: "get_weather",
|
||||
description: "Get the weather for a location",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
location: { type: "string", description: "City name" },
|
||||
},
|
||||
required: ["location"],
|
||||
},
|
||||
parameters: z.object({
|
||||
location: z.string().describe("City name"),
|
||||
}),
|
||||
};
|
||||
|
||||
// Pre-built contexts representing typical outputs from each provider
|
||||
|
|
|
|||
112
packages/ai/test/tool-validation.test.ts
Normal file
112
packages/ai/test/tool-validation.test.ts
Normal file
|
|
@ -0,0 +1,112 @@
|
|||
import { describe, expect, it } from "vitest";
|
||||
import { z } from "zod";
|
||||
import type { AgentTool } from "../src/agent/types.js";
|
||||
|
||||
describe("Tool Validation with Zod", () => {
|
||||
// Define a test tool with Zod schema
|
||||
const testSchema = z.object({
|
||||
name: z.string().min(1, "Name is required"),
|
||||
age: z.number().int().min(0).max(150),
|
||||
email: z.string().email("Invalid email format"),
|
||||
tags: z.array(z.string()).optional(),
|
||||
});
|
||||
|
||||
const testTool: AgentTool<typeof testSchema, void> = {
|
||||
label: "Test Tool",
|
||||
name: "test_tool",
|
||||
description: "A test tool for validation",
|
||||
parameters: testSchema,
|
||||
execute: async (_toolCallId, args) => {
|
||||
return {
|
||||
output: `Processed: ${args.name}, ${args.age}, ${args.email}`,
|
||||
details: undefined,
|
||||
};
|
||||
},
|
||||
};
|
||||
|
||||
it("should validate correct input", () => {
|
||||
const validInput = {
|
||||
name: "John Doe",
|
||||
age: 30,
|
||||
email: "john@example.com",
|
||||
tags: ["developer", "typescript"],
|
||||
};
|
||||
|
||||
// This should not throw
|
||||
const result = testTool.parameters.parse(validInput);
|
||||
expect(result).toEqual(validInput);
|
||||
});
|
||||
|
||||
it("should reject invalid email", () => {
|
||||
const invalidInput = {
|
||||
name: "John Doe",
|
||||
age: 30,
|
||||
email: "not-an-email",
|
||||
};
|
||||
|
||||
expect(() => testTool.parameters.parse(invalidInput)).toThrowError(z.ZodError);
|
||||
});
|
||||
|
||||
it("should reject missing required fields", () => {
|
||||
const invalidInput = {
|
||||
age: 30,
|
||||
email: "john@example.com",
|
||||
};
|
||||
|
||||
expect(() => testTool.parameters.parse(invalidInput)).toThrowError(z.ZodError);
|
||||
});
|
||||
|
||||
it("should reject invalid age", () => {
|
||||
const invalidInput = {
|
||||
name: "John Doe",
|
||||
age: -5,
|
||||
email: "john@example.com",
|
||||
};
|
||||
|
||||
expect(() => testTool.parameters.parse(invalidInput)).toThrowError(z.ZodError);
|
||||
});
|
||||
|
||||
it("should format validation errors nicely", () => {
|
||||
const invalidInput = {
|
||||
name: "",
|
||||
age: 200,
|
||||
email: "invalid",
|
||||
};
|
||||
|
||||
try {
|
||||
testTool.parameters.parse(invalidInput);
|
||||
// Should not reach here
|
||||
expect(true).toBe(false);
|
||||
} catch (e) {
|
||||
if (e instanceof z.ZodError) {
|
||||
const errors = e.issues
|
||||
.map((err) => {
|
||||
const path = err.path.length > 0 ? err.path.join(".") : "root";
|
||||
return ` - ${path}: ${err.message}`;
|
||||
})
|
||||
.join("\n");
|
||||
|
||||
expect(errors).toContain("name: Name is required");
|
||||
expect(errors).toContain("age: Number must be less than or equal to 150");
|
||||
expect(errors).toContain("email: Invalid email format");
|
||||
} else {
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
it("should have type-safe execute function", async () => {
|
||||
const validInput = {
|
||||
name: "John Doe",
|
||||
age: 30,
|
||||
email: "john@example.com",
|
||||
};
|
||||
|
||||
// Validate and execute
|
||||
const validated = testTool.parameters.parse(validInput);
|
||||
const result = await testTool.execute("test-id", validated);
|
||||
|
||||
expect(result.output).toBe("Processed: John Doe, 30, john@example.com");
|
||||
expect(result.details).toBeUndefined();
|
||||
});
|
||||
});
|
||||
Loading…
Add table
Add a link
Reference in a new issue