diff --git a/packages/ai/package.json b/packages/ai/package.json index 49895535..5fee305c 100644 --- a/packages/ai/package.json +++ b/packages/ai/package.json @@ -13,6 +13,7 @@ "clean": "rm -rf dist", "generate-models": "npx tsx scripts/generate-models.ts", "build": "npm run generate-models && tsc -p tsconfig.build.json", + "dev": "tsc -p tsconfig.build.json --watch", "check": "biome check --write .", "test": "vitest --run", "prepublishOnly": "npm run clean && npm run build" diff --git a/packages/ai/src/agent/agent.ts b/packages/ai/src/agent/agent.ts new file mode 100644 index 00000000..2ec05833 --- /dev/null +++ b/packages/ai/src/agent/agent.ts @@ -0,0 +1,231 @@ +import { EventStream } from "../event-stream"; +import { streamSimple } from "../generate.js"; +import type { + AssistantMessage, + Context, + Message, + Model, + SimpleGenerateOptions, + ToolResultMessage, + UserMessage, +} from "../types.js"; +import type { AgentContext, AgentTool, AgentToolResult } from "./types"; + +// Event types +export type AgentEvent = + | { type: "message_start"; message: Message } + | { type: "message_update"; message: AssistantMessage } + | { type: "message_complete"; message: Message } + | { type: "tool_execution_start"; toolCallId: string; toolName: string; args: any } + | { + type: "tool_execution_complete"; + toolCallId: string; + toolName: string; + result: AgentToolResult | string; + isError: boolean; + } + | { type: "turn_complete"; messages: AgentContext["messages"] }; + +// Configuration for prompt execution +export interface PromptConfig { + model: Model; + apiKey: string; + enableThinking?: boolean; + preprocessor?: (messages: AgentContext["messages"], abortSignal?: AbortSignal) => Promise; +} + +// Main prompt function - returns a stream of events +export function prompt( + context: AgentContext, + config: PromptConfig, + prompt: UserMessage, + signal?: AbortSignal, +): EventStream { + const stream = new EventStream( + (event) => event.type === "turn_complete", + 
(event) => (event.type === "turn_complete" ? event.messages : []), + ); + + // Run the prompt async + (async () => { + try { + // Track new messages generated during this prompt + const newMessages: AgentContext["messages"] = []; + + // Create user message + const messages = [...context.messages, prompt]; + newMessages.push(prompt); + + stream.push({ type: "message_start", message: prompt }); + stream.push({ type: "message_complete", message: prompt }); + + // Update context with new messages + const currentContext: AgentContext = { + ...context, + messages, + }; + + // Keep looping while we have tool calls + let hasMoreToolCalls = true; + while (hasMoreToolCalls) { + // Stream assistant response + const assistantMessage = await streamAssistantResponse(currentContext, config, signal, stream); + newMessages.push(assistantMessage); + + // Check for tool calls + const toolCalls = assistantMessage.content.filter((c) => c.type === "toolCall"); + hasMoreToolCalls = toolCalls.length > 0; + + if (hasMoreToolCalls) { + // Execute tool calls + const toolResults = await executeToolCalls(currentContext.tools, assistantMessage, signal, stream); + newMessages.push(...toolResults); + + // Add tool results to context + currentContext.messages = [...currentContext.messages, ...toolResults]; + } + } + + stream.push({ type: "turn_complete", messages: newMessages }); + } catch (error) { + // End stream on error + stream.end([]); + throw error; + } + })(); + + return stream; +} + +// Helper functions +async function streamAssistantResponse( + context: AgentContext, + config: PromptConfig, + signal: AbortSignal | undefined, + stream: EventStream, +): Promise { + // Convert AgentContext to Context for streamSimple + // Use a copy of messages to avoid mutating the original context + const processedMessages = config.preprocessor + ? 
await config.preprocessor(context.messages, signal) + : [...context.messages]; + const processedContext: Context = { + systemPrompt: context.systemPrompt, + messages: [...processedMessages].map((m) => { + if (m.role === "toolResult") { + const { details, ...rest } = m; + return rest; + } else { + return m; + } + }), + tools: context.tools, // AgentTool extends Tool, so this works + }; + + const options: SimpleGenerateOptions = { + apiKey: config.apiKey, + signal, + }; + + if (config.model.reasoning && config.enableThinking) { + options.reasoning = "medium"; + } + + const response = await streamSimple(config.model, processedContext, options); + + let partialMessage: AssistantMessage | null = null; + let addedPartial = false; + + for await (const event of response) { + switch (event.type) { + case "start": + partialMessage = event.partial; + context.messages.push(partialMessage); + addedPartial = true; + stream.push({ type: "message_start", message: { ...partialMessage } }); + break; + + case "text_start": + case "text_delta": + case "thinking_start": + case "thinking_delta": + case "toolcall_start": + case "toolcall_delta": + if (partialMessage) { + partialMessage = event.partial; + context.messages[context.messages.length - 1] = partialMessage; + stream.push({ type: "message_update", message: { ...partialMessage } }); + } + break; + + case "done": + case "error": { + const finalMessage = await response.result(); + if (addedPartial) { + context.messages[context.messages.length - 1] = finalMessage; + } else { + context.messages.push(finalMessage); + } + stream.push({ type: "message_complete", message: finalMessage }); + return finalMessage; + } + } + } + + return await response.result(); +} + +async function executeToolCalls( + tools: AgentTool[] | undefined, + assistantMessage: AssistantMessage, + signal: AbortSignal | undefined, + stream: EventStream, +): Promise[]> { + const toolCalls = assistantMessage.content.filter((c) => c.type === "toolCall"); + const 
results: ToolResultMessage[] = []; + + for (const toolCall of toolCalls) { + const tool = tools?.find((t) => t.name === toolCall.name); + + stream.push({ + type: "tool_execution_start", + toolCallId: toolCall.id, + toolName: toolCall.name, + args: toolCall.arguments, + }); + + let resultOrError: AgentToolResult | string; + let isError = false; + + try { + if (!tool) throw new Error(`Tool ${toolCall.name} not found`); + resultOrError = await tool.execute(toolCall.arguments, toolCall.id, signal); + } catch (e) { + resultOrError = `Error: ${e instanceof Error ? e.message : String(e)}`; + isError = true; + } + + stream.push({ + type: "tool_execution_complete", + toolCallId: toolCall.id, + toolName: toolCall.name, + result: resultOrError, + isError, + }); + + const toolResultMessage: ToolResultMessage = { + role: "toolResult", + toolCallId: toolCall.id, + toolName: toolCall.name, + output: typeof resultOrError === "string" ? resultOrError : resultOrError.output, + details: typeof resultOrError === "string" ? 
({} as T) : resultOrError.details, + isError, + }; + + results.push(toolResultMessage); + stream.push({ type: "message_start", message: toolResultMessage }); + stream.push({ type: "message_complete", message: toolResultMessage }); + } + + return results; +} diff --git a/packages/ai/src/agent/index.ts b/packages/ai/src/agent/index.ts new file mode 100644 index 00000000..df4b022f --- /dev/null +++ b/packages/ai/src/agent/index.ts @@ -0,0 +1,3 @@ +export { type AgentEvent, type PromptConfig, prompt } from "./agent"; +export * from "./tools"; +export type { AgentContext, AgentTool } from "./types"; diff --git a/packages/ai/src/agent/tools/calculate.ts b/packages/ai/src/agent/tools/calculate.ts new file mode 100644 index 00000000..c0eff265 --- /dev/null +++ b/packages/ai/src/agent/tools/calculate.ts @@ -0,0 +1,34 @@ +import type { AgentTool } from "../../agent"; + +export interface CalculateResult { + output: string; + details: undefined; +} + +export function calculate(expression: string): CalculateResult { + try { + const result = new Function("return " + expression)(); + return { output: `${expression} = ${result}`, details: undefined }; + } catch (e: any) { + throw new Error(e.message || String(e)); + } +} + +export const calculateTool: AgentTool = { + label: "Calculator", + name: "calculate", + description: "Evaluate mathematical expressions", + parameters: { + type: "object", + properties: { + expression: { + type: "string", + description: "The mathematical expression to evaluate", + }, + }, + required: ["expression"], + }, + execute: async (args: { expression: string }) => { + return calculate(args.expression); + }, +}; diff --git a/packages/ai/src/agent/tools/get-current-time.ts b/packages/ai/src/agent/tools/get-current-time.ts new file mode 100644 index 00000000..a4774302 --- /dev/null +++ b/packages/ai/src/agent/tools/get-current-time.ts @@ -0,0 +1,44 @@ +import type { AgentTool } from "../../agent"; +import type { AgentToolResult } from "../types"; + +export 
interface GetCurrentTimeResult extends AgentToolResult<{ utcTimestamp: number }> {} + +export async function getCurrentTime(timezone?: string): Promise { + const date = new Date(); + if (timezone) { + try { + return { + output: date.toLocaleString("en-US", { + timeZone: timezone, + dateStyle: "full", + timeStyle: "long", + }), + details: { utcTimestamp: date.getTime() }, + }; + } catch (e) { + throw new Error(`Invalid timezone: ${timezone}. Current UTC time: ${date.toISOString()}`); + } + } + return { + output: date.toLocaleString("en-US", { dateStyle: "full", timeStyle: "long" }), + details: { utcTimestamp: date.getTime() }, + }; +} + +export const getCurrentTimeTool: AgentTool<{ utcTimestamp: number }> = { + label: "Current Time", + name: "get_current_time", + description: "Get the current date and time", + parameters: { + type: "object", + properties: { + timezone: { + type: "string", + description: "Optional timezone (e.g., 'America/New_York', 'Europe/London')", + }, + }, + }, + execute: async (args: { timezone?: string }) => { + return getCurrentTime(args.timezone); + }, +}; diff --git a/packages/ai/src/agent/tools/index.ts b/packages/ai/src/agent/tools/index.ts new file mode 100644 index 00000000..ddd96932 --- /dev/null +++ b/packages/ai/src/agent/tools/index.ts @@ -0,0 +1,2 @@ +export { calculate, calculateTool } from "./calculate"; +export { getCurrentTime, getCurrentTimeTool } from "./get-current-time"; diff --git a/packages/ai/src/agent/types.ts b/packages/ai/src/agent/types.ts new file mode 100644 index 00000000..d8ddcac9 --- /dev/null +++ b/packages/ai/src/agent/types.ts @@ -0,0 +1,22 @@ +import type { Message, Tool } from "../types.js"; + +export interface AgentToolResult { + // Output of the tool to be given to the LLM in ToolResultMessage.output + output: string; + // Details to be displayed in a UI or logged + details: T; +} + +// AgentTool extends Tool but adds the execute function +export interface AgentTool extends Tool { + // A human-readable
label for the tool to be displayed in UI + label: string; + execute: (params: any, toolCallId: string, signal?: AbortSignal) => Promise>; +} + +// AgentContext is like Context but uses AgentTool +export interface AgentContext { + systemPrompt: string; + messages: Message[]; + tools?: AgentTool[]; +} diff --git a/packages/ai/src/event-stream.ts b/packages/ai/src/event-stream.ts new file mode 100644 index 00000000..2f0a82a4 --- /dev/null +++ b/packages/ai/src/event-stream.ts @@ -0,0 +1,82 @@ +import type { AssistantMessage, AssistantMessageEvent } from "./types"; + +// Generic event stream class for async iteration +export class EventStream implements AsyncIterable { + private queue: T[] = []; + private waiting: ((value: IteratorResult) => void)[] = []; + private done = false; + private finalResultPromise: Promise; + private resolveFinalResult!: (result: R) => void; + + constructor( + private isComplete: (event: T) => boolean, + private extractResult: (event: T) => R, + ) { + this.finalResultPromise = new Promise((resolve) => { + this.resolveFinalResult = resolve; + }); + } + + push(event: T): void { + if (this.done) return; + + if (this.isComplete(event)) { + this.done = true; + this.resolveFinalResult(this.extractResult(event)); + } + + // Deliver to waiting consumer or queue it + const waiter = this.waiting.shift(); + if (waiter) { + waiter({ value: event, done: false }); + } else { + this.queue.push(event); + } + } + + end(result?: R): void { + this.done = true; + if (result !== undefined) { + this.resolveFinalResult(result); + } + // Notify all waiting consumers that we're done + while (this.waiting.length > 0) { + const waiter = this.waiting.shift()!; + waiter({ value: undefined as any, done: true }); + } + } + + async *[Symbol.asyncIterator](): AsyncIterator { + while (true) { + if (this.queue.length > 0) { + yield this.queue.shift()!; + } else if (this.done) { + return; + } else { + const result = await new Promise>((resolve) => this.waiting.push(resolve)); + 
if (result.done) return; + yield result.value; + } + } + } + + result(): Promise { + return this.finalResultPromise; + } +} + +export class AssistantMessageEventStream extends EventStream { + constructor() { + super( + (event) => event.type === "done" || event.type === "error", + (event) => { + if (event.type === "done") { + return event.message; + } else if (event.type === "error") { + return event.partial; + } + throw new Error("Unexpected event type for final result"); + }, + ); + } +} diff --git a/packages/ai/src/generate.ts b/packages/ai/src/generate.ts index fd67f04b..54eeed0f 100644 --- a/packages/ai/src/generate.ts +++ b/packages/ai/src/generate.ts @@ -5,9 +5,8 @@ import { type OpenAIResponsesOptions, streamOpenAIResponses } from "./providers/ import type { Api, AssistantMessage, - AssistantMessageEvent, + AssistantMessageEventStream, Context, - GenerateStream, KnownProvider, Model, OptionsForApi, @@ -15,73 +14,6 @@ import type { SimpleGenerateOptions, } from "./types.js"; -export class QueuedGenerateStream implements GenerateStream { - private queue: AssistantMessageEvent[] = []; - private waiting: ((value: IteratorResult) => void)[] = []; - private done = false; - private finalMessagePromise: Promise; - private resolveFinalMessage!: (message: AssistantMessage) => void; - - constructor() { - this.finalMessagePromise = new Promise((resolve) => { - this.resolveFinalMessage = resolve; - }); - } - - push(event: AssistantMessageEvent): void { - if (this.done) return; - - if (event.type === "done") { - this.done = true; - this.resolveFinalMessage(event.message); - } - if (event.type === "error") { - this.done = true; - this.resolveFinalMessage(event.partial); - } - - // Deliver to waiting consumer or queue it - const waiter = this.waiting.shift(); - if (waiter) { - waiter({ value: event, done: false }); - } else { - this.queue.push(event); - } - } - - end(): void { - this.done = true; - // Notify all waiting consumers that we're done - while (this.waiting.length 
> 0) { - const waiter = this.waiting.shift()!; - waiter({ value: undefined as any, done: true }); - } - } - - async *[Symbol.asyncIterator](): AsyncIterator { - while (true) { - // If we have queued events, yield them - if (this.queue.length > 0) { - yield this.queue.shift()!; - } else if (this.done) { - // No more events and we're done - return; - } else { - // Wait for next event - const result = await new Promise>((resolve) => - this.waiting.push(resolve), - ); - if (result.done) return; - yield result.value; - } - } - } - - finalMessage(): Promise { - return this.finalMessagePromise; - } -} - const apiKeys: Map = new Map(); export function setApiKey(provider: KnownProvider, key: string): void; @@ -117,7 +49,7 @@ export function stream( model: Model, context: Context, options?: OptionsForApi, -): GenerateStream { +): AssistantMessageEventStream { const apiKey = options?.apiKey || getApiKey(model.provider); if (!apiKey) { throw new Error(`No API key for provider: ${model.provider}`); @@ -152,14 +84,14 @@ export async function complete( options?: OptionsForApi, ): Promise { const s = stream(model, context, options); - return s.finalMessage(); + return s.result(); } export function streamSimple( model: Model, context: Context, options?: SimpleGenerateOptions, -): GenerateStream { +): AssistantMessageEventStream { const apiKey = options?.apiKey || getApiKey(model.provider); if (!apiKey) { throw new Error(`No API key for provider: ${model.provider}`); @@ -175,7 +107,7 @@ export async function completeSimple( options?: SimpleGenerateOptions, ): Promise { const s = streamSimple(model, context, options); - return s.finalMessage(); + return s.result(); } function mapOptionsForApi( diff --git a/packages/ai/src/index.ts b/packages/ai/src/index.ts index d163aad6..e0e24534 100644 --- a/packages/ai/src/index.ts +++ b/packages/ai/src/index.ts @@ -1,3 +1,4 @@ +export * from "./agent/index.js"; export * from "./generate.js"; export * from "./models.js"; export * from 
"./providers/anthropic.js"; diff --git a/packages/ai/src/providers/anthropic.ts b/packages/ai/src/providers/anthropic.ts index b6c20726..683f0ed1 100644 --- a/packages/ai/src/providers/anthropic.ts +++ b/packages/ai/src/providers/anthropic.ts @@ -4,7 +4,7 @@ import type { MessageCreateParamsStreaming, MessageParam, } from "@anthropic-ai/sdk/resources/messages.js"; -import { QueuedGenerateStream } from "../generate.js"; +import { AssistantMessageEventStream } from "../event-stream.js"; import { calculateCost } from "../models.js"; import type { Api, @@ -12,7 +12,6 @@ import type { Context, GenerateFunction, GenerateOptions, - GenerateStream, Message, Model, StopReason, @@ -20,8 +19,9 @@ import type { ThinkingContent, Tool, ToolCall, + ToolResultMessage, } from "../types.js"; -import { transformMessages } from "./utils.js"; +import { transformMessages } from "./transorm-messages.js"; export interface AnthropicOptions extends GenerateOptions { thinkingEnabled?: boolean; @@ -33,8 +33,8 @@ export const streamAnthropic: GenerateFunction<"anthropic-messages"> = ( model: Model<"anthropic-messages">, context: Context, options?: AnthropicOptions, -): GenerateStream => { - const stream = new QueuedGenerateStream(); +): AssistantMessageEventStream => { + const stream = new AssistantMessageEventStream(); (async () => { const output: AssistantMessage = { @@ -59,93 +59,114 @@ export const streamAnthropic: GenerateFunction<"anthropic-messages"> = ( const anthropicStream = client.messages.stream({ ...params, stream: true }, { signal: options?.signal }); stream.push({ type: "start", partial: output }); - let currentBlock: ThinkingContent | TextContent | (ToolCall & { partialJson: string }) | null = null; + type Block = (ThinkingContent | TextContent | (ToolCall & { partialJson: string })) & { index: number }; + const blocks = output.content as Block[]; for await (const event of anthropicStream) { if (event.type === "content_block_start") { if (event.content_block.type === "text") { 
- currentBlock = { + const block: Block = { type: "text", text: "", + index: event.index, }; - output.content.push(currentBlock); - stream.push({ type: "text_start", partial: output }); + output.content.push(block); + stream.push({ type: "text_start", contentIndex: output.content.length - 1, partial: output }); } else if (event.content_block.type === "thinking") { - currentBlock = { + const block: Block = { type: "thinking", thinking: "", thinkingSignature: "", + index: event.index, }; - output.content.push(currentBlock); - stream.push({ type: "thinking_start", partial: output }); + output.content.push(block); + stream.push({ type: "thinking_start", contentIndex: output.content.length - 1, partial: output }); } else if (event.content_block.type === "tool_use") { - // We wait for the full tool use to be streamed - currentBlock = { + const block: Block = { type: "toolCall", id: event.content_block.id, name: event.content_block.name, arguments: event.content_block.input as Record, partialJson: "", + index: event.index, }; + output.content.push(block); + stream.push({ type: "toolcall_start", contentIndex: output.content.length - 1, partial: output }); } } else if (event.type === "content_block_delta") { if (event.delta.type === "text_delta") { - if (currentBlock && currentBlock.type === "text") { - currentBlock.text += event.delta.text; + const index = blocks.findIndex((b) => b.index === event.index); + const block = blocks[index]; + if (block && block.type === "text") { + block.text += event.delta.text; stream.push({ type: "text_delta", + contentIndex: index, delta: event.delta.text, partial: output, }); } } else if (event.delta.type === "thinking_delta") { - if (currentBlock && currentBlock.type === "thinking") { - currentBlock.thinking += event.delta.thinking; + const index = blocks.findIndex((b) => b.index === event.index); + const block = blocks[index]; + if (block && block.type === "thinking") { + block.thinking += event.delta.thinking; stream.push({ type: 
"thinking_delta", + contentIndex: index, delta: event.delta.thinking, partial: output, }); } } else if (event.delta.type === "input_json_delta") { - if (currentBlock && currentBlock.type === "toolCall") { - currentBlock.partialJson += event.delta.partial_json; + const index = blocks.findIndex((b) => b.index === event.index); + const block = blocks[index]; + if (block && block.type === "toolCall") { + block.partialJson += event.delta.partial_json; + stream.push({ + type: "toolcall_delta", + contentIndex: index, + delta: event.delta.partial_json, + partial: output, + }); } } else if (event.delta.type === "signature_delta") { - if (currentBlock && currentBlock.type === "thinking") { - currentBlock.thinkingSignature = currentBlock.thinkingSignature || ""; - currentBlock.thinkingSignature += event.delta.signature; + const index = blocks.findIndex((b) => b.index === event.index); + const block = blocks[index]; + if (block && block.type === "thinking") { + block.thinkingSignature = block.thinkingSignature || ""; + block.thinkingSignature += event.delta.signature; } } } else if (event.type === "content_block_stop") { - if (currentBlock) { - if (currentBlock.type === "text") { + const index = blocks.findIndex((b) => b.index === event.index); + const block = blocks[index]; + if (block) { + delete (block as any).index; + if (block.type === "text") { stream.push({ type: "text_end", - content: currentBlock.text, + contentIndex: index, + content: block.text, partial: output, }); - } else if (currentBlock.type === "thinking") { + } else if (block.type === "thinking") { stream.push({ type: "thinking_end", - content: currentBlock.thinking, + contentIndex: index, + content: block.thinking, partial: output, }); - } else if (currentBlock.type === "toolCall") { - const finalToolCall: ToolCall = { - type: "toolCall", - id: currentBlock.id, - name: currentBlock.name, - arguments: JSON.parse(currentBlock.partialJson), - }; - output.content.push(finalToolCall); + } else if (block.type === 
"toolCall") { + block.arguments = JSON.parse(block.partialJson); + delete (block as any).partialJson; stream.push({ - type: "toolCall", - toolCall: finalToolCall, + type: "toolcall_end", + contentIndex: index, + toolCall: block, partial: output, }); } - currentBlock = null; } } else if (event.type === "message_delta") { if (event.delta.stop_reason) { @@ -166,6 +187,7 @@ export const streamAnthropic: GenerateFunction<"anthropic-messages"> = ( stream.push({ type: "done", reason: output.stopReason, message: output }); stream.end(); } catch (error) { + for (const block of output.content) delete (block as any).index; output.stopReason = "error"; output.error = error instanceof Error ? error.message : JSON.stringify(error); stream.push({ type: "error", error: output.error, partial: output }); @@ -294,7 +316,9 @@ function convertMessages(messages: Message[], model: Model<"anthropic-messages"> // Transform messages for cross-provider compatibility const transformedMessages = transformMessages(messages, model); - for (const msg of transformedMessages) { + for (let i = 0; i < transformedMessages.length; i++) { + const msg = transformedMessages[i]; + if (msg.role === "user") { if (typeof msg.content === "string") { if (msg.content.trim().length > 0) { @@ -366,16 +390,37 @@ function convertMessages(messages: Message[], model: Model<"anthropic-messages"> content: blocks, }); } else if (msg.role === "toolResult") { + // Collect all consecutive toolResult messages + const toolResults: ContentBlockParam[] = []; + + // Add the current tool result + toolResults.push({ + type: "tool_result", + tool_use_id: sanitizeToolCallId(msg.toolCallId), + content: msg.output, + is_error: msg.isError, + }); + + // Look ahead for consecutive toolResult messages + let j = i + 1; + while (j < transformedMessages.length && transformedMessages[j].role === "toolResult") { + const nextMsg = transformedMessages[j] as ToolResultMessage; // We know it's a toolResult + toolResults.push({ + type: 
"tool_result", + tool_use_id: sanitizeToolCallId(nextMsg.toolCallId), + content: nextMsg.output, + is_error: nextMsg.isError, + }); + j++; + } + + // Skip the messages we've already processed + i = j - 1; + + // Add a single user message with all tool results params.push({ role: "user", - content: [ - { - type: "tool_result", - tool_use_id: sanitizeToolCallId(msg.toolCallId), - content: msg.content, - is_error: msg.isError, - }, - ], + content: toolResults, }); } } diff --git a/packages/ai/src/providers/google.ts b/packages/ai/src/providers/google.ts index 88439eeb..8a90f9eb 100644 --- a/packages/ai/src/providers/google.ts +++ b/packages/ai/src/providers/google.ts @@ -7,7 +7,7 @@ import { GoogleGenAI, type Part, } from "@google/genai"; -import { QueuedGenerateStream } from "../generate.js"; +import { AssistantMessageEventStream } from "../event-stream.js"; import { calculateCost } from "../models.js"; import type { Api, @@ -15,7 +15,6 @@ import type { Context, GenerateFunction, GenerateOptions, - GenerateStream, Model, StopReason, TextContent, @@ -23,7 +22,7 @@ import type { Tool, ToolCall, } from "../types.js"; -import { transformMessages } from "./utils.js"; +import { transformMessages } from "./transorm-messages.js"; export interface GoogleOptions extends GenerateOptions { toolChoice?: "auto" | "none" | "any"; @@ -40,8 +39,8 @@ export const streamGoogle: GenerateFunction<"google-generative-ai"> = ( model: Model<"google-generative-ai">, context: Context, options?: GoogleOptions, -): GenerateStream => { - const stream = new QueuedGenerateStream(); +): AssistantMessageEventStream => { + const stream = new AssistantMessageEventStream(); (async () => { const output: AssistantMessage = { @@ -67,6 +66,8 @@ export const streamGoogle: GenerateFunction<"google-generative-ai"> = ( stream.push({ type: "start", partial: output }); let currentBlock: TextContent | ThinkingContent | null = null; + const blocks = output.content; + const blockIndex = () => blocks.length - 1; for 
await (const chunk of googleStream) { const candidate = chunk.candidates?.[0]; if (candidate?.content?.parts) { @@ -82,12 +83,14 @@ export const streamGoogle: GenerateFunction<"google-generative-ai"> = ( if (currentBlock.type === "text") { stream.push({ type: "text_end", + contentIndex: blocks.length - 1, content: currentBlock.text, partial: output, }); } else { stream.push({ type: "thinking_end", + contentIndex: blockIndex(), content: currentBlock.thinking, partial: output, }); @@ -95,10 +98,10 @@ export const streamGoogle: GenerateFunction<"google-generative-ai"> = ( } if (isThinking) { currentBlock = { type: "thinking", thinking: "", thinkingSignature: undefined }; - stream.push({ type: "thinking_start", partial: output }); + stream.push({ type: "thinking_start", contentIndex: blockIndex(), partial: output }); } else { currentBlock = { type: "text", text: "" }; - stream.push({ type: "text_start", partial: output }); + stream.push({ type: "text_start", contentIndex: blockIndex(), partial: output }); } output.content.push(currentBlock); } @@ -107,12 +110,18 @@ export const streamGoogle: GenerateFunction<"google-generative-ai"> = ( currentBlock.thinkingSignature = part.thoughtSignature; stream.push({ type: "thinking_delta", + contentIndex: blockIndex(), delta: part.text, partial: output, }); } else { currentBlock.text += part.text; - stream.push({ type: "text_delta", delta: part.text, partial: output }); + stream.push({ + type: "text_delta", + contentIndex: blockIndex(), + delta: part.text, + partial: output, + }); } } @@ -121,12 +130,14 @@ export const streamGoogle: GenerateFunction<"google-generative-ai"> = ( if (currentBlock.type === "text") { stream.push({ type: "text_end", + contentIndex: blockIndex(), content: currentBlock.text, partial: output, }); } else { stream.push({ type: "thinking_end", + contentIndex: blockIndex(), content: currentBlock.thinking, partial: output, }); @@ -149,7 +160,14 @@ export const streamGoogle: 
GenerateFunction<"google-generative-ai"> = ( arguments: part.functionCall.args as Record, }; output.content.push(toolCall); - stream.push({ type: "toolCall", toolCall, partial: output }); + stream.push({ type: "toolcall_start", contentIndex: blockIndex(), partial: output }); + stream.push({ + type: "toolcall_delta", + contentIndex: blockIndex(), + delta: JSON.stringify(toolCall.arguments), + partial: output, + }); + stream.push({ type: "toolcall_end", contentIndex: blockIndex(), toolCall, partial: output }); } } } @@ -182,9 +200,19 @@ export const streamGoogle: GenerateFunction<"google-generative-ai"> = ( if (currentBlock) { if (currentBlock.type === "text") { - stream.push({ type: "text_end", content: currentBlock.text, partial: output }); + stream.push({ + type: "text_end", + contentIndex: blockIndex(), + content: currentBlock.text, + partial: output, + }); } else { - stream.push({ type: "thinking_end", content: currentBlock.thinking, partial: output }); + stream.push({ + type: "thinking_end", + contentIndex: blockIndex(), + content: currentBlock.thinking, + partial: output, + }); } } @@ -333,7 +361,7 @@ function convertMessages(model: Model<"google-generative-ai">, context: Context) id: msg.toolCallId, name: msg.toolName, response: { - result: msg.content, + result: msg.output, isError: msg.isError, }, }, diff --git a/packages/ai/src/providers/openai-completions.ts b/packages/ai/src/providers/openai-completions.ts index a28d4924..dcd8eb01 100644 --- a/packages/ai/src/providers/openai-completions.ts +++ b/packages/ai/src/providers/openai-completions.ts @@ -7,14 +7,13 @@ import type { ChatCompletionContentPartText, ChatCompletionMessageParam, } from "openai/resources/chat/completions.js"; -import { QueuedGenerateStream } from "../generate.js"; +import { AssistantMessageEventStream } from "../event-stream.js"; import { calculateCost } from "../models.js"; import type { AssistantMessage, Context, GenerateFunction, GenerateOptions, - GenerateStream, Model, 
StopReason, TextContent, @@ -22,7 +21,7 @@ import type { Tool, ToolCall, } from "../types.js"; -import { transformMessages } from "./utils.js"; +import { transformMessages } from "./transorm-messages.js"; export interface OpenAICompletionsOptions extends GenerateOptions { toolChoice?: "auto" | "none" | "required" | { type: "function"; function: { name: string } }; @@ -33,8 +32,8 @@ export const streamOpenAICompletions: GenerateFunction<"openai-completions"> = ( model: Model<"openai-completions">, context: Context, options?: OpenAICompletionsOptions, -): GenerateStream => { - const stream = new QueuedGenerateStream(); +): AssistantMessageEventStream => { + const stream = new AssistantMessageEventStream(); (async () => { const output: AssistantMessage = { @@ -60,6 +59,37 @@ export const streamOpenAICompletions: GenerateFunction<"openai-completions"> = ( stream.push({ type: "start", partial: output }); let currentBlock: TextContent | ThinkingContent | (ToolCall & { partialArgs?: string }) | null = null; + const blocks = output.content; + const blockIndex = () => blocks.length - 1; + const finishCurrentBlock = (block?: typeof currentBlock) => { + if (block) { + if (block.type === "text") { + stream.push({ + type: "text_end", + contentIndex: blockIndex(), + content: block.text, + partial: output, + }); + } else if (block.type === "thinking") { + stream.push({ + type: "thinking_end", + contentIndex: blockIndex(), + content: block.thinking, + partial: output, + }); + } else if (block.type === "toolCall") { + block.arguments = JSON.parse(block.partialArgs || "{}"); + delete block.partialArgs; + stream.push({ + type: "toolcall_end", + contentIndex: blockIndex(), + toolCall: block, + partial: output, + }); + } + } + }; + for await (const chunk of openaiStream) { if (chunk.usage) { output.usage = { @@ -94,119 +124,53 @@ export const streamOpenAICompletions: GenerateFunction<"openai-completions"> = ( choice.delta.content.length > 0 ) { if (!currentBlock || currentBlock.type 
!== "text") { - if (currentBlock) { - if (currentBlock.type === "thinking") { - stream.push({ - type: "thinking_end", - content: currentBlock.thinking, - partial: output, - }); - } else if (currentBlock.type === "toolCall") { - currentBlock.arguments = JSON.parse(currentBlock.partialArgs || "{}"); - delete currentBlock.partialArgs; - stream.push({ - type: "toolCall", - toolCall: currentBlock as ToolCall, - partial: output, - }); - } - } + finishCurrentBlock(currentBlock); currentBlock = { type: "text", text: "" }; output.content.push(currentBlock); - stream.push({ type: "text_start", partial: output }); + stream.push({ type: "text_start", contentIndex: blockIndex(), partial: output }); } if (currentBlock.type === "text") { currentBlock.text += choice.delta.content; stream.push({ type: "text_delta", + contentIndex: blockIndex(), delta: choice.delta.content, partial: output, }); } } - // Some endpoints return reasoning in reasoning_content (llama.cpp) - if ( - (choice.delta as any).reasoning_content !== null && - (choice.delta as any).reasoning_content !== undefined && - (choice.delta as any).reasoning_content.length > 0 - ) { - if (!currentBlock || currentBlock.type !== "thinking") { - if (currentBlock) { - if (currentBlock.type === "text") { - stream.push({ - type: "text_end", - content: currentBlock.text, - partial: output, - }); - } else if (currentBlock.type === "toolCall") { - currentBlock.arguments = JSON.parse(currentBlock.partialArgs || "{}"); - delete currentBlock.partialArgs; - stream.push({ - type: "toolCall", - toolCall: currentBlock as ToolCall, - partial: output, - }); - } + // Some endpoints return reasoning in reasoning_content (llama.cpp), + // or reasoning (other openai compatible endpoints) + const reasoningFields = ["reasoning_content", "reasoning"]; + for (const field of reasoningFields) { + if ( + (choice.delta as any)[field] !== null && + (choice.delta as any)[field] !== undefined && + (choice.delta as any)[field].length > 0 + ) { + if 
(!currentBlock || currentBlock.type !== "thinking") { + finishCurrentBlock(currentBlock); + currentBlock = { + type: "thinking", + thinking: "", + thinkingSignature: field, + }; + output.content.push(currentBlock); + stream.push({ type: "thinking_start", contentIndex: blockIndex(), partial: output }); } - currentBlock = { - type: "thinking", - thinking: "", - thinkingSignature: "reasoning_content", - }; - output.content.push(currentBlock); - stream.push({ type: "thinking_start", partial: output }); - } - if (currentBlock.type === "thinking") { - const delta = (choice.delta as any).reasoning_content; - currentBlock.thinking += delta; - stream.push({ - type: "thinking_delta", - delta, - partial: output, - }); - } - } - - // Some endpoints return reasoning in reasining (ollama, xAI, ...) - if ( - (choice.delta as any).reasoning !== null && - (choice.delta as any).reasoning !== undefined && - (choice.delta as any).reasoning.length > 0 - ) { - if (!currentBlock || currentBlock.type !== "thinking") { - if (currentBlock) { - if (currentBlock.type === "text") { - stream.push({ - type: "text_end", - content: currentBlock.text, - partial: output, - }); - } else if (currentBlock.type === "toolCall") { - currentBlock.arguments = JSON.parse(currentBlock.partialArgs || "{}"); - delete currentBlock.partialArgs; - stream.push({ - type: "toolCall", - toolCall: currentBlock as ToolCall, - partial: output, - }); - } + if (currentBlock.type === "thinking") { + const delta = (choice.delta as any)[field]; + currentBlock.thinking += delta; + stream.push({ + type: "thinking_delta", + contentIndex: blockIndex(), + delta, + partial: output, + }); } - currentBlock = { - type: "thinking", - thinking: "", - thinkingSignature: "reasoning", - }; - output.content.push(currentBlock); - stream.push({ type: "thinking_start", partial: output }); - } - - if (currentBlock.type === "thinking") { - const delta = (choice.delta as any).reasoning; - currentBlock.thinking += delta; - stream.push({ type: 
"thinking_delta", delta, partial: output }); } } @@ -217,30 +181,7 @@ export const streamOpenAICompletions: GenerateFunction<"openai-completions"> = ( currentBlock.type !== "toolCall" || (toolCall.id && currentBlock.id !== toolCall.id) ) { - if (currentBlock) { - if (currentBlock.type === "text") { - stream.push({ - type: "text_end", - content: currentBlock.text, - partial: output, - }); - } else if (currentBlock.type === "thinking") { - stream.push({ - type: "thinking_end", - content: currentBlock.thinking, - partial: output, - }); - } else if (currentBlock.type === "toolCall") { - currentBlock.arguments = JSON.parse(currentBlock.partialArgs || "{}"); - delete currentBlock.partialArgs; - stream.push({ - type: "toolCall", - toolCall: currentBlock as ToolCall, - partial: output, - }); - } - } - + finishCurrentBlock(currentBlock); currentBlock = { type: "toolCall", id: toolCall.id || "", @@ -249,43 +190,30 @@ export const streamOpenAICompletions: GenerateFunction<"openai-completions"> = ( partialArgs: "", }; output.content.push(currentBlock); + stream.push({ type: "toolcall_start", contentIndex: blockIndex(), partial: output }); } if (currentBlock.type === "toolCall") { if (toolCall.id) currentBlock.id = toolCall.id; if (toolCall.function?.name) currentBlock.name = toolCall.function.name; + let delta = ""; if (toolCall.function?.arguments) { + delta = toolCall.function.arguments; currentBlock.partialArgs += toolCall.function.arguments; } + stream.push({ + type: "toolcall_delta", + contentIndex: blockIndex(), + delta, + partial: output, + }); } } } } } - if (currentBlock) { - if (currentBlock.type === "text") { - stream.push({ - type: "text_end", - content: currentBlock.text, - partial: output, - }); - } else if (currentBlock.type === "thinking") { - stream.push({ - type: "thinking_end", - content: currentBlock.thinking, - partial: output, - }); - } else if (currentBlock.type === "toolCall") { - currentBlock.arguments = JSON.parse(currentBlock.partialArgs || "{}"); - 
delete currentBlock.partialArgs; - stream.push({ - type: "toolCall", - toolCall: currentBlock as ToolCall, - partial: output, - }); - } - } + finishCurrentBlock(currentBlock); if (options?.signal?.aborted) { throw new Error("Request was aborted"); @@ -438,7 +366,7 @@ function convertMessages(model: Model<"openai-completions">, context: Context): } else if (msg.role === "toolResult") { params.push({ role: "tool", - content: msg.content, + content: msg.output, tool_call_id: msg.toolCallId, }); } diff --git a/packages/ai/src/providers/openai-responses.ts b/packages/ai/src/providers/openai-responses.ts index a4d1ed4d..484caa2f 100644 --- a/packages/ai/src/providers/openai-responses.ts +++ b/packages/ai/src/providers/openai-responses.ts @@ -10,7 +10,7 @@ import type { ResponseOutputMessage, ResponseReasoningItem, } from "openai/resources/responses/responses.js"; -import { QueuedGenerateStream } from "../generate.js"; +import { AssistantMessageEventStream } from "../event-stream.js"; import { calculateCost } from "../models.js"; import type { Api, @@ -18,7 +18,6 @@ import type { Context, GenerateFunction, GenerateOptions, - GenerateStream, Model, StopReason, TextContent, @@ -26,7 +25,7 @@ import type { Tool, ToolCall, } from "../types.js"; -import { transformMessages } from "./utils.js"; +import { transformMessages } from "./transorm-messages.js"; // OpenAI Responses-specific options export interface OpenAIResponsesOptions extends GenerateOptions { @@ -41,8 +40,8 @@ export const streamOpenAIResponses: GenerateFunction<"openai-responses"> = ( model: Model<"openai-responses">, context: Context, options?: OpenAIResponsesOptions, -): GenerateStream => { - const stream = new QueuedGenerateStream(); +): AssistantMessageEventStream => { + const stream = new AssistantMessageEventStream(); // Start async processing (async () => { @@ -70,7 +69,9 @@ export const streamOpenAIResponses: GenerateFunction<"openai-responses"> = ( stream.push({ type: "start", partial: output }); let 
currentItem: ResponseReasoningItem | ResponseOutputMessage | ResponseFunctionToolCall | null = null; - let currentBlock: ThinkingContent | TextContent | ToolCall | null = null; + let currentBlock: ThinkingContent | TextContent | (ToolCall & { partialJson: string }) | null = null; + const blocks = output.content; + const blockIndex = () => blocks.length - 1; for await (const event of openaiStream) { // Handle output item start @@ -80,12 +81,23 @@ export const streamOpenAIResponses: GenerateFunction<"openai-responses"> = ( currentItem = item; currentBlock = { type: "thinking", thinking: "" }; output.content.push(currentBlock); - stream.push({ type: "thinking_start", partial: output }); + stream.push({ type: "thinking_start", contentIndex: blockIndex(), partial: output }); } else if (item.type === "message") { currentItem = item; currentBlock = { type: "text", text: "" }; output.content.push(currentBlock); - stream.push({ type: "text_start", partial: output }); + stream.push({ type: "text_start", contentIndex: blockIndex(), partial: output }); + } else if (item.type === "function_call") { + currentItem = item; + currentBlock = { + type: "toolCall", + id: item.call_id + "|" + item.id, + name: item.name, + arguments: {}, + partialJson: item.arguments || "", + }; + output.content.push(currentBlock); + stream.push({ type: "toolcall_start", contentIndex: blockIndex(), partial: output }); } } // Handle reasoning summary deltas @@ -108,6 +120,7 @@ export const streamOpenAIResponses: GenerateFunction<"openai-responses"> = ( lastPart.text += event.delta; stream.push({ type: "thinking_delta", + contentIndex: blockIndex(), delta: event.delta, partial: output, }); @@ -129,6 +142,7 @@ export const streamOpenAIResponses: GenerateFunction<"openai-responses"> = ( lastPart.text += "\n\n"; stream.push({ type: "thinking_delta", + contentIndex: blockIndex(), delta: "\n\n", partial: output, }); @@ -149,6 +163,7 @@ export const streamOpenAIResponses: GenerateFunction<"openai-responses"> = 
( lastPart.text += event.delta; stream.push({ type: "text_delta", + contentIndex: blockIndex(), delta: event.delta, partial: output, }); @@ -162,12 +177,36 @@ export const streamOpenAIResponses: GenerateFunction<"openai-responses"> = ( lastPart.refusal += event.delta; stream.push({ type: "text_delta", + contentIndex: blockIndex(), delta: event.delta, partial: output, }); } } } + // Handle function call argument deltas + else if (event.type === "response.function_call_arguments.delta") { + if ( + currentItem && + currentItem.type === "function_call" && + currentBlock && + currentBlock.type === "toolCall" + ) { + currentBlock.partialJson += event.delta; + try { + const args = JSON.parse(currentBlock.partialJson); + currentBlock.arguments = args; + } catch { + // Ignore JSON parse errors - the JSON might be incomplete + } + stream.push({ + type: "toolcall_delta", + contentIndex: blockIndex(), + delta: event.delta, + partial: output, + }); + } + } // Handle output item completion else if (event.type === "response.output_item.done") { const item = event.item; @@ -177,6 +216,7 @@ export const streamOpenAIResponses: GenerateFunction<"openai-responses"> = ( currentBlock.thinkingSignature = JSON.stringify(item); stream.push({ type: "thinking_end", + contentIndex: blockIndex(), content: currentBlock.thinking, partial: output, }); @@ -186,6 +226,7 @@ export const streamOpenAIResponses: GenerateFunction<"openai-responses"> = ( currentBlock.textSignature = item.id; stream.push({ type: "text_end", + contentIndex: blockIndex(), content: currentBlock.text, partial: output, }); @@ -197,8 +238,7 @@ export const streamOpenAIResponses: GenerateFunction<"openai-responses"> = ( name: item.name, arguments: JSON.parse(item.arguments), }; - output.content.push(toolCall); - stream.push({ type: "toolCall", toolCall, partial: output }); + stream.push({ type: "toolcall_end", contentIndex: blockIndex(), toolCall, partial: output }); } } // Handle completion @@ -398,7 +438,7 @@ function 
convertMessages(model: Model<"openai-responses">, context: Context): Re messages.push({ type: "function_call_output", call_id: msg.toolCallId.split("|")[0], - output: msg.content, + output: msg.output, }); } } diff --git a/packages/ai/src/providers/utils.ts b/packages/ai/src/providers/transorm-messages.ts similarity index 100% rename from packages/ai/src/providers/utils.ts rename to packages/ai/src/providers/transorm-messages.ts diff --git a/packages/ai/src/types.ts b/packages/ai/src/types.ts index dbbb2bfc..6af1945e 100644 --- a/packages/ai/src/types.ts +++ b/packages/ai/src/types.ts @@ -1,8 +1,11 @@ +import type { AssistantMessageEventStream } from "./event-stream"; import type { AnthropicOptions } from "./providers/anthropic"; import type { GoogleOptions } from "./providers/google"; import type { OpenAICompletionsOptions } from "./providers/openai-completions"; import type { OpenAIResponsesOptions } from "./providers/openai-responses"; +export type { AssistantMessageEventStream } from "./event-stream"; + export type Api = "openai-completions" | "openai-responses" | "anthropic-messages" | "google-generative-ai"; export interface ApiOptionsMap { @@ -28,12 +31,6 @@ export type Provider = KnownProvider | string; export type ReasoningEffort = "minimal" | "low" | "medium" | "high"; -// The stream interface - what generate() returns -export interface GenerateStream extends AsyncIterable { - // Get the final message (waits for streaming to complete) - finalMessage(): Promise; -} - // Base options all providers share export interface GenerateOptions { temperature?: number; @@ -52,7 +49,7 @@ export type GenerateFunction = ( model: Model, context: Context, options: OptionsForApi, -) => GenerateStream; +) => AssistantMessageEventStream; export interface TextContent { type: "text"; @@ -111,11 +108,12 @@ export interface AssistantMessage { error?: string; } -export interface ToolResultMessage { +export interface ToolResultMessage { role: "toolResult"; toolCallId: string; 
toolName: string; - content: string; + output: string; + details?: TDetails; isError: boolean; } @@ -135,13 +133,15 @@ export interface Context { export type AssistantMessageEvent = | { type: "start"; partial: AssistantMessage } - | { type: "text_start"; partial: AssistantMessage } - | { type: "text_delta"; delta: string; partial: AssistantMessage } - | { type: "text_end"; content: string; partial: AssistantMessage } - | { type: "thinking_start"; partial: AssistantMessage } - | { type: "thinking_delta"; delta: string; partial: AssistantMessage } - | { type: "thinking_end"; content: string; partial: AssistantMessage } - | { type: "toolCall"; toolCall: ToolCall; partial: AssistantMessage } + | { type: "text_start"; contentIndex: number; partial: AssistantMessage } + | { type: "text_delta"; contentIndex: number; delta: string; partial: AssistantMessage } + | { type: "text_end"; contentIndex: number; content: string; partial: AssistantMessage } + | { type: "thinking_start"; contentIndex: number; partial: AssistantMessage } + | { type: "thinking_delta"; contentIndex: number; delta: string; partial: AssistantMessage } + | { type: "thinking_end"; contentIndex: number; content: string; partial: AssistantMessage } + | { type: "toolcall_start"; contentIndex: number; partial: AssistantMessage } + | { type: "toolcall_delta"; contentIndex: number; delta: string; partial: AssistantMessage } + | { type: "toolcall_end"; contentIndex: number; toolCall: ToolCall; partial: AssistantMessage } | { type: "done"; reason: StopReason; message: AssistantMessage } | { type: "error"; error: string; partial: AssistantMessage }; diff --git a/packages/ai/test/abort.test.ts b/packages/ai/test/abort.test.ts index 028ea13c..7de2892d 100644 --- a/packages/ai/test/abort.test.ts +++ b/packages/ai/test/abort.test.ts @@ -22,7 +22,7 @@ async function testAbortSignal(llm: Model, options: Opti abortFired = true; break; } - const msg = await response.finalMessage(); + const msg = await response.result(); // 
If we get here without throwing, the abort didn't work expect(msg.stopReason).toBe("error"); diff --git a/packages/ai/test/cross-provider-toolcall.test.ts b/packages/ai/test/cross-provider-toolcall.test.ts deleted file mode 100644 index 2707ce56..00000000 --- a/packages/ai/test/cross-provider-toolcall.test.ts +++ /dev/null @@ -1,113 +0,0 @@ -import { type Context, complete, getModel } from "../src/index.js"; - -async function testCrossProviderToolCall() { - console.log("Testing cross-provider tool call handoff...\n"); - - // Define a simple tool - const tools = [ - { - name: "get_weather", - description: "Get current weather for a location", - parameters: { - type: "object", - properties: { - location: { type: "string", description: "City name" }, - }, - required: ["location"], - }, - }, - ]; - - // Create context with tools - const context: Context = { - systemPrompt: "You are a helpful assistant. Use the get_weather tool when asked about weather.", - messages: [{ role: "user", content: "What is the weather in Paris?" 
}], - tools, - }; - - try { - // Step 1: Get tool call from GPT-5 - console.log("Step 1: Getting tool call from GPT-5..."); - const gpt5 = getModel("openai", "gpt-5-mini"); - const gpt5Response = await complete(gpt5, context); - context.messages.push(gpt5Response); - - // Check for tool calls - const toolCalls = gpt5Response.content.filter((b) => b.type === "toolCall"); - console.log(`GPT-5 made ${toolCalls.length} tool call(s)`); - - if (toolCalls.length > 0) { - const toolCall = toolCalls[0]; - console.log(`Tool call ID: ${toolCall.id}`); - console.log(`Tool call contains pipe: ${toolCall.id.includes("|")}`); - console.log(`Tool: ${toolCall.name}(${JSON.stringify(toolCall.arguments)})\n`); - - // Add tool result - context.messages.push({ - role: "toolResult", - toolCallId: toolCall.id, - toolName: toolCall.name, - content: JSON.stringify({ - location: "Paris", - temperature: "22°C", - conditions: "Partly cloudy", - }), - isError: false, - }); - - // Step 2: Send to Claude Haiku for follow-up - console.log("Step 2: Sending to Claude Haiku for follow-up..."); - const haiku = getModel("anthropic", "claude-3-5-haiku-20241022"); - - try { - const haikuResponse = await complete(haiku, context); - console.log("✅ Claude Haiku successfully processed the conversation!"); - console.log("Response content types:", haikuResponse.content.map((b) => b.type).join(", ")); - console.log("Number of content blocks:", haikuResponse.content.length); - console.log("Stop reason:", haikuResponse.stopReason); - if (haikuResponse.error) { - console.log("Error message:", haikuResponse.error); - } - - // Print all response content - for (const block of haikuResponse.content) { - if (block.type === "text") { - console.log("\nClaude text response:", block.text); - } else if (block.type === "thinking") { - console.log("\nClaude thinking:", block.thinking); - } else if (block.type === "toolCall") { - console.log("\nClaude tool call:", block.name, block.arguments); - } - } - - if 
(haikuResponse.content.length === 0) { - console.log("⚠️ Claude returned an empty response!"); - } - } catch (error) { - console.error("❌ Claude Haiku failed to process the conversation:"); - console.error("Error:", error); - - // Check if it's related to the tool call ID - if (error instanceof Error && error.message.includes("tool")) { - console.error("\n⚠️ This appears to be a tool call ID issue!"); - console.error("The pipe character (|) in OpenAI Response API tool IDs might be causing problems."); - } - } - } else { - console.log("No tool calls were made by GPT-5"); - } - } catch (error) { - console.error("Test failed:", error); - } -} - -// Set API keys from environment or pass them explicitly -const openaiKey = process.env.OPENAI_API_KEY; -const anthropicKey = process.env.ANTHROPIC_API_KEY; - -if (!openaiKey || !anthropicKey) { - console.error("Please set OPENAI_API_KEY and ANTHROPIC_API_KEY environment variables"); - process.exit(1); -} - -testCrossProviderToolCall().catch(console.error); diff --git a/packages/ai/test/generate.test.ts b/packages/ai/test/generate.test.ts index faf23e8f..50f7c124 100644 --- a/packages/ai/test/generate.test.ts +++ b/packages/ai/test/generate.test.ts @@ -5,7 +5,7 @@ import { fileURLToPath } from "url"; import { afterAll, beforeAll, describe, expect, it } from "vitest"; import { complete, stream } from "../src/generate.js"; import { getModel } from "../src/models.js"; -import type { Api, Context, ImageContent, Model, OptionsForApi, Tool } from "../src/types.js"; +import type { Api, Context, ImageContent, Model, OptionsForApi, Tool, ToolResultMessage } from "../src/types.js"; const __filename = fileURLToPath(import.meta.url); const __dirname = dirname(__filename); @@ -70,13 +70,62 @@ async function handleToolCall(model: Model, options?: Op tools: [calculatorTool], }; - const response = await complete(model, context, options); + const s = await stream(model, context, options); + let hasToolStart = false; + let hasToolDelta = false; 
+ let hasToolEnd = false; + let accumulatedToolArgs = ""; + let index = 0; + for await (const event of s) { + if (event.type === "toolcall_start") { + hasToolStart = true; + const toolCall = event.partial.content[event.contentIndex]; + index = event.contentIndex; + expect(toolCall.type).toBe("toolCall"); + if (toolCall.type === "toolCall") { + expect(toolCall.name).toBe("calculator"); + expect(toolCall.id).toBeTruthy(); + } + } + if (event.type === "toolcall_delta") { + hasToolDelta = true; + const toolCall = event.partial.content[event.contentIndex]; + expect(event.contentIndex).toBe(index); + expect(toolCall.type).toBe("toolCall"); + if (toolCall.type === "toolCall") { + expect(toolCall.name).toBe("calculator"); + accumulatedToolArgs += event.delta; + } + } + if (event.type === "toolcall_end") { + hasToolEnd = true; + const toolCall = event.partial.content[event.contentIndex]; + expect(event.contentIndex).toBe(index); + expect(toolCall.type).toBe("toolCall"); + if (toolCall.type === "toolCall") { + expect(toolCall.name).toBe("calculator"); + JSON.parse(accumulatedToolArgs); + expect(toolCall.arguments).not.toBeUndefined(); + expect((toolCall.arguments as any).a).toBe(15); + expect((toolCall.arguments as any).b).toBe(27); + expect((toolCall.arguments as any).operation).oneOf(["add", "subtract", "multiply", "divide"]); + } + } + } + + expect(hasToolStart).toBe(true); + expect(hasToolDelta).toBe(true); + expect(hasToolEnd).toBe(true); + + const response = await s.result(); expect(response.stopReason).toBe("toolUse"); expect(response.content.some((b) => b.type === "toolCall")).toBeTruthy(); const toolCall = response.content.find((b) => b.type === "toolCall"); if (toolCall && toolCall.type === "toolCall") { expect(toolCall.name).toBe("calculator"); expect(toolCall.id).toBeTruthy(); + } else { + throw new Error("No tool call found in response"); } } @@ -101,7 +150,7 @@ async function handleStreaming(model: Model, options?: O } } - const response = await 
s.finalMessage(); + const response = await s.result(); expect(textStarted).toBe(true); expect(textChunks.length).toBeGreaterThan(0); @@ -135,7 +184,7 @@ async function handleThinking(model: Model, options?: Op } } - const response = await s.finalMessage(); + const response = await s.result(); expect(response.stopReason, `Error: ${response.error}`).toBe("stop"); expect(thinkingStarted).toBe(true); @@ -214,6 +263,7 @@ async function multiTurn(model: Model, options?: Options context.messages.push(response); // Process content blocks + const results: ToolResultMessage[] = []; for (const block of response.content) { if (block.type === "text") { allTextContent += block.text; @@ -241,15 +291,16 @@ async function multiTurn(model: Model, options?: Options } // Add tool result to context - context.messages.push({ + results.push({ role: "toolResult", toolCallId: block.id, toolName: block.name, - content: `${result}`, + output: `${result}`, isError: false, }); } } + context.messages.push(...results); // If we got a stop response with text content, we're likely done expect(response.stopReason).not.toBe("error"); @@ -331,12 +382,12 @@ describe("Generate E2E Tests", () => { await handleStreaming(llm); }); - it("should handle ", { retry: 2 }, async () => { - await handleThinking(llm, { reasoningEffort: "medium" }); + it("should handle thinking", { retry: 2 }, async () => { + await handleThinking(llm, { reasoningEffort: "high" }); }); it("should handle multi-turn with thinking and tools", async () => { - await multiTurn(llm, { reasoningEffort: "medium" }); + await multiTurn(llm, { reasoningEffort: "high" }); }); it("should handle image input", async () => { diff --git a/packages/ai/test/handoff.test.ts b/packages/ai/test/handoff.test.ts index 9992a6a9..cced9703 100644 --- a/packages/ai/test/handoff.test.ts +++ b/packages/ai/test/handoff.test.ts @@ -1,7 +1,7 @@ import { describe, expect, it } from "vitest"; import { complete } from "../src/generate.js"; import { getModel } from 
"../src/models.js"; -import type { Api, AssistantMessage, Context, Message, Model, Tool } from "../src/types.js"; +import type { Api, AssistantMessage, Context, Message, Model, Tool, ToolResultMessage } from "../src/types.js"; // Tool for testing const weatherTool: Tool = { @@ -22,6 +22,7 @@ const providerContexts = { anthropic: { message: { role: "assistant", + api: "anthropic-messages", content: [ { type: "thinking", @@ -49,14 +50,14 @@ const providerContexts = { cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, }, stopReason: "toolUse", - } as AssistantMessage, + } satisfies AssistantMessage, toolResult: { role: "toolResult" as const, toolCallId: "toolu_01abc123", toolName: "get_weather", - content: "Weather in Tokyo: 18°C, partly cloudy", + output: "Weather in Tokyo: 18°C, partly cloudy", isError: false, - }, + } satisfies ToolResultMessage, facts: { calculation: 391, city: "Tokyo", @@ -69,6 +70,7 @@ const providerContexts = { google: { message: { role: "assistant", + api: "google-generative-ai", content: [ { type: "thinking", @@ -97,14 +99,14 @@ const providerContexts = { cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, }, stopReason: "toolUse", - } as AssistantMessage, + } satisfies AssistantMessage, toolResult: { role: "toolResult" as const, toolCallId: "call_gemini_123", toolName: "get_weather", - content: "Weather in Berlin: 22°C, sunny", + output: "Weather in Berlin: 22°C, sunny", isError: false, - }, + } satisfies ToolResultMessage, facts: { calculation: 456, city: "Berlin", @@ -117,6 +119,7 @@ const providerContexts = { openaiCompletions: { message: { role: "assistant", + api: "openai-completions", content: [ { type: "thinking", @@ -144,14 +147,14 @@ const providerContexts = { cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, }, stopReason: "toolUse", - } as AssistantMessage, + } satisfies AssistantMessage, toolResult: { role: "toolResult" as const, toolCallId: "call_abc123", toolName: 
"get_weather", - content: "Weather in London: 15°C, rainy", + output: "Weather in London: 15°C, rainy", isError: false, - }, + } satisfies ToolResultMessage, facts: { calculation: 525, city: "London", @@ -164,6 +167,7 @@ const providerContexts = { openaiResponses: { message: { role: "assistant", + api: "openai-responses", content: [ { type: "thinking", @@ -193,14 +197,14 @@ const providerContexts = { cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, }, stopReason: "toolUse", - } as AssistantMessage, + } satisfies AssistantMessage, toolResult: { role: "toolResult" as const, toolCallId: "call_789_item_012", // Match the updated ID format toolName: "get_weather", - content: "Weather in Sydney: 25°C, clear", + output: "Weather in Sydney: 25°C, clear", isError: false, - }, + } satisfies ToolResultMessage, facts: { calculation: 486, city: "Sydney", @@ -213,6 +217,7 @@ const providerContexts = { aborted: { message: { role: "assistant", + api: "anthropic-messages", content: [ { type: "thinking", @@ -235,7 +240,7 @@ const providerContexts = { }, stopReason: "error", error: "Request was aborted", - } as AssistantMessage, + } satisfies AssistantMessage, toolResult: null, facts: { calculation: 600, diff --git a/packages/ai/tsconfig.build.json b/packages/ai/tsconfig.build.json index 5ce43029..6089faa7 100644 --- a/packages/ai/tsconfig.build.json +++ b/packages/ai/tsconfig.build.json @@ -4,6 +4,6 @@ "outDir": "./dist", "rootDir": "./src" }, - "include": ["src/**/*"], - "exclude": ["node_modules", "dist"] + "include": ["src/**/*.ts"], + "exclude": ["node_modules", "dist", "**/*.d.ts", "src/**/*.d.ts"] } \ No newline at end of file