feat(ai): Add cross-provider message handoff support

- Add transformMessages utility to handle cross-provider compatibility
- Convert thinking blocks to <thinking> tagged text when switching providers
- Preserve native thinking blocks when staying with same provider/model
- Add comprehensive handoff tests verifying all provider combinations
- Fix OpenAI Completions to return partial results on abort
- Update tool call ID format for Anthropic compatibility
- Document cross-provider handoff capabilities in README
This commit is contained in:
Mario Zechner 2025-09-01 18:43:49 +02:00
parent bf1f410c2b
commit 46b5800d36
10 changed files with 828 additions and 130 deletions

View file

@ -18,6 +18,7 @@ import type {
ThinkingContent,
ToolCall,
} from "../types.js";
import { transformMessages } from "./utils.js";
export interface AnthropicLLMOptions extends LLMOptions {
thinking?: {
@ -61,7 +62,7 @@ export class AnthropicLLM implements LLM<AnthropicLLMOptions> {
return this.modelInfo;
}
async complete(context: Context, options?: AnthropicLLMOptions): Promise<AssistantMessage> {
async generate(context: Context, options?: AnthropicLLMOptions): Promise<AssistantMessage> {
const output: AssistantMessage = {
role: "assistant",
content: [],
@ -243,7 +244,10 @@ export class AnthropicLLM implements LLM<AnthropicLLMOptions> {
private convertMessages(messages: Message[]): MessageParam[] {
const params: MessageParam[] = [];
for (const msg of messages) {
// Transform messages for cross-provider compatibility
const transformedMessages = transformMessages(messages, this.modelInfo);
for (const msg of transformedMessages) {
if (msg.role === "user") {
// Handle both string and array content
if (typeof msg.content === "string") {

View file

@ -21,6 +21,7 @@ import type {
Tool,
ToolCall,
} from "../types.js";
import { transformMessages } from "./utils.js";
export interface GoogleLLMOptions extends LLMOptions {
toolChoice?: "auto" | "none" | "any";
@ -51,7 +52,7 @@ export class GoogleLLM implements LLM<GoogleLLMOptions> {
return this.modelInfo;
}
async complete(context: Context, options?: GoogleLLMOptions): Promise<AssistantMessage> {
async generate(context: Context, options?: GoogleLLMOptions): Promise<AssistantMessage> {
const output: AssistantMessage = {
role: "assistant",
content: [],
@ -223,6 +224,15 @@ export class GoogleLLM implements LLM<GoogleLLMOptions> {
}
}
// Finalize last block
if (currentBlock) {
if (currentBlock.type === "text") {
options?.onEvent?.({ type: "text_end", content: currentBlock.text });
} else {
options?.onEvent?.({ type: "thinking_end", content: currentBlock.thinking });
}
}
options?.onEvent?.({ type: "done", reason: output.stopReason, message: output });
return output;
} catch (error) {
@ -236,7 +246,10 @@ export class GoogleLLM implements LLM<GoogleLLMOptions> {
private convertMessages(messages: Message[]): Content[] {
const contents: Content[] = [];
for (const msg of messages) {
// Transform messages for cross-provider compatibility
const transformedMessages = transformMessages(messages, this.modelInfo);
for (const msg of transformedMessages) {
if (msg.role === "user") {
// Handle both string and array content
if (typeof msg.content === "string") {

View file

@ -19,8 +19,8 @@ import type {
ThinkingContent,
Tool,
ToolCall,
Usage,
} from "../types.js";
import { transformMessages } from "./utils.js";
export interface OpenAICompletionsLLMOptions extends LLMOptions {
toolChoice?: "auto" | "none" | "required" | { type: "function"; function: { name: string } };
@ -48,7 +48,22 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
return this.modelInfo;
}
async complete(request: Context, options?: OpenAICompletionsLLMOptions): Promise<AssistantMessage> {
async generate(request: Context, options?: OpenAICompletionsLLMOptions): Promise<AssistantMessage> {
const output: AssistantMessage = {
role: "assistant",
content: [],
provider: this.modelInfo.provider,
model: this.modelInfo.id,
usage: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
},
stopReason: "stop",
};
try {
const messages = this.convertMessages(request.messages, request.systemPrompt);
@ -94,19 +109,10 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
options?.onEvent?.({ type: "start", model: this.modelInfo.id, provider: this.modelInfo.provider });
const blocks: AssistantMessage["content"] = [];
let currentBlock: TextContent | ThinkingContent | (ToolCall & { partialArgs?: string }) | null = null;
let usage: Usage = {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
};
let finishReason: ChatCompletionChunk.Choice["finish_reason"] | null = null;
for await (const chunk of stream) {
if (chunk.usage) {
usage = {
output.usage = {
input: chunk.usage.prompt_tokens || 0,
output:
(chunk.usage.completion_tokens || 0) +
@ -121,11 +127,17 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
total: 0,
},
};
calculateCost(this.modelInfo, output.usage);
}
const choice = chunk.choices[0];
if (!choice) continue;
// Capture finish reason
if (choice.finish_reason) {
output.stopReason = this.mapStopReason(choice.finish_reason);
}
if (choice.delta) {
// Handle text content
if (
@ -144,10 +156,10 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
delete currentBlock.partialArgs;
options?.onEvent?.({ type: "toolCall", toolCall: currentBlock as ToolCall });
}
blocks.push(currentBlock);
}
// Start new text block
currentBlock = { type: "text", text: "" };
output.content.push(currentBlock);
options?.onEvent?.({ type: "text_start" });
}
// Append to text block
@ -178,10 +190,10 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
delete currentBlock.partialArgs;
options?.onEvent?.({ type: "toolCall", toolCall: currentBlock as ToolCall });
}
blocks.push(currentBlock);
}
// Start new thinking block
currentBlock = { type: "thinking", thinking: "", thinkingSignature: "reasoning_content" };
output.content.push(currentBlock);
options?.onEvent?.({ type: "thinking_start" });
}
// Append to thinking block
@ -209,10 +221,10 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
delete currentBlock.partialArgs;
options?.onEvent?.({ type: "toolCall", toolCall: currentBlock as ToolCall });
}
blocks.push(currentBlock);
}
// Start new thinking block
currentBlock = { type: "thinking", thinking: "", thinkingSignature: "reasoning" };
output.content.push(currentBlock);
options?.onEvent?.({ type: "thinking_start" });
}
// Append to thinking block
@ -243,7 +255,6 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
delete currentBlock.partialArgs;
options?.onEvent?.({ type: "toolCall", toolCall: currentBlock as ToolCall });
}
blocks.push(currentBlock);
}
// Start new tool call block
@ -254,6 +265,7 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
arguments: {},
partialArgs: "",
};
output.content.push(currentBlock);
}
// Accumulate tool call data
@ -267,11 +279,6 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
}
}
}
// Capture finish reason
if (choice.finish_reason) {
finishReason = choice.finish_reason;
}
}
// Save final block if exists
@ -285,39 +292,19 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
delete currentBlock.partialArgs;
options?.onEvent?.({ type: "toolCall", toolCall: currentBlock as ToolCall });
}
blocks.push(currentBlock);
}
// Calculate cost
calculateCost(this.modelInfo, usage);
if (options?.signal?.aborted) {
throw new Error("Request was aborted");
}
const output = {
role: "assistant",
content: blocks,
provider: this.modelInfo.provider,
model: this.modelInfo.id,
usage,
stopReason: this.mapStopReason(finishReason),
} satisfies AssistantMessage;
options?.onEvent?.({ type: "done", reason: output.stopReason, message: output });
return output;
} catch (error) {
const output = {
role: "assistant",
content: [],
provider: this.modelInfo.provider,
model: this.modelInfo.id,
usage: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
},
stopReason: "error",
error: error instanceof Error ? error.message : String(error),
} satisfies AssistantMessage;
options?.onEvent?.({ type: "error", error: output.error || "Unknown error" });
// Update output with error information
output.stopReason = "error";
output.error = error instanceof Error ? error.message : String(error);
options?.onEvent?.({ type: "error", error: output.error });
return output;
}
}
@ -325,6 +312,9 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
private convertMessages(messages: Message[], systemPrompt?: string): ChatCompletionMessageParam[] {
const params: ChatCompletionMessageParam[] = [];
// Transform messages for cross-provider compatibility
const transformedMessages = transformMessages(messages, this.modelInfo);
// Add system prompt if provided
if (systemPrompt) {
// Cerebras/xAi don't like the "developer" role
@ -337,7 +327,7 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
}
// Convert messages
for (const msg of messages) {
for (const msg of transformedMessages) {
if (msg.role === "user") {
// Handle both string and array content
if (typeof msg.content === "string") {

View file

@ -23,6 +23,7 @@ import type {
Tool,
ToolCall,
} from "../types.js";
import { transformMessages } from "./utils.js";
export interface OpenAIResponsesLLMOptions extends LLMOptions {
reasoningEffort?: "minimal" | "low" | "medium" | "high";
@ -50,7 +51,7 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
return this.modelInfo;
}
async complete(request: Context, options?: OpenAIResponsesLLMOptions): Promise<AssistantMessage> {
async generate(request: Context, options?: OpenAIResponsesLLMOptions): Promise<AssistantMessage> {
const output: AssistantMessage = {
role: "assistant",
content: [],
@ -132,7 +133,7 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
lastPart.text += event.delta;
options?.onEvent?.({
type: "thinking_delta",
content: currentItem.summary.join("\n\n"),
content: currentItem.summary.map((s) => s.text).join("\n\n"),
delta: event.delta,
});
}
@ -141,11 +142,16 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
// Add a new line between summary parts (hack...)
else if (event.type === "response.reasoning_summary_part.done") {
if (currentItem && currentItem.type === "reasoning") {
options?.onEvent?.({
type: "thinking_delta",
content: currentItem.summary.join("\n\n"),
delta: "\n\n",
});
currentItem.summary = currentItem.summary || [];
const lastPart = currentItem.summary[currentItem.summary.length - 1];
if (lastPart) {
lastPart.text += "\n\n";
options?.onEvent?.({
type: "thinking_delta",
content: currentItem.summary.map((s) => s.text).join("\n\n"),
delta: "\n\n",
});
}
}
}
// Handle text output deltas
@ -189,7 +195,7 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
if (item.type === "reasoning") {
outputItems[outputItems.length - 1] = item; // Update with final item
const thinkingContent = item.summary?.map((s: any) => s.text).join("\n\n") || "";
const thinkingContent = item.summary?.map((s) => s.text).join("\n\n") || "";
options?.onEvent?.({ type: "thinking_end", content: thinkingContent });
} else if (item.type === "message") {
outputItems[outputItems.length - 1] = item; // Update with final item
@ -280,6 +286,9 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
private convertToInput(messages: Message[], systemPrompt?: string): ResponseInput {
const input: ResponseInput = [];
// Transform messages for cross-provider compatibility
const transformedMessages = transformMessages(messages, this.modelInfo);
// Add system prompt if provided
if (systemPrompt) {
const role = this.modelInfo?.reasoning ? "developer" : "system";
@ -290,7 +299,7 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
}
// Convert messages
for (const msg of messages) {
for (const msg of transformedMessages) {
if (msg.role === "user") {
// Handle both string and array content
if (typeof msg.content === "string") {

View file

@ -0,0 +1,54 @@
import type { AssistantMessage, Message, Model } from "../types.js";
/**
* Transform messages for cross-provider compatibility.
*
* - User and toolResult messages are copied verbatim
* - Assistant messages:
* - If from the same provider/model, copied as-is
* - If from different provider/model, thinking blocks are converted to text blocks with <thinking></thinking> tags
*
* @param messages The messages to transform
* @param model The target model that will process these messages
* @returns A copy of the messages array with transformations applied
*/
/**
 * Rewrite a conversation so it can be replayed against a (possibly different)
 * target model.
 *
 * Rules applied per message:
 * - user / toolResult messages pass through untouched.
 * - Assistant messages produced by the exact same provider + model are kept
 *   as-is, so native thinking blocks (and their signatures) stay intact.
 * - Assistant messages from any other provider/model have each thinking block
 *   downgraded to a plain text block wrapped in <thinking></thinking> tags;
 *   text and toolCall blocks survive unchanged.
 *
 * The input array is not mutated; a new array is returned (untouched messages
 * are returned by reference).
 *
 * @param messages Conversation history to transform.
 * @param model Target model that will consume the transformed messages.
 * @returns A new array with the transformations applied.
 */
export function transformMessages(messages: Message[], model: Model): Message[] {
	return messages.map((msg) => {
		// Only assistant messages are candidates for transformation; every
		// other role (user, toolResult, or anything future) passes through.
		if (msg.role !== "assistant") {
			return msg;
		}
		const assistant = msg as AssistantMessage;
		// Same provider AND same model id: native blocks remain valid for the
		// target, so keep the message by reference.
		const sameTarget = assistant.provider === model.provider && assistant.model === model.id;
		if (sameTarget) {
			return msg;
		}
		// Cross-provider handoff: thinking blocks become tagged text so the
		// target model still sees the reasoning, just not as a native block.
		const content = assistant.content.map((block) =>
			block.type === "thinking"
				? { type: "text" as const, text: `<thinking>\n${block.thinking}\n</thinking>` }
				: block,
		);
		return { ...assistant, content };
	});
}