Refactor OAuth/API key handling: AuthStorage and ModelRegistry

- Add AuthStorage class for credential storage (auth.json)
- Add ModelRegistry class for model management with API key resolution
- Add discoverAuthStorage() and discoverModels() discovery functions
- Add migration from legacy oauth.json and settings.json apiKeys to auth.json
- Remove configureOAuthStorage, defaultGetApiKey, findModel, discoverAvailableModels
- Remove apiKeys from Settings type and SettingsManager methods
- Rename getOAuthPath to getAuthPath
- Update SDK, examples, docs, tests, and mom package

Fixes #296
This commit is contained in:
Mario Zechner 2025-12-25 03:48:36 +01:00
parent 9f97f0c8da
commit 54018b6cc0
29 changed files with 953 additions and 2017 deletions

View file

@ -3,9 +3,18 @@
* Handles loading, saving, and refreshing credentials from auth.json.
*/
import { getApiKeyFromEnv, getOAuthApiKey, type OAuthCredentials, type OAuthProvider } from "@mariozechner/pi-ai";
import { chmodSync, existsSync, mkdirSync, readFileSync, writeFileSync } from "fs";
import { dirname } from "path";
import {
getEnvApiKey,
getOAuthApiKey,
loginAnthropic,
loginAntigravity,
loginGeminiCli,
loginGitHubCopilot,
type OAuthCredentials,
type OAuthProvider,
} from "@mariozechner/pi-ai";
import { chmodSync, existsSync, mkdirSync, readFileSync, renameSync, writeFileSync } from "fs";
import { dirname, join } from "path";
export type ApiKeyCredential = {
type: "api_key";
@ -25,11 +34,29 @@ export type AuthStorageData = Record<string, AuthCredential>;
*/
export class AuthStorage {
private data: AuthStorageData = {};
private runtimeOverrides: Map<string, string> = new Map();
private fallbackResolver?: (provider: string) => string | undefined;
constructor(private authPath: string) {
this.reload();
}
/**
* Set a runtime API key override (not persisted to disk).
* Used for CLI --api-key flag.
*/
setRuntimeApiKey(provider: string, apiKey: string): void {
this.runtimeOverrides.set(provider, apiKey);
}
/**
* Set a fallback resolver for API keys not found in auth.json or env vars.
* Used for custom provider keys from models.json.
*/
setFallbackResolver(resolver: (provider: string) => string | undefined): void {
this.fallbackResolver = resolver;
}
/**
* Reload credentials from disk.
*/
@ -101,14 +128,69 @@ export class AuthStorage {
return { ...this.data };
}
/**
* Login to an OAuth provider.
*/
async login(
provider: OAuthProvider,
callbacks: {
onAuth: (info: { url: string; instructions?: string }) => void;
onPrompt: (prompt: { message: string; placeholder?: string }) => Promise<string>;
onProgress?: (message: string) => void;
},
): Promise<void> {
let credentials: OAuthCredentials;
switch (provider) {
case "anthropic":
credentials = await loginAnthropic(
(url) => callbacks.onAuth({ url }),
() => callbacks.onPrompt({ message: "Paste the authorization code:" }),
);
break;
case "github-copilot":
credentials = await loginGitHubCopilot({
onAuth: (url, instructions) => callbacks.onAuth({ url, instructions }),
onPrompt: callbacks.onPrompt,
onProgress: callbacks.onProgress,
});
break;
case "google-gemini-cli":
credentials = await loginGeminiCli(callbacks.onAuth, callbacks.onProgress);
break;
case "google-antigravity":
credentials = await loginAntigravity(callbacks.onAuth, callbacks.onProgress);
break;
default:
throw new Error(`Unknown OAuth provider: ${provider}`);
}
this.set(provider, { type: "oauth", ...credentials });
}
/**
* Logout from a provider.
*/
logout(provider: string): void {
this.remove(provider);
}
/**
* Get API key for a provider.
* Priority:
* 1. API key from auth.json
* 2. OAuth token from auth.json (auto-refreshed)
* 3. Environment variable (via getApiKeyFromEnv)
* 1. Runtime override (CLI --api-key)
* 2. API key from auth.json
* 3. OAuth token from auth.json (auto-refreshed)
* 4. Environment variable
* 5. Fallback resolver (models.json custom providers)
*/
async getApiKey(provider: string): Promise<string | null> {
// Runtime override takes highest priority
const runtimeKey = this.runtimeOverrides.get(provider);
if (runtimeKey) {
return runtimeKey;
}
const cred = this.data[provider];
if (cred?.type === "api_key") {
@ -116,30 +198,83 @@ export class AuthStorage {
}
if (cred?.type === "oauth") {
// Build OAuthCredentials map (without type discriminator)
// Filter to only oauth credentials for getOAuthApiKey
const oauthCreds: Record<string, OAuthCredentials> = {};
for (const [key, value] of Object.entries(this.data)) {
if (value.type === "oauth") {
const { type: _, ...rest } = value;
oauthCreds[key] = rest;
oauthCreds[key] = value;
}
}
try {
const result = await getOAuthApiKey(provider as OAuthProvider, oauthCreds);
if (result) {
// Save refreshed credentials
this.data[provider] = { type: "oauth", ...result.newCredentials };
this.save();
return result.apiKey;
}
} catch {
// Token refresh failed, remove invalid credentials
this.remove(provider);
}
}
// Fall back to environment variable
return getApiKeyFromEnv(provider) ?? null;
const envKey = getEnvApiKey(provider);
if (envKey) return envKey;
// Fall back to custom resolver (e.g., models.json custom providers)
return this.fallbackResolver?.(provider) ?? null;
}
/**
* Migrate credentials from legacy oauth.json and settings.json apiKeys to auth.json.
* Only runs if auth.json doesn't exist yet. Returns list of migrated providers.
*/
static migrateLegacy(authPath: string, agentDir: string): string[] {
const oauthPath = join(agentDir, "oauth.json");
const settingsPath = join(agentDir, "settings.json");
// Skip if auth.json already exists
if (existsSync(authPath)) return [];
const migrated: AuthStorageData = {};
const providers: string[] = [];
// Migrate oauth.json
if (existsSync(oauthPath)) {
try {
const oauth = JSON.parse(readFileSync(oauthPath, "utf-8"));
for (const [provider, cred] of Object.entries(oauth)) {
migrated[provider] = { type: "oauth", ...(cred as object) } as OAuthCredential;
providers.push(provider);
}
renameSync(oauthPath, `${oauthPath}.migrated`);
} catch {}
}
// Migrate settings.json apiKeys
if (existsSync(settingsPath)) {
try {
const content = readFileSync(settingsPath, "utf-8");
const settings = JSON.parse(content);
if (settings.apiKeys && typeof settings.apiKeys === "object") {
for (const [provider, key] of Object.entries(settings.apiKeys)) {
if (!migrated[provider] && typeof key === "string") {
migrated[provider] = { type: "api_key", key };
providers.push(provider);
}
}
delete settings.apiKeys;
writeFileSync(settingsPath, JSON.stringify(settings, null, 2));
}
} catch {}
}
if (Object.keys(migrated).length > 0) {
mkdirSync(dirname(authPath), { recursive: true });
writeFileSync(authPath, JSON.stringify(migrated, null, 2), { mode: 0o600 });
}
return providers;
}
}