Enhance provider override to support baseUrl-only mode

Builds on #406 to support simpler proxy use case:
- Override just baseUrl to route built-in provider through proxy
- All built-in models preserved, no need to redefine them
- Full replacement still works when models array is provided
This commit is contained in:
Mario Zechner 2026-01-03 01:06:08 +01:00
parent 243104fa18
commit d747ec6e23
4 changed files with 238 additions and 25 deletions

View file

@ -53,8 +53,8 @@ const ModelDefinitionSchema = Type.Object({
});
const ProviderConfigSchema = Type.Object({
baseUrl: Type.String({ minLength: 1 }),
apiKey: Type.String({ minLength: 1 }),
baseUrl: Type.Optional(Type.String({ minLength: 1 })),
apiKey: Type.Optional(Type.String({ minLength: 1 })),
api: Type.Optional(
Type.Union([
Type.Literal("openai-completions"),
@ -65,7 +65,7 @@ const ProviderConfigSchema = Type.Object({
),
headers: Type.Optional(Type.Record(Type.String(), Type.String())),
authHeader: Type.Optional(Type.Boolean()),
models: Type.Array(ModelDefinitionSchema),
models: Type.Optional(Type.Array(ModelDefinitionSchema)),
});
const ModelsConfigSchema = Type.Object({
@ -74,15 +74,25 @@ const ModelsConfigSchema = Type.Object({
type ModelsConfig = Static<typeof ModelsConfigSchema>;
/** Provider override config (baseUrl, headers, apiKey) without custom models */
interface ProviderOverride {
baseUrl?: string;
headers?: Record<string, string>;
apiKey?: string;
}
/** Result of loading custom models from models.json */
interface CustomModelsResult {
models: Model<Api>[];
providers: Set<string>;
/** Providers with custom models (full replacement) */
replacedProviders: Set<string>;
/** Providers with only baseUrl/headers override (no custom models) */
overrides: Map<string, ProviderOverride>;
error: string | undefined;
}
function emptyCustomModelsResult(error?: string): CustomModelsResult {
return { models: [], providers: new Set(), error };
return { models: [], replacedProviders: new Set(), overrides: new Map(), error };
}
/**
@ -137,17 +147,20 @@ export class ModelRegistry {
}
private loadModels(): void {
// Load custom models from models.json first (to know which providers to skip)
const { models: customModels, providers: customProviders, error } = this.modelsJsonPath
? this.loadCustomModels(this.modelsJsonPath)
: emptyCustomModelsResult();
// Load custom models from models.json first (to know which providers to skip/override)
const {
models: customModels,
replacedProviders,
overrides,
error,
} = this.modelsJsonPath ? this.loadCustomModels(this.modelsJsonPath) : emptyCustomModelsResult();
if (error) {
this.loadError = error;
// Keep built-in models even if custom models failed to load
}
const builtInModels = this.loadBuiltInModels(customProviders);
const builtInModels = this.loadBuiltInModels(replacedProviders, overrides);
const combined = [...builtInModels, ...customModels];
// Update github-copilot base URL based on OAuth credentials
@ -163,11 +176,22 @@ export class ModelRegistry {
}
}
/** Load built-in models, skipping providers that are overridden in models.json */
private loadBuiltInModels(skipProviders: Set<string>): Model<Api>[] {
/** Load built-in models, skipping replaced providers and applying overrides */
private loadBuiltInModels(replacedProviders: Set<string>, overrides: Map<string, ProviderOverride>): Model<Api>[] {
return getProviders()
.filter((provider) => !skipProviders.has(provider))
.flatMap((provider) => getModels(provider as KnownProvider) as Model<Api>[]);
.filter((provider) => !replacedProviders.has(provider))
.flatMap((provider) => {
const models = getModels(provider as KnownProvider) as Model<Api>[];
const override = overrides.get(provider);
if (!override) return models;
// Apply baseUrl/headers override to all models of this provider
return models.map((m) => ({
...m,
baseUrl: override.baseUrl ?? m.baseUrl,
headers: override.headers ? { ...m.headers, ...override.headers } : m.headers,
}));
});
}
private loadCustomModels(modelsJsonPath: string): CustomModelsResult {
@ -192,9 +216,29 @@ export class ModelRegistry {
// Additional validation
this.validateConfig(config);
// Parse models and collect provider names
const providers = new Set(Object.keys(config.providers));
return { models: this.parseModels(config), providers, error: undefined };
// Separate providers into "full replacement" (has models) vs "override-only" (no models)
const replacedProviders = new Set<string>();
const overrides = new Map<string, ProviderOverride>();
for (const [providerName, providerConfig] of Object.entries(config.providers)) {
if (providerConfig.models && providerConfig.models.length > 0) {
// Has custom models -> full replacement
replacedProviders.add(providerName);
} else {
// No models -> just override baseUrl/headers on built-in
overrides.set(providerName, {
baseUrl: providerConfig.baseUrl,
headers: providerConfig.headers,
apiKey: providerConfig.apiKey,
});
// Store API key for fallback resolver
if (providerConfig.apiKey) {
this.customProviderApiKeys.set(providerName, providerConfig.apiKey);
}
}
}
return { models: this.parseModels(config), replacedProviders, overrides, error: undefined };
} catch (error) {
if (error instanceof SyntaxError) {
return emptyCustomModelsResult(`Failed to parse models.json: ${error.message}\n\nFile: ${modelsJsonPath}`);
@ -208,8 +252,26 @@ export class ModelRegistry {
private validateConfig(config: ModelsConfig): void {
for (const [providerName, providerConfig] of Object.entries(config.providers)) {
const hasProviderApi = !!providerConfig.api;
const models = providerConfig.models ?? [];
for (const modelDef of providerConfig.models) {
if (models.length === 0) {
// Override-only config: just needs baseUrl (to override built-in)
if (!providerConfig.baseUrl) {
throw new Error(
`Provider ${providerName}: must specify either "baseUrl" (for override) or "models" (for replacement).`,
);
}
} else {
// Full replacement: needs baseUrl and apiKey
if (!providerConfig.baseUrl) {
throw new Error(`Provider ${providerName}: "baseUrl" is required when defining custom models.`);
}
if (!providerConfig.apiKey) {
throw new Error(`Provider ${providerName}: "apiKey" is required when defining custom models.`);
}
}
for (const modelDef of models) {
const hasModelApi = !!modelDef.api;
if (!hasProviderApi && !hasModelApi) {
@ -232,10 +294,15 @@ export class ModelRegistry {
const models: Model<Api>[] = [];
for (const [providerName, providerConfig] of Object.entries(config.providers)) {
// Store API key config for fallback resolver
this.customProviderApiKeys.set(providerName, providerConfig.apiKey);
const modelDefs = providerConfig.models ?? [];
if (modelDefs.length === 0) continue; // Override-only, no custom models
for (const modelDef of providerConfig.models) {
// Store API key config for fallback resolver
if (providerConfig.apiKey) {
this.customProviderApiKeys.set(providerName, providerConfig.apiKey);
}
for (const modelDef of modelDefs) {
const api = modelDef.api || providerConfig.api;
if (!api) continue;
@ -246,19 +313,20 @@ export class ModelRegistry {
: undefined;
// If authHeader is true, add Authorization header with resolved API key
if (providerConfig.authHeader) {
if (providerConfig.authHeader && providerConfig.apiKey) {
const resolvedKey = resolveApiKeyConfig(providerConfig.apiKey);
if (resolvedKey) {
headers = { ...headers, Authorization: `Bearer ${resolvedKey}` };
}
}
// baseUrl is validated to exist for providers with models
models.push({
id: modelDef.id,
name: modelDef.name,
api: api as Api,
provider: providerName,
baseUrl: providerConfig.baseUrl,
baseUrl: providerConfig.baseUrl!,
reasoning: modelDef.reasoning,
input: modelDef.input as ("text" | "image")[],
cost: modelDef.cost,