mirror of
https://github.com/getcompanion-ai/co-mono.git
synced 2026-04-17 02:04:05 +00:00
Remove provider-level tool validation, add validateToolCall helper
This commit is contained in:
parent
0196308266
commit
8bec289dc6
14 changed files with 59 additions and 68 deletions
|
|
@ -221,7 +221,6 @@ export class Agent {
|
||||||
tools: this._state.tools,
|
tools: this._state.tools,
|
||||||
model,
|
model,
|
||||||
reasoning,
|
reasoning,
|
||||||
validateToolCallsAtProvider: false,
|
|
||||||
getQueuedMessages: async <T>() => {
|
getQueuedMessages: async <T>() => {
|
||||||
// Return queued messages based on queue mode
|
// Return queued messages based on queue mode
|
||||||
if (this.queueMode === "one-at-a-time") {
|
if (this.queueMode === "one-at-a-time") {
|
||||||
|
|
|
||||||
|
|
@ -77,7 +77,6 @@ function streamSimpleProxy(
|
||||||
temperature: options.temperature,
|
temperature: options.temperature,
|
||||||
maxTokens: options.maxTokens,
|
maxTokens: options.maxTokens,
|
||||||
reasoning: options.reasoning,
|
reasoning: options.reasoning,
|
||||||
validateToolCallsAtProvider: options.validateToolCallsAtProvider,
|
|
||||||
// Don't send apiKey or signal - those are added server-side
|
// Don't send apiKey or signal - those are added server-side
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
|
|
@ -366,7 +365,6 @@ export class AppTransport implements AgentTransport {
|
||||||
model: cfg.model,
|
model: cfg.model,
|
||||||
reasoning: cfg.reasoning,
|
reasoning: cfg.reasoning,
|
||||||
getQueuedMessages: cfg.getQueuedMessages,
|
getQueuedMessages: cfg.getQueuedMessages,
|
||||||
validateToolCallsAtProvider: cfg.validateToolCallsAtProvider ?? false,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Yield events from the upstream agentLoop iterator
|
// Yield events from the upstream agentLoop iterator
|
||||||
|
|
|
||||||
|
|
@ -65,7 +65,6 @@ export class ProviderTransport implements AgentTransport {
|
||||||
reasoning: cfg.reasoning,
|
reasoning: cfg.reasoning,
|
||||||
apiKey,
|
apiKey,
|
||||||
getQueuedMessages: cfg.getQueuedMessages,
|
getQueuedMessages: cfg.getQueuedMessages,
|
||||||
validateToolCallsAtProvider: cfg.validateToolCallsAtProvider ?? false,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Yield events from agentLoop
|
// Yield events from agentLoop
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,6 @@ export interface AgentRunConfig {
|
||||||
model: Model<any>;
|
model: Model<any>;
|
||||||
reasoning?: "low" | "medium" | "high";
|
reasoning?: "low" | "medium" | "high";
|
||||||
getQueuedMessages?: <T>() => Promise<QueuedMessage<T>[]>;
|
getQueuedMessages?: <T>() => Promise<QueuedMessage<T>[]>;
|
||||||
validateToolCallsAtProvider?: boolean;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
||||||
|
|
@ -2,9 +2,13 @@
|
||||||
|
|
||||||
## [Unreleased]
|
## [Unreleased]
|
||||||
|
|
||||||
|
### Breaking Changes
|
||||||
|
|
||||||
|
- Removed provider-level tool argument validation. Validation now happens in `agentLoop` via `executeToolCalls`, allowing models to retry on validation errors. For manual tool execution, use `validateToolCall(tools, toolCall)` or `validateToolArguments(tool, toolCall)`.
|
||||||
|
|
||||||
### Added
|
### Added
|
||||||
|
|
||||||
- Added `validateToolCallsAtProvider` option to streaming and agent APIs to optionally skip provider-level tool-call validation (default on), allowing agent loops to surface schema errors as toolResult messages and retry.
|
- Added `validateToolCall(tools, toolCall)` helper that finds the tool by name and validates arguments.
|
||||||
|
|
||||||
## [0.13.0] - 2025-12-06
|
## [0.13.0] - 2025-12-06
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -194,8 +194,8 @@ const response = await complete(model, context);
|
||||||
// Check for tool calls in the response
|
// Check for tool calls in the response
|
||||||
for (const block of response.content) {
|
for (const block of response.content) {
|
||||||
if (block.type === 'toolCall') {
|
if (block.type === 'toolCall') {
|
||||||
// Arguments are automatically validated against the TypeBox schema using AJV
|
// Execute your tool with the arguments
|
||||||
// If validation fails, an error event is emitted
|
// See "Validating Tool Arguments" section for validation
|
||||||
const result = await executeWeatherApi(block.arguments);
|
const result = await executeWeatherApi(block.arguments);
|
||||||
|
|
||||||
// Add tool result with text content
|
// Add tool result with text content
|
||||||
|
|
@ -253,7 +253,7 @@ for await (const event of s) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (event.type === 'toolcall_end') {
|
if (event.type === 'toolcall_end') {
|
||||||
// Here toolCall.arguments is complete and validated
|
// Here toolCall.arguments is complete (but not yet validated)
|
||||||
const toolCall = event.toolCall;
|
const toolCall = event.toolCall;
|
||||||
console.log(`Tool completed: ${toolCall.name}`, toolCall.arguments);
|
console.log(`Tool completed: ${toolCall.name}`, toolCall.arguments);
|
||||||
}
|
}
|
||||||
|
|
@ -267,22 +267,43 @@ for await (const event of s) {
|
||||||
- Arrays may be incomplete
|
- Arrays may be incomplete
|
||||||
- Nested objects may be partially populated
|
- Nested objects may be partially populated
|
||||||
- At minimum, `arguments` will be an empty object `{}`, never `undefined`
|
- At minimum, `arguments` will be an empty object `{}`, never `undefined`
|
||||||
- Full validation only occurs at `toolcall_end` when arguments are complete
|
|
||||||
- The Google provider does not support function call streaming. Instead, you will receive a single `toolcall_delta` event with the full arguments.
|
- The Google provider does not support function call streaming. Instead, you will receive a single `toolcall_delta` event with the full arguments.
|
||||||
|
|
||||||
### Provider tool-call validation
|
### Validating Tool Arguments
|
||||||
|
|
||||||
By default, providers validate streamed tool calls against your tool schema and abort the stream on validation errors. Set `validateToolCallsAtProvider: false` on `stream`, `streamSimple`, `complete`, `completeSimple`, or `AgentLoopConfig` to skip provider-level validation and let downstream code (for example, `agentLoop` via `executeToolCalls` → `validateToolArguments`) surface schema errors as `toolResult` messages. This enables the model to retry after receiving a validation error.
|
When using `agentLoop`, tool arguments are automatically validated against your TypeBox schemas before execution. If validation fails, the error is returned to the model as a tool result, allowing it to retry.
|
||||||
|
|
||||||
|
When implementing your own tool execution loop with `stream()` or `complete()`, use `validateToolCall` to validate arguments before passing them to your tools:
|
||||||
|
|
||||||
```typescript
|
```typescript
|
||||||
await streamSimple(model, context, {
|
import { stream, validateToolCall, Tool } from '@mariozechner/pi-ai';
|
||||||
apiKey: 'your-key',
|
|
||||||
validateToolCallsAtProvider: false
|
|
||||||
});
|
|
||||||
```
|
|
||||||
|
|
||||||
- `true` (default): Provider validates tool calls and emits an error if arguments do not match the schema
|
const tools: Tool[] = [weatherTool, calculatorTool];
|
||||||
- `false`: Provider emits tool calls even when arguments are invalid; callers must validate and handle errors themselves
|
const s = stream(model, { messages, tools });
|
||||||
|
|
||||||
|
for await (const event of s) {
|
||||||
|
if (event.type === 'toolcall_end') {
|
||||||
|
const toolCall = event.toolCall;
|
||||||
|
|
||||||
|
try {
|
||||||
|
// Validate arguments against the tool's schema (throws on invalid args)
|
||||||
|
const validatedArgs = validateToolCall(tools, toolCall);
|
||||||
|
const result = await executeMyTool(toolCall.name, validatedArgs);
|
||||||
|
// ... add tool result to context
|
||||||
|
} catch (error) {
|
||||||
|
// Validation failed - return error as tool result so model can retry
|
||||||
|
context.messages.push({
|
||||||
|
role: 'toolResult',
|
||||||
|
toolCallId: toolCall.id,
|
||||||
|
toolName: toolCall.name,
|
||||||
|
content: [{ type: 'text', text: error.message }],
|
||||||
|
isError: true,
|
||||||
|
timestamp: Date.now()
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
### Complete Event Reference
|
### Complete Event Reference
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -8,3 +8,4 @@ export * from "./stream.js";
|
||||||
export * from "./types.js";
|
export * from "./types.js";
|
||||||
export * from "./utils/overflow.js";
|
export * from "./utils/overflow.js";
|
||||||
export * from "./utils/typebox-helpers.js";
|
export * from "./utils/typebox-helpers.js";
|
||||||
|
export * from "./utils/validation.js";
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,7 @@ import type {
|
||||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||||
import { parseStreamingJson } from "../utils/json-parse.js";
|
import { parseStreamingJson } from "../utils/json-parse.js";
|
||||||
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
||||||
import { validateToolArguments } from "../utils/validation.js";
|
|
||||||
import { transformMessages } from "./transorm-messages.js";
|
import { transformMessages } from "./transorm-messages.js";
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -92,7 +92,6 @@ export const streamAnthropic: StreamFunction<"anthropic-messages"> = (
|
||||||
options?: AnthropicOptions,
|
options?: AnthropicOptions,
|
||||||
): AssistantMessageEventStream => {
|
): AssistantMessageEventStream => {
|
||||||
const stream = new AssistantMessageEventStream();
|
const stream = new AssistantMessageEventStream();
|
||||||
const shouldValidateToolCalls = options?.validateToolCallsAtProvider !== false;
|
|
||||||
|
|
||||||
(async () => {
|
(async () => {
|
||||||
const output: AssistantMessage = {
|
const output: AssistantMessage = {
|
||||||
|
|
@ -232,15 +231,6 @@ export const streamAnthropic: StreamFunction<"anthropic-messages"> = (
|
||||||
});
|
});
|
||||||
} else if (block.type === "toolCall") {
|
} else if (block.type === "toolCall") {
|
||||||
block.arguments = parseStreamingJson(block.partialJson);
|
block.arguments = parseStreamingJson(block.partialJson);
|
||||||
|
|
||||||
// Validate tool arguments if tool definition is available
|
|
||||||
if (shouldValidateToolCalls && context.tools) {
|
|
||||||
const tool = context.tools.find((t) => t.name === block.name);
|
|
||||||
if (tool) {
|
|
||||||
block.arguments = validateToolArguments(tool, block);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
delete (block as any).partialJson;
|
delete (block as any).partialJson;
|
||||||
stream.push({
|
stream.push({
|
||||||
type: "toolcall_end",
|
type: "toolcall_end",
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,7 @@ import type {
|
||||||
} from "../types.js";
|
} from "../types.js";
|
||||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||||
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
||||||
import { validateToolArguments } from "../utils/validation.js";
|
|
||||||
import { transformMessages } from "./transorm-messages.js";
|
import { transformMessages } from "./transorm-messages.js";
|
||||||
|
|
||||||
export interface GoogleOptions extends StreamOptions {
|
export interface GoogleOptions extends StreamOptions {
|
||||||
|
|
@ -43,7 +43,6 @@ export const streamGoogle: StreamFunction<"google-generative-ai"> = (
|
||||||
options?: GoogleOptions,
|
options?: GoogleOptions,
|
||||||
): AssistantMessageEventStream => {
|
): AssistantMessageEventStream => {
|
||||||
const stream = new AssistantMessageEventStream();
|
const stream = new AssistantMessageEventStream();
|
||||||
const shouldValidateToolCalls = options?.validateToolCallsAtProvider !== false;
|
|
||||||
|
|
||||||
(async () => {
|
(async () => {
|
||||||
const output: AssistantMessage = {
|
const output: AssistantMessage = {
|
||||||
|
|
@ -167,14 +166,6 @@ export const streamGoogle: StreamFunction<"google-generative-ai"> = (
|
||||||
...(part.thoughtSignature && { thoughtSignature: part.thoughtSignature }),
|
...(part.thoughtSignature && { thoughtSignature: part.thoughtSignature }),
|
||||||
};
|
};
|
||||||
|
|
||||||
// Validate tool arguments if tool definition is available
|
|
||||||
if (shouldValidateToolCalls && context.tools) {
|
|
||||||
const tool = context.tools.find((t) => t.name === toolCall.name);
|
|
||||||
if (tool) {
|
|
||||||
toolCall.arguments = validateToolArguments(tool, toolCall);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
output.content.push(toolCall);
|
output.content.push(toolCall);
|
||||||
stream.push({ type: "toolcall_start", contentIndex: blockIndex(), partial: output });
|
stream.push({ type: "toolcall_start", contentIndex: blockIndex(), partial: output });
|
||||||
stream.push({
|
stream.push({
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,7 @@ import type {
|
||||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||||
import { parseStreamingJson } from "../utils/json-parse.js";
|
import { parseStreamingJson } from "../utils/json-parse.js";
|
||||||
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
||||||
import { validateToolArguments } from "../utils/validation.js";
|
|
||||||
import { transformMessages } from "./transorm-messages.js";
|
import { transformMessages } from "./transorm-messages.js";
|
||||||
|
|
||||||
export interface OpenAICompletionsOptions extends StreamOptions {
|
export interface OpenAICompletionsOptions extends StreamOptions {
|
||||||
|
|
@ -37,7 +37,6 @@ export const streamOpenAICompletions: StreamFunction<"openai-completions"> = (
|
||||||
options?: OpenAICompletionsOptions,
|
options?: OpenAICompletionsOptions,
|
||||||
): AssistantMessageEventStream => {
|
): AssistantMessageEventStream => {
|
||||||
const stream = new AssistantMessageEventStream();
|
const stream = new AssistantMessageEventStream();
|
||||||
const shouldValidateToolCalls = options?.validateToolCallsAtProvider !== false;
|
|
||||||
|
|
||||||
(async () => {
|
(async () => {
|
||||||
const output: AssistantMessage = {
|
const output: AssistantMessage = {
|
||||||
|
|
@ -85,15 +84,6 @@ export const streamOpenAICompletions: StreamFunction<"openai-completions"> = (
|
||||||
});
|
});
|
||||||
} else if (block.type === "toolCall") {
|
} else if (block.type === "toolCall") {
|
||||||
block.arguments = JSON.parse(block.partialArgs || "{}");
|
block.arguments = JSON.parse(block.partialArgs || "{}");
|
||||||
|
|
||||||
// Validate tool arguments if tool definition is available
|
|
||||||
if (shouldValidateToolCalls && context.tools) {
|
|
||||||
const tool = context.tools.find((t) => t.name === block.name);
|
|
||||||
if (tool) {
|
|
||||||
block.arguments = validateToolArguments(tool, block);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
delete block.partialArgs;
|
delete block.partialArgs;
|
||||||
stream.push({
|
stream.push({
|
||||||
type: "toolcall_end",
|
type: "toolcall_end",
|
||||||
|
|
|
||||||
|
|
@ -27,7 +27,7 @@ import type {
|
||||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||||
import { parseStreamingJson } from "../utils/json-parse.js";
|
import { parseStreamingJson } from "../utils/json-parse.js";
|
||||||
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
||||||
import { validateToolArguments } from "../utils/validation.js";
|
|
||||||
import { transformMessages } from "./transorm-messages.js";
|
import { transformMessages } from "./transorm-messages.js";
|
||||||
|
|
||||||
// OpenAI Responses-specific options
|
// OpenAI Responses-specific options
|
||||||
|
|
@ -45,7 +45,6 @@ export const streamOpenAIResponses: StreamFunction<"openai-responses"> = (
|
||||||
options?: OpenAIResponsesOptions,
|
options?: OpenAIResponsesOptions,
|
||||||
): AssistantMessageEventStream => {
|
): AssistantMessageEventStream => {
|
||||||
const stream = new AssistantMessageEventStream();
|
const stream = new AssistantMessageEventStream();
|
||||||
const shouldValidateToolCalls = options?.validateToolCallsAtProvider !== false;
|
|
||||||
|
|
||||||
// Start async processing
|
// Start async processing
|
||||||
(async () => {
|
(async () => {
|
||||||
|
|
@ -240,14 +239,6 @@ export const streamOpenAIResponses: StreamFunction<"openai-responses"> = (
|
||||||
arguments: JSON.parse(item.arguments),
|
arguments: JSON.parse(item.arguments),
|
||||||
};
|
};
|
||||||
|
|
||||||
// Validate tool arguments if tool definition is available
|
|
||||||
if (shouldValidateToolCalls && context.tools) {
|
|
||||||
const tool = context.tools.find((t) => t.name === toolCall.name);
|
|
||||||
if (tool) {
|
|
||||||
toolCall.arguments = validateToolArguments(tool, toolCall);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
stream.push({ type: "toolcall_end", contentIndex: blockIndex(), toolCall, partial: output });
|
stream.push({ type: "toolcall_end", contentIndex: blockIndex(), toolCall, partial: output });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -120,7 +120,6 @@ function mapOptionsForApi<TApi extends Api>(
|
||||||
maxTokens: options?.maxTokens || Math.min(model.maxTokens, 32000),
|
maxTokens: options?.maxTokens || Math.min(model.maxTokens, 32000),
|
||||||
signal: options?.signal,
|
signal: options?.signal,
|
||||||
apiKey: apiKey || options?.apiKey,
|
apiKey: apiKey || options?.apiKey,
|
||||||
validateToolCallsAtProvider: options?.validateToolCallsAtProvider ?? true,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
switch (model.api) {
|
switch (model.api) {
|
||||||
|
|
|
||||||
|
|
@ -37,12 +37,6 @@ export interface StreamOptions {
|
||||||
maxTokens?: number;
|
maxTokens?: number;
|
||||||
signal?: AbortSignal;
|
signal?: AbortSignal;
|
||||||
apiKey?: string;
|
apiKey?: string;
|
||||||
/**
|
|
||||||
* Controls whether providers validate streamed tool calls against the tool schema.
|
|
||||||
* Defaults to true. Set to false to skip provider-level validation and allow
|
|
||||||
* downstream consumers (e.g., agentLoop) to handle validation failures.
|
|
||||||
*/
|
|
||||||
validateToolCallsAtProvider?: boolean;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Unified options with reasoning passed to streamSimple() and completeSimple()
|
// Unified options with reasoning passed to streamSimple() and completeSimple()
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,21 @@ if (!isBrowserExtension) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Looks up the tool whose name matches the given tool call and validates the
 * call's arguments against that tool's TypeBox schema.
 * @param tools Array of tool definitions to search
 * @param toolCall The tool call from the LLM
 * @returns The validated arguments, exactly as returned by validateToolArguments
 * @throws Error if no tool with a matching name exists, or if validation fails
 */
export function validateToolCall(tools: Tool[], toolCall: ToolCall): any {
	const requestedName = toolCall.name;
	// First match wins, mirroring a linear search over the tool list.
	for (const candidate of tools) {
		if (candidate.name === requestedName) {
			return validateToolArguments(candidate, toolCall);
		}
	}
	throw new Error(`Tool "${requestedName}" not found`);
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Validates tool call arguments against the tool's TypeBox schema
|
* Validates tool call arguments against the tool's TypeBox schema
|
||||||
* @param tool The tool definition with TypeBox schema
|
* @param tool The tool definition with TypeBox schema
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue