mirror of
https://github.com/getcompanion-ai/co-mono.git
synced 2026-04-16 21:03:42 +00:00
Allow models.json to override built-in providers (#406)
* Allow models.json to override built-in providers
When a provider is defined in models.json with the same name as a
built-in provider (e.g., 'anthropic', 'google'), the built-in models
for that provider are completely replaced by the custom definition.
This enables users to:
- Use custom base URLs (proxies, self-hosted endpoints)
- Define a subset of models they want available
- Customize model configurations for built-in providers
Example usage in ~/.pi/agent/models.json:
{
"providers": {
"anthropic": {
"baseUrl": "https://my-proxy.example.com/v1",
"apiKey": "ANTHROPIC_API_KEY",
"api": "anthropic-messages",
"models": [...]
}
}
}
* Refactor model-registry for readability
- Extract CustomModelsResult type and emptyCustomModelsResult helper
- Extract loadBuiltInModels method with clear skip logic
- Simplify loadModels with destructuring and ternary
- Reduce repetition in error handling paths
* Refactor model-registry tests for readability
- Extract providerConfig() helper to hide irrelevant model fields
- Extract writeModelsJson() helper for file writing
- Extract getModelsForProvider() helper for filtering
- Move modelsJsonPath to beforeEach
Reduces test file from 262 to 130 lines while maintaining same coverage.
This commit is contained in:
parent
6186e497c5
commit
243104fa18
2 changed files with 173 additions and 33 deletions
|
|
@ -74,6 +74,17 @@ const ModelsConfigSchema = Type.Object({
|
|||
|
||||
type ModelsConfig = Static<typeof ModelsConfigSchema>;
|
||||
|
||||
/** Result of loading custom models from models.json */
|
||||
interface CustomModelsResult {
|
||||
models: Model<Api>[];
|
||||
providers: Set<string>;
|
||||
error: string | undefined;
|
||||
}
|
||||
|
||||
function emptyCustomModelsResult(error?: string): CustomModelsResult {
|
||||
return { models: [], providers: new Set(), error };
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolve an API key config value to an actual key.
|
||||
* Checks environment variable first, then treats as literal.
|
||||
|
|
@ -126,25 +137,17 @@ export class ModelRegistry {
|
|||
}
|
||||
|
||||
private loadModels(): void {
|
||||
// Load built-in models
|
||||
const builtInModels: Model<Api>[] = [];
|
||||
for (const provider of getProviders()) {
|
||||
const providerModels = getModels(provider as KnownProvider);
|
||||
builtInModels.push(...(providerModels as Model<Api>[]));
|
||||
}
|
||||
|
||||
// Load custom models from models.json (if path provided)
|
||||
let customModels: Model<Api>[] = [];
|
||||
if (this.modelsJsonPath) {
|
||||
const result = this.loadCustomModels(this.modelsJsonPath);
|
||||
if (result.error) {
|
||||
this.loadError = result.error;
|
||||
// Keep built-in models even if custom models failed to load
|
||||
} else {
|
||||
customModels = result.models;
|
||||
}
|
||||
// Load custom models from models.json first (to know which providers to skip)
|
||||
const { models: customModels, providers: customProviders, error } = this.modelsJsonPath
|
||||
? this.loadCustomModels(this.modelsJsonPath)
|
||||
: emptyCustomModelsResult();
|
||||
|
||||
if (error) {
|
||||
this.loadError = error;
|
||||
// Keep built-in models even if custom models failed to load
|
||||
}
|
||||
|
||||
const builtInModels = this.loadBuiltInModels(customProviders);
|
||||
const combined = [...builtInModels, ...customModels];
|
||||
|
||||
// Update github-copilot base URL based on OAuth credentials
|
||||
|
|
@ -160,9 +163,16 @@ export class ModelRegistry {
|
|||
}
|
||||
}
|
||||
|
||||
private loadCustomModels(modelsJsonPath: string): { models: Model<Api>[]; error: string | undefined } {
|
||||
/** Load built-in models, skipping providers that are overridden in models.json */
|
||||
private loadBuiltInModels(skipProviders: Set<string>): Model<Api>[] {
|
||||
return getProviders()
|
||||
.filter((provider) => !skipProviders.has(provider))
|
||||
.flatMap((provider) => getModels(provider as KnownProvider) as Model<Api>[]);
|
||||
}
|
||||
|
||||
private loadCustomModels(modelsJsonPath: string): CustomModelsResult {
|
||||
if (!existsSync(modelsJsonPath)) {
|
||||
return { models: [], error: undefined };
|
||||
return emptyCustomModelsResult();
|
||||
}
|
||||
|
||||
try {
|
||||
|
|
@ -176,28 +186,22 @@ export class ModelRegistry {
|
|||
const errors =
|
||||
validate.errors?.map((e: any) => ` - ${e.instancePath || "root"}: ${e.message}`).join("\n") ||
|
||||
"Unknown schema error";
|
||||
return {
|
||||
models: [],
|
||||
error: `Invalid models.json schema:\n${errors}\n\nFile: ${modelsJsonPath}`,
|
||||
};
|
||||
return emptyCustomModelsResult(`Invalid models.json schema:\n${errors}\n\nFile: ${modelsJsonPath}`);
|
||||
}
|
||||
|
||||
// Additional validation
|
||||
this.validateConfig(config);
|
||||
|
||||
// Parse models
|
||||
return { models: this.parseModels(config), error: undefined };
|
||||
// Parse models and collect provider names
|
||||
const providers = new Set(Object.keys(config.providers));
|
||||
return { models: this.parseModels(config), providers, error: undefined };
|
||||
} catch (error) {
|
||||
if (error instanceof SyntaxError) {
|
||||
return {
|
||||
models: [],
|
||||
error: `Failed to parse models.json: ${error.message}\n\nFile: ${modelsJsonPath}`,
|
||||
};
|
||||
return emptyCustomModelsResult(`Failed to parse models.json: ${error.message}\n\nFile: ${modelsJsonPath}`);
|
||||
}
|
||||
return {
|
||||
models: [],
|
||||
error: `Failed to load models.json: ${error instanceof Error ? error.message : error}\n\nFile: ${modelsJsonPath}`,
|
||||
};
|
||||
return emptyCustomModelsResult(
|
||||
`Failed to load models.json: ${error instanceof Error ? error.message : error}\n\nFile: ${modelsJsonPath}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
136
packages/coding-agent/test/model-registry.test.ts
Normal file
136
packages/coding-agent/test/model-registry.test.ts
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
import { existsSync, mkdirSync, rmSync, writeFileSync } from "node:fs";
|
||||
import { tmpdir } from "node:os";
|
||||
import { join } from "node:path";
|
||||
import { afterEach, beforeEach, describe, expect, test } from "vitest";
|
||||
import { AuthStorage } from "../src/core/auth-storage.js";
|
||||
import { ModelRegistry } from "../src/core/model-registry.js";
|
||||
|
||||
describe("ModelRegistry", () => {
|
||||
let tempDir: string;
|
||||
let modelsJsonPath: string;
|
||||
let authStorage: AuthStorage;
|
||||
|
||||
beforeEach(() => {
|
||||
tempDir = join(tmpdir(), `pi-test-model-registry-${Date.now()}-${Math.random().toString(36).slice(2)}`);
|
||||
mkdirSync(tempDir, { recursive: true });
|
||||
modelsJsonPath = join(tempDir, "models.json");
|
||||
authStorage = new AuthStorage(join(tempDir, "auth.json"));
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
if (tempDir && existsSync(tempDir)) {
|
||||
rmSync(tempDir, { recursive: true });
|
||||
}
|
||||
});
|
||||
|
||||
/** Create minimal provider config */
|
||||
function providerConfig(
|
||||
baseUrl: string,
|
||||
models: Array<{ id: string; name?: string }>,
|
||||
api: string = "anthropic-messages",
|
||||
) {
|
||||
return {
|
||||
baseUrl,
|
||||
apiKey: "TEST_KEY",
|
||||
api,
|
||||
models: models.map((m) => ({
|
||||
id: m.id,
|
||||
name: m.name ?? m.id,
|
||||
reasoning: false,
|
||||
input: ["text"],
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
|
||||
contextWindow: 100000,
|
||||
maxTokens: 8000,
|
||||
})),
|
||||
};
|
||||
}
|
||||
|
||||
function writeModelsJson(providers: Record<string, ReturnType<typeof providerConfig>>) {
|
||||
writeFileSync(modelsJsonPath, JSON.stringify({ providers }));
|
||||
}
|
||||
|
||||
function getModelsForProvider(registry: ModelRegistry, provider: string) {
|
||||
return registry.getAll().filter((m) => m.provider === provider);
|
||||
}
|
||||
|
||||
describe("provider override", () => {
|
||||
test("custom provider with same name as built-in replaces built-in models", () => {
|
||||
writeModelsJson({
|
||||
anthropic: providerConfig("https://my-proxy.example.com/v1", [{ id: "claude-custom" }]),
|
||||
});
|
||||
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const anthropicModels = getModelsForProvider(registry, "anthropic");
|
||||
|
||||
expect(anthropicModels).toHaveLength(1);
|
||||
expect(anthropicModels[0].id).toBe("claude-custom");
|
||||
expect(anthropicModels[0].baseUrl).toBe("https://my-proxy.example.com/v1");
|
||||
});
|
||||
|
||||
test("custom provider with same name as built-in does not affect other built-in providers", () => {
|
||||
writeModelsJson({
|
||||
anthropic: providerConfig("https://my-proxy.example.com/v1", [{ id: "claude-custom" }]),
|
||||
});
|
||||
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
|
||||
expect(getModelsForProvider(registry, "google").length).toBeGreaterThan(0);
|
||||
expect(getModelsForProvider(registry, "openai").length).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
test("multiple built-in providers can be overridden", () => {
|
||||
writeModelsJson({
|
||||
anthropic: providerConfig("https://anthropic-proxy.example.com/v1", [{ id: "claude-proxy" }]),
|
||||
google: providerConfig("https://google-proxy.example.com/v1", [{ id: "gemini-proxy" }], "google-generative-ai"),
|
||||
});
|
||||
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const anthropicModels = getModelsForProvider(registry, "anthropic");
|
||||
const googleModels = getModelsForProvider(registry, "google");
|
||||
|
||||
expect(anthropicModels).toHaveLength(1);
|
||||
expect(anthropicModels[0].id).toBe("claude-proxy");
|
||||
expect(anthropicModels[0].baseUrl).toBe("https://anthropic-proxy.example.com/v1");
|
||||
|
||||
expect(googleModels).toHaveLength(1);
|
||||
expect(googleModels[0].id).toBe("gemini-proxy");
|
||||
expect(googleModels[0].baseUrl).toBe("https://google-proxy.example.com/v1");
|
||||
});
|
||||
|
||||
test("refresh() reloads overrides from disk", () => {
|
||||
writeModelsJson({
|
||||
anthropic: providerConfig("https://first-proxy.example.com/v1", [{ id: "claude-first" }]),
|
||||
});
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
|
||||
expect(getModelsForProvider(registry, "anthropic")[0].id).toBe("claude-first");
|
||||
|
||||
// Update and refresh
|
||||
writeModelsJson({
|
||||
anthropic: providerConfig("https://second-proxy.example.com/v1", [{ id: "claude-second" }]),
|
||||
});
|
||||
registry.refresh();
|
||||
|
||||
const anthropicModels = getModelsForProvider(registry, "anthropic");
|
||||
expect(anthropicModels[0].id).toBe("claude-second");
|
||||
expect(anthropicModels[0].baseUrl).toBe("https://second-proxy.example.com/v1");
|
||||
});
|
||||
|
||||
test("removing override from models.json restores built-in provider", () => {
|
||||
writeModelsJson({
|
||||
anthropic: providerConfig("https://proxy.example.com/v1", [{ id: "claude-custom" }]),
|
||||
});
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
|
||||
expect(getModelsForProvider(registry, "anthropic")).toHaveLength(1);
|
||||
|
||||
// Remove override and refresh
|
||||
writeModelsJson({});
|
||||
registry.refresh();
|
||||
|
||||
const anthropicModels = getModelsForProvider(registry, "anthropic");
|
||||
expect(anthropicModels.length).toBeGreaterThan(1);
|
||||
expect(anthropicModels.some((m) => m.id.includes("claude"))).toBe(true);
|
||||
});
|
||||
});
|
||||
});
|
||||
Loading…
Add table
Add a link
Reference in a new issue