mirror of
https://github.com/getcompanion-ai/co-mono.git
synced 2026-04-15 17:00:59 +00:00
Fixes #39 - Added headers field to Model type (provider and model level) - Model headers override provider headers when merged - Supported in all APIs: - Anthropic: defaultHeaders - OpenAI (completions/responses): defaultHeaders - Google: httpOptions.headers - Enables bypassing Cloudflare bot detection for proxied endpoints - Updated documentation with examples Also fixed: - Mistral/Chutes syntax error (iif -> if) - process.env.ANTHROPIC_API_KEY bug (use delete instead of = undefined)
294 lines
8.7 KiB
TypeScript
294 lines
8.7 KiB
TypeScript
import { type Api, getApiKey, getModels, getProviders, type KnownProvider, type Model } from "@mariozechner/pi-ai";
|
|
import { type Static, Type } from "@sinclair/typebox";
|
|
import AjvModule from "ajv";
|
|
import { existsSync, readFileSync } from "fs";
|
|
import { homedir } from "os";
|
|
import { join } from "path";
|
|
import { getOAuthToken } from "./oauth/index.js";
|
|
|
|
// Handle both default and named exports
|
|
const Ajv = (AjvModule as any).default || AjvModule;
|
|
|
|
// Schema for custom model definition
|
|
const ModelDefinitionSchema = Type.Object({
|
|
id: Type.String({ minLength: 1 }),
|
|
name: Type.String({ minLength: 1 }),
|
|
api: Type.Optional(
|
|
Type.Union([
|
|
Type.Literal("openai-completions"),
|
|
Type.Literal("openai-responses"),
|
|
Type.Literal("anthropic-messages"),
|
|
Type.Literal("google-generative-ai"),
|
|
]),
|
|
),
|
|
reasoning: Type.Boolean(),
|
|
input: Type.Array(Type.Union([Type.Literal("text"), Type.Literal("image")])),
|
|
cost: Type.Object({
|
|
input: Type.Number(),
|
|
output: Type.Number(),
|
|
cacheRead: Type.Number(),
|
|
cacheWrite: Type.Number(),
|
|
}),
|
|
contextWindow: Type.Number(),
|
|
maxTokens: Type.Number(),
|
|
headers: Type.Optional(Type.Record(Type.String(), Type.String())),
|
|
});
|
|
|
|
const ProviderConfigSchema = Type.Object({
|
|
baseUrl: Type.String({ minLength: 1 }),
|
|
apiKey: Type.String({ minLength: 1 }),
|
|
api: Type.Optional(
|
|
Type.Union([
|
|
Type.Literal("openai-completions"),
|
|
Type.Literal("openai-responses"),
|
|
Type.Literal("anthropic-messages"),
|
|
Type.Literal("google-generative-ai"),
|
|
]),
|
|
),
|
|
headers: Type.Optional(Type.Record(Type.String(), Type.String())),
|
|
models: Type.Array(ModelDefinitionSchema),
|
|
});
|
|
|
|
const ModelsConfigSchema = Type.Object({
|
|
providers: Type.Record(Type.String(), ProviderConfigSchema),
|
|
});
|
|
|
|
type ModelsConfig = Static<typeof ModelsConfigSchema>;
|
|
type ProviderConfig = Static<typeof ProviderConfigSchema>;
|
|
type ModelDefinition = Static<typeof ModelDefinitionSchema>;
|
|
|
|
// Custom provider API key mappings (provider name -> apiKey config)
|
|
const customProviderApiKeys: Map<string, string> = new Map();
|
|
|
|
/**
|
|
* Resolve an API key config value to an actual key.
|
|
* First checks if it's an environment variable, then treats as literal.
|
|
*/
|
|
export function resolveApiKey(keyConfig: string): string | undefined {
|
|
// First check if it's an env var name
|
|
const envValue = process.env[keyConfig];
|
|
if (envValue) return envValue;
|
|
|
|
// Otherwise treat as literal API key
|
|
return keyConfig;
|
|
}
|
|
|
|
/**
|
|
* Load custom models from ~/.pi/agent/models.json
|
|
* Returns { models, error } - either models array or error message
|
|
*/
|
|
function loadCustomModels(): { models: Model<Api>[]; error: string | null } {
|
|
const configPath = join(homedir(), ".pi", "agent", "models.json");
|
|
if (!existsSync(configPath)) {
|
|
return { models: [], error: null };
|
|
}
|
|
|
|
try {
|
|
const content = readFileSync(configPath, "utf-8");
|
|
const config: ModelsConfig = JSON.parse(content);
|
|
|
|
// Validate schema
|
|
const ajv = new Ajv();
|
|
const validate = ajv.compile(ModelsConfigSchema);
|
|
if (!validate(config)) {
|
|
const errors =
|
|
validate.errors?.map((e: any) => ` - ${e.instancePath || "root"}: ${e.message}`).join("\n") ||
|
|
"Unknown schema error";
|
|
return {
|
|
models: [],
|
|
error: `Invalid models.json schema:\n${errors}\n\nFile: ${configPath}`,
|
|
};
|
|
}
|
|
|
|
// Additional validation
|
|
try {
|
|
validateConfig(config);
|
|
} catch (error) {
|
|
return {
|
|
models: [],
|
|
error: `Invalid models.json: ${error instanceof Error ? error.message : error}\n\nFile: ${configPath}`,
|
|
};
|
|
}
|
|
|
|
// Parse models
|
|
return { models: parseModels(config), error: null };
|
|
} catch (error) {
|
|
if (error instanceof SyntaxError) {
|
|
return {
|
|
models: [],
|
|
error: `Failed to parse models.json: ${error.message}\n\nFile: ${configPath}`,
|
|
};
|
|
}
|
|
return {
|
|
models: [],
|
|
error: `Failed to load models.json: ${error instanceof Error ? error.message : error}\n\nFile: ${configPath}`,
|
|
};
|
|
}
|
|
}
|
|
|
|
/**
|
|
* Validate config structure and requirements
|
|
*/
|
|
function validateConfig(config: ModelsConfig): void {
|
|
for (const [providerName, providerConfig] of Object.entries(config.providers)) {
|
|
const hasProviderApi = !!providerConfig.api;
|
|
|
|
for (const modelDef of providerConfig.models) {
|
|
const hasModelApi = !!modelDef.api;
|
|
|
|
if (!hasProviderApi && !hasModelApi) {
|
|
throw new Error(
|
|
`Provider ${providerName}, model ${modelDef.id}: no "api" specified. ` +
|
|
`Set at provider or model level.`,
|
|
);
|
|
}
|
|
|
|
// Validate required fields
|
|
if (!modelDef.id) throw new Error(`Provider ${providerName}: model missing "id"`);
|
|
if (!modelDef.name) throw new Error(`Provider ${providerName}: model missing "name"`);
|
|
if (modelDef.contextWindow <= 0)
|
|
throw new Error(`Provider ${providerName}, model ${modelDef.id}: invalid contextWindow`);
|
|
if (modelDef.maxTokens <= 0)
|
|
throw new Error(`Provider ${providerName}, model ${modelDef.id}: invalid maxTokens`);
|
|
}
|
|
}
|
|
}
|
|
|
|
/**
|
|
* Parse config into Model objects
|
|
*/
|
|
function parseModels(config: ModelsConfig): Model<Api>[] {
|
|
const models: Model<Api>[] = [];
|
|
|
|
// Clear and rebuild custom provider API key mappings
|
|
customProviderApiKeys.clear();
|
|
|
|
for (const [providerName, providerConfig] of Object.entries(config.providers)) {
|
|
// Store API key config for this provider
|
|
customProviderApiKeys.set(providerName, providerConfig.apiKey);
|
|
|
|
for (const modelDef of providerConfig.models) {
|
|
// Model-level api overrides provider-level api
|
|
const api = modelDef.api || providerConfig.api;
|
|
|
|
if (!api) {
|
|
// This should have been caught by validateConfig, but be safe
|
|
continue;
|
|
}
|
|
|
|
// Merge headers: provider headers are base, model headers override
|
|
const headers =
|
|
providerConfig.headers || modelDef.headers ? { ...providerConfig.headers, ...modelDef.headers } : undefined;
|
|
|
|
models.push({
|
|
id: modelDef.id,
|
|
name: modelDef.name,
|
|
api: api as Api,
|
|
provider: providerName,
|
|
baseUrl: providerConfig.baseUrl,
|
|
reasoning: modelDef.reasoning,
|
|
input: modelDef.input as ("text" | "image")[],
|
|
cost: modelDef.cost,
|
|
contextWindow: modelDef.contextWindow,
|
|
maxTokens: modelDef.maxTokens,
|
|
headers,
|
|
});
|
|
}
|
|
}
|
|
|
|
return models;
|
|
}
|
|
|
|
/**
|
|
* Get all models (built-in + custom), freshly loaded
|
|
* Returns { models, error } - either models array or error message
|
|
*/
|
|
export function loadAndMergeModels(): { models: Model<Api>[]; error: string | null } {
|
|
const builtInModels: Model<Api>[] = [];
|
|
const providers = getProviders();
|
|
|
|
// Load all built-in models
|
|
for (const provider of providers) {
|
|
const providerModels = getModels(provider as KnownProvider);
|
|
builtInModels.push(...(providerModels as Model<Api>[]));
|
|
}
|
|
|
|
// Load custom models
|
|
const { models: customModels, error } = loadCustomModels();
|
|
|
|
if (error) {
|
|
return { models: [], error };
|
|
}
|
|
|
|
// Merge: custom models come after built-in
|
|
return { models: [...builtInModels, ...customModels], error: null };
|
|
}
|
|
|
|
/**
|
|
* Get API key for a model (checks custom providers first, then built-in)
|
|
* Now async to support OAuth token refresh
|
|
*/
|
|
export async function getApiKeyForModel(model: Model<Api>): Promise<string | undefined> {
|
|
// For custom providers, check their apiKey config
|
|
const customKeyConfig = customProviderApiKeys.get(model.provider);
|
|
if (customKeyConfig) {
|
|
return resolveApiKey(customKeyConfig);
|
|
}
|
|
|
|
// For Anthropic, check OAuth first
|
|
if (model.provider === "anthropic") {
|
|
// 1. Check OAuth storage (auto-refresh if needed)
|
|
const oauthToken = await getOAuthToken("anthropic");
|
|
if (oauthToken) {
|
|
return oauthToken;
|
|
}
|
|
|
|
// 2. Check ANTHROPIC_OAUTH_TOKEN env var (manual OAuth token)
|
|
const oauthEnv = process.env.ANTHROPIC_OAUTH_TOKEN;
|
|
if (oauthEnv) {
|
|
return oauthEnv;
|
|
}
|
|
|
|
// 3. Fall back to ANTHROPIC_API_KEY env var
|
|
}
|
|
|
|
// For built-in providers, use getApiKey from @mariozechner/pi-ai
|
|
return getApiKey(model.provider as KnownProvider);
|
|
}
|
|
|
|
/**
|
|
* Get only models that have valid API keys available
|
|
* Returns { models, error } - either models array or error message
|
|
*/
|
|
export async function getAvailableModels(): Promise<{ models: Model<Api>[]; error: string | null }> {
|
|
const { models: allModels, error } = loadAndMergeModels();
|
|
|
|
if (error) {
|
|
return { models: [], error };
|
|
}
|
|
|
|
const availableModels: Model<Api>[] = [];
|
|
for (const model of allModels) {
|
|
const apiKey = await getApiKeyForModel(model);
|
|
if (apiKey) {
|
|
availableModels.push(model);
|
|
}
|
|
}
|
|
|
|
return { models: availableModels, error: null };
|
|
}
|
|
|
|
/**
|
|
* Find a specific model by provider and ID
|
|
* Returns { model, error } - either model or error message
|
|
*/
|
|
export function findModel(provider: string, modelId: string): { model: Model<Api> | null; error: string | null } {
|
|
const { models: allModels, error } = loadAndMergeModels();
|
|
|
|
if (error) {
|
|
return { model: null, error };
|
|
}
|
|
|
|
const model = allModels.find((m) => m.provider === provider && m.id === modelId) || null;
|
|
return { model, error: null };
|
|
}
|