mirror of
https://github.com/getcompanion-ai/co-mono.git
synced 2026-04-16 00:03:00 +00:00
- Add ~/.pi/agent/models.json config for custom providers (Ollama, vLLM, etc.) - Support all 4 API types (openai-completions, openai-responses, anthropic-messages, google-generative-ai) - Live reload models.json on /model selector open - Smart model defaults per provider (claude-sonnet-4-5, gpt-5.1-codex, etc.) - Graceful session fallback when saved model missing or no API key - Validation errors show precise file/field info in CLI and TUI - Agent knows its own README.md path for self-documentation - Added gpt-5.1-codex (400k context, 128k output, reasoning) Fixes #21
265 lines
7.7 KiB
TypeScript
265 lines
7.7 KiB
TypeScript
import { type Api, getApiKey, getModels, getProviders, type KnownProvider, type Model } from "@mariozechner/pi-ai";
|
|
import { type Static, Type } from "@sinclair/typebox";
|
|
import AjvModule from "ajv";
|
|
import { existsSync, readFileSync } from "fs";
|
|
import { homedir } from "os";
|
|
import { join } from "path";
|
|
|
|
// The default import of "ajv" may resolve either to the Ajv class itself or to a
// module object that carries the class on `.default`, depending on module interop.
// Use `.default` when it is truthy, otherwise the import as-is.
const ajvImport = AjvModule as any;
const Ajv = ajvImport.default || ajvImport;
|
|
|
|
// Wire APIs a custom model can speak. Shared by the provider-level default and
// the per-model override so the two lists cannot drift apart.
const ApiSchema = Type.Union([
	Type.Literal("openai-completions"),
	Type.Literal("openai-responses"),
	Type.Literal("anthropic-messages"),
	Type.Literal("google-generative-ai"),
]);

// Schema for a single custom model entry in models.json.
const ModelDefinitionSchema = Type.Object({
	id: Type.String({ minLength: 1 }),
	name: Type.String({ minLength: 1 }),
	// When present, takes precedence over the provider-level `api`.
	api: Type.Optional(ApiSchema),
	reasoning: Type.Boolean(),
	input: Type.Array(Type.Union([Type.Literal("text"), Type.Literal("image")])),
	cost: Type.Object({
		input: Type.Number(),
		output: Type.Number(),
		cacheRead: Type.Number(),
		cacheWrite: Type.Number(),
	}),
	contextWindow: Type.Number(),
	maxTokens: Type.Number(),
});

// Schema for one provider entry: connection details plus the models it serves.
const ProviderConfigSchema = Type.Object({
	baseUrl: Type.String({ minLength: 1 }),
	// Either an environment variable name or a literal key.
	apiKey: Type.String({ minLength: 1 }),
	// Default `api` for every model of this provider that does not set its own.
	api: Type.Optional(ApiSchema),
	models: Type.Array(ModelDefinitionSchema),
});
|
|
|
|
// Top-level shape of models.json: a map from provider name to provider config.
const ModelsConfigSchema = Type.Object({
	providers: Type.Record(Type.String(), ProviderConfigSchema),
});

// Static TypeScript types derived from the TypeBox schemas, so the runtime
// validation and the compile-time types stay in sync.
type ModelDefinition = Static<typeof ModelDefinitionSchema>;
type ProviderConfig = Static<typeof ProviderConfigSchema>;
type ModelsConfig = Static<typeof ModelsConfigSchema>;
|
|
|
|
// Providers declared in models.json, keyed by provider name, mapped to their raw
// `apiKey` config value (an env var name or a literal key; see resolveApiKey).
// Cleared and refilled by parseModels each time models.json is parsed.
const customProviderApiKeys = new Map<string, string>();
|
|
|
|
/**
 * Turn an `apiKey` value from models.json into the key to actually use.
 *
 * The value is first looked up as an environment variable name. If that
 * variable is unset or empty, the value itself is used as a literal key.
 *
 * @param keyConfig - Environment variable name or literal API key.
 * @returns The variable's value when non-empty, otherwise `keyConfig` unchanged.
 */
export function resolveApiKey(keyConfig: string): string | undefined {
	return process.env[keyConfig] || keyConfig;
}
|
|
|
|
/**
 * Structural type for the compiled models.json validator. `Ajv` is an untyped
 * (`any`) interop binding in this module, so this describes only the members
 * loadCustomModels relies on.
 */
type ModelsConfigValidator = ((data: unknown) => boolean) & {
	errors?: { instancePath?: string; message?: string }[] | null;
};

/**
 * Compiled models.json validator, created on first use and then reused.
 * The schema is fixed at module load, and loadCustomModels runs on every call
 * to loadAndMergeModels, so recompiling each time would be wasted work.
 */
let modelsConfigValidator: ModelsConfigValidator | undefined;

/** Return the cached models.json schema validator, compiling it on first call. */
function getModelsConfigValidator(): ModelsConfigValidator {
	if (!modelsConfigValidator) {
		const ajv = new Ajv();
		modelsConfigValidator = ajv.compile(ModelsConfigSchema) as ModelsConfigValidator;
	}
	return modelsConfigValidator;
}

/**
 * Load custom models from ~/.pi/agent/models.json.
 *
 * Also resets the custom-provider API key table. That table is refilled (via
 * parseModels) only when the file loads successfully. When the file is missing
 * or invalid, the table stays empty, so keys from an earlier load are not used
 * for providers that no longer exist.
 *
 * @returns `{ models, error }`. On success `error` is null and `models` holds the
 *   parsed custom models (possibly none). On failure `models` is empty and
 *   `error` is a human-readable message that names the file.
 */
function loadCustomModels(): { models: Model<Api>[]; error: string | null } {
	// Drop keys from any previous load first; every path below that does not
	// reach parseModels must leave no custom keys behind.
	customProviderApiKeys.clear();

	const configPath = join(homedir(), ".pi", "agent", "models.json");
	if (!existsSync(configPath)) {
		return { models: [], error: null };
	}

	try {
		const content = readFileSync(configPath, "utf-8");
		// Untrusted until the schema check below passes.
		const parsed: unknown = JSON.parse(content);

		// Validate schema
		const validate = getModelsConfigValidator();
		if (!validate(parsed)) {
			const errors =
				validate.errors?.map((e) => ` - ${e.instancePath || "root"}: ${e.message}`).join("\n") ||
				"Unknown schema error";
			return {
				models: [],
				error: `Invalid models.json schema:\n${errors}\n\nFile: ${configPath}`,
			};
		}
		const config = parsed as ModelsConfig;

		// Semantic checks the schema cannot express (e.g. "api" set at provider or model level).
		try {
			validateConfig(config);
		} catch (error) {
			return {
				models: [],
				error: `Invalid models.json: ${error instanceof Error ? error.message : error}\n\nFile: ${configPath}`,
			};
		}

		// Also repopulates customProviderApiKeys.
		return { models: parseModels(config), error: null };
	} catch (error) {
		if (error instanceof SyntaxError) {
			return {
				models: [],
				error: `Failed to parse models.json: ${error.message}\n\nFile: ${configPath}`,
			};
		}
		return {
			models: [],
			error: `Failed to load models.json: ${error instanceof Error ? error.message : error}\n\nFile: ${configPath}`,
		};
	}
}
|
|
|
|
/**
 * Enforce models.json rules that the JSON schema does not express.
 *
 * Every model must have a non-empty `id` and `name`, and must resolve to an
 * `api` set either on the model or on its provider. Its `contextWindow` and
 * `maxTokens` must both be positive. The `id`/`name` checks repeat the schema's
 * `minLength: 1`, but they run first so later messages can safely name the
 * model by id.
 *
 * @param config - models.json contents that already passed schema validation.
 * @throws Error describing the first offending provider/model.
 */
function validateConfig(config: ModelsConfig): void {
	for (const [providerName, providerConfig] of Object.entries(config.providers)) {
		const hasProviderApi = !!providerConfig.api;

		for (const modelDef of providerConfig.models) {
			// Identity first: the messages below reference modelDef.id.
			if (!modelDef.id) throw new Error(`Provider ${providerName}: model missing "id"`);
			if (!modelDef.name) throw new Error(`Provider ${providerName}: model missing "name"`);

			if (!hasProviderApi && !modelDef.api) {
				throw new Error(
					`Provider ${providerName}, model ${modelDef.id}: no "api" specified. ` +
						`Set at provider or model level.`,
				);
			}

			if (modelDef.contextWindow <= 0)
				throw new Error(`Provider ${providerName}, model ${modelDef.id}: invalid contextWindow`);
			if (modelDef.maxTokens <= 0)
				throw new Error(`Provider ${providerName}, model ${modelDef.id}: invalid maxTokens`);
		}
	}
}
|
|
|
|
/**
 * Convert a validated models.json config into `Model` objects.
 *
 * Side effect: clears `customProviderApiKeys` and refills it with one entry per
 * configured provider (provider name -> raw `apiKey` value), so the table
 * mirrors the config parsed most recently.
 *
 * A model's own `api` takes precedence over its provider's `api`. Models with
 * neither are skipped; validateConfig is expected to have rejected them already.
 *
 * @param config - Config that has passed schema validation and validateConfig.
 * @returns Custom models, grouped by provider (in `Object.entries` order), each
 *   group in the order the models are listed.
 */
function parseModels(config: ModelsConfig): Model<Api>[] {
	customProviderApiKeys.clear();
	const parsed: Model<Api>[] = [];

	for (const [provider, providerConfig] of Object.entries(config.providers)) {
		const { baseUrl, apiKey, api: providerApi, models: definitions } = providerConfig;
		customProviderApiKeys.set(provider, apiKey);

		for (const definition of definitions) {
			const api = definition.api || providerApi;
			// Defensive only: validateConfig rejects this case.
			if (!api) continue;

			parsed.push({
				id: definition.id,
				name: definition.name,
				api: api as Api,
				provider,
				baseUrl,
				reasoning: definition.reasoning,
				input: definition.input as ("text" | "image")[],
				cost: definition.cost,
				contextWindow: definition.contextWindow,
				maxTokens: definition.maxTokens,
			});
		}
	}

	return parsed;
}
|
|
|
|
/**
 * Build the complete model list from scratch: every built-in model known to
 * `@mariozechner/pi-ai`, followed by the custom models from models.json.
 *
 * models.json is re-read on every call, so edits take effect without a restart.
 *
 * @returns `{ models, error }`. If models.json fails to load, `models` is empty
 *   (built-in models are not returned either) and `error` explains why.
 *   Otherwise `error` is null.
 */
export function loadAndMergeModels(): { models: Model<Api>[]; error: string | null } {
	const builtInModels = getProviders().flatMap(
		(provider) => getModels(provider as KnownProvider) as Model<Api>[],
	);

	const custom = loadCustomModels();
	if (custom.error) {
		return { models: [], error: custom.error };
	}

	// Built-ins first; custom models are appended after them.
	return { models: builtInModels.concat(custom.models), error: null };
}
|
|
|
|
/**
 * Resolve the API key to use for `model`.
 *
 * Providers declared in models.json take priority: their configured `apiKey` is
 * resolved through resolveApiKey. Every other provider falls back to pi-ai's
 * `getApiKey`. The lookup is by provider name only, so a models.json provider
 * named like a built-in one also supplies the key for that provider's built-in
 * models.
 *
 * @returns The resolved key, or whatever `getApiKey` yields (possibly undefined).
 */
export function getApiKeyForModel(model: Model<Api>): string | undefined {
	const customKeyConfig = customProviderApiKeys.get(model.provider);
	return customKeyConfig
		? resolveApiKey(customKeyConfig)
		: getApiKey(model.provider as KnownProvider);
}
|
|
|
|
/**
 * List the models, built-in and custom, for which an API key can be resolved.
 *
 * @returns `{ models, error }`. When models.json fails to load, `models` is empty
 *   and `error` carries the message from loadAndMergeModels.
 */
export function getAvailableModels(): { models: Model<Api>[]; error: string | null } {
	const merged = loadAndMergeModels();
	if (merged.error) {
		return { models: [], error: merged.error };
	}

	// An empty-string key counts as missing, just like undefined.
	const withKeys = merged.models.filter((model) => Boolean(getApiKeyForModel(model)));
	return { models: withKeys, error: null };
}
|
|
|
|
/**
 * Look up a model by provider name and model id among built-in and custom models.
 *
 * @returns `{ model, error }`. If models.json fails to load, `model` is null and
 *   `error` is set. Otherwise `model` is the first match, or null if none exists.
 */
export function findModel(provider: string, modelId: string): { model: Model<Api> | null; error: string | null } {
	const { models, error } = loadAndMergeModels();
	if (error) {
		return { model: null, error };
	}

	// loadAndMergeModels lists built-ins before custom models, so the first match
	// wins when a custom entry reuses a built-in provider/id pair.
	const match = models.find((candidate) => candidate.provider === provider && candidate.id === modelId);
	return { model: match ?? null, error: null };
}
|