refactor(oauth): add provider registry

This commit is contained in:
Mario Zechner 2026-01-24 21:17:05 +01:00
parent 89636cfe6e
commit 3256d3c083
19 changed files with 655 additions and 291 deletions

View file

@ -9,13 +9,11 @@
import {
getEnvApiKey,
getOAuthApiKey,
loginAnthropic,
loginAntigravity,
loginGeminiCli,
loginGitHubCopilot,
loginOpenAICodex,
getOAuthProvider,
getOAuthProviders,
type OAuthCredentials,
type OAuthProvider,
type OAuthLoginCallbacks,
type OAuthProviderId,
} from "@mariozechner/pi-ai";
import { chmodSync, existsSync, mkdirSync, readFileSync, writeFileSync } from "fs";
import { dirname, join } from "path";
@ -156,54 +154,14 @@ export class AuthStorage {
/**
* Login to an OAuth provider.
*/
async login(
provider: OAuthProvider,
callbacks: {
onAuth: (info: { url: string; instructions?: string }) => void;
onPrompt: (prompt: { message: string; placeholder?: string }) => Promise<string>;
onProgress?: (message: string) => void;
/** For providers with local callback servers (e.g., openai-codex), races with browser callback */
onManualCodeInput?: () => Promise<string>;
/** For cancellation support (e.g., github-copilot polling) */
signal?: AbortSignal;
},
): Promise<void> {
let credentials: OAuthCredentials;
switch (provider) {
case "anthropic":
credentials = await loginAnthropic(
(url) => callbacks.onAuth({ url }),
() => callbacks.onPrompt({ message: "Paste the authorization code:" }),
);
break;
case "github-copilot":
credentials = await loginGitHubCopilot({
onAuth: (url, instructions) => callbacks.onAuth({ url, instructions }),
onPrompt: callbacks.onPrompt,
onProgress: callbacks.onProgress,
signal: callbacks.signal,
});
break;
case "google-gemini-cli":
credentials = await loginGeminiCli(callbacks.onAuth, callbacks.onProgress, callbacks.onManualCodeInput);
break;
case "google-antigravity":
credentials = await loginAntigravity(callbacks.onAuth, callbacks.onProgress, callbacks.onManualCodeInput);
break;
case "openai-codex":
credentials = await loginOpenAICodex({
onAuth: callbacks.onAuth,
onPrompt: callbacks.onPrompt,
onProgress: callbacks.onProgress,
onManualCodeInput: callbacks.onManualCodeInput,
});
break;
default:
throw new Error(`Unknown OAuth provider: ${provider}`);
async login(providerId: OAuthProviderId, callbacks: OAuthLoginCallbacks): Promise<void> {
const provider = getOAuthProvider(providerId);
if (!provider) {
throw new Error(`Unknown OAuth provider: ${providerId}`);
}
this.set(provider, { type: "oauth", ...credentials });
const credentials = await provider.login(callbacks);
this.set(providerId, { type: "oauth", ...credentials });
}
/**
@ -219,8 +177,13 @@ export class AuthStorage {
* This ensures only one instance refreshes while others wait and use the result.
*/
private async refreshOAuthTokenWithLock(
provider: OAuthProvider,
providerId: OAuthProviderId,
): Promise<{ apiKey: string; newCredentials: OAuthCredentials } | null> {
const provider = getOAuthProvider(providerId);
if (!provider) {
return null;
}
// Ensure auth file exists for locking
if (!existsSync(this.authPath)) {
const dir = dirname(this.authPath);
@ -250,7 +213,7 @@ export class AuthStorage {
// Re-read file after acquiring lock - another instance may have refreshed
this.reload();
const cred = this.data[provider];
const cred = this.data[providerId];
if (cred?.type !== "oauth") {
return null;
}
@ -259,10 +222,7 @@ export class AuthStorage {
// (another instance may have already refreshed it)
if (Date.now() < cred.expires) {
// Token is now valid - another instance refreshed it
const needsProjectId = provider === "google-gemini-cli" || provider === "google-antigravity";
const apiKey = needsProjectId
? JSON.stringify({ token: cred.access, projectId: cred.projectId })
: cred.access;
const apiKey = provider.getApiKey(cred);
return { apiKey, newCredentials: cred };
}
@ -274,9 +234,9 @@ export class AuthStorage {
}
}
const result = await getOAuthApiKey(provider, oauthCreds);
const result = await getOAuthApiKey(providerId, oauthCreds);
if (result) {
this.data[provider] = { type: "oauth", ...result.newCredentials };
this.data[providerId] = { type: "oauth", ...result.newCredentials };
this.save();
return result;
}
@ -303,41 +263,44 @@ export class AuthStorage {
* 4. Environment variable
* 5. Fallback resolver (models.json custom providers)
*/
async getApiKey(provider: string): Promise<string | undefined> {
async getApiKey(providerId: string): Promise<string | undefined> {
// Runtime override takes highest priority
const runtimeKey = this.runtimeOverrides.get(provider);
const runtimeKey = this.runtimeOverrides.get(providerId);
if (runtimeKey) {
return runtimeKey;
}
const cred = this.data[provider];
const cred = this.data[providerId];
if (cred?.type === "api_key") {
return cred.key;
}
if (cred?.type === "oauth") {
const provider = getOAuthProvider(providerId);
if (!provider) {
// Unknown OAuth provider, can't get API key
return undefined;
}
// Check if token needs refresh
const needsRefresh = Date.now() >= cred.expires;
if (needsRefresh) {
// Use locked refresh to prevent race conditions
try {
const result = await this.refreshOAuthTokenWithLock(provider as OAuthProvider);
const result = await this.refreshOAuthTokenWithLock(providerId);
if (result) {
return result.apiKey;
}
} catch {
// Refresh failed - re-read file to check if another instance succeeded
this.reload();
const updatedCred = this.data[provider];
const updatedCred = this.data[providerId];
if (updatedCred?.type === "oauth" && Date.now() < updatedCred.expires) {
// Another instance refreshed successfully, use those credentials
const needsProjectId = provider === "google-gemini-cli" || provider === "google-antigravity";
return needsProjectId
? JSON.stringify({ token: updatedCred.access, projectId: updatedCred.projectId })
: updatedCred.access;
return provider.getApiKey(updatedCred);
}
// Refresh truly failed - return undefined so model discovery skips this provider
@ -346,16 +309,22 @@ export class AuthStorage {
}
} else {
// Token not expired, use current access token
const needsProjectId = provider === "google-gemini-cli" || provider === "google-antigravity";
return needsProjectId ? JSON.stringify({ token: cred.access, projectId: cred.projectId }) : cred.access;
return provider.getApiKey(cred);
}
}
// Fall back to environment variable
const envKey = getEnvApiKey(provider);
const envKey = getEnvApiKey(providerId);
if (envKey) return envKey;
// Fall back to custom resolver (e.g., models.json custom providers)
return this.fallbackResolver?.(provider) ?? undefined;
return this.fallbackResolver?.(providerId) ?? undefined;
}
/**
* Get all registered OAuth providers
*/
getOAuthProviders() {
return getOAuthProviders();
}
}

View file

@ -76,6 +76,9 @@ export type {
MessageRenderOptions,
ModelSelectEvent,
ModelSelectSource,
// Provider Registration
ProviderConfig,
ProviderModelConfig,
ReadToolResultEvent,
// Commands
RegisteredCommand,

View file

@ -32,6 +32,7 @@ import type {
ExtensionRuntime,
LoadExtensionsResult,
MessageRenderer,
ProviderConfig,
RegisteredCommand,
ToolDefinition,
} from "./types.js";
@ -122,6 +123,7 @@ export function createExtensionRuntime(): ExtensionRuntime {
getThinkingLevel: notInitialized,
setThinkingLevel: notInitialized,
flagValues: new Map(),
pendingProviderRegistrations: [],
};
}
@ -238,6 +240,10 @@ function createExtensionAPI(
runtime.setThinkingLevel(level);
},
registerProvider(name: string, config: ProviderConfig) {
runtime.pendingProviderRegistrations.push({ name, config });
},
events: eventBus,
} as ExtensionAPI;

View file

@ -203,6 +203,12 @@ export class ExtensionRunner {
this.shutdownHandler = contextActions.shutdown;
this.getContextUsageFn = contextActions.getContextUsage;
this.compactFn = contextActions.compact;
// Process provider registrations queued during extension loading
for (const { name, config } of this.runtime.pendingProviderRegistrations) {
this.modelRegistry.registerProvider(name, config);
}
this.runtime.pendingProviderRegistrations = [];
}
bindCommandContext(actions?: ExtensionCommandContextActions): void {

View file

@ -14,7 +14,15 @@ import type {
AgentToolUpdateCallback,
ThinkingLevel,
} from "@mariozechner/pi-agent-core";
import type { ImageContent, Model, TextContent, ToolResultMessage } from "@mariozechner/pi-ai";
import type {
Api,
ImageContent,
Model,
OAuthCredentials,
OAuthLoginCallbacks,
TextContent,
ToolResultMessage,
} from "@mariozechner/pi-ai";
import type {
AutocompleteItem,
Component,
@ -854,10 +862,119 @@ export interface ExtensionAPI {
/** Set thinking level (clamped to model capabilities). */
setThinkingLevel(level: ThinkingLevel): void;
// =========================================================================
// Provider Registration
// =========================================================================
/**
* Register or override a model provider.
*
* If `models` is provided: replaces all existing models for this provider.
* If only `baseUrl` is provided: overrides the URL for existing models.
* If `oauth` is provided: registers OAuth provider for /login support.
*
* @example
* // Register a new provider with custom models
* pi.registerProvider("my-proxy", {
* baseUrl: "https://proxy.example.com",
* apiKey: "PROXY_API_KEY",
* api: "anthropic-messages",
* models: [
* {
* id: "claude-sonnet-4-20250514",
* name: "Claude 4 Sonnet (proxy)",
* reasoning: false,
* input: ["text", "image"],
* cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
* contextWindow: 200000,
* maxTokens: 16384
* }
* ]
* });
*
* @example
* // Override baseUrl for an existing provider
* pi.registerProvider("anthropic", {
* baseUrl: "https://proxy.example.com"
* });
*
* @example
* // Register provider with OAuth support
* pi.registerProvider("corporate-ai", {
* baseUrl: "https://ai.corp.com",
* api: "openai-responses",
* models: [...],
* oauth: {
* name: "Corporate AI (SSO)",
* async login(callbacks) { ... },
* async refreshToken(credentials) { ... },
* getApiKey(credentials) { return credentials.access; }
* }
* });
*/
registerProvider(name: string, config: ProviderConfig): void;
/** Shared event bus for extension communication. */
events: EventBus;
}
// ============================================================================
// Provider Registration Types
// ============================================================================
/** Configuration for registering a provider via pi.registerProvider(). */
export interface ProviderConfig {
/** Base URL for the API endpoint. Required when defining models. */
baseUrl?: string;
/** API key or environment variable name. Required when defining models (unless oauth provided). */
apiKey?: string;
/** API type. Required at provider or model level when defining models. */
api?: Api;
/** Custom headers to include in requests. */
headers?: Record<string, string>;
/** If true, adds Authorization: Bearer header with the resolved API key. */
authHeader?: boolean;
/** Models to register. If provided, replaces all existing models for this provider. */
models?: ProviderModelConfig[];
/** OAuth provider for /login support. The `id` is set automatically from the provider name. */
oauth?: {
/** Display name for the provider in login UI. */
name: string;
/** Run the login flow, return credentials to persist. */
login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials>;
/** Refresh expired credentials, return updated credentials to persist. */
refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials>;
/** Convert credentials to API key string for the provider. */
getApiKey(credentials: OAuthCredentials): string;
/** Optional: modify models for this provider (e.g., update baseUrl based on credentials). */
modifyModels?(models: Model<Api>[], credentials: OAuthCredentials): Model<Api>[];
};
}
/** Configuration for a model within a provider. */
export interface ProviderModelConfig {
/** Model ID (e.g., "claude-sonnet-4-20250514"). */
id: string;
/** Display name (e.g., "Claude 4 Sonnet"). */
name: string;
/** API type override for this model. */
api?: Api;
/** Whether the model supports extended thinking. */
reasoning: boolean;
/** Supported input types. */
input: ("text" | "image")[];
/** Cost per token (for tracking, can be 0). */
cost: { input: number; output: number; cacheRead: number; cacheWrite: number };
/** Maximum context window size in tokens. */
contextWindow: number;
/** Maximum output tokens. */
maxTokens: number;
/** Custom headers for this model. */
headers?: Record<string, string>;
/** OpenAI compatibility settings. */
compat?: Model<Api>["compat"];
}
/** Extension factory function type. Supports both sync and async initialization. */
export type ExtensionFactory = (pi: ExtensionAPI) => void | Promise<void>;
@ -926,6 +1043,8 @@ export type SetLabelHandler = (entryId: string, label: string | undefined) => vo
*/
export interface ExtensionRuntimeState {
flagValues: Map<string, boolean | string>;
/** Provider registrations queued during extension loading, processed when runner binds */
pendingProviderRegistrations: Array<{ name: string; config: ProviderConfig }>;
}
/**

View file

@ -4,12 +4,12 @@
import {
type Api,
getGitHubCopilotBaseUrl,
getModels,
getProviders,
type KnownProvider,
type Model,
normalizeDomain,
type OAuthProviderInterface,
registerOAuthProvider,
} from "@mariozechner/pi-ai";
import { type Static, Type } from "@sinclair/typebox";
import AjvModule from "ajv";
@ -26,7 +26,13 @@ const OpenAICompletionsCompatSchema = Type.Object({
supportsStore: Type.Optional(Type.Boolean()),
supportsDeveloperRole: Type.Optional(Type.Boolean()),
supportsReasoningEffort: Type.Optional(Type.Boolean()),
supportsUsageInStreaming: Type.Optional(Type.Boolean()),
maxTokensField: Type.Optional(Type.Union([Type.Literal("max_completion_tokens"), Type.Literal("max_tokens")])),
requiresToolResultName: Type.Optional(Type.Boolean()),
requiresAssistantAfterToolResult: Type.Optional(Type.Boolean()),
requiresThinkingAsText: Type.Optional(Type.Boolean()),
requiresMistralToolIds: Type.Optional(Type.Boolean()),
thinkingFormat: Type.Optional(Type.Union([Type.Literal("openai"), Type.Literal("zai")])),
});
const OpenAIResponsesCompatSchema = Type.Object({
@ -174,6 +180,7 @@ export function clearApiKeyCache(): void {
export class ModelRegistry {
private models: Model<Api>[] = [];
private customProviderApiKeys: Map<string, string> = new Map();
private registeredProviders: Map<string, ProviderConfigInput> = new Map();
private loadError: string | undefined = undefined;
constructor(
@ -200,6 +207,10 @@ export class ModelRegistry {
this.customProviderApiKeys.clear();
this.loadError = undefined;
this.loadModels();
for (const [providerName, config] of this.registeredProviders.entries()) {
this.applyProviderConfig(providerName, config);
}
}
/**
@ -224,19 +235,17 @@ export class ModelRegistry {
}
const builtInModels = this.loadBuiltInModels(replacedProviders, overrides);
const combined = [...builtInModels, ...customModels];
let combined = [...builtInModels, ...customModels];
// Update github-copilot base URL based on OAuth credentials
const copilotCred = this.authStorage.get("github-copilot");
if (copilotCred?.type === "oauth") {
const domain = copilotCred.enterpriseUrl
? (normalizeDomain(copilotCred.enterpriseUrl) ?? undefined)
: undefined;
const baseUrl = getGitHubCopilotBaseUrl(copilotCred.access, domain);
this.models = combined.map((m) => (m.provider === "github-copilot" ? { ...m, baseUrl } : m));
} else {
this.models = combined;
// Let OAuth providers modify their models (e.g., update baseUrl)
for (const oauthProvider of this.authStorage.getOAuthProviders()) {
const cred = this.authStorage.get(oauthProvider.id);
if (cred?.type === "oauth" && oauthProvider.modifyModels) {
combined = oauthProvider.modifyModels(combined, cred);
}
}
this.models = combined;
}
/** Load built-in models, skipping replaced providers and applying overrides */
@ -449,4 +458,118 @@ export class ModelRegistry {
const cred = this.authStorage.get(model.provider);
return cred?.type === "oauth";
}
/**
* Register a provider dynamically (from extensions).
*
* If provider has models: replaces all existing models for this provider.
* If provider has only baseUrl/headers: overrides existing models' URLs.
* If provider has oauth: registers OAuth provider for /login support.
*/
registerProvider(providerName: string, config: ProviderConfigInput): void {
this.registeredProviders.set(providerName, config);
this.applyProviderConfig(providerName, config);
}
private applyProviderConfig(providerName: string, config: ProviderConfigInput): void {
// Register OAuth provider if provided
if (config.oauth) {
// Ensure the OAuth provider ID matches the provider name
const oauthProvider: OAuthProviderInterface = {
...config.oauth,
id: providerName,
};
registerOAuthProvider(oauthProvider);
}
// Store API key for auth resolution
if (config.apiKey) {
this.customProviderApiKeys.set(providerName, config.apiKey);
}
if (config.models && config.models.length > 0) {
// Full replacement: remove existing models for this provider
this.models = this.models.filter((m) => m.provider !== providerName);
// Validate required fields
if (!config.baseUrl) {
throw new Error(`Provider ${providerName}: "baseUrl" is required when defining models.`);
}
if (!config.apiKey && !config.oauth) {
throw new Error(`Provider ${providerName}: "apiKey" or "oauth" is required when defining models.`);
}
// Parse and add new models
for (const modelDef of config.models) {
const api = modelDef.api || config.api;
if (!api) {
throw new Error(`Provider ${providerName}, model ${modelDef.id}: no "api" specified.`);
}
// Merge headers
const providerHeaders = resolveHeaders(config.headers);
const modelHeaders = resolveHeaders(modelDef.headers);
let headers = providerHeaders || modelHeaders ? { ...providerHeaders, ...modelHeaders } : undefined;
// If authHeader is true, add Authorization header
if (config.authHeader && config.apiKey) {
const resolvedKey = resolveConfigValue(config.apiKey);
if (resolvedKey) {
headers = { ...headers, Authorization: `Bearer ${resolvedKey}` };
}
}
this.models.push({
id: modelDef.id,
name: modelDef.name,
api: api as Api,
provider: providerName,
baseUrl: config.baseUrl,
reasoning: modelDef.reasoning,
input: modelDef.input as ("text" | "image")[],
cost: modelDef.cost,
contextWindow: modelDef.contextWindow,
maxTokens: modelDef.maxTokens,
headers,
compat: modelDef.compat,
} as Model<Api>);
}
} else if (config.baseUrl) {
// Override-only: update baseUrl/headers for existing models
const resolvedHeaders = resolveHeaders(config.headers);
this.models = this.models.map((m) => {
if (m.provider !== providerName) return m;
return {
...m,
baseUrl: config.baseUrl ?? m.baseUrl,
headers: resolvedHeaders ? { ...m.headers, ...resolvedHeaders } : m.headers,
};
});
}
}
}
/**
* Input type for registerProvider API.
*/
export interface ProviderConfigInput {
baseUrl?: string;
apiKey?: string;
api?: Api;
headers?: Record<string, string>;
authHeader?: boolean;
/** OAuth provider for /login support */
oauth?: Omit<OAuthProviderInterface, "id">;
models?: Array<{
id: string;
name: string;
api?: Api;
reasoning: boolean;
input: ("text" | "image")[];
cost: { input: number; output: number; cacheRead: number; cacheWrite: number };
contextWindow: number;
maxTokens: number;
headers?: Record<string, string>;
compat?: Model<Api>["compat"];
}>;
}