From 3256d3c0836637b9b5734ce94882a17a7da37976 Mon Sep 17 00:00:00 2001 From: Mario Zechner Date: Sat, 24 Jan 2026 21:17:05 +0100 Subject: [PATCH] refactor(oauth): add provider registry --- packages/ai/src/cli.ts | 103 +++-------- packages/ai/src/utils/oauth/anthropic.ts | 22 ++- packages/ai/src/utils/oauth/github-copilot.ts | 37 +++- .../ai/src/utils/oauth/google-antigravity.ts | 29 ++- .../ai/src/utils/oauth/google-gemini-cli.ts | 29 ++- packages/ai/src/utils/oauth/index.ts | 170 ++++++++---------- packages/ai/src/utils/oauth/openai-codex.ts | 25 ++- packages/ai/src/utils/oauth/types.ts | 48 +++-- packages/coding-agent/docs/extensions.md | 61 +++++++ .../coding-agent/src/core/auth-storage.ts | 115 +++++------- .../coding-agent/src/core/extensions/index.ts | 3 + .../src/core/extensions/loader.ts | 6 + .../src/core/extensions/runner.ts | 6 + .../coding-agent/src/core/extensions/types.ts | 121 ++++++++++++- .../coding-agent/src/core/model-registry.ts | 149 +++++++++++++-- packages/coding-agent/src/index.ts | 2 + packages/coding-agent/src/main.ts | 6 +- .../interactive/components/oauth-selector.ts | 11 +- .../src/modes/interactive/interactive-mode.ts | 3 +- 19 files changed, 655 insertions(+), 291 deletions(-) diff --git a/packages/ai/src/cli.ts b/packages/ai/src/cli.ts index 1a865c1c..158881c9 100644 --- a/packages/ai/src/cli.ts +++ b/packages/ai/src/cli.ts @@ -2,13 +2,8 @@ import { existsSync, readFileSync, writeFileSync } from "fs"; import { createInterface } from "readline"; -import { loginAnthropic } from "./utils/oauth/anthropic.js"; -import { loginGitHubCopilot } from "./utils/oauth/github-copilot.js"; -import { loginAntigravity } from "./utils/oauth/google-antigravity.js"; -import { loginGeminiCli } from "./utils/oauth/google-gemini-cli.js"; -import { getOAuthProviders } from "./utils/oauth/index.js"; -import { loginOpenAICodex } from "./utils/oauth/openai-codex.js"; -import type { OAuthCredentials, OAuthProvider } from "./utils/oauth/types.js"; +import { getOAuthProvider, getOAuthProviders } from "./utils/oauth/index.js"; +import type { OAuthCredentials, OAuthProviderId } from "./utils/oauth/types.js"; const AUTH_FILE = "auth.json"; const PROVIDERS = getOAuthProviders(); @@ -30,78 +25,31 @@ function saveAuth(auth: Record): v writeFileSync(AUTH_FILE, JSON.stringify(auth, null, 2), "utf-8"); } -async function login(provider: OAuthProvider): Promise { - const rl = createInterface({ input: process.stdin, output: process.stdout }); +async function login(providerId: OAuthProviderId): Promise { + const provider = getOAuthProvider(providerId); + if (!provider) { + console.error(`Unknown provider: ${providerId}`); + process.exit(1); + } + const rl = createInterface({ input: process.stdin, output: process.stdout }); const promptFn = (msg: string) => prompt(rl, `${msg} `); try { - let credentials: OAuthCredentials; - - switch (provider) { - case "anthropic": - credentials = await loginAnthropic( - (url) => { - console.log(`\nOpen this URL in your browser:\n${url}\n`); - }, - async () => { - return await promptFn("Paste the authorization code:"); - }, - ); - break; - - case "github-copilot": - credentials = await loginGitHubCopilot({ - onAuth: (url, instructions) => { - console.log(`\nOpen this URL in your browser:\n${url}`); - if (instructions) console.log(instructions); - console.log(); - }, - onPrompt: async (p) => { - return await promptFn(`${p.message}${p.placeholder ? ` (${p.placeholder})` : ""}:`); - }, - onProgress: (msg) => console.log(msg), - }); - break; - - case "google-gemini-cli": - credentials = await loginGeminiCli( - (info) => { - console.log(`\nOpen this URL in your browser:\n${info.url}`); - if (info.instructions) console.log(info.instructions); - console.log(); - }, - (msg) => console.log(msg), - ); - break; - - case "google-antigravity": - credentials = await loginAntigravity( - (info) => { - console.log(`\nOpen this URL in your browser:\n${info.url}`); - if (info.instructions) console.log(info.instructions); - console.log(); - }, - (msg) => console.log(msg), - ); - break; - case "openai-codex": - credentials = await loginOpenAICodex({ - onAuth: (info) => { - console.log(`\nOpen this URL in your browser:\n${info.url}`); - if (info.instructions) console.log(info.instructions); - console.log(); - }, - onPrompt: async (p) => { - return await promptFn(`${p.message}${p.placeholder ? ` (${p.placeholder})` : ""}:`); - }, - onProgress: (msg) => console.log(msg), - }); - break; - } + const credentials = await provider.login({ + onAuth: (info) => { + console.log(`\nOpen this URL in your browser:\n${info.url}`); + if (info.instructions) console.log(info.instructions); + console.log(); + }, + onPrompt: async (p) => { + return await promptFn(`${p.message}${p.placeholder ? ` (${p.placeholder})` : ""}:`); + }, + onProgress: (msg) => console.log(msg), + }); const auth = loadAuth(); - auth[provider] = { type: "oauth", ...credentials }; + auth[providerId] = { type: "oauth", ...credentials }; saveAuth(auth); console.log(`\nCredentials saved to ${AUTH_FILE}`); @@ -115,6 +63,7 @@ async function main(): Promise { const command = args[0]; if (!command || command === "help" || command === "--help" || command === "-h") { + const providerList = PROVIDERS.map((p) => ` ${p.id.padEnd(20)} ${p.name}`).join("\n"); console.log(`Usage: npx @mariozechner/pi-ai [provider] Commands: @@ -122,11 +71,7 @@ Commands: list List available providers Providers: - anthropic Anthropic (Claude Pro/Max) - github-copilot GitHub Copilot - google-gemini-cli Google Gemini CLI - google-antigravity Antigravity (Gemini 3, Claude, GPT-OSS) - openai-codex OpenAI Codex (ChatGPT Plus/Pro) +${providerList} Examples: npx @mariozechner/pi-ai login # interactive provider selection @@ -145,7 +90,7 @@ Examples: } if (command === "login") { - let provider = args[1] as OAuthProvider | undefined; + let provider = args[1] as OAuthProviderId | undefined; if (!provider) { const rl = createInterface({ input: process.stdin, output: process.stdout }); diff --git a/packages/ai/src/utils/oauth/anthropic.ts b/packages/ai/src/utils/oauth/anthropic.ts index 74a2228c..5355df0d 100644 --- a/packages/ai/src/utils/oauth/anthropic.ts +++ b/packages/ai/src/utils/oauth/anthropic.ts @@ -3,7 +3,7 @@ */ import { generatePKCE } from "./pkce.js"; -import type { OAuthCredentials } from "./types.js"; +import type { OAuthCredentials, OAuthLoginCallbacks, OAuthProviderInterface } from "./types.js"; const decode = (s: string) => atob(s); const CLIENT_ID = decode("OWQxYzI1MGEtZTYxYi00NGQ5LTg4ZWQtNTk0NGQxOTYyZjVl"); @@ -116,3 +116,23 @@ export async function refreshAnthropicToken(refreshToken: string): Promise { + return loginAnthropic( + (url) => callbacks.onAuth({ url }), + () => callbacks.onPrompt({ message: "Paste the authorization code:" }), + ); + }, + + async refreshToken(credentials: OAuthCredentials): Promise { + return refreshAnthropicToken(credentials.refresh); + }, + + getApiKey(credentials: OAuthCredentials): string { + return credentials.access; + }, +}; diff --git a/packages/ai/src/utils/oauth/github-copilot.ts b/packages/ai/src/utils/oauth/github-copilot.ts index 06661fb9..1b0fe623 100644 --- a/packages/ai/src/utils/oauth/github-copilot.ts +++ b/packages/ai/src/utils/oauth/github-copilot.ts @@ -3,7 +3,12 @@ */ import { getModels } from "../../models.js"; -import type { OAuthCredentials } from "./types.js"; +import type { Api, Model } from "../../types.js"; +import type { OAuthCredentials, OAuthLoginCallbacks, OAuthProviderInterface } from "./types.js"; + +type CopilotCredentials = OAuthCredentials & { + enterpriseUrl?: string; +}; const decode = (s: string) => atob(s); const CLIENT_ID = decode("SXYxLmI1MDdhMDhjODdlY2ZlOTg="); @@ -344,3 +349,33 @@ export async function loginGitHubCopilot(options: { await enableAllGitHubCopilotModels(credentials.access, enterpriseDomain ?? undefined); return credentials; } + +export const githubCopilotOAuthProvider: OAuthProviderInterface = { + id: "github-copilot", + name: "GitHub Copilot", + + async login(callbacks: OAuthLoginCallbacks): Promise { + return loginGitHubCopilot({ + onAuth: (url, instructions) => callbacks.onAuth({ url, instructions }), + onPrompt: callbacks.onPrompt, + onProgress: callbacks.onProgress, + signal: callbacks.signal, + }); + }, + + async refreshToken(credentials: OAuthCredentials): Promise { + const creds = credentials as CopilotCredentials; + return refreshGitHubCopilotToken(creds.refresh, creds.enterpriseUrl); + }, + + getApiKey(credentials: OAuthCredentials): string { + return credentials.access; + }, + + modifyModels(models: Model[], credentials: OAuthCredentials): Model[] { + const creds = credentials as CopilotCredentials; + const domain = creds.enterpriseUrl ? (normalizeDomain(creds.enterpriseUrl) ?? undefined) : undefined; + const baseUrl = getGitHubCopilotBaseUrl(creds.access, domain); + return models.map((m) => (m.provider === "github-copilot" ? { ...m, baseUrl } : m)); + }, +}; diff --git a/packages/ai/src/utils/oauth/google-antigravity.ts b/packages/ai/src/utils/oauth/google-antigravity.ts index 86428f49..70033fa3 100644 --- a/packages/ai/src/utils/oauth/google-antigravity.ts +++ b/packages/ai/src/utils/oauth/google-antigravity.ts @@ -8,7 +8,11 @@ import type { Server } from "http"; import { generatePKCE } from "./pkce.js"; -import type { OAuthCredentials } from "./types.js"; +import type { OAuthCredentials, OAuthLoginCallbacks, OAuthProviderInterface } from "./types.js"; + +type AntigravityCredentials = OAuthCredentials & { + projectId: string; +}; // Antigravity OAuth credentials (different from Gemini CLI) const decode = (s: string) => atob(s); @@ -411,3 +415,26 @@ export async function loginAntigravity( server.server.close(); } } + +export const antigravityOAuthProvider: OAuthProviderInterface = { + id: "google-antigravity", + name: "Antigravity (Gemini 3, Claude, GPT-OSS)", + usesCallbackServer: true, + + async login(callbacks: OAuthLoginCallbacks): Promise { + return loginAntigravity(callbacks.onAuth, callbacks.onProgress, callbacks.onManualCodeInput); + }, + + async refreshToken(credentials: OAuthCredentials): Promise { + const creds = credentials as AntigravityCredentials; + if (!creds.projectId) { + throw new Error("Antigravity credentials missing projectId"); + } + return refreshAntigravityToken(creds.refresh, creds.projectId); + }, + + getApiKey(credentials: OAuthCredentials): string { + const creds = credentials as AntigravityCredentials; + return JSON.stringify({ token: creds.access, projectId: creds.projectId }); + }, +}; diff --git a/packages/ai/src/utils/oauth/google-gemini-cli.ts b/packages/ai/src/utils/oauth/google-gemini-cli.ts index 8f258336..29fdc710 100644 --- a/packages/ai/src/utils/oauth/google-gemini-cli.ts +++ b/packages/ai/src/utils/oauth/google-gemini-cli.ts @@ -8,7 +8,11 @@ import type { Server } from "http"; import { generatePKCE } from "./pkce.js"; -import type { OAuthCredentials } from "./types.js"; +import type { OAuthCredentials, OAuthLoginCallbacks, OAuthProviderInterface } from "./types.js"; + +type GeminiCredentials = OAuthCredentials & { + projectId: string; +}; const decode = (s: string) => atob(s); const CLIENT_ID = decode( @@ -553,3 +557,26 @@ export async function loginGeminiCli( server.server.close(); } } + +export const geminiCliOAuthProvider: OAuthProviderInterface = { + id: "google-gemini-cli", + name: "Google Cloud Code Assist (Gemini CLI)", + usesCallbackServer: true, + + async login(callbacks: OAuthLoginCallbacks): Promise { + return loginGeminiCli(callbacks.onAuth, callbacks.onProgress, callbacks.onManualCodeInput); + }, + + async refreshToken(credentials: OAuthCredentials): Promise { + const creds = credentials as GeminiCredentials; + if (!creds.projectId) { + throw new Error("Google Cloud credentials missing projectId"); + } + return refreshGoogleCloudToken(creds.refresh, creds.projectId); + }, + + getApiKey(credentials: OAuthCredentials): string { + const creds = credentials as GeminiCredentials; + return JSON.stringify({ token: creds.access, projectId: creds.projectId }); + }, +}; diff --git a/packages/ai/src/utils/oauth/index.ts b/packages/ai/src/utils/oauth/index.ts index 1f9ff1a2..1077d748 100644 --- a/packages/ai/src/utils/oauth/index.ts +++ b/packages/ai/src/utils/oauth/index.ts @@ -10,100 +10,111 @@ */ // Anthropic -export { loginAnthropic, refreshAnthropicToken } from "./anthropic.js"; +export { anthropicOAuthProvider, loginAnthropic, refreshAnthropicToken } from "./anthropic.js"; // GitHub Copilot export { getGitHubCopilotBaseUrl, + githubCopilotOAuthProvider, loginGitHubCopilot, normalizeDomain, refreshGitHubCopilotToken, } from "./github-copilot.js"; // Google Antigravity -export { - loginAntigravity, - refreshAntigravityToken, -} from "./google-antigravity.js"; +export { antigravityOAuthProvider, loginAntigravity, refreshAntigravityToken } from "./google-antigravity.js"; // Google Gemini CLI -export { - loginGeminiCli, - refreshGoogleCloudToken, -} from "./google-gemini-cli.js"; +export { geminiCliOAuthProvider, loginGeminiCli, refreshGoogleCloudToken } from "./google-gemini-cli.js"; // OpenAI Codex (ChatGPT OAuth) -export { - loginOpenAICodex, - refreshOpenAICodexToken, -} from "./openai-codex.js"; +export { loginOpenAICodex, openaiCodexOAuthProvider, refreshOpenAICodexToken } from "./openai-codex.js"; export * from "./types.js"; // ============================================================================ -// High-level API +// Provider Registry // ============================================================================ -import { refreshAnthropicToken } from "./anthropic.js"; -import { refreshGitHubCopilotToken } from "./github-copilot.js"; -import { refreshAntigravityToken } from "./google-antigravity.js"; -import { refreshGoogleCloudToken } from "./google-gemini-cli.js"; -import { refreshOpenAICodexToken } from "./openai-codex.js"; -import type { OAuthCredentials, OAuthProvider, OAuthProviderInfo } from "./types.js"; +import { anthropicOAuthProvider } from "./anthropic.js"; +import { githubCopilotOAuthProvider } from "./github-copilot.js"; +import { antigravityOAuthProvider } from "./google-antigravity.js"; +import { geminiCliOAuthProvider } from "./google-gemini-cli.js"; +import { openaiCodexOAuthProvider } from "./openai-codex.js"; +import type { OAuthCredentials, OAuthProviderId, OAuthProviderInfo, OAuthProviderInterface } from "./types.js"; + +const oauthProviderRegistry = new Map([ + [anthropicOAuthProvider.id, anthropicOAuthProvider], + [githubCopilotOAuthProvider.id, githubCopilotOAuthProvider], + [geminiCliOAuthProvider.id, geminiCliOAuthProvider], + [antigravityOAuthProvider.id, antigravityOAuthProvider], + [openaiCodexOAuthProvider.id, openaiCodexOAuthProvider], +]); + +/** + * Get an OAuth provider by ID + */ +export function getOAuthProvider(id: OAuthProviderId): OAuthProviderInterface | undefined { + return oauthProviderRegistry.get(id); +} + +/** + * Register a custom OAuth provider + */ +export function registerOAuthProvider(provider: OAuthProviderInterface): void { + oauthProviderRegistry.set(provider.id, provider); +} + +/** + * Get all registered OAuth providers + */ +export function getOAuthProviders(): OAuthProviderInterface[] { + return Array.from(oauthProviderRegistry.values()); +} + +/** + * @deprecated Use getOAuthProviders() which returns OAuthProviderInterface[] + */ +export function getOAuthProviderInfoList(): OAuthProviderInfo[] { + return getOAuthProviders().map((p) => ({ + id: p.id, + name: p.name, + available: true, + })); +} + +// ============================================================================ +// High-level API (uses provider registry) +// ============================================================================ /** * Refresh token for any OAuth provider. - * Saves the new credentials and returns the new access token. + * @deprecated Use getOAuthProvider(id).refreshToken() instead */ export async function refreshOAuthToken( - provider: OAuthProvider, + providerId: OAuthProviderId, credentials: OAuthCredentials, ): Promise { - if (!credentials) { - throw new Error(`No OAuth credentials found for ${provider}`); + const provider = getOAuthProvider(providerId); + if (!provider) { + throw new Error(`Unknown OAuth provider: ${providerId}`); } - - let newCredentials: OAuthCredentials; - - switch (provider) { - case "anthropic": - newCredentials = await refreshAnthropicToken(credentials.refresh); - break; - case "github-copilot": - newCredentials = await refreshGitHubCopilotToken(credentials.refresh, credentials.enterpriseUrl); - break; - case "google-gemini-cli": - if (!credentials.projectId) { - throw new Error("Google Cloud credentials missing projectId"); - } - newCredentials = await refreshGoogleCloudToken(credentials.refresh, credentials.projectId); - break; - case "google-antigravity": - if (!credentials.projectId) { - throw new Error("Antigravity credentials missing projectId"); - } - newCredentials = await refreshAntigravityToken(credentials.refresh, credentials.projectId); - break; - case "openai-codex": - newCredentials = await refreshOpenAICodexToken(credentials.refresh); - break; - default: - throw new Error(`Unknown OAuth provider: ${provider}`); - } - - return newCredentials; + return provider.refreshToken(credentials); } /** * Get API key for a provider from OAuth credentials. * Automatically refreshes expired tokens. * - * For google-gemini-cli and antigravity, returns JSON-encoded { token, projectId } - * - * @returns API key string, or null if no credentials + * @returns API key string and updated credentials, or null if no credentials * @throws Error if refresh fails */ export async function getOAuthApiKey( - provider: OAuthProvider, + providerId: OAuthProviderId, credentials: Record, ): Promise<{ newCredentials: OAuthCredentials; apiKey: string } | null> { - let creds = credentials[provider]; + const provider = getOAuthProvider(providerId); + if (!provider) { + throw new Error(`Unknown OAuth provider: ${providerId}`); + } + + let creds = credentials[providerId]; if (!creds) { return null; } @@ -111,47 +122,12 @@ export async function getOAuthApiKey( // Refresh if expired if (Date.now() >= creds.expires) { try { - creds = await refreshOAuthToken(provider, creds); + creds = await provider.refreshToken(creds); } catch (_error) { - throw new Error(`Failed to refresh OAuth token for ${provider}`); + throw new Error(`Failed to refresh OAuth token for ${providerId}`); } } - // For providers that need projectId, return JSON - const needsProjectId = provider === "google-gemini-cli" || provider === "google-antigravity"; - const apiKey = needsProjectId ? JSON.stringify({ token: creds.access, projectId: creds.projectId }) : creds.access; + const apiKey = provider.getApiKey(creds); return { newCredentials: creds, apiKey }; } - -/** - * Get list of OAuth providers - */ -export function getOAuthProviders(): OAuthProviderInfo[] { - return [ - { - id: "anthropic", - name: "Anthropic (Claude Pro/Max)", - available: true, - }, - { - id: "openai-codex", - name: "ChatGPT Plus/Pro (Codex Subscription)", - available: true, - }, - { - id: "github-copilot", - name: "GitHub Copilot", - available: true, - }, - { - id: "google-gemini-cli", - name: "Google Cloud Code Assist (Gemini CLI)", - available: true, - }, - { - id: "google-antigravity", - name: "Antigravity (Gemini 3, Claude, GPT-OSS)", - available: true, - }, - ]; -} diff --git a/packages/ai/src/utils/oauth/openai-codex.ts b/packages/ai/src/utils/oauth/openai-codex.ts index 8b0578ab..820168d9 100644 --- a/packages/ai/src/utils/oauth/openai-codex.ts +++ b/packages/ai/src/utils/oauth/openai-codex.ts @@ -18,7 +18,7 @@ if (typeof process !== "undefined" && (process.versions?.node || process.version } import { generatePKCE } from "./pkce.js"; -import type { OAuthCredentials, OAuthPrompt } from "./types.js"; +import type { OAuthCredentials, OAuthLoginCallbacks, OAuthPrompt, OAuthProviderInterface } from "./types.js"; const CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"; const AUTHORIZE_URL = "https://auth.openai.com/oauth/authorize"; @@ -430,3 +430,26 @@ export async function refreshOpenAICodexToken(refreshToken: string): Promise { + return loginOpenAICodex({ + onAuth: callbacks.onAuth, + onPrompt: callbacks.onPrompt, + onProgress: callbacks.onProgress, + onManualCodeInput: callbacks.onManualCodeInput, + }); + }, + + async refreshToken(credentials: OAuthCredentials): Promise { + return refreshOpenAICodexToken(credentials.refresh); + }, + + getApiKey(credentials: OAuthCredentials): string { + return credentials.access; + }, +}; diff --git a/packages/ai/src/utils/oauth/types.ts b/packages/ai/src/utils/oauth/types.ts index 245d93f6..e3520342 100644 --- a/packages/ai/src/utils/oauth/types.ts +++ b/packages/ai/src/utils/oauth/types.ts @@ -1,19 +1,16 @@ +import type { Api, Model } from "../../types.js"; + export type OAuthCredentials = { refresh: string; access: string; expires: number; - enterpriseUrl?: string; - projectId?: string; - email?: string; - accountId?: string; + [key: string]: unknown; }; -export type OAuthProvider = - | "anthropic" - | "github-copilot" - | "google-gemini-cli" - | "google-antigravity" - | "openai-codex"; +export type OAuthProviderId = string; + +/** @deprecated Use OAuthProviderId instead */ +export type OAuthProvider = OAuthProviderId; export type OAuthPrompt = { message: string; @@ -26,8 +23,37 @@ export type OAuthAuthInfo = { instructions?: string; }; +export interface OAuthLoginCallbacks { + onAuth: (info: OAuthAuthInfo) => void; + onPrompt: (prompt: OAuthPrompt) => Promise; + onProgress?: (message: string) => void; + onManualCodeInput?: () => Promise; + signal?: AbortSignal; +} + +export interface OAuthProviderInterface { + readonly id: OAuthProviderId; + readonly name: string; + + /** Run the login flow, return credentials to persist */ + login(callbacks: OAuthLoginCallbacks): Promise; + + /** Whether login uses a local callback server and supports manual code input. */ + usesCallbackServer?: boolean; + + /** Refresh expired credentials, return updated credentials to persist */ + refreshToken(credentials: OAuthCredentials): Promise; + + /** Convert credentials to API key string for the provider */ + getApiKey(credentials: OAuthCredentials): string; + + /** Optional: modify models for this provider (e.g., update baseUrl) */ + modifyModels?(models: Model[], credentials: OAuthCredentials): Model[]; +} + +/** @deprecated Use OAuthProviderInterface instead */ export interface OAuthProviderInfo { - id: OAuthProvider; + id: OAuthProviderId; name: string; available: boolean; } diff --git a/packages/coding-agent/docs/extensions.md b/packages/coding-agent/docs/extensions.md index a71b5ea8..879d461c 100644 --- a/packages/coding-agent/docs/extensions.md +++ b/packages/coding-agent/docs/extensions.md @@ -1146,6 +1146,67 @@ pi.events.on("my:event", (data) => { ... }); pi.events.emit("my:event", { ... }); ``` +### pi.registerProvider(name, config) + +Register or override a model provider dynamically. Useful for proxies, custom endpoints, or team-wide model configurations. + +```typescript +// Register a new provider with custom models +pi.registerProvider("my-proxy", { + baseUrl: "https://proxy.example.com", + apiKey: "PROXY_API_KEY", // env var name or literal + api: "anthropic-messages", + models: [ + { + id: "claude-sonnet-4-20250514", + name: "Claude 4 Sonnet (proxy)", + reasoning: false, + input: ["text", "image"], + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 }, + contextWindow: 200000, + maxTokens: 16384 + } + ] +}); + +// Override baseUrl for an existing provider (keeps all models) +pi.registerProvider("anthropic", { + baseUrl: "https://proxy.example.com" +}); + +// Register provider with OAuth support for /login +pi.registerProvider("corporate-ai", { + baseUrl: "https://ai.corp.com", + api: "openai-responses", + models: [...], + oauth: { + name: "Corporate AI (SSO)", + async login(callbacks) { + // Custom OAuth flow + callbacks.onAuth({ url: "https://sso.corp.com/..." }); + const code = await callbacks.onPrompt({ message: "Enter code:" }); + return { refresh: code, access: code, expires: Date.now() + 3600000 }; + }, + async refreshToken(credentials) { + // Refresh logic + return credentials; + }, + getApiKey(credentials) { + return credentials.access; + } + } +}); +``` + +**Config options:** +- `baseUrl` - API endpoint URL. Required when defining models. +- `apiKey` - API key or environment variable name. Required when defining models (unless `oauth` provided). +- `api` - API type: `"anthropic-messages"`, `"openai-completions"`, `"openai-responses"`, etc. +- `headers` - Custom headers to include in requests. +- `authHeader` - If true, adds `Authorization: Bearer` header automatically. +- `models` - Array of model definitions. If provided, replaces all existing models for this provider. +- `oauth` - OAuth provider config for `/login` support. When provided, the provider appears in the login menu. + ## State Management Extensions with state should store it in tool result `details` for proper branching support: diff --git a/packages/coding-agent/src/core/auth-storage.ts b/packages/coding-agent/src/core/auth-storage.ts index 7042c137..47b798ae 100644 --- a/packages/coding-agent/src/core/auth-storage.ts +++ b/packages/coding-agent/src/core/auth-storage.ts @@ -9,13 +9,11 @@ import { getEnvApiKey, getOAuthApiKey, - loginAnthropic, - loginAntigravity, - loginGeminiCli, - loginGitHubCopilot, - loginOpenAICodex, + getOAuthProvider, + getOAuthProviders, type OAuthCredentials, - type OAuthProvider, + type OAuthLoginCallbacks, + type OAuthProviderId, } from "@mariozechner/pi-ai"; import { chmodSync, existsSync, mkdirSync, readFileSync, writeFileSync } from "fs"; import { dirname, join } from "path"; @@ -156,54 +154,14 @@ export class AuthStorage { /** * Login to an OAuth provider. */ - async login( - provider: OAuthProvider, - callbacks: { - onAuth: (info: { url: string; instructions?: string }) => void; - onPrompt: (prompt: { message: string; placeholder?: string }) => Promise; - onProgress?: (message: string) => void; - /** For providers with local callback servers (e.g., openai-codex), races with browser callback */ - onManualCodeInput?: () => Promise; - /** For cancellation support (e.g., github-copilot polling) */ - signal?: AbortSignal; - }, - ): Promise { - let credentials: OAuthCredentials; - - switch (provider) { - case "anthropic": - credentials = await loginAnthropic( - (url) => callbacks.onAuth({ url }), - () => callbacks.onPrompt({ message: "Paste the authorization code:" }), - ); - break; - case "github-copilot": - credentials = await loginGitHubCopilot({ - onAuth: (url, instructions) => callbacks.onAuth({ url, instructions }), - onPrompt: callbacks.onPrompt, - onProgress: callbacks.onProgress, - signal: callbacks.signal, - }); - break; - case "google-gemini-cli": - credentials = await loginGeminiCli(callbacks.onAuth, callbacks.onProgress, callbacks.onManualCodeInput); - break; - case "google-antigravity": - credentials = await loginAntigravity(callbacks.onAuth, callbacks.onProgress, callbacks.onManualCodeInput); - break; - case "openai-codex": - credentials = await loginOpenAICodex({ - onAuth: callbacks.onAuth, - onPrompt: callbacks.onPrompt, - onProgress: callbacks.onProgress, - onManualCodeInput: callbacks.onManualCodeInput, - }); - break; - default: - throw new Error(`Unknown OAuth provider: ${provider}`); + async login(providerId: OAuthProviderId, callbacks: OAuthLoginCallbacks): Promise { + const provider = getOAuthProvider(providerId); + if (!provider) { + throw new Error(`Unknown OAuth provider: ${providerId}`); } - this.set(provider, { type: "oauth", ...credentials }); + const credentials = await provider.login(callbacks); + this.set(providerId, { type: "oauth", ...credentials }); } /** @@ -219,8 +177,13 @@ export class AuthStorage { * This ensures only one instance refreshes while others wait and use the result. */ private async refreshOAuthTokenWithLock( - provider: OAuthProvider, + providerId: OAuthProviderId, ): Promise<{ apiKey: string; newCredentials: OAuthCredentials } | null> { + const provider = getOAuthProvider(providerId); + if (!provider) { + return null; + } + // Ensure auth file exists for locking if (!existsSync(this.authPath)) { const dir = dirname(this.authPath); @@ -250,7 +213,7 @@ export class AuthStorage { // Re-read file after acquiring lock - another instance may have refreshed this.reload(); - const cred = this.data[provider]; + const cred = this.data[providerId]; if (cred?.type !== "oauth") { return null; } @@ -259,10 +222,7 @@ export class AuthStorage { // (another instance may have already refreshed it) if (Date.now() < cred.expires) { // Token is now valid - another instance refreshed it - const needsProjectId = provider === "google-gemini-cli" || provider === "google-antigravity"; - const apiKey = needsProjectId - ? JSON.stringify({ token: cred.access, projectId: cred.projectId }) - : cred.access; + const apiKey = provider.getApiKey(cred); return { apiKey, newCredentials: cred }; } @@ -274,9 +234,9 @@ export class AuthStorage { } } - const result = await getOAuthApiKey(provider, oauthCreds); + const result = await getOAuthApiKey(providerId, oauthCreds); if (result) { - this.data[provider] = { type: "oauth", ...result.newCredentials }; + this.data[providerId] = { type: "oauth", ...result.newCredentials }; this.save(); return result; } @@ -303,41 +263,44 @@ export class AuthStorage { * 4. Environment variable * 5. Fallback resolver (models.json custom providers) */ - async getApiKey(provider: string): Promise { + async getApiKey(providerId: string): Promise { // Runtime override takes highest priority - const runtimeKey = this.runtimeOverrides.get(provider); + const runtimeKey = this.runtimeOverrides.get(providerId); if (runtimeKey) { return runtimeKey; } - const cred = this.data[provider]; + const cred = this.data[providerId]; if (cred?.type === "api_key") { return cred.key; } if (cred?.type === "oauth") { + const provider = getOAuthProvider(providerId); + if (!provider) { + // Unknown OAuth provider, can't get API key + return undefined; + } + // Check if token needs refresh const needsRefresh = Date.now() >= cred.expires; if (needsRefresh) { // Use locked refresh to prevent race conditions try { - const result = await this.refreshOAuthTokenWithLock(provider as OAuthProvider); + const result = await this.refreshOAuthTokenWithLock(providerId); if (result) { return result.apiKey; } } catch { // Refresh failed - re-read file to check if another instance succeeded this.reload(); - const updatedCred = this.data[provider]; + const updatedCred = this.data[providerId]; if (updatedCred?.type === "oauth" && Date.now() < updatedCred.expires) { // Another instance refreshed successfully, use those credentials - const needsProjectId = provider === "google-gemini-cli" || provider === "google-antigravity"; - return needsProjectId - ? JSON.stringify({ token: updatedCred.access, projectId: updatedCred.projectId }) - : updatedCred.access; + return provider.getApiKey(updatedCred); } // Refresh truly failed - return undefined so model discovery skips this provider @@ -346,16 +309,22 @@ export class AuthStorage { } } else { // Token not expired, use current access token - const needsProjectId = provider === "google-gemini-cli" || provider === "google-antigravity"; - return needsProjectId ? JSON.stringify({ token: cred.access, projectId: cred.projectId }) : cred.access; + return provider.getApiKey(cred); } } // Fall back to environment variable - const envKey = getEnvApiKey(provider); + const envKey = getEnvApiKey(providerId); if (envKey) return envKey; // Fall back to custom resolver (e.g., models.json custom providers) - return this.fallbackResolver?.(provider) ?? undefined; + return this.fallbackResolver?.(providerId) ?? undefined; + } + + /** + * Get all registered OAuth providers + */ + getOAuthProviders() { + return getOAuthProviders(); } } diff --git a/packages/coding-agent/src/core/extensions/index.ts b/packages/coding-agent/src/core/extensions/index.ts index d10d0a1a..b842dd53 100644 --- a/packages/coding-agent/src/core/extensions/index.ts +++ b/packages/coding-agent/src/core/extensions/index.ts @@ -76,6 +76,9 @@ export type { MessageRenderOptions, ModelSelectEvent, ModelSelectSource, + // Provider Registration + ProviderConfig, + ProviderModelConfig, ReadToolResultEvent, // Commands RegisteredCommand, diff --git a/packages/coding-agent/src/core/extensions/loader.ts b/packages/coding-agent/src/core/extensions/loader.ts index 8c2f8721..4b46bb56 100644 --- a/packages/coding-agent/src/core/extensions/loader.ts +++ b/packages/coding-agent/src/core/extensions/loader.ts @@ -32,6 +32,7 @@ import type { ExtensionRuntime, LoadExtensionsResult, MessageRenderer, + ProviderConfig, RegisteredCommand, ToolDefinition, } from "./types.js"; @@ -122,6 +123,7 @@ export function createExtensionRuntime(): ExtensionRuntime { getThinkingLevel: notInitialized, setThinkingLevel: notInitialized, flagValues: new Map(), + pendingProviderRegistrations: [], }; } @@ -238,6 +240,10 @@ function createExtensionAPI( runtime.setThinkingLevel(level); }, + registerProvider(name: string, config: ProviderConfig) { + runtime.pendingProviderRegistrations.push({ name, config }); + }, + events: eventBus, } as ExtensionAPI; diff --git a/packages/coding-agent/src/core/extensions/runner.ts b/packages/coding-agent/src/core/extensions/runner.ts index c7deee56..73197acc 100644 --- a/packages/coding-agent/src/core/extensions/runner.ts +++ b/packages/coding-agent/src/core/extensions/runner.ts @@ -203,6 +203,12 @@ export class ExtensionRunner { this.shutdownHandler = contextActions.shutdown; this.getContextUsageFn = contextActions.getContextUsage; this.compactFn = contextActions.compact; + + // Process provider registrations queued during extension loading + for (const { name, config } of this.runtime.pendingProviderRegistrations) { + this.modelRegistry.registerProvider(name, config); + } + this.runtime.pendingProviderRegistrations = []; } bindCommandContext(actions?: ExtensionCommandContextActions): void { diff --git a/packages/coding-agent/src/core/extensions/types.ts b/packages/coding-agent/src/core/extensions/types.ts index 2a058160..08316693 100644 --- a/packages/coding-agent/src/core/extensions/types.ts +++ b/packages/coding-agent/src/core/extensions/types.ts @@ -14,7 +14,15 @@ import type { AgentToolUpdateCallback, ThinkingLevel, } from "@mariozechner/pi-agent-core"; -import type { ImageContent, Model, TextContent, ToolResultMessage } from "@mariozechner/pi-ai"; +import type { + Api, + ImageContent, + Model, + OAuthCredentials, + OAuthLoginCallbacks, + TextContent, + ToolResultMessage, +} from "@mariozechner/pi-ai"; import type { AutocompleteItem, Component, @@ -854,10 +862,119 @@ export interface ExtensionAPI { /** Set thinking level (clamped to model capabilities). */ setThinkingLevel(level: ThinkingLevel): void; + // ========================================================================= + // Provider Registration + // ========================================================================= + + /** + * Register or override a model provider. + * + * If `models` is provided: replaces all existing models for this provider. + * If only `baseUrl` is provided: overrides the URL for existing models. + * If `oauth` is provided: registers OAuth provider for /login support. + * + * @example + * // Register a new provider with custom models + * pi.registerProvider("my-proxy", { + * baseUrl: "https://proxy.example.com", + * apiKey: "PROXY_API_KEY", + * api: "anthropic-messages", + * models: [ + * { + * id: "claude-sonnet-4-20250514", + * name: "Claude 4 Sonnet (proxy)", + * reasoning: false, + * input: ["text", "image"], + * cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 }, + * contextWindow: 200000, + * maxTokens: 16384 + * } + * ] + * }); + * + * @example + * // Override baseUrl for an existing provider + * pi.registerProvider("anthropic", { + * baseUrl: "https://proxy.example.com" + * }); + * + * @example + * // Register provider with OAuth support + * pi.registerProvider("corporate-ai", { + * baseUrl: "https://ai.corp.com", + * api: "openai-responses", + * models: [...], + * oauth: { + * name: "Corporate AI (SSO)", + * async login(callbacks) { ... }, + * async refreshToken(credentials) { ... }, + * getApiKey(credentials) { return credentials.access; } + * } + * }); + */ + registerProvider(name: string, config: ProviderConfig): void; + /** Shared event bus for extension communication. */ events: EventBus; } +// ============================================================================ +// Provider Registration Types +// ============================================================================ + +/** Configuration for registering a provider via pi.registerProvider(). */ +export interface ProviderConfig { + /** Base URL for the API endpoint. Required when defining models. */ + baseUrl?: string; + /** API key or environment variable name. Required when defining models (unless oauth provided). */ + apiKey?: string; + /** API type. Required at provider or model level when defining models. */ + api?: Api; + /** Custom headers to include in requests. */ + headers?: Record; + /** If true, adds Authorization: Bearer header with the resolved API key. */ + authHeader?: boolean; + /** Models to register. If provided, replaces all existing models for this provider. */ + models?: ProviderModelConfig[]; + /** OAuth provider for /login support. The `id` is set automatically from the provider name. */ + oauth?: { + /** Display name for the provider in login UI. */ + name: string; + /** Run the login flow, return credentials to persist. */ + login(callbacks: OAuthLoginCallbacks): Promise; + /** Refresh expired credentials, return updated credentials to persist. */ + refreshToken(credentials: OAuthCredentials): Promise; + /** Convert credentials to API key string for the provider. */ + getApiKey(credentials: OAuthCredentials): string; + /** Optional: modify models for this provider (e.g., update baseUrl based on credentials). */ + modifyModels?(models: Model[], credentials: OAuthCredentials): Model[]; + }; +} + +/** Configuration for a model within a provider. */ +export interface ProviderModelConfig { + /** Model ID (e.g., "claude-sonnet-4-20250514"). */ + id: string; + /** Display name (e.g., "Claude 4 Sonnet"). */ + name: string; + /** API type override for this model. */ + api?: Api; + /** Whether the model supports extended thinking. */ + reasoning: boolean; + /** Supported input types. */ + input: ("text" | "image")[]; + /** Cost per token (for tracking, can be 0). */ + cost: { input: number; output: number; cacheRead: number; cacheWrite: number }; + /** Maximum context window size in tokens. */ + contextWindow: number; + /** Maximum output tokens. */ + maxTokens: number; + /** Custom headers for this model. */ + headers?: Record; + /** OpenAI compatibility settings. */ + compat?: Model["compat"]; +} + /** Extension factory function type. Supports both sync and async initialization. */ export type ExtensionFactory = (pi: ExtensionAPI) => void | Promise; @@ -926,6 +1043,8 @@ export type SetLabelHandler = (entryId: string, label: string | undefined) => vo */ export interface ExtensionRuntimeState { flagValues: Map; + /** Provider registrations queued during extension loading, processed when runner binds */ + pendingProviderRegistrations: Array<{ name: string; config: ProviderConfig }>; } /** diff --git a/packages/coding-agent/src/core/model-registry.ts b/packages/coding-agent/src/core/model-registry.ts index a276cce8..77c8c488 100644 --- a/packages/coding-agent/src/core/model-registry.ts +++ b/packages/coding-agent/src/core/model-registry.ts @@ -4,12 +4,12 @@ import { type Api, - getGitHubCopilotBaseUrl, getModels, getProviders, type KnownProvider, type Model, - normalizeDomain, + type OAuthProviderInterface, + registerOAuthProvider, } from "@mariozechner/pi-ai"; import { type Static, Type } from "@sinclair/typebox"; import AjvModule from "ajv"; @@ -26,7 +26,13 @@ const OpenAICompletionsCompatSchema = Type.Object({ supportsStore: Type.Optional(Type.Boolean()), supportsDeveloperRole: Type.Optional(Type.Boolean()), supportsReasoningEffort: Type.Optional(Type.Boolean()), + supportsUsageInStreaming: Type.Optional(Type.Boolean()), maxTokensField: Type.Optional(Type.Union([Type.Literal("max_completion_tokens"), Type.Literal("max_tokens")])), + requiresToolResultName: Type.Optional(Type.Boolean()), + requiresAssistantAfterToolResult: Type.Optional(Type.Boolean()), + requiresThinkingAsText: Type.Optional(Type.Boolean()), + requiresMistralToolIds: Type.Optional(Type.Boolean()), + thinkingFormat: Type.Optional(Type.Union([Type.Literal("openai"), Type.Literal("zai")])), }); const OpenAIResponsesCompatSchema = Type.Object({ @@ -174,6 +180,7 @@ export function clearApiKeyCache(): void { export class ModelRegistry { private models: Model[] = []; private customProviderApiKeys: Map = new Map(); + private registeredProviders: Map = new Map(); private loadError: string | undefined = undefined; constructor( @@ -200,6 +207,10 @@ export class ModelRegistry { this.customProviderApiKeys.clear(); this.loadError = undefined; this.loadModels(); + + for (const [providerName, config] of this.registeredProviders.entries()) { + this.applyProviderConfig(providerName, config); + } } /** @@ -224,19 +235,17 @@ export class ModelRegistry { } const builtInModels = this.loadBuiltInModels(replacedProviders, overrides); - const combined = [...builtInModels, ...customModels]; + let combined = [...builtInModels, ...customModels]; - // Update github-copilot base URL based on OAuth credentials - const copilotCred = this.authStorage.get("github-copilot"); - if (copilotCred?.type === "oauth") { - const domain = copilotCred.enterpriseUrl - ? (normalizeDomain(copilotCred.enterpriseUrl) ?? undefined) - : undefined; - const baseUrl = getGitHubCopilotBaseUrl(copilotCred.access, domain); - this.models = combined.map((m) => (m.provider === "github-copilot" ? { ...m, baseUrl } : m)); - } else { - this.models = combined; + // Let OAuth providers modify their models (e.g., update baseUrl) + for (const oauthProvider of this.authStorage.getOAuthProviders()) { + const cred = this.authStorage.get(oauthProvider.id); + if (cred?.type === "oauth" && oauthProvider.modifyModels) { + combined = oauthProvider.modifyModels(combined, cred); + } } + + this.models = combined; } /** Load built-in models, skipping replaced providers and applying overrides */ @@ -449,4 +458,118 @@ export class ModelRegistry { const cred = this.authStorage.get(model.provider); return cred?.type === "oauth"; } + + /** + * Register a provider dynamically (from extensions). + * + * If provider has models: replaces all existing models for this provider. + * If provider has only baseUrl/headers: overrides existing models' URLs. + * If provider has oauth: registers OAuth provider for /login support. + */ + registerProvider(providerName: string, config: ProviderConfigInput): void { + this.registeredProviders.set(providerName, config); + this.applyProviderConfig(providerName, config); + } + + private applyProviderConfig(providerName: string, config: ProviderConfigInput): void { + // Register OAuth provider if provided + if (config.oauth) { + // Ensure the OAuth provider ID matches the provider name + const oauthProvider: OAuthProviderInterface = { + ...config.oauth, + id: providerName, + }; + registerOAuthProvider(oauthProvider); + } + + // Store API key for auth resolution + if (config.apiKey) { + this.customProviderApiKeys.set(providerName, config.apiKey); + } + + if (config.models && config.models.length > 0) { + // Full replacement: remove existing models for this provider + this.models = this.models.filter((m) => m.provider !== providerName); + + // Validate required fields + if (!config.baseUrl) { + throw new Error(`Provider ${providerName}: "baseUrl" is required when defining models.`); + } + if (!config.apiKey && !config.oauth) { + throw new Error(`Provider ${providerName}: "apiKey" or "oauth" is required when defining models.`); + } + + // Parse and add new models + for (const modelDef of config.models) { + const api = modelDef.api || config.api; + if (!api) { + throw new Error(`Provider ${providerName}, model ${modelDef.id}: no "api" specified.`); + } + + // Merge headers + const providerHeaders = resolveHeaders(config.headers); + const modelHeaders = resolveHeaders(modelDef.headers); + let headers = providerHeaders || modelHeaders ? { ...providerHeaders, ...modelHeaders } : undefined; + + // If authHeader is true, add Authorization header + if (config.authHeader && config.apiKey) { + const resolvedKey = resolveConfigValue(config.apiKey); + if (resolvedKey) { + headers = { ...headers, Authorization: `Bearer ${resolvedKey}` }; + } + } + + this.models.push({ + id: modelDef.id, + name: modelDef.name, + api: api as Api, + provider: providerName, + baseUrl: config.baseUrl, + reasoning: modelDef.reasoning, + input: modelDef.input as ("text" | "image")[], + cost: modelDef.cost, + contextWindow: modelDef.contextWindow, + maxTokens: modelDef.maxTokens, + headers, + compat: modelDef.compat, + } as Model); + } + } else if (config.baseUrl) { + // Override-only: update baseUrl/headers for existing models + const resolvedHeaders = resolveHeaders(config.headers); + this.models = this.models.map((m) => { + if (m.provider !== providerName) return m; + return { + ...m, + baseUrl: config.baseUrl ?? m.baseUrl, + headers: resolvedHeaders ? { ...m.headers, ...resolvedHeaders } : m.headers, + }; + }); + } + } +} + +/** + * Input type for registerProvider API. + */ +export interface ProviderConfigInput { + baseUrl?: string; + apiKey?: string; + api?: Api; + headers?: Record; + authHeader?: boolean; + /** OAuth provider for /login support */ + oauth?: Omit; + models?: Array<{ + id: string; + name: string; + api?: Api; + reasoning: boolean; + input: ("text" | "image")[]; + cost: { input: number; output: number; cacheRead: number; cacheWrite: number }; + contextWindow: number; + maxTokens: number; + headers?: Record; + compat?: Model["compat"]; + }>; } diff --git a/packages/coding-agent/src/index.ts b/packages/coding-agent/src/index.ts index 5ababa02..3c5652e7 100644 --- a/packages/coding-agent/src/index.ts +++ b/packages/coding-agent/src/index.ts @@ -76,6 +76,8 @@ export type { LoadExtensionsResult, MessageRenderer, MessageRenderOptions, + ProviderConfig, + ProviderModelConfig, RegisteredCommand, RegisteredTool, SessionBeforeCompactEvent, diff --git a/packages/coding-agent/src/main.ts b/packages/coding-agent/src/main.ts index 932106bf..9869c3aa 100644 --- a/packages/coding-agent/src/main.ts +++ b/packages/coding-agent/src/main.ts @@ -432,10 +432,6 @@ export async function main(args: string[]) { // Run migrations (pass cwd for project-local migrations) const { migratedAuthProviders: migratedProviders, deprecationWarnings } = runMigrations(process.cwd()); - // Create AuthStorage and ModelRegistry upfront - const authStorage = new AuthStorage(); - const modelRegistry = new ModelRegistry(authStorage); - // First pass: parse args to get --extension paths const firstPass = parseArgs(args); @@ -443,6 +439,8 @@ export async function main(args: string[]) { const cwd = process.cwd(); const agentDir = getAgentDir(); const settingsManager = SettingsManager.create(cwd, agentDir); + const authStorage = new AuthStorage(); + const modelRegistry = new ModelRegistry(authStorage, getModelsPath()); const resourceLoader = new DefaultResourceLoader({ cwd, diff --git a/packages/coding-agent/src/modes/interactive/components/oauth-selector.ts b/packages/coding-agent/src/modes/interactive/components/oauth-selector.ts index 5b29281d..640e38fe 100644 --- a/packages/coding-agent/src/modes/interactive/components/oauth-selector.ts +++ b/packages/coding-agent/src/modes/interactive/components/oauth-selector.ts @@ -1,4 +1,4 @@ -import { getOAuthProviders, type OAuthProviderInfo } from "@mariozechner/pi-ai"; +import { getOAuthProviders, type OAuthProviderInterface } from "@mariozechner/pi-ai"; import { Container, getEditorKeybindings, Spacer, TruncatedText } from "@mariozechner/pi-tui"; import type { AuthStorage } from "../../../core/auth-storage.js"; import { theme } from "../theme/theme.js"; @@ -9,7 +9,7 @@ import { DynamicBorder } from "./dynamic-border.js"; */ export class OAuthSelectorComponent extends Container { private listContainer: Container; - private allProviders: OAuthProviderInfo[] = []; + private allProviders: OAuthProviderInterface[] = []; private selectedIndex: number = 0; private mode: "login" | "logout"; private authStorage: AuthStorage; @@ -66,7 +66,6 @@ export class OAuthSelectorComponent extends Container { if (!provider) continue; const isSelected = i === this.selectedIndex; - const isAvailable = provider.available; // Check if user is logged in for this provider const credentials = this.authStorage.get(provider.id); @@ -76,10 +75,10 @@ export class OAuthSelectorComponent extends Container { let line = ""; if (isSelected) { const prefix = theme.fg("accent", "→ "); - const text = isAvailable ? theme.fg("accent", provider.name) : theme.fg("dim", provider.name); + const text = theme.fg("accent", provider.name); line = prefix + text + statusIndicator; } else { - const text = isAvailable ? ` ${provider.name}` : theme.fg("dim", ` ${provider.name}`); + const text = ` ${provider.name}`; line = text + statusIndicator; } @@ -109,7 +108,7 @@ export class OAuthSelectorComponent extends Container { // Enter else if (kb.matches(keyData, "selectConfirm")) { const selectedProvider = this.allProviders[this.selectedIndex]; - if (selectedProvider?.available) { + if (selectedProvider) { this.onSelectCallback(selectedProvider.id); } } diff --git a/packages/coding-agent/src/modes/interactive/interactive-mode.ts b/packages/coding-agent/src/modes/interactive/interactive-mode.ts index b2b6d248..0ca5e50a 100644 --- a/packages/coding-agent/src/modes/interactive/interactive-mode.ts +++ b/packages/coding-agent/src/modes/interactive/interactive-mode.ts @@ -3529,8 +3529,7 @@ export class InteractiveMode { const providerName = providerInfo?.name || providerId; // Providers that use callback servers (can paste redirect URL) - const usesCallbackServer = - providerId === "openai-codex" || providerId === "google-gemini-cli" || providerId === "google-antigravity"; + const usesCallbackServer = providerInfo?.usesCallbackServer ?? false; // Create login dialog component const dialog = new LoginDialogComponent(this.ui, providerId, (_success, _message) => {