mirror of
https://github.com/getcompanion-ai/co-mono.git
synced 2026-04-17 18:02:31 +00:00
refactor(oauth): add provider registry
This commit is contained in:
parent
89636cfe6e
commit
3256d3c083
19 changed files with 655 additions and 291 deletions
|
|
@ -9,13 +9,11 @@
|
|||
import {
|
||||
getEnvApiKey,
|
||||
getOAuthApiKey,
|
||||
loginAnthropic,
|
||||
loginAntigravity,
|
||||
loginGeminiCli,
|
||||
loginGitHubCopilot,
|
||||
loginOpenAICodex,
|
||||
getOAuthProvider,
|
||||
getOAuthProviders,
|
||||
type OAuthCredentials,
|
||||
type OAuthProvider,
|
||||
type OAuthLoginCallbacks,
|
||||
type OAuthProviderId,
|
||||
} from "@mariozechner/pi-ai";
|
||||
import { chmodSync, existsSync, mkdirSync, readFileSync, writeFileSync } from "fs";
|
||||
import { dirname, join } from "path";
|
||||
|
|
@ -156,54 +154,14 @@ export class AuthStorage {
|
|||
/**
|
||||
* Login to an OAuth provider.
|
||||
*/
|
||||
async login(
|
||||
provider: OAuthProvider,
|
||||
callbacks: {
|
||||
onAuth: (info: { url: string; instructions?: string }) => void;
|
||||
onPrompt: (prompt: { message: string; placeholder?: string }) => Promise<string>;
|
||||
onProgress?: (message: string) => void;
|
||||
/** For providers with local callback servers (e.g., openai-codex), races with browser callback */
|
||||
onManualCodeInput?: () => Promise<string>;
|
||||
/** For cancellation support (e.g., github-copilot polling) */
|
||||
signal?: AbortSignal;
|
||||
},
|
||||
): Promise<void> {
|
||||
let credentials: OAuthCredentials;
|
||||
|
||||
switch (provider) {
|
||||
case "anthropic":
|
||||
credentials = await loginAnthropic(
|
||||
(url) => callbacks.onAuth({ url }),
|
||||
() => callbacks.onPrompt({ message: "Paste the authorization code:" }),
|
||||
);
|
||||
break;
|
||||
case "github-copilot":
|
||||
credentials = await loginGitHubCopilot({
|
||||
onAuth: (url, instructions) => callbacks.onAuth({ url, instructions }),
|
||||
onPrompt: callbacks.onPrompt,
|
||||
onProgress: callbacks.onProgress,
|
||||
signal: callbacks.signal,
|
||||
});
|
||||
break;
|
||||
case "google-gemini-cli":
|
||||
credentials = await loginGeminiCli(callbacks.onAuth, callbacks.onProgress, callbacks.onManualCodeInput);
|
||||
break;
|
||||
case "google-antigravity":
|
||||
credentials = await loginAntigravity(callbacks.onAuth, callbacks.onProgress, callbacks.onManualCodeInput);
|
||||
break;
|
||||
case "openai-codex":
|
||||
credentials = await loginOpenAICodex({
|
||||
onAuth: callbacks.onAuth,
|
||||
onPrompt: callbacks.onPrompt,
|
||||
onProgress: callbacks.onProgress,
|
||||
onManualCodeInput: callbacks.onManualCodeInput,
|
||||
});
|
||||
break;
|
||||
default:
|
||||
throw new Error(`Unknown OAuth provider: ${provider}`);
|
||||
async login(providerId: OAuthProviderId, callbacks: OAuthLoginCallbacks): Promise<void> {
|
||||
const provider = getOAuthProvider(providerId);
|
||||
if (!provider) {
|
||||
throw new Error(`Unknown OAuth provider: ${providerId}`);
|
||||
}
|
||||
|
||||
this.set(provider, { type: "oauth", ...credentials });
|
||||
const credentials = await provider.login(callbacks);
|
||||
this.set(providerId, { type: "oauth", ...credentials });
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -219,8 +177,13 @@ export class AuthStorage {
|
|||
* This ensures only one instance refreshes while others wait and use the result.
|
||||
*/
|
||||
private async refreshOAuthTokenWithLock(
|
||||
provider: OAuthProvider,
|
||||
providerId: OAuthProviderId,
|
||||
): Promise<{ apiKey: string; newCredentials: OAuthCredentials } | null> {
|
||||
const provider = getOAuthProvider(providerId);
|
||||
if (!provider) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Ensure auth file exists for locking
|
||||
if (!existsSync(this.authPath)) {
|
||||
const dir = dirname(this.authPath);
|
||||
|
|
@ -250,7 +213,7 @@ export class AuthStorage {
|
|||
// Re-read file after acquiring lock - another instance may have refreshed
|
||||
this.reload();
|
||||
|
||||
const cred = this.data[provider];
|
||||
const cred = this.data[providerId];
|
||||
if (cred?.type !== "oauth") {
|
||||
return null;
|
||||
}
|
||||
|
|
@ -259,10 +222,7 @@ export class AuthStorage {
|
|||
// (another instance may have already refreshed it)
|
||||
if (Date.now() < cred.expires) {
|
||||
// Token is now valid - another instance refreshed it
|
||||
const needsProjectId = provider === "google-gemini-cli" || provider === "google-antigravity";
|
||||
const apiKey = needsProjectId
|
||||
? JSON.stringify({ token: cred.access, projectId: cred.projectId })
|
||||
: cred.access;
|
||||
const apiKey = provider.getApiKey(cred);
|
||||
return { apiKey, newCredentials: cred };
|
||||
}
|
||||
|
||||
|
|
@ -274,9 +234,9 @@ export class AuthStorage {
|
|||
}
|
||||
}
|
||||
|
||||
const result = await getOAuthApiKey(provider, oauthCreds);
|
||||
const result = await getOAuthApiKey(providerId, oauthCreds);
|
||||
if (result) {
|
||||
this.data[provider] = { type: "oauth", ...result.newCredentials };
|
||||
this.data[providerId] = { type: "oauth", ...result.newCredentials };
|
||||
this.save();
|
||||
return result;
|
||||
}
|
||||
|
|
@ -303,41 +263,44 @@ export class AuthStorage {
|
|||
* 4. Environment variable
|
||||
* 5. Fallback resolver (models.json custom providers)
|
||||
*/
|
||||
async getApiKey(provider: string): Promise<string | undefined> {
|
||||
async getApiKey(providerId: string): Promise<string | undefined> {
|
||||
// Runtime override takes highest priority
|
||||
const runtimeKey = this.runtimeOverrides.get(provider);
|
||||
const runtimeKey = this.runtimeOverrides.get(providerId);
|
||||
if (runtimeKey) {
|
||||
return runtimeKey;
|
||||
}
|
||||
|
||||
const cred = this.data[provider];
|
||||
const cred = this.data[providerId];
|
||||
|
||||
if (cred?.type === "api_key") {
|
||||
return cred.key;
|
||||
}
|
||||
|
||||
if (cred?.type === "oauth") {
|
||||
const provider = getOAuthProvider(providerId);
|
||||
if (!provider) {
|
||||
// Unknown OAuth provider, can't get API key
|
||||
return undefined;
|
||||
}
|
||||
|
||||
// Check if token needs refresh
|
||||
const needsRefresh = Date.now() >= cred.expires;
|
||||
|
||||
if (needsRefresh) {
|
||||
// Use locked refresh to prevent race conditions
|
||||
try {
|
||||
const result = await this.refreshOAuthTokenWithLock(provider as OAuthProvider);
|
||||
const result = await this.refreshOAuthTokenWithLock(providerId);
|
||||
if (result) {
|
||||
return result.apiKey;
|
||||
}
|
||||
} catch {
|
||||
// Refresh failed - re-read file to check if another instance succeeded
|
||||
this.reload();
|
||||
const updatedCred = this.data[provider];
|
||||
const updatedCred = this.data[providerId];
|
||||
|
||||
if (updatedCred?.type === "oauth" && Date.now() < updatedCred.expires) {
|
||||
// Another instance refreshed successfully, use those credentials
|
||||
const needsProjectId = provider === "google-gemini-cli" || provider === "google-antigravity";
|
||||
return needsProjectId
|
||||
? JSON.stringify({ token: updatedCred.access, projectId: updatedCred.projectId })
|
||||
: updatedCred.access;
|
||||
return provider.getApiKey(updatedCred);
|
||||
}
|
||||
|
||||
// Refresh truly failed - return undefined so model discovery skips this provider
|
||||
|
|
@ -346,16 +309,22 @@ export class AuthStorage {
|
|||
}
|
||||
} else {
|
||||
// Token not expired, use current access token
|
||||
const needsProjectId = provider === "google-gemini-cli" || provider === "google-antigravity";
|
||||
return needsProjectId ? JSON.stringify({ token: cred.access, projectId: cred.projectId }) : cred.access;
|
||||
return provider.getApiKey(cred);
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to environment variable
|
||||
const envKey = getEnvApiKey(provider);
|
||||
const envKey = getEnvApiKey(providerId);
|
||||
if (envKey) return envKey;
|
||||
|
||||
// Fall back to custom resolver (e.g., models.json custom providers)
|
||||
return this.fallbackResolver?.(provider) ?? undefined;
|
||||
return this.fallbackResolver?.(providerId) ?? undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all registered OAuth providers
|
||||
*/
|
||||
getOAuthProviders() {
|
||||
return getOAuthProviders();
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue