mirror of
https://github.com/getcompanion-ai/co-mono.git
synced 2026-04-15 18:01:22 +00:00
Refactor OAuth/API key handling: AuthStorage and ModelRegistry
- Add AuthStorage class for credential storage (auth.json) - Add ModelRegistry class for model management with API key resolution - Add discoverAuthStorage() and discoverModels() discovery functions - Add migration from legacy oauth.json and settings.json apiKeys to auth.json - Remove configureOAuthStorage, defaultGetApiKey, findModel, discoverAvailableModels - Remove apiKeys from Settings type and SettingsManager methods - Rename getOAuthPath to getAuthPath - Update SDK, examples, docs, tests, and mom package Fixes #296
This commit is contained in:
parent
9f97f0c8da
commit
54018b6cc0
29 changed files with 953 additions and 2017 deletions
|
|
@ -23,7 +23,7 @@ import type { LoadedCustomTool, SessionEvent as ToolSessionEvent } from "./custo
|
|||
import { exportSessionToHtml } from "./export-html.js";
|
||||
import type { HookRunner, SessionEventResult, TurnEndEvent, TurnStartEvent } from "./hooks/index.js";
|
||||
import type { BashExecutionMessage } from "./messages.js";
|
||||
import { getApiKeyForModel, getAvailableModels } from "./models-json.js";
|
||||
import type { ModelRegistry } from "./model-registry.js";
|
||||
import type { CompactionEntry, SessionManager } from "./session-manager.js";
|
||||
import type { SettingsManager, SkillsSettings } from "./settings-manager.js";
|
||||
import { expandSlashCommand, type FileSlashCommand } from "./slash-commands.js";
|
||||
|
|
@ -56,8 +56,8 @@ export interface AgentSessionConfig {
|
|||
/** Custom tools for session lifecycle events */
|
||||
customTools?: LoadedCustomTool[];
|
||||
skillsSettings?: Required<SkillsSettings>;
|
||||
/** Resolve API key for a model. Default: getApiKeyForModel */
|
||||
resolveApiKey?: (model: Model<any>) => Promise<string | undefined>;
|
||||
/** Model registry for API key resolution and model discovery */
|
||||
modelRegistry: ModelRegistry;
|
||||
}
|
||||
|
||||
/** Options for AgentSession.prompt() */
|
||||
|
|
@ -153,8 +153,8 @@ export class AgentSession {
|
|||
|
||||
private _skillsSettings: Required<SkillsSettings> | undefined;
|
||||
|
||||
// API key resolver
|
||||
private _resolveApiKey: (model: Model<any>) => Promise<string | undefined>;
|
||||
// Model registry for API key resolution
|
||||
private _modelRegistry: ModelRegistry;
|
||||
|
||||
constructor(config: AgentSessionConfig) {
|
||||
this.agent = config.agent;
|
||||
|
|
@ -165,7 +165,12 @@ export class AgentSession {
|
|||
this._hookRunner = config.hookRunner ?? null;
|
||||
this._customTools = config.customTools ?? [];
|
||||
this._skillsSettings = config.skillsSettings;
|
||||
this._resolveApiKey = config.resolveApiKey ?? getApiKeyForModel;
|
||||
this._modelRegistry = config.modelRegistry;
|
||||
}
|
||||
|
||||
/** Model registry for API key resolution and model discovery */
|
||||
get modelRegistry(): ModelRegistry {
|
||||
return this._modelRegistry;
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
|
|
@ -434,7 +439,7 @@ export class AgentSession {
|
|||
}
|
||||
|
||||
// Validate API key
|
||||
const apiKey = await this._resolveApiKey(this.model);
|
||||
const apiKey = await this._modelRegistry.getApiKey(this.model);
|
||||
if (!apiKey) {
|
||||
throw new Error(
|
||||
`No API key found for ${this.model.provider}.\n\n` +
|
||||
|
|
@ -561,7 +566,7 @@ export class AgentSession {
|
|||
* @throws Error if no API key available for the model
|
||||
*/
|
||||
async setModel(model: Model<any>): Promise<void> {
|
||||
const apiKey = await this._resolveApiKey(model);
|
||||
const apiKey = await this._modelRegistry.getApiKey(model);
|
||||
if (!apiKey) {
|
||||
throw new Error(`No API key for ${model.provider}/${model.id}`);
|
||||
}
|
||||
|
|
@ -599,7 +604,7 @@ export class AgentSession {
|
|||
const next = this._scopedModels[nextIndex];
|
||||
|
||||
// Validate API key
|
||||
const apiKey = await this._resolveApiKey(next.model);
|
||||
const apiKey = await this._modelRegistry.getApiKey(next.model);
|
||||
if (!apiKey) {
|
||||
throw new Error(`No API key for ${next.model.provider}/${next.model.id}`);
|
||||
}
|
||||
|
|
@ -616,10 +621,7 @@ export class AgentSession {
|
|||
}
|
||||
|
||||
private async _cycleAvailableModel(): Promise<ModelCycleResult | null> {
|
||||
const { models: availableModels, error } = await getAvailableModels(undefined, (provider) =>
|
||||
this.settingsManager.getApiKey(provider),
|
||||
);
|
||||
if (error) throw new Error(`Failed to load models: ${error}`);
|
||||
const availableModels = await this._modelRegistry.getAvailable();
|
||||
if (availableModels.length <= 1) return null;
|
||||
|
||||
const currentModel = this.model;
|
||||
|
|
@ -631,7 +633,7 @@ export class AgentSession {
|
|||
const nextIndex = (currentIndex + 1) % availableModels.length;
|
||||
const nextModel = availableModels[nextIndex];
|
||||
|
||||
const apiKey = await this._resolveApiKey(nextModel);
|
||||
const apiKey = await this._modelRegistry.getApiKey(nextModel);
|
||||
if (!apiKey) {
|
||||
throw new Error(`No API key for ${nextModel.provider}/${nextModel.id}`);
|
||||
}
|
||||
|
|
@ -650,11 +652,7 @@ export class AgentSession {
|
|||
* Get all available models with valid API keys.
|
||||
*/
|
||||
async getAvailableModels(): Promise<Model<any>[]> {
|
||||
const { models, error } = await getAvailableModels(undefined, (provider) =>
|
||||
this.settingsManager.getApiKey(provider),
|
||||
);
|
||||
if (error) throw new Error(error);
|
||||
return models;
|
||||
return this._modelRegistry.getAvailable();
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
|
|
@ -747,7 +745,7 @@ export class AgentSession {
|
|||
throw new Error("No model selected");
|
||||
}
|
||||
|
||||
const apiKey = await this._resolveApiKey(this.model);
|
||||
const apiKey = await this._modelRegistry.getApiKey(this.model);
|
||||
if (!apiKey) {
|
||||
throw new Error(`No API key for ${this.model.provider}`);
|
||||
}
|
||||
|
|
@ -786,7 +784,7 @@ export class AgentSession {
|
|||
tokensBefore: preparation.tokensBefore,
|
||||
customInstructions,
|
||||
model: this.model,
|
||||
resolveApiKey: this._resolveApiKey,
|
||||
resolveApiKey: async (m: Model<any>) => (await this._modelRegistry.getApiKey(m)) ?? undefined,
|
||||
signal: this._compactionAbortController.signal,
|
||||
})) as SessionEventResult | undefined;
|
||||
|
||||
|
|
@ -908,7 +906,7 @@ export class AgentSession {
|
|||
return;
|
||||
}
|
||||
|
||||
const apiKey = await this._resolveApiKey(this.model);
|
||||
const apiKey = await this._modelRegistry.getApiKey(this.model);
|
||||
if (!apiKey) {
|
||||
this._emit({ type: "auto_compaction_end", result: null, aborted: false, willRetry: false });
|
||||
return;
|
||||
|
|
@ -948,7 +946,7 @@ export class AgentSession {
|
|||
tokensBefore: preparation.tokensBefore,
|
||||
customInstructions: undefined,
|
||||
model: this.model,
|
||||
resolveApiKey: this._resolveApiKey,
|
||||
resolveApiKey: async (m: Model<any>) => (await this._modelRegistry.getApiKey(m)) ?? undefined,
|
||||
signal: this._autoCompactionAbortController.signal,
|
||||
})) as SessionEventResult | undefined;
|
||||
|
||||
|
|
@ -1334,9 +1332,7 @@ export class AgentSession {
|
|||
|
||||
// Restore model if saved
|
||||
if (sessionContext.model) {
|
||||
const availableModels = (
|
||||
await getAvailableModels(undefined, (provider) => this.settingsManager.getApiKey(provider))
|
||||
).models;
|
||||
const availableModels = await this._modelRegistry.getAvailable();
|
||||
const match = availableModels.find(
|
||||
(m) => m.provider === sessionContext.model!.provider && m.id === sessionContext.model!.modelId,
|
||||
);
|
||||
|
|
|
|||
|
|
@ -3,9 +3,18 @@
|
|||
* Handles loading, saving, and refreshing credentials from auth.json.
|
||||
*/
|
||||
|
||||
import { getApiKeyFromEnv, getOAuthApiKey, type OAuthCredentials, type OAuthProvider } from "@mariozechner/pi-ai";
|
||||
import { chmodSync, existsSync, mkdirSync, readFileSync, writeFileSync } from "fs";
|
||||
import { dirname } from "path";
|
||||
import {
|
||||
getEnvApiKey,
|
||||
getOAuthApiKey,
|
||||
loginAnthropic,
|
||||
loginAntigravity,
|
||||
loginGeminiCli,
|
||||
loginGitHubCopilot,
|
||||
type OAuthCredentials,
|
||||
type OAuthProvider,
|
||||
} from "@mariozechner/pi-ai";
|
||||
import { chmodSync, existsSync, mkdirSync, readFileSync, renameSync, writeFileSync } from "fs";
|
||||
import { dirname, join } from "path";
|
||||
|
||||
export type ApiKeyCredential = {
|
||||
type: "api_key";
|
||||
|
|
@ -25,11 +34,29 @@ export type AuthStorageData = Record<string, AuthCredential>;
|
|||
*/
|
||||
export class AuthStorage {
|
||||
private data: AuthStorageData = {};
|
||||
private runtimeOverrides: Map<string, string> = new Map();
|
||||
private fallbackResolver?: (provider: string) => string | undefined;
|
||||
|
||||
constructor(private authPath: string) {
|
||||
this.reload();
|
||||
}
|
||||
|
||||
/**
|
||||
* Set a runtime API key override (not persisted to disk).
|
||||
* Used for CLI --api-key flag.
|
||||
*/
|
||||
setRuntimeApiKey(provider: string, apiKey: string): void {
|
||||
this.runtimeOverrides.set(provider, apiKey);
|
||||
}
|
||||
|
||||
/**
|
||||
* Set a fallback resolver for API keys not found in auth.json or env vars.
|
||||
* Used for custom provider keys from models.json.
|
||||
*/
|
||||
setFallbackResolver(resolver: (provider: string) => string | undefined): void {
|
||||
this.fallbackResolver = resolver;
|
||||
}
|
||||
|
||||
/**
|
||||
* Reload credentials from disk.
|
||||
*/
|
||||
|
|
@ -101,14 +128,69 @@ export class AuthStorage {
|
|||
return { ...this.data };
|
||||
}
|
||||
|
||||
/**
|
||||
* Login to an OAuth provider.
|
||||
*/
|
||||
async login(
|
||||
provider: OAuthProvider,
|
||||
callbacks: {
|
||||
onAuth: (info: { url: string; instructions?: string }) => void;
|
||||
onPrompt: (prompt: { message: string; placeholder?: string }) => Promise<string>;
|
||||
onProgress?: (message: string) => void;
|
||||
},
|
||||
): Promise<void> {
|
||||
let credentials: OAuthCredentials;
|
||||
|
||||
switch (provider) {
|
||||
case "anthropic":
|
||||
credentials = await loginAnthropic(
|
||||
(url) => callbacks.onAuth({ url }),
|
||||
() => callbacks.onPrompt({ message: "Paste the authorization code:" }),
|
||||
);
|
||||
break;
|
||||
case "github-copilot":
|
||||
credentials = await loginGitHubCopilot({
|
||||
onAuth: (url, instructions) => callbacks.onAuth({ url, instructions }),
|
||||
onPrompt: callbacks.onPrompt,
|
||||
onProgress: callbacks.onProgress,
|
||||
});
|
||||
break;
|
||||
case "google-gemini-cli":
|
||||
credentials = await loginGeminiCli(callbacks.onAuth, callbacks.onProgress);
|
||||
break;
|
||||
case "google-antigravity":
|
||||
credentials = await loginAntigravity(callbacks.onAuth, callbacks.onProgress);
|
||||
break;
|
||||
default:
|
||||
throw new Error(`Unknown OAuth provider: ${provider}`);
|
||||
}
|
||||
|
||||
this.set(provider, { type: "oauth", ...credentials });
|
||||
}
|
||||
|
||||
/**
|
||||
* Logout from a provider.
|
||||
*/
|
||||
logout(provider: string): void {
|
||||
this.remove(provider);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get API key for a provider.
|
||||
* Priority:
|
||||
* 1. API key from auth.json
|
||||
* 2. OAuth token from auth.json (auto-refreshed)
|
||||
* 3. Environment variable (via getApiKeyFromEnv)
|
||||
* 1. Runtime override (CLI --api-key)
|
||||
* 2. API key from auth.json
|
||||
* 3. OAuth token from auth.json (auto-refreshed)
|
||||
* 4. Environment variable
|
||||
* 5. Fallback resolver (models.json custom providers)
|
||||
*/
|
||||
async getApiKey(provider: string): Promise<string | null> {
|
||||
// Runtime override takes highest priority
|
||||
const runtimeKey = this.runtimeOverrides.get(provider);
|
||||
if (runtimeKey) {
|
||||
return runtimeKey;
|
||||
}
|
||||
|
||||
const cred = this.data[provider];
|
||||
|
||||
if (cred?.type === "api_key") {
|
||||
|
|
@ -116,30 +198,83 @@ export class AuthStorage {
|
|||
}
|
||||
|
||||
if (cred?.type === "oauth") {
|
||||
// Build OAuthCredentials map (without type discriminator)
|
||||
// Filter to only oauth credentials for getOAuthApiKey
|
||||
const oauthCreds: Record<string, OAuthCredentials> = {};
|
||||
for (const [key, value] of Object.entries(this.data)) {
|
||||
if (value.type === "oauth") {
|
||||
const { type: _, ...rest } = value;
|
||||
oauthCreds[key] = rest;
|
||||
oauthCreds[key] = value;
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
const result = await getOAuthApiKey(provider as OAuthProvider, oauthCreds);
|
||||
if (result) {
|
||||
// Save refreshed credentials
|
||||
this.data[provider] = { type: "oauth", ...result.newCredentials };
|
||||
this.save();
|
||||
return result.apiKey;
|
||||
}
|
||||
} catch {
|
||||
// Token refresh failed, remove invalid credentials
|
||||
this.remove(provider);
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to environment variable
|
||||
return getApiKeyFromEnv(provider) ?? null;
|
||||
const envKey = getEnvApiKey(provider);
|
||||
if (envKey) return envKey;
|
||||
|
||||
// Fall back to custom resolver (e.g., models.json custom providers)
|
||||
return this.fallbackResolver?.(provider) ?? null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Migrate credentials from legacy oauth.json and settings.json apiKeys to auth.json.
|
||||
* Only runs if auth.json doesn't exist yet. Returns list of migrated providers.
|
||||
*/
|
||||
static migrateLegacy(authPath: string, agentDir: string): string[] {
|
||||
const oauthPath = join(agentDir, "oauth.json");
|
||||
const settingsPath = join(agentDir, "settings.json");
|
||||
|
||||
// Skip if auth.json already exists
|
||||
if (existsSync(authPath)) return [];
|
||||
|
||||
const migrated: AuthStorageData = {};
|
||||
const providers: string[] = [];
|
||||
|
||||
// Migrate oauth.json
|
||||
if (existsSync(oauthPath)) {
|
||||
try {
|
||||
const oauth = JSON.parse(readFileSync(oauthPath, "utf-8"));
|
||||
for (const [provider, cred] of Object.entries(oauth)) {
|
||||
migrated[provider] = { type: "oauth", ...(cred as object) } as OAuthCredential;
|
||||
providers.push(provider);
|
||||
}
|
||||
renameSync(oauthPath, `${oauthPath}.migrated`);
|
||||
} catch {}
|
||||
}
|
||||
|
||||
// Migrate settings.json apiKeys
|
||||
if (existsSync(settingsPath)) {
|
||||
try {
|
||||
const content = readFileSync(settingsPath, "utf-8");
|
||||
const settings = JSON.parse(content);
|
||||
if (settings.apiKeys && typeof settings.apiKeys === "object") {
|
||||
for (const [provider, key] of Object.entries(settings.apiKeys)) {
|
||||
if (!migrated[provider] && typeof key === "string") {
|
||||
migrated[provider] = { type: "api_key", key };
|
||||
providers.push(provider);
|
||||
}
|
||||
}
|
||||
delete settings.apiKeys;
|
||||
writeFileSync(settingsPath, JSON.stringify(settings, null, 2));
|
||||
}
|
||||
} catch {}
|
||||
}
|
||||
|
||||
if (Object.keys(migrated).length > 0) {
|
||||
mkdirSync(dirname(authPath), { recursive: true });
|
||||
writeFileSync(authPath, JSON.stringify(migrated, null, 2), { mode: 0o600 });
|
||||
}
|
||||
|
||||
return providers;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
315
packages/coding-agent/src/core/model-registry.ts
Normal file
315
packages/coding-agent/src/core/model-registry.ts
Normal file
|
|
@ -0,0 +1,315 @@
|
|||
/**
|
||||
* Model registry - manages built-in and custom models, provides API key resolution.
|
||||
*/
|
||||
|
||||
import {
|
||||
type Api,
|
||||
getGitHubCopilotBaseUrl,
|
||||
getModels,
|
||||
getProviders,
|
||||
type KnownProvider,
|
||||
type Model,
|
||||
normalizeDomain,
|
||||
} from "@mariozechner/pi-ai";
|
||||
import { type Static, Type } from "@sinclair/typebox";
|
||||
import AjvModule from "ajv";
|
||||
import { existsSync, readFileSync } from "fs";
|
||||
import type { AuthStorage } from "./auth-storage.js";
|
||||
|
||||
const Ajv = (AjvModule as any).default || AjvModule;
|
||||
|
||||
// Schema for OpenAI compatibility settings
|
||||
const OpenAICompatSchema = Type.Object({
|
||||
supportsStore: Type.Optional(Type.Boolean()),
|
||||
supportsDeveloperRole: Type.Optional(Type.Boolean()),
|
||||
supportsReasoningEffort: Type.Optional(Type.Boolean()),
|
||||
maxTokensField: Type.Optional(Type.Union([Type.Literal("max_completion_tokens"), Type.Literal("max_tokens")])),
|
||||
});
|
||||
|
||||
// Schema for custom model definition
|
||||
const ModelDefinitionSchema = Type.Object({
|
||||
id: Type.String({ minLength: 1 }),
|
||||
name: Type.String({ minLength: 1 }),
|
||||
api: Type.Optional(
|
||||
Type.Union([
|
||||
Type.Literal("openai-completions"),
|
||||
Type.Literal("openai-responses"),
|
||||
Type.Literal("anthropic-messages"),
|
||||
Type.Literal("google-generative-ai"),
|
||||
]),
|
||||
),
|
||||
reasoning: Type.Boolean(),
|
||||
input: Type.Array(Type.Union([Type.Literal("text"), Type.Literal("image")])),
|
||||
cost: Type.Object({
|
||||
input: Type.Number(),
|
||||
output: Type.Number(),
|
||||
cacheRead: Type.Number(),
|
||||
cacheWrite: Type.Number(),
|
||||
}),
|
||||
contextWindow: Type.Number(),
|
||||
maxTokens: Type.Number(),
|
||||
headers: Type.Optional(Type.Record(Type.String(), Type.String())),
|
||||
compat: Type.Optional(OpenAICompatSchema),
|
||||
});
|
||||
|
||||
const ProviderConfigSchema = Type.Object({
|
||||
baseUrl: Type.String({ minLength: 1 }),
|
||||
apiKey: Type.String({ minLength: 1 }),
|
||||
api: Type.Optional(
|
||||
Type.Union([
|
||||
Type.Literal("openai-completions"),
|
||||
Type.Literal("openai-responses"),
|
||||
Type.Literal("anthropic-messages"),
|
||||
Type.Literal("google-generative-ai"),
|
||||
]),
|
||||
),
|
||||
headers: Type.Optional(Type.Record(Type.String(), Type.String())),
|
||||
authHeader: Type.Optional(Type.Boolean()),
|
||||
models: Type.Array(ModelDefinitionSchema),
|
||||
});
|
||||
|
||||
const ModelsConfigSchema = Type.Object({
|
||||
providers: Type.Record(Type.String(), ProviderConfigSchema),
|
||||
});
|
||||
|
||||
type ModelsConfig = Static<typeof ModelsConfigSchema>;
|
||||
|
||||
/**
|
||||
* Resolve an API key config value to an actual key.
|
||||
* Checks environment variable first, then treats as literal.
|
||||
*/
|
||||
function resolveApiKeyConfig(keyConfig: string): string | undefined {
|
||||
const envValue = process.env[keyConfig];
|
||||
if (envValue) return envValue;
|
||||
return keyConfig;
|
||||
}
|
||||
|
||||
/**
|
||||
* Model registry - loads and manages models, resolves API keys via AuthStorage.
|
||||
*/
|
||||
export class ModelRegistry {
|
||||
private models: Model<Api>[] = [];
|
||||
private customProviderApiKeys: Map<string, string> = new Map();
|
||||
private loadError: string | null = null;
|
||||
|
||||
constructor(
|
||||
readonly authStorage: AuthStorage,
|
||||
private modelsJsonPath: string | null = null,
|
||||
) {
|
||||
// Set up fallback resolver for custom provider API keys
|
||||
this.authStorage.setFallbackResolver((provider) => {
|
||||
const keyConfig = this.customProviderApiKeys.get(provider);
|
||||
if (keyConfig) {
|
||||
return resolveApiKeyConfig(keyConfig);
|
||||
}
|
||||
return undefined;
|
||||
});
|
||||
|
||||
// Load models
|
||||
this.loadModels();
|
||||
}
|
||||
|
||||
/**
|
||||
* Reload models from disk (built-in + custom from models.json).
|
||||
*/
|
||||
refresh(): void {
|
||||
this.customProviderApiKeys.clear();
|
||||
this.loadError = null;
|
||||
this.loadModels();
|
||||
}
|
||||
|
||||
/**
|
||||
* Get any error from loading models.json (null if no error).
|
||||
*/
|
||||
getError(): string | null {
|
||||
return this.loadError;
|
||||
}
|
||||
|
||||
private loadModels(): void {
|
||||
// Load built-in models
|
||||
const builtInModels: Model<Api>[] = [];
|
||||
for (const provider of getProviders()) {
|
||||
const providerModels = getModels(provider as KnownProvider);
|
||||
builtInModels.push(...(providerModels as Model<Api>[]));
|
||||
}
|
||||
|
||||
// Load custom models from models.json (if path provided)
|
||||
let customModels: Model<Api>[] = [];
|
||||
if (this.modelsJsonPath) {
|
||||
const result = this.loadCustomModels(this.modelsJsonPath);
|
||||
if (result.error) {
|
||||
this.loadError = result.error;
|
||||
// Keep built-in models even if custom models failed to load
|
||||
} else {
|
||||
customModels = result.models;
|
||||
}
|
||||
}
|
||||
|
||||
const combined = [...builtInModels, ...customModels];
|
||||
|
||||
// Update github-copilot base URL based on OAuth credentials
|
||||
const copilotCred = this.authStorage.get("github-copilot");
|
||||
if (copilotCred?.type === "oauth") {
|
||||
const domain = copilotCred.enterpriseUrl
|
||||
? (normalizeDomain(copilotCred.enterpriseUrl) ?? undefined)
|
||||
: undefined;
|
||||
const baseUrl = getGitHubCopilotBaseUrl(copilotCred.access, domain);
|
||||
this.models = combined.map((m) => (m.provider === "github-copilot" ? { ...m, baseUrl } : m));
|
||||
} else {
|
||||
this.models = combined;
|
||||
}
|
||||
}
|
||||
|
||||
private loadCustomModels(modelsJsonPath: string): { models: Model<Api>[]; error: string | null } {
|
||||
if (!existsSync(modelsJsonPath)) {
|
||||
return { models: [], error: null };
|
||||
}
|
||||
|
||||
try {
|
||||
const content = readFileSync(modelsJsonPath, "utf-8");
|
||||
const config: ModelsConfig = JSON.parse(content);
|
||||
|
||||
// Validate schema
|
||||
const ajv = new Ajv();
|
||||
const validate = ajv.compile(ModelsConfigSchema);
|
||||
if (!validate(config)) {
|
||||
const errors =
|
||||
validate.errors?.map((e: any) => ` - ${e.instancePath || "root"}: ${e.message}`).join("\n") ||
|
||||
"Unknown schema error";
|
||||
return {
|
||||
models: [],
|
||||
error: `Invalid models.json schema:\n${errors}\n\nFile: ${modelsJsonPath}`,
|
||||
};
|
||||
}
|
||||
|
||||
// Additional validation
|
||||
this.validateConfig(config);
|
||||
|
||||
// Parse models
|
||||
return { models: this.parseModels(config), error: null };
|
||||
} catch (error) {
|
||||
if (error instanceof SyntaxError) {
|
||||
return {
|
||||
models: [],
|
||||
error: `Failed to parse models.json: ${error.message}\n\nFile: ${modelsJsonPath}`,
|
||||
};
|
||||
}
|
||||
return {
|
||||
models: [],
|
||||
error: `Failed to load models.json: ${error instanceof Error ? error.message : error}\n\nFile: ${modelsJsonPath}`,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
private validateConfig(config: ModelsConfig): void {
|
||||
for (const [providerName, providerConfig] of Object.entries(config.providers)) {
|
||||
const hasProviderApi = !!providerConfig.api;
|
||||
|
||||
for (const modelDef of providerConfig.models) {
|
||||
const hasModelApi = !!modelDef.api;
|
||||
|
||||
if (!hasProviderApi && !hasModelApi) {
|
||||
throw new Error(
|
||||
`Provider ${providerName}, model ${modelDef.id}: no "api" specified. Set at provider or model level.`,
|
||||
);
|
||||
}
|
||||
|
||||
if (!modelDef.id) throw new Error(`Provider ${providerName}: model missing "id"`);
|
||||
if (!modelDef.name) throw new Error(`Provider ${providerName}: model missing "name"`);
|
||||
if (modelDef.contextWindow <= 0)
|
||||
throw new Error(`Provider ${providerName}, model ${modelDef.id}: invalid contextWindow`);
|
||||
if (modelDef.maxTokens <= 0)
|
||||
throw new Error(`Provider ${providerName}, model ${modelDef.id}: invalid maxTokens`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private parseModels(config: ModelsConfig): Model<Api>[] {
|
||||
const models: Model<Api>[] = [];
|
||||
|
||||
for (const [providerName, providerConfig] of Object.entries(config.providers)) {
|
||||
// Store API key config for fallback resolver
|
||||
this.customProviderApiKeys.set(providerName, providerConfig.apiKey);
|
||||
|
||||
for (const modelDef of providerConfig.models) {
|
||||
const api = modelDef.api || providerConfig.api;
|
||||
if (!api) continue;
|
||||
|
||||
// Merge headers: provider headers are base, model headers override
|
||||
let headers =
|
||||
providerConfig.headers || modelDef.headers
|
||||
? { ...providerConfig.headers, ...modelDef.headers }
|
||||
: undefined;
|
||||
|
||||
// If authHeader is true, add Authorization header with resolved API key
|
||||
if (providerConfig.authHeader) {
|
||||
const resolvedKey = resolveApiKeyConfig(providerConfig.apiKey);
|
||||
if (resolvedKey) {
|
||||
headers = { ...headers, Authorization: `Bearer ${resolvedKey}` };
|
||||
}
|
||||
}
|
||||
|
||||
models.push({
|
||||
id: modelDef.id,
|
||||
name: modelDef.name,
|
||||
api: api as Api,
|
||||
provider: providerName,
|
||||
baseUrl: providerConfig.baseUrl,
|
||||
reasoning: modelDef.reasoning,
|
||||
input: modelDef.input as ("text" | "image")[],
|
||||
cost: modelDef.cost,
|
||||
contextWindow: modelDef.contextWindow,
|
||||
maxTokens: modelDef.maxTokens,
|
||||
headers,
|
||||
compat: modelDef.compat,
|
||||
} as Model<Api>);
|
||||
}
|
||||
}
|
||||
|
||||
return models;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all models (built-in + custom).
|
||||
* If models.json had errors, returns only built-in models.
|
||||
*/
|
||||
getAll(): Model<Api>[] {
|
||||
return this.models;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get only models that have valid API keys available.
|
||||
*/
|
||||
async getAvailable(): Promise<Model<Api>[]> {
|
||||
const available: Model<Api>[] = [];
|
||||
for (const model of this.models) {
|
||||
const apiKey = await this.authStorage.getApiKey(model.provider);
|
||||
if (apiKey) {
|
||||
available.push(model);
|
||||
}
|
||||
}
|
||||
return available;
|
||||
}
|
||||
|
||||
/**
|
||||
* Find a model by provider and ID.
|
||||
*/
|
||||
find(provider: string, modelId: string): Model<Api> | null {
|
||||
return this.models.find((m) => m.provider === provider && m.id === modelId) ?? null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get API key for a model.
|
||||
*/
|
||||
async getApiKey(model: Model<Api>): Promise<string | null> {
|
||||
return this.authStorage.getApiKey(model.provider);
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a model is using OAuth credentials (subscription).
|
||||
*/
|
||||
isUsingOAuth(model: Model<Api>): boolean {
|
||||
const cred = this.authStorage.get(model.provider);
|
||||
return cred?.type === "oauth";
|
||||
}
|
||||
}
|
||||
|
|
@ -6,8 +6,7 @@ import type { ThinkingLevel } from "@mariozechner/pi-agent-core";
|
|||
import type { Api, KnownProvider, Model } from "@mariozechner/pi-ai";
|
||||
import chalk from "chalk";
|
||||
import { isValidThinkingLevel } from "../cli/args.js";
|
||||
import { findModel, getApiKeyForModel, getAvailableModels } from "./models-json.js";
|
||||
import type { SettingsManager } from "./settings-manager.js";
|
||||
import type { ModelRegistry } from "./model-registry.js";
|
||||
|
||||
/** Default model IDs for each known provider */
|
||||
export const defaultModelPerProvider: Record<KnownProvider, string> = {
|
||||
|
|
@ -167,21 +166,9 @@ export function parseModelPattern(pattern: string, availableModels: Model<Api>[]
|
|||
* Supports models with colons in their IDs (e.g., OpenRouter's model:exacto).
|
||||
* The algorithm tries to match the full pattern first, then progressively
|
||||
* strips colon-suffixes to find a match.
|
||||
*
|
||||
* @param patterns - Model patterns to resolve
|
||||
* @param settingsManager - Optional settings manager for API key fallback from settings.json
|
||||
*/
|
||||
export async function resolveModelScope(patterns: string[], settingsManager?: SettingsManager): Promise<ScopedModel[]> {
|
||||
const { models: availableModels, error } = await getAvailableModels(
|
||||
undefined,
|
||||
settingsManager ? (provider) => settingsManager.getApiKey(provider) : undefined,
|
||||
);
|
||||
|
||||
if (error) {
|
||||
console.warn(chalk.yellow(`Warning: Error loading models: ${error}`));
|
||||
return [];
|
||||
}
|
||||
|
||||
export async function resolveModelScope(patterns: string[], modelRegistry: ModelRegistry): Promise<ScopedModel[]> {
|
||||
const availableModels = await modelRegistry.getAvailable();
|
||||
const scopedModels: ScopedModel[] = [];
|
||||
|
||||
for (const pattern of patterns) {
|
||||
|
|
@ -224,20 +211,28 @@ export async function findInitialModel(options: {
|
|||
cliModel?: string;
|
||||
scopedModels: ScopedModel[];
|
||||
isContinuing: boolean;
|
||||
settingsManager: SettingsManager;
|
||||
defaultProvider?: string;
|
||||
defaultModelId?: string;
|
||||
defaultThinkingLevel?: ThinkingLevel;
|
||||
modelRegistry: ModelRegistry;
|
||||
}): Promise<InitialModelResult> {
|
||||
const { cliProvider, cliModel, scopedModels, isContinuing, settingsManager } = options;
|
||||
const {
|
||||
cliProvider,
|
||||
cliModel,
|
||||
scopedModels,
|
||||
isContinuing,
|
||||
defaultProvider,
|
||||
defaultModelId,
|
||||
defaultThinkingLevel,
|
||||
modelRegistry,
|
||||
} = options;
|
||||
|
||||
let model: Model<Api> | null = null;
|
||||
let thinkingLevel: ThinkingLevel = "off";
|
||||
|
||||
// 1. CLI args take priority
|
||||
if (cliProvider && cliModel) {
|
||||
const { model: found, error } = findModel(cliProvider, cliModel);
|
||||
if (error) {
|
||||
console.error(chalk.red(error));
|
||||
process.exit(1);
|
||||
}
|
||||
const found = modelRegistry.find(cliProvider, cliModel);
|
||||
if (!found) {
|
||||
console.error(chalk.red(`Model ${cliProvider}/${cliModel} not found`));
|
||||
process.exit(1);
|
||||
|
|
@ -255,34 +250,19 @@ export async function findInitialModel(options: {
|
|||
}
|
||||
|
||||
// 3. Try saved default from settings
|
||||
const defaultProvider = settingsManager.getDefaultProvider();
|
||||
const defaultModelId = settingsManager.getDefaultModel();
|
||||
if (defaultProvider && defaultModelId) {
|
||||
const { model: found, error } = findModel(defaultProvider, defaultModelId);
|
||||
if (error) {
|
||||
console.error(chalk.red(error));
|
||||
process.exit(1);
|
||||
}
|
||||
const found = modelRegistry.find(defaultProvider, defaultModelId);
|
||||
if (found) {
|
||||
model = found;
|
||||
// Also load saved thinking level
|
||||
const savedThinking = settingsManager.getDefaultThinkingLevel();
|
||||
if (savedThinking) {
|
||||
thinkingLevel = savedThinking;
|
||||
if (defaultThinkingLevel) {
|
||||
thinkingLevel = defaultThinkingLevel;
|
||||
}
|
||||
return { model, thinkingLevel, fallbackMessage: null };
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Try first available model with valid API key
|
||||
const { models: availableModels, error } = await getAvailableModels(undefined, (provider) =>
|
||||
settingsManager.getApiKey(provider),
|
||||
);
|
||||
|
||||
if (error) {
|
||||
console.error(chalk.red(error));
|
||||
process.exit(1);
|
||||
}
|
||||
const availableModels = await modelRegistry.getAvailable();
|
||||
|
||||
if (availableModels.length > 0) {
|
||||
// Try to find a default model from known providers
|
||||
|
|
@ -310,17 +290,12 @@ export async function restoreModelFromSession(
|
|||
savedModelId: string,
|
||||
currentModel: Model<Api> | null,
|
||||
shouldPrintMessages: boolean,
|
||||
settingsManager?: SettingsManager,
|
||||
modelRegistry: ModelRegistry,
|
||||
): Promise<{ model: Model<Api> | null; fallbackMessage: string | null }> {
|
||||
const { model: restoredModel, error } = findModel(savedProvider, savedModelId);
|
||||
|
||||
if (error) {
|
||||
console.error(chalk.red(error));
|
||||
process.exit(1);
|
||||
}
|
||||
const restoredModel = modelRegistry.find(savedProvider, savedModelId);
|
||||
|
||||
// Check if restored model exists and has a valid API key
|
||||
const hasApiKey = restoredModel ? !!(await getApiKeyForModel(restoredModel)) : false;
|
||||
const hasApiKey = restoredModel ? !!(await modelRegistry.getApiKey(restoredModel)) : false;
|
||||
|
||||
if (restoredModel && hasApiKey) {
|
||||
if (shouldPrintMessages) {
|
||||
|
|
@ -348,14 +323,7 @@ export async function restoreModelFromSession(
|
|||
}
|
||||
|
||||
// Try to find any available model
|
||||
const { models: availableModels, error: availableError } = await getAvailableModels(
|
||||
undefined,
|
||||
settingsManager ? (provider) => settingsManager.getApiKey(provider) : undefined,
|
||||
);
|
||||
if (availableError) {
|
||||
console.error(chalk.red(availableError));
|
||||
process.exit(1);
|
||||
}
|
||||
const availableModels = await modelRegistry.getAvailable();
|
||||
|
||||
if (availableModels.length > 0) {
|
||||
// Try to find a default model from known providers
|
||||
|
|
|
|||
|
|
@ -1,467 +0,0 @@
|
|||
import {
|
||||
type Api,
|
||||
getApiKey,
|
||||
getGitHubCopilotBaseUrl,
|
||||
getModels,
|
||||
getProviders,
|
||||
type KnownProvider,
|
||||
loadOAuthCredentials,
|
||||
type Model,
|
||||
normalizeDomain,
|
||||
refreshGitHubCopilotToken,
|
||||
removeOAuthCredentials,
|
||||
saveOAuthCredentials,
|
||||
} from "@mariozechner/pi-ai";
|
||||
import { type Static, Type } from "@sinclair/typebox";
|
||||
import AjvModule from "ajv";
|
||||
import { existsSync, readFileSync } from "fs";
|
||||
import { join } from "path";
|
||||
import { getAgentDir } from "../config.js";
|
||||
import { getOAuthToken, type OAuthProvider, refreshToken } from "./oauth/index.js";
|
||||
|
||||
// Handle both default and named exports
|
||||
const Ajv = (AjvModule as any).default || AjvModule;
|
||||
|
||||
// Schema for OpenAI compatibility settings
|
||||
const OpenAICompatSchema = Type.Object({
|
||||
supportsStore: Type.Optional(Type.Boolean()),
|
||||
supportsDeveloperRole: Type.Optional(Type.Boolean()),
|
||||
supportsReasoningEffort: Type.Optional(Type.Boolean()),
|
||||
maxTokensField: Type.Optional(Type.Union([Type.Literal("max_completion_tokens"), Type.Literal("max_tokens")])),
|
||||
});
|
||||
|
||||
// Schema for custom model definition
|
||||
const ModelDefinitionSchema = Type.Object({
|
||||
id: Type.String({ minLength: 1 }),
|
||||
name: Type.String({ minLength: 1 }),
|
||||
api: Type.Optional(
|
||||
Type.Union([
|
||||
Type.Literal("openai-completions"),
|
||||
Type.Literal("openai-responses"),
|
||||
Type.Literal("anthropic-messages"),
|
||||
Type.Literal("google-generative-ai"),
|
||||
]),
|
||||
),
|
||||
reasoning: Type.Boolean(),
|
||||
input: Type.Array(Type.Union([Type.Literal("text"), Type.Literal("image")])),
|
||||
cost: Type.Object({
|
||||
input: Type.Number(),
|
||||
output: Type.Number(),
|
||||
cacheRead: Type.Number(),
|
||||
cacheWrite: Type.Number(),
|
||||
}),
|
||||
contextWindow: Type.Number(),
|
||||
maxTokens: Type.Number(),
|
||||
headers: Type.Optional(Type.Record(Type.String(), Type.String())),
|
||||
compat: Type.Optional(OpenAICompatSchema),
|
||||
});
|
||||
|
||||
const ProviderConfigSchema = Type.Object({
|
||||
baseUrl: Type.String({ minLength: 1 }),
|
||||
apiKey: Type.String({ minLength: 1 }),
|
||||
api: Type.Optional(
|
||||
Type.Union([
|
||||
Type.Literal("openai-completions"),
|
||||
Type.Literal("openai-responses"),
|
||||
Type.Literal("anthropic-messages"),
|
||||
Type.Literal("google-generative-ai"),
|
||||
]),
|
||||
),
|
||||
headers: Type.Optional(Type.Record(Type.String(), Type.String())),
|
||||
authHeader: Type.Optional(Type.Boolean()),
|
||||
models: Type.Array(ModelDefinitionSchema),
|
||||
});
|
||||
|
||||
const ModelsConfigSchema = Type.Object({
|
||||
providers: Type.Record(Type.String(), ProviderConfigSchema),
|
||||
});
|
||||
|
||||
type ModelsConfig = Static<typeof ModelsConfigSchema>;
|
||||
|
||||
// Custom provider API key mappings (provider name -> apiKey config)
|
||||
const customProviderApiKeys: Map<string, string> = new Map();
|
||||
|
||||
/**
|
||||
* Resolve an API key config value to an actual key.
|
||||
* First checks if it's an environment variable, then treats as literal.
|
||||
*/
|
||||
export function resolveApiKey(keyConfig: string): string | undefined {
|
||||
// First check if it's an env var name
|
||||
const envValue = process.env[keyConfig];
|
||||
if (envValue) return envValue;
|
||||
|
||||
// Otherwise treat as literal API key
|
||||
return keyConfig;
|
||||
}
|
||||
|
||||
/**
|
||||
* Load custom models from a models.json file
|
||||
* Returns { models, error } - either models array or error message
|
||||
*/
|
||||
function loadCustomModels(modelsJsonPath: string): { models: Model<Api>[]; error: string | null } {
|
||||
if (!existsSync(modelsJsonPath)) {
|
||||
return { models: [], error: null };
|
||||
}
|
||||
|
||||
try {
|
||||
const content = readFileSync(modelsJsonPath, "utf-8");
|
||||
const config: ModelsConfig = JSON.parse(content);
|
||||
|
||||
// Validate schema
|
||||
const ajv = new Ajv();
|
||||
const validate = ajv.compile(ModelsConfigSchema);
|
||||
if (!validate(config)) {
|
||||
const errors =
|
||||
validate.errors?.map((e: any) => ` - ${e.instancePath || "root"}: ${e.message}`).join("\n") ||
|
||||
"Unknown schema error";
|
||||
return {
|
||||
models: [],
|
||||
error: `Invalid models.json schema:\n${errors}\n\nFile: ${modelsJsonPath}`,
|
||||
};
|
||||
}
|
||||
|
||||
// Additional validation
|
||||
try {
|
||||
validateConfig(config);
|
||||
} catch (error) {
|
||||
return {
|
||||
models: [],
|
||||
error: `Invalid models.json: ${error instanceof Error ? error.message : error}\n\nFile: ${modelsJsonPath}`,
|
||||
};
|
||||
}
|
||||
|
||||
// Parse models
|
||||
return { models: parseModels(config), error: null };
|
||||
} catch (error) {
|
||||
if (error instanceof SyntaxError) {
|
||||
return {
|
||||
models: [],
|
||||
error: `Failed to parse models.json: ${error.message}\n\nFile: ${modelsJsonPath}`,
|
||||
};
|
||||
}
|
||||
return {
|
||||
models: [],
|
||||
error: `Failed to load models.json: ${error instanceof Error ? error.message : error}\n\nFile: ${modelsJsonPath}`,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate config structure and requirements
|
||||
*/
|
||||
function validateConfig(config: ModelsConfig): void {
|
||||
for (const [providerName, providerConfig] of Object.entries(config.providers)) {
|
||||
const hasProviderApi = !!providerConfig.api;
|
||||
|
||||
for (const modelDef of providerConfig.models) {
|
||||
const hasModelApi = !!modelDef.api;
|
||||
|
||||
if (!hasProviderApi && !hasModelApi) {
|
||||
throw new Error(
|
||||
`Provider ${providerName}, model ${modelDef.id}: no "api" specified. ` +
|
||||
`Set at provider or model level.`,
|
||||
);
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
if (!modelDef.id) throw new Error(`Provider ${providerName}: model missing "id"`);
|
||||
if (!modelDef.name) throw new Error(`Provider ${providerName}: model missing "name"`);
|
||||
if (modelDef.contextWindow <= 0)
|
||||
throw new Error(`Provider ${providerName}, model ${modelDef.id}: invalid contextWindow`);
|
||||
if (modelDef.maxTokens <= 0)
|
||||
throw new Error(`Provider ${providerName}, model ${modelDef.id}: invalid maxTokens`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse config into Model objects
|
||||
*/
|
||||
function parseModels(config: ModelsConfig): Model<Api>[] {
|
||||
const models: Model<Api>[] = [];
|
||||
|
||||
// Clear and rebuild custom provider API key mappings
|
||||
customProviderApiKeys.clear();
|
||||
|
||||
for (const [providerName, providerConfig] of Object.entries(config.providers)) {
|
||||
// Store API key config for this provider
|
||||
customProviderApiKeys.set(providerName, providerConfig.apiKey);
|
||||
|
||||
for (const modelDef of providerConfig.models) {
|
||||
// Model-level api overrides provider-level api
|
||||
const api = modelDef.api || providerConfig.api;
|
||||
|
||||
if (!api) {
|
||||
// This should have been caught by validateConfig, but be safe
|
||||
continue;
|
||||
}
|
||||
|
||||
// Merge headers: provider headers are base, model headers override
|
||||
let headers =
|
||||
providerConfig.headers || modelDef.headers ? { ...providerConfig.headers, ...modelDef.headers } : undefined;
|
||||
|
||||
// If authHeader is true, add Authorization header with resolved API key
|
||||
if (providerConfig.authHeader) {
|
||||
const resolvedKey = resolveApiKey(providerConfig.apiKey);
|
||||
if (resolvedKey) {
|
||||
headers = { ...headers, Authorization: `Bearer ${resolvedKey}` };
|
||||
}
|
||||
}
|
||||
|
||||
models.push({
|
||||
id: modelDef.id,
|
||||
name: modelDef.name,
|
||||
api: api as Api,
|
||||
provider: providerName,
|
||||
baseUrl: providerConfig.baseUrl,
|
||||
reasoning: modelDef.reasoning,
|
||||
input: modelDef.input as ("text" | "image")[],
|
||||
cost: modelDef.cost,
|
||||
contextWindow: modelDef.contextWindow,
|
||||
maxTokens: modelDef.maxTokens,
|
||||
headers,
|
||||
compat: modelDef.compat,
|
||||
} as Model<Api>);
|
||||
}
|
||||
}
|
||||
|
||||
return models;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all models (built-in + custom), freshly loaded
|
||||
* Returns { models, error } - either models array or error message
|
||||
*/
|
||||
export function loadAndMergeModels(agentDir: string = getAgentDir()): { models: Model<Api>[]; error: string | null } {
|
||||
const builtInModels: Model<Api>[] = [];
|
||||
const providers = getProviders();
|
||||
|
||||
// Load all built-in models
|
||||
for (const provider of providers) {
|
||||
const providerModels = getModels(provider as KnownProvider);
|
||||
builtInModels.push(...(providerModels as Model<Api>[]));
|
||||
}
|
||||
|
||||
// Load custom models
|
||||
const { models: customModels, error } = loadCustomModels(join(agentDir, "models.json"));
|
||||
|
||||
if (error) {
|
||||
return { models: [], error };
|
||||
}
|
||||
|
||||
const combined = [...builtInModels, ...customModels];
|
||||
|
||||
// Update github-copilot base URL based on OAuth token or enterprise domain
|
||||
const copilotCreds = loadOAuthCredentials("github-copilot");
|
||||
if (copilotCreds) {
|
||||
const domain = copilotCreds.enterpriseUrl ? normalizeDomain(copilotCreds.enterpriseUrl) : undefined;
|
||||
const baseUrl = getGitHubCopilotBaseUrl(copilotCreds.access, domain ?? undefined);
|
||||
return {
|
||||
models: combined.map((m) => (m.provider === "github-copilot" ? { ...m, baseUrl } : m)),
|
||||
error: null,
|
||||
};
|
||||
}
|
||||
|
||||
return { models: combined, error: null };
|
||||
}
|
||||
|
||||
/**
|
||||
* Get API key for a model (checks custom providers first, then built-in)
|
||||
* Now async to support OAuth token refresh.
|
||||
* Note: OAuth storage location is configured globally via setOAuthStorage.
|
||||
*/
|
||||
export async function getApiKeyForModel(model: Model<Api>): Promise<string | undefined> {
|
||||
// For custom providers, check their apiKey config
|
||||
const customKeyConfig = customProviderApiKeys.get(model.provider);
|
||||
if (customKeyConfig) {
|
||||
return resolveApiKey(customKeyConfig);
|
||||
}
|
||||
|
||||
// For Anthropic, check OAuth first
|
||||
if (model.provider === "anthropic") {
|
||||
// 1. Check OAuth storage (auto-refresh if needed)
|
||||
const oauthToken = await getOAuthToken("anthropic");
|
||||
if (oauthToken) {
|
||||
return oauthToken;
|
||||
}
|
||||
|
||||
// 2. Check ANTHROPIC_OAUTH_TOKEN env var (manual OAuth token)
|
||||
const oauthEnv = process.env.ANTHROPIC_OAUTH_TOKEN;
|
||||
if (oauthEnv) {
|
||||
return oauthEnv;
|
||||
}
|
||||
|
||||
// 3. Fall back to ANTHROPIC_API_KEY env var
|
||||
}
|
||||
|
||||
if (model.provider === "github-copilot") {
|
||||
// 1. Check OAuth storage (from device flow login)
|
||||
const oauthToken = await getOAuthToken("github-copilot");
|
||||
if (oauthToken) {
|
||||
return oauthToken;
|
||||
}
|
||||
|
||||
// 2. Use GitHub token directly (works with copilot scope on github.com)
|
||||
const githubToken = process.env.COPILOT_GITHUB_TOKEN || process.env.GH_TOKEN || process.env.GITHUB_TOKEN;
|
||||
if (!githubToken) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
// 3. For enterprise, exchange token for short-lived Copilot token
|
||||
const enterpriseDomain = process.env.COPILOT_ENTERPRISE_URL
|
||||
? normalizeDomain(process.env.COPILOT_ENTERPRISE_URL)
|
||||
: undefined;
|
||||
|
||||
if (enterpriseDomain) {
|
||||
const creds = await refreshGitHubCopilotToken(githubToken, enterpriseDomain);
|
||||
saveOAuthCredentials("github-copilot", creds);
|
||||
return creds.access;
|
||||
}
|
||||
|
||||
// 4. For github.com, use token directly
|
||||
return githubToken;
|
||||
}
|
||||
|
||||
// For Google Gemini CLI and Antigravity, check OAuth and encode projectId with token
|
||||
if (model.provider === "google-gemini-cli" || model.provider === "google-antigravity") {
|
||||
const oauthProvider = model.provider as "google-gemini-cli" | "google-antigravity";
|
||||
const credentials = loadOAuthCredentials(oauthProvider);
|
||||
if (!credentials) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
// Check if token is expired
|
||||
if (Date.now() >= credentials.expires) {
|
||||
try {
|
||||
await refreshToken(oauthProvider);
|
||||
const refreshedCreds = loadOAuthCredentials(oauthProvider);
|
||||
if (refreshedCreds?.projectId) {
|
||||
return JSON.stringify({ token: refreshedCreds.access, projectId: refreshedCreds.projectId });
|
||||
}
|
||||
} catch {
|
||||
removeOAuthCredentials(oauthProvider);
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
if (credentials.projectId) {
|
||||
return JSON.stringify({ token: credentials.access, projectId: credentials.projectId });
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
// For built-in providers, use getApiKey from @mariozechner/pi-ai
|
||||
return getApiKey(model.provider as KnownProvider);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get only models that have valid API keys available
|
||||
* Returns { models, error } - either models array or error message
|
||||
*
|
||||
* @param agentDir - Agent config directory
|
||||
* @param fallbackKeyResolver - Optional function to check for API keys not found by getApiKeyForModel
|
||||
* (e.g., keys from settings.json)
|
||||
*/
|
||||
export async function getAvailableModels(
|
||||
agentDir: string = getAgentDir(),
|
||||
fallbackKeyResolver?: (provider: string) => string | undefined,
|
||||
): Promise<{ models: Model<Api>[]; error: string | null }> {
|
||||
const { models: allModels, error } = loadAndMergeModels(agentDir);
|
||||
|
||||
if (error) {
|
||||
return { models: [], error };
|
||||
}
|
||||
|
||||
const availableModels: Model<Api>[] = [];
|
||||
for (const model of allModels) {
|
||||
let apiKey = await getApiKeyForModel(model);
|
||||
// Check fallback resolver if primary lookup failed
|
||||
if (!apiKey && fallbackKeyResolver) {
|
||||
apiKey = fallbackKeyResolver(model.provider);
|
||||
}
|
||||
if (apiKey) {
|
||||
availableModels.push(model);
|
||||
}
|
||||
}
|
||||
|
||||
return { models: availableModels, error: null };
|
||||
}
|
||||
|
||||
/**
|
||||
* Find a specific model by provider and ID.
|
||||
*
|
||||
* Searches models from:
|
||||
* 1. Built-in models from @mariozechner/pi-ai
|
||||
* 2. Custom models defined in ~/.pi/agent/models.json
|
||||
*
|
||||
* Returns { model, error } - either the model or an error message.
|
||||
*/
|
||||
export function findModel(
|
||||
provider: string,
|
||||
modelId: string,
|
||||
agentDir: string = getAgentDir(),
|
||||
): { model: Model<Api> | null; error: string | null } {
|
||||
const { models: allModels, error } = loadAndMergeModels(agentDir);
|
||||
|
||||
if (error) {
|
||||
return { model: null, error };
|
||||
}
|
||||
|
||||
const model = allModels.find((m) => m.provider === provider && m.id === modelId) || null;
|
||||
return { model, error: null };
|
||||
}
|
||||
|
||||
/**
|
||||
* Mapping from model provider to OAuth provider ID.
|
||||
* Only providers that support OAuth are listed here.
|
||||
*/
|
||||
const providerToOAuthProvider: Record<string, OAuthProvider> = {
|
||||
anthropic: "anthropic",
|
||||
"github-copilot": "github-copilot",
|
||||
"google-gemini-cli": "google-gemini-cli",
|
||||
"google-antigravity": "google-antigravity",
|
||||
};
|
||||
|
||||
// Cache for OAuth status per provider (avoids file reads on every render)
|
||||
const oauthStatusCache: Map<string, boolean> = new Map();
|
||||
|
||||
/**
|
||||
* Invalidate the OAuth status cache.
|
||||
* Call this after login/logout operations.
|
||||
*/
|
||||
export function invalidateOAuthCache(): void {
|
||||
oauthStatusCache.clear();
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a model is using OAuth credentials (subscription).
|
||||
* This checks if OAuth credentials exist and would be used for the model,
|
||||
* without actually fetching or refreshing the token.
|
||||
* Results are cached until invalidateOAuthCache() is called.
|
||||
*/
|
||||
export function isModelUsingOAuth(model: Model<Api>): boolean {
|
||||
const oauthProvider = providerToOAuthProvider[model.provider];
|
||||
if (!oauthProvider) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check cache first
|
||||
if (oauthStatusCache.has(oauthProvider)) {
|
||||
return oauthStatusCache.get(oauthProvider)!;
|
||||
}
|
||||
|
||||
// Check if OAuth credentials exist for this provider
|
||||
let usingOAuth = false;
|
||||
const credentials = loadOAuthCredentials(oauthProvider);
|
||||
if (credentials) {
|
||||
usingOAuth = true;
|
||||
}
|
||||
|
||||
// Also check for manual OAuth token env var (for Anthropic)
|
||||
if (!usingOAuth && model.provider === "anthropic" && process.env.ANTHROPIC_OAUTH_TOKEN) {
|
||||
usingOAuth = true;
|
||||
}
|
||||
|
||||
oauthStatusCache.set(oauthProvider, usingOAuth);
|
||||
return usingOAuth;
|
||||
}
|
||||
|
|
@ -30,22 +30,17 @@
|
|||
*/
|
||||
|
||||
import { Agent, ProviderTransport, type ThinkingLevel } from "@mariozechner/pi-agent-core";
|
||||
import { type Model, setOAuthStorage } from "@mariozechner/pi-ai";
|
||||
import { chmodSync, existsSync, mkdirSync, readFileSync, writeFileSync } from "fs";
|
||||
import { dirname, join } from "path";
|
||||
import type { Model } from "@mariozechner/pi-ai";
|
||||
import { join } from "path";
|
||||
import { getAgentDir } from "../config.js";
|
||||
import { AgentSession } from "./agent-session.js";
|
||||
import { AuthStorage } from "./auth-storage.js";
|
||||
import { discoverAndLoadCustomTools, type LoadedCustomTool } from "./custom-tools/index.js";
|
||||
import type { CustomAgentTool } from "./custom-tools/types.js";
|
||||
import { discoverAndLoadHooks, HookRunner, type LoadedHook, wrapToolsWithHooks } from "./hooks/index.js";
|
||||
import type { HookFactory } from "./hooks/types.js";
|
||||
import { messageTransformer } from "./messages.js";
|
||||
import {
|
||||
findModel as findModelInternal,
|
||||
getApiKeyForModel,
|
||||
getAvailableModels,
|
||||
loadAndMergeModels,
|
||||
} from "./models-json.js";
|
||||
import { ModelRegistry } from "./model-registry.js";
|
||||
import { SessionManager } from "./session-manager.js";
|
||||
import { type Settings, SettingsManager, type SkillsSettings } from "./settings-manager.js";
|
||||
import { loadSkills as loadSkillsInternal, type Skill } from "./skills.js";
|
||||
|
|
@ -86,6 +81,11 @@ export interface CreateAgentSessionOptions {
|
|||
/** Global config directory. Default: ~/.pi/agent */
|
||||
agentDir?: string;
|
||||
|
||||
/** Auth storage for credentials. Default: discoverAuthStorage(agentDir) */
|
||||
authStorage?: AuthStorage;
|
||||
/** Model registry. Default: discoverModels(authStorage, agentDir) */
|
||||
modelRegistry?: ModelRegistry;
|
||||
|
||||
/** Model to use. Default: from settings, else first available */
|
||||
model?: Model<any>;
|
||||
/** Thinking level. Default: from settings, else 'off' (clamped to model capabilities) */
|
||||
|
|
@ -93,9 +93,6 @@ export interface CreateAgentSessionOptions {
|
|||
/** Models available for cycling (Ctrl+P in interactive mode) */
|
||||
scopedModels?: Array<{ model: Model<any>; thinkingLevel: ThinkingLevel }>;
|
||||
|
||||
/** API key resolver. Default: defaultGetApiKey() */
|
||||
getApiKey?: (model: Model<any>) => Promise<string | undefined>;
|
||||
|
||||
/** System prompt. String replaces default, function receives default and returns final. */
|
||||
systemPrompt?: string | ((defaultPrompt: string) => string);
|
||||
|
||||
|
|
@ -177,73 +174,20 @@ function getDefaultAgentDir(): string {
|
|||
return getAgentDir();
|
||||
}
|
||||
|
||||
/**
|
||||
* Configure OAuth storage to use the specified agent directory.
|
||||
* Must be called before using OAuth-based authentication.
|
||||
*/
|
||||
export function configureOAuthStorage(agentDir: string = getDefaultAgentDir()): void {
|
||||
const oauthPath = join(agentDir, "oauth.json");
|
||||
|
||||
setOAuthStorage({
|
||||
load: () => {
|
||||
if (!existsSync(oauthPath)) {
|
||||
return {};
|
||||
}
|
||||
try {
|
||||
return JSON.parse(readFileSync(oauthPath, "utf-8"));
|
||||
} catch {
|
||||
return {};
|
||||
}
|
||||
},
|
||||
save: (storage) => {
|
||||
const dir = dirname(oauthPath);
|
||||
if (!existsSync(dir)) {
|
||||
mkdirSync(dir, { recursive: true, mode: 0o700 });
|
||||
}
|
||||
writeFileSync(oauthPath, JSON.stringify(storage, null, 2), "utf-8");
|
||||
chmodSync(oauthPath, 0o600);
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
// Discovery Functions
|
||||
|
||||
/**
|
||||
* Get all models (built-in + custom from models.json).
|
||||
* Create an AuthStorage instance for the given agent directory.
|
||||
*/
|
||||
export function discoverModels(agentDir: string = getDefaultAgentDir()): Model<any>[] {
|
||||
const { models, error } = loadAndMergeModels(agentDir);
|
||||
if (error) {
|
||||
throw new Error(error);
|
||||
}
|
||||
return models;
|
||||
export function discoverAuthStorage(agentDir: string = getDefaultAgentDir()): AuthStorage {
|
||||
return new AuthStorage(join(agentDir, "auth.json"));
|
||||
}
|
||||
|
||||
/**
|
||||
* Get models that have valid API keys available.
|
||||
* Create a ModelRegistry for the given agent directory.
|
||||
*/
|
||||
export async function discoverAvailableModels(agentDir: string = getDefaultAgentDir()): Promise<Model<any>[]> {
|
||||
const { models, error } = await getAvailableModels(agentDir);
|
||||
if (error) {
|
||||
throw new Error(error);
|
||||
}
|
||||
return models;
|
||||
}
|
||||
|
||||
/**
|
||||
* Find a model by provider and ID.
|
||||
* @returns The model, or null if not found
|
||||
*/
|
||||
export function findModel(
|
||||
provider: string,
|
||||
modelId: string,
|
||||
agentDir: string = getDefaultAgentDir(),
|
||||
): Model<any> | null {
|
||||
const { model, error } = findModelInternal(provider, modelId, agentDir);
|
||||
if (error) {
|
||||
throw new Error(error);
|
||||
}
|
||||
return model;
|
||||
export function discoverModels(authStorage: AuthStorage, agentDir: string = getDefaultAgentDir()): ModelRegistry {
|
||||
return new ModelRegistry(authStorage, join(agentDir, "models.json"));
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -326,30 +270,6 @@ export function discoverSlashCommands(cwd?: string, agentDir?: string): FileSlas
|
|||
|
||||
// API Key Helpers
|
||||
|
||||
/**
|
||||
* Create the default API key resolver.
|
||||
* Priority: OAuth > custom providers (models.json) > environment variables > settings.json apiKeys.
|
||||
*
|
||||
* OAuth takes priority so users logged in with a plan (e.g. unlimited tokens) aren't
|
||||
* accidentally billed via a PAYG API key sitting in settings.json.
|
||||
*/
|
||||
export function defaultGetApiKey(
|
||||
settingsManager?: SettingsManager,
|
||||
): (model: Model<any>) => Promise<string | undefined> {
|
||||
return async (model: Model<any>) => {
|
||||
// Check OAuth, custom providers, env vars first
|
||||
const resolvedKey = await getApiKeyForModel(model);
|
||||
if (resolvedKey) {
|
||||
return resolvedKey;
|
||||
}
|
||||
// Fall back to settings.json apiKeys
|
||||
if (settingsManager) {
|
||||
return settingsManager.getApiKey(model.provider);
|
||||
}
|
||||
return undefined;
|
||||
};
|
||||
}
|
||||
|
||||
// System Prompt
|
||||
|
||||
export interface BuildSystemPromptOptions {
|
||||
|
|
@ -457,8 +377,9 @@ function createLoadedHooksFromDefinitions(definitions: Array<{ path?: string; fa
|
|||
* const { session } = await createAgentSession();
|
||||
*
|
||||
* // With explicit model
|
||||
* import { getModel } from '@mariozechner/pi-ai';
|
||||
* const { session } = await createAgentSession({
|
||||
* model: findModel('anthropic', 'claude-sonnet-4-20250514'),
|
||||
* model: getModel('anthropic', 'claude-opus-4-5'),
|
||||
* thinkingLevel: 'high',
|
||||
* });
|
||||
*
|
||||
|
|
@ -483,22 +404,16 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
|
|||
const cwd = options.cwd ?? process.cwd();
|
||||
const agentDir = options.agentDir ?? getDefaultAgentDir();
|
||||
|
||||
// Configure OAuth storage for this agentDir
|
||||
configureOAuthStorage(agentDir);
|
||||
time("configureOAuthStorage");
|
||||
// Use provided or create AuthStorage and ModelRegistry
|
||||
const authStorage = options.authStorage ?? discoverAuthStorage(agentDir);
|
||||
const modelRegistry = options.modelRegistry ?? discoverModels(authStorage, agentDir);
|
||||
time("discoverModels");
|
||||
|
||||
const settingsManager = options.settingsManager ?? SettingsManager.create(cwd, agentDir);
|
||||
time("settingsManager");
|
||||
const sessionManager = options.sessionManager ?? SessionManager.create(cwd, agentDir);
|
||||
time("sessionManager");
|
||||
|
||||
// Helper to check API key availability (settings first, then OAuth/env vars)
|
||||
const hasApiKey = async (m: Model<any>): Promise<boolean> => {
|
||||
const settingsKey = settingsManager.getApiKey(m.provider);
|
||||
if (settingsKey) return true;
|
||||
return !!(await getApiKeyForModel(m));
|
||||
};
|
||||
|
||||
// Check if session has existing data to restore
|
||||
const existingSession = sessionManager.buildSessionContext();
|
||||
time("loadSession");
|
||||
|
|
@ -509,8 +424,8 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
|
|||
|
||||
// If session has data, try to restore model from it
|
||||
if (!model && hasExistingSession && existingSession.model) {
|
||||
const restoredModel = findModel(existingSession.model.provider, existingSession.model.modelId);
|
||||
if (restoredModel && (await hasApiKey(restoredModel))) {
|
||||
const restoredModel = modelRegistry.find(existingSession.model.provider, existingSession.model.modelId);
|
||||
if (restoredModel && (await modelRegistry.getApiKey(restoredModel))) {
|
||||
model = restoredModel;
|
||||
}
|
||||
if (!model) {
|
||||
|
|
@ -523,8 +438,8 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
|
|||
const defaultProvider = settingsManager.getDefaultProvider();
|
||||
const defaultModelId = settingsManager.getDefaultModel();
|
||||
if (defaultProvider && defaultModelId) {
|
||||
const settingsModel = findModel(defaultProvider, defaultModelId);
|
||||
if (settingsModel && (await hasApiKey(settingsModel))) {
|
||||
const settingsModel = modelRegistry.find(defaultProvider, defaultModelId);
|
||||
if (settingsModel && (await modelRegistry.getApiKey(settingsModel))) {
|
||||
model = settingsModel;
|
||||
}
|
||||
}
|
||||
|
|
@ -532,14 +447,13 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
|
|||
|
||||
// Fall back to first available model with a valid API key
|
||||
if (!model) {
|
||||
const allModels = discoverModels(agentDir);
|
||||
for (const m of allModels) {
|
||||
if (await hasApiKey(m)) {
|
||||
for (const m of modelRegistry.getAll()) {
|
||||
if (await modelRegistry.getApiKey(m)) {
|
||||
model = m;
|
||||
break;
|
||||
}
|
||||
}
|
||||
time("discoverAvailableModels");
|
||||
time("findAvailableModel");
|
||||
if (model) {
|
||||
if (modelFallbackMessage) {
|
||||
modelFallbackMessage += `. Using ${model.provider}/${model.id}`;
|
||||
|
|
@ -567,8 +481,6 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
|
|||
thinkingLevel = "off";
|
||||
}
|
||||
|
||||
const getApiKey = options.getApiKey ?? defaultGetApiKey(settingsManager);
|
||||
|
||||
const skills = options.skills ?? discoverSkills(cwd, agentDir, settingsManager.getSkillsSettings());
|
||||
time("discoverSkills");
|
||||
|
||||
|
|
@ -661,7 +573,7 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
|
|||
if (!currentModel) {
|
||||
throw new Error("No model selected");
|
||||
}
|
||||
const key = await getApiKey(currentModel);
|
||||
const key = await modelRegistry.getApiKey(currentModel);
|
||||
if (!key) {
|
||||
throw new Error(`No API key found for provider "${currentModel.provider}"`);
|
||||
}
|
||||
|
|
@ -685,7 +597,7 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
|
|||
hookRunner,
|
||||
customTools: customToolsResult.tools,
|
||||
skillsSettings: settingsManager.getSkillsSettings(),
|
||||
resolveApiKey: getApiKey,
|
||||
modelRegistry,
|
||||
});
|
||||
time("createAgentSession");
|
||||
|
||||
|
|
|
|||
|
|
@ -47,7 +47,6 @@ export interface Settings {
|
|||
customTools?: string[]; // Array of custom tool file paths
|
||||
skills?: SkillsSettings;
|
||||
terminal?: TerminalSettings;
|
||||
apiKeys?: Record<string, string>; // provider -> API key (e.g., { "anthropic": "sk-..." })
|
||||
}
|
||||
|
||||
/** Deep merge settings: project/overrides take precedence, nested objects merge recursively */
|
||||
|
|
@ -366,27 +365,4 @@ export class SettingsManager {
|
|||
this.globalSettings.terminal.showImages = show;
|
||||
this.save();
|
||||
}
|
||||
|
||||
getApiKey(provider: string): string | undefined {
|
||||
return this.settings.apiKeys?.[provider];
|
||||
}
|
||||
|
||||
setApiKey(provider: string, key: string): void {
|
||||
if (!this.globalSettings.apiKeys) {
|
||||
this.globalSettings.apiKeys = {};
|
||||
}
|
||||
this.globalSettings.apiKeys[provider] = key;
|
||||
this.save();
|
||||
}
|
||||
|
||||
removeApiKey(provider: string): void {
|
||||
if (this.globalSettings.apiKeys) {
|
||||
delete this.globalSettings.apiKeys[provider];
|
||||
this.save();
|
||||
}
|
||||
}
|
||||
|
||||
getApiKeys(): Record<string, string> {
|
||||
return this.settings.apiKeys ?? {};
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue