fix(ai): Use API type instead of model for message compatibility checks

- Add getApi() method to all providers to identify the API type
- Add api field to AssistantMessage to track which API generated it
- Update transformMessages to check API compatibility instead of model
- Fixes issue where OpenAI Responses API failed when switching models
- Preserves thinking blocks and signatures when staying within same API
This commit is contained in:
Mario Zechner 2025-09-02 00:20:06 +02:00
parent 3007b7a5ac
commit 2cfd8ff3c3
6 changed files with 46 additions and 11 deletions

View file

@ -77,10 +77,15 @@ export class AnthropicLLM implements LLM<AnthropicLLMOptions> {
return this.modelInfo; return this.modelInfo;
} }
getApi(): string {
return "anthropic-messages";
}
async generate(context: Context, options?: AnthropicLLMOptions): Promise<AssistantMessage> { async generate(context: Context, options?: AnthropicLLMOptions): Promise<AssistantMessage> {
const output: AssistantMessage = { const output: AssistantMessage = {
role: "assistant", role: "assistant",
content: [], content: [],
api: this.getApi(),
provider: this.modelInfo.provider, provider: this.modelInfo.provider,
model: this.modelInfo.id, model: this.modelInfo.id,
usage: { usage: {
@ -260,7 +265,7 @@ export class AnthropicLLM implements LLM<AnthropicLLMOptions> {
const params: MessageParam[] = []; const params: MessageParam[] = [];
// Transform messages for cross-provider compatibility // Transform messages for cross-provider compatibility
const transformedMessages = transformMessages(messages, this.modelInfo); const transformedMessages = transformMessages(messages, this.modelInfo, this.getApi());
for (const msg of transformedMessages) { for (const msg of transformedMessages) {
if (msg.role === "user") { if (msg.role === "user") {
@ -290,9 +295,12 @@ export class AnthropicLLM implements LLM<AnthropicLLMOptions> {
}; };
} }
}); });
const filteredBlocks = !this.modelInfo?.input.includes("image")
? blocks.filter((b) => b.type !== "image")
: blocks;
params.push({ params.push({
role: "user", role: "user",
content: blocks, content: filteredBlocks,
}); });
} }
} else if (msg.role === "assistant") { } else if (msg.role === "assistant") {

View file

@ -6,6 +6,7 @@ import {
type GenerateContentParameters, type GenerateContentParameters,
GoogleGenAI, GoogleGenAI,
type Part, type Part,
setDefaultBaseUrls,
} from "@google/genai"; } from "@google/genai";
import { calculateCost } from "../models.js"; import { calculateCost } from "../models.js";
import type { import type {
@ -52,10 +53,15 @@ export class GoogleLLM implements LLM<GoogleLLMOptions> {
return this.modelInfo; return this.modelInfo;
} }
getApi(): string {
return "google-generative-ai";
}
async generate(context: Context, options?: GoogleLLMOptions): Promise<AssistantMessage> { async generate(context: Context, options?: GoogleLLMOptions): Promise<AssistantMessage> {
const output: AssistantMessage = { const output: AssistantMessage = {
role: "assistant", role: "assistant",
content: [], content: [],
api: this.getApi(),
provider: this.modelInfo.provider, provider: this.modelInfo.provider,
model: this.modelInfo.id, model: this.modelInfo.id,
usage: { usage: {
@ -247,7 +253,7 @@ export class GoogleLLM implements LLM<GoogleLLMOptions> {
const contents: Content[] = []; const contents: Content[] = [];
// Transform messages for cross-provider compatibility // Transform messages for cross-provider compatibility
const transformedMessages = transformMessages(messages, this.modelInfo); const transformedMessages = transformMessages(messages, this.modelInfo, this.getApi());
for (const msg of transformedMessages) { for (const msg of transformedMessages) {
if (msg.role === "user") { if (msg.role === "user") {
@ -272,9 +278,12 @@ export class GoogleLLM implements LLM<GoogleLLMOptions> {
}; };
} }
}); });
const filteredParts = !this.modelInfo?.input.includes("image")
? parts.filter((p) => p.text !== undefined)
: parts;
contents.push({ contents.push({
role: "user", role: "user",
parts, parts: filteredParts,
}); });
} }
} else if (msg.role === "assistant") { } else if (msg.role === "assistant") {

View file

@ -48,10 +48,15 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
return this.modelInfo; return this.modelInfo;
} }
getApi(): string {
return "openai-completions";
}
async generate(request: Context, options?: OpenAICompletionsLLMOptions): Promise<AssistantMessage> { async generate(request: Context, options?: OpenAICompletionsLLMOptions): Promise<AssistantMessage> {
const output: AssistantMessage = { const output: AssistantMessage = {
role: "assistant", role: "assistant",
content: [], content: [],
api: this.getApi(),
provider: this.modelInfo.provider, provider: this.modelInfo.provider,
model: this.modelInfo.id, model: this.modelInfo.id,
usage: { usage: {
@ -313,7 +318,7 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
const params: ChatCompletionMessageParam[] = []; const params: ChatCompletionMessageParam[] = [];
// Transform messages for cross-provider compatibility // Transform messages for cross-provider compatibility
const transformedMessages = transformMessages(messages, this.modelInfo); const transformedMessages = transformMessages(messages, this.modelInfo, this.getApi());
// Add system prompt if provided // Add system prompt if provided
if (systemPrompt) { if (systemPrompt) {
@ -353,9 +358,12 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
} satisfies ChatCompletionContentPartImage; } satisfies ChatCompletionContentPartImage;
} }
}); });
const filteredContent = !this.modelInfo?.input.includes("image")
? content.filter((c) => c.type !== "image_url")
: content;
params.push({ params.push({
role: "user", role: "user",
content, content: filteredContent,
}); });
} }
} else if (msg.role === "assistant") { } else if (msg.role === "assistant") {

View file

@ -51,10 +51,15 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
return this.modelInfo; return this.modelInfo;
} }
getApi(): string {
return "openai-responses";
}
async generate(request: Context, options?: OpenAIResponsesLLMOptions): Promise<AssistantMessage> { async generate(request: Context, options?: OpenAIResponsesLLMOptions): Promise<AssistantMessage> {
const output: AssistantMessage = { const output: AssistantMessage = {
role: "assistant", role: "assistant",
content: [], content: [],
api: this.getApi(),
provider: this.modelInfo.provider, provider: this.modelInfo.provider,
model: this.modelInfo.id, model: this.modelInfo.id,
usage: { usage: {
@ -287,7 +292,7 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
const input: ResponseInput = []; const input: ResponseInput = [];
// Transform messages for cross-provider compatibility // Transform messages for cross-provider compatibility
const transformedMessages = transformMessages(messages, this.modelInfo); const transformedMessages = transformMessages(messages, this.modelInfo, this.getApi());
// Add system prompt if provided // Add system prompt if provided
if (systemPrompt) { if (systemPrompt) {
@ -324,9 +329,12 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
} satisfies ResponseInputImage; } satisfies ResponseInputImage;
} }
}); });
const filteredContent = !this.modelInfo?.input.includes("image")
? content.filter((c) => c.type !== "input_image")
: content;
input.push({ input.push({
role: "user", role: "user",
content, content: filteredContent,
}); });
} }
} else if (msg.role === "assistant") { } else if (msg.role === "assistant") {

View file

@ -12,7 +12,7 @@ import type { AssistantMessage, Message, Model } from "../types.js";
* @param model The target model that will process these messages * @param model The target model that will process these messages
* @returns A copy of the messages array with transformations applied * @returns A copy of the messages array with transformations applied
*/ */
export function transformMessages(messages: Message[], model: Model): Message[] { export function transformMessages(messages: Message[], model: Model, api: string): Message[] {
return messages.map((msg) => { return messages.map((msg) => {
// User and toolResult messages pass through unchanged // User and toolResult messages pass through unchanged
if (msg.role === "user" || msg.role === "toolResult") { if (msg.role === "user" || msg.role === "toolResult") {
@ -23,8 +23,8 @@ export function transformMessages(messages: Message[], model: Model): Message[]
if (msg.role === "assistant") { if (msg.role === "assistant") {
const assistantMsg = msg as AssistantMessage; const assistantMsg = msg as AssistantMessage;
// If message is from the same provider and model, keep as-is // If message is from the same provider and API, keep as is
if (assistantMsg.provider === model.provider && assistantMsg.model === model.id) { if (assistantMsg.provider === model.provider && assistantMsg.api === api) {
return msg; return msg;
} }

View file

@ -8,6 +8,7 @@ export interface LLMOptions {
export interface LLM<T extends LLMOptions> { export interface LLM<T extends LLMOptions> {
generate(request: Context, options?: T): Promise<AssistantMessage>; generate(request: Context, options?: T): Promise<AssistantMessage>;
getModel(): Model; getModel(): Model;
getApi(): string;
} }
export interface TextContent { export interface TextContent {
@ -59,6 +60,7 @@ export interface UserMessage {
export interface AssistantMessage { export interface AssistantMessage {
role: "assistant"; role: "assistant";
content: (TextContent | ThinkingContent | ToolCall)[]; content: (TextContent | ThinkingContent | ToolCall)[];
api: string;
provider: string; provider: string;
model: string; model: string;
usage: Usage; usage: Usage;