mirror of
https://github.com/getcompanion-ai/co-mono.git
synced 2026-04-22 01:02:16 +00:00
Enhance provider override to support baseUrl-only mode
Builds on #406 to support simpler proxy use case: - Override just baseUrl to route built-in provider through proxy - All built-in models preserved, no need to redefine them - Full replacement still works when models array is provided
This commit is contained in:
parent
243104fa18
commit
d747ec6e23
4 changed files with 238 additions and 25 deletions
|
|
@ -25,6 +25,7 @@
|
||||||
|
|
||||||
### Added
|
### Added
|
||||||
|
|
||||||
|
- Built-in provider overrides in `models.json`: override just `baseUrl` to route a built-in provider through a proxy while keeping all its models, or define `models` to fully replace the provider ([#406](https://github.com/badlogic/pi-mono/pull/406) by [@yevhen](https://github.com/yevhen))
|
||||||
- Automatic image resizing: images larger than 2000x2000 are resized for better model compatibility. Original dimensions are injected into the prompt. Controlled via `/settings` or `images.autoResize` in settings.json. ([#402](https://github.com/badlogic/pi-mono/pull/402) by [@mitsuhiko](https://github.com/mitsuhiko))
|
- Automatic image resizing: images larger than 2000x2000 are resized for better model compatibility. Original dimensions are injected into the prompt. Controlled via `/settings` or `images.autoResize` in settings.json. ([#402](https://github.com/badlogic/pi-mono/pull/402) by [@mitsuhiko](https://github.com/mitsuhiko))
|
||||||
- Alt+Enter keybind to queue follow-up messages while agent is streaming
|
- Alt+Enter keybind to queue follow-up messages while agent is streaming
|
||||||
- `Theme` and `ThemeColor` types now exported for hooks using `ctx.ui.custom()`
|
- `Theme` and `ThemeColor` types now exported for hooks using `ctx.ui.custom()`
|
||||||
|
|
|
||||||
|
|
@ -465,6 +465,37 @@ Add custom models (Ollama, vLLM, LM Studio, etc.) via `~/.pi/agent/models.json`:
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
**Overriding built-in providers:**
|
||||||
|
|
||||||
|
To route a built-in provider (anthropic, openai, google, etc.) through a proxy without redefining all models, just specify the `baseUrl`:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"providers": {
|
||||||
|
"anthropic": {
|
||||||
|
"baseUrl": "https://my-proxy.example.com/v1"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
All built-in Anthropic models remain available with the new endpoint. Existing OAuth or API key auth continues to work.
|
||||||
|
|
||||||
|
To fully replace a built-in provider with custom models, include the `models` array:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"providers": {
|
||||||
|
"anthropic": {
|
||||||
|
"baseUrl": "https://my-proxy.example.com/v1",
|
||||||
|
"apiKey": "ANTHROPIC_API_KEY",
|
||||||
|
"api": "anthropic-messages",
|
||||||
|
"models": [...]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
**Authorization header:** Set `authHeader: true` to add `Authorization: Bearer <apiKey>` automatically.
|
**Authorization header:** Set `authHeader: true` to add `Authorization: Bearer <apiKey>` automatically.
|
||||||
|
|
||||||
**OpenAI compatibility (`compat` field):**
|
**OpenAI compatibility (`compat` field):**
|
||||||
|
|
|
||||||
|
|
@ -53,8 +53,8 @@ const ModelDefinitionSchema = Type.Object({
|
||||||
});
|
});
|
||||||
|
|
||||||
const ProviderConfigSchema = Type.Object({
|
const ProviderConfigSchema = Type.Object({
|
||||||
baseUrl: Type.String({ minLength: 1 }),
|
baseUrl: Type.Optional(Type.String({ minLength: 1 })),
|
||||||
apiKey: Type.String({ minLength: 1 }),
|
apiKey: Type.Optional(Type.String({ minLength: 1 })),
|
||||||
api: Type.Optional(
|
api: Type.Optional(
|
||||||
Type.Union([
|
Type.Union([
|
||||||
Type.Literal("openai-completions"),
|
Type.Literal("openai-completions"),
|
||||||
|
|
@ -65,7 +65,7 @@ const ProviderConfigSchema = Type.Object({
|
||||||
),
|
),
|
||||||
headers: Type.Optional(Type.Record(Type.String(), Type.String())),
|
headers: Type.Optional(Type.Record(Type.String(), Type.String())),
|
||||||
authHeader: Type.Optional(Type.Boolean()),
|
authHeader: Type.Optional(Type.Boolean()),
|
||||||
models: Type.Array(ModelDefinitionSchema),
|
models: Type.Optional(Type.Array(ModelDefinitionSchema)),
|
||||||
});
|
});
|
||||||
|
|
||||||
const ModelsConfigSchema = Type.Object({
|
const ModelsConfigSchema = Type.Object({
|
||||||
|
|
@ -74,15 +74,25 @@ const ModelsConfigSchema = Type.Object({
|
||||||
|
|
||||||
type ModelsConfig = Static<typeof ModelsConfigSchema>;
|
type ModelsConfig = Static<typeof ModelsConfigSchema>;
|
||||||
|
|
||||||
|
/** Provider override config (baseUrl, headers, apiKey) without custom models */
|
||||||
|
interface ProviderOverride {
|
||||||
|
baseUrl?: string;
|
||||||
|
headers?: Record<string, string>;
|
||||||
|
apiKey?: string;
|
||||||
|
}
|
||||||
|
|
||||||
/** Result of loading custom models from models.json */
|
/** Result of loading custom models from models.json */
|
||||||
interface CustomModelsResult {
|
interface CustomModelsResult {
|
||||||
models: Model<Api>[];
|
models: Model<Api>[];
|
||||||
providers: Set<string>;
|
/** Providers with custom models (full replacement) */
|
||||||
|
replacedProviders: Set<string>;
|
||||||
|
/** Providers with only baseUrl/headers override (no custom models) */
|
||||||
|
overrides: Map<string, ProviderOverride>;
|
||||||
error: string | undefined;
|
error: string | undefined;
|
||||||
}
|
}
|
||||||
|
|
||||||
function emptyCustomModelsResult(error?: string): CustomModelsResult {
|
function emptyCustomModelsResult(error?: string): CustomModelsResult {
|
||||||
return { models: [], providers: new Set(), error };
|
return { models: [], replacedProviders: new Set(), overrides: new Map(), error };
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -137,17 +147,20 @@ export class ModelRegistry {
|
||||||
}
|
}
|
||||||
|
|
||||||
private loadModels(): void {
|
private loadModels(): void {
|
||||||
// Load custom models from models.json first (to know which providers to skip)
|
// Load custom models from models.json first (to know which providers to skip/override)
|
||||||
const { models: customModels, providers: customProviders, error } = this.modelsJsonPath
|
const {
|
||||||
? this.loadCustomModels(this.modelsJsonPath)
|
models: customModels,
|
||||||
: emptyCustomModelsResult();
|
replacedProviders,
|
||||||
|
overrides,
|
||||||
|
error,
|
||||||
|
} = this.modelsJsonPath ? this.loadCustomModels(this.modelsJsonPath) : emptyCustomModelsResult();
|
||||||
|
|
||||||
if (error) {
|
if (error) {
|
||||||
this.loadError = error;
|
this.loadError = error;
|
||||||
// Keep built-in models even if custom models failed to load
|
// Keep built-in models even if custom models failed to load
|
||||||
}
|
}
|
||||||
|
|
||||||
const builtInModels = this.loadBuiltInModels(customProviders);
|
const builtInModels = this.loadBuiltInModels(replacedProviders, overrides);
|
||||||
const combined = [...builtInModels, ...customModels];
|
const combined = [...builtInModels, ...customModels];
|
||||||
|
|
||||||
// Update github-copilot base URL based on OAuth credentials
|
// Update github-copilot base URL based on OAuth credentials
|
||||||
|
|
@ -163,11 +176,22 @@ export class ModelRegistry {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Load built-in models, skipping providers that are overridden in models.json */
|
/** Load built-in models, skipping replaced providers and applying overrides */
|
||||||
private loadBuiltInModels(skipProviders: Set<string>): Model<Api>[] {
|
private loadBuiltInModels(replacedProviders: Set<string>, overrides: Map<string, ProviderOverride>): Model<Api>[] {
|
||||||
return getProviders()
|
return getProviders()
|
||||||
.filter((provider) => !skipProviders.has(provider))
|
.filter((provider) => !replacedProviders.has(provider))
|
||||||
.flatMap((provider) => getModels(provider as KnownProvider) as Model<Api>[]);
|
.flatMap((provider) => {
|
||||||
|
const models = getModels(provider as KnownProvider) as Model<Api>[];
|
||||||
|
const override = overrides.get(provider);
|
||||||
|
if (!override) return models;
|
||||||
|
|
||||||
|
// Apply baseUrl/headers override to all models of this provider
|
||||||
|
return models.map((m) => ({
|
||||||
|
...m,
|
||||||
|
baseUrl: override.baseUrl ?? m.baseUrl,
|
||||||
|
headers: override.headers ? { ...m.headers, ...override.headers } : m.headers,
|
||||||
|
}));
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
private loadCustomModels(modelsJsonPath: string): CustomModelsResult {
|
private loadCustomModels(modelsJsonPath: string): CustomModelsResult {
|
||||||
|
|
@ -192,9 +216,29 @@ export class ModelRegistry {
|
||||||
// Additional validation
|
// Additional validation
|
||||||
this.validateConfig(config);
|
this.validateConfig(config);
|
||||||
|
|
||||||
// Parse models and collect provider names
|
// Separate providers into "full replacement" (has models) vs "override-only" (no models)
|
||||||
const providers = new Set(Object.keys(config.providers));
|
const replacedProviders = new Set<string>();
|
||||||
return { models: this.parseModels(config), providers, error: undefined };
|
const overrides = new Map<string, ProviderOverride>();
|
||||||
|
|
||||||
|
for (const [providerName, providerConfig] of Object.entries(config.providers)) {
|
||||||
|
if (providerConfig.models && providerConfig.models.length > 0) {
|
||||||
|
// Has custom models -> full replacement
|
||||||
|
replacedProviders.add(providerName);
|
||||||
|
} else {
|
||||||
|
// No models -> just override baseUrl/headers on built-in
|
||||||
|
overrides.set(providerName, {
|
||||||
|
baseUrl: providerConfig.baseUrl,
|
||||||
|
headers: providerConfig.headers,
|
||||||
|
apiKey: providerConfig.apiKey,
|
||||||
|
});
|
||||||
|
// Store API key for fallback resolver
|
||||||
|
if (providerConfig.apiKey) {
|
||||||
|
this.customProviderApiKeys.set(providerName, providerConfig.apiKey);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return { models: this.parseModels(config), replacedProviders, overrides, error: undefined };
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
if (error instanceof SyntaxError) {
|
if (error instanceof SyntaxError) {
|
||||||
return emptyCustomModelsResult(`Failed to parse models.json: ${error.message}\n\nFile: ${modelsJsonPath}`);
|
return emptyCustomModelsResult(`Failed to parse models.json: ${error.message}\n\nFile: ${modelsJsonPath}`);
|
||||||
|
|
@ -208,8 +252,26 @@ export class ModelRegistry {
|
||||||
private validateConfig(config: ModelsConfig): void {
|
private validateConfig(config: ModelsConfig): void {
|
||||||
for (const [providerName, providerConfig] of Object.entries(config.providers)) {
|
for (const [providerName, providerConfig] of Object.entries(config.providers)) {
|
||||||
const hasProviderApi = !!providerConfig.api;
|
const hasProviderApi = !!providerConfig.api;
|
||||||
|
const models = providerConfig.models ?? [];
|
||||||
|
|
||||||
for (const modelDef of providerConfig.models) {
|
if (models.length === 0) {
|
||||||
|
// Override-only config: just needs baseUrl (to override built-in)
|
||||||
|
if (!providerConfig.baseUrl) {
|
||||||
|
throw new Error(
|
||||||
|
`Provider ${providerName}: must specify either "baseUrl" (for override) or "models" (for replacement).`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Full replacement: needs baseUrl and apiKey
|
||||||
|
if (!providerConfig.baseUrl) {
|
||||||
|
throw new Error(`Provider ${providerName}: "baseUrl" is required when defining custom models.`);
|
||||||
|
}
|
||||||
|
if (!providerConfig.apiKey) {
|
||||||
|
throw new Error(`Provider ${providerName}: "apiKey" is required when defining custom models.`);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const modelDef of models) {
|
||||||
const hasModelApi = !!modelDef.api;
|
const hasModelApi = !!modelDef.api;
|
||||||
|
|
||||||
if (!hasProviderApi && !hasModelApi) {
|
if (!hasProviderApi && !hasModelApi) {
|
||||||
|
|
@ -232,10 +294,15 @@ export class ModelRegistry {
|
||||||
const models: Model<Api>[] = [];
|
const models: Model<Api>[] = [];
|
||||||
|
|
||||||
for (const [providerName, providerConfig] of Object.entries(config.providers)) {
|
for (const [providerName, providerConfig] of Object.entries(config.providers)) {
|
||||||
// Store API key config for fallback resolver
|
const modelDefs = providerConfig.models ?? [];
|
||||||
this.customProviderApiKeys.set(providerName, providerConfig.apiKey);
|
if (modelDefs.length === 0) continue; // Override-only, no custom models
|
||||||
|
|
||||||
for (const modelDef of providerConfig.models) {
|
// Store API key config for fallback resolver
|
||||||
|
if (providerConfig.apiKey) {
|
||||||
|
this.customProviderApiKeys.set(providerName, providerConfig.apiKey);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const modelDef of modelDefs) {
|
||||||
const api = modelDef.api || providerConfig.api;
|
const api = modelDef.api || providerConfig.api;
|
||||||
if (!api) continue;
|
if (!api) continue;
|
||||||
|
|
||||||
|
|
@ -246,19 +313,20 @@ export class ModelRegistry {
|
||||||
: undefined;
|
: undefined;
|
||||||
|
|
||||||
// If authHeader is true, add Authorization header with resolved API key
|
// If authHeader is true, add Authorization header with resolved API key
|
||||||
if (providerConfig.authHeader) {
|
if (providerConfig.authHeader && providerConfig.apiKey) {
|
||||||
const resolvedKey = resolveApiKeyConfig(providerConfig.apiKey);
|
const resolvedKey = resolveApiKeyConfig(providerConfig.apiKey);
|
||||||
if (resolvedKey) {
|
if (resolvedKey) {
|
||||||
headers = { ...headers, Authorization: `Bearer ${resolvedKey}` };
|
headers = { ...headers, Authorization: `Bearer ${resolvedKey}` };
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// baseUrl is validated to exist for providers with models
|
||||||
models.push({
|
models.push({
|
||||||
id: modelDef.id,
|
id: modelDef.id,
|
||||||
name: modelDef.name,
|
name: modelDef.name,
|
||||||
api: api as Api,
|
api: api as Api,
|
||||||
provider: providerName,
|
provider: providerName,
|
||||||
baseUrl: providerConfig.baseUrl,
|
baseUrl: providerConfig.baseUrl!,
|
||||||
reasoning: modelDef.reasoning,
|
reasoning: modelDef.reasoning,
|
||||||
input: modelDef.input as ("text" | "image")[],
|
input: modelDef.input as ("text" | "image")[],
|
||||||
cost: modelDef.cost,
|
cost: modelDef.cost,
|
||||||
|
|
|
||||||
|
|
@ -53,7 +53,116 @@ describe("ModelRegistry", () => {
|
||||||
return registry.getAll().filter((m) => m.provider === provider);
|
return registry.getAll().filter((m) => m.provider === provider);
|
||||||
}
|
}
|
||||||
|
|
||||||
describe("provider override", () => {
|
/** Create a baseUrl-only override (no custom models) */
|
||||||
|
function overrideConfig(baseUrl: string, headers?: Record<string, string>) {
|
||||||
|
return { baseUrl, ...(headers && { headers }) };
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Write raw providers config (for mixed override/replacement scenarios) */
|
||||||
|
function writeRawModelsJson(providers: Record<string, unknown>) {
|
||||||
|
writeFileSync(modelsJsonPath, JSON.stringify({ providers }));
|
||||||
|
}
|
||||||
|
|
||||||
|
describe("baseUrl override (no custom models)", () => {
|
||||||
|
test("overriding baseUrl keeps all built-in models", () => {
|
||||||
|
writeRawModelsJson({
|
||||||
|
anthropic: overrideConfig("https://my-proxy.example.com/v1"),
|
||||||
|
});
|
||||||
|
|
||||||
|
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||||
|
const anthropicModels = getModelsForProvider(registry, "anthropic");
|
||||||
|
|
||||||
|
// Should have multiple built-in models, not just one
|
||||||
|
expect(anthropicModels.length).toBeGreaterThan(1);
|
||||||
|
expect(anthropicModels.some((m) => m.id.includes("claude"))).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
test("overriding baseUrl changes URL on all built-in models", () => {
|
||||||
|
writeRawModelsJson({
|
||||||
|
anthropic: overrideConfig("https://my-proxy.example.com/v1"),
|
||||||
|
});
|
||||||
|
|
||||||
|
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||||
|
const anthropicModels = getModelsForProvider(registry, "anthropic");
|
||||||
|
|
||||||
|
// All models should have the new baseUrl
|
||||||
|
for (const model of anthropicModels) {
|
||||||
|
expect(model.baseUrl).toBe("https://my-proxy.example.com/v1");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
test("overriding headers merges with model headers", () => {
|
||||||
|
writeRawModelsJson({
|
||||||
|
anthropic: overrideConfig("https://my-proxy.example.com/v1", {
|
||||||
|
"X-Custom-Header": "custom-value",
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
|
||||||
|
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||||
|
const anthropicModels = getModelsForProvider(registry, "anthropic");
|
||||||
|
|
||||||
|
for (const model of anthropicModels) {
|
||||||
|
expect(model.headers?.["X-Custom-Header"]).toBe("custom-value");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
test("baseUrl-only override does not affect other providers", () => {
|
||||||
|
writeRawModelsJson({
|
||||||
|
anthropic: overrideConfig("https://my-proxy.example.com/v1"),
|
||||||
|
});
|
||||||
|
|
||||||
|
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||||
|
const googleModels = getModelsForProvider(registry, "google");
|
||||||
|
|
||||||
|
// Google models should still have their original baseUrl
|
||||||
|
expect(googleModels.length).toBeGreaterThan(0);
|
||||||
|
expect(googleModels[0].baseUrl).not.toBe("https://my-proxy.example.com/v1");
|
||||||
|
});
|
||||||
|
|
||||||
|
test("can mix baseUrl override and full replacement", () => {
|
||||||
|
writeRawModelsJson({
|
||||||
|
// baseUrl-only for anthropic
|
||||||
|
anthropic: overrideConfig("https://anthropic-proxy.example.com/v1"),
|
||||||
|
// Full replacement for google
|
||||||
|
google: providerConfig(
|
||||||
|
"https://google-proxy.example.com/v1",
|
||||||
|
[{ id: "gemini-custom" }],
|
||||||
|
"google-generative-ai",
|
||||||
|
),
|
||||||
|
});
|
||||||
|
|
||||||
|
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||||
|
|
||||||
|
// Anthropic: multiple built-in models with new baseUrl
|
||||||
|
const anthropicModels = getModelsForProvider(registry, "anthropic");
|
||||||
|
expect(anthropicModels.length).toBeGreaterThan(1);
|
||||||
|
expect(anthropicModels[0].baseUrl).toBe("https://anthropic-proxy.example.com/v1");
|
||||||
|
|
||||||
|
// Google: single custom model
|
||||||
|
const googleModels = getModelsForProvider(registry, "google");
|
||||||
|
expect(googleModels).toHaveLength(1);
|
||||||
|
expect(googleModels[0].id).toBe("gemini-custom");
|
||||||
|
});
|
||||||
|
|
||||||
|
test("refresh() picks up baseUrl override changes", () => {
|
||||||
|
writeRawModelsJson({
|
||||||
|
anthropic: overrideConfig("https://first-proxy.example.com/v1"),
|
||||||
|
});
|
||||||
|
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||||
|
|
||||||
|
expect(getModelsForProvider(registry, "anthropic")[0].baseUrl).toBe("https://first-proxy.example.com/v1");
|
||||||
|
|
||||||
|
// Update and refresh
|
||||||
|
writeRawModelsJson({
|
||||||
|
anthropic: overrideConfig("https://second-proxy.example.com/v1"),
|
||||||
|
});
|
||||||
|
registry.refresh();
|
||||||
|
|
||||||
|
expect(getModelsForProvider(registry, "anthropic")[0].baseUrl).toBe("https://second-proxy.example.com/v1");
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("provider replacement (with custom models)", () => {
|
||||||
test("custom provider with same name as built-in replaces built-in models", () => {
|
test("custom provider with same name as built-in replaces built-in models", () => {
|
||||||
writeModelsJson({
|
writeModelsJson({
|
||||||
anthropic: providerConfig("https://my-proxy.example.com/v1", [{ id: "claude-custom" }]),
|
anthropic: providerConfig("https://my-proxy.example.com/v1", [{ id: "claude-custom" }]),
|
||||||
|
|
@ -81,7 +190,11 @@ describe("ModelRegistry", () => {
|
||||||
test("multiple built-in providers can be overridden", () => {
|
test("multiple built-in providers can be overridden", () => {
|
||||||
writeModelsJson({
|
writeModelsJson({
|
||||||
anthropic: providerConfig("https://anthropic-proxy.example.com/v1", [{ id: "claude-proxy" }]),
|
anthropic: providerConfig("https://anthropic-proxy.example.com/v1", [{ id: "claude-proxy" }]),
|
||||||
google: providerConfig("https://google-proxy.example.com/v1", [{ id: "gemini-proxy" }], "google-generative-ai"),
|
google: providerConfig(
|
||||||
|
"https://google-proxy.example.com/v1",
|
||||||
|
[{ id: "gemini-proxy" }],
|
||||||
|
"google-generative-ai",
|
||||||
|
),
|
||||||
});
|
});
|
||||||
|
|
||||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue