mirror of
https://github.com/getcompanion-ai/co-mono.git
synced 2026-04-15 07:04:45 +00:00
WIP: Remove global state from pi-ai OAuth/API key handling
- Remove setApiKey, resolveApiKey, and global apiKeys Map from stream.ts - Rename getApiKey to getApiKeyFromEnv (only checks env vars) - Remove OAuth storage layer (storage.ts deleted) - OAuth login/refresh functions now return credentials instead of saving - getOAuthApiKey/refreshOAuthToken now take credentials as params - Add test/oauth.ts helper for ai package tests - Simplify root npm run check (single biome + tsgo pass) - Remove redundant check scripts from most packages - Add web-ui and coding-agent examples to biome/tsgo includes coding-agent still has compile errors - needs refactoring for new API
This commit is contained in:
parent
d93cbf8c32
commit
030788140a
51 changed files with 646 additions and 570 deletions
|
|
@ -25,10 +25,11 @@
|
|||
},
|
||||
"files": {
|
||||
"includes": [
|
||||
"packages/*/src/**/*",
|
||||
"packages/*/test/**/*",
|
||||
"*.json",
|
||||
"*.md",
|
||||
"packages/*/src/**/*.ts",
|
||||
"packages/*/test/**/*.ts",
|
||||
"packages/coding-agent/examples/**/*.ts",
|
||||
"packages/web-ui/src/**/*.ts",
|
||||
"packages/web-ui/example/**/*.ts",
|
||||
"!**/node_modules/**/*",
|
||||
"!**/test-sessions.ts",
|
||||
"!**/models.generated.ts",
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@
|
|||
"build": "npm run build -w @mariozechner/pi-tui && npm run build -w @mariozechner/pi-ai && npm run build -w @mariozechner/pi-agent-core && npm run build -w @mariozechner/pi-coding-agent && npm run build -w @mariozechner/pi-mom && npm run build -w @mariozechner/pi-web-ui && npm run build -w @mariozechner/pi-proxy && npm run build -w @mariozechner/pi",
|
||||
"dev": "concurrently --names \"ai,agent,coding-agent,mom,web-ui,tui,proxy\" --prefix-colors \"cyan,yellow,red,white,green,magenta,blue\" \"npm run dev -w @mariozechner/pi-ai\" \"npm run dev -w @mariozechner/pi-agent-core\" \"npm run dev -w @mariozechner/pi-coding-agent\" \"npm run dev -w @mariozechner/pi-mom\" \"npm run dev -w @mariozechner/pi-web-ui\" \"npm run dev -w @mariozechner/pi-tui\" \"npm run dev -w @mariozechner/pi-proxy\"",
|
||||
"dev:tsc": "concurrently --names \"ai,web-ui\" --prefix-colors \"cyan,green\" \"npm run dev:tsc -w @mariozechner/pi-ai\" \"npm run dev:tsc -w @mariozechner/pi-web-ui\"",
|
||||
"check": "biome check --write . && npm run check --workspaces --if-present && tsgo --noEmit",
|
||||
"check": "biome check --write . && tsgo --noEmit && npm run check -w @mariozechner/pi-web-ui",
|
||||
"test": "npm run test --workspaces --if-present",
|
||||
"version:patch": "npm version patch -ws --no-git-tag-version && node scripts/sync-versions.js && rm -rf node_modules packages/*/node_modules package-lock.json && npm install",
|
||||
"version:minor": "npm version minor -ws --no-git-tag-version && node scripts/sync-versions.js && rm -rf node_modules packages/*/node_modules package-lock.json && npm install",
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@
|
|||
"clean": "rm -rf dist",
|
||||
"build": "tsgo -p tsconfig.build.json",
|
||||
"dev": "tsgo -p tsconfig.build.json --watch --preserveWatchOutput",
|
||||
"check": "tsgo --noEmit",
|
||||
"test": "vitest --run",
|
||||
"prepublishOnly": "npm run clean && npm run build"
|
||||
},
|
||||
|
|
|
|||
|
|
@ -2,9 +2,14 @@
|
|||
|
||||
## [Unreleased]
|
||||
|
||||
## Breaking Changes
|
||||
- **setApiKey, resolveApiKey**: removed. You need to create your own api key storage/resolution
|
||||
- **getApiKey**: renamed to `getApiKeyFromEnv`. Given a provider, checks for the known env variable holding the API key.
|
||||
### Breaking Changes
|
||||
- **setApiKey, resolveApiKey**: Removed. Callers must manage their own API key storage/resolution.
|
||||
- **getApiKey**: Renamed to `getApiKeyFromEnv`. Only checks environment variables for known providers.
|
||||
- **OAuth storage removed**: All storage functions (`loadOAuthCredentials`, `saveOAuthCredentials`, `setOAuthStorage`, etc.) removed. Callers are responsible for storing credentials.
|
||||
- **OAuth login functions**: `loginAnthropic`, `loginGitHubCopilot`, `loginGeminiCli`, `loginAntigravity` now return `OAuthCredentials` instead of saving to disk.
|
||||
- **refreshOAuthToken**: Now takes `(provider, credentials)` and returns new `OAuthCredentials` instead of saving.
|
||||
- **getOAuthApiKey**: Now takes `(provider, credentials)` and returns `{ newCredentials, apiKey }` or null.
|
||||
- **OAuthCredentials type**: No longer includes `type: "oauth"` discriminator. Callers add discriminator when storing.
|
||||
|
||||
## [0.27.7] - 2025-12-24
|
||||
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@
|
|||
"build": "npm run generate-models && tsgo -p tsconfig.build.json",
|
||||
"dev": "tsgo -p tsconfig.build.json --watch --preserveWatchOutput",
|
||||
"dev:tsc": "tsgo -p tsconfig.build.json --watch --preserveWatchOutput",
|
||||
"check": "biome check --write . && tsgo --noEmit",
|
||||
"test": "vitest --run",
|
||||
"prepublishOnly": "npm run clean && npm run build"
|
||||
},
|
||||
|
|
|
|||
|
|
@ -18,8 +18,9 @@ import type {
|
|||
} from "./types.js";
|
||||
|
||||
/**
|
||||
* Get API key from environment variables (sync).
|
||||
* Does NOT check OAuth credentials - use getApiKeyAsync for that.
|
||||
* Get API key for provider from known environment variables, e.g. OPENAI_API_KEY.
|
||||
*
|
||||
* Will not return API keys for providers that require OAuth tokens.
|
||||
*/
|
||||
export function getApiKeyFromEnv(provider: KnownProvider): string | undefined;
|
||||
export function getApiKeyFromEnv(provider: string): string | undefined;
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
*/
|
||||
|
||||
import { createHash, randomBytes } from "crypto";
|
||||
import { type OAuthCredentials, saveOAuthCredentials } from "./storage.js";
|
||||
import type { OAuthCredentials } from "./types.js";
|
||||
|
||||
const decode = (s: string) => Buffer.from(s, "base64").toString();
|
||||
const CLIENT_ID = decode("OWQxYzI1MGEtZTYxYi00NGQ5LTg4ZWQtNTk0NGQxOTYyZjVl");
|
||||
|
|
@ -30,7 +30,7 @@ function generatePKCE(): { verifier: string; challenge: string } {
|
|||
export async function loginAnthropic(
|
||||
onAuthUrl: (url: string) => void,
|
||||
onPromptCode: () => Promise<string>,
|
||||
): Promise<void> {
|
||||
): Promise<OAuthCredentials> {
|
||||
const { verifier, challenge } = generatePKCE();
|
||||
|
||||
// Build authorization URL
|
||||
|
|
@ -87,14 +87,11 @@ export async function loginAnthropic(
|
|||
const expiresAt = Date.now() + tokenData.expires_in * 1000 - 5 * 60 * 1000;
|
||||
|
||||
// Save credentials
|
||||
const credentials: OAuthCredentials = {
|
||||
type: "oauth",
|
||||
return {
|
||||
refresh: tokenData.refresh_token,
|
||||
access: tokenData.access_token,
|
||||
expires: expiresAt,
|
||||
};
|
||||
|
||||
saveOAuthCredentials("anthropic", credentials);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -123,7 +120,6 @@ export async function refreshAnthropicToken(refreshToken: string): Promise<OAuth
|
|||
};
|
||||
|
||||
return {
|
||||
type: "oauth",
|
||||
refresh: data.refresh_token,
|
||||
access: data.access_token,
|
||||
expires: Date.now() + data.expires_in * 1000 - 5 * 60 * 1000,
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
*/
|
||||
|
||||
import { getModels } from "../../models.js";
|
||||
import { type OAuthCredentials, saveOAuthCredentials } from "./storage.js";
|
||||
import type { OAuthCredentials } from "./types.js";
|
||||
|
||||
const decode = (s: string) => Buffer.from(s, "base64").toString();
|
||||
const CLIENT_ID = decode("SXYxLmI1MDdhMDhjODdlY2ZlOTg=");
|
||||
|
|
@ -63,7 +63,7 @@ function getUrls(domain: string): {
|
|||
* Token format: tid=...;exp=...;proxy-ep=proxy.individual.githubcopilot.com;...
|
||||
* Returns API URL like https://api.individual.githubcopilot.com
|
||||
*/
|
||||
export function getBaseUrlFromToken(token: string): string | null {
|
||||
function getBaseUrlFromToken(token: string): string | null {
|
||||
const match = token.match(/proxy-ep=([^;]+)/);
|
||||
if (!match) return null;
|
||||
const proxyHost = match[1];
|
||||
|
|
@ -217,7 +217,6 @@ export async function refreshGitHubCopilotToken(
|
|||
}
|
||||
|
||||
return {
|
||||
type: "oauth",
|
||||
refresh: refreshToken,
|
||||
access: token,
|
||||
expires: expiresAt * 1000 - 5 * 60 * 1000,
|
||||
|
|
@ -229,11 +228,7 @@ export async function refreshGitHubCopilotToken(
|
|||
* Enable a model for the user's GitHub Copilot account.
|
||||
* This is required for some models (like Claude, Grok) before they can be used.
|
||||
*/
|
||||
export async function enableGitHubCopilotModel(
|
||||
token: string,
|
||||
modelId: string,
|
||||
enterpriseDomain?: string,
|
||||
): Promise<boolean> {
|
||||
async function enableGitHubCopilotModel(token: string, modelId: string, enterpriseDomain?: string): Promise<boolean> {
|
||||
const baseUrl = getGitHubCopilotBaseUrl(token, enterpriseDomain);
|
||||
const url = `${baseUrl}/models/${modelId}/policy`;
|
||||
|
||||
|
|
@ -259,7 +254,7 @@ export async function enableGitHubCopilotModel(
|
|||
* Enable all known GitHub Copilot models that may require policy acceptance.
|
||||
* Called after successful login to ensure all models are available.
|
||||
*/
|
||||
export async function enableAllGitHubCopilotModels(
|
||||
async function enableAllGitHubCopilotModels(
|
||||
token: string,
|
||||
enterpriseDomain?: string,
|
||||
onProgress?: (model: string, success: boolean) => void,
|
||||
|
|
@ -312,9 +307,5 @@ export async function loginGitHubCopilot(options: {
|
|||
// Enable all models after successful login
|
||||
options.onProgress?.("Enabling models...");
|
||||
await enableAllGitHubCopilotModels(credentials.access, enterpriseDomain ?? undefined);
|
||||
|
||||
// Save credentials
|
||||
saveOAuthCredentials("github-copilot", credentials);
|
||||
|
||||
return credentials;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
|
||||
import { createHash, randomBytes } from "crypto";
|
||||
import { createServer, type Server } from "http";
|
||||
import { type OAuthCredentials, saveOAuthCredentials } from "./storage.js";
|
||||
import type { OAuthCredentials } from "./types.js";
|
||||
|
||||
// Antigravity OAuth credentials (different from Gemini CLI)
|
||||
const decode = (s: string) => Buffer.from(s, "base64").toString();
|
||||
|
|
@ -30,11 +30,6 @@ const TOKEN_URL = "https://oauth2.googleapis.com/token";
|
|||
// Fallback project ID when discovery fails
|
||||
const DEFAULT_PROJECT_ID = "rising-fact-p41fc";
|
||||
|
||||
export interface AntigravityCredentials extends OAuthCredentials {
|
||||
projectId: string;
|
||||
email?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate PKCE code verifier and challenge
|
||||
*/
|
||||
|
|
@ -220,7 +215,6 @@ export async function refreshAntigravityToken(refreshToken: string, projectId: s
|
|||
};
|
||||
|
||||
return {
|
||||
type: "oauth",
|
||||
refresh: data.refresh_token || refreshToken,
|
||||
access: data.access_token,
|
||||
expires: Date.now() + data.expires_in * 1000 - 5 * 60 * 1000,
|
||||
|
|
@ -237,7 +231,7 @@ export async function refreshAntigravityToken(refreshToken: string, projectId: s
|
|||
export async function loginAntigravity(
|
||||
onAuth: (info: { url: string; instructions?: string }) => void,
|
||||
onProgress?: (message: string) => void,
|
||||
): Promise<AntigravityCredentials> {
|
||||
): Promise<OAuthCredentials> {
|
||||
const { verifier, challenge } = generatePKCE();
|
||||
|
||||
// Start local server for callback
|
||||
|
|
@ -317,8 +311,7 @@ export async function loginAntigravity(
|
|||
// Calculate expiry time (current time + expires_in seconds - 5 min buffer)
|
||||
const expiresAt = Date.now() + tokenData.expires_in * 1000 - 5 * 60 * 1000;
|
||||
|
||||
const credentials: AntigravityCredentials = {
|
||||
type: "oauth",
|
||||
const credentials: OAuthCredentials = {
|
||||
refresh: tokenData.refresh_token,
|
||||
access: tokenData.access_token,
|
||||
expires: expiresAt,
|
||||
|
|
@ -326,8 +319,6 @@ export async function loginAntigravity(
|
|||
email,
|
||||
};
|
||||
|
||||
saveOAuthCredentials("google-antigravity", credentials);
|
||||
|
||||
return credentials;
|
||||
} finally {
|
||||
server.close();
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
|
||||
import { createHash, randomBytes } from "crypto";
|
||||
import { createServer, type Server } from "http";
|
||||
import { type OAuthCredentials, saveOAuthCredentials } from "./storage.js";
|
||||
import type { OAuthCredentials } from "./types.js";
|
||||
|
||||
const decode = (s: string) => Buffer.from(s, "base64").toString();
|
||||
const CLIENT_ID = decode(
|
||||
|
|
@ -22,11 +22,6 @@ const AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth";
|
|||
const TOKEN_URL = "https://oauth2.googleapis.com/token";
|
||||
const CODE_ASSIST_ENDPOINT = "https://cloudcode-pa.googleapis.com";
|
||||
|
||||
export interface GoogleCloudCredentials extends OAuthCredentials {
|
||||
projectId: string;
|
||||
email?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate PKCE code verifier and challenge
|
||||
*/
|
||||
|
|
@ -251,7 +246,6 @@ export async function refreshGoogleCloudToken(refreshToken: string, projectId: s
|
|||
};
|
||||
|
||||
return {
|
||||
type: "oauth",
|
||||
refresh: data.refresh_token || refreshToken,
|
||||
access: data.access_token,
|
||||
expires: Date.now() + data.expires_in * 1000 - 5 * 60 * 1000,
|
||||
|
|
@ -268,7 +262,7 @@ export async function refreshGoogleCloudToken(refreshToken: string, projectId: s
|
|||
export async function loginGeminiCli(
|
||||
onAuth: (info: { url: string; instructions?: string }) => void,
|
||||
onProgress?: (message: string) => void,
|
||||
): Promise<GoogleCloudCredentials> {
|
||||
): Promise<OAuthCredentials> {
|
||||
const { verifier, challenge } = generatePKCE();
|
||||
|
||||
// Start local server for callback
|
||||
|
|
@ -348,8 +342,7 @@ export async function loginGeminiCli(
|
|||
// Calculate expiry time (current time + expires_in seconds - 5 min buffer)
|
||||
const expiresAt = Date.now() + tokenData.expires_in * 1000 - 5 * 60 * 1000;
|
||||
|
||||
const credentials: GoogleCloudCredentials = {
|
||||
type: "oauth",
|
||||
const credentials: OAuthCredentials = {
|
||||
refresh: tokenData.refresh_token,
|
||||
access: tokenData.access_token,
|
||||
expires: expiresAt,
|
||||
|
|
@ -357,8 +350,6 @@ export async function loginGeminiCli(
|
|||
email,
|
||||
};
|
||||
|
||||
saveOAuthCredentials("google-gemini-cli", credentials);
|
||||
|
||||
return credentials;
|
||||
} finally {
|
||||
server.close();
|
||||
|
|
|
|||
|
|
@ -13,9 +13,6 @@
|
|||
export { loginAnthropic, refreshAnthropicToken } from "./anthropic.js";
|
||||
// GitHub Copilot
|
||||
export {
|
||||
enableAllGitHubCopilotModels,
|
||||
enableGitHubCopilotModel,
|
||||
getBaseUrlFromToken,
|
||||
getGitHubCopilotBaseUrl,
|
||||
loginGitHubCopilot,
|
||||
normalizeDomain,
|
||||
|
|
@ -23,32 +20,16 @@ export {
|
|||
} from "./github-copilot.js";
|
||||
// Google Antigravity
|
||||
export {
|
||||
type AntigravityCredentials,
|
||||
loginAntigravity,
|
||||
refreshAntigravityToken,
|
||||
} from "./google-antigravity.js";
|
||||
// Google Gemini CLI
|
||||
export {
|
||||
type GoogleCloudCredentials,
|
||||
loginGeminiCli,
|
||||
refreshGoogleCloudToken,
|
||||
} from "./google-gemini-cli.js";
|
||||
// Storage
|
||||
export {
|
||||
getOAuthPath,
|
||||
hasOAuthCredentials,
|
||||
listOAuthProviders,
|
||||
loadOAuthCredentials,
|
||||
loadOAuthStorage,
|
||||
type OAuthCredentials,
|
||||
type OAuthProvider,
|
||||
type OAuthStorage,
|
||||
type OAuthStorageBackend,
|
||||
removeOAuthCredentials,
|
||||
resetOAuthStorage,
|
||||
saveOAuthCredentials,
|
||||
setOAuthStorage,
|
||||
} from "./storage.js";
|
||||
|
||||
export * from "./types.js";
|
||||
|
||||
// ============================================================================
|
||||
// High-level API
|
||||
|
|
@ -58,15 +39,16 @@ import { refreshAnthropicToken } from "./anthropic.js";
|
|||
import { refreshGitHubCopilotToken } from "./github-copilot.js";
|
||||
import { refreshAntigravityToken } from "./google-antigravity.js";
|
||||
import { refreshGoogleCloudToken } from "./google-gemini-cli.js";
|
||||
import type { OAuthCredentials, OAuthProvider } from "./storage.js";
|
||||
import { loadOAuthCredentials, removeOAuthCredentials, saveOAuthCredentials } from "./storage.js";
|
||||
import type { OAuthCredentials, OAuthProvider, OAuthProviderInfo } from "./types.js";
|
||||
|
||||
/**
|
||||
* Refresh token for any OAuth provider.
|
||||
* Saves the new credentials and returns the new access token.
|
||||
*/
|
||||
export async function refreshToken(provider: OAuthProvider): Promise<string> {
|
||||
const credentials = loadOAuthCredentials(provider);
|
||||
export async function refreshOAuthToken(
|
||||
provider: OAuthProvider,
|
||||
credentials: OAuthCredentials,
|
||||
): Promise<OAuthCredentials> {
|
||||
if (!credentials) {
|
||||
throw new Error(`No OAuth credentials found for ${provider}`);
|
||||
}
|
||||
|
|
@ -96,8 +78,7 @@ export async function refreshToken(provider: OAuthProvider): Promise<string> {
|
|||
throw new Error(`Unknown OAuth provider: ${provider}`);
|
||||
}
|
||||
|
||||
saveOAuthCredentials(provider, newCredentials);
|
||||
return newCredentials.access;
|
||||
return newCredentials;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -107,81 +88,30 @@ export async function refreshToken(provider: OAuthProvider): Promise<string> {
|
|||
* For google-gemini-cli and antigravity, returns JSON-encoded { token, projectId }
|
||||
*
|
||||
* @returns API key string, or null if no credentials
|
||||
* @throws Error if refresh fails
|
||||
*/
|
||||
export async function getOAuthApiKey(provider: OAuthProvider): Promise<string | null> {
|
||||
const credentials = loadOAuthCredentials(provider);
|
||||
if (!credentials) {
|
||||
export async function getOAuthApiKey(
|
||||
provider: OAuthProvider,
|
||||
credentials: Record<string, OAuthCredentials>,
|
||||
): Promise<{ newCredentials: OAuthCredentials; apiKey: string } | null> {
|
||||
let creds = credentials[provider];
|
||||
if (!creds) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Providers that need projectId in the API key
|
||||
const needsProjectId = provider === "google-gemini-cli" || provider === "google-antigravity";
|
||||
|
||||
// Check if expired
|
||||
if (Date.now() >= credentials.expires) {
|
||||
// Refresh if expired
|
||||
if (Date.now() >= creds.expires) {
|
||||
try {
|
||||
const newToken = await refreshToken(provider);
|
||||
|
||||
// For providers that need projectId, return JSON
|
||||
if (needsProjectId) {
|
||||
const refreshedCreds = loadOAuthCredentials(provider);
|
||||
if (refreshedCreds?.projectId) {
|
||||
return JSON.stringify({ token: newToken, projectId: refreshedCreds.projectId });
|
||||
}
|
||||
}
|
||||
|
||||
return newToken;
|
||||
} catch (error) {
|
||||
console.error(`Failed to refresh OAuth token for ${provider}:`, error);
|
||||
removeOAuthCredentials(provider);
|
||||
return null;
|
||||
creds = await refreshOAuthToken(provider, creds);
|
||||
} catch (_error) {
|
||||
throw new Error(`Failed to refresh OAuth token for ${provider}`);
|
||||
}
|
||||
}
|
||||
|
||||
// For providers that need projectId, return JSON
|
||||
if (needsProjectId) {
|
||||
if (!credentials.projectId) {
|
||||
return null;
|
||||
}
|
||||
return JSON.stringify({ token: credentials.access, projectId: credentials.projectId });
|
||||
}
|
||||
|
||||
return credentials.access;
|
||||
}
|
||||
|
||||
/**
|
||||
* Map model provider to OAuth provider.
|
||||
* Returns undefined if the provider doesn't use OAuth.
|
||||
*/
|
||||
export function getOAuthProviderForModelProvider(modelProvider: string): OAuthProvider | undefined {
|
||||
const mapping: Record<string, OAuthProvider> = {
|
||||
anthropic: "anthropic",
|
||||
"github-copilot": "github-copilot",
|
||||
"google-gemini-cli": "google-gemini-cli",
|
||||
"google-antigravity": "google-antigravity",
|
||||
};
|
||||
return mapping[modelProvider];
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Login/Logout types for convenience
|
||||
// ============================================================================
|
||||
|
||||
export type OAuthPrompt = {
|
||||
message: string;
|
||||
placeholder?: string;
|
||||
allowEmpty?: boolean;
|
||||
};
|
||||
|
||||
export type OAuthAuthInfo = {
|
||||
url: string;
|
||||
instructions?: string;
|
||||
};
|
||||
|
||||
export interface OAuthProviderInfo {
|
||||
id: OAuthProvider;
|
||||
name: string;
|
||||
available: boolean;
|
||||
const needsProjectId = provider === "google-gemini-cli" || provider === "google-antigravity";
|
||||
const apiKey = needsProjectId ? JSON.stringify({ token: creds.access, projectId: creds.projectId }) : creds.access;
|
||||
return { newCredentials: creds, apiKey };
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -1,161 +0,0 @@
|
|||
/**
|
||||
* OAuth credential storage with configurable backend.
|
||||
*
|
||||
* Default: ~/.pi/agent/oauth.json
|
||||
* Override with setOAuthStorage() for custom storage locations or backends.
|
||||
*/
|
||||
|
||||
import { chmodSync, existsSync, mkdirSync, readFileSync, writeFileSync } from "fs";
|
||||
import { homedir } from "os";
|
||||
import { dirname, join } from "path";
|
||||
|
||||
export interface OAuthCredentials {
|
||||
type: "oauth";
|
||||
refresh: string;
|
||||
access: string;
|
||||
expires: number;
|
||||
enterpriseUrl?: string;
|
||||
projectId?: string;
|
||||
email?: string;
|
||||
}
|
||||
|
||||
export interface OAuthStorage {
|
||||
[provider: string]: OAuthCredentials;
|
||||
}
|
||||
|
||||
export type OAuthProvider = "anthropic" | "github-copilot" | "google-gemini-cli" | "google-antigravity";
|
||||
|
||||
/**
|
||||
* Storage backend interface.
|
||||
* Implement this to use a custom storage location or backend.
|
||||
*/
|
||||
export interface OAuthStorageBackend {
|
||||
/** Load all OAuth credentials. Return empty object if none exist. */
|
||||
load(): OAuthStorage;
|
||||
/** Save all OAuth credentials. */
|
||||
save(storage: OAuthStorage): void;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Default filesystem backend
|
||||
// ============================================================================
|
||||
|
||||
const DEFAULT_PATH = join(homedir(), ".pi", "agent", "oauth.json");
|
||||
|
||||
function defaultLoad(): OAuthStorage {
|
||||
if (!existsSync(DEFAULT_PATH)) {
|
||||
return {};
|
||||
}
|
||||
try {
|
||||
const content = readFileSync(DEFAULT_PATH, "utf-8");
|
||||
return JSON.parse(content);
|
||||
} catch {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
function defaultSave(storage: OAuthStorage): void {
|
||||
const configDir = dirname(DEFAULT_PATH);
|
||||
if (!existsSync(configDir)) {
|
||||
mkdirSync(configDir, { recursive: true, mode: 0o700 });
|
||||
}
|
||||
writeFileSync(DEFAULT_PATH, JSON.stringify(storage, null, 2), "utf-8");
|
||||
chmodSync(DEFAULT_PATH, 0o600);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Configurable backend
|
||||
// ============================================================================
|
||||
|
||||
let currentBackend: OAuthStorageBackend = {
|
||||
load: defaultLoad,
|
||||
save: defaultSave,
|
||||
};
|
||||
|
||||
/**
|
||||
* Configure the OAuth storage backend.
|
||||
*
|
||||
* @example
|
||||
* // Custom file path
|
||||
* setOAuthStorage({
|
||||
* load: () => JSON.parse(readFileSync('/custom/path/oauth.json', 'utf-8')),
|
||||
* save: (storage) => writeFileSync('/custom/path/oauth.json', JSON.stringify(storage))
|
||||
* });
|
||||
*
|
||||
* @example
|
||||
* // In-memory storage (for testing)
|
||||
* let memoryStorage = {};
|
||||
* setOAuthStorage({
|
||||
* load: () => memoryStorage,
|
||||
* save: (storage) => { memoryStorage = storage; }
|
||||
* });
|
||||
*/
|
||||
export function setOAuthStorage(backend: OAuthStorageBackend): void {
|
||||
currentBackend = backend;
|
||||
}
|
||||
|
||||
/**
|
||||
* Reset to default filesystem storage (~/.pi/agent/oauth.json)
|
||||
*/
|
||||
export function resetOAuthStorage(): void {
|
||||
currentBackend = { load: defaultLoad, save: defaultSave };
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the default OAuth path (for reference, may not be used if custom backend is set)
|
||||
*/
|
||||
export function getOAuthPath(): string {
|
||||
return DEFAULT_PATH;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Public API (uses current backend)
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Load all OAuth credentials
|
||||
*/
|
||||
export function loadOAuthStorage(): OAuthStorage {
|
||||
return currentBackend.load();
|
||||
}
|
||||
|
||||
/**
|
||||
* Load OAuth credentials for a specific provider
|
||||
*/
|
||||
export function loadOAuthCredentials(provider: string): OAuthCredentials | null {
|
||||
const storage = currentBackend.load();
|
||||
return storage[provider] || null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Save OAuth credentials for a specific provider
|
||||
*/
|
||||
export function saveOAuthCredentials(provider: string, creds: OAuthCredentials): void {
|
||||
const storage = currentBackend.load();
|
||||
storage[provider] = creds;
|
||||
currentBackend.save(storage);
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove OAuth credentials for a specific provider
|
||||
*/
|
||||
export function removeOAuthCredentials(provider: string): void {
|
||||
const storage = currentBackend.load();
|
||||
delete storage[provider];
|
||||
currentBackend.save(storage);
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if OAuth credentials exist for a provider
|
||||
*/
|
||||
export function hasOAuthCredentials(provider: string): boolean {
|
||||
return loadOAuthCredentials(provider) !== null;
|
||||
}
|
||||
|
||||
/**
|
||||
* List all providers with OAuth credentials
|
||||
*/
|
||||
export function listOAuthProviders(): string[] {
|
||||
const storage = currentBackend.load();
|
||||
return Object.keys(storage);
|
||||
}
|
||||
27
packages/ai/src/utils/oauth/types.ts
Normal file
27
packages/ai/src/utils/oauth/types.ts
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
export type OAuthCredentials = {
|
||||
refresh: string;
|
||||
access: string;
|
||||
expires: number;
|
||||
enterpriseUrl?: string;
|
||||
projectId?: string;
|
||||
email?: string;
|
||||
};
|
||||
|
||||
export type OAuthProvider = "anthropic" | "github-copilot" | "google-gemini-cli" | "google-antigravity";
|
||||
|
||||
export type OAuthPrompt = {
|
||||
message: string;
|
||||
placeholder?: string;
|
||||
allowEmpty?: boolean;
|
||||
};
|
||||
|
||||
export type OAuthAuthInfo = {
|
||||
url: string;
|
||||
instructions?: string;
|
||||
};
|
||||
|
||||
export interface OAuthProviderInfo {
|
||||
id: OAuthProvider;
|
||||
name: string;
|
||||
available: boolean;
|
||||
}
|
||||
|
|
@ -1,7 +1,8 @@
|
|||
import { describe, expect, it } from "vitest";
|
||||
import { getModel } from "../src/models.js";
|
||||
import { complete, resolveApiKey, stream } from "../src/stream.js";
|
||||
import { complete, stream } from "../src/stream.js";
|
||||
import type { Api, Context, Model, OptionsForApi } from "../src/types.js";
|
||||
import { resolveApiKey } from "./oauth.js";
|
||||
|
||||
// Resolve OAuth tokens at module level (async, runs before tests)
|
||||
const geminiCliToken = await resolveApiKey("google-gemini-cli");
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ import { agentLoop, agentLoopContinue } from "../src/agent/agent-loop.js";
|
|||
import { calculateTool } from "../src/agent/tools/calculate.js";
|
||||
import type { AgentContext, AgentEvent, AgentLoopConfig } from "../src/agent/types.js";
|
||||
import { getModel } from "../src/models.js";
|
||||
import { resolveApiKey } from "../src/stream.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
|
|
@ -13,6 +12,7 @@ import type {
|
|||
ToolResultMessage,
|
||||
UserMessage,
|
||||
} from "../src/types.js";
|
||||
import { resolveApiKey } from "./oauth.js";
|
||||
|
||||
// Resolve OAuth tokens at module level (async, runs before tests)
|
||||
const oauthTokens = await Promise.all([
|
||||
|
|
|
|||
|
|
@ -15,9 +15,10 @@ import type { ChildProcess } from "child_process";
|
|||
import { execSync, spawn } from "child_process";
|
||||
import { afterAll, beforeAll, describe, expect, it } from "vitest";
|
||||
import { getModel } from "../src/models.js";
|
||||
import { complete, resolveApiKey } from "../src/stream.js";
|
||||
import { complete } from "../src/stream.js";
|
||||
import type { AssistantMessage, Context, Model, Usage } from "../src/types.js";
|
||||
import { isContextOverflow } from "../src/utils/overflow.js";
|
||||
import { resolveApiKey } from "./oauth.js";
|
||||
|
||||
// Resolve OAuth tokens at module level (async, runs before tests)
|
||||
const oauthTokens = await Promise.all([
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
import { describe, expect, it } from "vitest";
|
||||
import { getModel } from "../src/models.js";
|
||||
import { complete, resolveApiKey } from "../src/stream.js";
|
||||
import { complete } from "../src/stream.js";
|
||||
import type { Api, AssistantMessage, Context, Model, OptionsForApi, UserMessage } from "../src/types.js";
|
||||
import { resolveApiKey } from "./oauth.js";
|
||||
|
||||
// Resolve OAuth tokens at module level (async, runs before tests)
|
||||
const oauthTokens = await Promise.all([
|
||||
|
|
|
|||
|
|
@ -3,8 +3,9 @@ import { join } from "node:path";
|
|||
import { Type } from "@sinclair/typebox";
|
||||
import { describe, expect, it } from "vitest";
|
||||
import type { Api, Context, Model, Tool, ToolResultMessage } from "../src/index.js";
|
||||
import { complete, getModel, resolveApiKey } from "../src/index.js";
|
||||
import { complete, getModel } from "../src/index.js";
|
||||
import type { OptionsForApi } from "../src/types.js";
|
||||
import { resolveApiKey } from "./oauth.js";
|
||||
|
||||
// Resolve OAuth tokens at module level (async, runs before tests)
|
||||
const oauthTokens = await Promise.all([
|
||||
|
|
|
|||
89
packages/ai/test/oauth.ts
Normal file
89
packages/ai/test/oauth.ts
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
/**
|
||||
* Test helper for resolving API keys from ~/.pi/agent/auth.json
|
||||
*
|
||||
* Supports both API key and OAuth credentials.
|
||||
* OAuth tokens are automatically refreshed if expired and saved back to auth.json.
|
||||
*/
|
||||
|
||||
import { chmodSync, existsSync, mkdirSync, readFileSync, writeFileSync } from "fs";
|
||||
import { homedir } from "os";
|
||||
import { dirname, join } from "path";
|
||||
import { getOAuthApiKey } from "../src/utils/oauth/index.js";
|
||||
import type { OAuthCredentials, OAuthProvider } from "../src/utils/oauth/types.js";
|
||||
|
||||
const AUTH_PATH = join(homedir(), ".pi", "agent", "auth.json");
|
||||
|
||||
type ApiKeyCredential = {
|
||||
type: "api_key";
|
||||
key: string;
|
||||
};
|
||||
|
||||
type OAuthCredentialEntry = {
|
||||
type: "oauth";
|
||||
} & OAuthCredentials;
|
||||
|
||||
type AuthCredential = ApiKeyCredential | OAuthCredentialEntry;
|
||||
|
||||
type AuthStorage = Record<string, AuthCredential>;
|
||||
|
||||
function loadAuthStorage(): AuthStorage {
|
||||
if (!existsSync(AUTH_PATH)) {
|
||||
return {};
|
||||
}
|
||||
try {
|
||||
const content = readFileSync(AUTH_PATH, "utf-8");
|
||||
return JSON.parse(content);
|
||||
} catch {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
function saveAuthStorage(storage: AuthStorage): void {
|
||||
const configDir = dirname(AUTH_PATH);
|
||||
if (!existsSync(configDir)) {
|
||||
mkdirSync(configDir, { recursive: true, mode: 0o700 });
|
||||
}
|
||||
writeFileSync(AUTH_PATH, JSON.stringify(storage, null, 2), "utf-8");
|
||||
chmodSync(AUTH_PATH, 0o600);
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolve API key for a provider from ~/.pi/agent/auth.json
|
||||
*
|
||||
* For API key credentials, returns the key directly.
|
||||
* For OAuth credentials, returns the access token (refreshing if expired and saving back).
|
||||
*
|
||||
* For google-gemini-cli and google-antigravity, returns JSON-encoded { token, projectId }
|
||||
*/
|
||||
export async function resolveApiKey(provider: string): Promise<string | undefined> {
|
||||
const storage = loadAuthStorage();
|
||||
const entry = storage[provider];
|
||||
|
||||
if (!entry) return undefined;
|
||||
|
||||
if (entry.type === "api_key") {
|
||||
return entry.key;
|
||||
}
|
||||
|
||||
if (entry.type === "oauth") {
|
||||
// Build OAuthCredentials record for getOAuthApiKey
|
||||
const oauthCredentials: Record<string, OAuthCredentials> = {};
|
||||
for (const [key, value] of Object.entries(storage)) {
|
||||
if (value.type === "oauth") {
|
||||
const { type: _, ...creds } = value;
|
||||
oauthCredentials[key] = creds;
|
||||
}
|
||||
}
|
||||
|
||||
const result = await getOAuthApiKey(provider as OAuthProvider, oauthCredentials);
|
||||
if (!result) return undefined;
|
||||
|
||||
// Save refreshed credentials back to auth.json
|
||||
storage[provider] = { type: "oauth", ...result.newCredentials };
|
||||
saveAuthStorage(storage);
|
||||
|
||||
return result.apiKey;
|
||||
}
|
||||
|
||||
return undefined;
|
||||
}
|
||||
|
|
@ -5,9 +5,10 @@ import { dirname, join } from "path";
|
|||
import { fileURLToPath } from "url";
|
||||
import { afterAll, beforeAll, describe, expect, it } from "vitest";
|
||||
import { getModel } from "../src/models.js";
|
||||
import { complete, resolveApiKey, stream } from "../src/stream.js";
|
||||
import { complete, stream } from "../src/stream.js";
|
||||
import type { Api, Context, ImageContent, Model, OptionsForApi, Tool, ToolResultMessage } from "../src/types.js";
|
||||
import { StringEnum } from "../src/utils/typebox-helpers.js";
|
||||
import { resolveApiKey } from "./oauth.js";
|
||||
|
||||
const __filename = fileURLToPath(import.meta.url);
|
||||
const __dirname = dirname(__filename);
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
import { describe, expect, it } from "vitest";
|
||||
import { getModel } from "../src/models.js";
|
||||
import { resolveApiKey, stream } from "../src/stream.js";
|
||||
import { stream } from "../src/stream.js";
|
||||
import type { Api, Context, Model, OptionsForApi } from "../src/types.js";
|
||||
import { resolveApiKey } from "./oauth.js";
|
||||
|
||||
// Resolve OAuth tokens at module level (async, runs before tests)
|
||||
const oauthTokens = await Promise.all([
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
import { Type } from "@sinclair/typebox";
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { getModel } from "../src/models.js";
|
||||
import { complete, resolveApiKey } from "../src/stream.js";
|
||||
import { complete } from "../src/stream.js";
|
||||
import type { Api, Context, Model, OptionsForApi, Tool } from "../src/types.js";
|
||||
import { resolveApiKey } from "./oauth.js";
|
||||
|
||||
// Resolve OAuth tokens at module level (async, runs before tests)
|
||||
const oauthTokens = await Promise.all([
|
||||
|
|
|
|||
|
|
@ -14,8 +14,9 @@
|
|||
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { getModel } from "../src/models.js";
|
||||
import { complete, resolveApiKey } from "../src/stream.js";
|
||||
import { complete } from "../src/stream.js";
|
||||
import type { Api, Context, Model, OptionsForApi, Usage } from "../src/types.js";
|
||||
import { resolveApiKey } from "./oauth.js";
|
||||
|
||||
// Resolve OAuth tokens at module level (async, runs before tests)
|
||||
const oauthTokens = await Promise.all([
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
import { Type } from "@sinclair/typebox";
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { getModel } from "../src/models.js";
|
||||
import { complete, resolveApiKey } from "../src/stream.js";
|
||||
import { complete } from "../src/stream.js";
|
||||
import type { Api, Context, Model, OptionsForApi, ToolResultMessage } from "../src/types.js";
|
||||
import { resolveApiKey } from "./oauth.js";
|
||||
|
||||
// Empty schema for test tools - must be proper OBJECT type for Cloud Code Assist
|
||||
const emptySchema = Type.Object({});
|
||||
|
|
|
|||
|
|
@ -1,20 +1,20 @@
|
|||
import { Type } from "@sinclair/typebox";
|
||||
import type { CustomToolFactory } from "@mariozechner/pi-coding-agent";
|
||||
import { Type } from "@sinclair/typebox";
|
||||
|
||||
const factory: CustomToolFactory = (pi) => ({
|
||||
name: "hello",
|
||||
label: "Hello",
|
||||
description: "A simple greeting tool",
|
||||
parameters: Type.Object({
|
||||
name: Type.String({ description: "Name to greet" }),
|
||||
}),
|
||||
const factory: CustomToolFactory = (_pi) => ({
|
||||
name: "hello",
|
||||
label: "Hello",
|
||||
description: "A simple greeting tool",
|
||||
parameters: Type.Object({
|
||||
name: Type.String({ description: "Name to greet" }),
|
||||
}),
|
||||
|
||||
async execute(toolCallId, params) {
|
||||
return {
|
||||
content: [{ type: "text", text: `Hello, ${params.name}!` }],
|
||||
details: { greeted: params.name },
|
||||
};
|
||||
},
|
||||
async execute(_toolCallId, params) {
|
||||
return {
|
||||
content: [{ type: "text", text: `Hello, ${params.name}!` }],
|
||||
details: { greeted: params.name },
|
||||
};
|
||||
},
|
||||
});
|
||||
|
||||
export default factory;
|
||||
export default factory;
|
||||
|
|
|
|||
|
|
@ -2,9 +2,9 @@
|
|||
* Question Tool - Let the LLM ask the user a question with options
|
||||
*/
|
||||
|
||||
import { Type } from "@sinclair/typebox";
|
||||
import { Text } from "@mariozechner/pi-tui";
|
||||
import type { CustomAgentTool, CustomToolFactory } from "@mariozechner/pi-coding-agent";
|
||||
import { Text } from "@mariozechner/pi-tui";
|
||||
import { Type } from "@sinclair/typebox";
|
||||
|
||||
interface QuestionDetails {
|
||||
question: string;
|
||||
|
|
@ -57,7 +57,7 @@ const factory: CustomToolFactory = (pi) => {
|
|||
renderCall(args, theme) {
|
||||
let text = theme.fg("toolTitle", theme.bold("question ")) + theme.fg("muted", args.question);
|
||||
if (args.options?.length) {
|
||||
text += "\n" + theme.fg("dim", ` Options: ${args.options.join(", ")}`);
|
||||
text += `\n${theme.fg("dim", ` Options: ${args.options.join(", ")}`)}`;
|
||||
}
|
||||
return new Text(text, 0, 0);
|
||||
},
|
||||
|
|
|
|||
|
|
@ -129,8 +129,7 @@ export function discoverAgents(cwd: string, scope: AgentScope): AgentDiscoveryRe
|
|||
const projectAgentsDir = findNearestProjectAgentsDir(cwd);
|
||||
|
||||
const userAgents = scope === "project" ? [] : loadAgentsFromDir(userDir, "user");
|
||||
const projectAgents =
|
||||
scope === "user" || !projectAgentsDir ? [] : loadAgentsFromDir(projectAgentsDir, "project");
|
||||
const projectAgents = scope === "user" || !projectAgentsDir ? [] : loadAgentsFromDir(projectAgentsDir, "project");
|
||||
|
||||
const agentMap = new Map<string, AgentConfig>();
|
||||
|
||||
|
|
|
|||
|
|
@ -16,11 +16,16 @@ import { spawn } from "node:child_process";
|
|||
import * as fs from "node:fs";
|
||||
import * as os from "node:os";
|
||||
import * as path from "node:path";
|
||||
import { Type } from "@sinclair/typebox";
|
||||
import type { AgentToolResult, Message } from "@mariozechner/pi-ai";
|
||||
import { StringEnum } from "@mariozechner/pi-ai";
|
||||
import {
|
||||
type CustomAgentTool,
|
||||
type CustomToolFactory,
|
||||
getMarkdownTheme,
|
||||
type ToolAPI,
|
||||
} from "@mariozechner/pi-coding-agent";
|
||||
import { Container, Markdown, Spacer, Text } from "@mariozechner/pi-tui";
|
||||
import { getMarkdownTheme, type CustomAgentTool, type CustomToolFactory, type ToolAPI } from "@mariozechner/pi-coding-agent";
|
||||
import { Type } from "@sinclair/typebox";
|
||||
import { type AgentConfig, type AgentScope, discoverAgents, formatAgentList } from "./agents.js";
|
||||
|
||||
const MAX_PARALLEL_TASKS = 8;
|
||||
|
|
@ -30,12 +35,23 @@ const COLLAPSED_ITEM_COUNT = 10;
|
|||
|
||||
function formatTokens(count: number): string {
|
||||
if (count < 1000) return count.toString();
|
||||
if (count < 10000) return (count / 1000).toFixed(1) + "k";
|
||||
if (count < 1000000) return Math.round(count / 1000) + "k";
|
||||
return (count / 1000000).toFixed(1) + "M";
|
||||
if (count < 10000) return `${(count / 1000).toFixed(1)}k`;
|
||||
if (count < 1000000) return `${Math.round(count / 1000)}k`;
|
||||
return `${(count / 1000000).toFixed(1)}M`;
|
||||
}
|
||||
|
||||
function formatUsageStats(usage: { input: number; output: number; cacheRead: number; cacheWrite: number; cost: number; contextTokens?: number; turns?: number }, model?: string): string {
|
||||
function formatUsageStats(
|
||||
usage: {
|
||||
input: number;
|
||||
output: number;
|
||||
cacheRead: number;
|
||||
cacheWrite: number;
|
||||
cost: number;
|
||||
contextTokens?: number;
|
||||
turns?: number;
|
||||
},
|
||||
model?: string,
|
||||
): string {
|
||||
const parts: string[] = [];
|
||||
if (usage.turns) parts.push(`${usage.turns} turn${usage.turns > 1 ? "s" : ""}`);
|
||||
if (usage.input) parts.push(`↑${formatTokens(usage.input)}`);
|
||||
|
|
@ -50,16 +66,20 @@ function formatUsageStats(usage: { input: number; output: number; cacheRead: num
|
|||
return parts.join(" ");
|
||||
}
|
||||
|
||||
function formatToolCall(toolName: string, args: Record<string, unknown>, themeFg: (color: any, text: string) => string): string {
|
||||
function formatToolCall(
|
||||
toolName: string,
|
||||
args: Record<string, unknown>,
|
||||
themeFg: (color: any, text: string) => string,
|
||||
): string {
|
||||
const shortenPath = (p: string) => {
|
||||
const home = os.homedir();
|
||||
return p.startsWith(home) ? "~" + p.slice(home.length) : p;
|
||||
return p.startsWith(home) ? `~${p.slice(home.length)}` : p;
|
||||
};
|
||||
|
||||
switch (toolName) {
|
||||
case "bash": {
|
||||
const command = (args.command as string) || "...";
|
||||
const preview = command.length > 60 ? command.slice(0, 60) + "..." : command;
|
||||
const preview = command.length > 60 ? `${command.slice(0, 60)}...` : command;
|
||||
return themeFg("muted", "$ ") + themeFg("toolOutput", preview);
|
||||
}
|
||||
case "read": {
|
||||
|
|
@ -100,11 +120,15 @@ function formatToolCall(toolName: string, args: Record<string, unknown>, themeFg
|
|||
case "grep": {
|
||||
const pattern = (args.pattern || "") as string;
|
||||
const rawPath = (args.path || ".") as string;
|
||||
return themeFg("muted", "grep ") + themeFg("accent", `/${pattern}/`) + themeFg("dim", ` in ${shortenPath(rawPath)}`);
|
||||
return (
|
||||
themeFg("muted", "grep ") +
|
||||
themeFg("accent", `/${pattern}/`) +
|
||||
themeFg("dim", ` in ${shortenPath(rawPath)}`)
|
||||
);
|
||||
}
|
||||
default: {
|
||||
const argsStr = JSON.stringify(args);
|
||||
const preview = argsStr.length > 50 ? argsStr.slice(0, 50) + "..." : argsStr;
|
||||
const preview = argsStr.length > 50 ? `${argsStr.slice(0, 50)}...` : argsStr;
|
||||
return themeFg("accent", toolName) + themeFg("dim", ` ${preview}`);
|
||||
}
|
||||
}
|
||||
|
|
@ -171,7 +195,7 @@ function getDisplayItems(messages: Message[]): DisplayItem[] {
|
|||
async function mapWithConcurrencyLimit<TIn, TOut>(
|
||||
items: TIn[],
|
||||
concurrency: number,
|
||||
fn: (item: TIn, index: number) => Promise<TOut>
|
||||
fn: (item: TIn, index: number) => Promise<TOut>,
|
||||
): Promise<TOut[]> {
|
||||
if (items.length === 0) return [];
|
||||
const limit = Math.max(1, Math.min(concurrency, items.length));
|
||||
|
|
@ -207,7 +231,7 @@ async function runSingleAgent(
|
|||
step: number | undefined,
|
||||
signal: AbortSignal | undefined,
|
||||
onUpdate: OnUpdateCallback | undefined,
|
||||
makeDetails: (results: SingleResult[]) => SubagentDetails
|
||||
makeDetails: (results: SingleResult[]) => SubagentDetails,
|
||||
): Promise<SingleResult> {
|
||||
const agent = agents.find((a) => a.name === agentName);
|
||||
|
||||
|
|
@ -270,7 +294,11 @@ async function runSingleAgent(
|
|||
const processLine = (line: string) => {
|
||||
if (!line.trim()) return;
|
||||
let event: any;
|
||||
try { event = JSON.parse(line); } catch { return; }
|
||||
try {
|
||||
event = JSON.parse(line);
|
||||
} catch {
|
||||
return;
|
||||
}
|
||||
|
||||
if (event.type === "message_end" && event.message) {
|
||||
const msg = event.message as Message;
|
||||
|
|
@ -307,20 +335,26 @@ async function runSingleAgent(
|
|||
for (const line of lines) processLine(line);
|
||||
});
|
||||
|
||||
proc.stderr.on("data", (data) => { currentResult.stderr += data.toString(); });
|
||||
proc.stderr.on("data", (data) => {
|
||||
currentResult.stderr += data.toString();
|
||||
});
|
||||
|
||||
proc.on("close", (code) => {
|
||||
if (buffer.trim()) processLine(buffer);
|
||||
resolve(code ?? 0);
|
||||
});
|
||||
|
||||
proc.on("error", () => { resolve(1); });
|
||||
proc.on("error", () => {
|
||||
resolve(1);
|
||||
});
|
||||
|
||||
if (signal) {
|
||||
const killProc = () => {
|
||||
wasAborted = true;
|
||||
proc.kill("SIGTERM");
|
||||
setTimeout(() => { if (!proc.killed) proc.kill("SIGKILL"); }, 5000);
|
||||
setTimeout(() => {
|
||||
if (!proc.killed) proc.kill("SIGKILL");
|
||||
}, 5000);
|
||||
};
|
||||
if (signal.aborted) killProc();
|
||||
else signal.addEventListener("abort", killProc, { once: true });
|
||||
|
|
@ -331,8 +365,18 @@ async function runSingleAgent(
|
|||
if (wasAborted) throw new Error("Subagent was aborted");
|
||||
return currentResult;
|
||||
} finally {
|
||||
if (tmpPromptPath) try { fs.unlinkSync(tmpPromptPath); } catch { /* ignore */ }
|
||||
if (tmpPromptDir) try { fs.rmdirSync(tmpPromptDir); } catch { /* ignore */ }
|
||||
if (tmpPromptPath)
|
||||
try {
|
||||
fs.unlinkSync(tmpPromptPath);
|
||||
} catch {
|
||||
/* ignore */
|
||||
}
|
||||
if (tmpPromptDir)
|
||||
try {
|
||||
fs.rmdirSync(tmpPromptDir);
|
||||
} catch {
|
||||
/* ignore */
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -359,7 +403,9 @@ const SubagentParams = Type.Object({
|
|||
tasks: Type.Optional(Type.Array(TaskItem, { description: "Array of {agent, task} for parallel execution" })),
|
||||
chain: Type.Optional(Type.Array(ChainItem, { description: "Array of {agent, task} for sequential execution" })),
|
||||
agentScope: Type.Optional(AgentScopeSchema),
|
||||
confirmProjectAgents: Type.Optional(Type.Boolean({ description: "Prompt before running project-local agents. Default: true.", default: true })),
|
||||
confirmProjectAgents: Type.Optional(
|
||||
Type.Boolean({ description: "Prompt before running project-local agents. Default: true.", default: true }),
|
||||
),
|
||||
cwd: Type.Optional(Type.String({ description: "Working directory for the agent process (single mode)" })),
|
||||
});
|
||||
|
||||
|
|
@ -397,13 +443,26 @@ const factory: CustomToolFactory = (pi) => {
|
|||
const hasSingle = Boolean(params.agent && params.task);
|
||||
const modeCount = Number(hasChain) + Number(hasTasks) + Number(hasSingle);
|
||||
|
||||
const makeDetails = (mode: "single" | "parallel" | "chain") => (results: SingleResult[]): SubagentDetails => ({
|
||||
mode, agentScope, projectAgentsDir: discovery.projectAgentsDir, results,
|
||||
});
|
||||
const makeDetails =
|
||||
(mode: "single" | "parallel" | "chain") =>
|
||||
(results: SingleResult[]): SubagentDetails => ({
|
||||
mode,
|
||||
agentScope,
|
||||
projectAgentsDir: discovery.projectAgentsDir,
|
||||
results,
|
||||
});
|
||||
|
||||
if (modeCount !== 1) {
|
||||
const available = agents.map((a) => `${a.name} (${a.source})`).join(", ") || "none";
|
||||
return { content: [{ type: "text", text: `Invalid parameters. Provide exactly one mode.\nAvailable agents: ${available}` }], details: makeDetails("single")([]) };
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: `Invalid parameters. Provide exactly one mode.\nAvailable agents: ${available}`,
|
||||
},
|
||||
],
|
||||
details: makeDetails("single")([]),
|
||||
};
|
||||
}
|
||||
|
||||
if ((agentScope === "project" || agentScope === "both") && confirmProjectAgents && pi.hasUI) {
|
||||
|
|
@ -419,51 +478,88 @@ const factory: CustomToolFactory = (pi) => {
|
|||
if (projectAgentsRequested.length > 0) {
|
||||
const names = projectAgentsRequested.map((a) => a.name).join(", ");
|
||||
const dir = discovery.projectAgentsDir ?? "(unknown)";
|
||||
const ok = await pi.ui.confirm("Run project-local agents?", `Agents: ${names}\nSource: ${dir}\n\nProject agents are repo-controlled. Only continue for trusted repositories.`);
|
||||
if (!ok) return { content: [{ type: "text", text: "Canceled: project-local agents not approved." }], details: makeDetails(hasChain ? "chain" : hasTasks ? "parallel" : "single")([]) };
|
||||
const ok = await pi.ui.confirm(
|
||||
"Run project-local agents?",
|
||||
`Agents: ${names}\nSource: ${dir}\n\nProject agents are repo-controlled. Only continue for trusted repositories.`,
|
||||
);
|
||||
if (!ok)
|
||||
return {
|
||||
content: [{ type: "text", text: "Canceled: project-local agents not approved." }],
|
||||
details: makeDetails(hasChain ? "chain" : hasTasks ? "parallel" : "single")([]),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
if (params.chain && params.chain.length > 0) {
|
||||
const results: SingleResult[] = [];
|
||||
let previousOutput = "";
|
||||
|
||||
|
||||
for (let i = 0; i < params.chain.length; i++) {
|
||||
const step = params.chain[i];
|
||||
const taskWithContext = step.task.replace(/\{previous\}/g, previousOutput);
|
||||
|
||||
|
||||
// Create update callback that includes all previous results
|
||||
const chainUpdate: OnUpdateCallback | undefined = onUpdate ? (partial) => {
|
||||
// Combine completed results with current streaming result
|
||||
const currentResult = partial.details?.results[0];
|
||||
if (currentResult) {
|
||||
const allResults = [...results, currentResult];
|
||||
onUpdate({
|
||||
content: partial.content,
|
||||
details: makeDetails("chain")(allResults),
|
||||
});
|
||||
}
|
||||
} : undefined;
|
||||
|
||||
const result = await runSingleAgent(pi, agents, step.agent, taskWithContext, step.cwd, i + 1, signal, chainUpdate, makeDetails("chain"));
|
||||
const chainUpdate: OnUpdateCallback | undefined = onUpdate
|
||||
? (partial) => {
|
||||
// Combine completed results with current streaming result
|
||||
const currentResult = partial.details?.results[0];
|
||||
if (currentResult) {
|
||||
const allResults = [...results, currentResult];
|
||||
onUpdate({
|
||||
content: partial.content,
|
||||
details: makeDetails("chain")(allResults),
|
||||
});
|
||||
}
|
||||
}
|
||||
: undefined;
|
||||
|
||||
const result = await runSingleAgent(
|
||||
pi,
|
||||
agents,
|
||||
step.agent,
|
||||
taskWithContext,
|
||||
step.cwd,
|
||||
i + 1,
|
||||
signal,
|
||||
chainUpdate,
|
||||
makeDetails("chain"),
|
||||
);
|
||||
results.push(result);
|
||||
|
||||
const isError = result.exitCode !== 0 || result.stopReason === "error" || result.stopReason === "aborted";
|
||||
|
||||
const isError =
|
||||
result.exitCode !== 0 || result.stopReason === "error" || result.stopReason === "aborted";
|
||||
if (isError) {
|
||||
const errorMsg = result.errorMessage || result.stderr || getFinalOutput(result.messages) || "(no output)";
|
||||
return { content: [{ type: "text", text: `Chain stopped at step ${i + 1} (${step.agent}): ${errorMsg}` }], details: makeDetails("chain")(results), isError: true };
|
||||
const errorMsg =
|
||||
result.errorMessage || result.stderr || getFinalOutput(result.messages) || "(no output)";
|
||||
return {
|
||||
content: [{ type: "text", text: `Chain stopped at step ${i + 1} (${step.agent}): ${errorMsg}` }],
|
||||
details: makeDetails("chain")(results),
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
previousOutput = getFinalOutput(result.messages);
|
||||
}
|
||||
return { content: [{ type: "text", text: getFinalOutput(results[results.length - 1].messages) || "(no output)" }], details: makeDetails("chain")(results) };
|
||||
return {
|
||||
content: [{ type: "text", text: getFinalOutput(results[results.length - 1].messages) || "(no output)" }],
|
||||
details: makeDetails("chain")(results),
|
||||
};
|
||||
}
|
||||
|
||||
if (params.tasks && params.tasks.length > 0) {
|
||||
if (params.tasks.length > MAX_PARALLEL_TASKS) return { content: [{ type: "text", text: `Too many parallel tasks (${params.tasks.length}). Max is ${MAX_PARALLEL_TASKS}.` }], details: makeDetails("parallel")([]) };
|
||||
|
||||
if (params.tasks.length > MAX_PARALLEL_TASKS)
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: `Too many parallel tasks (${params.tasks.length}). Max is ${MAX_PARALLEL_TASKS}.`,
|
||||
},
|
||||
],
|
||||
details: makeDetails("parallel")([]),
|
||||
};
|
||||
|
||||
// Track all results for streaming updates
|
||||
const allResults: SingleResult[] = new Array(params.tasks.length);
|
||||
|
||||
|
||||
// Initialize placeholder results
|
||||
for (let i = 0; i < params.tasks.length; i++) {
|
||||
allResults[i] = {
|
||||
|
|
@ -476,21 +572,29 @@ const factory: CustomToolFactory = (pi) => {
|
|||
usage: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, cost: 0, contextTokens: 0, turns: 0 },
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
const emitParallelUpdate = () => {
|
||||
if (onUpdate) {
|
||||
const running = allResults.filter(r => r.exitCode === -1).length;
|
||||
const done = allResults.filter(r => r.exitCode !== -1).length;
|
||||
const running = allResults.filter((r) => r.exitCode === -1).length;
|
||||
const done = allResults.filter((r) => r.exitCode !== -1).length;
|
||||
onUpdate({
|
||||
content: [{ type: "text", text: `Parallel: ${done}/${allResults.length} done, ${running} running...` }],
|
||||
content: [
|
||||
{ type: "text", text: `Parallel: ${done}/${allResults.length} done, ${running} running...` },
|
||||
],
|
||||
details: makeDetails("parallel")([...allResults]),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
const results = await mapWithConcurrencyLimit(params.tasks, MAX_CONCURRENCY, async (t, index) => {
|
||||
const result = await runSingleAgent(
|
||||
pi, agents, t.agent, t.task, t.cwd, undefined, signal,
|
||||
pi,
|
||||
agents,
|
||||
t.agent,
|
||||
t.task,
|
||||
t.cwd,
|
||||
undefined,
|
||||
signal,
|
||||
// Per-task update callback
|
||||
(partial) => {
|
||||
if (partial.details?.results[0]) {
|
||||
|
|
@ -498,63 +602,106 @@ const factory: CustomToolFactory = (pi) => {
|
|||
emitParallelUpdate();
|
||||
}
|
||||
},
|
||||
makeDetails("parallel")
|
||||
makeDetails("parallel"),
|
||||
);
|
||||
allResults[index] = result;
|
||||
emitParallelUpdate();
|
||||
return result;
|
||||
});
|
||||
|
||||
|
||||
const successCount = results.filter((r) => r.exitCode === 0).length;
|
||||
const summaries = results.map((r) => {
|
||||
const output = getFinalOutput(r.messages);
|
||||
const preview = output.slice(0, 100) + (output.length > 100 ? "..." : "");
|
||||
return `[${r.agent}] ${r.exitCode === 0 ? "completed" : "failed"}: ${preview || "(no output)"}`;
|
||||
});
|
||||
return { content: [{ type: "text", text: `Parallel: ${successCount}/${results.length} succeeded\n\n${summaries.join("\n\n")}` }], details: makeDetails("parallel")(results) };
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: `Parallel: ${successCount}/${results.length} succeeded\n\n${summaries.join("\n\n")}`,
|
||||
},
|
||||
],
|
||||
details: makeDetails("parallel")(results),
|
||||
};
|
||||
}
|
||||
|
||||
if (params.agent && params.task) {
|
||||
const result = await runSingleAgent(pi, agents, params.agent, params.task, params.cwd, undefined, signal, onUpdate, makeDetails("single"));
|
||||
const result = await runSingleAgent(
|
||||
pi,
|
||||
agents,
|
||||
params.agent,
|
||||
params.task,
|
||||
params.cwd,
|
||||
undefined,
|
||||
signal,
|
||||
onUpdate,
|
||||
makeDetails("single"),
|
||||
);
|
||||
const isError = result.exitCode !== 0 || result.stopReason === "error" || result.stopReason === "aborted";
|
||||
if (isError) {
|
||||
const errorMsg = result.errorMessage || result.stderr || getFinalOutput(result.messages) || "(no output)";
|
||||
return { content: [{ type: "text", text: `Agent ${result.stopReason || "failed"}: ${errorMsg}` }], details: makeDetails("single")([result]), isError: true };
|
||||
const errorMsg =
|
||||
result.errorMessage || result.stderr || getFinalOutput(result.messages) || "(no output)";
|
||||
return {
|
||||
content: [{ type: "text", text: `Agent ${result.stopReason || "failed"}: ${errorMsg}` }],
|
||||
details: makeDetails("single")([result]),
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
return { content: [{ type: "text", text: getFinalOutput(result.messages) || "(no output)" }], details: makeDetails("single")([result]) };
|
||||
return {
|
||||
content: [{ type: "text", text: getFinalOutput(result.messages) || "(no output)" }],
|
||||
details: makeDetails("single")([result]),
|
||||
};
|
||||
}
|
||||
|
||||
const available = agents.map((a) => `${a.name} (${a.source})`).join(", ") || "none";
|
||||
return { content: [{ type: "text", text: `Invalid parameters. Available agents: ${available}` }], details: makeDetails("single")([]) };
|
||||
return {
|
||||
content: [{ type: "text", text: `Invalid parameters. Available agents: ${available}` }],
|
||||
details: makeDetails("single")([]),
|
||||
};
|
||||
},
|
||||
|
||||
renderCall(args, theme) {
|
||||
const scope: AgentScope = args.agentScope ?? "user";
|
||||
if (args.chain && args.chain.length > 0) {
|
||||
let text = theme.fg("toolTitle", theme.bold("subagent ")) + theme.fg("accent", `chain (${args.chain.length} steps)`) + theme.fg("muted", ` [${scope}]`);
|
||||
let text =
|
||||
theme.fg("toolTitle", theme.bold("subagent ")) +
|
||||
theme.fg("accent", `chain (${args.chain.length} steps)`) +
|
||||
theme.fg("muted", ` [${scope}]`);
|
||||
for (let i = 0; i < Math.min(args.chain.length, 3); i++) {
|
||||
const step = args.chain[i];
|
||||
// Clean up {previous} placeholder for display
|
||||
const cleanTask = step.task.replace(/\{previous\}/g, "").trim();
|
||||
const preview = cleanTask.length > 40 ? cleanTask.slice(0, 40) + "..." : cleanTask;
|
||||
text += "\n " + theme.fg("muted", `${i + 1}.`) + " " + theme.fg("accent", step.agent) + theme.fg("dim", ` ${preview}`);
|
||||
const preview = cleanTask.length > 40 ? `${cleanTask.slice(0, 40)}...` : cleanTask;
|
||||
text +=
|
||||
"\n " +
|
||||
theme.fg("muted", `${i + 1}.`) +
|
||||
" " +
|
||||
theme.fg("accent", step.agent) +
|
||||
theme.fg("dim", ` ${preview}`);
|
||||
}
|
||||
if (args.chain.length > 3) text += "\n " + theme.fg("muted", `... +${args.chain.length - 3} more`);
|
||||
if (args.chain.length > 3) text += `\n ${theme.fg("muted", `... +${args.chain.length - 3} more`)}`;
|
||||
return new Text(text, 0, 0);
|
||||
}
|
||||
if (args.tasks && args.tasks.length > 0) {
|
||||
let text = theme.fg("toolTitle", theme.bold("subagent ")) + theme.fg("accent", `parallel (${args.tasks.length} tasks)`) + theme.fg("muted", ` [${scope}]`);
|
||||
let text =
|
||||
theme.fg("toolTitle", theme.bold("subagent ")) +
|
||||
theme.fg("accent", `parallel (${args.tasks.length} tasks)`) +
|
||||
theme.fg("muted", ` [${scope}]`);
|
||||
for (const t of args.tasks.slice(0, 3)) {
|
||||
const preview = t.task.length > 40 ? t.task.slice(0, 40) + "..." : t.task;
|
||||
text += "\n " + theme.fg("accent", t.agent) + theme.fg("dim", ` ${preview}`);
|
||||
const preview = t.task.length > 40 ? `${t.task.slice(0, 40)}...` : t.task;
|
||||
text += `\n ${theme.fg("accent", t.agent)}${theme.fg("dim", ` ${preview}`)}`;
|
||||
}
|
||||
if (args.tasks.length > 3) text += "\n " + theme.fg("muted", `... +${args.tasks.length - 3} more`);
|
||||
if (args.tasks.length > 3) text += `\n ${theme.fg("muted", `... +${args.tasks.length - 3} more`)}`;
|
||||
return new Text(text, 0, 0);
|
||||
}
|
||||
const agentName = args.agent || "...";
|
||||
const preview = args.task ? (args.task.length > 60 ? args.task.slice(0, 60) + "..." : args.task) : "...";
|
||||
let text = theme.fg("toolTitle", theme.bold("subagent ")) + theme.fg("accent", agentName) + theme.fg("muted", ` [${scope}]`);
|
||||
text += "\n " + theme.fg("dim", preview);
|
||||
const preview = args.task ? (args.task.length > 60 ? `${args.task.slice(0, 60)}...` : args.task) : "...";
|
||||
let text =
|
||||
theme.fg("toolTitle", theme.bold("subagent ")) +
|
||||
theme.fg("accent", agentName) +
|
||||
theme.fg("muted", ` [${scope}]`);
|
||||
text += `\n ${theme.fg("dim", preview)}`;
|
||||
return new Text(text, 0, 0);
|
||||
},
|
||||
|
||||
|
|
@ -575,9 +722,9 @@ const factory: CustomToolFactory = (pi) => {
|
|||
for (const item of toShow) {
|
||||
if (item.type === "text") {
|
||||
const preview = expanded ? item.text : item.text.split("\n").slice(0, 3).join("\n");
|
||||
text += theme.fg("toolOutput", preview) + "\n";
|
||||
text += `${theme.fg("toolOutput", preview)}\n`;
|
||||
} else {
|
||||
text += theme.fg("muted", "→ ") + formatToolCall(item.name, item.args, theme.fg.bind(theme)) + "\n";
|
||||
text += `${theme.fg("muted", "→ ") + formatToolCall(item.name, item.args, theme.fg.bind(theme))}\n`;
|
||||
}
|
||||
}
|
||||
return text.trimEnd();
|
||||
|
|
@ -592,10 +739,11 @@ const factory: CustomToolFactory = (pi) => {
|
|||
|
||||
if (expanded) {
|
||||
const container = new Container();
|
||||
let header = icon + " " + theme.fg("toolTitle", theme.bold(r.agent)) + theme.fg("muted", ` (${r.agentSource})`);
|
||||
if (isError && r.stopReason) header += " " + theme.fg("error", `[${r.stopReason}]`);
|
||||
let header = `${icon} ${theme.fg("toolTitle", theme.bold(r.agent))}${theme.fg("muted", ` (${r.agentSource})`)}`;
|
||||
if (isError && r.stopReason) header += ` ${theme.fg("error", `[${r.stopReason}]`)}`;
|
||||
container.addChild(new Text(header, 0, 0));
|
||||
if (isError && r.errorMessage) container.addChild(new Text(theme.fg("error", `Error: ${r.errorMessage}`), 0, 0));
|
||||
if (isError && r.errorMessage)
|
||||
container.addChild(new Text(theme.fg("error", `Error: ${r.errorMessage}`), 0, 0));
|
||||
container.addChild(new Spacer(1));
|
||||
container.addChild(new Text(theme.fg("muted", "─── Task ───"), 0, 0));
|
||||
container.addChild(new Text(theme.fg("dim", r.task), 0, 0));
|
||||
|
|
@ -605,7 +753,14 @@ const factory: CustomToolFactory = (pi) => {
|
|||
container.addChild(new Text(theme.fg("muted", "(no output)"), 0, 0));
|
||||
} else {
|
||||
for (const item of displayItems) {
|
||||
if (item.type === "toolCall") container.addChild(new Text(theme.fg("muted", "→ ") + formatToolCall(item.name, item.args, theme.fg.bind(theme)), 0, 0));
|
||||
if (item.type === "toolCall")
|
||||
container.addChild(
|
||||
new Text(
|
||||
theme.fg("muted", "→ ") + formatToolCall(item.name, item.args, theme.fg.bind(theme)),
|
||||
0,
|
||||
0,
|
||||
),
|
||||
);
|
||||
}
|
||||
if (finalOutput) {
|
||||
container.addChild(new Spacer(1));
|
||||
|
|
@ -613,20 +768,23 @@ const factory: CustomToolFactory = (pi) => {
|
|||
}
|
||||
}
|
||||
const usageStr = formatUsageStats(r.usage, r.model);
|
||||
if (usageStr) { container.addChild(new Spacer(1)); container.addChild(new Text(theme.fg("dim", usageStr), 0, 0)); }
|
||||
if (usageStr) {
|
||||
container.addChild(new Spacer(1));
|
||||
container.addChild(new Text(theme.fg("dim", usageStr), 0, 0));
|
||||
}
|
||||
return container;
|
||||
}
|
||||
|
||||
let text = icon + " " + theme.fg("toolTitle", theme.bold(r.agent)) + theme.fg("muted", ` (${r.agentSource})`);
|
||||
if (isError && r.stopReason) text += " " + theme.fg("error", `[${r.stopReason}]`);
|
||||
if (isError && r.errorMessage) text += "\n" + theme.fg("error", `Error: ${r.errorMessage}`);
|
||||
else if (displayItems.length === 0) text += "\n" + theme.fg("muted", "(no output)");
|
||||
let text = `${icon} ${theme.fg("toolTitle", theme.bold(r.agent))}${theme.fg("muted", ` (${r.agentSource})`)}`;
|
||||
if (isError && r.stopReason) text += ` ${theme.fg("error", `[${r.stopReason}]`)}`;
|
||||
if (isError && r.errorMessage) text += `\n${theme.fg("error", `Error: ${r.errorMessage}`)}`;
|
||||
else if (displayItems.length === 0) text += `\n${theme.fg("muted", "(no output)")}`;
|
||||
else {
|
||||
text += "\n" + renderDisplayItems(displayItems, COLLAPSED_ITEM_COUNT);
|
||||
if (displayItems.length > COLLAPSED_ITEM_COUNT) text += "\n" + theme.fg("muted", "(Ctrl+O to expand)");
|
||||
text += `\n${renderDisplayItems(displayItems, COLLAPSED_ITEM_COUNT)}`;
|
||||
if (displayItems.length > COLLAPSED_ITEM_COUNT) text += `\n${theme.fg("muted", "(Ctrl+O to expand)")}`;
|
||||
}
|
||||
const usageStr = formatUsageStats(r.usage, r.model);
|
||||
if (usageStr) text += "\n" + theme.fg("dim", usageStr);
|
||||
if (usageStr) text += `\n${theme.fg("dim", usageStr)}`;
|
||||
return new Text(text, 0, 0);
|
||||
}
|
||||
|
||||
|
|
@ -646,37 +804,58 @@ const factory: CustomToolFactory = (pi) => {
|
|||
if (details.mode === "chain") {
|
||||
const successCount = details.results.filter((r) => r.exitCode === 0).length;
|
||||
const icon = successCount === details.results.length ? theme.fg("success", "✓") : theme.fg("error", "✗");
|
||||
|
||||
|
||||
if (expanded) {
|
||||
const container = new Container();
|
||||
container.addChild(new Text(icon + " " + theme.fg("toolTitle", theme.bold("chain ")) + theme.fg("accent", `${successCount}/${details.results.length} steps`), 0, 0));
|
||||
|
||||
container.addChild(
|
||||
new Text(
|
||||
icon +
|
||||
" " +
|
||||
theme.fg("toolTitle", theme.bold("chain ")) +
|
||||
theme.fg("accent", `${successCount}/${details.results.length} steps`),
|
||||
0,
|
||||
0,
|
||||
),
|
||||
);
|
||||
|
||||
for (const r of details.results) {
|
||||
const rIcon = r.exitCode === 0 ? theme.fg("success", "✓") : theme.fg("error", "✗");
|
||||
const displayItems = getDisplayItems(r.messages);
|
||||
const finalOutput = getFinalOutput(r.messages);
|
||||
|
||||
|
||||
container.addChild(new Spacer(1));
|
||||
container.addChild(new Text(theme.fg("muted", `─── Step ${r.step}: `) + theme.fg("accent", r.agent) + " " + rIcon, 0, 0));
|
||||
container.addChild(
|
||||
new Text(
|
||||
`${theme.fg("muted", `─── Step ${r.step}: `) + theme.fg("accent", r.agent)} ${rIcon}`,
|
||||
0,
|
||||
0,
|
||||
),
|
||||
);
|
||||
container.addChild(new Text(theme.fg("muted", "Task: ") + theme.fg("dim", r.task), 0, 0));
|
||||
|
||||
|
||||
// Show tool calls
|
||||
for (const item of displayItems) {
|
||||
if (item.type === "toolCall") {
|
||||
container.addChild(new Text(theme.fg("muted", "→ ") + formatToolCall(item.name, item.args, theme.fg.bind(theme)), 0, 0));
|
||||
container.addChild(
|
||||
new Text(
|
||||
theme.fg("muted", "→ ") + formatToolCall(item.name, item.args, theme.fg.bind(theme)),
|
||||
0,
|
||||
0,
|
||||
),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Show final output as markdown
|
||||
if (finalOutput) {
|
||||
container.addChild(new Spacer(1));
|
||||
container.addChild(new Markdown(finalOutput.trim(), 0, 0, mdTheme));
|
||||
}
|
||||
|
||||
|
||||
const stepUsage = formatUsageStats(r.usage, r.model);
|
||||
if (stepUsage) container.addChild(new Text(theme.fg("dim", stepUsage), 0, 0));
|
||||
}
|
||||
|
||||
|
||||
const usageStr = formatUsageStats(aggregateUsage(details.results));
|
||||
if (usageStr) {
|
||||
container.addChild(new Spacer(1));
|
||||
|
|
@ -684,19 +863,23 @@ const factory: CustomToolFactory = (pi) => {
|
|||
}
|
||||
return container;
|
||||
}
|
||||
|
||||
|
||||
// Collapsed view
|
||||
let text = icon + " " + theme.fg("toolTitle", theme.bold("chain ")) + theme.fg("accent", `${successCount}/${details.results.length} steps`);
|
||||
let text =
|
||||
icon +
|
||||
" " +
|
||||
theme.fg("toolTitle", theme.bold("chain ")) +
|
||||
theme.fg("accent", `${successCount}/${details.results.length} steps`);
|
||||
for (const r of details.results) {
|
||||
const rIcon = r.exitCode === 0 ? theme.fg("success", "✓") : theme.fg("error", "✗");
|
||||
const displayItems = getDisplayItems(r.messages);
|
||||
text += "\n\n" + theme.fg("muted", `─── Step ${r.step}: `) + theme.fg("accent", r.agent) + " " + rIcon;
|
||||
if (displayItems.length === 0) text += "\n" + theme.fg("muted", "(no output)");
|
||||
else text += "\n" + renderDisplayItems(displayItems, 5);
|
||||
text += `\n\n${theme.fg("muted", `─── Step ${r.step}: `)}${theme.fg("accent", r.agent)} ${rIcon}`;
|
||||
if (displayItems.length === 0) text += `\n${theme.fg("muted", "(no output)")}`;
|
||||
else text += `\n${renderDisplayItems(displayItems, 5)}`;
|
||||
}
|
||||
const usageStr = formatUsageStats(aggregateUsage(details.results));
|
||||
if (usageStr) text += "\n\n" + theme.fg("dim", `Total: ${usageStr}`);
|
||||
text += "\n" + theme.fg("muted", "(Ctrl+O to expand)");
|
||||
if (usageStr) text += `\n\n${theme.fg("dim", `Total: ${usageStr}`)}`;
|
||||
text += `\n${theme.fg("muted", "(Ctrl+O to expand)")}`;
|
||||
return new Text(text, 0, 0);
|
||||
}
|
||||
|
||||
|
|
@ -705,41 +888,59 @@ const factory: CustomToolFactory = (pi) => {
|
|||
const successCount = details.results.filter((r) => r.exitCode === 0).length;
|
||||
const failCount = details.results.filter((r) => r.exitCode > 0).length;
|
||||
const isRunning = running > 0;
|
||||
const icon = isRunning ? theme.fg("warning", "⏳") : (failCount > 0 ? theme.fg("warning", "◐") : theme.fg("success", "✓"));
|
||||
const status = isRunning
|
||||
const icon = isRunning
|
||||
? theme.fg("warning", "⏳")
|
||||
: failCount > 0
|
||||
? theme.fg("warning", "◐")
|
||||
: theme.fg("success", "✓");
|
||||
const status = isRunning
|
||||
? `${successCount + failCount}/${details.results.length} done, ${running} running`
|
||||
: `${successCount}/${details.results.length} tasks`;
|
||||
|
||||
|
||||
if (expanded && !isRunning) {
|
||||
const container = new Container();
|
||||
container.addChild(new Text(icon + " " + theme.fg("toolTitle", theme.bold("parallel ")) + theme.fg("accent", status), 0, 0));
|
||||
|
||||
container.addChild(
|
||||
new Text(
|
||||
`${icon} ${theme.fg("toolTitle", theme.bold("parallel "))}${theme.fg("accent", status)}`,
|
||||
0,
|
||||
0,
|
||||
),
|
||||
);
|
||||
|
||||
for (const r of details.results) {
|
||||
const rIcon = r.exitCode === 0 ? theme.fg("success", "✓") : theme.fg("error", "✗");
|
||||
const displayItems = getDisplayItems(r.messages);
|
||||
const finalOutput = getFinalOutput(r.messages);
|
||||
|
||||
|
||||
container.addChild(new Spacer(1));
|
||||
container.addChild(new Text(theme.fg("muted", "─── ") + theme.fg("accent", r.agent) + " " + rIcon, 0, 0));
|
||||
container.addChild(
|
||||
new Text(`${theme.fg("muted", "─── ") + theme.fg("accent", r.agent)} ${rIcon}`, 0, 0),
|
||||
);
|
||||
container.addChild(new Text(theme.fg("muted", "Task: ") + theme.fg("dim", r.task), 0, 0));
|
||||
|
||||
|
||||
// Show tool calls
|
||||
for (const item of displayItems) {
|
||||
if (item.type === "toolCall") {
|
||||
container.addChild(new Text(theme.fg("muted", "→ ") + formatToolCall(item.name, item.args, theme.fg.bind(theme)), 0, 0));
|
||||
container.addChild(
|
||||
new Text(
|
||||
theme.fg("muted", "→ ") + formatToolCall(item.name, item.args, theme.fg.bind(theme)),
|
||||
0,
|
||||
0,
|
||||
),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Show final output as markdown
|
||||
if (finalOutput) {
|
||||
container.addChild(new Spacer(1));
|
||||
container.addChild(new Markdown(finalOutput.trim(), 0, 0, mdTheme));
|
||||
}
|
||||
|
||||
|
||||
const taskUsage = formatUsageStats(r.usage, r.model);
|
||||
if (taskUsage) container.addChild(new Text(theme.fg("dim", taskUsage), 0, 0));
|
||||
}
|
||||
|
||||
|
||||
const usageStr = formatUsageStats(aggregateUsage(details.results));
|
||||
if (usageStr) {
|
||||
container.addChild(new Spacer(1));
|
||||
|
|
@ -747,21 +948,27 @@ const factory: CustomToolFactory = (pi) => {
|
|||
}
|
||||
return container;
|
||||
}
|
||||
|
||||
|
||||
// Collapsed view (or still running)
|
||||
let text = icon + " " + theme.fg("toolTitle", theme.bold("parallel ")) + theme.fg("accent", status);
|
||||
let text = `${icon} ${theme.fg("toolTitle", theme.bold("parallel "))}${theme.fg("accent", status)}`;
|
||||
for (const r of details.results) {
|
||||
const rIcon = r.exitCode === -1 ? theme.fg("warning", "⏳") : (r.exitCode === 0 ? theme.fg("success", "✓") : theme.fg("error", "✗"));
|
||||
const rIcon =
|
||||
r.exitCode === -1
|
||||
? theme.fg("warning", "⏳")
|
||||
: r.exitCode === 0
|
||||
? theme.fg("success", "✓")
|
||||
: theme.fg("error", "✗");
|
||||
const displayItems = getDisplayItems(r.messages);
|
||||
text += "\n\n" + theme.fg("muted", "─── ") + theme.fg("accent", r.agent) + " " + rIcon;
|
||||
if (displayItems.length === 0) text += "\n" + theme.fg("muted", r.exitCode === -1 ? "(running...)" : "(no output)");
|
||||
else text += "\n" + renderDisplayItems(displayItems, 5);
|
||||
text += `\n\n${theme.fg("muted", "─── ")}${theme.fg("accent", r.agent)} ${rIcon}`;
|
||||
if (displayItems.length === 0)
|
||||
text += `\n${theme.fg("muted", r.exitCode === -1 ? "(running...)" : "(no output)")}`;
|
||||
else text += `\n${renderDisplayItems(displayItems, 5)}`;
|
||||
}
|
||||
if (!isRunning) {
|
||||
const usageStr = formatUsageStats(aggregateUsage(details.results));
|
||||
if (usageStr) text += "\n\n" + theme.fg("dim", `Total: ${usageStr}`);
|
||||
if (usageStr) text += `\n\n${theme.fg("dim", `Total: ${usageStr}`)}`;
|
||||
}
|
||||
if (!expanded) text += "\n" + theme.fg("muted", "(Ctrl+O to expand)");
|
||||
if (!expanded) text += `\n${theme.fg("muted", "(Ctrl+O to expand)")}`;
|
||||
return new Text(text, 0, 0);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -8,10 +8,10 @@
|
|||
* The onSession callback reconstructs state by scanning past tool results.
|
||||
*/
|
||||
|
||||
import { Type } from "@sinclair/typebox";
|
||||
import { StringEnum } from "@mariozechner/pi-ai";
|
||||
import { Text } from "@mariozechner/pi-tui";
|
||||
import type { CustomAgentTool, CustomToolFactory, ToolSessionEvent } from "@mariozechner/pi-coding-agent";
|
||||
import { Text } from "@mariozechner/pi-tui";
|
||||
import { Type } from "@sinclair/typebox";
|
||||
|
||||
interface Todo {
|
||||
id: number;
|
||||
|
|
@ -76,11 +76,18 @@ const factory: CustomToolFactory = (_pi) => {
|
|||
switch (params.action) {
|
||||
case "list":
|
||||
return {
|
||||
content: [{ type: "text", text: todos.length ? todos.map((t) => `[${t.done ? "x" : " "}] #${t.id}: ${t.text}`).join("\n") : "No todos" }],
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: todos.length
|
||||
? todos.map((t) => `[${t.done ? "x" : " "}] #${t.id}: ${t.text}`).join("\n")
|
||||
: "No todos",
|
||||
},
|
||||
],
|
||||
details: { action: "list", todos: [...todos], nextId },
|
||||
};
|
||||
|
||||
case "add":
|
||||
case "add": {
|
||||
if (!params.text) {
|
||||
return {
|
||||
content: [{ type: "text", text: "Error: text required for add" }],
|
||||
|
|
@ -93,8 +100,9 @@ const factory: CustomToolFactory = (_pi) => {
|
|||
content: [{ type: "text", text: `Added todo #${newTodo.id}: ${newTodo.text}` }],
|
||||
details: { action: "add", todos: [...todos], nextId },
|
||||
};
|
||||
}
|
||||
|
||||
case "toggle":
|
||||
case "toggle": {
|
||||
if (params.id === undefined) {
|
||||
return {
|
||||
content: [{ type: "text", text: "Error: id required for toggle" }],
|
||||
|
|
@ -113,8 +121,9 @@ const factory: CustomToolFactory = (_pi) => {
|
|||
content: [{ type: "text", text: `Todo #${todo.id} ${todo.done ? "completed" : "uncompleted"}` }],
|
||||
details: { action: "toggle", todos: [...todos], nextId },
|
||||
};
|
||||
}
|
||||
|
||||
case "clear":
|
||||
case "clear": {
|
||||
const count = todos.length;
|
||||
todos = [];
|
||||
nextId = 1;
|
||||
|
|
@ -122,6 +131,7 @@ const factory: CustomToolFactory = (_pi) => {
|
|||
content: [{ type: "text", text: `Cleared ${count} todos` }],
|
||||
details: { action: "clear", todos: [], nextId: 1 },
|
||||
};
|
||||
}
|
||||
|
||||
default:
|
||||
return {
|
||||
|
|
@ -133,8 +143,8 @@ const factory: CustomToolFactory = (_pi) => {
|
|||
|
||||
renderCall(args, theme) {
|
||||
let text = theme.fg("toolTitle", theme.bold("todo ")) + theme.fg("muted", args.action);
|
||||
if (args.text) text += " " + theme.fg("dim", `"${args.text}"`);
|
||||
if (args.id !== undefined) text += " " + theme.fg("accent", `#${args.id}`);
|
||||
if (args.text) text += ` ${theme.fg("dim", `"${args.text}"`)}`;
|
||||
if (args.id !== undefined) text += ` ${theme.fg("accent", `#${args.id}`)}`;
|
||||
return new Text(text, 0, 0);
|
||||
},
|
||||
|
||||
|
|
@ -153,7 +163,7 @@ const factory: CustomToolFactory = (_pi) => {
|
|||
const todoList = details.todos;
|
||||
|
||||
switch (details.action) {
|
||||
case "list":
|
||||
case "list": {
|
||||
if (todoList.length === 0) {
|
||||
return new Text(theme.fg("dim", "No todos"), 0, 0);
|
||||
}
|
||||
|
|
@ -162,16 +172,24 @@ const factory: CustomToolFactory = (_pi) => {
|
|||
for (const t of display) {
|
||||
const check = t.done ? theme.fg("success", "✓") : theme.fg("dim", "○");
|
||||
const itemText = t.done ? theme.fg("dim", t.text) : theme.fg("muted", t.text);
|
||||
listText += "\n" + check + " " + theme.fg("accent", `#${t.id}`) + " " + itemText;
|
||||
listText += `\n${check} ${theme.fg("accent", `#${t.id}`)} ${itemText}`;
|
||||
}
|
||||
if (!expanded && todoList.length > 5) {
|
||||
listText += "\n" + theme.fg("dim", `... ${todoList.length - 5} more`);
|
||||
listText += `\n${theme.fg("dim", `... ${todoList.length - 5} more`)}`;
|
||||
}
|
||||
return new Text(listText, 0, 0);
|
||||
}
|
||||
|
||||
case "add": {
|
||||
const added = todoList[todoList.length - 1];
|
||||
return new Text(theme.fg("success", "✓ Added ") + theme.fg("accent", `#${added.id}`) + " " + theme.fg("muted", added.text), 0, 0);
|
||||
return new Text(
|
||||
theme.fg("success", "✓ Added ") +
|
||||
theme.fg("accent", `#${added.id}`) +
|
||||
" " +
|
||||
theme.fg("muted", added.text),
|
||||
0,
|
||||
0,
|
||||
);
|
||||
}
|
||||
|
||||
case "toggle": {
|
||||
|
|
|
|||
|
|
@ -28,9 +28,7 @@ export default function (pi: HookAPI) {
|
|||
if (!ctx.hasUI) return;
|
||||
|
||||
// Check if there are unsaved changes (messages since last assistant response)
|
||||
const hasUnsavedWork = event.entries.some(
|
||||
(e) => e.type === "message" && e.message.role === "user",
|
||||
);
|
||||
const hasUnsavedWork = event.entries.some((e) => e.type === "message" && e.message.role === "user");
|
||||
|
||||
if (hasUnsavedWork) {
|
||||
const confirmed = await ctx.ui.confirm(
|
||||
|
|
@ -48,10 +46,10 @@ export default function (pi: HookAPI) {
|
|||
if (event.reason === "before_branch") {
|
||||
if (!ctx.hasUI) return;
|
||||
|
||||
const choice = await ctx.ui.select(
|
||||
`Branch from turn ${event.targetTurnIndex}?`,
|
||||
["Yes, create branch", "No, stay in current session"],
|
||||
);
|
||||
const choice = await ctx.ui.select(`Branch from turn ${event.targetTurnIndex}?`, [
|
||||
"Yes, create branch",
|
||||
"No, stay in current session",
|
||||
]);
|
||||
|
||||
if (choice !== "Yes, create branch") {
|
||||
ctx.ui.notify("Branch cancelled", "info");
|
||||
|
|
|
|||
|
|
@ -23,7 +23,8 @@ export default function (pi: HookAPI) {
|
|||
|
||||
ctx.ui.notify("Custom compaction hook triggered", "info");
|
||||
|
||||
const { messagesToSummarize, messagesToKeep, previousSummary, tokensBefore, resolveApiKey, entries, signal } = event;
|
||||
const { messagesToSummarize, messagesToKeep, previousSummary, tokensBefore, resolveApiKey, entries, signal } =
|
||||
event;
|
||||
|
||||
// Use Gemini Flash for summarization (cheaper/faster than most conversation models)
|
||||
// findModel searches both built-in models and custom models from models.json
|
||||
|
|
|
|||
|
|
@ -10,11 +10,7 @@ import type { HookAPI } from "@mariozechner/pi-coding-agent/hooks";
|
|||
export default function (pi: HookAPI) {
|
||||
pi.on("session", async (event, ctx) => {
|
||||
// Only guard destructive actions
|
||||
if (
|
||||
event.reason !== "before_clear" &&
|
||||
event.reason !== "before_switch" &&
|
||||
event.reason !== "before_branch"
|
||||
) {
|
||||
if (event.reason !== "before_clear" && event.reason !== "before_switch" && event.reason !== "before_branch") {
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -46,10 +42,10 @@ export default function (pi: HookAPI) {
|
|||
? "switch session"
|
||||
: "branch";
|
||||
|
||||
const choice = await ctx.ui.select(
|
||||
`You have ${changedFiles} uncommitted file(s). ${action} anyway?`,
|
||||
["Yes, proceed anyway", "No, let me commit first"],
|
||||
);
|
||||
const choice = await ctx.ui.select(`You have ${changedFiles} uncommitted file(s). ${action} anyway?`, [
|
||||
"Yes, proceed anyway",
|
||||
"No, let me commit first",
|
||||
]);
|
||||
|
||||
if (choice !== "Yes, proceed anyway") {
|
||||
ctx.ui.notify("Commit your changes first", "warning");
|
||||
|
|
|
|||
|
|
@ -8,11 +8,7 @@
|
|||
import type { HookAPI } from "@mariozechner/pi-coding-agent/hooks";
|
||||
|
||||
export default function (pi: HookAPI) {
|
||||
const dangerousPatterns = [
|
||||
/\brm\s+(-rf?|--recursive)/i,
|
||||
/\bsudo\b/i,
|
||||
/\b(chmod|chown)\b.*777/i,
|
||||
];
|
||||
const dangerousPatterns = [/\brm\s+(-rf?|--recursive)/i, /\bsudo\b/i, /\b(chmod|chown)\b.*777/i];
|
||||
|
||||
pi.on("tool_call", async (event, ctx) => {
|
||||
if (event.toolName !== "bash") return undefined;
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
* Shows how to select a specific model and thinking level.
|
||||
*/
|
||||
|
||||
import { createAgentSession, findModel, discoverAvailableModels } from "../../src/index.js";
|
||||
import { createAgentSession, discoverAvailableModels, findModel } from "../../src/index.js";
|
||||
|
||||
// Option 1: Find a specific model by provider/id
|
||||
const { model: sonnet } = findModel("anthropic", "claude-sonnet-4-20250514");
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ const customSkill: Skill = {
|
|||
};
|
||||
|
||||
// Use filtered + custom skills
|
||||
const { session } = await createAgentSession({
|
||||
await createAgentSession({
|
||||
skills: [...filteredSkills, customSkill],
|
||||
sessionManager: SessionManager.inMemory(),
|
||||
});
|
||||
|
|
|
|||
|
|
@ -10,31 +10,28 @@
|
|||
|
||||
import { Type } from "@sinclair/typebox";
|
||||
import {
|
||||
createAgentSession,
|
||||
discoverCustomTools,
|
||||
SessionManager,
|
||||
codingTools, // read, bash, edit, write - uses process.cwd()
|
||||
readOnlyTools, // read, grep, find, ls - uses process.cwd()
|
||||
createCodingTools, // Factory: creates tools for specific cwd
|
||||
createReadOnlyTools, // Factory: creates tools for specific cwd
|
||||
createReadTool,
|
||||
createBashTool,
|
||||
createGrepTool,
|
||||
readTool,
|
||||
bashTool,
|
||||
grepTool,
|
||||
bashTool, // read, bash, edit, write - uses process.cwd()
|
||||
type CustomAgentTool,
|
||||
createAgentSession,
|
||||
createBashTool,
|
||||
createCodingTools, // Factory: creates tools for specific cwd
|
||||
createGrepTool,
|
||||
createReadTool,
|
||||
grepTool,
|
||||
readOnlyTools, // read, grep, find, ls - uses process.cwd()
|
||||
readTool,
|
||||
SessionManager,
|
||||
} from "../../src/index.js";
|
||||
|
||||
// Read-only mode (no edit/write) - uses process.cwd()
|
||||
const { session: readOnly } = await createAgentSession({
|
||||
await createAgentSession({
|
||||
tools: readOnlyTools,
|
||||
sessionManager: SessionManager.inMemory(),
|
||||
});
|
||||
console.log("Read-only session created");
|
||||
|
||||
// Custom tool selection - uses process.cwd()
|
||||
const { session: custom } = await createAgentSession({
|
||||
await createAgentSession({
|
||||
tools: [readTool, bashTool, grepTool],
|
||||
sessionManager: SessionManager.inMemory(),
|
||||
});
|
||||
|
|
@ -42,7 +39,7 @@ console.log("Custom tools session created");
|
|||
|
||||
// With custom cwd - MUST use factory functions!
|
||||
const customCwd = "/path/to/project";
|
||||
const { session: customCwdSession } = await createAgentSession({
|
||||
await createAgentSession({
|
||||
cwd: customCwd,
|
||||
tools: createCodingTools(customCwd), // Tools resolve paths relative to customCwd
|
||||
sessionManager: SessionManager.inMemory(),
|
||||
|
|
@ -50,7 +47,7 @@ const { session: customCwdSession } = await createAgentSession({
|
|||
console.log("Custom cwd session created");
|
||||
|
||||
// Or pick specific tools for custom cwd
|
||||
const { session: specificTools } = await createAgentSession({
|
||||
await createAgentSession({
|
||||
cwd: customCwd,
|
||||
tools: [createReadTool(customCwd), createBashTool(customCwd), createGrepTool(customCwd)],
|
||||
sessionManager: SessionManager.inMemory(),
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
* Hooks intercept agent events for logging, blocking, or modification.
|
||||
*/
|
||||
|
||||
import { createAgentSession, discoverHooks, SessionManager, type HookFactory } from "../../src/index.js";
|
||||
import { createAgentSession, type HookFactory, SessionManager } from "../../src/index.js";
|
||||
|
||||
// Logging hook
|
||||
const loggingHook: HookFactory = (api) => {
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ for (const file of discovered) {
|
|||
}
|
||||
|
||||
// Use custom context files
|
||||
const { session } = await createAgentSession({
|
||||
await createAgentSession({
|
||||
contextFiles: [
|
||||
...discovered,
|
||||
{
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
* File-based commands that inject content when invoked with /commandname.
|
||||
*/
|
||||
|
||||
import { createAgentSession, discoverSlashCommands, SessionManager, type FileSlashCommand } from "../../src/index.js";
|
||||
import { createAgentSession, discoverSlashCommands, type FileSlashCommand, SessionManager } from "../../src/index.js";
|
||||
|
||||
// Discover commands from cwd/.pi/commands/ and ~/.pi/agent/commands/
|
||||
const discovered = discoverSlashCommands();
|
||||
|
|
@ -21,12 +21,12 @@ const deployCommand: FileSlashCommand = {
|
|||
content: `# Deploy Instructions
|
||||
|
||||
1. Build: npm run build
|
||||
2. Test: npm test
|
||||
2. Test: npm test
|
||||
3. Deploy: npm run deploy`,
|
||||
};
|
||||
|
||||
// Use discovered + custom commands
|
||||
const { session } = await createAgentSession({
|
||||
await createAgentSession({
|
||||
slashCommands: [...discovered, deployCommand],
|
||||
sessionManager: SessionManager.inMemory(),
|
||||
});
|
||||
|
|
|
|||
|
|
@ -4,22 +4,17 @@
|
|||
* Configure API key resolution. Default checks: models.json, OAuth, env vars.
|
||||
*/
|
||||
|
||||
import {
|
||||
createAgentSession,
|
||||
configureOAuthStorage,
|
||||
defaultGetApiKey,
|
||||
SessionManager,
|
||||
} from "../../src/index.js";
|
||||
import { getAgentDir } from "../../src/config.js";
|
||||
import { configureOAuthStorage, createAgentSession, defaultGetApiKey, SessionManager } from "../../src/index.js";
|
||||
|
||||
// Default: uses env vars (ANTHROPIC_API_KEY, etc.), OAuth, and models.json
|
||||
const { session: defaultSession } = await createAgentSession({
|
||||
await createAgentSession({
|
||||
sessionManager: SessionManager.inMemory(),
|
||||
});
|
||||
console.log("Session with default API key resolution");
|
||||
|
||||
// Custom resolver
|
||||
const { session: customSession } = await createAgentSession({
|
||||
await createAgentSession({
|
||||
getApiKey: async (model) => {
|
||||
// Custom logic (secrets manager, database, etc.)
|
||||
if (model.provider === "anthropic") {
|
||||
|
|
@ -35,7 +30,7 @@ console.log("Session with custom API key resolver");
|
|||
// Use OAuth from ~/.pi/agent while customizing everything else
|
||||
configureOAuthStorage(getAgentDir()); // Must call before createAgentSession
|
||||
|
||||
const { session: hybridSession } = await createAgentSession({
|
||||
await createAgentSession({
|
||||
agentDir: "/tmp/custom-config", // Custom config location
|
||||
// But OAuth tokens still come from ~/.pi/agent/oauth.json
|
||||
systemPrompt: "You are helpful.",
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ settingsManager.applyOverrides({
|
|||
retry: { enabled: true, maxRetries: 5, baseDelayMs: 1000 },
|
||||
});
|
||||
|
||||
const { session } = await createAgentSession({
|
||||
await createAgentSession({
|
||||
settingsManager,
|
||||
sessionManager: SessionManager.inMemory(),
|
||||
});
|
||||
|
|
@ -30,7 +30,7 @@ const inMemorySettings = SettingsManager.inMemory({
|
|||
retry: { enabled: false },
|
||||
});
|
||||
|
||||
const { session: testSession } = await createAgentSession({
|
||||
await createAgentSession({
|
||||
settingsManager: inMemorySettings,
|
||||
sessionManager: SessionManager.inMemory(),
|
||||
});
|
||||
|
|
|
|||
|
|
@ -10,19 +10,19 @@
|
|||
*/
|
||||
|
||||
import { Type } from "@sinclair/typebox";
|
||||
import { getAgentDir } from "../../src/config.js";
|
||||
import {
|
||||
createAgentSession,
|
||||
type CustomAgentTool,
|
||||
configureOAuthStorage,
|
||||
createAgentSession,
|
||||
createBashTool,
|
||||
createReadTool,
|
||||
defaultGetApiKey,
|
||||
findModel,
|
||||
type HookFactory,
|
||||
SessionManager,
|
||||
SettingsManager,
|
||||
createReadTool,
|
||||
createBashTool,
|
||||
type HookFactory,
|
||||
type CustomAgentTool,
|
||||
} from "../../src/index.js";
|
||||
import { getAgentDir } from "../../src/config.js";
|
||||
|
||||
// Use OAuth from default location
|
||||
configureOAuthStorage(getAgentDir());
|
||||
|
|
|
|||
|
|
@ -34,7 +34,6 @@
|
|||
"build:binary": "npm run build && bun build --compile ./dist/cli.js --outfile dist/pi && npm run copy-binary-assets",
|
||||
"copy-assets": "mkdir -p dist/modes/interactive/theme && cp src/modes/interactive/theme/*.json dist/modes/interactive/theme/",
|
||||
"copy-binary-assets": "cp package.json dist/ && cp README.md dist/ && cp CHANGELOG.md dist/ && mkdir -p dist/theme && cp src/modes/interactive/theme/*.json dist/theme/ && cp -r docs dist/ && cp -r examples dist/",
|
||||
"check": "tsgo --noEmit && tsgo -p tsconfig.examples.json",
|
||||
"test": "vitest --run",
|
||||
"prepublishOnly": "npm run clean && npm run build"
|
||||
},
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@
|
|||
"clean": "rm -rf dist",
|
||||
"build": "tsgo -p tsconfig.build.json && chmod +x dist/main.js",
|
||||
"dev": "tsgo -p tsconfig.build.json --watch --preserveWatchOutput",
|
||||
"check": "biome check --write . && tsgo --noEmit",
|
||||
"prepublishOnly": "npm run clean && npm run build"
|
||||
},
|
||||
"dependencies": {
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@
|
|||
"scripts": {
|
||||
"clean": "rm -rf dist",
|
||||
"build": "tsgo -p tsconfig.build.json && chmod +x dist/cli.js && cp src/models.json dist/ && cp -r scripts dist/",
|
||||
"check": "biome check --write .",
|
||||
"prepublishOnly": "npm run clean && npm run build"
|
||||
},
|
||||
"files": [
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@
|
|||
"scripts": {
|
||||
"clean": "rm -rf dist",
|
||||
"build": "tsc",
|
||||
"check": "biome check --write .",
|
||||
"typecheck": "tsgo --noEmit",
|
||||
"dev": "tsx src/cors-proxy.ts 3001"
|
||||
},
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@
|
|||
"clean": "rm -rf dist",
|
||||
"build": "tsgo -p tsconfig.build.json",
|
||||
"dev": "tsgo -p tsconfig.build.json --watch --preserveWatchOutput",
|
||||
"check": "biome check --write . && tsgo --noEmit",
|
||||
"test": "node --test --import tsx test/*.test.ts",
|
||||
"prepublishOnly": "npm run clean && npm run build"
|
||||
},
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
import type { Message } from "@mariozechner/pi-ai";
|
||||
import { html } from "lit";
|
||||
import { registerMessageRenderer } from "@mariozechner/pi-web-ui";
|
||||
import type { AppMessage, MessageRenderer } from "@mariozechner/pi-web-ui";
|
||||
import { Alert } from "@mariozechner/mini-lit/dist/Alert.js";
|
||||
import type { Message } from "@mariozechner/pi-ai";
|
||||
import type { AppMessage, MessageRenderer } from "@mariozechner/pi-web-ui";
|
||||
import { registerMessageRenderer } from "@mariozechner/pi-web-ui";
|
||||
import { html } from "lit";
|
||||
|
||||
// ============================================================================
|
||||
// 1. EXTEND AppMessage TYPE VIA DECLARATION MERGING
|
||||
|
|
@ -85,10 +85,7 @@ export function customMessageTransformer(messages: AppMessage[]): Message[] {
|
|||
|
||||
// Keep LLM-compatible messages + custom messages
|
||||
return (
|
||||
m.role === "user" ||
|
||||
m.role === "assistant" ||
|
||||
m.role === "toolResult" ||
|
||||
m.role === "system-notification"
|
||||
m.role === "user" || m.role === "assistant" || m.role === "toolResult" || m.role === "system-notification"
|
||||
);
|
||||
})
|
||||
.map((m) => {
|
||||
|
|
@ -103,7 +100,7 @@ export function customMessageTransformer(messages: AppMessage[]): Message[] {
|
|||
|
||||
// Strip attachments from user messages
|
||||
if (m.role === "user") {
|
||||
const { attachments, ...rest } = m as any;
|
||||
const { attachments: _, ...rest } = m as any;
|
||||
return rest as Message;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -7,27 +7,31 @@ import {
|
|||
type AppMessage,
|
||||
AppStorage,
|
||||
ChatPanel,
|
||||
createJavaScriptReplTool,
|
||||
CustomProvidersStore,
|
||||
createJavaScriptReplTool,
|
||||
IndexedDBStorageBackend,
|
||||
// PersistentStorageDialog, // TODO: Fix - currently broken
|
||||
ProviderKeysStore,
|
||||
ProviderTransport,
|
||||
ProvidersModelsTab,
|
||||
ProviderTransport,
|
||||
ProxyTab,
|
||||
SessionListDialog,
|
||||
SessionsStore,
|
||||
setAppStorage,
|
||||
SettingsDialog,
|
||||
SettingsStore,
|
||||
setAppStorage,
|
||||
} from "@mariozechner/pi-web-ui";
|
||||
import { html, render } from "lit";
|
||||
import { Bell, History, Plus, Settings } from "lucide";
|
||||
import "./app.css";
|
||||
import { createSystemNotification, customMessageTransformer, registerCustomMessageRenderers } from "./custom-messages.js";
|
||||
import { Button } from "@mariozechner/mini-lit/dist/Button.js";
|
||||
import { icon } from "@mariozechner/mini-lit";
|
||||
import { Button } from "@mariozechner/mini-lit/dist/Button.js";
|
||||
import { Input } from "@mariozechner/mini-lit/dist/Input.js";
|
||||
import {
|
||||
createSystemNotification,
|
||||
customMessageTransformer,
|
||||
registerCustomMessageRenderers,
|
||||
} from "./custom-messages.js";
|
||||
|
||||
// Register custom message renderers
|
||||
registerCustomMessageRenderers();
|
||||
|
|
@ -92,7 +96,7 @@ const generateTitle = (messages: AppMessage[]): string => {
|
|||
if (sentenceEnd > 0 && sentenceEnd <= 50) {
|
||||
return text.substring(0, sentenceEnd + 1);
|
||||
}
|
||||
return text.length <= 50 ? text : text.substring(0, 47) + "...";
|
||||
return text.length <= 50 ? text : `${text.substring(0, 47)}...`;
|
||||
};
|
||||
|
||||
const shouldSaveSession = (messages: AppMessage[]): boolean => {
|
||||
|
|
@ -211,12 +215,12 @@ Feel free to use these tools when needed to provide accurate and helpful respons
|
|||
onApiKeyRequired: async (provider: string) => {
|
||||
return await ApiKeyPromptDialog.prompt(provider);
|
||||
},
|
||||
toolsFactory: (agent, agentInterface, artifactsPanel, runtimeProvidersFactory) => {
|
||||
toolsFactory: (_agent, _agentInterface, _artifactsPanel, runtimeProvidersFactory) => {
|
||||
// Create javascript_repl tool with access to attachments + artifacts
|
||||
const replTool = createJavaScriptReplTool();
|
||||
replTool.runtimeProvidersFactory = runtimeProvidersFactory;
|
||||
return [replTool];
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
|
|
@ -290,9 +294,10 @@ const renderApp = () => {
|
|||
title: "New Session",
|
||||
})}
|
||||
|
||||
${currentTitle
|
||||
? isEditingTitle
|
||||
? html`<div class="flex items-center gap-2">
|
||||
${
|
||||
currentTitle
|
||||
? isEditingTitle
|
||||
? html`<div class="flex items-center gap-2">
|
||||
${Input({
|
||||
type: "text",
|
||||
value: currentTitle,
|
||||
|
|
@ -322,7 +327,7 @@ const renderApp = () => {
|
|||
},
|
||||
})}
|
||||
</div>`
|
||||
: html`<button
|
||||
: html`<button
|
||||
class="px-2 py-1 text-sm text-foreground hover:bg-secondary rounded transition-colors"
|
||||
@click=${() => {
|
||||
isEditingTitle = true;
|
||||
|
|
@ -339,7 +344,8 @@ const renderApp = () => {
|
|||
>
|
||||
${currentTitle}
|
||||
</button>`
|
||||
: html`<span class="text-base font-semibold text-foreground">Pi Web UI Example</span>`}
|
||||
: html`<span class="text-base font-semibold text-foreground">Pi Web UI Example</span>`
|
||||
}
|
||||
</div>
|
||||
<div class="flex items-center gap-1 px-2">
|
||||
${Button({
|
||||
|
|
@ -350,7 +356,9 @@ const renderApp = () => {
|
|||
// Demo: Inject custom message
|
||||
if (agent) {
|
||||
agent.appendMessage(
|
||||
createSystemNotification("This is a custom message! It appears in the UI but is never sent to the LLM."),
|
||||
createSystemNotification(
|
||||
"This is a custom message! It appears in the UI but is never sent to the LLM.",
|
||||
),
|
||||
);
|
||||
}
|
||||
},
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@
|
|||
"build": "tsc -p tsconfig.build.json && tailwindcss -i ./src/app.css -o ./dist/app.css --minify",
|
||||
"dev": "concurrently --names \"build,example\" --prefix-colors \"cyan,green\" \"tsc -p tsconfig.build.json --watch --preserveWatchOutput\" \"tailwindcss -i ./src/app.css -o ./dist/app.css --watch\" \"npm run dev --prefix example\"",
|
||||
"dev:tsc": "concurrently --names \"build\" --prefix-colors \"cyan\" \"tsc -p tsconfig.build.json --watch --preserveWatchOutput\" \"tailwindcss -i ./src/app.css -o ./dist/app.css --watch\"",
|
||||
"check": "tsgo --noEmit && cd example && tsgo --noEmit"
|
||||
"check": "biome check --write . && tsc --noEmit && cd example && biome check --write . && tsc --noEmit"
|
||||
},
|
||||
"dependencies": {
|
||||
"@lmstudio/sdk": "^1.5.0",
|
||||
|
|
|
|||
|
|
@ -27,5 +27,6 @@
|
|||
"@mariozechner/pi-agent-old/*": ["./packages/agent-old/src/*"]
|
||||
}
|
||||
},
|
||||
"include": ["packages/*/src/**/*", "packages/*/test/**/*", "packages/coding-agent/examples/**/*"]
|
||||
"include": ["packages/*/src/**/*", "packages/*/test/**/*", "packages/coding-agent/examples/**/*"],
|
||||
"exclude": ["packages/web-ui/**/*"]
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue