move pi-mono into companion-cloud as apps/companion-os

- Copy all pi-mono source into apps/companion-os/
- Update Dockerfile to COPY pre-built binary instead of downloading from GitHub Releases
- Update deploy-staging.yml to build pi from source (bun compile) before Docker build
- Add apps/companion-os/** to path triggers
- No more cross-repo dispatch needed

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Harivansh Rathi 2026-03-07 09:22:50 -08:00
commit 0250f72976
579 changed files with 206942 additions and 0 deletions

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,503 @@
/**
* Credential storage for API keys and OAuth tokens.
* Handles loading, saving, and refreshing credentials from auth.json.
*
* Uses file locking to prevent race conditions when multiple pi instances
* try to refresh tokens simultaneously.
*/
import {
getEnvApiKey,
type OAuthCredentials,
type OAuthLoginCallbacks,
type OAuthProviderId,
} from "@mariozechner/pi-ai";
import {
getOAuthApiKey,
getOAuthProvider,
getOAuthProviders,
} from "@mariozechner/pi-ai/oauth";
import {
chmodSync,
existsSync,
mkdirSync,
readFileSync,
writeFileSync,
} from "fs";
import { dirname, join } from "path";
import lockfile from "proper-lockfile";
import { getAgentDir } from "../config.js";
import { resolveConfigValue } from "./resolve-config-value.js";
/** API key credential; the key may be a config-resolvable value (see getApiKey). */
export type ApiKeyCredential = {
  type: "api_key";
  key: string;
};

/** OAuth credential: discriminator plus the provider's token fields. */
export type OAuthCredential = {
  type: "oauth";
} & OAuthCredentials;

/** Any credential that can appear in auth.json. */
export type AuthCredential = ApiKeyCredential | OAuthCredential;

/** Shape of the auth.json document: provider id -> credential. */
export type AuthStorageData = Record<string, AuthCredential>;

/**
 * Value returned from a locked critical section.
 * `result` is handed back to the caller; when `next` is defined, the backend
 * persists it as the new storage content before releasing the lock.
 */
type LockResult<T> = {
  result: T;
  next?: string;
};

/**
 * Abstraction over the credential persistence layer.
 * Implementations run `fn` with exclusive access to the stored string,
 * passing the current content (undefined when absent) and writing back
 * `next` when the callback provides it.
 */
export interface AuthStorageBackend {
  withLock<T>(fn: (current: string | undefined) => LockResult<T>): T;
  withLockAsync<T>(
    fn: (current: string | undefined) => Promise<LockResult<T>>,
  ): Promise<T>;
}
/**
 * File-backed credential storage using proper-lockfile for cross-process
 * mutual exclusion on auth.json.
 */
export class FileAuthStorageBackend implements AuthStorageBackend {
  constructor(private authPath: string = join(getAgentDir(), "auth.json")) {}

  // Create the parent directory, owner-only (0700), if missing.
  private ensureParentDir(): void {
    const dir = dirname(this.authPath);
    if (!existsSync(dir)) {
      mkdirSync(dir, { recursive: true, mode: 0o700 });
    }
  }

  // Create an empty JSON file, owner-only (0600), if missing, so there is
  // always a file to lock and secrets are never world-readable.
  private ensureFileExists(): void {
    if (!existsSync(this.authPath)) {
      writeFileSync(this.authPath, "{}", "utf-8");
      chmodSync(this.authPath, 0o600);
    }
  }

  /**
   * Run `fn` while holding a synchronous lock on auth.json.
   * The current file content is passed in; if `fn` returns `next`, it is
   * written back (and permissions re-tightened) before the lock releases.
   */
  withLock<T>(fn: (current: string | undefined) => LockResult<T>): T {
    this.ensureParentDir();
    this.ensureFileExists();
    let release: (() => void) | undefined;
    try {
      // realpath: false — lock the literal path without resolving symlinks.
      release = lockfile.lockSync(this.authPath, { realpath: false });
      const current = existsSync(this.authPath)
        ? readFileSync(this.authPath, "utf-8")
        : undefined;
      const { result, next } = fn(current);
      if (next !== undefined) {
        writeFileSync(this.authPath, next, "utf-8");
        chmodSync(this.authPath, 0o600);
      }
      return result;
    } finally {
      if (release) {
        release();
      }
    }
  }

  /**
   * Async variant of withLock with retry/backoff, used where `fn` must
   * await (e.g. OAuth token refresh).
   *
   * proper-lockfile may report the lock as "compromised" (e.g. the lockfile
   * disappeared or went stale); we record that and re-check before reading,
   * before writing, and before returning, so we never act on a lost lock.
   */
  async withLockAsync<T>(
    fn: (current: string | undefined) => Promise<LockResult<T>>,
  ): Promise<T> {
    this.ensureParentDir();
    this.ensureFileExists();
    let release: (() => Promise<void>) | undefined;
    let lockCompromised = false;
    let lockCompromisedError: Error | undefined;
    const throwIfCompromised = () => {
      if (lockCompromised) {
        throw (
          lockCompromisedError ?? new Error("Auth storage lock was compromised")
        );
      }
    };
    try {
      release = await lockfile.lock(this.authPath, {
        // Exponential backoff: up to 10 attempts, 100ms..10s, jittered.
        retries: {
          retries: 10,
          factor: 2,
          minTimeout: 100,
          maxTimeout: 10000,
          randomize: true,
        },
        // Treat a held lock as stale after 30s (crashed-process recovery).
        stale: 30000,
        onCompromised: (err) => {
          lockCompromised = true;
          lockCompromisedError = err;
        },
      });
      throwIfCompromised();
      const current = existsSync(this.authPath)
        ? readFileSync(this.authPath, "utf-8")
        : undefined;
      const { result, next } = await fn(current);
      // Re-check after the (possibly long-running) callback before writing.
      throwIfCompromised();
      if (next !== undefined) {
        writeFileSync(this.authPath, next, "utf-8");
        chmodSync(this.authPath, 0o600);
      }
      throwIfCompromised();
      return result;
    } finally {
      if (release) {
        try {
          await release();
        } catch {
          // Ignore unlock errors when lock is compromised.
        }
      }
    }
  }
}
/**
 * Volatile in-memory backend, mainly for tests and ephemeral sessions.
 * Single-process only, so no real locking is needed: each call reads the
 * stored string, runs the callback, and persists `next` when provided.
 */
export class InMemoryAuthStorageBackend implements AuthStorageBackend {
  private value: string | undefined;

  withLock<T>(fn: (current: string | undefined) => LockResult<T>): T {
    const outcome = fn(this.value);
    if (outcome.next !== undefined) {
      this.value = outcome.next;
    }
    return outcome.result;
  }

  async withLockAsync<T>(
    fn: (current: string | undefined) => Promise<LockResult<T>>,
  ): Promise<T> {
    const outcome = await fn(this.value);
    if (outcome.next !== undefined) {
      this.value = outcome.next;
    }
    return outcome.result;
  }
}
/**
 * Credential storage backed by a JSON file.
 *
 * Keeps an in-memory snapshot (`data`) of the backend content; all writes
 * go through the backend's lock so multiple pi instances can safely share
 * the same auth.json.
 */
export class AuthStorage {
  // In-memory snapshot of the persisted credentials.
  private data: AuthStorageData = {};
  // Non-persisted per-provider API key overrides (CLI --api-key).
  private runtimeOverrides: Map<string, string> = new Map();
  // Last-chance key lookup for providers not in auth.json or env vars.
  private fallbackResolver?: (provider: string) => string | undefined;
  // Non-null when the last reload failed; persistence is skipped while set
  // so a file we could not read is never overwritten.
  private loadError: Error | null = null;
  // Collected errors, retrievable (and cleared) via drainErrors().
  private errors: Error[] = [];

  private constructor(private storage: AuthStorageBackend) {
    this.reload();
  }

  /** Create a store backed by auth.json (default: <agent-dir>/auth.json). */
  static create(authPath?: string): AuthStorage {
    return new AuthStorage(
      new FileAuthStorageBackend(authPath ?? join(getAgentDir(), "auth.json")),
    );
  }

  /** Wrap an arbitrary backend implementation. */
  static fromStorage(storage: AuthStorageBackend): AuthStorage {
    return new AuthStorage(storage);
  }

  /** Create a purely in-memory store seeded with `data` (useful in tests). */
  static inMemory(data: AuthStorageData = {}): AuthStorage {
    const storage = new InMemoryAuthStorageBackend();
    storage.withLock(() => ({
      result: undefined,
      next: JSON.stringify(data, null, 2),
    }));
    return AuthStorage.fromStorage(storage);
  }

  /**
   * Set a runtime API key override (not persisted to disk).
   * Used for CLI --api-key flag.
   */
  setRuntimeApiKey(provider: string, apiKey: string): void {
    this.runtimeOverrides.set(provider, apiKey);
  }

  /**
   * Remove a runtime API key override.
   */
  removeRuntimeApiKey(provider: string): void {
    this.runtimeOverrides.delete(provider);
  }

  /**
   * Set a fallback resolver for API keys not found in auth.json or env vars.
   * Used for custom provider keys from models.json.
   */
  setFallbackResolver(
    resolver: (provider: string) => string | undefined,
  ): void {
    this.fallbackResolver = resolver;
  }

  // Normalize any thrown value into an Error and queue it for drainErrors().
  private recordError(error: unknown): void {
    const normalizedError =
      error instanceof Error ? error : new Error(String(error));
    this.errors.push(normalizedError);
  }

  // Parse backend content; empty/absent content means "no credentials".
  // Throws on malformed JSON (callers route that through recordError).
  private parseStorageData(content: string | undefined): AuthStorageData {
    if (!content) {
      return {};
    }
    return JSON.parse(content) as AuthStorageData;
  }

  /**
   * Reload credentials from storage.
   * On failure the previous snapshot is kept and loadError is set, which
   * also suspends persistence (see persistProviderChange).
   */
  reload(): void {
    let content: string | undefined;
    try {
      this.storage.withLock((current) => {
        content = current;
        return { result: undefined };
      });
      this.data = this.parseStorageData(content);
      this.loadError = null;
    } catch (error) {
      this.loadError = error as Error;
      this.recordError(error);
    }
  }

  // Persist a single provider change, merging with whatever is currently in
  // storage (another instance may have written since our last reload).
  // Skipped entirely while loadError is set.
  private persistProviderChange(
    provider: string,
    credential: AuthCredential | undefined,
  ): void {
    if (this.loadError) {
      return;
    }
    try {
      this.storage.withLock((current) => {
        const currentData = this.parseStorageData(current);
        const merged: AuthStorageData = { ...currentData };
        if (credential) {
          merged[provider] = credential;
        } else {
          delete merged[provider];
        }
        return { result: undefined, next: JSON.stringify(merged, null, 2) };
      });
    } catch (error) {
      this.recordError(error);
    }
  }

  /**
   * Get credential for a provider.
   */
  get(provider: string): AuthCredential | undefined {
    return this.data[provider] ?? undefined;
  }

  /**
   * Set credential for a provider (updates snapshot, then persists).
   */
  set(provider: string, credential: AuthCredential): void {
    this.data[provider] = credential;
    this.persistProviderChange(provider, credential);
  }

  /**
   * Remove credential for a provider (updates snapshot, then persists).
   */
  remove(provider: string): void {
    delete this.data[provider];
    this.persistProviderChange(provider, undefined);
  }

  /**
   * List all providers with credentials.
   */
  list(): string[] {
    return Object.keys(this.data);
  }

  /**
   * Check if credentials exist for a provider in auth.json.
   */
  has(provider: string): boolean {
    return provider in this.data;
  }

  /**
   * Check if any form of auth is configured for a provider.
   * Unlike getApiKey(), this doesn't refresh OAuth tokens.
   */
  hasAuth(provider: string): boolean {
    if (this.runtimeOverrides.has(provider)) return true;
    if (this.data[provider]) return true;
    if (getEnvApiKey(provider)) return true;
    if (this.fallbackResolver?.(provider)) return true;
    return false;
  }

  /**
   * Get all credentials (for passing to getOAuthApiKey).
   * Returns a shallow copy so callers can't mutate the snapshot.
   */
  getAll(): AuthStorageData {
    return { ...this.data };
  }

  /** Return all recorded errors and clear the queue. */
  drainErrors(): Error[] {
    const drained = [...this.errors];
    this.errors = [];
    return drained;
  }

  /**
   * Login to an OAuth provider: runs the provider's interactive flow and
   * persists the resulting credentials.
   */
  async login(
    providerId: OAuthProviderId,
    callbacks: OAuthLoginCallbacks,
  ): Promise<void> {
    const provider = getOAuthProvider(providerId);
    if (!provider) {
      throw new Error(`Unknown OAuth provider: ${providerId}`);
    }
    const credentials = await provider.login(callbacks);
    this.set(providerId, { type: "oauth", ...credentials });
  }

  /**
   * Logout from a provider.
   */
  logout(provider: string): void {
    this.remove(provider);
  }

  /**
   * Refresh OAuth token with backend locking to prevent race conditions.
   * Multiple pi instances may try to refresh simultaneously when tokens expire.
   *
   * Re-reads storage under the lock: if another instance already refreshed
   * (token no longer expired), its credentials are used without a new
   * refresh. Returns null when the provider is unknown, has no OAuth
   * credential, or the refresh yielded nothing.
   */
  private async refreshOAuthTokenWithLock(
    providerId: OAuthProviderId,
  ): Promise<{ apiKey: string; newCredentials: OAuthCredentials } | null> {
    const provider = getOAuthProvider(providerId);
    if (!provider) {
      return null;
    }
    const result = await this.storage.withLockAsync(async (current) => {
      // Sync our snapshot with what's actually in storage, under the lock.
      const currentData = this.parseStorageData(current);
      this.data = currentData;
      this.loadError = null;
      const cred = currentData[providerId];
      if (cred?.type !== "oauth") {
        return { result: null };
      }
      if (Date.now() < cred.expires) {
        // Another instance refreshed while we waited for the lock.
        return {
          result: { apiKey: provider.getApiKey(cred), newCredentials: cred },
        };
      }
      // getOAuthApiKey takes the full map of OAuth credentials.
      const oauthCreds: Record<string, OAuthCredentials> = {};
      for (const [key, value] of Object.entries(currentData)) {
        if (value.type === "oauth") {
          oauthCreds[key] = value;
        }
      }
      const refreshed = await getOAuthApiKey(providerId, oauthCreds);
      if (!refreshed) {
        return { result: null };
      }
      // Persist the refreshed credentials (written back via `next`).
      const merged: AuthStorageData = {
        ...currentData,
        [providerId]: { type: "oauth", ...refreshed.newCredentials },
      };
      this.data = merged;
      this.loadError = null;
      return { result: refreshed, next: JSON.stringify(merged, null, 2) };
    });
    return result;
  }

  /**
   * Get API key for a provider.
   * Priority:
   * 1. Runtime override (CLI --api-key)
   * 2. API key from auth.json
   * 3. OAuth token from auth.json (auto-refreshed with locking)
   * 4. Environment variable
   * 5. Fallback resolver (models.json custom providers)
   */
  async getApiKey(providerId: string): Promise<string | undefined> {
    // Runtime override takes highest priority
    const runtimeKey = this.runtimeOverrides.get(providerId);
    if (runtimeKey) {
      return runtimeKey;
    }
    const cred = this.data[providerId];
    if (cred?.type === "api_key") {
      // Resolve via resolveConfigValue (may expand config-style references;
      // see resolve-config-value.js).
      return resolveConfigValue(cred.key);
    }
    if (cred?.type === "oauth") {
      const provider = getOAuthProvider(providerId);
      if (!provider) {
        // Unknown OAuth provider, can't get API key
        return undefined;
      }
      // Check if token needs refresh
      const needsRefresh = Date.now() >= cred.expires;
      if (needsRefresh) {
        // Use locked refresh to prevent race conditions
        try {
          const result = await this.refreshOAuthTokenWithLock(providerId);
          if (result) {
            return result.apiKey;
          }
          // result === null: fall through to env var / fallback resolver.
        } catch (error) {
          this.recordError(error);
          // Refresh failed - re-read file to check if another instance succeeded
          this.reload();
          const updatedCred = this.data[providerId];
          if (
            updatedCred?.type === "oauth" &&
            Date.now() < updatedCred.expires
          ) {
            // Another instance refreshed successfully, use those credentials
            return provider.getApiKey(updatedCred);
          }
          // Refresh truly failed - return undefined so model discovery skips this provider
          // User can /login to re-authenticate (credentials preserved for retry)
          return undefined;
        }
      } else {
        // Token not expired, use current access token
        return provider.getApiKey(cred);
      }
    }
    // Fall back to environment variable
    const envKey = getEnvApiKey(providerId);
    if (envKey) return envKey;
    // Fall back to custom resolver (e.g., models.json custom providers)
    return this.fallbackResolver?.(providerId) ?? undefined;
  }

  /**
   * Get all registered OAuth providers
   */
  getOAuthProviders() {
    return getOAuthProviders();
  }
}

View file

@ -0,0 +1,296 @@
/**
* Bash command execution with streaming support and cancellation.
*
* This module provides a unified bash execution implementation used by:
* - AgentSession.executeBash() for interactive and RPC modes
* - Direct calls from modes that need bash execution
*/
import { randomBytes } from "node:crypto";
import { createWriteStream, type WriteStream } from "node:fs";
import { tmpdir } from "node:os";
import { join } from "node:path";
import { type ChildProcess, spawn } from "child_process";
import stripAnsi from "strip-ansi";
import {
getShellConfig,
getShellEnv,
killProcessTree,
sanitizeBinaryOutput,
} from "../utils/shell.js";
import type { BashOperations } from "./tools/bash.js";
import { DEFAULT_MAX_BYTES, truncateTail } from "./tools/truncate.js";
// ============================================================================
// Types
// ============================================================================
/** Options controlling streaming and cancellation for bash execution. */
export interface BashExecutorOptions {
  /** Callback for streaming output chunks (already sanitized) */
  onChunk?: (chunk: string) => void;
  /** AbortSignal for cancellation */
  signal?: AbortSignal;
}

/** Outcome of a bash execution (see executeBash / executeBashWithOperations). */
export interface BashResult {
  /** Combined stdout + stderr output (sanitized, possibly truncated) */
  output: string;
  /** Process exit code (undefined if killed/cancelled) */
  exitCode: number | undefined;
  /** Whether the command was cancelled via signal */
  cancelled: boolean;
  /** Whether the output was truncated */
  truncated: boolean;
  /** Path to temp file containing full output (if output exceeded truncation threshold) */
  fullOutputPath?: string;
}
// ============================================================================
// Implementation
// ============================================================================
/**
 * Execute a bash command with optional streaming and cancellation support.
 *
 * Features:
 * - Streams sanitized output via onChunk callback
 * - Writes large output to temp file for later retrieval
 * - Supports cancellation via AbortSignal
 * - Sanitizes output (strips ANSI, removes binary garbage, normalizes newlines)
 * - Truncates output if it exceeds the default max bytes
 *
 * @param command - The bash command to execute
 * @param options - Optional streaming callback and abort signal
 * @returns Promise resolving to execution result
 */
export function executeBash(
  command: string,
  options?: BashExecutorOptions,
): Promise<BashResult> {
  return new Promise((resolve, reject) => {
    const { shell, args } = getShellConfig();
    const child: ChildProcess = spawn(shell, [...args, command], {
      // detached so killProcessTree can take down the whole process group
      detached: true,
      env: getShellEnv(),
      stdio: ["ignore", "pipe", "pipe"],
    });

    // Rolling buffer of sanitized output, capped at maxOutputBytes.
    const outputChunks: string[] = [];
    let outputBytes = 0;
    const maxOutputBytes = DEFAULT_MAX_BYTES * 2;

    // Temp file holding the full output once it exceeds the threshold.
    let tempFilePath: string | undefined;
    let tempFileStream: WriteStream | undefined;
    let totalBytes = 0;

    // Handle abort signal
    const abortHandler = () => {
      if (child.pid) {
        killProcessTree(child.pid);
      }
    };
    if (options?.signal) {
      if (options.signal.aborted) {
        // Already aborted, don't even start
        child.kill();
        resolve({
          output: "",
          exitCode: undefined,
          cancelled: true,
          truncated: false,
        });
        return;
      }
      options.signal.addEventListener("abort", abortHandler, { once: true });
    }

    // NOTE(review): stdout and stderr share one streaming decoder (as in the
    // original); interleaved multi-byte sequences across the two streams
    // could in principle garble — preexisting behavior, left unchanged.
    const decoder = new TextDecoder();

    // Sanitize once at the source: strip ANSI, replace binary garbage,
    // normalize newlines.
    const sanitize = (raw: string): string =>
      sanitizeBinaryOutput(stripAnsi(raw)).replace(/\r/g, "");

    // Append sanitized text to buffer/temp file and stream to the callback.
    const appendText = (text: string) => {
      if (text.length === 0) return;
      // Lazily open the temp file once raw output exceeds the threshold,
      // back-filling everything buffered so far.
      if (totalBytes > DEFAULT_MAX_BYTES && !tempFilePath) {
        const id = randomBytes(8).toString("hex");
        tempFilePath = join(tmpdir(), `pi-bash-${id}.log`);
        tempFileStream = createWriteStream(tempFilePath);
        for (const chunk of outputChunks) {
          tempFileStream.write(chunk);
        }
      }
      if (tempFileStream) {
        tempFileStream.write(text);
      }
      outputChunks.push(text);
      outputBytes += text.length;
      // Keep the in-memory buffer bounded: drop oldest chunks beyond 2x
      // the truncation size.
      while (outputBytes > maxOutputBytes && outputChunks.length > 1) {
        const removed = outputChunks.shift()!;
        outputBytes -= removed.length;
      }
      if (options?.onChunk) {
        options.onChunk(text);
      }
    };

    const handleData = (data: Buffer) => {
      totalBytes += data.length;
      appendText(sanitize(decoder.decode(data, { stream: true })));
    };
    child.stdout?.on("data", handleData);
    child.stderr?.on("data", handleData);

    child.on("close", (code) => {
      // Clean up abort listener
      if (options?.signal) {
        options.signal.removeEventListener("abort", abortHandler);
      }
      // BUGFIX: flush the streaming decoder so a trailing partial
      // multi-byte sequence is not silently dropped.
      appendText(sanitize(decoder.decode()));
      if (tempFileStream) {
        tempFileStream.end();
      }
      // Combine buffered chunks for truncation (already sanitized)
      const fullOutput = outputChunks.join("");
      const truncationResult = truncateTail(fullOutput);
      // code === null means killed (cancelled)
      const cancelled = code === null;
      resolve({
        output: truncationResult.truncated
          ? truncationResult.content
          : fullOutput,
        exitCode: cancelled ? undefined : code,
        cancelled,
        truncated: truncationResult.truncated,
        fullOutputPath: tempFilePath,
      });
    });

    child.on("error", (err) => {
      // Clean up abort listener
      if (options?.signal) {
        options.signal.removeEventListener("abort", abortHandler);
      }
      if (tempFileStream) {
        tempFileStream.end();
      }
      reject(err);
    });
  });
}
/**
 * Execute a bash command using custom BashOperations.
 * Used for remote execution (SSH, containers, etc.).
 *
 * Mirrors executeBash: sanitizes and streams output, spills large output to
 * a temp file, and truncates the returned text.
 */
export async function executeBashWithOperations(
  command: string,
  cwd: string,
  operations: BashOperations,
  options?: BashExecutorOptions,
): Promise<BashResult> {
  // Rolling buffer of sanitized output, capped at maxOutputBytes.
  const outputChunks: string[] = [];
  let outputBytes = 0;
  const maxOutputBytes = DEFAULT_MAX_BYTES * 2;
  // Temp file holding the full output once it exceeds the threshold.
  let tempFilePath: string | undefined;
  let tempFileStream: WriteStream | undefined;
  let totalBytes = 0;
  const decoder = new TextDecoder();

  // Sanitize: strip ANSI, replace binary garbage, normalize newlines.
  const sanitize = (raw: string): string =>
    sanitizeBinaryOutput(stripAnsi(raw)).replace(/\r/g, "");

  // Append sanitized text to buffer/temp file and stream to the callback.
  const appendText = (text: string) => {
    if (text.length === 0) return;
    // Lazily open the temp file once raw output exceeds the threshold,
    // back-filling everything buffered so far.
    if (totalBytes > DEFAULT_MAX_BYTES && !tempFilePath) {
      const id = randomBytes(8).toString("hex");
      tempFilePath = join(tmpdir(), `pi-bash-${id}.log`);
      tempFileStream = createWriteStream(tempFilePath);
      for (const chunk of outputChunks) {
        tempFileStream.write(chunk);
      }
    }
    if (tempFileStream) {
      tempFileStream.write(text);
    }
    outputChunks.push(text);
    outputBytes += text.length;
    // Bound the in-memory buffer: drop oldest chunks beyond 2x truncation size.
    while (outputBytes > maxOutputBytes && outputChunks.length > 1) {
      const removed = outputChunks.shift()!;
      outputBytes -= removed.length;
    }
    if (options?.onChunk) {
      options.onChunk(text);
    }
  };

  const onData = (data: Buffer) => {
    totalBytes += data.length;
    appendText(sanitize(decoder.decode(data, { stream: true })));
  };

  // BUGFIX: flush the streaming decoder so a trailing partial multi-byte
  // sequence is not silently dropped, then close the temp file.
  const finalize = () => {
    appendText(sanitize(decoder.decode()));
    if (tempFileStream) {
      tempFileStream.end();
    }
  };

  // Assemble the BashResult from the buffered (sanitized) output.
  const buildResult = (
    exitCode: number | undefined,
    cancelled: boolean,
  ): BashResult => {
    const fullOutput = outputChunks.join("");
    const truncationResult = truncateTail(fullOutput);
    return {
      output: truncationResult.truncated
        ? truncationResult.content
        : fullOutput,
      exitCode: cancelled ? undefined : exitCode,
      cancelled,
      truncated: truncationResult.truncated,
      fullOutputPath: tempFilePath,
    };
  };

  try {
    const result = await operations.exec(command, cwd, {
      onData,
      signal: options?.signal,
    });
    finalize();
    const cancelled = options?.signal?.aborted ?? false;
    return buildResult(result.exitCode ?? undefined, cancelled);
  } catch (err) {
    // On abort, return what we captured instead of throwing.
    if (options?.signal?.aborted) {
      finalize();
      return buildResult(undefined, true);
    }
    if (tempFileStream) {
      tempFileStream.end();
    }
    throw err;
  }
}

View file

@ -0,0 +1,382 @@
/**
* Branch summarization for tree navigation.
*
* When navigating to a different point in the session tree, this generates
* a summary of the branch being left so context isn't lost.
*/
import type { AgentMessage } from "@mariozechner/pi-agent-core";
import type { Model } from "@mariozechner/pi-ai";
import { completeSimple } from "@mariozechner/pi-ai";
import {
convertToLlm,
createBranchSummaryMessage,
createCompactionSummaryMessage,
createCustomMessage,
} from "../messages.js";
import type {
ReadonlySessionManager,
SessionEntry,
} from "../session-manager.js";
import { estimateTokens } from "./compaction.js";
import {
computeFileLists,
createFileOps,
extractFileOpsFromMessage,
type FileOperations,
formatFileOperations,
SUMMARIZATION_SYSTEM_PROMPT,
serializeConversation,
} from "./utils.js";
// ============================================================================
// Types
// ============================================================================
/** Outcome of generateBranchSummary. Exactly one of summary/aborted/error is meaningful. */
export interface BranchSummaryResult {
  /** Generated summary text (preamble + LLM output + file lists). */
  summary?: string;
  /** Files read on the summarized branch (cumulative). */
  readFiles?: string[];
  /** Files modified on the summarized branch (cumulative). */
  modifiedFiles?: string[];
  /** True when summarization was cancelled via the abort signal. */
  aborted?: boolean;
  /** Error message when the LLM call failed. */
  error?: string;
}

/** Details stored in BranchSummaryEntry.details for file tracking */
export interface BranchSummaryDetails {
  readFiles: string[];
  modifiedFiles: string[];
}

export type { FileOperations } from "./utils.js";

export interface BranchPreparation {
  /** Messages extracted for summarization, in chronological order */
  messages: AgentMessage[];
  /** File operations extracted from tool calls */
  fileOps: FileOperations;
  /** Total estimated tokens in messages */
  totalTokens: number;
}

export interface CollectEntriesResult {
  /** Entries to summarize, in chronological order */
  entries: SessionEntry[];
  /** Common ancestor between old and new position, if any */
  commonAncestorId: string | null;
}

export interface GenerateBranchSummaryOptions {
  /** Model to use for summarization */
  model: Model<any>;
  /** API key for the model */
  apiKey: string;
  /** Abort signal for cancellation */
  signal: AbortSignal;
  /** Optional custom instructions for summarization */
  customInstructions?: string;
  /** If true, customInstructions replaces the default prompt instead of being appended */
  replaceInstructions?: boolean;
  /** Tokens reserved for prompt + LLM response (default 16384) */
  reserveTokens?: number;
}
// ============================================================================
// Entry Collection
// ============================================================================
/**
 * Collect entries that should be summarized when navigating from one position to another.
 *
 * Walks from oldLeafId back to the common ancestor with targetId, collecting entries
 * along the way. Does NOT stop at compaction boundaries - those are included and their
 * summaries become context.
 *
 * @param session - Session manager (read-only access)
 * @param oldLeafId - Current position (where we're navigating from)
 * @param targetId - Target position (where we're navigating to)
 * @returns Entries to summarize and the common ancestor
 */
export function collectEntriesForBranchSummary(
  session: ReadonlySessionManager,
  oldLeafId: string | null,
  targetId: string,
): CollectEntriesResult {
  // Nothing was abandoned if there was no prior position.
  if (!oldLeafId) {
    return { entries: [], commonAncestorId: null };
  }

  // Ids on the root-to-old-leaf path, for O(1) membership checks.
  const idsOnOldPath = new Set(session.getBranch(oldLeafId).map((e) => e.id));

  // getBranch returns root-first; scan the target path from its leaf upward
  // so the first hit is the DEEPEST node shared with the old path.
  const targetPath = session.getBranch(targetId);
  let commonAncestorId: string | null = null;
  for (let idx = targetPath.length - 1; idx >= 0; idx--) {
    const candidate = targetPath[idx];
    if (idsOnOldPath.has(candidate.id)) {
      commonAncestorId = candidate.id;
      break;
    }
  }

  // Gather everything between the old leaf (inclusive) and the common
  // ancestor (exclusive), newest-first, then flip to chronological order.
  const collected: SessionEntry[] = [];
  let cursor: string | null = oldLeafId;
  while (cursor && cursor !== commonAncestorId) {
    const entry = session.getEntry(cursor);
    if (!entry) break;
    collected.push(entry);
    cursor = entry.parentId;
  }
  collected.reverse();

  return { entries: collected, commonAncestorId };
}
// ============================================================================
// Entry to Message Conversion
// ============================================================================
/**
 * Extract the AgentMessage a session entry contributes, if any.
 * Similar to getMessageFromEntry in compaction.ts but also handles
 * compaction entries. Entries with no conversational content yield undefined.
 */
function getMessageFromEntry(entry: SessionEntry): AgentMessage | undefined {
  if (entry.type === "message") {
    // Tool results are skipped - their context lives in the assistant's tool call.
    return entry.message.role === "toolResult" ? undefined : entry.message;
  }
  if (entry.type === "custom_message") {
    return createCustomMessage(
      entry.customType,
      entry.content,
      entry.display,
      entry.details,
      entry.timestamp,
    );
  }
  if (entry.type === "branch_summary") {
    return createBranchSummaryMessage(
      entry.summary,
      entry.fromId,
      entry.timestamp,
    );
  }
  if (entry.type === "compaction") {
    return createCompactionSummaryMessage(
      entry.summary,
      entry.tokensBefore,
      entry.timestamp,
    );
  }
  // thinking_level_change / model_change / custom / label carry no content.
  return undefined;
}
/**
 * Prepare entries for summarization with token budget.
 *
 * Walks entries from NEWEST to OLDEST, adding messages until we hit the token budget.
 * This ensures we keep the most recent context when the branch is too long.
 *
 * Also collects file operations from:
 * - Tool calls in assistant messages
 * - Existing branch_summary entries' details (for cumulative tracking)
 *
 * @param entries - Entries in chronological order
 * @param tokenBudget - Maximum tokens to include (0 = no limit)
 */
export function prepareBranchEntries(
  entries: SessionEntry[],
  tokenBudget: number = 0,
): BranchPreparation {
  const messages: AgentMessage[] = [];
  const fileOps = createFileOps();
  let totalTokens = 0;
  // First pass: collect file ops from ALL entries (even if they don't fit in token budget)
  // This ensures we capture cumulative file tracking from nested branch summaries
  // Only extract from pi-generated summaries (fromHook !== true), not extension-generated ones
  for (const entry of entries) {
    if (entry.type === "branch_summary" && !entry.fromHook && entry.details) {
      const details = entry.details as BranchSummaryDetails;
      if (Array.isArray(details.readFiles)) {
        for (const f of details.readFiles) fileOps.read.add(f);
      }
      if (Array.isArray(details.modifiedFiles)) {
        // Previously-recorded modifications are tracked as edits.
        for (const f of details.modifiedFiles) {
          fileOps.edited.add(f);
        }
      }
    }
  }
  // Second pass: walk from newest to oldest, adding messages until token budget
  for (let i = entries.length - 1; i >= 0; i--) {
    const entry = entries[i];
    const message = getMessageFromEntry(entry);
    if (!message) continue;
    // Extract file ops from assistant messages (tool calls)
    extractFileOpsFromMessage(message, fileOps);
    const tokens = estimateTokens(message);
    // Check budget before adding
    if (tokenBudget > 0 && totalTokens + tokens > tokenBudget) {
      // If this is a summary entry, try to fit it anyway as it's important context
      // (only while under 90% of the budget, to limit the overshoot).
      if (entry.type === "compaction" || entry.type === "branch_summary") {
        if (totalTokens < tokenBudget * 0.9) {
          messages.unshift(message);
          totalTokens += tokens;
        }
      }
      // Stop - we've hit the budget; anything older is dropped.
      break;
    }
    // unshift keeps `messages` chronological while we walk backwards.
    messages.unshift(message);
    totalTokens += tokens;
  }
  return { messages, fileOps, totalTokens };
}
// ============================================================================
// Summary Generation
// ============================================================================
// Prefix prepended to every branch summary so the model (and user) knows the
// text describes an abandoned exploration, not the current conversation.
const BRANCH_SUMMARY_PREAMBLE = `The user explored a different conversation branch before returning here.
Summary of that exploration:
`;
// Default instruction block appended after the serialized conversation.
// Custom instructions may extend or replace it (see generateBranchSummary).
const BRANCH_SUMMARY_PROMPT = `Create a structured summary of this conversation branch for context when returning later.
Use this EXACT format:
## Goal
[What was the user trying to accomplish in this branch?]
## Constraints & Preferences
- [Any constraints, preferences, or requirements mentioned]
- [Or "(none)" if none were mentioned]
## Progress
### Done
- [x] [Completed tasks/changes]
### In Progress
- [ ] [Work that was started but not finished]
### Blocked
- [Issues preventing progress, if any]
## Key Decisions
- **[Decision]**: [Brief rationale]
## Next Steps
1. [What should happen next to continue this work]
Keep each section concise. Preserve exact file paths, function names, and error messages.`;
/**
 * Generate a summary of abandoned branch entries.
 *
 * @param entries - Session entries to summarize (chronological order)
 * @param options - Generation options
 * @returns Summary text plus cumulative file lists, or aborted/error status
 */
export async function generateBranchSummary(
  entries: SessionEntry[],
  options: GenerateBranchSummaryOptions,
): Promise<BranchSummaryResult> {
  const {
    model,
    apiKey,
    signal,
    customInstructions,
    replaceInstructions,
    reserveTokens = 16384,
  } = options;
  // Token budget = context window minus reserved space for prompt + response
  const contextWindow = model.contextWindow || 128000;
  const tokenBudget = contextWindow - reserveTokens;
  const { messages, fileOps } = prepareBranchEntries(entries, tokenBudget);
  if (messages.length === 0) {
    return { summary: "No content to summarize" };
  }
  // Transform to LLM-compatible messages, then serialize to text
  // Serialization prevents the model from treating it as a conversation to continue
  const llmMessages = convertToLlm(messages);
  const conversationText = serializeConversation(llmMessages);
  // Build prompt: custom instructions may replace or extend the default.
  let instructions: string;
  if (replaceInstructions && customInstructions) {
    instructions = customInstructions;
  } else if (customInstructions) {
    instructions = `${BRANCH_SUMMARY_PROMPT}\n\nAdditional focus: ${customInstructions}`;
  } else {
    instructions = BRANCH_SUMMARY_PROMPT;
  }
  const promptText = `<conversation>\n${conversationText}\n</conversation>\n\n${instructions}`;
  const summarizationMessages = [
    {
      role: "user" as const,
      content: [{ type: "text" as const, text: promptText }],
      timestamp: Date.now(),
    },
  ];
  // Call LLM for summarization
  const response = await completeSimple(
    model,
    {
      systemPrompt: SUMMARIZATION_SYSTEM_PROMPT,
      messages: summarizationMessages,
    },
    { apiKey, signal, maxTokens: 2048 },
  );
  // Check if aborted or errored
  if (response.stopReason === "aborted") {
    return { aborted: true };
  }
  if (response.stopReason === "error") {
    return { error: response.errorMessage || "Summarization failed" };
  }
  const generated = response.content
    .filter((c): c is { type: "text"; text: string } => c.type === "text")
    .map((c) => c.text)
    .join("\n");
  // BUGFIX: the empty-summary fallback previously ran AFTER the preamble was
  // prepended, so it could never trigger. Apply it to the raw LLM text.
  let summary = BRANCH_SUMMARY_PREAMBLE + (generated || "No summary generated");
  // Compute file lists and append to summary
  const { readFiles, modifiedFiles } = computeFileLists(fileOps);
  summary += formatFileOperations(readFiles, modifiedFiles);
  return {
    summary,
    readFiles,
    modifiedFiles,
  };
}

View file

@ -0,0 +1,899 @@
/**
* Context compaction for long sessions.
*
* Pure functions for compaction logic. The session manager handles I/O,
* and after compaction the session is reloaded.
*/
import type { AgentMessage } from "@mariozechner/pi-agent-core";
import type { AssistantMessage, Model, Usage } from "@mariozechner/pi-ai";
import { completeSimple } from "@mariozechner/pi-ai";
import {
convertToLlm,
createBranchSummaryMessage,
createCompactionSummaryMessage,
createCustomMessage,
} from "../messages.js";
import type { CompactionEntry, SessionEntry } from "../session-manager.js";
import {
computeFileLists,
createFileOps,
extractFileOpsFromMessage,
type FileOperations,
formatFileOperations,
SUMMARIZATION_SYSTEM_PROMPT,
serializeConversation,
} from "./utils.js";
// ============================================================================
// File Operation Tracking
// ============================================================================
/** Details stored in CompactionEntry.details for file tracking */
export interface CompactionDetails {
  /** Paths read during the compacted span. */
  readFiles: string[];
  /** Paths modified during the compacted span. */
  modifiedFiles: string[];
}
/**
 * Extract file operations from messages and previous compaction entries.
 * Seeds from the prior pi-generated compaction's recorded lists (cumulative
 * tracking across compactions), then folds in tool-call operations.
 */
function extractFileOperations(
  messages: AgentMessage[],
  entries: SessionEntry[],
  prevCompactionIndex: number,
): FileOperations {
  const ops = createFileOps();

  // Seed from the previous compaction's details (if pi-generated).
  if (prevCompactionIndex >= 0) {
    const previous = entries[prevCompactionIndex] as CompactionEntry;
    // fromHook field kept for session file compatibility
    if (!previous.fromHook && previous.details) {
      const details = previous.details as CompactionDetails;
      const readList = Array.isArray(details.readFiles)
        ? details.readFiles
        : [];
      for (const file of readList) {
        ops.read.add(file);
      }
      const modifiedList = Array.isArray(details.modifiedFiles)
        ? details.modifiedFiles
        : [];
      for (const file of modifiedList) {
        ops.edited.add(file);
      }
    }
  }

  // Fold in operations referenced by tool calls in the messages.
  for (const message of messages) {
    extractFileOpsFromMessage(message, ops);
  }
  return ops;
}
// ============================================================================
// Message Extraction
// ============================================================================
/**
 * Extract AgentMessage from an entry if it produces one.
 * Returns undefined for entries that don't contribute to LLM context
 * (e.g. labels, model changes, thinking-level changes).
 */
function getMessageFromEntry(entry: SessionEntry): AgentMessage | undefined {
  // Plain conversation messages are used as-is.
  if (entry.type === "message") {
    return entry.message;
  }
  // Extension-defined messages are reconstructed into a custom AgentMessage.
  if (entry.type === "custom_message") {
    return createCustomMessage(
      entry.customType,
      entry.content,
      entry.display,
      entry.details,
      entry.timestamp,
    );
  }
  // Branch summaries surface as user-role summary messages.
  if (entry.type === "branch_summary") {
    return createBranchSummaryMessage(
      entry.summary,
      entry.fromId,
      entry.timestamp,
    );
  }
  // Compaction entries surface as compaction-summary messages.
  if (entry.type === "compaction") {
    return createCompactionSummaryMessage(
      entry.summary,
      entry.tokensBefore,
      entry.timestamp,
    );
  }
  return undefined;
}
/** Result from compact() - SessionManager adds uuid/parentUuid when saving */
export interface CompactionResult<T = unknown> {
  /** Merged summary text (history summary plus optional turn-prefix context). */
  summary: string;
  /** UUID of the first session entry retained after compaction. */
  firstKeptEntryId: string;
  /** Estimated context tokens before compaction was applied. */
  tokensBefore: number;
  /** Extension-specific data (e.g., ArtifactIndex, version markers for structured compaction) */
  details?: T;
}
// ============================================================================
// Types
// ============================================================================
export interface CompactionSettings {
  /** Whether automatic compaction is enabled. */
  enabled: boolean;
  /** Tokens reserved for the response; compaction triggers past window - reserve. */
  reserveTokens: number;
  /** Approximate token budget of recent messages kept verbatim. */
  keepRecentTokens: number;
}
export const DEFAULT_COMPACTION_SETTINGS: CompactionSettings = {
  enabled: true,
  reserveTokens: 16384,
  keepRecentTokens: 20000,
};
// ============================================================================
// Token calculation
// ============================================================================
/**
 * Calculate total context tokens from usage.
 * Prefers the provider-reported totalTokens; otherwise sums the
 * input/output/cache components.
 */
export function calculateContextTokens(usage: Usage): number {
  if (usage.totalTokens) {
    return usage.totalTokens;
  }
  return usage.input + usage.output + usage.cacheRead + usage.cacheWrite;
}
/**
 * Get usage from an assistant message if available.
 * Aborted and errored responses are skipped: their usage data is not valid.
 */
function getAssistantUsage(msg: AgentMessage): Usage | undefined {
  if (msg.role !== "assistant" || !("usage" in msg)) return undefined;
  const assistant = msg as AssistantMessage;
  const invalid =
    assistant.stopReason === "aborted" || assistant.stopReason === "error";
  return !invalid && assistant.usage ? assistant.usage : undefined;
}
/**
 * Find the usage of the most recent non-aborted assistant message among
 * the given session entries, scanning newest-first.
 */
export function getLastAssistantUsage(
  entries: SessionEntry[],
): Usage | undefined {
  for (let i = entries.length - 1; i >= 0; i--) {
    const entry = entries[i];
    if (entry.type !== "message") continue;
    const usage = getAssistantUsage(entry.message);
    if (usage) return usage;
  }
  return undefined;
}
/** Breakdown of an estimated context-token count. */
export interface ContextUsageEstimate {
  /** Total estimated context tokens (usageTokens + trailingTokens). */
  tokens: number;
  /** Tokens reported by the last assistant usage, or 0 if none. */
  usageTokens: number;
  /** Estimated tokens of messages after the last usage report. */
  trailingTokens: number;
  /** Index of the message carrying the last usage, or null if none. */
  lastUsageIndex: number | null;
}
/** Find the newest assistant message with valid usage, plus its index. */
function getLastAssistantUsageInfo(
  messages: AgentMessage[],
): { usage: Usage; index: number } | undefined {
  let i = messages.length - 1;
  while (i >= 0) {
    const usage = getAssistantUsage(messages[i]);
    if (usage) return { usage, index: i };
    i--;
  }
  return undefined;
}
/**
 * Estimate context tokens from messages, anchored at the last assistant
 * usage when available. Messages after that point are estimated with the
 * chars/4 heuristic; with no usage at all, every message is estimated.
 */
export function estimateContextTokens(
  messages: AgentMessage[],
): ContextUsageEstimate {
  const usageInfo = getLastAssistantUsageInfo(messages);

  // No usage anywhere: estimate everything.
  if (!usageInfo) {
    const estimated = messages.reduce(
      (sum, message) => sum + estimateTokens(message),
      0,
    );
    return {
      tokens: estimated,
      usageTokens: 0,
      trailingTokens: estimated,
      lastUsageIndex: null,
    };
  }

  const usageTokens = calculateContextTokens(usageInfo.usage);
  // Estimate only the messages that came after the last usage report.
  const trailingTokens = messages
    .slice(usageInfo.index + 1)
    .reduce((sum, message) => sum + estimateTokens(message), 0);

  return {
    tokens: usageTokens + trailingTokens,
    usageTokens,
    trailingTokens,
    lastUsageIndex: usageInfo.index,
  };
}
/**
 * Check if compaction should trigger: context usage has grown past the
 * window minus the reserved response budget.
 */
export function shouldCompact(
  contextTokens: number,
  contextWindow: number,
  settings: CompactionSettings,
): boolean {
  if (!settings.enabled) return false;
  const threshold = contextWindow - settings.reserveTokens;
  return contextTokens > threshold;
}
// ============================================================================
// Cut point detection
// ============================================================================
/**
 * Estimate token count for a message using a chars/4 heuristic.
 * Conservative (tends to overestimate). Images count as a flat 4800 chars
 * (~1200 tokens). Unknown roles contribute 0.
 */
export function estimateTokens(message: AgentMessage): number {
  const toTokens = (chars: number) => Math.ceil(chars / 4);

  switch (message.role) {
    case "user": {
      const content = (
        message as { content: string | Array<{ type: string; text?: string }> }
      ).content;
      if (typeof content === "string") return toTokens(content.length);
      let chars = 0;
      if (Array.isArray(content)) {
        for (const block of content) {
          if (block.type === "text" && block.text) chars += block.text.length;
        }
      }
      return toTokens(chars);
    }
    case "assistant": {
      let chars = 0;
      for (const block of (message as AssistantMessage).content) {
        if (block.type === "text") {
          chars += block.text.length;
        } else if (block.type === "thinking") {
          chars += block.thinking.length;
        } else if (block.type === "toolCall") {
          // Count the tool name plus the serialized arguments.
          chars += block.name.length + JSON.stringify(block.arguments).length;
        }
      }
      return toTokens(chars);
    }
    case "custom":
    case "toolResult": {
      if (typeof message.content === "string") {
        return toTokens(message.content.length);
      }
      let chars = 0;
      for (const block of message.content) {
        if (block.type === "text" && block.text) chars += block.text.length;
        // Estimate images as 4800 chars, i.e. 1200 tokens.
        if (block.type === "image") chars += 4800;
      }
      return toTokens(chars);
    }
    case "bashExecution":
      return toTokens(message.command.length + message.output.length);
    case "branchSummary":
    case "compactionSummary":
      return toTokens(message.summary.length);
  }
  return 0;
}
/**
 * Find valid cut points: indices of entries that may begin the kept suffix.
 * Valid starts are user/assistant/custom/bashExecution/summary messages and
 * branch_summary / custom_message entries (both render as user-role messages).
 * Tool results are never cut points (they must follow their tool call); when
 * cutting at an assistant message with tool calls, its tool results come
 * after it and are kept.
 */
function findValidCutPoints(
  entries: SessionEntry[],
  startIndex: number,
  endIndex: number,
): number[] {
  // Every message role except toolResult is a legal cut point.
  const cuttableRoles = new Set([
    "bashExecution",
    "custom",
    "branchSummary",
    "compactionSummary",
    "user",
    "assistant",
  ]);
  const cutPoints: number[] = [];
  for (let i = startIndex; i < endIndex; i++) {
    const entry = entries[i];
    if (entry.type === "message") {
      if (cuttableRoles.has(entry.message.role)) {
        cutPoints.push(i);
      }
    } else if (
      entry.type === "branch_summary" ||
      entry.type === "custom_message"
    ) {
      // These are user-role messages, so they are valid cut points too.
      cutPoints.push(i);
    }
  }
  return cutPoints;
}
/**
 * Find the user message (or bashExecution) that starts the turn containing
 * the entry at `entryIndex`, scanning backwards but not before `startIndex`.
 * branch_summary / custom_message entries count as user-role turn starts.
 * Returns -1 when no turn start is found in range.
 */
export function findTurnStartIndex(
  entries: SessionEntry[],
  entryIndex: number,
  startIndex: number,
): number {
  for (let i = entryIndex; i >= startIndex; i--) {
    const entry = entries[i];
    const isUserRoleEntry =
      entry.type === "branch_summary" || entry.type === "custom_message";
    if (isUserRoleEntry) return i;
    if (
      entry.type === "message" &&
      (entry.message.role === "user" || entry.message.role === "bashExecution")
    ) {
      return i;
    }
  }
  return -1;
}
/** Result of findCutPoint(): where to cut and whether a turn is split. */
export interface CutPointResult {
  /** Index of first entry to keep */
  firstKeptEntryIndex: number;
  /** Index of user message that starts the turn being split, or -1 if not splitting */
  turnStartIndex: number;
  /** Whether this cut splits a turn (cut point is not a user message) */
  isSplitTurn: boolean;
}
/**
 * Find the cut point in session entries that keeps approximately `keepRecentTokens`.
 *
 * Algorithm: Walk backwards from newest, accumulating estimated message sizes.
 * Stop when we've accumulated >= keepRecentTokens. Cut at that point.
 *
 * Can cut at user OR assistant messages (never tool results). When cutting at an
 * assistant message with tool calls, its tool results come after and will be kept.
 *
 * Returns CutPointResult with:
 * - firstKeptEntryIndex: the entry index to start keeping from
 * - turnStartIndex: if cutting mid-turn, the user message that started that turn
 * - isSplitTurn: whether we're cutting in the middle of a turn
 *
 * Only considers entries between `startIndex` and `endIndex` (exclusive).
 */
export function findCutPoint(
  entries: SessionEntry[],
  startIndex: number,
  endIndex: number,
  keepRecentTokens: number,
): CutPointResult {
  const cutPoints = findValidCutPoints(entries, startIndex, endIndex);
  // No valid cut point at all: keep everything from startIndex.
  if (cutPoints.length === 0) {
    return {
      firstKeptEntryIndex: startIndex,
      turnStartIndex: -1,
      isSplitTurn: false,
    };
  }
  // Walk backwards from newest, accumulating estimated message sizes
  let accumulatedTokens = 0;
  let cutIndex = cutPoints[0]; // Default: keep from first message (not header)
  for (let i = endIndex - 1; i >= startIndex; i--) {
    const entry = entries[i];
    if (entry.type !== "message") continue;
    // Estimate this message's size
    const messageTokens = estimateTokens(entry.message);
    accumulatedTokens += messageTokens;
    // Check if we've exceeded the budget
    if (accumulatedTokens >= keepRecentTokens) {
      // Find the closest valid cut point at or after this entry
      for (let c = 0; c < cutPoints.length; c++) {
        if (cutPoints[c] >= i) {
          cutIndex = cutPoints[c];
          break;
        }
      }
      break;
    }
  }
  // Scan backwards from cutIndex to include any non-message entries (bash, settings, etc.)
  while (cutIndex > startIndex) {
    const prevEntry = entries[cutIndex - 1];
    // Stop at session header or compaction boundaries
    if (prevEntry.type === "compaction") {
      break;
    }
    if (prevEntry.type === "message") {
      // Stop if we hit any message
      break;
    }
    // Include this non-message entry (bash, settings change, etc.)
    cutIndex--;
  }
  // Determine if this is a split turn: if the cut does not land on a user
  // message, the turn it belongs to is being split, so record the user
  // message that started that turn for the prefix summary.
  const cutEntry = entries[cutIndex];
  const isUserMessage =
    cutEntry.type === "message" && cutEntry.message.role === "user";
  const turnStartIndex = isUserMessage
    ? -1
    : findTurnStartIndex(entries, cutIndex, startIndex);
  return {
    firstKeptEntryIndex: cutIndex,
    turnStartIndex,
    isSplitTurn: !isUserMessage && turnStartIndex !== -1,
  };
}
// ============================================================================
// Summarization
// ============================================================================
/**
 * Prompt for the initial structured checkpoint summary. The serialized
 * conversation is provided above it inside <conversation> tags; the model
 * must output only the structured summary in this exact section format.
 */
const SUMMARIZATION_PROMPT = `The messages above are a conversation to summarize. Create a structured context checkpoint summary that another LLM will use to continue the work.
Use this EXACT format:
## Goal
[What is the user trying to accomplish? Can be multiple items if the session covers different tasks.]
## Constraints & Preferences
- [Any constraints, preferences, or requirements mentioned by user]
- [Or "(none)" if none were mentioned]
## Progress
### Done
- [x] [Completed tasks/changes]
### In Progress
- [ ] [Current work]
### Blocked
- [Issues preventing progress, if any]
## Key Decisions
- **[Decision]**: [Brief rationale]
## Next Steps
1. [Ordered list of what should happen next]
## Critical Context
- [Any data, examples, or references needed to continue]
- [Or "(none)" if not applicable]
Keep each section concise. Preserve exact file paths, function names, and error messages.`;
/**
 * Prompt for iterative compaction: merges new messages into the existing
 * summary (supplied in <previous-summary> tags) instead of starting fresh.
 */
const UPDATE_SUMMARIZATION_PROMPT = `The messages above are NEW conversation messages to incorporate into the existing summary provided in <previous-summary> tags.
Update the existing structured summary with new information. RULES:
- PRESERVE all existing information from the previous summary
- ADD new progress, decisions, and context from the new messages
- UPDATE the Progress section: move items from "In Progress" to "Done" when completed
- UPDATE "Next Steps" based on what was accomplished
- PRESERVE exact file paths, function names, and error messages
- If something is no longer relevant, you may remove it
Use this EXACT format:
## Goal
[Preserve existing goals, add new ones if the task expanded]
## Constraints & Preferences
- [Preserve existing, add new ones discovered]
## Progress
### Done
- [x] [Include previously done items AND newly completed items]
### In Progress
- [ ] [Current work - update based on progress]
### Blocked
- [Current blockers - remove if resolved]
## Key Decisions
- **[Decision]**: [Brief rationale] (preserve all previous, add new)
## Next Steps
1. [Update based on current state]
## Critical Context
- [Preserve important context, add new if needed]
Keep each section concise. Preserve exact file paths, function names, and error messages.`;
/**
 * Generate a summary of the conversation using the LLM.
 * If previousSummary is provided, uses the update prompt to merge new
 * messages into the existing summary instead of starting fresh.
 *
 * @param currentMessages Messages to summarize (AgentMessages, pre-conversion).
 * @param model Model used for the summarization call.
 * @param reserveTokens Reserve budget; summary output is capped at 80% of it.
 * @param apiKey API key for the model provider.
 * @param signal Optional abort signal for cancellation.
 * @param customInstructions Optional extra focus appended to the prompt.
 * @param previousSummary Prior summary to update iteratively, if any.
 * @returns The concatenated text content of the model's response.
 * @throws Error when the completion returns an error stop reason.
 */
export async function generateSummary(
  currentMessages: AgentMessage[],
  model: Model<any>,
  reserveTokens: number,
  apiKey: string,
  signal?: AbortSignal,
  customInstructions?: string,
  previousSummary?: string,
): Promise<string> {
  // Cap the summary output at 80% of the reserve budget.
  const maxTokens = Math.floor(0.8 * reserveTokens);
  // Use update prompt if we have a previous summary, otherwise initial prompt
  let basePrompt = previousSummary
    ? UPDATE_SUMMARIZATION_PROMPT
    : SUMMARIZATION_PROMPT;
  if (customInstructions) {
    basePrompt = `${basePrompt}\n\nAdditional focus: ${customInstructions}`;
  }
  // Serialize conversation to text so model doesn't try to continue it
  // Convert to LLM messages first (handles custom types like bashExecution, custom, etc.)
  const llmMessages = convertToLlm(currentMessages);
  const conversationText = serializeConversation(llmMessages);
  // Build the prompt with conversation wrapped in tags
  let promptText = `<conversation>\n${conversationText}\n</conversation>\n\n`;
  if (previousSummary) {
    promptText += `<previous-summary>\n${previousSummary}\n</previous-summary>\n\n`;
  }
  promptText += basePrompt;
  const summarizationMessages = [
    {
      role: "user" as const,
      content: [{ type: "text" as const, text: promptText }],
      timestamp: Date.now(),
    },
  ];
  // Request high reasoning effort on reasoning-capable models.
  const completionOptions = model.reasoning
    ? { maxTokens, signal, apiKey, reasoning: "high" as const }
    : { maxTokens, signal, apiKey };
  const response = await completeSimple(
    model,
    {
      systemPrompt: SUMMARIZATION_SYSTEM_PROMPT,
      messages: summarizationMessages,
    },
    completionOptions,
  );
  if (response.stopReason === "error") {
    throw new Error(
      `Summarization failed: ${response.errorMessage || "Unknown error"}`,
    );
  }
  // Concatenate all text blocks of the response.
  const textContent = response.content
    .filter((c): c is { type: "text"; text: string } => c.type === "text")
    .map((c) => c.text)
    .join("\n");
  return textContent;
}
// ============================================================================
// Compaction Preparation (for extensions)
// ============================================================================
/** Everything needed to run compact(), produced by prepareCompaction(). */
export interface CompactionPreparation {
  /** UUID of first entry to keep */
  firstKeptEntryId: string;
  /** Messages that will be summarized and discarded */
  messagesToSummarize: AgentMessage[];
  /** Messages that will be turned into turn prefix summary (if splitting) */
  turnPrefixMessages: AgentMessage[];
  /** Whether this is a split turn (cut point in middle of turn) */
  isSplitTurn: boolean;
  /** Estimated context tokens before compaction. */
  tokensBefore: number;
  /** Summary from previous compaction, for iterative update */
  previousSummary?: string;
  /** File operations extracted from messagesToSummarize */
  fileOps: FileOperations;
  /** Compaction settings from settings.jsonl */
  settings: CompactionSettings;
}
/**
 * Prepare everything needed to compact a session path: locate the previous
 * compaction boundary, find the cut point, collect messages to summarize
 * (plus a split-turn prefix if the cut lands mid-turn), gather file
 * operations, and estimate pre-compaction tokens.
 *
 * Returns undefined when compaction is not applicable: the session already
 * ends with a compaction entry, or the first kept entry has no UUID (the
 * session needs migration).
 */
export function prepareCompaction(
  pathEntries: SessionEntry[],
  settings: CompactionSettings,
): CompactionPreparation | undefined {
  // Nothing to do if the session already ends with a compaction entry.
  if (
    pathEntries.length > 0 &&
    pathEntries[pathEntries.length - 1].type === "compaction"
  ) {
    return undefined;
  }
  // Locate the most recent previous compaction; only entries after it are
  // candidates for summarization.
  let prevCompactionIndex = -1;
  for (let i = pathEntries.length - 1; i >= 0; i--) {
    if (pathEntries[i].type === "compaction") {
      prevCompactionIndex = i;
      break;
    }
  }
  const boundaryStart = prevCompactionIndex + 1;
  const boundaryEnd = pathEntries.length;
  // Token usage is measured from the previous compaction itself (inclusive,
  // so its summary message counts toward the estimate).
  const usageStart = prevCompactionIndex >= 0 ? prevCompactionIndex : 0;
  const usageMessages: AgentMessage[] = [];
  for (let i = usageStart; i < boundaryEnd; i++) {
    const msg = getMessageFromEntry(pathEntries[i]);
    if (msg) usageMessages.push(msg);
  }
  const tokensBefore = estimateContextTokens(usageMessages).tokens;
  const cutPoint = findCutPoint(
    pathEntries,
    boundaryStart,
    boundaryEnd,
    settings.keepRecentTokens,
  );
  // Get UUID of first kept entry
  const firstKeptEntry = pathEntries[cutPoint.firstKeptEntryIndex];
  if (!firstKeptEntry?.id) {
    return undefined; // Session needs migration
  }
  const firstKeptEntryId = firstKeptEntry.id;
  // When splitting a turn, history ends at the turn start; the split turn's
  // prefix is summarized separately from the discarded history.
  const historyEnd = cutPoint.isSplitTurn
    ? cutPoint.turnStartIndex
    : cutPoint.firstKeptEntryIndex;
  // Messages to summarize (will be discarded after summary)
  const messagesToSummarize: AgentMessage[] = [];
  for (let i = boundaryStart; i < historyEnd; i++) {
    const msg = getMessageFromEntry(pathEntries[i]);
    if (msg) messagesToSummarize.push(msg);
  }
  // Messages for turn prefix summary (if splitting a turn)
  const turnPrefixMessages: AgentMessage[] = [];
  if (cutPoint.isSplitTurn) {
    for (
      let i = cutPoint.turnStartIndex;
      i < cutPoint.firstKeptEntryIndex;
      i++
    ) {
      const msg = getMessageFromEntry(pathEntries[i]);
      if (msg) turnPrefixMessages.push(msg);
    }
  }
  // Get previous summary for iterative update
  let previousSummary: string | undefined;
  if (prevCompactionIndex >= 0) {
    const prevCompaction = pathEntries[prevCompactionIndex] as CompactionEntry;
    previousSummary = prevCompaction.summary;
  }
  // Extract file operations from messages and previous compaction
  const fileOps = extractFileOperations(
    messagesToSummarize,
    pathEntries,
    prevCompactionIndex,
  );
  // Also extract file ops from turn prefix if splitting
  if (cutPoint.isSplitTurn) {
    for (const msg of turnPrefixMessages) {
      extractFileOpsFromMessage(msg, fileOps);
    }
  }
  return {
    firstKeptEntryId,
    messagesToSummarize,
    turnPrefixMessages,
    isSplitTurn: cutPoint.isSplitTurn,
    tokensBefore,
    previousSummary,
    fileOps,
    settings,
  };
}
// ============================================================================
// Main compaction function
// ============================================================================
/**
 * Prompt used when a turn is split by the cut point: summarizes the
 * discarded prefix of the turn so the retained suffix stays intelligible.
 */
const TURN_PREFIX_SUMMARIZATION_PROMPT = `This is the PREFIX of a turn that was too large to keep. The SUFFIX (recent work) is retained.
Summarize the prefix to provide context for the retained suffix:
## Original Request
[What did the user ask for in this turn?]
## Early Progress
- [Key decisions and work done in the prefix]
## Context for Suffix
- [Information needed to understand the retained recent work]
Be concise. Focus on what's needed to understand the kept suffix.`;
/**
 * Generate summaries for compaction using prepared data.
 * Returns CompactionResult - SessionManager adds uuid/parentUuid when saving.
 *
 * When the cut splits a turn, the history summary and the turn-prefix
 * summary are generated in parallel and merged into a single summary.
 *
 * @param preparation - Pre-calculated preparation from prepareCompaction()
 * @param model - Model used for the summarization calls
 * @param apiKey - API key for the model provider
 * @param customInstructions - Optional custom focus for the summary
 * @param signal - Optional abort signal for cancellation
 * @throws Error when summarization fails or the kept entry lacks a UUID
 */
export async function compact(
  preparation: CompactionPreparation,
  model: Model<any>,
  apiKey: string,
  customInstructions?: string,
  signal?: AbortSignal,
): Promise<CompactionResult> {
  const {
    firstKeptEntryId,
    messagesToSummarize,
    turnPrefixMessages,
    isSplitTurn,
    tokensBefore,
    previousSummary,
    fileOps,
    settings,
  } = preparation;
  // Generate summaries (can be parallel if both needed) and merge into one
  let summary: string;
  if (isSplitTurn && turnPrefixMessages.length > 0) {
    // Generate both summaries in parallel
    const [historyResult, turnPrefixResult] = await Promise.all([
      messagesToSummarize.length > 0
        ? generateSummary(
            messagesToSummarize,
            model,
            settings.reserveTokens,
            apiKey,
            signal,
            customInstructions,
            previousSummary,
          )
        : Promise.resolve("No prior history."),
      generateTurnPrefixSummary(
        turnPrefixMessages,
        model,
        settings.reserveTokens,
        apiKey,
        signal,
      ),
    ]);
    // Merge into single summary
    summary = `${historyResult}\n\n---\n\n**Turn Context (split turn):**\n\n${turnPrefixResult}`;
  } else {
    // Just generate history summary
    summary = await generateSummary(
      messagesToSummarize,
      model,
      settings.reserveTokens,
      apiKey,
      signal,
      customInstructions,
      previousSummary,
    );
  }
  // Compute file lists and append to summary
  const { readFiles, modifiedFiles } = computeFileLists(fileOps);
  summary += formatFileOperations(readFiles, modifiedFiles);
  if (!firstKeptEntryId) {
    throw new Error(
      "First kept entry has no UUID - session may need migration",
    );
  }
  return {
    summary,
    firstKeptEntryId,
    tokensBefore,
    details: { readFiles, modifiedFiles } as CompactionDetails,
  };
}
/**
 * Generate a summary for a turn prefix (when splitting a turn).
 * Uses a smaller output budget (50% of reserve) than the main summary.
 *
 * @throws Error when the completion returns an error stop reason.
 */
async function generateTurnPrefixSummary(
  messages: AgentMessage[],
  model: Model<any>,
  reserveTokens: number,
  apiKey: string,
  signal?: AbortSignal,
): Promise<string> {
  const maxTokens = Math.floor(0.5 * reserveTokens); // Smaller budget for turn prefix
  // Serialize so the model summarizes rather than continues the conversation.
  const llmMessages = convertToLlm(messages);
  const conversationText = serializeConversation(llmMessages);
  const promptText = `<conversation>\n${conversationText}\n</conversation>\n\n${TURN_PREFIX_SUMMARIZATION_PROMPT}`;
  const summarizationMessages = [
    {
      role: "user" as const,
      content: [{ type: "text" as const, text: promptText }],
      timestamp: Date.now(),
    },
  ];
  const response = await completeSimple(
    model,
    {
      systemPrompt: SUMMARIZATION_SYSTEM_PROMPT,
      messages: summarizationMessages,
    },
    { maxTokens, signal, apiKey },
  );
  if (response.stopReason === "error") {
    throw new Error(
      `Turn prefix summarization failed: ${response.errorMessage || "Unknown error"}`,
    );
  }
  // Concatenate all text blocks of the response.
  return response.content
    .filter((c): c is { type: "text"; text: string } => c.type === "text")
    .map((c) => c.text)
    .join("\n");
}

View file

@ -0,0 +1,7 @@
/**
* Compaction and summarization utilities.
*/
export * from "./branch-summarization.js";
export * from "./compaction.js";
export * from "./utils.js";

View file

@ -0,0 +1,167 @@
/**
* Shared utilities for compaction and branch summarization.
*/
import type { AgentMessage } from "@mariozechner/pi-agent-core";
import type { Message } from "@mariozechner/pi-ai";
// ============================================================================
// File Operation Tracking
// ============================================================================
/** Accumulated file paths seen in read/write/edit tool calls. */
export interface FileOperations {
  /** Paths passed to the `read` tool. */
  read: Set<string>;
  /** Paths passed to the `write` tool. */
  written: Set<string>;
  /** Paths passed to the `edit` tool. */
  edited: Set<string>;
}
/** Create a fresh, empty FileOperations accumulator. */
export function createFileOps(): FileOperations {
  const read = new Set<string>();
  const written = new Set<string>();
  const edited = new Set<string>();
  return { read, written, edited };
}
/**
 * Extract file operations from tool calls in an assistant message.
 * Recognizes the read/write/edit tools and records their string `path`
 * argument; other messages, blocks, and tool names are ignored.
 */
export function extractFileOpsFromMessage(
  message: AgentMessage,
  fileOps: FileOperations,
): void {
  if (message.role !== "assistant") return;
  if (!("content" in message) || !Array.isArray(message.content)) return;

  // Map tool name to the set that should record its path.
  const targets: Record<string, Set<string>> = {
    read: fileOps.read,
    write: fileOps.written,
    edit: fileOps.edited,
  };

  for (const block of message.content) {
    if (typeof block !== "object" || block === null) continue;
    if (!("type" in block) || block.type !== "toolCall") continue;
    if (!("arguments" in block) || !("name" in block)) continue;
    const args = block.arguments as Record<string, unknown> | undefined;
    const path = typeof args?.path === "string" ? args.path : undefined;
    if (!path) continue;
    targets[block.name as string]?.add(path);
  }
}
/**
 * Compute final file lists from accumulated operations.
 * A file that was both read and modified counts only as modified.
 * Both returned lists are sorted.
 */
export function computeFileLists(fileOps: FileOperations): {
  readFiles: string[];
  modifiedFiles: string[];
} {
  const modified = new Set<string>([...fileOps.edited, ...fileOps.written]);
  const readFiles = [...fileOps.read]
    .filter((file) => !modified.has(file))
    .sort();
  return { readFiles, modifiedFiles: [...modified].sort() };
}
/**
 * Format file lists as XML-style tags for appending to a summary.
 * Returns "" when both lists are empty; otherwise the sections are joined
 * with blank lines and prefixed by one blank line.
 */
export function formatFileOperations(
  readFiles: string[],
  modifiedFiles: string[],
): string {
  const tag = (name: string, files: string[]): string =>
    `<${name}>\n${files.join("\n")}\n</${name}>`;
  const sections: string[] = [];
  if (readFiles.length > 0) {
    sections.push(tag("read-files", readFiles));
  }
  if (modifiedFiles.length > 0) {
    sections.push(tag("modified-files", modifiedFiles));
  }
  return sections.length > 0 ? `\n\n${sections.join("\n\n")}` : "";
}
// ============================================================================
// Message Serialization
// ============================================================================
/**
 * Serialize LLM messages to plain labelled text for summarization.
 * This prevents the model from treating the transcript as a conversation to
 * continue. Call convertToLlm() first to handle custom message types.
 */
export function serializeConversation(messages: Message[]): string {
  // Join the text blocks of string-or-array content.
  const textOf = (content: unknown): string => {
    if (typeof content === "string") return content;
    return (content as Array<{ type: string; text?: string }>)
      .filter((c) => c.type === "text")
      .map((c) => c.text)
      .join("");
  };

  const parts: string[] = [];
  for (const msg of messages) {
    if (msg.role === "user") {
      const content = textOf(msg.content);
      if (content) parts.push(`[User]: ${content}`);
    } else if (msg.role === "assistant") {
      const texts: string[] = [];
      const thoughts: string[] = [];
      const calls: string[] = [];
      for (const block of msg.content) {
        if (block.type === "text") {
          texts.push(block.text);
        } else if (block.type === "thinking") {
          thoughts.push(block.thinking);
        } else if (block.type === "toolCall") {
          const args = block.arguments as Record<string, unknown>;
          const rendered = Object.entries(args)
            .map(([key, value]) => `${key}=${JSON.stringify(value)}`)
            .join(", ");
          calls.push(`${block.name}(${rendered})`);
        }
      }
      if (thoughts.length > 0) {
        parts.push(`[Assistant thinking]: ${thoughts.join("\n")}`);
      }
      if (texts.length > 0) {
        parts.push(`[Assistant]: ${texts.join("\n")}`);
      }
      if (calls.length > 0) {
        parts.push(`[Assistant tool calls]: ${calls.join("; ")}`);
      }
    } else if (msg.role === "toolResult") {
      const content = textOf(msg.content);
      if (content) {
        parts.push(`[Tool result]: ${content}`);
      }
    }
  }
  return parts.join("\n\n");
}
// ============================================================================
// Summarization System Prompt
// ============================================================================
/** System prompt shared by all compaction/branch summarization requests. */
export const SUMMARIZATION_SYSTEM_PROMPT = `You are a context summarization assistant. Your task is to read a conversation between a user and an AI coding assistant, then produce a structured summary following the exact format specified.
Do NOT continue the conversation. Do NOT respond to any questions in the conversation. ONLY output the structured summary.`;

View file

@ -0,0 +1,3 @@
import type { ThinkingLevel } from "@mariozechner/pi-agent-core";
/** Default thinking (reasoning) level applied when none is configured. */
export const DEFAULT_THINKING_LEVEL: ThinkingLevel = "medium";

View file

@ -0,0 +1,15 @@
/** A conflict where two resources of the same type share a name. */
export interface ResourceCollision {
  resourceType: "extension" | "skill" | "prompt" | "theme";
  name: string; // skill name, command/tool/flag name, prompt name, theme name
  /** Path of the resource that takes precedence. */
  winnerPath: string;
  /** Path of the shadowed resource. */
  loserPath: string;
  winnerSource?: string; // e.g., "npm:foo", "git:...", "local"
  loserSource?: string;
}
/** A diagnostic reported while loading resources. */
export interface ResourceDiagnostic {
  type: "warning" | "error" | "collision";
  message: string;
  /** Resource path the diagnostic refers to, if applicable. */
  path?: string;
  /** Collision details; present when type === "collision". */
  collision?: ResourceCollision;
}

View file

@ -0,0 +1,33 @@
import { EventEmitter } from "node:events";
/** Minimal pub/sub surface for cross-component events. */
export interface EventBus {
  /** Publish `data` to all handlers subscribed to `channel`. */
  emit(channel: string, data: unknown): void;
  /** Subscribe to `channel`; returns an unsubscribe function. */
  on(channel: string, handler: (data: unknown) => void): () => void;
}
/** EventBus plus lifecycle control for the owning host. */
export interface EventBusController extends EventBus {
  /** Remove all handlers on all channels. */
  clear(): void;
}
/**
 * Create an EventBus backed by a Node EventEmitter.
 * Handlers are isolated: a throwing (or rejecting) handler is logged to
 * stderr and does not prevent other handlers from running.
 */
export function createEventBus(): EventBusController {
  const emitter = new EventEmitter();

  const on = (
    channel: string,
    handler: (data: unknown) => void,
  ): (() => void) => {
    // Wrap so sync throws and async rejections are caught and logged.
    const guarded = async (data: unknown) => {
      try {
        await handler(data);
      } catch (err) {
        console.error(`Event handler error (${channel}):`, err);
      }
    };
    emitter.on(channel, guarded);
    return () => {
      emitter.off(channel, guarded);
    };
  };

  return {
    emit: (channel, data) => {
      emitter.emit(channel, data);
    },
    on,
    clear: () => {
      emitter.removeAllListeners();
    },
  };
}

View file

@ -0,0 +1,104 @@
/**
* Shared command execution utilities for extensions and custom tools.
*/
import { spawn } from "node:child_process";
/**
 * Options for executing shell commands.
 */
export interface ExecOptions {
  /** AbortSignal to cancel the command */
  signal?: AbortSignal;
  /** Timeout in milliseconds */
  timeout?: number;
  /** Working directory */
  // NOTE(review): execCommand receives cwd as a positional argument and does
  // not read this field — confirm whether it is used elsewhere or redundant.
  cwd?: string;
}
/**
 * Result of executing a shell command.
 */
export interface ExecResult {
  /** Captured standard output. */
  stdout: string;
  /** Captured standard error. */
  stderr: string;
  /** Process exit code (1 on spawn error; 0 when the code is unavailable). */
  code: number;
  /** True if the process was killed due to timeout or abort. */
  killed: boolean;
}
/**
 * Execute a shell command and return stdout/stderr/exit code.
 * Supports cancellation via AbortSignal and a millisecond timeout; on
 * either, the process receives SIGTERM and is escalated to SIGKILL after
 * 5 seconds if it has not exited. Never rejects: spawn errors resolve with
 * code 1.
 */
export async function execCommand(
  command: string,
  args: string[],
  cwd: string,
  options?: ExecOptions,
): Promise<ExecResult> {
  return new Promise((resolve) => {
    const proc = spawn(command, args, {
      cwd,
      shell: false,
      stdio: ["ignore", "pipe", "pipe"],
    });
    let stdout = "";
    let stderr = "";
    let killed = false;
    let settled = false;
    let timeoutId: NodeJS.Timeout | undefined;
    let forceKillId: NodeJS.Timeout | undefined;
    const killProcess = () => {
      if (!killed) {
        killed = true;
        proc.kill("SIGTERM");
        // Escalate to SIGKILL after 5 seconds if the process hasn't exited.
        // BUG FIX: proc.killed only reports that a signal was successfully
        // *sent*, so the previous `if (!proc.killed)` check could never
        // trigger the escalation; track actual exit via `settled` instead.
        forceKillId = setTimeout(() => {
          if (!settled) {
            proc.kill("SIGKILL");
          }
        }, 5000);
      }
    };
    // Resolve exactly once ("error" may be followed by "close"), and release
    // the abort listener and any pending timers so nothing leaks.
    const finish = (code: number) => {
      if (settled) return;
      settled = true;
      if (timeoutId) clearTimeout(timeoutId);
      if (forceKillId) clearTimeout(forceKillId);
      if (options?.signal) {
        options.signal.removeEventListener("abort", killProcess);
      }
      resolve({ stdout, stderr, code, killed });
    };
    // Handle abort signal (may already be aborted before spawn returns)
    if (options?.signal) {
      if (options.signal.aborted) {
        killProcess();
      } else {
        options.signal.addEventListener("abort", killProcess, { once: true });
      }
    }
    // Handle timeout
    if (options?.timeout && options.timeout > 0) {
      timeoutId = setTimeout(() => {
        killProcess();
      }, options.timeout);
    }
    proc.stdout?.on("data", (data) => {
      stdout += data.toString();
    });
    proc.stderr?.on("data", (data) => {
      stderr += data.toString();
    });
    proc.on("close", (code) => {
      finish(code ?? 0);
    });
    proc.on("error", (_err) => {
      finish(1);
    });
  });
}

View file

@ -0,0 +1,271 @@
/**
* ANSI escape code to HTML converter.
*
* Converts terminal ANSI color/style codes to HTML with inline styles.
* Supports:
* - Standard foreground colors (30-37) and bright variants (90-97)
* - Standard background colors (40-47) and bright variants (100-107)
* - 256-color palette (38;5;N and 48;5;N)
* - RGB true color (38;2;R;G;B and 48;2;R;G;B)
* - Text styles: bold (1), dim (2), italic (3), underline (4)
* - Reset (0)
*/
// Standard ANSI color palette (0-15): indices 0-7 are the standard colors,
// 8-15 their bright variants.
const ANSI_COLORS = [
  "#000000", // 0: black
  "#800000", // 1: red
  "#008000", // 2: green
  "#808000", // 3: yellow
  "#000080", // 4: blue
  "#800080", // 5: magenta
  "#008080", // 6: cyan
  "#c0c0c0", // 7: white
  "#808080", // 8: bright black
  "#ff0000", // 9: bright red
  "#00ff00", // 10: bright green
  "#ffff00", // 11: bright yellow
  "#0000ff", // 12: bright blue
  "#ff00ff", // 13: bright magenta
  "#00ffff", // 14: bright cyan
  "#ffffff", // 15: bright white
];
/**
* Convert 256-color index to hex.
*/
function color256ToHex(index: number): string {
// Standard colors (0-15)
if (index < 16) {
return ANSI_COLORS[index];
}
// Color cube (16-231): 6x6x6 = 216 colors
if (index < 232) {
const cubeIndex = index - 16;
const r = Math.floor(cubeIndex / 36);
const g = Math.floor((cubeIndex % 36) / 6);
const b = cubeIndex % 6;
const toComponent = (n: number) => (n === 0 ? 0 : 55 + n * 40);
const toHex = (n: number) => toComponent(n).toString(16).padStart(2, "0");
return `#${toHex(r)}${toHex(g)}${toHex(b)}`;
}
// Grayscale (232-255): 24 shades
const gray = 8 + (index - 232) * 10;
const grayHex = gray.toString(16).padStart(2, "0");
return `#${grayHex}${grayHex}${grayHex}`;
}
/**
 * Escape HTML special characters so arbitrary text renders literally.
 */
function escapeHtml(text: string): string {
  const replacements: Record<string, string> = {
    "&": "&amp;",
    "<": "&lt;",
    ">": "&gt;",
    '"': "&quot;",
    "'": "&#039;",
  };
  // Single-pass replacement avoids any double-escaping concerns.
  return text.replace(/[&<>"']/g, (ch) => replacements[ch] ?? ch);
}
/** Mutable accumulator for the current terminal text attributes. */
interface TextStyle {
  fg: string | null;
  bg: string | null;
  bold: boolean;
  dim: boolean;
  italic: boolean;
  underline: boolean;
}

/** A fresh style with no colors or attributes set. */
function createEmptyStyle(): TextStyle {
  return {
    fg: null,
    bg: null,
    bold: false,
    dim: false,
    italic: false,
    underline: false,
  };
}

/** Serialize a style to a semicolon-joined inline CSS declaration list. */
function styleToInlineCSS(style: TextStyle): string {
  const decls = [
    style.fg ? `color:${style.fg}` : false,
    style.bg ? `background-color:${style.bg}` : false,
    style.bold && "font-weight:bold",
    style.dim && "opacity:0.6",
    style.italic && "font-style:italic",
    style.underline && "text-decoration:underline",
  ];
  return decls.filter((d): d is string => d !== false).join(";");
}

/** True when any color or attribute is active (i.e. a span is needed). */
function hasStyle(style: TextStyle): boolean {
  const { fg, bg, bold, dim, italic, underline } = style;
  return fg !== null || bg !== null || bold || dim || italic || underline;
}
/**
 * Parse ANSI SGR (Select Graphic Rendition) codes and update style.
 *
 * Mutates `style` in place. `params` is the numeric parameter list of one
 * `ESC[...m` sequence (e.g. [38, 5, 196]). Extended color introducers
 * (38 = foreground, 48 = background) consume their extra parameters by
 * advancing `i`; a malformed extended sequence (wrong sub-code or too few
 * parameters) is skipped without consuming anything. Unrecognized codes
 * are ignored.
 */
function applySgrCode(params: number[], style: TextStyle): void {
  let i = 0;
  while (i < params.length) {
    const code = params[i];
    if (code === 0) {
      // Reset all
      style.fg = null;
      style.bg = null;
      style.bold = false;
      style.dim = false;
      style.italic = false;
      style.underline = false;
    } else if (code === 1) {
      style.bold = true;
    } else if (code === 2) {
      style.dim = true;
    } else if (code === 3) {
      style.italic = true;
    } else if (code === 4) {
      style.underline = true;
    } else if (code === 22) {
      // Reset bold/dim
      style.bold = false;
      style.dim = false;
    } else if (code === 23) {
      style.italic = false;
    } else if (code === 24) {
      style.underline = false;
    } else if (code >= 30 && code <= 37) {
      // Standard foreground colors
      style.fg = ANSI_COLORS[code - 30];
    } else if (code === 38) {
      // Extended foreground color
      if (params[i + 1] === 5 && params.length > i + 2) {
        // 256-color: 38;5;N
        style.fg = color256ToHex(params[i + 2]);
        i += 2;
      } else if (params[i + 1] === 2 && params.length > i + 4) {
        // RGB: 38;2;R;G;B
        const r = params[i + 2];
        const g = params[i + 3];
        const b = params[i + 4];
        style.fg = `rgb(${r},${g},${b})`;
        i += 4;
      }
    } else if (code === 39) {
      // Default foreground
      style.fg = null;
    } else if (code >= 40 && code <= 47) {
      // Standard background colors
      style.bg = ANSI_COLORS[code - 40];
    } else if (code === 48) {
      // Extended background color
      if (params[i + 1] === 5 && params.length > i + 2) {
        // 256-color: 48;5;N
        style.bg = color256ToHex(params[i + 2]);
        i += 2;
      } else if (params[i + 1] === 2 && params.length > i + 4) {
        // RGB: 48;2;R;G;B
        const r = params[i + 2];
        const g = params[i + 3];
        const b = params[i + 4];
        style.bg = `rgb(${r},${g},${b})`;
        i += 4;
      }
    } else if (code === 49) {
      // Default background
      style.bg = null;
    } else if (code >= 90 && code <= 97) {
      // Bright foreground colors (map onto palette slots 8-15)
      style.fg = ANSI_COLORS[code - 90 + 8];
    } else if (code >= 100 && code <= 107) {
      // Bright background colors (map onto palette slots 8-15)
      style.bg = ANSI_COLORS[code - 100 + 8];
    }
    // Ignore unrecognized codes
    i++;
  }
}
// Matches SGR escape sequences: ESC[ followed by numeric params, ending in 'm'.
const ANSI_REGEX = /\x1b\[([\d;]*)m/g;

/**
 * Convert ANSI-escaped text to HTML with inline styles.
 * Styled runs are wrapped in <span style="..."> elements; all literal text
 * is HTML-escaped. Style state carries across escape sequences until reset.
 */
export function ansiToHtml(text: string): string {
  const style = createEmptyStyle();
  const pieces: string[] = [];
  let cursor = 0;
  let spanOpen = false;

  // Module-level regex with /g flag: reset state before scanning.
  ANSI_REGEX.lastIndex = 0;
  for (let m = ANSI_REGEX.exec(text); m !== null; m = ANSI_REGEX.exec(text)) {
    // Emit the literal text preceding this escape sequence.
    if (m.index > cursor) {
      pieces.push(escapeHtml(text.slice(cursor, m.index)));
    }

    // An empty parameter string (ESC[m) means reset.
    const raw = m[1];
    const params = raw ? raw.split(";").map((p) => parseInt(p, 10) || 0) : [0];

    // Close the current span before the style changes.
    if (spanOpen) {
      pieces.push("</span>");
      spanOpen = false;
    }

    applySgrCode(params, style);

    // Open a new span only when some attribute is actually active.
    if (hasStyle(style)) {
      pieces.push(`<span style="${styleToInlineCSS(style)}">`);
      spanOpen = true;
    }

    cursor = m.index + m[0].length;
  }

  // Emit any trailing literal text and close a dangling span.
  if (cursor < text.length) {
    pieces.push(escapeHtml(text.slice(cursor)));
  }
  if (spanOpen) {
    pieces.push("</span>");
  }
  return pieces.join("");
}
/**
 * Convert an array of ANSI-escaped lines to HTML, one <div> per line.
 * Empty lines render as a non-breaking space so they keep their height.
 */
export function ansiLinesToHtml(lines: string[]): string {
  const htmlLines: string[] = [];
  for (const line of lines) {
    const inner = ansiToHtml(line) || "&nbsp;";
    htmlLines.push(`<div class="ansi-line">${inner}</div>`);
  }
  return htmlLines.join("\n");
}

View file

@ -0,0 +1,353 @@
import type { AgentState } from "@mariozechner/pi-agent-core";
import { existsSync, readFileSync, writeFileSync } from "fs";
import { basename, join } from "path";
import { APP_NAME, getExportTemplateDir } from "../../config.js";
import {
getResolvedThemeColors,
getThemeExportColors,
} from "../../modes/interactive/theme/theme.js";
import type { ToolInfo } from "../extensions/types.js";
import type { SessionEntry } from "../session-manager.js";
import { SessionManager } from "../session-manager.js";
/**
 * Interface for rendering custom tools to HTML.
 * Used by agent-session to pre-render extension tool output.
 */
export interface ToolHtmlRenderer {
  /** Render a tool call to HTML. Returns undefined if tool has no custom renderer. */
  renderCall(toolName: string, args: unknown): string | undefined;
  /** Render a tool result to HTML. Returns undefined if tool has no custom renderer. */
  renderResult(
    toolName: string,
    result: Array<{
      type: string;
      text?: string;
      data?: string;
      mimeType?: string;
    }>,
    details: unknown,
    isError: boolean,
  ): string | undefined;
}

/** Pre-rendered HTML for a custom tool call and result */
interface RenderedToolHtml {
  /** HTML for the tool call (arguments), when the renderer produced any. */
  callHtml?: string;
  /** HTML for the tool result, when the renderer produced any. */
  resultHtml?: string;
}

/** Options accepted by the HTML export functions. */
export interface ExportOptions {
  /** Destination path; defaults to `<APP_NAME>-session-<name>.html` in the CWD. */
  outputPath?: string;
  /** Theme name forwarded to getResolvedThemeColors/getThemeExportColors. */
  themeName?: string;
  /** Optional tool renderer for custom tools */
  toolRenderer?: ToolHtmlRenderer;
}
/** Parse a color string to RGB values. Supports hex (#RRGGBB) and rgb(r,g,b) formats. */
function parseColor(
  color: string,
): { r: number; g: number; b: number } | undefined {
  const hex = /^#([0-9a-fA-F]{2})([0-9a-fA-F]{2})([0-9a-fA-F]{2})$/.exec(color);
  if (hex) {
    return {
      r: Number.parseInt(hex[1], 16),
      g: Number.parseInt(hex[2], 16),
      b: Number.parseInt(hex[3], 16),
    };
  }
  const rgb = /^rgb\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)$/.exec(color);
  if (rgb) {
    return {
      r: Number.parseInt(rgb[1], 10),
      g: Number.parseInt(rgb[2], 10),
      b: Number.parseInt(rgb[3], 10),
    };
  }
  // Unsupported format (named colors, #RGB shorthand, rgba, etc.).
  return undefined;
}
/** Relative luminance per the sRGB/WCAG formula (0 = darkest, 1 = lightest). */
function getLuminance(r: number, g: number, b: number): number {
  // Linearize a 0-255 sRGB channel.
  const linear = (channel: number): number => {
    const srgb = channel / 255;
    return srgb <= 0.03928 ? srgb / 12.92 : ((srgb + 0.055) / 1.055) ** 2.4;
  };
  const weighted: Array<[number, number]> = [
    [0.2126, r],
    [0.7152, g],
    [0.0722, b],
  ];
  return weighted.reduce((sum, [w, c]) => sum + w * linear(c), 0);
}
/** Adjust color brightness. Factor > 1 lightens, < 1 darkens. */
function adjustBrightness(color: string, factor: number): string {
  const rgb = parseColor(color);
  // Pass unparseable colors through unchanged.
  if (!rgb) return color;
  const scale = (channel: number): number =>
    Math.min(255, Math.max(0, Math.round(channel * factor)));
  return `rgb(${scale(rgb.r)}, ${scale(rgb.g)}, ${scale(rgb.b)})`;
}
/** Derive export background colors from a base color (e.g., userMessageBg). */
function deriveExportColors(baseColor: string): {
pageBg: string;
cardBg: string;
infoBg: string;
} {
const parsed = parseColor(baseColor);
if (!parsed) {
return {
pageBg: "rgb(24, 24, 30)",
cardBg: "rgb(30, 30, 36)",
infoBg: "rgb(60, 55, 40)",
};
}
const luminance = getLuminance(parsed.r, parsed.g, parsed.b);
const isLight = luminance > 0.5;
if (isLight) {
return {
pageBg: adjustBrightness(baseColor, 0.96),
cardBg: baseColor,
infoBg: `rgb(${Math.min(255, parsed.r + 10)}, ${Math.min(255, parsed.g + 5)}, ${Math.max(0, parsed.b - 20)})`,
};
}
return {
pageBg: adjustBrightness(baseColor, 0.7),
cardBg: adjustBrightness(baseColor, 0.85),
infoBg: `rgb(${Math.min(255, parsed.r + 20)}, ${Math.min(255, parsed.g + 15)}, ${parsed.b})`,
};
}
/**
* Generate CSS custom property declarations from theme colors.
*/
function generateThemeVars(themeName?: string): string {
const colors = getResolvedThemeColors(themeName);
const lines: string[] = [];
for (const [key, value] of Object.entries(colors)) {
lines.push(`--${key}: ${value};`);
}
// Use explicit theme export colors if available, otherwise derive from userMessageBg
const themeExport = getThemeExportColors(themeName);
const userMessageBg = colors.userMessageBg || "#343541";
const derivedColors = deriveExportColors(userMessageBg);
lines.push(`--exportPageBg: ${themeExport.pageBg ?? derivedColors.pageBg};`);
lines.push(`--exportCardBg: ${themeExport.cardBg ?? derivedColors.cardBg};`);
lines.push(`--exportInfoBg: ${themeExport.infoBg ?? derivedColors.infoBg};`);
return lines.join("\n ");
}
/** Session payload embedded (as base64 JSON) into the exported HTML page. */
interface SessionData {
  header: ReturnType<SessionManager["getHeader"]>;
  entries: ReturnType<SessionManager["getEntries"]>;
  // From SessionManager.getLeafId(); may be null.
  leafId: string | null;
  systemPrompt?: string;
  tools?: ToolInfo[];
  /** Pre-rendered HTML for custom tool calls/results, keyed by tool call ID */
  renderedTools?: Record<string, RenderedToolHtml>;
}
/**
 * Core HTML generation logic shared by both export functions.
 *
 * Reads the template assets from the export template directory, injects the
 * theme's CSS variables and derived background colors, and embeds the session
 * data as base64 JSON (which sidesteps </script> and quote escaping issues).
 */
function generateHtml(sessionData: SessionData, themeName?: string): string {
  const templateDir = getExportTemplateDir();
  const read = (...parts: string[]): string =>
    readFileSync(join(templateDir, ...parts), "utf-8");
  const template = read("template.html");
  const templateCss = read("template.css");
  const templateJs = read("template.js");
  const markedJs = read("vendor", "marked.min.js");
  const hljsJs = read("vendor", "highlight.min.js");

  const themeVars = generateThemeVars(themeName);
  const colors = getResolvedThemeColors(themeName);
  const exportColors = deriveExportColors(colors.userMessageBg || "#343541");

  // Base64 encode session data to avoid escaping issues
  const sessionDataBase64 = Buffer.from(JSON.stringify(sessionData)).toString(
    "base64",
  );

  // BUG FIX: use a replacer FUNCTION for every substitution. With a plain
  // string replacement, String.prototype.replace interprets `$&`, `$'`,
  // `$1`, ... in the replacement text as special patterns, which silently
  // corrupts injected CSS/JS — minified libraries routinely contain `$`.
  const fill = (input: string, placeholder: string, value: string): string =>
    input.replace(placeholder, () => value);

  // Build the CSS with theme variables injected
  let css = templateCss;
  css = fill(css, "{{THEME_VARS}}", themeVars);
  css = fill(css, "{{BODY_BG}}", exportColors.pageBg);
  css = fill(css, "{{CONTAINER_BG}}", exportColors.cardBg);
  css = fill(css, "{{INFO_BG}}", exportColors.infoBg);

  let html = template;
  html = fill(html, "{{CSS}}", css);
  html = fill(html, "{{JS}}", templateJs);
  html = fill(html, "{{SESSION_DATA}}", sessionDataBase64);
  html = fill(html, "{{MARKED_JS}}", markedJs);
  html = fill(html, "{{HIGHLIGHT_JS}}", hljsJs);
  return html;
}
/** Built-in tool names that have custom rendering in template.js */
const BUILTIN_TOOLS = new Set([
  "bash",
  "read",
  "write",
  "edit",
  "ls",
  "find",
  "grep",
]);

/**
 * Pre-render custom (non-built-in) tools to HTML using their TUI renderers.
 * Returns a map keyed by tool call ID; entries the renderer declines
 * (returns undefined/empty) are omitted.
 */
function preRenderCustomTools(
  entries: SessionEntry[],
  toolRenderer: ToolHtmlRenderer,
): Record<string, RenderedToolHtml> {
  const rendered: Record<string, RenderedToolHtml> = {};
  for (const entry of entries) {
    if (entry.type !== "message") continue;
    const message = entry.message;

    if (message.role === "assistant" && Array.isArray(message.content)) {
      // Collect custom tool calls from assistant messages.
      for (const block of message.content) {
        if (block.type !== "toolCall" || BUILTIN_TOOLS.has(block.name)) {
          continue;
        }
        const callHtml = toolRenderer.renderCall(block.name, block.arguments);
        if (callHtml) {
          rendered[block.id] = { callHtml };
        }
      }
    } else if (message.role === "toolResult" && message.toolCallId) {
      const toolName = message.toolName || "";
      const existing = rendered[message.toolCallId];
      // Render results for tools we already rendered a call for, or for any
      // non-built-in tool.
      if (!existing && BUILTIN_TOOLS.has(toolName)) continue;
      const resultHtml = toolRenderer.renderResult(
        toolName,
        message.content,
        message.details,
        message.isError || false,
      );
      if (resultHtml) {
        rendered[message.toolCallId] = { ...existing, resultHtml };
      }
    }
  }
  return rendered;
}
/**
 * Export session to HTML using SessionManager and AgentState.
 * Used by TUI's /export command.
 *
 * Returns the path the HTML file was written to. Throws when the session
 * is in-memory only or its file does not exist yet.
 */
export async function exportSessionToHtml(
  sm: SessionManager,
  state?: AgentState,
  options?: ExportOptions | string,
): Promise<string> {
  // A bare string argument is shorthand for { outputPath }.
  const opts: ExportOptions =
    typeof options === "string" ? { outputPath: options } : options || {};

  const sessionFile = sm.getSessionFile();
  if (!sessionFile) {
    throw new Error("Cannot export in-memory session to HTML");
  }
  if (!existsSync(sessionFile)) {
    throw new Error("Nothing to export yet - start a conversation first");
  }

  const entries = sm.getEntries();

  // Pre-render custom tools when a renderer is available; omit the map
  // entirely when nothing was rendered so the embedded payload stays small.
  let renderedTools: Record<string, RenderedToolHtml> | undefined;
  if (opts.toolRenderer) {
    const map = preRenderCustomTools(entries, opts.toolRenderer);
    if (Object.keys(map).length > 0) {
      renderedTools = map;
    }
  }

  const sessionData: SessionData = {
    header: sm.getHeader(),
    entries,
    leafId: sm.getLeafId(),
    systemPrompt: state?.systemPrompt,
    tools: state?.tools?.map((t) => ({
      name: t.name,
      description: t.description,
      parameters: t.parameters,
    })),
    renderedTools,
  };

  const html = generateHtml(sessionData, opts.themeName);
  const outputPath =
    opts.outputPath ||
    `${APP_NAME}-session-${basename(sessionFile, ".jsonl")}.html`;
  writeFileSync(outputPath, html, "utf8");
  return outputPath;
}
/**
 * Export session file to HTML (standalone, without AgentState).
 * Used by CLI for exporting arbitrary session files.
 *
 * Returns the path the HTML file was written to.
 */
export async function exportFromFile(
  inputPath: string,
  options?: ExportOptions | string,
): Promise<string> {
  // A bare string argument is shorthand for { outputPath }.
  const opts: ExportOptions =
    typeof options === "string" ? { outputPath: options } : options || {};

  if (!existsSync(inputPath)) {
    throw new Error(`File not found: ${inputPath}`);
  }

  // Standalone export: no live agent state, so system prompt and tools
  // are unknown and left undefined.
  const sm = SessionManager.open(inputPath);
  const sessionData: SessionData = {
    header: sm.getHeader(),
    entries: sm.getEntries(),
    leafId: sm.getLeafId(),
    systemPrompt: undefined,
    tools: undefined,
  };

  const html = generateHtml(sessionData, opts.themeName);
  const outputPath =
    opts.outputPath ||
    `${APP_NAME}-session-${basename(inputPath, ".jsonl")}.html`;
  writeFileSync(outputPath, html, "utf8");
  return outputPath;
}

View file

@ -0,0 +1,971 @@
:root {
{{THEME_VARS}}
--body-bg: {{BODY_BG}};
--container-bg: {{CONTAINER_BG}};
--info-bg: {{INFO_BG}};
}
* { margin: 0; padding: 0; box-sizing: border-box; }
:root {
--line-height: 18px; /* 12px font * 1.5 */
}
body {
font-family: ui-monospace, 'Cascadia Code', 'Source Code Pro', Menlo, Consolas, 'DejaVu Sans Mono', monospace;
font-size: 12px;
line-height: var(--line-height);
color: var(--text);
background: var(--body-bg);
}
#app {
display: flex;
min-height: 100vh;
}
/* Sidebar */
#sidebar {
width: 400px;
background: var(--container-bg);
flex-shrink: 0;
display: flex;
flex-direction: column;
position: sticky;
top: 0;
height: 100vh;
border-right: 1px solid var(--dim);
}
.sidebar-header {
padding: 8px 12px;
flex-shrink: 0;
}
.sidebar-controls {
padding: 8px 8px 4px 8px;
}
.sidebar-search {
width: 100%;
box-sizing: border-box;
padding: 4px 8px;
font-size: 11px;
font-family: inherit;
background: var(--body-bg);
color: var(--text);
border: 1px solid var(--dim);
border-radius: 3px;
}
.sidebar-filters {
display: flex;
padding: 4px 8px 8px 8px;
gap: 4px;
align-items: center;
flex-wrap: wrap;
}
.sidebar-search:focus {
outline: none;
border-color: var(--accent);
}
.sidebar-search::placeholder {
color: var(--muted);
}
.filter-btn {
padding: 3px 8px;
font-size: 10px;
font-family: inherit;
background: transparent;
color: var(--muted);
border: 1px solid var(--dim);
border-radius: 3px;
cursor: pointer;
}
.filter-btn:hover {
color: var(--text);
border-color: var(--text);
}
.filter-btn.active {
background: var(--accent);
color: var(--body-bg);
border-color: var(--accent);
}
.sidebar-close {
display: none;
padding: 3px 8px;
font-size: 12px;
font-family: inherit;
background: transparent;
color: var(--muted);
border: 1px solid var(--dim);
border-radius: 3px;
cursor: pointer;
margin-left: auto;
}
.sidebar-close:hover {
color: var(--text);
border-color: var(--text);
}
.tree-container {
flex: 1;
overflow: auto;
padding: 4px 0;
}
.tree-node {
padding: 0 8px;
cursor: pointer;
display: flex;
align-items: baseline;
font-size: 11px;
line-height: 13px;
white-space: nowrap;
}
.tree-node:hover {
background: var(--selectedBg);
}
.tree-node.active {
background: var(--selectedBg);
}
.tree-node.active .tree-content {
font-weight: bold;
}
.tree-node.in-path {
background: color-mix(in srgb, var(--accent) 10%, transparent);
}
.tree-node:not(.in-path) {
opacity: 0.5;
}
.tree-node:not(.in-path):hover {
opacity: 1;
}
.tree-prefix {
color: var(--muted);
flex-shrink: 0;
font-family: monospace;
white-space: pre;
}
.tree-marker {
color: var(--accent);
flex-shrink: 0;
}
.tree-content {
color: var(--text);
}
.tree-role-user {
color: var(--accent);
}
.tree-role-assistant {
color: var(--success);
}
.tree-role-tool {
color: var(--muted);
}
.tree-muted {
color: var(--muted);
}
.tree-error {
color: var(--error);
}
.tree-compaction {
color: var(--borderAccent);
}
.tree-branch-summary {
color: var(--warning);
}
.tree-custom-message {
color: var(--customMessageLabel);
}
.tree-status {
padding: 4px 12px;
font-size: 10px;
color: var(--muted);
flex-shrink: 0;
}
/* Main content */
#content {
flex: 1;
overflow-y: auto;
padding: var(--line-height) calc(var(--line-height) * 2);
display: flex;
flex-direction: column;
align-items: center;
}
#content > * {
width: 100%;
max-width: 800px;
}
/* Help bar */
.help-bar {
font-size: 11px;
color: var(--warning);
margin-bottom: var(--line-height);
display: flex;
align-items: center;
gap: 12px;
}
.download-json-btn {
font-size: 10px;
padding: 2px 8px;
background: var(--container-bg);
border: 1px solid var(--border);
border-radius: 3px;
color: var(--text);
cursor: pointer;
font-family: inherit;
}
.download-json-btn:hover {
background: var(--hover);
border-color: var(--borderAccent);
}
/* Header */
.header {
background: var(--container-bg);
border-radius: 4px;
padding: var(--line-height);
margin-bottom: var(--line-height);
}
.header h1 {
font-size: 12px;
font-weight: bold;
color: var(--borderAccent);
margin-bottom: var(--line-height);
}
.header-info {
display: flex;
flex-direction: column;
gap: 0;
font-size: 11px;
}
.info-item {
color: var(--dim);
display: flex;
align-items: baseline;
}
.info-label {
font-weight: 600;
margin-right: 8px;
min-width: 100px;
}
.info-value {
color: var(--text);
flex: 1;
}
/* Messages */
#messages {
display: flex;
flex-direction: column;
gap: var(--line-height);
}
.message-timestamp {
font-size: 10px;
color: var(--dim);
opacity: 0.8;
}
.user-message {
background: var(--userMessageBg);
color: var(--userMessageText);
padding: var(--line-height);
border-radius: 4px;
position: relative;
}
.assistant-message {
padding: 0;
position: relative;
}
/* Copy link button - appears on hover */
.copy-link-btn {
position: absolute;
top: 8px;
right: 8px;
width: 28px;
height: 28px;
padding: 6px;
background: var(--container-bg);
border: 1px solid var(--dim);
border-radius: 4px;
color: var(--muted);
cursor: pointer;
opacity: 0;
transition: opacity 0.15s, background 0.15s, color 0.15s;
display: flex;
align-items: center;
justify-content: center;
z-index: 10;
}
.user-message:hover .copy-link-btn,
.assistant-message:hover .copy-link-btn {
opacity: 1;
}
.copy-link-btn:hover {
background: var(--accent);
color: var(--body-bg);
border-color: var(--accent);
}
.copy-link-btn.copied {
background: var(--success, #22c55e);
color: white;
border-color: var(--success, #22c55e);
}
/* Highlight effect for deep-linked messages */
.user-message.highlight,
.assistant-message.highlight {
animation: highlight-pulse 2s ease-out;
}
@keyframes highlight-pulse {
0% {
box-shadow: 0 0 0 3px var(--accent);
}
100% {
box-shadow: 0 0 0 0 transparent;
}
}
.assistant-message > .message-timestamp {
padding-left: var(--line-height);
}
.assistant-text {
padding: var(--line-height);
padding-bottom: 0;
}
.message-timestamp + .assistant-text,
.message-timestamp + .thinking-block {
padding-top: 0;
}
.thinking-block + .assistant-text {
padding-top: 0;
}
.thinking-text {
padding: var(--line-height);
color: var(--thinkingText);
font-style: italic;
white-space: pre-wrap;
}
.message-timestamp + .thinking-block .thinking-text,
.message-timestamp + .thinking-block .thinking-collapsed {
padding-top: 0;
}
.thinking-collapsed {
display: none;
padding: var(--line-height);
color: var(--thinkingText);
font-style: italic;
}
/* Tool execution */
.tool-execution {
padding: var(--line-height);
border-radius: 4px;
}
.tool-execution + .tool-execution {
margin-top: var(--line-height);
}
.assistant-text + .tool-execution {
margin-top: var(--line-height);
}
.tool-execution.pending { background: var(--toolPendingBg); }
.tool-execution.success { background: var(--toolSuccessBg); }
.tool-execution.error { background: var(--toolErrorBg); }
.tool-header, .tool-name {
font-weight: bold;
}
.tool-path {
color: var(--accent);
word-break: break-all;
}
.line-numbers {
color: var(--warning);
}
.line-count {
color: var(--dim);
}
.tool-command {
font-weight: bold;
white-space: pre-wrap;
word-wrap: break-word;
overflow-wrap: break-word;
word-break: break-word;
}
.tool-output {
margin-top: var(--line-height);
color: var(--toolOutput);
word-wrap: break-word;
overflow-wrap: break-word;
word-break: break-word;
font-family: inherit;
overflow-x: auto;
}
.tool-output > div,
.output-preview,
.output-full {
margin: 0;
padding: 0;
line-height: var(--line-height);
}
.tool-output pre {
margin: 0;
padding: 0;
font-family: inherit;
color: inherit;
white-space: pre-wrap;
word-wrap: break-word;
overflow-wrap: break-word;
}
.tool-output code {
padding: 0;
background: none;
color: var(--text);
}
.tool-output.expandable {
cursor: pointer;
}
.tool-output.expandable:hover {
opacity: 0.9;
}
.tool-output.expandable .output-full {
display: none;
}
.tool-output.expandable.expanded .output-preview {
display: none;
}
.tool-output.expandable.expanded .output-full {
display: block;
}
.ansi-line {
white-space: pre-wrap;
}
.tool-images {
}
.tool-image {
max-width: 100%;
max-height: 500px;
border-radius: 4px;
margin: var(--line-height) 0;
}
.expand-hint {
color: var(--toolOutput);
}
/* Diff */
.tool-diff {
font-size: 11px;
overflow-x: auto;
white-space: pre;
}
.diff-added { color: var(--toolDiffAdded); }
.diff-removed { color: var(--toolDiffRemoved); }
.diff-context { color: var(--toolDiffContext); }
/* Model change */
.model-change {
padding: 0 var(--line-height);
color: var(--dim);
font-size: 11px;
}
.model-name {
color: var(--borderAccent);
font-weight: bold;
}
/* Compaction / Branch Summary - matches customMessage colors from TUI */
.compaction {
background: var(--customMessageBg);
border-radius: 4px;
padding: var(--line-height);
cursor: pointer;
}
.compaction-label {
color: var(--customMessageLabel);
font-weight: bold;
}
.compaction-collapsed {
color: var(--customMessageText);
}
.compaction-content {
display: none;
color: var(--customMessageText);
white-space: pre-wrap;
margin-top: var(--line-height);
}
.compaction.expanded .compaction-collapsed {
display: none;
}
.compaction.expanded .compaction-content {
display: block;
}
/* System prompt */
.system-prompt {
background: var(--customMessageBg);
padding: var(--line-height);
border-radius: 4px;
margin-bottom: var(--line-height);
}
.system-prompt.expandable {
cursor: pointer;
}
.system-prompt-header {
font-weight: bold;
color: var(--customMessageLabel);
}
.system-prompt-preview {
color: var(--customMessageText);
white-space: pre-wrap;
word-wrap: break-word;
font-size: 11px;
margin-top: var(--line-height);
}
.system-prompt-expand-hint {
color: var(--muted);
font-style: italic;
margin-top: 4px;
}
.system-prompt-full {
display: none;
color: var(--customMessageText);
white-space: pre-wrap;
word-wrap: break-word;
font-size: 11px;
margin-top: var(--line-height);
}
.system-prompt.expanded .system-prompt-preview,
.system-prompt.expanded .system-prompt-expand-hint {
display: none;
}
.system-prompt.expanded .system-prompt-full {
display: block;
}
.system-prompt.provider-prompt {
border-left: 3px solid var(--warning);
}
.system-prompt-note {
font-size: 10px;
font-style: italic;
color: var(--muted);
margin-top: 4px;
}
/* Tools list */
.tools-list {
background: var(--customMessageBg);
padding: var(--line-height);
border-radius: 4px;
margin-bottom: var(--line-height);
}
.tools-header {
font-weight: bold;
color: var(--customMessageLabel);
margin-bottom: var(--line-height);
}
.tool-item {
font-size: 11px;
}
.tool-item-name {
font-weight: bold;
color: var(--text);
}
.tool-item-desc {
color: var(--dim);
}
.tool-params-hint {
color: var(--muted);
font-style: italic;
}
.tool-item:has(.tool-params-hint) {
cursor: pointer;
}
.tool-params-hint::after {
content: '[click to show parameters]';
}
.tool-item.params-expanded .tool-params-hint::after {
content: '[hide parameters]';
}
.tool-params-content {
display: none;
margin-top: 4px;
margin-left: 12px;
padding-left: 8px;
border-left: 1px solid var(--dim);
}
.tool-item.params-expanded .tool-params-content {
display: block;
}
.tool-param {
margin-bottom: 4px;
font-size: 11px;
}
.tool-param-name {
font-weight: bold;
color: var(--text);
}
.tool-param-type {
color: var(--dim);
font-style: italic;
}
.tool-param-required {
color: var(--warning, #e8a838);
font-size: 10px;
}
.tool-param-optional {
color: var(--dim);
font-size: 10px;
}
.tool-param-desc {
color: var(--dim);
margin-left: 8px;
}
/* Hook/custom messages */
.hook-message {
background: var(--customMessageBg);
color: var(--customMessageText);
padding: var(--line-height);
border-radius: 4px;
}
.hook-type {
color: var(--customMessageLabel);
font-weight: bold;
}
/* Branch summary */
.branch-summary {
background: var(--customMessageBg);
padding: var(--line-height);
border-radius: 4px;
}
.branch-summary-header {
font-weight: bold;
color: var(--borderAccent);
}
/* Error */
.error-text {
color: var(--error);
padding: 0 var(--line-height);
}
.tool-error {
color: var(--error);
}
/* Images */
.message-images {
margin-bottom: 12px;
}
.message-image {
max-width: 100%;
max-height: 400px;
border-radius: 4px;
margin: var(--line-height) 0;
}
/* Markdown content */
.markdown-content h1,
.markdown-content h2,
.markdown-content h3,
.markdown-content h4,
.markdown-content h5,
.markdown-content h6 {
color: var(--mdHeading);
margin: var(--line-height) 0 0 0;
font-weight: bold;
}
.markdown-content h1 { font-size: 1em; }
.markdown-content h2 { font-size: 1em; }
.markdown-content h3 { font-size: 1em; }
.markdown-content h4 { font-size: 1em; }
.markdown-content h5 { font-size: 1em; }
.markdown-content h6 { font-size: 1em; }
.markdown-content p { margin: 0; }
.markdown-content p + p { margin-top: var(--line-height); }
.markdown-content a {
color: var(--mdLink);
text-decoration: underline;
}
.markdown-content code {
background: rgba(128, 128, 128, 0.2);
color: var(--mdCode);
padding: 0 4px;
border-radius: 3px;
font-family: inherit;
}
.markdown-content pre {
background: transparent;
margin: var(--line-height) 0;
overflow-x: auto;
}
.markdown-content pre code {
display: block;
background: none;
color: var(--text);
}
.markdown-content blockquote {
border-left: 3px solid var(--mdQuoteBorder);
padding-left: var(--line-height);
margin: var(--line-height) 0;
color: var(--mdQuote);
font-style: italic;
}
.markdown-content ul,
.markdown-content ol {
margin: var(--line-height) 0;
padding-left: calc(var(--line-height) * 2);
}
.markdown-content li { margin: 0; }
.markdown-content li::marker { color: var(--mdListBullet); }
.markdown-content hr {
border: none;
border-top: 1px solid var(--mdHr);
margin: var(--line-height) 0;
}
.markdown-content table {
border-collapse: collapse;
margin: 0.5em 0;
width: 100%;
}
.markdown-content th,
.markdown-content td {
border: 1px solid var(--mdCodeBlockBorder);
padding: 6px 10px;
text-align: left;
}
.markdown-content th {
background: rgba(128, 128, 128, 0.1);
font-weight: bold;
}
.markdown-content img {
max-width: 100%;
border-radius: 4px;
}
/* Syntax highlighting */
.hljs { background: transparent; color: var(--text); }
.hljs-comment, .hljs-quote { color: var(--syntaxComment); }
.hljs-keyword, .hljs-selector-tag { color: var(--syntaxKeyword); }
.hljs-number, .hljs-literal { color: var(--syntaxNumber); }
.hljs-string, .hljs-doctag { color: var(--syntaxString); }
/* Function names: hljs v11 uses .hljs-title.function_ compound class */
.hljs-function, .hljs-title, .hljs-title.function_, .hljs-section, .hljs-name { color: var(--syntaxFunction); }
/* Types: hljs v11 uses .hljs-title.class_ for class names */
.hljs-type, .hljs-class, .hljs-title.class_, .hljs-built_in { color: var(--syntaxType); }
.hljs-attr, .hljs-variable, .hljs-variable.language_, .hljs-params, .hljs-property { color: var(--syntaxVariable); }
.hljs-meta, .hljs-meta .hljs-keyword, .hljs-meta .hljs-string { color: var(--syntaxKeyword); }
.hljs-operator { color: var(--syntaxOperator); }
.hljs-punctuation { color: var(--syntaxPunctuation); }
.hljs-subst { color: var(--text); }
/* Footer */
.footer {
margin-top: 48px;
padding: 20px;
text-align: center;
color: var(--dim);
font-size: 10px;
}
/* Mobile */
#hamburger {
display: none;
position: fixed;
top: 10px;
left: 10px;
z-index: 100;
padding: 3px 8px;
font-size: 12px;
font-family: inherit;
background: transparent;
color: var(--muted);
border: 1px solid var(--dim);
border-radius: 3px;
cursor: pointer;
}
#hamburger:hover {
color: var(--text);
border-color: var(--text);
}
#sidebar-overlay {
display: none;
position: fixed;
top: 0;
left: 0;
right: 0;
bottom: 0;
background: rgba(0, 0, 0, 0.5);
z-index: 98;
}
@media (max-width: 900px) {
#sidebar {
position: fixed;
left: -400px;
width: 400px;
top: 0;
bottom: 0;
height: 100vh;
z-index: 99;
transition: left 0.3s;
}
#sidebar.open {
left: 0;
}
#sidebar-overlay.open {
display: block;
}
#hamburger {
display: block;
}
.sidebar-close {
display: block;
}
#content {
padding: var(--line-height) 16px;
}
#content > * {
max-width: 100%;
}
}
@media (max-width: 500px) {
#sidebar {
width: 100vw;
left: -100vw;
}
}
@media print {
  /* #sidebar-toggle does not exist in template.html; also hide the
     hamburger button and overlay so no chrome appears in print. */
  #sidebar, #sidebar-toggle, #hamburger, #sidebar-overlay { display: none !important; }
  body { background: white; color: black; }
  #content { max-width: none; }
}

View file

@ -0,0 +1,54 @@
<!DOCTYPE html>
<!-- Standalone session-export page. The {{...}} placeholders are substituted
     at export time with inlined CSS, session JSON, vendored libraries, and the
     viewer application code, producing a single self-contained HTML file. -->
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Session Export</title>
<style>
{{CSS}}
</style>
</head>
<body>
<!-- Mobile-only hamburger button that opens the sidebar drawer (see CSS media queries) -->
<button id="hamburger" title="Open sidebar"><svg width="14" height="14" viewBox="0 0 24 24" fill="currentColor" stroke="none"><circle cx="6" cy="6" r="2.5"/><circle cx="6" cy="18" r="2.5"/><circle cx="18" cy="12" r="2.5"/><rect x="5" y="6" width="2" height="12"/><path d="M6 12h10c1 0 2 0 2-2V8"/></svg></button>
<!-- Dimming overlay behind the open sidebar on narrow screens -->
<div id="sidebar-overlay"></div>
<div id="app">
<!-- Navigation sidebar: search box, entry filters, and the session tree -->
<aside id="sidebar">
<div class="sidebar-header">
<div class="sidebar-controls">
<input type="text" class="sidebar-search" id="tree-search" placeholder="Search...">
</div>
<div class="sidebar-filters">
<button class="filter-btn active" data-filter="default" title="Hide settings entries">Default</button>
<button class="filter-btn" data-filter="no-tools" title="Default minus tool results">No-tools</button>
<button class="filter-btn" data-filter="user-only" title="Only user messages">User</button>
<button class="filter-btn" data-filter="labeled-only" title="Only labeled entries">Labeled</button>
<button class="filter-btn" data-filter="all" title="Show everything">All</button>
<button class="sidebar-close" id="sidebar-close" title="Close">&#x2715;</button>
</div>
</div>
<div class="tree-container" id="tree-container"></div>
<div class="tree-status" id="tree-status"></div>
</aside>
<!-- Main panel: session header plus the rendered message list -->
<main id="content">
<div id="header-container"></div>
<div id="messages"></div>
</main>
<!-- Full-screen lightbox for viewing embedded images -->
<div id="image-modal" class="image-modal">
<img id="modal-image" src="" alt="">
</div>
</div>
<!-- Session entries as JSON, read by the viewer code below -->
<script id="session-data" type="application/json">{{SESSION_DATA}}</script>
<!-- Vendored libraries -->
<script>{{MARKED_JS}}</script>
<!-- highlight.js -->
<script>{{HIGHLIGHT_JS}}</script>
<!-- Main application code -->
<script>
{{JS}}
</script>
</body>
</html>

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,112 @@
/**
* Tool HTML renderer for custom tools in HTML export.
*
* Renders custom tool calls and results to HTML by invoking their TUI renderers
* and converting the ANSI output to HTML.
*/
import type { ImageContent, TextContent } from "@mariozechner/pi-ai";
import type { Theme } from "../../modes/interactive/theme/theme.js";
import type { ToolDefinition } from "../extensions/types.js";
import { ansiLinesToHtml } from "./ansi-to-html.js";
/** Dependencies injected into {@link createToolHtmlRenderer}. */
export interface ToolHtmlRendererDeps {
  /** Function to look up tool definition by name */
  getToolDefinition: (name: string) => ToolDefinition | undefined;
  /** Theme for styling */
  theme: Theme;
  /** Terminal width for rendering (default: 100) */
  width?: number;
}

/**
 * Renders custom tool calls and results to HTML by delegating to the tool's
 * own TUI renderers and converting their ANSI output.
 */
export interface ToolHtmlRenderer {
  /** Render a tool call to HTML. Returns undefined if tool has no custom renderer. */
  renderCall(toolName: string, args: unknown): string | undefined;
  /** Render a tool result to HTML. Returns undefined if tool has no custom renderer. */
  renderResult(
    toolName: string,
    // Raw content blocks as stored in the session (text and/or base64 images)
    result: Array<{
      type: string;
      text?: string;
      data?: string;
      mimeType?: string;
    }>,
    // Tool-specific details payload, passed through to the tool's renderer
    details: unknown,
    isError: boolean,
  ): string | undefined;
}
/**
* Create a tool HTML renderer.
*
* The renderer looks up tool definitions and invokes their renderCall/renderResult
* methods, converting the resulting TUI Component output (ANSI) to HTML.
*/
/**
 * Create a tool HTML renderer.
 *
 * Looks up each tool's definition and invokes its renderCall/renderResult
 * TUI renderer, converting the resulting ANSI lines to HTML. Any failure
 * (missing tool, missing renderer, renderer throwing) yields `undefined`,
 * which callers treat as "fall back to JSON rendering".
 */
export function createToolHtmlRenderer(
  deps: ToolHtmlRendererDeps,
): ToolHtmlRenderer {
  const { getToolDefinition, theme, width = 100 } = deps;

  return {
    renderCall(toolName: string, args: unknown): string | undefined {
      try {
        // Optional chaining covers both "unknown tool" and "no renderCall".
        const component = getToolDefinition(toolName)?.renderCall?.(
          args,
          theme,
        );
        if (!component) {
          return undefined;
        }
        return ansiLinesToHtml(component.render(width));
      } catch {
        // On error, return undefined to trigger JSON fallback
        return undefined;
      }
    },

    renderResult(
      toolName: string,
      result: Array<{
        type: string;
        text?: string;
        data?: string;
        mimeType?: string;
      }>,
      details: unknown,
      isError: boolean,
    ): string | undefined {
      try {
        const toolDef = getToolDefinition(toolName);
        if (!toolDef?.renderResult) {
          return undefined;
        }
        // Build AgentToolResult from content array.
        // Cast content since session storage uses generic object types.
        const agentToolResult = {
          content: result as (TextContent | ImageContent)[],
          details,
          isError,
        };
        // Always render expanded; the client applies truncation itself.
        const component = toolDef.renderResult(
          agentToolResult,
          { expanded: true, isPartial: false },
          theme,
        );
        return component
          ? ansiLinesToHtml(component.render(width))
          : undefined;
      } catch {
        // On error, return undefined to trigger JSON fallback
        return undefined;
      }
    },
  };
}

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,170 @@
/**
* Extension system for lifecycle events and custom tools.
*/
export type {
SlashCommandInfo,
SlashCommandLocation,
SlashCommandSource,
} from "../slash-commands.js";
export {
createExtensionRuntime,
discoverAndLoadExtensions,
loadExtensionFromFactory,
loadExtensions,
} from "./loader.js";
export type {
ExtensionErrorListener,
ForkHandler,
NavigateTreeHandler,
NewSessionHandler,
ShutdownHandler,
SwitchSessionHandler,
} from "./runner.js";
export { ExtensionRunner } from "./runner.js";
export type {
AgentEndEvent,
AgentStartEvent,
// Re-exports
AgentToolResult,
AgentToolUpdateCallback,
// App keybindings (for custom editors)
AppAction,
AppendEntryHandler,
// Events - Tool (ToolCallEvent types)
BashToolCallEvent,
BashToolResultEvent,
BeforeAgentStartEvent,
BeforeAgentStartEventResult,
// Context
CompactOptions,
// Events - Agent
ContextEvent,
// Event Results
ContextEventResult,
ContextUsage,
CustomToolCallEvent,
CustomToolResultEvent,
EditToolCallEvent,
EditToolResultEvent,
ExecOptions,
ExecResult,
Extension,
ExtensionActions,
// API
ExtensionAPI,
ExtensionCommandContext,
ExtensionCommandContextActions,
ExtensionContext,
ExtensionContextActions,
// Errors
ExtensionError,
ExtensionEvent,
ExtensionFactory,
ExtensionFlag,
ExtensionHandler,
// Runtime
ExtensionRuntime,
ExtensionShortcut,
ExtensionUIContext,
ExtensionUIDialogOptions,
ExtensionWidgetOptions,
FindToolCallEvent,
FindToolResultEvent,
GetActiveToolsHandler,
GetAllToolsHandler,
GetCommandsHandler,
GetThinkingLevelHandler,
GrepToolCallEvent,
GrepToolResultEvent,
// Events - Input
InputEvent,
InputEventResult,
InputSource,
KeybindingsManager,
LoadExtensionsResult,
LsToolCallEvent,
LsToolResultEvent,
// Events - Message
MessageEndEvent,
// Message Rendering
MessageRenderer,
MessageRenderOptions,
MessageStartEvent,
MessageUpdateEvent,
ModelSelectEvent,
ModelSelectSource,
// Provider Registration
ProviderConfig,
ProviderModelConfig,
ReadToolCallEvent,
ReadToolResultEvent,
// Commands
RegisteredCommand,
RegisteredTool,
// Events - Resources
ResourcesDiscoverEvent,
ResourcesDiscoverResult,
SendMessageHandler,
SendUserMessageHandler,
SessionBeforeCompactEvent,
SessionBeforeCompactResult,
SessionBeforeForkEvent,
SessionBeforeForkResult,
SessionBeforeSwitchEvent,
SessionBeforeSwitchResult,
SessionBeforeTreeEvent,
SessionBeforeTreeResult,
SessionCompactEvent,
SessionEvent,
SessionForkEvent,
SessionShutdownEvent,
// Events - Session
SessionStartEvent,
SessionSwitchEvent,
SessionTreeEvent,
SetActiveToolsHandler,
SetLabelHandler,
SetModelHandler,
SetThinkingLevelHandler,
TerminalInputHandler,
// Events - Tool
ToolCallEvent,
ToolCallEventResult,
// Tools
ToolDefinition,
// Events - Tool Execution
ToolExecutionEndEvent,
ToolExecutionStartEvent,
ToolExecutionUpdateEvent,
ToolInfo,
ToolRenderResultOptions,
ToolResultEvent,
ToolResultEventResult,
TreePreparation,
TurnEndEvent,
TurnStartEvent,
// Events - User Bash
UserBashEvent,
UserBashEventResult,
WidgetPlacement,
WriteToolCallEvent,
WriteToolResultEvent,
} from "./types.js";
// Type guards
export {
isBashToolResult,
isEditToolResult,
isFindToolResult,
isGrepToolResult,
isLsToolResult,
isReadToolResult,
isToolCallEventType,
isWriteToolResult,
} from "./types.js";
export {
wrapRegisteredTool,
wrapRegisteredTools,
wrapToolsWithExtensions,
wrapToolWithExtensions,
} from "./wrapper.js";

View file

@ -0,0 +1,607 @@
/**
* Extension loader - loads TypeScript extension modules using jiti.
*
* Uses @mariozechner/jiti fork with virtualModules support for compiled Bun binaries.
*/
import * as fs from "node:fs";
import { createRequire } from "node:module";
import * as os from "node:os";
import * as path from "node:path";
import { fileURLToPath } from "node:url";
import { createJiti } from "@mariozechner/jiti";
import * as _bundledPiAgentCore from "@mariozechner/pi-agent-core";
import * as _bundledPiAi from "@mariozechner/pi-ai";
import * as _bundledPiAiOauth from "@mariozechner/pi-ai/oauth";
import type { KeyId } from "@mariozechner/pi-tui";
import * as _bundledPiTui from "@mariozechner/pi-tui";
// Static imports of packages that extensions may use.
// These MUST be static so Bun bundles them into the compiled binary.
// The virtualModules option then makes them available to extensions.
import * as _bundledTypebox from "@sinclair/typebox";
import { getAgentDir, isBunBinary } from "../../config.js";
// NOTE: This import works because loader.ts exports are NOT re-exported from index.ts,
// avoiding a circular dependency. Extensions can import from @mariozechner/pi-coding-agent.
import * as _bundledPiCodingAgent from "../../index.js";
import { createEventBus, type EventBus } from "../event-bus.js";
import type { ExecOptions } from "../exec.js";
import { execCommand } from "../exec.js";
import type {
Extension,
ExtensionAPI,
ExtensionFactory,
ExtensionRuntime,
LoadExtensionsResult,
MessageRenderer,
ProviderConfig,
RegisteredCommand,
ToolDefinition,
} from "./types.js";
/** Modules available to extensions via virtualModules (for compiled Bun binary) */
const VIRTUAL_MODULES: Record<string, unknown> = {
  "@sinclair/typebox": _bundledTypebox,
  "@mariozechner/pi-agent-core": _bundledPiAgentCore,
  "@mariozechner/pi-tui": _bundledPiTui,
  "@mariozechner/pi-ai": _bundledPiAi,
  "@mariozechner/pi-ai/oauth": _bundledPiAiOauth,
  "@mariozechner/pi-coding-agent": _bundledPiCodingAgent,
};
// CommonJS-style resolver anchored at this module; used by getAliases() below
// to locate the @sinclair/typebox package root on disk via require.resolve.
const require = createRequire(import.meta.url);
/**
 * Get aliases for jiti (used in Node.js/development mode).
 * In Bun binary mode, virtualModules is used instead.
 *
 * Maps the package specifiers extensions may import to concrete file paths,
 * preferring workspace builds (monorepo dev) over installed packages.
 */
let _aliases: Record<string, string> | null = null; // lazily computed, cached per process
function getAliases(): Record<string, string> {
  if (_aliases) return _aliases;
  const __dirname = path.dirname(fileURLToPath(import.meta.url));
  // Entry point of this package, so extensions can import pi-coding-agent itself.
  const packageIndex = path.resolve(__dirname, "../..", "index.js");
  const typeboxEntry = require.resolve("@sinclair/typebox");
  // Strip typebox's build/cjs/index.js suffix to get the package root directory.
  const typeboxRoot = typeboxEntry.replace(
    /[\\/]build[\\/]cjs[\\/]index\.js$/,
    "",
  );
  // NOTE(review): assumes this file sits four levels below the packages root
  // in the monorepo layout — confirm if the directory structure changes.
  const packagesRoot = path.resolve(__dirname, "../../../../");
  // Prefer a sibling workspace build when it exists on disk; otherwise fall
  // back to normal module resolution of the published package.
  const resolveWorkspaceOrImport = (
    workspaceRelativePath: string,
    specifier: string,
  ): string => {
    const workspacePath = path.join(packagesRoot, workspaceRelativePath);
    if (fs.existsSync(workspacePath)) {
      return workspacePath;
    }
    return fileURLToPath(import.meta.resolve(specifier));
  };
  _aliases = {
    "@mariozechner/pi-coding-agent": packageIndex,
    "@mariozechner/pi-agent-core": resolveWorkspaceOrImport(
      "agent/dist/index.js",
      "@mariozechner/pi-agent-core",
    ),
    "@mariozechner/pi-tui": resolveWorkspaceOrImport(
      "tui/dist/index.js",
      "@mariozechner/pi-tui",
    ),
    "@mariozechner/pi-ai": resolveWorkspaceOrImport(
      "ai/dist/index.js",
      "@mariozechner/pi-ai",
    ),
    "@mariozechner/pi-ai/oauth": resolveWorkspaceOrImport(
      "ai/dist/oauth.js",
      "@mariozechner/pi-ai/oauth",
    ),
    "@sinclair/typebox": typeboxRoot,
  };
  return _aliases;
}
// Non-ASCII space characters (NBSP, en/em spaces, narrow NBSP, ideographic
// space, etc.) that can sneak into paths via copy/paste.
const UNICODE_SPACES = /[\u00A0\u2000-\u200A\u202F\u205F\u3000]/g;

/** Fold exotic Unicode space characters into plain ASCII spaces. */
function normalizeUnicodeSpaces(str: string): string {
  return str.replace(UNICODE_SPACES, " ");
}

/** Expand a leading `~` to the user's home directory (after space cleanup). */
function expandPath(p: string): string {
  const normalized = normalizeUnicodeSpaces(p);
  if (!normalized.startsWith("~")) {
    return normalized;
  }
  // "~/rest" and "~rest" are both treated as home-relative.
  const rest = normalized.startsWith("~/")
    ? normalized.slice(2)
    : normalized.slice(1);
  return path.join(os.homedir(), rest);
}

/** Resolve an extension path: expand `~`, then anchor relative paths at cwd. */
function resolvePath(extPath: string, cwd: string): string {
  const expanded = expandPath(extPath);
  return path.isAbsolute(expanded) ? expanded : path.resolve(cwd, expanded);
}
// Generic shape for extension event handlers registered via api.on().
type HandlerFn = (...args: unknown[]) => Promise<unknown>;

/**
 * Create a runtime with throwing stubs for action methods.
 * Runner.bindCore() replaces these with real implementations.
 */
export function createExtensionRuntime(): ExtensionRuntime {
  const fail = (): never => {
    throw new Error(
      "Extension runtime not initialized. Action methods cannot be called during extension loading.",
    );
  };
  const runtime: ExtensionRuntime = {
    sendMessage: fail,
    sendUserMessage: fail,
    appendEntry: fail,
    setSessionName: fail,
    getSessionName: fail,
    setLabel: fail,
    getActiveTools: fail,
    getAllTools: fail,
    setActiveTools: fail,
    // registerTool() is valid during extension load; refresh is only needed post-bind.
    refreshTools: () => {},
    getCommands: fail,
    // Async API: surface the not-initialized error as a rejection, not a throw.
    setModel: () =>
      Promise.reject(new Error("Extension runtime not initialized")),
    getThinkingLevel: fail,
    setThinkingLevel: fail,
    flagValues: new Map(),
    pendingProviderRegistrations: [],
    // Pre-bind: queue registrations so bindCore() can flush them once the
    // model registry is available. bindCore() replaces both with direct calls.
    registerProvider(name, config) {
      runtime.pendingProviderRegistrations.push({ name, config });
    },
    unregisterProvider(name) {
      runtime.pendingProviderRegistrations =
        runtime.pendingProviderRegistrations.filter((r) => r.name !== name);
    },
  };
  return runtime;
}
/**
 * Create the ExtensionAPI for an extension.
 * Registration methods write to the extension object.
 * Action methods delegate to the shared runtime.
 *
 * One API object is built per extension so registrations are attributed to
 * the extension that made them, while actions share one runtime whose slots
 * are filled in later by Runner.bindCore().
 */
function createExtensionAPI(
  extension: Extension,
  runtime: ExtensionRuntime,
  cwd: string,
  eventBus: EventBus,
): ExtensionAPI {
  const api = {
    // Registration methods - write to extension
    // Append a handler for a lifecycle event (multiple handlers per event allowed).
    on(event: string, handler: HandlerFn): void {
      const list = extension.handlers.get(event) ?? [];
      list.push(handler);
      extension.handlers.set(event, list);
    },
    // Register a custom tool; refreshTools() is a no-op before bindCore().
    registerTool(tool: ToolDefinition): void {
      extension.tools.set(tool.name, {
        definition: tool,
        extensionPath: extension.path,
      });
      runtime.refreshTools();
    },
    registerCommand(
      name: string,
      options: Omit<RegisteredCommand, "name">,
    ): void {
      extension.commands.set(name, { name, ...options });
    },
    registerShortcut(
      shortcut: KeyId,
      options: {
        description?: string;
        handler: (
          ctx: import("./types.js").ExtensionContext,
        ) => Promise<void> | void;
      },
    ): void {
      extension.shortcuts.set(shortcut, {
        shortcut,
        extensionPath: extension.path,
        ...options,
      });
    },
    registerFlag(
      name: string,
      options: {
        description?: string;
        type: "boolean" | "string";
        default?: boolean | string;
      },
    ): void {
      extension.flags.set(name, {
        name,
        extensionPath: extension.path,
        ...options,
      });
      // Seed the default value only if nothing (e.g. a CLI flag) set it already.
      if (options.default !== undefined && !runtime.flagValues.has(name)) {
        runtime.flagValues.set(name, options.default);
      }
    },
    registerMessageRenderer<T>(
      customType: string,
      renderer: MessageRenderer<T>,
    ): void {
      extension.messageRenderers.set(customType, renderer as MessageRenderer);
    },
    // Flag access - checks extension registered it, reads from runtime
    getFlag(name: string): boolean | string | undefined {
      if (!extension.flags.has(name)) return undefined;
      return runtime.flagValues.get(name);
    },
    // Action methods - delegate to shared runtime
    sendMessage(message, options): void {
      runtime.sendMessage(message, options);
    },
    sendUserMessage(content, options): void {
      runtime.sendUserMessage(content, options);
    },
    appendEntry(customType: string, data?: unknown): void {
      runtime.appendEntry(customType, data);
    },
    setSessionName(name: string): void {
      runtime.setSessionName(name);
    },
    getSessionName(): string | undefined {
      return runtime.getSessionName();
    },
    setLabel(entryId: string, label: string | undefined): void {
      runtime.setLabel(entryId, label);
    },
    // Shell out relative to the session cwd unless the caller overrides it.
    exec(command: string, args: string[], options?: ExecOptions) {
      return execCommand(command, args, options?.cwd ?? cwd, options);
    },
    getActiveTools(): string[] {
      return runtime.getActiveTools();
    },
    getAllTools() {
      return runtime.getAllTools();
    },
    setActiveTools(toolNames: string[]): void {
      runtime.setActiveTools(toolNames);
    },
    getCommands() {
      return runtime.getCommands();
    },
    setModel(model) {
      return runtime.setModel(model);
    },
    getThinkingLevel() {
      return runtime.getThinkingLevel();
    },
    setThinkingLevel(level) {
      runtime.setThinkingLevel(level);
    },
    registerProvider(name: string, config: ProviderConfig) {
      runtime.registerProvider(name, config);
    },
    unregisterProvider(name: string) {
      runtime.unregisterProvider(name);
    },
    events: eventBus,
  } as ExtensionAPI;
  return api;
}
/**
 * Import an extension module via jiti and return its default export, or
 * undefined when the export is not a factory function.
 */
async function loadExtensionModule(extensionPath: string) {
  const jiti = createJiti(import.meta.url, {
    moduleCache: false,
    // In Bun binary: use virtualModules for bundled packages (no filesystem resolution)
    // Also disable tryNative so jiti handles ALL imports (not just the entry point)
    // In Node.js/dev: use aliases to resolve to node_modules paths
    ...(isBunBinary
      ? { virtualModules: VIRTUAL_MODULES, tryNative: false }
      : { alias: getAliases() }),
  });
  const loaded = await jiti.import(extensionPath, { default: true });
  if (typeof loaded !== "function") {
    return undefined;
  }
  return loaded as ExtensionFactory;
}
/**
 * Build a fresh Extension record for the given path, with every
 * registration collection starting out empty.
 */
function createExtension(
  extensionPath: string,
  resolvedPath: string,
): Extension {
  return {
    path: extensionPath,
    resolvedPath,
    handlers: new Map(),
    messageRenderers: new Map(),
    tools: new Map(),
    commands: new Map(),
    shortcuts: new Map(),
    flags: new Map(),
  };
}
/**
 * Load one extension: resolve its path, import the factory, and run the
 * factory against a fresh Extension record. Never throws; failures are
 * reported through the returned `error` string.
 */
async function loadExtension(
  extensionPath: string,
  cwd: string,
  eventBus: EventBus,
  runtime: ExtensionRuntime,
): Promise<{ extension: Extension | null; error: string | null }> {
  const resolvedPath = resolvePath(extensionPath, cwd);
  try {
    const factory = await loadExtensionModule(resolvedPath);
    if (factory === undefined) {
      return {
        extension: null,
        error: `Extension does not export a valid factory function: ${extensionPath}`,
      };
    }
    const extension = createExtension(extensionPath, resolvedPath);
    const api = createExtensionAPI(extension, runtime, cwd, eventBus);
    await factory(api);
    return { extension, error: null };
  } catch (err) {
    const message = err instanceof Error ? err.message : String(err);
    return { extension: null, error: `Failed to load extension: ${message}` };
  }
}
/**
 * Create an Extension directly from an inline factory function, bypassing
 * module loading. Used for extensions defined in code rather than on disk.
 */
export async function loadExtensionFromFactory(
  factory: ExtensionFactory,
  cwd: string,
  eventBus: EventBus,
  runtime: ExtensionRuntime,
  extensionPath = "<inline>",
): Promise<Extension> {
  // The synthetic path doubles as the resolved path for inline extensions.
  const extension = createExtension(extensionPath, extensionPath);
  await factory(createExtensionAPI(extension, runtime, cwd, eventBus));
  return extension;
}
/**
* Load extensions from paths.
*/
export async function loadExtensions(
paths: string[],
cwd: string,
eventBus?: EventBus,
): Promise<LoadExtensionsResult> {
const extensions: Extension[] = [];
const errors: Array<{ path: string; error: string }> = [];
const resolvedEventBus = eventBus ?? createEventBus();
const runtime = createExtensionRuntime();
for (const extPath of paths) {
const { extension, error } = await loadExtension(
extPath,
cwd,
resolvedEventBus,
runtime,
);
if (error) {
errors.push({ path: extPath, error });
continue;
}
if (extension) {
extensions.push(extension);
}
}
return {
extensions,
errors,
runtime,
};
}
/** Shape of the optional "pi" field in an extension package.json. */
interface PiManifest {
  extensions?: string[];
  themes?: string[];
  skills?: string[];
  prompts?: string[];
}

/**
 * Read the "pi" manifest from a package.json file. Returns null when the
 * file is unreadable, is not valid JSON, or has no object-valued "pi" field.
 */
function readPiManifest(packageJsonPath: string): PiManifest | null {
  try {
    const pkg = JSON.parse(fs.readFileSync(packageJsonPath, "utf-8"));
    const manifest = pkg?.pi;
    if (manifest && typeof manifest === "object") {
      return manifest as PiManifest;
    }
  } catch {
    // Missing file or malformed JSON: treated the same as "no manifest".
  }
  return null;
}
/** True for file names pi can load as extensions (.ts or .js). */
function isExtensionFile(name: string): boolean {
  return [".ts", ".js"].some((ext) => name.endsWith(ext));
}
/**
 * Resolve extension entry points from a directory.
 *
 * Checks for:
 * 1. package.json with "pi.extensions" field -> returns declared paths
 * 2. index.ts or index.js -> returns the index file
 *
 * Returns resolved paths or null if no entry points found.
 */
function resolveExtensionEntries(dir: string): string[] | null {
  // A package.json with a "pi" manifest takes precedence over index files.
  const packageJsonPath = path.join(dir, "package.json");
  const manifest = fs.existsSync(packageJsonPath)
    ? readPiManifest(packageJsonPath)
    : null;
  if (manifest?.extensions?.length) {
    const declared = manifest.extensions
      .map((extPath) => path.resolve(dir, extPath))
      .filter((resolved) => fs.existsSync(resolved));
    if (declared.length > 0) {
      return declared;
    }
    // All declared paths missing: fall through to index-file detection.
  }
  for (const candidate of ["index.ts", "index.js"]) {
    const indexPath = path.join(dir, candidate);
    if (fs.existsSync(indexPath)) {
      return [indexPath];
    }
  }
  return null;
}
/**
 * Discover extensions in a directory.
 *
 * Discovery rules:
 * 1. Direct files: `extensions/*.ts` or `*.js` load
 * 2. Subdirectory with index: `extensions/* /index.ts` or `index.js` load
 * 3. Subdirectory with package.json: `extensions/* /package.json` with "pi" field load what it declares
 *
 * No recursion beyond one level. Complex packages must use package.json manifest.
 * Unreadable directories yield an empty result rather than throwing.
 */
function discoverExtensionsInDir(dir: string): string[] {
  if (!fs.existsSync(dir)) {
    return [];
  }
  const discovered: string[] = [];
  try {
    for (const dirent of fs.readdirSync(dir, { withFileTypes: true })) {
      const direntPath = path.join(dir, dirent.name);
      // Rule 1: direct *.ts / *.js files (symlinked files included)
      if (
        (dirent.isFile() || dirent.isSymbolicLink()) &&
        isExtensionFile(dirent.name)
      ) {
        discovered.push(direntPath);
        continue;
      }
      // Rules 2 & 3: one level of subdirectories
      if (dirent.isDirectory() || dirent.isSymbolicLink()) {
        const subEntries = resolveExtensionEntries(direntPath);
        if (subEntries) {
          discovered.push(...subEntries);
        }
      }
    }
  } catch {
    return [];
  }
  return discovered;
}
/**
 * Discover and load extensions from standard locations, in priority order:
 * project-local (.pi/extensions), global (agentDir/extensions), then
 * explicitly configured paths. Duplicate resolved paths are loaded once.
 */
export async function discoverAndLoadExtensions(
  configuredPaths: string[],
  cwd: string,
  agentDir: string = getAgentDir(),
  eventBus?: EventBus,
): Promise<LoadExtensionsResult> {
  const allPaths: string[] = [];
  const seen = new Set<string>();
  // Dedupe on the absolute path but preserve the original spelling.
  const addPath = (p: string) => {
    const resolved = path.resolve(p);
    if (seen.has(resolved)) return;
    seen.add(resolved);
    allPaths.push(p);
  };
  // 1. Project-local extensions: cwd/.pi/extensions/
  for (const p of discoverExtensionsInDir(path.join(cwd, ".pi", "extensions"))) {
    addPath(p);
  }
  // 2. Global extensions: agentDir/extensions/
  for (const p of discoverExtensionsInDir(path.join(agentDir, "extensions"))) {
    addPath(p);
  }
  // 3. Explicitly configured paths: files, manifest/index dirs, or plain dirs
  for (const configured of configuredPaths) {
    const resolved = resolvePath(configured, cwd);
    if (fs.existsSync(resolved) && fs.statSync(resolved).isDirectory()) {
      const entries = resolveExtensionEntries(resolved);
      if (entries) {
        entries.forEach(addPath);
      } else {
        // No manifest or index file: discover individual files in the dir.
        discoverExtensionsInDir(resolved).forEach(addPath);
      }
    } else {
      addPath(resolved);
    }
  }
  return loadExtensions(allPaths, cwd, eventBus);
}

View file

@ -0,0 +1,950 @@
/**
* Extension runner - executes extensions and manages their lifecycle.
*/
import type { AgentMessage } from "@mariozechner/pi-agent-core";
import type { ImageContent, Model } from "@mariozechner/pi-ai";
import type { KeyId } from "@mariozechner/pi-tui";
import { type Theme, theme } from "../../modes/interactive/theme/theme.js";
import type { ResourceDiagnostic } from "../diagnostics.js";
import type { KeyAction, KeybindingsConfig } from "../keybindings.js";
import type { ModelRegistry } from "../model-registry.js";
import type { SessionManager } from "../session-manager.js";
import type {
BeforeAgentStartEvent,
BeforeAgentStartEventResult,
CompactOptions,
ContextEvent,
ContextEventResult,
ContextUsage,
Extension,
ExtensionActions,
ExtensionCommandContext,
ExtensionCommandContextActions,
ExtensionContext,
ExtensionContextActions,
ExtensionError,
ExtensionEvent,
ExtensionFlag,
ExtensionRuntime,
ExtensionShortcut,
ExtensionUIContext,
InputEvent,
InputEventResult,
InputSource,
MessageRenderer,
RegisteredCommand,
RegisteredTool,
ResourcesDiscoverEvent,
ResourcesDiscoverResult,
SessionBeforeCompactResult,
SessionBeforeForkResult,
SessionBeforeSwitchResult,
SessionBeforeTreeResult,
ToolCallEvent,
ToolCallEventResult,
ToolResultEvent,
ToolResultEventResult,
UserBashEvent,
UserBashEventResult,
} from "./types.js";
// Keybindings for these actions cannot be overridden by extensions
const RESERVED_ACTIONS_FOR_EXTENSION_CONFLICTS: ReadonlyArray<KeyAction> = [
  "interrupt",
  "clear",
  "exit",
  "suspend",
  "cycleThinkingLevel",
  "cycleModelForward",
  "cycleModelBackward",
  "selectModel",
  "expandTools",
  "toggleThinking",
  "externalEditor",
  "followUp",
  "submit",
  "selectConfirm",
  "selectCancel",
  "copy",
  "deleteToLineEnd",
];

// Per-key lookup: which built-in action a key maps to, and whether
// extensions are barred from rebinding that key.
type BuiltInKeyBindings = Partial<
  Record<KeyId, { action: KeyAction; restrictOverride: boolean }>
>;

/**
 * Flatten the effective keybindings config into a per-key table, lowercasing
 * keys and marking those whose actions extensions may not override.
 */
const buildBuiltinKeybindings = (
  effectiveKeybindings: Required<KeybindingsConfig>,
): BuiltInKeyBindings => {
  const table = {} as BuiltInKeyBindings;
  for (const [action, keys] of Object.entries(effectiveKeybindings)) {
    const keyAction = action as KeyAction;
    const restrictOverride =
      RESERVED_ACTIONS_FOR_EXTENSION_CONFLICTS.includes(keyAction);
    // Config values may be a single key or a list of keys.
    for (const key of Array.isArray(keys) ? keys : [keys]) {
      table[key.toLowerCase() as KeyId] = {
        action: keyAction,
        restrictOverride,
      };
    }
  }
  return table;
};
/** Combined result from all before_agent_start handlers */
interface BeforeAgentStartCombinedResult {
  messages?: NonNullable<BeforeAgentStartEventResult["message"]>[];
  systemPrompt?: string;
}
/**
 * Events handled by the generic emit() method.
 * Events with dedicated emitXxx() methods are excluded for stronger type safety.
 */
type RunnerEmitEvent = Exclude<
  ExtensionEvent,
  | ToolCallEvent
  | ToolResultEvent
  | UserBashEvent
  | ContextEvent
  | BeforeAgentStartEvent
  | ResourcesDiscoverEvent
  | InputEvent
>;
// The session_before_* subset of emit()-handled events: lifecycle events
// fired before a session operation, whose handlers may return a result.
type SessionBeforeEvent = Extract<
  RunnerEmitEvent,
  {
    type:
      | "session_before_switch"
      | "session_before_fork"
      | "session_before_compact"
      | "session_before_tree";
  }
>;
// Union of the result payloads the session_before_* handlers can produce.
type SessionBeforeEventResult =
  | SessionBeforeSwitchResult
  | SessionBeforeForkResult
  | SessionBeforeCompactResult
  | SessionBeforeTreeResult;
// Maps an emitted event type to the value emit() resolves with:
// session_before_* events yield their specific result (or undefined),
// every other emit()-handled event yields undefined.
type RunnerEmitResult<TEvent extends RunnerEmitEvent> = TEvent extends {
  type: "session_before_switch";
}
  ? SessionBeforeSwitchResult | undefined
  : TEvent extends { type: "session_before_fork" }
    ? SessionBeforeForkResult | undefined
    : TEvent extends { type: "session_before_compact" }
      ? SessionBeforeCompactResult | undefined
      : TEvent extends { type: "session_before_tree" }
        ? SessionBeforeTreeResult | undefined
        : undefined;
/** Listener notified when an extension handler raises an ExtensionError. */
export type ExtensionErrorListener = (error: ExtensionError) => void;
/**
 * Starts a new session. `setup` runs against the fresh SessionManager;
 * the result reports whether the operation was cancelled.
 */
export type NewSessionHandler = (options?: {
  parentSession?: string;
  setup?: (sessionManager: SessionManager) => Promise<void>;
}) => Promise<{ cancelled: boolean }>;
/** Forks the session at the entry identified by `entryId`. */
export type ForkHandler = (entryId: string) => Promise<{ cancelled: boolean }>;
/** Navigates the session tree to `targetId`, optionally summarizing on the way. */
export type NavigateTreeHandler = (
  targetId: string,
  options?: {
    summarize?: boolean;
    customInstructions?: string;
    replaceInstructions?: boolean;
    label?: string;
  },
) => Promise<{ cancelled: boolean }>;
/** Switches to the session stored at `sessionPath`. */
export type SwitchSessionHandler = (
  sessionPath: string,
) => Promise<{ cancelled: boolean }>;
/** Reloads extensions/configuration. */
export type ReloadHandler = () => Promise<void>;
/** Shuts the runner down synchronously. */
export type ShutdownHandler = () => void;
/**
 * Helper function to emit session_shutdown event to extensions.
 * Returns true if the event was emitted, false if there were no handlers
 * (or no runner is available at all).
 */
export async function emitSessionShutdownEvent(
  extensionRunner: ExtensionRunner | undefined,
): Promise<boolean> {
  // Guard clause: nothing to do without a runner or registered handlers.
  if (!extensionRunner?.hasHandlers("session_shutdown")) {
    return false;
  }
  await extensionRunner.emit({ type: "session_shutdown" });
  return true;
}
/**
 * Inert UI context installed when no interactive UI is attached:
 * queries resolve to "nothing selected/confirmed", setters are no-ops,
 * and theme changes report "UI not available".
 */
const noOpUIContext: ExtensionUIContext = {
  select: async () => undefined,
  confirm: async () => false,
  input: async () => undefined,
  notify: () => {},
  // Returns an unsubscribe function, as the real implementation does.
  onTerminalInput: () => () => {},
  setStatus: () => {},
  setWorkingMessage: () => {},
  setWidget: () => {},
  setFooter: () => {},
  setHeader: () => {},
  setTitle: () => {},
  custom: async () => undefined as never,
  pasteToEditor: () => {},
  setEditorText: () => {},
  getEditorText: () => "",
  editor: async () => undefined,
  setEditorComponent: () => {},
  // Even without a UI, the current theme object is still readable.
  get theme() {
    return theme;
  },
  getAllThemes: () => [],
  getTheme: () => undefined,
  setTheme: (_theme: string | Theme) => ({
    success: false,
    error: "UI not available",
  }),
  getToolsExpanded: () => false,
  setToolsExpanded: () => {},
};
export class ExtensionRunner {
  /** Loaded extensions, in load order. */
  private extensions: Extension[];
  /** Shared runtime whose action slots bindCore() fills with real implementations. */
  private runtime: ExtensionRuntime;
  /** Active UI context; defaults to the inert no-op UI until setUIContext(). */
  private uiContext: ExtensionUIContext;
  private cwd: string;
  private sessionManager: SessionManager;
  private modelRegistry: ModelRegistry;
  private errorListeners: Set<ExtensionErrorListener> = new Set();
  // The fields below hold safe defaults so the runner is usable before
  // bindCore()/bindCommandContext() wire in the real agent-core callbacks.
  private getModel: () => Model<any> | undefined = () => undefined;
  private isIdleFn: () => boolean = () => true;
  private waitForIdleFn: () => Promise<void> = async () => {};
  private abortFn: () => void = () => {};
  private hasPendingMessagesFn: () => boolean = () => false;
  private getContextUsageFn: () => ContextUsage | undefined = () => undefined;
  private compactFn: (options?: CompactOptions) => void = () => {};
  private getSystemPromptFn: () => string = () => "";
  // Session-operation handlers default to "not cancelled" no-ops.
  private newSessionHandler: NewSessionHandler = async () => ({
    cancelled: false,
  });
  private forkHandler: ForkHandler = async () => ({ cancelled: false });
  private navigateTreeHandler: NavigateTreeHandler = async () => ({
    cancelled: false,
  });
  private switchSessionHandler: SwitchSessionHandler = async () => ({
    cancelled: false,
  });
  private reloadHandler: ReloadHandler = async () => {};
  private shutdownHandler: ShutdownHandler = () => {};
  private shortcutDiagnostics: ResourceDiagnostic[] = [];
  private commandDiagnostics: ResourceDiagnostic[] = [];

  constructor(
    extensions: Extension[],
    runtime: ExtensionRuntime,
    cwd: string,
    sessionManager: SessionManager,
    modelRegistry: ModelRegistry,
  ) {
    this.extensions = extensions;
    this.runtime = runtime;
    this.uiContext = noOpUIContext;
    this.cwd = cwd;
    this.sessionManager = sessionManager;
    this.modelRegistry = modelRegistry;
  }
  /**
   * Wire the shared runtime and this runner's context accessors to the real
   * agent core. Called once after extensions load; also flushes provider
   * registrations that extensions queued during loading.
   */
  bindCore(
    actions: ExtensionActions,
    contextActions: ExtensionContextActions,
  ): void {
    // Copy actions into the shared runtime (all extension APIs reference this)
    this.runtime.sendMessage = actions.sendMessage;
    this.runtime.sendUserMessage = actions.sendUserMessage;
    this.runtime.appendEntry = actions.appendEntry;
    this.runtime.setSessionName = actions.setSessionName;
    this.runtime.getSessionName = actions.getSessionName;
    this.runtime.setLabel = actions.setLabel;
    this.runtime.getActiveTools = actions.getActiveTools;
    this.runtime.getAllTools = actions.getAllTools;
    this.runtime.setActiveTools = actions.setActiveTools;
    this.runtime.refreshTools = actions.refreshTools;
    this.runtime.getCommands = actions.getCommands;
    this.runtime.setModel = actions.setModel;
    this.runtime.getThinkingLevel = actions.getThinkingLevel;
    this.runtime.setThinkingLevel = actions.setThinkingLevel;
    // Context actions (required)
    this.getModel = contextActions.getModel;
    this.isIdleFn = contextActions.isIdle;
    this.abortFn = contextActions.abort;
    this.hasPendingMessagesFn = contextActions.hasPendingMessages;
    this.shutdownHandler = contextActions.shutdown;
    this.getContextUsageFn = contextActions.getContextUsage;
    this.compactFn = contextActions.compact;
    this.getSystemPromptFn = contextActions.getSystemPrompt;
    // Flush provider registrations queued during extension loading
    for (const { name, config } of this.runtime.pendingProviderRegistrations) {
      this.modelRegistry.registerProvider(name, config);
    }
    this.runtime.pendingProviderRegistrations = [];
    // From this point on, provider registration/unregistration takes effect immediately
    // without requiring a /reload.
    this.runtime.registerProvider = (name, config) =>
      this.modelRegistry.registerProvider(name, config);
    this.runtime.unregisterProvider = (name) =>
      this.modelRegistry.unregisterProvider(name);
  }
bindCommandContext(actions?: ExtensionCommandContextActions): void {
if (actions) {
this.waitForIdleFn = actions.waitForIdle;
this.newSessionHandler = actions.newSession;
this.forkHandler = actions.fork;
this.navigateTreeHandler = actions.navigateTree;
this.switchSessionHandler = actions.switchSession;
this.reloadHandler = actions.reload;
return;
}
this.waitForIdleFn = async () => {};
this.newSessionHandler = async () => ({ cancelled: false });
this.forkHandler = async () => ({ cancelled: false });
this.navigateTreeHandler = async () => ({ cancelled: false });
this.switchSessionHandler = async () => ({ cancelled: false });
this.reloadHandler = async () => {};
}
setUIContext(uiContext?: ExtensionUIContext): void {
this.uiContext = uiContext ?? noOpUIContext;
}
getUIContext(): ExtensionUIContext {
return this.uiContext;
}
hasUI(): boolean {
return this.uiContext !== noOpUIContext;
}
getExtensionPaths(): string[] {
return this.extensions.map((e) => e.path);
}
/** Get all registered tools from all extensions (first registration per name wins). */
getAllRegisteredTools(): RegisteredTool[] {
const toolsByName = new Map<string, RegisteredTool>();
for (const ext of this.extensions) {
for (const tool of ext.tools.values()) {
if (!toolsByName.has(tool.definition.name)) {
toolsByName.set(tool.definition.name, tool);
}
}
}
return Array.from(toolsByName.values());
}
/** Get a tool definition by name. Returns undefined if not found. */
getToolDefinition(
toolName: string,
): RegisteredTool["definition"] | undefined {
for (const ext of this.extensions) {
const tool = ext.tools.get(toolName);
if (tool) {
return tool.definition;
}
}
return undefined;
}
getFlags(): Map<string, ExtensionFlag> {
const allFlags = new Map<string, ExtensionFlag>();
for (const ext of this.extensions) {
for (const [name, flag] of ext.flags) {
if (!allFlags.has(name)) {
allFlags.set(name, flag);
}
}
}
return allFlags;
}
setFlagValue(name: string, value: boolean | string): void {
this.runtime.flagValues.set(name, value);
}
getFlagValues(): Map<string, boolean | string> {
return new Map(this.runtime.flagValues);
}
  /**
   * Merge extension keyboard shortcuts, resolving conflicts against built-in
   * keybindings and between extensions. Conflict rules:
   *  - built-in with restrictOverride === true: extension shortcut dropped;
   *  - built-in with restrictOverride === false: extension shadows it (warn);
   *  - between extensions: last registration wins (warn).
   * Diagnostics are rebuilt from scratch on every call.
   */
  getShortcuts(
    effectiveKeybindings: Required<KeybindingsConfig>,
  ): Map<KeyId, ExtensionShortcut> {
    this.shortcutDiagnostics = [];
    const builtinKeybindings = buildBuiltinKeybindings(effectiveKeybindings);
    const extensionShortcuts = new Map<KeyId, ExtensionShortcut>();
    const addDiagnostic = (message: string, extensionPath: string) => {
      this.shortcutDiagnostics.push({
        type: "warning",
        message,
        path: extensionPath,
      });
      // Without a UI there is no other surface for these warnings.
      if (!this.hasUI()) {
        console.warn(message);
      }
    };
    for (const ext of this.extensions) {
      for (const [key, shortcut] of ext.shortcuts) {
        // Keys are matched case-insensitively.
        const normalizedKey = key.toLowerCase() as KeyId;
        const builtInKeybinding = builtinKeybindings[normalizedKey];
        if (builtInKeybinding?.restrictOverride === true) {
          addDiagnostic(
            `Extension shortcut '${key}' from ${shortcut.extensionPath} conflicts with built-in shortcut. Skipping.`,
            shortcut.extensionPath,
          );
          continue;
        }
        if (builtInKeybinding?.restrictOverride === false) {
          addDiagnostic(
            `Extension shortcut conflict: '${key}' is built-in shortcut for ${builtInKeybinding.action} and ${shortcut.extensionPath}. Using ${shortcut.extensionPath}.`,
            shortcut.extensionPath,
          );
        }
        const existingExtensionShortcut = extensionShortcuts.get(normalizedKey);
        if (existingExtensionShortcut) {
          addDiagnostic(
            `Extension shortcut conflict: '${key}' registered by both ${existingExtensionShortcut.extensionPath} and ${shortcut.extensionPath}. Using ${shortcut.extensionPath}.`,
            shortcut.extensionPath,
          );
        }
        // Last registration wins among extensions.
        extensionShortcuts.set(normalizedKey, shortcut);
      }
    }
    return extensionShortcuts;
  }
getShortcutDiagnostics(): ResourceDiagnostic[] {
return this.shortcutDiagnostics;
}
onError(listener: ExtensionErrorListener): () => void {
this.errorListeners.add(listener);
return () => this.errorListeners.delete(listener);
}
emitError(error: ExtensionError): void {
for (const listener of this.errorListeners) {
listener(error);
}
}
hasHandlers(eventType: string): boolean {
for (const ext of this.extensions) {
const handlers = ext.handlers.get(eventType);
if (handlers && handlers.length > 0) {
return true;
}
}
return false;
}
getMessageRenderer(customType: string): MessageRenderer | undefined {
for (const ext of this.extensions) {
const renderer = ext.messageRenderers.get(customType);
if (renderer) {
return renderer;
}
}
return undefined;
}
getRegisteredCommands(reserved?: Set<string>): RegisteredCommand[] {
this.commandDiagnostics = [];
const commands: RegisteredCommand[] = [];
const commandOwners = new Map<string, string>();
for (const ext of this.extensions) {
for (const command of ext.commands.values()) {
if (reserved?.has(command.name)) {
const message = `Extension command '${command.name}' from ${ext.path} conflicts with built-in commands. Skipping.`;
this.commandDiagnostics.push({
type: "warning",
message,
path: ext.path,
});
if (!this.hasUI()) {
console.warn(message);
}
continue;
}
const existingOwner = commandOwners.get(command.name);
if (existingOwner) {
const message = `Extension command '${command.name}' from ${ext.path} conflicts with ${existingOwner}. Skipping.`;
this.commandDiagnostics.push({
type: "warning",
message,
path: ext.path,
});
if (!this.hasUI()) {
console.warn(message);
}
continue;
}
commandOwners.set(command.name, ext.path);
commands.push(command);
}
}
return commands;
}
getCommandDiagnostics(): ResourceDiagnostic[] {
return this.commandDiagnostics;
}
getRegisteredCommandsWithPaths(): Array<{
command: RegisteredCommand;
extensionPath: string;
}> {
const result: Array<{ command: RegisteredCommand; extensionPath: string }> =
[];
for (const ext of this.extensions) {
for (const command of ext.commands.values()) {
result.push({ command, extensionPath: ext.path });
}
}
return result;
}
getCommand(name: string): RegisteredCommand | undefined {
for (const ext of this.extensions) {
const command = ext.commands.get(name);
if (command) {
return command;
}
}
return undefined;
}
/**
* Request a graceful shutdown. Called by extension tools and event handlers.
* The actual shutdown behavior is provided by the mode via bindExtensions().
*/
shutdown(): void {
this.shutdownHandler();
}
  /**
   * Create an ExtensionContext for use in event handlers and tool execution.
   * Context values are resolved at call time, so changes via bindCore/bindUI are reflected.
   */
  createContext(): ExtensionContext {
    // Capture the current getter so the `model` property getter below does
    // not need `this` (object-literal getters rebind `this` to the literal).
    const getModel = this.getModel;
    return {
      ui: this.uiContext,
      hasUI: this.hasUI(),
      cwd: this.cwd,
      sessionManager: this.sessionManager,
      modelRegistry: this.modelRegistry,
      get model() {
        return getModel();
      },
      // Arrow wrappers defer to the current bindings at call time.
      isIdle: () => this.isIdleFn(),
      abort: () => this.abortFn(),
      hasPendingMessages: () => this.hasPendingMessagesFn(),
      shutdown: () => this.shutdownHandler(),
      getContextUsage: () => this.getContextUsageFn(),
      compact: (options) => this.compactFn(options),
      getSystemPrompt: () => this.getSystemPromptFn(),
    };
  }
createCommandContext(): ExtensionCommandContext {
return {
...this.createContext(),
waitForIdle: () => this.waitForIdleFn(),
newSession: (options) => this.newSessionHandler(options),
fork: (entryId) => this.forkHandler(entryId),
navigateTree: (targetId, options) =>
this.navigateTreeHandler(targetId, options),
switchSession: (sessionPath) => this.switchSessionHandler(sessionPath),
reload: () => this.reloadHandler(),
};
}
private isSessionBeforeEvent(
event: RunnerEmitEvent,
): event is SessionBeforeEvent {
return (
event.type === "session_before_switch" ||
event.type === "session_before_fork" ||
event.type === "session_before_compact" ||
event.type === "session_before_tree"
);
}
  /**
   * Dispatch an event to every extension handler in load order.
   * For cancellable session_before_* events the last handler result wins,
   * except that a `cancel` result short-circuits immediately. Handler
   * exceptions are reported via emitError and do not stop dispatch.
   */
  async emit<TEvent extends RunnerEmitEvent>(
    event: TEvent,
  ): Promise<RunnerEmitResult<TEvent>> {
    const ctx = this.createContext();
    let result: SessionBeforeEventResult | undefined;
    for (const ext of this.extensions) {
      const handlers = ext.handlers.get(event.type);
      if (!handlers || handlers.length === 0) continue;
      for (const handler of handlers) {
        try {
          const handlerResult = await handler(event, ctx);
          if (this.isSessionBeforeEvent(event) && handlerResult) {
            result = handlerResult as SessionBeforeEventResult;
            if (result.cancel) {
              // First cancellation wins; skip remaining handlers.
              return result as RunnerEmitResult<TEvent>;
            }
          }
        } catch (err) {
          const message = err instanceof Error ? err.message : String(err);
          const stack = err instanceof Error ? err.stack : undefined;
          this.emitError({
            extensionPath: ext.path,
            event: event.type,
            error: message,
            stack,
          });
        }
      }
    }
    return result as RunnerEmitResult<TEvent>;
  }
  /**
   * Let extensions rewrite a tool result. Handlers chain: each one sees the
   * event as modified by earlier handlers (content/details/isError fields).
   * Returns undefined when no handler changed anything, so callers can keep
   * the original result object untouched.
   */
  async emitToolResult(
    event: ToolResultEvent,
  ): Promise<ToolResultEventResult | undefined> {
    const ctx = this.createContext();
    // Shallow copy so handler edits never mutate the caller's event.
    const currentEvent: ToolResultEvent = { ...event };
    let modified = false;
    for (const ext of this.extensions) {
      const handlers = ext.handlers.get("tool_result");
      if (!handlers || handlers.length === 0) continue;
      for (const handler of handlers) {
        try {
          const handlerResult = (await handler(currentEvent, ctx)) as
            | ToolResultEventResult
            | undefined;
          if (!handlerResult) continue;
          if (handlerResult.content !== undefined) {
            currentEvent.content = handlerResult.content;
            modified = true;
          }
          if (handlerResult.details !== undefined) {
            currentEvent.details = handlerResult.details;
            modified = true;
          }
          if (handlerResult.isError !== undefined) {
            currentEvent.isError = handlerResult.isError;
            modified = true;
          }
        } catch (err) {
          const message = err instanceof Error ? err.message : String(err);
          const stack = err instanceof Error ? err.stack : undefined;
          this.emitError({
            extensionPath: ext.path,
            event: "tool_result",
            error: message,
            stack,
          });
        }
      }
    }
    if (!modified) {
      return undefined;
    }
    return {
      content: currentEvent.content,
      details: currentEvent.details,
      isError: currentEvent.isError,
    };
  }
  /**
   * Ask extensions whether a tool call may proceed. A `block` result
   * short-circuits immediately; otherwise the last non-empty result wins.
   *
   * NOTE(review): unlike the other emit* methods, handler exceptions here are
   * NOT caught — they propagate to the caller. wrapToolWithExtensions()
   * appears to rely on that to fail closed ("Extension failed, blocking
   * execution"); confirm this asymmetry is intentional before changing it.
   */
  async emitToolCall(
    event: ToolCallEvent,
  ): Promise<ToolCallEventResult | undefined> {
    const ctx = this.createContext();
    let result: ToolCallEventResult | undefined;
    for (const ext of this.extensions) {
      const handlers = ext.handlers.get("tool_call");
      if (!handlers || handlers.length === 0) continue;
      for (const handler of handlers) {
        const handlerResult = await handler(event, ctx);
        if (handlerResult) {
          result = handlerResult as ToolCallEventResult;
          if (result.block) {
            return result;
          }
        }
      }
    }
    return result;
  }
  /**
   * Dispatch a user `!` bash event. The first handler returning a result
   * wins and stops dispatch; handler exceptions are reported via emitError
   * and dispatch continues.
   */
  async emitUserBash(
    event: UserBashEvent,
  ): Promise<UserBashEventResult | undefined> {
    const ctx = this.createContext();
    for (const ext of this.extensions) {
      const handlers = ext.handlers.get("user_bash");
      if (!handlers || handlers.length === 0) continue;
      for (const handler of handlers) {
        try {
          const handlerResult = await handler(event, ctx);
          if (handlerResult) {
            // First non-empty result short-circuits.
            return handlerResult as UserBashEventResult;
          }
        } catch (err) {
          const message = err instanceof Error ? err.message : String(err);
          const stack = err instanceof Error ? err.stack : undefined;
          this.emitError({
            extensionPath: ext.path,
            event: "user_bash",
            error: message,
            stack,
          });
        }
      }
    }
    return undefined;
  }
  /**
   * Let extensions transform the message list before it goes to the LLM.
   * Handlers chain: each sees the output of the previous one. The input is
   * deep-cloned first so handlers can never mutate the caller's messages.
   */
  async emitContext(messages: AgentMessage[]): Promise<AgentMessage[]> {
    const ctx = this.createContext();
    let currentMessages = structuredClone(messages);
    for (const ext of this.extensions) {
      const handlers = ext.handlers.get("context");
      if (!handlers || handlers.length === 0) continue;
      for (const handler of handlers) {
        try {
          const event: ContextEvent = {
            type: "context",
            messages: currentMessages,
          };
          const handlerResult = await handler(event, ctx);
          if (handlerResult && (handlerResult as ContextEventResult).messages) {
            currentMessages = (handlerResult as ContextEventResult).messages!;
          }
        } catch (err) {
          const message = err instanceof Error ? err.message : String(err);
          const stack = err instanceof Error ? err.stack : undefined;
          this.emitError({
            extensionPath: ext.path,
            event: "context",
            error: message,
            stack,
          });
        }
      }
    }
    return currentMessages;
  }
  /**
   * Dispatch before_agent_start to all extensions. Extension-supplied
   * messages accumulate; the system prompt chains (each handler sees the
   * previous handler's rewrite). Returns undefined when no extension
   * contributed anything, so the caller can skip the merge entirely.
   */
  async emitBeforeAgentStart(
    prompt: string,
    images: ImageContent[] | undefined,
    systemPrompt: string,
  ): Promise<BeforeAgentStartCombinedResult | undefined> {
    const ctx = this.createContext();
    const messages: NonNullable<BeforeAgentStartEventResult["message"]>[] = [];
    let currentSystemPrompt = systemPrompt;
    let systemPromptModified = false;
    for (const ext of this.extensions) {
      const handlers = ext.handlers.get("before_agent_start");
      if (!handlers || handlers.length === 0) continue;
      for (const handler of handlers) {
        try {
          const event: BeforeAgentStartEvent = {
            type: "before_agent_start",
            prompt,
            images,
            systemPrompt: currentSystemPrompt,
          };
          const handlerResult = await handler(event, ctx);
          if (handlerResult) {
            const result = handlerResult as BeforeAgentStartEventResult;
            if (result.message) {
              messages.push(result.message);
            }
            if (result.systemPrompt !== undefined) {
              currentSystemPrompt = result.systemPrompt;
              systemPromptModified = true;
            }
          }
        } catch (err) {
          const message = err instanceof Error ? err.message : String(err);
          const stack = err instanceof Error ? err.stack : undefined;
          this.emitError({
            extensionPath: ext.path,
            event: "before_agent_start",
            error: message,
            stack,
          });
        }
      }
    }
    if (messages.length > 0 || systemPromptModified) {
      return {
        messages: messages.length > 0 ? messages : undefined,
        systemPrompt: systemPromptModified ? currentSystemPrompt : undefined,
      };
    }
    return undefined;
  }
async emitResourcesDiscover(
cwd: string,
reason: ResourcesDiscoverEvent["reason"],
): Promise<{
skillPaths: Array<{ path: string; extensionPath: string }>;
promptPaths: Array<{ path: string; extensionPath: string }>;
themePaths: Array<{ path: string; extensionPath: string }>;
}> {
const ctx = this.createContext();
const skillPaths: Array<{ path: string; extensionPath: string }> = [];
const promptPaths: Array<{ path: string; extensionPath: string }> = [];
const themePaths: Array<{ path: string; extensionPath: string }> = [];
for (const ext of this.extensions) {
const handlers = ext.handlers.get("resources_discover");
if (!handlers || handlers.length === 0) continue;
for (const handler of handlers) {
try {
const event: ResourcesDiscoverEvent = {
type: "resources_discover",
cwd,
reason,
};
const handlerResult = await handler(event, ctx);
const result = handlerResult as ResourcesDiscoverResult | undefined;
if (result?.skillPaths?.length) {
skillPaths.push(
...result.skillPaths.map((path) => ({
path,
extensionPath: ext.path,
})),
);
}
if (result?.promptPaths?.length) {
promptPaths.push(
...result.promptPaths.map((path) => ({
path,
extensionPath: ext.path,
})),
);
}
if (result?.themePaths?.length) {
themePaths.push(
...result.themePaths.map((path) => ({
path,
extensionPath: ext.path,
})),
);
}
} catch (err) {
const message = err instanceof Error ? err.message : String(err);
const stack = err instanceof Error ? err.stack : undefined;
this.emitError({
extensionPath: ext.path,
event: "resources_discover",
error: message,
stack,
});
}
}
}
return { skillPaths, promptPaths, themePaths };
}
  /** Emit input event. Transforms chain, "handled" short-circuits. */
  async emitInput(
    text: string,
    images: ImageContent[] | undefined,
    source: InputSource,
  ): Promise<InputEventResult> {
    const ctx = this.createContext();
    let currentText = text;
    let currentImages = images;
    for (const ext of this.extensions) {
      for (const handler of ext.handlers.get("input") ?? []) {
        try {
          // Each handler sees the text/images as transformed so far.
          const event: InputEvent = {
            type: "input",
            text: currentText,
            images: currentImages,
            source,
          };
          const result = (await handler(event, ctx)) as
            | InputEventResult
            | undefined;
          if (result?.action === "handled") return result;
          if (result?.action === "transform") {
            currentText = result.text;
            // A transform may omit images to keep the current set.
            currentImages = result.images ?? currentImages;
          }
        } catch (err) {
          this.emitError({
            extensionPath: ext.path,
            event: "input",
            error: err instanceof Error ? err.message : String(err),
            stack: err instanceof Error ? err.stack : undefined,
          });
        }
      }
    }
    // Report "transform" only if something actually changed (reference
    // comparison for images is intentional — any new array counts).
    return currentText !== text || currentImages !== images
      ? { action: "transform", text: currentText, images: currentImages }
      : { action: "continue" };
  }
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,147 @@
/**
* Tool wrappers for extensions.
*/
import type {
AgentTool,
AgentToolUpdateCallback,
} from "@mariozechner/pi-agent-core";
import type { ExtensionRunner } from "./runner.js";
import type { RegisteredTool, ToolCallEventResult } from "./types.js";
/**
 * Wrap a RegisteredTool into an AgentTool.
 * Uses the runner's createContext() for consistent context across tools and event handlers.
 */
export function wrapRegisteredTool(
  registeredTool: RegisteredTool,
  runner: ExtensionRunner,
): AgentTool {
  const def = registeredTool.definition;
  return {
    name: def.name,
    label: def.label,
    description: def.description,
    parameters: def.parameters,
    execute: (toolCallId, params, signal, onUpdate) => {
      // Fresh context per call so late bindings (bindCore/bindUI) apply.
      const ctx = runner.createContext();
      return def.execute(toolCallId, params, signal, onUpdate, ctx);
    },
  };
}
/**
 * Wrap all registered tools into AgentTools.
 * Uses the runner's createContext() for consistent context across tools and event handlers.
 */
export function wrapRegisteredTools(
  registeredTools: RegisteredTool[],
  runner: ExtensionRunner,
): AgentTool[] {
  const wrapped: AgentTool[] = [];
  for (const registered of registeredTools) {
    wrapped.push(wrapRegisteredTool(registered, runner));
  }
  return wrapped;
}
/**
 * Wrap a tool with extension callbacks for interception.
 * - Emits tool_call event before execution (can block)
 * - Emits tool_result event after execution (can modify result)
 *
 * Failure semantics: an extension error during the tool_call phase blocks
 * execution (fail closed). Errors thrown by the tool itself are re-emitted
 * as an error-flagged tool_result event, then rethrown.
 */
export function wrapToolWithExtensions<T>(
  tool: AgentTool<any, T>,
  runner: ExtensionRunner,
): AgentTool<any, T> {
  return {
    ...tool,
    execute: async (
      toolCallId: string,
      params: Record<string, unknown>,
      signal?: AbortSignal,
      onUpdate?: AgentToolUpdateCallback<T>,
    ) => {
      // Emit tool_call event - extensions can block execution
      if (runner.hasHandlers("tool_call")) {
        try {
          const callResult = (await runner.emitToolCall({
            type: "tool_call",
            toolName: tool.name,
            toolCallId,
            input: params,
          })) as ToolCallEventResult | undefined;
          if (callResult?.block) {
            const reason =
              callResult.reason || "Tool execution was blocked by an extension";
            throw new Error(reason);
          }
        } catch (err) {
          // Both the deliberate block above and extension failures land
          // here; Error instances are rethrown as-is, anything else is
          // wrapped so the caller always sees an Error.
          if (err instanceof Error) {
            throw err;
          }
          throw new Error(
            `Extension failed, blocking execution: ${String(err)}`,
          );
        }
      }
      // Execute the actual tool
      try {
        const result = await tool.execute(toolCallId, params, signal, onUpdate);
        // Emit tool_result event - extensions can modify the result
        if (runner.hasHandlers("tool_result")) {
          const resultResult = await runner.emitToolResult({
            type: "tool_result",
            toolName: tool.name,
            toolCallId,
            input: params,
            content: result.content,
            details: result.details,
            isError: false,
          });
          if (resultResult) {
            // Merge: extension fields win, original fields fill the gaps.
            return {
              content: resultResult.content ?? result.content,
              details: (resultResult.details ?? result.details) as T,
            };
          }
        }
        return result;
      } catch (err) {
        // Emit tool_result event for errors
        if (runner.hasHandlers("tool_result")) {
          await runner.emitToolResult({
            type: "tool_result",
            toolName: tool.name,
            toolCallId,
            input: params,
            content: [
              {
                type: "text",
                text: err instanceof Error ? err.message : String(err),
              },
            ],
            details: undefined,
            isError: true,
          });
        }
        // Extensions observe the failure but cannot suppress it.
        throw err;
      }
    },
  };
}
/**
 * Wrap all tools with extension callbacks.
 */
export function wrapToolsWithExtensions<T>(
  tools: AgentTool<any, T>[],
  runner: ExtensionRunner,
): AgentTool<any, T>[] {
  const wrapped: AgentTool<any, T>[] = [];
  for (const tool of tools) {
    wrapped.push(wrapToolWithExtensions(tool, runner));
  }
  return wrapped;
}

View file

@ -0,0 +1,149 @@
import { existsSync, type FSWatcher, readFileSync, statSync, watch } from "fs";
import { dirname, join, resolve } from "path";
/**
 * Find the git HEAD path by walking up from cwd.
 * Handles both regular git repos (.git is a directory) and worktrees (.git is a file).
 * Returns null when no repository is found or the layout is unreadable.
 */
function findGitHeadPath(): string | null {
  for (let dir = process.cwd(); ; ) {
    const gitPath = join(dir, ".git");
    if (existsSync(gitPath)) {
      try {
        const info = statSync(gitPath);
        if (info.isDirectory()) {
          // Regular repository: HEAD lives directly inside .git/
          const head = join(gitPath, "HEAD");
          if (existsSync(head)) return head;
        } else if (info.isFile()) {
          // Worktree: .git is a one-line pointer "gitdir: <path>"
          const pointer = readFileSync(gitPath, "utf8").trim();
          if (pointer.startsWith("gitdir: ")) {
            const head = resolve(dir, pointer.slice("gitdir: ".length), "HEAD");
            if (existsSync(head)) return head;
          }
        }
      } catch {
        return null;
      }
    }
    const parent = dirname(dir);
    if (parent === dir) return null;
    dir = parent;
  }
}
/**
 * Provides git branch and extension statuses - data not otherwise accessible to extensions.
 * Token stats, model info available via ctx.sessionManager and ctx.model.
 */
export class FooterDataProvider {
  // Status texts keyed by extension-chosen key (see setExtensionStatus).
  private extensionStatuses = new Map<string, string>();
  // undefined = not yet computed; null = not in a repo; string = branch name.
  private cachedBranch: string | null | undefined = undefined;
  private gitWatcher: FSWatcher | null = null;
  private branchChangeCallbacks = new Set<() => void>();
  private availableProviderCount = 0;
  constructor() {
    this.setupGitWatcher();
  }
  /** Current git branch, null if not in repo, "detached" if detached HEAD */
  getGitBranch(): string | null {
    // Cache hit covers both "known branch" and "known no-repo" (null).
    if (this.cachedBranch !== undefined) return this.cachedBranch;
    try {
      const gitHeadPath = findGitHeadPath();
      if (!gitHeadPath) {
        this.cachedBranch = null;
        return null;
      }
      const content = readFileSync(gitHeadPath, "utf8").trim();
      // HEAD contains "ref: refs/heads/<branch>" on a branch, a raw SHA when detached.
      this.cachedBranch = content.startsWith("ref: refs/heads/")
        ? content.slice(16)
        : "detached";
    } catch {
      this.cachedBranch = null;
    }
    return this.cachedBranch;
  }
  /** Extension status texts set via ctx.ui.setStatus() */
  getExtensionStatuses(): ReadonlyMap<string, string> {
    return this.extensionStatuses;
  }
  /** Subscribe to git branch changes. Returns unsubscribe function. */
  onBranchChange(callback: () => void): () => void {
    this.branchChangeCallbacks.add(callback);
    return () => this.branchChangeCallbacks.delete(callback);
  }
  /** Internal: set extension status (undefined text clears the entry) */
  setExtensionStatus(key: string, text: string | undefined): void {
    if (text === undefined) {
      this.extensionStatuses.delete(key);
    } else {
      this.extensionStatuses.set(key, text);
    }
  }
  /** Internal: clear extension statuses */
  clearExtensionStatuses(): void {
    this.extensionStatuses.clear();
  }
  /** Number of unique providers with available models (for footer display) */
  getAvailableProviderCount(): number {
    return this.availableProviderCount;
  }
  /** Internal: update available provider count */
  setAvailableProviderCount(count: number): void {
    this.availableProviderCount = count;
  }
  /** Internal: cleanup (stops the fs watcher and drops subscribers) */
  dispose(): void {
    if (this.gitWatcher) {
      this.gitWatcher.close();
      this.gitWatcher = null;
    }
    this.branchChangeCallbacks.clear();
  }
  private setupGitWatcher(): void {
    if (this.gitWatcher) {
      this.gitWatcher.close();
      this.gitWatcher = null;
    }
    const gitHeadPath = findGitHeadPath();
    if (!gitHeadPath) return;
    // Watch the directory containing HEAD, not HEAD itself.
    // Git uses atomic writes (write temp, rename over HEAD), which changes the inode.
    // fs.watch on a file stops working after the inode changes.
    const gitDir = dirname(gitHeadPath);
    try {
      this.gitWatcher = watch(gitDir, (_eventType, filename) => {
        if (filename === "HEAD") {
          // Invalidate the cache; next getGitBranch() re-reads HEAD.
          this.cachedBranch = undefined;
          for (const cb of this.branchChangeCallbacks) cb();
        }
      });
    } catch {
      // Silently fail if we can't watch
    }
  }
}
/**
 * Read-only view for extensions - excludes setExtensionStatus,
 * setAvailableProviderCount and dispose, so extensions can observe footer
 * data but never mutate host-owned state.
 */
export type ReadonlyFooterDataProvider = Pick<
  FooterDataProvider,
  | "getGitBranch"
  | "getExtensionStatuses"
  | "getAvailableProviderCount"
  | "onBranchChange"
>;

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,70 @@
/**
* Core modules shared between all run modes.
*/
export {
AgentSession,
type AgentSessionConfig,
type AgentSessionEvent,
type AgentSessionEventListener,
type ModelCycleResult,
type PromptOptions,
type SessionStats,
} from "./agent-session.js";
export {
type BashExecutorOptions,
type BashResult,
executeBash,
executeBashWithOperations,
} from "./bash-executor.js";
export type { CompactionResult } from "./compaction/index.js";
export {
createEventBus,
type EventBus,
type EventBusController,
} from "./event-bus.js";
// Extensions system
export {
type AgentEndEvent,
type AgentStartEvent,
type AgentToolResult,
type AgentToolUpdateCallback,
type BeforeAgentStartEvent,
type ContextEvent,
discoverAndLoadExtensions,
type ExecOptions,
type ExecResult,
type Extension,
type ExtensionAPI,
type ExtensionCommandContext,
type ExtensionContext,
type ExtensionError,
type ExtensionEvent,
type ExtensionFactory,
type ExtensionFlag,
type ExtensionHandler,
ExtensionRunner,
type ExtensionShortcut,
type ExtensionUIContext,
type LoadExtensionsResult,
type MessageRenderer,
type RegisteredCommand,
type SessionBeforeCompactEvent,
type SessionBeforeForkEvent,
type SessionBeforeSwitchEvent,
type SessionBeforeTreeEvent,
type SessionCompactEvent,
type SessionForkEvent,
type SessionShutdownEvent,
type SessionStartEvent,
type SessionSwitchEvent,
type SessionTreeEvent,
type ToolCallEvent,
type ToolDefinition,
type ToolRenderResultOptions,
type ToolResultEvent,
type TurnEndEvent,
type TurnStartEvent,
wrapToolsWithExtensions,
} from "./extensions/index.js";

View file

@ -0,0 +1,211 @@
import {
DEFAULT_EDITOR_KEYBINDINGS,
type EditorAction,
type EditorKeybindingsConfig,
EditorKeybindingsManager,
type KeyId,
matchesKey,
setEditorKeybindings,
} from "@mariozechner/pi-tui";
import { existsSync, readFileSync } from "fs";
import { join } from "path";
import { getAgentDir } from "../config.js";
/**
 * Application-level actions (coding agent specific).
 * Editor-local actions are defined separately as EditorAction; see KeyAction
 * for the combined union.
 */
export type AppAction =
  | "interrupt"
  | "clear"
  | "exit"
  | "suspend"
  | "cycleThinkingLevel"
  | "cycleModelForward"
  | "cycleModelBackward"
  | "selectModel"
  | "expandTools"
  | "toggleThinking"
  | "toggleSessionNamedFilter"
  | "externalEditor"
  | "followUp"
  | "dequeue"
  | "pasteImage"
  | "newSession"
  | "tree"
  | "fork"
  | "resume";
/**
 * All configurable actions (application actions plus editor actions).
 */
export type KeyAction = AppAction | EditorAction;
/**
 * Full keybindings configuration (app + editor actions).
 * Each action maps to a single key or a list of keys; absent entries fall
 * back to the defaults.
 */
export type KeybindingsConfig = {
  [K in KeyAction]?: KeyId | KeyId[];
};
/**
 * Default application keybindings.
 * An empty array means the action ships unbound (user must bind it).
 */
export const DEFAULT_APP_KEYBINDINGS: Record<AppAction, KeyId | KeyId[]> = {
  interrupt: "escape",
  clear: "ctrl+c",
  exit: "ctrl+d",
  suspend: "ctrl+z",
  cycleThinkingLevel: "shift+tab",
  cycleModelForward: "ctrl+p",
  cycleModelBackward: "shift+ctrl+p",
  selectModel: "ctrl+l",
  expandTools: "ctrl+o",
  toggleThinking: "ctrl+t",
  toggleSessionNamedFilter: "ctrl+n",
  externalEditor: "ctrl+g",
  followUp: "alt+enter",
  dequeue: "alt+up",
  // On Windows ctrl+v is terminal paste, so fall back to alt+v there.
  pasteImage: process.platform === "win32" ? "alt+v" : "ctrl+v",
  newSession: [],
  tree: [],
  fork: [],
  resume: [],
};
/**
 * All default keybindings (app + editor).
 * App defaults are spread last, so they take precedence for any action name
 * present in both maps.
 */
export const DEFAULT_KEYBINDINGS: Required<KeybindingsConfig> = {
  ...DEFAULT_EDITOR_KEYBINDINGS,
  ...DEFAULT_APP_KEYBINDINGS,
};
// App actions list for type checking
const APP_ACTIONS: AppAction[] = [
"interrupt",
"clear",
"exit",
"suspend",
"cycleThinkingLevel",
"cycleModelForward",
"cycleModelBackward",
"selectModel",
"expandTools",
"toggleThinking",
"toggleSessionNamedFilter",
"externalEditor",
"followUp",
"dequeue",
"pasteImage",
"newSession",
"tree",
"fork",
"resume",
];
/** Type guard: does this action name belong to the app-level set? */
function isAppAction(action: string): action is AppAction {
  return (APP_ACTIONS as readonly string[]).includes(action);
}
/**
 * Manages all keybindings (app + editor).
 * App-action lookups are served from this class; editor bindings are pushed
 * into pi-tui's global EditorKeybindingsManager at creation time.
 */
export class KeybindingsManager {
  private config: KeybindingsConfig;
  // Resolved app-action -> key list (defaults merged with user config).
  private appActionToKeys: Map<AppAction, KeyId[]>;
  private constructor(config: KeybindingsConfig) {
    this.config = config;
    this.appActionToKeys = new Map();
    this.buildMaps();
  }
  /**
   * Create from config file and set up editor keybindings.
   */
  static create(agentDir: string = getAgentDir()): KeybindingsManager {
    const configPath = join(agentDir, "keybindings.json");
    const config = KeybindingsManager.loadFromFile(configPath);
    const manager = new KeybindingsManager(config);
    // Set up editor keybindings globally
    // Include both editor actions and expandTools (shared between app and editor)
    const editorConfig: EditorKeybindingsConfig = {};
    for (const [action, keys] of Object.entries(config)) {
      if (!isAppAction(action) || action === "expandTools") {
        editorConfig[action as EditorAction] = keys;
      }
    }
    setEditorKeybindings(new EditorKeybindingsManager(editorConfig));
    return manager;
  }
  /**
   * Create in-memory (no file I/O, no global editor setup) — for tests.
   */
  static inMemory(config: KeybindingsConfig = {}): KeybindingsManager {
    return new KeybindingsManager(config);
  }
  // Missing or malformed config files silently fall back to defaults.
  private static loadFromFile(path: string): KeybindingsConfig {
    if (!existsSync(path)) return {};
    try {
      return JSON.parse(readFileSync(path, "utf-8"));
    } catch {
      return {};
    }
  }
  private buildMaps(): void {
    this.appActionToKeys.clear();
    // Set defaults for app actions
    for (const [action, keys] of Object.entries(DEFAULT_APP_KEYBINDINGS)) {
      const keyArray = Array.isArray(keys) ? keys : [keys];
      this.appActionToKeys.set(action as AppAction, [...keyArray]);
    }
    // Override with user config (app actions only)
    for (const [action, keys] of Object.entries(this.config)) {
      if (keys === undefined || !isAppAction(action)) continue;
      const keyArray = Array.isArray(keys) ? keys : [keys];
      this.appActionToKeys.set(action, keyArray);
    }
  }
  /**
   * Check if input matches an app action.
   * @param data Raw terminal input sequence.
   */
  matches(data: string, action: AppAction): boolean {
    const keys = this.appActionToKeys.get(action);
    if (!keys) return false;
    for (const key of keys) {
      if (matchesKey(data, key)) return true;
    }
    return false;
  }
  /**
   * Get keys bound to an app action (empty array when unbound).
   */
  getKeys(action: AppAction): KeyId[] {
    return this.appActionToKeys.get(action) ?? [];
  }
  /**
   * Get the full effective config (defaults overlaid with user config).
   */
  getEffectiveConfig(): Required<KeybindingsConfig> {
    const result = { ...DEFAULT_KEYBINDINGS };
    for (const [action, keys] of Object.entries(this.config)) {
      if (keys !== undefined) {
        (result as KeybindingsConfig)[action as KeyAction] = keys;
      }
    }
    return result;
  }
}
// Re-export for convenience
export type { EditorAction, KeyId };

View file

@ -0,0 +1,217 @@
/**
* Custom message types and transformers for the coding agent.
*
* Extends the base AgentMessage type with coding-agent specific message types,
* and provides a transformer to convert them to LLM-compatible messages.
*/
import type { AgentMessage } from "@mariozechner/pi-agent-core";
import type { ImageContent, Message, TextContent } from "@mariozechner/pi-ai";
// Markers wrapped around compaction / branch summaries when they are
// serialized into user-message text for the LLM. The summary body is spliced
// between the matching prefix and suffix.
export const COMPACTION_SUMMARY_PREFIX = `The conversation history before this point was compacted into the following summary:
<summary>
`;
export const COMPACTION_SUMMARY_SUFFIX = `
</summary>`;
export const BRANCH_SUMMARY_PREFIX = `The following is a summary of a branch that this conversation came back from:
<summary>
`;
export const BRANCH_SUMMARY_SUFFIX = `</summary>`;
/**
 * Message type for bash executions via the ! command.
 */
export interface BashExecutionMessage {
  role: "bashExecution";
  // The shell command as typed by the user.
  command: string;
  // Combined (possibly truncated) command output.
  output: string;
  // undefined when the process produced no exit code (e.g. cancelled).
  exitCode: number | undefined;
  cancelled: boolean;
  truncated: boolean;
  // Path to the untruncated output, present only when truncated.
  fullOutputPath?: string;
  timestamp: number;
  /** If true, this message is excluded from LLM context (!! prefix) */
  excludeFromContext?: boolean;
}
/**
 * Message type for extension-injected messages via sendMessage().
 * These are custom messages that extensions can inject into the conversation.
 */
export interface CustomMessage<T = unknown> {
  role: "custom";
  // Extension-defined discriminator, used to pick a MessageRenderer.
  customType: string;
  content: string | (TextContent | ImageContent)[];
  // Whether the message is rendered in the UI (it may be context-only).
  display: boolean;
  // Arbitrary extension payload carried alongside the content.
  details?: T;
  timestamp: number;
}
// Summary of a session branch the conversation returned from.
export interface BranchSummaryMessage {
  role: "branchSummary";
  summary: string;
  // Session entry id the branch forked from.
  fromId: string;
  timestamp: number;
}
// Summary that replaced compacted history; tokensBefore records the size
// of the history that was compacted away.
export interface CompactionSummaryMessage {
  role: "compactionSummary";
  summary: string;
  tokensBefore: number;
  timestamp: number;
}
// Extend CustomAgentMessages via declaration merging so AgentMessage unions
// across the codebase include the coding-agent specific roles above.
declare module "@mariozechner/pi-agent-core" {
  interface CustomAgentMessages {
    bashExecution: BashExecutionMessage;
    custom: CustomMessage;
    branchSummary: BranchSummaryMessage;
    compactionSummary: CompactionSummaryMessage;
  }
}
/**
 * Render a BashExecutionMessage as plain text for inclusion in a user
 * message sent to the LLM: the command, its (possibly empty) output, and
 * trailing notes about cancellation, exit code, or truncation.
 */
export function bashExecutionToText(msg: BashExecutionMessage): string {
	const body = msg.output ? `\`\`\`\n${msg.output}\n\`\`\`` : "(no output)";
	const parts: string[] = [`Ran \`${msg.command}\`\n` + body];
	if (msg.cancelled) {
		parts.push("(command cancelled)");
	} else if (msg.exitCode != null && msg.exitCode !== 0) {
		parts.push(`Command exited with code ${msg.exitCode}`);
	}
	if (msg.truncated && msg.fullOutputPath) {
		parts.push(`[Output truncated. Full output: ${msg.fullOutputPath}]`);
	}
	return parts.join("\n\n");
}
/** Build a BranchSummaryMessage from persisted data, parsing the ISO timestamp. */
export function createBranchSummaryMessage(
	summary: string,
	fromId: string,
	timestamp: string,
): BranchSummaryMessage {
	const millis = new Date(timestamp).getTime();
	return { role: "branchSummary", summary, fromId, timestamp: millis };
}
/** Build a CompactionSummaryMessage from persisted data, parsing the ISO timestamp. */
export function createCompactionSummaryMessage(
	summary: string,
	tokensBefore: number,
	timestamp: string,
): CompactionSummaryMessage {
	const millis = new Date(timestamp).getTime();
	return { role: "compactionSummary", summary, tokensBefore, timestamp: millis };
}
/** Convert CustomMessageEntry to AgentMessage format */
export function createCustomMessage(
	customType: string,
	content: string | (TextContent | ImageContent)[],
	display: boolean,
	details: unknown | undefined,
	timestamp: string,
): CustomMessage {
	const message: CustomMessage = {
		role: "custom",
		customType,
		content,
		display,
		details,
		timestamp: new Date(timestamp).getTime(),
	};
	return message;
}
/**
 * Transform AgentMessages (including custom types) to LLM-compatible Messages.
 *
 * Used by the agent's LLM-transform hook (prompt calls and queued messages),
 * compaction's summary generation, and custom extensions/tools. Custom roles
 * are rendered as user messages; unknown roles are dropped.
 */
export function convertToLlm(messages: AgentMessage[]): Message[] {
	const converted: Message[] = [];
	for (const m of messages) {
		switch (m.role) {
			case "bashExecution": {
				// Skip messages excluded from context (!! prefix).
				if (m.excludeFromContext) break;
				converted.push({
					role: "user",
					content: [{ type: "text", text: bashExecutionToText(m) }],
					timestamp: m.timestamp,
				});
				break;
			}
			case "custom": {
				const content =
					typeof m.content === "string"
						? [{ type: "text" as const, text: m.content }]
						: m.content;
				converted.push({ role: "user", content, timestamp: m.timestamp });
				break;
			}
			case "branchSummary": {
				converted.push({
					role: "user",
					content: [
						{
							type: "text" as const,
							text: BRANCH_SUMMARY_PREFIX + m.summary + BRANCH_SUMMARY_SUFFIX,
						},
					],
					timestamp: m.timestamp,
				});
				break;
			}
			case "compactionSummary": {
				converted.push({
					role: "user",
					content: [
						{
							type: "text" as const,
							text:
								COMPACTION_SUMMARY_PREFIX +
								m.summary +
								COMPACTION_SUMMARY_SUFFIX,
						},
					],
					timestamp: m.timestamp,
				});
				break;
			}
			// Plain LLM roles pass through untouched.
			case "user":
			case "assistant":
			case "toolResult":
				converted.push(m);
				break;
			default: {
				// Exhaustiveness check: adding a new custom role without handling
				// it above is a compile error.
				const _exhaustive: never = m;
				break;
			}
		}
	}
	return converted;
}

View file

@ -0,0 +1,822 @@
/**
* Model registry - manages built-in and custom models, provides API key resolution.
*/
import {
type Api,
type AssistantMessageEventStream,
type Context,
getModels,
getProviders,
type KnownProvider,
type Model,
type OAuthProviderInterface,
type OpenAICompletionsCompat,
type OpenAIResponsesCompat,
registerApiProvider,
resetApiProviders,
type SimpleStreamOptions,
} from "@mariozechner/pi-ai";
import {
registerOAuthProvider,
resetOAuthProviders,
} from "@mariozechner/pi-ai/oauth";
import { type Static, Type } from "@sinclair/typebox";
import AjvModule from "ajv";
import { existsSync, readFileSync } from "fs";
import { join } from "path";
import { getAgentDir } from "../config.js";
import type { AuthStorage } from "./auth-storage.js";
import {
clearConfigValueCache,
resolveConfigValue,
resolveHeaders,
} from "./resolve-config-value.js";
// CJS/ESM interop: Ajv may be exposed as a default export or as the module
// object itself depending on how it was loaded.
const Ajv = (AjvModule as any).default || AjvModule;
const ajv = new Ajv();
// Schema for OpenRouter routing preferences
const OpenRouterRoutingSchema = Type.Object({
	only: Type.Optional(Type.Array(Type.String())),
	order: Type.Optional(Type.Array(Type.String())),
});
// Schema for Vercel AI Gateway routing preferences
const VercelGatewayRoutingSchema = Type.Object({
	only: Type.Optional(Type.Array(Type.String())),
	order: Type.Optional(Type.Array(Type.String())),
});
// Schema for OpenAI compatibility settings
const OpenAICompletionsCompatSchema = Type.Object({
	supportsStore: Type.Optional(Type.Boolean()),
	supportsDeveloperRole: Type.Optional(Type.Boolean()),
	supportsReasoningEffort: Type.Optional(Type.Boolean()),
	supportsUsageInStreaming: Type.Optional(Type.Boolean()),
	maxTokensField: Type.Optional(
		Type.Union([
			Type.Literal("max_completion_tokens"),
			Type.Literal("max_tokens"),
		]),
	),
	requiresToolResultName: Type.Optional(Type.Boolean()),
	requiresAssistantAfterToolResult: Type.Optional(Type.Boolean()),
	requiresThinkingAsText: Type.Optional(Type.Boolean()),
	requiresMistralToolIds: Type.Optional(Type.Boolean()),
	thinkingFormat: Type.Optional(
		Type.Union([
			Type.Literal("openai"),
			Type.Literal("zai"),
			Type.Literal("qwen"),
		]),
	),
	openRouterRouting: Type.Optional(OpenRouterRoutingSchema),
	vercelGatewayRouting: Type.Optional(VercelGatewayRoutingSchema),
});
const OpenAIResponsesCompatSchema = Type.Object({
	// Reserved for future use
});
// A model's compat block may target either the completions or responses API.
const OpenAICompatSchema = Type.Union([
	OpenAICompletionsCompatSchema,
	OpenAIResponsesCompatSchema,
]);
// Schema for custom model definition
// Most fields are optional with sensible defaults for local models (Ollama, LM Studio, etc.)
const ModelDefinitionSchema = Type.Object({
	id: Type.String({ minLength: 1 }),
	name: Type.Optional(Type.String({ minLength: 1 })),
	api: Type.Optional(Type.String({ minLength: 1 })),
	baseUrl: Type.Optional(Type.String({ minLength: 1 })),
	reasoning: Type.Optional(Type.Boolean()),
	input: Type.Optional(
		Type.Array(Type.Union([Type.Literal("text"), Type.Literal("image")])),
	),
	cost: Type.Optional(
		Type.Object({
			input: Type.Number(),
			output: Type.Number(),
			cacheRead: Type.Number(),
			cacheWrite: Type.Number(),
		}),
	),
	contextWindow: Type.Optional(Type.Number()),
	maxTokens: Type.Optional(Type.Number()),
	headers: Type.Optional(Type.Record(Type.String(), Type.String())),
	compat: Type.Optional(OpenAICompatSchema),
});
// Schema for per-model overrides (all fields optional, merged with built-in model)
const ModelOverrideSchema = Type.Object({
	name: Type.Optional(Type.String({ minLength: 1 })),
	reasoning: Type.Optional(Type.Boolean()),
	input: Type.Optional(
		Type.Array(Type.Union([Type.Literal("text"), Type.Literal("image")])),
	),
	cost: Type.Optional(
		Type.Object({
			input: Type.Optional(Type.Number()),
			output: Type.Optional(Type.Number()),
			cacheRead: Type.Optional(Type.Number()),
			cacheWrite: Type.Optional(Type.Number()),
		}),
	),
	contextWindow: Type.Optional(Type.Number()),
	maxTokens: Type.Optional(Type.Number()),
	headers: Type.Optional(Type.Record(Type.String(), Type.String())),
	compat: Type.Optional(OpenAICompatSchema),
});
type ModelOverride = Static<typeof ModelOverrideSchema>;
// Schema for a single provider entry in models.json. Cross-field requirements
// (e.g. baseUrl/apiKey needed when "models" is present) are enforced in
// ModelRegistry.validateConfig, not here.
const ProviderConfigSchema = Type.Object({
	baseUrl: Type.Optional(Type.String({ minLength: 1 })),
	apiKey: Type.Optional(Type.String({ minLength: 1 })),
	api: Type.Optional(Type.String({ minLength: 1 })),
	headers: Type.Optional(Type.Record(Type.String(), Type.String())),
	authHeader: Type.Optional(Type.Boolean()),
	models: Type.Optional(Type.Array(ModelDefinitionSchema)),
	modelOverrides: Type.Optional(
		Type.Record(Type.String(), ModelOverrideSchema),
	),
});
// Top-level models.json schema: a map of provider name -> provider config.
const ModelsConfigSchema = Type.Object({
	providers: Type.Record(Type.String(), ProviderConfigSchema),
});
// Pre-register so loadCustomModels can look the validator up by name.
ajv.addSchema(ModelsConfigSchema, "ModelsConfig");
type ModelsConfig = Static<typeof ModelsConfigSchema>;
/** Provider override config (baseUrl, headers, apiKey) without custom models */
interface ProviderOverride {
	// Replacement endpoint for the provider's built-in models.
	baseUrl?: string;
	// Extra headers merged into each built-in model's headers (unresolved values).
	headers?: Record<string, string>;
	// Unresolved API key config (env ref / command / literal).
	apiKey?: string;
}
/** Result of loading custom models from models.json */
interface CustomModelsResult {
	models: Model<Api>[];
	/** Providers with baseUrl/headers/apiKey overrides for built-in models */
	overrides: Map<string, ProviderOverride>;
	/** Per-model overrides: provider -> modelId -> override */
	modelOverrides: Map<string, Map<string, ModelOverride>>;
	// Human-readable load/validation error; undefined on success.
	error: string | undefined;
}
/** Build an empty CustomModelsResult, optionally carrying a load error. */
function emptyCustomModelsResult(error?: string): CustomModelsResult {
	return {
		models: [],
		overrides: new Map(),
		modelOverrides: new Map(),
		error,
	};
}
/**
 * Merge an override's compat block into a model's existing compat block.
 * Top-level fields from the override win; the nested openRouterRouting and
 * vercelGatewayRouting objects are merged field-by-field instead of replaced.
 * Returns the base unchanged when there is no override.
 */
function mergeCompat(
	baseCompat: Model<Api>["compat"],
	overrideCompat: ModelOverride["compat"],
): Model<Api>["compat"] | undefined {
	if (!overrideCompat) return baseCompat;
	// The union members are structurally compatible for spreading; narrow to
	// the completions shape to access the routing fields.
	const base = baseCompat as OpenAICompletionsCompat | undefined;
	const override = overrideCompat as OpenAICompletionsCompat;
	const merged: OpenAICompletionsCompat = { ...base, ...override };
	if (base?.openRouterRouting || override.openRouterRouting) {
		merged.openRouterRouting = {
			...base?.openRouterRouting,
			...override.openRouterRouting,
		};
	}
	if (base?.vercelGatewayRouting || override.vercelGatewayRouting) {
		merged.vercelGatewayRouting = {
			...base?.vercelGatewayRouting,
			...override.vercelGatewayRouting,
		};
	}
	return merged as Model<Api>["compat"];
}
/**
 * Apply a per-model override onto a model definition.
 * Nested objects (cost, compat) are merged field-by-field rather than
 * replaced wholesale; headers are resolved and merged over the model's own.
 * The input model is not mutated.
 */
function applyModelOverride(
	model: Model<Api>,
	override: ModelOverride,
): Model<Api> {
	// Scalar fields: take the override only when explicitly present.
	const result: Model<Api> = {
		...model,
		...(override.name !== undefined && { name: override.name }),
		...(override.reasoning !== undefined && { reasoning: override.reasoning }),
		...(override.input !== undefined && {
			input: override.input as ("text" | "image")[],
		}),
		...(override.contextWindow !== undefined && {
			contextWindow: override.contextWindow,
		}),
		...(override.maxTokens !== undefined && { maxTokens: override.maxTokens }),
	};
	// Partial cost override: unspecified components keep the model's values.
	if (override.cost) {
		result.cost = {
			input: override.cost.input ?? model.cost.input,
			output: override.cost.output ?? model.cost.output,
			cacheRead: override.cost.cacheRead ?? model.cost.cacheRead,
			cacheWrite: override.cost.cacheWrite ?? model.cost.cacheWrite,
		};
	}
	// Header values may reference env vars / commands; resolve before merging.
	if (override.headers) {
		const resolved = resolveHeaders(override.headers);
		if (resolved) {
			result.headers = { ...model.headers, ...resolved };
		}
	}
	result.compat = mergeCompat(model.compat, override.compat);
	return result;
}
/** Clear the config value command cache. Exported for testing. */
// Alias of clearConfigValueCache from resolve-config-value.js.
export const clearApiKeyCache = clearConfigValueCache;
/**
 * Model registry - loads and manages models, resolves API keys via AuthStorage.
 */
export class ModelRegistry {
	/** Built-in models (with overrides applied) plus custom/registered ones. */
	private models: Model<Api>[] = [];
	/** Provider name -> unresolved apiKey config from models.json / registerProvider. */
	private customProviderApiKeys: Map<string, string> = new Map();
	/** Providers registered via registerProvider(); reapplied on refresh(). */
	private registeredProviders: Map<string, ProviderConfigInput> = new Map();
	/** Error from the last models.json load attempt, if any. */
	private loadError: string | undefined = undefined;
	constructor(
		readonly authStorage: AuthStorage,
		private modelsJsonPath: string | undefined = join(
			getAgentDir(),
			"models.json",
		),
	) {
		// Set up fallback resolver for custom provider API keys
		this.authStorage.setFallbackResolver((provider) => {
			const keyConfig = this.customProviderApiKeys.get(provider);
			if (keyConfig) {
				return resolveConfigValue(keyConfig);
			}
			return undefined;
		});
		// Load models
		this.loadModels();
	}
	/**
	 * Reload models from disk (built-in + custom from models.json).
	 */
	refresh(): void {
		this.customProviderApiKeys.clear();
		this.loadError = undefined;
		// Ensure dynamic API/OAuth registrations are rebuilt from current provider state.
		resetApiProviders();
		resetOAuthProviders();
		this.loadModels();
		// Reapply extension-registered providers on top of the reloaded models.
		for (const [providerName, config] of this.registeredProviders.entries()) {
			this.applyProviderConfig(providerName, config);
		}
	}
	/**
	 * Get any error from loading models.json (undefined if no error).
	 */
	getError(): string | undefined {
		return this.loadError;
	}
	/** Rebuild this.models from built-in models, models.json, and OAuth hooks. */
	private loadModels(): void {
		// Load custom models and overrides from models.json
		const {
			models: customModels,
			overrides,
			modelOverrides,
			error,
		} = this.modelsJsonPath
			? this.loadCustomModels(this.modelsJsonPath)
			: emptyCustomModelsResult();
		if (error) {
			this.loadError = error;
			// Keep built-in models even if custom models failed to load
		}
		const builtInModels = this.loadBuiltInModels(overrides, modelOverrides);
		let combined = this.mergeCustomModels(builtInModels, customModels);
		// Let OAuth providers modify their models (e.g., update baseUrl)
		for (const oauthProvider of this.authStorage.getOAuthProviders()) {
			const cred = this.authStorage.get(oauthProvider.id);
			if (cred?.type === "oauth" && oauthProvider.modifyModels) {
				combined = oauthProvider.modifyModels(combined, cred);
			}
		}
		this.models = combined;
	}
	/** Load built-in models and apply provider/model overrides */
	private loadBuiltInModels(
		overrides: Map<string, ProviderOverride>,
		modelOverrides: Map<string, Map<string, ModelOverride>>,
	): Model<Api>[] {
		return getProviders().flatMap((provider) => {
			const models = getModels(provider as KnownProvider) as Model<Api>[];
			const providerOverride = overrides.get(provider);
			const perModelOverrides = modelOverrides.get(provider);
			return models.map((m) => {
				let model = m;
				// Apply provider-level baseUrl/headers override
				if (providerOverride) {
					const resolvedHeaders = resolveHeaders(providerOverride.headers);
					model = {
						...model,
						baseUrl: providerOverride.baseUrl ?? model.baseUrl,
						headers: resolvedHeaders
							? { ...model.headers, ...resolvedHeaders }
							: model.headers,
					};
				}
				// Apply per-model override
				const modelOverride = perModelOverrides?.get(m.id);
				if (modelOverride) {
					model = applyModelOverride(model, modelOverride);
				}
				return model;
			});
		});
	}
	/** Merge custom models into built-in list by provider+id (custom wins on conflicts). */
	private mergeCustomModels(
		builtInModels: Model<Api>[],
		customModels: Model<Api>[],
	): Model<Api>[] {
		const merged = [...builtInModels];
		for (const customModel of customModels) {
			const existingIndex = merged.findIndex(
				(m) => m.provider === customModel.provider && m.id === customModel.id,
			);
			if (existingIndex >= 0) {
				merged[existingIndex] = customModel;
			} else {
				merged.push(customModel);
			}
		}
		return merged;
	}
	/**
	 * Parse and validate models.json. Returns custom models plus provider-level
	 * and per-model override maps; on any failure returns an empty result
	 * carrying a descriptive error string (never throws).
	 */
	private loadCustomModels(modelsJsonPath: string): CustomModelsResult {
		if (!existsSync(modelsJsonPath)) {
			return emptyCustomModelsResult();
		}
		try {
			const content = readFileSync(modelsJsonPath, "utf-8");
			const config: ModelsConfig = JSON.parse(content);
			// Validate schema
			const validate = ajv.getSchema("ModelsConfig")!;
			if (!validate(config)) {
				const errors =
					validate.errors
						?.map((e: any) => ` - ${e.instancePath || "root"}: ${e.message}`)
						.join("\n") || "Unknown schema error";
				return emptyCustomModelsResult(
					`Invalid models.json schema:\n${errors}\n\nFile: ${modelsJsonPath}`,
				);
			}
			// Additional validation
			this.validateConfig(config);
			const overrides = new Map<string, ProviderOverride>();
			const modelOverrides = new Map<string, Map<string, ModelOverride>>();
			for (const [providerName, providerConfig] of Object.entries(
				config.providers,
			)) {
				// Apply provider-level baseUrl/headers/apiKey override to built-in models when configured.
				if (
					providerConfig.baseUrl ||
					providerConfig.headers ||
					providerConfig.apiKey
				) {
					overrides.set(providerName, {
						baseUrl: providerConfig.baseUrl,
						headers: providerConfig.headers,
						apiKey: providerConfig.apiKey,
					});
				}
				// Store API key for fallback resolver.
				if (providerConfig.apiKey) {
					this.customProviderApiKeys.set(providerName, providerConfig.apiKey);
				}
				if (providerConfig.modelOverrides) {
					modelOverrides.set(
						providerName,
						new Map(Object.entries(providerConfig.modelOverrides)),
					);
				}
			}
			return {
				models: this.parseModels(config),
				overrides,
				modelOverrides,
				error: undefined,
			};
		} catch (error) {
			if (error instanceof SyntaxError) {
				return emptyCustomModelsResult(
					`Failed to parse models.json: ${error.message}\n\nFile: ${modelsJsonPath}`,
				);
			}
			return emptyCustomModelsResult(
				`Failed to load models.json: ${error instanceof Error ? error.message : error}\n\nFile: ${modelsJsonPath}`,
			);
		}
	}
	/**
	 * Cross-field validation beyond the JSON schema. Throws on invalid configs;
	 * the caller (loadCustomModels) converts the throw into a load error.
	 */
	private validateConfig(config: ModelsConfig): void {
		for (const [providerName, providerConfig] of Object.entries(
			config.providers,
		)) {
			const hasProviderApi = !!providerConfig.api;
			const models = providerConfig.models ?? [];
			const hasModelOverrides =
				providerConfig.modelOverrides &&
				Object.keys(providerConfig.modelOverrides).length > 0;
			if (models.length === 0) {
				// Override-only config: needs baseUrl OR modelOverrides (or both)
				if (!providerConfig.baseUrl && !hasModelOverrides) {
					throw new Error(
						`Provider ${providerName}: must specify "baseUrl", "modelOverrides", or "models".`,
					);
				}
			} else {
				// Custom models are merged into provider models and require endpoint + auth.
				if (!providerConfig.baseUrl) {
					throw new Error(
						`Provider ${providerName}: "baseUrl" is required when defining custom models.`,
					);
				}
				if (!providerConfig.apiKey) {
					throw new Error(
						`Provider ${providerName}: "apiKey" is required when defining custom models.`,
					);
				}
			}
			for (const modelDef of models) {
				const hasModelApi = !!modelDef.api;
				if (!hasProviderApi && !hasModelApi) {
					throw new Error(
						`Provider ${providerName}, model ${modelDef.id}: no "api" specified. Set at provider or model level.`,
					);
				}
				if (!modelDef.id)
					throw new Error(`Provider ${providerName}: model missing "id"`);
				// Validate contextWindow/maxTokens only if provided (they have defaults)
				if (modelDef.contextWindow !== undefined && modelDef.contextWindow <= 0)
					throw new Error(
						`Provider ${providerName}, model ${modelDef.id}: invalid contextWindow`,
					);
				if (modelDef.maxTokens !== undefined && modelDef.maxTokens <= 0)
					throw new Error(
						`Provider ${providerName}, model ${modelDef.id}: invalid maxTokens`,
					);
			}
		}
	}
	/** Build Model objects from the custom model definitions in models.json. */
	private parseModels(config: ModelsConfig): Model<Api>[] {
		const models: Model<Api>[] = [];
		for (const [providerName, providerConfig] of Object.entries(
			config.providers,
		)) {
			const modelDefs = providerConfig.models ?? [];
			if (modelDefs.length === 0) continue; // Override-only, no custom models
			// Store API key config for fallback resolver
			if (providerConfig.apiKey) {
				this.customProviderApiKeys.set(providerName, providerConfig.apiKey);
			}
			for (const modelDef of modelDefs) {
				const api = modelDef.api || providerConfig.api;
				if (!api) continue;
				// Merge headers: provider headers are base, model headers override
				// Resolve env vars and shell commands in header values
				const providerHeaders = resolveHeaders(providerConfig.headers);
				const modelHeaders = resolveHeaders(modelDef.headers);
				let headers =
					providerHeaders || modelHeaders
						? { ...providerHeaders, ...modelHeaders }
						: undefined;
				// If authHeader is true, add Authorization header with resolved API key
				if (providerConfig.authHeader && providerConfig.apiKey) {
					const resolvedKey = resolveConfigValue(providerConfig.apiKey);
					if (resolvedKey) {
						headers = { ...headers, Authorization: `Bearer ${resolvedKey}` };
					}
				}
				// Provider baseUrl is required when custom models are defined.
				// Individual models can override it with modelDef.baseUrl.
				const defaultCost = {
					input: 0,
					output: 0,
					cacheRead: 0,
					cacheWrite: 0,
				};
				models.push({
					id: modelDef.id,
					name: modelDef.name ?? modelDef.id,
					api: api as Api,
					provider: providerName,
					baseUrl: modelDef.baseUrl ?? providerConfig.baseUrl!,
					reasoning: modelDef.reasoning ?? false,
					input: (modelDef.input ?? ["text"]) as ("text" | "image")[],
					cost: modelDef.cost ?? defaultCost,
					contextWindow: modelDef.contextWindow ?? 128000,
					maxTokens: modelDef.maxTokens ?? 16384,
					headers,
					compat: modelDef.compat,
				} as Model<Api>);
			}
		}
		return models;
	}
	/**
	 * Get all models (built-in + custom).
	 * If models.json had errors, returns only built-in models.
	 */
	getAll(): Model<Api>[] {
		return this.models;
	}
	/**
	 * Get only models that have auth configured.
	 * This is a fast check that doesn't refresh OAuth tokens.
	 */
	getAvailable(): Model<Api>[] {
		return this.models.filter((m) => this.authStorage.hasAuth(m.provider));
	}
	/**
	 * Find a model by provider and ID.
	 */
	find(provider: string, modelId: string): Model<Api> | undefined {
		return this.models.find((m) => m.provider === provider && m.id === modelId);
	}
	/**
	 * Get API key for a model.
	 */
	async getApiKey(model: Model<Api>): Promise<string | undefined> {
		return this.authStorage.getApiKey(model.provider);
	}
	/**
	 * Get API key for a provider.
	 */
	async getApiKeyForProvider(provider: string): Promise<string | undefined> {
		return this.authStorage.getApiKey(provider);
	}
	/**
	 * Check if a model is using OAuth credentials (subscription).
	 */
	isUsingOAuth(model: Model<Api>): boolean {
		const cred = this.authStorage.get(model.provider);
		return cred?.type === "oauth";
	}
	/**
	 * Register a provider dynamically (from extensions).
	 *
	 * If provider has models: replaces all existing models for this provider.
	 * If provider has only baseUrl/headers: overrides existing models' URLs.
	 * If provider has oauth: registers OAuth provider for /login support.
	 */
	registerProvider(providerName: string, config: ProviderConfigInput): void {
		this.registeredProviders.set(providerName, config);
		this.applyProviderConfig(providerName, config);
	}
	/**
	 * Unregister a previously registered provider.
	 *
	 * Removes the provider from the registry and reloads models from disk so that
	 * built-in models overridden by this provider are restored to their original state.
	 * Also resets dynamic OAuth and API stream registrations before reapplying
	 * remaining dynamic providers.
	 * Has no effect if the provider was never registered.
	 */
	unregisterProvider(providerName: string): void {
		if (!this.registeredProviders.has(providerName)) return;
		this.registeredProviders.delete(providerName);
		this.customProviderApiKeys.delete(providerName);
		this.refresh();
	}
	/**
	 * Apply one dynamically registered provider config: registers its OAuth
	 * provider and custom stream (if any), then either replaces this provider's
	 * models or overrides baseUrl/headers on the existing ones.
	 * Throws when required fields (api/baseUrl/apiKey-or-oauth) are missing.
	 */
	private applyProviderConfig(
		providerName: string,
		config: ProviderConfigInput,
	): void {
		// Register OAuth provider if provided
		if (config.oauth) {
			// Ensure the OAuth provider ID matches the provider name
			const oauthProvider: OAuthProviderInterface = {
				...config.oauth,
				id: providerName,
			};
			registerOAuthProvider(oauthProvider);
		}
		if (config.streamSimple) {
			if (!config.api) {
				throw new Error(
					`Provider ${providerName}: "api" is required when registering streamSimple.`,
				);
			}
			const streamSimple = config.streamSimple;
			registerApiProvider(
				{
					api: config.api,
					stream: (model, context, options) =>
						streamSimple(model, context, options as SimpleStreamOptions),
					streamSimple,
				},
				`provider:${providerName}`,
			);
		}
		// Store API key for auth resolution
		if (config.apiKey) {
			this.customProviderApiKeys.set(providerName, config.apiKey);
		}
		if (config.models && config.models.length > 0) {
			// Full replacement: remove existing models for this provider
			this.models = this.models.filter((m) => m.provider !== providerName);
			// Validate required fields
			if (!config.baseUrl) {
				throw new Error(
					`Provider ${providerName}: "baseUrl" is required when defining models.`,
				);
			}
			if (!config.apiKey && !config.oauth) {
				throw new Error(
					`Provider ${providerName}: "apiKey" or "oauth" is required when defining models.`,
				);
			}
			// Parse and add new models
			for (const modelDef of config.models) {
				const api = modelDef.api || config.api;
				if (!api) {
					throw new Error(
						`Provider ${providerName}, model ${modelDef.id}: no "api" specified.`,
					);
				}
				// Merge headers
				const providerHeaders = resolveHeaders(config.headers);
				const modelHeaders = resolveHeaders(modelDef.headers);
				let headers =
					providerHeaders || modelHeaders
						? { ...providerHeaders, ...modelHeaders }
						: undefined;
				// If authHeader is true, add Authorization header
				if (config.authHeader && config.apiKey) {
					const resolvedKey = resolveConfigValue(config.apiKey);
					if (resolvedKey) {
						headers = { ...headers, Authorization: `Bearer ${resolvedKey}` };
					}
				}
				this.models.push({
					id: modelDef.id,
					name: modelDef.name,
					api: api as Api,
					provider: providerName,
					baseUrl: config.baseUrl,
					reasoning: modelDef.reasoning,
					input: modelDef.input as ("text" | "image")[],
					cost: modelDef.cost,
					contextWindow: modelDef.contextWindow,
					maxTokens: modelDef.maxTokens,
					headers,
					compat: modelDef.compat,
				} as Model<Api>);
			}
			// Apply OAuth modifyModels if credentials exist (e.g., to update baseUrl)
			if (config.oauth?.modifyModels) {
				const cred = this.authStorage.get(providerName);
				if (cred?.type === "oauth") {
					this.models = config.oauth.modifyModels(this.models, cred);
				}
			}
		} else if (config.baseUrl) {
			// Override-only: update baseUrl/headers for existing models
			const resolvedHeaders = resolveHeaders(config.headers);
			this.models = this.models.map((m) => {
				if (m.provider !== providerName) return m;
				return {
					...m,
					baseUrl: config.baseUrl ?? m.baseUrl,
					headers: resolvedHeaders
						? { ...m.headers, ...resolvedHeaders }
						: m.headers,
				};
			});
		}
	}
}
/**
 * Input type for registerProvider API.
 */
export interface ProviderConfigInput {
	// Endpoint for this provider's models; required when `models` is given.
	baseUrl?: string;
	// Unresolved API key config (env ref / command / literal).
	apiKey?: string;
	// API dialect; required when registering `streamSimple`.
	api?: Api;
	// Optional custom streaming implementation registered for `api`.
	streamSimple?: (
		model: Model<Api>,
		context: Context,
		options?: SimpleStreamOptions,
	) => AssistantMessageEventStream;
	// Extra headers merged into each model's headers (values may be unresolved).
	headers?: Record<string, string>;
	// When true and apiKey is set, an `Authorization: Bearer <key>` header is added.
	authHeader?: boolean;
	/** OAuth provider for /login support */
	oauth?: Omit<OAuthProviderInterface, "id">;
	// Models to register; replaces any existing models for this provider.
	models?: Array<{
		id: string;
		name: string;
		api?: Api;
		baseUrl?: string;
		reasoning: boolean;
		input: ("text" | "image")[];
		cost: {
			input: number;
			output: number;
			cacheRead: number;
			cacheWrite: number;
		};
		contextWindow: number;
		maxTokens: number;
		headers?: Record<string, string>;
		compat?: Model<Api>["compat"];
	}>;
}

View file

@ -0,0 +1,707 @@
/**
* Model resolution, scoping, and initial selection
*/
import type { ThinkingLevel } from "@mariozechner/pi-agent-core";
import {
type Api,
type KnownProvider,
type Model,
modelsAreEqual,
} from "@mariozechner/pi-ai";
import chalk from "chalk";
import { minimatch } from "minimatch";
import { isValidThinkingLevel } from "../cli/args.js";
import { DEFAULT_THINKING_LEVEL } from "./defaults.js";
import type { ModelRegistry } from "./model-registry.js";
/** Default model IDs for each known provider */
// Also used by buildFallbackModel as the template when synthesizing an entry
// for an unknown model ID on a known provider.
export const defaultModelPerProvider: Record<KnownProvider, string> = {
	"amazon-bedrock": "us.anthropic.claude-opus-4-6-v1",
	anthropic: "claude-opus-4-6",
	openai: "gpt-5.4",
	"azure-openai-responses": "gpt-5.2",
	"openai-codex": "gpt-5.4",
	google: "gemini-2.5-pro",
	"google-gemini-cli": "gemini-2.5-pro",
	"google-antigravity": "gemini-3.1-pro-high",
	"google-vertex": "gemini-3-pro-preview",
	"github-copilot": "gpt-4o",
	openrouter: "openai/gpt-5.1-codex",
	"vercel-ai-gateway": "anthropic/claude-opus-4-6",
	xai: "grok-4-fast-non-reasoning",
	groq: "openai/gpt-oss-120b",
	cerebras: "zai-glm-4.6",
	zai: "glm-4.6",
	mistral: "devstral-medium-latest",
	minimax: "MiniMax-M2.1",
	"minimax-cn": "MiniMax-M2.1",
	huggingface: "moonshotai/Kimi-K2.5",
	opencode: "claude-opus-4-6",
	"opencode-go": "kimi-k2.5",
	"kimi-coding": "kimi-k2-thinking",
};
/** A model selected by a scope pattern, plus any thinking level the pattern carried. */
export interface ScopedModel {
	model: Model<Api>;
	/** Thinking level if explicitly specified in pattern (e.g., "model:high"), undefined otherwise */
	thinkingLevel?: ThinkingLevel;
}
/**
 * Heuristic: a model ID is an "alias" when it carries no release date stamp.
 * IDs ending in "-latest" are always aliases; IDs ending in -YYYYMMDD
 * (eight trailing digits) are pinned, dated releases.
 */
function isAlias(id: string): boolean {
	if (id.endsWith("-latest")) return true;
	return !/-\d{8}$/.test(id);
}
/**
 * Resolve a user-supplied pattern against the available models.
 * Matching order (all case-insensitive):
 *   1. exact "provider/modelId",
 *   2. exact model ID,
 *   3. substring match on ID or display name — preferring aliases over
 *      dated versions, and the lexicographically greatest ID within a group.
 * Returns undefined when nothing matches.
 */
function tryMatchModel(
	modelPattern: string,
	availableModels: Model<Api>[],
): Model<Api> | undefined {
	const lowered = modelPattern.toLowerCase();
	// "provider/modelId" form: provider is everything before the first slash.
	const slashIndex = modelPattern.indexOf("/");
	if (slashIndex !== -1) {
		const provider = modelPattern.substring(0, slashIndex).toLowerCase();
		const modelId = modelPattern.substring(slashIndex + 1).toLowerCase();
		const byProvider = availableModels.find(
			(m) =>
				m.provider.toLowerCase() === provider &&
				m.id.toLowerCase() === modelId,
		);
		if (byProvider) return byProvider;
		// No exact provider/model hit — fall through to generic matching.
	}
	const exact = availableModels.find((m) => m.id.toLowerCase() === lowered);
	if (exact) return exact;
	// Substring match against ID and display name.
	const candidates = availableModels.filter(
		(m) =>
			m.id.toLowerCase().includes(lowered) ||
			m.name?.toLowerCase().includes(lowered),
	);
	if (candidates.length === 0) return undefined;
	const aliases = candidates.filter((m) => isAlias(m.id));
	const pool =
		aliases.length > 0 ? aliases : candidates.filter((m) => !isAlias(m.id));
	// Lexicographically greatest ID wins (newest date / highest alias).
	return [...pool].sort((a, b) => b.id.localeCompare(a.id))[0];
}
/** Result of parseModelPattern: matched model, parsed thinking level, and any warning. */
export interface ParsedModelResult {
	// Undefined when the pattern matched no available model.
	model: Model<Api> | undefined;
	/** Thinking level if explicitly specified in pattern, undefined otherwise */
	thinkingLevel?: ThinkingLevel;
	// Human-readable note (e.g., invalid thinking-level suffix); undefined when clean.
	warning: string | undefined;
}
/**
 * Synthesize a model entry for an unknown model ID on a known provider by
 * cloning the provider's default model (or its first available model) and
 * swapping in the requested ID. Returns undefined when the provider has no
 * available models to clone from.
 */
function buildFallbackModel(
	provider: string,
	modelId: string,
	availableModels: Model<Api>[],
): Model<Api> | undefined {
	const candidates = availableModels.filter((m) => m.provider === provider);
	if (candidates.length === 0) return undefined;
	const preferredId = defaultModelPerProvider[provider as KnownProvider];
	const template = preferredId
		? (candidates.find((m) => m.id === preferredId) ?? candidates[0])
		: candidates[0];
	return { ...template, id: modelId, name: modelId };
}
/**
 * Parse a pattern to extract model and thinking level.
 * Handles models with colons in their IDs (e.g., OpenRouter's :exacto suffix).
 *
 * Algorithm:
 * 1. Try to match the full pattern as a model; a direct hit wins outright.
 * 2. Otherwise split on the LAST colon:
 *    - valid thinking-level suffix: recurse on the prefix and attach the level;
 *    - invalid suffix: in strict mode, fail; otherwise recurse on the prefix
 *      and surface a warning (the default thinking level applies).
 *
 * @internal Exported for testing
 */
export function parseModelPattern(
	pattern: string,
	availableModels: Model<Api>[],
	options?: { allowInvalidThinkingLevelFallback?: boolean },
): ParsedModelResult {
	const direct = tryMatchModel(pattern, availableModels);
	if (direct) {
		return { model: direct, thinkingLevel: undefined, warning: undefined };
	}
	const lastColonIndex = pattern.lastIndexOf(":");
	if (lastColonIndex === -1) {
		// No colons to strip; the pattern simply matches nothing.
		return { model: undefined, thinkingLevel: undefined, warning: undefined };
	}
	const prefix = pattern.substring(0, lastColonIndex);
	const suffix = pattern.substring(lastColonIndex + 1);
	if (!isValidThinkingLevel(suffix)) {
		if (!(options?.allowInvalidThinkingLevelFallback ?? true)) {
			// Strict mode (CLI --model): treat the suffix as part of the model ID
			// and fail rather than accidentally resolving a different model.
			return { model: undefined, thinkingLevel: undefined, warning: undefined };
		}
		// Scope mode: resolve the prefix and warn about the bad suffix.
		const inner = parseModelPattern(prefix, availableModels, options);
		if (!inner.model) return inner;
		return {
			model: inner.model,
			thinkingLevel: undefined,
			warning: `Invalid thinking level "${suffix}" in pattern "${pattern}". Using default instead.`,
		};
	}
	const inner = parseModelPattern(prefix, availableModels, options);
	if (!inner.model) return inner;
	return {
		model: inner.model,
		// An inner warning means an earlier suffix was bad; let the default
		// thinking level apply instead of this one.
		thinkingLevel: inner.warning ? undefined : suffix,
		warning: inner.warning,
	};
}
/**
* Resolve model patterns to actual Model objects with optional thinking levels
* Format: "pattern:level" where :level is optional
* For each pattern, finds all matching models and picks the best version:
* 1. Prefer alias (e.g., claude-sonnet-4-5) over dated versions (claude-sonnet-4-5-20250929)
* 2. If no alias, pick the latest dated version
*
* Supports models with colons in their IDs (e.g., OpenRouter's model:exacto).
* The algorithm tries to match the full pattern first, then progressively
* strips colon-suffixes to find a match.
*/
export async function resolveModelScope(
patterns: string[],
modelRegistry: ModelRegistry,
): Promise<ScopedModel[]> {
const availableModels = await modelRegistry.getAvailable();
const scopedModels: ScopedModel[] = [];
for (const pattern of patterns) {
// Check if pattern contains glob characters
if (
pattern.includes("*") ||
pattern.includes("?") ||
pattern.includes("[")
) {
// Extract optional thinking level suffix (e.g., "provider/*:high")
const colonIdx = pattern.lastIndexOf(":");
let globPattern = pattern;
let thinkingLevel: ThinkingLevel | undefined;
if (colonIdx !== -1) {
const suffix = pattern.substring(colonIdx + 1);
if (isValidThinkingLevel(suffix)) {
thinkingLevel = suffix;
globPattern = pattern.substring(0, colonIdx);
}
}
// Match against "provider/modelId" format OR just model ID
// This allows "*sonnet*" to match without requiring "anthropic/*sonnet*"
const matchingModels = availableModels.filter((m) => {
const fullId = `${m.provider}/${m.id}`;
return (
minimatch(fullId, globPattern, { nocase: true }) ||
minimatch(m.id, globPattern, { nocase: true })
);
});
if (matchingModels.length === 0) {
console.warn(
chalk.yellow(`Warning: No models match pattern "${pattern}"`),
);
continue;
}
for (const model of matchingModels) {
if (!scopedModels.find((sm) => modelsAreEqual(sm.model, model))) {
scopedModels.push({ model, thinkingLevel });
}
}
continue;
}
const { model, thinkingLevel, warning } = parseModelPattern(
pattern,
availableModels,
);
if (warning) {
console.warn(chalk.yellow(`Warning: ${warning}`));
}
if (!model) {
console.warn(
chalk.yellow(`Warning: No models match pattern "${pattern}"`),
);
continue;
}
// Avoid duplicates
if (!scopedModels.find((sm) => modelsAreEqual(sm.model, model))) {
scopedModels.push({ model, thinkingLevel });
}
}
return scopedModels;
}
export interface ResolveCliModelResult {
	/** Resolved model, or undefined when resolution failed (see `error`). */
	model: Model<Api> | undefined;
	/** Thinking level parsed from a "<pattern>:<level>" suffix, if present. */
	thinkingLevel?: ThinkingLevel;
	/** Non-fatal note (e.g. custom-model-id fallback); safe to display and continue. */
	warning: string | undefined;
	/**
	 * Error message suitable for CLI display.
	 * When set, model will be undefined.
	 */
	error: string | undefined;
}
/**
 * Resolve a single model from CLI flags.
 *
 * Supports:
 * - --provider <provider> --model <pattern>
 * - --model <provider>/<pattern>
 * - Fuzzy matching (same rules as model scoping: exact id, then partial id/name)
 *
 * Resolution order:
 * 1. Validate --provider (if given) against known providers.
 * 2. Without --provider, infer one from a "provider/model" prefix when possible.
 * 3. Without an inferred provider, try exact id / "provider/id" matches.
 * 4. Fuzzy-match the pattern against candidate models (strict thinking-level parsing).
 * 5. If a provider was inferred but nothing matched, retry the full input as a raw model id.
 * 6. With a known provider, fall back to a custom model id (buildFallbackModel).
 *
 * Note: This does not apply the thinking level by itself, but it may *parse* and
 * return a thinking level from "<pattern>:<thinking>" so the caller can apply it.
 */
export function resolveCliModel(options: {
	cliProvider?: string;
	cliModel?: string;
	modelRegistry: ModelRegistry;
}): ResolveCliModelResult {
	const { cliProvider, cliModel, modelRegistry } = options;
	// No --model flag: nothing to resolve, and not an error.
	if (!cliModel) {
		return { model: undefined, warning: undefined, error: undefined };
	}
	// Important: use *all* models here, not just models with pre-configured auth.
	// This allows "--api-key" to be used for first-time setup.
	const availableModels = modelRegistry.getAll();
	if (availableModels.length === 0) {
		return {
			model: undefined,
			warning: undefined,
			error:
				"No models available. Check your installation or add models to models.json.",
		};
	}
	// Build canonical provider lookup (case-insensitive)
	const providerMap = new Map<string, string>();
	for (const m of availableModels) {
		providerMap.set(m.provider.toLowerCase(), m.provider);
	}
	let provider = cliProvider
		? providerMap.get(cliProvider.toLowerCase())
		: undefined;
	// An explicit --provider that matches nothing is a hard error.
	if (cliProvider && !provider) {
		return {
			model: undefined,
			warning: undefined,
			error: `Unknown provider "${cliProvider}". Use --list-models to see available providers/models.`,
		};
	}
	// If no explicit --provider, try to interpret "provider/model" format first.
	// When the prefix before the first slash matches a known provider, prefer that
	// interpretation over matching models whose IDs literally contain slashes
	// (e.g. "zai/glm-5" should resolve to provider=zai, model=glm-5, not to a
	// vercel-ai-gateway model with id "zai/glm-5").
	let pattern = cliModel;
	let inferredProvider = false;
	if (!provider) {
		const slashIndex = cliModel.indexOf("/");
		if (slashIndex !== -1) {
			const maybeProvider = cliModel.substring(0, slashIndex);
			const canonical = providerMap.get(maybeProvider.toLowerCase());
			if (canonical) {
				provider = canonical;
				pattern = cliModel.substring(slashIndex + 1);
				inferredProvider = true;
			}
		}
	}
	// If no provider was inferred from the slash, try exact matches without provider inference.
	// This handles models whose IDs naturally contain slashes (e.g. OpenRouter-style IDs).
	if (!provider) {
		const lower = cliModel.toLowerCase();
		const exact = availableModels.find(
			(m) =>
				m.id.toLowerCase() === lower ||
				`${m.provider}/${m.id}`.toLowerCase() === lower,
		);
		if (exact) {
			return {
				model: exact,
				warning: undefined,
				thinkingLevel: undefined,
				error: undefined,
			};
		}
	}
	if (cliProvider && provider) {
		// If both were provided, tolerate --model <provider>/<pattern> by stripping the provider prefix
		const prefix = `${provider}/`;
		if (cliModel.toLowerCase().startsWith(prefix.toLowerCase())) {
			pattern = cliModel.substring(prefix.length);
		}
	}
	// Fuzzy-match against only the chosen provider's models (or all models
	// when no provider was determined). Strict mode: an unrecognized
	// ":<suffix>" must not silently resolve to a different model.
	const candidates = provider
		? availableModels.filter((m) => m.provider === provider)
		: availableModels;
	const { model, thinkingLevel, warning } = parseModelPattern(
		pattern,
		candidates,
		{
			allowInvalidThinkingLevelFallback: false,
		},
	);
	if (model) {
		return { model, thinkingLevel, warning, error: undefined };
	}
	// If we inferred a provider from the slash but found no match within that provider,
	// fall back to matching the full input as a raw model id across all models.
	// This handles OpenRouter-style IDs like "openai/gpt-4o:extended" where "openai"
	// looks like a provider but the full string is actually a model id on openrouter.
	if (inferredProvider) {
		const lower = cliModel.toLowerCase();
		const exact = availableModels.find(
			(m) =>
				m.id.toLowerCase() === lower ||
				`${m.provider}/${m.id}`.toLowerCase() === lower,
		);
		if (exact) {
			return {
				model: exact,
				warning: undefined,
				thinkingLevel: undefined,
				error: undefined,
			};
		}
		// Also try parseModelPattern on the full input against all models
		const fallback = parseModelPattern(cliModel, availableModels, {
			allowInvalidThinkingLevelFallback: false,
		});
		if (fallback.model) {
			return {
				model: fallback.model,
				thinkingLevel: fallback.thinkingLevel,
				warning: fallback.warning,
				error: undefined,
			};
		}
	}
	// Last resort for a known provider: accept the pattern as a custom model id
	// (first-time setup for models not yet in models.json), with a warning.
	if (provider) {
		const fallbackModel = buildFallbackModel(
			provider,
			pattern,
			availableModels,
		);
		if (fallbackModel) {
			const fallbackWarning = warning
				? `${warning} Model "${pattern}" not found for provider "${provider}". Using custom model id.`
				: `Model "${pattern}" not found for provider "${provider}". Using custom model id.`;
			return {
				model: fallbackModel,
				thinkingLevel: undefined,
				warning: fallbackWarning,
				error: undefined,
			};
		}
	}
	const display = provider ? `${provider}/${pattern}` : cliModel;
	return {
		model: undefined,
		thinkingLevel: undefined,
		warning,
		error: `Model "${display}" not found. Use --list-models to see available models.`,
	};
}
export interface InitialModelResult {
	/** Model to start the session with, or undefined when none could be resolved. */
	model: Model<Api> | undefined;
	/** Thinking level to apply; always defined (defaults to DEFAULT_THINKING_LEVEL). */
	thinkingLevel: ThinkingLevel;
	/** Human-readable note when a desired model could not be used, else undefined. */
	fallbackMessage: string | undefined;
}
/**
 * Find the initial model to use based on priority:
 * 1. CLI args (provider + model)
 * 2. First model from scoped models (if not continuing/resuming)
 * 3. Restored from session (if continuing/resuming)
 * 4. Saved default from settings
 * 5. First available model with valid API key
 *
 * Exits the process with an error when explicit CLI args cannot be resolved.
 */
export async function findInitialModel(options: {
	cliProvider?: string;
	cliModel?: string;
	scopedModels: ScopedModel[];
	isContinuing: boolean;
	defaultProvider?: string;
	defaultModelId?: string;
	defaultThinkingLevel?: ThinkingLevel;
	modelRegistry: ModelRegistry;
}): Promise<InitialModelResult> {
	const {
		cliProvider,
		cliModel,
		scopedModels,
		isContinuing,
		defaultProvider,
		defaultModelId,
		defaultThinkingLevel,
		modelRegistry,
	} = options;
	let model: Model<Api> | undefined;
	let thinkingLevel: ThinkingLevel = DEFAULT_THINKING_LEVEL;
	// 1. CLI args take priority
	if (cliProvider && cliModel) {
		const resolved = resolveCliModel({
			cliProvider,
			cliModel,
			modelRegistry,
		});
		if (resolved.error) {
			console.error(chalk.red(resolved.error));
			process.exit(1);
		}
		if (resolved.model) {
			return {
				model: resolved.model,
				// Fix: honor a thinking level parsed from "--model <pattern>:<level>".
				// resolveCliModel documents that the caller is responsible for
				// applying it; previously it was silently dropped here.
				thinkingLevel: resolved.thinkingLevel ?? DEFAULT_THINKING_LEVEL,
				fallbackMessage: undefined,
			};
		}
	}
	// 2. Use first model from scoped models (skip if continuing/resuming)
	if (scopedModels.length > 0 && !isContinuing) {
		return {
			model: scopedModels[0].model,
			// Scope-level thinking level wins over the settings default.
			thinkingLevel:
				scopedModels[0].thinkingLevel ??
				defaultThinkingLevel ??
				DEFAULT_THINKING_LEVEL,
			fallbackMessage: undefined,
		};
	}
	// 3. Try saved default from settings
	if (defaultProvider && defaultModelId) {
		const found = modelRegistry.find(defaultProvider, defaultModelId);
		if (found) {
			model = found;
			if (defaultThinkingLevel) {
				thinkingLevel = defaultThinkingLevel;
			}
			return { model, thinkingLevel, fallbackMessage: undefined };
		}
	}
	// 4. Try first available model with valid API key
	const availableModels = await modelRegistry.getAvailable();
	if (availableModels.length > 0) {
		// Prefer a known provider's default model over an arbitrary first entry.
		for (const provider of Object.keys(
			defaultModelPerProvider,
		) as KnownProvider[]) {
			const defaultId = defaultModelPerProvider[provider];
			const match = availableModels.find(
				(m) => m.provider === provider && m.id === defaultId,
			);
			if (match) {
				return {
					model: match,
					thinkingLevel: DEFAULT_THINKING_LEVEL,
					fallbackMessage: undefined,
				};
			}
		}
		// If no default found, use first available
		return {
			model: availableModels[0],
			thinkingLevel: DEFAULT_THINKING_LEVEL,
			fallbackMessage: undefined,
		};
	}
	// 5. No model found
	return {
		model: undefined,
		thinkingLevel: DEFAULT_THINKING_LEVEL,
		fallbackMessage: undefined,
	};
}
/**
 * Restore model from session, with fallback to available models.
 *
 * Tries the saved provider/model first (requires a usable API key).
 * On failure, falls back to `currentModel` when set, otherwise to a
 * known-provider default or the first available model.
 */
export async function restoreModelFromSession(
	savedProvider: string,
	savedModelId: string,
	currentModel: Model<Api> | undefined,
	shouldPrintMessages: boolean,
	modelRegistry: ModelRegistry,
): Promise<{
	model: Model<Api> | undefined;
	fallbackMessage: string | undefined;
}> {
	const restored = modelRegistry.find(savedProvider, savedModelId);
	// A restored model is only usable when we can actually authenticate with it.
	const keyAvailable = restored
		? !!(await modelRegistry.getApiKey(restored))
		: false;

	if (restored && keyAvailable) {
		if (shouldPrintMessages) {
			console.log(
				chalk.dim(`Restored model: ${savedProvider}/${savedModelId}`),
			);
		}
		return { model: restored, fallbackMessage: undefined };
	}

	// Restore failed — explain why, then pick a fallback.
	const reason = restored ? "no API key available" : "model no longer exists";
	if (shouldPrintMessages) {
		console.error(
			chalk.yellow(
				`Warning: Could not restore model ${savedProvider}/${savedModelId} (${reason}).`,
			),
		);
	}

	const useFallback = (fallback: Model<Api>) => {
		if (shouldPrintMessages) {
			console.log(
				chalk.dim(`Falling back to: ${fallback.provider}/${fallback.id}`),
			);
		}
		return {
			model: fallback,
			fallbackMessage: `Could not restore model ${savedProvider}/${savedModelId} (${reason}). Using ${fallback.provider}/${fallback.id}.`,
		};
	};

	// Prefer whatever model is already active.
	if (currentModel) {
		return useFallback(currentModel);
	}

	const available = await modelRegistry.getAvailable();
	if (available.length === 0) {
		// Nothing to fall back to.
		return { model: undefined, fallbackMessage: undefined };
	}

	// Prefer a known provider's default model; otherwise take the first available.
	let fallback: Model<Api> | undefined;
	for (const provider of Object.keys(
		defaultModelPerProvider,
	) as KnownProvider[]) {
		const wanted = defaultModelPerProvider[provider];
		fallback = available.find(
			(m) => m.provider === provider && m.id === wanted,
		);
		if (fallback) {
			break;
		}
	}
	return useFallback(fallback ?? available[0]);
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,327 @@
import { existsSync, readdirSync, readFileSync, statSync } from "fs";
import { homedir } from "os";
import { basename, isAbsolute, join, resolve, sep } from "path";
import { CONFIG_DIR_NAME, getPromptsDir } from "../config.js";
import { parseFrontmatter } from "../utils/frontmatter.js";
/**
 * Represents a prompt template loaded from a markdown file
 */
export interface PromptTemplate {
	// Markdown file's basename without the .md extension
	name: string;
	// Frontmatter description (or truncated first body line) plus a source label
	description: string;
	// Template body with frontmatter stripped
	content: string;
	source: string; // "user", "project", or "path"
	filePath: string; // Absolute path to the template file
}
/**
 * Parse command arguments respecting quoted strings (bash-style)
 * Returns array of arguments
 *
 * Quotes (single or double) group characters, including spaces, into one
 * argument. An empty quoted string produces an empty argument (matching
 * bash: `a "" b` yields three arguments) — previously it was dropped.
 */
export function parseCommandArgs(argsString: string): string[] {
	const args: string[] = [];
	let current = "";
	let inQuote: string | null = null;
	// True when the current token was started by a quote, so that "" / ''
	// yields an empty argument instead of being silently discarded.
	let quotedToken = false;

	const flush = () => {
		if (current || quotedToken) {
			args.push(current);
			current = "";
			quotedToken = false;
		}
	};

	for (let i = 0; i < argsString.length; i++) {
		const char = argsString[i];
		if (inQuote) {
			if (char === inQuote) {
				inQuote = null;
			} else {
				current += char;
			}
		} else if (char === '"' || char === "'") {
			inQuote = char;
			quotedToken = true;
		} else if (char === " " || char === "\t") {
			flush();
		} else {
			current += char;
		}
	}
	flush();
	return args;
}
/**
 * Substitute argument placeholders in template content
 * Supports:
 * - $1, $2, ... for positional args
 * - $@ and $ARGUMENTS for all args
 * - ${@:N} for args from Nth onwards (bash-style slicing)
 * - ${@:N:L} for L args starting from Nth
 *
 * Note: Replacement happens on the template string only. Argument values
 * containing patterns like $1, $@, or $ARGUMENTS are NOT recursively substituted.
 * Argument values are always inserted literally; characters like "$&" or "$'"
 * are never interpreted as String.replace special replacement patterns.
 */
export function substituteArgs(content: string, args: string[]): string {
	let result = content;
	// Replace $1, $2, etc. with positional args FIRST (before wildcards)
	// This prevents wildcard replacement values containing $<digit> patterns from being re-substituted
	result = result.replace(/\$(\d+)/g, (_, num) => {
		const index = parseInt(num, 10) - 1;
		return args[index] ?? "";
	});
	// Replace ${@:start} or ${@:start:length} with sliced args (bash-style)
	// Process BEFORE simple $@ to avoid conflicts
	result = result.replace(
		/\$\{@:(\d+)(?::(\d+))?\}/g,
		(_, startStr, lengthStr) => {
			let start = parseInt(startStr, 10) - 1; // Convert to 0-indexed (user provides 1-indexed)
			// Treat 0 as 1 (bash convention: args start at 1)
			if (start < 0) start = 0;
			if (lengthStr) {
				const length = parseInt(lengthStr, 10);
				return args.slice(start, start + length).join(" ");
			}
			return args.slice(start).join(" ");
		},
	);
	// Pre-compute all args joined (optimization)
	const allArgs = args.join(" ");
	// Fix: use function replacers so that special replacement patterns such as
	// "$&" (matched substring) or "$'" (trailing portion) inside argument
	// values are inserted literally instead of being expanded by String.replace.
	// Replace $ARGUMENTS with all args joined (new syntax, aligns with Claude, Codex, OpenCode)
	result = result.replace(/\$ARGUMENTS/g, () => allArgs);
	// Replace $@ with all args joined (existing syntax)
	result = result.replace(/\$@/g, () => allArgs);
	return result;
}
/**
 * Load a single .md file as a prompt template.
 * Returns null when the file cannot be read or parsed.
 */
function loadTemplateFromFile(
	filePath: string,
	source: string,
	sourceLabel: string,
): PromptTemplate | null {
	try {
		const raw = readFileSync(filePath, "utf-8");
		const { frontmatter, body } =
			parseFrontmatter<Record<string, string>>(raw);

		// Description: frontmatter wins, otherwise the first non-empty body line
		// (truncated to 60 characters with an ellipsis).
		let description = frontmatter.description || "";
		if (!description) {
			const firstLine = body.split("\n").find((line) => line.trim());
			if (firstLine) {
				description = firstLine.slice(0, 60);
				if (firstLine.length > 60) {
					description += "...";
				}
			}
		}
		// The source label is always appended (or stands alone).
		description = description ? `${description} ${sourceLabel}` : sourceLabel;

		return {
			name: basename(filePath).replace(/\.md$/, ""),
			description,
			content: body,
			source,
			filePath,
		};
	} catch {
		return null;
	}
}
/**
 * Scan a directory for .md files (non-recursive) and load them as prompt templates.
 * Missing directories, unreadable directories, and broken symlinks yield
 * whatever was collected so far (possibly an empty list) — never a throw.
 */
function loadTemplatesFromDir(
	dir: string,
	source: string,
	sourceLabel: string,
): PromptTemplate[] {
	const found: PromptTemplate[] = [];
	if (!existsSync(dir)) {
		return found;
	}
	try {
		for (const entry of readdirSync(dir, { withFileTypes: true })) {
			if (!entry.name.endsWith(".md")) {
				continue;
			}
			const fullPath = join(dir, entry.name);
			// Symlinks count only if they resolve to a regular file.
			let isFile = entry.isFile();
			if (entry.isSymbolicLink()) {
				try {
					isFile = statSync(fullPath).isFile();
				} catch {
					continue; // broken symlink
				}
			}
			if (!isFile) {
				continue;
			}
			const template = loadTemplateFromFile(fullPath, source, sourceLabel);
			if (template) {
				found.push(template);
			}
		}
	} catch {
		// Directory read failed; return what was collected so far.
	}
	return found;
}
/** Options for loadPromptTemplates. All fields optional. */
export interface LoadPromptTemplatesOptions {
	/** Working directory for project-local templates. Default: process.cwd() */
	cwd?: string;
	/** Agent config directory for global templates. Default: from getPromptsDir() */
	agentDir?: string;
	/** Explicit prompt template paths (files or directories, "~" expansion supported) */
	promptPaths?: string[];
	/** Include default prompt directories. Default: true */
	includeDefaults?: boolean;
}
/** Trim whitespace and expand a leading "~" to the user's home directory. */
function normalizePath(input: string): string {
	const path = input.trim();
	if (!path.startsWith("~")) {
		return path;
	}
	// "~" alone, "~/rest", or "~rest" all resolve relative to the home dir.
	const rest =
		path === "~" ? "" : path.startsWith("~/") ? path.slice(2) : path.slice(1);
	return rest ? join(homedir(), rest) : homedir();
}
function resolvePromptPath(p: string, cwd: string): string {
const normalized = normalizePath(p);
return isAbsolute(normalized) ? normalized : resolve(cwd, normalized);
}
/** Build a "(path:<stem>)" label from a file or directory path. */
function buildPathSourceLabel(p: string): string {
	const stem = basename(p).replace(/\.md$/, "");
	return `(path:${stem || "path"})`;
}
/**
 * Load all prompt templates from:
 * 1. Global: agentDir/prompts/
 * 2. Project: cwd/{CONFIG_DIR_NAME}/prompts/
 * 3. Explicit prompt paths
 *
 * Templates are appended in that order; duplicates are not removed.
 * Read failures are swallowed — a missing or unreadable location simply
 * contributes no templates.
 */
export function loadPromptTemplates(
	options: LoadPromptTemplatesOptions = {},
): PromptTemplate[] {
	const resolvedCwd = options.cwd ?? process.cwd();
	const resolvedAgentDir = options.agentDir ?? getPromptsDir();
	const promptPaths = options.promptPaths ?? [];
	const includeDefaults = options.includeDefaults ?? true;
	const templates: PromptTemplate[] = [];
	if (includeDefaults) {
		// 1. Load global templates from agentDir/prompts/
		// Note: if agentDir is provided, it should be the agent dir, not the prompts dir
		const globalPromptsDir = options.agentDir
			? join(options.agentDir, "prompts")
			: resolvedAgentDir;
		templates.push(...loadTemplatesFromDir(globalPromptsDir, "user", "(user)"));
		// 2. Load project templates from cwd/{CONFIG_DIR_NAME}/prompts/
		const projectPromptsDir = resolve(resolvedCwd, CONFIG_DIR_NAME, "prompts");
		templates.push(
			...loadTemplatesFromDir(projectPromptsDir, "project", "(project)"),
		);
	}
	// Recompute the default dirs so explicit paths can be classified below.
	const userPromptsDir = options.agentDir
		? join(options.agentDir, "prompts")
		: resolvedAgentDir;
	const projectPromptsDir = resolve(resolvedCwd, CONFIG_DIR_NAME, "prompts");
	// True when target equals root or lies strictly inside it (path-segment safe).
	const isUnderPath = (target: string, root: string): boolean => {
		const normalizedRoot = resolve(root);
		if (target === normalizedRoot) {
			return true;
		}
		const prefix = normalizedRoot.endsWith(sep)
			? normalizedRoot
			: `${normalizedRoot}${sep}`;
		return target.startsWith(prefix);
	};
	// Classify an explicit path: when defaults are skipped, paths inside the
	// default user/project dirs keep their "user"/"project" identity; anything
	// else is labelled by its own basename.
	const getSourceInfo = (
		resolvedPath: string,
	): { source: string; label: string } => {
		if (!includeDefaults) {
			if (isUnderPath(resolvedPath, userPromptsDir)) {
				return { source: "user", label: "(user)" };
			}
			if (isUnderPath(resolvedPath, projectPromptsDir)) {
				return { source: "project", label: "(project)" };
			}
		}
		return { source: "path", label: buildPathSourceLabel(resolvedPath) };
	};
	// 3. Load explicit prompt paths
	for (const rawPath of promptPaths) {
		const resolvedPath = resolvePromptPath(rawPath, resolvedCwd);
		if (!existsSync(resolvedPath)) {
			continue;
		}
		try {
			const stats = statSync(resolvedPath);
			const { source, label } = getSourceInfo(resolvedPath);
			if (stats.isDirectory()) {
				templates.push(...loadTemplatesFromDir(resolvedPath, source, label));
			} else if (stats.isFile() && resolvedPath.endsWith(".md")) {
				const template = loadTemplateFromFile(resolvedPath, source, label);
				if (template) {
					templates.push(template);
				}
			}
		} catch {
			// Ignore read failures
		}
	}
	return templates;
}
/**
 * Expand a prompt template if it matches a template name.
 * Returns the expanded content or the original text if not a template.
 */
export function expandPromptTemplate(
	text: string,
	templates: PromptTemplate[],
): string {
	// Only "/name [args...]" invocations are candidates for expansion.
	if (!text.startsWith("/")) {
		return text;
	}
	const spaceIndex = text.indexOf(" ");
	const name = spaceIndex < 0 ? text.slice(1) : text.slice(1, spaceIndex);
	const match = templates.find((t) => t.name === name);
	if (!match) {
		return text;
	}
	const argsString = spaceIndex < 0 ? "" : text.slice(spaceIndex + 1);
	return substituteArgs(match.content, parseCommandArgs(argsString));
}

View file

@ -0,0 +1,66 @@
/**
* Resolve configuration values that may be shell commands, environment variables, or literals.
* Used by auth-storage.ts and model-registry.ts.
*/
import { execSync } from "child_process";
// Cache for shell command results (persists for process lifetime)
const commandResultCache = new Map<string, string | undefined>();

/**
 * Resolve a config value (API key, header value, etc.) to an actual value.
 * - If starts with "!", executes the rest as a shell command and uses stdout (cached)
 * - Otherwise checks environment variable first, then treats as literal (not cached)
 */
export function resolveConfigValue(config: string): string | undefined {
	if (config.startsWith("!")) {
		return executeCommand(config);
	}
	// Env var lookup first; unset or empty env falls back to the literal string.
	return process.env[config] || config;
}

/** Run the "!command" (sans "!"), caching trimmed stdout (or undefined on failure). */
function executeCommand(commandConfig: string): string | undefined {
	if (commandResultCache.has(commandConfig)) {
		return commandResultCache.get(commandConfig);
	}
	let result: string | undefined;
	try {
		const stdout = execSync(commandConfig.slice(1), {
			encoding: "utf-8",
			timeout: 10000,
			stdio: ["ignore", "pipe", "ignore"],
		});
		result = stdout.trim() || undefined;
	} catch {
		// Command failed or timed out — cache the failure too.
		result = undefined;
	}
	commandResultCache.set(commandConfig, result);
	return result;
}
/**
* Resolve all header values using the same resolution logic as API keys.
*/
export function resolveHeaders(
headers: Record<string, string> | undefined,
): Record<string, string> | undefined {
if (!headers) return undefined;
const resolved: Record<string, string> = {};
for (const [key, value] of Object.entries(headers)) {
const resolvedValue = resolveConfigValue(value);
if (resolvedValue) {
resolved[key] = resolvedValue;
}
}
return Object.keys(resolved).length > 0 ? resolved : undefined;
}
/**
 * Clear the cached results of "!command" config values.
 * Exported for testing so tests can force commands to re-execute.
 */
export function clearConfigValueCache(): void {
	commandResultCache.clear();
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,398 @@
import { join } from "node:path";
import {
Agent,
type AgentMessage,
type ThinkingLevel,
} from "@mariozechner/pi-agent-core";
import type { Message, Model } from "@mariozechner/pi-ai";
import { getAgentDir, getDocsPath } from "../config.js";
import { AgentSession } from "./agent-session.js";
import { AuthStorage } from "./auth-storage.js";
import { DEFAULT_THINKING_LEVEL } from "./defaults.js";
import type {
ExtensionRunner,
LoadExtensionsResult,
ToolDefinition,
} from "./extensions/index.js";
import { convertToLlm } from "./messages.js";
import { ModelRegistry } from "./model-registry.js";
import { findInitialModel } from "./model-resolver.js";
import type { ResourceLoader } from "./resource-loader.js";
import { DefaultResourceLoader } from "./resource-loader.js";
import { SessionManager } from "./session-manager.js";
import { SettingsManager } from "./settings-manager.js";
import { time } from "./timings.js";
import {
allTools,
bashTool,
codingTools,
createBashTool,
createCodingTools,
createEditTool,
createFindTool,
createGrepTool,
createLsTool,
createReadOnlyTools,
createReadTool,
createWriteTool,
editTool,
findTool,
grepTool,
lsTool,
readOnlyTools,
readTool,
type Tool,
type ToolName,
writeTool,
} from "./tools/index.js";
/** Options for createAgentSession. Every field is optional with a sensible default. */
export interface CreateAgentSessionOptions {
	/** Working directory for project-local discovery. Default: process.cwd() */
	cwd?: string;
	/** Global config directory. Default: ~/.pi/agent */
	agentDir?: string;
	/** Auth storage for credentials. Default: AuthStorage.create(agentDir/auth.json) */
	authStorage?: AuthStorage;
	/** Model registry. Default: new ModelRegistry(authStorage, agentDir/models.json) */
	modelRegistry?: ModelRegistry;
	/** Model to use. Default: from settings, else first available */
	model?: Model<any>;
	/** Thinking level. Default: from settings, else 'medium' (clamped to model capabilities) */
	thinkingLevel?: ThinkingLevel;
	/** Models available for cycling (Ctrl+P in interactive mode) */
	scopedModels?: Array<{ model: Model<any>; thinkingLevel?: ThinkingLevel }>;
	/** Built-in tools to use. Default: codingTools [read, bash, edit, write] */
	tools?: Tool[];
	/** Custom tools to register (in addition to built-in tools). */
	customTools?: ToolDefinition[];
	/** Resource loader. When omitted, DefaultResourceLoader is used (and reloaded). */
	resourceLoader?: ResourceLoader;
	/** Session manager. Default: SessionManager.create(cwd) */
	sessionManager?: SessionManager;
	/** Settings manager. Default: SettingsManager.create(cwd, agentDir) */
	settingsManager?: SettingsManager;
}
/** Result from createAgentSession */
export interface CreateAgentSessionResult {
	/** The created session */
	session: AgentSession;
	/** Extensions result (for UI context setup in interactive mode) */
	extensionsResult: LoadExtensionsResult;
	/** Warning if session was restored with a different model than saved (or none at all) */
	modelFallbackMessage?: string;
}
// Re-exports
export type {
ExtensionAPI,
ExtensionCommandContext,
ExtensionContext,
ExtensionFactory,
SlashCommandInfo,
SlashCommandLocation,
SlashCommandSource,
ToolDefinition,
} from "./extensions/index.js";
export type { PromptTemplate } from "./prompt-templates.js";
export type { Skill } from "./skills.js";
export type { Tool } from "./tools/index.js";
export {
// Pre-built tools (use process.cwd())
readTool,
bashTool,
editTool,
writeTool,
grepTool,
findTool,
lsTool,
codingTools,
readOnlyTools,
allTools as allBuiltInTools,
// Tool factories (for custom cwd)
createCodingTools,
createReadOnlyTools,
createReadTool,
createBashTool,
createEditTool,
createWriteTool,
createGrepTool,
createFindTool,
createLsTool,
};
// Helper Functions
/** Default global config directory used when options.agentDir is omitted. */
function getDefaultAgentDir(): string {
	return getAgentDir();
}
/**
* Create an AgentSession with the specified options.
*
* @example
* ```typescript
* // Minimal - uses defaults
* const { session } = await createAgentSession();
*
* // With explicit model
* import { getModel } from '@mariozechner/pi-ai';
* const { session } = await createAgentSession({
* model: getModel('anthropic', 'claude-opus-4-5'),
* thinkingLevel: 'high',
* });
*
* // Continue previous session
* const { session, modelFallbackMessage } = await createAgentSession({
* continueSession: true,
* });
*
* // Full control
* const loader = new DefaultResourceLoader({
* cwd: process.cwd(),
* agentDir: getAgentDir(),
* settingsManager: SettingsManager.create(),
* });
* await loader.reload();
* const { session } = await createAgentSession({
* model: myModel,
* tools: [readTool, bashTool],
* resourceLoader: loader,
* sessionManager: SessionManager.inMemory(),
* });
* ```
*/
export async function createAgentSession(
options: CreateAgentSessionOptions = {},
): Promise<CreateAgentSessionResult> {
const cwd = options.cwd ?? process.cwd();
const agentDir = options.agentDir ?? getDefaultAgentDir();
let resourceLoader = options.resourceLoader;
// Use provided or create AuthStorage and ModelRegistry
const authPath = options.agentDir ? join(agentDir, "auth.json") : undefined;
const modelsPath = options.agentDir
? join(agentDir, "models.json")
: undefined;
const authStorage = options.authStorage ?? AuthStorage.create(authPath);
const modelRegistry =
options.modelRegistry ?? new ModelRegistry(authStorage, modelsPath);
const settingsManager =
options.settingsManager ?? SettingsManager.create(cwd, agentDir);
const sessionManager = options.sessionManager ?? SessionManager.create(cwd);
if (!resourceLoader) {
resourceLoader = new DefaultResourceLoader({
cwd,
agentDir,
settingsManager,
});
await resourceLoader.reload();
time("resourceLoader.reload");
}
// Check if session has existing data to restore
const existingSession = sessionManager.buildSessionContext();
const hasExistingSession = existingSession.messages.length > 0;
const hasThinkingEntry = sessionManager
.getBranch()
.some((entry) => entry.type === "thinking_level_change");
let model = options.model;
let modelFallbackMessage: string | undefined;
// If session has data, try to restore model from it
if (!model && hasExistingSession && existingSession.model) {
const restoredModel = modelRegistry.find(
existingSession.model.provider,
existingSession.model.modelId,
);
if (restoredModel && (await modelRegistry.getApiKey(restoredModel))) {
model = restoredModel;
}
if (!model) {
modelFallbackMessage = `Could not restore model ${existingSession.model.provider}/${existingSession.model.modelId}`;
}
}
// If still no model, use findInitialModel (checks settings default, then provider defaults)
if (!model) {
const result = await findInitialModel({
scopedModels: [],
isContinuing: hasExistingSession,
defaultProvider: settingsManager.getDefaultProvider(),
defaultModelId: settingsManager.getDefaultModel(),
defaultThinkingLevel: settingsManager.getDefaultThinkingLevel(),
modelRegistry,
});
model = result.model;
if (!model) {
modelFallbackMessage = `No models available. Use /login or set an API key environment variable. See ${join(getDocsPath(), "providers.md")}. Then use /model to select a model.`;
} else if (modelFallbackMessage) {
modelFallbackMessage += `. Using ${model.provider}/${model.id}`;
}
}
let thinkingLevel = options.thinkingLevel;
// If session has data, restore thinking level from it
if (thinkingLevel === undefined && hasExistingSession) {
thinkingLevel = hasThinkingEntry
? (existingSession.thinkingLevel as ThinkingLevel)
: (settingsManager.getDefaultThinkingLevel() ?? DEFAULT_THINKING_LEVEL);
}
// Fall back to settings default
if (thinkingLevel === undefined) {
thinkingLevel =
settingsManager.getDefaultThinkingLevel() ?? DEFAULT_THINKING_LEVEL;
}
// Clamp to model capabilities
if (!model || !model.reasoning) {
thinkingLevel = "off";
}
const defaultActiveToolNames: ToolName[] = ["read", "bash", "edit", "write"];
const initialActiveToolNames: ToolName[] = options.tools
? options.tools
.map((t) => t.name)
.filter((n): n is ToolName => n in allTools)
: defaultActiveToolNames;
let agent: Agent;
// Create convertToLlm wrapper that filters images if blockImages is enabled (defense-in-depth)
const convertToLlmWithBlockImages = (messages: AgentMessage[]): Message[] => {
const converted = convertToLlm(messages);
// Check setting dynamically so mid-session changes take effect
if (!settingsManager.getBlockImages()) {
return converted;
}
// Filter out ImageContent from all messages, replacing with text placeholder
return converted.map((msg) => {
if (msg.role === "user" || msg.role === "toolResult") {
const content = msg.content;
if (Array.isArray(content)) {
const hasImages = content.some((c) => c.type === "image");
if (hasImages) {
const filteredContent = content
.map((c) =>
c.type === "image"
? {
type: "text" as const,
text: "Image reading is disabled.",
}
: c,
)
.filter(
(c, i, arr) =>
// Dedupe consecutive "Image reading is disabled." texts
!(
c.type === "text" &&
c.text === "Image reading is disabled." &&
i > 0 &&
arr[i - 1].type === "text" &&
(arr[i - 1] as { type: "text"; text: string }).text ===
"Image reading is disabled."
),
);
return { ...msg, content: filteredContent };
}
}
}
return msg;
});
};
const extensionRunnerRef: { current?: ExtensionRunner } = {};
agent = new Agent({
initialState: {
systemPrompt: "",
model,
thinkingLevel,
tools: [],
},
convertToLlm: convertToLlmWithBlockImages,
sessionId: sessionManager.getSessionId(),
transformContext: async (messages) => {
const runner = extensionRunnerRef.current;
if (!runner) return messages;
return runner.emitContext(messages);
},
steeringMode: settingsManager.getSteeringMode(),
followUpMode: settingsManager.getFollowUpMode(),
transport: settingsManager.getTransport(),
thinkingBudgets: settingsManager.getThinkingBudgets(),
maxRetryDelayMs: settingsManager.getRetrySettings().maxDelayMs,
getApiKey: async (provider) => {
// Use the provider argument from the in-flight request;
// agent.state.model may already be switched mid-turn.
const resolvedProvider = provider || agent.state.model?.provider;
if (!resolvedProvider) {
throw new Error("No model selected");
}
const key = await modelRegistry.getApiKeyForProvider(resolvedProvider);
if (!key) {
const model = agent.state.model;
const isOAuth = model && modelRegistry.isUsingOAuth(model);
if (isOAuth) {
throw new Error(
`Authentication failed for "${resolvedProvider}". ` +
`Credentials may have expired or network is unavailable. ` +
`Run '/login ${resolvedProvider}' to re-authenticate.`,
);
}
throw new Error(
`No API key found for "${resolvedProvider}". ` +
`Set an API key environment variable or run '/login ${resolvedProvider}'.`,
);
}
return key;
},
});
// Restore messages if session has existing data
if (hasExistingSession) {
agent.replaceMessages(existingSession.messages);
if (!hasThinkingEntry) {
sessionManager.appendThinkingLevelChange(thinkingLevel);
}
} else {
// Save initial model and thinking level for new sessions so they can be restored on resume
if (model) {
sessionManager.appendModelChange(model.provider, model.id);
}
sessionManager.appendThinkingLevelChange(thinkingLevel);
}
const session = new AgentSession({
agent,
sessionManager,
settingsManager,
cwd,
scopedModels: options.scopedModels,
resourceLoader,
customTools: options.customTools,
modelRegistry,
initialActiveToolNames,
extensionRunnerRef,
});
const extensionsResult = resourceLoader.getExtensions();
return {
session,
extensionsResult,
modelFallbackMessage,
};
}

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,518 @@
import {
existsSync,
readdirSync,
readFileSync,
realpathSync,
statSync,
} from "fs";
import ignore from "ignore";
import { homedir } from "os";
import {
basename,
dirname,
isAbsolute,
join,
relative,
resolve,
sep,
} from "path";
import { CONFIG_DIR_NAME, getAgentDir } from "../config.js";
import { parseFrontmatter } from "../utils/frontmatter.js";
import type { ResourceDiagnostic } from "./diagnostics.js";
/** Max name length per spec */
const MAX_NAME_LENGTH = 64;
/** Max description length per spec */
const MAX_DESCRIPTION_LENGTH = 1024;
const IGNORE_FILE_NAMES = [".gitignore", ".ignore", ".fdignore"];
type IgnoreMatcher = ReturnType<typeof ignore>;
/** Convert an OS-native path into forward-slash (POSIX-style) form. */
function toPosixPath(p: string): string {
  const segments = p.split(sep);
  return segments.join("/");
}
/**
 * Re-anchor a single ignore-file line under `prefix` (a POSIX dir prefix
 * ending in "/", or "" at the scan root).
 *
 * Returns null for blank lines and comments. Negation ("!") is preserved by
 * moving it in front of the prefixed pattern. Escaped literals ("\!", "\#")
 * are kept literal.
 */
function prefixIgnorePattern(line: string, prefix: string): string | null {
  const trimmed = line.trim();
  if (!trimmed) return null;
  // Comment lines are dropped; "\#" is an escaped literal '#' and is kept.
  if (trimmed.startsWith("#") && !trimmed.startsWith("\\#")) return null;
  let pattern = line;
  let negated = false;
  if (pattern.startsWith("!")) {
    negated = true;
    pattern = pattern.slice(1);
  } else if (pattern.startsWith("\\!")) {
    // Escaped literal '!': strip the backslash here; re-escape below if the
    // '!' would land at the start of the final pattern.
    pattern = pattern.slice(1);
  }
  if (pattern.startsWith("/")) {
    pattern = pattern.slice(1);
  }
  let prefixed = prefix ? `${prefix}${pattern}` : pattern;
  if (negated) {
    return `!${prefixed}`;
  }
  // Bug fix: with an empty prefix, a stripped "\!name" would start with '!'
  // and be misread as a negation by the ignore matcher. Restore the escape so
  // the pattern stays a literal match.
  if (prefixed.startsWith("!")) {
    prefixed = `\\${prefixed}`;
  }
  return prefixed;
}
/**
 * Read each supported ignore file (.gitignore, .ignore, .fdignore) located
 * directly in `dir` and add its patterns to `ig`, re-anchored relative to
 * `rootDir` so that rules from nested directories apply at the right depth.
 */
function addIgnoreRules(ig: IgnoreMatcher, dir: string, rootDir: string): void {
  const relativeDir = relative(rootDir, dir);
  // POSIX-style prefix ("sub/dir/") used to anchor patterns found in subdirs;
  // empty when dir is the scan root.
  const prefix = relativeDir ? `${toPosixPath(relativeDir)}/` : "";
  for (const filename of IGNORE_FILE_NAMES) {
    const ignorePath = join(dir, filename);
    if (!existsSync(ignorePath)) continue;
    try {
      const content = readFileSync(ignorePath, "utf-8");
      const patterns = content
        .split(/\r?\n/)
        .map((line) => prefixIgnorePattern(line, prefix))
        // prefixIgnorePattern returns null for blanks and comments.
        .filter((line): line is string => Boolean(line));
      if (patterns.length > 0) {
        ig.add(patterns);
      }
    } catch {} // unreadable ignore file: best-effort, skip silently
  }
}
/** Frontmatter fields recognized in a skill's markdown file. */
export interface SkillFrontmatter {
  /** Skill name; validated against the parent directory name. */
  name?: string;
  /** Required summary; skills without one are rejected. */
  description?: string;
  /** When true, the skill is hidden from the model (explicit invocation only). */
  "disable-model-invocation"?: boolean;
  /** Any additional, unvalidated frontmatter keys. */
  [key: string]: unknown;
}
/** A discovered skill ready for prompt inclusion or explicit invocation. */
export interface Skill {
  /** Name from frontmatter, falling back to the parent directory name. */
  name: string;
  /** Description from frontmatter (always non-empty). */
  description: string;
  /** Path of the skill markdown file. */
  filePath: string;
  /** Directory containing the skill file. */
  baseDir: string;
  /** Where the skill came from (e.g. "user", "project", "path"). */
  source: string;
  /** True when frontmatter sets disable-model-invocation. */
  disableModelInvocation: boolean;
}
/** Result of a skill scan: loaded skills plus warnings/collisions. */
export interface LoadSkillsResult {
  skills: Skill[];
  diagnostics: ResourceDiagnostic[];
}
/**
 * Validate a skill name against the Agent Skills spec.
 *
 * @param name - Candidate skill name.
 * @param parentDirName - Name of the directory holding the skill file.
 * @returns Human-readable violation messages; empty when the name is valid.
 */
function validateName(name: string, parentDirName: string): string[] {
  const problems: string[] = [];
  const check = (failed: boolean, message: string): void => {
    if (failed) {
      problems.push(message);
    }
  };
  check(
    name !== parentDirName,
    `name "${name}" does not match parent directory "${parentDirName}"`,
  );
  check(
    name.length > MAX_NAME_LENGTH,
    `name exceeds ${MAX_NAME_LENGTH} characters (${name.length})`,
  );
  check(
    !/^[a-z0-9-]+$/.test(name),
    `name contains invalid characters (must be lowercase a-z, 0-9, hyphens only)`,
  );
  check(
    name.startsWith("-") || name.endsWith("-"),
    `name must not start or end with a hyphen`,
  );
  check(name.includes("--"), `name must not contain consecutive hyphens`);
  return problems;
}
/**
 * Validate a skill description against the Agent Skills spec.
 *
 * @returns Violation messages; at most one (missing or too long), empty when
 *   the description is acceptable.
 */
function validateDescription(description: string | undefined): string[] {
  if (!description || description.trim() === "") {
    return ["description is required"];
  }
  if (description.length > MAX_DESCRIPTION_LENGTH) {
    return [
      `description exceeds ${MAX_DESCRIPTION_LENGTH} characters (${description.length})`,
    ];
  }
  return [];
}
/** Options for {@link loadSkillsFromDir}. */
export interface LoadSkillsFromDirOptions {
  /** Directory to scan for skills */
  dir: string;
  /** Source identifier for these skills */
  source: string;
}
/**
 * Load skills from a single directory.
 *
 * Discovery rules:
 * - direct .md children in the root
 * - recursive SKILL.md under subdirectories
 */
export function loadSkillsFromDir(
  options: LoadSkillsFromDirOptions,
): LoadSkillsResult {
  return loadSkillsFromDirInternal(options.dir, options.source, true);
}
/**
 * Recursive worker behind {@link loadSkillsFromDir}.
 *
 * @param dir - Directory to scan.
 * @param source - Source label attached to every skill found.
 * @param includeRootFiles - True at the scan root, where any *.md child is a
 *   skill; in subdirectories only files named SKILL.md count.
 * @param ignoreMatcher - Accumulated ignore rules from ancestor directories
 *   (shared and extended as the recursion descends).
 * @param rootDir - Scan root used to compute ignore-relative paths.
 */
function loadSkillsFromDirInternal(
  dir: string,
  source: string,
  includeRootFiles: boolean,
  ignoreMatcher?: IgnoreMatcher,
  rootDir?: string,
): LoadSkillsResult {
  const skills: Skill[] = [];
  const diagnostics: ResourceDiagnostic[] = [];
  if (!existsSync(dir)) {
    return { skills, diagnostics };
  }
  const root = rootDir ?? dir;
  const ig = ignoreMatcher ?? ignore();
  // Pick up .gitignore/.ignore/.fdignore rules from this directory.
  addIgnoreRules(ig, dir, root);
  try {
    const entries = readdirSync(dir, { withFileTypes: true });
    for (const entry of entries) {
      // Skip hidden files and directories.
      if (entry.name.startsWith(".")) {
        continue;
      }
      // Skip node_modules to avoid scanning dependencies
      if (entry.name === "node_modules") {
        continue;
      }
      const fullPath = join(dir, entry.name);
      // For symlinks, check if they point to a directory and follow them
      let isDirectory = entry.isDirectory();
      let isFile = entry.isFile();
      if (entry.isSymbolicLink()) {
        try {
          // statSync follows the link to its target.
          const stats = statSync(fullPath);
          isDirectory = stats.isDirectory();
          isFile = stats.isFile();
        } catch {
          // Broken symlink, skip it
          continue;
        }
      }
      // Ignore rules match against root-relative POSIX paths; directories are
      // tested with a trailing slash (gitignore convention).
      const relPath = toPosixPath(relative(root, fullPath));
      const ignorePath = isDirectory ? `${relPath}/` : relPath;
      if (ig.ignores(ignorePath)) {
        continue;
      }
      if (isDirectory) {
        const subResult = loadSkillsFromDirInternal(
          fullPath,
          source,
          false,
          ig,
          root,
        );
        skills.push(...subResult.skills);
        diagnostics.push(...subResult.diagnostics);
        continue;
      }
      if (!isFile) {
        continue;
      }
      // Root level: any .md file. Below root: only SKILL.md.
      const isRootMd = includeRootFiles && entry.name.endsWith(".md");
      const isSkillMd = !includeRootFiles && entry.name === "SKILL.md";
      if (!isRootMd && !isSkillMd) {
        continue;
      }
      const result = loadSkillFromFile(fullPath, source);
      if (result.skill) {
        skills.push(result.skill);
      }
      diagnostics.push(...result.diagnostics);
    }
  } catch {} // unreadable directory: best-effort, return what we have
  return { skills, diagnostics };
}
/**
 * Parse one skill markdown file.
 *
 * Spec violations are reported as warning diagnostics; only a missing or
 * blank description prevents the skill from loading. Read/parse failures
 * yield a warning and no skill.
 */
function loadSkillFromFile(
  filePath: string,
  source: string,
): { skill: Skill | null; diagnostics: ResourceDiagnostic[] } {
  const diagnostics: ResourceDiagnostic[] = [];
  const warn = (message: string): void => {
    diagnostics.push({ type: "warning", message, path: filePath });
  };
  try {
    const rawContent = readFileSync(filePath, "utf-8");
    const { frontmatter } = parseFrontmatter<SkillFrontmatter>(rawContent);
    const skillDir = dirname(filePath);
    const parentDirName = basename(skillDir);
    for (const error of validateDescription(frontmatter.description)) {
      warn(error);
    }
    // Fall back to the parent directory name when frontmatter omits `name`.
    const name = frontmatter.name || parentDirName;
    for (const error of validateName(name, parentDirName)) {
      warn(error);
    }
    // A missing/blank description is fatal; other violations only warn.
    const description = frontmatter.description;
    if (!description || description.trim() === "") {
      return { skill: null, diagnostics };
    }
    return {
      skill: {
        name,
        description,
        filePath,
        baseDir: skillDir,
        source,
        disableModelInvocation:
          frontmatter["disable-model-invocation"] === true,
      },
      diagnostics,
    };
  } catch (error) {
    warn(error instanceof Error ? error.message : "failed to parse skill file");
    return { skill: null, diagnostics };
  }
}
/**
 * Format skills for inclusion in a system prompt.
 * Uses XML format per Agent Skills standard.
 * See: https://agentskills.io/integrate-skills
 *
 * Skills with disableModelInvocation=true are excluded from the prompt
 * (they can only be invoked explicitly via /skill:name commands).
 *
 * Returns "" when no promptable skills exist, so callers can append the
 * result unconditionally.
 */
export function formatSkillsForPrompt(skills: Skill[]): string {
  const visibleSkills = skills.filter((s) => !s.disableModelInvocation);
  if (visibleSkills.length === 0) {
    return "";
  }
  const lines = [
    "\n\nThe following skills provide specialized instructions for specific tasks.",
    "Use the read tool to load a skill's file when the task matches its description.",
    "When a skill file references a relative path, resolve it against the skill directory (parent of SKILL.md / dirname of the path) and use that absolute path in tool commands.",
    "",
    "<available_skills>",
  ];
  for (const skill of visibleSkills) {
    // All text fields are XML-escaped; paths and descriptions may contain
    // characters like & or <.
    lines.push("  <skill>");
    lines.push(`    <name>${escapeXml(skill.name)}</name>`);
    lines.push(
      `    <description>${escapeXml(skill.description)}</description>`,
    );
    lines.push(`    <location>${escapeXml(skill.filePath)}</location>`);
    lines.push("  </skill>");
  }
  lines.push("</available_skills>");
  return lines.join("\n");
}
/** Escape the five XML-reserved characters in `str` for safe embedding. */
function escapeXml(str: string): string {
  const entities = new Map<string, string>([
    ["&", "&amp;"],
    ["<", "&lt;"],
    [">", "&gt;"],
    ['"', "&quot;"],
    ["'", "&apos;"],
  ]);
  return str.replace(/[&<>"']/g, (ch) => entities.get(ch) ?? ch);
}
/** Options for {@link loadSkills}. */
export interface LoadSkillsOptions {
  /** Working directory for project-local skills. Default: process.cwd() */
  cwd?: string;
  /** Agent config directory for global skills. Default: ~/.pi/agent */
  agentDir?: string;
  /** Explicit skill paths (files or directories) */
  skillPaths?: string[];
  /** Include default skills directories. Default: true */
  includeDefaults?: boolean;
}
/**
 * Trim whitespace and expand a leading tilde against the user's home
 * directory ("~", "~/x", and "~x" all resolve under $HOME).
 */
function normalizePath(input: string): string {
  const cleaned = input.trim();
  if (!cleaned.startsWith("~")) {
    return cleaned;
  }
  if (cleaned === "~") {
    return homedir();
  }
  const rest = cleaned.startsWith("~/") ? cleaned.slice(2) : cleaned.slice(1);
  return join(homedir(), rest);
}
function resolveSkillPath(p: string, cwd: string): string {
const normalized = normalizePath(p);
return isAbsolute(normalized) ? normalized : resolve(cwd, normalized);
}
/**
 * Load skills from all configured locations.
 * Returns skills and any validation diagnostics.
 *
 * Precedence: user skills, then project skills, then explicit `skillPaths`,
 * in that order — the first skill loaded under a given name wins and later
 * duplicates produce collision diagnostics. The same physical file reached
 * twice (e.g. via a symlink) is skipped silently.
 */
export function loadSkills(options: LoadSkillsOptions = {}): LoadSkillsResult {
  const {
    cwd = process.cwd(),
    agentDir,
    skillPaths = [],
    includeDefaults = true,
  } = options;
  // Resolve agentDir - if not provided, use default from config
  const resolvedAgentDir = agentDir ?? getAgentDir();
  // name -> winning skill; insertion order is preserved in the result.
  const skillMap = new Map<string, Skill>();
  // Canonical (symlink-resolved) paths of every loaded file, for dedupe.
  const realPathSet = new Set<string>();
  const allDiagnostics: ResourceDiagnostic[] = [];
  // Collisions are appended after all other diagnostics in the result.
  const collisionDiagnostics: ResourceDiagnostic[] = [];
  // Merge one scan result into the accumulators, applying dedupe and
  // first-wins name-collision handling.
  function addSkills(result: LoadSkillsResult) {
    allDiagnostics.push(...result.diagnostics);
    for (const skill of result.skills) {
      // Resolve symlinks to detect duplicate files
      let realPath: string;
      try {
        realPath = realpathSync(skill.filePath);
      } catch {
        realPath = skill.filePath;
      }
      // Skip silently if we've already loaded this exact file (via symlink)
      if (realPathSet.has(realPath)) {
        continue;
      }
      const existing = skillMap.get(skill.name);
      if (existing) {
        collisionDiagnostics.push({
          type: "collision",
          message: `name "${skill.name}" collision`,
          path: skill.filePath,
          collision: {
            resourceType: "skill",
            name: skill.name,
            winnerPath: existing.filePath,
            loserPath: skill.filePath,
          },
        });
      } else {
        skillMap.set(skill.name, skill);
        realPathSet.add(realPath);
      }
    }
  }
  if (includeDefaults) {
    // User-global skills load before project skills, so they win collisions.
    addSkills(
      loadSkillsFromDirInternal(join(resolvedAgentDir, "skills"), "user", true),
    );
    addSkills(
      loadSkillsFromDirInternal(
        resolve(cwd, CONFIG_DIR_NAME, "skills"),
        "project",
        true,
      ),
    );
  }
  const userSkillsDir = join(resolvedAgentDir, "skills");
  const projectSkillsDir = resolve(cwd, CONFIG_DIR_NAME, "skills");
  // True when `target` equals `root` or lives somewhere beneath it.
  const isUnderPath = (target: string, root: string): boolean => {
    const normalizedRoot = resolve(root);
    if (target === normalizedRoot) {
      return true;
    }
    const prefix = normalizedRoot.endsWith(sep)
      ? normalizedRoot
      : `${normalizedRoot}${sep}`;
    return target.startsWith(prefix);
  };
  // Label explicit paths: when defaults are disabled, paths inside the
  // standard user/project skill dirs keep their canonical source label;
  // everything else is "path".
  const getSource = (resolvedPath: string): "user" | "project" | "path" => {
    if (!includeDefaults) {
      if (isUnderPath(resolvedPath, userSkillsDir)) return "user";
      if (isUnderPath(resolvedPath, projectSkillsDir)) return "project";
    }
    return "path";
  };
  for (const rawPath of skillPaths) {
    const resolvedPath = resolveSkillPath(rawPath, cwd);
    if (!existsSync(resolvedPath)) {
      allDiagnostics.push({
        type: "warning",
        message: "skill path does not exist",
        path: resolvedPath,
      });
      continue;
    }
    try {
      const stats = statSync(resolvedPath);
      const source = getSource(resolvedPath);
      if (stats.isDirectory()) {
        addSkills(loadSkillsFromDirInternal(resolvedPath, source, true));
      } else if (stats.isFile() && resolvedPath.endsWith(".md")) {
        const result = loadSkillFromFile(resolvedPath, source);
        if (result.skill) {
          addSkills({
            skills: [result.skill],
            diagnostics: result.diagnostics,
          });
        } else {
          // Load failed: keep the warnings even though no skill was produced.
          allDiagnostics.push(...result.diagnostics);
        }
      } else {
        allDiagnostics.push({
          type: "warning",
          message: "skill path is not a markdown file",
          path: resolvedPath,
        });
      }
    } catch (error) {
      const message =
        error instanceof Error ? error.message : "failed to read skill path";
      allDiagnostics.push({ type: "warning", message, path: resolvedPath });
    }
  }
  return {
    skills: Array.from(skillMap.values()),
    diagnostics: [...allDiagnostics, ...collisionDiagnostics],
  };
}

View file

@ -0,0 +1,44 @@
/** Kind of resource a dynamically-discovered slash command comes from. */
export type SlashCommandSource = "extension" | "prompt" | "skill";
/** Configuration scope that provided a discovered slash command. */
export type SlashCommandLocation = "user" | "project" | "path";
/** Metadata for a dynamically-discovered slash command. */
export interface SlashCommandInfo {
  name: string;
  description?: string;
  source: SlashCommandSource;
  location?: SlashCommandLocation;
  /** Path of the file that defines the command, when file-backed. */
  path?: string;
}
/** A slash command built into pi itself. */
export interface BuiltinSlashCommand {
  name: string;
  description: string;
}
/** All built-in slash commands, in display order. */
export const BUILTIN_SLASH_COMMANDS: ReadonlyArray<BuiltinSlashCommand> = [
  { name: "settings", description: "Open settings menu" },
  { name: "model", description: "Select model (opens selector UI)" },
  {
    name: "scoped-models",
    description: "Enable/disable models for Ctrl+P cycling",
  },
  { name: "export", description: "Export session to HTML file" },
  { name: "share", description: "Share session as a secret GitHub gist" },
  { name: "copy", description: "Copy last agent message to clipboard" },
  { name: "name", description: "Set session display name" },
  { name: "session", description: "Show session info and stats" },
  { name: "changelog", description: "Show changelog entries" },
  { name: "hotkeys", description: "Show all keyboard shortcuts" },
  { name: "fork", description: "Create a new fork from a previous message" },
  { name: "tree", description: "Navigate session tree (switch branches)" },
  { name: "login", description: "Login with OAuth provider" },
  { name: "logout", description: "Logout from OAuth provider" },
  { name: "new", description: "Start a new session" },
  { name: "compact", description: "Manually compact the session context" },
  { name: "resume", description: "Resume a different session" },
  {
    name: "reload",
    description: "Reload extensions, skills, prompts, and themes",
  },
  { name: "quit", description: "Quit pi" },
];

View file

@ -0,0 +1,237 @@
/**
* System prompt construction and project context loading
*/
import { getDocsPath, getReadmePath } from "../config.js";
import { formatSkillsForPrompt, type Skill } from "./skills.js";
/** Tool descriptions for system prompt */
const toolDescriptions: Record<string, string> = {
  // Keys must match tool names; values become the one-line snippets rendered
  // in the "Available tools" section of the default system prompt.
  read: "Read file contents",
  bash: "Execute bash commands (ls, grep, find, etc.)",
  edit: "Make surgical edits to files (find exact text and replace)",
  write: "Create or overwrite files",
  grep: "Search file contents for patterns (respects .gitignore)",
  find: "Find files by glob pattern (respects .gitignore)",
  ls: "List directory contents",
};
/** Options controlling {@link buildSystemPrompt}. */
export interface BuildSystemPromptOptions {
  /** Custom system prompt (replaces default). */
  customPrompt?: string;
  /** Tools to include in prompt. Default: [read, bash, edit, write] */
  selectedTools?: string[];
  /** Optional one-line tool snippets keyed by tool name. */
  toolSnippets?: Record<string, string>;
  /** Additional guideline bullets appended to the default system prompt guidelines. */
  promptGuidelines?: string[];
  /** Text to append to system prompt. */
  appendSystemPrompt?: string;
  /** Working directory. Default: process.cwd() */
  cwd?: string;
  /** Pre-loaded context files. */
  contextFiles?: Array<{ path: string; content: string }>;
  /** Pre-loaded skills. */
  skills?: Skill[];
}
/**
 * Render the "# Project Context" section from pre-loaded context files.
 * Returns "" when there are no files. When a SOUL.md file is present (at any
 * depth), an extra persona instruction is included.
 */
function buildProjectContextSection(
  contextFiles: Array<{ path: string; content: string }>,
): string {
  if (contextFiles.length === 0) {
    return "";
  }
  // SOUL.md may be the path itself or live in any nested directory;
  // backslashes are normalized so Windows paths match too.
  const isSoulFile = (p: string): boolean =>
    p === "SOUL.md" || p.split("\\").join("/").endsWith("/SOUL.md");
  const parts: string[] = ["\n\n# Project Context\n\n"];
  parts.push("Project-specific instructions and guidelines:\n");
  if (contextFiles.some(({ path }) => isSoulFile(path))) {
    parts.push(
      "\nIf SOUL.md is present, embody its persona and tone. Avoid generic assistant filler and follow its guidance unless higher-priority instructions override it.\n",
    );
  }
  parts.push("\n");
  for (const { path: filePath, content } of contextFiles) {
    parts.push(`## ${filePath}\n\n${content}\n\n`);
  }
  return parts.join("");
}
/**
 * Build the system prompt with tools, guidelines, and context.
 *
 * When `customPrompt` is provided it replaces the default scaffold entirely;
 * only the append section, project context, skills (if read is available),
 * and the date/cwd footer are still added. Otherwise the default prompt is
 * assembled from the selected tools, tool-dependent guidelines, and pointers
 * to pi's own documentation.
 */
export function buildSystemPrompt(
  options: BuildSystemPromptOptions = {},
): string {
  const {
    customPrompt,
    selectedTools,
    toolSnippets,
    promptGuidelines,
    appendSystemPrompt,
    cwd,
    contextFiles: providedContextFiles,
    skills: providedSkills,
  } = options;
  const resolvedCwd = cwd ?? process.cwd();
  const now = new Date();
  // Human-readable local timestamp embedded at the end of the prompt.
  const dateTime = now.toLocaleString("en-US", {
    weekday: "long",
    year: "numeric",
    month: "long",
    day: "numeric",
    hour: "2-digit",
    minute: "2-digit",
    second: "2-digit",
    timeZoneName: "short",
  });
  const appendSection = appendSystemPrompt ? `\n\n${appendSystemPrompt}` : "";
  const contextFiles = providedContextFiles ?? [];
  const skills = providedSkills ?? [];
  if (customPrompt) {
    let prompt = customPrompt;
    if (appendSection) {
      prompt += appendSection;
    }
    // Append project context files
    prompt += buildProjectContextSection(contextFiles);
    // Append skills section (only if read tool is available)
    const customPromptHasRead =
      !selectedTools || selectedTools.includes("read");
    if (customPromptHasRead && skills.length > 0) {
      prompt += formatSkillsForPrompt(skills);
    }
    // Add date/time and working directory last
    prompt += `\nCurrent date and time: ${dateTime}`;
    prompt += `\nCurrent working directory: ${resolvedCwd}`;
    return prompt;
  }
  // Get absolute paths to documentation
  const readmePath = getReadmePath();
  const docsPath = getDocsPath();
  // Build tools list based on selected tools.
  // Built-ins use toolDescriptions. Custom tools can provide one-line snippets.
  const tools = selectedTools || ["read", "bash", "edit", "write"];
  const toolsList =
    tools.length > 0
      ? tools
          .map((name) => {
            // Precedence: caller-supplied snippet, built-in description,
            // bare tool name.
            const snippet =
              toolSnippets?.[name] ?? toolDescriptions[name] ?? name;
            return `- ${name}: ${snippet}`;
          })
          .join("\n")
      : "(none)";
  // Build guidelines based on which tools are actually available
  const guidelinesList: string[] = [];
  // Set mirrors the list purely for O(1) dedupe; insertion order is kept.
  const guidelinesSet = new Set<string>();
  const addGuideline = (guideline: string): void => {
    if (guidelinesSet.has(guideline)) {
      return;
    }
    guidelinesSet.add(guideline);
    guidelinesList.push(guideline);
  };
  const hasBash = tools.includes("bash");
  const hasEdit = tools.includes("edit");
  const hasWrite = tools.includes("write");
  const hasGrep = tools.includes("grep");
  const hasFind = tools.includes("find");
  const hasLs = tools.includes("ls");
  const hasRead = tools.includes("read");
  // File exploration guidelines
  if (hasBash && !hasGrep && !hasFind && !hasLs) {
    addGuideline("Use bash for file operations like ls, rg, find");
  } else if (hasBash && (hasGrep || hasFind || hasLs)) {
    addGuideline(
      "Prefer grep/find/ls tools over bash for file exploration (faster, respects .gitignore)",
    );
  }
  // Read before edit guideline
  if (hasRead && hasEdit) {
    addGuideline(
      "Use read to examine files before editing. You must use this tool instead of cat or sed.",
    );
  }
  // Edit guideline
  if (hasEdit) {
    addGuideline("Use edit for precise changes (old text must match exactly)");
  }
  // Write guideline
  if (hasWrite) {
    addGuideline("Use write only for new files or complete rewrites");
  }
  // Output guideline (only when actually writing or executing)
  if (hasEdit || hasWrite) {
    addGuideline(
      "When summarizing your actions, output plain text directly - do NOT use cat or bash to display what you did",
    );
  }
  // Caller-provided extra guidelines (blank entries are dropped).
  for (const guideline of promptGuidelines ?? []) {
    const normalized = guideline.trim();
    if (normalized.length > 0) {
      addGuideline(normalized);
    }
  }
  // Always include these
  addGuideline("Be concise in your responses");
  addGuideline("Show file paths clearly when working with files");
  const guidelines = guidelinesList.map((g) => `- ${g}`).join("\n");
  let prompt = `You are an expert coding assistant operating inside pi, a coding agent harness. You help users by reading files, executing commands, editing code, and writing new files.
Available tools:
${toolsList}
In addition to the tools above, you may have access to other custom tools depending on the project.
Guidelines:
${guidelines}
Pi documentation (read only when the user asks about pi itself, its SDK, extensions, themes, skills, or TUI):
- Main documentation: ${readmePath}
- Additional docs: ${docsPath}
- When asked about: extensions (docs/extensions.md), themes (docs/themes.md), skills (docs/skills.md), prompt templates (docs/prompt-templates.md), TUI components (docs/tui.md), keybindings (docs/keybindings.md), SDK integrations (docs/sdk.md), custom providers (docs/custom-provider.md), adding models (docs/models.md), pi packages (docs/packages.md)
- When working on pi topics, read the docs and follow .md cross-references before implementing
- Always read pi .md files completely and follow links to related docs (e.g., tui.md for TUI API details)`;
  if (appendSection) {
    prompt += appendSection;
  }
  // Append project context files
  prompt += buildProjectContextSection(contextFiles);
  // Append skills section (only if read tool is available)
  if (hasRead && skills.length > 0) {
    prompt += formatSkillsForPrompt(skills);
  }
  // Add date/time and working directory last
  prompt += `\nCurrent date and time: ${dateTime}`;
  prompt += `\nCurrent working directory: ${resolvedCwd}`;
  return prompt;
}

View file

@ -0,0 +1,25 @@
/**
* Central timing instrumentation for startup profiling.
* Enable with PI_TIMING=1 environment variable.
*/
// Timing is opt-in so the instrumentation is a no-op on normal runs.
const ENABLED = process.env.PI_TIMING === "1";
// Ordered (label, elapsed-ms) marks recorded since module load.
const timings: Array<{ label: string; ms: number }> = [];
// Timestamp of the previous mark; each time() call records the delta.
let lastTime = Date.now();
/**
 * Record the elapsed time since the previous mark under `label`.
 * No-op unless PI_TIMING=1.
 */
export function time(label: string): void {
  if (!ENABLED) {
    return;
  }
  const current = Date.now();
  const elapsed = current - lastTime;
  timings.push({ label, ms: elapsed });
  lastTime = current;
}
/**
 * Dump all recorded timings plus a total to stderr.
 * No-op unless PI_TIMING=1 and at least one mark was recorded.
 */
export function printTimings(): void {
  if (!ENABLED || timings.length === 0) {
    return;
  }
  console.error("\n--- Startup Timings ---");
  timings.forEach((entry) => {
    console.error(`  ${entry.label}: ${entry.ms}ms`);
  });
  const total = timings.reduce((sum, entry) => sum + entry.ms, 0);
  console.error(`  TOTAL: ${total}ms`);
  console.error("------------------------\n");
}

View file

@ -0,0 +1,358 @@
import { randomBytes } from "node:crypto";
import { createWriteStream, existsSync } from "node:fs";
import { tmpdir } from "node:os";
import { join } from "node:path";
import type { AgentTool } from "@mariozechner/pi-agent-core";
import { type Static, Type } from "@sinclair/typebox";
import { spawn } from "child_process";
import {
getShellConfig,
getShellEnv,
killProcessTree,
} from "../../utils/shell.js";
import {
DEFAULT_MAX_BYTES,
DEFAULT_MAX_LINES,
formatSize,
type TruncationResult,
truncateTail,
} from "./truncate.js";
/**
 * Generate a collision-resistant temp file path for captured bash output,
 * e.g. "<tmpdir>/pi-bash-<16 hex chars>.log".
 */
function getTempFilePath(): string {
  const token = randomBytes(8).toString("hex");
  return join(tmpdir(), `pi-bash-${token}.log`);
}
// JSON schema for the bash tool's input parameters.
const bashSchema = Type.Object({
  command: Type.String({ description: "Bash command to execute" }),
  timeout: Type.Optional(
    Type.Number({
      description: "Timeout in seconds (optional, no default timeout)",
    }),
  ),
});
/** Static input type derived from the bash tool schema. */
export type BashToolInput = Static<typeof bashSchema>;
/** Extra result metadata reported by the bash tool. */
export interface BashToolDetails {
  /** Present when output exceeded the line/byte limits. */
  truncation?: TruncationResult;
  /** Temp file containing the complete, untruncated output, when written. */
  fullOutputPath?: string;
}
/**
* Pluggable operations for the bash tool.
* Override these to delegate command execution to remote systems (e.g., SSH).
*/
export interface BashOperations {
/**
* Execute a command and stream output.
* @param command - The command to execute
* @param cwd - Working directory
* @param options - Execution options
* @returns Promise resolving to exit code (null if killed)
*/
exec: (
command: string,
cwd: string,
options: {
onData: (data: Buffer) => void;
signal?: AbortSignal;
timeout?: number;
env?: NodeJS.ProcessEnv;
},
) => Promise<{ exitCode: number | null }>;
}
/**
 * Default bash operations using local shell.
 *
 * Failure modes are reported through the promise: spawn errors reject with
 * the underlying error; aborts reject with message "aborted"; timeouts
 * reject with "timeout:<seconds>". Callers (createBashTool) parse these
 * sentinel messages, so they must stay stable.
 */
const defaultBashOperations: BashOperations = {
  exec: (command, cwd, { onData, signal, timeout, env }) => {
    return new Promise((resolve, reject) => {
      const { shell, args } = getShellConfig();
      if (!existsSync(cwd)) {
        reject(
          new Error(
            `Working directory does not exist: ${cwd}\nCannot execute bash commands.`,
          ),
        );
        return;
      }
      // detached: child leads its own process group so killProcessTree can
      // take down the whole tree on abort/timeout.
      const child = spawn(shell, [...args, command], {
        cwd,
        detached: true,
        env: env ?? getShellEnv(),
        stdio: ["ignore", "pipe", "pipe"],
      });
      let timedOut = false;
      // Set timeout if provided
      let timeoutHandle: NodeJS.Timeout | undefined;
      if (timeout !== undefined && timeout > 0) {
        timeoutHandle = setTimeout(() => {
          timedOut = true;
          if (child.pid) {
            killProcessTree(child.pid);
          }
        }, timeout * 1000);
      }
      // Stream stdout and stderr (interleaved into a single onData stream)
      if (child.stdout) {
        child.stdout.on("data", onData);
      }
      if (child.stderr) {
        child.stderr.on("data", onData);
      }
      // Handle shell spawn errors
      child.on("error", (err) => {
        if (timeoutHandle) clearTimeout(timeoutHandle);
        if (signal) signal.removeEventListener("abort", onAbort);
        reject(err);
      });
      // Handle abort signal - kill entire process tree
      const onAbort = () => {
        if (child.pid) {
          killProcessTree(child.pid);
        }
      };
      if (signal) {
        if (signal.aborted) {
          onAbort();
        } else {
          signal.addEventListener("abort", onAbort, { once: true });
        }
      }
      // Handle process exit. Settlement happens only here (or on spawn
      // error): abort/timeout kill the tree and then this close handler
      // rejects with the matching sentinel message.
      child.on("close", (code) => {
        if (timeoutHandle) clearTimeout(timeoutHandle);
        if (signal) signal.removeEventListener("abort", onAbort);
        if (signal?.aborted) {
          reject(new Error("aborted"));
          return;
        }
        if (timedOut) {
          reject(new Error(`timeout:${timeout}`));
          return;
        }
        resolve({ exitCode: code });
      });
    });
  },
};
/** Description of how a bash command will be spawned. */
export interface BashSpawnContext {
  /** The command string handed to the shell. */
  command: string;
  /** Working directory for the spawned process. */
  cwd: string;
  /** Environment variables for the spawned process. */
  env: NodeJS.ProcessEnv;
}
/** Hook that may rewrite the spawn context before execution. */
export type BashSpawnHook = (context: BashSpawnContext) => BashSpawnContext;
/**
 * Build the spawn context (command, cwd, shell env copy) for a bash
 * invocation, letting an optional hook rewrite it before execution.
 */
function resolveSpawnContext(
  command: string,
  cwd: string,
  spawnHook?: BashSpawnHook,
): BashSpawnContext {
  const context: BashSpawnContext = {
    command,
    cwd,
    env: { ...getShellEnv() },
  };
  if (!spawnHook) {
    return context;
  }
  return spawnHook(context);
}
export interface BashToolOptions {
/** Custom operations for command execution. Default: local shell */
operations?: BashOperations;
/** Command prefix prepended to every command (e.g., "shopt -s expand_aliases" for alias support) */
commandPrefix?: string;
/** Hook to adjust command, cwd, or env before execution */
spawnHook?: BashSpawnHook;
}
export function createBashTool(
cwd: string,
options?: BashToolOptions,
): AgentTool<typeof bashSchema> {
const ops = options?.operations ?? defaultBashOperations;
const commandPrefix = options?.commandPrefix;
const spawnHook = options?.spawnHook;
return {
name: "bash",
label: "bash",
description: `Execute a bash command in the current working directory. Returns stdout and stderr. Output is truncated to last ${DEFAULT_MAX_LINES} lines or ${DEFAULT_MAX_BYTES / 1024}KB (whichever is hit first). If truncated, full output is saved to a temp file. Optionally provide a timeout in seconds.`,
parameters: bashSchema,
execute: async (
_toolCallId: string,
{ command, timeout }: { command: string; timeout?: number },
signal?: AbortSignal,
onUpdate?,
) => {
// Apply command prefix if configured (e.g., "shopt -s expand_aliases" for alias support)
const resolvedCommand = commandPrefix
? `${commandPrefix}\n${command}`
: command;
const spawnContext = resolveSpawnContext(resolvedCommand, cwd, spawnHook);
return new Promise((resolve, reject) => {
// We'll stream to a temp file if output gets large
let tempFilePath: string | undefined;
let tempFileStream: ReturnType<typeof createWriteStream> | undefined;
let totalBytes = 0;
// Keep a rolling buffer of the last chunk for tail truncation
const chunks: Buffer[] = [];
let chunksBytes = 0;
// Keep more than we need so we have enough for truncation
const maxChunksBytes = DEFAULT_MAX_BYTES * 2;
const handleData = (data: Buffer) => {
totalBytes += data.length;
// Start writing to temp file once we exceed the threshold
if (totalBytes > DEFAULT_MAX_BYTES && !tempFilePath) {
tempFilePath = getTempFilePath();
tempFileStream = createWriteStream(tempFilePath);
// Write all buffered chunks to the file
for (const chunk of chunks) {
tempFileStream.write(chunk);
}
}
// Write to temp file if we have one
if (tempFileStream) {
tempFileStream.write(data);
}
// Keep rolling buffer of recent data
chunks.push(data);
chunksBytes += data.length;
// Trim old chunks if buffer is too large
while (chunksBytes > maxChunksBytes && chunks.length > 1) {
const removed = chunks.shift()!;
chunksBytes -= removed.length;
}
// Stream partial output to callback (truncated rolling buffer)
if (onUpdate) {
const fullBuffer = Buffer.concat(chunks);
const fullText = fullBuffer.toString("utf-8");
const truncation = truncateTail(fullText);
onUpdate({
content: [{ type: "text", text: truncation.content || "" }],
details: {
truncation: truncation.truncated ? truncation : undefined,
fullOutputPath: tempFilePath,
},
});
}
};
ops
.exec(spawnContext.command, spawnContext.cwd, {
onData: handleData,
signal,
timeout,
env: spawnContext.env,
})
.then(({ exitCode }) => {
// Close temp file stream
if (tempFileStream) {
tempFileStream.end();
}
// Combine all buffered chunks
const fullBuffer = Buffer.concat(chunks);
const fullOutput = fullBuffer.toString("utf-8");
// Apply tail truncation
const truncation = truncateTail(fullOutput);
let outputText = truncation.content || "(no output)";
// Build details with truncation info
let details: BashToolDetails | undefined;
if (truncation.truncated) {
details = {
truncation,
fullOutputPath: tempFilePath,
};
// Build actionable notice
const startLine =
truncation.totalLines - truncation.outputLines + 1;
const endLine = truncation.totalLines;
if (truncation.lastLinePartial) {
// Edge case: last line alone > 30KB
const lastLineSize = formatSize(
Buffer.byteLength(
fullOutput.split("\n").pop() || "",
"utf-8",
),
);
outputText += `\n\n[Showing last ${formatSize(truncation.outputBytes)} of line ${endLine} (line is ${lastLineSize}). Full output: ${tempFilePath}]`;
} else if (truncation.truncatedBy === "lines") {
outputText += `\n\n[Showing lines ${startLine}-${endLine} of ${truncation.totalLines}. Full output: ${tempFilePath}]`;
} else {
outputText += `\n\n[Showing lines ${startLine}-${endLine} of ${truncation.totalLines} (${formatSize(DEFAULT_MAX_BYTES)} limit). Full output: ${tempFilePath}]`;
}
}
if (exitCode !== 0 && exitCode !== null) {
outputText += `\n\nCommand exited with code ${exitCode}`;
reject(new Error(outputText));
} else {
resolve({
content: [{ type: "text", text: outputText }],
details,
});
}
})
.catch((err: Error) => {
// Close temp file stream
if (tempFileStream) {
tempFileStream.end();
}
// Combine all buffered chunks for error output
const fullBuffer = Buffer.concat(chunks);
let output = fullBuffer.toString("utf-8");
if (err.message === "aborted") {
if (output) output += "\n\n";
output += "Command aborted";
reject(new Error(output));
} else if (err.message.startsWith("timeout:")) {
const timeoutSecs = err.message.split(":")[1];
if (output) output += "\n\n";
output += `Command timed out after ${timeoutSecs} seconds`;
reject(new Error(output));
} else {
reject(err);
}
});
});
},
};
}
/**
 * Default bash tool bound to process.cwd() at module load time.
 * Prefer createBashTool(cwd) so the working directory is explicit;
 * this export exists for backwards compatibility only.
 */
export const bashTool = createBashTool(process.cwd());

View file

@ -0,0 +1,317 @@
/**
* Shared diff computation utilities for the edit tool.
* Used by both edit.ts (for execution) and tool-execution.ts (for preview rendering).
*/
import * as Diff from "diff";
import { constants } from "fs";
import { access, readFile } from "fs/promises";
import { resolveToCwd } from "./path-utils.js";
/**
 * Detect the dominant line ending of a document by inspecting its first
 * newline. Returns "\n" when the content contains no newline at all, or
 * when the first newline is a bare LF rather than part of a CRLF pair.
 */
export function detectLineEnding(content: string): "\r\n" | "\n" {
  const firstLf = content.indexOf("\n");
  if (firstLf === -1) {
    // No newline anywhere — default to LF.
    return "\n";
  }
  const firstCrlf = content.indexOf("\r\n");
  if (firstCrlf === -1) {
    return "\n";
  }
  // CRLF wins only when it occurs at or before the first standalone LF.
  return firstCrlf < firstLf ? "\r\n" : "\n";
}
/** Convert CRLF and bare CR line endings to plain LF. */
export function normalizeToLF(text: string): string {
  // Collapse CRLF pairs first so their CRs do not become doubled newlines.
  const withoutCrlf = text.replace(/\r\n/g, "\n");
  return withoutCrlf.replace(/\r/g, "\n");
}
/**
 * Re-apply a file's original line ending to LF-normalized text.
 * No-op when the target ending is already "\n".
 */
export function restoreLineEndings(
  text: string,
  ending: "\r\n" | "\n",
): string {
  if (ending === "\n") {
    return text;
  }
  return text.replace(/\n/g, "\r\n");
}
/**
 * Canonicalize text for fuzzy matching so cosmetic differences do not
 * prevent a match:
 * - trailing whitespace is stripped from every line
 * - curly/smart quotes become their ASCII equivalents
 * - Unicode hyphen/dash variants and the minus sign become "-"
 * - exotic Unicode spaces become a plain ASCII space
 */
export function normalizeForFuzzyMatch(text: string): string {
  // Strip trailing whitespace line by line before character mapping.
  const stripped = text
    .split("\n")
    .map((line) => line.trimEnd())
    .join("\n");
  // Map smart quotes to their ASCII forms.
  const asciiQuotes = stripped
    // U+2018/2019/201A/201B single-quote variants
    .replace(/[\u2018\u2019\u201A\u201B]/g, "'")
    // U+201C/201D/201E/201F double-quote variants
    .replace(/[\u201C\u201D\u201E\u201F]/g, '"');
  return asciiQuotes
    // Hyphen/dash family (U+2010-U+2015) and minus sign (U+2212)
    .replace(/[\u2010\u2011\u2012\u2013\u2014\u2015\u2212]/g, "-")
    // NBSP, en/em-type spaces, narrow NBSP, math space, ideographic space
    .replace(/[\u00A0\u2002-\u200A\u202F\u205F\u3000]/g, " ");
}
/** Result of fuzzyFindText: where and how oldText was located in content. */
export interface FuzzyMatchResult {
  /** Whether a match was found */
  found: boolean;
  /** The index where the match starts (in the content that should be used for replacement) */
  index: number;
  /** Length of the matched text */
  matchLength: number;
  /** Whether fuzzy matching was used (false = exact match) */
  usedFuzzyMatch: boolean;
  /**
   * The content to use for replacement operations.
   * When exact match: original content. When fuzzy match: normalized content.
   * Callers must splice replacements into THIS string, not the raw input.
   */
  contentForReplacement: string;
}
/**
 * Locate oldText inside content. An exact indexOf match is preferred; when
 * that fails, both strings are canonicalized (trailing whitespace stripped,
 * smart quotes/dashes/exotic spaces mapped to ASCII) and the search is
 * retried in that normalized space. On a fuzzy hit, contentForReplacement
 * is the normalized content, so replacements must target it rather than the
 * original string.
 */
export function fuzzyFindText(
  content: string,
  oldText: string,
): FuzzyMatchResult {
  const exactAt = content.indexOf(oldText);
  if (exactAt !== -1) {
    return {
      found: true,
      index: exactAt,
      matchLength: oldText.length,
      usedFuzzyMatch: false,
      contentForReplacement: content,
    };
  }
  // Same canonicalization as normalizeForFuzzyMatch, applied to both sides.
  const canon = (text: string): string =>
    text
      .split("\n")
      .map((line) => line.trimEnd())
      .join("\n")
      .replace(/[\u2018\u2019\u201A\u201B]/g, "'")
      .replace(/[\u201C\u201D\u201E\u201F]/g, '"')
      .replace(/[\u2010\u2011\u2012\u2013\u2014\u2015\u2212]/g, "-")
      .replace(/[\u00A0\u2002-\u200A\u202F\u205F\u3000]/g, " ");
  const canonContent = canon(content);
  const canonOldText = canon(oldText);
  const fuzzyAt = canonContent.indexOf(canonOldText);
  if (fuzzyAt === -1) {
    return {
      found: false,
      index: -1,
      matchLength: 0,
      usedFuzzyMatch: false,
      contentForReplacement: content,
    };
  }
  // Replacement happens in normalized space, so minor formatting
  // differences (whitespace, quotes, dashes) are normalized in the output.
  return {
    found: true,
    index: fuzzyAt,
    matchLength: canonOldText.length,
    usedFuzzyMatch: true,
    contentForReplacement: canonContent,
  };
}
/**
 * Split a UTF-8 BOM off the front of content. Returns the BOM ("" when
 * absent) and the remaining text so callers can re-prepend it after editing.
 */
export function stripBom(content: string): { bom: string; text: string } {
  // charCodeAt on an empty string yields NaN, so "" safely takes this path.
  if (content.charCodeAt(0) !== 0xfeff) {
    return { bom: "", text: content };
  }
  return { bom: "\uFEFF", text: content.slice(1) };
}
/**
 * Generate a unified diff string with line numbers and context.
 * Returns both the diff string and the first changed line number (in the new file).
 *
 * Line format: "+N text" for additions, "-N text" for removals, " N text"
 * for kept context, and a bare "..." row where unchanged runs are elided.
 * `contextLines` bounds how many unchanged lines are rendered on each side
 * of a change.
 */
export function generateDiffString(
  oldContent: string,
  newContent: string,
  contextLines = 4,
): { diff: string; firstChangedLine: number | undefined } {
  const parts = Diff.diffLines(oldContent, newContent);
  const output: string[] = [];
  const oldLines = oldContent.split("\n");
  const newLines = newContent.split("\n");
  // Width of the widest possible line number, used to right-align the gutter.
  const maxLineNum = Math.max(oldLines.length, newLines.length);
  const lineNumWidth = String(maxLineNum).length;
  // Separate counters for old/new files; both advance through context,
  // only one advances through an added or removed part.
  let oldLineNum = 1;
  let newLineNum = 1;
  let lastWasChange = false;
  let firstChangedLine: number | undefined;
  for (let i = 0; i < parts.length; i++) {
    const part = parts[i];
    // diffLines chunk values end with "\n", so the split leaves a trailing
    // empty element that must not be rendered as its own line.
    const raw = part.value.split("\n");
    if (raw[raw.length - 1] === "") {
      raw.pop();
    }
    if (part.added || part.removed) {
      // Capture the first changed line (in the new file)
      if (firstChangedLine === undefined) {
        firstChangedLine = newLineNum;
      }
      // Show the change
      for (const line of raw) {
        if (part.added) {
          const lineNum = String(newLineNum).padStart(lineNumWidth, " ");
          output.push(`+${lineNum} ${line}`);
          newLineNum++;
        } else {
          // removed
          const lineNum = String(oldLineNum).padStart(lineNumWidth, " ");
          output.push(`-${lineNum} ${line}`);
          oldLineNum++;
        }
      }
      lastWasChange = true;
    } else {
      // Context lines - only show a few before/after changes
      const nextPartIsChange =
        i < parts.length - 1 && (parts[i + 1].added || parts[i + 1].removed);
      if (lastWasChange || nextPartIsChange) {
        // Show context
        let linesToShow = raw;
        let skipStart = 0;
        let skipEnd = 0;
        if (!lastWasChange) {
          // Show only last N lines as leading context
          skipStart = Math.max(0, raw.length - contextLines);
          linesToShow = raw.slice(skipStart);
        }
        if (!nextPartIsChange && linesToShow.length > contextLines) {
          // Show only first N lines as trailing context
          skipEnd = linesToShow.length - contextLines;
          linesToShow = linesToShow.slice(0, contextLines);
        }
        // Add ellipsis if we skipped lines at start
        if (skipStart > 0) {
          output.push(` ${"".padStart(lineNumWidth, " ")} ...`);
          // Update line numbers for the skipped leading context
          oldLineNum += skipStart;
          newLineNum += skipStart;
        }
        for (const line of linesToShow) {
          // Context is numbered by the old file's counter; both advance.
          const lineNum = String(oldLineNum).padStart(lineNumWidth, " ");
          output.push(` ${lineNum} ${line}`);
          oldLineNum++;
          newLineNum++;
        }
        // Add ellipsis if we skipped lines at end
        if (skipEnd > 0) {
          output.push(` ${"".padStart(lineNumWidth, " ")} ...`);
          // Update line numbers for the skipped trailing context
          oldLineNum += skipEnd;
          newLineNum += skipEnd;
        }
      } else {
        // Skip these context lines entirely
        oldLineNum += raw.length;
        newLineNum += raw.length;
      }
      lastWasChange = false;
    }
  }
  return { diff: output.join("\n"), firstChangedLine };
}
/** Successful diff preview: rendered diff plus a cursor hint for editors. */
export interface EditDiffResult {
  diff: string;
  firstChangedLine: number | undefined;
}
/** Failed diff preview: human-readable reason the edit cannot be applied. */
export interface EditDiffError {
  error: string;
}
/**
* Compute the diff for an edit operation without applying it.
* Used for preview rendering in the TUI before the tool executes.
*/
export async function computeEditDiff(
path: string,
oldText: string,
newText: string,
cwd: string,
): Promise<EditDiffResult | EditDiffError> {
const absolutePath = resolveToCwd(path, cwd);
try {
// Check if file exists and is readable
try {
await access(absolutePath, constants.R_OK);
} catch {
return { error: `File not found: ${path}` };
}
// Read the file
const rawContent = await readFile(absolutePath, "utf-8");
// Strip BOM before matching (LLM won't include invisible BOM in oldText)
const { text: content } = stripBom(rawContent);
const normalizedContent = normalizeToLF(content);
const normalizedOldText = normalizeToLF(oldText);
const normalizedNewText = normalizeToLF(newText);
// Find the old text using fuzzy matching (tries exact match first, then fuzzy)
const matchResult = fuzzyFindText(normalizedContent, normalizedOldText);
if (!matchResult.found) {
return {
error: `Could not find the exact text in ${path}. The old text must match exactly including all whitespace and newlines.`,
};
}
// Count occurrences using fuzzy-normalized content for consistency
const fuzzyContent = normalizeForFuzzyMatch(normalizedContent);
const fuzzyOldText = normalizeForFuzzyMatch(normalizedOldText);
const occurrences = fuzzyContent.split(fuzzyOldText).length - 1;
if (occurrences > 1) {
return {
error: `Found ${occurrences} occurrences of the text in ${path}. The text must be unique. Please provide more context to make it unique.`,
};
}
// Compute the new content using the matched position
// When fuzzy matching was used, contentForReplacement is the normalized version
const baseContent = matchResult.contentForReplacement;
const newContent =
baseContent.substring(0, matchResult.index) +
normalizedNewText +
baseContent.substring(matchResult.index + matchResult.matchLength);
// Check if it would actually change anything
if (baseContent === newContent) {
return {
error: `No changes would be made to ${path}. The replacement produces identical content.`,
};
}
// Generate the diff
return generateDiffString(baseContent, newContent);
} catch (err) {
return { error: err instanceof Error ? err.message : String(err) };
}
}

View file

@ -0,0 +1,253 @@
import type { AgentTool } from "@mariozechner/pi-agent-core";
import { type Static, Type } from "@sinclair/typebox";
import { constants } from "fs";
import {
access as fsAccess,
readFile as fsReadFile,
writeFile as fsWriteFile,
} from "fs/promises";
import {
detectLineEnding,
fuzzyFindText,
generateDiffString,
normalizeForFuzzyMatch,
normalizeToLF,
restoreLineEndings,
stripBom,
} from "./edit-diff.js";
import { resolveToCwd } from "./path-utils.js";
// Input schema for the edit tool. oldText/newText are compared after
// line-ending normalization; see edit-diff.ts for the matching rules.
const editSchema = Type.Object({
  path: Type.String({
    description: "Path to the file to edit (relative or absolute)",
  }),
  oldText: Type.String({
    description: "Exact text to find and replace (must match exactly)",
  }),
  newText: Type.String({
    description: "New text to replace the old text with",
  }),
});
/** Static input type derived from the edit tool's schema. */
export type EditToolInput = Static<typeof editSchema>;
/** Extra result metadata surfaced to UIs alongside the text output. */
export interface EditToolDetails {
  /** Unified diff of the changes made */
  diff: string;
  /** Line number of the first change in the new file (for editor navigation) */
  firstChangedLine?: number;
}
/**
 * Pluggable operations for the edit tool.
 * Override these to delegate file editing to remote systems (e.g., SSH).
 */
export interface EditOperations {
  /** Read file contents as a Buffer */
  readFile: (absolutePath: string) => Promise<Buffer>;
  /** Write content to a file */
  writeFile: (absolutePath: string, content: string) => Promise<void>;
  /** Check if file is readable and writable (throw if not) */
  access: (absolutePath: string) => Promise<void>;
}
// Local-filesystem implementation used when no custom operations are given.
const defaultEditOperations: EditOperations = {
  readFile: (path) => fsReadFile(path),
  writeFile: (path, content) => fsWriteFile(path, content, "utf-8"),
  access: (path) => fsAccess(path, constants.R_OK | constants.W_OK),
};
/** Options accepted by createEditTool. */
export interface EditToolOptions {
  /** Custom operations for file editing. Default: local filesystem */
  operations?: EditOperations;
}
/**
 * Build the edit tool for a given working directory.
 *
 * Performs a single find-and-replace in one file: oldText is located
 * (exact match first, then fuzzy — see edit-diff.ts), must be unique, and
 * is replaced by newText. The file's BOM and original line endings are
 * preserved on write. All file access goes through EditOperations so the
 * same tool can edit remote filesystems.
 */
export function createEditTool(
  cwd: string,
  options?: EditToolOptions,
): AgentTool<typeof editSchema> {
  const ops = options?.operations ?? defaultEditOperations;
  return {
    name: "edit",
    label: "edit",
    description:
      "Edit a file by replacing exact text. The oldText must match exactly (including whitespace). Use this for precise, surgical edits.",
    parameters: editSchema,
    execute: async (
      _toolCallId: string,
      {
        path,
        oldText,
        newText,
      }: { path: string; oldText: string; newText: string },
      signal?: AbortSignal,
    ) => {
      const absolutePath = resolveToCwd(path, cwd);
      // Manual promise so the abort listener can reject immediately while
      // the async work below checks the `aborted` flag between each step.
      return new Promise<{
        content: Array<{ type: "text"; text: string }>;
        details: EditToolDetails | undefined;
      }>((resolve, reject) => {
        // Check if already aborted
        if (signal?.aborted) {
          reject(new Error("Operation aborted"));
          return;
        }
        let aborted = false;
        // Set up abort handler
        const onAbort = () => {
          aborted = true;
          reject(new Error("Operation aborted"));
        };
        if (signal) {
          signal.addEventListener("abort", onAbort, { once: true });
        }
        // Perform the edit operation
        (async () => {
          try {
            // Check if file exists
            try {
              await ops.access(absolutePath);
            } catch {
              if (signal) {
                signal.removeEventListener("abort", onAbort);
              }
              reject(new Error(`File not found: ${path}`));
              return;
            }
            // Check if aborted before reading
            if (aborted) {
              return;
            }
            // Read the file
            const buffer = await ops.readFile(absolutePath);
            const rawContent = buffer.toString("utf-8");
            // Check if aborted after reading
            if (aborted) {
              return;
            }
            // Strip BOM before matching (LLM won't include invisible BOM in oldText)
            const { bom, text: content } = stripBom(rawContent);
            const originalEnding = detectLineEnding(content);
            const normalizedContent = normalizeToLF(content);
            const normalizedOldText = normalizeToLF(oldText);
            const normalizedNewText = normalizeToLF(newText);
            // Find the old text using fuzzy matching (tries exact match first, then fuzzy)
            const matchResult = fuzzyFindText(
              normalizedContent,
              normalizedOldText,
            );
            if (!matchResult.found) {
              if (signal) {
                signal.removeEventListener("abort", onAbort);
              }
              reject(
                new Error(
                  `Could not find the exact text in ${path}. The old text must match exactly including all whitespace and newlines.`,
                ),
              );
              return;
            }
            // Count occurrences using fuzzy-normalized content for consistency
            const fuzzyContent = normalizeForFuzzyMatch(normalizedContent);
            const fuzzyOldText = normalizeForFuzzyMatch(normalizedOldText);
            const occurrences = fuzzyContent.split(fuzzyOldText).length - 1;
            if (occurrences > 1) {
              if (signal) {
                signal.removeEventListener("abort", onAbort);
              }
              reject(
                new Error(
                  `Found ${occurrences} occurrences of the text in ${path}. The text must be unique. Please provide more context to make it unique.`,
                ),
              );
              return;
            }
            // Check if aborted before writing
            if (aborted) {
              return;
            }
            // Perform replacement using the matched text position
            // When fuzzy matching was used, contentForReplacement is the normalized version
            const baseContent = matchResult.contentForReplacement;
            const newContent =
              baseContent.substring(0, matchResult.index) +
              normalizedNewText +
              baseContent.substring(
                matchResult.index + matchResult.matchLength,
              );
            // Verify the replacement actually changed something
            if (baseContent === newContent) {
              if (signal) {
                signal.removeEventListener("abort", onAbort);
              }
              reject(
                new Error(
                  `No changes made to ${path}. The replacement produced identical content. This might indicate an issue with special characters or the text not existing as expected.`,
                ),
              );
              return;
            }
            // Restore the BOM and the file's original line endings on write.
            const finalContent =
              bom + restoreLineEndings(newContent, originalEnding);
            await ops.writeFile(absolutePath, finalContent);
            // Check if aborted after writing
            if (aborted) {
              return;
            }
            // Clean up abort handler
            if (signal) {
              signal.removeEventListener("abort", onAbort);
            }
            const diffResult = generateDiffString(baseContent, newContent);
            resolve({
              content: [
                {
                  type: "text",
                  text: `Successfully replaced text in ${path}.`,
                },
              ],
              details: {
                diff: diffResult.diff,
                firstChangedLine: diffResult.firstChangedLine,
              },
            });
          } catch (error: any) {
            // Clean up abort handler
            if (signal) {
              signal.removeEventListener("abort", onAbort);
            }
            // The promise was already rejected by onAbort; avoid a second
            // reject with an unrelated error.
            if (!aborted) {
              reject(error);
            }
          }
        })();
      });
    },
  };
}
/**
 * Default edit tool bound to process.cwd() at module load time.
 * Prefer createEditTool(cwd); kept for backwards compatibility.
 */
export const editTool = createEditTool(process.cwd());

View file

@ -0,0 +1,308 @@
import type { AgentTool } from "@mariozechner/pi-agent-core";
import { type Static, Type } from "@sinclair/typebox";
import { spawnSync } from "child_process";
import { existsSync } from "fs";
import { globSync } from "glob";
import path from "path";
import { ensureTool } from "../../utils/tools-manager.js";
import { resolveToCwd } from "./path-utils.js";
import {
DEFAULT_MAX_BYTES,
formatSize,
type TruncationResult,
truncateHead,
} from "./truncate.js";
// Input schema for the find tool. `pattern` is a glob; `path` and `limit`
// are optional refinements.
const findSchema = Type.Object({
  pattern: Type.String({
    description:
      "Glob pattern to match files, e.g. '*.ts', '**/*.json', or 'src/**/*.spec.ts'",
  }),
  path: Type.Optional(
    Type.String({
      description: "Directory to search in (default: current directory)",
    }),
  ),
  limit: Type.Optional(
    Type.Number({ description: "Maximum number of results (default: 1000)" }),
  ),
});
/** Static input type derived from the find tool's schema. */
export type FindToolInput = Static<typeof findSchema>;
// Default cap on the number of paths returned when `limit` is not given.
const DEFAULT_LIMIT = 1000;
/** Extra result metadata describing any truncation that was applied. */
export interface FindToolDetails {
  truncation?: TruncationResult;
  resultLimitReached?: number;
}
/**
 * Pluggable operations for the find tool.
 * Override these to delegate file search to remote systems (e.g., SSH).
 */
export interface FindOperations {
  /** Check if path exists */
  exists: (absolutePath: string) => Promise<boolean> | boolean;
  /** Find files matching glob pattern. Returns relative paths. */
  glob: (
    pattern: string,
    cwd: string,
    options: { ignore: string[]; limit: number },
  ) => Promise<string[]> | string[];
}
// Local defaults: only `exists` is actually consulted. execute() shells out
// to fd directly when no custom operations are supplied, so this glob
// placeholder is never called for local search.
const defaultFindOperations: FindOperations = {
  exists: existsSync,
  glob: (_pattern, _searchCwd, _options) => {
    // This is a placeholder - actual fd execution happens in execute
    return [];
  },
};
/** Options accepted by createFindTool. */
export interface FindToolOptions {
  /** Custom operations for find. Default: local filesystem + fd */
  operations?: FindOperations;
}
/**
 * Build the find tool for a given working directory.
 *
 * Two code paths: when custom operations supply a glob implementation, it
 * is used directly (e.g. remote search); otherwise the tool shells out to
 * the fd binary with --glob, honoring every .gitignore found under the
 * search path. Results are relativized to the search directory and capped
 * by both a result-count limit and a byte limit.
 */
export function createFindTool(
  cwd: string,
  options?: FindToolOptions,
): AgentTool<typeof findSchema> {
  const customOps = options?.operations;
  return {
    name: "find",
    label: "find",
    description: `Search for files by glob pattern. Returns matching file paths relative to the search directory. Respects .gitignore. Output is truncated to ${DEFAULT_LIMIT} results or ${DEFAULT_MAX_BYTES / 1024}KB (whichever is hit first).`,
    parameters: findSchema,
    execute: async (
      _toolCallId: string,
      {
        pattern,
        path: searchDir,
        limit,
      }: { pattern: string; path?: string; limit?: number },
      signal?: AbortSignal,
    ) => {
      return new Promise((resolve, reject) => {
        if (signal?.aborted) {
          reject(new Error("Operation aborted"));
          return;
        }
        const onAbort = () => reject(new Error("Operation aborted"));
        signal?.addEventListener("abort", onAbort, { once: true });
        (async () => {
          try {
            const searchPath = resolveToCwd(searchDir || ".", cwd);
            const effectiveLimit = limit ?? DEFAULT_LIMIT;
            const ops = customOps ?? defaultFindOperations;
            // If custom operations provided with glob, use that
            if (customOps?.glob) {
              if (!(await ops.exists(searchPath))) {
                reject(new Error(`Path not found: ${searchPath}`));
                return;
              }
              const results = await ops.glob(pattern, searchPath, {
                ignore: ["**/node_modules/**", "**/.git/**"],
                limit: effectiveLimit,
              });
              signal?.removeEventListener("abort", onAbort);
              if (results.length === 0) {
                resolve({
                  content: [
                    { type: "text", text: "No files found matching pattern" },
                  ],
                  details: undefined,
                });
                return;
              }
              // Relativize paths
              const relativized = results.map((p) => {
                if (p.startsWith(searchPath)) {
                  return p.slice(searchPath.length + 1);
                }
                return path.relative(searchPath, p);
              });
              const resultLimitReached = relativized.length >= effectiveLimit;
              const rawOutput = relativized.join("\n");
              // Byte-only truncation: the result-count limit already bounds
              // the number of lines.
              const truncation = truncateHead(rawOutput, {
                maxLines: Number.MAX_SAFE_INTEGER,
              });
              let resultOutput = truncation.content;
              const details: FindToolDetails = {};
              const notices: string[] = [];
              if (resultLimitReached) {
                notices.push(`${effectiveLimit} results limit reached`);
                details.resultLimitReached = effectiveLimit;
              }
              if (truncation.truncated) {
                notices.push(`${formatSize(DEFAULT_MAX_BYTES)} limit reached`);
                details.truncation = truncation;
              }
              if (notices.length > 0) {
                resultOutput += `\n\n[${notices.join(". ")}]`;
              }
              resolve({
                content: [{ type: "text", text: resultOutput }],
                details: Object.keys(details).length > 0 ? details : undefined,
              });
              return;
            }
            // Default: use fd
            const fdPath = await ensureTool("fd", true);
            if (!fdPath) {
              reject(
                new Error("fd is not available and could not be downloaded"),
              );
              return;
            }
            // Build fd arguments
            const args: string[] = [
              "--glob",
              "--color=never",
              "--hidden",
              "--max-results",
              String(effectiveLimit),
            ];
            // Include .gitignore files
            const gitignoreFiles = new Set<string>();
            const rootGitignore = path.join(searchPath, ".gitignore");
            if (existsSync(rootGitignore)) {
              gitignoreFiles.add(rootGitignore);
            }
            try {
              const nestedGitignores = globSync("**/.gitignore", {
                cwd: searchPath,
                dot: true,
                absolute: true,
                ignore: ["**/node_modules/**", "**/.git/**"],
              });
              for (const file of nestedGitignores) {
                gitignoreFiles.add(file);
              }
            } catch {
              // Ignore glob errors
            }
            for (const gitignorePath of gitignoreFiles) {
              args.push("--ignore-file", gitignorePath);
            }
            args.push(pattern, searchPath);
            // spawnSync blocks this async step, so the abort signal cannot
            // interrupt the fd run itself — only the steps around it.
            const result = spawnSync(fdPath, args, {
              encoding: "utf-8",
              maxBuffer: 10 * 1024 * 1024,
            });
            signal?.removeEventListener("abort", onAbort);
            if (result.error) {
              reject(new Error(`Failed to run fd: ${result.error.message}`));
              return;
            }
            const output = result.stdout?.trim() || "";
            // Non-zero exit with partial output falls through and is
            // reported as results; only a fully-empty failure rejects.
            if (result.status !== 0) {
              const errorMsg =
                result.stderr?.trim() || `fd exited with code ${result.status}`;
              if (!output) {
                reject(new Error(errorMsg));
                return;
              }
            }
            if (!output) {
              resolve({
                content: [
                  { type: "text", text: "No files found matching pattern" },
                ],
                details: undefined,
              });
              return;
            }
            const lines = output.split("\n");
            const relativized: string[] = [];
            for (const rawLine of lines) {
              const line = rawLine.replace(/\r$/, "").trim();
              if (!line) continue;
              // Preserve a directory marker through the relativization.
              const hadTrailingSlash =
                line.endsWith("/") || line.endsWith("\\");
              let relativePath = line;
              if (line.startsWith(searchPath)) {
                relativePath = line.slice(searchPath.length + 1);
              } else {
                relativePath = path.relative(searchPath, line);
              }
              if (hadTrailingSlash && !relativePath.endsWith("/")) {
                relativePath += "/";
              }
              relativized.push(relativePath);
            }
            const resultLimitReached = relativized.length >= effectiveLimit;
            const rawOutput = relativized.join("\n");
            const truncation = truncateHead(rawOutput, {
              maxLines: Number.MAX_SAFE_INTEGER,
            });
            let resultOutput = truncation.content;
            const details: FindToolDetails = {};
            const notices: string[] = [];
            if (resultLimitReached) {
              notices.push(
                `${effectiveLimit} results limit reached. Use limit=${effectiveLimit * 2} for more, or refine pattern`,
              );
              details.resultLimitReached = effectiveLimit;
            }
            if (truncation.truncated) {
              notices.push(`${formatSize(DEFAULT_MAX_BYTES)} limit reached`);
              details.truncation = truncation;
            }
            if (notices.length > 0) {
              resultOutput += `\n\n[${notices.join(". ")}]`;
            }
            resolve({
              content: [{ type: "text", text: resultOutput }],
              details: Object.keys(details).length > 0 ? details : undefined,
            });
          } catch (e: any) {
            signal?.removeEventListener("abort", onAbort);
            reject(e);
          }
        })();
      });
    },
  };
}
/**
 * Default find tool bound to process.cwd() at module load time.
 * Prefer createFindTool(cwd); kept for backwards compatibility.
 */
export const findTool = createFindTool(process.cwd());

View file

@ -0,0 +1,412 @@
import { createInterface } from "node:readline";
import type { AgentTool } from "@mariozechner/pi-agent-core";
import { type Static, Type } from "@sinclair/typebox";
import { spawn } from "child_process";
import { readFileSync, statSync } from "fs";
import path from "path";
import { ensureTool } from "../../utils/tools-manager.js";
import { resolveToCwd } from "./path-utils.js";
import {
DEFAULT_MAX_BYTES,
formatSize,
GREP_MAX_LINE_LENGTH,
type TruncationResult,
truncateHead,
truncateLine,
} from "./truncate.js";
// Input schema for the grep tool. `pattern` is a regex unless `literal`
// is set; the remaining fields mirror common ripgrep flags.
const grepSchema = Type.Object({
  pattern: Type.String({
    description: "Search pattern (regex or literal string)",
  }),
  path: Type.Optional(
    Type.String({
      description: "Directory or file to search (default: current directory)",
    }),
  ),
  glob: Type.Optional(
    Type.String({
      description:
        "Filter files by glob pattern, e.g. '*.ts' or '**/*.spec.ts'",
    }),
  ),
  ignoreCase: Type.Optional(
    Type.Boolean({ description: "Case-insensitive search (default: false)" }),
  ),
  literal: Type.Optional(
    Type.Boolean({
      description:
        "Treat pattern as literal string instead of regex (default: false)",
    }),
  ),
  context: Type.Optional(
    Type.Number({
      description:
        "Number of lines to show before and after each match (default: 0)",
    }),
  ),
  limit: Type.Optional(
    Type.Number({
      description: "Maximum number of matches to return (default: 100)",
    }),
  ),
});
/** Static input type derived from the grep tool's schema. */
export type GrepToolInput = Static<typeof grepSchema>;
// Default cap on the number of matches returned when `limit` is not given.
const DEFAULT_LIMIT = 100;
/** Extra result metadata describing limits and truncation that applied. */
export interface GrepToolDetails {
  truncation?: TruncationResult;
  matchLimitReached?: number;
  linesTruncated?: boolean;
}
/**
 * Pluggable operations for the grep tool.
 * Override these to delegate search to remote systems (e.g., SSH).
 * Note: matching always runs through the local ripgrep binary; these hooks
 * cover only the directory check and context-line reads.
 */
export interface GrepOperations {
  /** Check if path is a directory. Throws if path doesn't exist. */
  isDirectory: (absolutePath: string) => Promise<boolean> | boolean;
  /** Read file contents for context lines */
  readFile: (absolutePath: string) => Promise<string> | string;
}
// Local-filesystem implementation used when no custom operations are given.
const defaultGrepOperations: GrepOperations = {
  isDirectory: (p) => statSync(p).isDirectory(),
  readFile: (p) => readFileSync(p, "utf-8"),
};
/** Options accepted by createGrepTool. */
export interface GrepToolOptions {
  /** Custom operations for grep. Default: local filesystem + ripgrep */
  operations?: GrepOperations;
}
/**
 * Build the grep tool for a given working directory.
 *
 * Spawns ripgrep with --json and streams its output line by line,
 * collecting match locations until the match limit is hit (at which point
 * the child is killed early). Context lines are re-read from the matched
 * files so remote GrepOperations can serve them. Final output is capped by
 * a byte limit, and over-long lines are truncated.
 */
export function createGrepTool(
  cwd: string,
  options?: GrepToolOptions,
): AgentTool<typeof grepSchema> {
  const customOps = options?.operations;
  return {
    name: "grep",
    label: "grep",
    description: `Search file contents for a pattern. Returns matching lines with file paths and line numbers. Respects .gitignore. Output is truncated to ${DEFAULT_LIMIT} matches or ${DEFAULT_MAX_BYTES / 1024}KB (whichever is hit first). Long lines are truncated to ${GREP_MAX_LINE_LENGTH} chars.`,
    parameters: grepSchema,
    execute: async (
      _toolCallId: string,
      {
        pattern,
        path: searchDir,
        glob,
        ignoreCase,
        literal,
        context,
        limit,
      }: {
        pattern: string;
        path?: string;
        glob?: string;
        ignoreCase?: boolean;
        literal?: boolean;
        context?: number;
        limit?: number;
      },
      signal?: AbortSignal,
    ) => {
      return new Promise((resolve, reject) => {
        if (signal?.aborted) {
          reject(new Error("Operation aborted"));
          return;
        }
        // Multiple event handlers below may try to settle the promise;
        // `settle` makes the first one win.
        let settled = false;
        const settle = (fn: () => void) => {
          if (!settled) {
            settled = true;
            fn();
          }
        };
        (async () => {
          try {
            const rgPath = await ensureTool("rg", true);
            if (!rgPath) {
              settle(() =>
                reject(
                  new Error(
                    "ripgrep (rg) is not available and could not be downloaded",
                  ),
                ),
              );
              return;
            }
            const searchPath = resolveToCwd(searchDir || ".", cwd);
            const ops = customOps ?? defaultGrepOperations;
            let isDirectory: boolean;
            try {
              isDirectory = await ops.isDirectory(searchPath);
            } catch (_err) {
              settle(() => reject(new Error(`Path not found: ${searchPath}`)));
              return;
            }
            const contextValue = context && context > 0 ? context : 0;
            // Clamp to at least one match so limit=0 cannot suppress output.
            const effectiveLimit = Math.max(1, limit ?? DEFAULT_LIMIT);
            // Render paths relative to the search dir when searching a
            // directory; fall back to the basename otherwise.
            const formatPath = (filePath: string): string => {
              if (isDirectory) {
                const relative = path.relative(searchPath, filePath);
                if (relative && !relative.startsWith("..")) {
                  return relative.replace(/\\/g, "/");
                }
              }
              return path.basename(filePath);
            };
            // Cache file contents so multiple matches in one file read once.
            const fileCache = new Map<string, string[]>();
            const getFileLines = async (
              filePath: string,
            ): Promise<string[]> => {
              let lines = fileCache.get(filePath);
              if (!lines) {
                try {
                  const content = await ops.readFile(filePath);
                  lines = content
                    .replace(/\r\n/g, "\n")
                    .replace(/\r/g, "\n")
                    .split("\n");
                } catch {
                  lines = [];
                }
                fileCache.set(filePath, lines);
              }
              return lines;
            };
            const args: string[] = [
              "--json",
              "--line-number",
              "--color=never",
              "--hidden",
            ];
            if (ignoreCase) {
              args.push("--ignore-case");
            }
            if (literal) {
              args.push("--fixed-strings");
            }
            if (glob) {
              args.push("--glob", glob);
            }
            args.push(pattern, searchPath);
            const child = spawn(rgPath, args, {
              stdio: ["ignore", "pipe", "pipe"],
            });
            const rl = createInterface({ input: child.stdout });
            let stderr = "";
            let matchCount = 0;
            let matchLimitReached = false;
            let linesTruncated = false;
            let aborted = false;
            let killedDueToLimit = false;
            const outputLines: string[] = [];
            // cleanup/stopChild/onAbort reference each other; all are
            // defined before any event handler can fire.
            const cleanup = () => {
              rl.close();
              signal?.removeEventListener("abort", onAbort);
            };
            const stopChild = (dueToLimit: boolean = false) => {
              if (!child.killed) {
                killedDueToLimit = dueToLimit;
                child.kill();
              }
            };
            const onAbort = () => {
              aborted = true;
              stopChild();
            };
            signal?.addEventListener("abort", onAbort, { once: true });
            child.stderr?.on("data", (chunk) => {
              stderr += chunk.toString();
            });
            // Render one match (plus surrounding context lines) as
            // "path:line: text" / "path-line- text" rows.
            const formatBlock = async (
              filePath: string,
              lineNumber: number,
            ): Promise<string[]> => {
              const relativePath = formatPath(filePath);
              const lines = await getFileLines(filePath);
              if (!lines.length) {
                return [`${relativePath}:${lineNumber}: (unable to read file)`];
              }
              const block: string[] = [];
              const start =
                contextValue > 0
                  ? Math.max(1, lineNumber - contextValue)
                  : lineNumber;
              const end =
                contextValue > 0
                  ? Math.min(lines.length, lineNumber + contextValue)
                  : lineNumber;
              for (let current = start; current <= end; current++) {
                const lineText = lines[current - 1] ?? "";
                const sanitized = lineText.replace(/\r/g, "");
                const isMatchLine = current === lineNumber;
                // Truncate long lines
                const { text: truncatedText, wasTruncated } =
                  truncateLine(sanitized);
                if (wasTruncated) {
                  linesTruncated = true;
                }
                if (isMatchLine) {
                  block.push(`${relativePath}:${current}: ${truncatedText}`);
                } else {
                  block.push(`${relativePath}-${current}- ${truncatedText}`);
                }
              }
              return block;
            };
            // Collect matches during streaming, format after
            const matches: Array<{ filePath: string; lineNumber: number }> = [];
            rl.on("line", (line) => {
              if (!line.trim() || matchCount >= effectiveLimit) {
                return;
              }
              let event: any;
              try {
                event = JSON.parse(line);
              } catch {
                return;
              }
              if (event.type === "match") {
                matchCount++;
                const filePath = event.data?.path?.text;
                const lineNumber = event.data?.line_number;
                if (filePath && typeof lineNumber === "number") {
                  matches.push({ filePath, lineNumber });
                }
                if (matchCount >= effectiveLimit) {
                  matchLimitReached = true;
                  stopChild(true);
                }
              }
            });
            child.on("error", (error) => {
              cleanup();
              settle(() =>
                reject(new Error(`Failed to run ripgrep: ${error.message}`)),
              );
            });
            child.on("close", async (code) => {
              cleanup();
              if (aborted) {
                settle(() => reject(new Error("Operation aborted")));
                return;
              }
              // ripgrep exits 1 for "no matches"; anything else (unless we
              // killed it for hitting the limit) is a real failure.
              if (!killedDueToLimit && code !== 0 && code !== 1) {
                const errorMsg =
                  stderr.trim() || `ripgrep exited with code ${code}`;
                settle(() => reject(new Error(errorMsg)));
                return;
              }
              if (matchCount === 0) {
                settle(() =>
                  resolve({
                    content: [{ type: "text", text: "No matches found" }],
                    details: undefined,
                  }),
                );
                return;
              }
              // Format matches (async to support remote file reading)
              for (const match of matches) {
                const block = await formatBlock(
                  match.filePath,
                  match.lineNumber,
                );
                outputLines.push(...block);
              }
              // Apply byte truncation (no line limit since we already have match limit)
              const rawOutput = outputLines.join("\n");
              const truncation = truncateHead(rawOutput, {
                maxLines: Number.MAX_SAFE_INTEGER,
              });
              let output = truncation.content;
              const details: GrepToolDetails = {};
              // Build notices
              const notices: string[] = [];
              if (matchLimitReached) {
                notices.push(
                  `${effectiveLimit} matches limit reached. Use limit=${effectiveLimit * 2} for more, or refine pattern`,
                );
                details.matchLimitReached = effectiveLimit;
              }
              if (truncation.truncated) {
                notices.push(`${formatSize(DEFAULT_MAX_BYTES)} limit reached`);
                details.truncation = truncation;
              }
              if (linesTruncated) {
                notices.push(
                  `Some lines truncated to ${GREP_MAX_LINE_LENGTH} chars. Use read tool to see full lines`,
                );
                details.linesTruncated = true;
              }
              if (notices.length > 0) {
                output += `\n\n[${notices.join(". ")}]`;
              }
              settle(() =>
                resolve({
                  content: [{ type: "text", text: output }],
                  details:
                    Object.keys(details).length > 0 ? details : undefined,
                }),
              );
            });
          } catch (err) {
            settle(() => reject(err as Error));
          }
        })();
      });
    },
  };
}
/**
 * Default grep tool bound to process.cwd() at module load time.
 * Kept for backwards compatibility; prefer createGrepTool(cwd) so the
 * working directory is explicit rather than captured at import.
 */
export const grepTool = createGrepTool(process.cwd());

View file

@ -0,0 +1,150 @@
export {
type BashOperations,
type BashSpawnContext,
type BashSpawnHook,
type BashToolDetails,
type BashToolInput,
type BashToolOptions,
bashTool,
createBashTool,
} from "./bash.js";
export {
createEditTool,
type EditOperations,
type EditToolDetails,
type EditToolInput,
type EditToolOptions,
editTool,
} from "./edit.js";
export {
createFindTool,
type FindOperations,
type FindToolDetails,
type FindToolInput,
type FindToolOptions,
findTool,
} from "./find.js";
export {
createGrepTool,
type GrepOperations,
type GrepToolDetails,
type GrepToolInput,
type GrepToolOptions,
grepTool,
} from "./grep.js";
export {
createLsTool,
type LsOperations,
type LsToolDetails,
type LsToolInput,
type LsToolOptions,
lsTool,
} from "./ls.js";
export {
createReadTool,
type ReadOperations,
type ReadToolDetails,
type ReadToolInput,
type ReadToolOptions,
readTool,
} from "./read.js";
export {
DEFAULT_MAX_BYTES,
DEFAULT_MAX_LINES,
formatSize,
type TruncationOptions,
type TruncationResult,
truncateHead,
truncateLine,
truncateTail,
} from "./truncate.js";
export {
createWriteTool,
type WriteOperations,
type WriteToolInput,
type WriteToolOptions,
writeTool,
} from "./write.js";
import type { AgentTool } from "@mariozechner/pi-agent-core";
import { type BashToolOptions, bashTool, createBashTool } from "./bash.js";
import { createEditTool, editTool } from "./edit.js";
import { createFindTool, findTool } from "./find.js";
import { createGrepTool, grepTool } from "./grep.js";
import { createLsTool, lsTool } from "./ls.js";
import { createReadTool, type ReadToolOptions, readTool } from "./read.js";
import { createWriteTool, writeTool } from "./write.js";
/** Tool type (AgentTool from pi-ai) */
export type Tool = AgentTool<any>;
// NOTE: the singleton tool arrays/objects below capture process.cwd() at
// module import time. Prefer the createCodingTools/createReadOnlyTools/
// createAllTools factories when the working directory should be explicit.
// Default tools for full access mode (using process.cwd())
export const codingTools: Tool[] = [readTool, bashTool, editTool, writeTool];
// Read-only tools for exploration without modification (using process.cwd())
export const readOnlyTools: Tool[] = [readTool, grepTool, findTool, lsTool];
// All available tools (using process.cwd())
export const allTools = {
  read: readTool,
  bash: bashTool,
  edit: editTool,
  write: writeTool,
  grep: grepTool,
  find: findTool,
  ls: lsTool,
};
// Union of the tool names above ("read" | "bash" | ... | "ls").
export type ToolName = keyof typeof allTools;
// Per-tool configuration accepted by the factory functions below.
export interface ToolsOptions {
  /** Options for the read tool */
  read?: ReadToolOptions;
  /** Options for the bash tool */
  bash?: BashToolOptions;
}
/**
 * Build the full-access coding tool set (read, bash, edit, write), all
 * rooted at the given working directory.
 *
 * @param cwd Working directory the tools resolve relative paths against.
 * @param options Optional per-tool configuration (read/bash).
 */
export function createCodingTools(cwd: string, options?: ToolsOptions): Tool[] {
  const tools: Tool[] = [
    createReadTool(cwd, options?.read),
    createBashTool(cwd, options?.bash),
  ];
  tools.push(createEditTool(cwd), createWriteTool(cwd));
  return tools;
}
/**
 * Build the read-only exploration tool set (read, grep, find, ls) rooted at
 * the given working directory. None of these tools modify the filesystem.
 *
 * @param cwd Working directory the tools resolve relative paths against.
 * @param options Optional per-tool configuration (only `read` applies here).
 */
export function createReadOnlyTools(
  cwd: string,
  options?: ToolsOptions,
): Tool[] {
  const readOnly: Tool[] = [createReadTool(cwd, options?.read)];
  readOnly.push(createGrepTool(cwd), createFindTool(cwd), createLsTool(cwd));
  return readOnly;
}
/**
 * Build every available tool, keyed by name, rooted at the given working
 * directory.
 *
 * @param cwd Working directory the tools resolve relative paths against.
 * @param options Optional per-tool configuration (read/bash).
 */
export function createAllTools(
  cwd: string,
  options?: ToolsOptions,
): Record<ToolName, Tool> {
  const read = createReadTool(cwd, options?.read);
  const bash = createBashTool(cwd, options?.bash);
  return {
    read,
    bash,
    edit: createEditTool(cwd),
    write: createWriteTool(cwd),
    grep: createGrepTool(cwd),
    find: createFindTool(cwd),
    ls: createLsTool(cwd),
  };
}

View file

@ -0,0 +1,197 @@
import type { AgentTool } from "@mariozechner/pi-agent-core";
import { type Static, Type } from "@sinclair/typebox";
import { existsSync, readdirSync, statSync } from "fs";
import nodePath from "path";
import { resolveToCwd } from "./path-utils.js";
import {
DEFAULT_MAX_BYTES,
formatSize,
type TruncationResult,
truncateHead,
} from "./truncate.js";
// JSON schema for the ls tool's parameters (both optional).
const lsSchema = Type.Object({
  path: Type.Optional(
    Type.String({
      description: "Directory to list (default: current directory)",
    }),
  ),
  limit: Type.Optional(
    Type.Number({
      description: "Maximum number of entries to return (default: 500)",
    }),
  ),
});
export type LsToolInput = Static<typeof lsSchema>;
// Default cap on the number of directory entries returned.
const DEFAULT_LIMIT = 500;
// Structured result metadata surfaced alongside the text output.
export interface LsToolDetails {
  /** Set when the byte limit cut the output short. */
  truncation?: TruncationResult;
  /** Set to the entry limit that was hit, when it was hit. */
  entryLimitReached?: number;
}
/**
 * Pluggable operations for the ls tool.
 * Override these to delegate directory listing to remote systems (e.g., SSH).
 */
export interface LsOperations {
  /** Check if path exists */
  exists: (absolutePath: string) => Promise<boolean> | boolean;
  /** Get file/directory stats. Throws if not found. */
  stat: (
    absolutePath: string,
  ) => Promise<{ isDirectory: () => boolean }> | { isDirectory: () => boolean };
  /** Read directory entries */
  readdir: (absolutePath: string) => Promise<string[]> | string[];
}
// Local-filesystem implementations used when no override is supplied.
const defaultLsOperations: LsOperations = {
  exists: existsSync,
  stat: statSync,
  readdir: readdirSync,
};
export interface LsToolOptions {
  /** Custom operations for directory listing. Default: local filesystem */
  operations?: LsOperations;
}
/**
 * Build an `ls` tool rooted at `cwd`.
 *
 * Entries are sorted case-insensitively, directories get a '/' suffix, and
 * output is bounded by an entry limit and a byte limit (whichever hits first).
 *
 * Fix over the previous version: the abort listener is now detached on every
 * settlement path (the early reject paths for a missing path, a non-directory,
 * or an unreadable directory previously leaked it), and the per-entry stat
 * loop stops as soon as the signal aborts instead of continuing to do
 * possibly-remote stat calls.
 */
export function createLsTool(
  cwd: string,
  options?: LsToolOptions,
): AgentTool<typeof lsSchema> {
  const ops = options?.operations ?? defaultLsOperations;
  return {
    name: "ls",
    label: "ls",
    description: `List directory contents. Returns entries sorted alphabetically, with '/' suffix for directories. Includes dotfiles. Output is truncated to ${DEFAULT_LIMIT} entries or ${DEFAULT_MAX_BYTES / 1024}KB (whichever is hit first).`,
    parameters: lsSchema,
    execute: async (
      _toolCallId: string,
      { path, limit }: { path?: string; limit?: number },
      signal?: AbortSignal,
    ) => {
      return new Promise((resolve, reject) => {
        if (signal?.aborted) {
          reject(new Error("Operation aborted"));
          return;
        }
        // Settle at most once, and always detach the abort listener when we
        // do — including the early error paths, which previously leaked it.
        let settled = false;
        let aborted = false;
        const settle = (fn: () => void) => {
          if (settled) return;
          settled = true;
          signal?.removeEventListener("abort", onAbort);
          fn();
        };
        const onAbort = () => {
          aborted = true;
          settle(() => reject(new Error("Operation aborted")));
        };
        signal?.addEventListener("abort", onAbort, { once: true });
        (async () => {
          try {
            const dirPath = resolveToCwd(path || ".", cwd);
            const effectiveLimit = limit ?? DEFAULT_LIMIT;
            // Validate the target: must exist and be a directory.
            if (!(await ops.exists(dirPath))) {
              settle(() => reject(new Error(`Path not found: ${dirPath}`)));
              return;
            }
            const stat = await ops.stat(dirPath);
            if (!stat.isDirectory()) {
              settle(() => reject(new Error(`Not a directory: ${dirPath}`)));
              return;
            }
            // Read directory entries
            let entries: string[];
            try {
              entries = await ops.readdir(dirPath);
            } catch (e: any) {
              settle(() =>
                reject(new Error(`Cannot read directory: ${e.message}`)),
              );
              return;
            }
            // Sort alphabetically (case-insensitive)
            entries.sort((a, b) =>
              a.toLowerCase().localeCompare(b.toLowerCase()),
            );
            // Format entries with directory indicators ('/' suffix).
            const results: string[] = [];
            let entryLimitReached = false;
            for (const entry of entries) {
              // Stop doing (possibly remote) stat work once aborted; the
              // promise has already been rejected by onAbort.
              if (aborted) {
                return;
              }
              if (results.length >= effectiveLimit) {
                entryLimitReached = true;
                break;
              }
              const fullPath = nodePath.join(dirPath, entry);
              let suffix = "";
              try {
                const entryStat = await ops.stat(fullPath);
                if (entryStat.isDirectory()) {
                  suffix = "/";
                }
              } catch {
                // Skip entries we can't stat
                continue;
              }
              results.push(entry + suffix);
            }
            if (results.length === 0) {
              settle(() =>
                resolve({
                  content: [{ type: "text", text: "(empty directory)" }],
                  details: undefined,
                }),
              );
              return;
            }
            // Apply byte truncation only; the entry limit already bounds
            // the line count.
            const rawOutput = results.join("\n");
            const truncation = truncateHead(rawOutput, {
              maxLines: Number.MAX_SAFE_INTEGER,
            });
            let output = truncation.content;
            const details: LsToolDetails = {};
            // Build notices appended to the text output.
            const notices: string[] = [];
            if (entryLimitReached) {
              notices.push(
                `${effectiveLimit} entries limit reached. Use limit=${effectiveLimit * 2} for more`,
              );
              details.entryLimitReached = effectiveLimit;
            }
            if (truncation.truncated) {
              notices.push(`${formatSize(DEFAULT_MAX_BYTES)} limit reached`);
              details.truncation = truncation;
            }
            if (notices.length > 0) {
              output += `\n\n[${notices.join(". ")}]`;
            }
            settle(() =>
              resolve({
                content: [{ type: "text", text: output }],
                details: Object.keys(details).length > 0 ? details : undefined,
              }),
            );
          } catch (e: any) {
            settle(() => reject(e));
          }
        })();
      });
    },
  };
}
/**
 * Default ls tool bound to process.cwd() at module load time.
 * Kept for backwards compatibility; prefer createLsTool(cwd).
 */
export const lsTool = createLsTool(process.cwd());

View file

@ -0,0 +1,94 @@
import { accessSync, constants } from "node:fs";
import * as os from "node:os";
import { isAbsolute, resolve as resolvePath } from "node:path";
// Unicode space variants that commonly leak into pasted paths:
// NBSP, the en/em/thin space range, narrow NBSP, medium mathematical
// space, and the ideographic (CJK) space.
const UNICODE_SPACES = /[\u00A0\u2000-\u200A\u202F\u205F\u3000]/g;
// U+202F — the character macOS inserts before AM/PM in screenshot filenames.
const NARROW_NO_BREAK_SPACE = "\u202F";
// Replace all Unicode space variants above with a plain ASCII space.
function normalizeUnicodeSpaces(str: string): string {
  return str.replace(UNICODE_SPACES, " ");
}
// macOS screenshot filenames contain a narrow no-break space before "AM."
// or "PM."; users typically type a plain space. Produce the on-disk spelling.
function tryMacOSScreenshotPath(filePath: string): string {
  const meridiemPattern = / (AM|PM)\./g;
  return filePath.replace(
    meridiemPattern,
    (_match, meridiem) => `${NARROW_NO_BREAK_SPACE}${meridiem}.`,
  );
}
// macOS stores filenames in NFD (decomposed) Unicode form; user input is
// typically NFC. Produce the decomposed spelling so lookups can match disk.
function tryNFDVariant(filePath: string): string {
  const decomposed = filePath.normalize("NFD");
  return decomposed;
}
// macOS uses U+2019 (right single quotation mark) in names like
// "Capture d'écran" while users type the straight apostrophe U+0027.
// Produce the curly-quote spelling.
function tryCurlyQuoteVariant(filePath: string): string {
  return filePath.split("'").join("\u2019");
}
// True when the path is visible to this process (F_OK access check);
// any access error is treated as "does not exist".
function fileExists(filePath: string): boolean {
  try {
    accessSync(filePath, constants.F_OK);
  } catch {
    return false;
  }
  return true;
}
// Strip a single leading '@' (chat-style file mention marker), if present.
function normalizeAtPrefix(filePath: string): string {
  if (!filePath.startsWith("@")) {
    return filePath;
  }
  return filePath.slice(1);
}
/**
 * Normalize a user-supplied path: strip a leading '@', collapse Unicode
 * space variants to ASCII spaces, and expand a leading '~' to the home
 * directory. Does not resolve relative paths.
 */
export function expandPath(filePath: string): string {
  const cleaned = normalizeUnicodeSpaces(normalizeAtPrefix(filePath));
  if (cleaned === "~") {
    return os.homedir();
  }
  return cleaned.startsWith("~/")
    ? os.homedir() + cleaned.slice(1)
    : cleaned;
}
/**
 * Resolve a path relative to the given cwd.
 * Handles '~' expansion and '@' stripping via expandPath; absolute paths
 * are returned as-is after expansion.
 */
export function resolveToCwd(filePath: string, cwd: string): string {
  const expanded = expandPath(filePath);
  return isAbsolute(expanded) ? expanded : resolvePath(cwd, expanded);
}
export function resolveReadPath(filePath: string, cwd: string): string {
const resolved = resolveToCwd(filePath, cwd);
if (fileExists(resolved)) {
return resolved;
}
// Try macOS AM/PM variant (narrow no-break space before AM/PM)
const amPmVariant = tryMacOSScreenshotPath(resolved);
if (amPmVariant !== resolved && fileExists(amPmVariant)) {
return amPmVariant;
}
// Try NFD variant (macOS stores filenames in NFD form)
const nfdVariant = tryNFDVariant(resolved);
if (nfdVariant !== resolved && fileExists(nfdVariant)) {
return nfdVariant;
}
// Try curly quote variant (macOS uses U+2019 in screenshot names)
const curlyVariant = tryCurlyQuoteVariant(resolved);
if (curlyVariant !== resolved && fileExists(curlyVariant)) {
return curlyVariant;
}
// Try combined NFD + curly quote (for French macOS screenshots like "Capture d'écran")
const nfdCurlyVariant = tryCurlyQuoteVariant(nfdVariant);
if (nfdCurlyVariant !== resolved && fileExists(nfdCurlyVariant)) {
return nfdCurlyVariant;
}
return resolved;
}

View file

@ -0,0 +1,265 @@
import type { AgentTool } from "@mariozechner/pi-agent-core";
import type { ImageContent, TextContent } from "@mariozechner/pi-ai";
import { type Static, Type } from "@sinclair/typebox";
import { constants } from "fs";
import { access as fsAccess, readFile as fsReadFile } from "fs/promises";
import { formatDimensionNote, resizeImage } from "../../utils/image-resize.js";
import { detectSupportedImageMimeTypeFromFile } from "../../utils/mime.js";
import { resolveReadPath } from "./path-utils.js";
import {
DEFAULT_MAX_BYTES,
DEFAULT_MAX_LINES,
formatSize,
type TruncationResult,
truncateHead,
} from "./truncate.js";
// JSON schema for the read tool's parameters.
const readSchema = Type.Object({
  path: Type.String({
    description: "Path to the file to read (relative or absolute)",
  }),
  offset: Type.Optional(
    Type.Number({
      description: "Line number to start reading from (1-indexed)",
    }),
  ),
  limit: Type.Optional(
    Type.Number({ description: "Maximum number of lines to read" }),
  ),
});
export type ReadToolInput = Static<typeof readSchema>;
// Structured result metadata surfaced alongside the text output.
export interface ReadToolDetails {
  /** Present when the file content had to be truncated. */
  truncation?: TruncationResult;
}
/**
 * Pluggable operations for the read tool.
 * Override these to delegate file reading to remote systems (e.g., SSH).
 */
export interface ReadOperations {
  /** Read file contents as a Buffer */
  readFile: (absolutePath: string) => Promise<Buffer>;
  /** Check if file is readable (throw if not) */
  access: (absolutePath: string) => Promise<void>;
  /** Detect image MIME type, return null/undefined for non-images */
  detectImageMimeType?: (
    absolutePath: string,
  ) => Promise<string | null | undefined>;
}
// Local-filesystem implementations used when no override is supplied.
const defaultReadOperations: ReadOperations = {
  readFile: (path) => fsReadFile(path),
  access: (path) => fsAccess(path, constants.R_OK),
  detectImageMimeType: detectSupportedImageMimeTypeFromFile,
};
export interface ReadToolOptions {
  /** Whether to auto-resize images to 2000x2000 max. Default: true */
  autoResizeImages?: boolean;
  /** Custom operations for file reading. Default: local filesystem */
  operations?: ReadOperations;
}
/**
 * Build a `read` tool rooted at `cwd`.
 *
 * Images (detected via ops.detectImageMimeType) are returned as base64
 * attachments, optionally resized; text files are returned with
 * offset/limit windowing and line/byte truncation notices that tell the
 * model how to continue reading.
 */
export function createReadTool(
  cwd: string,
  options?: ReadToolOptions,
): AgentTool<typeof readSchema> {
  const autoResizeImages = options?.autoResizeImages ?? true;
  const ops = options?.operations ?? defaultReadOperations;
  return {
    name: "read",
    label: "read",
    description: `Read the contents of a file. Supports text files and images (jpg, png, gif, webp). Images are sent as attachments. For text files, output is truncated to ${DEFAULT_MAX_LINES} lines or ${DEFAULT_MAX_BYTES / 1024}KB (whichever is hit first). Use offset/limit for large files. When you need the full file, continue with offset until complete.`,
    parameters: readSchema,
    execute: async (
      _toolCallId: string,
      {
        path,
        offset,
        limit,
      }: { path: string; offset?: number; limit?: number },
      signal?: AbortSignal,
    ) => {
      // Resolves macOS filename quirks (NFD, narrow NBSP, curly quotes).
      const absolutePath = resolveReadPath(path, cwd);
      return new Promise<{
        content: (TextContent | ImageContent)[];
        details: ReadToolDetails | undefined;
      }>((resolve, reject) => {
        // Check if already aborted
        if (signal?.aborted) {
          reject(new Error("Operation aborted"));
          return;
        }
        let aborted = false;
        // Set up abort handler; rejects immediately, and the flag makes the
        // async body below bail out without settling a second time.
        const onAbort = () => {
          aborted = true;
          reject(new Error("Operation aborted"));
        };
        if (signal) {
          signal.addEventListener("abort", onAbort, { once: true });
        }
        // Perform the read operation
        (async () => {
          try {
            // Check if file exists
            await ops.access(absolutePath);
            // Check if aborted before reading
            if (aborted) {
              return;
            }
            const mimeType = ops.detectImageMimeType
              ? await ops.detectImageMimeType(absolutePath)
              : undefined;
            // Read the file based on type
            let content: (TextContent | ImageContent)[];
            let details: ReadToolDetails | undefined;
            if (mimeType) {
              // Read as image (binary)
              const buffer = await ops.readFile(absolutePath);
              const base64 = buffer.toString("base64");
              if (autoResizeImages) {
                // Resize image if needed
                const resized = await resizeImage({
                  type: "image",
                  data: base64,
                  mimeType,
                });
                const dimensionNote = formatDimensionNote(resized);
                let textNote = `Read image file [${resized.mimeType}]`;
                if (dimensionNote) {
                  textNote += `\n${dimensionNote}`;
                }
                content = [
                  { type: "text", text: textNote },
                  {
                    type: "image",
                    data: resized.data,
                    mimeType: resized.mimeType,
                  },
                ];
              } else {
                const textNote = `Read image file [${mimeType}]`;
                content = [
                  { type: "text", text: textNote },
                  { type: "image", data: base64, mimeType },
                ];
              }
            } else {
              // Read as text
              const buffer = await ops.readFile(absolutePath);
              const textContent = buffer.toString("utf-8");
              const allLines = textContent.split("\n");
              const totalFileLines = allLines.length;
              // Apply offset if specified (1-indexed to 0-indexed)
              const startLine = offset ? Math.max(0, offset - 1) : 0;
              const startLineDisplay = startLine + 1; // For display (1-indexed)
              // Check if offset is out of bounds
              if (startLine >= allLines.length) {
                throw new Error(
                  `Offset ${offset} is beyond end of file (${allLines.length} lines total)`,
                );
              }
              // If limit is specified by user, use it; otherwise we'll let truncateHead decide
              let selectedContent: string;
              let userLimitedLines: number | undefined;
              if (limit !== undefined) {
                const endLine = Math.min(startLine + limit, allLines.length);
                selectedContent = allLines.slice(startLine, endLine).join("\n");
                userLimitedLines = endLine - startLine;
              } else {
                selectedContent = allLines.slice(startLine).join("\n");
              }
              // Apply truncation (respects both line and byte limits)
              const truncation = truncateHead(selectedContent);
              let outputText: string;
              if (truncation.firstLineExceedsLimit) {
                // First line at offset exceeds the byte limit - tell model to use bash
                const firstLineSize = formatSize(
                  Buffer.byteLength(allLines[startLine], "utf-8"),
                );
                outputText = `[Line ${startLineDisplay} is ${firstLineSize}, exceeds ${formatSize(DEFAULT_MAX_BYTES)} limit. Use bash: sed -n '${startLineDisplay}p' ${path} | head -c ${DEFAULT_MAX_BYTES}]`;
                details = { truncation };
              } else if (truncation.truncated) {
                // Truncation occurred - build actionable notice with the
                // 1-indexed line range shown and the offset to continue from.
                const endLineDisplay =
                  startLineDisplay + truncation.outputLines - 1;
                const nextOffset = endLineDisplay + 1;
                outputText = truncation.content;
                if (truncation.truncatedBy === "lines") {
                  outputText += `\n\n[Showing lines ${startLineDisplay}-${endLineDisplay} of ${totalFileLines}. Use offset=${nextOffset} to continue.]`;
                } else {
                  outputText += `\n\n[Showing lines ${startLineDisplay}-${endLineDisplay} of ${totalFileLines} (${formatSize(DEFAULT_MAX_BYTES)} limit). Use offset=${nextOffset} to continue.]`;
                }
                details = { truncation };
              } else if (
                userLimitedLines !== undefined &&
                startLine + userLimitedLines < allLines.length
              ) {
                // User specified limit, there's more content, but no truncation
                const remaining =
                  allLines.length - (startLine + userLimitedLines);
                const nextOffset = startLine + userLimitedLines + 1;
                outputText = truncation.content;
                outputText += `\n\n[${remaining} more lines in file. Use offset=${nextOffset} to continue.]`;
              } else {
                // No truncation, no user limit exceeded
                outputText = truncation.content;
              }
              content = [{ type: "text", text: outputText }];
            }
            // Check if aborted after reading
            if (aborted) {
              return;
            }
            // Clean up abort handler
            if (signal) {
              signal.removeEventListener("abort", onAbort);
            }
            resolve({ content, details });
          } catch (error: any) {
            // Clean up abort handler
            if (signal) {
              signal.removeEventListener("abort", onAbort);
            }
            if (!aborted) {
              reject(error);
            }
          }
        })();
      });
    },
  };
}
/**
 * Default read tool bound to process.cwd() at module load time.
 * Kept for backwards compatibility; prefer createReadTool(cwd).
 */
export const readTool = createReadTool(process.cwd());

View file

@ -0,0 +1,279 @@
/**
* Shared truncation utilities for tool outputs.
*
* Truncation is based on two independent limits - whichever is hit first wins:
* - Line limit (default: 2000 lines)
* - Byte limit (default: 50KB)
*
* Never returns partial lines (except bash tail truncation edge case).
*/
// The two limits below are independent; truncation stops at whichever is
// hit first (see file header).
export const DEFAULT_MAX_LINES = 2000;
export const DEFAULT_MAX_BYTES = 50 * 1024; // 50KB
export const GREP_MAX_LINE_LENGTH = 500; // Max chars per grep match line
// Rich description of a truncation outcome, returned by truncateHead/Tail.
export interface TruncationResult {
  /** The truncated content */
  content: string;
  /** Whether truncation occurred */
  truncated: boolean;
  /** Which limit was hit: "lines", "bytes", or null if not truncated */
  truncatedBy: "lines" | "bytes" | null;
  /** Total number of lines in the original content */
  totalLines: number;
  /** Total number of bytes in the original content */
  totalBytes: number;
  /** Number of complete lines in the truncated output */
  outputLines: number;
  /** Number of bytes in the truncated output */
  outputBytes: number;
  /** Whether the last line was partially truncated (only for tail truncation edge case) */
  lastLinePartial: boolean;
  /** Whether the first line exceeded the byte limit (for head truncation) */
  firstLineExceedsLimit: boolean;
  /** The max lines limit that was applied */
  maxLines: number;
  /** The max bytes limit that was applied */
  maxBytes: number;
}
// Optional overrides for the default limits.
export interface TruncationOptions {
  /** Maximum number of lines (default: 2000) */
  maxLines?: number;
  /** Maximum number of bytes (default: 50KB) */
  maxBytes?: number;
}
/**
 * Format a byte count as a human-readable size: B below 1KB,
 * KB below 1MB (one decimal), MB otherwise (one decimal).
 */
export function formatSize(bytes: number): string {
  const KB = 1024;
  const MB = KB * 1024;
  if (bytes >= MB) {
    return `${(bytes / MB).toFixed(1)}MB`;
  }
  if (bytes >= KB) {
    return `${(bytes / KB).toFixed(1)}KB`;
  }
  return `${bytes}B`;
}
/**
 * Keep the first complete lines of `content` that fit within both the line
 * and byte limits (whichever is hit first). Suitable for file reads where
 * the beginning matters.
 *
 * Never emits a partial line: when even the first line exceeds the byte
 * budget, the result has empty content and firstLineExceedsLimit=true.
 */
export function truncateHead(
  content: string,
  options: TruncationOptions = {},
): TruncationResult {
  const maxLines = options.maxLines ?? DEFAULT_MAX_LINES;
  const maxBytes = options.maxBytes ?? DEFAULT_MAX_BYTES;
  const lines = content.split("\n");
  const totalLines = lines.length;
  const totalBytes = Buffer.byteLength(content, "utf-8");
  // Fields shared by every result shape below.
  const base = {
    totalLines,
    totalBytes,
    lastLinePartial: false,
    firstLineExceedsLimit: false,
    maxLines,
    maxBytes,
  };
  // Fast path: everything already fits.
  if (totalLines <= maxLines && totalBytes <= maxBytes) {
    return {
      ...base,
      content,
      truncated: false,
      truncatedBy: null,
      outputLines: totalLines,
      outputBytes: totalBytes,
    };
  }
  // A first line larger than the byte budget cannot be emitted whole.
  if (Buffer.byteLength(lines[0], "utf-8") > maxBytes) {
    return {
      ...base,
      content: "",
      truncated: true,
      truncatedBy: "bytes",
      outputLines: 0,
      outputBytes: 0,
      firstLineExceedsLimit: true,
    };
  }
  // Accumulate whole lines from the top until a limit is hit.
  const kept: string[] = [];
  let keptBytes = 0;
  let truncatedBy: "lines" | "bytes" = "lines";
  for (const line of lines) {
    if (kept.length >= maxLines) {
      break;
    }
    // Each line after the first costs one extra byte for its joining "\n".
    const cost = Buffer.byteLength(line, "utf-8") + (kept.length > 0 ? 1 : 0);
    if (keptBytes + cost > maxBytes) {
      truncatedBy = "bytes";
      break;
    }
    kept.push(line);
    keptBytes += cost;
  }
  const outputContent = kept.join("\n");
  return {
    ...base,
    content: outputContent,
    truncated: true,
    truncatedBy,
    outputLines: kept.length,
    outputBytes: Buffer.byteLength(outputContent, "utf-8"),
  };
}
/**
 * Keep the last complete lines of `content` that fit within both the line
 * and byte limits (whichever is hit first). Suitable for bash output where
 * the end (errors, final results) matters most.
 *
 * Edge case: when not even one whole line fits in the byte budget, the tail
 * end of that line is kept and lastLinePartial is set.
 */
export function truncateTail(
  content: string,
  options: TruncationOptions = {},
): TruncationResult {
  const maxLines = options.maxLines ?? DEFAULT_MAX_LINES;
  const maxBytes = options.maxBytes ?? DEFAULT_MAX_BYTES;
  const lines = content.split("\n");
  const totalLines = lines.length;
  const totalBytes = Buffer.byteLength(content, "utf-8");
  // Fast path: everything already fits.
  if (totalLines <= maxLines && totalBytes <= maxBytes) {
    return {
      content,
      truncated: false,
      truncatedBy: null,
      totalLines,
      totalBytes,
      outputLines: totalLines,
      outputBytes: totalBytes,
      lastLinePartial: false,
      firstLineExceedsLimit: false,
      maxLines,
      maxBytes,
    };
  }
  // Walk backwards from the last line, prepending while limits allow.
  const kept: string[] = [];
  let keptBytes = 0;
  let truncatedBy: "lines" | "bytes" = "lines";
  let lastLinePartial = false;
  for (let i = lines.length - 1; i >= 0; i--) {
    if (kept.length >= maxLines) {
      break;
    }
    const line = lines[i];
    // Each line after the first kept one costs one extra byte for "\n".
    const cost = Buffer.byteLength(line, "utf-8") + (kept.length > 0 ? 1 : 0);
    if (keptBytes + cost > maxBytes) {
      truncatedBy = "bytes";
      // Not even one whole line fits: keep the tail end of this line.
      if (kept.length === 0) {
        const partial = truncateStringToBytesFromEnd(line, maxBytes);
        kept.unshift(partial);
        keptBytes = Buffer.byteLength(partial, "utf-8");
        lastLinePartial = true;
      }
      break;
    }
    kept.unshift(line);
    keptBytes += cost;
  }
  const outputContent = kept.join("\n");
  return {
    content: outputContent,
    truncated: true,
    truncatedBy,
    totalLines,
    totalBytes,
    outputLines: kept.length,
    outputBytes: Buffer.byteLength(outputContent, "utf-8"),
    lastLinePartial,
    firstLineExceedsLimit: false,
    maxLines,
    maxBytes,
  };
}
/**
 * Return the longest UTF-8-valid suffix of `str` that fits in `maxBytes`.
 * Skips forward past UTF-8 continuation bytes (10xxxxxx) so no multi-byte
 * character is cut in half.
 */
function truncateStringToBytesFromEnd(str: string, maxBytes: number): string {
  const bytes = Buffer.from(str, "utf-8");
  if (bytes.length <= maxBytes) {
    return str;
  }
  // Start maxBytes from the end, then advance to the first lead byte.
  let offset = bytes.length - maxBytes;
  while (offset < bytes.length && (bytes[offset] & 0xc0) === 0x80) {
    offset++;
  }
  return bytes.subarray(offset).toString("utf-8");
}
/**
 * Cap a single line at `maxChars` characters, appending a
 * "... [truncated]" marker when it was shortened. Used for grep match
 * lines so one pathological line cannot flood the output.
 */
export function truncateLine(
  line: string,
  maxChars: number = GREP_MAX_LINE_LENGTH,
): { text: string; wasTruncated: boolean } {
  const wasTruncated = line.length > maxChars;
  const text = wasTruncated
    ? `${line.slice(0, maxChars)}... [truncated]`
    : line;
  return { text, wasTruncated };
}

View file

@ -0,0 +1,129 @@
import type { AgentTool } from "@mariozechner/pi-agent-core";
import { type Static, Type } from "@sinclair/typebox";
import { mkdir as fsMkdir, writeFile as fsWriteFile } from "fs/promises";
import { dirname } from "path";
import { resolveToCwd } from "./path-utils.js";
// JSON schema for the write tool's parameters.
const writeSchema = Type.Object({
  path: Type.String({
    description: "Path to the file to write (relative or absolute)",
  }),
  content: Type.String({ description: "Content to write to the file" }),
});
export type WriteToolInput = Static<typeof writeSchema>;
/**
 * Pluggable operations for the write tool.
 * Override these to delegate file writing to remote systems (e.g., SSH).
 */
export interface WriteOperations {
  /** Write content to a file */
  writeFile: (absolutePath: string, content: string) => Promise<void>;
  /** Create directory (recursively) */
  mkdir: (dir: string) => Promise<void>;
}
// Local-filesystem implementations used when no override is supplied.
const defaultWriteOperations: WriteOperations = {
  writeFile: (path, content) => fsWriteFile(path, content, "utf-8"),
  mkdir: (dir) => fsMkdir(dir, { recursive: true }).then(() => {}),
};
export interface WriteToolOptions {
  /** Custom operations for file writing. Default: local filesystem */
  operations?: WriteOperations;
}
/**
 * Build a `write` tool rooted at `cwd`.
 *
 * Creates the file if missing, overwrites if present, and creates parent
 * directories automatically.
 *
 * Fix over the previous version: the success message reported
 * `content.length` (UTF-16 code units) as "bytes", which is wrong for any
 * non-ASCII content; it now reports the actual UTF-8 byte count.
 */
export function createWriteTool(
  cwd: string,
  options?: WriteToolOptions,
): AgentTool<typeof writeSchema> {
  const ops = options?.operations ?? defaultWriteOperations;
  return {
    name: "write",
    label: "write",
    description:
      "Write content to a file. Creates the file if it doesn't exist, overwrites if it does. Automatically creates parent directories.",
    parameters: writeSchema,
    execute: async (
      _toolCallId: string,
      { path, content }: { path: string; content: string },
      signal?: AbortSignal,
    ) => {
      const absolutePath = resolveToCwd(path, cwd);
      const dir = dirname(absolutePath);
      return new Promise<{
        content: Array<{ type: "text"; text: string }>;
        details: undefined;
      }>((resolve, reject) => {
        // Check if already aborted
        if (signal?.aborted) {
          reject(new Error("Operation aborted"));
          return;
        }
        let aborted = false;
        // Abort rejects immediately; the flag makes the async body below
        // bail out without settling a second time.
        const onAbort = () => {
          aborted = true;
          reject(new Error("Operation aborted"));
        };
        const cleanup = () => {
          if (signal) {
            signal.removeEventListener("abort", onAbort);
          }
        };
        if (signal) {
          signal.addEventListener("abort", onAbort, { once: true });
        }
        // Perform the write operation
        (async () => {
          try {
            // Create parent directories if needed
            await ops.mkdir(dir);
            if (aborted) {
              return;
            }
            // Write the file
            await ops.writeFile(absolutePath, content);
            if (aborted) {
              return;
            }
            cleanup();
            // Report UTF-8 bytes, not string length: for non-ASCII content
            // the number of code units differs from the bytes on disk.
            const bytesWritten = Buffer.byteLength(content, "utf-8");
            resolve({
              content: [
                {
                  type: "text",
                  text: `Successfully wrote ${bytesWritten} bytes to ${path}`,
                },
              ],
              details: undefined,
            });
          } catch (error: any) {
            cleanup();
            if (!aborted) {
              reject(error);
            }
          }
        })();
      });
    },
  };
}
/**
 * Default write tool bound to process.cwd() at module load time.
 * Kept for backwards compatibility; prefer createWriteTool(cwd).
 */
export const writeTool = createWriteTool(process.cwd());

View file

@ -0,0 +1,205 @@
import { randomUUID } from "node:crypto";
import type { ServerResponse } from "node:http";
import type { AgentSessionEvent } from "./agent-session.js";
/**
 * Emit one Vercel AI SDK v5+ SSE chunk on the response.
 * Objects are JSON-serialized; strings (e.g. the terminal "[DONE]"
 * sentinel) are written verbatim. Format: `data: <payload>\n\n`.
 * No-op once the response has ended.
 */
function writeChunk(response: ServerResponse, chunk: object | string): void {
  if (response.writableEnded) {
    return;
  }
  const payload = typeof chunk === "string" ? chunk : JSON.stringify(chunk);
  response.write("data: " + payload + "\n\n");
}
/**
 * Extract the user's text from a chat request body.
 *
 * Accepts, in priority order: the simple gateway format ({ text }), the
 * convenience format ({ prompt }), and the Vercel AI SDK useChat format
 * ({ messages }), where the newest user message wins — its first v5+ text
 * part, or its v4 content string. Returns null when nothing usable is found.
 */
export function extractUserText(body: Record<string, unknown>): string | null {
  const { text, prompt, messages } = body;
  // Simple gateway format: { text: "..." }
  if (typeof text === "string" && text.trim()) {
    return text;
  }
  // Convenience format: { prompt: "..." }
  if (typeof prompt === "string" && prompt.trim()) {
    return prompt;
  }
  if (!Array.isArray(messages)) {
    return null;
  }
  // Scan messages newest-first for a user message with usable text.
  for (let i = messages.length - 1; i >= 0; i--) {
    const msg = messages[i] as Record<string, unknown>;
    if (msg.role !== "user") {
      continue;
    }
    // v5+ format: parts array — take the first text part.
    if (Array.isArray(msg.parts)) {
      for (const part of msg.parts as Array<Record<string, unknown>>) {
        if (part.type === "text" && typeof part.text === "string") {
          return part.text;
        }
      }
    }
    // v4 format: plain content string.
    if (typeof msg.content === "string" && msg.content.trim()) {
      return msg.content;
    }
  }
  return null;
}
/**
 * Create an AgentSessionEvent listener that translates events to Vercel AI
 * SDK v5+ SSE chunks and writes them to the HTTP response.
 *
 * @param response HTTP response the SSE chunks are written into.
 * @param messageId Optional stable id for the streamed message; a random
 *   UUID is generated when omitted.
 * @returns The listener function. The caller is responsible for
 *   subscribing/unsubscribing.
 */
export function createVercelStreamListener(
  response: ServerResponse,
  messageId?: string,
): (event: AgentSessionEvent) => void {
  // Gate: only forward events within a single prompt's agent_start -> agent_end lifecycle.
  // handleChat now subscribes this listener immediately before the queued prompt starts,
  // so these guards only need to bound the stream to that prompt's event span.
  let active = false;
  const msgId = messageId ?? randomUUID();
  return (event: AgentSessionEvent) => {
    // Never write to a response that has already been finished.
    if (response.writableEnded) return;
    // Activate on our agent_start, deactivate on agent_end
    if (event.type === "agent_start") {
      if (!active) {
        active = true;
        // "start" opens the streamed message on the client side.
        writeChunk(response, { type: "start", messageId: msgId });
      }
      return;
    }
    if (event.type === "agent_end") {
      // The terminal finish/[DONE] frames are written elsewhere
      // (see finishVercelStream); here we only close the gate.
      active = false;
      return;
    }
    // Drop events that don't belong to our message
    if (!active) return;
    switch (event.type) {
      case "turn_start":
        // Each agent turn maps to one SSE "step".
        writeChunk(response, { type: "start-step" });
        return;
      case "message_update": {
        // Streaming deltas for the in-progress assistant message: plain
        // text, tool-call input, and reasoning ("thinking") content.
        // Chunk ids are derived from the content index so the client can
        // correlate each start/delta/end triplet.
        const inner = event.assistantMessageEvent;
        switch (inner.type) {
          case "text_start":
            writeChunk(response, {
              type: "text-start",
              id: `text_${inner.contentIndex}`,
            });
            return;
          case "text_delta":
            writeChunk(response, {
              type: "text-delta",
              id: `text_${inner.contentIndex}`,
              delta: inner.delta,
            });
            return;
          case "text_end":
            writeChunk(response, {
              type: "text-end",
              id: `text_${inner.contentIndex}`,
            });
            return;
          case "toolcall_start": {
            // The tool-call id/name live on the partial message content,
            // not on the event itself, so look them up by content index.
            const content = inner.partial.content[inner.contentIndex];
            if (content?.type === "toolCall") {
              writeChunk(response, {
                type: "tool-input-start",
                toolCallId: content.id,
                toolName: content.name,
              });
            }
            return;
          }
          case "toolcall_delta": {
            const content = inner.partial.content[inner.contentIndex];
            if (content?.type === "toolCall") {
              writeChunk(response, {
                type: "tool-input-delta",
                toolCallId: content.id,
                inputTextDelta: inner.delta,
              });
            }
            return;
          }
          case "toolcall_end":
            // Arguments are complete: emit the full tool input.
            writeChunk(response, {
              type: "tool-input-available",
              toolCallId: inner.toolCall.id,
              toolName: inner.toolCall.name,
              input: inner.toolCall.arguments,
            });
            return;
          case "thinking_start":
            writeChunk(response, {
              type: "reasoning-start",
              id: `reasoning_${inner.contentIndex}`,
            });
            return;
          case "thinking_delta":
            writeChunk(response, {
              type: "reasoning-delta",
              id: `reasoning_${inner.contentIndex}`,
              delta: inner.delta,
            });
            return;
          case "thinking_end":
            writeChunk(response, {
              type: "reasoning-end",
              id: `reasoning_${inner.contentIndex}`,
            });
            return;
        }
        return;
      }
      case "turn_end":
        writeChunk(response, { type: "finish-step" });
        return;
      case "tool_execution_end":
        // Tool finished executing: forward its output to the client.
        writeChunk(response, {
          type: "tool-output-available",
          toolCallId: event.toolCallId,
          output: event.result,
        });
        return;
    }
  };
}
/**
 * Write the terminal finish sequence and end the response.
 *
 * Emits a `finish` chunk followed by the `[DONE]` sentinel, then closes
 * the response. No-op when the response has already ended.
 */
export function finishVercelStream(
  response: ServerResponse,
  finishReason: string = "stop",
): void {
  if (response.writableEnded) {
    return;
  }
  const terminalChunks: Array<object | string> = [
    { type: "finish", finishReason },
    "[DONE]",
  ];
  for (const chunk of terminalChunks) {
    writeChunk(response, chunk);
  }
  response.end();
}
/**
 * Write an error chunk and end the response.
 *
 * Emits an `error` chunk carrying `errorText`, then the `[DONE]` sentinel,
 * and closes the response. No-op when the response has already ended.
 */
export function errorVercelStream(
  response: ServerResponse,
  errorText: string,
): void {
  if (response.writableEnded) {
    return;
  }
  const terminalChunks: Array<object | string> = [
    { type: "error", errorText },
    "[DONE]",
  ];
  for (const chunk of terminalChunks) {
    writeChunk(response, chunk);
  }
  response.end();
}