refactor(oauth): add provider registry

This commit is contained in:
Mario Zechner 2026-01-24 21:17:05 +01:00
parent 89636cfe6e
commit 3256d3c083
19 changed files with 655 additions and 291 deletions

View file

@ -9,13 +9,11 @@
import {
getEnvApiKey,
getOAuthApiKey,
loginAnthropic,
loginAntigravity,
loginGeminiCli,
loginGitHubCopilot,
loginOpenAICodex,
getOAuthProvider,
getOAuthProviders,
type OAuthCredentials,
type OAuthProvider,
type OAuthLoginCallbacks,
type OAuthProviderId,
} from "@mariozechner/pi-ai";
import { chmodSync, existsSync, mkdirSync, readFileSync, writeFileSync } from "fs";
import { dirname, join } from "path";
@ -156,54 +154,14 @@ export class AuthStorage {
/**
* Login to an OAuth provider.
*/
async login(
provider: OAuthProvider,
callbacks: {
onAuth: (info: { url: string; instructions?: string }) => void;
onPrompt: (prompt: { message: string; placeholder?: string }) => Promise<string>;
onProgress?: (message: string) => void;
/** For providers with local callback servers (e.g., openai-codex), races with browser callback */
onManualCodeInput?: () => Promise<string>;
/** For cancellation support (e.g., github-copilot polling) */
signal?: AbortSignal;
},
): Promise<void> {
let credentials: OAuthCredentials;
switch (provider) {
case "anthropic":
credentials = await loginAnthropic(
(url) => callbacks.onAuth({ url }),
() => callbacks.onPrompt({ message: "Paste the authorization code:" }),
);
break;
case "github-copilot":
credentials = await loginGitHubCopilot({
onAuth: (url, instructions) => callbacks.onAuth({ url, instructions }),
onPrompt: callbacks.onPrompt,
onProgress: callbacks.onProgress,
signal: callbacks.signal,
});
break;
case "google-gemini-cli":
credentials = await loginGeminiCli(callbacks.onAuth, callbacks.onProgress, callbacks.onManualCodeInput);
break;
case "google-antigravity":
credentials = await loginAntigravity(callbacks.onAuth, callbacks.onProgress, callbacks.onManualCodeInput);
break;
case "openai-codex":
credentials = await loginOpenAICodex({
onAuth: callbacks.onAuth,
onPrompt: callbacks.onPrompt,
onProgress: callbacks.onProgress,
onManualCodeInput: callbacks.onManualCodeInput,
});
break;
default:
throw new Error(`Unknown OAuth provider: ${provider}`);
async login(providerId: OAuthProviderId, callbacks: OAuthLoginCallbacks): Promise<void> {
const provider = getOAuthProvider(providerId);
if (!provider) {
throw new Error(`Unknown OAuth provider: ${providerId}`);
}
this.set(provider, { type: "oauth", ...credentials });
const credentials = await provider.login(callbacks);
this.set(providerId, { type: "oauth", ...credentials });
}
/**
@ -219,8 +177,13 @@ export class AuthStorage {
* This ensures only one instance refreshes while others wait and use the result.
*/
private async refreshOAuthTokenWithLock(
provider: OAuthProvider,
providerId: OAuthProviderId,
): Promise<{ apiKey: string; newCredentials: OAuthCredentials } | null> {
const provider = getOAuthProvider(providerId);
if (!provider) {
return null;
}
// Ensure auth file exists for locking
if (!existsSync(this.authPath)) {
const dir = dirname(this.authPath);
@ -250,7 +213,7 @@ export class AuthStorage {
// Re-read file after acquiring lock - another instance may have refreshed
this.reload();
const cred = this.data[provider];
const cred = this.data[providerId];
if (cred?.type !== "oauth") {
return null;
}
@ -259,10 +222,7 @@ export class AuthStorage {
// (another instance may have already refreshed it)
if (Date.now() < cred.expires) {
// Token is now valid - another instance refreshed it
const needsProjectId = provider === "google-gemini-cli" || provider === "google-antigravity";
const apiKey = needsProjectId
? JSON.stringify({ token: cred.access, projectId: cred.projectId })
: cred.access;
const apiKey = provider.getApiKey(cred);
return { apiKey, newCredentials: cred };
}
@ -274,9 +234,9 @@ export class AuthStorage {
}
}
const result = await getOAuthApiKey(provider, oauthCreds);
const result = await getOAuthApiKey(providerId, oauthCreds);
if (result) {
this.data[provider] = { type: "oauth", ...result.newCredentials };
this.data[providerId] = { type: "oauth", ...result.newCredentials };
this.save();
return result;
}
@ -303,41 +263,44 @@ export class AuthStorage {
* 4. Environment variable
* 5. Fallback resolver (models.json custom providers)
*/
async getApiKey(provider: string): Promise<string | undefined> {
async getApiKey(providerId: string): Promise<string | undefined> {
// Runtime override takes highest priority
const runtimeKey = this.runtimeOverrides.get(provider);
const runtimeKey = this.runtimeOverrides.get(providerId);
if (runtimeKey) {
return runtimeKey;
}
const cred = this.data[provider];
const cred = this.data[providerId];
if (cred?.type === "api_key") {
return cred.key;
}
if (cred?.type === "oauth") {
const provider = getOAuthProvider(providerId);
if (!provider) {
// Unknown OAuth provider, can't get API key
return undefined;
}
// Check if token needs refresh
const needsRefresh = Date.now() >= cred.expires;
if (needsRefresh) {
// Use locked refresh to prevent race conditions
try {
const result = await this.refreshOAuthTokenWithLock(provider as OAuthProvider);
const result = await this.refreshOAuthTokenWithLock(providerId);
if (result) {
return result.apiKey;
}
} catch {
// Refresh failed - re-read file to check if another instance succeeded
this.reload();
const updatedCred = this.data[provider];
const updatedCred = this.data[providerId];
if (updatedCred?.type === "oauth" && Date.now() < updatedCred.expires) {
// Another instance refreshed successfully, use those credentials
const needsProjectId = provider === "google-gemini-cli" || provider === "google-antigravity";
return needsProjectId
? JSON.stringify({ token: updatedCred.access, projectId: updatedCred.projectId })
: updatedCred.access;
return provider.getApiKey(updatedCred);
}
// Refresh truly failed - return undefined so model discovery skips this provider
@ -346,16 +309,22 @@ export class AuthStorage {
}
} else {
// Token not expired, use current access token
const needsProjectId = provider === "google-gemini-cli" || provider === "google-antigravity";
return needsProjectId ? JSON.stringify({ token: cred.access, projectId: cred.projectId }) : cred.access;
return provider.getApiKey(cred);
}
}
// Fall back to environment variable
const envKey = getEnvApiKey(provider);
const envKey = getEnvApiKey(providerId);
if (envKey) return envKey;
// Fall back to custom resolver (e.g., models.json custom providers)
return this.fallbackResolver?.(provider) ?? undefined;
return this.fallbackResolver?.(providerId) ?? undefined;
}
/**
* Get all registered OAuth providers
*/
getOAuthProviders() {
return getOAuthProviders();
}
}