refactor(oauth): add provider registry

This commit is contained in:
Mario Zechner 2026-01-24 21:17:05 +01:00
parent 89636cfe6e
commit 3256d3c083
19 changed files with 655 additions and 291 deletions

View file

@ -2,13 +2,8 @@
import { existsSync, readFileSync, writeFileSync } from "fs"; import { existsSync, readFileSync, writeFileSync } from "fs";
import { createInterface } from "readline"; import { createInterface } from "readline";
import { loginAnthropic } from "./utils/oauth/anthropic.js"; import { getOAuthProvider, getOAuthProviders } from "./utils/oauth/index.js";
import { loginGitHubCopilot } from "./utils/oauth/github-copilot.js"; import type { OAuthCredentials, OAuthProviderId } from "./utils/oauth/types.js";
import { loginAntigravity } from "./utils/oauth/google-antigravity.js";
import { loginGeminiCli } from "./utils/oauth/google-gemini-cli.js";
import { getOAuthProviders } from "./utils/oauth/index.js";
import { loginOpenAICodex } from "./utils/oauth/openai-codex.js";
import type { OAuthCredentials, OAuthProvider } from "./utils/oauth/types.js";
const AUTH_FILE = "auth.json"; const AUTH_FILE = "auth.json";
const PROVIDERS = getOAuthProviders(); const PROVIDERS = getOAuthProviders();
@ -30,63 +25,18 @@ function saveAuth(auth: Record<string, { type: "oauth" } & OAuthCredentials>): v
writeFileSync(AUTH_FILE, JSON.stringify(auth, null, 2), "utf-8"); writeFileSync(AUTH_FILE, JSON.stringify(auth, null, 2), "utf-8");
} }
async function login(provider: OAuthProvider): Promise<void> { async function login(providerId: OAuthProviderId): Promise<void> {
const rl = createInterface({ input: process.stdin, output: process.stdout }); const provider = getOAuthProvider(providerId);
if (!provider) {
console.error(`Unknown provider: ${providerId}`);
process.exit(1);
}
const rl = createInterface({ input: process.stdin, output: process.stdout });
const promptFn = (msg: string) => prompt(rl, `${msg} `); const promptFn = (msg: string) => prompt(rl, `${msg} `);
try { try {
let credentials: OAuthCredentials; const credentials = await provider.login({
switch (provider) {
case "anthropic":
credentials = await loginAnthropic(
(url) => {
console.log(`\nOpen this URL in your browser:\n${url}\n`);
},
async () => {
return await promptFn("Paste the authorization code:");
},
);
break;
case "github-copilot":
credentials = await loginGitHubCopilot({
onAuth: (url, instructions) => {
console.log(`\nOpen this URL in your browser:\n${url}`);
if (instructions) console.log(instructions);
console.log();
},
onPrompt: async (p) => {
return await promptFn(`${p.message}${p.placeholder ? ` (${p.placeholder})` : ""}:`);
},
onProgress: (msg) => console.log(msg),
});
break;
case "google-gemini-cli":
credentials = await loginGeminiCli(
(info) => {
console.log(`\nOpen this URL in your browser:\n${info.url}`);
if (info.instructions) console.log(info.instructions);
console.log();
},
(msg) => console.log(msg),
);
break;
case "google-antigravity":
credentials = await loginAntigravity(
(info) => {
console.log(`\nOpen this URL in your browser:\n${info.url}`);
if (info.instructions) console.log(info.instructions);
console.log();
},
(msg) => console.log(msg),
);
break;
case "openai-codex":
credentials = await loginOpenAICodex({
onAuth: (info) => { onAuth: (info) => {
console.log(`\nOpen this URL in your browser:\n${info.url}`); console.log(`\nOpen this URL in your browser:\n${info.url}`);
if (info.instructions) console.log(info.instructions); if (info.instructions) console.log(info.instructions);
@ -97,11 +47,9 @@ async function login(provider: OAuthProvider): Promise<void> {
}, },
onProgress: (msg) => console.log(msg), onProgress: (msg) => console.log(msg),
}); });
break;
}
const auth = loadAuth(); const auth = loadAuth();
auth[provider] = { type: "oauth", ...credentials }; auth[providerId] = { type: "oauth", ...credentials };
saveAuth(auth); saveAuth(auth);
console.log(`\nCredentials saved to ${AUTH_FILE}`); console.log(`\nCredentials saved to ${AUTH_FILE}`);
@ -115,6 +63,7 @@ async function main(): Promise<void> {
const command = args[0]; const command = args[0];
if (!command || command === "help" || command === "--help" || command === "-h") { if (!command || command === "help" || command === "--help" || command === "-h") {
const providerList = PROVIDERS.map((p) => ` ${p.id.padEnd(20)} ${p.name}`).join("\n");
console.log(`Usage: npx @mariozechner/pi-ai <command> [provider] console.log(`Usage: npx @mariozechner/pi-ai <command> [provider]
Commands: Commands:
@ -122,11 +71,7 @@ Commands:
list List available providers list List available providers
Providers: Providers:
anthropic Anthropic (Claude Pro/Max) ${providerList}
github-copilot GitHub Copilot
google-gemini-cli Google Gemini CLI
google-antigravity Antigravity (Gemini 3, Claude, GPT-OSS)
openai-codex OpenAI Codex (ChatGPT Plus/Pro)
Examples: Examples:
npx @mariozechner/pi-ai login # interactive provider selection npx @mariozechner/pi-ai login # interactive provider selection
@ -145,7 +90,7 @@ Examples:
} }
if (command === "login") { if (command === "login") {
let provider = args[1] as OAuthProvider | undefined; let provider = args[1] as OAuthProviderId | undefined;
if (!provider) { if (!provider) {
const rl = createInterface({ input: process.stdin, output: process.stdout }); const rl = createInterface({ input: process.stdin, output: process.stdout });

View file

@ -3,7 +3,7 @@
*/ */
import { generatePKCE } from "./pkce.js"; import { generatePKCE } from "./pkce.js";
import type { OAuthCredentials } from "./types.js"; import type { OAuthCredentials, OAuthLoginCallbacks, OAuthProviderInterface } from "./types.js";
const decode = (s: string) => atob(s); const decode = (s: string) => atob(s);
const CLIENT_ID = decode("OWQxYzI1MGEtZTYxYi00NGQ5LTg4ZWQtNTk0NGQxOTYyZjVl"); const CLIENT_ID = decode("OWQxYzI1MGEtZTYxYi00NGQ5LTg4ZWQtNTk0NGQxOTYyZjVl");
@ -116,3 +116,23 @@ export async function refreshAnthropicToken(refreshToken: string): Promise<OAuth
expires: Date.now() + data.expires_in * 1000 - 5 * 60 * 1000, expires: Date.now() + data.expires_in * 1000 - 5 * 60 * 1000,
}; };
} }
export const anthropicOAuthProvider: OAuthProviderInterface = {
id: "anthropic",
name: "Anthropic (Claude Pro/Max)",
async login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials> {
return loginAnthropic(
(url) => callbacks.onAuth({ url }),
() => callbacks.onPrompt({ message: "Paste the authorization code:" }),
);
},
async refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials> {
return refreshAnthropicToken(credentials.refresh);
},
getApiKey(credentials: OAuthCredentials): string {
return credentials.access;
},
};

View file

@ -3,7 +3,12 @@
*/ */
import { getModels } from "../../models.js"; import { getModels } from "../../models.js";
import type { OAuthCredentials } from "./types.js"; import type { Api, Model } from "../../types.js";
import type { OAuthCredentials, OAuthLoginCallbacks, OAuthProviderInterface } from "./types.js";
type CopilotCredentials = OAuthCredentials & {
enterpriseUrl?: string;
};
const decode = (s: string) => atob(s); const decode = (s: string) => atob(s);
const CLIENT_ID = decode("SXYxLmI1MDdhMDhjODdlY2ZlOTg="); const CLIENT_ID = decode("SXYxLmI1MDdhMDhjODdlY2ZlOTg=");
@ -344,3 +349,33 @@ export async function loginGitHubCopilot(options: {
await enableAllGitHubCopilotModels(credentials.access, enterpriseDomain ?? undefined); await enableAllGitHubCopilotModels(credentials.access, enterpriseDomain ?? undefined);
return credentials; return credentials;
} }
export const githubCopilotOAuthProvider: OAuthProviderInterface = {
id: "github-copilot",
name: "GitHub Copilot",
async login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials> {
return loginGitHubCopilot({
onAuth: (url, instructions) => callbacks.onAuth({ url, instructions }),
onPrompt: callbacks.onPrompt,
onProgress: callbacks.onProgress,
signal: callbacks.signal,
});
},
async refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials> {
const creds = credentials as CopilotCredentials;
return refreshGitHubCopilotToken(creds.refresh, creds.enterpriseUrl);
},
getApiKey(credentials: OAuthCredentials): string {
return credentials.access;
},
modifyModels(models: Model<Api>[], credentials: OAuthCredentials): Model<Api>[] {
const creds = credentials as CopilotCredentials;
const domain = creds.enterpriseUrl ? (normalizeDomain(creds.enterpriseUrl) ?? undefined) : undefined;
const baseUrl = getGitHubCopilotBaseUrl(creds.access, domain);
return models.map((m) => (m.provider === "github-copilot" ? { ...m, baseUrl } : m));
},
};

View file

@ -8,7 +8,11 @@
import type { Server } from "http"; import type { Server } from "http";
import { generatePKCE } from "./pkce.js"; import { generatePKCE } from "./pkce.js";
import type { OAuthCredentials } from "./types.js"; import type { OAuthCredentials, OAuthLoginCallbacks, OAuthProviderInterface } from "./types.js";
type AntigravityCredentials = OAuthCredentials & {
projectId: string;
};
// Antigravity OAuth credentials (different from Gemini CLI) // Antigravity OAuth credentials (different from Gemini CLI)
const decode = (s: string) => atob(s); const decode = (s: string) => atob(s);
@ -411,3 +415,26 @@ export async function loginAntigravity(
server.server.close(); server.server.close();
} }
} }
export const antigravityOAuthProvider: OAuthProviderInterface = {
id: "google-antigravity",
name: "Antigravity (Gemini 3, Claude, GPT-OSS)",
usesCallbackServer: true,
async login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials> {
return loginAntigravity(callbacks.onAuth, callbacks.onProgress, callbacks.onManualCodeInput);
},
async refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials> {
const creds = credentials as AntigravityCredentials;
if (!creds.projectId) {
throw new Error("Antigravity credentials missing projectId");
}
return refreshAntigravityToken(creds.refresh, creds.projectId);
},
getApiKey(credentials: OAuthCredentials): string {
const creds = credentials as AntigravityCredentials;
return JSON.stringify({ token: creds.access, projectId: creds.projectId });
},
};

View file

@ -8,7 +8,11 @@
import type { Server } from "http"; import type { Server } from "http";
import { generatePKCE } from "./pkce.js"; import { generatePKCE } from "./pkce.js";
import type { OAuthCredentials } from "./types.js"; import type { OAuthCredentials, OAuthLoginCallbacks, OAuthProviderInterface } from "./types.js";
type GeminiCredentials = OAuthCredentials & {
projectId: string;
};
const decode = (s: string) => atob(s); const decode = (s: string) => atob(s);
const CLIENT_ID = decode( const CLIENT_ID = decode(
@ -553,3 +557,26 @@ export async function loginGeminiCli(
server.server.close(); server.server.close();
} }
} }
export const geminiCliOAuthProvider: OAuthProviderInterface = {
id: "google-gemini-cli",
name: "Google Cloud Code Assist (Gemini CLI)",
usesCallbackServer: true,
async login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials> {
return loginGeminiCli(callbacks.onAuth, callbacks.onProgress, callbacks.onManualCodeInput);
},
async refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials> {
const creds = credentials as GeminiCredentials;
if (!creds.projectId) {
throw new Error("Google Cloud credentials missing projectId");
}
return refreshGoogleCloudToken(creds.refresh, creds.projectId);
},
getApiKey(credentials: OAuthCredentials): string {
const creds = credentials as GeminiCredentials;
return JSON.stringify({ token: creds.access, projectId: creds.projectId });
},
};

View file

@ -10,100 +10,111 @@
*/ */
// Anthropic // Anthropic
export { loginAnthropic, refreshAnthropicToken } from "./anthropic.js"; export { anthropicOAuthProvider, loginAnthropic, refreshAnthropicToken } from "./anthropic.js";
// GitHub Copilot // GitHub Copilot
export { export {
getGitHubCopilotBaseUrl, getGitHubCopilotBaseUrl,
githubCopilotOAuthProvider,
loginGitHubCopilot, loginGitHubCopilot,
normalizeDomain, normalizeDomain,
refreshGitHubCopilotToken, refreshGitHubCopilotToken,
} from "./github-copilot.js"; } from "./github-copilot.js";
// Google Antigravity // Google Antigravity
export { export { antigravityOAuthProvider, loginAntigravity, refreshAntigravityToken } from "./google-antigravity.js";
loginAntigravity,
refreshAntigravityToken,
} from "./google-antigravity.js";
// Google Gemini CLI // Google Gemini CLI
export { export { geminiCliOAuthProvider, loginGeminiCli, refreshGoogleCloudToken } from "./google-gemini-cli.js";
loginGeminiCli,
refreshGoogleCloudToken,
} from "./google-gemini-cli.js";
// OpenAI Codex (ChatGPT OAuth) // OpenAI Codex (ChatGPT OAuth)
export { export { loginOpenAICodex, openaiCodexOAuthProvider, refreshOpenAICodexToken } from "./openai-codex.js";
loginOpenAICodex,
refreshOpenAICodexToken,
} from "./openai-codex.js";
export * from "./types.js"; export * from "./types.js";
// ============================================================================ // ============================================================================
// High-level API // Provider Registry
// ============================================================================ // ============================================================================
import { refreshAnthropicToken } from "./anthropic.js"; import { anthropicOAuthProvider } from "./anthropic.js";
import { refreshGitHubCopilotToken } from "./github-copilot.js"; import { githubCopilotOAuthProvider } from "./github-copilot.js";
import { refreshAntigravityToken } from "./google-antigravity.js"; import { antigravityOAuthProvider } from "./google-antigravity.js";
import { refreshGoogleCloudToken } from "./google-gemini-cli.js"; import { geminiCliOAuthProvider } from "./google-gemini-cli.js";
import { refreshOpenAICodexToken } from "./openai-codex.js"; import { openaiCodexOAuthProvider } from "./openai-codex.js";
import type { OAuthCredentials, OAuthProvider, OAuthProviderInfo } from "./types.js"; import type { OAuthCredentials, OAuthProviderId, OAuthProviderInfo, OAuthProviderInterface } from "./types.js";
const oauthProviderRegistry = new Map<string, OAuthProviderInterface>([
[anthropicOAuthProvider.id, anthropicOAuthProvider],
[githubCopilotOAuthProvider.id, githubCopilotOAuthProvider],
[geminiCliOAuthProvider.id, geminiCliOAuthProvider],
[antigravityOAuthProvider.id, antigravityOAuthProvider],
[openaiCodexOAuthProvider.id, openaiCodexOAuthProvider],
]);
/**
* Get an OAuth provider by ID
*/
export function getOAuthProvider(id: OAuthProviderId): OAuthProviderInterface | undefined {
return oauthProviderRegistry.get(id);
}
/**
* Register a custom OAuth provider
*/
export function registerOAuthProvider(provider: OAuthProviderInterface): void {
oauthProviderRegistry.set(provider.id, provider);
}
/**
* Get all registered OAuth providers
*/
export function getOAuthProviders(): OAuthProviderInterface[] {
return Array.from(oauthProviderRegistry.values());
}
/**
* @deprecated Use getOAuthProviders() which returns OAuthProviderInterface[]
*/
export function getOAuthProviderInfoList(): OAuthProviderInfo[] {
return getOAuthProviders().map((p) => ({
id: p.id,
name: p.name,
available: true,
}));
}
// ============================================================================
// High-level API (uses provider registry)
// ============================================================================
/** /**
* Refresh token for any OAuth provider. * Refresh token for any OAuth provider.
* Saves the new credentials and returns the new access token. * @deprecated Use getOAuthProvider(id).refreshToken() instead
*/ */
export async function refreshOAuthToken( export async function refreshOAuthToken(
provider: OAuthProvider, providerId: OAuthProviderId,
credentials: OAuthCredentials, credentials: OAuthCredentials,
): Promise<OAuthCredentials> { ): Promise<OAuthCredentials> {
if (!credentials) { const provider = getOAuthProvider(providerId);
throw new Error(`No OAuth credentials found for ${provider}`); if (!provider) {
throw new Error(`Unknown OAuth provider: ${providerId}`);
} }
return provider.refreshToken(credentials);
let newCredentials: OAuthCredentials;
switch (provider) {
case "anthropic":
newCredentials = await refreshAnthropicToken(credentials.refresh);
break;
case "github-copilot":
newCredentials = await refreshGitHubCopilotToken(credentials.refresh, credentials.enterpriseUrl);
break;
case "google-gemini-cli":
if (!credentials.projectId) {
throw new Error("Google Cloud credentials missing projectId");
}
newCredentials = await refreshGoogleCloudToken(credentials.refresh, credentials.projectId);
break;
case "google-antigravity":
if (!credentials.projectId) {
throw new Error("Antigravity credentials missing projectId");
}
newCredentials = await refreshAntigravityToken(credentials.refresh, credentials.projectId);
break;
case "openai-codex":
newCredentials = await refreshOpenAICodexToken(credentials.refresh);
break;
default:
throw new Error(`Unknown OAuth provider: ${provider}`);
}
return newCredentials;
} }
/** /**
* Get API key for a provider from OAuth credentials. * Get API key for a provider from OAuth credentials.
* Automatically refreshes expired tokens. * Automatically refreshes expired tokens.
* *
* For google-gemini-cli and antigravity, returns JSON-encoded { token, projectId } * @returns API key string and updated credentials, or null if no credentials
*
* @returns API key string, or null if no credentials
* @throws Error if refresh fails * @throws Error if refresh fails
*/ */
export async function getOAuthApiKey( export async function getOAuthApiKey(
provider: OAuthProvider, providerId: OAuthProviderId,
credentials: Record<string, OAuthCredentials>, credentials: Record<string, OAuthCredentials>,
): Promise<{ newCredentials: OAuthCredentials; apiKey: string } | null> { ): Promise<{ newCredentials: OAuthCredentials; apiKey: string } | null> {
let creds = credentials[provider]; const provider = getOAuthProvider(providerId);
if (!provider) {
throw new Error(`Unknown OAuth provider: ${providerId}`);
}
let creds = credentials[providerId];
if (!creds) { if (!creds) {
return null; return null;
} }
@ -111,47 +122,12 @@ export async function getOAuthApiKey(
// Refresh if expired // Refresh if expired
if (Date.now() >= creds.expires) { if (Date.now() >= creds.expires) {
try { try {
creds = await refreshOAuthToken(provider, creds); creds = await provider.refreshToken(creds);
} catch (_error) { } catch (_error) {
throw new Error(`Failed to refresh OAuth token for ${provider}`); throw new Error(`Failed to refresh OAuth token for ${providerId}`);
} }
} }
// For providers that need projectId, return JSON const apiKey = provider.getApiKey(creds);
const needsProjectId = provider === "google-gemini-cli" || provider === "google-antigravity";
const apiKey = needsProjectId ? JSON.stringify({ token: creds.access, projectId: creds.projectId }) : creds.access;
return { newCredentials: creds, apiKey }; return { newCredentials: creds, apiKey };
} }
/**
* Get list of OAuth providers
*/
export function getOAuthProviders(): OAuthProviderInfo[] {
return [
{
id: "anthropic",
name: "Anthropic (Claude Pro/Max)",
available: true,
},
{
id: "openai-codex",
name: "ChatGPT Plus/Pro (Codex Subscription)",
available: true,
},
{
id: "github-copilot",
name: "GitHub Copilot",
available: true,
},
{
id: "google-gemini-cli",
name: "Google Cloud Code Assist (Gemini CLI)",
available: true,
},
{
id: "google-antigravity",
name: "Antigravity (Gemini 3, Claude, GPT-OSS)",
available: true,
},
];
}

View file

@ -18,7 +18,7 @@ if (typeof process !== "undefined" && (process.versions?.node || process.version
} }
import { generatePKCE } from "./pkce.js"; import { generatePKCE } from "./pkce.js";
import type { OAuthCredentials, OAuthPrompt } from "./types.js"; import type { OAuthCredentials, OAuthLoginCallbacks, OAuthPrompt, OAuthProviderInterface } from "./types.js";
const CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"; const CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann";
const AUTHORIZE_URL = "https://auth.openai.com/oauth/authorize"; const AUTHORIZE_URL = "https://auth.openai.com/oauth/authorize";
@ -430,3 +430,26 @@ export async function refreshOpenAICodexToken(refreshToken: string): Promise<OAu
accountId, accountId,
}; };
} }
export const openaiCodexOAuthProvider: OAuthProviderInterface = {
id: "openai-codex",
name: "ChatGPT Plus/Pro (Codex Subscription)",
usesCallbackServer: true,
async login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials> {
return loginOpenAICodex({
onAuth: callbacks.onAuth,
onPrompt: callbacks.onPrompt,
onProgress: callbacks.onProgress,
onManualCodeInput: callbacks.onManualCodeInput,
});
},
async refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials> {
return refreshOpenAICodexToken(credentials.refresh);
},
getApiKey(credentials: OAuthCredentials): string {
return credentials.access;
},
};

View file

@ -1,19 +1,16 @@
import type { Api, Model } from "../../types.js";
export type OAuthCredentials = { export type OAuthCredentials = {
refresh: string; refresh: string;
access: string; access: string;
expires: number; expires: number;
enterpriseUrl?: string; [key: string]: unknown;
projectId?: string;
email?: string;
accountId?: string;
}; };
export type OAuthProvider = export type OAuthProviderId = string;
| "anthropic"
| "github-copilot" /** @deprecated Use OAuthProviderId instead */
| "google-gemini-cli" export type OAuthProvider = OAuthProviderId;
| "google-antigravity"
| "openai-codex";
export type OAuthPrompt = { export type OAuthPrompt = {
message: string; message: string;
@ -26,8 +23,37 @@ export type OAuthAuthInfo = {
instructions?: string; instructions?: string;
}; };
export interface OAuthLoginCallbacks {
onAuth: (info: OAuthAuthInfo) => void;
onPrompt: (prompt: OAuthPrompt) => Promise<string>;
onProgress?: (message: string) => void;
onManualCodeInput?: () => Promise<string>;
signal?: AbortSignal;
}
export interface OAuthProviderInterface {
readonly id: OAuthProviderId;
readonly name: string;
/** Run the login flow, return credentials to persist */
login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials>;
/** Whether login uses a local callback server and supports manual code input. */
usesCallbackServer?: boolean;
/** Refresh expired credentials, return updated credentials to persist */
refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials>;
/** Convert credentials to API key string for the provider */
getApiKey(credentials: OAuthCredentials): string;
/** Optional: modify models for this provider (e.g., update baseUrl) */
modifyModels?(models: Model<Api>[], credentials: OAuthCredentials): Model<Api>[];
}
/** @deprecated Use OAuthProviderInterface instead */
export interface OAuthProviderInfo { export interface OAuthProviderInfo {
id: OAuthProvider; id: OAuthProviderId;
name: string; name: string;
available: boolean; available: boolean;
} }

View file

@ -1146,6 +1146,67 @@ pi.events.on("my:event", (data) => { ... });
pi.events.emit("my:event", { ... }); pi.events.emit("my:event", { ... });
``` ```
### pi.registerProvider(name, config)
Register or override a model provider dynamically. Useful for proxies, custom endpoints, or team-wide model configurations.
```typescript
// Register a new provider with custom models
pi.registerProvider("my-proxy", {
baseUrl: "https://proxy.example.com",
apiKey: "PROXY_API_KEY", // env var name or literal
api: "anthropic-messages",
models: [
{
id: "claude-sonnet-4-20250514",
name: "Claude 4 Sonnet (proxy)",
reasoning: false,
input: ["text", "image"],
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
contextWindow: 200000,
maxTokens: 16384
}
]
});
// Override baseUrl for an existing provider (keeps all models)
pi.registerProvider("anthropic", {
baseUrl: "https://proxy.example.com"
});
// Register provider with OAuth support for /login
pi.registerProvider("corporate-ai", {
baseUrl: "https://ai.corp.com",
api: "openai-responses",
models: [...],
oauth: {
name: "Corporate AI (SSO)",
async login(callbacks) {
// Custom OAuth flow
callbacks.onAuth({ url: "https://sso.corp.com/..." });
const code = await callbacks.onPrompt({ message: "Enter code:" });
return { refresh: code, access: code, expires: Date.now() + 3600000 };
},
async refreshToken(credentials) {
// Refresh logic
return credentials;
},
getApiKey(credentials) {
return credentials.access;
}
}
});
```
**Config options:**
- `baseUrl` - API endpoint URL. Required when defining models.
- `apiKey` - API key or environment variable name. Required when defining models (unless `oauth` provided).
- `api` - API type: `"anthropic-messages"`, `"openai-completions"`, `"openai-responses"`, etc.
- `headers` - Custom headers to include in requests.
- `authHeader` - If true, adds `Authorization: Bearer` header automatically.
- `models` - Array of model definitions. If provided, replaces all existing models for this provider.
- `oauth` - OAuth provider config for `/login` support. When provided, the provider appears in the login menu.
## State Management ## State Management
Extensions with state should store it in tool result `details` for proper branching support: Extensions with state should store it in tool result `details` for proper branching support:

View file

@ -9,13 +9,11 @@
import { import {
getEnvApiKey, getEnvApiKey,
getOAuthApiKey, getOAuthApiKey,
loginAnthropic, getOAuthProvider,
loginAntigravity, getOAuthProviders,
loginGeminiCli,
loginGitHubCopilot,
loginOpenAICodex,
type OAuthCredentials, type OAuthCredentials,
type OAuthProvider, type OAuthLoginCallbacks,
type OAuthProviderId,
} from "@mariozechner/pi-ai"; } from "@mariozechner/pi-ai";
import { chmodSync, existsSync, mkdirSync, readFileSync, writeFileSync } from "fs"; import { chmodSync, existsSync, mkdirSync, readFileSync, writeFileSync } from "fs";
import { dirname, join } from "path"; import { dirname, join } from "path";
@ -156,54 +154,14 @@ export class AuthStorage {
/** /**
* Login to an OAuth provider. * Login to an OAuth provider.
*/ */
async login( async login(providerId: OAuthProviderId, callbacks: OAuthLoginCallbacks): Promise<void> {
provider: OAuthProvider, const provider = getOAuthProvider(providerId);
callbacks: { if (!provider) {
onAuth: (info: { url: string; instructions?: string }) => void; throw new Error(`Unknown OAuth provider: ${providerId}`);
onPrompt: (prompt: { message: string; placeholder?: string }) => Promise<string>;
onProgress?: (message: string) => void;
/** For providers with local callback servers (e.g., openai-codex), races with browser callback */
onManualCodeInput?: () => Promise<string>;
/** For cancellation support (e.g., github-copilot polling) */
signal?: AbortSignal;
},
): Promise<void> {
let credentials: OAuthCredentials;
switch (provider) {
case "anthropic":
credentials = await loginAnthropic(
(url) => callbacks.onAuth({ url }),
() => callbacks.onPrompt({ message: "Paste the authorization code:" }),
);
break;
case "github-copilot":
credentials = await loginGitHubCopilot({
onAuth: (url, instructions) => callbacks.onAuth({ url, instructions }),
onPrompt: callbacks.onPrompt,
onProgress: callbacks.onProgress,
signal: callbacks.signal,
});
break;
case "google-gemini-cli":
credentials = await loginGeminiCli(callbacks.onAuth, callbacks.onProgress, callbacks.onManualCodeInput);
break;
case "google-antigravity":
credentials = await loginAntigravity(callbacks.onAuth, callbacks.onProgress, callbacks.onManualCodeInput);
break;
case "openai-codex":
credentials = await loginOpenAICodex({
onAuth: callbacks.onAuth,
onPrompt: callbacks.onPrompt,
onProgress: callbacks.onProgress,
onManualCodeInput: callbacks.onManualCodeInput,
});
break;
default:
throw new Error(`Unknown OAuth provider: ${provider}`);
} }
this.set(provider, { type: "oauth", ...credentials }); const credentials = await provider.login(callbacks);
this.set(providerId, { type: "oauth", ...credentials });
} }
/** /**
@ -219,8 +177,13 @@ export class AuthStorage {
* This ensures only one instance refreshes while others wait and use the result. * This ensures only one instance refreshes while others wait and use the result.
*/ */
private async refreshOAuthTokenWithLock( private async refreshOAuthTokenWithLock(
provider: OAuthProvider, providerId: OAuthProviderId,
): Promise<{ apiKey: string; newCredentials: OAuthCredentials } | null> { ): Promise<{ apiKey: string; newCredentials: OAuthCredentials } | null> {
const provider = getOAuthProvider(providerId);
if (!provider) {
return null;
}
// Ensure auth file exists for locking // Ensure auth file exists for locking
if (!existsSync(this.authPath)) { if (!existsSync(this.authPath)) {
const dir = dirname(this.authPath); const dir = dirname(this.authPath);
@ -250,7 +213,7 @@ export class AuthStorage {
// Re-read file after acquiring lock - another instance may have refreshed // Re-read file after acquiring lock - another instance may have refreshed
this.reload(); this.reload();
const cred = this.data[provider]; const cred = this.data[providerId];
if (cred?.type !== "oauth") { if (cred?.type !== "oauth") {
return null; return null;
} }
@ -259,10 +222,7 @@ export class AuthStorage {
// (another instance may have already refreshed it) // (another instance may have already refreshed it)
if (Date.now() < cred.expires) { if (Date.now() < cred.expires) {
// Token is now valid - another instance refreshed it // Token is now valid - another instance refreshed it
const needsProjectId = provider === "google-gemini-cli" || provider === "google-antigravity"; const apiKey = provider.getApiKey(cred);
const apiKey = needsProjectId
? JSON.stringify({ token: cred.access, projectId: cred.projectId })
: cred.access;
return { apiKey, newCredentials: cred }; return { apiKey, newCredentials: cred };
} }
@ -274,9 +234,9 @@ export class AuthStorage {
} }
} }
const result = await getOAuthApiKey(provider, oauthCreds); const result = await getOAuthApiKey(providerId, oauthCreds);
if (result) { if (result) {
this.data[provider] = { type: "oauth", ...result.newCredentials }; this.data[providerId] = { type: "oauth", ...result.newCredentials };
this.save(); this.save();
return result; return result;
} }
@ -303,41 +263,44 @@ export class AuthStorage {
* 4. Environment variable * 4. Environment variable
* 5. Fallback resolver (models.json custom providers) * 5. Fallback resolver (models.json custom providers)
*/ */
async getApiKey(provider: string): Promise<string | undefined> { async getApiKey(providerId: string): Promise<string | undefined> {
// Runtime override takes highest priority // Runtime override takes highest priority
const runtimeKey = this.runtimeOverrides.get(provider); const runtimeKey = this.runtimeOverrides.get(providerId);
if (runtimeKey) { if (runtimeKey) {
return runtimeKey; return runtimeKey;
} }
const cred = this.data[provider]; const cred = this.data[providerId];
if (cred?.type === "api_key") { if (cred?.type === "api_key") {
return cred.key; return cred.key;
} }
if (cred?.type === "oauth") { if (cred?.type === "oauth") {
const provider = getOAuthProvider(providerId);
if (!provider) {
// Unknown OAuth provider, can't get API key
return undefined;
}
// Check if token needs refresh // Check if token needs refresh
const needsRefresh = Date.now() >= cred.expires; const needsRefresh = Date.now() >= cred.expires;
if (needsRefresh) { if (needsRefresh) {
// Use locked refresh to prevent race conditions // Use locked refresh to prevent race conditions
try { try {
const result = await this.refreshOAuthTokenWithLock(provider as OAuthProvider); const result = await this.refreshOAuthTokenWithLock(providerId);
if (result) { if (result) {
return result.apiKey; return result.apiKey;
} }
} catch { } catch {
// Refresh failed - re-read file to check if another instance succeeded // Refresh failed - re-read file to check if another instance succeeded
this.reload(); this.reload();
const updatedCred = this.data[provider]; const updatedCred = this.data[providerId];
if (updatedCred?.type === "oauth" && Date.now() < updatedCred.expires) { if (updatedCred?.type === "oauth" && Date.now() < updatedCred.expires) {
// Another instance refreshed successfully, use those credentials // Another instance refreshed successfully, use those credentials
const needsProjectId = provider === "google-gemini-cli" || provider === "google-antigravity"; return provider.getApiKey(updatedCred);
return needsProjectId
? JSON.stringify({ token: updatedCred.access, projectId: updatedCred.projectId })
: updatedCred.access;
} }
// Refresh truly failed - return undefined so model discovery skips this provider // Refresh truly failed - return undefined so model discovery skips this provider
@ -346,16 +309,22 @@ export class AuthStorage {
} }
} else { } else {
// Token not expired, use current access token // Token not expired, use current access token
const needsProjectId = provider === "google-gemini-cli" || provider === "google-antigravity"; return provider.getApiKey(cred);
return needsProjectId ? JSON.stringify({ token: cred.access, projectId: cred.projectId }) : cred.access;
} }
} }
// Fall back to environment variable // Fall back to environment variable
const envKey = getEnvApiKey(provider); const envKey = getEnvApiKey(providerId);
if (envKey) return envKey; if (envKey) return envKey;
// Fall back to custom resolver (e.g., models.json custom providers) // Fall back to custom resolver (e.g., models.json custom providers)
return this.fallbackResolver?.(provider) ?? undefined; return this.fallbackResolver?.(providerId) ?? undefined;
}
/**
* Get all registered OAuth providers
*/
getOAuthProviders() {
return getOAuthProviders();
} }
} }

View file

@ -76,6 +76,9 @@ export type {
MessageRenderOptions, MessageRenderOptions,
ModelSelectEvent, ModelSelectEvent,
ModelSelectSource, ModelSelectSource,
// Provider Registration
ProviderConfig,
ProviderModelConfig,
ReadToolResultEvent, ReadToolResultEvent,
// Commands // Commands
RegisteredCommand, RegisteredCommand,

View file

@ -32,6 +32,7 @@ import type {
ExtensionRuntime, ExtensionRuntime,
LoadExtensionsResult, LoadExtensionsResult,
MessageRenderer, MessageRenderer,
ProviderConfig,
RegisteredCommand, RegisteredCommand,
ToolDefinition, ToolDefinition,
} from "./types.js"; } from "./types.js";
@ -122,6 +123,7 @@ export function createExtensionRuntime(): ExtensionRuntime {
getThinkingLevel: notInitialized, getThinkingLevel: notInitialized,
setThinkingLevel: notInitialized, setThinkingLevel: notInitialized,
flagValues: new Map(), flagValues: new Map(),
pendingProviderRegistrations: [],
}; };
} }
@ -238,6 +240,10 @@ function createExtensionAPI(
runtime.setThinkingLevel(level); runtime.setThinkingLevel(level);
}, },
registerProvider(name: string, config: ProviderConfig) {
runtime.pendingProviderRegistrations.push({ name, config });
},
events: eventBus, events: eventBus,
} as ExtensionAPI; } as ExtensionAPI;

View file

@ -203,6 +203,12 @@ export class ExtensionRunner {
this.shutdownHandler = contextActions.shutdown; this.shutdownHandler = contextActions.shutdown;
this.getContextUsageFn = contextActions.getContextUsage; this.getContextUsageFn = contextActions.getContextUsage;
this.compactFn = contextActions.compact; this.compactFn = contextActions.compact;
// Process provider registrations queued during extension loading
for (const { name, config } of this.runtime.pendingProviderRegistrations) {
this.modelRegistry.registerProvider(name, config);
}
this.runtime.pendingProviderRegistrations = [];
} }
bindCommandContext(actions?: ExtensionCommandContextActions): void { bindCommandContext(actions?: ExtensionCommandContextActions): void {

View file

@ -14,7 +14,15 @@ import type {
AgentToolUpdateCallback, AgentToolUpdateCallback,
ThinkingLevel, ThinkingLevel,
} from "@mariozechner/pi-agent-core"; } from "@mariozechner/pi-agent-core";
import type { ImageContent, Model, TextContent, ToolResultMessage } from "@mariozechner/pi-ai"; import type {
Api,
ImageContent,
Model,
OAuthCredentials,
OAuthLoginCallbacks,
TextContent,
ToolResultMessage,
} from "@mariozechner/pi-ai";
import type { import type {
AutocompleteItem, AutocompleteItem,
Component, Component,
@ -854,10 +862,119 @@ export interface ExtensionAPI {
/** Set thinking level (clamped to model capabilities). */ /** Set thinking level (clamped to model capabilities). */
setThinkingLevel(level: ThinkingLevel): void; setThinkingLevel(level: ThinkingLevel): void;
// =========================================================================
// Provider Registration
// =========================================================================
/**
* Register or override a model provider.
*
* If `models` is provided: replaces all existing models for this provider.
* If only `baseUrl` is provided: overrides the URL for existing models.
* If `oauth` is provided: registers OAuth provider for /login support.
*
* @example
* // Register a new provider with custom models
* pi.registerProvider("my-proxy", {
* baseUrl: "https://proxy.example.com",
* apiKey: "PROXY_API_KEY",
* api: "anthropic-messages",
* models: [
* {
* id: "claude-sonnet-4-20250514",
* name: "Claude 4 Sonnet (proxy)",
* reasoning: false,
* input: ["text", "image"],
* cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
* contextWindow: 200000,
* maxTokens: 16384
* }
* ]
* });
*
* @example
* // Override baseUrl for an existing provider
* pi.registerProvider("anthropic", {
* baseUrl: "https://proxy.example.com"
* });
*
* @example
* // Register provider with OAuth support
* pi.registerProvider("corporate-ai", {
* baseUrl: "https://ai.corp.com",
* api: "openai-responses",
* models: [...],
* oauth: {
* name: "Corporate AI (SSO)",
* async login(callbacks) { ... },
* async refreshToken(credentials) { ... },
* getApiKey(credentials) { return credentials.access; }
* }
* });
*/
registerProvider(name: string, config: ProviderConfig): void;
/** Shared event bus for extension communication. */ /** Shared event bus for extension communication. */
events: EventBus; events: EventBus;
} }
// ============================================================================
// Provider Registration Types
// ============================================================================
/** Configuration for registering a provider via pi.registerProvider(). */
export interface ProviderConfig {
/** Base URL for the API endpoint. Required when defining models. */
baseUrl?: string;
/** API key or environment variable name. Required when defining models (unless oauth provided). */
apiKey?: string;
/** API type. Required at provider or model level when defining models. */
api?: Api;
/** Custom headers to include in requests. */
headers?: Record<string, string>;
/** If true, adds Authorization: Bearer header with the resolved API key. */
authHeader?: boolean;
/** Models to register. If provided, replaces all existing models for this provider. */
models?: ProviderModelConfig[];
/** OAuth provider for /login support. The `id` is set automatically from the provider name. */
oauth?: {
/** Display name for the provider in login UI. */
name: string;
/** Run the login flow, return credentials to persist. */
login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials>;
/** Refresh expired credentials, return updated credentials to persist. */
refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials>;
/** Convert credentials to API key string for the provider. */
getApiKey(credentials: OAuthCredentials): string;
/** Optional: modify models for this provider (e.g., update baseUrl based on credentials). */
modifyModels?(models: Model<Api>[], credentials: OAuthCredentials): Model<Api>[];
};
}
/** Configuration for a model within a provider. */
export interface ProviderModelConfig {
/** Model ID (e.g., "claude-sonnet-4-20250514"). */
id: string;
/** Display name (e.g., "Claude 4 Sonnet"). */
name: string;
/** API type override for this model. */
api?: Api;
/** Whether the model supports extended thinking. */
reasoning: boolean;
/** Supported input types. */
input: ("text" | "image")[];
/** Cost per token (for tracking, can be 0). */
cost: { input: number; output: number; cacheRead: number; cacheWrite: number };
/** Maximum context window size in tokens. */
contextWindow: number;
/** Maximum output tokens. */
maxTokens: number;
/** Custom headers for this model. */
headers?: Record<string, string>;
/** OpenAI compatibility settings. */
compat?: Model<Api>["compat"];
}
/** Extension factory function type. Supports both sync and async initialization. */ /** Extension factory function type. Supports both sync and async initialization. */
export type ExtensionFactory = (pi: ExtensionAPI) => void | Promise<void>; export type ExtensionFactory = (pi: ExtensionAPI) => void | Promise<void>;
@ -926,6 +1043,8 @@ export type SetLabelHandler = (entryId: string, label: string | undefined) => vo
*/ */
export interface ExtensionRuntimeState { export interface ExtensionRuntimeState {
flagValues: Map<string, boolean | string>; flagValues: Map<string, boolean | string>;
/** Provider registrations queued during extension loading, processed when runner binds */
pendingProviderRegistrations: Array<{ name: string; config: ProviderConfig }>;
} }
/** /**

View file

@ -4,12 +4,12 @@
import { import {
type Api, type Api,
getGitHubCopilotBaseUrl,
getModels, getModels,
getProviders, getProviders,
type KnownProvider, type KnownProvider,
type Model, type Model,
normalizeDomain, type OAuthProviderInterface,
registerOAuthProvider,
} from "@mariozechner/pi-ai"; } from "@mariozechner/pi-ai";
import { type Static, Type } from "@sinclair/typebox"; import { type Static, Type } from "@sinclair/typebox";
import AjvModule from "ajv"; import AjvModule from "ajv";
@ -26,7 +26,13 @@ const OpenAICompletionsCompatSchema = Type.Object({
supportsStore: Type.Optional(Type.Boolean()), supportsStore: Type.Optional(Type.Boolean()),
supportsDeveloperRole: Type.Optional(Type.Boolean()), supportsDeveloperRole: Type.Optional(Type.Boolean()),
supportsReasoningEffort: Type.Optional(Type.Boolean()), supportsReasoningEffort: Type.Optional(Type.Boolean()),
supportsUsageInStreaming: Type.Optional(Type.Boolean()),
maxTokensField: Type.Optional(Type.Union([Type.Literal("max_completion_tokens"), Type.Literal("max_tokens")])), maxTokensField: Type.Optional(Type.Union([Type.Literal("max_completion_tokens"), Type.Literal("max_tokens")])),
requiresToolResultName: Type.Optional(Type.Boolean()),
requiresAssistantAfterToolResult: Type.Optional(Type.Boolean()),
requiresThinkingAsText: Type.Optional(Type.Boolean()),
requiresMistralToolIds: Type.Optional(Type.Boolean()),
thinkingFormat: Type.Optional(Type.Union([Type.Literal("openai"), Type.Literal("zai")])),
}); });
const OpenAIResponsesCompatSchema = Type.Object({ const OpenAIResponsesCompatSchema = Type.Object({
@ -174,6 +180,7 @@ export function clearApiKeyCache(): void {
export class ModelRegistry { export class ModelRegistry {
private models: Model<Api>[] = []; private models: Model<Api>[] = [];
private customProviderApiKeys: Map<string, string> = new Map(); private customProviderApiKeys: Map<string, string> = new Map();
private registeredProviders: Map<string, ProviderConfigInput> = new Map();
private loadError: string | undefined = undefined; private loadError: string | undefined = undefined;
constructor( constructor(
@ -200,6 +207,10 @@ export class ModelRegistry {
this.customProviderApiKeys.clear(); this.customProviderApiKeys.clear();
this.loadError = undefined; this.loadError = undefined;
this.loadModels(); this.loadModels();
for (const [providerName, config] of this.registeredProviders.entries()) {
this.applyProviderConfig(providerName, config);
}
} }
/** /**
@ -224,21 +235,19 @@ export class ModelRegistry {
} }
const builtInModels = this.loadBuiltInModels(replacedProviders, overrides); const builtInModels = this.loadBuiltInModels(replacedProviders, overrides);
const combined = [...builtInModels, ...customModels]; let combined = [...builtInModels, ...customModels];
// Update github-copilot base URL based on OAuth credentials // Let OAuth providers modify their models (e.g., update baseUrl)
const copilotCred = this.authStorage.get("github-copilot"); for (const oauthProvider of this.authStorage.getOAuthProviders()) {
if (copilotCred?.type === "oauth") { const cred = this.authStorage.get(oauthProvider.id);
const domain = copilotCred.enterpriseUrl if (cred?.type === "oauth" && oauthProvider.modifyModels) {
? (normalizeDomain(copilotCred.enterpriseUrl) ?? undefined) combined = oauthProvider.modifyModels(combined, cred);
: undefined;
const baseUrl = getGitHubCopilotBaseUrl(copilotCred.access, domain);
this.models = combined.map((m) => (m.provider === "github-copilot" ? { ...m, baseUrl } : m));
} else {
this.models = combined;
} }
} }
this.models = combined;
}
/** Load built-in models, skipping replaced providers and applying overrides */ /** Load built-in models, skipping replaced providers and applying overrides */
private loadBuiltInModels(replacedProviders: Set<string>, overrides: Map<string, ProviderOverride>): Model<Api>[] { private loadBuiltInModels(replacedProviders: Set<string>, overrides: Map<string, ProviderOverride>): Model<Api>[] {
return getProviders() return getProviders()
@ -449,4 +458,118 @@ export class ModelRegistry {
const cred = this.authStorage.get(model.provider); const cred = this.authStorage.get(model.provider);
return cred?.type === "oauth"; return cred?.type === "oauth";
} }
/**
* Register a provider dynamically (from extensions).
*
* If provider has models: replaces all existing models for this provider.
* If provider has only baseUrl/headers: overrides existing models' URLs.
* If provider has oauth: registers OAuth provider for /login support.
*/
registerProvider(providerName: string, config: ProviderConfigInput): void {
this.registeredProviders.set(providerName, config);
this.applyProviderConfig(providerName, config);
}
private applyProviderConfig(providerName: string, config: ProviderConfigInput): void {
// Register OAuth provider if provided
if (config.oauth) {
// Ensure the OAuth provider ID matches the provider name
const oauthProvider: OAuthProviderInterface = {
...config.oauth,
id: providerName,
};
registerOAuthProvider(oauthProvider);
}
// Store API key for auth resolution
if (config.apiKey) {
this.customProviderApiKeys.set(providerName, config.apiKey);
}
if (config.models && config.models.length > 0) {
// Full replacement: remove existing models for this provider
this.models = this.models.filter((m) => m.provider !== providerName);
// Validate required fields
if (!config.baseUrl) {
throw new Error(`Provider ${providerName}: "baseUrl" is required when defining models.`);
}
if (!config.apiKey && !config.oauth) {
throw new Error(`Provider ${providerName}: "apiKey" or "oauth" is required when defining models.`);
}
// Parse and add new models
for (const modelDef of config.models) {
const api = modelDef.api || config.api;
if (!api) {
throw new Error(`Provider ${providerName}, model ${modelDef.id}: no "api" specified.`);
}
// Merge headers
const providerHeaders = resolveHeaders(config.headers);
const modelHeaders = resolveHeaders(modelDef.headers);
let headers = providerHeaders || modelHeaders ? { ...providerHeaders, ...modelHeaders } : undefined;
// If authHeader is true, add Authorization header
if (config.authHeader && config.apiKey) {
const resolvedKey = resolveConfigValue(config.apiKey);
if (resolvedKey) {
headers = { ...headers, Authorization: `Bearer ${resolvedKey}` };
}
}
this.models.push({
id: modelDef.id,
name: modelDef.name,
api: api as Api,
provider: providerName,
baseUrl: config.baseUrl,
reasoning: modelDef.reasoning,
input: modelDef.input as ("text" | "image")[],
cost: modelDef.cost,
contextWindow: modelDef.contextWindow,
maxTokens: modelDef.maxTokens,
headers,
compat: modelDef.compat,
} as Model<Api>);
}
} else if (config.baseUrl) {
// Override-only: update baseUrl/headers for existing models
const resolvedHeaders = resolveHeaders(config.headers);
this.models = this.models.map((m) => {
if (m.provider !== providerName) return m;
return {
...m,
baseUrl: config.baseUrl ?? m.baseUrl,
headers: resolvedHeaders ? { ...m.headers, ...resolvedHeaders } : m.headers,
};
});
}
}
}
/**
* Input type for registerProvider API.
*/
export interface ProviderConfigInput {
baseUrl?: string;
apiKey?: string;
api?: Api;
headers?: Record<string, string>;
authHeader?: boolean;
/** OAuth provider for /login support */
oauth?: Omit<OAuthProviderInterface, "id">;
models?: Array<{
id: string;
name: string;
api?: Api;
reasoning: boolean;
input: ("text" | "image")[];
cost: { input: number; output: number; cacheRead: number; cacheWrite: number };
contextWindow: number;
maxTokens: number;
headers?: Record<string, string>;
compat?: Model<Api>["compat"];
}>;
} }

View file

@ -76,6 +76,8 @@ export type {
LoadExtensionsResult, LoadExtensionsResult,
MessageRenderer, MessageRenderer,
MessageRenderOptions, MessageRenderOptions,
ProviderConfig,
ProviderModelConfig,
RegisteredCommand, RegisteredCommand,
RegisteredTool, RegisteredTool,
SessionBeforeCompactEvent, SessionBeforeCompactEvent,

View file

@ -432,10 +432,6 @@ export async function main(args: string[]) {
// Run migrations (pass cwd for project-local migrations) // Run migrations (pass cwd for project-local migrations)
const { migratedAuthProviders: migratedProviders, deprecationWarnings } = runMigrations(process.cwd()); const { migratedAuthProviders: migratedProviders, deprecationWarnings } = runMigrations(process.cwd());
// Create AuthStorage and ModelRegistry upfront
const authStorage = new AuthStorage();
const modelRegistry = new ModelRegistry(authStorage);
// First pass: parse args to get --extension paths // First pass: parse args to get --extension paths
const firstPass = parseArgs(args); const firstPass = parseArgs(args);
@ -443,6 +439,8 @@ export async function main(args: string[]) {
const cwd = process.cwd(); const cwd = process.cwd();
const agentDir = getAgentDir(); const agentDir = getAgentDir();
const settingsManager = SettingsManager.create(cwd, agentDir); const settingsManager = SettingsManager.create(cwd, agentDir);
const authStorage = new AuthStorage();
const modelRegistry = new ModelRegistry(authStorage, getModelsPath());
const resourceLoader = new DefaultResourceLoader({ const resourceLoader = new DefaultResourceLoader({
cwd, cwd,

View file

@ -1,4 +1,4 @@
import { getOAuthProviders, type OAuthProviderInfo } from "@mariozechner/pi-ai"; import { getOAuthProviders, type OAuthProviderInterface } from "@mariozechner/pi-ai";
import { Container, getEditorKeybindings, Spacer, TruncatedText } from "@mariozechner/pi-tui"; import { Container, getEditorKeybindings, Spacer, TruncatedText } from "@mariozechner/pi-tui";
import type { AuthStorage } from "../../../core/auth-storage.js"; import type { AuthStorage } from "../../../core/auth-storage.js";
import { theme } from "../theme/theme.js"; import { theme } from "../theme/theme.js";
@ -9,7 +9,7 @@ import { DynamicBorder } from "./dynamic-border.js";
*/ */
export class OAuthSelectorComponent extends Container { export class OAuthSelectorComponent extends Container {
private listContainer: Container; private listContainer: Container;
private allProviders: OAuthProviderInfo[] = []; private allProviders: OAuthProviderInterface[] = [];
private selectedIndex: number = 0; private selectedIndex: number = 0;
private mode: "login" | "logout"; private mode: "login" | "logout";
private authStorage: AuthStorage; private authStorage: AuthStorage;
@ -66,7 +66,6 @@ export class OAuthSelectorComponent extends Container {
if (!provider) continue; if (!provider) continue;
const isSelected = i === this.selectedIndex; const isSelected = i === this.selectedIndex;
const isAvailable = provider.available;
// Check if user is logged in for this provider // Check if user is logged in for this provider
const credentials = this.authStorage.get(provider.id); const credentials = this.authStorage.get(provider.id);
@ -76,10 +75,10 @@ export class OAuthSelectorComponent extends Container {
let line = ""; let line = "";
if (isSelected) { if (isSelected) {
const prefix = theme.fg("accent", "→ "); const prefix = theme.fg("accent", "→ ");
const text = isAvailable ? theme.fg("accent", provider.name) : theme.fg("dim", provider.name); const text = theme.fg("accent", provider.name);
line = prefix + text + statusIndicator; line = prefix + text + statusIndicator;
} else { } else {
const text = isAvailable ? ` ${provider.name}` : theme.fg("dim", ` ${provider.name}`); const text = ` ${provider.name}`;
line = text + statusIndicator; line = text + statusIndicator;
} }
@ -109,7 +108,7 @@ export class OAuthSelectorComponent extends Container {
// Enter // Enter
else if (kb.matches(keyData, "selectConfirm")) { else if (kb.matches(keyData, "selectConfirm")) {
const selectedProvider = this.allProviders[this.selectedIndex]; const selectedProvider = this.allProviders[this.selectedIndex];
if (selectedProvider?.available) { if (selectedProvider) {
this.onSelectCallback(selectedProvider.id); this.onSelectCallback(selectedProvider.id);
} }
} }

View file

@ -3529,8 +3529,7 @@ export class InteractiveMode {
const providerName = providerInfo?.name || providerId; const providerName = providerInfo?.name || providerId;
// Providers that use callback servers (can paste redirect URL) // Providers that use callback servers (can paste redirect URL)
const usesCallbackServer = const usesCallbackServer = providerInfo?.usesCallbackServer ?? false;
providerId === "openai-codex" || providerId === "google-gemini-cli" || providerId === "google-antigravity";
// Create login dialog component // Create login dialog component
const dialog = new LoginDialogComponent(this.ui, providerId, (_success, _message) => { const dialog = new LoginDialogComponent(this.ui, providerId, (_success, _message) => {