diff --git a/packages/coding-agent/src/core/model-registry.ts b/packages/coding-agent/src/core/model-registry.ts index 94a839f8..08a35eb6 100644 --- a/packages/coding-agent/src/core/model-registry.ts +++ b/packages/coding-agent/src/core/model-registry.ts @@ -74,6 +74,17 @@ const ModelsConfigSchema = Type.Object({ type ModelsConfig = Static; +/** Result of loading custom models from models.json */ +interface CustomModelsResult { + models: Model[]; + providers: Set; + error: string | undefined; +} + +function emptyCustomModelsResult(error?: string): CustomModelsResult { + return { models: [], providers: new Set(), error }; +} + /** * Resolve an API key config value to an actual key. * Checks environment variable first, then treats as literal. @@ -126,25 +137,17 @@ export class ModelRegistry { } private loadModels(): void { - // Load built-in models - const builtInModels: Model[] = []; - for (const provider of getProviders()) { - const providerModels = getModels(provider as KnownProvider); - builtInModels.push(...(providerModels as Model[])); - } - - // Load custom models from models.json (if path provided) - let customModels: Model[] = []; - if (this.modelsJsonPath) { - const result = this.loadCustomModels(this.modelsJsonPath); - if (result.error) { - this.loadError = result.error; - // Keep built-in models even if custom models failed to load - } else { - customModels = result.models; - } + // Load custom models from models.json first (to know which providers to skip) + const { models: customModels, providers: customProviders, error } = this.modelsJsonPath + ? this.loadCustomModels(this.modelsJsonPath) + : emptyCustomModelsResult(); + + if (error) { + this.loadError = error; + // Keep built-in models even if custom models failed to load } + const builtInModels = this.loadBuiltInModels(customProviders); const combined = [...builtInModels, ...customModels]; // Update github-copilot base URL based on OAuth credentials @@ -160,9 +163,16 @@ export class ModelRegistry { } } - private loadCustomModels(modelsJsonPath: string): { models: Model[]; error: string | undefined } { + /** Load built-in models, skipping providers that are overridden in models.json */ + private loadBuiltInModels(skipProviders: Set): Model[] { + return getProviders() + .filter((provider) => !skipProviders.has(provider)) + .flatMap((provider) => getModels(provider as KnownProvider) as Model[]); + } + + private loadCustomModels(modelsJsonPath: string): CustomModelsResult { if (!existsSync(modelsJsonPath)) { - return { models: [], error: undefined }; + return emptyCustomModelsResult(); } try { @@ -176,28 +186,22 @@ export class ModelRegistry { const errors = validate.errors?.map((e: any) => ` - ${e.instancePath || "root"}: ${e.message}`).join("\n") || "Unknown schema error"; - return { - models: [], - error: `Invalid models.json schema:\n${errors}\n\nFile: ${modelsJsonPath}`, - }; + return emptyCustomModelsResult(`Invalid models.json schema:\n${errors}\n\nFile: ${modelsJsonPath}`); } // Additional validation this.validateConfig(config); - // Parse models - return { models: this.parseModels(config), error: undefined }; + // Parse models and collect provider names + const providers = new Set(Object.keys(config.providers)); + return { models: this.parseModels(config), providers, error: undefined }; } catch (error) { if (error instanceof SyntaxError) { - return { - models: [], - error: `Failed to parse models.json: ${error.message}\n\nFile: ${modelsJsonPath}`, - }; + return emptyCustomModelsResult(`Failed to parse models.json: ${error.message}\n\nFile: ${modelsJsonPath}`); } - return { - models: [], - error: `Failed to load models.json: ${error instanceof Error ? error.message : error}\n\nFile: ${modelsJsonPath}`, - }; + return emptyCustomModelsResult( + `Failed to load models.json: ${error instanceof Error ? error.message : error}\n\nFile: ${modelsJsonPath}`, + ); } } diff --git a/packages/coding-agent/test/model-registry.test.ts b/packages/coding-agent/test/model-registry.test.ts new file mode 100644 index 00000000..1c318ca7 --- /dev/null +++ b/packages/coding-agent/test/model-registry.test.ts @@ -0,0 +1,136 @@ +import { existsSync, mkdirSync, rmSync, writeFileSync } from "node:fs"; +import { tmpdir } from "node:os"; +import { join } from "node:path"; +import { afterEach, beforeEach, describe, expect, test } from "vitest"; +import { AuthStorage } from "../src/core/auth-storage.js"; +import { ModelRegistry } from "../src/core/model-registry.js"; + +describe("ModelRegistry", () => { + let tempDir: string; + let modelsJsonPath: string; + let authStorage: AuthStorage; + + beforeEach(() => { + tempDir = join(tmpdir(), `pi-test-model-registry-${Date.now()}-${Math.random().toString(36).slice(2)}`); + mkdirSync(tempDir, { recursive: true }); + modelsJsonPath = join(tempDir, "models.json"); + authStorage = new AuthStorage(join(tempDir, "auth.json")); + }); + + afterEach(() => { + if (tempDir && existsSync(tempDir)) { + rmSync(tempDir, { recursive: true }); + } + }); + + /** Create minimal provider config */ + function providerConfig( + baseUrl: string, + models: Array<{ id: string; name?: string }>, + api: string = "anthropic-messages", + ) { + return { + baseUrl, + apiKey: "TEST_KEY", + api, + models: models.map((m) => ({ + id: m.id, + name: m.name ?? m.id, + reasoning: false, + input: ["text"], + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 }, + contextWindow: 100000, + maxTokens: 8000, + })), + }; + } + + function writeModelsJson(providers: Record>) { + writeFileSync(modelsJsonPath, JSON.stringify({ providers })); + } + + function getModelsForProvider(registry: ModelRegistry, provider: string) { + return registry.getAll().filter((m) => m.provider === provider); + } + + describe("provider override", () => { + test("custom provider with same name as built-in replaces built-in models", () => { + writeModelsJson({ + anthropic: providerConfig("https://my-proxy.example.com/v1", [{ id: "claude-custom" }]), + }); + + const registry = new ModelRegistry(authStorage, modelsJsonPath); + const anthropicModels = getModelsForProvider(registry, "anthropic"); + + expect(anthropicModels).toHaveLength(1); + expect(anthropicModels[0].id).toBe("claude-custom"); + expect(anthropicModels[0].baseUrl).toBe("https://my-proxy.example.com/v1"); + }); + + test("custom provider with same name as built-in does not affect other built-in providers", () => { + writeModelsJson({ + anthropic: providerConfig("https://my-proxy.example.com/v1", [{ id: "claude-custom" }]), + }); + + const registry = new ModelRegistry(authStorage, modelsJsonPath); + + expect(getModelsForProvider(registry, "google").length).toBeGreaterThan(0); + expect(getModelsForProvider(registry, "openai").length).toBeGreaterThan(0); + }); + + test("multiple built-in providers can be overridden", () => { + writeModelsJson({ + anthropic: providerConfig("https://anthropic-proxy.example.com/v1", [{ id: "claude-proxy" }]), + google: providerConfig("https://google-proxy.example.com/v1", [{ id: "gemini-proxy" }], "google-generative-ai"), + }); + + const registry = new ModelRegistry(authStorage, modelsJsonPath); + const anthropicModels = getModelsForProvider(registry, "anthropic"); + const googleModels = getModelsForProvider(registry, "google"); + + expect(anthropicModels).toHaveLength(1); + expect(anthropicModels[0].id).toBe("claude-proxy"); + expect(anthropicModels[0].baseUrl).toBe("https://anthropic-proxy.example.com/v1"); + + expect(googleModels).toHaveLength(1); + expect(googleModels[0].id).toBe("gemini-proxy"); + expect(googleModels[0].baseUrl).toBe("https://google-proxy.example.com/v1"); + }); + + test("refresh() reloads overrides from disk", () => { + writeModelsJson({ + anthropic: providerConfig("https://first-proxy.example.com/v1", [{ id: "claude-first" }]), + }); + const registry = new ModelRegistry(authStorage, modelsJsonPath); + + expect(getModelsForProvider(registry, "anthropic")[0].id).toBe("claude-first"); + + // Update and refresh + writeModelsJson({ + anthropic: providerConfig("https://second-proxy.example.com/v1", [{ id: "claude-second" }]), + }); + registry.refresh(); + + const anthropicModels = getModelsForProvider(registry, "anthropic"); + expect(anthropicModels[0].id).toBe("claude-second"); + expect(anthropicModels[0].baseUrl).toBe("https://second-proxy.example.com/v1"); + }); + + test("removing override from models.json restores built-in provider", () => { + writeModelsJson({ + anthropic: providerConfig("https://proxy.example.com/v1", [{ id: "claude-custom" }]), + }); + const registry = new ModelRegistry(authStorage, modelsJsonPath); + + expect(getModelsForProvider(registry, "anthropic")).toHaveLength(1); + + // Remove override and refresh + writeModelsJson({}); + registry.refresh(); + + const anthropicModels = getModelsForProvider(registry, "anthropic"); + expect(anthropicModels.length).toBeGreaterThan(1); + expect(anthropicModels.some((m) => m.id.includes("claude"))).toBe(true); + }); + }); +});