co-mono/packages/coding-agent/src/model-config.ts
Mario Zechner 0c5cbd0068 v0.7.12: Custom models/providers support via models.json
- Add ~/.pi/agent/models.json config for custom providers (Ollama, vLLM, etc.)
- Support all 4 API types (openai-completions, openai-responses, anthropic-messages, google-generative-ai)
- Live reload models.json on /model selector open
- Smart model defaults per provider (claude-sonnet-4-5, gpt-5.1-codex, etc.)
- Graceful session fallback when saved model missing or no API key
- Validation errors show precise file/field info in CLI and TUI
- Agent knows its own README.md path for self-documentation
- Added gpt-5.1-codex (400k context, 128k output, reasoning)

Fixes #21
2025-11-16 22:56:24 +01:00

265 lines
7.7 KiB
TypeScript

import { type Api, getApiKey, getModels, getProviders, type KnownProvider, type Model } from "@mariozechner/pi-ai";
import { type Static, Type } from "@sinclair/typebox";
import AjvModule from "ajv";
import { existsSync, readFileSync } from "fs";
import { homedir } from "os";
import { join } from "path";
// ajv is published as CommonJS: depending on how the module is interop-loaded, the
// constructor is either exposed on `.default` or is the imported module object itself.
const ajvExports = AjvModule as any;
const Ajv = ajvExports.default || ajvExports;
// Wire APIs a custom model can speak. Shared by the provider-level default and the
// per-model override so the two lists cannot drift apart.
const ApiSchema = Type.Union([
	Type.Literal("openai-completions"),
	Type.Literal("openai-responses"),
	Type.Literal("anthropic-messages"),
	Type.Literal("google-generative-ai"),
]);

// Schema for custom model definition
const ModelDefinitionSchema = Type.Object({
	id: Type.String({ minLength: 1 }),
	name: Type.String({ minLength: 1 }),
	// Optional per-model override of the provider-level `api`
	api: Type.Optional(ApiSchema),
	reasoning: Type.Boolean(),
	input: Type.Array(Type.Union([Type.Literal("text"), Type.Literal("image")])),
	cost: Type.Object({
		input: Type.Number(),
		output: Type.Number(),
		cacheRead: Type.Number(),
		cacheWrite: Type.Number(),
	}),
	// Positivity of these two is enforced by validateConfig, not by the schema
	contextWindow: Type.Number(),
	maxTokens: Type.Number(),
});

// Schema for one custom provider entry in models.json
const ProviderConfigSchema = Type.Object({
	baseUrl: Type.String({ minLength: 1 }),
	// Either the name of an environment variable or a literal key (see resolveApiKey)
	apiKey: Type.String({ minLength: 1 }),
	// Default `api` for every model of this provider that does not set its own
	api: Type.Optional(ApiSchema),
	models: Type.Array(ModelDefinitionSchema),
});

// Top-level schema of ~/.pi/agent/models.json: provider name -> provider config
const ModelsConfigSchema = Type.Object({
	providers: Type.Record(Type.String(), ProviderConfigSchema),
});

type ModelsConfig = Static<typeof ModelsConfigSchema>;
type ProviderConfig = Static<typeof ProviderConfigSchema>;
type ModelDefinition = Static<typeof ModelDefinitionSchema>;
/**
 * Provider name -> raw `apiKey` config string (env var name or literal key) for every
 * custom provider in models.json. Cleared and refilled by parseModels on each successful load.
 */
const customProviderApiKeys = new Map<string, string>();
/**
 * Turn an `apiKey` config value into the key to send.
 *
 * If an environment variable with that exact name exists and is non-empty, its value is
 * returned; otherwise the config value itself is used as a literal key.
 */
export function resolveApiKey(keyConfig: string): string | undefined {
	const fromEnvironment = process.env[keyConfig];
	return fromEnvironment || keyConfig;
}
/**
 * Load custom models from ~/.pi/agent/models.json.
 *
 * A missing file is not an error. It yields no custom models and clears any custom provider
 * API keys left over from an earlier load, the same outcome as reloading a file that no
 * longer lists those providers. On a read, JSON, schema, or semantic-validation failure,
 * the previously loaded key mappings are left untouched.
 *
 * @returns `{ models, error }`. On success `error` is null. On failure `models` is empty and
 *   `error` is a user-facing message that ends with the config file path.
 */
function loadCustomModels(): { models: Model<Api>[]; error: string | null } {
	const configPath = join(homedir(), ".pi", "agent", "models.json");
	if (!existsSync(configPath)) {
		// Without this, keys from a models.json that has since been deleted would still be
		// returned by getApiKeyForModel for the rest of the process.
		customProviderApiKeys.clear();
		return { models: [], error: null };
	}

	try {
		const content = readFileSync(configPath, "utf-8");
		// External input: keep it untyped until it has passed schema validation.
		const parsed: unknown = JSON.parse(content);

		// Validate schema, collecting every violation so the user can fix them in one pass.
		const ajv = new Ajv({ allErrors: true });
		const validate = ajv.compile(ModelsConfigSchema);
		if (!validate(parsed)) {
			const errors =
				validate.errors?.map((e: any) => ` - ${e.instancePath || "root"}: ${e.message}`).join("\n") ||
				"Unknown schema error";
			return {
				models: [],
				error: `Invalid models.json schema:\n${errors}\n\nFile: ${configPath}`,
			};
		}
		// The schema check above guarantees this shape.
		const config = parsed as ModelsConfig;

		// Checks the schema cannot express (api presence, positive token limits).
		try {
			validateConfig(config);
		} catch (error) {
			return {
				models: [],
				error: `Invalid models.json: ${error instanceof Error ? error.message : String(error)}\n\nFile: ${configPath}`,
			};
		}

		return { models: parseModels(config), error: null };
	} catch (error) {
		// JSON.parse signals malformed JSON with SyntaxError; anything else is I/O or unexpected.
		if (error instanceof SyntaxError) {
			return {
				models: [],
				error: `Failed to parse models.json: ${error.message}\n\nFile: ${configPath}`,
			};
		}
		return {
			models: [],
			error: `Failed to load models.json: ${error instanceof Error ? error.message : String(error)}\n\nFile: ${configPath}`,
		};
	}
}
/**
 * Semantic checks on an already schema-valid config.
 *
 * For every model it requires an `api` (own or inherited from the provider), a non-empty
 * id and name, and positive contextWindow / maxTokens.
 *
 * @throws Error naming the first offending provider/model.
 */
function validateConfig(config: ModelsConfig): void {
	Object.entries(config.providers).forEach(([provider, { api: providerApi, models }]) => {
		models.forEach((model) => {
			const where = `Provider ${provider}, model ${model.id}`;
			if (!(providerApi || model.api)) {
				throw new Error(`${where}: no "api" specified. Set at provider or model level.`);
			}
			if (!model.id) throw new Error(`Provider ${provider}: model missing "id"`);
			if (!model.name) throw new Error(`Provider ${provider}: model missing "name"`);
			if (model.contextWindow <= 0) throw new Error(`${where}: invalid contextWindow`);
			if (model.maxTokens <= 0) throw new Error(`${where}: invalid maxTokens`);
		});
	});
}
/**
 * Convert a validated config into Model objects and rebuild customProviderApiKeys.
 *
 * Each model takes its baseUrl from its provider, and a model-level `api` wins over the
 * provider's. A model left with no `api` at all is skipped; validateConfig normally rejects
 * such configs before this point.
 */
function parseModels(config: ModelsConfig): Model<Api>[] {
	// Start from an empty key table; it is refilled provider by provider below.
	customProviderApiKeys.clear();

	return Object.entries(config.providers).flatMap(([providerName, providerConfig]) => {
		customProviderApiKeys.set(providerName, providerConfig.apiKey);

		return providerConfig.models.flatMap((modelDef): Model<Api>[] => {
			const effectiveApi = modelDef.api || providerConfig.api;
			if (!effectiveApi) {
				return [];
			}
			return [
				{
					id: modelDef.id,
					name: modelDef.name,
					api: effectiveApi as Api,
					provider: providerName,
					baseUrl: providerConfig.baseUrl,
					reasoning: modelDef.reasoning,
					input: modelDef.input as ("text" | "image")[],
					cost: modelDef.cost,
					contextWindow: modelDef.contextWindow,
					maxTokens: modelDef.maxTokens,
				},
			];
		});
	});
}
/**
 * Freshly load every model: the built-in catalogue from @mariozechner/pi-ai followed by the
 * custom models from models.json.
 *
 * @returns `{ models, error }`. If models.json fails to load, `models` is empty (built-ins
 *   included) and `error` carries the message.
 */
export function loadAndMergeModels(): { models: Model<Api>[]; error: string | null } {
	const builtInModels = getProviders().flatMap(
		(provider) => getModels(provider as KnownProvider) as Model<Api>[],
	);

	const custom = loadCustomModels();
	if (custom.error) {
		return { models: [], error: custom.error };
	}

	// Built-ins first, custom models appended after them.
	return { models: [...builtInModels, ...custom.models], error: null };
}
/**
 * Resolve the API key for a model's provider.
 *
 * A custom provider entry from models.json takes precedence, so a custom provider named
 * like a built-in one overrides its key. Otherwise @mariozechner/pi-ai's getApiKey is used.
 */
export function getApiKeyForModel(model: Model<Api>): string | undefined {
	const { provider } = model;
	const customKeyConfig = customProviderApiKeys.get(provider);
	return customKeyConfig ? resolveApiKey(customKeyConfig) : getApiKey(provider as KnownProvider);
}
/**
 * All models (built-in + custom) whose provider currently resolves to a non-empty API key.
 *
 * @returns `{ models, error }`. If loading failed, `models` is empty and `error` is set.
 */
export function getAvailableModels(): { models: Model<Api>[]; error: string | null } {
	const merged = loadAndMergeModels();
	if (merged.error) {
		return { models: [], error: merged.error };
	}

	const usable = merged.models.filter((candidate) => Boolean(getApiKeyForModel(candidate)));
	return { models: usable, error: null };
}
/**
 * Look up one model by provider name and model id among freshly loaded built-in + custom models.
 *
 * @returns `{ model, error }`. `model` is null when nothing matches or when loading failed;
 *   in the latter case `error` holds the load error.
 */
export function findModel(provider: string, modelId: string): { model: Model<Api> | null; error: string | null } {
	const { models, error } = loadAndMergeModels();
	if (error) return { model: null, error };

	const match = models.find((candidate) => candidate.provider === provider && candidate.id === modelId);
	return { model: match ?? null, error: null };
}