mirror of
https://github.com/getcompanion-ai/co-mono.git
synced 2026-04-15 07:04:45 +00:00
Refactor OAuth/API key handling: AuthStorage and ModelRegistry
- Add AuthStorage class for credential storage (auth.json) - Add ModelRegistry class for model management with API key resolution - Add discoverAuthStorage() and discoverModels() discovery functions - Add migration from legacy oauth.json and settings.json apiKeys to auth.json - Remove configureOAuthStorage, defaultGetApiKey, findModel, discoverAvailableModels - Remove apiKeys from Settings type and SettingsManager methods - Rename getOAuthPath to getAuthPath - Update SDK, examples, docs, tests, and mom package Fixes #296
This commit is contained in:
parent
9f97f0c8da
commit
54018b6cc0
29 changed files with 953 additions and 2017 deletions
1013
compact.jsonl
1013
compact.jsonl
File diff suppressed because one or more lines are too long
|
|
@ -6359,6 +6359,23 @@ export const MODELS = {
|
|||
contextWindow: 128000,
|
||||
maxTokens: 16384,
|
||||
} satisfies Model<"openai-completions">,
|
||||
"meta-llama/llama-3.1-70b-instruct": {
|
||||
id: "meta-llama/llama-3.1-70b-instruct",
|
||||
name: "Meta: Llama 3.1 70B Instruct",
|
||||
api: "openai-completions",
|
||||
provider: "openrouter",
|
||||
baseUrl: "https://openrouter.ai/api/v1",
|
||||
reasoning: false,
|
||||
input: ["text"],
|
||||
cost: {
|
||||
input: 0.39999999999999997,
|
||||
output: 0.39999999999999997,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
},
|
||||
contextWindow: 131072,
|
||||
maxTokens: 4096,
|
||||
} satisfies Model<"openai-completions">,
|
||||
"meta-llama/llama-3.1-8b-instruct": {
|
||||
id: "meta-llama/llama-3.1-8b-instruct",
|
||||
name: "Meta: Llama 3.1 8B Instruct",
|
||||
|
|
@ -6393,23 +6410,6 @@ export const MODELS = {
|
|||
contextWindow: 10000,
|
||||
maxTokens: 4096,
|
||||
} satisfies Model<"openai-completions">,
|
||||
"meta-llama/llama-3.1-70b-instruct": {
|
||||
id: "meta-llama/llama-3.1-70b-instruct",
|
||||
name: "Meta: Llama 3.1 70B Instruct",
|
||||
api: "openai-completions",
|
||||
provider: "openrouter",
|
||||
baseUrl: "https://openrouter.ai/api/v1",
|
||||
reasoning: false,
|
||||
input: ["text"],
|
||||
cost: {
|
||||
input: 0.39999999999999997,
|
||||
output: 0.39999999999999997,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
},
|
||||
contextWindow: 131072,
|
||||
maxTokens: 4096,
|
||||
} satisfies Model<"openai-completions">,
|
||||
"mistralai/mistral-nemo": {
|
||||
id: "mistralai/mistral-nemo",
|
||||
name: "Mistral: Mistral Nemo",
|
||||
|
|
@ -6546,23 +6546,6 @@ export const MODELS = {
|
|||
contextWindow: 128000,
|
||||
maxTokens: 4096,
|
||||
} satisfies Model<"openai-completions">,
|
||||
"openai/gpt-4o-2024-05-13": {
|
||||
id: "openai/gpt-4o-2024-05-13",
|
||||
name: "OpenAI: GPT-4o (2024-05-13)",
|
||||
api: "openai-completions",
|
||||
provider: "openrouter",
|
||||
baseUrl: "https://openrouter.ai/api/v1",
|
||||
reasoning: false,
|
||||
input: ["text", "image"],
|
||||
cost: {
|
||||
input: 5,
|
||||
output: 15,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
},
|
||||
contextWindow: 128000,
|
||||
maxTokens: 4096,
|
||||
} satisfies Model<"openai-completions">,
|
||||
"openai/gpt-4o": {
|
||||
id: "openai/gpt-4o",
|
||||
name: "OpenAI: GPT-4o",
|
||||
|
|
@ -6597,6 +6580,23 @@ export const MODELS = {
|
|||
contextWindow: 128000,
|
||||
maxTokens: 64000,
|
||||
} satisfies Model<"openai-completions">,
|
||||
"openai/gpt-4o-2024-05-13": {
|
||||
id: "openai/gpt-4o-2024-05-13",
|
||||
name: "OpenAI: GPT-4o (2024-05-13)",
|
||||
api: "openai-completions",
|
||||
provider: "openrouter",
|
||||
baseUrl: "https://openrouter.ai/api/v1",
|
||||
reasoning: false,
|
||||
input: ["text", "image"],
|
||||
cost: {
|
||||
input: 5,
|
||||
output: 15,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
},
|
||||
contextWindow: 128000,
|
||||
maxTokens: 4096,
|
||||
} satisfies Model<"openai-completions">,
|
||||
"meta-llama/llama-3-70b-instruct": {
|
||||
id: "meta-llama/llama-3-70b-instruct",
|
||||
name: "Meta: Llama 3 70B Instruct",
|
||||
|
|
@ -6716,23 +6716,6 @@ export const MODELS = {
|
|||
contextWindow: 128000,
|
||||
maxTokens: 4096,
|
||||
} satisfies Model<"openai-completions">,
|
||||
"openai/gpt-3.5-turbo-0613": {
|
||||
id: "openai/gpt-3.5-turbo-0613",
|
||||
name: "OpenAI: GPT-3.5 Turbo (older v0613)",
|
||||
api: "openai-completions",
|
||||
provider: "openrouter",
|
||||
baseUrl: "https://openrouter.ai/api/v1",
|
||||
reasoning: false,
|
||||
input: ["text"],
|
||||
cost: {
|
||||
input: 1,
|
||||
output: 2,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
},
|
||||
contextWindow: 4095,
|
||||
maxTokens: 4096,
|
||||
} satisfies Model<"openai-completions">,
|
||||
"openai/gpt-4-turbo-preview": {
|
||||
id: "openai/gpt-4-turbo-preview",
|
||||
name: "OpenAI: GPT-4 Turbo Preview",
|
||||
|
|
@ -6750,6 +6733,23 @@ export const MODELS = {
|
|||
contextWindow: 128000,
|
||||
maxTokens: 4096,
|
||||
} satisfies Model<"openai-completions">,
|
||||
"openai/gpt-3.5-turbo-0613": {
|
||||
id: "openai/gpt-3.5-turbo-0613",
|
||||
name: "OpenAI: GPT-3.5 Turbo (older v0613)",
|
||||
api: "openai-completions",
|
||||
provider: "openrouter",
|
||||
baseUrl: "https://openrouter.ai/api/v1",
|
||||
reasoning: false,
|
||||
input: ["text"],
|
||||
cost: {
|
||||
input: 1,
|
||||
output: 2,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
},
|
||||
contextWindow: 4095,
|
||||
maxTokens: 4096,
|
||||
} satisfies Model<"openai-completions">,
|
||||
"mistralai/mistral-tiny": {
|
||||
id: "mistralai/mistral-tiny",
|
||||
name: "Mistral Tiny",
|
||||
|
|
|
|||
|
|
@ -106,29 +106,21 @@ For most users, [Git for Windows](https://git-scm.com/download/win) is sufficien
|
|||
|
||||
### API Keys & OAuth
|
||||
|
||||
**Option 1: Settings file** (recommended)
|
||||
**Option 1: Auth file** (recommended)
|
||||
|
||||
Add API keys to `~/.pi/agent/settings.json`:
|
||||
Add API keys to `~/.pi/agent/auth.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"apiKeys": {
|
||||
"anthropic": "sk-ant-...",
|
||||
"openai": "sk-...",
|
||||
"google": "...",
|
||||
"mistral": "...",
|
||||
"groq": "...",
|
||||
"cerebras": "...",
|
||||
"xai": "...",
|
||||
"openrouter": "...",
|
||||
"zai": "..."
|
||||
}
|
||||
"anthropic": { "type": "api_key", "key": "sk-ant-..." },
|
||||
"openai": { "type": "api_key", "key": "sk-..." },
|
||||
"google": { "type": "api_key", "key": "..." }
|
||||
}
|
||||
```
|
||||
|
||||
**Option 2: Environment variables**
|
||||
|
||||
| Provider | Settings Key | Environment Variable |
|
||||
| Provider | Auth Key | Environment Variable |
|
||||
|----------|--------------|---------------------|
|
||||
| Anthropic | `anthropic` | `ANTHROPIC_API_KEY` |
|
||||
| OpenAI | `openai` | `OPENAI_API_KEY` |
|
||||
|
|
@ -140,7 +132,7 @@ Add API keys to `~/.pi/agent/settings.json`:
|
|||
| OpenRouter | `openrouter` | `OPENROUTER_API_KEY` |
|
||||
| ZAI | `zai` | `ZAI_API_KEY` |
|
||||
|
||||
Settings file keys take priority over environment variables.
|
||||
Auth file keys take priority over environment variables.
|
||||
|
||||
**OAuth Providers:**
|
||||
|
||||
|
|
@ -158,6 +150,8 @@ pi
|
|||
/login # Select provider, authorize in browser
|
||||
```
|
||||
|
||||
**Note:** `/login` replaces any existing API key for that provider with OAuth credentials in `auth.json`.
|
||||
|
||||
**GitHub Copilot notes:**
|
||||
- Press Enter for github.com, or enter your GitHub Enterprise Server domain
|
||||
- If you get "model not supported" error, enable it in VS Code: Copilot Chat → model selector → select model → "Enable"
|
||||
|
|
@ -167,7 +161,7 @@ pi
|
|||
- Antigravity uses a sandbox endpoint with access to Gemini 3, Claude (sonnet/opus thinking), and GPT-OSS models
|
||||
- Both are free with any Google account, subject to rate limits
|
||||
|
||||
Tokens stored in `~/.pi/agent/oauth.json`. Use `/logout` to clear.
|
||||
Credentials stored in `~/.pi/agent/auth.json`. Use `/logout` to clear.
|
||||
|
||||
### Quick Start
|
||||
|
||||
|
|
@ -855,10 +849,15 @@ For adding new tools, see [Custom Tools](#custom-tools) in the Configuration sec
|
|||
For embedding pi in Node.js/TypeScript applications, use the SDK:
|
||||
|
||||
```typescript
|
||||
import { createAgentSession, SessionManager } from "@mariozechner/pi-coding-agent";
|
||||
import { createAgentSession, discoverAuthStorage, discoverModels, SessionManager } from "@mariozechner/pi-coding-agent";
|
||||
|
||||
const authStorage = discoverAuthStorage();
|
||||
const modelRegistry = discoverModels(authStorage);
|
||||
|
||||
const { session } = await createAgentSession({
|
||||
sessionManager: SessionManager.inMemory(),
|
||||
authStorage,
|
||||
modelRegistry,
|
||||
});
|
||||
|
||||
session.subscribe((event) => {
|
||||
|
|
|
|||
|
|
@ -14,10 +14,16 @@ See [examples/sdk/](../examples/sdk/) for working examples from minimal to full
|
|||
## Quick Start
|
||||
|
||||
```typescript
|
||||
import { createAgentSession, SessionManager } from "@mariozechner/pi-coding-agent";
|
||||
import { createAgentSession, discoverAuthStorage, discoverModels, SessionManager } from "@mariozechner/pi-coding-agent";
|
||||
|
||||
// Set up credential storage and model registry
|
||||
const authStorage = discoverAuthStorage();
|
||||
const modelRegistry = discoverModels(authStorage);
|
||||
|
||||
const { session } = await createAgentSession({
|
||||
sessionManager: SessionManager.inMemory(),
|
||||
authStorage,
|
||||
modelRegistry,
|
||||
});
|
||||
|
||||
session.subscribe((event) => {
|
||||
|
|
@ -220,32 +226,42 @@ const { session } = await createAgentSession({
|
|||
- Global commands (`commands/`)
|
||||
- Global context file (`AGENTS.md`)
|
||||
- Settings (`settings.json`)
|
||||
- Models (`models.json`)
|
||||
- OAuth tokens (`oauth.json`)
|
||||
- Custom models (`models.json`)
|
||||
- Credentials (`auth.json`)
|
||||
- Sessions (`sessions/`)
|
||||
|
||||
### Model
|
||||
|
||||
```typescript
|
||||
import { findModel, discoverAvailableModels } from "@mariozechner/pi-coding-agent";
|
||||
import { getModel } from "@mariozechner/pi-ai";
|
||||
import { discoverAuthStorage, discoverModels } from "@mariozechner/pi-coding-agent";
|
||||
|
||||
// Find specific model (returns { model, error })
|
||||
const { model, error } = findModel("anthropic", "claude-sonnet-4-20250514");
|
||||
if (error) throw new Error(error);
|
||||
if (!model) throw new Error("Model not found");
|
||||
const authStorage = discoverAuthStorage();
|
||||
const modelRegistry = discoverModels(authStorage);
|
||||
|
||||
// Or get all models with valid API keys
|
||||
const available = await discoverAvailableModels();
|
||||
// Find specific built-in model (doesn't check if API key exists)
|
||||
const opus = getModel("anthropic", "claude-opus-4-5");
|
||||
if (!opus) throw new Error("Model not found");
|
||||
|
||||
// Find any model by provider/id, including custom models from models.json
|
||||
// (doesn't check if API key exists)
|
||||
const customModel = modelRegistry.find("my-provider", "my-model");
|
||||
|
||||
// Get only models that have valid API keys configured
|
||||
const available = await modelRegistry.getAvailable();
|
||||
|
||||
const { session } = await createAgentSession({
|
||||
model: model,
|
||||
model: opus,
|
||||
thinkingLevel: "medium", // off, minimal, low, medium, high, xhigh
|
||||
|
||||
// Models for cycling (Ctrl+P in interactive mode)
|
||||
scopedModels: [
|
||||
{ model: sonnet, thinkingLevel: "high" },
|
||||
{ model: opus, thinkingLevel: "high" },
|
||||
{ model: haiku, thinkingLevel: "off" },
|
||||
],
|
||||
|
||||
authStorage,
|
||||
modelRegistry,
|
||||
});
|
||||
```
|
||||
|
||||
|
|
@ -256,38 +272,42 @@ If no model is provided:
|
|||
|
||||
> See [examples/sdk/02-custom-model.ts](../examples/sdk/02-custom-model.ts)
|
||||
|
||||
### API Keys
|
||||
### API Keys and OAuth
|
||||
|
||||
API key resolution priority:
|
||||
1. `settings.json` apiKeys (e.g., `{ "apiKeys": { "anthropic": "sk-..." } }`)
|
||||
2. Custom providers from `models.json`
|
||||
3. OAuth credentials from `oauth.json`
|
||||
4. Environment variables (`ANTHROPIC_API_KEY`, `OPENAI_API_KEY`, etc.)
|
||||
API key resolution priority (handled by AuthStorage):
|
||||
1. Runtime overrides (via `setRuntimeApiKey`, not persisted)
|
||||
2. Stored credentials in `auth.json` (API keys or OAuth tokens)
|
||||
3. Environment variables (`ANTHROPIC_API_KEY`, `OPENAI_API_KEY`, etc.)
|
||||
4. Fallback resolver (for custom provider keys from `models.json`)
|
||||
|
||||
```typescript
|
||||
import { defaultGetApiKey, configureOAuthStorage } from "@mariozechner/pi-coding-agent";
|
||||
import { AuthStorage, ModelRegistry, discoverAuthStorage, discoverModels } from "@mariozechner/pi-coding-agent";
|
||||
|
||||
// Default: checks settings.json, models.json, OAuth, environment variables
|
||||
const { session } = await createAgentSession();
|
||||
// Default: uses ~/.pi/agent/auth.json and ~/.pi/agent/models.json
|
||||
const authStorage = discoverAuthStorage();
|
||||
const modelRegistry = discoverModels(authStorage);
|
||||
|
||||
// Custom resolver
|
||||
const { session } = await createAgentSession({
|
||||
getApiKey: async (model) => {
|
||||
// Custom logic (secrets manager, database, etc.)
|
||||
if (model.provider === "anthropic") {
|
||||
return process.env.MY_ANTHROPIC_KEY;
|
||||
}
|
||||
// Fall back to default (pass settingsManager for settings.json lookup)
|
||||
return defaultGetApiKey()(model);
|
||||
},
|
||||
sessionManager: SessionManager.inMemory(),
|
||||
authStorage,
|
||||
modelRegistry,
|
||||
});
|
||||
|
||||
// Use OAuth from ~/.pi/agent with custom agentDir for everything else
|
||||
configureOAuthStorage(); // Must call before createAgentSession
|
||||
// Runtime API key override (not persisted to disk)
|
||||
authStorage.setRuntimeApiKey("anthropic", "sk-my-temp-key");
|
||||
|
||||
// Custom auth storage location
|
||||
const customAuth = new AuthStorage("/my/app/auth.json");
|
||||
const customRegistry = new ModelRegistry(customAuth, "/my/app/models.json");
|
||||
|
||||
const { session } = await createAgentSession({
|
||||
agentDir: "/custom/config",
|
||||
// OAuth tokens still come from ~/.pi/agent/oauth.json
|
||||
sessionManager: SessionManager.inMemory(),
|
||||
authStorage: customAuth,
|
||||
modelRegistry: customRegistry,
|
||||
});
|
||||
|
||||
// No custom models.json (built-in models only)
|
||||
const simpleRegistry = new ModelRegistry(authStorage);
|
||||
```
|
||||
|
||||
> See [examples/sdk/09-api-keys-and-oauth.ts](../examples/sdk/09-api-keys-and-oauth.ts)
|
||||
|
|
@ -630,10 +650,12 @@ Project overrides global. Nested objects merge keys. Setters only modify global
|
|||
All discovery functions accept optional `cwd` and `agentDir` parameters.
|
||||
|
||||
```typescript
|
||||
import { getModel } from "@mariozechner/pi-ai";
|
||||
import {
|
||||
AuthStorage,
|
||||
ModelRegistry,
|
||||
discoverAuthStorage,
|
||||
discoverModels,
|
||||
discoverAvailableModels,
|
||||
findModel,
|
||||
discoverSkills,
|
||||
discoverHooks,
|
||||
discoverCustomTools,
|
||||
|
|
@ -643,10 +665,13 @@ import {
|
|||
buildSystemPrompt,
|
||||
} from "@mariozechner/pi-coding-agent";
|
||||
|
||||
// Models
|
||||
const allModels = discoverModels();
|
||||
const available = await discoverAvailableModels();
|
||||
const { model, error } = findModel("anthropic", "claude-sonnet-4-20250514");
|
||||
// Auth and Models
|
||||
const authStorage = discoverAuthStorage(); // ~/.pi/agent/auth.json
|
||||
const modelRegistry = discoverModels(authStorage); // + ~/.pi/agent/models.json
|
||||
const allModels = modelRegistry.getAll(); // All models (built-in + custom)
|
||||
const available = await modelRegistry.getAvailable(); // Only models with API keys
|
||||
const model = modelRegistry.find("provider", "id"); // Find specific model
|
||||
const builtIn = getModel("anthropic", "claude-opus-4-5"); // Built-in only
|
||||
|
||||
// Skills
|
||||
const skills = discoverSkills(cwd, agentDir, skillsSettings);
|
||||
|
|
@ -698,12 +723,12 @@ interface CreateAgentSessionResult {
|
|||
## Complete Example
|
||||
|
||||
```typescript
|
||||
import { getModel } from "@mariozechner/pi-ai";
|
||||
import { Type } from "@sinclair/typebox";
|
||||
import {
|
||||
AuthStorage,
|
||||
createAgentSession,
|
||||
configureOAuthStorage,
|
||||
defaultGetApiKey,
|
||||
findModel,
|
||||
ModelRegistry,
|
||||
SessionManager,
|
||||
SettingsManager,
|
||||
readTool,
|
||||
|
|
@ -711,18 +736,17 @@ import {
|
|||
type HookFactory,
|
||||
type CustomAgentTool,
|
||||
} from "@mariozechner/pi-coding-agent";
|
||||
import { getAgentDir } from "@mariozechner/pi-coding-agent/config";
|
||||
|
||||
// Use OAuth from default location
|
||||
configureOAuthStorage(getAgentDir());
|
||||
// Set up auth storage (custom location)
|
||||
const authStorage = new AuthStorage("/custom/agent/auth.json");
|
||||
|
||||
// Custom API key with fallback
|
||||
const getApiKey = async (model: { provider: string }) => {
|
||||
if (model.provider === "anthropic" && process.env.MY_KEY) {
|
||||
return process.env.MY_KEY;
|
||||
}
|
||||
return defaultGetApiKey()(model as any);
|
||||
};
|
||||
// Runtime API key override (not persisted)
|
||||
if (process.env.MY_KEY) {
|
||||
authStorage.setRuntimeApiKey("anthropic", process.env.MY_KEY);
|
||||
}
|
||||
|
||||
// Model registry (no custom models.json)
|
||||
const modelRegistry = new ModelRegistry(authStorage);
|
||||
|
||||
// Inline hook
|
||||
const auditHook: HookFactory = (api) => {
|
||||
|
|
@ -744,8 +768,7 @@ const statusTool: CustomAgentTool = {
|
|||
}),
|
||||
};
|
||||
|
||||
const { model, error } = findModel("anthropic", "claude-sonnet-4-20250514");
|
||||
if (error) throw new Error(error);
|
||||
const model = getModel("anthropic", "claude-opus-4-5");
|
||||
if (!model) throw new Error("Model not found");
|
||||
|
||||
// In-memory settings with overrides
|
||||
|
|
@ -760,7 +783,8 @@ const { session } = await createAgentSession({
|
|||
|
||||
model,
|
||||
thinkingLevel: "off",
|
||||
getApiKey,
|
||||
authStorage,
|
||||
modelRegistry,
|
||||
|
||||
systemPrompt: "You are a minimal assistant. Be concise.",
|
||||
|
||||
|
|
@ -812,12 +836,14 @@ The main entry point exports:
|
|||
```typescript
|
||||
// Factory
|
||||
createAgentSession
|
||||
configureOAuthStorage
|
||||
|
||||
// Auth and Models
|
||||
AuthStorage
|
||||
ModelRegistry
|
||||
discoverAuthStorage
|
||||
discoverModels
|
||||
|
||||
// Discovery
|
||||
discoverModels
|
||||
discoverAvailableModels
|
||||
findModel
|
||||
discoverSkills
|
||||
discoverHooks
|
||||
discoverCustomTools
|
||||
|
|
@ -825,7 +851,6 @@ discoverContextFiles
|
|||
discoverSlashCommands
|
||||
|
||||
// Helpers
|
||||
defaultGetApiKey
|
||||
loadSettings
|
||||
buildSystemPrompt
|
||||
|
||||
|
|
|
|||
|
|
@ -13,8 +13,8 @@
|
|||
* pi --hook examples/hooks/custom-compaction.ts
|
||||
*/
|
||||
|
||||
import { complete } from "@mariozechner/pi-ai";
|
||||
import { findModel, messageTransformer } from "@mariozechner/pi-coding-agent";
|
||||
import { complete, getModel } from "@mariozechner/pi-ai";
|
||||
import { messageTransformer } from "@mariozechner/pi-coding-agent";
|
||||
import type { HookAPI } from "@mariozechner/pi-coding-agent/hooks";
|
||||
|
||||
export default function (pi: HookAPI) {
|
||||
|
|
@ -27,10 +27,9 @@ export default function (pi: HookAPI) {
|
|||
event;
|
||||
|
||||
// Use Gemini Flash for summarization (cheaper/faster than most conversation models)
|
||||
// findModel searches both built-in models and custom models from models.json
|
||||
const { model, error } = findModel("google", "gemini-2.5-flash");
|
||||
if (error || !model) {
|
||||
ctx.ui.notify(`Could not find Gemini Flash model: ${error}, using default compaction`, "warning");
|
||||
const model = getModel("google", "gemini-2.5-flash");
|
||||
if (!model) {
|
||||
ctx.ui.notify(`Could not find Gemini Flash model, using default compaction`, "warning");
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -4,16 +4,27 @@
|
|||
* Shows how to select a specific model and thinking level.
|
||||
*/
|
||||
|
||||
import { createAgentSession, discoverAvailableModels, findModel } from "../../src/index.js";
|
||||
import { getModel } from "@mariozechner/pi-ai";
|
||||
import { createAgentSession, discoverAuthStorage, discoverModels } from "../../src/index.js";
|
||||
|
||||
// Option 1: Find a specific model by provider/id
|
||||
const { model: sonnet } = findModel("anthropic", "claude-sonnet-4-20250514");
|
||||
if (sonnet) {
|
||||
console.log(`Found model: ${sonnet.provider}/${sonnet.id}`);
|
||||
// Set up auth storage and model registry
|
||||
const authStorage = discoverAuthStorage();
|
||||
const modelRegistry = discoverModels(authStorage);
|
||||
|
||||
// Option 1: Find a specific built-in model by provider/id
|
||||
const opus = getModel("anthropic", "claude-opus-4-5");
|
||||
if (opus) {
|
||||
console.log(`Found model: ${opus.provider}/${opus.id}`);
|
||||
}
|
||||
|
||||
// Option 2: Pick from available models (have valid API keys)
|
||||
const available = await discoverAvailableModels();
|
||||
// Option 2: Find model via registry (includes custom models from models.json)
|
||||
const customModel = modelRegistry.find("my-provider", "my-model");
|
||||
if (customModel) {
|
||||
console.log(`Found custom model: ${customModel.provider}/${customModel.id}`);
|
||||
}
|
||||
|
||||
// Option 3: Pick from available models (have valid API keys)
|
||||
const available = await modelRegistry.getAvailable();
|
||||
console.log(
|
||||
"Available models:",
|
||||
available.map((m) => `${m.provider}/${m.id}`),
|
||||
|
|
@ -23,6 +34,8 @@ if (available.length > 0) {
|
|||
const { session } = await createAgentSession({
|
||||
model: available[0],
|
||||
thinkingLevel: "medium", // off, low, medium, high
|
||||
authStorage,
|
||||
modelRegistry,
|
||||
});
|
||||
|
||||
session.subscribe((event) => {
|
||||
|
|
|
|||
|
|
@ -1,40 +1,55 @@
|
|||
/**
|
||||
* API Keys and OAuth
|
||||
*
|
||||
* Configure API key resolution. Default checks: models.json, OAuth, env vars.
|
||||
* Configure API key resolution via AuthStorage and ModelRegistry.
|
||||
*/
|
||||
|
||||
import { getAgentDir } from "../../src/config.js";
|
||||
import { configureOAuthStorage, createAgentSession, defaultGetApiKey, SessionManager } from "../../src/index.js";
|
||||
import {
|
||||
AuthStorage,
|
||||
createAgentSession,
|
||||
discoverAuthStorage,
|
||||
discoverModels,
|
||||
ModelRegistry,
|
||||
SessionManager,
|
||||
} from "../../src/index.js";
|
||||
|
||||
// Default: uses env vars (ANTHROPIC_API_KEY, etc.), OAuth, and models.json
|
||||
await createAgentSession({
|
||||
sessionManager: SessionManager.inMemory(),
|
||||
});
|
||||
console.log("Session with default API key resolution");
|
||||
|
||||
// Custom resolver
|
||||
await createAgentSession({
|
||||
getApiKey: async (model) => {
|
||||
// Custom logic (secrets manager, database, etc.)
|
||||
if (model.provider === "anthropic") {
|
||||
return process.env.MY_ANTHROPIC_KEY;
|
||||
}
|
||||
// Fall back to default
|
||||
return defaultGetApiKey()(model);
|
||||
},
|
||||
sessionManager: SessionManager.inMemory(),
|
||||
});
|
||||
console.log("Session with custom API key resolver");
|
||||
|
||||
// Use OAuth from ~/.pi/agent while customizing everything else
|
||||
configureOAuthStorage(getAgentDir()); // Must call before createAgentSession
|
||||
// Default: discoverAuthStorage() uses ~/.pi/agent/auth.json
|
||||
// discoverModels() loads built-in + custom models from ~/.pi/agent/models.json
|
||||
const authStorage = discoverAuthStorage();
|
||||
const modelRegistry = discoverModels(authStorage);
|
||||
|
||||
await createAgentSession({
|
||||
agentDir: "/tmp/custom-config", // Custom config location
|
||||
// But OAuth tokens still come from ~/.pi/agent/oauth.json
|
||||
systemPrompt: "You are helpful.",
|
||||
skills: [],
|
||||
sessionManager: SessionManager.inMemory(),
|
||||
authStorage,
|
||||
modelRegistry,
|
||||
});
|
||||
console.log("Session with OAuth from default location, custom config elsewhere");
|
||||
console.log("Session with default auth storage and model registry");
|
||||
|
||||
// Custom auth storage location
|
||||
const customAuthStorage = new AuthStorage("/tmp/my-app/auth.json");
|
||||
const customModelRegistry = new ModelRegistry(customAuthStorage, "/tmp/my-app/models.json");
|
||||
|
||||
await createAgentSession({
|
||||
sessionManager: SessionManager.inMemory(),
|
||||
authStorage: customAuthStorage,
|
||||
modelRegistry: customModelRegistry,
|
||||
});
|
||||
console.log("Session with custom auth storage location");
|
||||
|
||||
// Runtime API key override (not persisted to disk)
|
||||
authStorage.setRuntimeApiKey("anthropic", "sk-my-temp-key");
|
||||
await createAgentSession({
|
||||
sessionManager: SessionManager.inMemory(),
|
||||
authStorage,
|
||||
modelRegistry,
|
||||
});
|
||||
console.log("Session with runtime API key override");
|
||||
|
||||
// No models.json - only built-in models
|
||||
const simpleRegistry = new ModelRegistry(authStorage); // null = no models.json
|
||||
await createAgentSession({
|
||||
sessionManager: SessionManager.inMemory(),
|
||||
authStorage,
|
||||
modelRegistry: simpleRegistry,
|
||||
});
|
||||
console.log("Session with only built-in models");
|
||||
|
|
|
|||
|
|
@ -2,38 +2,36 @@
|
|||
* Full Control
|
||||
*
|
||||
* Replace everything - no discovery, explicit configuration.
|
||||
* Still uses OAuth from ~/.pi/agent for convenience.
|
||||
*
|
||||
* IMPORTANT: When providing `tools` with a custom `cwd`, use the tool factory
|
||||
* functions (createReadTool, createBashTool, etc.) to ensure tools resolve
|
||||
* paths relative to your cwd.
|
||||
*/
|
||||
|
||||
import { getModel } from "@mariozechner/pi-ai";
|
||||
import { Type } from "@sinclair/typebox";
|
||||
import { getAgentDir } from "../../src/config.js";
|
||||
import {
|
||||
AuthStorage,
|
||||
type CustomAgentTool,
|
||||
configureOAuthStorage,
|
||||
createAgentSession,
|
||||
createBashTool,
|
||||
createReadTool,
|
||||
defaultGetApiKey,
|
||||
findModel,
|
||||
type HookFactory,
|
||||
ModelRegistry,
|
||||
SessionManager,
|
||||
SettingsManager,
|
||||
} from "../../src/index.js";
|
||||
|
||||
// Use OAuth from default location
|
||||
configureOAuthStorage(getAgentDir());
|
||||
// Custom auth storage location
|
||||
const authStorage = new AuthStorage("/tmp/my-agent/auth.json");
|
||||
|
||||
// Custom API key with fallback
|
||||
const getApiKey = async (model: { provider: string }) => {
|
||||
if (model.provider === "anthropic" && process.env.MY_ANTHROPIC_KEY) {
|
||||
return process.env.MY_ANTHROPIC_KEY;
|
||||
}
|
||||
return defaultGetApiKey()(model as any);
|
||||
};
|
||||
// Runtime API key override (not persisted)
|
||||
if (process.env.MY_ANTHROPIC_KEY) {
|
||||
authStorage.setRuntimeApiKey("anthropic", process.env.MY_ANTHROPIC_KEY);
|
||||
}
|
||||
|
||||
// Model registry with no custom models.json
|
||||
const modelRegistry = new ModelRegistry(authStorage);
|
||||
|
||||
// Inline hook
|
||||
const auditHook: HookFactory = (api) => {
|
||||
|
|
@ -55,7 +53,7 @@ const statusTool: CustomAgentTool = {
|
|||
}),
|
||||
};
|
||||
|
||||
const { model } = findModel("anthropic", "claude-sonnet-4-20250514");
|
||||
const model = getModel("anthropic", "claude-opus-4-5");
|
||||
if (!model) throw new Error("Model not found");
|
||||
|
||||
// In-memory settings with overrides
|
||||
|
|
@ -73,7 +71,8 @@ const { session } = await createAgentSession({
|
|||
|
||||
model,
|
||||
thinkingLevel: "off",
|
||||
getApiKey,
|
||||
authStorage,
|
||||
modelRegistry,
|
||||
|
||||
systemPrompt: `You are a minimal assistant.
|
||||
Available: read, bash, status. Be concise.`,
|
||||
|
|
|
|||
|
|
@ -29,50 +29,63 @@ npx tsx examples/sdk/01-minimal.ts
|
|||
## Quick Reference
|
||||
|
||||
```typescript
|
||||
import { getModel } from "@mariozechner/pi-ai";
|
||||
import {
|
||||
AuthStorage,
|
||||
createAgentSession,
|
||||
configureOAuthStorage,
|
||||
discoverAuthStorage,
|
||||
discoverModels,
|
||||
discoverSkills,
|
||||
discoverHooks,
|
||||
discoverCustomTools,
|
||||
discoverContextFiles,
|
||||
discoverSlashCommands,
|
||||
discoverAvailableModels,
|
||||
findModel,
|
||||
defaultGetApiKey,
|
||||
loadSettings,
|
||||
buildSystemPrompt,
|
||||
ModelRegistry,
|
||||
SessionManager,
|
||||
codingTools,
|
||||
readOnlyTools,
|
||||
readTool, bashTool, editTool, writeTool,
|
||||
} from "@mariozechner/pi-coding-agent";
|
||||
|
||||
// Auth and models setup
|
||||
const authStorage = discoverAuthStorage();
|
||||
const modelRegistry = discoverModels(authStorage);
|
||||
|
||||
// Minimal
|
||||
const { session } = await createAgentSession();
|
||||
const { session } = await createAgentSession({ authStorage, modelRegistry });
|
||||
|
||||
// Custom model
|
||||
const { model } = findModel("anthropic", "claude-sonnet-4-20250514");
|
||||
const { session } = await createAgentSession({ model, thinkingLevel: "high" });
|
||||
const model = getModel("anthropic", "claude-opus-4-5");
|
||||
const { session } = await createAgentSession({ model, thinkingLevel: "high", authStorage, modelRegistry });
|
||||
|
||||
// Modify prompt
|
||||
const { session } = await createAgentSession({
|
||||
systemPrompt: (defaultPrompt) => defaultPrompt + "\n\nBe concise.",
|
||||
authStorage,
|
||||
modelRegistry,
|
||||
});
|
||||
|
||||
// Read-only
|
||||
const { session } = await createAgentSession({ tools: readOnlyTools });
|
||||
const { session } = await createAgentSession({ tools: readOnlyTools, authStorage, modelRegistry });
|
||||
|
||||
// In-memory
|
||||
const { session } = await createAgentSession({
|
||||
sessionManager: SessionManager.inMemory(),
|
||||
authStorage,
|
||||
modelRegistry,
|
||||
});
|
||||
|
||||
// Full control
|
||||
configureOAuthStorage(); // Use OAuth from ~/.pi/agent
|
||||
const customAuth = new AuthStorage("/my/app/auth.json");
|
||||
customAuth.setRuntimeApiKey("anthropic", process.env.MY_KEY!);
|
||||
const customRegistry = new ModelRegistry(customAuth);
|
||||
|
||||
const { session } = await createAgentSession({
|
||||
model,
|
||||
getApiKey: async (m) => process.env.MY_KEY,
|
||||
authStorage: customAuth,
|
||||
modelRegistry: customRegistry,
|
||||
systemPrompt: "You are helpful.",
|
||||
tools: [readTool, bashTool],
|
||||
customTools: [{ tool: myTool }],
|
||||
|
|
@ -81,7 +94,6 @@ const { session } = await createAgentSession({
|
|||
contextFiles: [],
|
||||
slashCommands: [],
|
||||
sessionManager: SessionManager.inMemory(),
|
||||
settings: { compaction: { enabled: false } },
|
||||
});
|
||||
|
||||
// Run prompts
|
||||
|
|
@ -97,11 +109,12 @@ await session.prompt("Hello");
|
|||
|
||||
| Option | Default | Description |
|
||||
|--------|---------|-------------|
|
||||
| `authStorage` | `discoverAuthStorage()` | Credential storage |
|
||||
| `modelRegistry` | `discoverModels(authStorage)` | Model registry |
|
||||
| `cwd` | `process.cwd()` | Working directory |
|
||||
| `agentDir` | `~/.pi/agent` | Config directory |
|
||||
| `model` | From settings/first available | Model to use |
|
||||
| `thinkingLevel` | From settings/"off" | off, low, medium, high |
|
||||
| `getApiKey` | Built-in resolver | API key function |
|
||||
| `systemPrompt` | Discovered | String or `(default) => modified` |
|
||||
| `tools` | `codingTools` | Built-in tools |
|
||||
| `customTools` | Discovered | Replaces discovery |
|
||||
|
|
@ -112,7 +125,7 @@ await session.prompt("Hello");
|
|||
| `contextFiles` | Discovered | AGENTS.md files |
|
||||
| `slashCommands` | Discovered | File commands |
|
||||
| `sessionManager` | `SessionManager.create(cwd)` | Persistence |
|
||||
| `settings` | From agentDir | Overrides |
|
||||
| `settingsManager` | From agentDir | Settings overrides |
|
||||
|
||||
## Events
|
||||
|
||||
|
|
|
|||
|
|
@ -3,8 +3,7 @@
|
|||
*/
|
||||
|
||||
import type { Api, Model } from "@mariozechner/pi-ai";
|
||||
import { getAvailableModels } from "../core/models-json.js";
|
||||
import type { SettingsManager } from "../core/settings-manager.js";
|
||||
import type { ModelRegistry } from "../core/model-registry.js";
|
||||
import { fuzzyFilter } from "../utils/fuzzy.js";
|
||||
|
||||
/**
|
||||
|
|
@ -25,16 +24,8 @@ function formatTokenCount(count: number): string {
|
|||
/**
|
||||
* List available models, optionally filtered by search pattern
|
||||
*/
|
||||
export async function listModels(searchPattern?: string, settingsManager?: SettingsManager): Promise<void> {
|
||||
const { models, error } = await getAvailableModels(
|
||||
undefined,
|
||||
settingsManager ? (provider) => settingsManager.getApiKey(provider) : undefined,
|
||||
);
|
||||
|
||||
if (error) {
|
||||
console.error(error);
|
||||
process.exit(1);
|
||||
}
|
||||
export async function listModels(modelRegistry: ModelRegistry, searchPattern?: string): Promise<void> {
|
||||
const models = await modelRegistry.getAvailable();
|
||||
|
||||
if (models.length === 0) {
|
||||
console.log("No models available. Set API keys in environment variables.");
|
||||
|
|
|
|||
|
|
@ -112,9 +112,9 @@ export function getModelsPath(): string {
|
|||
return join(getAgentDir(), "models.json");
|
||||
}
|
||||
|
||||
/** Get path to oauth.json */
|
||||
export function getOAuthPath(): string {
|
||||
return join(getAgentDir(), "oauth.json");
|
||||
/** Get path to auth.json */
|
||||
export function getAuthPath(): string {
|
||||
return join(getAgentDir(), "auth.json");
|
||||
}
|
||||
|
||||
/** Get path to settings.json */
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ import type { LoadedCustomTool, SessionEvent as ToolSessionEvent } from "./custo
|
|||
import { exportSessionToHtml } from "./export-html.js";
|
||||
import type { HookRunner, SessionEventResult, TurnEndEvent, TurnStartEvent } from "./hooks/index.js";
|
||||
import type { BashExecutionMessage } from "./messages.js";
|
||||
import { getApiKeyForModel, getAvailableModels } from "./models-json.js";
|
||||
import type { ModelRegistry } from "./model-registry.js";
|
||||
import type { CompactionEntry, SessionManager } from "./session-manager.js";
|
||||
import type { SettingsManager, SkillsSettings } from "./settings-manager.js";
|
||||
import { expandSlashCommand, type FileSlashCommand } from "./slash-commands.js";
|
||||
|
|
@ -56,8 +56,8 @@ export interface AgentSessionConfig {
|
|||
/** Custom tools for session lifecycle events */
|
||||
customTools?: LoadedCustomTool[];
|
||||
skillsSettings?: Required<SkillsSettings>;
|
||||
/** Resolve API key for a model. Default: getApiKeyForModel */
|
||||
resolveApiKey?: (model: Model<any>) => Promise<string | undefined>;
|
||||
/** Model registry for API key resolution and model discovery */
|
||||
modelRegistry: ModelRegistry;
|
||||
}
|
||||
|
||||
/** Options for AgentSession.prompt() */
|
||||
|
|
@ -153,8 +153,8 @@ export class AgentSession {
|
|||
|
||||
private _skillsSettings: Required<SkillsSettings> | undefined;
|
||||
|
||||
// API key resolver
|
||||
private _resolveApiKey: (model: Model<any>) => Promise<string | undefined>;
|
||||
// Model registry for API key resolution
|
||||
private _modelRegistry: ModelRegistry;
|
||||
|
||||
constructor(config: AgentSessionConfig) {
|
||||
this.agent = config.agent;
|
||||
|
|
@ -165,7 +165,12 @@ export class AgentSession {
|
|||
this._hookRunner = config.hookRunner ?? null;
|
||||
this._customTools = config.customTools ?? [];
|
||||
this._skillsSettings = config.skillsSettings;
|
||||
this._resolveApiKey = config.resolveApiKey ?? getApiKeyForModel;
|
||||
this._modelRegistry = config.modelRegistry;
|
||||
}
|
||||
|
||||
/** Model registry for API key resolution and model discovery */
|
||||
get modelRegistry(): ModelRegistry {
|
||||
return this._modelRegistry;
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
|
|
@ -434,7 +439,7 @@ export class AgentSession {
|
|||
}
|
||||
|
||||
// Validate API key
|
||||
const apiKey = await this._resolveApiKey(this.model);
|
||||
const apiKey = await this._modelRegistry.getApiKey(this.model);
|
||||
if (!apiKey) {
|
||||
throw new Error(
|
||||
`No API key found for ${this.model.provider}.\n\n` +
|
||||
|
|
@ -561,7 +566,7 @@ export class AgentSession {
|
|||
* @throws Error if no API key available for the model
|
||||
*/
|
||||
async setModel(model: Model<any>): Promise<void> {
|
||||
const apiKey = await this._resolveApiKey(model);
|
||||
const apiKey = await this._modelRegistry.getApiKey(model);
|
||||
if (!apiKey) {
|
||||
throw new Error(`No API key for ${model.provider}/${model.id}`);
|
||||
}
|
||||
|
|
@ -599,7 +604,7 @@ export class AgentSession {
|
|||
const next = this._scopedModels[nextIndex];
|
||||
|
||||
// Validate API key
|
||||
const apiKey = await this._resolveApiKey(next.model);
|
||||
const apiKey = await this._modelRegistry.getApiKey(next.model);
|
||||
if (!apiKey) {
|
||||
throw new Error(`No API key for ${next.model.provider}/${next.model.id}`);
|
||||
}
|
||||
|
|
@ -616,10 +621,7 @@ export class AgentSession {
|
|||
}
|
||||
|
||||
private async _cycleAvailableModel(): Promise<ModelCycleResult | null> {
|
||||
const { models: availableModels, error } = await getAvailableModels(undefined, (provider) =>
|
||||
this.settingsManager.getApiKey(provider),
|
||||
);
|
||||
if (error) throw new Error(`Failed to load models: ${error}`);
|
||||
const availableModels = await this._modelRegistry.getAvailable();
|
||||
if (availableModels.length <= 1) return null;
|
||||
|
||||
const currentModel = this.model;
|
||||
|
|
@ -631,7 +633,7 @@ export class AgentSession {
|
|||
const nextIndex = (currentIndex + 1) % availableModels.length;
|
||||
const nextModel = availableModels[nextIndex];
|
||||
|
||||
const apiKey = await this._resolveApiKey(nextModel);
|
||||
const apiKey = await this._modelRegistry.getApiKey(nextModel);
|
||||
if (!apiKey) {
|
||||
throw new Error(`No API key for ${nextModel.provider}/${nextModel.id}`);
|
||||
}
|
||||
|
|
@ -650,11 +652,7 @@ export class AgentSession {
|
|||
* Get all available models with valid API keys.
|
||||
*/
|
||||
async getAvailableModels(): Promise<Model<any>[]> {
|
||||
const { models, error } = await getAvailableModels(undefined, (provider) =>
|
||||
this.settingsManager.getApiKey(provider),
|
||||
);
|
||||
if (error) throw new Error(error);
|
||||
return models;
|
||||
return this._modelRegistry.getAvailable();
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
|
|
@ -747,7 +745,7 @@ export class AgentSession {
|
|||
throw new Error("No model selected");
|
||||
}
|
||||
|
||||
const apiKey = await this._resolveApiKey(this.model);
|
||||
const apiKey = await this._modelRegistry.getApiKey(this.model);
|
||||
if (!apiKey) {
|
||||
throw new Error(`No API key for ${this.model.provider}`);
|
||||
}
|
||||
|
|
@ -786,7 +784,7 @@ export class AgentSession {
|
|||
tokensBefore: preparation.tokensBefore,
|
||||
customInstructions,
|
||||
model: this.model,
|
||||
resolveApiKey: this._resolveApiKey,
|
||||
resolveApiKey: async (m: Model<any>) => (await this._modelRegistry.getApiKey(m)) ?? undefined,
|
||||
signal: this._compactionAbortController.signal,
|
||||
})) as SessionEventResult | undefined;
|
||||
|
||||
|
|
@ -908,7 +906,7 @@ export class AgentSession {
|
|||
return;
|
||||
}
|
||||
|
||||
const apiKey = await this._resolveApiKey(this.model);
|
||||
const apiKey = await this._modelRegistry.getApiKey(this.model);
|
||||
if (!apiKey) {
|
||||
this._emit({ type: "auto_compaction_end", result: null, aborted: false, willRetry: false });
|
||||
return;
|
||||
|
|
@ -948,7 +946,7 @@ export class AgentSession {
|
|||
tokensBefore: preparation.tokensBefore,
|
||||
customInstructions: undefined,
|
||||
model: this.model,
|
||||
resolveApiKey: this._resolveApiKey,
|
||||
resolveApiKey: async (m: Model<any>) => (await this._modelRegistry.getApiKey(m)) ?? undefined,
|
||||
signal: this._autoCompactionAbortController.signal,
|
||||
})) as SessionEventResult | undefined;
|
||||
|
||||
|
|
@ -1334,9 +1332,7 @@ export class AgentSession {
|
|||
|
||||
// Restore model if saved
|
||||
if (sessionContext.model) {
|
||||
const availableModels = (
|
||||
await getAvailableModels(undefined, (provider) => this.settingsManager.getApiKey(provider))
|
||||
).models;
|
||||
const availableModels = await this._modelRegistry.getAvailable();
|
||||
const match = availableModels.find(
|
||||
(m) => m.provider === sessionContext.model!.provider && m.id === sessionContext.model!.modelId,
|
||||
);
|
||||
|
|
|
|||
|
|
@ -3,9 +3,18 @@
|
|||
* Handles loading, saving, and refreshing credentials from auth.json.
|
||||
*/
|
||||
|
||||
import { getApiKeyFromEnv, getOAuthApiKey, type OAuthCredentials, type OAuthProvider } from "@mariozechner/pi-ai";
|
||||
import { chmodSync, existsSync, mkdirSync, readFileSync, writeFileSync } from "fs";
|
||||
import { dirname } from "path";
|
||||
import {
|
||||
getEnvApiKey,
|
||||
getOAuthApiKey,
|
||||
loginAnthropic,
|
||||
loginAntigravity,
|
||||
loginGeminiCli,
|
||||
loginGitHubCopilot,
|
||||
type OAuthCredentials,
|
||||
type OAuthProvider,
|
||||
} from "@mariozechner/pi-ai";
|
||||
import { chmodSync, existsSync, mkdirSync, readFileSync, renameSync, writeFileSync } from "fs";
|
||||
import { dirname, join } from "path";
|
||||
|
||||
export type ApiKeyCredential = {
|
||||
type: "api_key";
|
||||
|
|
@ -25,11 +34,29 @@ export type AuthStorageData = Record<string, AuthCredential>;
|
|||
*/
|
||||
export class AuthStorage {
|
||||
private data: AuthStorageData = {};
|
||||
private runtimeOverrides: Map<string, string> = new Map();
|
||||
private fallbackResolver?: (provider: string) => string | undefined;
|
||||
|
||||
constructor(private authPath: string) {
|
||||
this.reload();
|
||||
}
|
||||
|
||||
/**
|
||||
* Set a runtime API key override (not persisted to disk).
|
||||
* Used for CLI --api-key flag.
|
||||
*/
|
||||
setRuntimeApiKey(provider: string, apiKey: string): void {
|
||||
this.runtimeOverrides.set(provider, apiKey);
|
||||
}
|
||||
|
||||
/**
|
||||
* Set a fallback resolver for API keys not found in auth.json or env vars.
|
||||
* Used for custom provider keys from models.json.
|
||||
*/
|
||||
setFallbackResolver(resolver: (provider: string) => string | undefined): void {
|
||||
this.fallbackResolver = resolver;
|
||||
}
|
||||
|
||||
/**
|
||||
* Reload credentials from disk.
|
||||
*/
|
||||
|
|
@ -101,14 +128,69 @@ export class AuthStorage {
|
|||
return { ...this.data };
|
||||
}
|
||||
|
||||
/**
|
||||
* Login to an OAuth provider.
|
||||
*/
|
||||
async login(
|
||||
provider: OAuthProvider,
|
||||
callbacks: {
|
||||
onAuth: (info: { url: string; instructions?: string }) => void;
|
||||
onPrompt: (prompt: { message: string; placeholder?: string }) => Promise<string>;
|
||||
onProgress?: (message: string) => void;
|
||||
},
|
||||
): Promise<void> {
|
||||
let credentials: OAuthCredentials;
|
||||
|
||||
switch (provider) {
|
||||
case "anthropic":
|
||||
credentials = await loginAnthropic(
|
||||
(url) => callbacks.onAuth({ url }),
|
||||
() => callbacks.onPrompt({ message: "Paste the authorization code:" }),
|
||||
);
|
||||
break;
|
||||
case "github-copilot":
|
||||
credentials = await loginGitHubCopilot({
|
||||
onAuth: (url, instructions) => callbacks.onAuth({ url, instructions }),
|
||||
onPrompt: callbacks.onPrompt,
|
||||
onProgress: callbacks.onProgress,
|
||||
});
|
||||
break;
|
||||
case "google-gemini-cli":
|
||||
credentials = await loginGeminiCli(callbacks.onAuth, callbacks.onProgress);
|
||||
break;
|
||||
case "google-antigravity":
|
||||
credentials = await loginAntigravity(callbacks.onAuth, callbacks.onProgress);
|
||||
break;
|
||||
default:
|
||||
throw new Error(`Unknown OAuth provider: ${provider}`);
|
||||
}
|
||||
|
||||
this.set(provider, { type: "oauth", ...credentials });
|
||||
}
|
||||
|
||||
/**
|
||||
* Logout from a provider.
|
||||
*/
|
||||
logout(provider: string): void {
|
||||
this.remove(provider);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get API key for a provider.
|
||||
* Priority:
|
||||
* 1. API key from auth.json
|
||||
* 2. OAuth token from auth.json (auto-refreshed)
|
||||
* 3. Environment variable (via getApiKeyFromEnv)
|
||||
* 1. Runtime override (CLI --api-key)
|
||||
* 2. API key from auth.json
|
||||
* 3. OAuth token from auth.json (auto-refreshed)
|
||||
* 4. Environment variable
|
||||
* 5. Fallback resolver (models.json custom providers)
|
||||
*/
|
||||
async getApiKey(provider: string): Promise<string | null> {
|
||||
// Runtime override takes highest priority
|
||||
const runtimeKey = this.runtimeOverrides.get(provider);
|
||||
if (runtimeKey) {
|
||||
return runtimeKey;
|
||||
}
|
||||
|
||||
const cred = this.data[provider];
|
||||
|
||||
if (cred?.type === "api_key") {
|
||||
|
|
@ -116,30 +198,83 @@ export class AuthStorage {
|
|||
}
|
||||
|
||||
if (cred?.type === "oauth") {
|
||||
// Build OAuthCredentials map (without type discriminator)
|
||||
// Filter to only oauth credentials for getOAuthApiKey
|
||||
const oauthCreds: Record<string, OAuthCredentials> = {};
|
||||
for (const [key, value] of Object.entries(this.data)) {
|
||||
if (value.type === "oauth") {
|
||||
const { type: _, ...rest } = value;
|
||||
oauthCreds[key] = rest;
|
||||
oauthCreds[key] = value;
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
const result = await getOAuthApiKey(provider as OAuthProvider, oauthCreds);
|
||||
if (result) {
|
||||
// Save refreshed credentials
|
||||
this.data[provider] = { type: "oauth", ...result.newCredentials };
|
||||
this.save();
|
||||
return result.apiKey;
|
||||
}
|
||||
} catch {
|
||||
// Token refresh failed, remove invalid credentials
|
||||
this.remove(provider);
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to environment variable
|
||||
return getApiKeyFromEnv(provider) ?? null;
|
||||
const envKey = getEnvApiKey(provider);
|
||||
if (envKey) return envKey;
|
||||
|
||||
// Fall back to custom resolver (e.g., models.json custom providers)
|
||||
return this.fallbackResolver?.(provider) ?? null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Migrate credentials from legacy oauth.json and settings.json apiKeys to auth.json.
|
||||
* Only runs if auth.json doesn't exist yet. Returns list of migrated providers.
|
||||
*/
|
||||
static migrateLegacy(authPath: string, agentDir: string): string[] {
|
||||
const oauthPath = join(agentDir, "oauth.json");
|
||||
const settingsPath = join(agentDir, "settings.json");
|
||||
|
||||
// Skip if auth.json already exists
|
||||
if (existsSync(authPath)) return [];
|
||||
|
||||
const migrated: AuthStorageData = {};
|
||||
const providers: string[] = [];
|
||||
|
||||
// Migrate oauth.json
|
||||
if (existsSync(oauthPath)) {
|
||||
try {
|
||||
const oauth = JSON.parse(readFileSync(oauthPath, "utf-8"));
|
||||
for (const [provider, cred] of Object.entries(oauth)) {
|
||||
migrated[provider] = { type: "oauth", ...(cred as object) } as OAuthCredential;
|
||||
providers.push(provider);
|
||||
}
|
||||
renameSync(oauthPath, `${oauthPath}.migrated`);
|
||||
} catch {}
|
||||
}
|
||||
|
||||
// Migrate settings.json apiKeys
|
||||
if (existsSync(settingsPath)) {
|
||||
try {
|
||||
const content = readFileSync(settingsPath, "utf-8");
|
||||
const settings = JSON.parse(content);
|
||||
if (settings.apiKeys && typeof settings.apiKeys === "object") {
|
||||
for (const [provider, key] of Object.entries(settings.apiKeys)) {
|
||||
if (!migrated[provider] && typeof key === "string") {
|
||||
migrated[provider] = { type: "api_key", key };
|
||||
providers.push(provider);
|
||||
}
|
||||
}
|
||||
delete settings.apiKeys;
|
||||
writeFileSync(settingsPath, JSON.stringify(settings, null, 2));
|
||||
}
|
||||
} catch {}
|
||||
}
|
||||
|
||||
if (Object.keys(migrated).length > 0) {
|
||||
mkdirSync(dirname(authPath), { recursive: true });
|
||||
writeFileSync(authPath, JSON.stringify(migrated, null, 2), { mode: 0o600 });
|
||||
}
|
||||
|
||||
return providers;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
315
packages/coding-agent/src/core/model-registry.ts
Normal file
315
packages/coding-agent/src/core/model-registry.ts
Normal file
|
|
@ -0,0 +1,315 @@
|
|||
/**
|
||||
* Model registry - manages built-in and custom models, provides API key resolution.
|
||||
*/
|
||||
|
||||
import {
|
||||
type Api,
|
||||
getGitHubCopilotBaseUrl,
|
||||
getModels,
|
||||
getProviders,
|
||||
type KnownProvider,
|
||||
type Model,
|
||||
normalizeDomain,
|
||||
} from "@mariozechner/pi-ai";
|
||||
import { type Static, Type } from "@sinclair/typebox";
|
||||
import AjvModule from "ajv";
|
||||
import { existsSync, readFileSync } from "fs";
|
||||
import type { AuthStorage } from "./auth-storage.js";
|
||||
|
||||
const Ajv = (AjvModule as any).default || AjvModule;
|
||||
|
||||
// Schema for OpenAI compatibility settings
|
||||
const OpenAICompatSchema = Type.Object({
|
||||
supportsStore: Type.Optional(Type.Boolean()),
|
||||
supportsDeveloperRole: Type.Optional(Type.Boolean()),
|
||||
supportsReasoningEffort: Type.Optional(Type.Boolean()),
|
||||
maxTokensField: Type.Optional(Type.Union([Type.Literal("max_completion_tokens"), Type.Literal("max_tokens")])),
|
||||
});
|
||||
|
||||
// Schema for custom model definition
|
||||
const ModelDefinitionSchema = Type.Object({
|
||||
id: Type.String({ minLength: 1 }),
|
||||
name: Type.String({ minLength: 1 }),
|
||||
api: Type.Optional(
|
||||
Type.Union([
|
||||
Type.Literal("openai-completions"),
|
||||
Type.Literal("openai-responses"),
|
||||
Type.Literal("anthropic-messages"),
|
||||
Type.Literal("google-generative-ai"),
|
||||
]),
|
||||
),
|
||||
reasoning: Type.Boolean(),
|
||||
input: Type.Array(Type.Union([Type.Literal("text"), Type.Literal("image")])),
|
||||
cost: Type.Object({
|
||||
input: Type.Number(),
|
||||
output: Type.Number(),
|
||||
cacheRead: Type.Number(),
|
||||
cacheWrite: Type.Number(),
|
||||
}),
|
||||
contextWindow: Type.Number(),
|
||||
maxTokens: Type.Number(),
|
||||
headers: Type.Optional(Type.Record(Type.String(), Type.String())),
|
||||
compat: Type.Optional(OpenAICompatSchema),
|
||||
});
|
||||
|
||||
const ProviderConfigSchema = Type.Object({
|
||||
baseUrl: Type.String({ minLength: 1 }),
|
||||
apiKey: Type.String({ minLength: 1 }),
|
||||
api: Type.Optional(
|
||||
Type.Union([
|
||||
Type.Literal("openai-completions"),
|
||||
Type.Literal("openai-responses"),
|
||||
Type.Literal("anthropic-messages"),
|
||||
Type.Literal("google-generative-ai"),
|
||||
]),
|
||||
),
|
||||
headers: Type.Optional(Type.Record(Type.String(), Type.String())),
|
||||
authHeader: Type.Optional(Type.Boolean()),
|
||||
models: Type.Array(ModelDefinitionSchema),
|
||||
});
|
||||
|
||||
const ModelsConfigSchema = Type.Object({
|
||||
providers: Type.Record(Type.String(), ProviderConfigSchema),
|
||||
});
|
||||
|
||||
type ModelsConfig = Static<typeof ModelsConfigSchema>;
|
||||
|
||||
/**
|
||||
* Resolve an API key config value to an actual key.
|
||||
* Checks environment variable first, then treats as literal.
|
||||
*/
|
||||
function resolveApiKeyConfig(keyConfig: string): string | undefined {
|
||||
const envValue = process.env[keyConfig];
|
||||
if (envValue) return envValue;
|
||||
return keyConfig;
|
||||
}
|
||||
|
||||
/**
|
||||
* Model registry - loads and manages models, resolves API keys via AuthStorage.
|
||||
*/
|
||||
export class ModelRegistry {
|
||||
private models: Model<Api>[] = [];
|
||||
private customProviderApiKeys: Map<string, string> = new Map();
|
||||
private loadError: string | null = null;
|
||||
|
||||
constructor(
|
||||
readonly authStorage: AuthStorage,
|
||||
private modelsJsonPath: string | null = null,
|
||||
) {
|
||||
// Set up fallback resolver for custom provider API keys
|
||||
this.authStorage.setFallbackResolver((provider) => {
|
||||
const keyConfig = this.customProviderApiKeys.get(provider);
|
||||
if (keyConfig) {
|
||||
return resolveApiKeyConfig(keyConfig);
|
||||
}
|
||||
return undefined;
|
||||
});
|
||||
|
||||
// Load models
|
||||
this.loadModels();
|
||||
}
|
||||
|
||||
/**
|
||||
* Reload models from disk (built-in + custom from models.json).
|
||||
*/
|
||||
refresh(): void {
|
||||
this.customProviderApiKeys.clear();
|
||||
this.loadError = null;
|
||||
this.loadModels();
|
||||
}
|
||||
|
||||
/**
|
||||
* Get any error from loading models.json (null if no error).
|
||||
*/
|
||||
getError(): string | null {
|
||||
return this.loadError;
|
||||
}
|
||||
|
||||
private loadModels(): void {
|
||||
// Load built-in models
|
||||
const builtInModels: Model<Api>[] = [];
|
||||
for (const provider of getProviders()) {
|
||||
const providerModels = getModels(provider as KnownProvider);
|
||||
builtInModels.push(...(providerModels as Model<Api>[]));
|
||||
}
|
||||
|
||||
// Load custom models from models.json (if path provided)
|
||||
let customModels: Model<Api>[] = [];
|
||||
if (this.modelsJsonPath) {
|
||||
const result = this.loadCustomModels(this.modelsJsonPath);
|
||||
if (result.error) {
|
||||
this.loadError = result.error;
|
||||
// Keep built-in models even if custom models failed to load
|
||||
} else {
|
||||
customModels = result.models;
|
||||
}
|
||||
}
|
||||
|
||||
const combined = [...builtInModels, ...customModels];
|
||||
|
||||
// Update github-copilot base URL based on OAuth credentials
|
||||
const copilotCred = this.authStorage.get("github-copilot");
|
||||
if (copilotCred?.type === "oauth") {
|
||||
const domain = copilotCred.enterpriseUrl
|
||||
? (normalizeDomain(copilotCred.enterpriseUrl) ?? undefined)
|
||||
: undefined;
|
||||
const baseUrl = getGitHubCopilotBaseUrl(copilotCred.access, domain);
|
||||
this.models = combined.map((m) => (m.provider === "github-copilot" ? { ...m, baseUrl } : m));
|
||||
} else {
|
||||
this.models = combined;
|
||||
}
|
||||
}
|
||||
|
||||
private loadCustomModels(modelsJsonPath: string): { models: Model<Api>[]; error: string | null } {
|
||||
if (!existsSync(modelsJsonPath)) {
|
||||
return { models: [], error: null };
|
||||
}
|
||||
|
||||
try {
|
||||
const content = readFileSync(modelsJsonPath, "utf-8");
|
||||
const config: ModelsConfig = JSON.parse(content);
|
||||
|
||||
// Validate schema
|
||||
const ajv = new Ajv();
|
||||
const validate = ajv.compile(ModelsConfigSchema);
|
||||
if (!validate(config)) {
|
||||
const errors =
|
||||
validate.errors?.map((e: any) => ` - ${e.instancePath || "root"}: ${e.message}`).join("\n") ||
|
||||
"Unknown schema error";
|
||||
return {
|
||||
models: [],
|
||||
error: `Invalid models.json schema:\n${errors}\n\nFile: ${modelsJsonPath}`,
|
||||
};
|
||||
}
|
||||
|
||||
// Additional validation
|
||||
this.validateConfig(config);
|
||||
|
||||
// Parse models
|
||||
return { models: this.parseModels(config), error: null };
|
||||
} catch (error) {
|
||||
if (error instanceof SyntaxError) {
|
||||
return {
|
||||
models: [],
|
||||
error: `Failed to parse models.json: ${error.message}\n\nFile: ${modelsJsonPath}`,
|
||||
};
|
||||
}
|
||||
return {
|
||||
models: [],
|
||||
error: `Failed to load models.json: ${error instanceof Error ? error.message : error}\n\nFile: ${modelsJsonPath}`,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
private validateConfig(config: ModelsConfig): void {
|
||||
for (const [providerName, providerConfig] of Object.entries(config.providers)) {
|
||||
const hasProviderApi = !!providerConfig.api;
|
||||
|
||||
for (const modelDef of providerConfig.models) {
|
||||
const hasModelApi = !!modelDef.api;
|
||||
|
||||
if (!hasProviderApi && !hasModelApi) {
|
||||
throw new Error(
|
||||
`Provider ${providerName}, model ${modelDef.id}: no "api" specified. Set at provider or model level.`,
|
||||
);
|
||||
}
|
||||
|
||||
if (!modelDef.id) throw new Error(`Provider ${providerName}: model missing "id"`);
|
||||
if (!modelDef.name) throw new Error(`Provider ${providerName}: model missing "name"`);
|
||||
if (modelDef.contextWindow <= 0)
|
||||
throw new Error(`Provider ${providerName}, model ${modelDef.id}: invalid contextWindow`);
|
||||
if (modelDef.maxTokens <= 0)
|
||||
throw new Error(`Provider ${providerName}, model ${modelDef.id}: invalid maxTokens`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private parseModels(config: ModelsConfig): Model<Api>[] {
|
||||
const models: Model<Api>[] = [];
|
||||
|
||||
for (const [providerName, providerConfig] of Object.entries(config.providers)) {
|
||||
// Store API key config for fallback resolver
|
||||
this.customProviderApiKeys.set(providerName, providerConfig.apiKey);
|
||||
|
||||
for (const modelDef of providerConfig.models) {
|
||||
const api = modelDef.api || providerConfig.api;
|
||||
if (!api) continue;
|
||||
|
||||
// Merge headers: provider headers are base, model headers override
|
||||
let headers =
|
||||
providerConfig.headers || modelDef.headers
|
||||
? { ...providerConfig.headers, ...modelDef.headers }
|
||||
: undefined;
|
||||
|
||||
// If authHeader is true, add Authorization header with resolved API key
|
||||
if (providerConfig.authHeader) {
|
||||
const resolvedKey = resolveApiKeyConfig(providerConfig.apiKey);
|
||||
if (resolvedKey) {
|
||||
headers = { ...headers, Authorization: `Bearer ${resolvedKey}` };
|
||||
}
|
||||
}
|
||||
|
||||
models.push({
|
||||
id: modelDef.id,
|
||||
name: modelDef.name,
|
||||
api: api as Api,
|
||||
provider: providerName,
|
||||
baseUrl: providerConfig.baseUrl,
|
||||
reasoning: modelDef.reasoning,
|
||||
input: modelDef.input as ("text" | "image")[],
|
||||
cost: modelDef.cost,
|
||||
contextWindow: modelDef.contextWindow,
|
||||
maxTokens: modelDef.maxTokens,
|
||||
headers,
|
||||
compat: modelDef.compat,
|
||||
} as Model<Api>);
|
||||
}
|
||||
}
|
||||
|
||||
return models;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all models (built-in + custom).
|
||||
* If models.json had errors, returns only built-in models.
|
||||
*/
|
||||
getAll(): Model<Api>[] {
|
||||
return this.models;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get only models that have valid API keys available.
|
||||
*/
|
||||
async getAvailable(): Promise<Model<Api>[]> {
|
||||
const available: Model<Api>[] = [];
|
||||
for (const model of this.models) {
|
||||
const apiKey = await this.authStorage.getApiKey(model.provider);
|
||||
if (apiKey) {
|
||||
available.push(model);
|
||||
}
|
||||
}
|
||||
return available;
|
||||
}
|
||||
|
||||
/**
|
||||
* Find a model by provider and ID.
|
||||
*/
|
||||
find(provider: string, modelId: string): Model<Api> | null {
|
||||
return this.models.find((m) => m.provider === provider && m.id === modelId) ?? null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get API key for a model.
|
||||
*/
|
||||
async getApiKey(model: Model<Api>): Promise<string | null> {
|
||||
return this.authStorage.getApiKey(model.provider);
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a model is using OAuth credentials (subscription).
|
||||
*/
|
||||
isUsingOAuth(model: Model<Api>): boolean {
|
||||
const cred = this.authStorage.get(model.provider);
|
||||
return cred?.type === "oauth";
|
||||
}
|
||||
}
|
||||
|
|
@ -6,8 +6,7 @@ import type { ThinkingLevel } from "@mariozechner/pi-agent-core";
|
|||
import type { Api, KnownProvider, Model } from "@mariozechner/pi-ai";
|
||||
import chalk from "chalk";
|
||||
import { isValidThinkingLevel } from "../cli/args.js";
|
||||
import { findModel, getApiKeyForModel, getAvailableModels } from "./models-json.js";
|
||||
import type { SettingsManager } from "./settings-manager.js";
|
||||
import type { ModelRegistry } from "./model-registry.js";
|
||||
|
||||
/** Default model IDs for each known provider */
|
||||
export const defaultModelPerProvider: Record<KnownProvider, string> = {
|
||||
|
|
@ -167,21 +166,9 @@ export function parseModelPattern(pattern: string, availableModels: Model<Api>[]
|
|||
* Supports models with colons in their IDs (e.g., OpenRouter's model:exacto).
|
||||
* The algorithm tries to match the full pattern first, then progressively
|
||||
* strips colon-suffixes to find a match.
|
||||
*
|
||||
* @param patterns - Model patterns to resolve
|
||||
* @param settingsManager - Optional settings manager for API key fallback from settings.json
|
||||
*/
|
||||
export async function resolveModelScope(patterns: string[], settingsManager?: SettingsManager): Promise<ScopedModel[]> {
|
||||
const { models: availableModels, error } = await getAvailableModels(
|
||||
undefined,
|
||||
settingsManager ? (provider) => settingsManager.getApiKey(provider) : undefined,
|
||||
);
|
||||
|
||||
if (error) {
|
||||
console.warn(chalk.yellow(`Warning: Error loading models: ${error}`));
|
||||
return [];
|
||||
}
|
||||
|
||||
export async function resolveModelScope(patterns: string[], modelRegistry: ModelRegistry): Promise<ScopedModel[]> {
|
||||
const availableModels = await modelRegistry.getAvailable();
|
||||
const scopedModels: ScopedModel[] = [];
|
||||
|
||||
for (const pattern of patterns) {
|
||||
|
|
@ -224,20 +211,28 @@ export async function findInitialModel(options: {
|
|||
cliModel?: string;
|
||||
scopedModels: ScopedModel[];
|
||||
isContinuing: boolean;
|
||||
settingsManager: SettingsManager;
|
||||
defaultProvider?: string;
|
||||
defaultModelId?: string;
|
||||
defaultThinkingLevel?: ThinkingLevel;
|
||||
modelRegistry: ModelRegistry;
|
||||
}): Promise<InitialModelResult> {
|
||||
const { cliProvider, cliModel, scopedModels, isContinuing, settingsManager } = options;
|
||||
const {
|
||||
cliProvider,
|
||||
cliModel,
|
||||
scopedModels,
|
||||
isContinuing,
|
||||
defaultProvider,
|
||||
defaultModelId,
|
||||
defaultThinkingLevel,
|
||||
modelRegistry,
|
||||
} = options;
|
||||
|
||||
let model: Model<Api> | null = null;
|
||||
let thinkingLevel: ThinkingLevel = "off";
|
||||
|
||||
// 1. CLI args take priority
|
||||
if (cliProvider && cliModel) {
|
||||
const { model: found, error } = findModel(cliProvider, cliModel);
|
||||
if (error) {
|
||||
console.error(chalk.red(error));
|
||||
process.exit(1);
|
||||
}
|
||||
const found = modelRegistry.find(cliProvider, cliModel);
|
||||
if (!found) {
|
||||
console.error(chalk.red(`Model ${cliProvider}/${cliModel} not found`));
|
||||
process.exit(1);
|
||||
|
|
@ -255,34 +250,19 @@ export async function findInitialModel(options: {
|
|||
}
|
||||
|
||||
// 3. Try saved default from settings
|
||||
const defaultProvider = settingsManager.getDefaultProvider();
|
||||
const defaultModelId = settingsManager.getDefaultModel();
|
||||
if (defaultProvider && defaultModelId) {
|
||||
const { model: found, error } = findModel(defaultProvider, defaultModelId);
|
||||
if (error) {
|
||||
console.error(chalk.red(error));
|
||||
process.exit(1);
|
||||
}
|
||||
const found = modelRegistry.find(defaultProvider, defaultModelId);
|
||||
if (found) {
|
||||
model = found;
|
||||
// Also load saved thinking level
|
||||
const savedThinking = settingsManager.getDefaultThinkingLevel();
|
||||
if (savedThinking) {
|
||||
thinkingLevel = savedThinking;
|
||||
if (defaultThinkingLevel) {
|
||||
thinkingLevel = defaultThinkingLevel;
|
||||
}
|
||||
return { model, thinkingLevel, fallbackMessage: null };
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Try first available model with valid API key
|
||||
const { models: availableModels, error } = await getAvailableModels(undefined, (provider) =>
|
||||
settingsManager.getApiKey(provider),
|
||||
);
|
||||
|
||||
if (error) {
|
||||
console.error(chalk.red(error));
|
||||
process.exit(1);
|
||||
}
|
||||
const availableModels = await modelRegistry.getAvailable();
|
||||
|
||||
if (availableModels.length > 0) {
|
||||
// Try to find a default model from known providers
|
||||
|
|
@ -310,17 +290,12 @@ export async function restoreModelFromSession(
|
|||
savedModelId: string,
|
||||
currentModel: Model<Api> | null,
|
||||
shouldPrintMessages: boolean,
|
||||
settingsManager?: SettingsManager,
|
||||
modelRegistry: ModelRegistry,
|
||||
): Promise<{ model: Model<Api> | null; fallbackMessage: string | null }> {
|
||||
const { model: restoredModel, error } = findModel(savedProvider, savedModelId);
|
||||
|
||||
if (error) {
|
||||
console.error(chalk.red(error));
|
||||
process.exit(1);
|
||||
}
|
||||
const restoredModel = modelRegistry.find(savedProvider, savedModelId);
|
||||
|
||||
// Check if restored model exists and has a valid API key
|
||||
const hasApiKey = restoredModel ? !!(await getApiKeyForModel(restoredModel)) : false;
|
||||
const hasApiKey = restoredModel ? !!(await modelRegistry.getApiKey(restoredModel)) : false;
|
||||
|
||||
if (restoredModel && hasApiKey) {
|
||||
if (shouldPrintMessages) {
|
||||
|
|
@ -348,14 +323,7 @@ export async function restoreModelFromSession(
|
|||
}
|
||||
|
||||
// Try to find any available model
|
||||
const { models: availableModels, error: availableError } = await getAvailableModels(
|
||||
undefined,
|
||||
settingsManager ? (provider) => settingsManager.getApiKey(provider) : undefined,
|
||||
);
|
||||
if (availableError) {
|
||||
console.error(chalk.red(availableError));
|
||||
process.exit(1);
|
||||
}
|
||||
const availableModels = await modelRegistry.getAvailable();
|
||||
|
||||
if (availableModels.length > 0) {
|
||||
// Try to find a default model from known providers
|
||||
|
|
|
|||
|
|
@ -1,467 +0,0 @@
|
|||
import {
|
||||
type Api,
|
||||
getApiKey,
|
||||
getGitHubCopilotBaseUrl,
|
||||
getModels,
|
||||
getProviders,
|
||||
type KnownProvider,
|
||||
loadOAuthCredentials,
|
||||
type Model,
|
||||
normalizeDomain,
|
||||
refreshGitHubCopilotToken,
|
||||
removeOAuthCredentials,
|
||||
saveOAuthCredentials,
|
||||
} from "@mariozechner/pi-ai";
|
||||
import { type Static, Type } from "@sinclair/typebox";
|
||||
import AjvModule from "ajv";
|
||||
import { existsSync, readFileSync } from "fs";
|
||||
import { join } from "path";
|
||||
import { getAgentDir } from "../config.js";
|
||||
import { getOAuthToken, type OAuthProvider, refreshToken } from "./oauth/index.js";
|
||||
|
||||
// Handle both default and named exports
|
||||
const Ajv = (AjvModule as any).default || AjvModule;
|
||||
|
||||
// Schema for OpenAI compatibility settings
|
||||
const OpenAICompatSchema = Type.Object({
|
||||
supportsStore: Type.Optional(Type.Boolean()),
|
||||
supportsDeveloperRole: Type.Optional(Type.Boolean()),
|
||||
supportsReasoningEffort: Type.Optional(Type.Boolean()),
|
||||
maxTokensField: Type.Optional(Type.Union([Type.Literal("max_completion_tokens"), Type.Literal("max_tokens")])),
|
||||
});
|
||||
|
||||
// Schema for custom model definition
|
||||
const ModelDefinitionSchema = Type.Object({
|
||||
id: Type.String({ minLength: 1 }),
|
||||
name: Type.String({ minLength: 1 }),
|
||||
api: Type.Optional(
|
||||
Type.Union([
|
||||
Type.Literal("openai-completions"),
|
||||
Type.Literal("openai-responses"),
|
||||
Type.Literal("anthropic-messages"),
|
||||
Type.Literal("google-generative-ai"),
|
||||
]),
|
||||
),
|
||||
reasoning: Type.Boolean(),
|
||||
input: Type.Array(Type.Union([Type.Literal("text"), Type.Literal("image")])),
|
||||
cost: Type.Object({
|
||||
input: Type.Number(),
|
||||
output: Type.Number(),
|
||||
cacheRead: Type.Number(),
|
||||
cacheWrite: Type.Number(),
|
||||
}),
|
||||
contextWindow: Type.Number(),
|
||||
maxTokens: Type.Number(),
|
||||
headers: Type.Optional(Type.Record(Type.String(), Type.String())),
|
||||
compat: Type.Optional(OpenAICompatSchema),
|
||||
});
|
||||
|
||||
const ProviderConfigSchema = Type.Object({
|
||||
baseUrl: Type.String({ minLength: 1 }),
|
||||
apiKey: Type.String({ minLength: 1 }),
|
||||
api: Type.Optional(
|
||||
Type.Union([
|
||||
Type.Literal("openai-completions"),
|
||||
Type.Literal("openai-responses"),
|
||||
Type.Literal("anthropic-messages"),
|
||||
Type.Literal("google-generative-ai"),
|
||||
]),
|
||||
),
|
||||
headers: Type.Optional(Type.Record(Type.String(), Type.String())),
|
||||
authHeader: Type.Optional(Type.Boolean()),
|
||||
models: Type.Array(ModelDefinitionSchema),
|
||||
});
|
||||
|
||||
const ModelsConfigSchema = Type.Object({
|
||||
providers: Type.Record(Type.String(), ProviderConfigSchema),
|
||||
});
|
||||
|
||||
type ModelsConfig = Static<typeof ModelsConfigSchema>;
|
||||
|
||||
// Custom provider API key mappings (provider name -> apiKey config)
|
||||
const customProviderApiKeys: Map<string, string> = new Map();
|
||||
|
||||
/**
|
||||
* Resolve an API key config value to an actual key.
|
||||
* First checks if it's an environment variable, then treats as literal.
|
||||
*/
|
||||
export function resolveApiKey(keyConfig: string): string | undefined {
|
||||
// First check if it's an env var name
|
||||
const envValue = process.env[keyConfig];
|
||||
if (envValue) return envValue;
|
||||
|
||||
// Otherwise treat as literal API key
|
||||
return keyConfig;
|
||||
}
|
||||
|
||||
/**
|
||||
* Load custom models from a models.json file
|
||||
* Returns { models, error } - either models array or error message
|
||||
*/
|
||||
function loadCustomModels(modelsJsonPath: string): { models: Model<Api>[]; error: string | null } {
|
||||
if (!existsSync(modelsJsonPath)) {
|
||||
return { models: [], error: null };
|
||||
}
|
||||
|
||||
try {
|
||||
const content = readFileSync(modelsJsonPath, "utf-8");
|
||||
const config: ModelsConfig = JSON.parse(content);
|
||||
|
||||
// Validate schema
|
||||
const ajv = new Ajv();
|
||||
const validate = ajv.compile(ModelsConfigSchema);
|
||||
if (!validate(config)) {
|
||||
const errors =
|
||||
validate.errors?.map((e: any) => ` - ${e.instancePath || "root"}: ${e.message}`).join("\n") ||
|
||||
"Unknown schema error";
|
||||
return {
|
||||
models: [],
|
||||
error: `Invalid models.json schema:\n${errors}\n\nFile: ${modelsJsonPath}`,
|
||||
};
|
||||
}
|
||||
|
||||
// Additional validation
|
||||
try {
|
||||
validateConfig(config);
|
||||
} catch (error) {
|
||||
return {
|
||||
models: [],
|
||||
error: `Invalid models.json: ${error instanceof Error ? error.message : error}\n\nFile: ${modelsJsonPath}`,
|
||||
};
|
||||
}
|
||||
|
||||
// Parse models
|
||||
return { models: parseModels(config), error: null };
|
||||
} catch (error) {
|
||||
if (error instanceof SyntaxError) {
|
||||
return {
|
||||
models: [],
|
||||
error: `Failed to parse models.json: ${error.message}\n\nFile: ${modelsJsonPath}`,
|
||||
};
|
||||
}
|
||||
return {
|
||||
models: [],
|
||||
error: `Failed to load models.json: ${error instanceof Error ? error.message : error}\n\nFile: ${modelsJsonPath}`,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate config structure and requirements
|
||||
*/
|
||||
function validateConfig(config: ModelsConfig): void {
|
||||
for (const [providerName, providerConfig] of Object.entries(config.providers)) {
|
||||
const hasProviderApi = !!providerConfig.api;
|
||||
|
||||
for (const modelDef of providerConfig.models) {
|
||||
const hasModelApi = !!modelDef.api;
|
||||
|
||||
if (!hasProviderApi && !hasModelApi) {
|
||||
throw new Error(
|
||||
`Provider ${providerName}, model ${modelDef.id}: no "api" specified. ` +
|
||||
`Set at provider or model level.`,
|
||||
);
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
if (!modelDef.id) throw new Error(`Provider ${providerName}: model missing "id"`);
|
||||
if (!modelDef.name) throw new Error(`Provider ${providerName}: model missing "name"`);
|
||||
if (modelDef.contextWindow <= 0)
|
||||
throw new Error(`Provider ${providerName}, model ${modelDef.id}: invalid contextWindow`);
|
||||
if (modelDef.maxTokens <= 0)
|
||||
throw new Error(`Provider ${providerName}, model ${modelDef.id}: invalid maxTokens`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse config into Model objects
|
||||
*/
|
||||
function parseModels(config: ModelsConfig): Model<Api>[] {
|
||||
const models: Model<Api>[] = [];
|
||||
|
||||
// Clear and rebuild custom provider API key mappings
|
||||
customProviderApiKeys.clear();
|
||||
|
||||
for (const [providerName, providerConfig] of Object.entries(config.providers)) {
|
||||
// Store API key config for this provider
|
||||
customProviderApiKeys.set(providerName, providerConfig.apiKey);
|
||||
|
||||
for (const modelDef of providerConfig.models) {
|
||||
// Model-level api overrides provider-level api
|
||||
const api = modelDef.api || providerConfig.api;
|
||||
|
||||
if (!api) {
|
||||
// This should have been caught by validateConfig, but be safe
|
||||
continue;
|
||||
}
|
||||
|
||||
// Merge headers: provider headers are base, model headers override
|
||||
let headers =
|
||||
providerConfig.headers || modelDef.headers ? { ...providerConfig.headers, ...modelDef.headers } : undefined;
|
||||
|
||||
// If authHeader is true, add Authorization header with resolved API key
|
||||
if (providerConfig.authHeader) {
|
||||
const resolvedKey = resolveApiKey(providerConfig.apiKey);
|
||||
if (resolvedKey) {
|
||||
headers = { ...headers, Authorization: `Bearer ${resolvedKey}` };
|
||||
}
|
||||
}
|
||||
|
||||
models.push({
|
||||
id: modelDef.id,
|
||||
name: modelDef.name,
|
||||
api: api as Api,
|
||||
provider: providerName,
|
||||
baseUrl: providerConfig.baseUrl,
|
||||
reasoning: modelDef.reasoning,
|
||||
input: modelDef.input as ("text" | "image")[],
|
||||
cost: modelDef.cost,
|
||||
contextWindow: modelDef.contextWindow,
|
||||
maxTokens: modelDef.maxTokens,
|
||||
headers,
|
||||
compat: modelDef.compat,
|
||||
} as Model<Api>);
|
||||
}
|
||||
}
|
||||
|
||||
return models;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all models (built-in + custom), freshly loaded
|
||||
* Returns { models, error } - either models array or error message
|
||||
*/
|
||||
export function loadAndMergeModels(agentDir: string = getAgentDir()): { models: Model<Api>[]; error: string | null } {
|
||||
const builtInModels: Model<Api>[] = [];
|
||||
const providers = getProviders();
|
||||
|
||||
// Load all built-in models
|
||||
for (const provider of providers) {
|
||||
const providerModels = getModels(provider as KnownProvider);
|
||||
builtInModels.push(...(providerModels as Model<Api>[]));
|
||||
}
|
||||
|
||||
// Load custom models
|
||||
const { models: customModels, error } = loadCustomModels(join(agentDir, "models.json"));
|
||||
|
||||
if (error) {
|
||||
return { models: [], error };
|
||||
}
|
||||
|
||||
const combined = [...builtInModels, ...customModels];
|
||||
|
||||
// Update github-copilot base URL based on OAuth token or enterprise domain
|
||||
const copilotCreds = loadOAuthCredentials("github-copilot");
|
||||
if (copilotCreds) {
|
||||
const domain = copilotCreds.enterpriseUrl ? normalizeDomain(copilotCreds.enterpriseUrl) : undefined;
|
||||
const baseUrl = getGitHubCopilotBaseUrl(copilotCreds.access, domain ?? undefined);
|
||||
return {
|
||||
models: combined.map((m) => (m.provider === "github-copilot" ? { ...m, baseUrl } : m)),
|
||||
error: null,
|
||||
};
|
||||
}
|
||||
|
||||
return { models: combined, error: null };
|
||||
}
|
||||
|
||||
/**
|
||||
* Get API key for a model (checks custom providers first, then built-in)
|
||||
* Now async to support OAuth token refresh.
|
||||
* Note: OAuth storage location is configured globally via setOAuthStorage.
|
||||
*/
|
||||
export async function getApiKeyForModel(model: Model<Api>): Promise<string | undefined> {
|
||||
// For custom providers, check their apiKey config
|
||||
const customKeyConfig = customProviderApiKeys.get(model.provider);
|
||||
if (customKeyConfig) {
|
||||
return resolveApiKey(customKeyConfig);
|
||||
}
|
||||
|
||||
// For Anthropic, check OAuth first
|
||||
if (model.provider === "anthropic") {
|
||||
// 1. Check OAuth storage (auto-refresh if needed)
|
||||
const oauthToken = await getOAuthToken("anthropic");
|
||||
if (oauthToken) {
|
||||
return oauthToken;
|
||||
}
|
||||
|
||||
// 2. Check ANTHROPIC_OAUTH_TOKEN env var (manual OAuth token)
|
||||
const oauthEnv = process.env.ANTHROPIC_OAUTH_TOKEN;
|
||||
if (oauthEnv) {
|
||||
return oauthEnv;
|
||||
}
|
||||
|
||||
// 3. Fall back to ANTHROPIC_API_KEY env var
|
||||
}
|
||||
|
||||
if (model.provider === "github-copilot") {
|
||||
// 1. Check OAuth storage (from device flow login)
|
||||
const oauthToken = await getOAuthToken("github-copilot");
|
||||
if (oauthToken) {
|
||||
return oauthToken;
|
||||
}
|
||||
|
||||
// 2. Use GitHub token directly (works with copilot scope on github.com)
|
||||
const githubToken = process.env.COPILOT_GITHUB_TOKEN || process.env.GH_TOKEN || process.env.GITHUB_TOKEN;
|
||||
if (!githubToken) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
// 3. For enterprise, exchange token for short-lived Copilot token
|
||||
const enterpriseDomain = process.env.COPILOT_ENTERPRISE_URL
|
||||
? normalizeDomain(process.env.COPILOT_ENTERPRISE_URL)
|
||||
: undefined;
|
||||
|
||||
if (enterpriseDomain) {
|
||||
const creds = await refreshGitHubCopilotToken(githubToken, enterpriseDomain);
|
||||
saveOAuthCredentials("github-copilot", creds);
|
||||
return creds.access;
|
||||
}
|
||||
|
||||
// 4. For github.com, use token directly
|
||||
return githubToken;
|
||||
}
|
||||
|
||||
// For Google Gemini CLI and Antigravity, check OAuth and encode projectId with token
|
||||
if (model.provider === "google-gemini-cli" || model.provider === "google-antigravity") {
|
||||
const oauthProvider = model.provider as "google-gemini-cli" | "google-antigravity";
|
||||
const credentials = loadOAuthCredentials(oauthProvider);
|
||||
if (!credentials) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
// Check if token is expired
|
||||
if (Date.now() >= credentials.expires) {
|
||||
try {
|
||||
await refreshToken(oauthProvider);
|
||||
const refreshedCreds = loadOAuthCredentials(oauthProvider);
|
||||
if (refreshedCreds?.projectId) {
|
||||
return JSON.stringify({ token: refreshedCreds.access, projectId: refreshedCreds.projectId });
|
||||
}
|
||||
} catch {
|
||||
removeOAuthCredentials(oauthProvider);
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
if (credentials.projectId) {
|
||||
return JSON.stringify({ token: credentials.access, projectId: credentials.projectId });
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
// For built-in providers, use getApiKey from @mariozechner/pi-ai
|
||||
return getApiKey(model.provider as KnownProvider);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get only models that have valid API keys available
|
||||
* Returns { models, error } - either models array or error message
|
||||
*
|
||||
* @param agentDir - Agent config directory
|
||||
* @param fallbackKeyResolver - Optional function to check for API keys not found by getApiKeyForModel
|
||||
* (e.g., keys from settings.json)
|
||||
*/
|
||||
export async function getAvailableModels(
|
||||
agentDir: string = getAgentDir(),
|
||||
fallbackKeyResolver?: (provider: string) => string | undefined,
|
||||
): Promise<{ models: Model<Api>[]; error: string | null }> {
|
||||
const { models: allModels, error } = loadAndMergeModels(agentDir);
|
||||
|
||||
if (error) {
|
||||
return { models: [], error };
|
||||
}
|
||||
|
||||
const availableModels: Model<Api>[] = [];
|
||||
for (const model of allModels) {
|
||||
let apiKey = await getApiKeyForModel(model);
|
||||
// Check fallback resolver if primary lookup failed
|
||||
if (!apiKey && fallbackKeyResolver) {
|
||||
apiKey = fallbackKeyResolver(model.provider);
|
||||
}
|
||||
if (apiKey) {
|
||||
availableModels.push(model);
|
||||
}
|
||||
}
|
||||
|
||||
return { models: availableModels, error: null };
|
||||
}
|
||||
|
||||
/**
|
||||
* Find a specific model by provider and ID.
|
||||
*
|
||||
* Searches models from:
|
||||
* 1. Built-in models from @mariozechner/pi-ai
|
||||
* 2. Custom models defined in ~/.pi/agent/models.json
|
||||
*
|
||||
* Returns { model, error } - either the model or an error message.
|
||||
*/
|
||||
export function findModel(
|
||||
provider: string,
|
||||
modelId: string,
|
||||
agentDir: string = getAgentDir(),
|
||||
): { model: Model<Api> | null; error: string | null } {
|
||||
const { models: allModels, error } = loadAndMergeModels(agentDir);
|
||||
|
||||
if (error) {
|
||||
return { model: null, error };
|
||||
}
|
||||
|
||||
const model = allModels.find((m) => m.provider === provider && m.id === modelId) || null;
|
||||
return { model, error: null };
|
||||
}
|
||||
|
||||
/**
|
||||
* Mapping from model provider to OAuth provider ID.
|
||||
* Only providers that support OAuth are listed here.
|
||||
*/
|
||||
const providerToOAuthProvider: Record<string, OAuthProvider> = {
|
||||
anthropic: "anthropic",
|
||||
"github-copilot": "github-copilot",
|
||||
"google-gemini-cli": "google-gemini-cli",
|
||||
"google-antigravity": "google-antigravity",
|
||||
};
|
||||
|
||||
// Cache for OAuth status per provider (avoids file reads on every render)
|
||||
const oauthStatusCache: Map<string, boolean> = new Map();
|
||||
|
||||
/**
|
||||
* Invalidate the OAuth status cache.
|
||||
* Call this after login/logout operations.
|
||||
*/
|
||||
export function invalidateOAuthCache(): void {
|
||||
oauthStatusCache.clear();
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a model is using OAuth credentials (subscription).
|
||||
* This checks if OAuth credentials exist and would be used for the model,
|
||||
* without actually fetching or refreshing the token.
|
||||
* Results are cached until invalidateOAuthCache() is called.
|
||||
*/
|
||||
export function isModelUsingOAuth(model: Model<Api>): boolean {
|
||||
const oauthProvider = providerToOAuthProvider[model.provider];
|
||||
if (!oauthProvider) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check cache first
|
||||
if (oauthStatusCache.has(oauthProvider)) {
|
||||
return oauthStatusCache.get(oauthProvider)!;
|
||||
}
|
||||
|
||||
// Check if OAuth credentials exist for this provider
|
||||
let usingOAuth = false;
|
||||
const credentials = loadOAuthCredentials(oauthProvider);
|
||||
if (credentials) {
|
||||
usingOAuth = true;
|
||||
}
|
||||
|
||||
// Also check for manual OAuth token env var (for Anthropic)
|
||||
if (!usingOAuth && model.provider === "anthropic" && process.env.ANTHROPIC_OAUTH_TOKEN) {
|
||||
usingOAuth = true;
|
||||
}
|
||||
|
||||
oauthStatusCache.set(oauthProvider, usingOAuth);
|
||||
return usingOAuth;
|
||||
}
|
||||
|
|
@ -30,22 +30,17 @@
|
|||
*/
|
||||
|
||||
import { Agent, ProviderTransport, type ThinkingLevel } from "@mariozechner/pi-agent-core";
|
||||
import { type Model, setOAuthStorage } from "@mariozechner/pi-ai";
|
||||
import { chmodSync, existsSync, mkdirSync, readFileSync, writeFileSync } from "fs";
|
||||
import { dirname, join } from "path";
|
||||
import type { Model } from "@mariozechner/pi-ai";
|
||||
import { join } from "path";
|
||||
import { getAgentDir } from "../config.js";
|
||||
import { AgentSession } from "./agent-session.js";
|
||||
import { AuthStorage } from "./auth-storage.js";
|
||||
import { discoverAndLoadCustomTools, type LoadedCustomTool } from "./custom-tools/index.js";
|
||||
import type { CustomAgentTool } from "./custom-tools/types.js";
|
||||
import { discoverAndLoadHooks, HookRunner, type LoadedHook, wrapToolsWithHooks } from "./hooks/index.js";
|
||||
import type { HookFactory } from "./hooks/types.js";
|
||||
import { messageTransformer } from "./messages.js";
|
||||
import {
|
||||
findModel as findModelInternal,
|
||||
getApiKeyForModel,
|
||||
getAvailableModels,
|
||||
loadAndMergeModels,
|
||||
} from "./models-json.js";
|
||||
import { ModelRegistry } from "./model-registry.js";
|
||||
import { SessionManager } from "./session-manager.js";
|
||||
import { type Settings, SettingsManager, type SkillsSettings } from "./settings-manager.js";
|
||||
import { loadSkills as loadSkillsInternal, type Skill } from "./skills.js";
|
||||
|
|
@ -86,6 +81,11 @@ export interface CreateAgentSessionOptions {
|
|||
/** Global config directory. Default: ~/.pi/agent */
|
||||
agentDir?: string;
|
||||
|
||||
/** Auth storage for credentials. Default: discoverAuthStorage(agentDir) */
|
||||
authStorage?: AuthStorage;
|
||||
/** Model registry. Default: discoverModels(authStorage, agentDir) */
|
||||
modelRegistry?: ModelRegistry;
|
||||
|
||||
/** Model to use. Default: from settings, else first available */
|
||||
model?: Model<any>;
|
||||
/** Thinking level. Default: from settings, else 'off' (clamped to model capabilities) */
|
||||
|
|
@ -93,9 +93,6 @@ export interface CreateAgentSessionOptions {
|
|||
/** Models available for cycling (Ctrl+P in interactive mode) */
|
||||
scopedModels?: Array<{ model: Model<any>; thinkingLevel: ThinkingLevel }>;
|
||||
|
||||
/** API key resolver. Default: defaultGetApiKey() */
|
||||
getApiKey?: (model: Model<any>) => Promise<string | undefined>;
|
||||
|
||||
/** System prompt. String replaces default, function receives default and returns final. */
|
||||
systemPrompt?: string | ((defaultPrompt: string) => string);
|
||||
|
||||
|
|
@ -177,73 +174,20 @@ function getDefaultAgentDir(): string {
|
|||
return getAgentDir();
|
||||
}
|
||||
|
||||
/**
|
||||
* Configure OAuth storage to use the specified agent directory.
|
||||
* Must be called before using OAuth-based authentication.
|
||||
*/
|
||||
export function configureOAuthStorage(agentDir: string = getDefaultAgentDir()): void {
|
||||
const oauthPath = join(agentDir, "oauth.json");
|
||||
|
||||
setOAuthStorage({
|
||||
load: () => {
|
||||
if (!existsSync(oauthPath)) {
|
||||
return {};
|
||||
}
|
||||
try {
|
||||
return JSON.parse(readFileSync(oauthPath, "utf-8"));
|
||||
} catch {
|
||||
return {};
|
||||
}
|
||||
},
|
||||
save: (storage) => {
|
||||
const dir = dirname(oauthPath);
|
||||
if (!existsSync(dir)) {
|
||||
mkdirSync(dir, { recursive: true, mode: 0o700 });
|
||||
}
|
||||
writeFileSync(oauthPath, JSON.stringify(storage, null, 2), "utf-8");
|
||||
chmodSync(oauthPath, 0o600);
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
// Discovery Functions
|
||||
|
||||
/**
|
||||
* Get all models (built-in + custom from models.json).
|
||||
* Create an AuthStorage instance for the given agent directory.
|
||||
*/
|
||||
export function discoverModels(agentDir: string = getDefaultAgentDir()): Model<any>[] {
|
||||
const { models, error } = loadAndMergeModels(agentDir);
|
||||
if (error) {
|
||||
throw new Error(error);
|
||||
}
|
||||
return models;
|
||||
export function discoverAuthStorage(agentDir: string = getDefaultAgentDir()): AuthStorage {
|
||||
return new AuthStorage(join(agentDir, "auth.json"));
|
||||
}
|
||||
|
||||
/**
|
||||
* Get models that have valid API keys available.
|
||||
* Create a ModelRegistry for the given agent directory.
|
||||
*/
|
||||
export async function discoverAvailableModels(agentDir: string = getDefaultAgentDir()): Promise<Model<any>[]> {
|
||||
const { models, error } = await getAvailableModels(agentDir);
|
||||
if (error) {
|
||||
throw new Error(error);
|
||||
}
|
||||
return models;
|
||||
}
|
||||
|
||||
/**
|
||||
* Find a model by provider and ID.
|
||||
* @returns The model, or null if not found
|
||||
*/
|
||||
export function findModel(
|
||||
provider: string,
|
||||
modelId: string,
|
||||
agentDir: string = getDefaultAgentDir(),
|
||||
): Model<any> | null {
|
||||
const { model, error } = findModelInternal(provider, modelId, agentDir);
|
||||
if (error) {
|
||||
throw new Error(error);
|
||||
}
|
||||
return model;
|
||||
export function discoverModels(authStorage: AuthStorage, agentDir: string = getDefaultAgentDir()): ModelRegistry {
|
||||
return new ModelRegistry(authStorage, join(agentDir, "models.json"));
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -326,30 +270,6 @@ export function discoverSlashCommands(cwd?: string, agentDir?: string): FileSlas
|
|||
|
||||
// API Key Helpers
|
||||
|
||||
/**
|
||||
* Create the default API key resolver.
|
||||
* Priority: OAuth > custom providers (models.json) > environment variables > settings.json apiKeys.
|
||||
*
|
||||
* OAuth takes priority so users logged in with a plan (e.g. unlimited tokens) aren't
|
||||
* accidentally billed via a PAYG API key sitting in settings.json.
|
||||
*/
|
||||
export function defaultGetApiKey(
|
||||
settingsManager?: SettingsManager,
|
||||
): (model: Model<any>) => Promise<string | undefined> {
|
||||
return async (model: Model<any>) => {
|
||||
// Check OAuth, custom providers, env vars first
|
||||
const resolvedKey = await getApiKeyForModel(model);
|
||||
if (resolvedKey) {
|
||||
return resolvedKey;
|
||||
}
|
||||
// Fall back to settings.json apiKeys
|
||||
if (settingsManager) {
|
||||
return settingsManager.getApiKey(model.provider);
|
||||
}
|
||||
return undefined;
|
||||
};
|
||||
}
|
||||
|
||||
// System Prompt
|
||||
|
||||
export interface BuildSystemPromptOptions {
|
||||
|
|
@ -457,8 +377,9 @@ function createLoadedHooksFromDefinitions(definitions: Array<{ path?: string; fa
|
|||
* const { session } = await createAgentSession();
|
||||
*
|
||||
* // With explicit model
|
||||
* import { getModel } from '@mariozechner/pi-ai';
|
||||
* const { session } = await createAgentSession({
|
||||
* model: findModel('anthropic', 'claude-sonnet-4-20250514'),
|
||||
* model: getModel('anthropic', 'claude-opus-4-5'),
|
||||
* thinkingLevel: 'high',
|
||||
* });
|
||||
*
|
||||
|
|
@ -483,22 +404,16 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
|
|||
const cwd = options.cwd ?? process.cwd();
|
||||
const agentDir = options.agentDir ?? getDefaultAgentDir();
|
||||
|
||||
// Configure OAuth storage for this agentDir
|
||||
configureOAuthStorage(agentDir);
|
||||
time("configureOAuthStorage");
|
||||
// Use provided or create AuthStorage and ModelRegistry
|
||||
const authStorage = options.authStorage ?? discoverAuthStorage(agentDir);
|
||||
const modelRegistry = options.modelRegistry ?? discoverModels(authStorage, agentDir);
|
||||
time("discoverModels");
|
||||
|
||||
const settingsManager = options.settingsManager ?? SettingsManager.create(cwd, agentDir);
|
||||
time("settingsManager");
|
||||
const sessionManager = options.sessionManager ?? SessionManager.create(cwd, agentDir);
|
||||
time("sessionManager");
|
||||
|
||||
// Helper to check API key availability (settings first, then OAuth/env vars)
|
||||
const hasApiKey = async (m: Model<any>): Promise<boolean> => {
|
||||
const settingsKey = settingsManager.getApiKey(m.provider);
|
||||
if (settingsKey) return true;
|
||||
return !!(await getApiKeyForModel(m));
|
||||
};
|
||||
|
||||
// Check if session has existing data to restore
|
||||
const existingSession = sessionManager.buildSessionContext();
|
||||
time("loadSession");
|
||||
|
|
@ -509,8 +424,8 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
|
|||
|
||||
// If session has data, try to restore model from it
|
||||
if (!model && hasExistingSession && existingSession.model) {
|
||||
const restoredModel = findModel(existingSession.model.provider, existingSession.model.modelId);
|
||||
if (restoredModel && (await hasApiKey(restoredModel))) {
|
||||
const restoredModel = modelRegistry.find(existingSession.model.provider, existingSession.model.modelId);
|
||||
if (restoredModel && (await modelRegistry.getApiKey(restoredModel))) {
|
||||
model = restoredModel;
|
||||
}
|
||||
if (!model) {
|
||||
|
|
@ -523,8 +438,8 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
|
|||
const defaultProvider = settingsManager.getDefaultProvider();
|
||||
const defaultModelId = settingsManager.getDefaultModel();
|
||||
if (defaultProvider && defaultModelId) {
|
||||
const settingsModel = findModel(defaultProvider, defaultModelId);
|
||||
if (settingsModel && (await hasApiKey(settingsModel))) {
|
||||
const settingsModel = modelRegistry.find(defaultProvider, defaultModelId);
|
||||
if (settingsModel && (await modelRegistry.getApiKey(settingsModel))) {
|
||||
model = settingsModel;
|
||||
}
|
||||
}
|
||||
|
|
@ -532,14 +447,13 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
|
|||
|
||||
// Fall back to first available model with a valid API key
|
||||
if (!model) {
|
||||
const allModels = discoverModels(agentDir);
|
||||
for (const m of allModels) {
|
||||
if (await hasApiKey(m)) {
|
||||
for (const m of modelRegistry.getAll()) {
|
||||
if (await modelRegistry.getApiKey(m)) {
|
||||
model = m;
|
||||
break;
|
||||
}
|
||||
}
|
||||
time("discoverAvailableModels");
|
||||
time("findAvailableModel");
|
||||
if (model) {
|
||||
if (modelFallbackMessage) {
|
||||
modelFallbackMessage += `. Using ${model.provider}/${model.id}`;
|
||||
|
|
@ -567,8 +481,6 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
|
|||
thinkingLevel = "off";
|
||||
}
|
||||
|
||||
const getApiKey = options.getApiKey ?? defaultGetApiKey(settingsManager);
|
||||
|
||||
const skills = options.skills ?? discoverSkills(cwd, agentDir, settingsManager.getSkillsSettings());
|
||||
time("discoverSkills");
|
||||
|
||||
|
|
@ -661,7 +573,7 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
|
|||
if (!currentModel) {
|
||||
throw new Error("No model selected");
|
||||
}
|
||||
const key = await getApiKey(currentModel);
|
||||
const key = await modelRegistry.getApiKey(currentModel);
|
||||
if (!key) {
|
||||
throw new Error(`No API key found for provider "${currentModel.provider}"`);
|
||||
}
|
||||
|
|
@ -685,7 +597,7 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
|
|||
hookRunner,
|
||||
customTools: customToolsResult.tools,
|
||||
skillsSettings: settingsManager.getSkillsSettings(),
|
||||
resolveApiKey: getApiKey,
|
||||
modelRegistry,
|
||||
});
|
||||
time("createAgentSession");
|
||||
|
||||
|
|
|
|||
|
|
@ -47,7 +47,6 @@ export interface Settings {
|
|||
customTools?: string[]; // Array of custom tool file paths
|
||||
skills?: SkillsSettings;
|
||||
terminal?: TerminalSettings;
|
||||
apiKeys?: Record<string, string>; // provider -> API key (e.g., { "anthropic": "sk-..." })
|
||||
}
|
||||
|
||||
/** Deep merge settings: project/overrides take precedence, nested objects merge recursively */
|
||||
|
|
@ -366,27 +365,4 @@ export class SettingsManager {
|
|||
this.globalSettings.terminal.showImages = show;
|
||||
this.save();
|
||||
}
|
||||
|
||||
getApiKey(provider: string): string | undefined {
|
||||
return this.settings.apiKeys?.[provider];
|
||||
}
|
||||
|
||||
setApiKey(provider: string, key: string): void {
|
||||
if (!this.globalSettings.apiKeys) {
|
||||
this.globalSettings.apiKeys = {};
|
||||
}
|
||||
this.globalSettings.apiKeys[provider] = key;
|
||||
this.save();
|
||||
}
|
||||
|
||||
removeApiKey(provider: string): void {
|
||||
if (this.globalSettings.apiKeys) {
|
||||
delete this.globalSettings.apiKeys[provider];
|
||||
this.save();
|
||||
}
|
||||
}
|
||||
|
||||
getApiKeys(): Record<string, string> {
|
||||
return this.settings.apiKeys ?? {};
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@ export {
|
|||
type PromptOptions,
|
||||
type SessionStats,
|
||||
} from "./core/agent-session.js";
|
||||
// Auth and model registry
|
||||
export { type ApiKeyCredential, type AuthCredential, AuthStorage, type OAuthCredential } from "./core/auth-storage.js";
|
||||
// Compaction
|
||||
export {
|
||||
type CutPointResult,
|
||||
|
|
@ -72,24 +74,13 @@ export {
|
|||
isWriteToolResult,
|
||||
} from "./core/hooks/index.js";
|
||||
export { messageTransformer } from "./core/messages.js";
|
||||
// Model configuration and OAuth
|
||||
export { findModel, getApiKeyForModel, getAvailableModels } from "./core/models-json.js";
|
||||
export {
|
||||
getOAuthProviders,
|
||||
login,
|
||||
logout,
|
||||
type OAuthAuthInfo,
|
||||
type OAuthPrompt,
|
||||
type OAuthProvider,
|
||||
} from "./core/oauth/index.js";
|
||||
export { ModelRegistry } from "./core/model-registry.js";
|
||||
// SDK for programmatic usage
|
||||
export {
|
||||
type BuildSystemPromptOptions,
|
||||
buildSystemPrompt,
|
||||
type CreateAgentSessionOptions,
|
||||
type CreateAgentSessionResult,
|
||||
// Configuration
|
||||
configureOAuthStorage,
|
||||
// Factory
|
||||
createAgentSession,
|
||||
createBashTool,
|
||||
|
|
@ -102,18 +93,15 @@ export {
|
|||
createReadOnlyTools,
|
||||
createReadTool,
|
||||
createWriteTool,
|
||||
// Helpers
|
||||
defaultGetApiKey,
|
||||
discoverAvailableModels,
|
||||
// Discovery
|
||||
discoverAuthStorage,
|
||||
discoverContextFiles,
|
||||
discoverCustomTools,
|
||||
discoverHooks,
|
||||
// Discovery
|
||||
discoverModels,
|
||||
discoverSkills,
|
||||
discoverSlashCommands,
|
||||
type FileSlashCommand,
|
||||
findModel as findModelByProviderAndId,
|
||||
loadSettings,
|
||||
// Pre-built tools (use process.cwd())
|
||||
readOnlyTools,
|
||||
|
|
|
|||
|
|
@ -8,19 +8,20 @@
|
|||
import type { Attachment } from "@mariozechner/pi-agent-core";
|
||||
import { supportsXhigh } from "@mariozechner/pi-ai";
|
||||
import chalk from "chalk";
|
||||
|
||||
import { join } from "path";
|
||||
import { type Args, parseArgs, printHelp } from "./cli/args.js";
|
||||
import { processFileArguments } from "./cli/file-processor.js";
|
||||
import { listModels } from "./cli/list-models.js";
|
||||
import { selectSession } from "./cli/session-picker.js";
|
||||
import { getModelsPath, VERSION } from "./config.js";
|
||||
import { getAgentDir, getModelsPath, VERSION } from "./config.js";
|
||||
import type { AgentSession } from "./core/agent-session.js";
|
||||
import { AuthStorage } from "./core/auth-storage.js";
|
||||
import type { LoadedCustomTool } from "./core/custom-tools/index.js";
|
||||
import { exportFromFile } from "./core/export-html.js";
|
||||
import type { HookUIContext } from "./core/index.js";
|
||||
import type { ModelRegistry } from "./core/model-registry.js";
|
||||
import { resolveModelScope, type ScopedModel } from "./core/model-resolver.js";
|
||||
import { findModel } from "./core/models-json.js";
|
||||
import { type CreateAgentSessionOptions, configureOAuthStorage, createAgentSession } from "./core/sdk.js";
|
||||
import { type CreateAgentSessionOptions, createAgentSession, discoverAuthStorage, discoverModels } from "./core/sdk.js";
|
||||
import { SessionManager } from "./core/session-manager.js";
|
||||
import { SettingsManager } from "./core/settings-manager.js";
|
||||
import { resolvePromptInput } from "./core/system-prompt.js";
|
||||
|
|
@ -33,7 +34,7 @@ import { ensureTool } from "./utils/tools-manager.js";
|
|||
|
||||
async function checkForNewVersion(currentVersion: string): Promise<string | null> {
|
||||
try {
|
||||
const response = await fetch("https://registry.npmjs.org/@mariozechner/pi-coding-agent/latest");
|
||||
const response = await fetch("https://registry.npmjs.org/@mariozechner/pi -coding-agent/latest");
|
||||
if (!response.ok) return null;
|
||||
|
||||
const data = (await response.json()) as { version?: string };
|
||||
|
|
@ -54,6 +55,8 @@ async function runInteractiveMode(
|
|||
version: string,
|
||||
changelogMarkdown: string | null,
|
||||
modelFallbackMessage: string | undefined,
|
||||
modelsJsonError: string | null,
|
||||
migratedProviders: string[],
|
||||
versionCheckPromise: Promise<string | null>,
|
||||
initialMessages: string[],
|
||||
customTools: LoadedCustomTool[],
|
||||
|
|
@ -74,6 +77,14 @@ async function runInteractiveMode(
|
|||
|
||||
mode.renderInitialMessages(session.state);
|
||||
|
||||
if (migratedProviders.length > 0) {
|
||||
mode.showWarning(`Migrated credentials to auth.json: ${migratedProviders.join(", ")}`);
|
||||
}
|
||||
|
||||
if (modelsJsonError) {
|
||||
mode.showError(`models.json error: ${modelsJsonError}`);
|
||||
}
|
||||
|
||||
if (modelFallbackMessage) {
|
||||
mode.showWarning(modelFallbackMessage);
|
||||
}
|
||||
|
|
@ -175,6 +186,7 @@ function buildSessionOptions(
|
|||
parsed: Args,
|
||||
scopedModels: ScopedModel[],
|
||||
sessionManager: SessionManager | null,
|
||||
modelRegistry: ModelRegistry,
|
||||
): CreateAgentSessionOptions {
|
||||
const options: CreateAgentSessionOptions = {};
|
||||
|
||||
|
|
@ -187,11 +199,7 @@ function buildSessionOptions(
|
|||
|
||||
// Model from CLI
|
||||
if (parsed.provider && parsed.model) {
|
||||
const { model, error } = findModel(parsed.provider, parsed.model);
|
||||
if (error) {
|
||||
console.error(chalk.red(error));
|
||||
process.exit(1);
|
||||
}
|
||||
const model = modelRegistry.find(parsed.provider, parsed.model);
|
||||
if (!model) {
|
||||
console.error(chalk.red(`Model ${parsed.provider}/${parsed.model} not found`));
|
||||
process.exit(1);
|
||||
|
|
@ -213,10 +221,8 @@ function buildSessionOptions(
|
|||
options.scopedModels = scopedModels;
|
||||
}
|
||||
|
||||
// API key from CLI
|
||||
if (parsed.apiKey) {
|
||||
options.getApiKey = async () => parsed.apiKey!;
|
||||
}
|
||||
// API key from CLI - set in authStorage
|
||||
// (handled by caller before createAgentSession)
|
||||
|
||||
// System prompt
|
||||
if (resolvedSystemPrompt && resolvedAppendPrompt) {
|
||||
|
|
@ -252,8 +258,15 @@ function buildSessionOptions(
|
|||
|
||||
export async function main(args: string[]) {
|
||||
time("start");
|
||||
configureOAuthStorage();
|
||||
time("configureOAuthStorage");
|
||||
|
||||
// Migrate legacy oauth.json and settings.json apiKeys to auth.json
|
||||
const agentDir = getAgentDir();
|
||||
const migratedProviders = AuthStorage.migrateLegacy(join(agentDir, "auth.json"), agentDir);
|
||||
|
||||
// Create AuthStorage and ModelRegistry upfront
|
||||
const authStorage = discoverAuthStorage();
|
||||
const modelRegistry = discoverModels(authStorage);
|
||||
time("discoverModels");
|
||||
|
||||
const parsed = parseArgs(args);
|
||||
time("parseArgs");
|
||||
|
|
@ -270,8 +283,7 @@ export async function main(args: string[]) {
|
|||
|
||||
if (parsed.listModels !== undefined) {
|
||||
const searchPattern = typeof parsed.listModels === "string" ? parsed.listModels : undefined;
|
||||
const settingsManager = SettingsManager.create(process.cwd());
|
||||
await listModels(searchPattern, settingsManager);
|
||||
await listModels(modelRegistry, searchPattern);
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -306,7 +318,7 @@ export async function main(args: string[]) {
|
|||
|
||||
let scopedModels: ScopedModel[] = [];
|
||||
if (parsed.models && parsed.models.length > 0) {
|
||||
scopedModels = await resolveModelScope(parsed.models, settingsManager);
|
||||
scopedModels = await resolveModelScope(parsed.models, modelRegistry);
|
||||
time("resolveModelScope");
|
||||
}
|
||||
|
||||
|
|
@ -331,7 +343,19 @@ export async function main(args: string[]) {
|
|||
sessionManager = SessionManager.open(selectedPath);
|
||||
}
|
||||
|
||||
const sessionOptions = buildSessionOptions(parsed, scopedModels, sessionManager);
|
||||
const sessionOptions = buildSessionOptions(parsed, scopedModels, sessionManager, modelRegistry);
|
||||
sessionOptions.authStorage = authStorage;
|
||||
sessionOptions.modelRegistry = modelRegistry;
|
||||
|
||||
// Handle CLI --api-key as runtime override (not persisted)
|
||||
if (parsed.apiKey) {
|
||||
if (!sessionOptions.model) {
|
||||
console.error(chalk.red("--api-key requires a model to be specified via --provider/--model or -m/--models"));
|
||||
process.exit(1);
|
||||
}
|
||||
authStorage.setRuntimeApiKey(sessionOptions.model.provider, parsed.apiKey);
|
||||
}
|
||||
|
||||
time("buildSessionOptions");
|
||||
const { session, customToolsResult, modelFallbackMessage } = await createAgentSession(sessionOptions);
|
||||
time("createAgentSession");
|
||||
|
|
@ -382,6 +406,8 @@ export async function main(args: string[]) {
|
|||
VERSION,
|
||||
changelogMarkdown,
|
||||
modelFallbackMessage,
|
||||
modelRegistry.getError(),
|
||||
migratedProviders,
|
||||
versionCheckPromise,
|
||||
parsed.messages,
|
||||
customToolsResult.tools,
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import type { AssistantMessage } from "@mariozechner/pi-ai";
|
|||
import { type Component, visibleWidth } from "@mariozechner/pi-tui";
|
||||
import { existsSync, type FSWatcher, readFileSync, watch } from "fs";
|
||||
import { dirname, join } from "path";
|
||||
import { isModelUsingOAuth } from "../../../core/models-json.js";
|
||||
import type { ModelRegistry } from "../../../core/model-registry.js";
|
||||
import { theme } from "../theme/theme.js";
|
||||
|
||||
/**
|
||||
|
|
@ -31,13 +31,15 @@ function findGitHeadPath(): string | null {
|
|||
*/
|
||||
export class FooterComponent implements Component {
|
||||
private state: AgentState;
|
||||
private modelRegistry: ModelRegistry;
|
||||
private cachedBranch: string | null | undefined = undefined; // undefined = not checked yet, null = not in git repo, string = branch name
|
||||
private gitWatcher: FSWatcher | null = null;
|
||||
private onBranchChange: (() => void) | null = null;
|
||||
private autoCompactEnabled: boolean = true;
|
||||
|
||||
constructor(state: AgentState) {
|
||||
constructor(state: AgentState, modelRegistry: ModelRegistry) {
|
||||
this.state = state;
|
||||
this.modelRegistry = modelRegistry;
|
||||
}
|
||||
|
||||
setAutoCompactEnabled(enabled: boolean): void {
|
||||
|
|
@ -207,7 +209,7 @@ export class FooterComponent implements Component {
|
|||
if (totalCacheWrite) statsParts.push(`W${formatTokens(totalCacheWrite)}`);
|
||||
|
||||
// Show cost with "(sub)" indicator if using OAuth subscription
|
||||
const usingSubscription = this.state.model ? isModelUsingOAuth(this.state.model) : false;
|
||||
const usingSubscription = this.state.model ? this.modelRegistry.isUsingOAuth(this.state.model) : false;
|
||||
if (totalCost || usingSubscription) {
|
||||
const costStr = `$${totalCost.toFixed(3)}${usingSubscription ? " (sub)" : ""}`;
|
||||
statsParts.push(costStr);
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ import {
|
|||
Text,
|
||||
type TUI,
|
||||
} from "@mariozechner/pi-tui";
|
||||
import { getAvailableModels } from "../../../core/models-json.js";
|
||||
import type { ModelRegistry } from "../../../core/model-registry.js";
|
||||
import type { SettingsManager } from "../../../core/settings-manager.js";
|
||||
import { fuzzyFilter } from "../../../utils/fuzzy.js";
|
||||
import { theme } from "../theme/theme.js";
|
||||
|
|
@ -38,6 +38,7 @@ export class ModelSelectorComponent extends Container {
|
|||
private selectedIndex: number = 0;
|
||||
private currentModel: Model<any> | null;
|
||||
private settingsManager: SettingsManager;
|
||||
private modelRegistry: ModelRegistry;
|
||||
private onSelectCallback: (model: Model<any>) => void;
|
||||
private onCancelCallback: () => void;
|
||||
private errorMessage: string | null = null;
|
||||
|
|
@ -48,6 +49,7 @@ export class ModelSelectorComponent extends Container {
|
|||
tui: TUI,
|
||||
currentModel: Model<any> | null,
|
||||
settingsManager: SettingsManager,
|
||||
modelRegistry: ModelRegistry,
|
||||
scopedModels: ReadonlyArray<ScopedModelItem>,
|
||||
onSelect: (model: Model<any>) => void,
|
||||
onCancel: () => void,
|
||||
|
|
@ -57,6 +59,7 @@ export class ModelSelectorComponent extends Container {
|
|||
this.tui = tui;
|
||||
this.currentModel = currentModel;
|
||||
this.settingsManager = settingsManager;
|
||||
this.modelRegistry = modelRegistry;
|
||||
this.scopedModels = scopedModels;
|
||||
this.onSelectCallback = onSelect;
|
||||
this.onCancelCallback = onCancel;
|
||||
|
|
@ -113,26 +116,29 @@ export class ModelSelectorComponent extends Container {
|
|||
model: scoped.model,
|
||||
}));
|
||||
} else {
|
||||
// Load available models fresh (includes custom models from models.json)
|
||||
// Pass settings manager's key resolver as fallback for settings.json apiKeys
|
||||
const { models: availableModels, error } = await getAvailableModels(undefined, (provider) =>
|
||||
this.settingsManager.getApiKey(provider),
|
||||
);
|
||||
// Refresh to pick up any changes to models.json
|
||||
this.modelRegistry.refresh();
|
||||
|
||||
// If there's an error loading models.json, we'll show it via the "no models" path
|
||||
// The error will be displayed to the user
|
||||
if (error) {
|
||||
this.allModels = [];
|
||||
this.filteredModels = [];
|
||||
this.errorMessage = error;
|
||||
return;
|
||||
// Check for models.json errors
|
||||
const loadError = this.modelRegistry.getError();
|
||||
if (loadError) {
|
||||
this.errorMessage = loadError;
|
||||
}
|
||||
|
||||
models = availableModels.map((model) => ({
|
||||
provider: model.provider,
|
||||
id: model.id,
|
||||
model,
|
||||
}));
|
||||
// Load available models (built-in models still work even if models.json failed)
|
||||
try {
|
||||
const availableModels = await this.modelRegistry.getAvailable();
|
||||
models = availableModels.map((model: Model<any>) => ({
|
||||
provider: model.provider,
|
||||
id: model.id,
|
||||
model,
|
||||
}));
|
||||
} catch (error) {
|
||||
this.allModels = [];
|
||||
this.filteredModels = [];
|
||||
this.errorMessage = error instanceof Error ? error.message : String(error);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Sort: current model first, then by provider
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import { loadOAuthCredentials } from "@mariozechner/pi-ai";
|
||||
import { getOAuthProviders, type OAuthProviderInfo } from "@mariozechner/pi-ai";
|
||||
import { Container, isArrowDown, isArrowUp, isEnter, isEscape, Spacer, TruncatedText } from "@mariozechner/pi-tui";
|
||||
import { getOAuthProviders, type OAuthProviderInfo } from "../../../core/oauth/index.js";
|
||||
import type { AuthStorage } from "../../../core/auth-storage.js";
|
||||
import { theme } from "../theme/theme.js";
|
||||
import { DynamicBorder } from "./dynamic-border.js";
|
||||
|
||||
|
|
@ -12,13 +12,20 @@ export class OAuthSelectorComponent extends Container {
|
|||
private allProviders: OAuthProviderInfo[] = [];
|
||||
private selectedIndex: number = 0;
|
||||
private mode: "login" | "logout";
|
||||
private authStorage: AuthStorage;
|
||||
private onSelectCallback: (providerId: string) => void;
|
||||
private onCancelCallback: () => void;
|
||||
|
||||
constructor(mode: "login" | "logout", onSelect: (providerId: string) => void, onCancel: () => void) {
|
||||
constructor(
|
||||
mode: "login" | "logout",
|
||||
authStorage: AuthStorage,
|
||||
onSelect: (providerId: string) => void,
|
||||
onCancel: () => void,
|
||||
) {
|
||||
super();
|
||||
|
||||
this.mode = mode;
|
||||
this.authStorage = authStorage;
|
||||
this.onSelectCallback = onSelect;
|
||||
this.onCancelCallback = onCancel;
|
||||
|
||||
|
|
@ -49,7 +56,6 @@ export class OAuthSelectorComponent extends Container {
|
|||
|
||||
private loadProviders(): void {
|
||||
this.allProviders = getOAuthProviders();
|
||||
this.allProviders = this.allProviders.filter((p) => p.available);
|
||||
}
|
||||
|
||||
private updateList(): void {
|
||||
|
|
@ -63,8 +69,8 @@ export class OAuthSelectorComponent extends Container {
|
|||
const isAvailable = provider.available;
|
||||
|
||||
// Check if user is logged in for this provider
|
||||
const credentials = loadOAuthCredentials(provider.id);
|
||||
const isLoggedIn = credentials !== null;
|
||||
const credentials = this.authStorage.get(provider.id);
|
||||
const isLoggedIn = credentials?.type === "oauth";
|
||||
const statusIndicator = isLoggedIn ? theme.fg("success", " ✓ logged in") : "";
|
||||
|
||||
let line = "";
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import * as fs from "node:fs";
|
|||
import * as os from "node:os";
|
||||
import * as path from "node:path";
|
||||
import type { AgentState, AppMessage, Attachment } from "@mariozechner/pi-agent-core";
|
||||
import type { AssistantMessage, Message } from "@mariozechner/pi-ai";
|
||||
import type { AssistantMessage, Message, OAuthProvider } from "@mariozechner/pi-ai";
|
||||
import type { SlashCommand } from "@mariozechner/pi-tui";
|
||||
import {
|
||||
CombinedAutocompleteProvider,
|
||||
|
|
@ -25,13 +25,11 @@ import {
|
|||
visibleWidth,
|
||||
} from "@mariozechner/pi-tui";
|
||||
import { exec, spawnSync } from "child_process";
|
||||
import { APP_NAME, getDebugLogPath, getOAuthPath } from "../../config.js";
|
||||
import { APP_NAME, getAuthPath, getDebugLogPath } from "../../config.js";
|
||||
import type { AgentSession, AgentSessionEvent } from "../../core/agent-session.js";
|
||||
import type { LoadedCustomTool, SessionEvent as ToolSessionEvent } from "../../core/custom-tools/index.js";
|
||||
import type { HookUIContext } from "../../core/hooks/index.js";
|
||||
import { isBashExecutionMessage } from "../../core/messages.js";
|
||||
import { invalidateOAuthCache } from "../../core/models-json.js";
|
||||
import { listOAuthProviders, login, logout, type OAuthProvider } from "../../core/oauth/index.js";
|
||||
import {
|
||||
getLatestCompactionEntry,
|
||||
SessionManager,
|
||||
|
|
@ -154,7 +152,7 @@ export class InteractiveMode {
|
|||
this.editor = new CustomEditor(getEditorTheme());
|
||||
this.editorContainer = new Container();
|
||||
this.editorContainer.addChild(this.editor);
|
||||
this.footer = new FooterComponent(session.state);
|
||||
this.footer = new FooterComponent(session.state, session.modelRegistry);
|
||||
this.footer.setAutoCompactEnabled(session.autoCompactionEnabled);
|
||||
|
||||
// Define slash commands for autocomplete
|
||||
|
|
@ -1484,6 +1482,7 @@ export class InteractiveMode {
|
|||
this.ui,
|
||||
this.session.model,
|
||||
this.settingsManager,
|
||||
this.session.modelRegistry,
|
||||
this.session.scopedModels,
|
||||
async (model) => {
|
||||
try {
|
||||
|
|
@ -1588,7 +1587,10 @@ export class InteractiveMode {
|
|||
|
||||
private async showOAuthSelector(mode: "login" | "logout"): Promise<void> {
|
||||
if (mode === "logout") {
|
||||
const loggedInProviders = listOAuthProviders();
|
||||
const providers = this.session.modelRegistry.authStorage.list();
|
||||
const loggedInProviders = providers.filter(
|
||||
(p) => this.session.modelRegistry.authStorage.get(p)?.type === "oauth",
|
||||
);
|
||||
if (loggedInProviders.length === 0) {
|
||||
this.showStatus("No OAuth providers logged in. Use /login first.");
|
||||
return;
|
||||
|
|
@ -1598,6 +1600,7 @@ export class InteractiveMode {
|
|||
this.showSelector((done) => {
|
||||
const selector = new OAuthSelectorComponent(
|
||||
mode,
|
||||
this.session.modelRegistry.authStorage,
|
||||
async (providerId: string) => {
|
||||
done();
|
||||
|
||||
|
|
@ -1605,9 +1608,8 @@ export class InteractiveMode {
|
|||
this.showStatus(`Logging in to ${providerId}...`);
|
||||
|
||||
try {
|
||||
await login(
|
||||
providerId as OAuthProvider,
|
||||
(info) => {
|
||||
await this.session.modelRegistry.authStorage.login(providerId as OAuthProvider, {
|
||||
onAuth: (info: { url: string; instructions?: string }) => {
|
||||
this.chatContainer.addChild(new Spacer(1));
|
||||
this.chatContainer.addChild(new Text(theme.fg("accent", "Opening browser to:"), 1, 0));
|
||||
this.chatContainer.addChild(new Text(theme.fg("accent", info.url), 1, 0));
|
||||
|
|
@ -1625,7 +1627,7 @@ export class InteractiveMode {
|
|||
: "xdg-open";
|
||||
exec(`${openCmd} "${info.url}"`);
|
||||
},
|
||||
async (prompt) => {
|
||||
onPrompt: async (prompt: { message: string; placeholder?: string }) => {
|
||||
this.chatContainer.addChild(new Spacer(1));
|
||||
this.chatContainer.addChild(new Text(theme.fg("warning", prompt.message), 1, 0));
|
||||
if (prompt.placeholder) {
|
||||
|
|
@ -1648,32 +1650,35 @@ export class InteractiveMode {
|
|||
this.ui.requestRender();
|
||||
});
|
||||
},
|
||||
(message) => {
|
||||
onProgress: (message: string) => {
|
||||
this.chatContainer.addChild(new Text(theme.fg("dim", message), 1, 0));
|
||||
this.ui.requestRender();
|
||||
},
|
||||
);
|
||||
|
||||
invalidateOAuthCache();
|
||||
});
|
||||
// Refresh models to pick up new baseUrl (e.g., github-copilot)
|
||||
this.session.modelRegistry.refresh();
|
||||
this.chatContainer.addChild(new Spacer(1));
|
||||
this.chatContainer.addChild(
|
||||
new Text(theme.fg("success", `✓ Successfully logged in to ${providerId}`), 1, 0),
|
||||
);
|
||||
this.chatContainer.addChild(new Text(theme.fg("dim", `Tokens saved to ${getOAuthPath()}`), 1, 0));
|
||||
this.chatContainer.addChild(
|
||||
new Text(theme.fg("dim", `Credentials saved to ${getAuthPath()}`), 1, 0),
|
||||
);
|
||||
this.ui.requestRender();
|
||||
} catch (error: unknown) {
|
||||
this.showError(`Login failed: ${error instanceof Error ? error.message : String(error)}`);
|
||||
}
|
||||
} else {
|
||||
try {
|
||||
await logout(providerId as OAuthProvider);
|
||||
invalidateOAuthCache();
|
||||
this.session.modelRegistry.authStorage.logout(providerId);
|
||||
// Refresh models to reset baseUrl
|
||||
this.session.modelRegistry.refresh();
|
||||
this.chatContainer.addChild(new Spacer(1));
|
||||
this.chatContainer.addChild(
|
||||
new Text(theme.fg("success", `✓ Successfully logged out of ${providerId}`), 1, 0),
|
||||
);
|
||||
this.chatContainer.addChild(
|
||||
new Text(theme.fg("dim", `Credentials removed from ${getOAuthPath()}`), 1, 0),
|
||||
new Text(theme.fg("dim", `Credentials removed from ${getAuthPath()}`), 1, 0),
|
||||
);
|
||||
this.ui.requestRender();
|
||||
} catch (error: unknown) {
|
||||
|
|
|
|||
|
|
@ -14,6 +14,8 @@ import { Agent, ProviderTransport } from "@mariozechner/pi-agent-core";
|
|||
import { getModel } from "@mariozechner/pi-ai";
|
||||
import { afterEach, beforeEach, describe, expect, it } from "vitest";
|
||||
import { AgentSession } from "../src/core/agent-session.js";
|
||||
import { AuthStorage } from "../src/core/auth-storage.js";
|
||||
import { ModelRegistry } from "../src/core/model-registry.js";
|
||||
import { SessionManager } from "../src/core/session-manager.js";
|
||||
import { SettingsManager } from "../src/core/settings-manager.js";
|
||||
import { codingTools } from "../src/core/tools/index.js";
|
||||
|
|
@ -58,11 +60,14 @@ describe.skipIf(!API_KEY)("AgentSession branching", () => {
|
|||
|
||||
sessionManager = noSession ? SessionManager.inMemory() : SessionManager.create(tempDir);
|
||||
const settingsManager = SettingsManager.create(tempDir, tempDir);
|
||||
const authStorage = new AuthStorage(join(tempDir, "auth.json"));
|
||||
const modelRegistry = new ModelRegistry(authStorage, tempDir);
|
||||
|
||||
session = new AgentSession({
|
||||
agent,
|
||||
sessionManager,
|
||||
settingsManager,
|
||||
modelRegistry,
|
||||
});
|
||||
|
||||
// Must subscribe to enable session persistence
|
||||
|
|
|
|||
|
|
@ -14,6 +14,8 @@ import { Agent, ProviderTransport } from "@mariozechner/pi-agent-core";
|
|||
import { getModel } from "@mariozechner/pi-ai";
|
||||
import { afterEach, beforeEach, describe, expect, it } from "vitest";
|
||||
import { AgentSession, type AgentSessionEvent } from "../src/core/agent-session.js";
|
||||
import { AuthStorage } from "../src/core/auth-storage.js";
|
||||
import { ModelRegistry } from "../src/core/model-registry.js";
|
||||
import { SessionManager } from "../src/core/session-manager.js";
|
||||
import { SettingsManager } from "../src/core/settings-manager.js";
|
||||
import { codingTools } from "../src/core/tools/index.js";
|
||||
|
|
@ -62,11 +64,14 @@ describe.skipIf(!API_KEY)("AgentSession compaction e2e", () => {
|
|||
|
||||
sessionManager = SessionManager.create(tempDir);
|
||||
const settingsManager = SettingsManager.create(tempDir, tempDir);
|
||||
const authStorage = new AuthStorage(join(tempDir, "auth.json"));
|
||||
const modelRegistry = new ModelRegistry(authStorage);
|
||||
|
||||
session = new AgentSession({
|
||||
agent,
|
||||
sessionManager,
|
||||
settingsManager,
|
||||
modelRegistry,
|
||||
});
|
||||
|
||||
// Subscribe to track events
|
||||
|
|
@ -178,11 +183,14 @@ describe.skipIf(!API_KEY)("AgentSession compaction e2e", () => {
|
|||
const noSessionManager = SessionManager.inMemory();
|
||||
|
||||
const settingsManager = SettingsManager.create(tempDir, tempDir);
|
||||
const authStorage = new AuthStorage(join(tempDir, "auth.json"));
|
||||
const modelRegistry = new ModelRegistry(authStorage);
|
||||
|
||||
const noSessionSession = new AgentSession({
|
||||
agent,
|
||||
sessionManager: noSessionManager,
|
||||
settingsManager,
|
||||
modelRegistry,
|
||||
});
|
||||
|
||||
try {
|
||||
|
|
|
|||
|
|
@ -9,7 +9,9 @@ import { Agent, ProviderTransport } from "@mariozechner/pi-agent-core";
|
|||
import { getModel } from "@mariozechner/pi-ai";
|
||||
import { afterEach, beforeEach, describe, expect, it } from "vitest";
|
||||
import { AgentSession } from "../src/core/agent-session.js";
|
||||
import { AuthStorage } from "../src/core/auth-storage.js";
|
||||
import { HookRunner, type LoadedHook, type SessionEvent } from "../src/core/hooks/index.js";
|
||||
import { ModelRegistry } from "../src/core/model-registry.js";
|
||||
import { SessionManager } from "../src/core/session-manager.js";
|
||||
import { SettingsManager } from "../src/core/settings-manager.js";
|
||||
import { codingTools } from "../src/core/tools/index.js";
|
||||
|
|
@ -83,6 +85,8 @@ describe.skipIf(!API_KEY)("Compaction hooks", () => {
|
|||
|
||||
const sessionManager = SessionManager.create(tempDir);
|
||||
const settingsManager = SettingsManager.create(tempDir, tempDir);
|
||||
const authStorage = new AuthStorage(join(tempDir, "auth.json"));
|
||||
const modelRegistry = new ModelRegistry(authStorage);
|
||||
|
||||
hookRunner = new HookRunner(hooks, tempDir);
|
||||
hookRunner.setUIContext(
|
||||
|
|
@ -101,6 +105,7 @@ describe.skipIf(!API_KEY)("Compaction hooks", () => {
|
|||
sessionManager,
|
||||
settingsManager,
|
||||
hookRunner,
|
||||
modelRegistry,
|
||||
});
|
||||
|
||||
return session;
|
||||
|
|
|
|||
|
|
@ -2,13 +2,16 @@ import { Agent, type AgentEvent, type Attachment, ProviderTransport } from "@mar
|
|||
import { getModel } from "@mariozechner/pi-ai";
|
||||
import {
|
||||
AgentSession,
|
||||
AuthStorage,
|
||||
formatSkillsForPrompt,
|
||||
loadSkillsFromDir,
|
||||
ModelRegistry,
|
||||
messageTransformer,
|
||||
type Skill,
|
||||
} from "@mariozechner/pi-coding-agent";
|
||||
import { existsSync, readFileSync, statSync } from "fs";
|
||||
import { mkdir, writeFile } from "fs/promises";
|
||||
import { homedir } from "os";
|
||||
import { join } from "path";
|
||||
import { MomSessionManager, MomSettingsManager } from "./context.js";
|
||||
import * as log from "./log.js";
|
||||
|
|
@ -435,11 +438,17 @@ function createRunner(sandboxConfig: SandboxConfig, channelId: string, channelDi
|
|||
log.logInfo(`[${channelId}] Loaded ${loadedSession.messages.length} messages from context.jsonl`);
|
||||
}
|
||||
|
||||
// Create AuthStorage and ModelRegistry for AgentSession
|
||||
// Auth stored outside workspace so agent can't access it
|
||||
const authStorage = new AuthStorage(join(homedir(), ".pi", "mom", "auth.json"));
|
||||
const modelRegistry = new ModelRegistry(authStorage);
|
||||
|
||||
// Create AgentSession wrapper
|
||||
const session = new AgentSession({
|
||||
agent,
|
||||
sessionManager: sessionManager as any,
|
||||
settingsManager: settingsManager as any,
|
||||
modelRegistry,
|
||||
});
|
||||
|
||||
// Mutable per-run state - event handler references this
|
||||
|
|
|
|||
|
|
@ -28,5 +28,5 @@
|
|||
}
|
||||
},
|
||||
"include": ["packages/*/src/**/*", "packages/*/test/**/*", "packages/coding-agent/examples/**/*"],
|
||||
"exclude": ["packages/web-ui/**/*"]
|
||||
"exclude": ["packages/web-ui/**/*", "**/dist/**"]
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue