Refactor OAuth/API key handling: AuthStorage and ModelRegistry

- Add AuthStorage class for credential storage (auth.json)
- Add ModelRegistry class for model management with API key resolution
- Add discoverAuthStorage() and discoverModels() discovery functions
- Add migration from legacy oauth.json and settings.json apiKeys to auth.json
- Remove configureOAuthStorage, defaultGetApiKey, findModel, discoverAvailableModels
- Remove apiKeys from Settings type and SettingsManager methods
- Rename getOAuthPath to getAuthPath
- Update SDK, examples, docs, tests, and mom package

Fixes #296
This commit is contained in:
Mario Zechner 2025-12-25 03:48:36 +01:00
parent 9f97f0c8da
commit 54018b6cc0
29 changed files with 953 additions and 2017 deletions

View file

@ -30,22 +30,17 @@
*/
import { Agent, ProviderTransport, type ThinkingLevel } from "@mariozechner/pi-agent-core";
import { type Model, setOAuthStorage } from "@mariozechner/pi-ai";
import { chmodSync, existsSync, mkdirSync, readFileSync, writeFileSync } from "fs";
import { dirname, join } from "path";
import type { Model } from "@mariozechner/pi-ai";
import { join } from "path";
import { getAgentDir } from "../config.js";
import { AgentSession } from "./agent-session.js";
import { AuthStorage } from "./auth-storage.js";
import { discoverAndLoadCustomTools, type LoadedCustomTool } from "./custom-tools/index.js";
import type { CustomAgentTool } from "./custom-tools/types.js";
import { discoverAndLoadHooks, HookRunner, type LoadedHook, wrapToolsWithHooks } from "./hooks/index.js";
import type { HookFactory } from "./hooks/types.js";
import { messageTransformer } from "./messages.js";
import {
findModel as findModelInternal,
getApiKeyForModel,
getAvailableModels,
loadAndMergeModels,
} from "./models-json.js";
import { ModelRegistry } from "./model-registry.js";
import { SessionManager } from "./session-manager.js";
import { type Settings, SettingsManager, type SkillsSettings } from "./settings-manager.js";
import { loadSkills as loadSkillsInternal, type Skill } from "./skills.js";
@ -86,6 +81,11 @@ export interface CreateAgentSessionOptions {
/** Global config directory. Default: ~/.pi/agent */
agentDir?: string;
/** Auth storage for credentials. Default: discoverAuthStorage(agentDir) */
authStorage?: AuthStorage;
/** Model registry. Default: discoverModels(authStorage, agentDir) */
modelRegistry?: ModelRegistry;
/** Model to use. Default: from settings, else first available */
model?: Model<any>;
/** Thinking level. Default: from settings, else 'off' (clamped to model capabilities) */
@ -93,9 +93,6 @@ export interface CreateAgentSessionOptions {
/** Models available for cycling (Ctrl+P in interactive mode) */
scopedModels?: Array<{ model: Model<any>; thinkingLevel: ThinkingLevel }>;
/** API key resolver. Default: defaultGetApiKey() */
getApiKey?: (model: Model<any>) => Promise<string | undefined>;
/** System prompt. String replaces default, function receives default and returns final. */
systemPrompt?: string | ((defaultPrompt: string) => string);
@ -177,73 +174,20 @@ function getDefaultAgentDir(): string {
return getAgentDir();
}
/**
* Configure OAuth storage to use the specified agent directory.
* Must be called before using OAuth-based authentication.
*/
export function configureOAuthStorage(agentDir: string = getDefaultAgentDir()): void {
const oauthPath = join(agentDir, "oauth.json");
setOAuthStorage({
load: () => {
if (!existsSync(oauthPath)) {
return {};
}
try {
return JSON.parse(readFileSync(oauthPath, "utf-8"));
} catch {
return {};
}
},
save: (storage) => {
const dir = dirname(oauthPath);
if (!existsSync(dir)) {
mkdirSync(dir, { recursive: true, mode: 0o700 });
}
writeFileSync(oauthPath, JSON.stringify(storage, null, 2), "utf-8");
chmodSync(oauthPath, 0o600);
},
});
}
// Discovery Functions
/**
* Get all models (built-in + custom from models.json).
* Create an AuthStorage instance for the given agent directory.
*/
export function discoverModels(agentDir: string = getDefaultAgentDir()): Model<any>[] {
const { models, error } = loadAndMergeModels(agentDir);
if (error) {
throw new Error(error);
}
return models;
export function discoverAuthStorage(agentDir: string = getDefaultAgentDir()): AuthStorage {
return new AuthStorage(join(agentDir, "auth.json"));
}
/**
* Get models that have valid API keys available.
* Create a ModelRegistry for the given agent directory.
*/
export async function discoverAvailableModels(agentDir: string = getDefaultAgentDir()): Promise<Model<any>[]> {
const { models, error } = await getAvailableModels(agentDir);
if (error) {
throw new Error(error);
}
return models;
}
/**
* Find a model by provider and ID.
* @returns The model, or null if not found
*/
export function findModel(
provider: string,
modelId: string,
agentDir: string = getDefaultAgentDir(),
): Model<any> | null {
const { model, error } = findModelInternal(provider, modelId, agentDir);
if (error) {
throw new Error(error);
}
return model;
export function discoverModels(authStorage: AuthStorage, agentDir: string = getDefaultAgentDir()): ModelRegistry {
return new ModelRegistry(authStorage, join(agentDir, "models.json"));
}
/**
@ -326,30 +270,6 @@ export function discoverSlashCommands(cwd?: string, agentDir?: string): FileSlas
// API Key Helpers
/**
* Create the default API key resolver.
* Priority: OAuth > custom providers (models.json) > environment variables > settings.json apiKeys.
*
* OAuth takes priority so users logged in with a plan (e.g. unlimited tokens) aren't
* accidentally billed via a PAYG API key sitting in settings.json.
*/
export function defaultGetApiKey(
settingsManager?: SettingsManager,
): (model: Model<any>) => Promise<string | undefined> {
return async (model: Model<any>) => {
// Check OAuth, custom providers, env vars first
const resolvedKey = await getApiKeyForModel(model);
if (resolvedKey) {
return resolvedKey;
}
// Fall back to settings.json apiKeys
if (settingsManager) {
return settingsManager.getApiKey(model.provider);
}
return undefined;
};
}
// System Prompt
export interface BuildSystemPromptOptions {
@ -457,8 +377,9 @@ function createLoadedHooksFromDefinitions(definitions: Array<{ path?: string; fa
* const { session } = await createAgentSession();
*
* // With explicit model
* import { getModel } from '@mariozechner/pi-ai';
* const { session } = await createAgentSession({
* model: findModel('anthropic', 'claude-sonnet-4-20250514'),
* model: getModel('anthropic', 'claude-opus-4-5'),
* thinkingLevel: 'high',
* });
*
@ -483,22 +404,16 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
const cwd = options.cwd ?? process.cwd();
const agentDir = options.agentDir ?? getDefaultAgentDir();
// Configure OAuth storage for this agentDir
configureOAuthStorage(agentDir);
time("configureOAuthStorage");
// Use provided or create AuthStorage and ModelRegistry
const authStorage = options.authStorage ?? discoverAuthStorage(agentDir);
const modelRegistry = options.modelRegistry ?? discoverModels(authStorage, agentDir);
time("discoverModels");
const settingsManager = options.settingsManager ?? SettingsManager.create(cwd, agentDir);
time("settingsManager");
const sessionManager = options.sessionManager ?? SessionManager.create(cwd, agentDir);
time("sessionManager");
// Helper to check API key availability (settings first, then OAuth/env vars)
const hasApiKey = async (m: Model<any>): Promise<boolean> => {
const settingsKey = settingsManager.getApiKey(m.provider);
if (settingsKey) return true;
return !!(await getApiKeyForModel(m));
};
// Check if session has existing data to restore
const existingSession = sessionManager.buildSessionContext();
time("loadSession");
@ -509,8 +424,8 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
// If session has data, try to restore model from it
if (!model && hasExistingSession && existingSession.model) {
const restoredModel = findModel(existingSession.model.provider, existingSession.model.modelId);
if (restoredModel && (await hasApiKey(restoredModel))) {
const restoredModel = modelRegistry.find(existingSession.model.provider, existingSession.model.modelId);
if (restoredModel && (await modelRegistry.getApiKey(restoredModel))) {
model = restoredModel;
}
if (!model) {
@ -523,8 +438,8 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
const defaultProvider = settingsManager.getDefaultProvider();
const defaultModelId = settingsManager.getDefaultModel();
if (defaultProvider && defaultModelId) {
const settingsModel = findModel(defaultProvider, defaultModelId);
if (settingsModel && (await hasApiKey(settingsModel))) {
const settingsModel = modelRegistry.find(defaultProvider, defaultModelId);
if (settingsModel && (await modelRegistry.getApiKey(settingsModel))) {
model = settingsModel;
}
}
@ -532,14 +447,13 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
// Fall back to first available model with a valid API key
if (!model) {
const allModels = discoverModels(agentDir);
for (const m of allModels) {
if (await hasApiKey(m)) {
for (const m of modelRegistry.getAll()) {
if (await modelRegistry.getApiKey(m)) {
model = m;
break;
}
}
time("discoverAvailableModels");
time("findAvailableModel");
if (model) {
if (modelFallbackMessage) {
modelFallbackMessage += `. Using ${model.provider}/${model.id}`;
@ -567,8 +481,6 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
thinkingLevel = "off";
}
const getApiKey = options.getApiKey ?? defaultGetApiKey(settingsManager);
const skills = options.skills ?? discoverSkills(cwd, agentDir, settingsManager.getSkillsSettings());
time("discoverSkills");
@ -661,7 +573,7 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
if (!currentModel) {
throw new Error("No model selected");
}
const key = await getApiKey(currentModel);
const key = await modelRegistry.getApiKey(currentModel);
if (!key) {
throw new Error(`No API key found for provider "${currentModel.provider}"`);
}
@ -685,7 +597,7 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
hookRunner,
customTools: customToolsResult.tools,
skillsSettings: settingsManager.getSkillsSettings(),
resolveApiKey: getApiKey,
modelRegistry,
});
time("createAgentSession");