import { existsSync, mkdirSync, readFileSync, rmSync, writeFileSync } from "node:fs";
import { tmpdir } from "node:os";
import { join } from "node:path";
import { afterEach, beforeEach, describe, expect, test } from "vitest";
import { AuthStorage } from "../src/core/auth-storage.js";
import { clearApiKeyCache, ModelRegistry } from "../src/core/model-registry.js";

/**
 * Tests for ModelRegistry: how a models.json file overrides or replaces
 * built-in providers, and how a provider's `apiKey` setting is resolved
 * (shell command, environment variable, or literal) and cached.
 *
 * Every test gets its own temp directory holding models.json and auth.json.
 */
describe("ModelRegistry", () => {
	let tempDir: string;
	let modelsJsonPath: string;
	let authStorage: AuthStorage;

	beforeEach(() => {
		// Timestamp + random suffix keeps directories distinct between tests.
		tempDir = join(tmpdir(), `pi-test-model-registry-${Date.now()}-${Math.random().toString(36).slice(2)}`);
		mkdirSync(tempDir, { recursive: true });
		modelsJsonPath = join(tempDir, "models.json");
		authStorage = new AuthStorage(join(tempDir, "auth.json"));
	});

	afterEach(() => {
		try {
			if (tempDir && existsSync(tempDir)) {
				rmSync(tempDir, { recursive: true });
			}
		} finally {
			// Always reset the API key cache, even when directory removal throws;
			// the caching tests below show it outlives a single registry instance.
			clearApiKeyCache();
		}
	});

	/**
	 * Create minimal provider config.
	 *
	 * @param baseUrl - Base URL written into the provider entry.
	 * @param models - Models to declare; `name` falls back to `id` when omitted.
	 * @param api - API identifier for the provider entry.
	 * @returns A provider entry with `apiKey` fixed to "TEST_KEY" and fixed
	 *   placeholder values for cost, context window, and max tokens.
	 */
	function providerConfig(
		baseUrl: string,
		models: Array<{ id: string; name?: string }>,
		api: string = "anthropic-messages",
	) {
		return {
			baseUrl,
			apiKey: "TEST_KEY",
			api,
			models: models.map((m) => ({
				id: m.id,
				name: m.name ?? m.id,
				reasoning: false,
				input: ["text"],
				cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
				contextWindow: 100000,
				maxTokens: 8000,
			})),
		};
	}

	/** Write a models.json whose providers are all full `providerConfig` entries. */
	function writeModelsJson(providers: Record<string, ReturnType<typeof providerConfig>>) {
		writeFileSync(modelsJsonPath, JSON.stringify({ providers }));
	}

	/** Return every model in the registry that belongs to `provider`. */
	function getModelsForProvider(registry: ModelRegistry, provider: string) {
		return registry.getAll().filter((m) => m.provider === provider);
	}

	/**
	 * Create a baseUrl-only override (no custom models).
	 * The `headers` key is included only when headers are supplied.
	 */
	function overrideConfig(baseUrl: string, headers?: Record<string, string>) {
		return { baseUrl, ...(headers && { headers }) };
	}

	/** Write raw providers config (for mixed override/replacement scenarios). */
	function writeRawModelsJson(providers: Record<string, unknown>) {
		writeFileSync(modelsJsonPath, JSON.stringify({ providers }));
	}

	describe("baseUrl override (no custom models)", () => {
		test("overriding baseUrl keeps all built-in models", () => {
			writeRawModelsJson({
				anthropic: overrideConfig("https://my-proxy.example.com/v1"),
			});

			const registry = new ModelRegistry(authStorage, modelsJsonPath);
			const anthropicModels = getModelsForProvider(registry, "anthropic");

			// Should have multiple built-in models, not just one
			expect(anthropicModels.length).toBeGreaterThan(1);
			expect(anthropicModels.some((m) => m.id.includes("claude"))).toBe(true);
		});

		test("overriding baseUrl changes URL on all built-in models", () => {
			writeRawModelsJson({
				anthropic: overrideConfig("https://my-proxy.example.com/v1"),
			});

			const registry = new ModelRegistry(authStorage, modelsJsonPath);
			const anthropicModels = getModelsForProvider(registry, "anthropic");

			// All models should have the new baseUrl
			for (const model of anthropicModels) {
				expect(model.baseUrl).toBe("https://my-proxy.example.com/v1");
			}
		});

		test("overriding headers merges with model headers", () => {
			writeRawModelsJson({
				anthropic: overrideConfig("https://my-proxy.example.com/v1", {
					"X-Custom-Header": "custom-value",
				}),
			});

			const registry = new ModelRegistry(authStorage, modelsJsonPath);
			const anthropicModels = getModelsForProvider(registry, "anthropic");

			for (const model of anthropicModels) {
				expect(model.headers?.["X-Custom-Header"]).toBe("custom-value");
			}
		});

		test("baseUrl-only override does not affect other providers", () => {
			writeRawModelsJson({
				anthropic: overrideConfig("https://my-proxy.example.com/v1"),
			});

			const registry = new ModelRegistry(authStorage, modelsJsonPath);
			const googleModels = getModelsForProvider(registry, "google");

			// Google models should still have their original baseUrl
			expect(googleModels.length).toBeGreaterThan(0);
			expect(googleModels[0].baseUrl).not.toBe("https://my-proxy.example.com/v1");
		});

		test("can mix baseUrl override and full replacement", () => {
			writeRawModelsJson({
				// baseUrl-only for anthropic
				anthropic: overrideConfig("https://anthropic-proxy.example.com/v1"),
				// Full replacement for google
				google: providerConfig(
					"https://google-proxy.example.com/v1",
					[{ id: "gemini-custom" }],
					"google-generative-ai",
				),
			});

			const registry = new ModelRegistry(authStorage, modelsJsonPath);

			// Anthropic: multiple built-in models with new baseUrl
			const anthropicModels = getModelsForProvider(registry, "anthropic");
			expect(anthropicModels.length).toBeGreaterThan(1);
			expect(anthropicModels[0].baseUrl).toBe("https://anthropic-proxy.example.com/v1");

			// Google: single custom model
			const googleModels = getModelsForProvider(registry, "google");
			expect(googleModels).toHaveLength(1);
			expect(googleModels[0].id).toBe("gemini-custom");
		});

		test("refresh() picks up baseUrl override changes", () => {
			writeRawModelsJson({
				anthropic: overrideConfig("https://first-proxy.example.com/v1"),
			});
			const registry = new ModelRegistry(authStorage, modelsJsonPath);

			expect(getModelsForProvider(registry, "anthropic")[0].baseUrl).toBe("https://first-proxy.example.com/v1");

			// Update and refresh
			writeRawModelsJson({
				anthropic: overrideConfig("https://second-proxy.example.com/v1"),
			});
			registry.refresh();

			expect(getModelsForProvider(registry, "anthropic")[0].baseUrl).toBe("https://second-proxy.example.com/v1");
		});
	});

	describe("provider replacement (with custom models)", () => {
		test("custom provider with same name as built-in replaces built-in models", () => {
			writeModelsJson({
				anthropic: providerConfig("https://my-proxy.example.com/v1", [{ id: "claude-custom" }]),
			});

			const registry = new ModelRegistry(authStorage, modelsJsonPath);
			const anthropicModels = getModelsForProvider(registry, "anthropic");

			expect(anthropicModels).toHaveLength(1);
			expect(anthropicModels[0].id).toBe("claude-custom");
			expect(anthropicModels[0].baseUrl).toBe("https://my-proxy.example.com/v1");
		});

		test("custom provider with same name as built-in does not affect other built-in providers", () => {
			writeModelsJson({
				anthropic: providerConfig("https://my-proxy.example.com/v1", [{ id: "claude-custom" }]),
			});

			const registry = new ModelRegistry(authStorage, modelsJsonPath);

			expect(getModelsForProvider(registry, "google").length).toBeGreaterThan(0);
			expect(getModelsForProvider(registry, "openai").length).toBeGreaterThan(0);
		});

		test("multiple built-in providers can be overridden", () => {
			writeModelsJson({
				anthropic: providerConfig("https://anthropic-proxy.example.com/v1", [{ id: "claude-proxy" }]),
				google: providerConfig(
					"https://google-proxy.example.com/v1",
					[{ id: "gemini-proxy" }],
					"google-generative-ai",
				),
			});

			const registry = new ModelRegistry(authStorage, modelsJsonPath);
			const anthropicModels = getModelsForProvider(registry, "anthropic");
			const googleModels = getModelsForProvider(registry, "google");

			expect(anthropicModels).toHaveLength(1);
			expect(anthropicModels[0].id).toBe("claude-proxy");
			expect(anthropicModels[0].baseUrl).toBe("https://anthropic-proxy.example.com/v1");

			expect(googleModels).toHaveLength(1);
			expect(googleModels[0].id).toBe("gemini-proxy");
			expect(googleModels[0].baseUrl).toBe("https://google-proxy.example.com/v1");
		});

		test("refresh() reloads overrides from disk", () => {
			writeModelsJson({
				anthropic: providerConfig("https://first-proxy.example.com/v1", [{ id: "claude-first" }]),
			});
			const registry = new ModelRegistry(authStorage, modelsJsonPath);

			expect(getModelsForProvider(registry, "anthropic")[0].id).toBe("claude-first");

			// Update and refresh
			writeModelsJson({
				anthropic: providerConfig("https://second-proxy.example.com/v1", [{ id: "claude-second" }]),
			});
			registry.refresh();

			const anthropicModels = getModelsForProvider(registry, "anthropic");
			expect(anthropicModels[0].id).toBe("claude-second");
			expect(anthropicModels[0].baseUrl).toBe("https://second-proxy.example.com/v1");
		});

		test("removing override from models.json restores built-in provider", () => {
			writeModelsJson({
				anthropic: providerConfig("https://proxy.example.com/v1", [{ id: "claude-custom" }]),
			});
			const registry = new ModelRegistry(authStorage, modelsJsonPath);

			expect(getModelsForProvider(registry, "anthropic")).toHaveLength(1);

			// Remove override and refresh
			writeModelsJson({});
			registry.refresh();

			const anthropicModels = getModelsForProvider(registry, "anthropic");
			expect(anthropicModels.length).toBeGreaterThan(1);
			expect(anthropicModels.some((m) => m.id.includes("claude"))).toBe(true);
		});
	});

	describe("API key resolution", () => {
		/**
		 * Create provider config with custom apiKey.
		 * The tests below pass a "!"-prefixed shell command, an environment
		 * variable name, or a literal value as `apiKey`.
		 */
		function providerWithApiKey(apiKey: string) {
			return {
				baseUrl: "https://example.com/v1",
				apiKey,
				api: "anthropic-messages",
				models: [
					{
						id: "test-model",
						name: "Test Model",
						reasoning: false,
						input: ["text"],
						cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
						contextWindow: 100000,
						maxTokens: 8000,
					},
				],
			};
		}

		test("apiKey with ! prefix executes command and uses stdout", async () => {
			writeRawModelsJson({
				"custom-provider": providerWithApiKey("!echo test-api-key-from-command"),
			});

			const registry = new ModelRegistry(authStorage, modelsJsonPath);
			const apiKey = await registry.getApiKeyForProvider("custom-provider");

			expect(apiKey).toBe("test-api-key-from-command");
		});

		test("apiKey with ! prefix trims whitespace from command output", async () => {
			writeRawModelsJson({
				"custom-provider": providerWithApiKey("!echo ' spaced-key '"),
			});

			const registry = new ModelRegistry(authStorage, modelsJsonPath);
			const apiKey = await registry.getApiKeyForProvider("custom-provider");

			expect(apiKey).toBe("spaced-key");
		});

		test("apiKey with ! prefix handles multiline output (uses trimmed result)", async () => {
			writeRawModelsJson({
				"custom-provider": providerWithApiKey("!printf 'line1\\nline2'"),
			});

			const registry = new ModelRegistry(authStorage, modelsJsonPath);
			const apiKey = await registry.getApiKeyForProvider("custom-provider");

			expect(apiKey).toBe("line1\nline2");
		});

		test("apiKey with ! prefix returns undefined on command failure", async () => {
			writeRawModelsJson({
				"custom-provider": providerWithApiKey("!exit 1"),
			});

			const registry = new ModelRegistry(authStorage, modelsJsonPath);
			const apiKey = await registry.getApiKeyForProvider("custom-provider");

			expect(apiKey).toBeUndefined();
		});

		test("apiKey with ! prefix returns undefined on nonexistent command", async () => {
			writeRawModelsJson({
				"custom-provider": providerWithApiKey("!nonexistent-command-12345"),
			});

			const registry = new ModelRegistry(authStorage, modelsJsonPath);
			const apiKey = await registry.getApiKeyForProvider("custom-provider");

			expect(apiKey).toBeUndefined();
		});

		test("apiKey with ! prefix returns undefined on empty output", async () => {
			writeRawModelsJson({
				"custom-provider": providerWithApiKey("!printf ''"),
			});

			const registry = new ModelRegistry(authStorage, modelsJsonPath);
			const apiKey = await registry.getApiKeyForProvider("custom-provider");

			expect(apiKey).toBeUndefined();
		});

		test("apiKey as environment variable name resolves to env value", async () => {
			const originalEnv = process.env.TEST_API_KEY_12345;
			process.env.TEST_API_KEY_12345 = "env-api-key-value";

			try {
				writeRawModelsJson({
					"custom-provider": providerWithApiKey("TEST_API_KEY_12345"),
				});

				const registry = new ModelRegistry(authStorage, modelsJsonPath);
				const apiKey = await registry.getApiKeyForProvider("custom-provider");

				expect(apiKey).toBe("env-api-key-value");
			} finally {
				// Restore the previous environment so other tests are unaffected.
				if (originalEnv === undefined) {
					delete process.env.TEST_API_KEY_12345;
				} else {
					process.env.TEST_API_KEY_12345 = originalEnv;
				}
			}
		});

		test("apiKey as literal value is used directly when not an env var", async () => {
			// Make sure this isn't an env var
			delete process.env.literal_api_key_value;

			writeRawModelsJson({
				"custom-provider": providerWithApiKey("literal_api_key_value"),
			});

			const registry = new ModelRegistry(authStorage, modelsJsonPath);
			const apiKey = await registry.getApiKeyForProvider("custom-provider");

			expect(apiKey).toBe("literal_api_key_value");
		});

		test("apiKey command can use shell features like pipes", async () => {
			writeRawModelsJson({
				"custom-provider": providerWithApiKey("!echo 'hello world' | tr ' ' '-'"),
			});

			const registry = new ModelRegistry(authStorage, modelsJsonPath);
			const apiKey = await registry.getApiKeyForProvider("custom-provider");

			expect(apiKey).toBe("hello-world");
		});

		describe("caching", () => {
			/**
			 * Build a "!"-prefixed apiKey command that increments the integer
			 * stored in `counterFile` and then runs `tail`, so a test can count
			 * how many times the command was actually executed.
			 *
			 * The path is double-quoted for the inner shell so a temp directory
			 * containing spaces does not split the argument.
			 */
			function countingCommand(counterFile: string, tail: string): string {
				return `!sh -c 'count=$(cat "${counterFile}"); echo $((count + 1)) > "${counterFile}"; ${tail}'`;
			}

			/** Read back the invocation count written by `countingCommand`. */
			function readCounter(counterFile: string): number {
				return Number.parseInt(readFileSync(counterFile, "utf-8").trim(), 10);
			}

			test("command is only executed once per process", async () => {
				// Use a command that writes to a file to count invocations
				const counterFile = join(tempDir, "counter");
				writeFileSync(counterFile, "0");

				writeRawModelsJson({
					"custom-provider": providerWithApiKey(countingCommand(counterFile, 'echo "key-value"')),
				});

				const registry = new ModelRegistry(authStorage, modelsJsonPath);

				// Call multiple times
				await registry.getApiKeyForProvider("custom-provider");
				await registry.getApiKeyForProvider("custom-provider");
				await registry.getApiKeyForProvider("custom-provider");

				// Command should have only run once
				expect(readCounter(counterFile)).toBe(1);
			});

			test("cache persists across registry instances", async () => {
				const counterFile = join(tempDir, "counter");
				writeFileSync(counterFile, "0");

				writeRawModelsJson({
					"custom-provider": providerWithApiKey(countingCommand(counterFile, 'echo "key-value"')),
				});

				// Create multiple registry instances
				const registry1 = new ModelRegistry(authStorage, modelsJsonPath);
				await registry1.getApiKeyForProvider("custom-provider");

				const registry2 = new ModelRegistry(authStorage, modelsJsonPath);
				await registry2.getApiKeyForProvider("custom-provider");

				// Command should still have only run once
				expect(readCounter(counterFile)).toBe(1);
			});

			test("clearApiKeyCache allows command to run again", async () => {
				const counterFile = join(tempDir, "counter");
				writeFileSync(counterFile, "0");

				writeRawModelsJson({
					"custom-provider": providerWithApiKey(countingCommand(counterFile, 'echo "key-value"')),
				});

				const registry = new ModelRegistry(authStorage, modelsJsonPath);
				await registry.getApiKeyForProvider("custom-provider");

				// Clear cache and call again
				clearApiKeyCache();
				await registry.getApiKeyForProvider("custom-provider");

				// Command should have run twice
				expect(readCounter(counterFile)).toBe(2);
			});

			test("different commands are cached separately", async () => {
				writeRawModelsJson({
					"provider-a": providerWithApiKey("!echo key-a"),
					"provider-b": providerWithApiKey("!echo key-b"),
				});

				const registry = new ModelRegistry(authStorage, modelsJsonPath);

				const keyA = await registry.getApiKeyForProvider("provider-a");
				const keyB = await registry.getApiKeyForProvider("provider-b");

				expect(keyA).toBe("key-a");
				expect(keyB).toBe("key-b");
			});

			test("failed commands are cached (not retried)", async () => {
				const counterFile = join(tempDir, "counter");
				writeFileSync(counterFile, "0");

				writeRawModelsJson({
					"custom-provider": providerWithApiKey(countingCommand(counterFile, "exit 1")),
				});

				const registry = new ModelRegistry(authStorage, modelsJsonPath);

				// Call multiple times - all should return undefined
				const key1 = await registry.getApiKeyForProvider("custom-provider");
				const key2 = await registry.getApiKeyForProvider("custom-provider");

				expect(key1).toBeUndefined();
				expect(key2).toBeUndefined();

				// Command should have only run once despite failures
				expect(readCounter(counterFile)).toBe(1);
			});

			test("environment variables are not cached (changes are picked up)", async () => {
				const envVarName = "TEST_API_KEY_CACHE_TEST_98765";
				const originalEnv = process.env[envVarName];

				try {
					process.env[envVarName] = "first-value";

					writeRawModelsJson({
						"custom-provider": providerWithApiKey(envVarName),
					});

					const registry = new ModelRegistry(authStorage, modelsJsonPath);

					const key1 = await registry.getApiKeyForProvider("custom-provider");
					expect(key1).toBe("first-value");

					// Change env var
					process.env[envVarName] = "second-value";

					const key2 = await registry.getApiKeyForProvider("custom-provider");
					expect(key2).toBe("second-value");
				} finally {
					// Restore the previous environment so other tests are unaffected.
					if (originalEnv === undefined) {
						delete process.env[envVarName];
					} else {
						process.env[envVarName] = originalEnv;
					}
				}
			});
		});
	});
});