/**
 * Model registry - manages built-in and custom models, provides API key resolution.
 */

import {
	type Api,
	type AssistantMessageEventStream,
	type Context,
	getModels,
	getProviders,
	type KnownProvider,
	type Model,
	type OAuthProviderInterface,
	registerApiProvider,
	registerOAuthProvider,
	type SimpleStreamOptions,
} from "@mariozechner/pi-ai";
import { type Static, Type } from "@sinclair/typebox";
import AjvModule from "ajv";
import { execSync } from "child_process";
import { existsSync, readFileSync } from "fs";
import { join } from "path";
import { getAgentDir } from "../config.js";
import type { AuthStorage } from "./auth-storage.js";

// The Ajv constructor is taken from `.default` when present, otherwise the
// imported module itself is used as the constructor.
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const Ajv = (AjvModule as any).default || AjvModule;

// Schema for OpenRouter routing preferences
const OpenRouterRoutingSchema = Type.Object({
	only: Type.Optional(Type.Array(Type.String())),
	order: Type.Optional(Type.Array(Type.String())),
});

// Schema for OpenAI compatibility settings
const OpenAICompletionsCompatSchema = Type.Object({
	supportsStore: Type.Optional(Type.Boolean()),
	supportsDeveloperRole: Type.Optional(Type.Boolean()),
	supportsReasoningEffort: Type.Optional(Type.Boolean()),
	supportsUsageInStreaming: Type.Optional(Type.Boolean()),
	maxTokensField: Type.Optional(Type.Union([Type.Literal("max_completion_tokens"), Type.Literal("max_tokens")])),
	requiresToolResultName: Type.Optional(Type.Boolean()),
	requiresAssistantAfterToolResult: Type.Optional(Type.Boolean()),
	requiresThinkingAsText: Type.Optional(Type.Boolean()),
	requiresMistralToolIds: Type.Optional(Type.Boolean()),
	thinkingFormat: Type.Optional(Type.Union([Type.Literal("openai"), Type.Literal("zai")])),
	openRouterRouting: Type.Optional(OpenRouterRoutingSchema),
});

const OpenAIResponsesCompatSchema = Type.Object({
	// Reserved for future use
});

const OpenAICompatSchema = Type.Union([OpenAICompletionsCompatSchema, OpenAIResponsesCompatSchema]);

// Schema for custom model definition
const ModelDefinitionSchema = Type.Object({
	id: Type.String({ minLength: 1 }),
	name: Type.String({ minLength: 1 }),
	api: Type.Optional(Type.String({ minLength: 1 })),
	reasoning: Type.Boolean(),
	input: Type.Array(Type.Union([Type.Literal("text"), Type.Literal("image")])),
	cost: Type.Object({
		input: Type.Number(),
		output: Type.Number(),
		cacheRead: Type.Number(),
		cacheWrite: Type.Number(),
	}),
	contextWindow: Type.Number(),
	maxTokens: Type.Number(),
	headers: Type.Optional(Type.Record(Type.String(), Type.String())),
	compat: Type.Optional(OpenAICompatSchema),
});

// Schema for one provider entry in models.json
const ProviderConfigSchema = Type.Object({
	baseUrl: Type.Optional(Type.String({ minLength: 1 })),
	apiKey: Type.Optional(Type.String({ minLength: 1 })),
	api: Type.Optional(Type.String({ minLength: 1 })),
	headers: Type.Optional(Type.Record(Type.String(), Type.String())),
	authHeader: Type.Optional(Type.Boolean()),
	models: Type.Optional(Type.Array(ModelDefinitionSchema)),
});

// Schema for the whole models.json document: provider name -> provider config
const ModelsConfigSchema = Type.Object({
	providers: Type.Record(Type.String(), ProviderConfigSchema),
});

/** Static TypeScript shape derived from the models.json schema. */
type ModelsConfig = Static<typeof ModelsConfigSchema>;

/** Provider override config (baseUrl, headers, apiKey) without custom models */
interface ProviderOverride {
	baseUrl?: string;
	headers?: Record<string, string>;
	apiKey?: string;
}

/** Result of loading custom models from models.json */
interface CustomModelsResult {
	models: Model<Api>[];
	/** Providers with custom models (full replacement) */
	replacedProviders: Set<string>;
	/** Providers with only baseUrl/headers override (no custom models) */
	overrides: Map<string, ProviderOverride>;
	error: string | undefined;
}

/**
 * Build a result with no models, no replaced providers and no overrides.
 *
 * @param error - Optional error message to carry in the result.
 */
function emptyCustomModelsResult(error?: string): CustomModelsResult {
	return { models: [], replacedProviders: new Set(), overrides: new Map(), error };
}

// Cache for shell command results (persists for process lifetime).
// Keyed by the full "!command" config string; undefined records a failed or empty run.
const commandResultCache = new Map<string, string | undefined>();

/**
 * Resolve a config value (API key, header value, etc.) to an actual value.
* - If the value starts with "!", the rest is executed as a shell command and its
 *   trimmed stdout is used (cached per config string).
 * - Otherwise the value is first looked up as an environment variable name; when
 *   that variable is unset or empty, the value itself is returned as a literal
 *   (not cached).
 *
 * @param config - Raw value taken from configuration.
 * @returns The resolved value, or undefined when a "!" command fails or prints nothing.
 */
function resolveConfigValue(config: string): string | undefined {
	if (config.startsWith("!")) {
		return executeCommand(config);
	}
	const envValue = process.env[config];
	// `||` is intentional: an empty environment variable falls back to the literal.
	return envValue || config;
}

/**
 * Run the shell command encoded in a "!command" config string and return its trimmed stdout.
 *
 * The outcome is cached under the full config string, including failures and
 * empty output (both stored as undefined), so a command runs at most once until
 * the cache is cleared.
 *
 * NOTE(review): the command string is handed to the shell unescaped; it must only
 * come from trusted configuration.
 *
 * @param commandConfig - Config string whose first character is the "!" marker.
 * @returns Trimmed stdout, or undefined if the command failed, timed out or printed nothing.
 */
function executeCommand(commandConfig: string): string | undefined {
	if (commandResultCache.has(commandConfig)) {
		return commandResultCache.get(commandConfig);
	}

	// Drop the leading "!" marker.
	const command = commandConfig.slice(1);
	let result: string | undefined;
	try {
		const output = execSync(command, {
			encoding: "utf-8",
			timeout: 10000,
			// No stdin, capture stdout, discard stderr.
			stdio: ["ignore", "pipe", "ignore"],
		});
		result = output.trim() || undefined;
	} catch {
		// Deliberate best-effort: any failure resolves to "no value".
		result = undefined;
	}

	commandResultCache.set(commandConfig, result);
	return result;
}

/**
 * Resolve all header values using the same resolution logic as API keys.
 *
 * Headers whose value resolves to nothing are dropped.
 *
 * @param headers - Header name to raw config value, or undefined.
 * @returns The resolved headers, or undefined when none remain.
 */
function resolveHeaders(headers: Record<string, string> | undefined): Record<string, string> | undefined {
	if (!headers) return undefined;

	const resolved: Record<string, string> = {};
	for (const [key, value] of Object.entries(headers)) {
		const resolvedValue = resolveConfigValue(value);
		if (resolvedValue) {
			resolved[key] = resolvedValue;
		}
	}
	return Object.keys(resolved).length > 0 ? resolved : undefined;
}

/** Clear the config value command cache. Exported for testing. */
export function clearApiKeyCache(): void {
	commandResultCache.clear();
}

/**
 * Model registry - loads and manages models, resolves API keys via AuthStorage.
*
 * The model list combines built-in models with custom models read from
 * models.json and with providers registered at runtime via registerProvider().
 */
export class ModelRegistry {
	private models: Model<Api>[] = [];
	/** Provider name -> raw apiKey config value (resolved lazily by the fallback resolver). */
	private customProviderApiKeys: Map<string, string> = new Map();
	/** Providers registered at runtime; re-applied after every refresh(). */
	private registeredProviders: Map<string, ProviderConfigInput> = new Map();
	private loadError: string | undefined = undefined;

	/**
	 * @param authStorage - Credential store; receives a fallback resolver for custom provider keys.
	 * @param modelsJsonPath - Path to models.json, or undefined to skip loading custom models.
	 */
	constructor(
		readonly authStorage: AuthStorage,
		private modelsJsonPath: string | undefined = join(getAgentDir(), "models.json"),
	) {
		// Set up fallback resolver for custom provider API keys
		this.authStorage.setFallbackResolver((provider) => {
			const keyConfig = this.customProviderApiKeys.get(provider);
			return keyConfig ? resolveConfigValue(keyConfig) : undefined;
		});

		this.loadModels();
	}

	/**
	 * Reload models from disk (built-in + custom from models.json), then re-apply
	 * every provider registered through registerProvider().
	 */
	refresh(): void {
		this.customProviderApiKeys.clear();
		this.loadError = undefined;
		this.loadModels();

		for (const [providerName, config] of this.registeredProviders.entries()) {
			this.applyProviderConfig(providerName, config);
		}
	}

	/**
	 * Get any error from loading models.json (undefined if no error).
	 */
	getError(): string | undefined {
		return this.loadError;
	}

	/** Rebuild `this.models` from built-in models plus models.json. */
	private loadModels(): void {
		// Load custom models from models.json first (to know which providers to skip/override)
		const {
			models: customModels,
			replacedProviders,
			overrides,
			error,
		} = this.modelsJsonPath ? this.loadCustomModels(this.modelsJsonPath) : emptyCustomModelsResult();

		if (error) {
			this.loadError = error;
			// Keep built-in models even if custom models failed to load
		}

		const builtInModels = this.loadBuiltInModels(replacedProviders, overrides);
		let combined = [...builtInModels, ...customModels];

		// Let OAuth providers modify their models (e.g., update baseUrl)
		for (const oauthProvider of this.authStorage.getOAuthProviders()) {
			const cred = this.authStorage.get(oauthProvider.id);
			if (cred?.type === "oauth" && oauthProvider.modifyModels) {
				combined = oauthProvider.modifyModels(combined, cred);
			}
		}

		this.models = combined;
	}

	/** Load built-in models, skipping replaced providers and applying overrides */
	private loadBuiltInModels(replacedProviders: Set<string>, overrides: Map<string, ProviderOverride>): Model<Api>[] {
		return getProviders()
			.filter((provider) => !replacedProviders.has(provider))
			.flatMap((provider) => {
				const models = getModels(provider as KnownProvider) as Model<Api>[];
				const override = overrides.get(provider);
				if (!override) return models;

				// Apply baseUrl/headers override to all models of this provider
				const resolvedHeaders = resolveHeaders(override.headers);
				return models.map((m) => ({
					...m,
					baseUrl: override.baseUrl ?? m.baseUrl,
					headers: resolvedHeaders ? { ...m.headers, ...resolvedHeaders } : m.headers,
				}));
			});
	}

	/**
	 * Read, validate and parse models.json.
	 *
	 * Never throws: a missing file yields an empty result, and any parse or
	 * validation failure yields an empty result carrying an error message.
	 */
	private loadCustomModels(modelsJsonPath: string): CustomModelsResult {
		if (!existsSync(modelsJsonPath)) {
			return emptyCustomModelsResult();
		}

		try {
			const content = readFileSync(modelsJsonPath, "utf-8");
			// Keep the parsed value untyped until the schema check has passed.
			const parsed: unknown = JSON.parse(content);

			// Validate schema
			const ajv = new Ajv();
			const validate = ajv.compile(ModelsConfigSchema);
			if (!validate(parsed)) {
				const errors =
					validate.errors
						?.map(
							(e: { instancePath?: string; message?: string }) =>
								` - ${e.instancePath || "root"}: ${e.message}`,
						)
						.join("\n") || "Unknown schema error";
				return emptyCustomModelsResult(`Invalid models.json schema:\n${errors}\n\nFile: ${modelsJsonPath}`);
			}
			const config = parsed as ModelsConfig;

			// Additional validation
			this.validateConfig(config);

			// Separate providers into "full replacement" (has models) vs "override-only" (no models)
			const replacedProviders = new Set<string>();
			const overrides = new Map<string, ProviderOverride>();

			for (const [providerName, providerConfig] of Object.entries(config.providers)) {
				if (providerConfig.models && providerConfig.models.length > 0) {
					// Has custom models -> full replacement
					replacedProviders.add(providerName);
				} else {
					// No models -> just override baseUrl/headers on built-in
					overrides.set(providerName, {
						baseUrl: providerConfig.baseUrl,
						headers: providerConfig.headers,
						apiKey: providerConfig.apiKey,
					});
					// Store API key for fallback resolver
					if (providerConfig.apiKey) {
						this.customProviderApiKeys.set(providerName, providerConfig.apiKey);
					}
				}
			}

			return { models: this.parseModels(config), replacedProviders, overrides, error: undefined };
		} catch (error) {
			if (error instanceof SyntaxError) {
				return emptyCustomModelsResult(`Failed to parse models.json: ${error.message}\n\nFile: ${modelsJsonPath}`);
			}
			const reason = error instanceof Error ? error.message : String(error);
			return emptyCustomModelsResult(`Failed to load models.json: ${reason}\n\nFile: ${modelsJsonPath}`);
		}
	}

	/**
	 * Cross-field checks the JSON schema cannot express.
	 *
	 * @throws Error describing the first provider or model that is misconfigured.
	 */
	private validateConfig(config: ModelsConfig): void {
		for (const [providerName, providerConfig] of Object.entries(config.providers)) {
			const hasProviderApi = !!providerConfig.api;
			const models = providerConfig.models ?? [];

			if (models.length === 0) {
				// Override-only config: just needs baseUrl (to override built-in)
				if (!providerConfig.baseUrl) {
					throw new Error(
						`Provider ${providerName}: must specify either "baseUrl" (for override) or "models" (for replacement).`,
					);
				}
			} else {
				// Full replacement: needs baseUrl and apiKey
				if (!providerConfig.baseUrl) {
					throw new Error(`Provider ${providerName}: "baseUrl" is required when defining custom models.`);
				}
				if (!providerConfig.apiKey) {
					throw new Error(`Provider ${providerName}: "apiKey" is required when defining custom models.`);
				}
			}

			for (const modelDef of models) {
				const hasModelApi = !!modelDef.api;
				if (!hasProviderApi && !hasModelApi) {
					throw new Error(
						`Provider ${providerName}, model ${modelDef.id}: no "api" specified. Set at provider or model level.`,
					);
				}
				if (!modelDef.id) throw new Error(`Provider ${providerName}: model missing "id"`);
				if (!modelDef.name) throw new Error(`Provider ${providerName}: model missing "name"`);
				if (modelDef.contextWindow <= 0)
					throw new Error(`Provider ${providerName}, model ${modelDef.id}: invalid contextWindow`);
				if (modelDef.maxTokens <= 0)
					throw new Error(`Provider ${providerName}, model ${modelDef.id}: invalid maxTokens`);
			}
		}
	}

	/**
	 * Turn the model definitions of a validated models.json config into Model objects
	 * and remember each such provider's apiKey config for the fallback resolver.
	 */
	private parseModels(config: ModelsConfig): Model<Api>[] {
		const models: Model<Api>[] = [];

		for (const [providerName, providerConfig] of Object.entries(config.providers)) {
			const modelDefs = providerConfig.models ?? [];
			if (modelDefs.length === 0) continue; // Override-only, no custom models

			// Store API key config for fallback resolver
			if (providerConfig.apiKey) {
				this.customProviderApiKeys.set(providerName, providerConfig.apiKey);
			}

			// validateConfig() requires baseUrl for providers with models; the guard
			// only exists to narrow the type without a non-null assertion.
			const { baseUrl } = providerConfig;
			if (!baseUrl) continue;

			// validateConfig() also requires an api per model; definitions without
			// one are skipped here rather than raised.
			const usableDefs = modelDefs.filter((modelDef) => modelDef.api || providerConfig.api);

			models.push(
				...this.buildModels(
					providerName,
					{
						baseUrl,
						apiKey: providerConfig.apiKey,
						api: providerConfig.api,
						headers: providerConfig.headers,
						authHeader: providerConfig.authHeader,
					},
					usableDefs,
				),
			);
		}

		return models;
	}

	/**
	 * Build Model objects for one provider from its model definitions.
	 *
	 * Provider headers are the base and model headers override them; header values go
	 * through env-var / shell-command resolution. When `authHeader` is set and the API
	 * key resolves, an `Authorization: Bearer <key>` header is added last.
	 *
	 * @throws Error when a definition has no api at either model or provider level.
	 */
	private buildModels(
		providerName: string,
		provider: ProviderModelDefaults,
		modelDefs: readonly ModelDefinitionInput[],
	): Model<Api>[] {
		if (modelDefs.length === 0) return [];

		// Provider-level values are identical for every model, so resolve them once.
		const providerHeaders = resolveHeaders(provider.headers);
		const resolvedKey = provider.authHeader && provider.apiKey ? resolveConfigValue(provider.apiKey) : undefined;

		return modelDefs.map((modelDef) => {
			const api = modelDef.api || provider.api;
			if (!api) {
				throw new Error(`Provider ${providerName}, model ${modelDef.id}: no "api" specified.`);
			}

			const modelHeaders = resolveHeaders(modelDef.headers);
			let headers = providerHeaders || modelHeaders ? { ...providerHeaders, ...modelHeaders } : undefined;
			if (resolvedKey) {
				headers = { ...headers, Authorization: `Bearer ${resolvedKey}` };
			}

			return {
				id: modelDef.id,
				name: modelDef.name,
				api: api as Api,
				provider: providerName,
				baseUrl: provider.baseUrl,
				reasoning: modelDef.reasoning,
				input: modelDef.input as ("text" | "image")[],
				cost: modelDef.cost,
				contextWindow: modelDef.contextWindow,
				maxTokens: modelDef.maxTokens,
				headers,
				compat: modelDef.compat,
			} as Model<Api>;
		});
	}

	/**
	 * Get all models (built-in + custom).
	 * If models.json had errors, returns only built-in models.
	 */
	getAll(): Model<Api>[] {
		return this.models;
	}

	/**
	 * Get only models that have auth configured.
	 * This is a fast check that doesn't refresh OAuth tokens.
	 */
	getAvailable(): Model<Api>[] {
		return this.models.filter((m) => this.authStorage.hasAuth(m.provider));
	}

	/**
	 * Find a model by provider and ID.
	 */
	find(provider: string, modelId: string): Model<Api> | undefined {
		return this.models.find((m) => m.provider === provider && m.id === modelId);
	}

	/**
	 * Get API key for a model.
	 */
	async getApiKey(model: Model<Api>): Promise<string | undefined> {
		return this.authStorage.getApiKey(model.provider);
	}

	/**
	 * Get API key for a provider.
	 */
	async getApiKeyForProvider(provider: string): Promise<string | undefined> {
		return this.authStorage.getApiKey(provider);
	}

	/**
	 * Check if a model is using OAuth credentials (subscription).
	 */
	isUsingOAuth(model: Model<Api>): boolean {
		const cred = this.authStorage.get(model.provider);
		return cred?.type === "oauth";
	}

	/**
	 * Register a provider dynamically (from extensions).
	 *
	 * If provider has models: replaces all existing models for this provider.
	 * If provider has only baseUrl/headers: overrides existing models' URLs.
	 * If provider has oauth: registers OAuth provider for /login support.
	 *
	 * @throws Error when the config is invalid; in that case the registry is left
	 *         unchanged and the config is not remembered for refresh().
	 */
	registerProvider(providerName: string, config: ProviderConfigInput): void {
		// Apply first: only configs that applied cleanly are replayed by refresh().
		this.applyProviderConfig(providerName, config);
		this.registeredProviders.set(providerName, config);
	}

	/**
	 * Apply one runtime provider config to the registry.
	 *
	 * All validation and model construction happens before any state is touched,
	 * so a config that throws leaves handlers, API keys and the model list as they were.
	 */
	private applyProviderConfig(providerName: string, config: ProviderConfigInput): void {
		// ---- Phase 1: validate and build (no side effects on the registry) ----
		if (config.streamSimple && !config.api) {
			throw new Error(`Provider ${providerName}: "api" is required when registering streamSimple.`);
		}

		let replacementModels: Model<Api>[] | undefined;
		if (config.models && config.models.length > 0) {
			if (!config.baseUrl) {
				throw new Error(`Provider ${providerName}: "baseUrl" is required when defining models.`);
			}
			if (!config.apiKey && !config.oauth) {
				throw new Error(`Provider ${providerName}: "apiKey" or "oauth" is required when defining models.`);
			}
			replacementModels = this.buildModels(
				providerName,
				{
					baseUrl: config.baseUrl,
					apiKey: config.apiKey,
					api: config.api,
					headers: config.headers,
					authHeader: config.authHeader,
				},
				config.models,
			);
		}

		// ---- Phase 2: commit ----
		if (config.oauth) {
			// Ensure the OAuth provider ID matches the provider name
			const oauthProvider: OAuthProviderInterface = {
				...config.oauth,
				id: providerName,
			};
			registerOAuthProvider(oauthProvider);
		}

		if (config.streamSimple && config.api) {
			const streamSimple = config.streamSimple;
			registerApiProvider({
				api: config.api,
				stream: (model, context, options) => streamSimple(model, context, options as SimpleStreamOptions),
				streamSimple,
			});
		}

		// Store API key for auth resolution
		if (config.apiKey) {
			this.customProviderApiKeys.set(providerName, config.apiKey);
		}

		if (replacementModels) {
			// Full replacement: drop the provider's existing models, append the new ones
			this.models = [...this.models.filter((m) => m.provider !== providerName), ...replacementModels];
		} else if (config.baseUrl) {
			// Override-only: update baseUrl/headers for existing models
			const resolvedHeaders = resolveHeaders(config.headers);
			this.models = this.models.map((m) => {
				if (m.provider !== providerName) return m;
				return {
					...m,
					baseUrl: config.baseUrl ?? m.baseUrl,
					headers: resolvedHeaders ? { ...m.headers, ...resolvedHeaders } : m.headers,
				};
			});
		}
	}
}

/** Provider-level settings shared by every model built for that provider. */
interface ProviderModelDefaults {
	baseUrl: string;
	apiKey?: string;
	api?: string;
	headers?: Record<string, string>;
	authHeader?: boolean;
}

/** Structural shape common to models.json model definitions and registerProvider() model inputs. */
interface ModelDefinitionInput {
	id: string;
	name: string;
	api?: string;
	reasoning: boolean;
	input: readonly ("text" | "image")[];
	cost: { input: number; output: number; cacheRead: number; cacheWrite: number };
	contextWindow: number;
	maxTokens: number;
	headers?: Record<string, string>;
	compat?: unknown;
}

/**
 * Input type for registerProvider API.
*
 * With a non-empty `models` list the provider's existing models are replaced;
 * with only `baseUrl` (and optionally `headers`) the existing models are overridden.
 */
export interface ProviderConfigInput {
	/** Base URL for the provider's models; required when `models` is non-empty. */
	baseUrl?: string;
	/** API key config value: an environment variable name, a "!shell command", or a literal. */
	apiKey?: string;
	/** Default API for models that do not set their own; required with `streamSimple`. */
	api?: Api;
	/** Custom streaming implementation registered for `api`. */
	streamSimple?: (model: Model<Api>, context: Context, options?: SimpleStreamOptions) => AssistantMessageEventStream;
	/** Headers applied to every model; values are resolved like `apiKey`. */
	headers?: Record<string, string>;
	/** When true and `apiKey` resolves, an `Authorization: Bearer <key>` header is added to each model. */
	authHeader?: boolean;
	/** OAuth provider for /login support */
	oauth?: Omit<OAuthProviderInterface, "id">;
	/** Model definitions; model-level `api` and `headers` take precedence over the provider-level ones. */
	models?: Array<{
		id: string;
		name: string;
		api?: Api;
		reasoning: boolean;
		input: ("text" | "image")[];
		cost: { input: number; output: number; cacheRead: number; cacheWrite: number };
		contextWindow: number;
		maxTokens: number;
		headers?: Record<string, string>;
		compat?: Model<Api>["compat"];
	}>;
}