From d747ec6e23e562cba0f3b84112cb015493211997 Mon Sep 17 00:00:00 2001 From: Mario Zechner Date: Sat, 3 Jan 2026 01:06:08 +0100 Subject: [PATCH] Enhance provider override to support baseUrl-only mode Builds on #406 to support simpler proxy use case: - Override just baseUrl to route built-in provider through proxy - All built-in models preserved, no need to redefine them - Full replacement still works when models array is provided --- packages/coding-agent/CHANGELOG.md | 1 + packages/coding-agent/README.md | 31 +++++ .../coding-agent/src/core/model-registry.ts | 114 +++++++++++++---- .../coding-agent/test/model-registry.test.ts | 117 +++++++++++++++++- 4 files changed, 238 insertions(+), 25 deletions(-) diff --git a/packages/coding-agent/CHANGELOG.md b/packages/coding-agent/CHANGELOG.md index 8b85a280..7a3e89b5 100644 --- a/packages/coding-agent/CHANGELOG.md +++ b/packages/coding-agent/CHANGELOG.md @@ -25,6 +25,7 @@ ### Added +- Built-in provider overrides in `models.json`: override just `baseUrl` to route a built-in provider through a proxy while keeping all its models, or define `models` to fully replace the provider ([#406](https://github.com/badlogic/pi-mono/pull/406) by [@yevhen](https://github.com/yevhen)) - Automatic image resizing: images larger than 2000x2000 are resized for better model compatibility. Original dimensions are injected into the prompt. Controlled via `/settings` or `images.autoResize` in settings.json. ([#402](https://github.com/badlogic/pi-mono/pull/402) by [@mitsuhiko](https://github.com/mitsuhiko)) - Alt+Enter keybind to queue follow-up messages while agent is streaming - `Theme` and `ThemeColor` types now exported for hooks using `ctx.ui.custom()` diff --git a/packages/coding-agent/README.md b/packages/coding-agent/README.md index 95614797..71c2d1cd 100644 --- a/packages/coding-agent/README.md +++ b/packages/coding-agent/README.md @@ -465,6 +465,37 @@ Add custom models (Ollama, vLLM, LM Studio, etc.) via `~/.pi/agent/models.json`: } ``` +**Overriding built-in providers:** + +To route a built-in provider (anthropic, openai, google, etc.) through a proxy without redefining all models, just specify the `baseUrl`: + +```json +{ + "providers": { + "anthropic": { + "baseUrl": "https://my-proxy.example.com/v1" + } + } +} +``` + +All built-in Anthropic models remain available with the new endpoint. Existing OAuth or API key auth continues to work. + +To fully replace a built-in provider with custom models, include the `models` array: + +```json +{ + "providers": { + "anthropic": { + "baseUrl": "https://my-proxy.example.com/v1", + "apiKey": "ANTHROPIC_API_KEY", + "api": "anthropic-messages", + "models": [...] + } + } +} +``` + **Authorization header:** Set `authHeader: true` to add `Authorization: Bearer ` automatically. **OpenAI compatibility (`compat` field):** diff --git a/packages/coding-agent/src/core/model-registry.ts b/packages/coding-agent/src/core/model-registry.ts index 08a35eb6..16cd6d02 100644 --- a/packages/coding-agent/src/core/model-registry.ts +++ b/packages/coding-agent/src/core/model-registry.ts @@ -53,8 +53,8 @@ const ModelDefinitionSchema = Type.Object({ }); const ProviderConfigSchema = Type.Object({ - baseUrl: Type.String({ minLength: 1 }), - apiKey: Type.String({ minLength: 1 }), + baseUrl: Type.Optional(Type.String({ minLength: 1 })), + apiKey: Type.Optional(Type.String({ minLength: 1 })), api: Type.Optional( Type.Union([ Type.Literal("openai-completions"), @@ -65,7 +65,7 @@ const ProviderConfigSchema = Type.Object({ ), headers: Type.Optional(Type.Record(Type.String(), Type.String())), authHeader: Type.Optional(Type.Boolean()), - models: Type.Array(ModelDefinitionSchema), + models: Type.Optional(Type.Array(ModelDefinitionSchema)), }); const ModelsConfigSchema = Type.Object({ @@ -74,15 +74,25 @@ const ModelsConfigSchema = Type.Object({ type ModelsConfig = Static; +/** Provider override config (baseUrl, headers, apiKey) without custom models */ +interface ProviderOverride { + baseUrl?: string; + headers?: Record; + apiKey?: string; +} + /** Result of loading custom models from models.json */ interface CustomModelsResult { models: Model[]; - providers: Set; + /** Providers with custom models (full replacement) */ + replacedProviders: Set; + /** Providers with only baseUrl/headers override (no custom models) */ + overrides: Map; error: string | undefined; } function emptyCustomModelsResult(error?: string): CustomModelsResult { - return { models: [], providers: new Set(), error }; + return { models: [], replacedProviders: new Set(), overrides: new Map(), error }; } /** @@ -137,17 +147,20 @@ export class ModelRegistry { } private loadModels(): void { - // Load custom models from models.json first (to know which providers to skip) - const { models: customModels, providers: customProviders, error } = this.modelsJsonPath - ? this.loadCustomModels(this.modelsJsonPath) - : emptyCustomModelsResult(); + // Load custom models from models.json first (to know which providers to skip/override) + const { + models: customModels, + replacedProviders, + overrides, + error, + } = this.modelsJsonPath ? this.loadCustomModels(this.modelsJsonPath) : emptyCustomModelsResult(); if (error) { this.loadError = error; // Keep built-in models even if custom models failed to load } - const builtInModels = this.loadBuiltInModels(customProviders); + const builtInModels = this.loadBuiltInModels(replacedProviders, overrides); const combined = [...builtInModels, ...customModels]; // Update github-copilot base URL based on OAuth credentials @@ -163,11 +176,22 @@ export class ModelRegistry { } } - /** Load built-in models, skipping providers that are overridden in models.json */ - private loadBuiltInModels(skipProviders: Set): Model[] { + /** Load built-in models, skipping replaced providers and applying overrides */ + private loadBuiltInModels(replacedProviders: Set, overrides: Map): Model[] { return getProviders() - .filter((provider) => !skipProviders.has(provider)) - .flatMap((provider) => getModels(provider as KnownProvider) as Model[]); + .filter((provider) => !replacedProviders.has(provider)) + .flatMap((provider) => { + const models = getModels(provider as KnownProvider) as Model[]; + const override = overrides.get(provider); + if (!override) return models; + + // Apply baseUrl/headers override to all models of this provider + return models.map((m) => ({ + ...m, + baseUrl: override.baseUrl ?? m.baseUrl, + headers: override.headers ? { ...m.headers, ...override.headers } : m.headers, + })); + }); } private loadCustomModels(modelsJsonPath: string): CustomModelsResult { @@ -192,9 +216,29 @@ export class ModelRegistry { // Additional validation this.validateConfig(config); - // Parse models and collect provider names - const providers = new Set(Object.keys(config.providers)); - return { models: this.parseModels(config), providers, error: undefined }; + // Separate providers into "full replacement" (has models) vs "override-only" (no models) + const replacedProviders = new Set(); + const overrides = new Map(); + + for (const [providerName, providerConfig] of Object.entries(config.providers)) { + if (providerConfig.models && providerConfig.models.length > 0) { + // Has custom models -> full replacement + replacedProviders.add(providerName); + } else { + // No models -> just override baseUrl/headers on built-in + overrides.set(providerName, { + baseUrl: providerConfig.baseUrl, + headers: providerConfig.headers, + apiKey: providerConfig.apiKey, + }); + // Store API key for fallback resolver + if (providerConfig.apiKey) { + this.customProviderApiKeys.set(providerName, providerConfig.apiKey); + } + } + } + + return { models: this.parseModels(config), replacedProviders, overrides, error: undefined }; } catch (error) { if (error instanceof SyntaxError) { return emptyCustomModelsResult(`Failed to parse models.json: ${error.message}\n\nFile: ${modelsJsonPath}`); @@ -208,8 +252,26 @@ export class ModelRegistry { private validateConfig(config: ModelsConfig): void { for (const [providerName, providerConfig] of Object.entries(config.providers)) { const hasProviderApi = !!providerConfig.api; + const models = providerConfig.models ?? []; - for (const modelDef of providerConfig.models) { + if (models.length === 0) { + // Override-only config: just needs baseUrl (to override built-in) + if (!providerConfig.baseUrl) { + throw new Error( + `Provider ${providerName}: must specify either "baseUrl" (for override) or "models" (for replacement).`, + ); + } + } else { + // Full replacement: needs baseUrl and apiKey + if (!providerConfig.baseUrl) { + throw new Error(`Provider ${providerName}: "baseUrl" is required when defining custom models.`); + } + if (!providerConfig.apiKey) { + throw new Error(`Provider ${providerName}: "apiKey" is required when defining custom models.`); + } + } + + for (const modelDef of models) { const hasModelApi = !!modelDef.api; if (!hasProviderApi && !hasModelApi) { @@ -232,10 +294,15 @@ export class ModelRegistry { const models: Model[] = []; for (const [providerName, providerConfig] of Object.entries(config.providers)) { - // Store API key config for fallback resolver - this.customProviderApiKeys.set(providerName, providerConfig.apiKey); + const modelDefs = providerConfig.models ?? []; + if (modelDefs.length === 0) continue; // Override-only, no custom models - for (const modelDef of providerConfig.models) { + // Store API key config for fallback resolver + if (providerConfig.apiKey) { + this.customProviderApiKeys.set(providerName, providerConfig.apiKey); + } + + for (const modelDef of modelDefs) { const api = modelDef.api || providerConfig.api; if (!api) continue; @@ -246,19 +313,20 @@ export class ModelRegistry { : undefined; // If authHeader is true, add Authorization header with resolved API key - if (providerConfig.authHeader) { + if (providerConfig.authHeader && providerConfig.apiKey) { const resolvedKey = resolveApiKeyConfig(providerConfig.apiKey); if (resolvedKey) { headers = { ...headers, Authorization: `Bearer ${resolvedKey}` }; } } + // baseUrl is validated to exist for providers with models models.push({ id: modelDef.id, name: modelDef.name, api: api as Api, provider: providerName, - baseUrl: providerConfig.baseUrl, + baseUrl: providerConfig.baseUrl!, reasoning: modelDef.reasoning, input: modelDef.input as ("text" | "image")[], cost: modelDef.cost, diff --git a/packages/coding-agent/test/model-registry.test.ts b/packages/coding-agent/test/model-registry.test.ts index 1c318ca7..26bad48e 100644 --- a/packages/coding-agent/test/model-registry.test.ts +++ b/packages/coding-agent/test/model-registry.test.ts @@ -53,7 +53,116 @@ describe("ModelRegistry", () => { return registry.getAll().filter((m) => m.provider === provider); } - describe("provider override", () => { + /** Create a baseUrl-only override (no custom models) */ + function overrideConfig(baseUrl: string, headers?: Record) { + return { baseUrl, ...(headers && { headers }) }; + } + + /** Write raw providers config (for mixed override/replacement scenarios) */ + function writeRawModelsJson(providers: Record) { + writeFileSync(modelsJsonPath, JSON.stringify({ providers })); + } + + describe("baseUrl override (no custom models)", () => { + test("overriding baseUrl keeps all built-in models", () => { + writeRawModelsJson({ + anthropic: overrideConfig("https://my-proxy.example.com/v1"), + }); + + const registry = new ModelRegistry(authStorage, modelsJsonPath); + const anthropicModels = getModelsForProvider(registry, "anthropic"); + + // Should have multiple built-in models, not just one + expect(anthropicModels.length).toBeGreaterThan(1); + expect(anthropicModels.some((m) => m.id.includes("claude"))).toBe(true); + }); + + test("overriding baseUrl changes URL on all built-in models", () => { + writeRawModelsJson({ + anthropic: overrideConfig("https://my-proxy.example.com/v1"), + }); + + const registry = new ModelRegistry(authStorage, modelsJsonPath); + const anthropicModels = getModelsForProvider(registry, "anthropic"); + + // All models should have the new baseUrl + for (const model of anthropicModels) { + expect(model.baseUrl).toBe("https://my-proxy.example.com/v1"); + } + }); + + test("overriding headers merges with model headers", () => { + writeRawModelsJson({ + anthropic: overrideConfig("https://my-proxy.example.com/v1", { + "X-Custom-Header": "custom-value", + }), + }); + + const registry = new ModelRegistry(authStorage, modelsJsonPath); + const anthropicModels = getModelsForProvider(registry, "anthropic"); + + for (const model of anthropicModels) { + expect(model.headers?.["X-Custom-Header"]).toBe("custom-value"); + } + }); + + test("baseUrl-only override does not affect other providers", () => { + writeRawModelsJson({ + anthropic: overrideConfig("https://my-proxy.example.com/v1"), + }); + + const registry = new ModelRegistry(authStorage, modelsJsonPath); + const googleModels = getModelsForProvider(registry, "google"); + + // Google models should still have their original baseUrl + expect(googleModels.length).toBeGreaterThan(0); + expect(googleModels[0].baseUrl).not.toBe("https://my-proxy.example.com/v1"); + }); + + test("can mix baseUrl override and full replacement", () => { + writeRawModelsJson({ + // baseUrl-only for anthropic + anthropic: overrideConfig("https://anthropic-proxy.example.com/v1"), + // Full replacement for google + google: providerConfig( + "https://google-proxy.example.com/v1", + [{ id: "gemini-custom" }], + "google-generative-ai", + ), + }); + + const registry = new ModelRegistry(authStorage, modelsJsonPath); + + // Anthropic: multiple built-in models with new baseUrl + const anthropicModels = getModelsForProvider(registry, "anthropic"); + expect(anthropicModels.length).toBeGreaterThan(1); + expect(anthropicModels[0].baseUrl).toBe("https://anthropic-proxy.example.com/v1"); + + // Google: single custom model + const googleModels = getModelsForProvider(registry, "google"); + expect(googleModels).toHaveLength(1); + expect(googleModels[0].id).toBe("gemini-custom"); + }); + + test("refresh() picks up baseUrl override changes", () => { + writeRawModelsJson({ + anthropic: overrideConfig("https://first-proxy.example.com/v1"), + }); + const registry = new ModelRegistry(authStorage, modelsJsonPath); + + expect(getModelsForProvider(registry, "anthropic")[0].baseUrl).toBe("https://first-proxy.example.com/v1"); + + // Update and refresh + writeRawModelsJson({ + anthropic: overrideConfig("https://second-proxy.example.com/v1"), + }); + registry.refresh(); + + expect(getModelsForProvider(registry, "anthropic")[0].baseUrl).toBe("https://second-proxy.example.com/v1"); + }); + }); + + describe("provider replacement (with custom models)", () => { test("custom provider with same name as built-in replaces built-in models", () => { writeModelsJson({ anthropic: providerConfig("https://my-proxy.example.com/v1", [{ id: "claude-custom" }]), @@ -81,7 +190,11 @@ describe("ModelRegistry", () => { test("multiple built-in providers can be overridden", () => { writeModelsJson({ anthropic: providerConfig("https://anthropic-proxy.example.com/v1", [{ id: "claude-proxy" }]), - google: providerConfig("https://google-proxy.example.com/v1", [{ id: "gemini-proxy" }], "google-generative-ai"), + google: providerConfig( + "https://google-proxy.example.com/v1", + [{ id: "gemini-proxy" }], + "google-generative-ai", + ), }); const registry = new ModelRegistry(authStorage, modelsJsonPath);