From bf1f410c2bb6ef295635f3e815e10d8c231bf8cd Mon Sep 17 00:00:00 2001 From: Mario Zechner Date: Mon, 1 Sep 2025 01:57:08 +0200 Subject: [PATCH] refactor(ai): Update API to support partial results on abort - Anthropic, Google, and OpenAI Responses providers now return partial results when aborted - Restructured streaming to accumulate content blocks incrementally - Prevents submission of thinking/toolCall blocks from aborted completions in multi-turn conversations - Makes UI development easier by providing partial content even when requests are interrupted --- packages/ai/src/providers/anthropic.ts | 182 +++++++--------- packages/ai/src/providers/google.ts | 94 +++----- packages/ai/src/providers/openai-responses.ts | 203 +++++++++--------- packages/ai/test/abort.test.ts | 45 ++-- 4 files changed, 244 insertions(+), 280 deletions(-) diff --git a/packages/ai/src/providers/anthropic.ts b/packages/ai/src/providers/anthropic.ts index 7705ea30..9fc3dad5 100644 --- a/packages/ai/src/providers/anthropic.ts +++ b/packages/ai/src/providers/anthropic.ts @@ -14,8 +14,9 @@ import type { Message, Model, StopReason, + TextContent, + ThinkingContent, ToolCall, - Usage, } from "../types.js"; export interface AnthropicLLMOptions extends LLMOptions { @@ -61,6 +62,21 @@ export class AnthropicLLM implements LLM { } async complete(context: Context, options?: AnthropicLLMOptions): Promise { + const output: AssistantMessage = { + role: "assistant", + content: [], + provider: this.modelInfo.provider, + model: this.modelInfo.id, + usage: { + input: 0, + output: 0, + cacheRead: 0, + cacheWrite: 0, + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, + }, + stopReason: "stop", + }; + try { const messages = this.convertMessages(context.messages); @@ -131,132 +147,94 @@ export class AnthropicLLM implements LLM { options?.onEvent?.({ type: "start", model: this.modelInfo.id, provider: this.modelInfo.provider }); - let blockType: "text" | "thinking" | "toolUse" | "other" = 
"other"; - let blockContent = ""; - let toolCall: (ToolCall & { partialJson: string }) | null = null; + let currentBlock: ThinkingContent | TextContent | (ToolCall & { partialJson: string }) | null = null; for await (const event of stream) { if (event.type === "content_block_start") { if (event.content_block.type === "text") { - blockType = "text"; - blockContent = ""; + currentBlock = { + type: "text", + text: "", + }; + output.content.push(currentBlock); options?.onEvent?.({ type: "text_start" }); } else if (event.content_block.type === "thinking") { - blockType = "thinking"; - blockContent = ""; + currentBlock = { + type: "thinking", + thinking: "", + thinkingSignature: "", + }; + output.content.push(currentBlock); options?.onEvent?.({ type: "thinking_start" }); } else if (event.content_block.type === "tool_use") { // We wait for the full tool use to be streamed to send the event - toolCall = { + currentBlock = { type: "toolCall", id: event.content_block.id, name: event.content_block.name, arguments: event.content_block.input as Record, partialJson: "", }; - blockType = "toolUse"; - blockContent = ""; - } else { - blockType = "other"; - blockContent = ""; } - } - if (event.type === "content_block_delta") { + } else if (event.type === "content_block_delta") { if (event.delta.type === "text_delta") { - options?.onEvent?.({ type: "text_delta", content: blockContent, delta: event.delta.text }); - blockContent += event.delta.text; + if (currentBlock && currentBlock.type === "text") { + currentBlock.text += event.delta.text; + options?.onEvent?.({ type: "text_delta", content: currentBlock.text, delta: event.delta.text }); + } + } else if (event.delta.type === "thinking_delta") { + if (currentBlock && currentBlock.type === "thinking") { + currentBlock.thinking += event.delta.thinking; + options?.onEvent?.({ + type: "thinking_delta", + content: currentBlock.thinking, + delta: event.delta.thinking, + }); + } + } else if (event.delta.type === "input_json_delta") { + if 
(currentBlock && currentBlock.type === "toolCall") { + currentBlock.partialJson += event.delta.partial_json; + } + } else if (event.delta.type === "signature_delta") { + if (currentBlock && currentBlock.type === "thinking") { + currentBlock.thinkingSignature = currentBlock.thinkingSignature || ""; + currentBlock.thinkingSignature += event.delta.signature; + } } - if (event.delta.type === "thinking_delta") { - options?.onEvent?.({ type: "thinking_delta", content: blockContent, delta: event.delta.thinking }); - blockContent += event.delta.thinking; + } else if (event.type === "content_block_stop") { + if (currentBlock) { + if (currentBlock.type === "text") { + options?.onEvent?.({ type: "text_end", content: currentBlock.text }); + } else if (currentBlock.type === "thinking") { + options?.onEvent?.({ type: "thinking_end", content: currentBlock.thinking }); + } else if (currentBlock.type === "toolCall") { + const finalToolCall: ToolCall = { + type: "toolCall", + id: currentBlock.id, + name: currentBlock.name, + arguments: JSON.parse(currentBlock.partialJson), + }; + output.content.push(finalToolCall); + options?.onEvent?.({ type: "toolCall", toolCall: finalToolCall }); + } + currentBlock = null; } - if (event.delta.type === "input_json_delta") { - toolCall!.partialJson += event.delta.partial_json; + } else if (event.type === "message_delta") { + if (event.delta.stop_reason) { + output.stopReason = this.mapStopReason(event.delta.stop_reason); } - } - if (event.type === "content_block_stop") { - if (blockType === "text") { - options?.onEvent?.({ type: "text_end", content: blockContent }); - } else if (blockType === "thinking") { - options?.onEvent?.({ type: "thinking_end", content: blockContent }); - } else if (blockType === "toolUse") { - const finalToolCall: ToolCall = { - type: "toolCall", - id: toolCall!.id, - name: toolCall!.name, - arguments: toolCall!.partialJson ? 
JSON.parse(toolCall!.partialJson) : toolCall!.arguments, - }; - toolCall = null; - options?.onEvent?.({ type: "toolCall", toolCall: finalToolCall }); - } - blockType = "other"; - } - } - const msg = await stream.finalMessage(); - const blocks: AssistantMessage["content"] = []; - for (const block of msg.content) { - if (block.type === "text" && block.text) { - blocks.push({ - type: "text", - text: block.text, - }); - } else if (block.type === "thinking" && block.thinking) { - blocks.push({ - type: "thinking", - thinking: block.thinking, - thinkingSignature: block.signature, - }); - } else if (block.type === "tool_use") { - blocks.push({ - type: "toolCall", - id: block.id, - name: block.name, - arguments: block.input as Record, - }); + output.usage.input += event.usage.input_tokens || 0; + output.usage.output += event.usage.output_tokens || 0; + output.usage.cacheRead += event.usage.cache_read_input_tokens || 0; + output.usage.cacheWrite += event.usage.cache_creation_input_tokens || 0; + calculateCost(this.modelInfo, output.usage); } } - const usage: Usage = { - input: msg.usage.input_tokens, - output: msg.usage.output_tokens, - cacheRead: msg.usage.cache_read_input_tokens || 0, - cacheWrite: msg.usage.cache_creation_input_tokens || 0, - cost: { - input: 0, - output: 0, - cacheRead: 0, - cacheWrite: 0, - total: 0, - }, - }; - calculateCost(this.modelInfo, usage); - - const output = { - role: "assistant", - content: blocks, - provider: this.modelInfo.provider, - model: this.modelInfo.id, - usage, - stopReason: this.mapStopReason(msg.stop_reason), - } satisfies AssistantMessage; options?.onEvent?.({ type: "done", reason: output.stopReason, message: output }); - return output; } catch (error) { - const output = { - role: "assistant", - content: [], - provider: this.modelInfo.provider, - model: this.modelInfo.id, - usage: { - input: 0, - output: 0, - cacheRead: 0, - cacheWrite: 0, - cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, - }, - stopReason: 
"error", - error: error instanceof Error ? error.message : JSON.stringify(error), - } satisfies AssistantMessage; + output.stopReason = "error"; + output.error = error instanceof Error ? error.message : JSON.stringify(error); options?.onEvent?.({ type: "error", error: output.error }); return output; } diff --git a/packages/ai/src/providers/google.ts b/packages/ai/src/providers/google.ts index 1876f22c..b3cb1dc9 100644 --- a/packages/ai/src/providers/google.ts +++ b/packages/ai/src/providers/google.ts @@ -20,7 +20,6 @@ import type { ThinkingContent, Tool, ToolCall, - Usage, } from "../types.js"; export interface GoogleLLMOptions extends LLMOptions { @@ -33,7 +32,7 @@ export interface GoogleLLMOptions extends LLMOptions { export class GoogleLLM implements LLM { private client: GoogleGenAI; - private model: Model; + private modelInfo: Model; constructor(model: Model, apiKey?: string) { if (!apiKey) { @@ -45,14 +44,28 @@ export class GoogleLLM implements LLM { apiKey = process.env.GEMINI_API_KEY; } this.client = new GoogleGenAI({ apiKey }); - this.model = model; + this.modelInfo = model; } getModel(): Model { - return this.model; + return this.modelInfo; } async complete(context: Context, options?: GoogleLLMOptions): Promise { + const output: AssistantMessage = { + role: "assistant", + content: [], + provider: this.modelInfo.provider, + model: this.modelInfo.id, + usage: { + input: 0, + output: 0, + cacheRead: 0, + cacheWrite: 0, + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, + }, + stopReason: "stop", + }; try { const contents = this.convertMessages(context.messages); @@ -82,7 +95,7 @@ export class GoogleLLM implements LLM { } // Add thinking config if enabled and model supports it - if (options?.thinking?.enabled && this.model.reasoning) { + if (options?.thinking?.enabled && this.modelInfo.reasoning) { config.thinkingConfig = { includeThoughts: true, ...(options.thinking.budgetTokens !== undefined && { thinkingBudget: 
options.thinking.budgetTokens }), @@ -99,27 +112,15 @@ export class GoogleLLM implements LLM { // Build the request parameters const params: GenerateContentParameters = { - model: this.model.id, + model: this.modelInfo.id, contents, config, }; const stream = await this.client.models.generateContentStream(params); - options?.onEvent?.({ type: "start", model: this.model.id, provider: this.model.provider }); - - const blocks: AssistantMessage["content"] = []; + options?.onEvent?.({ type: "start", model: this.modelInfo.id, provider: this.modelInfo.provider }); let currentBlock: TextContent | ThinkingContent | null = null; - let usage: Usage = { - input: 0, - output: 0, - cacheRead: 0, - cacheWrite: 0, - cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, - }; - let stopReason: StopReason = "stop"; - - // Process the stream for await (const chunk of stream) { // Extract parts from the chunk const candidate = chunk.candidates?.[0]; @@ -134,14 +135,12 @@ export class GoogleLLM implements LLM { (isThinking && currentBlock.type !== "thinking") || (!isThinking && currentBlock.type !== "text") ) { - // Save and finalize current block if (currentBlock) { if (currentBlock.type === "text") { options?.onEvent?.({ type: "text_end", content: currentBlock.text }); } else { options?.onEvent?.({ type: "thinking_end", content: currentBlock.thinking }); } - blocks.push(currentBlock); } // Start new block @@ -152,6 +151,7 @@ export class GoogleLLM implements LLM { currentBlock = { type: "text", text: "" }; options?.onEvent?.({ type: "text_start" }); } + output.content.push(currentBlock); } // Append content to current block @@ -171,14 +171,12 @@ export class GoogleLLM implements LLM { // Handle function calls if (part.functionCall) { - // Save current block if exists if (currentBlock) { if (currentBlock.type === "text") { options?.onEvent?.({ type: "text_end", content: currentBlock.text }); } else { options?.onEvent?.({ type: "thinking_end", content: 
currentBlock.thinking }); } - blocks.push(currentBlock); currentBlock = null; } @@ -190,7 +188,7 @@ export class GoogleLLM implements LLM { name: part.functionCall.name || "", arguments: part.functionCall.args as Record, }; - blocks.push(toolCall); + output.content.push(toolCall); options?.onEvent?.({ type: "toolCall", toolCall }); } } @@ -198,16 +196,16 @@ export class GoogleLLM implements LLM { // Map finish reason if (candidate?.finishReason) { - stopReason = this.mapStopReason(candidate.finishReason); + output.stopReason = this.mapStopReason(candidate.finishReason); // Check if we have tool calls in blocks - if (blocks.some((b) => b.type === "toolCall")) { - stopReason = "toolUse"; + if (output.content.some((b) => b.type === "toolCall")) { + output.stopReason = "toolUse"; } } // Capture usage metadata if available if (chunk.usageMetadata) { - usage = { + output.usage = { input: chunk.usageMetadata.promptTokenCount || 0, output: (chunk.usageMetadata.candidatesTokenCount || 0) + (chunk.usageMetadata.thoughtsTokenCount || 0), @@ -221,47 +219,15 @@ export class GoogleLLM implements LLM { total: 0, }, }; + calculateCost(this.modelInfo, output.usage); } } - // Save final block if exists - if (currentBlock) { - if (currentBlock.type === "text") { - options?.onEvent?.({ type: "text_end", content: currentBlock.text }); - } else { - options?.onEvent?.({ type: "thinking_end", content: currentBlock.thinking }); - } - blocks.push(currentBlock); - } - - calculateCost(this.model, usage); - - const output = { - role: "assistant", - content: blocks, - provider: this.model.provider, - model: this.model.id, - usage, - stopReason, - } satisfies AssistantMessage; - options?.onEvent?.({ type: "done", reason: stopReason, message: output }); + options?.onEvent?.({ type: "done", reason: output.stopReason, message: output }); return output; } catch (error) { - const output = { - role: "assistant", - content: [], - provider: this.model.provider, - model: this.model.id, - usage: { - 
input: 0, - output: 0, - cacheRead: 0, - cacheWrite: 0, - cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, - }, - stopReason: "error", - error: error instanceof Error ? error.message : JSON.stringify(error), - } satisfies AssistantMessage; + output.stopReason = "error"; + output.error = error instanceof Error ? error.message : JSON.stringify(error); options?.onEvent?.({ type: "error", error: output.error }); return output; } diff --git a/packages/ai/src/providers/openai-responses.ts b/packages/ai/src/providers/openai-responses.ts index fc0e1dcf..a10cf818 100644 --- a/packages/ai/src/providers/openai-responses.ts +++ b/packages/ai/src/providers/openai-responses.ts @@ -22,7 +22,6 @@ import type { TextContent, Tool, ToolCall, - Usage, } from "../types.js"; export interface OpenAIResponsesLLMOptions extends LLMOptions { @@ -52,6 +51,20 @@ export class OpenAIResponsesLLM implements LLM { } async complete(request: Context, options?: OpenAIResponsesLLMOptions): Promise { + const output: AssistantMessage = { + role: "assistant", + content: [], + provider: this.modelInfo.provider, + model: this.modelInfo.id, + usage: { + input: 0, + output: 0, + cacheRead: 0, + cacheWrite: 0, + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, + }, + stopReason: "stop", + }; try { const input = this.convertToInput(request.messages, request.systemPrompt); @@ -88,17 +101,8 @@ export class OpenAIResponsesLLM implements LLM { options?.onEvent?.({ type: "start", model: this.modelInfo.id, provider: this.modelInfo.provider }); - const outputItems: (ResponseReasoningItem | ResponseOutputMessage | ResponseFunctionToolCall)[] = []; // any for function_call items - let currentTextAccum = ""; // For delta accumulation - let currentThinkingAccum = ""; // For delta accumulation - let usage: Usage = { - input: 0, - output: 0, - cacheRead: 0, - cacheWrite: 0, - cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, - }; - let stopReason: StopReason = 
"stop"; + const outputItems: (ResponseReasoningItem | ResponseOutputMessage | ResponseFunctionToolCall)[] = []; + let currentItem: ResponseReasoningItem | ResponseOutputMessage | ResponseFunctionToolCall | null = null; for await (const event of stream) { // Handle output item start @@ -106,47 +110,91 @@ export class OpenAIResponsesLLM implements LLM { const item = event.item; if (item.type === "reasoning") { options?.onEvent?.({ type: "thinking_start" }); - currentThinkingAccum = ""; + outputItems.push(item); + currentItem = item; } else if (item.type === "message") { options?.onEvent?.({ type: "text_start" }); - currentTextAccum = ""; + outputItems.push(item); + currentItem = item; } } // Handle reasoning summary deltas - else if (event.type === "response.reasoning_summary_text.delta") { - const delta = event.delta; - currentThinkingAccum += delta; - options?.onEvent?.({ type: "thinking_delta", content: currentThinkingAccum, delta }); + else if (event.type === "response.reasoning_summary_part.added") { + if (currentItem && currentItem.type === "reasoning") { + currentItem.summary = currentItem.summary || []; + currentItem.summary.push(event.part); + } + } else if (event.type === "response.reasoning_summary_text.delta") { + if (currentItem && currentItem.type === "reasoning") { + currentItem.summary = currentItem.summary || []; + const lastPart = currentItem.summary[currentItem.summary.length - 1]; + if (lastPart) { + lastPart.text += event.delta; + options?.onEvent?.({ + type: "thinking_delta", + content: currentItem.summary.join("\n\n"), + delta: event.delta, + }); + } + } } // Add a new line between summary parts (hack...) 
else if (event.type === "response.reasoning_summary_part.done") { - currentThinkingAccum += "\n\n"; - options?.onEvent?.({ type: "thinking_delta", content: currentThinkingAccum, delta: "\n\n" }); + if (currentItem && currentItem.type === "reasoning") { + options?.onEvent?.({ + type: "thinking_delta", + content: currentItem.summary.join("\n\n"), + delta: "\n\n", + }); + } } // Handle text output deltas - else if (event.type === "response.output_text.delta") { - const delta = event.delta; - currentTextAccum += delta; - options?.onEvent?.({ type: "text_delta", content: currentTextAccum, delta }); - } - // Handle refusal output deltas - else if (event.type === "response.refusal.delta") { - const delta = event.delta; - currentTextAccum += delta; - options?.onEvent?.({ type: "text_delta", content: currentTextAccum, delta }); + else if (event.type === "response.content_part.added") { + if (currentItem && currentItem.type === "message") { + currentItem.content = currentItem.content || []; + currentItem.content.push(event.part); + } + } else if (event.type === "response.output_text.delta") { + if (currentItem && currentItem.type === "message") { + const lastPart = currentItem.content[currentItem.content.length - 1]; + if (lastPart && lastPart.type === "output_text") { + lastPart.text += event.delta; + options?.onEvent?.({ + type: "text_delta", + content: currentItem.content + .map((c) => (c.type === "output_text" ? c.text : c.refusal)) + .join(""), + delta: event.delta, + }); + } + } + } else if (event.type === "response.refusal.delta") { + if (currentItem && currentItem.type === "message") { + const lastPart = currentItem.content[currentItem.content.length - 1]; + if (lastPart && lastPart.type === "refusal") { + lastPart.refusal += event.delta; + options?.onEvent?.({ + type: "text_delta", + content: currentItem.content + .map((c) => (c.type === "output_text" ? 
c.text : c.refusal)) + .join(""), + delta: event.delta, + }); + } + } } // Handle output item completion else if (event.type === "response.output_item.done") { const item = event.item; if (item.type === "reasoning") { + outputItems[outputItems.length - 1] = item; // Update with final item const thinkingContent = item.summary?.map((s: any) => s.text).join("\n\n") || ""; options?.onEvent?.({ type: "thinking_end", content: thinkingContent }); - outputItems.push(item); } else if (item.type === "message") { + outputItems[outputItems.length - 1] = item; // Update with final item const textContent = item.content.map((c) => (c.type === "output_text" ? c.text : c.refusal)).join(""); options?.onEvent?.({ type: "text_end", content: textContent }); - outputItems.push(item); } else if (item.type === "function_call") { const toolCall: ToolCall = { type: "toolCall", @@ -162,7 +210,7 @@ export class OpenAIResponsesLLM implements LLM { else if (event.type === "response.completed") { const response = event.response; if (response?.usage) { - usage = { + output.usage = { input: response.usage.input_tokens || 0, output: response.usage.output_tokens || 0, cacheRead: response.usage.input_tokens_details?.cached_tokens || 0, @@ -170,60 +218,43 @@ export class OpenAIResponsesLLM implements LLM { cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, }; } - + calculateCost(this.modelInfo, output.usage); // Map status to stop reason - stopReason = this.mapStopReason(response?.status); + output.stopReason = this.mapStopReason(response?.status); + if (outputItems.some((b) => b.type === "function_call") && output.stopReason === "stop") { + output.stopReason = "toolUse"; + } } // Handle errors else if (event.type === "error") { - const errorOutput = { - role: "assistant", - content: [], - provider: this.modelInfo.provider, - model: this.modelInfo.id, - usage, - stopReason: "error", - error: `Code ${event.code}: ${event.message}` || "Unknown error", - } satisfies AssistantMessage; - 
options?.onEvent?.({ type: "error", error: errorOutput.error || "Unknown error" }); - return errorOutput; + output.stopReason = "error"; + output.error = `Code ${event.code}: ${event.message}` || "Unknown error"; + options?.onEvent?.({ type: "error", error: output.error }); + return output; } else if (event.type === "response.failed") { - const errorOutput = { - role: "assistant", - content: [], - provider: this.modelInfo.provider, - model: this.modelInfo.id, - usage, - stopReason: "error", - error: "Unknown error", - } satisfies AssistantMessage; - options?.onEvent?.({ type: "error", error: errorOutput.error || "Unknown error" }); - return errorOutput; + output.stopReason = "error"; + output.error = "Unknown error"; + options?.onEvent?.({ type: "error", error: output.error }); + return output; } } - if (options?.signal?.aborted) { - throw new Error("Request was aborted"); - } - // Convert output items to blocks - const blocks: AssistantMessage["content"] = []; - for (const item of outputItems) { if (item.type === "reasoning") { - blocks.push({ + output.content.push({ type: "thinking", thinking: item.summary?.map((s: any) => s.text).join("\n\n") || "", thinkingSignature: JSON.stringify(item), // Full item for resubmission }); } else if (item.type === "message") { - blocks.push({ + output.content.push({ type: "text", text: item.content.map((c) => (c.type === "output_text" ? 
c.text : c.refusal)).join(""), textSignature: item.id, // ID for resubmission }); } else if (item.type === "function_call") { - blocks.push({ + output.content.push({ type: "toolCall", id: item.call_id + "|" + item.id, name: item.name, @@ -232,40 +263,16 @@ export class OpenAIResponsesLLM implements LLM { } } - // Check if we have tool calls for stop reason - if (blocks.some((b) => b.type === "toolCall") && stopReason === "stop") { - stopReason = "toolUse"; + if (options?.signal?.aborted) { + throw new Error("Request was aborted"); } - calculateCost(this.modelInfo, usage); - - const output = { - role: "assistant", - content: blocks, - provider: this.modelInfo.provider, - model: this.modelInfo.id, - usage, - stopReason, - } satisfies AssistantMessage; options?.onEvent?.({ type: "done", reason: output.stopReason, message: output }); return output; } catch (error) { - const output = { - role: "assistant", - content: [], - provider: this.modelInfo.provider, - model: this.modelInfo.id, - usage: { - input: 0, - output: 0, - cacheRead: 0, - cacheWrite: 0, - cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, - }, - stopReason: "error", - error: error instanceof Error ? error.message : String(error), - } satisfies AssistantMessage; - options?.onEvent?.({ type: "error", error: output.error || "Unknown error" }); + output.stopReason = "error"; + output.error = error instanceof Error ? error.message : JSON.stringify(error); + options?.onEvent?.({ type: "error", error: output.error }); return output; } } @@ -318,7 +325,8 @@ export class OpenAIResponsesLLM implements LLM { const output: ResponseInput = []; for (const block of msg.content) { - if (block.type === "thinking") { + // Do not submit thinking blocks if the completion had an error (i.e. 
abort) + if (block.type === "thinking" && msg.stopReason !== "error") { // Push the full reasoning item(s) from signature if (block.thinkingSignature) { const reasoningItem = JSON.parse(block.thinkingSignature); @@ -333,7 +341,8 @@ export class OpenAIResponsesLLM implements LLM { status: "completed", id: textBlock.textSignature || "msg_" + Math.random().toString(36).substring(2, 15), } satisfies ResponseOutputMessage); - } else if (block.type === "toolCall") { + // Do not submit tool call blocks if the completion had an error (i.e. abort) + } else if (block.type === "toolCall" && msg.stopReason !== "error") { const toolCall = block as ToolCall; output.push({ type: "function_call", diff --git a/packages/ai/test/abort.test.ts b/packages/ai/test/abort.test.ts index ab4e32f8..2e194090 100644 --- a/packages/ai/test/abort.test.ts +++ b/packages/ai/test/abort.test.ts @@ -6,28 +6,38 @@ import { AnthropicLLM } from "../src/providers/anthropic.js"; import type { LLM, LLMOptions, Context } from "../src/types.js"; import { getModel } from "../src/models.js"; -async function testAbortSignal(llm: LLM) { +async function testAbortSignal(llm: LLM, options: T) { const controller = new AbortController(); // Abort after 100ms - setTimeout(() => controller.abort(), 1000); + setTimeout(() => controller.abort(), 5000); const context: Context = { messages: [{ role: "user", - content: "Write a very long story about a dragon that lives in a mountain. Include lots of details about the dragon's appearance, its daily life, the treasures it guards, and its interactions with nearby villages. Make it at least 1000 words long." + content: "What is 15 + 27? Think step by step. Then list 100 first names." 
}] }; const response = await llm.complete(context, { + ...options, signal: controller.signal - } as T); + }); // If we get here without throwing, the abort didn't work expect(response.stopReason).toBe("error"); + expect(response.content.length).toBeGreaterThan(0); + + context.messages.push(response); + context.messages.push({ role: "user", content: "Please continue." }); + + // Ensure we can still make requests after abort + const followUp = await llm.complete(context, options); + expect(followUp.stopReason).toBe("stop"); + expect(followUp.content.length).toBeGreaterThan(0); } -async function testImmediateAbort(llm: LLM) { +async function testImmediateAbort(llm: LLM, options: T) { const controller = new AbortController(); // Abort immediately @@ -38,8 +48,9 @@ async function testImmediateAbort(llm: LLM) { }; const response = await llm.complete(context, { + ...options, signal: controller.signal - } as T); + }); expect(response.stopReason).toBe("error"); } @@ -52,11 +63,11 @@ describe("AI Providers Abort Tests", () => { }); it("should abort mid-stream", async () => { - await testAbortSignal(llm); + await testAbortSignal(llm, { thinking: { enabled: true } }); }); it("should handle immediate abort", async () => { - await testImmediateAbort(llm); + await testImmediateAbort(llm, { thinking: { enabled: true } }); }); }); @@ -64,15 +75,15 @@ describe("AI Providers Abort Tests", () => { let llm: OpenAICompletionsLLM; beforeAll(() => { - llm = new OpenAICompletionsLLM(getModel("openai", "gpt-4o-mini")!, process.env.OPENAI_API_KEY!); + llm = new OpenAICompletionsLLM(getModel("openai", "gpt-5-mini")!, process.env.OPENAI_API_KEY!); }); it("should abort mid-stream", async () => { - await testAbortSignal(llm); + await testAbortSignal(llm, { reasoningEffort: "medium"}); }); it("should handle immediate abort", async () => { - await testImmediateAbort(llm); + await testImmediateAbort(llm, { reasoningEffort: "medium" }); }); }); @@ -88,27 +99,27 @@ describe("AI Providers Abort 
Tests", () => { }); it("should abort mid-stream", async () => { - await testAbortSignal(llm); + await testAbortSignal(llm, {}); }); it("should handle immediate abort", async () => { - await testImmediateAbort(llm); + await testImmediateAbort(llm, {}); }); }); - describe.skipIf(!process.env.ANTHROPIC_API_KEY)("Anthropic Provider Abort", () => { + describe.skipIf(!process.env.ANTHROPIC_OAUTH_TOKEN)("Anthropic Provider Abort", () => { let llm: AnthropicLLM; beforeAll(() => { - llm = new AnthropicLLM(getModel("anthropic", "claude-3-5-haiku-latest")!, process.env.ANTHROPIC_API_KEY!); + llm = new AnthropicLLM(getModel("anthropic", "claude-opus-4-1")!, process.env.ANTHROPIC_OAUTH_TOKEN!); }); it("should abort mid-stream", async () => { - await testAbortSignal(llm); + await testAbortSignal(llm, {thinking: { enabled: true, budgetTokens: 2048 }}); }); it("should handle immediate abort", async () => { - await testImmediateAbort(llm); + await testImmediateAbort(llm, {thinking: { enabled: true, budgetTokens: 2048 }}); }); }); }); \ No newline at end of file