Add agentDir parameter to model config functions and OAuth storage

- loadAndMergeModels(agentDir) - loads models.json from agentDir
- getAvailableModels(agentDir) - uses agentDir for model discovery
- findModel(provider, id, agentDir) - uses agentDir for model lookup
- configureOAuthStorage(agentDir) - configures OAuth to use agentDir/oauth.json
- createAgentSession calls configureOAuthStorage with options.agentDir
- Export configureOAuthStorage from SDK
This commit is contained in:
Mario Zechner 2025-12-22 01:42:27 +01:00
parent 6d4ff74430
commit b168a6cae3
4 changed files with 69 additions and 52 deletions

View file

@ -15,7 +15,8 @@ import {
import { type Static, Type } from "@sinclair/typebox";
import AjvModule from "ajv";
import { existsSync, readFileSync } from "fs";
import { getModelsPath } from "../config.js";
import { join } from "path";
import { getAgentDir } from "../config.js";
import { getOAuthToken, type OAuthProvider, refreshToken } from "./oauth/index.js";
// Handle both default and named exports
@ -97,8 +98,8 @@ export function resolveApiKey(keyConfig: string): string | undefined {
* Load custom models from models.json in agent config dir
* Returns { models, error } - either models array or error message
*/
function loadCustomModels(): { models: Model<Api>[]; error: string | null } {
const configPath = getModelsPath();
function loadCustomModels(agentDir: string): { models: Model<Api>[]; error: string | null } {
const configPath = join(agentDir, "models.json");
if (!existsSync(configPath)) {
return { models: [], error: null };
}
@ -232,7 +233,7 @@ function parseModels(config: ModelsConfig): Model<Api>[] {
* Get all models (built-in + custom), freshly loaded
* Returns { models, error } - either models array or error message
*/
export function loadAndMergeModels(): { models: Model<Api>[]; error: string | null } {
export function loadAndMergeModels(agentDir: string = getAgentDir()): { models: Model<Api>[]; error: string | null } {
const builtInModels: Model<Api>[] = [];
const providers = getProviders();
@ -243,7 +244,7 @@ export function loadAndMergeModels(): { models: Model<Api>[]; error: string | nu
}
// Load custom models
const { models: customModels, error } = loadCustomModels();
const { models: customModels, error } = loadCustomModels(agentDir);
if (error) {
return { models: [], error };
@ -267,7 +268,8 @@ export function loadAndMergeModels(): { models: Model<Api>[]; error: string | nu
/**
* Get API key for a model (checks custom providers first, then built-in)
* Now async to support OAuth token refresh
* Now async to support OAuth token refresh.
* Note: OAuth storage location is configured globally via setOAuthStorage.
*/
export async function getApiKeyForModel(model: Model<Api>): Promise<string | undefined> {
// For custom providers, check their apiKey config
@ -357,8 +359,10 @@ export async function getApiKeyForModel(model: Model<Api>): Promise<string | und
* Get only models that have valid API keys available
* Returns { models, error } - either models array or error message
*/
export async function getAvailableModels(): Promise<{ models: Model<Api>[]; error: string | null }> {
const { models: allModels, error } = loadAndMergeModels();
export async function getAvailableModels(
agentDir: string = getAgentDir(),
): Promise<{ models: Model<Api>[]; error: string | null }> {
const { models: allModels, error } = loadAndMergeModels(agentDir);
if (error) {
return { models: [], error };
@ -390,8 +394,12 @@ export async function getAvailableModels(): Promise<{ models: Model<Api>[]; erro
* Find a specific model by provider and ID
* Returns { model, error } - either model or error message
*/
export function findModel(provider: string, modelId: string): { model: Model<Api> | null; error: string | null } {
const { models: allModels, error } = loadAndMergeModels();
export function findModel(
provider: string,
modelId: string,
agentDir: string = getAgentDir(),
): { model: Model<Api> | null; error: string | null } {
const { models: allModels, error } = loadAndMergeModels(agentDir);
if (error) {
return { model: null, error };

View file

@ -30,7 +30,9 @@
*/
import { Agent, ProviderTransport, type ThinkingLevel } from "@mariozechner/pi-agent-core";
import type { Model } from "@mariozechner/pi-ai";
import { type Model, setOAuthStorage } from "@mariozechner/pi-ai";
import { chmodSync, existsSync, mkdirSync, readFileSync, writeFileSync } from "fs";
import { dirname, join } from "path";
import { getAgentDir } from "../config.js";
import { AgentSession } from "./agent-session.js";
import { discoverAndLoadCustomTools, type LoadedCustomTool } from "./custom-tools/index.js";
@ -154,14 +156,42 @@ function getDefaultAgentDir(): string {
return getAgentDir();
}
/**
* Configure OAuth storage to use the specified agent directory.
* Must be called before using OAuth-based authentication.
*/
export function configureOAuthStorage(agentDir: string = getDefaultAgentDir()): void {
const oauthPath = join(agentDir, "oauth.json");
setOAuthStorage({
load: () => {
if (!existsSync(oauthPath)) {
return {};
}
try {
return JSON.parse(readFileSync(oauthPath, "utf-8"));
} catch {
return {};
}
},
save: (storage) => {
const dir = dirname(oauthPath);
if (!existsSync(dir)) {
mkdirSync(dir, { recursive: true, mode: 0o700 });
}
writeFileSync(oauthPath, JSON.stringify(storage, null, 2), "utf-8");
chmodSync(oauthPath, 0o600);
},
});
}
// Discovery Functions
/**
* Get all models (built-in + custom from models.json).
* Note: Uses default agentDir for models.json location.
*/
export function discoverModels(): Model<any>[] {
const { models, error } = loadAndMergeModels();
export function discoverModels(agentDir: string = getDefaultAgentDir()): Model<any>[] {
const { models, error } = loadAndMergeModels(agentDir);
if (error) {
throw new Error(error);
}
@ -170,10 +200,9 @@ export function discoverModels(): Model<any>[] {
/**
* Get models that have valid API keys available.
* Note: Uses default agentDir for models.json and oauth.json location.
*/
export async function discoverAvailableModels(): Promise<Model<any>[]> {
const { models, error } = await getAvailableModels();
export async function discoverAvailableModels(agentDir: string = getDefaultAgentDir()): Promise<Model<any>[]> {
const { models, error } = await getAvailableModels(agentDir);
if (error) {
throw new Error(error);
}
@ -182,11 +211,14 @@ export async function discoverAvailableModels(): Promise<Model<any>[]> {
/**
* Find a model by provider and ID.
* Note: Uses default agentDir for models.json location.
* @returns The model, or null if not found
*/
export function findModel(provider: string, modelId: string): Model<any> | null {
const { model, error } = findModelInternal(provider, modelId);
export function findModel(
provider: string,
modelId: string,
agentDir: string = getDefaultAgentDir(),
): Model<any> | null {
const { model, error } = findModelInternal(provider, modelId, agentDir);
if (error) {
throw new Error(error);
}
@ -276,7 +308,6 @@ export function discoverSlashCommands(cwd?: string, agentDir?: string): FileSlas
/**
* Create the default API key resolver.
* Checks custom providers (models.json), OAuth, and environment variables.
* Note: Uses default agentDir for models.json and oauth.json location.
*/
export function defaultGetApiKey(): (model: Model<any>) => Promise<string | undefined> {
return getApiKeyForModel;
@ -415,6 +446,9 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
const cwd = options.cwd ?? process.cwd();
const agentDir = options.agentDir ?? getDefaultAgentDir();
// Configure OAuth storage for this agentDir
configureOAuthStorage(agentDir);
const settingsManager = new SettingsManager(agentDir);
const sessionManager = options.sessionManager ?? SessionManager.create(cwd);

View file

@ -85,11 +85,12 @@ export {
} from "./core/oauth/index.js";
// SDK for programmatic usage
export {
allBuiltInTools,
type BuildSystemPromptOptions,
buildSystemPrompt,
type CreateAgentSessionOptions,
type CreateAgentSessionResult,
// Configuration
configureOAuthStorage,
// Factory
createAgentSession,
// Helpers

View file

@ -6,21 +6,20 @@
*/
import type { Attachment } from "@mariozechner/pi-agent-core";
import { setOAuthStorage, supportsXhigh } from "@mariozechner/pi-ai";
import { supportsXhigh } from "@mariozechner/pi-ai";
import chalk from "chalk";
import { chmodSync, existsSync, mkdirSync, readFileSync, writeFileSync } from "fs";
import { dirname } from "path";
import { existsSync, readFileSync } from "fs";
import { type Args, parseArgs, printHelp } from "./cli/args.js";
import { processFileArguments } from "./cli/file-processor.js";
import { listModels } from "./cli/list-models.js";
import { selectSession } from "./cli/session-picker.js";
import { getModelsPath, getOAuthPath, VERSION } from "./config.js";
import { getModelsPath, VERSION } from "./config.js";
import type { AgentSession } from "./core/agent-session.js";
import type { LoadedCustomTool } from "./core/custom-tools/index.js";
import { exportFromFile } from "./core/export-html.js";
import { findModel } from "./core/model-config.js";
import { resolveModelScope, type ScopedModel } from "./core/model-resolver.js";
import { type CreateAgentSessionOptions, createAgentSession } from "./core/sdk.js";
import { type CreateAgentSessionOptions, configureOAuthStorage, createAgentSession } from "./core/sdk.js";
import { SessionManager } from "./core/session-manager.js";
import { SettingsManager } from "./core/settings-manager.js";
import { allTools } from "./core/tools/index.js";
@ -29,31 +28,6 @@ import { initTheme, stopThemeWatcher } from "./modes/interactive/theme/theme.js"
import { getChangelogPath, getNewEntries, parseChangelog } from "./utils/changelog.js";
import { ensureTool } from "./utils/tools-manager.js";
function configureOAuthStorage(): void {
const oauthPath = getOAuthPath();
setOAuthStorage({
load: () => {
if (!existsSync(oauthPath)) {
return {};
}
try {
return JSON.parse(readFileSync(oauthPath, "utf-8"));
} catch {
return {};
}
},
save: (storage) => {
const dir = dirname(oauthPath);
if (!existsSync(dir)) {
mkdirSync(dir, { recursive: true, mode: 0o700 });
}
writeFileSync(oauthPath, JSON.stringify(storage, null, 2), "utf-8");
chmodSync(oauthPath, 0o600);
},
});
}
async function checkForNewVersion(currentVersion: string): Promise<string | null> {
try {
const response = await fetch("https://registry.npmjs.org/@mariozechner/pi-coding-agent/latest");