mirror of
https://github.com/getcompanion-ai/co-mono.git
synced 2026-04-16 15:02:32 +00:00
v0.7.12: Custom models/providers support via models.json
- Add ~/.pi/agent/models.json config for custom providers (Ollama, vLLM, etc.) - Support all 4 API types (openai-completions, openai-responses, anthropic-messages, google-generative-ai) - Live reload models.json on /model selector open - Smart model defaults per provider (claude-sonnet-4-5, gpt-5.1-codex, etc.) - Graceful session fallback when saved model missing or no API key - Validation errors show precise file/field info in CLI and TUI - Agent knows its own README.md path for self-documentation - Added gpt-5.1-codex (400k context, 128k output, reasoning) Fixes #21
This commit is contained in:
parent
112ce6e5d1
commit
0c5cbd0068
15 changed files with 793 additions and 114 deletions
265
packages/coding-agent/src/model-config.ts
Normal file
265
packages/coding-agent/src/model-config.ts
Normal file
|
|
@ -0,0 +1,265 @@
|
|||
import { type Api, getApiKey, getModels, getProviders, type KnownProvider, type Model } from "@mariozechner/pi-ai";
|
||||
import { type Static, Type } from "@sinclair/typebox";
|
||||
import AjvModule from "ajv";
|
||||
import { existsSync, readFileSync } from "fs";
|
||||
import { homedir } from "os";
|
||||
import { join } from "path";
|
||||
|
||||
// Handle both default and named exports
|
||||
const Ajv = (AjvModule as any).default || AjvModule;
|
||||
|
||||
// Schema for custom model definition
|
||||
const ModelDefinitionSchema = Type.Object({
|
||||
id: Type.String({ minLength: 1 }),
|
||||
name: Type.String({ minLength: 1 }),
|
||||
api: Type.Optional(
|
||||
Type.Union([
|
||||
Type.Literal("openai-completions"),
|
||||
Type.Literal("openai-responses"),
|
||||
Type.Literal("anthropic-messages"),
|
||||
Type.Literal("google-generative-ai"),
|
||||
]),
|
||||
),
|
||||
reasoning: Type.Boolean(),
|
||||
input: Type.Array(Type.Union([Type.Literal("text"), Type.Literal("image")])),
|
||||
cost: Type.Object({
|
||||
input: Type.Number(),
|
||||
output: Type.Number(),
|
||||
cacheRead: Type.Number(),
|
||||
cacheWrite: Type.Number(),
|
||||
}),
|
||||
contextWindow: Type.Number(),
|
||||
maxTokens: Type.Number(),
|
||||
});
|
||||
|
||||
const ProviderConfigSchema = Type.Object({
|
||||
baseUrl: Type.String({ minLength: 1 }),
|
||||
apiKey: Type.String({ minLength: 1 }),
|
||||
api: Type.Optional(
|
||||
Type.Union([
|
||||
Type.Literal("openai-completions"),
|
||||
Type.Literal("openai-responses"),
|
||||
Type.Literal("anthropic-messages"),
|
||||
Type.Literal("google-generative-ai"),
|
||||
]),
|
||||
),
|
||||
models: Type.Array(ModelDefinitionSchema),
|
||||
});
|
||||
|
||||
const ModelsConfigSchema = Type.Object({
|
||||
providers: Type.Record(Type.String(), ProviderConfigSchema),
|
||||
});
|
||||
|
||||
type ModelsConfig = Static<typeof ModelsConfigSchema>;
|
||||
type ProviderConfig = Static<typeof ProviderConfigSchema>;
|
||||
type ModelDefinition = Static<typeof ModelDefinitionSchema>;
|
||||
|
||||
// Custom provider API key mappings (provider name -> apiKey config)
|
||||
const customProviderApiKeys: Map<string, string> = new Map();
|
||||
|
||||
/**
|
||||
* Resolve an API key config value to an actual key.
|
||||
* First checks if it's an environment variable, then treats as literal.
|
||||
*/
|
||||
export function resolveApiKey(keyConfig: string): string | undefined {
|
||||
// First check if it's an env var name
|
||||
const envValue = process.env[keyConfig];
|
||||
if (envValue) return envValue;
|
||||
|
||||
// Otherwise treat as literal API key
|
||||
return keyConfig;
|
||||
}
|
||||
|
||||
/**
|
||||
* Load custom models from ~/.pi/agent/models.json
|
||||
* Returns { models, error } - either models array or error message
|
||||
*/
|
||||
function loadCustomModels(): { models: Model<Api>[]; error: string | null } {
|
||||
const configPath = join(homedir(), ".pi", "agent", "models.json");
|
||||
if (!existsSync(configPath)) {
|
||||
return { models: [], error: null };
|
||||
}
|
||||
|
||||
try {
|
||||
const content = readFileSync(configPath, "utf-8");
|
||||
const config: ModelsConfig = JSON.parse(content);
|
||||
|
||||
// Validate schema
|
||||
const ajv = new Ajv();
|
||||
const validate = ajv.compile(ModelsConfigSchema);
|
||||
if (!validate(config)) {
|
||||
const errors =
|
||||
validate.errors?.map((e: any) => ` - ${e.instancePath || "root"}: ${e.message}`).join("\n") ||
|
||||
"Unknown schema error";
|
||||
return {
|
||||
models: [],
|
||||
error: `Invalid models.json schema:\n${errors}\n\nFile: ${configPath}`,
|
||||
};
|
||||
}
|
||||
|
||||
// Additional validation
|
||||
try {
|
||||
validateConfig(config);
|
||||
} catch (error) {
|
||||
return {
|
||||
models: [],
|
||||
error: `Invalid models.json: ${error instanceof Error ? error.message : error}\n\nFile: ${configPath}`,
|
||||
};
|
||||
}
|
||||
|
||||
// Parse models
|
||||
return { models: parseModels(config), error: null };
|
||||
} catch (error) {
|
||||
if (error instanceof SyntaxError) {
|
||||
return {
|
||||
models: [],
|
||||
error: `Failed to parse models.json: ${error.message}\n\nFile: ${configPath}`,
|
||||
};
|
||||
}
|
||||
return {
|
||||
models: [],
|
||||
error: `Failed to load models.json: ${error instanceof Error ? error.message : error}\n\nFile: ${configPath}`,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate config structure and requirements
|
||||
*/
|
||||
function validateConfig(config: ModelsConfig): void {
|
||||
for (const [providerName, providerConfig] of Object.entries(config.providers)) {
|
||||
const hasProviderApi = !!providerConfig.api;
|
||||
|
||||
for (const modelDef of providerConfig.models) {
|
||||
const hasModelApi = !!modelDef.api;
|
||||
|
||||
if (!hasProviderApi && !hasModelApi) {
|
||||
throw new Error(
|
||||
`Provider ${providerName}, model ${modelDef.id}: no "api" specified. ` +
|
||||
`Set at provider or model level.`,
|
||||
);
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
if (!modelDef.id) throw new Error(`Provider ${providerName}: model missing "id"`);
|
||||
if (!modelDef.name) throw new Error(`Provider ${providerName}: model missing "name"`);
|
||||
if (modelDef.contextWindow <= 0)
|
||||
throw new Error(`Provider ${providerName}, model ${modelDef.id}: invalid contextWindow`);
|
||||
if (modelDef.maxTokens <= 0)
|
||||
throw new Error(`Provider ${providerName}, model ${modelDef.id}: invalid maxTokens`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse config into Model objects
|
||||
*/
|
||||
function parseModels(config: ModelsConfig): Model<Api>[] {
|
||||
const models: Model<Api>[] = [];
|
||||
|
||||
// Clear and rebuild custom provider API key mappings
|
||||
customProviderApiKeys.clear();
|
||||
|
||||
for (const [providerName, providerConfig] of Object.entries(config.providers)) {
|
||||
// Store API key config for this provider
|
||||
customProviderApiKeys.set(providerName, providerConfig.apiKey);
|
||||
|
||||
for (const modelDef of providerConfig.models) {
|
||||
// Model-level api overrides provider-level api
|
||||
const api = modelDef.api || providerConfig.api;
|
||||
|
||||
if (!api) {
|
||||
// This should have been caught by validateConfig, but be safe
|
||||
continue;
|
||||
}
|
||||
|
||||
models.push({
|
||||
id: modelDef.id,
|
||||
name: modelDef.name,
|
||||
api: api as Api,
|
||||
provider: providerName,
|
||||
baseUrl: providerConfig.baseUrl,
|
||||
reasoning: modelDef.reasoning,
|
||||
input: modelDef.input as ("text" | "image")[],
|
||||
cost: modelDef.cost,
|
||||
contextWindow: modelDef.contextWindow,
|
||||
maxTokens: modelDef.maxTokens,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return models;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all models (built-in + custom), freshly loaded
|
||||
* Returns { models, error } - either models array or error message
|
||||
*/
|
||||
export function loadAndMergeModels(): { models: Model<Api>[]; error: string | null } {
|
||||
const builtInModels: Model<Api>[] = [];
|
||||
const providers = getProviders();
|
||||
|
||||
// Load all built-in models
|
||||
for (const provider of providers) {
|
||||
const providerModels = getModels(provider as KnownProvider);
|
||||
builtInModels.push(...(providerModels as Model<Api>[]));
|
||||
}
|
||||
|
||||
// Load custom models
|
||||
const { models: customModels, error } = loadCustomModels();
|
||||
|
||||
if (error) {
|
||||
return { models: [], error };
|
||||
}
|
||||
|
||||
// Merge: custom models come after built-in
|
||||
return { models: [...builtInModels, ...customModels], error: null };
|
||||
}
|
||||
|
||||
/**
|
||||
* Get API key for a model (checks custom providers first, then built-in)
|
||||
*/
|
||||
export function getApiKeyForModel(model: Model<Api>): string | undefined {
|
||||
// For custom providers, check their apiKey config
|
||||
const customKeyConfig = customProviderApiKeys.get(model.provider);
|
||||
if (customKeyConfig) {
|
||||
return resolveApiKey(customKeyConfig);
|
||||
}
|
||||
|
||||
// For built-in providers, use getApiKey from @mariozechner/pi-ai
|
||||
return getApiKey(model.provider as KnownProvider);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get only models that have valid API keys available
|
||||
* Returns { models, error } - either models array or error message
|
||||
*/
|
||||
export function getAvailableModels(): { models: Model<Api>[]; error: string | null } {
|
||||
const { models: allModels, error } = loadAndMergeModels();
|
||||
|
||||
if (error) {
|
||||
return { models: [], error };
|
||||
}
|
||||
|
||||
const availableModels = allModels.filter((model) => {
|
||||
const apiKey = getApiKeyForModel(model);
|
||||
return !!apiKey;
|
||||
});
|
||||
|
||||
return { models: availableModels, error: null };
|
||||
}
|
||||
|
||||
/**
|
||||
* Find a specific model by provider and ID
|
||||
* Returns { model, error } - either model or error message
|
||||
*/
|
||||
export function findModel(provider: string, modelId: string): { model: Model<Api> | null; error: string | null } {
|
||||
const { models: allModels, error } = loadAndMergeModels();
|
||||
|
||||
if (error) {
|
||||
return { model: null, error };
|
||||
}
|
||||
|
||||
const model = allModels.find((m) => m.provider === provider && m.id === modelId) || null;
|
||||
return { model, error: null };
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue