import { existsSync, mkdirSync, rmSync, writeFileSync } from "node:fs"; import { tmpdir } from "node:os"; import { join } from "node:path"; import { afterEach, beforeEach, describe, expect, test } from "vitest"; import { AuthStorage } from "../src/core/auth-storage.js"; import { ModelRegistry } from "../src/core/model-registry.js"; describe("ModelRegistry", () => { let tempDir: string; let modelsJsonPath: string; let authStorage: AuthStorage; beforeEach(() => { tempDir = join(tmpdir(), `pi-test-model-registry-${Date.now()}-${Math.random().toString(36).slice(2)}`); mkdirSync(tempDir, { recursive: true }); modelsJsonPath = join(tempDir, "models.json"); authStorage = new AuthStorage(join(tempDir, "auth.json")); }); afterEach(() => { if (tempDir && existsSync(tempDir)) { rmSync(tempDir, { recursive: true }); } }); /** Create minimal provider config */ function providerConfig( baseUrl: string, models: Array<{ id: string; name?: string }>, api: string = "anthropic-messages", ) { return { baseUrl, apiKey: "TEST_KEY", api, models: models.map((m) => ({ id: m.id, name: m.name ?? m.id, reasoning: false, input: ["text"], cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 }, contextWindow: 100000, maxTokens: 8000, })), }; } function writeModelsJson(providers: Record>) { writeFileSync(modelsJsonPath, JSON.stringify({ providers })); } function getModelsForProvider(registry: ModelRegistry, provider: string) { return registry.getAll().filter((m) => m.provider === provider); } /** Create a baseUrl-only override (no custom models) */ function overrideConfig(baseUrl: string, headers?: Record) { return { baseUrl, ...(headers && { headers }) }; } /** Write raw providers config (for mixed override/replacement scenarios) */ function writeRawModelsJson(providers: Record) { writeFileSync(modelsJsonPath, JSON.stringify({ providers })); } describe("baseUrl override (no custom models)", () => { test("overriding baseUrl keeps all built-in models", () => { writeRawModelsJson({ anthropic: overrideConfig("https://my-proxy.example.com/v1"), }); const registry = new ModelRegistry(authStorage, modelsJsonPath); const anthropicModels = getModelsForProvider(registry, "anthropic"); // Should have multiple built-in models, not just one expect(anthropicModels.length).toBeGreaterThan(1); expect(anthropicModels.some((m) => m.id.includes("claude"))).toBe(true); }); test("overriding baseUrl changes URL on all built-in models", () => { writeRawModelsJson({ anthropic: overrideConfig("https://my-proxy.example.com/v1"), }); const registry = new ModelRegistry(authStorage, modelsJsonPath); const anthropicModels = getModelsForProvider(registry, "anthropic"); // All models should have the new baseUrl for (const model of anthropicModels) { expect(model.baseUrl).toBe("https://my-proxy.example.com/v1"); } }); test("overriding headers merges with model headers", () => { writeRawModelsJson({ anthropic: overrideConfig("https://my-proxy.example.com/v1", { "X-Custom-Header": "custom-value", }), }); const registry = new ModelRegistry(authStorage, modelsJsonPath); const anthropicModels = getModelsForProvider(registry, "anthropic"); for (const model of anthropicModels) { expect(model.headers?.["X-Custom-Header"]).toBe("custom-value"); } }); test("baseUrl-only override does not affect other providers", () => { writeRawModelsJson({ anthropic: overrideConfig("https://my-proxy.example.com/v1"), }); const registry = new ModelRegistry(authStorage, modelsJsonPath); const googleModels = getModelsForProvider(registry, "google"); // Google models should still have their original baseUrl expect(googleModels.length).toBeGreaterThan(0); expect(googleModels[0].baseUrl).not.toBe("https://my-proxy.example.com/v1"); }); test("can mix baseUrl override and full replacement", () => { writeRawModelsJson({ // baseUrl-only for anthropic anthropic: overrideConfig("https://anthropic-proxy.example.com/v1"), // Full replacement for google google: providerConfig( "https://google-proxy.example.com/v1", [{ id: "gemini-custom" }], "google-generative-ai", ), }); const registry = new ModelRegistry(authStorage, modelsJsonPath); // Anthropic: multiple built-in models with new baseUrl const anthropicModels = getModelsForProvider(registry, "anthropic"); expect(anthropicModels.length).toBeGreaterThan(1); expect(anthropicModels[0].baseUrl).toBe("https://anthropic-proxy.example.com/v1"); // Google: single custom model const googleModels = getModelsForProvider(registry, "google"); expect(googleModels).toHaveLength(1); expect(googleModels[0].id).toBe("gemini-custom"); }); test("refresh() picks up baseUrl override changes", () => { writeRawModelsJson({ anthropic: overrideConfig("https://first-proxy.example.com/v1"), }); const registry = new ModelRegistry(authStorage, modelsJsonPath); expect(getModelsForProvider(registry, "anthropic")[0].baseUrl).toBe("https://first-proxy.example.com/v1"); // Update and refresh writeRawModelsJson({ anthropic: overrideConfig("https://second-proxy.example.com/v1"), }); registry.refresh(); expect(getModelsForProvider(registry, "anthropic")[0].baseUrl).toBe("https://second-proxy.example.com/v1"); }); }); describe("provider replacement (with custom models)", () => { test("custom provider with same name as built-in replaces built-in models", () => { writeModelsJson({ anthropic: providerConfig("https://my-proxy.example.com/v1", [{ id: "claude-custom" }]), }); const registry = new ModelRegistry(authStorage, modelsJsonPath); const anthropicModels = getModelsForProvider(registry, "anthropic"); expect(anthropicModels).toHaveLength(1); expect(anthropicModels[0].id).toBe("claude-custom"); expect(anthropicModels[0].baseUrl).toBe("https://my-proxy.example.com/v1"); }); test("custom provider with same name as built-in does not affect other built-in providers", () => { writeModelsJson({ anthropic: providerConfig("https://my-proxy.example.com/v1", [{ id: "claude-custom" }]), }); const registry = new ModelRegistry(authStorage, modelsJsonPath); expect(getModelsForProvider(registry, "google").length).toBeGreaterThan(0); expect(getModelsForProvider(registry, "openai").length).toBeGreaterThan(0); }); test("multiple built-in providers can be overridden", () => { writeModelsJson({ anthropic: providerConfig("https://anthropic-proxy.example.com/v1", [{ id: "claude-proxy" }]), google: providerConfig( "https://google-proxy.example.com/v1", [{ id: "gemini-proxy" }], "google-generative-ai", ), }); const registry = new ModelRegistry(authStorage, modelsJsonPath); const anthropicModels = getModelsForProvider(registry, "anthropic"); const googleModels = getModelsForProvider(registry, "google"); expect(anthropicModels).toHaveLength(1); expect(anthropicModels[0].id).toBe("claude-proxy"); expect(anthropicModels[0].baseUrl).toBe("https://anthropic-proxy.example.com/v1"); expect(googleModels).toHaveLength(1); expect(googleModels[0].id).toBe("gemini-proxy"); expect(googleModels[0].baseUrl).toBe("https://google-proxy.example.com/v1"); }); test("refresh() reloads overrides from disk", () => { writeModelsJson({ anthropic: providerConfig("https://first-proxy.example.com/v1", [{ id: "claude-first" }]), }); const registry = new ModelRegistry(authStorage, modelsJsonPath); expect(getModelsForProvider(registry, "anthropic")[0].id).toBe("claude-first"); // Update and refresh writeModelsJson({ anthropic: providerConfig("https://second-proxy.example.com/v1", [{ id: "claude-second" }]), }); registry.refresh(); const anthropicModels = getModelsForProvider(registry, "anthropic"); expect(anthropicModels[0].id).toBe("claude-second"); expect(anthropicModels[0].baseUrl).toBe("https://second-proxy.example.com/v1"); }); test("removing override from models.json restores built-in provider", () => { writeModelsJson({ anthropic: providerConfig("https://proxy.example.com/v1", [{ id: "claude-custom" }]), }); const registry = new ModelRegistry(authStorage, modelsJsonPath); expect(getModelsForProvider(registry, "anthropic")).toHaveLength(1); // Remove override and refresh writeModelsJson({}); registry.refresh(); const anthropicModels = getModelsForProvider(registry, "anthropic"); expect(anthropicModels.length).toBeGreaterThan(1); expect(anthropicModels.some((m) => m.id.includes("claude"))).toBe(true); }); }); });