diff --git a/packages/ai/README.md b/packages/ai/README.md index 747371e3..e2ac6122 100644 --- a/packages/ai/README.md +++ b/packages/ai/README.md @@ -28,7 +28,7 @@ import { createLLM } from '@mariozechner/pi-ai'; const llm = createLLM('openai', 'gpt-4o-mini'); -const response = await llm.complete({ +const response = await llm.generate({ messages: [{ role: 'user', content: 'Hello!' }] }); @@ -48,7 +48,7 @@ import { readFileSync } from 'fs'; const imageBuffer = readFileSync('image.png'); const base64Image = imageBuffer.toString('base64'); -const response = await llm.complete({ +const response = await llm.generate({ messages: [{ role: 'user', content: [ @@ -77,7 +77,7 @@ const tools = [{ const messages = []; messages.push({ role: 'user', content: 'What is the weather in Paris?' }); -const response = await llm.complete({ messages, tools }); +const response = await llm.generate({ messages, tools }); messages.push(response); // Check for tool calls in the content blocks @@ -99,7 +99,7 @@ for (const call of toolCalls) { if (toolCalls.length > 0) { // Continue conversation with tool results - const followUp = await llm.complete({ messages, tools }); + const followUp = await llm.generate({ messages, tools }); messages.push(followUp); // Print text blocks from the response @@ -114,7 +114,7 @@ if (toolCalls.length > 0) { ## Streaming ```typescript -const response = await llm.complete({ +const response = await llm.generate({ messages: [{ role: 'user', content: 'Write a story' }] }, { onEvent: (event) => { @@ -157,13 +157,17 @@ const response = await llm.complete({ ## Abort Signal +The abort signal allows you to cancel in-progress requests. When aborted, providers return the partial results accumulated up to the cancellation point, along with token counts and cost estimates (which may be incomplete, depending on when the abort occurred). 
+ +### Basic Usage + ```typescript const controller = new AbortController(); // Abort after 2 seconds setTimeout(() => controller.abort(), 2000); -const response = await llm.complete({ +const response = await llm.generate({ messages: [{ role: 'user', content: 'Write a long story' }] }, { signal: controller.signal, @@ -177,18 +181,132 @@ const response = await llm.complete({ // Check if the request was aborted if (response.stopReason === 'error' && response.error) { console.log('Request was aborted:', response.error); + console.log('Partial content received:', response.content); + console.log('Tokens used:', response.usage); } else { console.log('Request completed successfully'); } ``` +### Partial Results and Token Tracking + +When a request is aborted, the API returns an `AssistantMessage` with: +- `stopReason: 'error'` - Indicates the request was aborted +- `error: string` - Error message describing the abort +- `content: array` - **Partial content** accumulated before the abort +- `usage: object` - **Token counts and costs** (may be incomplete depending on when abort occurred) + +```typescript +// Example: User interrupts a long-running request +const controller = new AbortController(); +document.getElementById('stop-button').onclick = () => controller.abort(); + +const response = await llm.generate(context, { + signal: controller.signal, + onEvent: (e) => { + if (e.type === 'text_delta') updateUI(e.delta); + } +}); + +// Even if aborted, you get: +// - Partial text that was streamed +// - Token count (may be partial/estimated) +// - Cost calculations (may be incomplete) +console.log(`Generated ${response.content.length} content blocks`); +console.log(`Estimated ${response.usage.output} output tokens`); +console.log(`Estimated cost: $${response.usage.cost.total}`); +``` + +### Continuing After Abort + +Aborted messages can be added to the conversation context and continued in subsequent requests: + +```typescript +const context = { + messages: [ + { role: 
'user', content: 'Explain quantum computing in detail' } + ] +}; + +// First request gets aborted after 2 seconds +const controller1 = new AbortController(); +setTimeout(() => controller1.abort(), 2000); + +const partial = await llm.generate(context, { signal: controller1.signal }); + +// Add the partial response to context +context.messages.push(partial); +context.messages.push({ role: 'user', content: 'Please continue' }); + +// Continue the conversation +const continuation = await llm.generate(context); +``` + +When an aborted message (with `stopReason: 'error'`) is resubmitted in the context: +- **OpenAI Responses**: Filters out thinking blocks and tool calls from aborted messages, as the API call will fail if incomplete thinking blocks and tool calls are submitted +- **Anthropic, Google, OpenAI Completions**: Send all blocks as-is (text, thinking, tool calls) + +## Cross-Provider Handoffs + +The library supports seamless handoffs between different LLM providers within the same conversation. This allows you to switch models mid-conversation while preserving context, including thinking blocks, tool calls, and tool results. + +### How It Works + +When messages from one provider are sent to a different provider, the library automatically transforms them for compatibility: + +- **User and tool result messages** are passed through unchanged +- **Assistant messages from the same provider/model** are preserved as-is +- **Assistant messages from a different provider or model** have their thinking blocks converted to text with `<thinking>` tags +- **Tool calls and regular text** are preserved unchanged + +### Example: Multi-Provider Conversation + +```typescript +import { createLLM } from '@mariozechner/pi-ai'; + +// Start with Claude +const claude = createLLM('anthropic', 'claude-sonnet-4-0'); +const messages = []; + +messages.push({ role: 'user', content: 'What is 25 * 18?' 
}); +const claudeResponse = await claude.generate({ messages }, { + thinking: { enabled: true } +}); +messages.push(claudeResponse); + +// Switch to GPT-5 - it will see Claude's thinking as tagged text +const gpt5 = createLLM('openai', 'gpt-5-mini'); +messages.push({ role: 'user', content: 'Is that calculation correct?' }); +const gptResponse = await gpt5.generate({ messages }); +messages.push(gptResponse); + +// Switch to Gemini +const gemini = createLLM('google', 'gemini-2.5-flash'); +messages.push({ role: 'user', content: 'What was the original question?' }); +const geminiResponse = await gemini.generate({ messages }); +``` + +### Provider Compatibility + +All providers can handle messages from other providers, including: +- Text content +- Tool calls and tool results +- Thinking/reasoning blocks (transformed to tagged text for cross-provider compatibility) +- Aborted messages with partial content + +This enables flexible workflows where you can: +- Start with a fast model for initial responses +- Switch to a more capable model for complex reasoning +- Use specialized models for specific tasks +- Maintain conversation continuity across provider outages + ## Provider-Specific Options ### OpenAI Reasoning (o1, o3) ```typescript const llm = createLLM('openai', 'o1-mini'); -await llm.complete(context, { +await llm.generate(context, { reasoningEffort: 'medium' // 'minimal' | 'low' | 'medium' | 'high' }); ``` @@ -197,7 +315,7 @@ await llm.complete(context, { ```typescript const llm = createLLM('anthropic', 'claude-3-5-sonnet-20241022'); -await llm.complete(context, { +await llm.generate(context, { thinking: { enabled: true, budgetTokens: 2048 // Optional thinking token limit @@ -209,7 +327,7 @@ await llm.complete(context, { ```typescript const llm = createLLM('google', 'gemini-2.5-pro'); -await llm.complete(context, { +await llm.generate(context, { thinking: { enabled: true } }); ``` diff --git a/packages/ai/src/providers/anthropic.ts 
b/packages/ai/src/providers/anthropic.ts index 9fc3dad5..e2f47c57 100644 --- a/packages/ai/src/providers/anthropic.ts +++ b/packages/ai/src/providers/anthropic.ts @@ -18,6 +18,7 @@ import type { ThinkingContent, ToolCall, } from "../types.js"; +import { transformMessages } from "./utils.js"; export interface AnthropicLLMOptions extends LLMOptions { thinking?: { @@ -61,7 +62,7 @@ export class AnthropicLLM implements LLM { return this.modelInfo; } - async complete(context: Context, options?: AnthropicLLMOptions): Promise { + async generate(context: Context, options?: AnthropicLLMOptions): Promise { const output: AssistantMessage = { role: "assistant", content: [], @@ -243,7 +244,10 @@ export class AnthropicLLM implements LLM { private convertMessages(messages: Message[]): MessageParam[] { const params: MessageParam[] = []; - for (const msg of messages) { + // Transform messages for cross-provider compatibility + const transformedMessages = transformMessages(messages, this.modelInfo); + + for (const msg of transformedMessages) { if (msg.role === "user") { // Handle both string and array content if (typeof msg.content === "string") { diff --git a/packages/ai/src/providers/google.ts b/packages/ai/src/providers/google.ts index b3cb1dc9..d5a5761c 100644 --- a/packages/ai/src/providers/google.ts +++ b/packages/ai/src/providers/google.ts @@ -21,6 +21,7 @@ import type { Tool, ToolCall, } from "../types.js"; +import { transformMessages } from "./utils.js"; export interface GoogleLLMOptions extends LLMOptions { toolChoice?: "auto" | "none" | "any"; @@ -51,7 +52,7 @@ export class GoogleLLM implements LLM { return this.modelInfo; } - async complete(context: Context, options?: GoogleLLMOptions): Promise { + async generate(context: Context, options?: GoogleLLMOptions): Promise { const output: AssistantMessage = { role: "assistant", content: [], @@ -223,6 +224,15 @@ export class GoogleLLM implements LLM { } } + // Finalize last block + if (currentBlock) { + if (currentBlock.type 
=== "text") { + options?.onEvent?.({ type: "text_end", content: currentBlock.text }); + } else { + options?.onEvent?.({ type: "thinking_end", content: currentBlock.thinking }); + } + } + options?.onEvent?.({ type: "done", reason: output.stopReason, message: output }); return output; } catch (error) { @@ -236,7 +246,10 @@ export class GoogleLLM implements LLM { private convertMessages(messages: Message[]): Content[] { const contents: Content[] = []; - for (const msg of messages) { + // Transform messages for cross-provider compatibility + const transformedMessages = transformMessages(messages, this.modelInfo); + + for (const msg of transformedMessages) { if (msg.role === "user") { // Handle both string and array content if (typeof msg.content === "string") { diff --git a/packages/ai/src/providers/openai-completions.ts b/packages/ai/src/providers/openai-completions.ts index 68ced3c5..0696875a 100644 --- a/packages/ai/src/providers/openai-completions.ts +++ b/packages/ai/src/providers/openai-completions.ts @@ -19,8 +19,8 @@ import type { ThinkingContent, Tool, ToolCall, - Usage, } from "../types.js"; +import { transformMessages } from "./utils.js"; export interface OpenAICompletionsLLMOptions extends LLMOptions { toolChoice?: "auto" | "none" | "required" | { type: "function"; function: { name: string } }; @@ -48,7 +48,22 @@ export class OpenAICompletionsLLM implements LLM { return this.modelInfo; } - async complete(request: Context, options?: OpenAICompletionsLLMOptions): Promise { + async generate(request: Context, options?: OpenAICompletionsLLMOptions): Promise { + const output: AssistantMessage = { + role: "assistant", + content: [], + provider: this.modelInfo.provider, + model: this.modelInfo.id, + usage: { + input: 0, + output: 0, + cacheRead: 0, + cacheWrite: 0, + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, + }, + stopReason: "stop", + }; + try { const messages = this.convertMessages(request.messages, request.systemPrompt); @@ -94,19 
+109,10 @@ export class OpenAICompletionsLLM implements LLM { options?.onEvent?.({ type: "start", model: this.modelInfo.id, provider: this.modelInfo.provider }); - const blocks: AssistantMessage["content"] = []; let currentBlock: TextContent | ThinkingContent | (ToolCall & { partialArgs?: string }) | null = null; - let usage: Usage = { - input: 0, - output: 0, - cacheRead: 0, - cacheWrite: 0, - cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, - }; - let finishReason: ChatCompletionChunk.Choice["finish_reason"] | null = null; for await (const chunk of stream) { if (chunk.usage) { - usage = { + output.usage = { input: chunk.usage.prompt_tokens || 0, output: (chunk.usage.completion_tokens || 0) + @@ -121,11 +127,17 @@ export class OpenAICompletionsLLM implements LLM { total: 0, }, }; + calculateCost(this.modelInfo, output.usage); } const choice = chunk.choices[0]; if (!choice) continue; + // Capture finish reason + if (choice.finish_reason) { + output.stopReason = this.mapStopReason(choice.finish_reason); + } + if (choice.delta) { // Handle text content if ( @@ -144,10 +156,10 @@ export class OpenAICompletionsLLM implements LLM { delete currentBlock.partialArgs; options?.onEvent?.({ type: "toolCall", toolCall: currentBlock as ToolCall }); } - blocks.push(currentBlock); } // Start new text block currentBlock = { type: "text", text: "" }; + output.content.push(currentBlock); options?.onEvent?.({ type: "text_start" }); } // Append to text block @@ -178,10 +190,10 @@ export class OpenAICompletionsLLM implements LLM { delete currentBlock.partialArgs; options?.onEvent?.({ type: "toolCall", toolCall: currentBlock as ToolCall }); } - blocks.push(currentBlock); } // Start new thinking block currentBlock = { type: "thinking", thinking: "", thinkingSignature: "reasoning_content" }; + output.content.push(currentBlock); options?.onEvent?.({ type: "thinking_start" }); } // Append to thinking block @@ -209,10 +221,10 @@ export class OpenAICompletionsLLM 
implements LLM { delete currentBlock.partialArgs; options?.onEvent?.({ type: "toolCall", toolCall: currentBlock as ToolCall }); } - blocks.push(currentBlock); } // Start new thinking block currentBlock = { type: "thinking", thinking: "", thinkingSignature: "reasoning" }; + output.content.push(currentBlock); options?.onEvent?.({ type: "thinking_start" }); } // Append to thinking block @@ -243,7 +255,6 @@ export class OpenAICompletionsLLM implements LLM { delete currentBlock.partialArgs; options?.onEvent?.({ type: "toolCall", toolCall: currentBlock as ToolCall }); } - blocks.push(currentBlock); } // Start new tool call block @@ -254,6 +265,7 @@ export class OpenAICompletionsLLM implements LLM { arguments: {}, partialArgs: "", }; + output.content.push(currentBlock); } // Accumulate tool call data @@ -267,11 +279,6 @@ export class OpenAICompletionsLLM implements LLM { } } } - - // Capture finish reason - if (choice.finish_reason) { - finishReason = choice.finish_reason; - } } // Save final block if exists @@ -285,39 +292,19 @@ export class OpenAICompletionsLLM implements LLM { delete currentBlock.partialArgs; options?.onEvent?.({ type: "toolCall", toolCall: currentBlock as ToolCall }); } - blocks.push(currentBlock); } - // Calculate cost - calculateCost(this.modelInfo, usage); + if (options?.signal?.aborted) { + throw new Error("Request was aborted"); + } - const output = { - role: "assistant", - content: blocks, - provider: this.modelInfo.provider, - model: this.modelInfo.id, - usage, - stopReason: this.mapStopReason(finishReason), - } satisfies AssistantMessage; options?.onEvent?.({ type: "done", reason: output.stopReason, message: output }); return output; } catch (error) { - const output = { - role: "assistant", - content: [], - provider: this.modelInfo.provider, - model: this.modelInfo.id, - usage: { - input: 0, - output: 0, - cacheRead: 0, - cacheWrite: 0, - cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, - }, - stopReason: "error", - error: 
error instanceof Error ? error.message : String(error), - } satisfies AssistantMessage; - options?.onEvent?.({ type: "error", error: output.error || "Unknown error" }); + // Update output with error information + output.stopReason = "error"; + output.error = error instanceof Error ? error.message : String(error); + options?.onEvent?.({ type: "error", error: output.error }); return output; } } @@ -325,6 +312,9 @@ export class OpenAICompletionsLLM implements LLM { private convertMessages(messages: Message[], systemPrompt?: string): ChatCompletionMessageParam[] { const params: ChatCompletionMessageParam[] = []; + // Transform messages for cross-provider compatibility + const transformedMessages = transformMessages(messages, this.modelInfo); + // Add system prompt if provided if (systemPrompt) { // Cerebras/xAi don't like the "developer" role @@ -337,7 +327,7 @@ export class OpenAICompletionsLLM implements LLM { } // Convert messages - for (const msg of messages) { + for (const msg of transformedMessages) { if (msg.role === "user") { // Handle both string and array content if (typeof msg.content === "string") { diff --git a/packages/ai/src/providers/openai-responses.ts b/packages/ai/src/providers/openai-responses.ts index a10cf818..b3afe6ef 100644 --- a/packages/ai/src/providers/openai-responses.ts +++ b/packages/ai/src/providers/openai-responses.ts @@ -23,6 +23,7 @@ import type { Tool, ToolCall, } from "../types.js"; +import { transformMessages } from "./utils.js"; export interface OpenAIResponsesLLMOptions extends LLMOptions { reasoningEffort?: "minimal" | "low" | "medium" | "high"; @@ -50,7 +51,7 @@ export class OpenAIResponsesLLM implements LLM { return this.modelInfo; } - async complete(request: Context, options?: OpenAIResponsesLLMOptions): Promise { + async generate(request: Context, options?: OpenAIResponsesLLMOptions): Promise { const output: AssistantMessage = { role: "assistant", content: [], @@ -132,7 +133,7 @@ export class OpenAIResponsesLLM implements LLM 
{ lastPart.text += event.delta; options?.onEvent?.({ type: "thinking_delta", - content: currentItem.summary.join("\n\n"), + content: currentItem.summary.map((s) => s.text).join("\n\n"), delta: event.delta, }); } @@ -141,11 +142,16 @@ export class OpenAIResponsesLLM implements LLM { // Add a new line between summary parts (hack...) else if (event.type === "response.reasoning_summary_part.done") { if (currentItem && currentItem.type === "reasoning") { - options?.onEvent?.({ - type: "thinking_delta", - content: currentItem.summary.join("\n\n"), - delta: "\n\n", - }); + currentItem.summary = currentItem.summary || []; + const lastPart = currentItem.summary[currentItem.summary.length - 1]; + if (lastPart) { + lastPart.text += "\n\n"; + options?.onEvent?.({ + type: "thinking_delta", + content: currentItem.summary.map((s) => s.text).join("\n\n"), + delta: "\n\n", + }); + } } } // Handle text output deltas @@ -189,7 +195,7 @@ export class OpenAIResponsesLLM implements LLM { if (item.type === "reasoning") { outputItems[outputItems.length - 1] = item; // Update with final item - const thinkingContent = item.summary?.map((s: any) => s.text).join("\n\n") || ""; + const thinkingContent = item.summary?.map((s) => s.text).join("\n\n") || ""; options?.onEvent?.({ type: "thinking_end", content: thinkingContent }); } else if (item.type === "message") { outputItems[outputItems.length - 1] = item; // Update with final item @@ -280,6 +286,9 @@ export class OpenAIResponsesLLM implements LLM { private convertToInput(messages: Message[], systemPrompt?: string): ResponseInput { const input: ResponseInput = []; + // Transform messages for cross-provider compatibility + const transformedMessages = transformMessages(messages, this.modelInfo); + // Add system prompt if provided if (systemPrompt) { const role = this.modelInfo?.reasoning ? 
"developer" : "system"; @@ -290,7 +299,7 @@ export class OpenAIResponsesLLM implements LLM { } // Convert messages - for (const msg of messages) { + for (const msg of transformedMessages) { if (msg.role === "user") { // Handle both string and array content if (typeof msg.content === "string") { diff --git a/packages/ai/src/providers/utils.ts b/packages/ai/src/providers/utils.ts new file mode 100644 index 00000000..db751fff --- /dev/null +++ b/packages/ai/src/providers/utils.ts @@ -0,0 +1,54 @@ +import type { AssistantMessage, Message, Model } from "../types.js"; + +/** + * Transform messages for cross-provider compatibility. + * + * - User and toolResult messages are copied verbatim + * - Assistant messages: + * - If from the same provider/model, copied as-is + * - If from different provider/model, thinking blocks are converted to text blocks with tags + * + * @param messages The messages to transform + * @param model The target model that will process these messages + * @returns A copy of the messages array with transformations applied + */ +export function transformMessages(messages: Message[], model: Model): Message[] { + return messages.map((msg) => { + // User and toolResult messages pass through unchanged + if (msg.role === "user" || msg.role === "toolResult") { + return msg; + } + + // Assistant messages need transformation check + if (msg.role === "assistant") { + const assistantMsg = msg as AssistantMessage; + + // If message is from the same provider and model, keep as-is + if (assistantMsg.provider === model.provider && assistantMsg.model === model.id) { + return msg; + } + + // Transform message from different provider/model + const transformedContent = assistantMsg.content.map((block) => { + if (block.type === "thinking") { + // Convert thinking block to text block with tags + return { + type: "text" as const, + text: `\n${block.thinking}\n`, + }; + } + // All other blocks (text, toolCall) pass through unchanged + return block; + }); + + // Return 
transformed assistant message + return { + ...assistantMsg, + content: transformedContent, + }; + } + + // Should not reach here, but return as-is for safety + return msg; + }); +} diff --git a/packages/ai/src/types.ts b/packages/ai/src/types.ts index 01227c1c..5ad51b00 100644 --- a/packages/ai/src/types.ts +++ b/packages/ai/src/types.ts @@ -6,7 +6,7 @@ export interface LLMOptions { } export interface LLM { - complete(request: Context, options?: T): Promise; + generate(request: Context, options?: T): Promise; getModel(): Model; } diff --git a/packages/ai/test/abort.test.ts b/packages/ai/test/abort.test.ts index 2e194090..d55ef412 100644 --- a/packages/ai/test/abort.test.ts +++ b/packages/ai/test/abort.test.ts @@ -6,22 +6,25 @@ import { AnthropicLLM } from "../src/providers/anthropic.js"; import type { LLM, LLMOptions, Context } from "../src/types.js"; import { getModel } from "../src/models.js"; -async function testAbortSignal(llm: LLM, options: T) { - const controller = new AbortController(); - - // Abort after 100ms - setTimeout(() => controller.abort(), 5000); - +async function testAbortSignal(llm: LLM, options: T = {} as T) { const context: Context = { messages: [{ role: "user", - content: "What is 15 + 27? Think step by step. Then list 100 first names." + content: "What is 15 + 27? Think step by step. Then list 50 first names." 
}] }; - const response = await llm.complete(context, { + let abortFired = false; + const controller = new AbortController(); + const response = await llm.generate(context, { ...options, - signal: controller.signal + signal: controller.signal, + onEvent: (event) => { + // console.log(JSON.stringify(event, null, 2)); + if (abortFired) return; + setTimeout(() => controller.abort(), 2000); + abortFired = true; + } }); // If we get here without throwing, the abort didn't work @@ -29,15 +32,15 @@ async function testAbortSignal(llm: LLM, options: T) { expect(response.content.length).toBeGreaterThan(0); context.messages.push(response); - context.messages.push({ role: "user", content: "Please continue." }); + context.messages.push({ role: "user", content: "Please continue, but only generate 5 names." }); // Ensure we can still make requests after abort - const followUp = await llm.complete(context, options); + const followUp = await llm.generate(context, options); expect(followUp.stopReason).toBe("stop"); expect(followUp.content.length).toBeGreaterThan(0); } -async function testImmediateAbort(llm: LLM, options: T) { +async function testImmediateAbort(llm: LLM, options: T = {} as T) { const controller = new AbortController(); // Abort immediately @@ -47,7 +50,7 @@ async function testImmediateAbort(llm: LLM, options: T) messages: [{ role: "user", content: "Hello" }] }; - const response = await llm.complete(context, { + const response = await llm.generate(context, { ...options, signal: controller.signal }); @@ -75,15 +78,15 @@ describe("AI Providers Abort Tests", () => { let llm: OpenAICompletionsLLM; beforeAll(() => { - llm = new OpenAICompletionsLLM(getModel("openai", "gpt-5-mini")!, process.env.OPENAI_API_KEY!); + llm = new OpenAICompletionsLLM(getModel("openai", "gpt-4o-mini")!, process.env.OPENAI_API_KEY!); }); it("should abort mid-stream", async () => { - await testAbortSignal(llm, { reasoningEffort: "medium"}); + await testAbortSignal(llm); }); it("should handle 
immediate abort", async () => { - await testImmediateAbort(llm, { reasoningEffort: "medium" }); + await testImmediateAbort(llm); }); }); diff --git a/packages/ai/test/handoff.test.ts b/packages/ai/test/handoff.test.ts new file mode 100644 index 00000000..6300b76f --- /dev/null +++ b/packages/ai/test/handoff.test.ts @@ -0,0 +1,503 @@ +import { describe, it, expect, beforeAll } from "vitest"; +import { GoogleLLM } from "../src/providers/google.js"; +import { OpenAICompletionsLLM } from "../src/providers/openai-completions.js"; +import { OpenAIResponsesLLM } from "../src/providers/openai-responses.js"; +import { AnthropicLLM } from "../src/providers/anthropic.js"; +import type { LLM, Context, AssistantMessage, Tool, Message } from "../src/types.js"; +import { getModel } from "../src/models.js"; + +// Tool for testing +const weatherTool: Tool = { + name: "get_weather", + description: "Get the weather for a location", + parameters: { + type: "object", + properties: { + location: { type: "string", description: "City name" } + }, + required: ["location"] + } +}; + +// Pre-built contexts representing typical outputs from each provider +const providerContexts = { + // Anthropic-style message with thinking block + anthropic: { + message: { + role: "assistant", + content: [ + { + type: "thinking", + thinking: "Let me calculate 17 * 23. That's 17 * 20 + 17 * 3 = 340 + 51 = 391", + thinkingSignature: "signature_abc123" + }, + { + type: "text", + text: "I'll help you with the calculation and check the weather. The result of 17 × 23 is 391. The capital of Austria is Vienna. Now let me check the weather for you." 
+ }, + { + type: "toolCall", + id: "toolu_01abc123", + name: "get_weather", + arguments: { location: "Tokyo" } + } + ], + provider: "anthropic", + model: "claude-3-5-haiku-latest", + usage: { input: 100, output: 50, cacheRead: 0, cacheWrite: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } }, + stopReason: "toolUse" + } as AssistantMessage, + toolResult: { + role: "toolResult" as const, + toolCallId: "toolu_01abc123", + toolName: "get_weather", + content: "Weather in Tokyo: 18°C, partly cloudy", + isError: false + }, + facts: { + calculation: 391, + city: "Tokyo", + temperature: 18, + capital: "Vienna" + } + }, + + // Google-style message with thinking + google: { + message: { + role: "assistant", + content: [ + { + type: "thinking", + thinking: "I need to multiply 19 * 24. Let me work through this: 19 * 24 = 19 * 20 + 19 * 4 = 380 + 76 = 456", + thinkingSignature: undefined + }, + { + type: "text", + text: "The multiplication of 19 × 24 equals 456. The capital of France is Paris. Let me check the weather in Berlin for you." + }, + { + type: "toolCall", + id: "call_gemini_123", + name: "get_weather", + arguments: { location: "Berlin" } + } + ], + provider: "google", + model: "gemini-2.5-flash", + usage: { input: 120, output: 60, cacheRead: 0, cacheWrite: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } }, + stopReason: "toolUse" + } as AssistantMessage, + toolResult: { + role: "toolResult" as const, + toolCallId: "call_gemini_123", + toolName: "get_weather", + content: "Weather in Berlin: 22°C, sunny", + isError: false + }, + facts: { + calculation: 456, + city: "Berlin", + temperature: 22, + capital: "Paris" + } + }, + + // OpenAI Completions style (with reasoning_content) + openaiCompletions: { + message: { + role: "assistant", + content: [ + { + type: "thinking", + thinking: "Let me calculate 21 * 25. 
That's 21 * 25 = 525", + thinkingSignature: "reasoning_content" + }, + { + type: "text", + text: "The result of 21 × 25 is 525. The capital of Spain is Madrid. I'll check the weather in London now." + }, + { + type: "toolCall", + id: "call_abc123", + name: "get_weather", + arguments: { location: "London" } + } + ], + provider: "openai", + model: "gpt-4o-mini", + usage: { input: 110, output: 55, cacheRead: 0, cacheWrite: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } }, + stopReason: "toolUse" + } as AssistantMessage, + toolResult: { + role: "toolResult" as const, + toolCallId: "call_abc123", + toolName: "get_weather", + content: "Weather in London: 15°C, rainy", + isError: false + }, + facts: { + calculation: 525, + city: "London", + temperature: 15, + capital: "Madrid" + } + }, + + // OpenAI Responses style (with complex tool call IDs) + openaiResponses: { + message: { + role: "assistant", + content: [ + { + type: "thinking", + thinking: "Calculating 18 * 27: 18 * 27 = 486", + thinkingSignature: '{"type":"reasoning","id":"rs_2b2342acdde","summary":[{"type":"summary_text","text":"Calculating 18 * 27: 18 * 27 = 486"}]}' + }, + { + type: "text", + text: "The calculation of 18 × 27 gives us 486. The capital of Italy is Rome. 
Let me check Sydney's weather.", + textSignature: "msg_response_456" + }, + { + type: "toolCall", + id: "call_789_item_012", // Anthropic requires alphanumeric, dash, and underscore only + name: "get_weather", + arguments: { location: "Sydney" } + } + ], + provider: "openai", + model: "gpt-5-mini", + usage: { input: 115, output: 58, cacheRead: 0, cacheWrite: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } }, + stopReason: "toolUse" + } as AssistantMessage, + toolResult: { + role: "toolResult" as const, + toolCallId: "call_789_item_012", // Match the updated ID format + toolName: "get_weather", + content: "Weather in Sydney: 25°C, clear", + isError: false + }, + facts: { + calculation: 486, + city: "Sydney", + temperature: 25, + capital: "Rome" + } + }, + + // Aborted message (stopReason: 'error') + aborted: { + message: { + role: "assistant", + content: [ + { + type: "thinking", + thinking: "Let me start calculating 20 * 30...", + thinkingSignature: "partial_sig" + }, + { + type: "text", + text: "I was about to calculate 20 × 30 which is" + } + ], + provider: "test", + model: "test-model", + usage: { input: 50, output: 25, cacheRead: 0, cacheWrite: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } }, + stopReason: "error", + error: "Request was aborted" + } as AssistantMessage, + toolResult: null, + facts: { + calculation: 600, + city: "none", + temperature: 0, + capital: "none" + } + } +}; + +/** + * Test that a provider can handle contexts from different sources + */ +async function testProviderHandoff( + targetProvider: LLM, + sourceLabel: string, + sourceContext: typeof providerContexts[keyof typeof providerContexts] +): Promise { + // Build conversation context + const messages: Message[] = [ + { + role: "user", + content: "Please do some calculations, tell me about capitals, and check the weather." 
+ }, + sourceContext.message + ]; + + // Add tool result if present + if (sourceContext.toolResult) { + messages.push(sourceContext.toolResult); + } + + // Ask follow-up question + messages.push({ + role: "user", + content: `Based on our conversation, please answer: + 1) What was the multiplication result? + 2) Which city's weather did we check? + 3) What was the temperature? + 4) What capital city was mentioned? + Please include the specific numbers and names.` + }); + + const context: Context = { + messages, + tools: [weatherTool] + }; + + try { + const response = await targetProvider.generate(context, {}); + + // Check for error + if (response.stopReason === "error") { + console.log(`[${sourceLabel} → ${targetProvider.getModel().provider}] Failed with error: ${response.error}`); + return false; + } + + // Extract text from response + const responseText = response.content + .filter(b => b.type === "text") + .map(b => b.text) + .join(" ") + .toLowerCase(); + + // For aborted messages, we don't expect to find the facts + if (sourceContext.message.stopReason === "error") { + const hasToolCalls = response.content.some(b => b.type === "toolCall"); + const hasThinking = response.content.some(b => b.type === "thinking"); + const hasText = response.content.some(b => b.type === "text"); + + expect(response.stopReason === "stop" || response.stopReason === "toolUse").toBe(true); + expect(hasThinking || hasText || hasToolCalls).toBe(true); + console.log(`[${sourceLabel} → ${targetProvider.getModel().provider}] Handled aborted message successfully, tool calls: ${hasToolCalls}, thinking: ${hasThinking}, text: ${hasText}`); + return true; + } + + // Check if response contains our facts + const hasCalculation = responseText.includes(sourceContext.facts.calculation.toString()); + const hasCity = sourceContext.facts.city !== "none" && responseText.includes(sourceContext.facts.city.toLowerCase()); + const hasTemperature = sourceContext.facts.temperature > 0 && 
responseText.includes(sourceContext.facts.temperature.toString()); + const hasCapital = sourceContext.facts.capital !== "none" && responseText.includes(sourceContext.facts.capital.toLowerCase()); + + const success = hasCalculation && hasCity && hasTemperature && hasCapital; + + console.log(`[${sourceLabel} → ${targetProvider.getModel().provider}] Handoff test:`); + if (!success) { + console.log(` Calculation (${sourceContext.facts.calculation}): ${hasCalculation ? '✓' : '✗'}`); + console.log(` City (${sourceContext.facts.city}): ${hasCity ? '✓' : '✗'}`); + console.log(` Temperature (${sourceContext.facts.temperature}): ${hasTemperature ? '✓' : '✗'}`); + console.log(` Capital (${sourceContext.facts.capital}): ${hasCapital ? '✓' : '✗'}`); + } else { + console.log(` ✓ All facts found`); + } + + return success; + } catch (error) { + console.error(`[${sourceLabel} → ${targetProvider.getModel().provider}] Exception:`, error); + return false; + } +} + +describe("Cross-Provider Handoff Tests", () => { + describe.skipIf(!process.env.ANTHROPIC_API_KEY)("Anthropic Provider Handoff", () => { + let provider: AnthropicLLM; + + beforeAll(() => { + const model = getModel("anthropic", "claude-3-5-haiku-20241022"); + if (model) { + provider = new AnthropicLLM(model, process.env.ANTHROPIC_API_KEY!); + } + }); + + it("should handle contexts from all providers", async () => { + if (!provider) { + console.log("Anthropic provider not available, skipping"); + return; + } + + console.log("\nTesting Anthropic with pre-built contexts:\n"); + + const contextTests = [ + { label: "Anthropic-style", context: providerContexts.anthropic, sourceModel: "claude-3-5-haiku-20241022" }, + { label: "Google-style", context: providerContexts.google, sourceModel: "gemini-2.5-flash" }, + { label: "OpenAI-Completions", context: providerContexts.openaiCompletions, sourceModel: "gpt-4o-mini" }, + { label: "OpenAI-Responses", context: providerContexts.openaiResponses, sourceModel: "gpt-5-mini" }, + { label: 
"Aborted", context: providerContexts.aborted, sourceModel: null } + ]; + + let successCount = 0; + let skippedCount = 0; + + for (const { label, context, sourceModel } of contextTests) { + // Skip testing same model against itself + if (sourceModel && sourceModel === provider.getModel().id) { + console.log(`[${label} → ${provider.getModel().provider}] Skipping same-model test`); + skippedCount++; + continue; + } + const success = await testProviderHandoff(provider, label, context); + if (success) successCount++; + } + + const totalTests = contextTests.length - skippedCount; + console.log(`\nAnthropic success rate: ${successCount}/${totalTests} (${skippedCount} skipped)\n`); + + // All non-skipped handoffs should succeed + expect(successCount).toBe(totalTests); + }); + }); + + describe.skipIf(!process.env.GEMINI_API_KEY)("Google Provider Handoff", () => { + let provider: GoogleLLM; + + beforeAll(() => { + const model = getModel("google", "gemini-2.5-flash"); + if (model) { + provider = new GoogleLLM(model, process.env.GEMINI_API_KEY!); + } + }); + + it("should handle contexts from all providers", async () => { + if (!provider) { + console.log("Google provider not available, skipping"); + return; + } + + console.log("\nTesting Google with pre-built contexts:\n"); + + const contextTests = [ + { label: "Anthropic-style", context: providerContexts.anthropic, sourceModel: "claude-3-5-haiku-20241022" }, + { label: "Google-style", context: providerContexts.google, sourceModel: "gemini-2.5-flash" }, + { label: "OpenAI-Completions", context: providerContexts.openaiCompletions, sourceModel: "gpt-4o-mini" }, + { label: "OpenAI-Responses", context: providerContexts.openaiResponses, sourceModel: "gpt-5-mini" }, + { label: "Aborted", context: providerContexts.aborted, sourceModel: null } + ]; + + let successCount = 0; + let skippedCount = 0; + + for (const { label, context, sourceModel } of contextTests) { + // Skip testing same model against itself + if (sourceModel && 
sourceModel === provider.getModel().id) { + console.log(`[${label} → ${provider.getModel().provider}] Skipping same-model test`); + skippedCount++; + continue; + } + const success = await testProviderHandoff(provider, label, context); + if (success) successCount++; + } + + const totalTests = contextTests.length - skippedCount; + console.log(`\nGoogle success rate: ${successCount}/${totalTests} (${skippedCount} skipped)\n`); + + // All non-skipped handoffs should succeed + expect(successCount).toBe(totalTests); + }); + }); + + describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Completions Provider Handoff", () => { + let provider: OpenAICompletionsLLM; + + beforeAll(() => { + const model = getModel("openai", "gpt-4o-mini"); + if (model) { + provider = new OpenAICompletionsLLM(model, process.env.OPENAI_API_KEY!); + } + }); + + it("should handle contexts from all providers", async () => { + if (!provider) { + console.log("OpenAI Completions provider not available, skipping"); + return; + } + + console.log("\nTesting OpenAI Completions with pre-built contexts:\n"); + + const contextTests = [ + { label: "Anthropic-style", context: providerContexts.anthropic, sourceModel: "claude-3-5-haiku-20241022" }, + { label: "Google-style", context: providerContexts.google, sourceModel: "gemini-2.5-flash" }, + { label: "OpenAI-Completions", context: providerContexts.openaiCompletions, sourceModel: "gpt-4o-mini" }, + { label: "OpenAI-Responses", context: providerContexts.openaiResponses, sourceModel: "gpt-5-mini" }, + { label: "Aborted", context: providerContexts.aborted, sourceModel: null } + ]; + + let successCount = 0; + let skippedCount = 0; + + for (const { label, context, sourceModel } of contextTests) { + // Skip testing same model against itself + if (sourceModel && sourceModel === provider.getModel().id) { + console.log(`[${label} → ${provider.getModel().provider}] Skipping same-model test`); + skippedCount++; + continue; + } + const success = await 
testProviderHandoff(provider, label, context); + if (success) successCount++; + } + + const totalTests = contextTests.length - skippedCount; + console.log(`\nOpenAI Completions success rate: ${successCount}/${totalTests} (${skippedCount} skipped)\n`); + + // All non-skipped handoffs should succeed + expect(successCount).toBe(totalTests); + }); + }); + + describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Responses Provider Handoff", () => { + let provider: OpenAIResponsesLLM; + + beforeAll(() => { + const model = getModel("openai", "gpt-5-mini"); + if (model) { + provider = new OpenAIResponsesLLM(model, process.env.OPENAI_API_KEY!); + } + }); + + it("should handle contexts from all providers", async () => { + if (!provider) { + console.log("OpenAI Responses provider not available, skipping"); + return; + } + + console.log("\nTesting OpenAI Responses with pre-built contexts:\n"); + + const contextTests = [ + { label: "Anthropic-style", context: providerContexts.anthropic, sourceModel: "claude-3-5-haiku-20241022" }, + { label: "Google-style", context: providerContexts.google, sourceModel: "gemini-2.5-flash" }, + { label: "OpenAI-Completions", context: providerContexts.openaiCompletions, sourceModel: "gpt-4o-mini" }, + { label: "OpenAI-Responses", context: providerContexts.openaiResponses, sourceModel: "gpt-5-mini" }, + { label: "Aborted", context: providerContexts.aborted, sourceModel: null } + ]; + + let successCount = 0; + let skippedCount = 0; + + for (const { label, context, sourceModel } of contextTests) { + // Skip testing same model against itself + if (sourceModel && sourceModel === provider.getModel().id) { + console.log(`[${label} → ${provider.getModel().provider}] Skipping same-model test`); + skippedCount++; + continue; + } + const success = await testProviderHandoff(provider, label, context); + if (success) successCount++; + } + + const totalTests = contextTests.length - skippedCount; + console.log(`\nOpenAI Responses success rate: 
${successCount}/${totalTests} (${skippedCount} skipped)\n`); + + // All non-skipped handoffs should succeed + expect(successCount).toBe(totalTests); + }); + }); +}); \ No newline at end of file diff --git a/packages/ai/test/providers.test.ts b/packages/ai/test/providers.test.ts index 25dd4fb1..36e20b91 100644 --- a/packages/ai/test/providers.test.ts +++ b/packages/ai/test/providers.test.ts @@ -40,11 +40,11 @@ async function basicTextGeneration(llm: LLM) { ] }; - const response = await llm.complete(context); + const response = await llm.generate(context); expect(response.role).toBe("assistant"); expect(response.content).toBeTruthy(); - expect(response.usage.input).toBeGreaterThan(0); + expect(response.usage.input + response.usage.cacheRead).toBeGreaterThan(0); expect(response.usage.output).toBeGreaterThan(0); expect(response.error).toBeFalsy(); expect(response.content.map(b => b.type == "text" ? b.text : "").join("")).toContain("Hello test successful"); @@ -52,7 +52,7 @@ async function basicTextGeneration(llm: LLM) { context.messages.push(response); context.messages.push({ role: "user", content: "Now say 'Goodbye test successful'" }); - const secondResponse = await llm.complete(context); + const secondResponse = await llm.generate(context); expect(secondResponse.role).toBe("assistant"); expect(secondResponse.content).toBeTruthy(); @@ -72,7 +72,7 @@ async function handleToolCall(llm: LLM) { tools: [calculatorTool] }; - const response = await llm.complete(context); + const response = await llm.generate(context); expect(response.stopReason).toBe("toolUse"); expect(response.content.some(b => b.type == "toolCall")).toBeTruthy(); const toolCall = response.content.find(b => b.type == "toolCall")!; @@ -89,7 +89,7 @@ async function handleStreaming(llm: LLM) { messages: [{ role: "user", content: "Count from 1 to 3" }] }; - const response = await llm.complete(context, { + const response = await llm.generate(context, { onEvent: (event) => { if (event.type === "text_start") { 
textStarted = true; @@ -113,14 +113,15 @@ async function handleThinking(llm: LLM, options: T) { let thinkingCompleted = false; const context: Context = { - messages: [{ role: "user", content: "What is 15 + 27? Think step by step." }] + messages: [{ role: "user", content: `Think about ${(Math.random() * 255) | 0} + 27. Think step by step. Then output the result.` }] }; - const response = await llm.complete(context, { + const response = await llm.generate(context, { onEvent: (event) => { if (event.type === "thinking_start") { thinkingStarted = true; } else if (event.type === "thinking_delta") { + expect(event.content.endsWith(event.delta)).toBe(true); thinkingChunks += event.delta; } else if (event.type === "thinking_end") { thinkingCompleted = true; @@ -130,6 +131,7 @@ async function handleThinking(llm: LLM, options: T) { }); + expect(response.stopReason, `Error: ${(response as any).error}`).toBe("stop"); expect(thinkingStarted).toBe(true); expect(thinkingChunks.length).toBeGreaterThan(0); expect(thinkingCompleted).toBe(true); @@ -160,14 +162,14 @@ async function handleImage(llm: LLM) { { role: "user", content: [ - { type: "text", text: "What do you see in this image? Please describe the shape and color." }, + { type: "text", text: "What do you see in this image? Please describe the shape (circle, rectangle, square, triangle, ...) and color (red, blue, green, ...)." 
}, imageContent, ], }, ], }; - const response = await llm.complete(context); + const response = await llm.generate(context); // Check the response mentions red and circle expect(response.content.length > 0).toBeTruthy(); @@ -195,7 +197,7 @@ async function multiTurn(llm: LLM, thinkingOptions: T) const maxTurns = 5; // Prevent infinite loops for (let turn = 0; turn < maxTurns; turn++) { - const response = await llm.complete(context, thinkingOptions); + const response = await llm.generate(context, thinkingOptions); // Add the assistant response to context context.messages.push(response); @@ -325,12 +327,12 @@ describe("AI Providers E2E Tests", () => { await handleStreaming(llm); }); - it("should handle thinking mode", async () => { - await handleThinking(llm, {reasoningEffort: "medium"}); + it("should handle thinking mode", {retry: 2}, async () => { + await handleThinking(llm, {reasoningEffort: "high"}); }); it("should handle multi-turn with thinking and tools", async () => { - await multiTurn(llm, {reasoningEffort: "medium"}); + await multiTurn(llm, {reasoningEffort: "high"}); }); it("should handle image input", async () => { @@ -370,34 +372,6 @@ describe("AI Providers E2E Tests", () => { }); }); - describe.skipIf(!process.env.ANTHROPIC_API_KEY)("Anthropic Provider (Haiku 3.5)", () => { - let llm: AnthropicLLM; - - beforeAll(() => { - llm = createLLM("anthropic", "claude-3-5-haiku-latest"); - }); - - it("should complete basic text generation", async () => { - await basicTextGeneration(llm); - }); - - it("should handle tool calling", async () => { - await handleToolCall(llm); - }); - - it("should handle streaming", async () => { - await handleStreaming(llm); - }); - - it("should handle multi-turn with thinking and tools", async () => { - await multiTurn(llm, {thinking: {enabled: true}}); - }); - - it("should handle image input", async () => { - await handleImage(llm); - }); - }); - describe.skipIf(!process.env.XAI_API_KEY)("xAI Provider (grok-code-fast-1 via OpenAI 
Completions)", () => { let llm: OpenAICompletionsLLM; @@ -505,7 +479,7 @@ describe("AI Providers E2E Tests", () => { await handleThinking(llm, {reasoningEffort: "medium"}); }); - it("should handle multi-turn with thinking and tools", async () => { + it("should handle multi-turn with thinking and tools", { retry: 2 }, async () => { await multiTurn(llm, {reasoningEffort: "medium"}); }); @@ -611,4 +585,34 @@ describe("AI Providers E2E Tests", () => { await multiTurn(llm, {reasoningEffort: "medium"}); }); }); + + /* + describe.skipIf(!process.env.ANTHROPIC_API_KEY)("Anthropic Provider (Haiku 3.5)", () => { + let llm: AnthropicLLM; + + beforeAll(() => { + llm = createLLM("anthropic", "claude-3-5-haiku-latest"); + }); + + it("should complete basic text generation", async () => { + await basicTextGeneration(llm); + }); + + it("should handle tool calling", async () => { + await handleToolCall(llm); + }); + + it("should handle streaming", async () => { + await handleStreaming(llm); + }); + + it("should handle multi-turn with thinking and tools", async () => { + await multiTurn(llm, {thinking: {enabled: true}}); + }); + + it("should handle image input", async () => { + await handleImage(llm); + }); + }); + */ }); \ No newline at end of file