Refactor OAuth/API key handling: AuthStorage and ModelRegistry

- Add AuthStorage class for credential storage (auth.json)
- Add ModelRegistry class for model management with API key resolution
- Add discoverAuthStorage() and discoverModels() discovery functions
- Add migration from legacy oauth.json and settings.json apiKeys to auth.json
- Remove configureOAuthStorage, defaultGetApiKey, findModel, discoverAvailableModels
- Remove apiKeys from Settings type and SettingsManager methods
- Rename getOAuthPath to getAuthPath
- Update SDK, examples, docs, tests, and mom package

Fixes #296
This commit is contained in:
Mario Zechner 2025-12-25 03:48:36 +01:00
parent 9f97f0c8da
commit 54018b6cc0
29 changed files with 953 additions and 2017 deletions

File diff suppressed because one or more lines are too long

View file

@ -6359,6 +6359,23 @@ export const MODELS = {
contextWindow: 128000,
maxTokens: 16384,
} satisfies Model<"openai-completions">,
"meta-llama/llama-3.1-70b-instruct": {
id: "meta-llama/llama-3.1-70b-instruct",
name: "Meta: Llama 3.1 70B Instruct",
api: "openai-completions",
provider: "openrouter",
baseUrl: "https://openrouter.ai/api/v1",
reasoning: false,
input: ["text"],
cost: {
input: 0.39999999999999997,
output: 0.39999999999999997,
cacheRead: 0,
cacheWrite: 0,
},
contextWindow: 131072,
maxTokens: 4096,
} satisfies Model<"openai-completions">,
"meta-llama/llama-3.1-8b-instruct": {
id: "meta-llama/llama-3.1-8b-instruct",
name: "Meta: Llama 3.1 8B Instruct",
@ -6393,23 +6410,6 @@ export const MODELS = {
contextWindow: 10000,
maxTokens: 4096,
} satisfies Model<"openai-completions">,
"meta-llama/llama-3.1-70b-instruct": {
id: "meta-llama/llama-3.1-70b-instruct",
name: "Meta: Llama 3.1 70B Instruct",
api: "openai-completions",
provider: "openrouter",
baseUrl: "https://openrouter.ai/api/v1",
reasoning: false,
input: ["text"],
cost: {
input: 0.39999999999999997,
output: 0.39999999999999997,
cacheRead: 0,
cacheWrite: 0,
},
contextWindow: 131072,
maxTokens: 4096,
} satisfies Model<"openai-completions">,
"mistralai/mistral-nemo": {
id: "mistralai/mistral-nemo",
name: "Mistral: Mistral Nemo",
@ -6546,23 +6546,6 @@ export const MODELS = {
contextWindow: 128000,
maxTokens: 4096,
} satisfies Model<"openai-completions">,
"openai/gpt-4o-2024-05-13": {
id: "openai/gpt-4o-2024-05-13",
name: "OpenAI: GPT-4o (2024-05-13)",
api: "openai-completions",
provider: "openrouter",
baseUrl: "https://openrouter.ai/api/v1",
reasoning: false,
input: ["text", "image"],
cost: {
input: 5,
output: 15,
cacheRead: 0,
cacheWrite: 0,
},
contextWindow: 128000,
maxTokens: 4096,
} satisfies Model<"openai-completions">,
"openai/gpt-4o": {
id: "openai/gpt-4o",
name: "OpenAI: GPT-4o",
@ -6597,6 +6580,23 @@ export const MODELS = {
contextWindow: 128000,
maxTokens: 64000,
} satisfies Model<"openai-completions">,
"openai/gpt-4o-2024-05-13": {
id: "openai/gpt-4o-2024-05-13",
name: "OpenAI: GPT-4o (2024-05-13)",
api: "openai-completions",
provider: "openrouter",
baseUrl: "https://openrouter.ai/api/v1",
reasoning: false,
input: ["text", "image"],
cost: {
input: 5,
output: 15,
cacheRead: 0,
cacheWrite: 0,
},
contextWindow: 128000,
maxTokens: 4096,
} satisfies Model<"openai-completions">,
"meta-llama/llama-3-70b-instruct": {
id: "meta-llama/llama-3-70b-instruct",
name: "Meta: Llama 3 70B Instruct",
@ -6716,23 +6716,6 @@ export const MODELS = {
contextWindow: 128000,
maxTokens: 4096,
} satisfies Model<"openai-completions">,
"openai/gpt-3.5-turbo-0613": {
id: "openai/gpt-3.5-turbo-0613",
name: "OpenAI: GPT-3.5 Turbo (older v0613)",
api: "openai-completions",
provider: "openrouter",
baseUrl: "https://openrouter.ai/api/v1",
reasoning: false,
input: ["text"],
cost: {
input: 1,
output: 2,
cacheRead: 0,
cacheWrite: 0,
},
contextWindow: 4095,
maxTokens: 4096,
} satisfies Model<"openai-completions">,
"openai/gpt-4-turbo-preview": {
id: "openai/gpt-4-turbo-preview",
name: "OpenAI: GPT-4 Turbo Preview",
@ -6750,6 +6733,23 @@ export const MODELS = {
contextWindow: 128000,
maxTokens: 4096,
} satisfies Model<"openai-completions">,
"openai/gpt-3.5-turbo-0613": {
id: "openai/gpt-3.5-turbo-0613",
name: "OpenAI: GPT-3.5 Turbo (older v0613)",
api: "openai-completions",
provider: "openrouter",
baseUrl: "https://openrouter.ai/api/v1",
reasoning: false,
input: ["text"],
cost: {
input: 1,
output: 2,
cacheRead: 0,
cacheWrite: 0,
},
contextWindow: 4095,
maxTokens: 4096,
} satisfies Model<"openai-completions">,
"mistralai/mistral-tiny": {
id: "mistralai/mistral-tiny",
name: "Mistral Tiny",

View file

@ -106,29 +106,21 @@ For most users, [Git for Windows](https://git-scm.com/download/win) is sufficien
### API Keys & OAuth
**Option 1: Settings file** (recommended)
**Option 1: Auth file** (recommended)
Add API keys to `~/.pi/agent/settings.json`:
Add API keys to `~/.pi/agent/auth.json`:
```json
{
"apiKeys": {
"anthropic": "sk-ant-...",
"openai": "sk-...",
"google": "...",
"mistral": "...",
"groq": "...",
"cerebras": "...",
"xai": "...",
"openrouter": "...",
"zai": "..."
}
"anthropic": { "type": "api_key", "key": "sk-ant-..." },
"openai": { "type": "api_key", "key": "sk-..." },
"google": { "type": "api_key", "key": "..." }
}
```
**Option 2: Environment variables**
| Provider | Settings Key | Environment Variable |
| Provider | Auth Key | Environment Variable |
|----------|--------------|---------------------|
| Anthropic | `anthropic` | `ANTHROPIC_API_KEY` |
| OpenAI | `openai` | `OPENAI_API_KEY` |
@ -140,7 +132,7 @@ Add API keys to `~/.pi/agent/settings.json`:
| OpenRouter | `openrouter` | `OPENROUTER_API_KEY` |
| ZAI | `zai` | `ZAI_API_KEY` |
Settings file keys take priority over environment variables.
Auth file keys take priority over environment variables.
**OAuth Providers:**
@ -158,6 +150,8 @@ pi
/login # Select provider, authorize in browser
```
**Note:** `/login` replaces any existing API key for that provider with OAuth credentials in `auth.json`.
**GitHub Copilot notes:**
- Press Enter for github.com, or enter your GitHub Enterprise Server domain
- If you get "model not supported" error, enable it in VS Code: Copilot Chat → model selector → select model → "Enable"
@ -167,7 +161,7 @@ pi
- Antigravity uses a sandbox endpoint with access to Gemini 3, Claude (sonnet/opus thinking), and GPT-OSS models
- Both are free with any Google account, subject to rate limits
Tokens stored in `~/.pi/agent/oauth.json`. Use `/logout` to clear.
Credentials stored in `~/.pi/agent/auth.json`. Use `/logout` to clear.
### Quick Start
@ -855,10 +849,15 @@ For adding new tools, see [Custom Tools](#custom-tools) in the Configuration sec
For embedding pi in Node.js/TypeScript applications, use the SDK:
```typescript
import { createAgentSession, SessionManager } from "@mariozechner/pi-coding-agent";
import { createAgentSession, discoverAuthStorage, discoverModels, SessionManager } from "@mariozechner/pi-coding-agent";
const authStorage = discoverAuthStorage();
const modelRegistry = discoverModels(authStorage);
const { session } = await createAgentSession({
sessionManager: SessionManager.inMemory(),
authStorage,
modelRegistry,
});
session.subscribe((event) => {

View file

@ -14,10 +14,16 @@ See [examples/sdk/](../examples/sdk/) for working examples from minimal to full
## Quick Start
```typescript
import { createAgentSession, SessionManager } from "@mariozechner/pi-coding-agent";
import { createAgentSession, discoverAuthStorage, discoverModels, SessionManager } from "@mariozechner/pi-coding-agent";
// Set up credential storage and model registry
const authStorage = discoverAuthStorage();
const modelRegistry = discoverModels(authStorage);
const { session } = await createAgentSession({
sessionManager: SessionManager.inMemory(),
authStorage,
modelRegistry,
});
session.subscribe((event) => {
@ -220,32 +226,42 @@ const { session } = await createAgentSession({
- Global commands (`commands/`)
- Global context file (`AGENTS.md`)
- Settings (`settings.json`)
- Models (`models.json`)
- OAuth tokens (`oauth.json`)
- Custom models (`models.json`)
- Credentials (`auth.json`)
- Sessions (`sessions/`)
### Model
```typescript
import { findModel, discoverAvailableModels } from "@mariozechner/pi-coding-agent";
import { getModel } from "@mariozechner/pi-ai";
import { discoverAuthStorage, discoverModels } from "@mariozechner/pi-coding-agent";
// Find specific model (returns { model, error })
const { model, error } = findModel("anthropic", "claude-sonnet-4-20250514");
if (error) throw new Error(error);
if (!model) throw new Error("Model not found");
const authStorage = discoverAuthStorage();
const modelRegistry = discoverModels(authStorage);
// Or get all models with valid API keys
const available = await discoverAvailableModels();
// Find specific built-in model (doesn't check if API key exists)
const opus = getModel("anthropic", "claude-opus-4-5");
if (!opus) throw new Error("Model not found");
// Find any model by provider/id, including custom models from models.json
// (doesn't check if API key exists)
const customModel = modelRegistry.find("my-provider", "my-model");
// Get only models that have valid API keys configured
const available = await modelRegistry.getAvailable();
const { session } = await createAgentSession({
model: model,
model: opus,
thinkingLevel: "medium", // off, minimal, low, medium, high, xhigh
// Models for cycling (Ctrl+P in interactive mode)
scopedModels: [
{ model: sonnet, thinkingLevel: "high" },
{ model: opus, thinkingLevel: "high" },
{ model: haiku, thinkingLevel: "off" },
],
authStorage,
modelRegistry,
});
```
@ -256,38 +272,42 @@ If no model is provided:
> See [examples/sdk/02-custom-model.ts](../examples/sdk/02-custom-model.ts)
### API Keys
### API Keys and OAuth
API key resolution priority:
1. `settings.json` apiKeys (e.g., `{ "apiKeys": { "anthropic": "sk-..." } }`)
2. Custom providers from `models.json`
3. OAuth credentials from `oauth.json`
4. Environment variables (`ANTHROPIC_API_KEY`, `OPENAI_API_KEY`, etc.)
API key resolution priority (handled by AuthStorage):
1. Runtime overrides (via `setRuntimeApiKey`, not persisted)
2. Stored credentials in `auth.json` (API keys or OAuth tokens)
3. Environment variables (`ANTHROPIC_API_KEY`, `OPENAI_API_KEY`, etc.)
4. Fallback resolver (for custom provider keys from `models.json`)
```typescript
import { defaultGetApiKey, configureOAuthStorage } from "@mariozechner/pi-coding-agent";
import { AuthStorage, ModelRegistry, discoverAuthStorage, discoverModels } from "@mariozechner/pi-coding-agent";
// Default: checks settings.json, models.json, OAuth, environment variables
const { session } = await createAgentSession();
// Default: uses ~/.pi/agent/auth.json and ~/.pi/agent/models.json
const authStorage = discoverAuthStorage();
const modelRegistry = discoverModels(authStorage);
// Custom resolver
const { session } = await createAgentSession({
getApiKey: async (model) => {
// Custom logic (secrets manager, database, etc.)
if (model.provider === "anthropic") {
return process.env.MY_ANTHROPIC_KEY;
}
// Fall back to default (pass settingsManager for settings.json lookup)
return defaultGetApiKey()(model);
},
sessionManager: SessionManager.inMemory(),
authStorage,
modelRegistry,
});
// Use OAuth from ~/.pi/agent with custom agentDir for everything else
configureOAuthStorage(); // Must call before createAgentSession
// Runtime API key override (not persisted to disk)
authStorage.setRuntimeApiKey("anthropic", "sk-my-temp-key");
// Custom auth storage location
const customAuth = new AuthStorage("/my/app/auth.json");
const customRegistry = new ModelRegistry(customAuth, "/my/app/models.json");
const { session } = await createAgentSession({
agentDir: "/custom/config",
// OAuth tokens still come from ~/.pi/agent/oauth.json
sessionManager: SessionManager.inMemory(),
authStorage: customAuth,
modelRegistry: customRegistry,
});
// No custom models.json (built-in models only)
const simpleRegistry = new ModelRegistry(authStorage);
```
> See [examples/sdk/09-api-keys-and-oauth.ts](../examples/sdk/09-api-keys-and-oauth.ts)
@ -630,10 +650,12 @@ Project overrides global. Nested objects merge keys. Setters only modify global
All discovery functions accept optional `cwd` and `agentDir` parameters.
```typescript
import { getModel } from "@mariozechner/pi-ai";
import {
AuthStorage,
ModelRegistry,
discoverAuthStorage,
discoverModels,
discoverAvailableModels,
findModel,
discoverSkills,
discoverHooks,
discoverCustomTools,
@ -643,10 +665,13 @@ import {
buildSystemPrompt,
} from "@mariozechner/pi-coding-agent";
// Models
const allModels = discoverModels();
const available = await discoverAvailableModels();
const { model, error } = findModel("anthropic", "claude-sonnet-4-20250514");
// Auth and Models
const authStorage = discoverAuthStorage(); // ~/.pi/agent/auth.json
const modelRegistry = discoverModels(authStorage); // + ~/.pi/agent/models.json
const allModels = modelRegistry.getAll(); // All models (built-in + custom)
const available = await modelRegistry.getAvailable(); // Only models with API keys
const model = modelRegistry.find("provider", "id"); // Find specific model
const builtIn = getModel("anthropic", "claude-opus-4-5"); // Built-in only
// Skills
const skills = discoverSkills(cwd, agentDir, skillsSettings);
@ -698,12 +723,12 @@ interface CreateAgentSessionResult {
## Complete Example
```typescript
import { getModel } from "@mariozechner/pi-ai";
import { Type } from "@sinclair/typebox";
import {
AuthStorage,
createAgentSession,
configureOAuthStorage,
defaultGetApiKey,
findModel,
ModelRegistry,
SessionManager,
SettingsManager,
readTool,
@ -711,18 +736,17 @@ import {
type HookFactory,
type CustomAgentTool,
} from "@mariozechner/pi-coding-agent";
import { getAgentDir } from "@mariozechner/pi-coding-agent/config";
// Use OAuth from default location
configureOAuthStorage(getAgentDir());
// Set up auth storage (custom location)
const authStorage = new AuthStorage("/custom/agent/auth.json");
// Custom API key with fallback
const getApiKey = async (model: { provider: string }) => {
if (model.provider === "anthropic" && process.env.MY_KEY) {
return process.env.MY_KEY;
}
return defaultGetApiKey()(model as any);
};
// Runtime API key override (not persisted)
if (process.env.MY_KEY) {
authStorage.setRuntimeApiKey("anthropic", process.env.MY_KEY);
}
// Model registry (no custom models.json)
const modelRegistry = new ModelRegistry(authStorage);
// Inline hook
const auditHook: HookFactory = (api) => {
@ -744,8 +768,7 @@ const statusTool: CustomAgentTool = {
}),
};
const { model, error } = findModel("anthropic", "claude-sonnet-4-20250514");
if (error) throw new Error(error);
const model = getModel("anthropic", "claude-opus-4-5");
if (!model) throw new Error("Model not found");
// In-memory settings with overrides
@ -760,7 +783,8 @@ const { session } = await createAgentSession({
model,
thinkingLevel: "off",
getApiKey,
authStorage,
modelRegistry,
systemPrompt: "You are a minimal assistant. Be concise.",
@ -812,12 +836,14 @@ The main entry point exports:
```typescript
// Factory
createAgentSession
configureOAuthStorage
// Auth and Models
AuthStorage
ModelRegistry
discoverAuthStorage
discoverModels
// Discovery
discoverModels
discoverAvailableModels
findModel
discoverSkills
discoverHooks
discoverCustomTools
@ -825,7 +851,6 @@ discoverContextFiles
discoverSlashCommands
// Helpers
defaultGetApiKey
loadSettings
buildSystemPrompt

View file

@ -13,8 +13,8 @@
* pi --hook examples/hooks/custom-compaction.ts
*/
import { complete } from "@mariozechner/pi-ai";
import { findModel, messageTransformer } from "@mariozechner/pi-coding-agent";
import { complete, getModel } from "@mariozechner/pi-ai";
import { messageTransformer } from "@mariozechner/pi-coding-agent";
import type { HookAPI } from "@mariozechner/pi-coding-agent/hooks";
export default function (pi: HookAPI) {
@ -27,10 +27,9 @@ export default function (pi: HookAPI) {
event;
// Use Gemini Flash for summarization (cheaper/faster than most conversation models)
// findModel searches both built-in models and custom models from models.json
const { model, error } = findModel("google", "gemini-2.5-flash");
if (error || !model) {
ctx.ui.notify(`Could not find Gemini Flash model: ${error}, using default compaction`, "warning");
const model = getModel("google", "gemini-2.5-flash");
if (!model) {
ctx.ui.notify(`Could not find Gemini Flash model, using default compaction`, "warning");
return;
}

View file

@ -4,16 +4,27 @@
* Shows how to select a specific model and thinking level.
*/
import { createAgentSession, discoverAvailableModels, findModel } from "../../src/index.js";
import { getModel } from "@mariozechner/pi-ai";
import { createAgentSession, discoverAuthStorage, discoverModels } from "../../src/index.js";
// Option 1: Find a specific model by provider/id
const { model: sonnet } = findModel("anthropic", "claude-sonnet-4-20250514");
if (sonnet) {
console.log(`Found model: ${sonnet.provider}/${sonnet.id}`);
// Set up auth storage and model registry
const authStorage = discoverAuthStorage();
const modelRegistry = discoverModels(authStorage);
// Option 1: Find a specific built-in model by provider/id
const opus = getModel("anthropic", "claude-opus-4-5");
if (opus) {
console.log(`Found model: ${opus.provider}/${opus.id}`);
}
// Option 2: Pick from available models (have valid API keys)
const available = await discoverAvailableModels();
// Option 2: Find model via registry (includes custom models from models.json)
const customModel = modelRegistry.find("my-provider", "my-model");
if (customModel) {
console.log(`Found custom model: ${customModel.provider}/${customModel.id}`);
}
// Option 3: Pick from available models (have valid API keys)
const available = await modelRegistry.getAvailable();
console.log(
"Available models:",
available.map((m) => `${m.provider}/${m.id}`),
@ -23,6 +34,8 @@ if (available.length > 0) {
const { session } = await createAgentSession({
model: available[0],
thinkingLevel: "medium", // off, low, medium, high
authStorage,
modelRegistry,
});
session.subscribe((event) => {

View file

@ -1,40 +1,55 @@
/**
* API Keys and OAuth
*
* Configure API key resolution. Default checks: models.json, OAuth, env vars.
* Configure API key resolution via AuthStorage and ModelRegistry.
*/
import { getAgentDir } from "../../src/config.js";
import { configureOAuthStorage, createAgentSession, defaultGetApiKey, SessionManager } from "../../src/index.js";
import {
AuthStorage,
createAgentSession,
discoverAuthStorage,
discoverModels,
ModelRegistry,
SessionManager,
} from "../../src/index.js";
// Default: uses env vars (ANTHROPIC_API_KEY, etc.), OAuth, and models.json
await createAgentSession({
sessionManager: SessionManager.inMemory(),
});
console.log("Session with default API key resolution");
// Custom resolver
await createAgentSession({
getApiKey: async (model) => {
// Custom logic (secrets manager, database, etc.)
if (model.provider === "anthropic") {
return process.env.MY_ANTHROPIC_KEY;
}
// Fall back to default
return defaultGetApiKey()(model);
},
sessionManager: SessionManager.inMemory(),
});
console.log("Session with custom API key resolver");
// Use OAuth from ~/.pi/agent while customizing everything else
configureOAuthStorage(getAgentDir()); // Must call before createAgentSession
// Default: discoverAuthStorage() uses ~/.pi/agent/auth.json
// discoverModels() loads built-in + custom models from ~/.pi/agent/models.json
const authStorage = discoverAuthStorage();
const modelRegistry = discoverModels(authStorage);
await createAgentSession({
agentDir: "/tmp/custom-config", // Custom config location
// But OAuth tokens still come from ~/.pi/agent/oauth.json
systemPrompt: "You are helpful.",
skills: [],
sessionManager: SessionManager.inMemory(),
authStorage,
modelRegistry,
});
console.log("Session with OAuth from default location, custom config elsewhere");
console.log("Session with default auth storage and model registry");
// Custom auth storage location
const customAuthStorage = new AuthStorage("/tmp/my-app/auth.json");
const customModelRegistry = new ModelRegistry(customAuthStorage, "/tmp/my-app/models.json");
await createAgentSession({
sessionManager: SessionManager.inMemory(),
authStorage: customAuthStorage,
modelRegistry: customModelRegistry,
});
console.log("Session with custom auth storage location");
// Runtime API key override (not persisted to disk)
authStorage.setRuntimeApiKey("anthropic", "sk-my-temp-key");
await createAgentSession({
sessionManager: SessionManager.inMemory(),
authStorage,
modelRegistry,
});
console.log("Session with runtime API key override");
// No models.json - only built-in models
const simpleRegistry = new ModelRegistry(authStorage); // null = no models.json
await createAgentSession({
sessionManager: SessionManager.inMemory(),
authStorage,
modelRegistry: simpleRegistry,
});
console.log("Session with only built-in models");

View file

@ -2,38 +2,36 @@
* Full Control
*
* Replace everything - no discovery, explicit configuration.
* Still uses OAuth from ~/.pi/agent for convenience.
*
* IMPORTANT: When providing `tools` with a custom `cwd`, use the tool factory
* functions (createReadTool, createBashTool, etc.) to ensure tools resolve
* paths relative to your cwd.
*/
import { getModel } from "@mariozechner/pi-ai";
import { Type } from "@sinclair/typebox";
import { getAgentDir } from "../../src/config.js";
import {
AuthStorage,
type CustomAgentTool,
configureOAuthStorage,
createAgentSession,
createBashTool,
createReadTool,
defaultGetApiKey,
findModel,
type HookFactory,
ModelRegistry,
SessionManager,
SettingsManager,
} from "../../src/index.js";
// Use OAuth from default location
configureOAuthStorage(getAgentDir());
// Custom auth storage location
const authStorage = new AuthStorage("/tmp/my-agent/auth.json");
// Custom API key with fallback
const getApiKey = async (model: { provider: string }) => {
if (model.provider === "anthropic" && process.env.MY_ANTHROPIC_KEY) {
return process.env.MY_ANTHROPIC_KEY;
}
return defaultGetApiKey()(model as any);
};
// Runtime API key override (not persisted)
if (process.env.MY_ANTHROPIC_KEY) {
authStorage.setRuntimeApiKey("anthropic", process.env.MY_ANTHROPIC_KEY);
}
// Model registry with no custom models.json
const modelRegistry = new ModelRegistry(authStorage);
// Inline hook
const auditHook: HookFactory = (api) => {
@ -55,7 +53,7 @@ const statusTool: CustomAgentTool = {
}),
};
const { model } = findModel("anthropic", "claude-sonnet-4-20250514");
const model = getModel("anthropic", "claude-opus-4-5");
if (!model) throw new Error("Model not found");
// In-memory settings with overrides
@ -73,7 +71,8 @@ const { session } = await createAgentSession({
model,
thinkingLevel: "off",
getApiKey,
authStorage,
modelRegistry,
systemPrompt: `You are a minimal assistant.
Available: read, bash, status. Be concise.`,

View file

@ -29,50 +29,63 @@ npx tsx examples/sdk/01-minimal.ts
## Quick Reference
```typescript
import { getModel } from "@mariozechner/pi-ai";
import {
AuthStorage,
createAgentSession,
configureOAuthStorage,
discoverAuthStorage,
discoverModels,
discoverSkills,
discoverHooks,
discoverCustomTools,
discoverContextFiles,
discoverSlashCommands,
discoverAvailableModels,
findModel,
defaultGetApiKey,
loadSettings,
buildSystemPrompt,
ModelRegistry,
SessionManager,
codingTools,
readOnlyTools,
readTool, bashTool, editTool, writeTool,
} from "@mariozechner/pi-coding-agent";
// Auth and models setup
const authStorage = discoverAuthStorage();
const modelRegistry = discoverModels(authStorage);
// Minimal
const { session } = await createAgentSession();
const { session } = await createAgentSession({ authStorage, modelRegistry });
// Custom model
const { model } = findModel("anthropic", "claude-sonnet-4-20250514");
const { session } = await createAgentSession({ model, thinkingLevel: "high" });
const model = getModel("anthropic", "claude-opus-4-5");
const { session } = await createAgentSession({ model, thinkingLevel: "high", authStorage, modelRegistry });
// Modify prompt
const { session } = await createAgentSession({
systemPrompt: (defaultPrompt) => defaultPrompt + "\n\nBe concise.",
authStorage,
modelRegistry,
});
// Read-only
const { session } = await createAgentSession({ tools: readOnlyTools });
const { session } = await createAgentSession({ tools: readOnlyTools, authStorage, modelRegistry });
// In-memory
const { session } = await createAgentSession({
sessionManager: SessionManager.inMemory(),
authStorage,
modelRegistry,
});
// Full control
configureOAuthStorage(); // Use OAuth from ~/.pi/agent
const customAuth = new AuthStorage("/my/app/auth.json");
customAuth.setRuntimeApiKey("anthropic", process.env.MY_KEY!);
const customRegistry = new ModelRegistry(customAuth);
const { session } = await createAgentSession({
model,
getApiKey: async (m) => process.env.MY_KEY,
authStorage: customAuth,
modelRegistry: customRegistry,
systemPrompt: "You are helpful.",
tools: [readTool, bashTool],
customTools: [{ tool: myTool }],
@ -81,7 +94,6 @@ const { session } = await createAgentSession({
contextFiles: [],
slashCommands: [],
sessionManager: SessionManager.inMemory(),
settings: { compaction: { enabled: false } },
});
// Run prompts
@ -97,11 +109,12 @@ await session.prompt("Hello");
| Option | Default | Description |
|--------|---------|-------------|
| `authStorage` | `discoverAuthStorage()` | Credential storage |
| `modelRegistry` | `discoverModels(authStorage)` | Model registry |
| `cwd` | `process.cwd()` | Working directory |
| `agentDir` | `~/.pi/agent` | Config directory |
| `model` | From settings/first available | Model to use |
| `thinkingLevel` | From settings/"off" | off, low, medium, high |
| `getApiKey` | Built-in resolver | API key function |
| `systemPrompt` | Discovered | String or `(default) => modified` |
| `tools` | `codingTools` | Built-in tools |
| `customTools` | Discovered | Replaces discovery |
@ -112,7 +125,7 @@ await session.prompt("Hello");
| `contextFiles` | Discovered | AGENTS.md files |
| `slashCommands` | Discovered | File commands |
| `sessionManager` | `SessionManager.create(cwd)` | Persistence |
| `settings` | From agentDir | Overrides |
| `settingsManager` | From agentDir | Settings overrides |
## Events

View file

@ -3,8 +3,7 @@
*/
import type { Api, Model } from "@mariozechner/pi-ai";
import { getAvailableModels } from "../core/models-json.js";
import type { SettingsManager } from "../core/settings-manager.js";
import type { ModelRegistry } from "../core/model-registry.js";
import { fuzzyFilter } from "../utils/fuzzy.js";
/**
@ -25,16 +24,8 @@ function formatTokenCount(count: number): string {
/**
* List available models, optionally filtered by search pattern
*/
export async function listModels(searchPattern?: string, settingsManager?: SettingsManager): Promise<void> {
const { models, error } = await getAvailableModels(
undefined,
settingsManager ? (provider) => settingsManager.getApiKey(provider) : undefined,
);
if (error) {
console.error(error);
process.exit(1);
}
export async function listModels(modelRegistry: ModelRegistry, searchPattern?: string): Promise<void> {
const models = await modelRegistry.getAvailable();
if (models.length === 0) {
console.log("No models available. Set API keys in environment variables.");

View file

@ -112,9 +112,9 @@ export function getModelsPath(): string {
return join(getAgentDir(), "models.json");
}
/** Get path to oauth.json */
export function getOAuthPath(): string {
return join(getAgentDir(), "oauth.json");
/** Get path to auth.json */
export function getAuthPath(): string {
return join(getAgentDir(), "auth.json");
}
/** Get path to settings.json */

View file

@ -23,7 +23,7 @@ import type { LoadedCustomTool, SessionEvent as ToolSessionEvent } from "./custo
import { exportSessionToHtml } from "./export-html.js";
import type { HookRunner, SessionEventResult, TurnEndEvent, TurnStartEvent } from "./hooks/index.js";
import type { BashExecutionMessage } from "./messages.js";
import { getApiKeyForModel, getAvailableModels } from "./models-json.js";
import type { ModelRegistry } from "./model-registry.js";
import type { CompactionEntry, SessionManager } from "./session-manager.js";
import type { SettingsManager, SkillsSettings } from "./settings-manager.js";
import { expandSlashCommand, type FileSlashCommand } from "./slash-commands.js";
@ -56,8 +56,8 @@ export interface AgentSessionConfig {
/** Custom tools for session lifecycle events */
customTools?: LoadedCustomTool[];
skillsSettings?: Required<SkillsSettings>;
/** Resolve API key for a model. Default: getApiKeyForModel */
resolveApiKey?: (model: Model<any>) => Promise<string | undefined>;
/** Model registry for API key resolution and model discovery */
modelRegistry: ModelRegistry;
}
/** Options for AgentSession.prompt() */
@ -153,8 +153,8 @@ export class AgentSession {
private _skillsSettings: Required<SkillsSettings> | undefined;
// API key resolver
private _resolveApiKey: (model: Model<any>) => Promise<string | undefined>;
// Model registry for API key resolution
private _modelRegistry: ModelRegistry;
constructor(config: AgentSessionConfig) {
this.agent = config.agent;
@ -165,7 +165,12 @@ export class AgentSession {
this._hookRunner = config.hookRunner ?? null;
this._customTools = config.customTools ?? [];
this._skillsSettings = config.skillsSettings;
this._resolveApiKey = config.resolveApiKey ?? getApiKeyForModel;
this._modelRegistry = config.modelRegistry;
}
/** Model registry for API key resolution and model discovery */
get modelRegistry(): ModelRegistry {
return this._modelRegistry;
}
// =========================================================================
@ -434,7 +439,7 @@ export class AgentSession {
}
// Validate API key
const apiKey = await this._resolveApiKey(this.model);
const apiKey = await this._modelRegistry.getApiKey(this.model);
if (!apiKey) {
throw new Error(
`No API key found for ${this.model.provider}.\n\n` +
@ -561,7 +566,7 @@ export class AgentSession {
* @throws Error if no API key available for the model
*/
async setModel(model: Model<any>): Promise<void> {
const apiKey = await this._resolveApiKey(model);
const apiKey = await this._modelRegistry.getApiKey(model);
if (!apiKey) {
throw new Error(`No API key for ${model.provider}/${model.id}`);
}
@ -599,7 +604,7 @@ export class AgentSession {
const next = this._scopedModels[nextIndex];
// Validate API key
const apiKey = await this._resolveApiKey(next.model);
const apiKey = await this._modelRegistry.getApiKey(next.model);
if (!apiKey) {
throw new Error(`No API key for ${next.model.provider}/${next.model.id}`);
}
@ -616,10 +621,7 @@ export class AgentSession {
}
private async _cycleAvailableModel(): Promise<ModelCycleResult | null> {
const { models: availableModels, error } = await getAvailableModels(undefined, (provider) =>
this.settingsManager.getApiKey(provider),
);
if (error) throw new Error(`Failed to load models: ${error}`);
const availableModels = await this._modelRegistry.getAvailable();
if (availableModels.length <= 1) return null;
const currentModel = this.model;
@ -631,7 +633,7 @@ export class AgentSession {
const nextIndex = (currentIndex + 1) % availableModels.length;
const nextModel = availableModels[nextIndex];
const apiKey = await this._resolveApiKey(nextModel);
const apiKey = await this._modelRegistry.getApiKey(nextModel);
if (!apiKey) {
throw new Error(`No API key for ${nextModel.provider}/${nextModel.id}`);
}
@ -650,11 +652,7 @@ export class AgentSession {
* Get all available models with valid API keys.
*/
async getAvailableModels(): Promise<Model<any>[]> {
const { models, error } = await getAvailableModels(undefined, (provider) =>
this.settingsManager.getApiKey(provider),
);
if (error) throw new Error(error);
return models;
return this._modelRegistry.getAvailable();
}
// =========================================================================
@ -747,7 +745,7 @@ export class AgentSession {
throw new Error("No model selected");
}
const apiKey = await this._resolveApiKey(this.model);
const apiKey = await this._modelRegistry.getApiKey(this.model);
if (!apiKey) {
throw new Error(`No API key for ${this.model.provider}`);
}
@ -786,7 +784,7 @@ export class AgentSession {
tokensBefore: preparation.tokensBefore,
customInstructions,
model: this.model,
resolveApiKey: this._resolveApiKey,
resolveApiKey: async (m: Model<any>) => (await this._modelRegistry.getApiKey(m)) ?? undefined,
signal: this._compactionAbortController.signal,
})) as SessionEventResult | undefined;
@ -908,7 +906,7 @@ export class AgentSession {
return;
}
const apiKey = await this._resolveApiKey(this.model);
const apiKey = await this._modelRegistry.getApiKey(this.model);
if (!apiKey) {
this._emit({ type: "auto_compaction_end", result: null, aborted: false, willRetry: false });
return;
@ -948,7 +946,7 @@ export class AgentSession {
tokensBefore: preparation.tokensBefore,
customInstructions: undefined,
model: this.model,
resolveApiKey: this._resolveApiKey,
resolveApiKey: async (m: Model<any>) => (await this._modelRegistry.getApiKey(m)) ?? undefined,
signal: this._autoCompactionAbortController.signal,
})) as SessionEventResult | undefined;
@ -1334,9 +1332,7 @@ export class AgentSession {
// Restore model if saved
if (sessionContext.model) {
const availableModels = (
await getAvailableModels(undefined, (provider) => this.settingsManager.getApiKey(provider))
).models;
const availableModels = await this._modelRegistry.getAvailable();
const match = availableModels.find(
(m) => m.provider === sessionContext.model!.provider && m.id === sessionContext.model!.modelId,
);

View file

@ -3,9 +3,18 @@
* Handles loading, saving, and refreshing credentials from auth.json.
*/
import { getApiKeyFromEnv, getOAuthApiKey, type OAuthCredentials, type OAuthProvider } from "@mariozechner/pi-ai";
import { chmodSync, existsSync, mkdirSync, readFileSync, writeFileSync } from "fs";
import { dirname } from "path";
import {
getEnvApiKey,
getOAuthApiKey,
loginAnthropic,
loginAntigravity,
loginGeminiCli,
loginGitHubCopilot,
type OAuthCredentials,
type OAuthProvider,
} from "@mariozechner/pi-ai";
import { chmodSync, existsSync, mkdirSync, readFileSync, renameSync, writeFileSync } from "fs";
import { dirname, join } from "path";
export type ApiKeyCredential = {
type: "api_key";
@ -25,11 +34,29 @@ export type AuthStorageData = Record<string, AuthCredential>;
*/
export class AuthStorage {
private data: AuthStorageData = {};
private runtimeOverrides: Map<string, string> = new Map();
private fallbackResolver?: (provider: string) => string | undefined;
constructor(private authPath: string) {
this.reload();
}
/**
* Set a runtime API key override (not persisted to disk).
* Used for CLI --api-key flag.
*/
setRuntimeApiKey(provider: string, apiKey: string): void {
this.runtimeOverrides.set(provider, apiKey);
}
/**
* Set a fallback resolver for API keys not found in auth.json or env vars.
* Used for custom provider keys from models.json.
*/
setFallbackResolver(resolver: (provider: string) => string | undefined): void {
this.fallbackResolver = resolver;
}
/**
* Reload credentials from disk.
*/
@ -101,14 +128,69 @@ export class AuthStorage {
return { ...this.data };
}
/**
* Login to an OAuth provider.
*/
async login(
provider: OAuthProvider,
callbacks: {
onAuth: (info: { url: string; instructions?: string }) => void;
onPrompt: (prompt: { message: string; placeholder?: string }) => Promise<string>;
onProgress?: (message: string) => void;
},
): Promise<void> {
let credentials: OAuthCredentials;
switch (provider) {
case "anthropic":
credentials = await loginAnthropic(
(url) => callbacks.onAuth({ url }),
() => callbacks.onPrompt({ message: "Paste the authorization code:" }),
);
break;
case "github-copilot":
credentials = await loginGitHubCopilot({
onAuth: (url, instructions) => callbacks.onAuth({ url, instructions }),
onPrompt: callbacks.onPrompt,
onProgress: callbacks.onProgress,
});
break;
case "google-gemini-cli":
credentials = await loginGeminiCli(callbacks.onAuth, callbacks.onProgress);
break;
case "google-antigravity":
credentials = await loginAntigravity(callbacks.onAuth, callbacks.onProgress);
break;
default:
throw new Error(`Unknown OAuth provider: ${provider}`);
}
this.set(provider, { type: "oauth", ...credentials });
}
/**
* Logout from a provider.
*/
logout(provider: string): void {
this.remove(provider);
}
/**
* Get API key for a provider.
* Priority:
* 1. API key from auth.json
* 2. OAuth token from auth.json (auto-refreshed)
* 3. Environment variable (via getApiKeyFromEnv)
* 1. Runtime override (CLI --api-key)
* 2. API key from auth.json
* 3. OAuth token from auth.json (auto-refreshed)
* 4. Environment variable
* 5. Fallback resolver (models.json custom providers)
*/
async getApiKey(provider: string): Promise<string | null> {
// Runtime override takes highest priority
const runtimeKey = this.runtimeOverrides.get(provider);
if (runtimeKey) {
return runtimeKey;
}
const cred = this.data[provider];
if (cred?.type === "api_key") {
@ -116,30 +198,83 @@ export class AuthStorage {
}
if (cred?.type === "oauth") {
// Build OAuthCredentials map (without type discriminator)
// Filter to only oauth credentials for getOAuthApiKey
const oauthCreds: Record<string, OAuthCredentials> = {};
for (const [key, value] of Object.entries(this.data)) {
if (value.type === "oauth") {
const { type: _, ...rest } = value;
oauthCreds[key] = rest;
oauthCreds[key] = value;
}
}
try {
const result = await getOAuthApiKey(provider as OAuthProvider, oauthCreds);
if (result) {
// Save refreshed credentials
this.data[provider] = { type: "oauth", ...result.newCredentials };
this.save();
return result.apiKey;
}
} catch {
// Token refresh failed, remove invalid credentials
this.remove(provider);
}
}
// Fall back to environment variable
return getApiKeyFromEnv(provider) ?? null;
const envKey = getEnvApiKey(provider);
if (envKey) return envKey;
// Fall back to custom resolver (e.g., models.json custom providers)
return this.fallbackResolver?.(provider) ?? null;
}
/**
* Migrate credentials from legacy oauth.json and settings.json apiKeys to auth.json.
* Only runs if auth.json doesn't exist yet. Returns list of migrated providers.
*/
static migrateLegacy(authPath: string, agentDir: string): string[] {
const oauthPath = join(agentDir, "oauth.json");
const settingsPath = join(agentDir, "settings.json");
// Skip if auth.json already exists
if (existsSync(authPath)) return [];
const migrated: AuthStorageData = {};
const providers: string[] = [];
// Migrate oauth.json
if (existsSync(oauthPath)) {
try {
const oauth = JSON.parse(readFileSync(oauthPath, "utf-8"));
for (const [provider, cred] of Object.entries(oauth)) {
migrated[provider] = { type: "oauth", ...(cred as object) } as OAuthCredential;
providers.push(provider);
}
renameSync(oauthPath, `${oauthPath}.migrated`);
} catch {}
}
// Migrate settings.json apiKeys
if (existsSync(settingsPath)) {
try {
const content = readFileSync(settingsPath, "utf-8");
const settings = JSON.parse(content);
if (settings.apiKeys && typeof settings.apiKeys === "object") {
for (const [provider, key] of Object.entries(settings.apiKeys)) {
if (!migrated[provider] && typeof key === "string") {
migrated[provider] = { type: "api_key", key };
providers.push(provider);
}
}
delete settings.apiKeys;
writeFileSync(settingsPath, JSON.stringify(settings, null, 2));
}
} catch {}
}
if (Object.keys(migrated).length > 0) {
mkdirSync(dirname(authPath), { recursive: true });
writeFileSync(authPath, JSON.stringify(migrated, null, 2), { mode: 0o600 });
}
return providers;
}
}

View file

@ -0,0 +1,315 @@
/**
* Model registry - manages built-in and custom models, provides API key resolution.
*/
import {
type Api,
getGitHubCopilotBaseUrl,
getModels,
getProviders,
type KnownProvider,
type Model,
normalizeDomain,
} from "@mariozechner/pi-ai";
import { type Static, Type } from "@sinclair/typebox";
import AjvModule from "ajv";
import { existsSync, readFileSync } from "fs";
import type { AuthStorage } from "./auth-storage.js";
const Ajv = (AjvModule as any).default || AjvModule;
// Schema for OpenAI compatibility settings
const OpenAICompatSchema = Type.Object({
supportsStore: Type.Optional(Type.Boolean()),
supportsDeveloperRole: Type.Optional(Type.Boolean()),
supportsReasoningEffort: Type.Optional(Type.Boolean()),
maxTokensField: Type.Optional(Type.Union([Type.Literal("max_completion_tokens"), Type.Literal("max_tokens")])),
});
// Schema for custom model definition
const ModelDefinitionSchema = Type.Object({
id: Type.String({ minLength: 1 }),
name: Type.String({ minLength: 1 }),
api: Type.Optional(
Type.Union([
Type.Literal("openai-completions"),
Type.Literal("openai-responses"),
Type.Literal("anthropic-messages"),
Type.Literal("google-generative-ai"),
]),
),
reasoning: Type.Boolean(),
input: Type.Array(Type.Union([Type.Literal("text"), Type.Literal("image")])),
cost: Type.Object({
input: Type.Number(),
output: Type.Number(),
cacheRead: Type.Number(),
cacheWrite: Type.Number(),
}),
contextWindow: Type.Number(),
maxTokens: Type.Number(),
headers: Type.Optional(Type.Record(Type.String(), Type.String())),
compat: Type.Optional(OpenAICompatSchema),
});
const ProviderConfigSchema = Type.Object({
baseUrl: Type.String({ minLength: 1 }),
apiKey: Type.String({ minLength: 1 }),
api: Type.Optional(
Type.Union([
Type.Literal("openai-completions"),
Type.Literal("openai-responses"),
Type.Literal("anthropic-messages"),
Type.Literal("google-generative-ai"),
]),
),
headers: Type.Optional(Type.Record(Type.String(), Type.String())),
authHeader: Type.Optional(Type.Boolean()),
models: Type.Array(ModelDefinitionSchema),
});
const ModelsConfigSchema = Type.Object({
providers: Type.Record(Type.String(), ProviderConfigSchema),
});
type ModelsConfig = Static<typeof ModelsConfigSchema>;
/**
* Resolve an API key config value to an actual key.
* Checks environment variable first, then treats as literal.
*/
function resolveApiKeyConfig(keyConfig: string): string | undefined {
const envValue = process.env[keyConfig];
if (envValue) return envValue;
return keyConfig;
}
/**
* Model registry - loads and manages models, resolves API keys via AuthStorage.
*/
export class ModelRegistry {
private models: Model<Api>[] = [];
private customProviderApiKeys: Map<string, string> = new Map();
private loadError: string | null = null;
constructor(
readonly authStorage: AuthStorage,
private modelsJsonPath: string | null = null,
) {
// Set up fallback resolver for custom provider API keys
this.authStorage.setFallbackResolver((provider) => {
const keyConfig = this.customProviderApiKeys.get(provider);
if (keyConfig) {
return resolveApiKeyConfig(keyConfig);
}
return undefined;
});
// Load models
this.loadModels();
}
/**
* Reload models from disk (built-in + custom from models.json).
*/
refresh(): void {
this.customProviderApiKeys.clear();
this.loadError = null;
this.loadModels();
}
/**
* Get any error from loading models.json (null if no error).
*/
getError(): string | null {
return this.loadError;
}
private loadModels(): void {
// Load built-in models
const builtInModels: Model<Api>[] = [];
for (const provider of getProviders()) {
const providerModels = getModels(provider as KnownProvider);
builtInModels.push(...(providerModels as Model<Api>[]));
}
// Load custom models from models.json (if path provided)
let customModels: Model<Api>[] = [];
if (this.modelsJsonPath) {
const result = this.loadCustomModels(this.modelsJsonPath);
if (result.error) {
this.loadError = result.error;
// Keep built-in models even if custom models failed to load
} else {
customModels = result.models;
}
}
const combined = [...builtInModels, ...customModels];
// Update github-copilot base URL based on OAuth credentials
const copilotCred = this.authStorage.get("github-copilot");
if (copilotCred?.type === "oauth") {
const domain = copilotCred.enterpriseUrl
? (normalizeDomain(copilotCred.enterpriseUrl) ?? undefined)
: undefined;
const baseUrl = getGitHubCopilotBaseUrl(copilotCred.access, domain);
this.models = combined.map((m) => (m.provider === "github-copilot" ? { ...m, baseUrl } : m));
} else {
this.models = combined;
}
}
private loadCustomModels(modelsJsonPath: string): { models: Model<Api>[]; error: string | null } {
if (!existsSync(modelsJsonPath)) {
return { models: [], error: null };
}
try {
const content = readFileSync(modelsJsonPath, "utf-8");
const config: ModelsConfig = JSON.parse(content);
// Validate schema
const ajv = new Ajv();
const validate = ajv.compile(ModelsConfigSchema);
if (!validate(config)) {
const errors =
validate.errors?.map((e: any) => ` - ${e.instancePath || "root"}: ${e.message}`).join("\n") ||
"Unknown schema error";
return {
models: [],
error: `Invalid models.json schema:\n${errors}\n\nFile: ${modelsJsonPath}`,
};
}
// Additional validation
this.validateConfig(config);
// Parse models
return { models: this.parseModels(config), error: null };
} catch (error) {
if (error instanceof SyntaxError) {
return {
models: [],
error: `Failed to parse models.json: ${error.message}\n\nFile: ${modelsJsonPath}`,
};
}
return {
models: [],
error: `Failed to load models.json: ${error instanceof Error ? error.message : error}\n\nFile: ${modelsJsonPath}`,
};
}
}
private validateConfig(config: ModelsConfig): void {
for (const [providerName, providerConfig] of Object.entries(config.providers)) {
const hasProviderApi = !!providerConfig.api;
for (const modelDef of providerConfig.models) {
const hasModelApi = !!modelDef.api;
if (!hasProviderApi && !hasModelApi) {
throw new Error(
`Provider ${providerName}, model ${modelDef.id}: no "api" specified. Set at provider or model level.`,
);
}
if (!modelDef.id) throw new Error(`Provider ${providerName}: model missing "id"`);
if (!modelDef.name) throw new Error(`Provider ${providerName}: model missing "name"`);
if (modelDef.contextWindow <= 0)
throw new Error(`Provider ${providerName}, model ${modelDef.id}: invalid contextWindow`);
if (modelDef.maxTokens <= 0)
throw new Error(`Provider ${providerName}, model ${modelDef.id}: invalid maxTokens`);
}
}
}
private parseModels(config: ModelsConfig): Model<Api>[] {
const models: Model<Api>[] = [];
for (const [providerName, providerConfig] of Object.entries(config.providers)) {
// Store API key config for fallback resolver
this.customProviderApiKeys.set(providerName, providerConfig.apiKey);
for (const modelDef of providerConfig.models) {
const api = modelDef.api || providerConfig.api;
if (!api) continue;
// Merge headers: provider headers are base, model headers override
let headers =
providerConfig.headers || modelDef.headers
? { ...providerConfig.headers, ...modelDef.headers }
: undefined;
// If authHeader is true, add Authorization header with resolved API key
if (providerConfig.authHeader) {
const resolvedKey = resolveApiKeyConfig(providerConfig.apiKey);
if (resolvedKey) {
headers = { ...headers, Authorization: `Bearer ${resolvedKey}` };
}
}
models.push({
id: modelDef.id,
name: modelDef.name,
api: api as Api,
provider: providerName,
baseUrl: providerConfig.baseUrl,
reasoning: modelDef.reasoning,
input: modelDef.input as ("text" | "image")[],
cost: modelDef.cost,
contextWindow: modelDef.contextWindow,
maxTokens: modelDef.maxTokens,
headers,
compat: modelDef.compat,
} as Model<Api>);
}
}
return models;
}
/**
* Get all models (built-in + custom).
* If models.json had errors, returns only built-in models.
*/
getAll(): Model<Api>[] {
return this.models;
}
/**
* Get only models that have valid API keys available.
*/
async getAvailable(): Promise<Model<Api>[]> {
const available: Model<Api>[] = [];
for (const model of this.models) {
const apiKey = await this.authStorage.getApiKey(model.provider);
if (apiKey) {
available.push(model);
}
}
return available;
}
/**
* Find a model by provider and ID.
*/
find(provider: string, modelId: string): Model<Api> | null {
return this.models.find((m) => m.provider === provider && m.id === modelId) ?? null;
}
/**
* Get API key for a model.
*/
async getApiKey(model: Model<Api>): Promise<string | null> {
return this.authStorage.getApiKey(model.provider);
}
/**
* Check if a model is using OAuth credentials (subscription).
*/
isUsingOAuth(model: Model<Api>): boolean {
const cred = this.authStorage.get(model.provider);
return cred?.type === "oauth";
}
}

View file

@ -6,8 +6,7 @@ import type { ThinkingLevel } from "@mariozechner/pi-agent-core";
import type { Api, KnownProvider, Model } from "@mariozechner/pi-ai";
import chalk from "chalk";
import { isValidThinkingLevel } from "../cli/args.js";
import { findModel, getApiKeyForModel, getAvailableModels } from "./models-json.js";
import type { SettingsManager } from "./settings-manager.js";
import type { ModelRegistry } from "./model-registry.js";
/** Default model IDs for each known provider */
export const defaultModelPerProvider: Record<KnownProvider, string> = {
@ -167,21 +166,9 @@ export function parseModelPattern(pattern: string, availableModels: Model<Api>[]
* Supports models with colons in their IDs (e.g., OpenRouter's model:exacto).
* The algorithm tries to match the full pattern first, then progressively
* strips colon-suffixes to find a match.
*
* @param patterns - Model patterns to resolve
* @param settingsManager - Optional settings manager for API key fallback from settings.json
*/
export async function resolveModelScope(patterns: string[], settingsManager?: SettingsManager): Promise<ScopedModel[]> {
const { models: availableModels, error } = await getAvailableModels(
undefined,
settingsManager ? (provider) => settingsManager.getApiKey(provider) : undefined,
);
if (error) {
console.warn(chalk.yellow(`Warning: Error loading models: ${error}`));
return [];
}
export async function resolveModelScope(patterns: string[], modelRegistry: ModelRegistry): Promise<ScopedModel[]> {
const availableModels = await modelRegistry.getAvailable();
const scopedModels: ScopedModel[] = [];
for (const pattern of patterns) {
@ -224,20 +211,28 @@ export async function findInitialModel(options: {
cliModel?: string;
scopedModels: ScopedModel[];
isContinuing: boolean;
settingsManager: SettingsManager;
defaultProvider?: string;
defaultModelId?: string;
defaultThinkingLevel?: ThinkingLevel;
modelRegistry: ModelRegistry;
}): Promise<InitialModelResult> {
const { cliProvider, cliModel, scopedModels, isContinuing, settingsManager } = options;
const {
cliProvider,
cliModel,
scopedModels,
isContinuing,
defaultProvider,
defaultModelId,
defaultThinkingLevel,
modelRegistry,
} = options;
let model: Model<Api> | null = null;
let thinkingLevel: ThinkingLevel = "off";
// 1. CLI args take priority
if (cliProvider && cliModel) {
const { model: found, error } = findModel(cliProvider, cliModel);
if (error) {
console.error(chalk.red(error));
process.exit(1);
}
const found = modelRegistry.find(cliProvider, cliModel);
if (!found) {
console.error(chalk.red(`Model ${cliProvider}/${cliModel} not found`));
process.exit(1);
@ -255,34 +250,19 @@ export async function findInitialModel(options: {
}
// 3. Try saved default from settings
const defaultProvider = settingsManager.getDefaultProvider();
const defaultModelId = settingsManager.getDefaultModel();
if (defaultProvider && defaultModelId) {
const { model: found, error } = findModel(defaultProvider, defaultModelId);
if (error) {
console.error(chalk.red(error));
process.exit(1);
}
const found = modelRegistry.find(defaultProvider, defaultModelId);
if (found) {
model = found;
// Also load saved thinking level
const savedThinking = settingsManager.getDefaultThinkingLevel();
if (savedThinking) {
thinkingLevel = savedThinking;
if (defaultThinkingLevel) {
thinkingLevel = defaultThinkingLevel;
}
return { model, thinkingLevel, fallbackMessage: null };
}
}
// 4. Try first available model with valid API key
const { models: availableModels, error } = await getAvailableModels(undefined, (provider) =>
settingsManager.getApiKey(provider),
);
if (error) {
console.error(chalk.red(error));
process.exit(1);
}
const availableModels = await modelRegistry.getAvailable();
if (availableModels.length > 0) {
// Try to find a default model from known providers
@ -310,17 +290,12 @@ export async function restoreModelFromSession(
savedModelId: string,
currentModel: Model<Api> | null,
shouldPrintMessages: boolean,
settingsManager?: SettingsManager,
modelRegistry: ModelRegistry,
): Promise<{ model: Model<Api> | null; fallbackMessage: string | null }> {
const { model: restoredModel, error } = findModel(savedProvider, savedModelId);
if (error) {
console.error(chalk.red(error));
process.exit(1);
}
const restoredModel = modelRegistry.find(savedProvider, savedModelId);
// Check if restored model exists and has a valid API key
const hasApiKey = restoredModel ? !!(await getApiKeyForModel(restoredModel)) : false;
const hasApiKey = restoredModel ? !!(await modelRegistry.getApiKey(restoredModel)) : false;
if (restoredModel && hasApiKey) {
if (shouldPrintMessages) {
@ -348,14 +323,7 @@ export async function restoreModelFromSession(
}
// Try to find any available model
const { models: availableModels, error: availableError } = await getAvailableModels(
undefined,
settingsManager ? (provider) => settingsManager.getApiKey(provider) : undefined,
);
if (availableError) {
console.error(chalk.red(availableError));
process.exit(1);
}
const availableModels = await modelRegistry.getAvailable();
if (availableModels.length > 0) {
// Try to find a default model from known providers

View file

@ -1,467 +0,0 @@
import {
type Api,
getApiKey,
getGitHubCopilotBaseUrl,
getModels,
getProviders,
type KnownProvider,
loadOAuthCredentials,
type Model,
normalizeDomain,
refreshGitHubCopilotToken,
removeOAuthCredentials,
saveOAuthCredentials,
} from "@mariozechner/pi-ai";
import { type Static, Type } from "@sinclair/typebox";
import AjvModule from "ajv";
import { existsSync, readFileSync } from "fs";
import { join } from "path";
import { getAgentDir } from "../config.js";
import { getOAuthToken, type OAuthProvider, refreshToken } from "./oauth/index.js";
// Handle both default and named exports
const Ajv = (AjvModule as any).default || AjvModule;
// Schema for OpenAI compatibility settings
const OpenAICompatSchema = Type.Object({
supportsStore: Type.Optional(Type.Boolean()),
supportsDeveloperRole: Type.Optional(Type.Boolean()),
supportsReasoningEffort: Type.Optional(Type.Boolean()),
maxTokensField: Type.Optional(Type.Union([Type.Literal("max_completion_tokens"), Type.Literal("max_tokens")])),
});
// Schema for custom model definition
const ModelDefinitionSchema = Type.Object({
id: Type.String({ minLength: 1 }),
name: Type.String({ minLength: 1 }),
api: Type.Optional(
Type.Union([
Type.Literal("openai-completions"),
Type.Literal("openai-responses"),
Type.Literal("anthropic-messages"),
Type.Literal("google-generative-ai"),
]),
),
reasoning: Type.Boolean(),
input: Type.Array(Type.Union([Type.Literal("text"), Type.Literal("image")])),
cost: Type.Object({
input: Type.Number(),
output: Type.Number(),
cacheRead: Type.Number(),
cacheWrite: Type.Number(),
}),
contextWindow: Type.Number(),
maxTokens: Type.Number(),
headers: Type.Optional(Type.Record(Type.String(), Type.String())),
compat: Type.Optional(OpenAICompatSchema),
});
const ProviderConfigSchema = Type.Object({
baseUrl: Type.String({ minLength: 1 }),
apiKey: Type.String({ minLength: 1 }),
api: Type.Optional(
Type.Union([
Type.Literal("openai-completions"),
Type.Literal("openai-responses"),
Type.Literal("anthropic-messages"),
Type.Literal("google-generative-ai"),
]),
),
headers: Type.Optional(Type.Record(Type.String(), Type.String())),
authHeader: Type.Optional(Type.Boolean()),
models: Type.Array(ModelDefinitionSchema),
});
const ModelsConfigSchema = Type.Object({
providers: Type.Record(Type.String(), ProviderConfigSchema),
});
type ModelsConfig = Static<typeof ModelsConfigSchema>;
// Custom provider API key mappings (provider name -> apiKey config)
const customProviderApiKeys: Map<string, string> = new Map();
/**
* Resolve an API key config value to an actual key.
* First checks if it's an environment variable, then treats as literal.
*/
export function resolveApiKey(keyConfig: string): string | undefined {
// First check if it's an env var name
const envValue = process.env[keyConfig];
if (envValue) return envValue;
// Otherwise treat as literal API key
return keyConfig;
}
/**
* Load custom models from a models.json file
* Returns { models, error } - either models array or error message
*/
function loadCustomModels(modelsJsonPath: string): { models: Model<Api>[]; error: string | null } {
if (!existsSync(modelsJsonPath)) {
return { models: [], error: null };
}
try {
const content = readFileSync(modelsJsonPath, "utf-8");
const config: ModelsConfig = JSON.parse(content);
// Validate schema
const ajv = new Ajv();
const validate = ajv.compile(ModelsConfigSchema);
if (!validate(config)) {
const errors =
validate.errors?.map((e: any) => ` - ${e.instancePath || "root"}: ${e.message}`).join("\n") ||
"Unknown schema error";
return {
models: [],
error: `Invalid models.json schema:\n${errors}\n\nFile: ${modelsJsonPath}`,
};
}
// Additional validation
try {
validateConfig(config);
} catch (error) {
return {
models: [],
error: `Invalid models.json: ${error instanceof Error ? error.message : error}\n\nFile: ${modelsJsonPath}`,
};
}
// Parse models
return { models: parseModels(config), error: null };
} catch (error) {
if (error instanceof SyntaxError) {
return {
models: [],
error: `Failed to parse models.json: ${error.message}\n\nFile: ${modelsJsonPath}`,
};
}
return {
models: [],
error: `Failed to load models.json: ${error instanceof Error ? error.message : error}\n\nFile: ${modelsJsonPath}`,
};
}
}
/**
* Validate config structure and requirements
*/
function validateConfig(config: ModelsConfig): void {
for (const [providerName, providerConfig] of Object.entries(config.providers)) {
const hasProviderApi = !!providerConfig.api;
for (const modelDef of providerConfig.models) {
const hasModelApi = !!modelDef.api;
if (!hasProviderApi && !hasModelApi) {
throw new Error(
`Provider ${providerName}, model ${modelDef.id}: no "api" specified. ` +
`Set at provider or model level.`,
);
}
// Validate required fields
if (!modelDef.id) throw new Error(`Provider ${providerName}: model missing "id"`);
if (!modelDef.name) throw new Error(`Provider ${providerName}: model missing "name"`);
if (modelDef.contextWindow <= 0)
throw new Error(`Provider ${providerName}, model ${modelDef.id}: invalid contextWindow`);
if (modelDef.maxTokens <= 0)
throw new Error(`Provider ${providerName}, model ${modelDef.id}: invalid maxTokens`);
}
}
}
/**
* Parse config into Model objects
*/
function parseModels(config: ModelsConfig): Model<Api>[] {
const models: Model<Api>[] = [];
// Clear and rebuild custom provider API key mappings
customProviderApiKeys.clear();
for (const [providerName, providerConfig] of Object.entries(config.providers)) {
// Store API key config for this provider
customProviderApiKeys.set(providerName, providerConfig.apiKey);
for (const modelDef of providerConfig.models) {
// Model-level api overrides provider-level api
const api = modelDef.api || providerConfig.api;
if (!api) {
// This should have been caught by validateConfig, but be safe
continue;
}
// Merge headers: provider headers are base, model headers override
let headers =
providerConfig.headers || modelDef.headers ? { ...providerConfig.headers, ...modelDef.headers } : undefined;
// If authHeader is true, add Authorization header with resolved API key
if (providerConfig.authHeader) {
const resolvedKey = resolveApiKey(providerConfig.apiKey);
if (resolvedKey) {
headers = { ...headers, Authorization: `Bearer ${resolvedKey}` };
}
}
models.push({
id: modelDef.id,
name: modelDef.name,
api: api as Api,
provider: providerName,
baseUrl: providerConfig.baseUrl,
reasoning: modelDef.reasoning,
input: modelDef.input as ("text" | "image")[],
cost: modelDef.cost,
contextWindow: modelDef.contextWindow,
maxTokens: modelDef.maxTokens,
headers,
compat: modelDef.compat,
} as Model<Api>);
}
}
return models;
}
/**
* Get all models (built-in + custom), freshly loaded
* Returns { models, error } - either models array or error message
*/
export function loadAndMergeModels(agentDir: string = getAgentDir()): { models: Model<Api>[]; error: string | null } {
const builtInModels: Model<Api>[] = [];
const providers = getProviders();
// Load all built-in models
for (const provider of providers) {
const providerModels = getModels(provider as KnownProvider);
builtInModels.push(...(providerModels as Model<Api>[]));
}
// Load custom models
const { models: customModels, error } = loadCustomModels(join(agentDir, "models.json"));
if (error) {
return { models: [], error };
}
const combined = [...builtInModels, ...customModels];
// Update github-copilot base URL based on OAuth token or enterprise domain
const copilotCreds = loadOAuthCredentials("github-copilot");
if (copilotCreds) {
const domain = copilotCreds.enterpriseUrl ? normalizeDomain(copilotCreds.enterpriseUrl) : undefined;
const baseUrl = getGitHubCopilotBaseUrl(copilotCreds.access, domain ?? undefined);
return {
models: combined.map((m) => (m.provider === "github-copilot" ? { ...m, baseUrl } : m)),
error: null,
};
}
return { models: combined, error: null };
}
/**
* Get API key for a model (checks custom providers first, then built-in)
* Now async to support OAuth token refresh.
* Note: OAuth storage location is configured globally via setOAuthStorage.
*/
export async function getApiKeyForModel(model: Model<Api>): Promise<string | undefined> {
// For custom providers, check their apiKey config
const customKeyConfig = customProviderApiKeys.get(model.provider);
if (customKeyConfig) {
return resolveApiKey(customKeyConfig);
}
// For Anthropic, check OAuth first
if (model.provider === "anthropic") {
// 1. Check OAuth storage (auto-refresh if needed)
const oauthToken = await getOAuthToken("anthropic");
if (oauthToken) {
return oauthToken;
}
// 2. Check ANTHROPIC_OAUTH_TOKEN env var (manual OAuth token)
const oauthEnv = process.env.ANTHROPIC_OAUTH_TOKEN;
if (oauthEnv) {
return oauthEnv;
}
// 3. Fall back to ANTHROPIC_API_KEY env var
}
if (model.provider === "github-copilot") {
// 1. Check OAuth storage (from device flow login)
const oauthToken = await getOAuthToken("github-copilot");
if (oauthToken) {
return oauthToken;
}
// 2. Use GitHub token directly (works with copilot scope on github.com)
const githubToken = process.env.COPILOT_GITHUB_TOKEN || process.env.GH_TOKEN || process.env.GITHUB_TOKEN;
if (!githubToken) {
return undefined;
}
// 3. For enterprise, exchange token for short-lived Copilot token
const enterpriseDomain = process.env.COPILOT_ENTERPRISE_URL
? normalizeDomain(process.env.COPILOT_ENTERPRISE_URL)
: undefined;
if (enterpriseDomain) {
const creds = await refreshGitHubCopilotToken(githubToken, enterpriseDomain);
saveOAuthCredentials("github-copilot", creds);
return creds.access;
}
// 4. For github.com, use token directly
return githubToken;
}
// For Google Gemini CLI and Antigravity, check OAuth and encode projectId with token
if (model.provider === "google-gemini-cli" || model.provider === "google-antigravity") {
const oauthProvider = model.provider as "google-gemini-cli" | "google-antigravity";
const credentials = loadOAuthCredentials(oauthProvider);
if (!credentials) {
return undefined;
}
// Check if token is expired
if (Date.now() >= credentials.expires) {
try {
await refreshToken(oauthProvider);
const refreshedCreds = loadOAuthCredentials(oauthProvider);
if (refreshedCreds?.projectId) {
return JSON.stringify({ token: refreshedCreds.access, projectId: refreshedCreds.projectId });
}
} catch {
removeOAuthCredentials(oauthProvider);
return undefined;
}
}
if (credentials.projectId) {
return JSON.stringify({ token: credentials.access, projectId: credentials.projectId });
}
return undefined;
}
// For built-in providers, use getApiKey from @mariozechner/pi-ai
return getApiKey(model.provider as KnownProvider);
}
/**
* Get only models that have valid API keys available
* Returns { models, error } - either models array or error message
*
* @param agentDir - Agent config directory
* @param fallbackKeyResolver - Optional function to check for API keys not found by getApiKeyForModel
* (e.g., keys from settings.json)
*/
export async function getAvailableModels(
agentDir: string = getAgentDir(),
fallbackKeyResolver?: (provider: string) => string | undefined,
): Promise<{ models: Model<Api>[]; error: string | null }> {
const { models: allModels, error } = loadAndMergeModels(agentDir);
if (error) {
return { models: [], error };
}
const availableModels: Model<Api>[] = [];
for (const model of allModels) {
let apiKey = await getApiKeyForModel(model);
// Check fallback resolver if primary lookup failed
if (!apiKey && fallbackKeyResolver) {
apiKey = fallbackKeyResolver(model.provider);
}
if (apiKey) {
availableModels.push(model);
}
}
return { models: availableModels, error: null };
}
/**
* Find a specific model by provider and ID.
*
* Searches models from:
* 1. Built-in models from @mariozechner/pi-ai
* 2. Custom models defined in ~/.pi/agent/models.json
*
* Returns { model, error } - either the model or an error message.
*/
export function findModel(
provider: string,
modelId: string,
agentDir: string = getAgentDir(),
): { model: Model<Api> | null; error: string | null } {
const { models: allModels, error } = loadAndMergeModels(agentDir);
if (error) {
return { model: null, error };
}
const model = allModels.find((m) => m.provider === provider && m.id === modelId) || null;
return { model, error: null };
}
/**
* Mapping from model provider to OAuth provider ID.
* Only providers that support OAuth are listed here.
*/
const providerToOAuthProvider: Record<string, OAuthProvider> = {
anthropic: "anthropic",
"github-copilot": "github-copilot",
"google-gemini-cli": "google-gemini-cli",
"google-antigravity": "google-antigravity",
};
// Cache for OAuth status per provider (avoids file reads on every render)
const oauthStatusCache: Map<string, boolean> = new Map();
/**
* Invalidate the OAuth status cache.
* Call this after login/logout operations.
*/
export function invalidateOAuthCache(): void {
oauthStatusCache.clear();
}
/**
* Check if a model is using OAuth credentials (subscription).
* This checks if OAuth credentials exist and would be used for the model,
* without actually fetching or refreshing the token.
* Results are cached until invalidateOAuthCache() is called.
*/
export function isModelUsingOAuth(model: Model<Api>): boolean {
const oauthProvider = providerToOAuthProvider[model.provider];
if (!oauthProvider) {
return false;
}
// Check cache first
if (oauthStatusCache.has(oauthProvider)) {
return oauthStatusCache.get(oauthProvider)!;
}
// Check if OAuth credentials exist for this provider
let usingOAuth = false;
const credentials = loadOAuthCredentials(oauthProvider);
if (credentials) {
usingOAuth = true;
}
// Also check for manual OAuth token env var (for Anthropic)
if (!usingOAuth && model.provider === "anthropic" && process.env.ANTHROPIC_OAUTH_TOKEN) {
usingOAuth = true;
}
oauthStatusCache.set(oauthProvider, usingOAuth);
return usingOAuth;
}

View file

@ -30,22 +30,17 @@
*/
import { Agent, ProviderTransport, type ThinkingLevel } from "@mariozechner/pi-agent-core";
import { type Model, setOAuthStorage } from "@mariozechner/pi-ai";
import { chmodSync, existsSync, mkdirSync, readFileSync, writeFileSync } from "fs";
import { dirname, join } from "path";
import type { Model } from "@mariozechner/pi-ai";
import { join } from "path";
import { getAgentDir } from "../config.js";
import { AgentSession } from "./agent-session.js";
import { AuthStorage } from "./auth-storage.js";
import { discoverAndLoadCustomTools, type LoadedCustomTool } from "./custom-tools/index.js";
import type { CustomAgentTool } from "./custom-tools/types.js";
import { discoverAndLoadHooks, HookRunner, type LoadedHook, wrapToolsWithHooks } from "./hooks/index.js";
import type { HookFactory } from "./hooks/types.js";
import { messageTransformer } from "./messages.js";
import {
findModel as findModelInternal,
getApiKeyForModel,
getAvailableModels,
loadAndMergeModels,
} from "./models-json.js";
import { ModelRegistry } from "./model-registry.js";
import { SessionManager } from "./session-manager.js";
import { type Settings, SettingsManager, type SkillsSettings } from "./settings-manager.js";
import { loadSkills as loadSkillsInternal, type Skill } from "./skills.js";
@ -86,6 +81,11 @@ export interface CreateAgentSessionOptions {
/** Global config directory. Default: ~/.pi/agent */
agentDir?: string;
/** Auth storage for credentials. Default: discoverAuthStorage(agentDir) */
authStorage?: AuthStorage;
/** Model registry. Default: discoverModels(authStorage, agentDir) */
modelRegistry?: ModelRegistry;
/** Model to use. Default: from settings, else first available */
model?: Model<any>;
/** Thinking level. Default: from settings, else 'off' (clamped to model capabilities) */
@ -93,9 +93,6 @@ export interface CreateAgentSessionOptions {
/** Models available for cycling (Ctrl+P in interactive mode) */
scopedModels?: Array<{ model: Model<any>; thinkingLevel: ThinkingLevel }>;
/** API key resolver. Default: defaultGetApiKey() */
getApiKey?: (model: Model<any>) => Promise<string | undefined>;
/** System prompt. String replaces default, function receives default and returns final. */
systemPrompt?: string | ((defaultPrompt: string) => string);
@ -177,73 +174,20 @@ function getDefaultAgentDir(): string {
return getAgentDir();
}
/**
* Configure OAuth storage to use the specified agent directory.
* Must be called before using OAuth-based authentication.
*/
export function configureOAuthStorage(agentDir: string = getDefaultAgentDir()): void {
const oauthPath = join(agentDir, "oauth.json");
setOAuthStorage({
load: () => {
if (!existsSync(oauthPath)) {
return {};
}
try {
return JSON.parse(readFileSync(oauthPath, "utf-8"));
} catch {
return {};
}
},
save: (storage) => {
const dir = dirname(oauthPath);
if (!existsSync(dir)) {
mkdirSync(dir, { recursive: true, mode: 0o700 });
}
writeFileSync(oauthPath, JSON.stringify(storage, null, 2), "utf-8");
chmodSync(oauthPath, 0o600);
},
});
}
// Discovery Functions
/**
* Get all models (built-in + custom from models.json).
* Create an AuthStorage instance for the given agent directory.
*/
export function discoverModels(agentDir: string = getDefaultAgentDir()): Model<any>[] {
const { models, error } = loadAndMergeModels(agentDir);
if (error) {
throw new Error(error);
}
return models;
export function discoverAuthStorage(agentDir: string = getDefaultAgentDir()): AuthStorage {
return new AuthStorage(join(agentDir, "auth.json"));
}
/**
* Get models that have valid API keys available.
* Create a ModelRegistry for the given agent directory.
*/
export async function discoverAvailableModels(agentDir: string = getDefaultAgentDir()): Promise<Model<any>[]> {
const { models, error } = await getAvailableModels(agentDir);
if (error) {
throw new Error(error);
}
return models;
}
/**
* Find a model by provider and ID.
* @returns The model, or null if not found
*/
export function findModel(
provider: string,
modelId: string,
agentDir: string = getDefaultAgentDir(),
): Model<any> | null {
const { model, error } = findModelInternal(provider, modelId, agentDir);
if (error) {
throw new Error(error);
}
return model;
export function discoverModels(authStorage: AuthStorage, agentDir: string = getDefaultAgentDir()): ModelRegistry {
return new ModelRegistry(authStorage, join(agentDir, "models.json"));
}
/**
@ -326,30 +270,6 @@ export function discoverSlashCommands(cwd?: string, agentDir?: string): FileSlas
// API Key Helpers
/**
* Create the default API key resolver.
* Priority: OAuth > custom providers (models.json) > environment variables > settings.json apiKeys.
*
* OAuth takes priority so users logged in with a plan (e.g. unlimited tokens) aren't
* accidentally billed via a PAYG API key sitting in settings.json.
*/
export function defaultGetApiKey(
settingsManager?: SettingsManager,
): (model: Model<any>) => Promise<string | undefined> {
return async (model: Model<any>) => {
// Check OAuth, custom providers, env vars first
const resolvedKey = await getApiKeyForModel(model);
if (resolvedKey) {
return resolvedKey;
}
// Fall back to settings.json apiKeys
if (settingsManager) {
return settingsManager.getApiKey(model.provider);
}
return undefined;
};
}
// System Prompt
export interface BuildSystemPromptOptions {
@ -457,8 +377,9 @@ function createLoadedHooksFromDefinitions(definitions: Array<{ path?: string; fa
* const { session } = await createAgentSession();
*
* // With explicit model
* import { getModel } from '@mariozechner/pi-ai';
* const { session } = await createAgentSession({
* model: findModel('anthropic', 'claude-sonnet-4-20250514'),
* model: getModel('anthropic', 'claude-opus-4-5'),
* thinkingLevel: 'high',
* });
*
@ -483,22 +404,16 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
const cwd = options.cwd ?? process.cwd();
const agentDir = options.agentDir ?? getDefaultAgentDir();
// Configure OAuth storage for this agentDir
configureOAuthStorage(agentDir);
time("configureOAuthStorage");
// Use provided or create AuthStorage and ModelRegistry
const authStorage = options.authStorage ?? discoverAuthStorage(agentDir);
const modelRegistry = options.modelRegistry ?? discoverModels(authStorage, agentDir);
time("discoverModels");
const settingsManager = options.settingsManager ?? SettingsManager.create(cwd, agentDir);
time("settingsManager");
const sessionManager = options.sessionManager ?? SessionManager.create(cwd, agentDir);
time("sessionManager");
// Helper to check API key availability (settings first, then OAuth/env vars)
const hasApiKey = async (m: Model<any>): Promise<boolean> => {
const settingsKey = settingsManager.getApiKey(m.provider);
if (settingsKey) return true;
return !!(await getApiKeyForModel(m));
};
// Check if session has existing data to restore
const existingSession = sessionManager.buildSessionContext();
time("loadSession");
@ -509,8 +424,8 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
// If session has data, try to restore model from it
if (!model && hasExistingSession && existingSession.model) {
const restoredModel = findModel(existingSession.model.provider, existingSession.model.modelId);
if (restoredModel && (await hasApiKey(restoredModel))) {
const restoredModel = modelRegistry.find(existingSession.model.provider, existingSession.model.modelId);
if (restoredModel && (await modelRegistry.getApiKey(restoredModel))) {
model = restoredModel;
}
if (!model) {
@ -523,8 +438,8 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
const defaultProvider = settingsManager.getDefaultProvider();
const defaultModelId = settingsManager.getDefaultModel();
if (defaultProvider && defaultModelId) {
const settingsModel = findModel(defaultProvider, defaultModelId);
if (settingsModel && (await hasApiKey(settingsModel))) {
const settingsModel = modelRegistry.find(defaultProvider, defaultModelId);
if (settingsModel && (await modelRegistry.getApiKey(settingsModel))) {
model = settingsModel;
}
}
@ -532,14 +447,13 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
// Fall back to first available model with a valid API key
if (!model) {
const allModels = discoverModels(agentDir);
for (const m of allModels) {
if (await hasApiKey(m)) {
for (const m of modelRegistry.getAll()) {
if (await modelRegistry.getApiKey(m)) {
model = m;
break;
}
}
time("discoverAvailableModels");
time("findAvailableModel");
if (model) {
if (modelFallbackMessage) {
modelFallbackMessage += `. Using ${model.provider}/${model.id}`;
@ -567,8 +481,6 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
thinkingLevel = "off";
}
const getApiKey = options.getApiKey ?? defaultGetApiKey(settingsManager);
const skills = options.skills ?? discoverSkills(cwd, agentDir, settingsManager.getSkillsSettings());
time("discoverSkills");
@ -661,7 +573,7 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
if (!currentModel) {
throw new Error("No model selected");
}
const key = await getApiKey(currentModel);
const key = await modelRegistry.getApiKey(currentModel);
if (!key) {
throw new Error(`No API key found for provider "${currentModel.provider}"`);
}
@ -685,7 +597,7 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
hookRunner,
customTools: customToolsResult.tools,
skillsSettings: settingsManager.getSkillsSettings(),
resolveApiKey: getApiKey,
modelRegistry,
});
time("createAgentSession");

View file

@ -47,7 +47,6 @@ export interface Settings {
customTools?: string[]; // Array of custom tool file paths
skills?: SkillsSettings;
terminal?: TerminalSettings;
apiKeys?: Record<string, string>; // provider -> API key (e.g., { "anthropic": "sk-..." })
}
/** Deep merge settings: project/overrides take precedence, nested objects merge recursively */
@ -366,27 +365,4 @@ export class SettingsManager {
this.globalSettings.terminal.showImages = show;
this.save();
}
getApiKey(provider: string): string | undefined {
return this.settings.apiKeys?.[provider];
}
setApiKey(provider: string, key: string): void {
if (!this.globalSettings.apiKeys) {
this.globalSettings.apiKeys = {};
}
this.globalSettings.apiKeys[provider] = key;
this.save();
}
removeApiKey(provider: string): void {
if (this.globalSettings.apiKeys) {
delete this.globalSettings.apiKeys[provider];
this.save();
}
}
getApiKeys(): Record<string, string> {
return this.settings.apiKeys ?? {};
}
}

View file

@ -9,6 +9,8 @@ export {
type PromptOptions,
type SessionStats,
} from "./core/agent-session.js";
// Auth and model registry
export { type ApiKeyCredential, type AuthCredential, AuthStorage, type OAuthCredential } from "./core/auth-storage.js";
// Compaction
export {
type CutPointResult,
@ -72,24 +74,13 @@ export {
isWriteToolResult,
} from "./core/hooks/index.js";
export { messageTransformer } from "./core/messages.js";
// Model configuration and OAuth
export { findModel, getApiKeyForModel, getAvailableModels } from "./core/models-json.js";
export {
getOAuthProviders,
login,
logout,
type OAuthAuthInfo,
type OAuthPrompt,
type OAuthProvider,
} from "./core/oauth/index.js";
export { ModelRegistry } from "./core/model-registry.js";
// SDK for programmatic usage
export {
type BuildSystemPromptOptions,
buildSystemPrompt,
type CreateAgentSessionOptions,
type CreateAgentSessionResult,
// Configuration
configureOAuthStorage,
// Factory
createAgentSession,
createBashTool,
@ -102,18 +93,15 @@ export {
createReadOnlyTools,
createReadTool,
createWriteTool,
// Helpers
defaultGetApiKey,
discoverAvailableModels,
// Discovery
discoverAuthStorage,
discoverContextFiles,
discoverCustomTools,
discoverHooks,
// Discovery
discoverModels,
discoverSkills,
discoverSlashCommands,
type FileSlashCommand,
findModel as findModelByProviderAndId,
loadSettings,
// Pre-built tools (use process.cwd())
readOnlyTools,

View file

@ -8,19 +8,20 @@
import type { Attachment } from "@mariozechner/pi-agent-core";
import { supportsXhigh } from "@mariozechner/pi-ai";
import chalk from "chalk";
import { join } from "path";
import { type Args, parseArgs, printHelp } from "./cli/args.js";
import { processFileArguments } from "./cli/file-processor.js";
import { listModels } from "./cli/list-models.js";
import { selectSession } from "./cli/session-picker.js";
import { getModelsPath, VERSION } from "./config.js";
import { getAgentDir, getModelsPath, VERSION } from "./config.js";
import type { AgentSession } from "./core/agent-session.js";
import { AuthStorage } from "./core/auth-storage.js";
import type { LoadedCustomTool } from "./core/custom-tools/index.js";
import { exportFromFile } from "./core/export-html.js";
import type { HookUIContext } from "./core/index.js";
import type { ModelRegistry } from "./core/model-registry.js";
import { resolveModelScope, type ScopedModel } from "./core/model-resolver.js";
import { findModel } from "./core/models-json.js";
import { type CreateAgentSessionOptions, configureOAuthStorage, createAgentSession } from "./core/sdk.js";
import { type CreateAgentSessionOptions, createAgentSession, discoverAuthStorage, discoverModels } from "./core/sdk.js";
import { SessionManager } from "./core/session-manager.js";
import { SettingsManager } from "./core/settings-manager.js";
import { resolvePromptInput } from "./core/system-prompt.js";
@ -33,7 +34,7 @@ import { ensureTool } from "./utils/tools-manager.js";
async function checkForNewVersion(currentVersion: string): Promise<string | null> {
try {
const response = await fetch("https://registry.npmjs.org/@mariozechner/pi-coding-agent/latest");
const response = await fetch("https://registry.npmjs.org/@mariozechner/pi -coding-agent/latest");
if (!response.ok) return null;
const data = (await response.json()) as { version?: string };
@ -54,6 +55,8 @@ async function runInteractiveMode(
version: string,
changelogMarkdown: string | null,
modelFallbackMessage: string | undefined,
modelsJsonError: string | null,
migratedProviders: string[],
versionCheckPromise: Promise<string | null>,
initialMessages: string[],
customTools: LoadedCustomTool[],
@ -74,6 +77,14 @@ async function runInteractiveMode(
mode.renderInitialMessages(session.state);
if (migratedProviders.length > 0) {
mode.showWarning(`Migrated credentials to auth.json: ${migratedProviders.join(", ")}`);
}
if (modelsJsonError) {
mode.showError(`models.json error: ${modelsJsonError}`);
}
if (modelFallbackMessage) {
mode.showWarning(modelFallbackMessage);
}
@ -175,6 +186,7 @@ function buildSessionOptions(
parsed: Args,
scopedModels: ScopedModel[],
sessionManager: SessionManager | null,
modelRegistry: ModelRegistry,
): CreateAgentSessionOptions {
const options: CreateAgentSessionOptions = {};
@ -187,11 +199,7 @@ function buildSessionOptions(
// Model from CLI
if (parsed.provider && parsed.model) {
const { model, error } = findModel(parsed.provider, parsed.model);
if (error) {
console.error(chalk.red(error));
process.exit(1);
}
const model = modelRegistry.find(parsed.provider, parsed.model);
if (!model) {
console.error(chalk.red(`Model ${parsed.provider}/${parsed.model} not found`));
process.exit(1);
@ -213,10 +221,8 @@ function buildSessionOptions(
options.scopedModels = scopedModels;
}
// API key from CLI
if (parsed.apiKey) {
options.getApiKey = async () => parsed.apiKey!;
}
// API key from CLI - set in authStorage
// (handled by caller before createAgentSession)
// System prompt
if (resolvedSystemPrompt && resolvedAppendPrompt) {
@ -252,8 +258,15 @@ function buildSessionOptions(
export async function main(args: string[]) {
time("start");
configureOAuthStorage();
time("configureOAuthStorage");
// Migrate legacy oauth.json and settings.json apiKeys to auth.json
const agentDir = getAgentDir();
const migratedProviders = AuthStorage.migrateLegacy(join(agentDir, "auth.json"), agentDir);
// Create AuthStorage and ModelRegistry upfront
const authStorage = discoverAuthStorage();
const modelRegistry = discoverModels(authStorage);
time("discoverModels");
const parsed = parseArgs(args);
time("parseArgs");
@ -270,8 +283,7 @@ export async function main(args: string[]) {
if (parsed.listModels !== undefined) {
const searchPattern = typeof parsed.listModels === "string" ? parsed.listModels : undefined;
const settingsManager = SettingsManager.create(process.cwd());
await listModels(searchPattern, settingsManager);
await listModels(modelRegistry, searchPattern);
return;
}
@ -306,7 +318,7 @@ export async function main(args: string[]) {
let scopedModels: ScopedModel[] = [];
if (parsed.models && parsed.models.length > 0) {
scopedModels = await resolveModelScope(parsed.models, settingsManager);
scopedModels = await resolveModelScope(parsed.models, modelRegistry);
time("resolveModelScope");
}
@ -331,7 +343,19 @@ export async function main(args: string[]) {
sessionManager = SessionManager.open(selectedPath);
}
const sessionOptions = buildSessionOptions(parsed, scopedModels, sessionManager);
const sessionOptions = buildSessionOptions(parsed, scopedModels, sessionManager, modelRegistry);
sessionOptions.authStorage = authStorage;
sessionOptions.modelRegistry = modelRegistry;
// Handle CLI --api-key as runtime override (not persisted)
if (parsed.apiKey) {
if (!sessionOptions.model) {
console.error(chalk.red("--api-key requires a model to be specified via --provider/--model or -m/--models"));
process.exit(1);
}
authStorage.setRuntimeApiKey(sessionOptions.model.provider, parsed.apiKey);
}
time("buildSessionOptions");
const { session, customToolsResult, modelFallbackMessage } = await createAgentSession(sessionOptions);
time("createAgentSession");
@ -382,6 +406,8 @@ export async function main(args: string[]) {
VERSION,
changelogMarkdown,
modelFallbackMessage,
modelRegistry.getError(),
migratedProviders,
versionCheckPromise,
parsed.messages,
customToolsResult.tools,

View file

@ -3,7 +3,7 @@ import type { AssistantMessage } from "@mariozechner/pi-ai";
import { type Component, visibleWidth } from "@mariozechner/pi-tui";
import { existsSync, type FSWatcher, readFileSync, watch } from "fs";
import { dirname, join } from "path";
import { isModelUsingOAuth } from "../../../core/models-json.js";
import type { ModelRegistry } from "../../../core/model-registry.js";
import { theme } from "../theme/theme.js";
/**
@ -31,13 +31,15 @@ function findGitHeadPath(): string | null {
*/
export class FooterComponent implements Component {
private state: AgentState;
private modelRegistry: ModelRegistry;
private cachedBranch: string | null | undefined = undefined; // undefined = not checked yet, null = not in git repo, string = branch name
private gitWatcher: FSWatcher | null = null;
private onBranchChange: (() => void) | null = null;
private autoCompactEnabled: boolean = true;
constructor(state: AgentState) {
constructor(state: AgentState, modelRegistry: ModelRegistry) {
this.state = state;
this.modelRegistry = modelRegistry;
}
setAutoCompactEnabled(enabled: boolean): void {
@ -207,7 +209,7 @@ export class FooterComponent implements Component {
if (totalCacheWrite) statsParts.push(`W${formatTokens(totalCacheWrite)}`);
// Show cost with "(sub)" indicator if using OAuth subscription
const usingSubscription = this.state.model ? isModelUsingOAuth(this.state.model) : false;
const usingSubscription = this.state.model ? this.modelRegistry.isUsingOAuth(this.state.model) : false;
if (totalCost || usingSubscription) {
const costStr = `$${totalCost.toFixed(3)}${usingSubscription ? " (sub)" : ""}`;
statsParts.push(costStr);

View file

@ -10,7 +10,7 @@ import {
Text,
type TUI,
} from "@mariozechner/pi-tui";
import { getAvailableModels } from "../../../core/models-json.js";
import type { ModelRegistry } from "../../../core/model-registry.js";
import type { SettingsManager } from "../../../core/settings-manager.js";
import { fuzzyFilter } from "../../../utils/fuzzy.js";
import { theme } from "../theme/theme.js";
@ -38,6 +38,7 @@ export class ModelSelectorComponent extends Container {
private selectedIndex: number = 0;
private currentModel: Model<any> | null;
private settingsManager: SettingsManager;
private modelRegistry: ModelRegistry;
private onSelectCallback: (model: Model<any>) => void;
private onCancelCallback: () => void;
private errorMessage: string | null = null;
@ -48,6 +49,7 @@ export class ModelSelectorComponent extends Container {
tui: TUI,
currentModel: Model<any> | null,
settingsManager: SettingsManager,
modelRegistry: ModelRegistry,
scopedModels: ReadonlyArray<ScopedModelItem>,
onSelect: (model: Model<any>) => void,
onCancel: () => void,
@ -57,6 +59,7 @@ export class ModelSelectorComponent extends Container {
this.tui = tui;
this.currentModel = currentModel;
this.settingsManager = settingsManager;
this.modelRegistry = modelRegistry;
this.scopedModels = scopedModels;
this.onSelectCallback = onSelect;
this.onCancelCallback = onCancel;
@ -113,26 +116,29 @@ export class ModelSelectorComponent extends Container {
model: scoped.model,
}));
} else {
// Load available models fresh (includes custom models from models.json)
// Pass settings manager's key resolver as fallback for settings.json apiKeys
const { models: availableModels, error } = await getAvailableModels(undefined, (provider) =>
this.settingsManager.getApiKey(provider),
);
// Refresh to pick up any changes to models.json
this.modelRegistry.refresh();
// If there's an error loading models.json, we'll show it via the "no models" path
// The error will be displayed to the user
if (error) {
this.allModels = [];
this.filteredModels = [];
this.errorMessage = error;
return;
// Check for models.json errors
const loadError = this.modelRegistry.getError();
if (loadError) {
this.errorMessage = loadError;
}
models = availableModels.map((model) => ({
provider: model.provider,
id: model.id,
model,
}));
// Load available models (built-in models still work even if models.json failed)
try {
const availableModels = await this.modelRegistry.getAvailable();
models = availableModels.map((model: Model<any>) => ({
provider: model.provider,
id: model.id,
model,
}));
} catch (error) {
this.allModels = [];
this.filteredModels = [];
this.errorMessage = error instanceof Error ? error.message : String(error);
return;
}
}
// Sort: current model first, then by provider

View file

@ -1,6 +1,6 @@
import { loadOAuthCredentials } from "@mariozechner/pi-ai";
import { getOAuthProviders, type OAuthProviderInfo } from "@mariozechner/pi-ai";
import { Container, isArrowDown, isArrowUp, isEnter, isEscape, Spacer, TruncatedText } from "@mariozechner/pi-tui";
import { getOAuthProviders, type OAuthProviderInfo } from "../../../core/oauth/index.js";
import type { AuthStorage } from "../../../core/auth-storage.js";
import { theme } from "../theme/theme.js";
import { DynamicBorder } from "./dynamic-border.js";
@ -12,13 +12,20 @@ export class OAuthSelectorComponent extends Container {
private allProviders: OAuthProviderInfo[] = [];
private selectedIndex: number = 0;
private mode: "login" | "logout";
private authStorage: AuthStorage;
private onSelectCallback: (providerId: string) => void;
private onCancelCallback: () => void;
constructor(mode: "login" | "logout", onSelect: (providerId: string) => void, onCancel: () => void) {
constructor(
mode: "login" | "logout",
authStorage: AuthStorage,
onSelect: (providerId: string) => void,
onCancel: () => void,
) {
super();
this.mode = mode;
this.authStorage = authStorage;
this.onSelectCallback = onSelect;
this.onCancelCallback = onCancel;
@ -49,7 +56,6 @@ export class OAuthSelectorComponent extends Container {
private loadProviders(): void {
this.allProviders = getOAuthProviders();
this.allProviders = this.allProviders.filter((p) => p.available);
}
private updateList(): void {
@ -63,8 +69,8 @@ export class OAuthSelectorComponent extends Container {
const isAvailable = provider.available;
// Check if user is logged in for this provider
const credentials = loadOAuthCredentials(provider.id);
const isLoggedIn = credentials !== null;
const credentials = this.authStorage.get(provider.id);
const isLoggedIn = credentials?.type === "oauth";
const statusIndicator = isLoggedIn ? theme.fg("success", " ✓ logged in") : "";
let line = "";

View file

@ -7,7 +7,7 @@ import * as fs from "node:fs";
import * as os from "node:os";
import * as path from "node:path";
import type { AgentState, AppMessage, Attachment } from "@mariozechner/pi-agent-core";
import type { AssistantMessage, Message } from "@mariozechner/pi-ai";
import type { AssistantMessage, Message, OAuthProvider } from "@mariozechner/pi-ai";
import type { SlashCommand } from "@mariozechner/pi-tui";
import {
CombinedAutocompleteProvider,
@ -25,13 +25,11 @@ import {
visibleWidth,
} from "@mariozechner/pi-tui";
import { exec, spawnSync } from "child_process";
import { APP_NAME, getDebugLogPath, getOAuthPath } from "../../config.js";
import { APP_NAME, getAuthPath, getDebugLogPath } from "../../config.js";
import type { AgentSession, AgentSessionEvent } from "../../core/agent-session.js";
import type { LoadedCustomTool, SessionEvent as ToolSessionEvent } from "../../core/custom-tools/index.js";
import type { HookUIContext } from "../../core/hooks/index.js";
import { isBashExecutionMessage } from "../../core/messages.js";
import { invalidateOAuthCache } from "../../core/models-json.js";
import { listOAuthProviders, login, logout, type OAuthProvider } from "../../core/oauth/index.js";
import {
getLatestCompactionEntry,
SessionManager,
@ -154,7 +152,7 @@ export class InteractiveMode {
this.editor = new CustomEditor(getEditorTheme());
this.editorContainer = new Container();
this.editorContainer.addChild(this.editor);
this.footer = new FooterComponent(session.state);
this.footer = new FooterComponent(session.state, session.modelRegistry);
this.footer.setAutoCompactEnabled(session.autoCompactionEnabled);
// Define slash commands for autocomplete
@ -1484,6 +1482,7 @@ export class InteractiveMode {
this.ui,
this.session.model,
this.settingsManager,
this.session.modelRegistry,
this.session.scopedModels,
async (model) => {
try {
@ -1588,7 +1587,10 @@ export class InteractiveMode {
private async showOAuthSelector(mode: "login" | "logout"): Promise<void> {
if (mode === "logout") {
const loggedInProviders = listOAuthProviders();
const providers = this.session.modelRegistry.authStorage.list();
const loggedInProviders = providers.filter(
(p) => this.session.modelRegistry.authStorage.get(p)?.type === "oauth",
);
if (loggedInProviders.length === 0) {
this.showStatus("No OAuth providers logged in. Use /login first.");
return;
@ -1598,6 +1600,7 @@ export class InteractiveMode {
this.showSelector((done) => {
const selector = new OAuthSelectorComponent(
mode,
this.session.modelRegistry.authStorage,
async (providerId: string) => {
done();
@ -1605,9 +1608,8 @@ export class InteractiveMode {
this.showStatus(`Logging in to ${providerId}...`);
try {
await login(
providerId as OAuthProvider,
(info) => {
await this.session.modelRegistry.authStorage.login(providerId as OAuthProvider, {
onAuth: (info: { url: string; instructions?: string }) => {
this.chatContainer.addChild(new Spacer(1));
this.chatContainer.addChild(new Text(theme.fg("accent", "Opening browser to:"), 1, 0));
this.chatContainer.addChild(new Text(theme.fg("accent", info.url), 1, 0));
@ -1625,7 +1627,7 @@ export class InteractiveMode {
: "xdg-open";
exec(`${openCmd} "${info.url}"`);
},
async (prompt) => {
onPrompt: async (prompt: { message: string; placeholder?: string }) => {
this.chatContainer.addChild(new Spacer(1));
this.chatContainer.addChild(new Text(theme.fg("warning", prompt.message), 1, 0));
if (prompt.placeholder) {
@ -1648,32 +1650,35 @@ export class InteractiveMode {
this.ui.requestRender();
});
},
(message) => {
onProgress: (message: string) => {
this.chatContainer.addChild(new Text(theme.fg("dim", message), 1, 0));
this.ui.requestRender();
},
);
invalidateOAuthCache();
});
// Refresh models to pick up new baseUrl (e.g., github-copilot)
this.session.modelRegistry.refresh();
this.chatContainer.addChild(new Spacer(1));
this.chatContainer.addChild(
new Text(theme.fg("success", `✓ Successfully logged in to ${providerId}`), 1, 0),
);
this.chatContainer.addChild(new Text(theme.fg("dim", `Tokens saved to ${getOAuthPath()}`), 1, 0));
this.chatContainer.addChild(
new Text(theme.fg("dim", `Credentials saved to ${getAuthPath()}`), 1, 0),
);
this.ui.requestRender();
} catch (error: unknown) {
this.showError(`Login failed: ${error instanceof Error ? error.message : String(error)}`);
}
} else {
try {
await logout(providerId as OAuthProvider);
invalidateOAuthCache();
this.session.modelRegistry.authStorage.logout(providerId);
// Refresh models to reset baseUrl
this.session.modelRegistry.refresh();
this.chatContainer.addChild(new Spacer(1));
this.chatContainer.addChild(
new Text(theme.fg("success", `✓ Successfully logged out of ${providerId}`), 1, 0),
);
this.chatContainer.addChild(
new Text(theme.fg("dim", `Credentials removed from ${getOAuthPath()}`), 1, 0),
new Text(theme.fg("dim", `Credentials removed from ${getAuthPath()}`), 1, 0),
);
this.ui.requestRender();
} catch (error: unknown) {

View file

@ -14,6 +14,8 @@ import { Agent, ProviderTransport } from "@mariozechner/pi-agent-core";
import { getModel } from "@mariozechner/pi-ai";
import { afterEach, beforeEach, describe, expect, it } from "vitest";
import { AgentSession } from "../src/core/agent-session.js";
import { AuthStorage } from "../src/core/auth-storage.js";
import { ModelRegistry } from "../src/core/model-registry.js";
import { SessionManager } from "../src/core/session-manager.js";
import { SettingsManager } from "../src/core/settings-manager.js";
import { codingTools } from "../src/core/tools/index.js";
@ -58,11 +60,14 @@ describe.skipIf(!API_KEY)("AgentSession branching", () => {
sessionManager = noSession ? SessionManager.inMemory() : SessionManager.create(tempDir);
const settingsManager = SettingsManager.create(tempDir, tempDir);
const authStorage = new AuthStorage(join(tempDir, "auth.json"));
const modelRegistry = new ModelRegistry(authStorage, tempDir);
session = new AgentSession({
agent,
sessionManager,
settingsManager,
modelRegistry,
});
// Must subscribe to enable session persistence

View file

@ -14,6 +14,8 @@ import { Agent, ProviderTransport } from "@mariozechner/pi-agent-core";
import { getModel } from "@mariozechner/pi-ai";
import { afterEach, beforeEach, describe, expect, it } from "vitest";
import { AgentSession, type AgentSessionEvent } from "../src/core/agent-session.js";
import { AuthStorage } from "../src/core/auth-storage.js";
import { ModelRegistry } from "../src/core/model-registry.js";
import { SessionManager } from "../src/core/session-manager.js";
import { SettingsManager } from "../src/core/settings-manager.js";
import { codingTools } from "../src/core/tools/index.js";
@ -62,11 +64,14 @@ describe.skipIf(!API_KEY)("AgentSession compaction e2e", () => {
sessionManager = SessionManager.create(tempDir);
const settingsManager = SettingsManager.create(tempDir, tempDir);
const authStorage = new AuthStorage(join(tempDir, "auth.json"));
const modelRegistry = new ModelRegistry(authStorage);
session = new AgentSession({
agent,
sessionManager,
settingsManager,
modelRegistry,
});
// Subscribe to track events
@ -178,11 +183,14 @@ describe.skipIf(!API_KEY)("AgentSession compaction e2e", () => {
const noSessionManager = SessionManager.inMemory();
const settingsManager = SettingsManager.create(tempDir, tempDir);
const authStorage = new AuthStorage(join(tempDir, "auth.json"));
const modelRegistry = new ModelRegistry(authStorage);
const noSessionSession = new AgentSession({
agent,
sessionManager: noSessionManager,
settingsManager,
modelRegistry,
});
try {

View file

@ -9,7 +9,9 @@ import { Agent, ProviderTransport } from "@mariozechner/pi-agent-core";
import { getModel } from "@mariozechner/pi-ai";
import { afterEach, beforeEach, describe, expect, it } from "vitest";
import { AgentSession } from "../src/core/agent-session.js";
import { AuthStorage } from "../src/core/auth-storage.js";
import { HookRunner, type LoadedHook, type SessionEvent } from "../src/core/hooks/index.js";
import { ModelRegistry } from "../src/core/model-registry.js";
import { SessionManager } from "../src/core/session-manager.js";
import { SettingsManager } from "../src/core/settings-manager.js";
import { codingTools } from "../src/core/tools/index.js";
@ -83,6 +85,8 @@ describe.skipIf(!API_KEY)("Compaction hooks", () => {
const sessionManager = SessionManager.create(tempDir);
const settingsManager = SettingsManager.create(tempDir, tempDir);
const authStorage = new AuthStorage(join(tempDir, "auth.json"));
const modelRegistry = new ModelRegistry(authStorage);
hookRunner = new HookRunner(hooks, tempDir);
hookRunner.setUIContext(
@ -101,6 +105,7 @@ describe.skipIf(!API_KEY)("Compaction hooks", () => {
sessionManager,
settingsManager,
hookRunner,
modelRegistry,
});
return session;

View file

@ -2,13 +2,16 @@ import { Agent, type AgentEvent, type Attachment, ProviderTransport } from "@mar
import { getModel } from "@mariozechner/pi-ai";
import {
AgentSession,
AuthStorage,
formatSkillsForPrompt,
loadSkillsFromDir,
ModelRegistry,
messageTransformer,
type Skill,
} from "@mariozechner/pi-coding-agent";
import { existsSync, readFileSync, statSync } from "fs";
import { mkdir, writeFile } from "fs/promises";
import { homedir } from "os";
import { join } from "path";
import { MomSessionManager, MomSettingsManager } from "./context.js";
import * as log from "./log.js";
@ -435,11 +438,17 @@ function createRunner(sandboxConfig: SandboxConfig, channelId: string, channelDi
log.logInfo(`[${channelId}] Loaded ${loadedSession.messages.length} messages from context.jsonl`);
}
// Create AuthStorage and ModelRegistry for AgentSession
// Auth stored outside workspace so agent can't access it
const authStorage = new AuthStorage(join(homedir(), ".pi", "mom", "auth.json"));
const modelRegistry = new ModelRegistry(authStorage);
// Create AgentSession wrapper
const session = new AgentSession({
agent,
sessionManager: sessionManager as any,
settingsManager: settingsManager as any,
modelRegistry,
});
// Mutable per-run state - event handler references this

View file

@ -28,5 +28,5 @@
}
},
"include": ["packages/*/src/**/*", "packages/*/test/**/*", "packages/coding-agent/examples/**/*"],
"exclude": ["packages/web-ui/**/*"]
"exclude": ["packages/web-ui/**/*", "**/dist/**"]
}