refactor(ai): Implement unified model system with type-safe createLLM

- Add Model interface to types.ts with normalized structure
- Create type-safe generic createLLM function with provider-specific model constraints
- Generate models from OpenRouter API and models.dev data
- Strip provider prefixes for direct providers (google, openai, anthropic, xai)
- Keep full model IDs for OpenRouter-proxied models
- Clean separation: types.ts (Model interface), models.ts (factory logic), models.generated.ts (data)
- Remove old model scripts and unused dependencies
- Rename GeminiLLM to GoogleLLM for consistency
- Add tests for new providers (xAI, Groq, Cerebras, OpenRouter)
- Support 181 tool-capable models across 7 providers with full type safety
This commit is contained in:
Mario Zechner 2025-08-29 23:19:47 +02:00
parent 3f36051bc6
commit c7618db3f7
8 changed files with 409 additions and 418 deletions

View file

@ -3,263 +3,260 @@
import { readFileSync, writeFileSync } from "fs";
import { join } from "path";
// Load the models.json file
const data = JSON.parse(readFileSync(join(process.cwd(), "src/models.json"), "utf-8"));
// Categorize providers by their API type
const openaiModels: Record<string, any> = {};
const openaiCompatibleProviders: Record<string, any> = {};
const anthropicModels: Record<string, any> = {};
const geminiModels: Record<string, any> = {};
for (const [providerId, provider] of Object.entries(data)) {
const p = provider as any;
if (providerId === "openai") {
// All OpenAI models use the Responses API
openaiModels[providerId] = p;
} else if (providerId === "anthropic" || providerId === "google-vertex-anthropic") {
// Anthropic direct and via Vertex
anthropicModels[providerId] = p;
} else if (providerId === "google" || providerId === "google-vertex") {
// Google Gemini models
geminiModels[providerId] = p;
} else if (p.npm === "@ai-sdk/openai-compatible" ||
p.npm === "@ai-sdk/groq" ||
p.npm === "@ai-sdk/cerebras" ||
p.npm === "@ai-sdk/fireworks" ||
p.npm === "@ai-sdk/openrouter" ||
p.npm === "@ai-sdk/openai" && providerId !== "openai" ||
p.api?.includes("/v1") ||
["together", "ollama", "llama", "github-models", "groq", "cerebras", "openrouter", "fireworks"].includes(providerId)) {
// OpenAI-compatible providers - they all speak the OpenAI completions API
// Set default base URLs for known providers
if (!p.api) {
switch (providerId) {
case "groq": p.api = "https://api.groq.com/openai/v1"; break;
case "cerebras": p.api = "https://api.cerebras.com/v1"; break;
case "together": p.api = "https://api.together.xyz/v1"; break;
case "fireworks": p.api = "https://api.fireworks.ai/v1"; break;
}
}
openaiCompatibleProviders[providerId] = p;
}
interface ModelsDevModel {
id: string;
name: string;
tool_call?: boolean;
reasoning?: boolean;
limit?: {
context?: number;
output?: number;
};
cost?: {
input?: number;
output?: number;
cache_read?: number;
cache_write?: number;
};
modalities?: {
input?: string[];
};
}
// Generate the TypeScript file
let output = `// This file is auto-generated by scripts/generate-models.ts
interface NormalizedModel {
id: string;
name: string;
provider: string;
reasoning: boolean;
input: ("text" | "image")[];
cost: {
input: number;
output: number;
cacheRead: number;
cacheWrite: number;
};
contextWindow: number;
maxTokens: number;
}
async function fetchOpenRouterModels(): Promise<NormalizedModel[]> {
try {
console.log("🌐 Fetching models from OpenRouter API...");
const response = await fetch("https://openrouter.ai/api/v1/models");
const data = await response.json();
const models: NormalizedModel[] = [];
for (const model of data.data) {
// Only include models that support tools
if (!model.supported_parameters?.includes("tools")) continue;
// Parse provider from model ID
const [providerPrefix] = model.id.split("/");
let provider = "";
let modelKey = model.id;
// Map provider prefixes to our provider names
if (model.id.startsWith("google/")) {
provider = "google";
modelKey = model.id.replace("google/", "");
} else if (model.id.startsWith("openai/")) {
provider = "openai";
modelKey = model.id.replace("openai/", "");
} else if (model.id.startsWith("anthropic/")) {
provider = "anthropic";
modelKey = model.id.replace("anthropic/", "");
} else if (model.id.startsWith("x-ai/")) {
provider = "xai";
modelKey = model.id.replace("x-ai/", "");
} else {
// All other models go through OpenRouter
provider = "openrouter";
modelKey = model.id; // Keep full ID for OpenRouter
}
// Skip if not one of our supported providers
if (!["google", "openai", "anthropic", "xai", "openrouter"].includes(provider)) {
continue;
}
// Parse input modalities
const input: ("text" | "image")[] = ["text"];
if (model.architecture?.modality?.includes("image")) {
input.push("image");
}
// Convert pricing from $/token to $/million tokens
const inputCost = parseFloat(model.pricing?.prompt || "0") * 1_000_000;
const outputCost = parseFloat(model.pricing?.completion || "0") * 1_000_000;
const cacheReadCost = parseFloat(model.pricing?.input_cache_read || "0") * 1_000_000;
const cacheWriteCost = parseFloat(model.pricing?.input_cache_write || "0") * 1_000_000;
models.push({
id: modelKey,
name: model.name,
provider,
reasoning: model.supported_parameters?.includes("reasoning") || false,
input,
cost: {
input: inputCost,
output: outputCost,
cacheRead: cacheReadCost,
cacheWrite: cacheWriteCost,
},
contextWindow: model.context_length || 4096,
maxTokens: model.top_provider?.max_completion_tokens || 4096,
});
}
console.log(`✅ Fetched ${models.length} tool-capable models from OpenRouter`);
return models;
} catch (error) {
console.error("❌ Failed to fetch OpenRouter models:", error);
return [];
}
}
function loadModelsDevData(): NormalizedModel[] {
try {
console.log("📁 Loading models from models.json...");
const data = JSON.parse(readFileSync(join(process.cwd(), "src/models.json"), "utf-8"));
const models: NormalizedModel[] = [];
// Process Groq models
if (data.groq?.models) {
for (const [modelId, model] of Object.entries(data.groq.models)) {
const m = model as ModelsDevModel;
if (m.tool_call !== true) continue;
models.push({
id: modelId,
name: m.name || modelId,
provider: "groq",
reasoning: m.reasoning === true,
input: m.modalities?.input?.includes("image") ? ["text", "image"] : ["text"],
cost: {
input: m.cost?.input || 0,
output: m.cost?.output || 0,
cacheRead: m.cost?.cache_read || 0,
cacheWrite: m.cost?.cache_write || 0,
},
contextWindow: m.limit?.context || 4096,
maxTokens: m.limit?.output || 4096,
});
}
}
// Process Cerebras models
if (data.cerebras?.models) {
for (const [modelId, model] of Object.entries(data.cerebras.models)) {
const m = model as ModelsDevModel;
if (m.tool_call !== true) continue;
models.push({
id: modelId,
name: m.name || modelId,
provider: "cerebras",
reasoning: m.reasoning === true,
input: m.modalities?.input?.includes("image") ? ["text", "image"] : ["text"],
cost: {
input: m.cost?.input || 0,
output: m.cost?.output || 0,
cacheRead: m.cost?.cache_read || 0,
cacheWrite: m.cost?.cache_write || 0,
},
contextWindow: m.limit?.context || 4096,
maxTokens: m.limit?.output || 4096,
});
}
}
console.log(`✅ Loaded ${models.length} tool-capable models from models.dev`);
return models;
} catch (error) {
console.error("❌ Failed to load models.dev data:", error);
return [];
}
}
async function generateModels() {
// Fetch all models
const openRouterModels = await fetchOpenRouterModels();
const modelsDevModels = loadModelsDevData();
// Combine models (models.dev takes priority for Groq/Cerebras)
const allModels = [...modelsDevModels, ...openRouterModels];
// Group by provider
const providers: Record<string, NormalizedModel[]> = {};
for (const model of allModels) {
if (!providers[model.provider]) {
providers[model.provider] = [];
}
providers[model.provider].push(model);
}
// Generate TypeScript file
let output = `// This file is auto-generated by scripts/generate-models.ts
// Do not edit manually - run 'npm run generate-models' to update
import type { ModalityInput, ModalityOutput } from "./models.js";
export interface ModelData {
id: string;
name: string;
reasoning: boolean;
tool_call: boolean;
attachment: boolean;
temperature: boolean;
knowledge?: string;
release_date: string;
last_updated: string;
modalities: {
input: ModalityInput[];
output: ModalityOutput[];
};
open_weights: boolean;
limit: {
context: number;
output: number;
};
cost?: {
input: number;
output: number;
cache_read?: number;
cache_write?: number;
};
}
export interface ProviderData {
id: string;
name: string;
baseUrl?: string;
env?: string[];
models: Record<string, ModelData>;
}
import type { Model } from "./types.js";
export const PROVIDERS = {
`;
// Generate OpenAI models
output += `// OpenAI models - all use OpenAIResponsesLLM\n`;
output += `export const OPENAI_MODELS = {\n`;
for (const [providerId, provider] of Object.entries(openaiModels)) {
const p = provider as any;
for (const [modelId, model] of Object.entries(p.models || {})) {
const m = model as any;
output += ` "${modelId}": ${JSON.stringify(m, null, 8).split('\n').join('\n ')},\n`;
}
}
output += `} as const;\n\n`;
// Generate provider sections
for (const [providerId, models] of Object.entries(providers)) {
output += `\t${providerId}: {\n`;
output += `\t\tmodels: {\n`;
// Generate OpenAI-compatible providers
output += `// OpenAI-compatible providers - use OpenAICompletionsLLM\n`;
output += `export const OPENAI_COMPATIBLE_PROVIDERS = {\n`;
for (const [providerId, provider] of Object.entries(openaiCompatibleProviders)) {
const p = provider as any;
output += ` "${providerId}": {\n`;
output += ` id: "${providerId}",\n`;
output += ` name: "${p.name}",\n`;
if (p.api) {
output += ` baseUrl: "${p.api}",\n`;
}
if (p.env) {
output += ` env: ${JSON.stringify(p.env)},\n`;
}
output += ` models: {\n`;
for (const [modelId, model] of Object.entries(p.models || {})) {
const m = model as any;
output += ` "${modelId}": ${JSON.stringify(m, null, 12).split('\n').join('\n ')},\n`;
}
output += ` }\n`;
output += ` },\n`;
}
output += `} as const;\n\n`;
for (const model of models) {
output += `\t\t\t"${model.id}": {\n`;
output += `\t\t\t\tid: "${model.id}",\n`;
output += `\t\t\t\tname: "${model.name}",\n`;
output += `\t\t\t\tprovider: "${model.provider}",\n`;
output += `\t\t\t\treasoning: ${model.reasoning},\n`;
output += `\t\t\t\tinput: ${JSON.stringify(model.input)},\n`;
output += `\t\t\t\tcost: {\n`;
output += `\t\t\t\t\tinput: ${model.cost.input},\n`;
output += `\t\t\t\t\toutput: ${model.cost.output},\n`;
output += `\t\t\t\t\tcacheRead: ${model.cost.cacheRead},\n`;
output += `\t\t\t\t\tcacheWrite: ${model.cost.cacheWrite},\n`;
output += `\t\t\t\t},\n`;
output += `\t\t\t\tcontextWindow: ${model.contextWindow},\n`;
output += `\t\t\t\tmaxTokens: ${model.maxTokens},\n`;
output += `\t\t\t} satisfies Model,\n`;
}
// Generate Anthropic models (avoiding duplicates)
output += `// Anthropic models - use AnthropicLLM\n`;
output += `export const ANTHROPIC_MODELS = {\n`;
const seenAnthropicModels = new Set<string>();
for (const [providerId, provider] of Object.entries(anthropicModels)) {
const p = provider as any;
for (const [modelId, model] of Object.entries(p.models || {})) {
if (!seenAnthropicModels.has(modelId)) {
seenAnthropicModels.add(modelId);
const m = model as any;
output += ` "${modelId}": ${JSON.stringify(m, null, 8).split('\n').join('\n ')},\n`;
}
}
}
output += `} as const;\n\n`;
output += `\t\t}\n`;
output += `\t},\n`;
}
// Generate Gemini models (avoiding duplicates)
output += `// Gemini models - use GeminiLLM\n`;
output += `export const GEMINI_MODELS = {\n`;
const seenGeminiModels = new Set<string>();
for (const [providerId, provider] of Object.entries(geminiModels)) {
const p = provider as any;
for (const [modelId, model] of Object.entries(p.models || {})) {
if (!seenGeminiModels.has(modelId)) {
seenGeminiModels.add(modelId);
const m = model as any;
output += ` "${modelId}": ${JSON.stringify(m, null, 8).split('\n').join('\n ')},\n`;
}
}
}
output += `} as const;\n\n`;
output += `} as const;
// Generate type helpers
output += `// Type helpers\n`;
output += `export type OpenAIModel = keyof typeof OPENAI_MODELS;\n`;
output += `export type OpenAICompatibleProvider = keyof typeof OPENAI_COMPATIBLE_PROVIDERS;\n`;
output += `export type AnthropicModel = keyof typeof ANTHROPIC_MODELS;\n`;
output += `export type GeminiModel = keyof typeof GEMINI_MODELS;\n\n`;
// Generate the factory function
output += `// Factory function implementation\n`;
output += `import { OpenAIResponsesLLM } from "./providers/openai-responses.js";\n`;
output += `import { OpenAICompletionsLLM } from "./providers/openai-completions.js";\n`;
output += `import { AnthropicLLM } from "./providers/anthropic.js";\n`;
output += `import { GeminiLLM } from "./providers/gemini.js";\n`;
output += `import type { LLM, LLMOptions } from "./types.js";\n\n`;
output += `export interface CreateLLMOptions {
apiKey?: string;
baseUrl?: string;
}
// Overloads for type safety
export function createLLM(
provider: "openai",
model: OpenAIModel,
options?: CreateLLMOptions
): OpenAIResponsesLLM;
export function createLLM(
provider: OpenAICompatibleProvider,
model: string, // We'll validate at runtime
options?: CreateLLMOptions
): OpenAICompletionsLLM;
export function createLLM(
provider: "anthropic",
model: AnthropicModel,
options?: CreateLLMOptions
): AnthropicLLM;
export function createLLM(
provider: "gemini",
model: GeminiModel,
options?: CreateLLMOptions
): GeminiLLM;
// Implementation
export function createLLM(
provider: string,
model: string,
options?: CreateLLMOptions
): LLM<LLMOptions> {
const apiKey = options?.apiKey || process.env[getEnvVar(provider)];
if (provider === "openai") {
return new OpenAIResponsesLLM(model, apiKey);
}
if (provider === "anthropic") {
return new AnthropicLLM(model, apiKey);
}
if (provider === "gemini") {
return new GeminiLLM(model, apiKey);
}
// OpenAI-compatible providers
if (provider in OPENAI_COMPATIBLE_PROVIDERS) {
const providerData = OPENAI_COMPATIBLE_PROVIDERS[provider as OpenAICompatibleProvider];
const baseUrl = options?.baseUrl || providerData.baseUrl;
return new OpenAICompletionsLLM(model, apiKey, baseUrl);
}
throw new Error(\`Unknown provider: \${provider}\`);
}
// Helper to get the default environment variable for a provider
function getEnvVar(provider: string): string {
switch (provider) {
case "openai": return "OPENAI_API_KEY";
case "anthropic": return "ANTHROPIC_API_KEY";
case "gemini": return "GEMINI_API_KEY";
case "groq": return "GROQ_API_KEY";
case "cerebras": return "CEREBRAS_API_KEY";
case "together": return "TOGETHER_API_KEY";
case "openrouter": return "OPENROUTER_API_KEY";
default: return \`\${provider.toUpperCase()}_API_KEY\`;
}
}
// Helper type to extract models for each provider
export type ProviderModels = {
[K in keyof typeof PROVIDERS]: typeof PROVIDERS[K]["models"]
};
`;
// Write the generated file
writeFileSync(join(process.cwd(), "src/models.generated.ts"), output);
console.log("✅ Generated src/models.generated.ts");
// Write file
writeFileSync(join(process.cwd(), "src/models.generated.ts"), output);
console.log("✅ Generated src/models.generated.ts");
// Count statistics
const openaiCount = Object.values(openaiModels).reduce((acc, p: any) => acc + Object.keys(p.models || {}).length, 0);
const compatCount = Object.values(openaiCompatibleProviders).reduce((acc, p: any) => acc + Object.keys(p.models || {}).length, 0);
const anthropicCount = Object.values(anthropicModels).reduce((acc, p: any) => acc + Object.keys(p.models || {}).length, 0);
const geminiCount = Object.values(geminiModels).reduce((acc, p: any) => acc + Object.keys(p.models || {}).length, 0);
// Print statistics
const totalModels = allModels.length;
const reasoningModels = allModels.filter(m => m.reasoning).length;
console.log(`\nModel counts:`);
console.log(` OpenAI (Responses API): ${openaiCount} models`);
console.log(` OpenAI-compatible: ${compatCount} models across ${Object.keys(openaiCompatibleProviders).length} providers`);
console.log(` Anthropic: ${anthropicCount} models`);
console.log(` Gemini: ${geminiCount} models`);
console.log(` Total: ${openaiCount + compatCount + anthropicCount + geminiCount} models`);
console.log(`\n📊 Model Statistics:`);
console.log(` Total tool-capable models: ${totalModels}`);
console.log(` Reasoning-capable models: ${reasoningModels}`);
for (const [provider, models] of Object.entries(providers)) {
console.log(` ${provider}: ${models.length} models`);
}
}
// Run the generator
generateModels().catch(console.error);