diff --git a/packages/coding-agent/src/core/model-config.ts b/packages/coding-agent/src/core/model-config.ts index a188b61e..8208bab3 100644 --- a/packages/coding-agent/src/core/model-config.ts +++ b/packages/coding-agent/src/core/model-config.ts @@ -15,7 +15,8 @@ import { import { type Static, Type } from "@sinclair/typebox"; import AjvModule from "ajv"; import { existsSync, readFileSync } from "fs"; -import { getModelsPath } from "../config.js"; +import { join } from "path"; +import { getAgentDir } from "../config.js"; import { getOAuthToken, type OAuthProvider, refreshToken } from "./oauth/index.js"; // Handle both default and named exports @@ -97,8 +98,8 @@ export function resolveApiKey(keyConfig: string): string | undefined { * Load custom models from models.json in agent config dir * Returns { models, error } - either models array or error message */ -function loadCustomModels(): { models: Model[]; error: string | null } { - const configPath = getModelsPath(); +function loadCustomModels(agentDir: string): { models: Model[]; error: string | null } { + const configPath = join(agentDir, "models.json"); if (!existsSync(configPath)) { return { models: [], error: null }; } @@ -232,7 +233,7 @@ function parseModels(config: ModelsConfig): Model[] { * Get all models (built-in + custom), freshly loaded * Returns { models, error } - either models array or error message */ -export function loadAndMergeModels(): { models: Model[]; error: string | null } { +export function loadAndMergeModels(agentDir: string = getAgentDir()): { models: Model[]; error: string | null } { const builtInModels: Model[] = []; const providers = getProviders(); @@ -243,7 +244,7 @@ export function loadAndMergeModels(): { models: Model[]; error: string | nu } // Load custom models - const { models: customModels, error } = loadCustomModels(); + const { models: customModels, error } = loadCustomModels(agentDir); if (error) { return { models: [], error }; @@ -267,7 +268,8 @@ export function loadAndMergeModels(): { models: Model[]; error: string | nu /** * Get API key for a model (checks custom providers first, then built-in) - * Now async to support OAuth token refresh + * Now async to support OAuth token refresh. + * Note: OAuth storage location is configured globally via setOAuthStorage. */ export async function getApiKeyForModel(model: Model): Promise { // For custom providers, check their apiKey config @@ -357,8 +359,10 @@ export async function getApiKeyForModel(model: Model): Promise[]; error: string | null }> { - const { models: allModels, error } = loadAndMergeModels(); +export async function getAvailableModels( + agentDir: string = getAgentDir(), +): Promise<{ models: Model[]; error: string | null }> { + const { models: allModels, error } = loadAndMergeModels(agentDir); if (error) { return { models: [], error }; @@ -390,8 +394,12 @@ export async function getAvailableModels(): Promise<{ models: Model[]; erro * Find a specific model by provider and ID * Returns { model, error } - either model or error message */ -export function findModel(provider: string, modelId: string): { model: Model | null; error: string | null } { - const { models: allModels, error } = loadAndMergeModels(); +export function findModel( + provider: string, + modelId: string, + agentDir: string = getAgentDir(), +): { model: Model | null; error: string | null } { + const { models: allModels, error } = loadAndMergeModels(agentDir); if (error) { return { model: null, error }; diff --git a/packages/coding-agent/src/core/sdk.ts b/packages/coding-agent/src/core/sdk.ts index 9b2f3220..0ecf4f24 100644 --- a/packages/coding-agent/src/core/sdk.ts +++ b/packages/coding-agent/src/core/sdk.ts @@ -30,7 +30,9 @@ */ import { Agent, ProviderTransport, type ThinkingLevel } from "@mariozechner/pi-agent-core"; -import type { Model } from "@mariozechner/pi-ai"; +import { type Model, setOAuthStorage } from "@mariozechner/pi-ai"; +import { chmodSync, existsSync, mkdirSync, readFileSync, writeFileSync } from "fs"; +import { dirname, join } from "path"; import { getAgentDir } from "../config.js"; import { AgentSession } from "./agent-session.js"; import { discoverAndLoadCustomTools, type LoadedCustomTool } from "./custom-tools/index.js"; @@ -154,14 +156,42 @@ function getDefaultAgentDir(): string { return getAgentDir(); } +/** + * Configure OAuth storage to use the specified agent directory. + * Must be called before using OAuth-based authentication. + */ +export function configureOAuthStorage(agentDir: string = getDefaultAgentDir()): void { + const oauthPath = join(agentDir, "oauth.json"); + + setOAuthStorage({ + load: () => { + if (!existsSync(oauthPath)) { + return {}; + } + try { + return JSON.parse(readFileSync(oauthPath, "utf-8")); + } catch { + return {}; + } + }, + save: (storage) => { + const dir = dirname(oauthPath); + if (!existsSync(dir)) { + mkdirSync(dir, { recursive: true, mode: 0o700 }); + } + writeFileSync(oauthPath, JSON.stringify(storage, null, 2), "utf-8"); + chmodSync(oauthPath, 0o600); + }, + }); +} + // Discovery Functions /** * Get all models (built-in + custom from models.json). - * Note: Uses default agentDir for models.json location. */ -export function discoverModels(): Model[] { - const { models, error } = loadAndMergeModels(); +export function discoverModels(agentDir: string = getDefaultAgentDir()): Model[] { + const { models, error } = loadAndMergeModels(agentDir); if (error) { throw new Error(error); } @@ -170,10 +200,9 @@ export function discoverModels(): Model[] { /** * Get models that have valid API keys available. - * Note: Uses default agentDir for models.json and oauth.json location. */ -export async function discoverAvailableModels(): Promise[]> { - const { models, error } = await getAvailableModels(); +export async function discoverAvailableModels(agentDir: string = getDefaultAgentDir()): Promise[]> { + const { models, error } = await getAvailableModels(agentDir); if (error) { throw new Error(error); } @@ -182,11 +211,14 @@ export async function discoverAvailableModels(): Promise[]> { /** * Find a model by provider and ID. - * Note: Uses default agentDir for models.json location. * @returns The model, or null if not found */ -export function findModel(provider: string, modelId: string): Model | null { - const { model, error } = findModelInternal(provider, modelId); +export function findModel( + provider: string, + modelId: string, + agentDir: string = getDefaultAgentDir(), +): Model | null { + const { model, error } = findModelInternal(provider, modelId, agentDir); if (error) { throw new Error(error); } @@ -276,7 +308,6 @@ export function discoverSlashCommands(cwd?: string, agentDir?: string): FileSlas /** * Create the default API key resolver. * Checks custom providers (models.json), OAuth, and environment variables. - * Note: Uses default agentDir for models.json and oauth.json location. */ export function defaultGetApiKey(): (model: Model) => Promise { return getApiKeyForModel; @@ -415,6 +446,9 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {} const cwd = options.cwd ?? process.cwd(); const agentDir = options.agentDir ?? getDefaultAgentDir(); + // Configure OAuth storage for this agentDir + configureOAuthStorage(agentDir); + const settingsManager = new SettingsManager(agentDir); const sessionManager = options.sessionManager ?? SessionManager.create(cwd); diff --git a/packages/coding-agent/src/index.ts b/packages/coding-agent/src/index.ts index 14c207ef..93056daa 100644 --- a/packages/coding-agent/src/index.ts +++ b/packages/coding-agent/src/index.ts @@ -85,11 +85,12 @@ export { } from "./core/oauth/index.js"; // SDK for programmatic usage export { - allBuiltInTools, type BuildSystemPromptOptions, buildSystemPrompt, type CreateAgentSessionOptions, type CreateAgentSessionResult, + // Configuration + configureOAuthStorage, // Factory createAgentSession, // Helpers diff --git a/packages/coding-agent/src/main.ts b/packages/coding-agent/src/main.ts index 217b563c..37bb13aa 100644 --- a/packages/coding-agent/src/main.ts +++ b/packages/coding-agent/src/main.ts @@ -6,21 +6,20 @@ */ import type { Attachment } from "@mariozechner/pi-agent-core"; -import { setOAuthStorage, supportsXhigh } from "@mariozechner/pi-ai"; +import { supportsXhigh } from "@mariozechner/pi-ai"; import chalk from "chalk"; -import { chmodSync, existsSync, mkdirSync, readFileSync, writeFileSync } from "fs"; -import { dirname } from "path"; +import { existsSync, readFileSync } from "fs"; import { type Args, parseArgs, printHelp } from "./cli/args.js"; import { processFileArguments } from "./cli/file-processor.js"; import { listModels } from "./cli/list-models.js"; import { selectSession } from "./cli/session-picker.js"; -import { getModelsPath, getOAuthPath, VERSION } from "./config.js"; +import { getModelsPath, VERSION } from "./config.js"; import type { AgentSession } from "./core/agent-session.js"; import type { LoadedCustomTool } from "./core/custom-tools/index.js"; import { exportFromFile } from "./core/export-html.js"; import { findModel } from "./core/model-config.js"; import { resolveModelScope, type ScopedModel } from "./core/model-resolver.js"; -import { type CreateAgentSessionOptions, createAgentSession } from "./core/sdk.js"; +import { type CreateAgentSessionOptions, configureOAuthStorage, createAgentSession } from "./core/sdk.js"; import { SessionManager } from "./core/session-manager.js"; import { SettingsManager } from "./core/settings-manager.js"; import { allTools } from "./core/tools/index.js"; @@ -29,31 +28,6 @@ import { initTheme, stopThemeWatcher } from "./modes/interactive/theme/theme.js" import { getChangelogPath, getNewEntries, parseChangelog } from "./utils/changelog.js"; import { ensureTool } from "./utils/tools-manager.js"; -function configureOAuthStorage(): void { - const oauthPath = getOAuthPath(); - - setOAuthStorage({ - load: () => { - if (!existsSync(oauthPath)) { - return {}; - } - try { - return JSON.parse(readFileSync(oauthPath, "utf-8")); - } catch { - return {}; - } - }, - save: (storage) => { - const dir = dirname(oauthPath); - if (!existsSync(dir)) { - mkdirSync(dir, { recursive: true, mode: 0o700 }); - } - writeFileSync(oauthPath, JSON.stringify(storage, null, 2), "utf-8"); - chmodSync(oauthPath, 0o600); - }, - }); -} - async function checkForNewVersion(currentVersion: string): Promise { try { const response = await fetch("https://registry.npmjs.org/@mariozechner/pi-coding-agent/latest");