mirror of
https://github.com/getcompanion-ai/co-mono.git
synced 2026-04-21 02:04:32 +00:00
refactor(ai): Update LLM implementations to use Model objects
- LLM constructors now take Model objects instead of string IDs
- Added provider field to AssistantMessage interface
- Updated getModel function with type-safe model ID autocomplete
- Fixed Anthropic model ID mapping for proper API aliases
- Added baseUrl to Model interface for provider-specific endpoints
- Updated all tests to use getModel for model instantiation
- Removed deprecated models.json in favor of generated models
This commit is contained in:
parent
d61d09b88d
commit
f9d688d577
11 changed files with 334 additions and 8447 deletions
|
|
@ -11,6 +11,7 @@ import type {
|
|||
LLM,
|
||||
LLMOptions,
|
||||
Message,
|
||||
Model,
|
||||
StopReason,
|
||||
TokenUsage,
|
||||
ToolCall,
|
||||
|
|
@ -26,10 +27,10 @@ export interface AnthropicLLMOptions extends LLMOptions {
|
|||
|
||||
export class AnthropicLLM implements LLM<AnthropicLLMOptions> {
|
||||
private client: Anthropic;
|
||||
private model: string;
|
||||
private modelInfo: Model;
|
||||
private isOAuthToken: boolean = false;
|
||||
|
||||
constructor(model: string, apiKey?: string, baseUrl?: string) {
|
||||
constructor(model: Model, apiKey?: string) {
|
||||
if (!apiKey) {
|
||||
if (!process.env.ANTHROPIC_API_KEY) {
|
||||
throw new Error(
|
||||
|
|
@ -45,13 +46,17 @@ export class AnthropicLLM implements LLM<AnthropicLLMOptions> {
|
|||
};
|
||||
|
||||
process.env.ANTHROPIC_API_KEY = undefined;
|
||||
this.client = new Anthropic({ apiKey: null, authToken: apiKey, baseURL: baseUrl, defaultHeaders });
|
||||
this.client = new Anthropic({ apiKey: null, authToken: apiKey, baseURL: model.baseUrl, defaultHeaders });
|
||||
this.isOAuthToken = true;
|
||||
} else {
|
||||
this.client = new Anthropic({ apiKey, baseURL: baseUrl });
|
||||
this.client = new Anthropic({ apiKey, baseURL: model.baseUrl });
|
||||
this.isOAuthToken = false;
|
||||
}
|
||||
this.model = model;
|
||||
this.modelInfo = model;
|
||||
}
|
||||
|
||||
getModel(): Model {
|
||||
return this.modelInfo;
|
||||
}
|
||||
|
||||
async complete(context: Context, options?: AnthropicLLMOptions): Promise<AssistantMessage> {
|
||||
|
|
@ -59,7 +64,7 @@ export class AnthropicLLM implements LLM<AnthropicLLMOptions> {
|
|||
const messages = this.convertMessages(context.messages);
|
||||
|
||||
const params: MessageCreateParamsStreaming = {
|
||||
model: this.model,
|
||||
model: this.modelInfo.id,
|
||||
messages,
|
||||
max_tokens: options?.maxTokens || 4096,
|
||||
stream: true,
|
||||
|
|
@ -97,7 +102,8 @@ export class AnthropicLLM implements LLM<AnthropicLLMOptions> {
|
|||
params.tools = this.convertTools(context.tools);
|
||||
}
|
||||
|
||||
if (options?.thinking?.enabled) {
|
||||
// Only enable thinking if the model supports it
|
||||
if (options?.thinking?.enabled && this.modelInfo.reasoning) {
|
||||
params.thinking = {
|
||||
type: "enabled",
|
||||
budget_tokens: options.thinking.budgetTokens || 1024,
|
||||
|
|
@ -194,14 +200,16 @@ export class AnthropicLLM implements LLM<AnthropicLLMOptions> {
|
|||
thinking,
|
||||
thinkingSignature,
|
||||
toolCalls,
|
||||
model: this.model,
|
||||
provider: this.modelInfo.provider,
|
||||
model: this.modelInfo.id,
|
||||
usage,
|
||||
stopReason: this.mapStopReason(msg.stop_reason),
|
||||
};
|
||||
} catch (error) {
|
||||
return {
|
||||
role: "assistant",
|
||||
model: this.model,
|
||||
provider: this.modelInfo.provider,
|
||||
model: this.modelInfo.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ import type {
|
|||
LLM,
|
||||
LLMOptions,
|
||||
Message,
|
||||
Model,
|
||||
StopReason,
|
||||
TokenUsage,
|
||||
Tool,
|
||||
|
|
@ -27,9 +28,9 @@ export interface GoogleLLMOptions extends LLMOptions {
|
|||
|
||||
export class GoogleLLM implements LLM<GoogleLLMOptions> {
|
||||
private client: GoogleGenAI;
|
||||
private model: string;
|
||||
private model: Model;
|
||||
|
||||
constructor(model: string, apiKey?: string) {
|
||||
constructor(model: Model, apiKey?: string) {
|
||||
if (!apiKey) {
|
||||
if (!process.env.GEMINI_API_KEY) {
|
||||
throw new Error(
|
||||
|
|
@ -42,6 +43,10 @@ export class GoogleLLM implements LLM<GoogleLLMOptions> {
|
|||
this.model = model;
|
||||
}
|
||||
|
||||
getModel(): Model {
|
||||
return this.model;
|
||||
}
|
||||
|
||||
async complete(context: Context, options?: GoogleLLMOptions): Promise<AssistantMessage> {
|
||||
try {
|
||||
const contents = this.convertMessages(context.messages);
|
||||
|
|
@ -71,8 +76,8 @@ export class GoogleLLM implements LLM<GoogleLLMOptions> {
|
|||
};
|
||||
}
|
||||
|
||||
// Add thinking config if enabled
|
||||
if (options?.thinking?.enabled) {
|
||||
// Add thinking config if enabled and model supports it
|
||||
if (options?.thinking?.enabled && this.model.reasoning) {
|
||||
config.thinkingConfig = {
|
||||
includeThoughts: true,
|
||||
...(options.thinking.budgetTokens !== undefined && { thinkingBudget: options.thinking.budgetTokens }),
|
||||
|
|
@ -81,7 +86,7 @@ export class GoogleLLM implements LLM<GoogleLLMOptions> {
|
|||
|
||||
// Build the request parameters
|
||||
const params: GenerateContentParameters = {
|
||||
model: this.model,
|
||||
model: this.model.id,
|
||||
contents,
|
||||
config,
|
||||
};
|
||||
|
|
@ -207,14 +212,16 @@ export class GoogleLLM implements LLM<GoogleLLMOptions> {
|
|||
thinking: thinking || undefined,
|
||||
thinkingSignature: thoughtSignature,
|
||||
toolCalls: toolCalls.length > 0 ? toolCalls : undefined,
|
||||
model: this.model,
|
||||
provider: this.model.provider,
|
||||
model: this.model.id,
|
||||
usage,
|
||||
stopReason,
|
||||
};
|
||||
} catch (error) {
|
||||
return {
|
||||
role: "assistant",
|
||||
model: this.model,
|
||||
provider: this.model.provider,
|
||||
model: this.model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
|
|
@ -6,6 +6,7 @@ import type {
|
|||
LLM,
|
||||
LLMOptions,
|
||||
Message,
|
||||
Model,
|
||||
StopReason,
|
||||
TokenUsage,
|
||||
Tool,
|
||||
|
|
@ -19,9 +20,9 @@ export interface OpenAICompletionsLLMOptions extends LLMOptions {
|
|||
|
||||
export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
|
||||
private client: OpenAI;
|
||||
private model: string;
|
||||
private modelInfo: Model;
|
||||
|
||||
constructor(model: string, apiKey?: string, baseUrl?: string) {
|
||||
constructor(model: Model, apiKey?: string) {
|
||||
if (!apiKey) {
|
||||
if (!process.env.OPENAI_API_KEY) {
|
||||
throw new Error(
|
||||
|
|
@ -30,8 +31,12 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
|
|||
}
|
||||
apiKey = process.env.OPENAI_API_KEY;
|
||||
}
|
||||
this.client = new OpenAI({ apiKey, baseURL: baseUrl });
|
||||
this.model = model;
|
||||
this.client = new OpenAI({ apiKey, baseURL: model.baseUrl });
|
||||
this.modelInfo = model;
|
||||
}
|
||||
|
||||
getModel(): Model {
|
||||
return this.modelInfo;
|
||||
}
|
||||
|
||||
async complete(request: Context, options?: OpenAICompletionsLLMOptions): Promise<AssistantMessage> {
|
||||
|
|
@ -39,14 +44,14 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
|
|||
const messages = this.convertMessages(request.messages, request.systemPrompt);
|
||||
|
||||
const params: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
|
||||
model: this.model,
|
||||
model: this.modelInfo.id,
|
||||
messages,
|
||||
stream: true,
|
||||
stream_options: { include_usage: true },
|
||||
};
|
||||
|
||||
// Cerebras/xAI dont like the "store" field
|
||||
if (!this.client.baseURL?.includes("cerebras.ai") || this.client.baseURL?.includes("api.x.ai")) {
|
||||
if (!this.modelInfo.baseUrl?.includes("cerebras.ai") && !this.modelInfo.baseUrl?.includes("api.x.ai")) {
|
||||
(params as any).store = false;
|
||||
}
|
||||
|
||||
|
|
@ -66,7 +71,11 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
|
|||
params.tool_choice = options.toolChoice;
|
||||
}
|
||||
|
||||
if (options?.reasoningEffort && this.isReasoningModel() && !this.model.toLowerCase().includes("grok")) {
|
||||
if (
|
||||
options?.reasoningEffort &&
|
||||
this.modelInfo.reasoning &&
|
||||
!this.modelInfo.id.toLowerCase().includes("grok")
|
||||
) {
|
||||
params.reasoning_effort = options.reasoningEffort;
|
||||
}
|
||||
|
||||
|
|
@ -203,14 +212,16 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
|
|||
thinking: reasoningContent || undefined,
|
||||
thinkingSignature: reasoningField || undefined,
|
||||
toolCalls: toolCalls.length > 0 ? toolCalls : undefined,
|
||||
model: this.model,
|
||||
provider: this.modelInfo.provider,
|
||||
model: this.modelInfo.id,
|
||||
usage,
|
||||
stopReason: this.mapStopReason(finishReason),
|
||||
};
|
||||
} catch (error) {
|
||||
return {
|
||||
role: "assistant",
|
||||
model: this.model,
|
||||
provider: this.modelInfo.provider,
|
||||
model: this.modelInfo.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
|
|
@ -230,9 +241,9 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
|
|||
if (systemPrompt) {
|
||||
// Cerebras/xAi don't like the "developer" role
|
||||
const useDeveloperRole =
|
||||
this.isReasoningModel() &&
|
||||
!this.client.baseURL?.includes("cerebras.ai") &&
|
||||
!this.client.baseURL?.includes("api.x.ai");
|
||||
this.modelInfo.reasoning &&
|
||||
!this.modelInfo.baseUrl?.includes("cerebras.ai") &&
|
||||
!this.modelInfo.baseUrl?.includes("api.x.ai");
|
||||
const role = useDeveloperRole ? "developer" : "system";
|
||||
params.push({ role: role, content: systemPrompt });
|
||||
}
|
||||
|
|
@ -305,9 +316,4 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
|
|||
return "stop";
|
||||
}
|
||||
}
|
||||
|
||||
private isReasoningModel(): boolean {
|
||||
// TODO base on models.dev
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ import type {
|
|||
LLM,
|
||||
LLMOptions,
|
||||
Message,
|
||||
Model,
|
||||
StopReason,
|
||||
TokenUsage,
|
||||
Tool,
|
||||
|
|
@ -24,9 +25,9 @@ export interface OpenAIResponsesLLMOptions extends LLMOptions {
|
|||
|
||||
export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
|
||||
private client: OpenAI;
|
||||
private model: string;
|
||||
private modelInfo: Model;
|
||||
|
||||
constructor(model: string, apiKey?: string, baseUrl?: string) {
|
||||
constructor(model: Model, apiKey?: string) {
|
||||
if (!apiKey) {
|
||||
if (!process.env.OPENAI_API_KEY) {
|
||||
throw new Error(
|
||||
|
|
@ -35,8 +36,12 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
|
|||
}
|
||||
apiKey = process.env.OPENAI_API_KEY;
|
||||
}
|
||||
this.client = new OpenAI({ apiKey, baseURL: baseUrl });
|
||||
this.model = model;
|
||||
this.client = new OpenAI({ apiKey, baseURL: model.baseUrl });
|
||||
this.modelInfo = model;
|
||||
}
|
||||
|
||||
getModel(): Model {
|
||||
return this.modelInfo;
|
||||
}
|
||||
|
||||
async complete(request: Context, options?: OpenAIResponsesLLMOptions): Promise<AssistantMessage> {
|
||||
|
|
@ -44,7 +49,7 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
|
|||
const input = this.convertToInput(request.messages, request.systemPrompt);
|
||||
|
||||
const params: ResponseCreateParamsStreaming = {
|
||||
model: this.model,
|
||||
model: this.modelInfo.id,
|
||||
input,
|
||||
stream: true,
|
||||
};
|
||||
|
|
@ -62,7 +67,7 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
|
|||
}
|
||||
|
||||
// Add reasoning options for models that support it
|
||||
if (this.supportsReasoning() && (options?.reasoningEffort || options?.reasoningSummary)) {
|
||||
if (this.modelInfo?.reasoning && (options?.reasoningEffort || options?.reasoningSummary)) {
|
||||
params.reasoning = {
|
||||
effort: options?.reasoningEffort || "medium",
|
||||
summary: options?.reasoningSummary || "auto",
|
||||
|
|
@ -145,7 +150,8 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
|
|||
else if (event.type === "error") {
|
||||
return {
|
||||
role: "assistant",
|
||||
model: this.model,
|
||||
provider: this.modelInfo.provider,
|
||||
model: this.modelInfo.id,
|
||||
usage,
|
||||
stopReason: "error",
|
||||
error: `Code ${event.code}: ${event.message}` || "Unknown error",
|
||||
|
|
@ -159,14 +165,16 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
|
|||
thinking: thinking || undefined,
|
||||
thinkingSignature: JSON.stringify(reasoningItems) || undefined,
|
||||
toolCalls: toolCalls.length > 0 ? toolCalls : undefined,
|
||||
model: this.model,
|
||||
provider: this.modelInfo.provider,
|
||||
model: this.modelInfo.id,
|
||||
usage,
|
||||
stopReason,
|
||||
};
|
||||
} catch (error) {
|
||||
return {
|
||||
role: "assistant",
|
||||
model: this.model,
|
||||
provider: this.modelInfo.provider,
|
||||
model: this.modelInfo.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
|
|
@ -184,7 +192,7 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
|
|||
|
||||
// Add system prompt if provided
|
||||
if (systemPrompt) {
|
||||
const role = this.supportsReasoning() ? "developer" : "system";
|
||||
const role = this.modelInfo?.reasoning ? "developer" : "system";
|
||||
input.push({
|
||||
role,
|
||||
content: systemPrompt,
|
||||
|
|
@ -260,14 +268,4 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
|
|||
return "stop";
|
||||
}
|
||||
}
|
||||
|
||||
private supportsReasoning(): boolean {
|
||||
// TODO base on models.dev
|
||||
return (
|
||||
this.model.includes("o1") ||
|
||||
this.model.includes("o3") ||
|
||||
this.model.includes("gpt-5") ||
|
||||
this.model.includes("gpt-4o")
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue