mirror of
https://github.com/getcompanion-ai/co-mono.git
synced 2026-04-16 04:01:56 +00:00
feat(coding-agent): merge custom models with built-ins by id
This commit is contained in:
parent
ddd5a65c7e
commit
76a6a74517
4 changed files with 147 additions and 122 deletions
|
|
@ -130,9 +130,7 @@ interface ProviderOverride {
|
|||
/** Result of loading custom models from models.json */
|
||||
interface CustomModelsResult {
|
||||
models: Model<Api>[];
|
||||
/** Providers with custom models (full replacement) */
|
||||
replacedProviders: Set<string>;
|
||||
/** Providers with only baseUrl/headers override (no custom models) */
|
||||
/** Providers with baseUrl/headers/apiKey overrides for built-in models */
|
||||
overrides: Map<string, ProviderOverride>;
|
||||
/** Per-model overrides: provider -> modelId -> override */
|
||||
modelOverrides: Map<string, Map<string, ModelOverride>>;
|
||||
|
|
@ -140,7 +138,7 @@ interface CustomModelsResult {
|
|||
}
|
||||
|
||||
function emptyCustomModelsResult(error?: string): CustomModelsResult {
|
||||
return { models: [], replacedProviders: new Set(), overrides: new Map(), modelOverrides: new Map(), error };
|
||||
return { models: [], overrides: new Map(), modelOverrides: new Map(), error };
|
||||
}
|
||||
|
||||
function mergeCompat(
|
||||
|
|
@ -260,10 +258,9 @@ export class ModelRegistry {
|
|||
}
|
||||
|
||||
private loadModels(): void {
|
||||
// Load custom models from models.json first (to know which providers to skip/override)
|
||||
// Load custom models and overrides from models.json
|
||||
const {
|
||||
models: customModels,
|
||||
replacedProviders,
|
||||
overrides,
|
||||
modelOverrides,
|
||||
error,
|
||||
|
|
@ -274,8 +271,8 @@ export class ModelRegistry {
|
|||
// Keep built-in models even if custom models failed to load
|
||||
}
|
||||
|
||||
const builtInModels = this.loadBuiltInModels(replacedProviders, overrides, modelOverrides);
|
||||
let combined = [...builtInModels, ...customModels];
|
||||
const builtInModels = this.loadBuiltInModels(overrides, modelOverrides);
|
||||
let combined = this.mergeCustomModels(builtInModels, customModels);
|
||||
|
||||
// Let OAuth providers modify their models (e.g., update baseUrl)
|
||||
for (const oauthProvider of this.authStorage.getOAuthProviders()) {
|
||||
|
|
@ -288,41 +285,52 @@ export class ModelRegistry {
|
|||
this.models = combined;
|
||||
}
|
||||
|
||||
/** Load built-in models, skipping replaced providers and applying overrides */
|
||||
/** Load built-in models and apply provider/model overrides */
|
||||
private loadBuiltInModels(
|
||||
replacedProviders: Set<string>,
|
||||
overrides: Map<string, ProviderOverride>,
|
||||
modelOverrides: Map<string, Map<string, ModelOverride>>,
|
||||
): Model<Api>[] {
|
||||
return getProviders()
|
||||
.filter((provider) => !replacedProviders.has(provider))
|
||||
.flatMap((provider) => {
|
||||
const models = getModels(provider as KnownProvider) as Model<Api>[];
|
||||
const providerOverride = overrides.get(provider);
|
||||
const perModelOverrides = modelOverrides.get(provider);
|
||||
return getProviders().flatMap((provider) => {
|
||||
const models = getModels(provider as KnownProvider) as Model<Api>[];
|
||||
const providerOverride = overrides.get(provider);
|
||||
const perModelOverrides = modelOverrides.get(provider);
|
||||
|
||||
return models.map((m) => {
|
||||
let model = m;
|
||||
return models.map((m) => {
|
||||
let model = m;
|
||||
|
||||
// Apply provider-level baseUrl/headers override
|
||||
if (providerOverride) {
|
||||
const resolvedHeaders = resolveHeaders(providerOverride.headers);
|
||||
model = {
|
||||
...model,
|
||||
baseUrl: providerOverride.baseUrl ?? model.baseUrl,
|
||||
headers: resolvedHeaders ? { ...model.headers, ...resolvedHeaders } : model.headers,
|
||||
};
|
||||
}
|
||||
// Apply provider-level baseUrl/headers override
|
||||
if (providerOverride) {
|
||||
const resolvedHeaders = resolveHeaders(providerOverride.headers);
|
||||
model = {
|
||||
...model,
|
||||
baseUrl: providerOverride.baseUrl ?? model.baseUrl,
|
||||
headers: resolvedHeaders ? { ...model.headers, ...resolvedHeaders } : model.headers,
|
||||
};
|
||||
}
|
||||
|
||||
// Apply per-model override
|
||||
const modelOverride = perModelOverrides?.get(m.id);
|
||||
if (modelOverride) {
|
||||
model = applyModelOverride(model, modelOverride);
|
||||
}
|
||||
// Apply per-model override
|
||||
const modelOverride = perModelOverrides?.get(m.id);
|
||||
if (modelOverride) {
|
||||
model = applyModelOverride(model, modelOverride);
|
||||
}
|
||||
|
||||
return model;
|
||||
});
|
||||
return model;
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
/** Merge custom models into built-in list by provider+id (custom wins on conflicts). */
|
||||
private mergeCustomModels(builtInModels: Model<Api>[], customModels: Model<Api>[]): Model<Api>[] {
|
||||
const merged = [...builtInModels];
|
||||
for (const customModel of customModels) {
|
||||
const existingIndex = merged.findIndex((m) => m.provider === customModel.provider && m.id === customModel.id);
|
||||
if (existingIndex >= 0) {
|
||||
merged[existingIndex] = customModel;
|
||||
} else {
|
||||
merged.push(customModel);
|
||||
}
|
||||
}
|
||||
return merged;
|
||||
}
|
||||
|
||||
private loadCustomModels(modelsJsonPath: string): CustomModelsResult {
|
||||
|
|
@ -347,35 +355,30 @@ export class ModelRegistry {
|
|||
// Additional validation
|
||||
this.validateConfig(config);
|
||||
|
||||
// Separate providers into "full replacement" (has models) vs "override-only" (no models)
|
||||
const replacedProviders = new Set<string>();
|
||||
const overrides = new Map<string, ProviderOverride>();
|
||||
const modelOverrides = new Map<string, Map<string, ModelOverride>>();
|
||||
|
||||
for (const [providerName, providerConfig] of Object.entries(config.providers)) {
|
||||
if (providerConfig.models && providerConfig.models.length > 0) {
|
||||
// Has custom models -> full replacement
|
||||
replacedProviders.add(providerName);
|
||||
} else {
|
||||
// No models -> just override baseUrl/headers on built-in
|
||||
// Apply provider-level baseUrl/headers/apiKey override to built-in models when configured.
|
||||
if (providerConfig.baseUrl || providerConfig.headers || providerConfig.apiKey) {
|
||||
overrides.set(providerName, {
|
||||
baseUrl: providerConfig.baseUrl,
|
||||
headers: providerConfig.headers,
|
||||
apiKey: providerConfig.apiKey,
|
||||
});
|
||||
// Store API key for fallback resolver
|
||||
if (providerConfig.apiKey) {
|
||||
this.customProviderApiKeys.set(providerName, providerConfig.apiKey);
|
||||
}
|
||||
}
|
||||
|
||||
// Collect per-model overrides (works with both full replacement and override-only)
|
||||
// Store API key for fallback resolver.
|
||||
if (providerConfig.apiKey) {
|
||||
this.customProviderApiKeys.set(providerName, providerConfig.apiKey);
|
||||
}
|
||||
|
||||
if (providerConfig.modelOverrides) {
|
||||
modelOverrides.set(providerName, new Map(Object.entries(providerConfig.modelOverrides)));
|
||||
}
|
||||
}
|
||||
|
||||
return { models: this.parseModels(config), replacedProviders, overrides, modelOverrides, error: undefined };
|
||||
return { models: this.parseModels(config), overrides, modelOverrides, error: undefined };
|
||||
} catch (error) {
|
||||
if (error instanceof SyntaxError) {
|
||||
return emptyCustomModelsResult(`Failed to parse models.json: ${error.message}\n\nFile: ${modelsJsonPath}`);
|
||||
|
|
@ -399,7 +402,7 @@ export class ModelRegistry {
|
|||
throw new Error(`Provider ${providerName}: must specify "baseUrl", "modelOverrides", or "models".`);
|
||||
}
|
||||
} else {
|
||||
// Full replacement: needs baseUrl and apiKey
|
||||
// Custom models are merged into provider models and require endpoint + auth.
|
||||
if (!providerConfig.baseUrl) {
|
||||
throw new Error(`Provider ${providerName}: "baseUrl" is required when defining custom models.`);
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue