Refactor OAuth/API key handling: AuthStorage and ModelRegistry

- Add AuthStorage class for credential storage (auth.json)
- Add ModelRegistry class for model management with API key resolution
- Add discoverAuthStorage() and discoverModels() discovery functions
- Add migration from legacy oauth.json and settings.json apiKeys to auth.json
- Remove configureOAuthStorage, defaultGetApiKey, findModel, discoverAvailableModels
- Remove apiKeys from Settings type and SettingsManager methods
- Rename getOAuthPath to getAuthPath
- Update SDK, examples, docs, tests, and mom package

Fixes #296
This commit is contained in:
Mario Zechner 2025-12-25 03:48:36 +01:00
parent 9f97f0c8da
commit 54018b6cc0
29 changed files with 953 additions and 2017 deletions

View file

@ -3,7 +3,7 @@ import type { AssistantMessage } from "@mariozechner/pi-ai";
import { type Component, visibleWidth } from "@mariozechner/pi-tui";
import { existsSync, type FSWatcher, readFileSync, watch } from "fs";
import { dirname, join } from "path";
import { isModelUsingOAuth } from "../../../core/models-json.js";
import type { ModelRegistry } from "../../../core/model-registry.js";
import { theme } from "../theme/theme.js";
/**
@ -31,13 +31,15 @@ function findGitHeadPath(): string | null {
*/
export class FooterComponent implements Component {
private state: AgentState;
private modelRegistry: ModelRegistry;
private cachedBranch: string | null | undefined = undefined; // undefined = not checked yet, null = not in git repo, string = branch name
private gitWatcher: FSWatcher | null = null;
private onBranchChange: (() => void) | null = null;
private autoCompactEnabled: boolean = true;
constructor(state: AgentState) {
constructor(state: AgentState, modelRegistry: ModelRegistry) {
this.state = state;
this.modelRegistry = modelRegistry;
}
setAutoCompactEnabled(enabled: boolean): void {
@ -207,7 +209,7 @@ export class FooterComponent implements Component {
if (totalCacheWrite) statsParts.push(`W${formatTokens(totalCacheWrite)}`);
// Show cost with "(sub)" indicator if using OAuth subscription
const usingSubscription = this.state.model ? isModelUsingOAuth(this.state.model) : false;
const usingSubscription = this.state.model ? this.modelRegistry.isUsingOAuth(this.state.model) : false;
if (totalCost || usingSubscription) {
const costStr = `$${totalCost.toFixed(3)}${usingSubscription ? " (sub)" : ""}`;
statsParts.push(costStr);

View file

@ -10,7 +10,7 @@ import {
Text,
type TUI,
} from "@mariozechner/pi-tui";
import { getAvailableModels } from "../../../core/models-json.js";
import type { ModelRegistry } from "../../../core/model-registry.js";
import type { SettingsManager } from "../../../core/settings-manager.js";
import { fuzzyFilter } from "../../../utils/fuzzy.js";
import { theme } from "../theme/theme.js";
@ -38,6 +38,7 @@ export class ModelSelectorComponent extends Container {
private selectedIndex: number = 0;
private currentModel: Model<any> | null;
private settingsManager: SettingsManager;
private modelRegistry: ModelRegistry;
private onSelectCallback: (model: Model<any>) => void;
private onCancelCallback: () => void;
private errorMessage: string | null = null;
@ -48,6 +49,7 @@ export class ModelSelectorComponent extends Container {
tui: TUI,
currentModel: Model<any> | null,
settingsManager: SettingsManager,
modelRegistry: ModelRegistry,
scopedModels: ReadonlyArray<ScopedModelItem>,
onSelect: (model: Model<any>) => void,
onCancel: () => void,
@ -57,6 +59,7 @@ export class ModelSelectorComponent extends Container {
this.tui = tui;
this.currentModel = currentModel;
this.settingsManager = settingsManager;
this.modelRegistry = modelRegistry;
this.scopedModels = scopedModels;
this.onSelectCallback = onSelect;
this.onCancelCallback = onCancel;
@ -113,26 +116,29 @@ export class ModelSelectorComponent extends Container {
model: scoped.model,
}));
} else {
// Load available models fresh (includes custom models from models.json)
// Pass settings manager's key resolver as fallback for settings.json apiKeys
const { models: availableModels, error } = await getAvailableModels(undefined, (provider) =>
this.settingsManager.getApiKey(provider),
);
// Refresh to pick up any changes to models.json
this.modelRegistry.refresh();
// If there's an error loading models.json, we'll show it via the "no models" path
// The error will be displayed to the user
if (error) {
this.allModels = [];
this.filteredModels = [];
this.errorMessage = error;
return;
// Check for models.json errors
const loadError = this.modelRegistry.getError();
if (loadError) {
this.errorMessage = loadError;
}
models = availableModels.map((model) => ({
provider: model.provider,
id: model.id,
model,
}));
// Load available models (built-in models still work even if models.json failed)
try {
const availableModels = await this.modelRegistry.getAvailable();
models = availableModels.map((model: Model<any>) => ({
provider: model.provider,
id: model.id,
model,
}));
} catch (error) {
this.allModels = [];
this.filteredModels = [];
this.errorMessage = error instanceof Error ? error.message : String(error);
return;
}
}
// Sort: current model first, then by provider

View file

@ -1,6 +1,6 @@
import { loadOAuthCredentials } from "@mariozechner/pi-ai";
import { getOAuthProviders, type OAuthProviderInfo } from "@mariozechner/pi-ai";
import { Container, isArrowDown, isArrowUp, isEnter, isEscape, Spacer, TruncatedText } from "@mariozechner/pi-tui";
import { getOAuthProviders, type OAuthProviderInfo } from "../../../core/oauth/index.js";
import type { AuthStorage } from "../../../core/auth-storage.js";
import { theme } from "../theme/theme.js";
import { DynamicBorder } from "./dynamic-border.js";
@ -12,13 +12,20 @@ export class OAuthSelectorComponent extends Container {
private allProviders: OAuthProviderInfo[] = [];
private selectedIndex: number = 0;
private mode: "login" | "logout";
private authStorage: AuthStorage;
private onSelectCallback: (providerId: string) => void;
private onCancelCallback: () => void;
constructor(mode: "login" | "logout", onSelect: (providerId: string) => void, onCancel: () => void) {
constructor(
mode: "login" | "logout",
authStorage: AuthStorage,
onSelect: (providerId: string) => void,
onCancel: () => void,
) {
super();
this.mode = mode;
this.authStorage = authStorage;
this.onSelectCallback = onSelect;
this.onCancelCallback = onCancel;
@ -49,7 +56,6 @@ export class OAuthSelectorComponent extends Container {
private loadProviders(): void {
this.allProviders = getOAuthProviders();
this.allProviders = this.allProviders.filter((p) => p.available);
}
private updateList(): void {
@ -63,8 +69,8 @@ export class OAuthSelectorComponent extends Container {
const isAvailable = provider.available;
// Check if user is logged in for this provider
const credentials = loadOAuthCredentials(provider.id);
const isLoggedIn = credentials !== null;
const credentials = this.authStorage.get(provider.id);
const isLoggedIn = credentials?.type === "oauth";
const statusIndicator = isLoggedIn ? theme.fg("success", " ✓ logged in") : "";
let line = "";

View file

@ -7,7 +7,7 @@ import * as fs from "node:fs";
import * as os from "node:os";
import * as path from "node:path";
import type { AgentState, AppMessage, Attachment } from "@mariozechner/pi-agent-core";
import type { AssistantMessage, Message } from "@mariozechner/pi-ai";
import type { AssistantMessage, Message, OAuthProvider } from "@mariozechner/pi-ai";
import type { SlashCommand } from "@mariozechner/pi-tui";
import {
CombinedAutocompleteProvider,
@ -25,13 +25,11 @@ import {
visibleWidth,
} from "@mariozechner/pi-tui";
import { exec, spawnSync } from "child_process";
import { APP_NAME, getDebugLogPath, getOAuthPath } from "../../config.js";
import { APP_NAME, getAuthPath, getDebugLogPath } from "../../config.js";
import type { AgentSession, AgentSessionEvent } from "../../core/agent-session.js";
import type { LoadedCustomTool, SessionEvent as ToolSessionEvent } from "../../core/custom-tools/index.js";
import type { HookUIContext } from "../../core/hooks/index.js";
import { isBashExecutionMessage } from "../../core/messages.js";
import { invalidateOAuthCache } from "../../core/models-json.js";
import { listOAuthProviders, login, logout, type OAuthProvider } from "../../core/oauth/index.js";
import {
getLatestCompactionEntry,
SessionManager,
@ -154,7 +152,7 @@ export class InteractiveMode {
this.editor = new CustomEditor(getEditorTheme());
this.editorContainer = new Container();
this.editorContainer.addChild(this.editor);
this.footer = new FooterComponent(session.state);
this.footer = new FooterComponent(session.state, session.modelRegistry);
this.footer.setAutoCompactEnabled(session.autoCompactionEnabled);
// Define slash commands for autocomplete
@ -1484,6 +1482,7 @@ export class InteractiveMode {
this.ui,
this.session.model,
this.settingsManager,
this.session.modelRegistry,
this.session.scopedModels,
async (model) => {
try {
@ -1588,7 +1587,10 @@ export class InteractiveMode {
private async showOAuthSelector(mode: "login" | "logout"): Promise<void> {
if (mode === "logout") {
const loggedInProviders = listOAuthProviders();
const providers = this.session.modelRegistry.authStorage.list();
const loggedInProviders = providers.filter(
(p) => this.session.modelRegistry.authStorage.get(p)?.type === "oauth",
);
if (loggedInProviders.length === 0) {
this.showStatus("No OAuth providers logged in. Use /login first.");
return;
@ -1598,6 +1600,7 @@ export class InteractiveMode {
this.showSelector((done) => {
const selector = new OAuthSelectorComponent(
mode,
this.session.modelRegistry.authStorage,
async (providerId: string) => {
done();
@ -1605,9 +1608,8 @@ export class InteractiveMode {
this.showStatus(`Logging in to ${providerId}...`);
try {
await login(
providerId as OAuthProvider,
(info) => {
await this.session.modelRegistry.authStorage.login(providerId as OAuthProvider, {
onAuth: (info: { url: string; instructions?: string }) => {
this.chatContainer.addChild(new Spacer(1));
this.chatContainer.addChild(new Text(theme.fg("accent", "Opening browser to:"), 1, 0));
this.chatContainer.addChild(new Text(theme.fg("accent", info.url), 1, 0));
@ -1625,7 +1627,7 @@ export class InteractiveMode {
: "xdg-open";
exec(`${openCmd} "${info.url}"`);
},
async (prompt) => {
onPrompt: async (prompt: { message: string; placeholder?: string }) => {
this.chatContainer.addChild(new Spacer(1));
this.chatContainer.addChild(new Text(theme.fg("warning", prompt.message), 1, 0));
if (prompt.placeholder) {
@ -1648,32 +1650,35 @@ export class InteractiveMode {
this.ui.requestRender();
});
},
(message) => {
onProgress: (message: string) => {
this.chatContainer.addChild(new Text(theme.fg("dim", message), 1, 0));
this.ui.requestRender();
},
);
invalidateOAuthCache();
});
// Refresh models to pick up new baseUrl (e.g., github-copilot)
this.session.modelRegistry.refresh();
this.chatContainer.addChild(new Spacer(1));
this.chatContainer.addChild(
new Text(theme.fg("success", `✓ Successfully logged in to ${providerId}`), 1, 0),
);
this.chatContainer.addChild(new Text(theme.fg("dim", `Tokens saved to ${getOAuthPath()}`), 1, 0));
this.chatContainer.addChild(
new Text(theme.fg("dim", `Credentials saved to ${getAuthPath()}`), 1, 0),
);
this.ui.requestRender();
} catch (error: unknown) {
this.showError(`Login failed: ${error instanceof Error ? error.message : String(error)}`);
}
} else {
try {
await logout(providerId as OAuthProvider);
invalidateOAuthCache();
this.session.modelRegistry.authStorage.logout(providerId);
// Refresh models to reset baseUrl
this.session.modelRegistry.refresh();
this.chatContainer.addChild(new Spacer(1));
this.chatContainer.addChild(
new Text(theme.fg("success", `✓ Successfully logged out of ${providerId}`), 1, 0),
);
this.chatContainer.addChild(
new Text(theme.fg("dim", `Credentials removed from ${getOAuthPath()}`), 1, 0),
new Text(theme.fg("dim", `Credentials removed from ${getAuthPath()}`), 1, 0),
);
this.ui.requestRender();
} catch (error: unknown) {