Refactor OAuth/API key handling: AuthStorage and ModelRegistry

- Add AuthStorage class for credential storage (auth.json)
- Add ModelRegistry class for model management with API key resolution
- Add discoverAuthStorage() and discoverModels() discovery functions
- Add migration from legacy oauth.json and settings.json apiKeys to auth.json
- Remove configureOAuthStorage, defaultGetApiKey, findModel, discoverAvailableModels
- Remove apiKeys from Settings type and SettingsManager methods
- Rename getOAuthPath to getAuthPath
- Update SDK, examples, docs, tests, and mom package

Fixes #296
This commit is contained in:
Mario Zechner 2025-12-25 03:48:36 +01:00
parent 9f97f0c8da
commit 54018b6cc0
29 changed files with 953 additions and 2017 deletions

View file

@ -23,7 +23,7 @@ import type { LoadedCustomTool, SessionEvent as ToolSessionEvent } from "./custo
import { exportSessionToHtml } from "./export-html.js";
import type { HookRunner, SessionEventResult, TurnEndEvent, TurnStartEvent } from "./hooks/index.js";
import type { BashExecutionMessage } from "./messages.js";
import { getApiKeyForModel, getAvailableModels } from "./models-json.js";
import type { ModelRegistry } from "./model-registry.js";
import type { CompactionEntry, SessionManager } from "./session-manager.js";
import type { SettingsManager, SkillsSettings } from "./settings-manager.js";
import { expandSlashCommand, type FileSlashCommand } from "./slash-commands.js";
@ -56,8 +56,8 @@ export interface AgentSessionConfig {
/** Custom tools for session lifecycle events */
customTools?: LoadedCustomTool[];
skillsSettings?: Required<SkillsSettings>;
/** Resolve API key for a model. Default: getApiKeyForModel */
resolveApiKey?: (model: Model<any>) => Promise<string | undefined>;
/** Model registry for API key resolution and model discovery */
modelRegistry: ModelRegistry;
}
/** Options for AgentSession.prompt() */
@ -153,8 +153,8 @@ export class AgentSession {
private _skillsSettings: Required<SkillsSettings> | undefined;
// API key resolver
private _resolveApiKey: (model: Model<any>) => Promise<string | undefined>;
// Model registry for API key resolution
private _modelRegistry: ModelRegistry;
constructor(config: AgentSessionConfig) {
this.agent = config.agent;
@ -165,7 +165,12 @@ export class AgentSession {
this._hookRunner = config.hookRunner ?? null;
this._customTools = config.customTools ?? [];
this._skillsSettings = config.skillsSettings;
this._resolveApiKey = config.resolveApiKey ?? getApiKeyForModel;
this._modelRegistry = config.modelRegistry;
}
/** Model registry for API key resolution and model discovery */
get modelRegistry(): ModelRegistry {
return this._modelRegistry;
}
// =========================================================================
@ -434,7 +439,7 @@ export class AgentSession {
}
// Validate API key
const apiKey = await this._resolveApiKey(this.model);
const apiKey = await this._modelRegistry.getApiKey(this.model);
if (!apiKey) {
throw new Error(
`No API key found for ${this.model.provider}.\n\n` +
@ -561,7 +566,7 @@ export class AgentSession {
* @throws Error if no API key available for the model
*/
async setModel(model: Model<any>): Promise<void> {
const apiKey = await this._resolveApiKey(model);
const apiKey = await this._modelRegistry.getApiKey(model);
if (!apiKey) {
throw new Error(`No API key for ${model.provider}/${model.id}`);
}
@ -599,7 +604,7 @@ export class AgentSession {
const next = this._scopedModels[nextIndex];
// Validate API key
const apiKey = await this._resolveApiKey(next.model);
const apiKey = await this._modelRegistry.getApiKey(next.model);
if (!apiKey) {
throw new Error(`No API key for ${next.model.provider}/${next.model.id}`);
}
@ -616,10 +621,7 @@ export class AgentSession {
}
private async _cycleAvailableModel(): Promise<ModelCycleResult | null> {
const { models: availableModels, error } = await getAvailableModels(undefined, (provider) =>
this.settingsManager.getApiKey(provider),
);
if (error) throw new Error(`Failed to load models: ${error}`);
const availableModels = await this._modelRegistry.getAvailable();
if (availableModels.length <= 1) return null;
const currentModel = this.model;
@ -631,7 +633,7 @@ export class AgentSession {
const nextIndex = (currentIndex + 1) % availableModels.length;
const nextModel = availableModels[nextIndex];
const apiKey = await this._resolveApiKey(nextModel);
const apiKey = await this._modelRegistry.getApiKey(nextModel);
if (!apiKey) {
throw new Error(`No API key for ${nextModel.provider}/${nextModel.id}`);
}
@ -650,11 +652,7 @@ export class AgentSession {
* Get all available models with valid API keys.
*/
async getAvailableModels(): Promise<Model<any>[]> {
const { models, error } = await getAvailableModels(undefined, (provider) =>
this.settingsManager.getApiKey(provider),
);
if (error) throw new Error(error);
return models;
return this._modelRegistry.getAvailable();
}
// =========================================================================
@ -747,7 +745,7 @@ export class AgentSession {
throw new Error("No model selected");
}
const apiKey = await this._resolveApiKey(this.model);
const apiKey = await this._modelRegistry.getApiKey(this.model);
if (!apiKey) {
throw new Error(`No API key for ${this.model.provider}`);
}
@ -786,7 +784,7 @@ export class AgentSession {
tokensBefore: preparation.tokensBefore,
customInstructions,
model: this.model,
resolveApiKey: this._resolveApiKey,
resolveApiKey: async (m: Model<any>) => (await this._modelRegistry.getApiKey(m)) ?? undefined,
signal: this._compactionAbortController.signal,
})) as SessionEventResult | undefined;
@ -908,7 +906,7 @@ export class AgentSession {
return;
}
const apiKey = await this._resolveApiKey(this.model);
const apiKey = await this._modelRegistry.getApiKey(this.model);
if (!apiKey) {
this._emit({ type: "auto_compaction_end", result: null, aborted: false, willRetry: false });
return;
@ -948,7 +946,7 @@ export class AgentSession {
tokensBefore: preparation.tokensBefore,
customInstructions: undefined,
model: this.model,
resolveApiKey: this._resolveApiKey,
resolveApiKey: async (m: Model<any>) => (await this._modelRegistry.getApiKey(m)) ?? undefined,
signal: this._autoCompactionAbortController.signal,
})) as SessionEventResult | undefined;
@ -1334,9 +1332,7 @@ export class AgentSession {
// Restore model if saved
if (sessionContext.model) {
const availableModels = (
await getAvailableModels(undefined, (provider) => this.settingsManager.getApiKey(provider))
).models;
const availableModels = await this._modelRegistry.getAvailable();
const match = availableModels.find(
(m) => m.provider === sessionContext.model!.provider && m.id === sessionContext.model!.modelId,
);

View file

@ -3,9 +3,18 @@
* Handles loading, saving, and refreshing credentials from auth.json.
*/
import { getApiKeyFromEnv, getOAuthApiKey, type OAuthCredentials, type OAuthProvider } from "@mariozechner/pi-ai";
import { chmodSync, existsSync, mkdirSync, readFileSync, writeFileSync } from "fs";
import { dirname } from "path";
import {
getEnvApiKey,
getOAuthApiKey,
loginAnthropic,
loginAntigravity,
loginGeminiCli,
loginGitHubCopilot,
type OAuthCredentials,
type OAuthProvider,
} from "@mariozechner/pi-ai";
import { chmodSync, existsSync, mkdirSync, readFileSync, renameSync, writeFileSync } from "fs";
import { dirname, join } from "path";
export type ApiKeyCredential = {
type: "api_key";
@ -25,11 +34,29 @@ export type AuthStorageData = Record<string, AuthCredential>;
*/
export class AuthStorage {
private data: AuthStorageData = {};
private runtimeOverrides: Map<string, string> = new Map();
private fallbackResolver?: (provider: string) => string | undefined;
constructor(private authPath: string) {
this.reload();
}
/**
* Set a runtime API key override (not persisted to disk).
* Used for CLI --api-key flag.
*/
setRuntimeApiKey(provider: string, apiKey: string): void {
this.runtimeOverrides.set(provider, apiKey);
}
/**
* Set a fallback resolver for API keys not found in auth.json or env vars.
* Used for custom provider keys from models.json.
*/
setFallbackResolver(resolver: (provider: string) => string | undefined): void {
this.fallbackResolver = resolver;
}
/**
* Reload credentials from disk.
*/
@ -101,14 +128,69 @@ export class AuthStorage {
return { ...this.data };
}
/**
* Login to an OAuth provider.
*/
async login(
provider: OAuthProvider,
callbacks: {
onAuth: (info: { url: string; instructions?: string }) => void;
onPrompt: (prompt: { message: string; placeholder?: string }) => Promise<string>;
onProgress?: (message: string) => void;
},
): Promise<void> {
let credentials: OAuthCredentials;
switch (provider) {
case "anthropic":
credentials = await loginAnthropic(
(url) => callbacks.onAuth({ url }),
() => callbacks.onPrompt({ message: "Paste the authorization code:" }),
);
break;
case "github-copilot":
credentials = await loginGitHubCopilot({
onAuth: (url, instructions) => callbacks.onAuth({ url, instructions }),
onPrompt: callbacks.onPrompt,
onProgress: callbacks.onProgress,
});
break;
case "google-gemini-cli":
credentials = await loginGeminiCli(callbacks.onAuth, callbacks.onProgress);
break;
case "google-antigravity":
credentials = await loginAntigravity(callbacks.onAuth, callbacks.onProgress);
break;
default:
throw new Error(`Unknown OAuth provider: ${provider}`);
}
this.set(provider, { type: "oauth", ...credentials });
}
/**
* Logout from a provider.
*/
logout(provider: string): void {
this.remove(provider);
}
/**
* Get API key for a provider.
* Priority:
* 1. API key from auth.json
* 2. OAuth token from auth.json (auto-refreshed)
* 3. Environment variable (via getApiKeyFromEnv)
* 1. Runtime override (CLI --api-key)
* 2. API key from auth.json
* 3. OAuth token from auth.json (auto-refreshed)
* 4. Environment variable
* 5. Fallback resolver (models.json custom providers)
*/
async getApiKey(provider: string): Promise<string | null> {
// Runtime override takes highest priority
const runtimeKey = this.runtimeOverrides.get(provider);
if (runtimeKey) {
return runtimeKey;
}
const cred = this.data[provider];
if (cred?.type === "api_key") {
@ -116,30 +198,83 @@ export class AuthStorage {
}
if (cred?.type === "oauth") {
// Build OAuthCredentials map (without type discriminator)
// Filter to only oauth credentials for getOAuthApiKey
const oauthCreds: Record<string, OAuthCredentials> = {};
for (const [key, value] of Object.entries(this.data)) {
if (value.type === "oauth") {
const { type: _, ...rest } = value;
oauthCreds[key] = rest;
oauthCreds[key] = value;
}
}
try {
const result = await getOAuthApiKey(provider as OAuthProvider, oauthCreds);
if (result) {
// Save refreshed credentials
this.data[provider] = { type: "oauth", ...result.newCredentials };
this.save();
return result.apiKey;
}
} catch {
// Token refresh failed, remove invalid credentials
this.remove(provider);
}
}
// Fall back to environment variable
return getApiKeyFromEnv(provider) ?? null;
const envKey = getEnvApiKey(provider);
if (envKey) return envKey;
// Fall back to custom resolver (e.g., models.json custom providers)
return this.fallbackResolver?.(provider) ?? null;
}
/**
* Migrate credentials from legacy oauth.json and settings.json apiKeys to auth.json.
* Only runs if auth.json doesn't exist yet. Returns list of migrated providers.
*/
static migrateLegacy(authPath: string, agentDir: string): string[] {
const oauthPath = join(agentDir, "oauth.json");
const settingsPath = join(agentDir, "settings.json");
// Skip if auth.json already exists
if (existsSync(authPath)) return [];
const migrated: AuthStorageData = {};
const providers: string[] = [];
// Migrate oauth.json
if (existsSync(oauthPath)) {
try {
const oauth = JSON.parse(readFileSync(oauthPath, "utf-8"));
for (const [provider, cred] of Object.entries(oauth)) {
migrated[provider] = { type: "oauth", ...(cred as object) } as OAuthCredential;
providers.push(provider);
}
renameSync(oauthPath, `${oauthPath}.migrated`);
} catch {}
}
// Migrate settings.json apiKeys
if (existsSync(settingsPath)) {
try {
const content = readFileSync(settingsPath, "utf-8");
const settings = JSON.parse(content);
if (settings.apiKeys && typeof settings.apiKeys === "object") {
for (const [provider, key] of Object.entries(settings.apiKeys)) {
if (!migrated[provider] && typeof key === "string") {
migrated[provider] = { type: "api_key", key };
providers.push(provider);
}
}
delete settings.apiKeys;
writeFileSync(settingsPath, JSON.stringify(settings, null, 2));
}
} catch {}
}
if (Object.keys(migrated).length > 0) {
mkdirSync(dirname(authPath), { recursive: true });
writeFileSync(authPath, JSON.stringify(migrated, null, 2), { mode: 0o600 });
}
return providers;
}
}

View file

@ -0,0 +1,315 @@
/**
* Model registry - manages built-in and custom models, provides API key resolution.
*/
import {
type Api,
getGitHubCopilotBaseUrl,
getModels,
getProviders,
type KnownProvider,
type Model,
normalizeDomain,
} from "@mariozechner/pi-ai";
import { type Static, Type } from "@sinclair/typebox";
import AjvModule from "ajv";
import { existsSync, readFileSync } from "fs";
import type { AuthStorage } from "./auth-storage.js";
const Ajv = (AjvModule as any).default || AjvModule;
// Schema for OpenAI compatibility settings
const OpenAICompatSchema = Type.Object({
supportsStore: Type.Optional(Type.Boolean()),
supportsDeveloperRole: Type.Optional(Type.Boolean()),
supportsReasoningEffort: Type.Optional(Type.Boolean()),
maxTokensField: Type.Optional(Type.Union([Type.Literal("max_completion_tokens"), Type.Literal("max_tokens")])),
});
// Schema for custom model definition
const ModelDefinitionSchema = Type.Object({
id: Type.String({ minLength: 1 }),
name: Type.String({ minLength: 1 }),
api: Type.Optional(
Type.Union([
Type.Literal("openai-completions"),
Type.Literal("openai-responses"),
Type.Literal("anthropic-messages"),
Type.Literal("google-generative-ai"),
]),
),
reasoning: Type.Boolean(),
input: Type.Array(Type.Union([Type.Literal("text"), Type.Literal("image")])),
cost: Type.Object({
input: Type.Number(),
output: Type.Number(),
cacheRead: Type.Number(),
cacheWrite: Type.Number(),
}),
contextWindow: Type.Number(),
maxTokens: Type.Number(),
headers: Type.Optional(Type.Record(Type.String(), Type.String())),
compat: Type.Optional(OpenAICompatSchema),
});
const ProviderConfigSchema = Type.Object({
baseUrl: Type.String({ minLength: 1 }),
apiKey: Type.String({ minLength: 1 }),
api: Type.Optional(
Type.Union([
Type.Literal("openai-completions"),
Type.Literal("openai-responses"),
Type.Literal("anthropic-messages"),
Type.Literal("google-generative-ai"),
]),
),
headers: Type.Optional(Type.Record(Type.String(), Type.String())),
authHeader: Type.Optional(Type.Boolean()),
models: Type.Array(ModelDefinitionSchema),
});
const ModelsConfigSchema = Type.Object({
providers: Type.Record(Type.String(), ProviderConfigSchema),
});
type ModelsConfig = Static<typeof ModelsConfigSchema>;
/**
* Resolve an API key config value to an actual key.
* Checks environment variable first, then treats as literal.
*/
function resolveApiKeyConfig(keyConfig: string): string | undefined {
const envValue = process.env[keyConfig];
if (envValue) return envValue;
return keyConfig;
}
/**
* Model registry - loads and manages models, resolves API keys via AuthStorage.
*/
export class ModelRegistry {
private models: Model<Api>[] = [];
private customProviderApiKeys: Map<string, string> = new Map();
private loadError: string | null = null;
constructor(
readonly authStorage: AuthStorage,
private modelsJsonPath: string | null = null,
) {
// Set up fallback resolver for custom provider API keys
this.authStorage.setFallbackResolver((provider) => {
const keyConfig = this.customProviderApiKeys.get(provider);
if (keyConfig) {
return resolveApiKeyConfig(keyConfig);
}
return undefined;
});
// Load models
this.loadModels();
}
/**
* Reload models from disk (built-in + custom from models.json).
*/
refresh(): void {
this.customProviderApiKeys.clear();
this.loadError = null;
this.loadModels();
}
/**
* Get any error from loading models.json (null if no error).
*/
getError(): string | null {
return this.loadError;
}
private loadModels(): void {
// Load built-in models
const builtInModels: Model<Api>[] = [];
for (const provider of getProviders()) {
const providerModels = getModels(provider as KnownProvider);
builtInModels.push(...(providerModels as Model<Api>[]));
}
// Load custom models from models.json (if path provided)
let customModels: Model<Api>[] = [];
if (this.modelsJsonPath) {
const result = this.loadCustomModels(this.modelsJsonPath);
if (result.error) {
this.loadError = result.error;
// Keep built-in models even if custom models failed to load
} else {
customModels = result.models;
}
}
const combined = [...builtInModels, ...customModels];
// Update github-copilot base URL based on OAuth credentials
const copilotCred = this.authStorage.get("github-copilot");
if (copilotCred?.type === "oauth") {
const domain = copilotCred.enterpriseUrl
? (normalizeDomain(copilotCred.enterpriseUrl) ?? undefined)
: undefined;
const baseUrl = getGitHubCopilotBaseUrl(copilotCred.access, domain);
this.models = combined.map((m) => (m.provider === "github-copilot" ? { ...m, baseUrl } : m));
} else {
this.models = combined;
}
}
private loadCustomModels(modelsJsonPath: string): { models: Model<Api>[]; error: string | null } {
if (!existsSync(modelsJsonPath)) {
return { models: [], error: null };
}
try {
const content = readFileSync(modelsJsonPath, "utf-8");
const config: ModelsConfig = JSON.parse(content);
// Validate schema
const ajv = new Ajv();
const validate = ajv.compile(ModelsConfigSchema);
if (!validate(config)) {
const errors =
validate.errors?.map((e: any) => ` - ${e.instancePath || "root"}: ${e.message}`).join("\n") ||
"Unknown schema error";
return {
models: [],
error: `Invalid models.json schema:\n${errors}\n\nFile: ${modelsJsonPath}`,
};
}
// Additional validation
this.validateConfig(config);
// Parse models
return { models: this.parseModels(config), error: null };
} catch (error) {
if (error instanceof SyntaxError) {
return {
models: [],
error: `Failed to parse models.json: ${error.message}\n\nFile: ${modelsJsonPath}`,
};
}
return {
models: [],
error: `Failed to load models.json: ${error instanceof Error ? error.message : error}\n\nFile: ${modelsJsonPath}`,
};
}
}
private validateConfig(config: ModelsConfig): void {
for (const [providerName, providerConfig] of Object.entries(config.providers)) {
const hasProviderApi = !!providerConfig.api;
for (const modelDef of providerConfig.models) {
const hasModelApi = !!modelDef.api;
if (!hasProviderApi && !hasModelApi) {
throw new Error(
`Provider ${providerName}, model ${modelDef.id}: no "api" specified. Set at provider or model level.`,
);
}
if (!modelDef.id) throw new Error(`Provider ${providerName}: model missing "id"`);
if (!modelDef.name) throw new Error(`Provider ${providerName}: model missing "name"`);
if (modelDef.contextWindow <= 0)
throw new Error(`Provider ${providerName}, model ${modelDef.id}: invalid contextWindow`);
if (modelDef.maxTokens <= 0)
throw new Error(`Provider ${providerName}, model ${modelDef.id}: invalid maxTokens`);
}
}
}
private parseModels(config: ModelsConfig): Model<Api>[] {
const models: Model<Api>[] = [];
for (const [providerName, providerConfig] of Object.entries(config.providers)) {
// Store API key config for fallback resolver
this.customProviderApiKeys.set(providerName, providerConfig.apiKey);
for (const modelDef of providerConfig.models) {
const api = modelDef.api || providerConfig.api;
if (!api) continue;
// Merge headers: provider headers are base, model headers override
let headers =
providerConfig.headers || modelDef.headers
? { ...providerConfig.headers, ...modelDef.headers }
: undefined;
// If authHeader is true, add Authorization header with resolved API key
if (providerConfig.authHeader) {
const resolvedKey = resolveApiKeyConfig(providerConfig.apiKey);
if (resolvedKey) {
headers = { ...headers, Authorization: `Bearer ${resolvedKey}` };
}
}
models.push({
id: modelDef.id,
name: modelDef.name,
api: api as Api,
provider: providerName,
baseUrl: providerConfig.baseUrl,
reasoning: modelDef.reasoning,
input: modelDef.input as ("text" | "image")[],
cost: modelDef.cost,
contextWindow: modelDef.contextWindow,
maxTokens: modelDef.maxTokens,
headers,
compat: modelDef.compat,
} as Model<Api>);
}
}
return models;
}
/**
* Get all models (built-in + custom).
* If models.json had errors, returns only built-in models.
*/
getAll(): Model<Api>[] {
return this.models;
}
/**
* Get only models that have valid API keys available.
*/
async getAvailable(): Promise<Model<Api>[]> {
const available: Model<Api>[] = [];
for (const model of this.models) {
const apiKey = await this.authStorage.getApiKey(model.provider);
if (apiKey) {
available.push(model);
}
}
return available;
}
/**
* Find a model by provider and ID.
*/
find(provider: string, modelId: string): Model<Api> | null {
return this.models.find((m) => m.provider === provider && m.id === modelId) ?? null;
}
/**
* Get API key for a model.
*/
async getApiKey(model: Model<Api>): Promise<string | null> {
return this.authStorage.getApiKey(model.provider);
}
/**
* Check if a model is using OAuth credentials (subscription).
*/
isUsingOAuth(model: Model<Api>): boolean {
const cred = this.authStorage.get(model.provider);
return cred?.type === "oauth";
}
}

View file

@ -6,8 +6,7 @@ import type { ThinkingLevel } from "@mariozechner/pi-agent-core";
import type { Api, KnownProvider, Model } from "@mariozechner/pi-ai";
import chalk from "chalk";
import { isValidThinkingLevel } from "../cli/args.js";
import { findModel, getApiKeyForModel, getAvailableModels } from "./models-json.js";
import type { SettingsManager } from "./settings-manager.js";
import type { ModelRegistry } from "./model-registry.js";
/** Default model IDs for each known provider */
export const defaultModelPerProvider: Record<KnownProvider, string> = {
@ -167,21 +166,9 @@ export function parseModelPattern(pattern: string, availableModels: Model<Api>[]
* Supports models with colons in their IDs (e.g., OpenRouter's model:exacto).
* The algorithm tries to match the full pattern first, then progressively
* strips colon-suffixes to find a match.
*
* @param patterns - Model patterns to resolve
* @param settingsManager - Optional settings manager for API key fallback from settings.json
*/
export async function resolveModelScope(patterns: string[], settingsManager?: SettingsManager): Promise<ScopedModel[]> {
const { models: availableModels, error } = await getAvailableModels(
undefined,
settingsManager ? (provider) => settingsManager.getApiKey(provider) : undefined,
);
if (error) {
console.warn(chalk.yellow(`Warning: Error loading models: ${error}`));
return [];
}
export async function resolveModelScope(patterns: string[], modelRegistry: ModelRegistry): Promise<ScopedModel[]> {
const availableModels = await modelRegistry.getAvailable();
const scopedModels: ScopedModel[] = [];
for (const pattern of patterns) {
@ -224,20 +211,28 @@ export async function findInitialModel(options: {
cliModel?: string;
scopedModels: ScopedModel[];
isContinuing: boolean;
settingsManager: SettingsManager;
defaultProvider?: string;
defaultModelId?: string;
defaultThinkingLevel?: ThinkingLevel;
modelRegistry: ModelRegistry;
}): Promise<InitialModelResult> {
const { cliProvider, cliModel, scopedModels, isContinuing, settingsManager } = options;
const {
cliProvider,
cliModel,
scopedModels,
isContinuing,
defaultProvider,
defaultModelId,
defaultThinkingLevel,
modelRegistry,
} = options;
let model: Model<Api> | null = null;
let thinkingLevel: ThinkingLevel = "off";
// 1. CLI args take priority
if (cliProvider && cliModel) {
const { model: found, error } = findModel(cliProvider, cliModel);
if (error) {
console.error(chalk.red(error));
process.exit(1);
}
const found = modelRegistry.find(cliProvider, cliModel);
if (!found) {
console.error(chalk.red(`Model ${cliProvider}/${cliModel} not found`));
process.exit(1);
@ -255,34 +250,19 @@ export async function findInitialModel(options: {
}
// 3. Try saved default from settings
const defaultProvider = settingsManager.getDefaultProvider();
const defaultModelId = settingsManager.getDefaultModel();
if (defaultProvider && defaultModelId) {
const { model: found, error } = findModel(defaultProvider, defaultModelId);
if (error) {
console.error(chalk.red(error));
process.exit(1);
}
const found = modelRegistry.find(defaultProvider, defaultModelId);
if (found) {
model = found;
// Also load saved thinking level
const savedThinking = settingsManager.getDefaultThinkingLevel();
if (savedThinking) {
thinkingLevel = savedThinking;
if (defaultThinkingLevel) {
thinkingLevel = defaultThinkingLevel;
}
return { model, thinkingLevel, fallbackMessage: null };
}
}
// 4. Try first available model with valid API key
const { models: availableModels, error } = await getAvailableModels(undefined, (provider) =>
settingsManager.getApiKey(provider),
);
if (error) {
console.error(chalk.red(error));
process.exit(1);
}
const availableModels = await modelRegistry.getAvailable();
if (availableModels.length > 0) {
// Try to find a default model from known providers
@ -310,17 +290,12 @@ export async function restoreModelFromSession(
savedModelId: string,
currentModel: Model<Api> | null,
shouldPrintMessages: boolean,
settingsManager?: SettingsManager,
modelRegistry: ModelRegistry,
): Promise<{ model: Model<Api> | null; fallbackMessage: string | null }> {
const { model: restoredModel, error } = findModel(savedProvider, savedModelId);
if (error) {
console.error(chalk.red(error));
process.exit(1);
}
const restoredModel = modelRegistry.find(savedProvider, savedModelId);
// Check if restored model exists and has a valid API key
const hasApiKey = restoredModel ? !!(await getApiKeyForModel(restoredModel)) : false;
const hasApiKey = restoredModel ? !!(await modelRegistry.getApiKey(restoredModel)) : false;
if (restoredModel && hasApiKey) {
if (shouldPrintMessages) {
@ -348,14 +323,7 @@ export async function restoreModelFromSession(
}
// Try to find any available model
const { models: availableModels, error: availableError } = await getAvailableModels(
undefined,
settingsManager ? (provider) => settingsManager.getApiKey(provider) : undefined,
);
if (availableError) {
console.error(chalk.red(availableError));
process.exit(1);
}
const availableModels = await modelRegistry.getAvailable();
if (availableModels.length > 0) {
// Try to find a default model from known providers

View file

@ -1,467 +0,0 @@
import {
type Api,
getApiKey,
getGitHubCopilotBaseUrl,
getModels,
getProviders,
type KnownProvider,
loadOAuthCredentials,
type Model,
normalizeDomain,
refreshGitHubCopilotToken,
removeOAuthCredentials,
saveOAuthCredentials,
} from "@mariozechner/pi-ai";
import { type Static, Type } from "@sinclair/typebox";
import AjvModule from "ajv";
import { existsSync, readFileSync } from "fs";
import { join } from "path";
import { getAgentDir } from "../config.js";
import { getOAuthToken, type OAuthProvider, refreshToken } from "./oauth/index.js";
// Handle both default and named exports
const Ajv = (AjvModule as any).default || AjvModule;
// Schema for OpenAI compatibility settings
const OpenAICompatSchema = Type.Object({
supportsStore: Type.Optional(Type.Boolean()),
supportsDeveloperRole: Type.Optional(Type.Boolean()),
supportsReasoningEffort: Type.Optional(Type.Boolean()),
maxTokensField: Type.Optional(Type.Union([Type.Literal("max_completion_tokens"), Type.Literal("max_tokens")])),
});
// Schema for custom model definition
const ModelDefinitionSchema = Type.Object({
id: Type.String({ minLength: 1 }),
name: Type.String({ minLength: 1 }),
api: Type.Optional(
Type.Union([
Type.Literal("openai-completions"),
Type.Literal("openai-responses"),
Type.Literal("anthropic-messages"),
Type.Literal("google-generative-ai"),
]),
),
reasoning: Type.Boolean(),
input: Type.Array(Type.Union([Type.Literal("text"), Type.Literal("image")])),
cost: Type.Object({
input: Type.Number(),
output: Type.Number(),
cacheRead: Type.Number(),
cacheWrite: Type.Number(),
}),
contextWindow: Type.Number(),
maxTokens: Type.Number(),
headers: Type.Optional(Type.Record(Type.String(), Type.String())),
compat: Type.Optional(OpenAICompatSchema),
});
const ProviderConfigSchema = Type.Object({
baseUrl: Type.String({ minLength: 1 }),
apiKey: Type.String({ minLength: 1 }),
api: Type.Optional(
Type.Union([
Type.Literal("openai-completions"),
Type.Literal("openai-responses"),
Type.Literal("anthropic-messages"),
Type.Literal("google-generative-ai"),
]),
),
headers: Type.Optional(Type.Record(Type.String(), Type.String())),
authHeader: Type.Optional(Type.Boolean()),
models: Type.Array(ModelDefinitionSchema),
});
const ModelsConfigSchema = Type.Object({
providers: Type.Record(Type.String(), ProviderConfigSchema),
});
type ModelsConfig = Static<typeof ModelsConfigSchema>;
// Custom provider API key mappings (provider name -> apiKey config)
const customProviderApiKeys: Map<string, string> = new Map();
/**
* Resolve an API key config value to an actual key.
* First checks if it's an environment variable, then treats as literal.
*/
export function resolveApiKey(keyConfig: string): string | undefined {
// First check if it's an env var name
const envValue = process.env[keyConfig];
if (envValue) return envValue;
// Otherwise treat as literal API key
return keyConfig;
}
/**
* Load custom models from a models.json file
* Returns { models, error } - either models array or error message
*/
function loadCustomModels(modelsJsonPath: string): { models: Model<Api>[]; error: string | null } {
if (!existsSync(modelsJsonPath)) {
return { models: [], error: null };
}
try {
const content = readFileSync(modelsJsonPath, "utf-8");
const config: ModelsConfig = JSON.parse(content);
// Validate schema
const ajv = new Ajv();
const validate = ajv.compile(ModelsConfigSchema);
if (!validate(config)) {
const errors =
validate.errors?.map((e: any) => ` - ${e.instancePath || "root"}: ${e.message}`).join("\n") ||
"Unknown schema error";
return {
models: [],
error: `Invalid models.json schema:\n${errors}\n\nFile: ${modelsJsonPath}`,
};
}
// Additional validation
try {
validateConfig(config);
} catch (error) {
return {
models: [],
error: `Invalid models.json: ${error instanceof Error ? error.message : error}\n\nFile: ${modelsJsonPath}`,
};
}
// Parse models
return { models: parseModels(config), error: null };
} catch (error) {
if (error instanceof SyntaxError) {
return {
models: [],
error: `Failed to parse models.json: ${error.message}\n\nFile: ${modelsJsonPath}`,
};
}
return {
models: [],
error: `Failed to load models.json: ${error instanceof Error ? error.message : error}\n\nFile: ${modelsJsonPath}`,
};
}
}
/**
* Validate config structure and requirements
*/
function validateConfig(config: ModelsConfig): void {
for (const [providerName, providerConfig] of Object.entries(config.providers)) {
const hasProviderApi = !!providerConfig.api;
for (const modelDef of providerConfig.models) {
const hasModelApi = !!modelDef.api;
if (!hasProviderApi && !hasModelApi) {
throw new Error(
`Provider ${providerName}, model ${modelDef.id}: no "api" specified. ` +
`Set at provider or model level.`,
);
}
// Validate required fields
if (!modelDef.id) throw new Error(`Provider ${providerName}: model missing "id"`);
if (!modelDef.name) throw new Error(`Provider ${providerName}: model missing "name"`);
if (modelDef.contextWindow <= 0)
throw new Error(`Provider ${providerName}, model ${modelDef.id}: invalid contextWindow`);
if (modelDef.maxTokens <= 0)
throw new Error(`Provider ${providerName}, model ${modelDef.id}: invalid maxTokens`);
}
}
}
/**
* Parse config into Model objects
*/
function parseModels(config: ModelsConfig): Model<Api>[] {
const models: Model<Api>[] = [];
// Clear and rebuild custom provider API key mappings
customProviderApiKeys.clear();
for (const [providerName, providerConfig] of Object.entries(config.providers)) {
// Store API key config for this provider
customProviderApiKeys.set(providerName, providerConfig.apiKey);
for (const modelDef of providerConfig.models) {
// Model-level api overrides provider-level api
const api = modelDef.api || providerConfig.api;
if (!api) {
// This should have been caught by validateConfig, but be safe
continue;
}
// Merge headers: provider headers are base, model headers override
let headers =
providerConfig.headers || modelDef.headers ? { ...providerConfig.headers, ...modelDef.headers } : undefined;
// If authHeader is true, add Authorization header with resolved API key
if (providerConfig.authHeader) {
const resolvedKey = resolveApiKey(providerConfig.apiKey);
if (resolvedKey) {
headers = { ...headers, Authorization: `Bearer ${resolvedKey}` };
}
}
models.push({
id: modelDef.id,
name: modelDef.name,
api: api as Api,
provider: providerName,
baseUrl: providerConfig.baseUrl,
reasoning: modelDef.reasoning,
input: modelDef.input as ("text" | "image")[],
cost: modelDef.cost,
contextWindow: modelDef.contextWindow,
maxTokens: modelDef.maxTokens,
headers,
compat: modelDef.compat,
} as Model<Api>);
}
}
return models;
}
/**
* Get all models (built-in + custom), freshly loaded
* Returns { models, error } - either models array or error message
*/
export function loadAndMergeModels(agentDir: string = getAgentDir()): { models: Model<Api>[]; error: string | null } {
const builtInModels: Model<Api>[] = [];
const providers = getProviders();
// Load all built-in models
for (const provider of providers) {
const providerModels = getModels(provider as KnownProvider);
builtInModels.push(...(providerModels as Model<Api>[]));
}
// Load custom models
const { models: customModels, error } = loadCustomModels(join(agentDir, "models.json"));
if (error) {
return { models: [], error };
}
const combined = [...builtInModels, ...customModels];
// Update github-copilot base URL based on OAuth token or enterprise domain
const copilotCreds = loadOAuthCredentials("github-copilot");
if (copilotCreds) {
const domain = copilotCreds.enterpriseUrl ? normalizeDomain(copilotCreds.enterpriseUrl) : undefined;
const baseUrl = getGitHubCopilotBaseUrl(copilotCreds.access, domain ?? undefined);
return {
models: combined.map((m) => (m.provider === "github-copilot" ? { ...m, baseUrl } : m)),
error: null,
};
}
return { models: combined, error: null };
}
/**
* Get API key for a model (checks custom providers first, then built-in)
* Now async to support OAuth token refresh.
* Note: OAuth storage location is configured globally via setOAuthStorage.
*/
export async function getApiKeyForModel(model: Model<Api>): Promise<string | undefined> {
// For custom providers, check their apiKey config
const customKeyConfig = customProviderApiKeys.get(model.provider);
if (customKeyConfig) {
return resolveApiKey(customKeyConfig);
}
// For Anthropic, check OAuth first
if (model.provider === "anthropic") {
// 1. Check OAuth storage (auto-refresh if needed)
const oauthToken = await getOAuthToken("anthropic");
if (oauthToken) {
return oauthToken;
}
// 2. Check ANTHROPIC_OAUTH_TOKEN env var (manual OAuth token)
const oauthEnv = process.env.ANTHROPIC_OAUTH_TOKEN;
if (oauthEnv) {
return oauthEnv;
}
// 3. Fall back to ANTHROPIC_API_KEY env var
}
if (model.provider === "github-copilot") {
// 1. Check OAuth storage (from device flow login)
const oauthToken = await getOAuthToken("github-copilot");
if (oauthToken) {
return oauthToken;
}
// 2. Use GitHub token directly (works with copilot scope on github.com)
const githubToken = process.env.COPILOT_GITHUB_TOKEN || process.env.GH_TOKEN || process.env.GITHUB_TOKEN;
if (!githubToken) {
return undefined;
}
// 3. For enterprise, exchange token for short-lived Copilot token
const enterpriseDomain = process.env.COPILOT_ENTERPRISE_URL
? normalizeDomain(process.env.COPILOT_ENTERPRISE_URL)
: undefined;
if (enterpriseDomain) {
const creds = await refreshGitHubCopilotToken(githubToken, enterpriseDomain);
saveOAuthCredentials("github-copilot", creds);
return creds.access;
}
// 4. For github.com, use token directly
return githubToken;
}
// For Google Gemini CLI and Antigravity, check OAuth and encode projectId with token
if (model.provider === "google-gemini-cli" || model.provider === "google-antigravity") {
const oauthProvider = model.provider as "google-gemini-cli" | "google-antigravity";
const credentials = loadOAuthCredentials(oauthProvider);
if (!credentials) {
return undefined;
}
// Check if token is expired
if (Date.now() >= credentials.expires) {
try {
await refreshToken(oauthProvider);
const refreshedCreds = loadOAuthCredentials(oauthProvider);
if (refreshedCreds?.projectId) {
return JSON.stringify({ token: refreshedCreds.access, projectId: refreshedCreds.projectId });
}
} catch {
removeOAuthCredentials(oauthProvider);
return undefined;
}
}
if (credentials.projectId) {
return JSON.stringify({ token: credentials.access, projectId: credentials.projectId });
}
return undefined;
}
// For built-in providers, use getApiKey from @mariozechner/pi-ai
return getApiKey(model.provider as KnownProvider);
}
/**
* Get only models that have valid API keys available
* Returns { models, error } - either models array or error message
*
* @param agentDir - Agent config directory
* @param fallbackKeyResolver - Optional function to check for API keys not found by getApiKeyForModel
* (e.g., keys from settings.json)
*/
export async function getAvailableModels(
agentDir: string = getAgentDir(),
fallbackKeyResolver?: (provider: string) => string | undefined,
): Promise<{ models: Model<Api>[]; error: string | null }> {
const { models: allModels, error } = loadAndMergeModels(agentDir);
if (error) {
return { models: [], error };
}
const availableModels: Model<Api>[] = [];
for (const model of allModels) {
let apiKey = await getApiKeyForModel(model);
// Check fallback resolver if primary lookup failed
if (!apiKey && fallbackKeyResolver) {
apiKey = fallbackKeyResolver(model.provider);
}
if (apiKey) {
availableModels.push(model);
}
}
return { models: availableModels, error: null };
}
/**
* Find a specific model by provider and ID.
*
* Searches models from:
* 1. Built-in models from @mariozechner/pi-ai
* 2. Custom models defined in ~/.pi/agent/models.json
*
* Returns { model, error } - either the model or an error message.
*/
export function findModel(
provider: string,
modelId: string,
agentDir: string = getAgentDir(),
): { model: Model<Api> | null; error: string | null } {
const { models: allModels, error } = loadAndMergeModels(agentDir);
if (error) {
return { model: null, error };
}
const model = allModels.find((m) => m.provider === provider && m.id === modelId) || null;
return { model, error: null };
}
/**
* Mapping from model provider to OAuth provider ID.
* Only providers that support OAuth are listed here.
*/
const providerToOAuthProvider: Record<string, OAuthProvider> = {
anthropic: "anthropic",
"github-copilot": "github-copilot",
"google-gemini-cli": "google-gemini-cli",
"google-antigravity": "google-antigravity",
};
// Cache for OAuth status per provider (avoids file reads on every render)
const oauthStatusCache: Map<string, boolean> = new Map();
/**
* Invalidate the OAuth status cache.
* Call this after login/logout operations.
*/
export function invalidateOAuthCache(): void {
oauthStatusCache.clear();
}
/**
* Check if a model is using OAuth credentials (subscription).
* This checks if OAuth credentials exist and would be used for the model,
* without actually fetching or refreshing the token.
* Results are cached until invalidateOAuthCache() is called.
*/
export function isModelUsingOAuth(model: Model<Api>): boolean {
const oauthProvider = providerToOAuthProvider[model.provider];
if (!oauthProvider) {
return false;
}
// Check cache first
if (oauthStatusCache.has(oauthProvider)) {
return oauthStatusCache.get(oauthProvider)!;
}
// Check if OAuth credentials exist for this provider
let usingOAuth = false;
const credentials = loadOAuthCredentials(oauthProvider);
if (credentials) {
usingOAuth = true;
}
// Also check for manual OAuth token env var (for Anthropic)
if (!usingOAuth && model.provider === "anthropic" && process.env.ANTHROPIC_OAUTH_TOKEN) {
usingOAuth = true;
}
oauthStatusCache.set(oauthProvider, usingOAuth);
return usingOAuth;
}

View file

@ -30,22 +30,17 @@
*/
import { Agent, ProviderTransport, type ThinkingLevel } from "@mariozechner/pi-agent-core";
import { type Model, setOAuthStorage } from "@mariozechner/pi-ai";
import { chmodSync, existsSync, mkdirSync, readFileSync, writeFileSync } from "fs";
import { dirname, join } from "path";
import type { Model } from "@mariozechner/pi-ai";
import { join } from "path";
import { getAgentDir } from "../config.js";
import { AgentSession } from "./agent-session.js";
import { AuthStorage } from "./auth-storage.js";
import { discoverAndLoadCustomTools, type LoadedCustomTool } from "./custom-tools/index.js";
import type { CustomAgentTool } from "./custom-tools/types.js";
import { discoverAndLoadHooks, HookRunner, type LoadedHook, wrapToolsWithHooks } from "./hooks/index.js";
import type { HookFactory } from "./hooks/types.js";
import { messageTransformer } from "./messages.js";
import {
findModel as findModelInternal,
getApiKeyForModel,
getAvailableModels,
loadAndMergeModels,
} from "./models-json.js";
import { ModelRegistry } from "./model-registry.js";
import { SessionManager } from "./session-manager.js";
import { type Settings, SettingsManager, type SkillsSettings } from "./settings-manager.js";
import { loadSkills as loadSkillsInternal, type Skill } from "./skills.js";
@ -86,6 +81,11 @@ export interface CreateAgentSessionOptions {
/** Global config directory. Default: ~/.pi/agent */
agentDir?: string;
/** Auth storage for credentials. Default: discoverAuthStorage(agentDir) */
authStorage?: AuthStorage;
/** Model registry. Default: discoverModels(authStorage, agentDir) */
modelRegistry?: ModelRegistry;
/** Model to use. Default: from settings, else first available */
model?: Model<any>;
/** Thinking level. Default: from settings, else 'off' (clamped to model capabilities) */
@ -93,9 +93,6 @@ export interface CreateAgentSessionOptions {
/** Models available for cycling (Ctrl+P in interactive mode) */
scopedModels?: Array<{ model: Model<any>; thinkingLevel: ThinkingLevel }>;
/** API key resolver. Default: defaultGetApiKey() */
getApiKey?: (model: Model<any>) => Promise<string | undefined>;
/** System prompt. String replaces default, function receives default and returns final. */
systemPrompt?: string | ((defaultPrompt: string) => string);
@ -177,73 +174,20 @@ function getDefaultAgentDir(): string {
return getAgentDir();
}
/**
* Configure OAuth storage to use the specified agent directory.
* Must be called before using OAuth-based authentication.
*/
export function configureOAuthStorage(agentDir: string = getDefaultAgentDir()): void {
const oauthPath = join(agentDir, "oauth.json");
setOAuthStorage({
load: () => {
if (!existsSync(oauthPath)) {
return {};
}
try {
return JSON.parse(readFileSync(oauthPath, "utf-8"));
} catch {
return {};
}
},
save: (storage) => {
const dir = dirname(oauthPath);
if (!existsSync(dir)) {
mkdirSync(dir, { recursive: true, mode: 0o700 });
}
writeFileSync(oauthPath, JSON.stringify(storage, null, 2), "utf-8");
chmodSync(oauthPath, 0o600);
},
});
}
// Discovery Functions
/**
* Get all models (built-in + custom from models.json).
* Create an AuthStorage instance for the given agent directory.
*/
export function discoverModels(agentDir: string = getDefaultAgentDir()): Model<any>[] {
const { models, error } = loadAndMergeModels(agentDir);
if (error) {
throw new Error(error);
}
return models;
export function discoverAuthStorage(agentDir: string = getDefaultAgentDir()): AuthStorage {
return new AuthStorage(join(agentDir, "auth.json"));
}
/**
* Get models that have valid API keys available.
* Create a ModelRegistry for the given agent directory.
*/
export async function discoverAvailableModels(agentDir: string = getDefaultAgentDir()): Promise<Model<any>[]> {
const { models, error } = await getAvailableModels(agentDir);
if (error) {
throw new Error(error);
}
return models;
}
/**
* Find a model by provider and ID.
* @returns The model, or null if not found
*/
export function findModel(
provider: string,
modelId: string,
agentDir: string = getDefaultAgentDir(),
): Model<any> | null {
const { model, error } = findModelInternal(provider, modelId, agentDir);
if (error) {
throw new Error(error);
}
return model;
export function discoverModels(authStorage: AuthStorage, agentDir: string = getDefaultAgentDir()): ModelRegistry {
return new ModelRegistry(authStorage, join(agentDir, "models.json"));
}
/**
@ -326,30 +270,6 @@ export function discoverSlashCommands(cwd?: string, agentDir?: string): FileSlas
// API Key Helpers
/**
* Create the default API key resolver.
* Priority: OAuth > custom providers (models.json) > environment variables > settings.json apiKeys.
*
* OAuth takes priority so users logged in with a plan (e.g. unlimited tokens) aren't
* accidentally billed via a PAYG API key sitting in settings.json.
*/
export function defaultGetApiKey(
settingsManager?: SettingsManager,
): (model: Model<any>) => Promise<string | undefined> {
return async (model: Model<any>) => {
// Check OAuth, custom providers, env vars first
const resolvedKey = await getApiKeyForModel(model);
if (resolvedKey) {
return resolvedKey;
}
// Fall back to settings.json apiKeys
if (settingsManager) {
return settingsManager.getApiKey(model.provider);
}
return undefined;
};
}
// System Prompt
export interface BuildSystemPromptOptions {
@ -457,8 +377,9 @@ function createLoadedHooksFromDefinitions(definitions: Array<{ path?: string; fa
* const { session } = await createAgentSession();
*
* // With explicit model
* import { getModel } from '@mariozechner/pi-ai';
* const { session } = await createAgentSession({
* model: findModel('anthropic', 'claude-sonnet-4-20250514'),
* model: getModel('anthropic', 'claude-opus-4-5'),
* thinkingLevel: 'high',
* });
*
@ -483,22 +404,16 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
const cwd = options.cwd ?? process.cwd();
const agentDir = options.agentDir ?? getDefaultAgentDir();
// Configure OAuth storage for this agentDir
configureOAuthStorage(agentDir);
time("configureOAuthStorage");
// Use provided or create AuthStorage and ModelRegistry
const authStorage = options.authStorage ?? discoverAuthStorage(agentDir);
const modelRegistry = options.modelRegistry ?? discoverModels(authStorage, agentDir);
time("discoverModels");
const settingsManager = options.settingsManager ?? SettingsManager.create(cwd, agentDir);
time("settingsManager");
const sessionManager = options.sessionManager ?? SessionManager.create(cwd, agentDir);
time("sessionManager");
// Helper to check API key availability (settings first, then OAuth/env vars)
const hasApiKey = async (m: Model<any>): Promise<boolean> => {
const settingsKey = settingsManager.getApiKey(m.provider);
if (settingsKey) return true;
return !!(await getApiKeyForModel(m));
};
// Check if session has existing data to restore
const existingSession = sessionManager.buildSessionContext();
time("loadSession");
@ -509,8 +424,8 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
// If session has data, try to restore model from it
if (!model && hasExistingSession && existingSession.model) {
const restoredModel = findModel(existingSession.model.provider, existingSession.model.modelId);
if (restoredModel && (await hasApiKey(restoredModel))) {
const restoredModel = modelRegistry.find(existingSession.model.provider, existingSession.model.modelId);
if (restoredModel && (await modelRegistry.getApiKey(restoredModel))) {
model = restoredModel;
}
if (!model) {
@ -523,8 +438,8 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
const defaultProvider = settingsManager.getDefaultProvider();
const defaultModelId = settingsManager.getDefaultModel();
if (defaultProvider && defaultModelId) {
const settingsModel = findModel(defaultProvider, defaultModelId);
if (settingsModel && (await hasApiKey(settingsModel))) {
const settingsModel = modelRegistry.find(defaultProvider, defaultModelId);
if (settingsModel && (await modelRegistry.getApiKey(settingsModel))) {
model = settingsModel;
}
}
@ -532,14 +447,13 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
// Fall back to first available model with a valid API key
if (!model) {
const allModels = discoverModels(agentDir);
for (const m of allModels) {
if (await hasApiKey(m)) {
for (const m of modelRegistry.getAll()) {
if (await modelRegistry.getApiKey(m)) {
model = m;
break;
}
}
time("discoverAvailableModels");
time("findAvailableModel");
if (model) {
if (modelFallbackMessage) {
modelFallbackMessage += `. Using ${model.provider}/${model.id}`;
@ -567,8 +481,6 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
thinkingLevel = "off";
}
const getApiKey = options.getApiKey ?? defaultGetApiKey(settingsManager);
const skills = options.skills ?? discoverSkills(cwd, agentDir, settingsManager.getSkillsSettings());
time("discoverSkills");
@ -661,7 +573,7 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
if (!currentModel) {
throw new Error("No model selected");
}
const key = await getApiKey(currentModel);
const key = await modelRegistry.getApiKey(currentModel);
if (!key) {
throw new Error(`No API key found for provider "${currentModel.provider}"`);
}
@ -685,7 +597,7 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
hookRunner,
customTools: customToolsResult.tools,
skillsSettings: settingsManager.getSkillsSettings(),
resolveApiKey: getApiKey,
modelRegistry,
});
time("createAgentSession");

View file

@ -47,7 +47,6 @@ export interface Settings {
customTools?: string[]; // Array of custom tool file paths
skills?: SkillsSettings;
terminal?: TerminalSettings;
apiKeys?: Record<string, string>; // provider -> API key (e.g., { "anthropic": "sk-..." })
}
/** Deep merge settings: project/overrides take precedence, nested objects merge recursively */
@ -366,27 +365,4 @@ export class SettingsManager {
this.globalSettings.terminal.showImages = show;
this.save();
}
getApiKey(provider: string): string | undefined {
return this.settings.apiKeys?.[provider];
}
setApiKey(provider: string, key: string): void {
if (!this.globalSettings.apiKeys) {
this.globalSettings.apiKeys = {};
}
this.globalSettings.apiKeys[provider] = key;
this.save();
}
removeApiKey(provider: string): void {
if (this.globalSettings.apiKeys) {
delete this.globalSettings.apiKeys[provider];
this.save();
}
}
getApiKeys(): Record<string, string> {
return this.settings.apiKeys ?? {};
}
}