mirror of
https://github.com/getcompanion-ai/co-mono.git
synced 2026-04-21 15:01:26 +00:00
feat(ai): Add cross-provider message handoff support
- Add transformMessages utility to handle cross-provider compatibility
- Convert thinking blocks to <thinking> tagged text when switching providers
- Preserve native thinking blocks when staying with same provider/model
- Add comprehensive handoff tests verifying all provider combinations
- Fix OpenAI Completions to return partial results on abort
- Update tool call ID format for Anthropic compatibility
- Document cross-provider handoff capabilities in README
This commit is contained in:
parent
bf1f410c2b
commit
46b5800d36
10 changed files with 828 additions and 130 deletions
|
|
@ -18,6 +18,7 @@ import type {
|
|||
ThinkingContent,
|
||||
ToolCall,
|
||||
} from "../types.js";
|
||||
import { transformMessages } from "./utils.js";
|
||||
|
||||
export interface AnthropicLLMOptions extends LLMOptions {
|
||||
thinking?: {
|
||||
|
|
@ -61,7 +62,7 @@ export class AnthropicLLM implements LLM<AnthropicLLMOptions> {
|
|||
return this.modelInfo;
|
||||
}
|
||||
|
||||
async complete(context: Context, options?: AnthropicLLMOptions): Promise<AssistantMessage> {
|
||||
async generate(context: Context, options?: AnthropicLLMOptions): Promise<AssistantMessage> {
|
||||
const output: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
|
|
@ -243,7 +244,10 @@ export class AnthropicLLM implements LLM<AnthropicLLMOptions> {
|
|||
private convertMessages(messages: Message[]): MessageParam[] {
|
||||
const params: MessageParam[] = [];
|
||||
|
||||
for (const msg of messages) {
|
||||
// Transform messages for cross-provider compatibility
|
||||
const transformedMessages = transformMessages(messages, this.modelInfo);
|
||||
|
||||
for (const msg of transformedMessages) {
|
||||
if (msg.role === "user") {
|
||||
// Handle both string and array content
|
||||
if (typeof msg.content === "string") {
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ import type {
|
|||
Tool,
|
||||
ToolCall,
|
||||
} from "../types.js";
|
||||
import { transformMessages } from "./utils.js";
|
||||
|
||||
export interface GoogleLLMOptions extends LLMOptions {
|
||||
toolChoice?: "auto" | "none" | "any";
|
||||
|
|
@ -51,7 +52,7 @@ export class GoogleLLM implements LLM<GoogleLLMOptions> {
|
|||
return this.modelInfo;
|
||||
}
|
||||
|
||||
async complete(context: Context, options?: GoogleLLMOptions): Promise<AssistantMessage> {
|
||||
async generate(context: Context, options?: GoogleLLMOptions): Promise<AssistantMessage> {
|
||||
const output: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
|
|
@ -223,6 +224,15 @@ export class GoogleLLM implements LLM<GoogleLLMOptions> {
|
|||
}
|
||||
}
|
||||
|
||||
// Finalize last block
|
||||
if (currentBlock) {
|
||||
if (currentBlock.type === "text") {
|
||||
options?.onEvent?.({ type: "text_end", content: currentBlock.text });
|
||||
} else {
|
||||
options?.onEvent?.({ type: "thinking_end", content: currentBlock.thinking });
|
||||
}
|
||||
}
|
||||
|
||||
options?.onEvent?.({ type: "done", reason: output.stopReason, message: output });
|
||||
return output;
|
||||
} catch (error) {
|
||||
|
|
@ -236,7 +246,10 @@ export class GoogleLLM implements LLM<GoogleLLMOptions> {
|
|||
private convertMessages(messages: Message[]): Content[] {
|
||||
const contents: Content[] = [];
|
||||
|
||||
for (const msg of messages) {
|
||||
// Transform messages for cross-provider compatibility
|
||||
const transformedMessages = transformMessages(messages, this.modelInfo);
|
||||
|
||||
for (const msg of transformedMessages) {
|
||||
if (msg.role === "user") {
|
||||
// Handle both string and array content
|
||||
if (typeof msg.content === "string") {
|
||||
|
|
|
|||
|
|
@ -19,8 +19,8 @@ import type {
|
|||
ThinkingContent,
|
||||
Tool,
|
||||
ToolCall,
|
||||
Usage,
|
||||
} from "../types.js";
|
||||
import { transformMessages } from "./utils.js";
|
||||
|
||||
export interface OpenAICompletionsLLMOptions extends LLMOptions {
|
||||
toolChoice?: "auto" | "none" | "required" | { type: "function"; function: { name: string } };
|
||||
|
|
@ -48,7 +48,22 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
|
|||
return this.modelInfo;
|
||||
}
|
||||
|
||||
async complete(request: Context, options?: OpenAICompletionsLLMOptions): Promise<AssistantMessage> {
|
||||
async generate(request: Context, options?: OpenAICompletionsLLMOptions): Promise<AssistantMessage> {
|
||||
const output: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
provider: this.modelInfo.provider,
|
||||
model: this.modelInfo.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "stop",
|
||||
};
|
||||
|
||||
try {
|
||||
const messages = this.convertMessages(request.messages, request.systemPrompt);
|
||||
|
||||
|
|
@ -94,19 +109,10 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
|
|||
|
||||
options?.onEvent?.({ type: "start", model: this.modelInfo.id, provider: this.modelInfo.provider });
|
||||
|
||||
const blocks: AssistantMessage["content"] = [];
|
||||
let currentBlock: TextContent | ThinkingContent | (ToolCall & { partialArgs?: string }) | null = null;
|
||||
let usage: Usage = {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
};
|
||||
let finishReason: ChatCompletionChunk.Choice["finish_reason"] | null = null;
|
||||
for await (const chunk of stream) {
|
||||
if (chunk.usage) {
|
||||
usage = {
|
||||
output.usage = {
|
||||
input: chunk.usage.prompt_tokens || 0,
|
||||
output:
|
||||
(chunk.usage.completion_tokens || 0) +
|
||||
|
|
@ -121,11 +127,17 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
|
|||
total: 0,
|
||||
},
|
||||
};
|
||||
calculateCost(this.modelInfo, output.usage);
|
||||
}
|
||||
|
||||
const choice = chunk.choices[0];
|
||||
if (!choice) continue;
|
||||
|
||||
// Capture finish reason
|
||||
if (choice.finish_reason) {
|
||||
output.stopReason = this.mapStopReason(choice.finish_reason);
|
||||
}
|
||||
|
||||
if (choice.delta) {
|
||||
// Handle text content
|
||||
if (
|
||||
|
|
@ -144,10 +156,10 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
|
|||
delete currentBlock.partialArgs;
|
||||
options?.onEvent?.({ type: "toolCall", toolCall: currentBlock as ToolCall });
|
||||
}
|
||||
blocks.push(currentBlock);
|
||||
}
|
||||
// Start new text block
|
||||
currentBlock = { type: "text", text: "" };
|
||||
output.content.push(currentBlock);
|
||||
options?.onEvent?.({ type: "text_start" });
|
||||
}
|
||||
// Append to text block
|
||||
|
|
@ -178,10 +190,10 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
|
|||
delete currentBlock.partialArgs;
|
||||
options?.onEvent?.({ type: "toolCall", toolCall: currentBlock as ToolCall });
|
||||
}
|
||||
blocks.push(currentBlock);
|
||||
}
|
||||
// Start new thinking block
|
||||
currentBlock = { type: "thinking", thinking: "", thinkingSignature: "reasoning_content" };
|
||||
output.content.push(currentBlock);
|
||||
options?.onEvent?.({ type: "thinking_start" });
|
||||
}
|
||||
// Append to thinking block
|
||||
|
|
@ -209,10 +221,10 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
|
|||
delete currentBlock.partialArgs;
|
||||
options?.onEvent?.({ type: "toolCall", toolCall: currentBlock as ToolCall });
|
||||
}
|
||||
blocks.push(currentBlock);
|
||||
}
|
||||
// Start new thinking block
|
||||
currentBlock = { type: "thinking", thinking: "", thinkingSignature: "reasoning" };
|
||||
output.content.push(currentBlock);
|
||||
options?.onEvent?.({ type: "thinking_start" });
|
||||
}
|
||||
// Append to thinking block
|
||||
|
|
@ -243,7 +255,6 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
|
|||
delete currentBlock.partialArgs;
|
||||
options?.onEvent?.({ type: "toolCall", toolCall: currentBlock as ToolCall });
|
||||
}
|
||||
blocks.push(currentBlock);
|
||||
}
|
||||
|
||||
// Start new tool call block
|
||||
|
|
@ -254,6 +265,7 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
|
|||
arguments: {},
|
||||
partialArgs: "",
|
||||
};
|
||||
output.content.push(currentBlock);
|
||||
}
|
||||
|
||||
// Accumulate tool call data
|
||||
|
|
@ -267,11 +279,6 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Capture finish reason
|
||||
if (choice.finish_reason) {
|
||||
finishReason = choice.finish_reason;
|
||||
}
|
||||
}
|
||||
|
||||
// Save final block if exists
|
||||
|
|
@ -285,39 +292,19 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
|
|||
delete currentBlock.partialArgs;
|
||||
options?.onEvent?.({ type: "toolCall", toolCall: currentBlock as ToolCall });
|
||||
}
|
||||
blocks.push(currentBlock);
|
||||
}
|
||||
|
||||
// Calculate cost
|
||||
calculateCost(this.modelInfo, usage);
|
||||
if (options?.signal?.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
|
||||
const output = {
|
||||
role: "assistant",
|
||||
content: blocks,
|
||||
provider: this.modelInfo.provider,
|
||||
model: this.modelInfo.id,
|
||||
usage,
|
||||
stopReason: this.mapStopReason(finishReason),
|
||||
} satisfies AssistantMessage;
|
||||
options?.onEvent?.({ type: "done", reason: output.stopReason, message: output });
|
||||
return output;
|
||||
} catch (error) {
|
||||
const output = {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
provider: this.modelInfo.provider,
|
||||
model: this.modelInfo.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "error",
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
} satisfies AssistantMessage;
|
||||
options?.onEvent?.({ type: "error", error: output.error || "Unknown error" });
|
||||
// Update output with error information
|
||||
output.stopReason = "error";
|
||||
output.error = error instanceof Error ? error.message : String(error);
|
||||
options?.onEvent?.({ type: "error", error: output.error });
|
||||
return output;
|
||||
}
|
||||
}
|
||||
|
|
@ -325,6 +312,9 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
|
|||
private convertMessages(messages: Message[], systemPrompt?: string): ChatCompletionMessageParam[] {
|
||||
const params: ChatCompletionMessageParam[] = [];
|
||||
|
||||
// Transform messages for cross-provider compatibility
|
||||
const transformedMessages = transformMessages(messages, this.modelInfo);
|
||||
|
||||
// Add system prompt if provided
|
||||
if (systemPrompt) {
|
||||
// Cerebras/xAi don't like the "developer" role
|
||||
|
|
@ -337,7 +327,7 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
|
|||
}
|
||||
|
||||
// Convert messages
|
||||
for (const msg of messages) {
|
||||
for (const msg of transformedMessages) {
|
||||
if (msg.role === "user") {
|
||||
// Handle both string and array content
|
||||
if (typeof msg.content === "string") {
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ import type {
|
|||
Tool,
|
||||
ToolCall,
|
||||
} from "../types.js";
|
||||
import { transformMessages } from "./utils.js";
|
||||
|
||||
export interface OpenAIResponsesLLMOptions extends LLMOptions {
|
||||
reasoningEffort?: "minimal" | "low" | "medium" | "high";
|
||||
|
|
@ -50,7 +51,7 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
|
|||
return this.modelInfo;
|
||||
}
|
||||
|
||||
async complete(request: Context, options?: OpenAIResponsesLLMOptions): Promise<AssistantMessage> {
|
||||
async generate(request: Context, options?: OpenAIResponsesLLMOptions): Promise<AssistantMessage> {
|
||||
const output: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
|
|
@ -132,7 +133,7 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
|
|||
lastPart.text += event.delta;
|
||||
options?.onEvent?.({
|
||||
type: "thinking_delta",
|
||||
content: currentItem.summary.join("\n\n"),
|
||||
content: currentItem.summary.map((s) => s.text).join("\n\n"),
|
||||
delta: event.delta,
|
||||
});
|
||||
}
|
||||
|
|
@ -141,11 +142,16 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
|
|||
// Add a new line between summary parts (hack...)
|
||||
else if (event.type === "response.reasoning_summary_part.done") {
|
||||
if (currentItem && currentItem.type === "reasoning") {
|
||||
options?.onEvent?.({
|
||||
type: "thinking_delta",
|
||||
content: currentItem.summary.join("\n\n"),
|
||||
delta: "\n\n",
|
||||
});
|
||||
currentItem.summary = currentItem.summary || [];
|
||||
const lastPart = currentItem.summary[currentItem.summary.length - 1];
|
||||
if (lastPart) {
|
||||
lastPart.text += "\n\n";
|
||||
options?.onEvent?.({
|
||||
type: "thinking_delta",
|
||||
content: currentItem.summary.map((s) => s.text).join("\n\n"),
|
||||
delta: "\n\n",
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
// Handle text output deltas
|
||||
|
|
@ -189,7 +195,7 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
|
|||
|
||||
if (item.type === "reasoning") {
|
||||
outputItems[outputItems.length - 1] = item; // Update with final item
|
||||
const thinkingContent = item.summary?.map((s: any) => s.text).join("\n\n") || "";
|
||||
const thinkingContent = item.summary?.map((s) => s.text).join("\n\n") || "";
|
||||
options?.onEvent?.({ type: "thinking_end", content: thinkingContent });
|
||||
} else if (item.type === "message") {
|
||||
outputItems[outputItems.length - 1] = item; // Update with final item
|
||||
|
|
@ -280,6 +286,9 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
|
|||
private convertToInput(messages: Message[], systemPrompt?: string): ResponseInput {
|
||||
const input: ResponseInput = [];
|
||||
|
||||
// Transform messages for cross-provider compatibility
|
||||
const transformedMessages = transformMessages(messages, this.modelInfo);
|
||||
|
||||
// Add system prompt if provided
|
||||
if (systemPrompt) {
|
||||
const role = this.modelInfo?.reasoning ? "developer" : "system";
|
||||
|
|
@ -290,7 +299,7 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
|
|||
}
|
||||
|
||||
// Convert messages
|
||||
for (const msg of messages) {
|
||||
for (const msg of transformedMessages) {
|
||||
if (msg.role === "user") {
|
||||
// Handle both string and array content
|
||||
if (typeof msg.content === "string") {
|
||||
|
|
|
|||
54
packages/ai/src/providers/utils.ts
Normal file
54
packages/ai/src/providers/utils.ts
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
import type { AssistantMessage, Message, Model } from "../types.js";
|
||||
|
||||
/**
|
||||
* Transform messages for cross-provider compatibility.
|
||||
*
|
||||
* - User and toolResult messages are copied verbatim
|
||||
* - Assistant messages:
|
||||
* - If from the same provider/model, copied as-is
|
||||
* - If from different provider/model, thinking blocks are converted to text blocks with <thinking></thinking> tags
|
||||
*
|
||||
* @param messages The messages to transform
|
||||
* @param model The target model that will process these messages
|
||||
* @returns A copy of the messages array with transformations applied
|
||||
*/
|
||||
export function transformMessages(messages: Message[], model: Model): Message[] {
|
||||
return messages.map((msg) => {
|
||||
// User and toolResult messages pass through unchanged
|
||||
if (msg.role === "user" || msg.role === "toolResult") {
|
||||
return msg;
|
||||
}
|
||||
|
||||
// Assistant messages need transformation check
|
||||
if (msg.role === "assistant") {
|
||||
const assistantMsg = msg as AssistantMessage;
|
||||
|
||||
// If message is from the same provider and model, keep as-is
|
||||
if (assistantMsg.provider === model.provider && assistantMsg.model === model.id) {
|
||||
return msg;
|
||||
}
|
||||
|
||||
// Transform message from different provider/model
|
||||
const transformedContent = assistantMsg.content.map((block) => {
|
||||
if (block.type === "thinking") {
|
||||
// Convert thinking block to text block with <thinking> tags
|
||||
return {
|
||||
type: "text" as const,
|
||||
text: `<thinking>\n${block.thinking}\n</thinking>`,
|
||||
};
|
||||
}
|
||||
// All other blocks (text, toolCall) pass through unchanged
|
||||
return block;
|
||||
});
|
||||
|
||||
// Return transformed assistant message
|
||||
return {
|
||||
...assistantMsg,
|
||||
content: transformedContent,
|
||||
};
|
||||
}
|
||||
|
||||
// Should not reach here, but return as-is for safety
|
||||
return msg;
|
||||
});
|
||||
}
|
||||
|
|
@ -6,7 +6,7 @@ export interface LLMOptions {
|
|||
}
|
||||
|
||||
export interface LLM<T extends LLMOptions> {
|
||||
complete(request: Context, options?: T): Promise<AssistantMessage>;
|
||||
generate(request: Context, options?: T): Promise<AssistantMessage>;
|
||||
getModel(): Model;
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue