mirror of
https://github.com/getcompanion-ai/co-mono.git
synced 2026-04-20 00:02:11 +00:00
refactor(oauth): add provider registry
This commit is contained in:
parent
89636cfe6e
commit
3256d3c083
19 changed files with 655 additions and 291 deletions
|
|
@ -4,12 +4,12 @@
|
|||
|
||||
import {
|
||||
type Api,
|
||||
getGitHubCopilotBaseUrl,
|
||||
getModels,
|
||||
getProviders,
|
||||
type KnownProvider,
|
||||
type Model,
|
||||
normalizeDomain,
|
||||
type OAuthProviderInterface,
|
||||
registerOAuthProvider,
|
||||
} from "@mariozechner/pi-ai";
|
||||
import { type Static, Type } from "@sinclair/typebox";
|
||||
import AjvModule from "ajv";
|
||||
|
|
@ -26,7 +26,13 @@ const OpenAICompletionsCompatSchema = Type.Object({
|
|||
supportsStore: Type.Optional(Type.Boolean()),
|
||||
supportsDeveloperRole: Type.Optional(Type.Boolean()),
|
||||
supportsReasoningEffort: Type.Optional(Type.Boolean()),
|
||||
supportsUsageInStreaming: Type.Optional(Type.Boolean()),
|
||||
maxTokensField: Type.Optional(Type.Union([Type.Literal("max_completion_tokens"), Type.Literal("max_tokens")])),
|
||||
requiresToolResultName: Type.Optional(Type.Boolean()),
|
||||
requiresAssistantAfterToolResult: Type.Optional(Type.Boolean()),
|
||||
requiresThinkingAsText: Type.Optional(Type.Boolean()),
|
||||
requiresMistralToolIds: Type.Optional(Type.Boolean()),
|
||||
thinkingFormat: Type.Optional(Type.Union([Type.Literal("openai"), Type.Literal("zai")])),
|
||||
});
|
||||
|
||||
const OpenAIResponsesCompatSchema = Type.Object({
|
||||
|
|
@ -174,6 +180,7 @@ export function clearApiKeyCache(): void {
|
|||
export class ModelRegistry {
|
||||
private models: Model<Api>[] = [];
|
||||
private customProviderApiKeys: Map<string, string> = new Map();
|
||||
private registeredProviders: Map<string, ProviderConfigInput> = new Map();
|
||||
private loadError: string | undefined = undefined;
|
||||
|
||||
constructor(
|
||||
|
|
@ -200,6 +207,10 @@ export class ModelRegistry {
|
|||
this.customProviderApiKeys.clear();
|
||||
this.loadError = undefined;
|
||||
this.loadModels();
|
||||
|
||||
for (const [providerName, config] of this.registeredProviders.entries()) {
|
||||
this.applyProviderConfig(providerName, config);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -224,19 +235,17 @@ export class ModelRegistry {
|
|||
}
|
||||
|
||||
const builtInModels = this.loadBuiltInModels(replacedProviders, overrides);
|
||||
const combined = [...builtInModels, ...customModels];
|
||||
let combined = [...builtInModels, ...customModels];
|
||||
|
||||
// Update github-copilot base URL based on OAuth credentials
|
||||
const copilotCred = this.authStorage.get("github-copilot");
|
||||
if (copilotCred?.type === "oauth") {
|
||||
const domain = copilotCred.enterpriseUrl
|
||||
? (normalizeDomain(copilotCred.enterpriseUrl) ?? undefined)
|
||||
: undefined;
|
||||
const baseUrl = getGitHubCopilotBaseUrl(copilotCred.access, domain);
|
||||
this.models = combined.map((m) => (m.provider === "github-copilot" ? { ...m, baseUrl } : m));
|
||||
} else {
|
||||
this.models = combined;
|
||||
// Let OAuth providers modify their models (e.g., update baseUrl)
|
||||
for (const oauthProvider of this.authStorage.getOAuthProviders()) {
|
||||
const cred = this.authStorage.get(oauthProvider.id);
|
||||
if (cred?.type === "oauth" && oauthProvider.modifyModels) {
|
||||
combined = oauthProvider.modifyModels(combined, cred);
|
||||
}
|
||||
}
|
||||
|
||||
this.models = combined;
|
||||
}
|
||||
|
||||
/** Load built-in models, skipping replaced providers and applying overrides */
|
||||
|
|
@ -449,4 +458,118 @@ export class ModelRegistry {
|
|||
const cred = this.authStorage.get(model.provider);
|
||||
return cred?.type === "oauth";
|
||||
}
|
||||
|
||||
/**
|
||||
* Register a provider dynamically (from extensions).
|
||||
*
|
||||
* If provider has models: replaces all existing models for this provider.
|
||||
* If provider has only baseUrl/headers: overrides existing models' URLs.
|
||||
* If provider has oauth: registers OAuth provider for /login support.
|
||||
*/
|
||||
registerProvider(providerName: string, config: ProviderConfigInput): void {
|
||||
this.registeredProviders.set(providerName, config);
|
||||
this.applyProviderConfig(providerName, config);
|
||||
}
|
||||
|
||||
private applyProviderConfig(providerName: string, config: ProviderConfigInput): void {
|
||||
// Register OAuth provider if provided
|
||||
if (config.oauth) {
|
||||
// Ensure the OAuth provider ID matches the provider name
|
||||
const oauthProvider: OAuthProviderInterface = {
|
||||
...config.oauth,
|
||||
id: providerName,
|
||||
};
|
||||
registerOAuthProvider(oauthProvider);
|
||||
}
|
||||
|
||||
// Store API key for auth resolution
|
||||
if (config.apiKey) {
|
||||
this.customProviderApiKeys.set(providerName, config.apiKey);
|
||||
}
|
||||
|
||||
if (config.models && config.models.length > 0) {
|
||||
// Full replacement: remove existing models for this provider
|
||||
this.models = this.models.filter((m) => m.provider !== providerName);
|
||||
|
||||
// Validate required fields
|
||||
if (!config.baseUrl) {
|
||||
throw new Error(`Provider ${providerName}: "baseUrl" is required when defining models.`);
|
||||
}
|
||||
if (!config.apiKey && !config.oauth) {
|
||||
throw new Error(`Provider ${providerName}: "apiKey" or "oauth" is required when defining models.`);
|
||||
}
|
||||
|
||||
// Parse and add new models
|
||||
for (const modelDef of config.models) {
|
||||
const api = modelDef.api || config.api;
|
||||
if (!api) {
|
||||
throw new Error(`Provider ${providerName}, model ${modelDef.id}: no "api" specified.`);
|
||||
}
|
||||
|
||||
// Merge headers
|
||||
const providerHeaders = resolveHeaders(config.headers);
|
||||
const modelHeaders = resolveHeaders(modelDef.headers);
|
||||
let headers = providerHeaders || modelHeaders ? { ...providerHeaders, ...modelHeaders } : undefined;
|
||||
|
||||
// If authHeader is true, add Authorization header
|
||||
if (config.authHeader && config.apiKey) {
|
||||
const resolvedKey = resolveConfigValue(config.apiKey);
|
||||
if (resolvedKey) {
|
||||
headers = { ...headers, Authorization: `Bearer ${resolvedKey}` };
|
||||
}
|
||||
}
|
||||
|
||||
this.models.push({
|
||||
id: modelDef.id,
|
||||
name: modelDef.name,
|
||||
api: api as Api,
|
||||
provider: providerName,
|
||||
baseUrl: config.baseUrl,
|
||||
reasoning: modelDef.reasoning,
|
||||
input: modelDef.input as ("text" | "image")[],
|
||||
cost: modelDef.cost,
|
||||
contextWindow: modelDef.contextWindow,
|
||||
maxTokens: modelDef.maxTokens,
|
||||
headers,
|
||||
compat: modelDef.compat,
|
||||
} as Model<Api>);
|
||||
}
|
||||
} else if (config.baseUrl) {
|
||||
// Override-only: update baseUrl/headers for existing models
|
||||
const resolvedHeaders = resolveHeaders(config.headers);
|
||||
this.models = this.models.map((m) => {
|
||||
if (m.provider !== providerName) return m;
|
||||
return {
|
||||
...m,
|
||||
baseUrl: config.baseUrl ?? m.baseUrl,
|
||||
headers: resolvedHeaders ? { ...m.headers, ...resolvedHeaders } : m.headers,
|
||||
};
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Input type for registerProvider API.
|
||||
*/
|
||||
export interface ProviderConfigInput {
|
||||
baseUrl?: string;
|
||||
apiKey?: string;
|
||||
api?: Api;
|
||||
headers?: Record<string, string>;
|
||||
authHeader?: boolean;
|
||||
/** OAuth provider for /login support */
|
||||
oauth?: Omit<OAuthProviderInterface, "id">;
|
||||
models?: Array<{
|
||||
id: string;
|
||||
name: string;
|
||||
api?: Api;
|
||||
reasoning: boolean;
|
||||
input: ("text" | "image")[];
|
||||
cost: { input: number; output: number; cacheRead: number; cacheWrite: number };
|
||||
contextWindow: number;
|
||||
maxTokens: number;
|
||||
headers?: Record<string, string>;
|
||||
compat?: Model<Api>["compat"];
|
||||
}>;
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue