mirror of
https://github.com/getcompanion-ai/co-mono.git
synced 2026-04-21 07:02:04 +00:00
fix(ai): Use API type instead of model for message compatibility checks
- Add getApi() method to all providers to identify the API type - Add api field to AssistantMessage to track which API generated it - Update transformMessages to check API compatibility instead of model - Fixes issue where OpenAI Responses API failed when switching models - Preserves thinking blocks and signatures when staying within same API
This commit is contained in:
parent
3007b7a5ac
commit
2cfd8ff3c3
6 changed files with 46 additions and 11 deletions
|
|
@ -77,10 +77,15 @@ export class AnthropicLLM implements LLM<AnthropicLLMOptions> {
|
||||||
return this.modelInfo;
|
return this.modelInfo;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
getApi(): string {
|
||||||
|
return "anthropic-messages";
|
||||||
|
}
|
||||||
|
|
||||||
async generate(context: Context, options?: AnthropicLLMOptions): Promise<AssistantMessage> {
|
async generate(context: Context, options?: AnthropicLLMOptions): Promise<AssistantMessage> {
|
||||||
const output: AssistantMessage = {
|
const output: AssistantMessage = {
|
||||||
role: "assistant",
|
role: "assistant",
|
||||||
content: [],
|
content: [],
|
||||||
|
api: this.getApi(),
|
||||||
provider: this.modelInfo.provider,
|
provider: this.modelInfo.provider,
|
||||||
model: this.modelInfo.id,
|
model: this.modelInfo.id,
|
||||||
usage: {
|
usage: {
|
||||||
|
|
@ -260,7 +265,7 @@ export class AnthropicLLM implements LLM<AnthropicLLMOptions> {
|
||||||
const params: MessageParam[] = [];
|
const params: MessageParam[] = [];
|
||||||
|
|
||||||
// Transform messages for cross-provider compatibility
|
// Transform messages for cross-provider compatibility
|
||||||
const transformedMessages = transformMessages(messages, this.modelInfo);
|
const transformedMessages = transformMessages(messages, this.modelInfo, this.getApi());
|
||||||
|
|
||||||
for (const msg of transformedMessages) {
|
for (const msg of transformedMessages) {
|
||||||
if (msg.role === "user") {
|
if (msg.role === "user") {
|
||||||
|
|
@ -290,9 +295,12 @@ export class AnthropicLLM implements LLM<AnthropicLLMOptions> {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
const filteredBlocks = !this.modelInfo?.input.includes("image")
|
||||||
|
? blocks.filter((b) => b.type !== "image")
|
||||||
|
: blocks;
|
||||||
params.push({
|
params.push({
|
||||||
role: "user",
|
role: "user",
|
||||||
content: blocks,
|
content: filteredBlocks,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
} else if (msg.role === "assistant") {
|
} else if (msg.role === "assistant") {
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ import {
|
||||||
type GenerateContentParameters,
|
type GenerateContentParameters,
|
||||||
GoogleGenAI,
|
GoogleGenAI,
|
||||||
type Part,
|
type Part,
|
||||||
|
setDefaultBaseUrls,
|
||||||
} from "@google/genai";
|
} from "@google/genai";
|
||||||
import { calculateCost } from "../models.js";
|
import { calculateCost } from "../models.js";
|
||||||
import type {
|
import type {
|
||||||
|
|
@ -52,10 +53,15 @@ export class GoogleLLM implements LLM<GoogleLLMOptions> {
|
||||||
return this.modelInfo;
|
return this.modelInfo;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
getApi(): string {
|
||||||
|
return "google-generative-ai";
|
||||||
|
}
|
||||||
|
|
||||||
async generate(context: Context, options?: GoogleLLMOptions): Promise<AssistantMessage> {
|
async generate(context: Context, options?: GoogleLLMOptions): Promise<AssistantMessage> {
|
||||||
const output: AssistantMessage = {
|
const output: AssistantMessage = {
|
||||||
role: "assistant",
|
role: "assistant",
|
||||||
content: [],
|
content: [],
|
||||||
|
api: this.getApi(),
|
||||||
provider: this.modelInfo.provider,
|
provider: this.modelInfo.provider,
|
||||||
model: this.modelInfo.id,
|
model: this.modelInfo.id,
|
||||||
usage: {
|
usage: {
|
||||||
|
|
@ -247,7 +253,7 @@ export class GoogleLLM implements LLM<GoogleLLMOptions> {
|
||||||
const contents: Content[] = [];
|
const contents: Content[] = [];
|
||||||
|
|
||||||
// Transform messages for cross-provider compatibility
|
// Transform messages for cross-provider compatibility
|
||||||
const transformedMessages = transformMessages(messages, this.modelInfo);
|
const transformedMessages = transformMessages(messages, this.modelInfo, this.getApi());
|
||||||
|
|
||||||
for (const msg of transformedMessages) {
|
for (const msg of transformedMessages) {
|
||||||
if (msg.role === "user") {
|
if (msg.role === "user") {
|
||||||
|
|
@ -272,9 +278,12 @@ export class GoogleLLM implements LLM<GoogleLLMOptions> {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
const filteredParts = !this.modelInfo?.input.includes("image")
|
||||||
|
? parts.filter((p) => p.text !== undefined)
|
||||||
|
: parts;
|
||||||
contents.push({
|
contents.push({
|
||||||
role: "user",
|
role: "user",
|
||||||
parts,
|
parts: filteredParts,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
} else if (msg.role === "assistant") {
|
} else if (msg.role === "assistant") {
|
||||||
|
|
|
||||||
|
|
@ -48,10 +48,15 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
|
||||||
return this.modelInfo;
|
return this.modelInfo;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
getApi(): string {
|
||||||
|
return "openai-completions";
|
||||||
|
}
|
||||||
|
|
||||||
async generate(request: Context, options?: OpenAICompletionsLLMOptions): Promise<AssistantMessage> {
|
async generate(request: Context, options?: OpenAICompletionsLLMOptions): Promise<AssistantMessage> {
|
||||||
const output: AssistantMessage = {
|
const output: AssistantMessage = {
|
||||||
role: "assistant",
|
role: "assistant",
|
||||||
content: [],
|
content: [],
|
||||||
|
api: this.getApi(),
|
||||||
provider: this.modelInfo.provider,
|
provider: this.modelInfo.provider,
|
||||||
model: this.modelInfo.id,
|
model: this.modelInfo.id,
|
||||||
usage: {
|
usage: {
|
||||||
|
|
@ -313,7 +318,7 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
|
||||||
const params: ChatCompletionMessageParam[] = [];
|
const params: ChatCompletionMessageParam[] = [];
|
||||||
|
|
||||||
// Transform messages for cross-provider compatibility
|
// Transform messages for cross-provider compatibility
|
||||||
const transformedMessages = transformMessages(messages, this.modelInfo);
|
const transformedMessages = transformMessages(messages, this.modelInfo, this.getApi());
|
||||||
|
|
||||||
// Add system prompt if provided
|
// Add system prompt if provided
|
||||||
if (systemPrompt) {
|
if (systemPrompt) {
|
||||||
|
|
@ -353,9 +358,12 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
|
||||||
} satisfies ChatCompletionContentPartImage;
|
} satisfies ChatCompletionContentPartImage;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
const filteredContent = !this.modelInfo?.input.includes("image")
|
||||||
|
? content.filter((c) => c.type !== "image_url")
|
||||||
|
: content;
|
||||||
params.push({
|
params.push({
|
||||||
role: "user",
|
role: "user",
|
||||||
content,
|
content: filteredContent,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
} else if (msg.role === "assistant") {
|
} else if (msg.role === "assistant") {
|
||||||
|
|
|
||||||
|
|
@ -51,10 +51,15 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
|
||||||
return this.modelInfo;
|
return this.modelInfo;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
getApi(): string {
|
||||||
|
return "openai-responses";
|
||||||
|
}
|
||||||
|
|
||||||
async generate(request: Context, options?: OpenAIResponsesLLMOptions): Promise<AssistantMessage> {
|
async generate(request: Context, options?: OpenAIResponsesLLMOptions): Promise<AssistantMessage> {
|
||||||
const output: AssistantMessage = {
|
const output: AssistantMessage = {
|
||||||
role: "assistant",
|
role: "assistant",
|
||||||
content: [],
|
content: [],
|
||||||
|
api: this.getApi(),
|
||||||
provider: this.modelInfo.provider,
|
provider: this.modelInfo.provider,
|
||||||
model: this.modelInfo.id,
|
model: this.modelInfo.id,
|
||||||
usage: {
|
usage: {
|
||||||
|
|
@ -287,7 +292,7 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
|
||||||
const input: ResponseInput = [];
|
const input: ResponseInput = [];
|
||||||
|
|
||||||
// Transform messages for cross-provider compatibility
|
// Transform messages for cross-provider compatibility
|
||||||
const transformedMessages = transformMessages(messages, this.modelInfo);
|
const transformedMessages = transformMessages(messages, this.modelInfo, this.getApi());
|
||||||
|
|
||||||
// Add system prompt if provided
|
// Add system prompt if provided
|
||||||
if (systemPrompt) {
|
if (systemPrompt) {
|
||||||
|
|
@ -324,9 +329,12 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
|
||||||
} satisfies ResponseInputImage;
|
} satisfies ResponseInputImage;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
const filteredContent = !this.modelInfo?.input.includes("image")
|
||||||
|
? content.filter((c) => c.type !== "input_image")
|
||||||
|
: content;
|
||||||
input.push({
|
input.push({
|
||||||
role: "user",
|
role: "user",
|
||||||
content,
|
content: filteredContent,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
} else if (msg.role === "assistant") {
|
} else if (msg.role === "assistant") {
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ import type { AssistantMessage, Message, Model } from "../types.js";
|
||||||
* @param model The target model that will process these messages
|
* @param model The target model that will process these messages
|
||||||
* @returns A copy of the messages array with transformations applied
|
* @returns A copy of the messages array with transformations applied
|
||||||
*/
|
*/
|
||||||
export function transformMessages(messages: Message[], model: Model): Message[] {
|
export function transformMessages(messages: Message[], model: Model, api: string): Message[] {
|
||||||
return messages.map((msg) => {
|
return messages.map((msg) => {
|
||||||
// User and toolResult messages pass through unchanged
|
// User and toolResult messages pass through unchanged
|
||||||
if (msg.role === "user" || msg.role === "toolResult") {
|
if (msg.role === "user" || msg.role === "toolResult") {
|
||||||
|
|
@ -23,8 +23,8 @@ export function transformMessages(messages: Message[], model: Model): Message[]
|
||||||
if (msg.role === "assistant") {
|
if (msg.role === "assistant") {
|
||||||
const assistantMsg = msg as AssistantMessage;
|
const assistantMsg = msg as AssistantMessage;
|
||||||
|
|
||||||
// If message is from the same provider and model, keep as-is
|
// If message is from the same provider and API, keep as is
|
||||||
if (assistantMsg.provider === model.provider && assistantMsg.model === model.id) {
|
if (assistantMsg.provider === model.provider && assistantMsg.api === api) {
|
||||||
return msg;
|
return msg;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ export interface LLMOptions {
|
||||||
export interface LLM<T extends LLMOptions> {
|
export interface LLM<T extends LLMOptions> {
|
||||||
generate(request: Context, options?: T): Promise<AssistantMessage>;
|
generate(request: Context, options?: T): Promise<AssistantMessage>;
|
||||||
getModel(): Model;
|
getModel(): Model;
|
||||||
|
getApi(): string;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface TextContent {
|
export interface TextContent {
|
||||||
|
|
@ -59,6 +60,7 @@ export interface UserMessage {
|
||||||
export interface AssistantMessage {
|
export interface AssistantMessage {
|
||||||
role: "assistant";
|
role: "assistant";
|
||||||
content: (TextContent | ThinkingContent | ToolCall)[];
|
content: (TextContent | ThinkingContent | ToolCall)[];
|
||||||
|
api: string;
|
||||||
provider: string;
|
provider: string;
|
||||||
model: string;
|
model: string;
|
||||||
usage: Usage;
|
usage: Usage;
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue