move pi-mono into companion-cloud as apps/companion-os

- Copy all pi-mono source into apps/companion-os/
- Update Dockerfile to COPY pre-built binary instead of downloading from GitHub Releases
- Update deploy-staging.yml to build pi from source (bun compile) before Docker build
- Add apps/companion-os/** to path triggers
- No more cross-repo dispatch needed

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Harivansh Rathi 2026-03-07 09:22:50 -08:00
commit 0250f72976
579 changed files with 206942 additions and 0 deletions

View file

@ -0,0 +1,18 @@
#!/usr/bin/env node
/**
* CLI entry point for the refactored coding agent.
* Uses main.ts with AgentSession and new mode modules.
*
* Test with: npx tsx src/cli-new.ts [args...]
*/
process.title = "pi";
import { setBedrockProviderModule } from "@mariozechner/pi-ai";
import { bedrockProviderModule } from "@mariozechner/pi-ai/bedrock-provider";
import { EnvHttpProxyAgent, setGlobalDispatcher } from "undici";
import { main } from "./main.js";
setGlobalDispatcher(new EnvHttpProxyAgent());
setBedrockProviderModule(bedrockProviderModule);
main(process.argv.slice(2));

View file

@ -0,0 +1,334 @@
/**
* CLI argument parsing and help display
*/
import type { ThinkingLevel } from "@mariozechner/pi-agent-core";
import chalk from "chalk";
import { APP_NAME, CONFIG_DIR_NAME, ENV_AGENT_DIR } from "../config.js";
import { allTools, type ToolName } from "../core/tools/index.js";
export type Mode = "text" | "json" | "rpc";
export interface Args {
provider?: string;
model?: string;
apiKey?: string;
systemPrompt?: string;
appendSystemPrompt?: string;
thinking?: ThinkingLevel;
continue?: boolean;
resume?: boolean;
help?: boolean;
version?: boolean;
mode?: Mode;
noSession?: boolean;
session?: string;
sessionDir?: string;
models?: string[];
tools?: ToolName[];
noTools?: boolean;
extensions?: string[];
noExtensions?: boolean;
print?: boolean;
export?: string;
noSkills?: boolean;
skills?: string[];
promptTemplates?: string[];
noPromptTemplates?: boolean;
themes?: string[];
noThemes?: boolean;
listModels?: string | true;
offline?: boolean;
verbose?: boolean;
messages: string[];
fileArgs: string[];
/** Unknown flags (potentially extension flags) - map of flag name to value */
unknownFlags: Map<string, boolean | string>;
}
const VALID_THINKING_LEVELS = [
"off",
"minimal",
"low",
"medium",
"high",
"xhigh",
] as const;
export function isValidThinkingLevel(level: string): level is ThinkingLevel {
return VALID_THINKING_LEVELS.includes(level as ThinkingLevel);
}
export function parseArgs(
args: string[],
extensionFlags?: Map<string, { type: "boolean" | "string" }>,
): Args {
const result: Args = {
messages: [],
fileArgs: [],
unknownFlags: new Map(),
};
for (let i = 0; i < args.length; i++) {
const arg = args[i];
if (arg === "--help" || arg === "-h") {
result.help = true;
} else if (arg === "--version" || arg === "-v") {
result.version = true;
} else if (arg === "--mode" && i + 1 < args.length) {
const mode = args[++i];
if (mode === "text" || mode === "json" || mode === "rpc") {
result.mode = mode;
}
} else if (arg === "--continue" || arg === "-c") {
result.continue = true;
} else if (arg === "--resume" || arg === "-r") {
result.resume = true;
} else if (arg === "--provider" && i + 1 < args.length) {
result.provider = args[++i];
} else if (arg === "--model" && i + 1 < args.length) {
result.model = args[++i];
} else if (arg === "--api-key" && i + 1 < args.length) {
result.apiKey = args[++i];
} else if (arg === "--system-prompt" && i + 1 < args.length) {
result.systemPrompt = args[++i];
} else if (arg === "--append-system-prompt" && i + 1 < args.length) {
result.appendSystemPrompt = args[++i];
} else if (arg === "--no-session") {
result.noSession = true;
} else if (arg === "--session" && i + 1 < args.length) {
result.session = args[++i];
} else if (arg === "--session-dir" && i + 1 < args.length) {
result.sessionDir = args[++i];
} else if (arg === "--models" && i + 1 < args.length) {
result.models = args[++i].split(",").map((s) => s.trim());
} else if (arg === "--no-tools") {
result.noTools = true;
} else if (arg === "--tools" && i + 1 < args.length) {
const toolNames = args[++i].split(",").map((s) => s.trim());
const validTools: ToolName[] = [];
for (const name of toolNames) {
if (name in allTools) {
validTools.push(name as ToolName);
} else {
console.error(
chalk.yellow(
`Warning: Unknown tool "${name}". Valid tools: ${Object.keys(allTools).join(", ")}`,
),
);
}
}
result.tools = validTools;
} else if (arg === "--thinking" && i + 1 < args.length) {
const level = args[++i];
if (isValidThinkingLevel(level)) {
result.thinking = level;
} else {
console.error(
chalk.yellow(
`Warning: Invalid thinking level "${level}". Valid values: ${VALID_THINKING_LEVELS.join(", ")}`,
),
);
}
} else if (arg === "--print" || arg === "-p") {
result.print = true;
} else if (arg === "--export" && i + 1 < args.length) {
result.export = args[++i];
} else if ((arg === "--extension" || arg === "-e") && i + 1 < args.length) {
result.extensions = result.extensions ?? [];
result.extensions.push(args[++i]);
} else if (arg === "--no-extensions" || arg === "-ne") {
result.noExtensions = true;
} else if (arg === "--skill" && i + 1 < args.length) {
result.skills = result.skills ?? [];
result.skills.push(args[++i]);
} else if (arg === "--prompt-template" && i + 1 < args.length) {
result.promptTemplates = result.promptTemplates ?? [];
result.promptTemplates.push(args[++i]);
} else if (arg === "--theme" && i + 1 < args.length) {
result.themes = result.themes ?? [];
result.themes.push(args[++i]);
} else if (arg === "--no-skills" || arg === "-ns") {
result.noSkills = true;
} else if (arg === "--no-prompt-templates" || arg === "-np") {
result.noPromptTemplates = true;
} else if (arg === "--no-themes") {
result.noThemes = true;
} else if (arg === "--list-models") {
// Check if next arg is a search pattern (not a flag or file arg)
if (
i + 1 < args.length &&
!args[i + 1].startsWith("-") &&
!args[i + 1].startsWith("@")
) {
result.listModels = args[++i];
} else {
result.listModels = true;
}
} else if (arg === "--verbose") {
result.verbose = true;
} else if (arg === "--offline") {
result.offline = true;
} else if (arg.startsWith("@")) {
result.fileArgs.push(arg.slice(1)); // Remove @ prefix
} else if (arg.startsWith("--") && extensionFlags) {
// Check if it's an extension-registered flag
const flagName = arg.slice(2);
const extFlag = extensionFlags.get(flagName);
if (extFlag) {
if (extFlag.type === "boolean") {
result.unknownFlags.set(flagName, true);
} else if (extFlag.type === "string" && i + 1 < args.length) {
result.unknownFlags.set(flagName, args[++i]);
}
}
// Unknown flags without extensionFlags are silently ignored (first pass)
} else if (!arg.startsWith("-")) {
result.messages.push(arg);
}
}
return result;
}
export function printHelp(): void {
console.log(`${chalk.bold(APP_NAME)} - AI coding assistant with read, bash, edit, write tools
${chalk.bold("Usage:")}
${APP_NAME} [options] [@files...] [messages...]
${chalk.bold("Commands:")}
${APP_NAME} install <source> [-l] Install extension source and add to settings
${APP_NAME} remove <source> [-l] Remove extension source from settings
${APP_NAME} update [source] Update installed extensions (skips pinned sources)
${APP_NAME} list List installed extensions from settings
${APP_NAME} gateway Run the always-on gateway process
${APP_NAME} daemon Alias for gateway
${APP_NAME} config Open TUI to enable/disable package resources
${APP_NAME} <command> --help Show help for install/remove/update/list
${chalk.bold("Options:")}
--provider <name> Provider name (default: google)
--model <pattern> Model pattern or ID (supports "provider/id" and optional ":<thinking>")
--api-key <key> API key (defaults to env vars)
--system-prompt <text> System prompt (default: coding assistant prompt)
--append-system-prompt <text> Append text or file contents to the system prompt
--mode <mode> Output mode: text (default), json, or rpc
--print, -p Non-interactive mode: process prompt and exit
--continue, -c Continue previous session
--resume, -r Select a session to resume
--session <path> Use specific session file
--session-dir <dir> Directory for session storage and lookup
--no-session Don't save session (ephemeral)
--models <patterns> Comma-separated model patterns for Ctrl+P cycling
Supports globs (anthropic/*, *sonnet*) and fuzzy matching
--no-tools Disable all built-in tools
--tools <tools> Comma-separated list of tools to enable (default: read,bash,edit,write)
Available: read, bash, edit, write, grep, find, ls
--thinking <level> Set thinking level: off, minimal, low, medium, high, xhigh
--extension, -e <path> Load an extension file (can be used multiple times)
--no-extensions, -ne Disable extension discovery (explicit -e paths still work)
--skill <path> Load a skill file or directory (can be used multiple times)
--no-skills, -ns Disable skills discovery and loading
--prompt-template <path> Load a prompt template file or directory (can be used multiple times)
--no-prompt-templates, -np Disable prompt template discovery and loading
--theme <path> Load a theme file or directory (can be used multiple times)
--no-themes Disable theme discovery and loading
--export <file> Export session file to HTML and exit
--list-models [search] List available models (with optional fuzzy search)
--verbose Force verbose startup (overrides quietStartup setting)
--offline Disable startup network operations (same as PI_OFFLINE=1)
--help, -h Show this help
--version, -v Show version number
Extensions can register additional flags (e.g., --plan from plan-mode extension).
${chalk.bold("Examples:")}
# Interactive mode
${APP_NAME}
# Interactive mode with initial prompt
${APP_NAME} "List all .ts files in src/"
# Include files in initial message
${APP_NAME} @prompt.md @image.png "What color is the sky?"
# Non-interactive mode (process and exit)
${APP_NAME} -p "List all .ts files in src/"
# Multiple messages (interactive)
${APP_NAME} "Read package.json" "What dependencies do we have?"
# Continue previous session
${APP_NAME} --continue "What did we discuss?"
# Use different model
${APP_NAME} --provider openai --model gpt-4o-mini "Help me refactor this code"
# Use model with provider prefix (no --provider needed)
${APP_NAME} --model openai/gpt-4o "Help me refactor this code"
# Use model with thinking level shorthand
${APP_NAME} --model sonnet:high "Solve this complex problem"
# Limit model cycling to specific models
${APP_NAME} --models claude-sonnet,claude-haiku,gpt-4o
# Limit to a specific provider with glob pattern
${APP_NAME} --models "github-copilot/*"
# Cycle models with fixed thinking levels
${APP_NAME} --models sonnet:high,haiku:low
# Start with a specific thinking level
${APP_NAME} --thinking high "Solve this complex problem"
# Read-only mode (no file modifications possible)
${APP_NAME} --tools read,grep,find,ls -p "Review the code in src/"
# Export a session file to HTML
${APP_NAME} --export ~/${CONFIG_DIR_NAME}/agent/sessions/--path--/session.jsonl
${APP_NAME} --export session.jsonl output.html
${chalk.bold("Environment Variables:")}
ANTHROPIC_API_KEY - Anthropic Claude API key
ANTHROPIC_OAUTH_TOKEN - Anthropic OAuth token (alternative to API key)
OPENAI_API_KEY - OpenAI GPT API key
AZURE_OPENAI_API_KEY - Azure OpenAI API key
AZURE_OPENAI_BASE_URL - Azure OpenAI base URL (https://{resource}.openai.azure.com/openai/v1)
AZURE_OPENAI_RESOURCE_NAME - Azure OpenAI resource name (alternative to base URL)
AZURE_OPENAI_API_VERSION - Azure OpenAI API version (default: v1)
AZURE_OPENAI_DEPLOYMENT_NAME_MAP - Azure OpenAI model=deployment map (comma-separated)
GEMINI_API_KEY - Google Gemini API key
GROQ_API_KEY - Groq API key
CEREBRAS_API_KEY - Cerebras API key
XAI_API_KEY - xAI Grok API key
OPENROUTER_API_KEY - OpenRouter API key
AI_GATEWAY_API_KEY - Vercel AI Gateway API key
ZAI_API_KEY - ZAI API key
MISTRAL_API_KEY - Mistral API key
MINIMAX_API_KEY - MiniMax API key
OPENCODE_API_KEY - OpenCode Zen/OpenCode Go API key
KIMI_API_KEY - Kimi For Coding API key
AWS_PROFILE - AWS profile for Amazon Bedrock
AWS_ACCESS_KEY_ID - AWS access key for Amazon Bedrock
AWS_SECRET_ACCESS_KEY - AWS secret key for Amazon Bedrock
AWS_BEARER_TOKEN_BEDROCK - Bedrock API key (bearer token)
AWS_REGION - AWS region for Amazon Bedrock (e.g., us-east-1)
${ENV_AGENT_DIR.padEnd(32)} - Session storage directory (default: ~/${CONFIG_DIR_NAME}/agent)
PI_PACKAGE_DIR - Override package directory (for Nix/Guix store paths)
PI_OFFLINE - Disable startup network operations when set to 1/true/yes
PI_SHARE_VIEWER_URL - Base URL for /share command (default: https://pi.dev/session/)
PI_AI_ANTIGRAVITY_VERSION - Override Antigravity User-Agent version (e.g., 1.23.0)
${chalk.bold("Available Tools (default: read, bash, edit, write):")}
read - Read file contents
bash - Execute bash commands
edit - Edit files with find/replace
write - Write files (creates/overwrites)
grep - Search file contents (read-only, off by default)
find - Find files by glob pattern (read-only, off by default)
ls - List directory contents (read-only, off by default)
`);
}

View file

@ -0,0 +1,57 @@
/**
* TUI config selector for `pi config` command
*/
import { ProcessTerminal, TUI } from "@mariozechner/pi-tui";
import type { ResolvedPaths } from "../core/package-manager.js";
import type { SettingsManager } from "../core/settings-manager.js";
import { ConfigSelectorComponent } from "../modes/interactive/components/config-selector.js";
import {
initTheme,
stopThemeWatcher,
} from "../modes/interactive/theme/theme.js";
export interface ConfigSelectorOptions {
resolvedPaths: ResolvedPaths;
settingsManager: SettingsManager;
cwd: string;
agentDir: string;
}
/** Show TUI config selector and return when closed */
export async function selectConfig(
options: ConfigSelectorOptions,
): Promise<void> {
// Initialize theme before showing TUI
initTheme(options.settingsManager.getTheme(), true);
return new Promise((resolve) => {
const ui = new TUI(new ProcessTerminal());
let resolved = false;
const selector = new ConfigSelectorComponent(
options.resolvedPaths,
options.settingsManager,
options.cwd,
options.agentDir,
() => {
if (!resolved) {
resolved = true;
ui.stop();
stopThemeWatcher();
resolve();
}
},
() => {
ui.stop();
stopThemeWatcher();
process.exit(0);
},
() => ui.requestRender(),
);
ui.addChild(selector);
ui.setFocus(selector.getResourceList());
ui.start();
});
}

View file

@ -0,0 +1,105 @@
/**
* Process @file CLI arguments into text content and image attachments
*/
import { access, readFile, stat } from "node:fs/promises";
import type { ImageContent } from "@mariozechner/pi-ai";
import chalk from "chalk";
import { resolve } from "path";
import { resolveReadPath } from "../core/tools/path-utils.js";
import { formatDimensionNote, resizeImage } from "../utils/image-resize.js";
import { detectSupportedImageMimeTypeFromFile } from "../utils/mime.js";
export interface ProcessedFiles {
text: string;
images: ImageContent[];
}
export interface ProcessFileOptions {
/** Whether to auto-resize images to 2000x2000 max. Default: true */
autoResizeImages?: boolean;
}
/** Process @file arguments into text content and image attachments */
export async function processFileArguments(
fileArgs: string[],
options?: ProcessFileOptions,
): Promise<ProcessedFiles> {
const autoResizeImages = options?.autoResizeImages ?? true;
let text = "";
const images: ImageContent[] = [];
for (const fileArg of fileArgs) {
// Expand and resolve path (handles ~ expansion and macOS screenshot Unicode spaces)
const absolutePath = resolve(resolveReadPath(fileArg, process.cwd()));
// Check if file exists
try {
await access(absolutePath);
} catch {
console.error(chalk.red(`Error: File not found: ${absolutePath}`));
process.exit(1);
}
// Check if file is empty
const stats = await stat(absolutePath);
if (stats.size === 0) {
// Skip empty files
continue;
}
const mimeType = await detectSupportedImageMimeTypeFromFile(absolutePath);
if (mimeType) {
// Handle image file
const content = await readFile(absolutePath);
const base64Content = content.toString("base64");
let attachment: ImageContent;
let dimensionNote: string | undefined;
if (autoResizeImages) {
const resized = await resizeImage({
type: "image",
data: base64Content,
mimeType,
});
dimensionNote = formatDimensionNote(resized);
attachment = {
type: "image",
mimeType: resized.mimeType,
data: resized.data,
};
} else {
attachment = {
type: "image",
mimeType,
data: base64Content,
};
}
images.push(attachment);
// Add text reference to image with optional dimension note
if (dimensionNote) {
text += `<file name="${absolutePath}">${dimensionNote}</file>\n`;
} else {
text += `<file name="${absolutePath}"></file>\n`;
}
} else {
// Handle text file
try {
const content = await readFile(absolutePath, "utf-8");
text += `<file name="${absolutePath}">\n${content}\n</file>\n`;
} catch (error: unknown) {
const message = error instanceof Error ? error.message : String(error);
console.error(
chalk.red(`Error: Could not read file ${absolutePath}: ${message}`),
);
process.exit(1);
}
}
}
return { text, images };
}

View file

@ -0,0 +1,126 @@
/**
* List available models with optional fuzzy search
*/
import type { Api, Model } from "@mariozechner/pi-ai";
import { fuzzyFilter } from "@mariozechner/pi-tui";
import type { ModelRegistry } from "../core/model-registry.js";
/**
* Format a number as human-readable (e.g., 200000 -> "200K", 1000000 -> "1M")
*/
function formatTokenCount(count: number): string {
if (count >= 1_000_000) {
const millions = count / 1_000_000;
return millions % 1 === 0 ? `${millions}M` : `${millions.toFixed(1)}M`;
}
if (count >= 1_000) {
const thousands = count / 1_000;
return thousands % 1 === 0 ? `${thousands}K` : `${thousands.toFixed(1)}K`;
}
return count.toString();
}
/**
* List available models, optionally filtered by search pattern
*/
export async function listModels(
modelRegistry: ModelRegistry,
searchPattern?: string,
): Promise<void> {
const models = modelRegistry.getAvailable();
if (models.length === 0) {
console.log("No models available. Set API keys in environment variables.");
return;
}
// Apply fuzzy filter if search pattern provided
let filteredModels: Model<Api>[] = models;
if (searchPattern) {
filteredModels = fuzzyFilter(
models,
searchPattern,
(m) => `${m.provider} ${m.id}`,
);
}
if (filteredModels.length === 0) {
console.log(`No models matching "${searchPattern}"`);
return;
}
// Sort by provider, then by model id
filteredModels.sort((a, b) => {
const providerCmp = a.provider.localeCompare(b.provider);
if (providerCmp !== 0) return providerCmp;
return a.id.localeCompare(b.id);
});
// Calculate column widths
const rows = filteredModels.map((m) => ({
provider: m.provider,
model: m.id,
context: formatTokenCount(m.contextWindow),
maxOut: formatTokenCount(m.maxTokens),
thinking: m.reasoning ? "yes" : "no",
images: m.input.includes("image") ? "yes" : "no",
}));
const headers = {
provider: "provider",
model: "model",
context: "context",
maxOut: "max-out",
thinking: "thinking",
images: "images",
};
const widths = {
provider: Math.max(
headers.provider.length,
...rows.map((r) => r.provider.length),
),
model: Math.max(headers.model.length, ...rows.map((r) => r.model.length)),
context: Math.max(
headers.context.length,
...rows.map((r) => r.context.length),
),
maxOut: Math.max(
headers.maxOut.length,
...rows.map((r) => r.maxOut.length),
),
thinking: Math.max(
headers.thinking.length,
...rows.map((r) => r.thinking.length),
),
images: Math.max(
headers.images.length,
...rows.map((r) => r.images.length),
),
};
// Print header
const headerLine = [
headers.provider.padEnd(widths.provider),
headers.model.padEnd(widths.model),
headers.context.padEnd(widths.context),
headers.maxOut.padEnd(widths.maxOut),
headers.thinking.padEnd(widths.thinking),
headers.images.padEnd(widths.images),
].join(" ");
console.log(headerLine);
// Print rows
for (const row of rows) {
const line = [
row.provider.padEnd(widths.provider),
row.model.padEnd(widths.model),
row.context.padEnd(widths.context),
row.maxOut.padEnd(widths.maxOut),
row.thinking.padEnd(widths.thinking),
row.images.padEnd(widths.images),
].join(" ");
console.log(line);
}
}

View file

@ -0,0 +1,56 @@
/**
* TUI session selector for --resume flag
*/
import { ProcessTerminal, TUI } from "@mariozechner/pi-tui";
import { KeybindingsManager } from "../core/keybindings.js";
import type {
SessionInfo,
SessionListProgress,
} from "../core/session-manager.js";
import { SessionSelectorComponent } from "../modes/interactive/components/session-selector.js";
type SessionsLoader = (
onProgress?: SessionListProgress,
) => Promise<SessionInfo[]>;
/** Show TUI session selector and return selected session path or null if cancelled */
export async function selectSession(
currentSessionsLoader: SessionsLoader,
allSessionsLoader: SessionsLoader,
): Promise<string | null> {
return new Promise((resolve) => {
const ui = new TUI(new ProcessTerminal());
const keybindings = KeybindingsManager.create();
let resolved = false;
const selector = new SessionSelectorComponent(
currentSessionsLoader,
allSessionsLoader,
(path: string) => {
if (!resolved) {
resolved = true;
ui.stop();
resolve(path);
}
},
() => {
if (!resolved) {
resolved = true;
ui.stop();
resolve(null);
}
},
() => {
ui.stop();
process.exit(0);
},
() => ui.requestRender(),
{ showRenameHint: false, keybindings },
);
ui.addChild(selector);
ui.setFocus(selector.getSessionList());
ui.start();
});
}

View file

@ -0,0 +1,256 @@
import { existsSync, readFileSync } from "fs";
import { homedir } from "os";
import { dirname, join, resolve } from "path";
import { fileURLToPath } from "url";
// =============================================================================
// Package Detection
// =============================================================================
const __filename = fileURLToPath(import.meta.url);
const __dirname = dirname(__filename);
/**
* Detect if we're running as a Bun compiled binary.
* Bun binaries have import.meta.url containing "$bunfs", "~BUN", or "%7EBUN" (Bun's virtual filesystem path)
*/
export const isBunBinary =
import.meta.url.includes("$bunfs") ||
import.meta.url.includes("~BUN") ||
import.meta.url.includes("%7EBUN");
/** Detect if Bun is the runtime (compiled binary or bun run) */
export const isBunRuntime = !!process.versions.bun;
// =============================================================================
// Install Method Detection
// =============================================================================
export type InstallMethod =
| "bun-binary"
| "npm"
| "pnpm"
| "yarn"
| "bun"
| "unknown";
export function detectInstallMethod(): InstallMethod {
if (isBunBinary) {
return "bun-binary";
}
const resolvedPath = `${__dirname}\0${process.execPath || ""}`.toLowerCase();
if (
resolvedPath.includes("/pnpm/") ||
resolvedPath.includes("/.pnpm/") ||
resolvedPath.includes("\\pnpm\\")
) {
return "pnpm";
}
if (
resolvedPath.includes("/yarn/") ||
resolvedPath.includes("/.yarn/") ||
resolvedPath.includes("\\yarn\\")
) {
return "yarn";
}
if (isBunRuntime) {
return "bun";
}
if (
resolvedPath.includes("/npm/") ||
resolvedPath.includes("/node_modules/") ||
resolvedPath.includes("\\npm\\")
) {
return "npm";
}
return "unknown";
}
export function getUpdateInstruction(packageName: string): string {
const method = detectInstallMethod();
switch (method) {
case "bun-binary":
return `Download from: https://github.com/badlogic/pi-mono/releases/latest`;
case "pnpm":
return `Run: pnpm install -g ${packageName}`;
case "yarn":
return `Run: yarn global add ${packageName}`;
case "bun":
return `Run: bun install -g ${packageName}`;
case "npm":
return `Run: npm install -g ${packageName}`;
default:
return `Run: npm install -g ${packageName}`;
}
}
// =============================================================================
// Package Asset Paths (shipped with executable)
// =============================================================================
/**
* Get the base directory for resolving package assets (themes, package.json, README.md, CHANGELOG.md).
* - For Bun binary: returns the directory containing the executable
* - For Node.js (dist/): returns __dirname (the dist/ directory)
* - For tsx (src/): returns parent directory (the package root)
*/
export function getPackageDir(): string {
// Allow override via environment variable (useful for Nix/Guix where store paths tokenize poorly)
const envDir = process.env.PI_PACKAGE_DIR;
if (envDir) {
if (envDir === "~") return homedir();
if (envDir.startsWith("~/")) return homedir() + envDir.slice(1);
return envDir;
}
if (isBunBinary) {
// Bun binary: process.execPath points to the compiled executable
return dirname(process.execPath);
}
// Node.js: walk up from __dirname until we find package.json
let dir = __dirname;
while (dir !== dirname(dir)) {
if (existsSync(join(dir, "package.json"))) {
return dir;
}
dir = dirname(dir);
}
// Fallback (shouldn't happen)
return __dirname;
}
/**
* Get path to built-in themes directory (shipped with package)
* - For Bun binary: theme/ next to executable
* - For Node.js (dist/): dist/modes/interactive/theme/
* - For tsx (src/): src/modes/interactive/theme/
*/
export function getThemesDir(): string {
if (isBunBinary) {
return join(dirname(process.execPath), "theme");
}
// Theme is in modes/interactive/theme/ relative to src/ or dist/
const packageDir = getPackageDir();
const srcOrDist = existsSync(join(packageDir, "src")) ? "src" : "dist";
return join(packageDir, srcOrDist, "modes", "interactive", "theme");
}
/**
* Get path to HTML export template directory (shipped with package)
* - For Bun binary: export-html/ next to executable
* - For Node.js (dist/): dist/core/export-html/
* - For tsx (src/): src/core/export-html/
*/
export function getExportTemplateDir(): string {
if (isBunBinary) {
return join(dirname(process.execPath), "export-html");
}
const packageDir = getPackageDir();
const srcOrDist = existsSync(join(packageDir, "src")) ? "src" : "dist";
return join(packageDir, srcOrDist, "core", "export-html");
}
/** Get path to package.json */
export function getPackageJsonPath(): string {
return join(getPackageDir(), "package.json");
}
/** Get path to README.md */
export function getReadmePath(): string {
return resolve(join(getPackageDir(), "README.md"));
}
/** Get path to docs directory */
export function getDocsPath(): string {
return resolve(join(getPackageDir(), "docs"));
}
/** Get path to CHANGELOG.md */
export function getChangelogPath(): string {
return resolve(join(getPackageDir(), "CHANGELOG.md"));
}
// =============================================================================
// App Config (from package.json piConfig)
// =============================================================================
const pkg = JSON.parse(readFileSync(getPackageJsonPath(), "utf-8"));
export const APP_NAME: string = pkg.piConfig?.name || "pi";
export const CONFIG_DIR_NAME: string = pkg.piConfig?.configDir || ".pi";
export const VERSION: string = pkg.version;
// e.g., PI_CODING_AGENT_DIR or TAU_CODING_AGENT_DIR
export const ENV_AGENT_DIR = `${APP_NAME.toUpperCase()}_CODING_AGENT_DIR`;
const DEFAULT_SHARE_VIEWER_URL = "https://pi.dev/session/";
/** Get the share viewer URL for a gist ID */
export function getShareViewerUrl(gistId: string): string {
const baseUrl = process.env.PI_SHARE_VIEWER_URL || DEFAULT_SHARE_VIEWER_URL;
return `${baseUrl}#${gistId}`;
}
// =============================================================================
// User Config Paths (~/.pi/agent/*)
// =============================================================================
/** Get the agent config directory (e.g., ~/.pi/agent/) */
export function getAgentDir(): string {
const envDir = process.env[ENV_AGENT_DIR];
if (envDir) {
// Expand tilde to home directory
if (envDir === "~") return homedir();
if (envDir.startsWith("~/")) return homedir() + envDir.slice(1);
return envDir;
}
return join(homedir(), CONFIG_DIR_NAME, "agent");
}
/** Get path to user's custom themes directory */
export function getCustomThemesDir(): string {
return join(getAgentDir(), "themes");
}
/** Get path to models.json */
export function getModelsPath(): string {
return join(getAgentDir(), "models.json");
}
/** Get path to auth.json */
export function getAuthPath(): string {
return join(getAgentDir(), "auth.json");
}
/** Get path to settings.json */
export function getSettingsPath(): string {
return join(getAgentDir(), "settings.json");
}
/** Get path to tools directory */
export function getToolsDir(): string {
return join(getAgentDir(), "tools");
}
/** Get path to managed binaries directory (fd, rg) */
export function getBinDir(): string {
return join(getAgentDir(), "bin");
}
/** Get path to prompt templates directory */
export function getPromptsDir(): string {
return join(getAgentDir(), "prompts");
}
/** Get path to sessions directory */
export function getSessionsDir(): string {
return join(getAgentDir(), "sessions");
}
/** Get path to debug log file */
export function getDebugLogPath(): string {
return join(getAgentDir(), `${APP_NAME}-debug.log`);
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,503 @@
/**
* Credential storage for API keys and OAuth tokens.
* Handles loading, saving, and refreshing credentials from auth.json.
*
* Uses file locking to prevent race conditions when multiple pi instances
* try to refresh tokens simultaneously.
*/
import {
getEnvApiKey,
type OAuthCredentials,
type OAuthLoginCallbacks,
type OAuthProviderId,
} from "@mariozechner/pi-ai";
import {
getOAuthApiKey,
getOAuthProvider,
getOAuthProviders,
} from "@mariozechner/pi-ai/oauth";
import {
chmodSync,
existsSync,
mkdirSync,
readFileSync,
writeFileSync,
} from "fs";
import { dirname, join } from "path";
import lockfile from "proper-lockfile";
import { getAgentDir } from "../config.js";
import { resolveConfigValue } from "./resolve-config-value.js";
export type ApiKeyCredential = {
type: "api_key";
key: string;
};
export type OAuthCredential = {
type: "oauth";
} & OAuthCredentials;
export type AuthCredential = ApiKeyCredential | OAuthCredential;
export type AuthStorageData = Record<string, AuthCredential>;
type LockResult<T> = {
result: T;
next?: string;
};
export interface AuthStorageBackend {
withLock<T>(fn: (current: string | undefined) => LockResult<T>): T;
withLockAsync<T>(
fn: (current: string | undefined) => Promise<LockResult<T>>,
): Promise<T>;
}
export class FileAuthStorageBackend implements AuthStorageBackend {
constructor(private authPath: string = join(getAgentDir(), "auth.json")) {}
private ensureParentDir(): void {
const dir = dirname(this.authPath);
if (!existsSync(dir)) {
mkdirSync(dir, { recursive: true, mode: 0o700 });
}
}
private ensureFileExists(): void {
if (!existsSync(this.authPath)) {
writeFileSync(this.authPath, "{}", "utf-8");
chmodSync(this.authPath, 0o600);
}
}
withLock<T>(fn: (current: string | undefined) => LockResult<T>): T {
this.ensureParentDir();
this.ensureFileExists();
let release: (() => void) | undefined;
try {
release = lockfile.lockSync(this.authPath, { realpath: false });
const current = existsSync(this.authPath)
? readFileSync(this.authPath, "utf-8")
: undefined;
const { result, next } = fn(current);
if (next !== undefined) {
writeFileSync(this.authPath, next, "utf-8");
chmodSync(this.authPath, 0o600);
}
return result;
} finally {
if (release) {
release();
}
}
}
async withLockAsync<T>(
fn: (current: string | undefined) => Promise<LockResult<T>>,
): Promise<T> {
this.ensureParentDir();
this.ensureFileExists();
let release: (() => Promise<void>) | undefined;
let lockCompromised = false;
let lockCompromisedError: Error | undefined;
const throwIfCompromised = () => {
if (lockCompromised) {
throw (
lockCompromisedError ?? new Error("Auth storage lock was compromised")
);
}
};
try {
release = await lockfile.lock(this.authPath, {
retries: {
retries: 10,
factor: 2,
minTimeout: 100,
maxTimeout: 10000,
randomize: true,
},
stale: 30000,
onCompromised: (err) => {
lockCompromised = true;
lockCompromisedError = err;
},
});
throwIfCompromised();
const current = existsSync(this.authPath)
? readFileSync(this.authPath, "utf-8")
: undefined;
const { result, next } = await fn(current);
throwIfCompromised();
if (next !== undefined) {
writeFileSync(this.authPath, next, "utf-8");
chmodSync(this.authPath, 0o600);
}
throwIfCompromised();
return result;
} finally {
if (release) {
try {
await release();
} catch {
// Ignore unlock errors when lock is compromised.
}
}
}
}
}
export class InMemoryAuthStorageBackend implements AuthStorageBackend {
private value: string | undefined;
withLock<T>(fn: (current: string | undefined) => LockResult<T>): T {
const { result, next } = fn(this.value);
if (next !== undefined) {
this.value = next;
}
return result;
}
async withLockAsync<T>(
fn: (current: string | undefined) => Promise<LockResult<T>>,
): Promise<T> {
const { result, next } = await fn(this.value);
if (next !== undefined) {
this.value = next;
}
return result;
}
}
/**
* Credential storage backed by a JSON file.
*/
export class AuthStorage {
private data: AuthStorageData = {};
private runtimeOverrides: Map<string, string> = new Map();
private fallbackResolver?: (provider: string) => string | undefined;
private loadError: Error | null = null;
private errors: Error[] = [];
private constructor(private storage: AuthStorageBackend) {
this.reload();
}
static create(authPath?: string): AuthStorage {
return new AuthStorage(
new FileAuthStorageBackend(authPath ?? join(getAgentDir(), "auth.json")),
);
}
static fromStorage(storage: AuthStorageBackend): AuthStorage {
return new AuthStorage(storage);
}
static inMemory(data: AuthStorageData = {}): AuthStorage {
const storage = new InMemoryAuthStorageBackend();
storage.withLock(() => ({
result: undefined,
next: JSON.stringify(data, null, 2),
}));
return AuthStorage.fromStorage(storage);
}
/**
* Set a runtime API key override (not persisted to disk).
* Used for CLI --api-key flag.
*/
setRuntimeApiKey(provider: string, apiKey: string): void {
this.runtimeOverrides.set(provider, apiKey);
}
/**
* Remove a runtime API key override.
*/
removeRuntimeApiKey(provider: string): void {
this.runtimeOverrides.delete(provider);
}
/**
* Set a fallback resolver for API keys not found in auth.json or env vars.
* Used for custom provider keys from models.json.
*/
setFallbackResolver(
resolver: (provider: string) => string | undefined,
): void {
this.fallbackResolver = resolver;
}
private recordError(error: unknown): void {
const normalizedError =
error instanceof Error ? error : new Error(String(error));
this.errors.push(normalizedError);
}
private parseStorageData(content: string | undefined): AuthStorageData {
if (!content) {
return {};
}
return JSON.parse(content) as AuthStorageData;
}
/**
* Reload credentials from storage.
*/
reload(): void {
let content: string | undefined;
try {
this.storage.withLock((current) => {
content = current;
return { result: undefined };
});
this.data = this.parseStorageData(content);
this.loadError = null;
} catch (error) {
this.loadError = error as Error;
this.recordError(error);
}
}
private persistProviderChange(
provider: string,
credential: AuthCredential | undefined,
): void {
if (this.loadError) {
return;
}
try {
this.storage.withLock((current) => {
const currentData = this.parseStorageData(current);
const merged: AuthStorageData = { ...currentData };
if (credential) {
merged[provider] = credential;
} else {
delete merged[provider];
}
return { result: undefined, next: JSON.stringify(merged, null, 2) };
});
} catch (error) {
this.recordError(error);
}
}
/**
* Get credential for a provider.
*/
get(provider: string): AuthCredential | undefined {
return this.data[provider] ?? undefined;
}
/**
* Set credential for a provider.
*/
set(provider: string, credential: AuthCredential): void {
this.data[provider] = credential;
this.persistProviderChange(provider, credential);
}
/**
* Remove credential for a provider.
*/
remove(provider: string): void {
delete this.data[provider];
this.persistProviderChange(provider, undefined);
}
/**
* List all providers with credentials.
*/
list(): string[] {
return Object.keys(this.data);
}
/**
* Check if credentials exist for a provider in auth.json.
*/
has(provider: string): boolean {
return provider in this.data;
}
/**
* Check if any form of auth is configured for a provider.
* Unlike getApiKey(), this doesn't refresh OAuth tokens.
*/
hasAuth(provider: string): boolean {
if (this.runtimeOverrides.has(provider)) return true;
if (this.data[provider]) return true;
if (getEnvApiKey(provider)) return true;
if (this.fallbackResolver?.(provider)) return true;
return false;
}
/**
* Get all credentials (for passing to getOAuthApiKey).
*/
getAll(): AuthStorageData {
return { ...this.data };
}
drainErrors(): Error[] {
const drained = [...this.errors];
this.errors = [];
return drained;
}
/**
* Login to an OAuth provider.
*/
async login(
providerId: OAuthProviderId,
callbacks: OAuthLoginCallbacks,
): Promise<void> {
const provider = getOAuthProvider(providerId);
if (!provider) {
throw new Error(`Unknown OAuth provider: ${providerId}`);
}
const credentials = await provider.login(callbacks);
this.set(providerId, { type: "oauth", ...credentials });
}
/**
* Logout from a provider.
*/
logout(provider: string): void {
this.remove(provider);
}
/**
* Refresh OAuth token with backend locking to prevent race conditions.
* Multiple pi instances may try to refresh simultaneously when tokens expire.
*/
private async refreshOAuthTokenWithLock(
providerId: OAuthProviderId,
): Promise<{ apiKey: string; newCredentials: OAuthCredentials } | null> {
const provider = getOAuthProvider(providerId);
if (!provider) {
return null;
}
const result = await this.storage.withLockAsync(async (current) => {
const currentData = this.parseStorageData(current);
this.data = currentData;
this.loadError = null;
const cred = currentData[providerId];
if (cred?.type !== "oauth") {
return { result: null };
}
if (Date.now() < cred.expires) {
return {
result: { apiKey: provider.getApiKey(cred), newCredentials: cred },
};
}
const oauthCreds: Record<string, OAuthCredentials> = {};
for (const [key, value] of Object.entries(currentData)) {
if (value.type === "oauth") {
oauthCreds[key] = value;
}
}
const refreshed = await getOAuthApiKey(providerId, oauthCreds);
if (!refreshed) {
return { result: null };
}
const merged: AuthStorageData = {
...currentData,
[providerId]: { type: "oauth", ...refreshed.newCredentials },
};
this.data = merged;
this.loadError = null;
return { result: refreshed, next: JSON.stringify(merged, null, 2) };
});
return result;
}
/**
* Get API key for a provider.
* Priority:
* 1. Runtime override (CLI --api-key)
* 2. API key from auth.json
* 3. OAuth token from auth.json (auto-refreshed with locking)
* 4. Environment variable
* 5. Fallback resolver (models.json custom providers)
*/
async getApiKey(providerId: string): Promise<string | undefined> {
// Runtime override takes highest priority
const runtimeKey = this.runtimeOverrides.get(providerId);
if (runtimeKey) {
return runtimeKey;
}
const cred = this.data[providerId];
if (cred?.type === "api_key") {
return resolveConfigValue(cred.key);
}
if (cred?.type === "oauth") {
const provider = getOAuthProvider(providerId);
if (!provider) {
// Unknown OAuth provider, can't get API key
return undefined;
}
// Check if token needs refresh
const needsRefresh = Date.now() >= cred.expires;
if (needsRefresh) {
// Use locked refresh to prevent race conditions
try {
const result = await this.refreshOAuthTokenWithLock(providerId);
if (result) {
return result.apiKey;
}
} catch (error) {
this.recordError(error);
// Refresh failed - re-read file to check if another instance succeeded
this.reload();
const updatedCred = this.data[providerId];
if (
updatedCred?.type === "oauth" &&
Date.now() < updatedCred.expires
) {
// Another instance refreshed successfully, use those credentials
return provider.getApiKey(updatedCred);
}
// Refresh truly failed - return undefined so model discovery skips this provider
// User can /login to re-authenticate (credentials preserved for retry)
return undefined;
}
} else {
// Token not expired, use current access token
return provider.getApiKey(cred);
}
}
// Fall back to environment variable
const envKey = getEnvApiKey(providerId);
if (envKey) return envKey;
// Fall back to custom resolver (e.g., models.json custom providers)
return this.fallbackResolver?.(providerId) ?? undefined;
}
/**
* Get all registered OAuth providers
*/
getOAuthProviders() {
return getOAuthProviders();
}
}

View file

@ -0,0 +1,296 @@
/**
* Bash command execution with streaming support and cancellation.
*
* This module provides a unified bash execution implementation used by:
* - AgentSession.executeBash() for interactive and RPC modes
* - Direct calls from modes that need bash execution
*/
import { randomBytes } from "node:crypto";
import { createWriteStream, type WriteStream } from "node:fs";
import { tmpdir } from "node:os";
import { join } from "node:path";
import { type ChildProcess, spawn } from "child_process";
import stripAnsi from "strip-ansi";
import {
getShellConfig,
getShellEnv,
killProcessTree,
sanitizeBinaryOutput,
} from "../utils/shell.js";
import type { BashOperations } from "./tools/bash.js";
import { DEFAULT_MAX_BYTES, truncateTail } from "./tools/truncate.js";
// ============================================================================
// Types
// ============================================================================
export interface BashExecutorOptions {
/** Callback for streaming output chunks (already sanitized) */
onChunk?: (chunk: string) => void;
/** AbortSignal for cancellation */
signal?: AbortSignal;
}
export interface BashResult {
/** Combined stdout + stderr output (sanitized, possibly truncated) */
output: string;
/** Process exit code (undefined if killed/cancelled) */
exitCode: number | undefined;
/** Whether the command was cancelled via signal */
cancelled: boolean;
/** Whether the output was truncated */
truncated: boolean;
/** Path to temp file containing full output (if output exceeded truncation threshold) */
fullOutputPath?: string;
}
// ============================================================================
// Implementation
// ============================================================================
/**
* Execute a bash command with optional streaming and cancellation support.
*
* Features:
* - Streams sanitized output via onChunk callback
* - Writes large output to temp file for later retrieval
* - Supports cancellation via AbortSignal
* - Sanitizes output (strips ANSI, removes binary garbage, normalizes newlines)
* - Truncates output if it exceeds the default max bytes
*
* @param command - The bash command to execute
* @param options - Optional streaming callback and abort signal
* @returns Promise resolving to execution result
*/
export function executeBash(
command: string,
options?: BashExecutorOptions,
): Promise<BashResult> {
return new Promise((resolve, reject) => {
const { shell, args } = getShellConfig();
const child: ChildProcess = spawn(shell, [...args, command], {
detached: true,
env: getShellEnv(),
stdio: ["ignore", "pipe", "pipe"],
});
// Track sanitized output for truncation
const outputChunks: string[] = [];
let outputBytes = 0;
const maxOutputBytes = DEFAULT_MAX_BYTES * 2;
// Temp file for large output
let tempFilePath: string | undefined;
let tempFileStream: WriteStream | undefined;
let totalBytes = 0;
// Handle abort signal
const abortHandler = () => {
if (child.pid) {
killProcessTree(child.pid);
}
};
if (options?.signal) {
if (options.signal.aborted) {
// Already aborted, don't even start
child.kill();
resolve({
output: "",
exitCode: undefined,
cancelled: true,
truncated: false,
});
return;
}
options.signal.addEventListener("abort", abortHandler, { once: true });
}
const decoder = new TextDecoder();
const handleData = (data: Buffer) => {
totalBytes += data.length;
// Sanitize once at the source: strip ANSI, replace binary garbage, normalize newlines
const text = sanitizeBinaryOutput(
stripAnsi(decoder.decode(data, { stream: true })),
).replace(/\r/g, "");
// Start writing to temp file if exceeds threshold
if (totalBytes > DEFAULT_MAX_BYTES && !tempFilePath) {
const id = randomBytes(8).toString("hex");
tempFilePath = join(tmpdir(), `pi-bash-${id}.log`);
tempFileStream = createWriteStream(tempFilePath);
// Write already-buffered chunks to temp file
for (const chunk of outputChunks) {
tempFileStream.write(chunk);
}
}
if (tempFileStream) {
tempFileStream.write(text);
}
// Keep rolling buffer of sanitized text
outputChunks.push(text);
outputBytes += text.length;
while (outputBytes > maxOutputBytes && outputChunks.length > 1) {
const removed = outputChunks.shift()!;
outputBytes -= removed.length;
}
// Stream to callback if provided
if (options?.onChunk) {
options.onChunk(text);
}
};
child.stdout?.on("data", handleData);
child.stderr?.on("data", handleData);
child.on("close", (code) => {
// Clean up abort listener
if (options?.signal) {
options.signal.removeEventListener("abort", abortHandler);
}
if (tempFileStream) {
tempFileStream.end();
}
// Combine buffered chunks for truncation (already sanitized)
const fullOutput = outputChunks.join("");
const truncationResult = truncateTail(fullOutput);
// code === null means killed (cancelled)
const cancelled = code === null;
resolve({
output: truncationResult.truncated
? truncationResult.content
: fullOutput,
exitCode: cancelled ? undefined : code,
cancelled,
truncated: truncationResult.truncated,
fullOutputPath: tempFilePath,
});
});
child.on("error", (err) => {
// Clean up abort listener
if (options?.signal) {
options.signal.removeEventListener("abort", abortHandler);
}
if (tempFileStream) {
tempFileStream.end();
}
reject(err);
});
});
}
/**
* Execute a bash command using custom BashOperations.
* Used for remote execution (SSH, containers, etc.).
*/
export async function executeBashWithOperations(
command: string,
cwd: string,
operations: BashOperations,
options?: BashExecutorOptions,
): Promise<BashResult> {
const outputChunks: string[] = [];
let outputBytes = 0;
const maxOutputBytes = DEFAULT_MAX_BYTES * 2;
let tempFilePath: string | undefined;
let tempFileStream: WriteStream | undefined;
let totalBytes = 0;
const decoder = new TextDecoder();
const onData = (data: Buffer) => {
totalBytes += data.length;
// Sanitize: strip ANSI, replace binary garbage, normalize newlines
const text = sanitizeBinaryOutput(
stripAnsi(decoder.decode(data, { stream: true })),
).replace(/\r/g, "");
// Start writing to temp file if exceeds threshold
if (totalBytes > DEFAULT_MAX_BYTES && !tempFilePath) {
const id = randomBytes(8).toString("hex");
tempFilePath = join(tmpdir(), `pi-bash-${id}.log`);
tempFileStream = createWriteStream(tempFilePath);
for (const chunk of outputChunks) {
tempFileStream.write(chunk);
}
}
if (tempFileStream) {
tempFileStream.write(text);
}
// Keep rolling buffer
outputChunks.push(text);
outputBytes += text.length;
while (outputBytes > maxOutputBytes && outputChunks.length > 1) {
const removed = outputChunks.shift()!;
outputBytes -= removed.length;
}
// Stream to callback
if (options?.onChunk) {
options.onChunk(text);
}
};
try {
const result = await operations.exec(command, cwd, {
onData,
signal: options?.signal,
});
if (tempFileStream) {
tempFileStream.end();
}
const fullOutput = outputChunks.join("");
const truncationResult = truncateTail(fullOutput);
const cancelled = options?.signal?.aborted ?? false;
return {
output: truncationResult.truncated
? truncationResult.content
: fullOutput,
exitCode: cancelled ? undefined : (result.exitCode ?? undefined),
cancelled,
truncated: truncationResult.truncated,
fullOutputPath: tempFilePath,
};
} catch (err) {
if (tempFileStream) {
tempFileStream.end();
}
// Check if it was an abort
if (options?.signal?.aborted) {
const fullOutput = outputChunks.join("");
const truncationResult = truncateTail(fullOutput);
return {
output: truncationResult.truncated
? truncationResult.content
: fullOutput,
exitCode: undefined,
cancelled: true,
truncated: truncationResult.truncated,
fullOutputPath: tempFilePath,
};
}
throw err;
}
}

View file

@ -0,0 +1,382 @@
/**
* Branch summarization for tree navigation.
*
* When navigating to a different point in the session tree, this generates
* a summary of the branch being left so context isn't lost.
*/
import type { AgentMessage } from "@mariozechner/pi-agent-core";
import type { Model } from "@mariozechner/pi-ai";
import { completeSimple } from "@mariozechner/pi-ai";
import {
convertToLlm,
createBranchSummaryMessage,
createCompactionSummaryMessage,
createCustomMessage,
} from "../messages.js";
import type {
ReadonlySessionManager,
SessionEntry,
} from "../session-manager.js";
import { estimateTokens } from "./compaction.js";
import {
computeFileLists,
createFileOps,
extractFileOpsFromMessage,
type FileOperations,
formatFileOperations,
SUMMARIZATION_SYSTEM_PROMPT,
serializeConversation,
} from "./utils.js";
// ============================================================================
// Types
// ============================================================================
export interface BranchSummaryResult {
summary?: string;
readFiles?: string[];
modifiedFiles?: string[];
aborted?: boolean;
error?: string;
}
/** Details stored in BranchSummaryEntry.details for file tracking */
export interface BranchSummaryDetails {
readFiles: string[];
modifiedFiles: string[];
}
export type { FileOperations } from "./utils.js";
export interface BranchPreparation {
/** Messages extracted for summarization, in chronological order */
messages: AgentMessage[];
/** File operations extracted from tool calls */
fileOps: FileOperations;
/** Total estimated tokens in messages */
totalTokens: number;
}
export interface CollectEntriesResult {
/** Entries to summarize, in chronological order */
entries: SessionEntry[];
/** Common ancestor between old and new position, if any */
commonAncestorId: string | null;
}
export interface GenerateBranchSummaryOptions {
/** Model to use for summarization */
model: Model<any>;
/** API key for the model */
apiKey: string;
/** Abort signal for cancellation */
signal: AbortSignal;
/** Optional custom instructions for summarization */
customInstructions?: string;
/** If true, customInstructions replaces the default prompt instead of being appended */
replaceInstructions?: boolean;
/** Tokens reserved for prompt + LLM response (default 16384) */
reserveTokens?: number;
}
// ============================================================================
// Entry Collection
// ============================================================================
/**
* Collect entries that should be summarized when navigating from one position to another.
*
* Walks from oldLeafId back to the common ancestor with targetId, collecting entries
* along the way. Does NOT stop at compaction boundaries - those are included and their
* summaries become context.
*
* @param session - Session manager (read-only access)
* @param oldLeafId - Current position (where we're navigating from)
* @param targetId - Target position (where we're navigating to)
* @returns Entries to summarize and the common ancestor
*/
export function collectEntriesForBranchSummary(
session: ReadonlySessionManager,
oldLeafId: string | null,
targetId: string,
): CollectEntriesResult {
// If no old position, nothing to summarize
if (!oldLeafId) {
return { entries: [], commonAncestorId: null };
}
// Find common ancestor (deepest node that's on both paths)
const oldPath = new Set(session.getBranch(oldLeafId).map((e) => e.id));
const targetPath = session.getBranch(targetId);
// targetPath is root-first, so iterate backwards to find deepest common ancestor
let commonAncestorId: string | null = null;
for (let i = targetPath.length - 1; i >= 0; i--) {
if (oldPath.has(targetPath[i].id)) {
commonAncestorId = targetPath[i].id;
break;
}
}
// Collect entries from old leaf back to common ancestor
const entries: SessionEntry[] = [];
let current: string | null = oldLeafId;
while (current && current !== commonAncestorId) {
const entry = session.getEntry(current);
if (!entry) break;
entries.push(entry);
current = entry.parentId;
}
// Reverse to get chronological order
entries.reverse();
return { entries, commonAncestorId };
}
// ============================================================================
// Entry to Message Conversion
// ============================================================================
/**
* Extract AgentMessage from a session entry.
* Similar to getMessageFromEntry in compaction.ts but also handles compaction entries.
*/
function getMessageFromEntry(entry: SessionEntry): AgentMessage | undefined {
switch (entry.type) {
case "message":
// Skip tool results - context is in assistant's tool call
if (entry.message.role === "toolResult") return undefined;
return entry.message;
case "custom_message":
return createCustomMessage(
entry.customType,
entry.content,
entry.display,
entry.details,
entry.timestamp,
);
case "branch_summary":
return createBranchSummaryMessage(
entry.summary,
entry.fromId,
entry.timestamp,
);
case "compaction":
return createCompactionSummaryMessage(
entry.summary,
entry.tokensBefore,
entry.timestamp,
);
// These don't contribute to conversation content
case "thinking_level_change":
case "model_change":
case "custom":
case "label":
return undefined;
}
}
/**
* Prepare entries for summarization with token budget.
*
* Walks entries from NEWEST to OLDEST, adding messages until we hit the token budget.
* This ensures we keep the most recent context when the branch is too long.
*
* Also collects file operations from:
* - Tool calls in assistant messages
* - Existing branch_summary entries' details (for cumulative tracking)
*
* @param entries - Entries in chronological order
* @param tokenBudget - Maximum tokens to include (0 = no limit)
*/
export function prepareBranchEntries(
entries: SessionEntry[],
tokenBudget: number = 0,
): BranchPreparation {
const messages: AgentMessage[] = [];
const fileOps = createFileOps();
let totalTokens = 0;
// First pass: collect file ops from ALL entries (even if they don't fit in token budget)
// This ensures we capture cumulative file tracking from nested branch summaries
// Only extract from pi-generated summaries (fromHook !== true), not extension-generated ones
for (const entry of entries) {
if (entry.type === "branch_summary" && !entry.fromHook && entry.details) {
const details = entry.details as BranchSummaryDetails;
if (Array.isArray(details.readFiles)) {
for (const f of details.readFiles) fileOps.read.add(f);
}
if (Array.isArray(details.modifiedFiles)) {
// Modified files go into both edited and written for proper deduplication
for (const f of details.modifiedFiles) {
fileOps.edited.add(f);
}
}
}
}
// Second pass: walk from newest to oldest, adding messages until token budget
for (let i = entries.length - 1; i >= 0; i--) {
const entry = entries[i];
const message = getMessageFromEntry(entry);
if (!message) continue;
// Extract file ops from assistant messages (tool calls)
extractFileOpsFromMessage(message, fileOps);
const tokens = estimateTokens(message);
// Check budget before adding
if (tokenBudget > 0 && totalTokens + tokens > tokenBudget) {
// If this is a summary entry, try to fit it anyway as it's important context
if (entry.type === "compaction" || entry.type === "branch_summary") {
if (totalTokens < tokenBudget * 0.9) {
messages.unshift(message);
totalTokens += tokens;
}
}
// Stop - we've hit the budget
break;
}
messages.unshift(message);
totalTokens += tokens;
}
return { messages, fileOps, totalTokens };
}
// ============================================================================
// Summary Generation
// ============================================================================
const BRANCH_SUMMARY_PREAMBLE = `The user explored a different conversation branch before returning here.
Summary of that exploration:
`;
const BRANCH_SUMMARY_PROMPT = `Create a structured summary of this conversation branch for context when returning later.
Use this EXACT format:
## Goal
[What was the user trying to accomplish in this branch?]
## Constraints & Preferences
- [Any constraints, preferences, or requirements mentioned]
- [Or "(none)" if none were mentioned]
## Progress
### Done
- [x] [Completed tasks/changes]
### In Progress
- [ ] [Work that was started but not finished]
### Blocked
- [Issues preventing progress, if any]
## Key Decisions
- **[Decision]**: [Brief rationale]
## Next Steps
1. [What should happen next to continue this work]
Keep each section concise. Preserve exact file paths, function names, and error messages.`;
/**
* Generate a summary of abandoned branch entries.
*
* @param entries - Session entries to summarize (chronological order)
* @param options - Generation options
*/
export async function generateBranchSummary(
entries: SessionEntry[],
options: GenerateBranchSummaryOptions,
): Promise<BranchSummaryResult> {
const {
model,
apiKey,
signal,
customInstructions,
replaceInstructions,
reserveTokens = 16384,
} = options;
// Token budget = context window minus reserved space for prompt + response
const contextWindow = model.contextWindow || 128000;
const tokenBudget = contextWindow - reserveTokens;
const { messages, fileOps } = prepareBranchEntries(entries, tokenBudget);
if (messages.length === 0) {
return { summary: "No content to summarize" };
}
// Transform to LLM-compatible messages, then serialize to text
// Serialization prevents the model from treating it as a conversation to continue
const llmMessages = convertToLlm(messages);
const conversationText = serializeConversation(llmMessages);
// Build prompt
let instructions: string;
if (replaceInstructions && customInstructions) {
instructions = customInstructions;
} else if (customInstructions) {
instructions = `${BRANCH_SUMMARY_PROMPT}\n\nAdditional focus: ${customInstructions}`;
} else {
instructions = BRANCH_SUMMARY_PROMPT;
}
const promptText = `<conversation>\n${conversationText}\n</conversation>\n\n${instructions}`;
const summarizationMessages = [
{
role: "user" as const,
content: [{ type: "text" as const, text: promptText }],
timestamp: Date.now(),
},
];
// Call LLM for summarization
const response = await completeSimple(
model,
{
systemPrompt: SUMMARIZATION_SYSTEM_PROMPT,
messages: summarizationMessages,
},
{ apiKey, signal, maxTokens: 2048 },
);
// Check if aborted or errored
if (response.stopReason === "aborted") {
return { aborted: true };
}
if (response.stopReason === "error") {
return { error: response.errorMessage || "Summarization failed" };
}
let summary = response.content
.filter((c): c is { type: "text"; text: string } => c.type === "text")
.map((c) => c.text)
.join("\n");
// Prepend preamble to provide context about the branch summary
summary = BRANCH_SUMMARY_PREAMBLE + summary;
// Compute file lists and append to summary
const { readFiles, modifiedFiles } = computeFileLists(fileOps);
summary += formatFileOperations(readFiles, modifiedFiles);
return {
summary: summary || "No summary generated",
readFiles,
modifiedFiles,
};
}

View file

@ -0,0 +1,899 @@
/**
* Context compaction for long sessions.
*
* Pure functions for compaction logic. The session manager handles I/O,
* and after compaction the session is reloaded.
*/
import type { AgentMessage } from "@mariozechner/pi-agent-core";
import type { AssistantMessage, Model, Usage } from "@mariozechner/pi-ai";
import { completeSimple } from "@mariozechner/pi-ai";
import {
convertToLlm,
createBranchSummaryMessage,
createCompactionSummaryMessage,
createCustomMessage,
} from "../messages.js";
import type { CompactionEntry, SessionEntry } from "../session-manager.js";
import {
computeFileLists,
createFileOps,
extractFileOpsFromMessage,
type FileOperations,
formatFileOperations,
SUMMARIZATION_SYSTEM_PROMPT,
serializeConversation,
} from "./utils.js";
// ============================================================================
// File Operation Tracking
// ============================================================================
/** Details stored in CompactionEntry.details for file tracking */
export interface CompactionDetails {
readFiles: string[];
modifiedFiles: string[];
}
/**
* Extract file operations from messages and previous compaction entries.
*/
function extractFileOperations(
messages: AgentMessage[],
entries: SessionEntry[],
prevCompactionIndex: number,
): FileOperations {
const fileOps = createFileOps();
// Collect from previous compaction's details (if pi-generated)
if (prevCompactionIndex >= 0) {
const prevCompaction = entries[prevCompactionIndex] as CompactionEntry;
if (!prevCompaction.fromHook && prevCompaction.details) {
// fromHook field kept for session file compatibility
const details = prevCompaction.details as CompactionDetails;
if (Array.isArray(details.readFiles)) {
for (const f of details.readFiles) fileOps.read.add(f);
}
if (Array.isArray(details.modifiedFiles)) {
for (const f of details.modifiedFiles) fileOps.edited.add(f);
}
}
}
// Extract from tool calls in messages
for (const msg of messages) {
extractFileOpsFromMessage(msg, fileOps);
}
return fileOps;
}
// ============================================================================
// Message Extraction
// ============================================================================
/**
* Extract AgentMessage from an entry if it produces one.
* Returns undefined for entries that don't contribute to LLM context.
*/
function getMessageFromEntry(entry: SessionEntry): AgentMessage | undefined {
if (entry.type === "message") {
return entry.message;
}
if (entry.type === "custom_message") {
return createCustomMessage(
entry.customType,
entry.content,
entry.display,
entry.details,
entry.timestamp,
);
}
if (entry.type === "branch_summary") {
return createBranchSummaryMessage(
entry.summary,
entry.fromId,
entry.timestamp,
);
}
if (entry.type === "compaction") {
return createCompactionSummaryMessage(
entry.summary,
entry.tokensBefore,
entry.timestamp,
);
}
return undefined;
}
/** Result from compact() - SessionManager adds uuid/parentUuid when saving */
export interface CompactionResult<T = unknown> {
summary: string;
firstKeptEntryId: string;
tokensBefore: number;
/** Extension-specific data (e.g., ArtifactIndex, version markers for structured compaction) */
details?: T;
}
// ============================================================================
// Types
// ============================================================================
export interface CompactionSettings {
enabled: boolean;
reserveTokens: number;
keepRecentTokens: number;
}
export const DEFAULT_COMPACTION_SETTINGS: CompactionSettings = {
enabled: true,
reserveTokens: 16384,
keepRecentTokens: 20000,
};
// ============================================================================
// Token calculation
// ============================================================================
/**
* Calculate total context tokens from usage.
* Uses the native totalTokens field when available, falls back to computing from components.
*/
export function calculateContextTokens(usage: Usage): number {
return (
usage.totalTokens ||
usage.input + usage.output + usage.cacheRead + usage.cacheWrite
);
}
/**
* Get usage from an assistant message if available.
* Skips aborted and error messages as they don't have valid usage data.
*/
function getAssistantUsage(msg: AgentMessage): Usage | undefined {
if (msg.role === "assistant" && "usage" in msg) {
const assistantMsg = msg as AssistantMessage;
if (
assistantMsg.stopReason !== "aborted" &&
assistantMsg.stopReason !== "error" &&
assistantMsg.usage
) {
return assistantMsg.usage;
}
}
return undefined;
}
/**
* Find the last non-aborted assistant message usage from session entries.
*/
export function getLastAssistantUsage(
entries: SessionEntry[],
): Usage | undefined {
for (let i = entries.length - 1; i >= 0; i--) {
const entry = entries[i];
if (entry.type === "message") {
const usage = getAssistantUsage(entry.message);
if (usage) return usage;
}
}
return undefined;
}
export interface ContextUsageEstimate {
tokens: number;
usageTokens: number;
trailingTokens: number;
lastUsageIndex: number | null;
}
function getLastAssistantUsageInfo(
messages: AgentMessage[],
): { usage: Usage; index: number } | undefined {
for (let i = messages.length - 1; i >= 0; i--) {
const usage = getAssistantUsage(messages[i]);
if (usage) return { usage, index: i };
}
return undefined;
}
/**
* Estimate context tokens from messages, using the last assistant usage when available.
* If there are messages after the last usage, estimate their tokens with estimateTokens.
*/
export function estimateContextTokens(
messages: AgentMessage[],
): ContextUsageEstimate {
const usageInfo = getLastAssistantUsageInfo(messages);
if (!usageInfo) {
let estimated = 0;
for (const message of messages) {
estimated += estimateTokens(message);
}
return {
tokens: estimated,
usageTokens: 0,
trailingTokens: estimated,
lastUsageIndex: null,
};
}
const usageTokens = calculateContextTokens(usageInfo.usage);
let trailingTokens = 0;
for (let i = usageInfo.index + 1; i < messages.length; i++) {
trailingTokens += estimateTokens(messages[i]);
}
return {
tokens: usageTokens + trailingTokens,
usageTokens,
trailingTokens,
lastUsageIndex: usageInfo.index,
};
}
/**
* Check if compaction should trigger based on context usage.
*/
export function shouldCompact(
contextTokens: number,
contextWindow: number,
settings: CompactionSettings,
): boolean {
if (!settings.enabled) return false;
return contextTokens > contextWindow - settings.reserveTokens;
}
// ============================================================================
// Cut point detection
// ============================================================================
/**
* Estimate token count for a message using chars/4 heuristic.
* This is conservative (overestimates tokens).
*/
export function estimateTokens(message: AgentMessage): number {
let chars = 0;
switch (message.role) {
case "user": {
const content = (
message as { content: string | Array<{ type: string; text?: string }> }
).content;
if (typeof content === "string") {
chars = content.length;
} else if (Array.isArray(content)) {
for (const block of content) {
if (block.type === "text" && block.text) {
chars += block.text.length;
}
}
}
return Math.ceil(chars / 4);
}
case "assistant": {
const assistant = message as AssistantMessage;
for (const block of assistant.content) {
if (block.type === "text") {
chars += block.text.length;
} else if (block.type === "thinking") {
chars += block.thinking.length;
} else if (block.type === "toolCall") {
chars += block.name.length + JSON.stringify(block.arguments).length;
}
}
return Math.ceil(chars / 4);
}
case "custom":
case "toolResult": {
if (typeof message.content === "string") {
chars = message.content.length;
} else {
for (const block of message.content) {
if (block.type === "text" && block.text) {
chars += block.text.length;
}
if (block.type === "image") {
chars += 4800; // Estimate images as 4000 chars, or 1200 tokens
}
}
}
return Math.ceil(chars / 4);
}
case "bashExecution": {
chars = message.command.length + message.output.length;
return Math.ceil(chars / 4);
}
case "branchSummary":
case "compactionSummary": {
chars = message.summary.length;
return Math.ceil(chars / 4);
}
}
return 0;
}
/**
* Find valid cut points: indices of user, assistant, custom, or bashExecution messages.
* Never cut at tool results (they must follow their tool call).
* When we cut at an assistant message with tool calls, its tool results follow it
* and will be kept.
* BashExecutionMessage is treated like a user message (user-initiated context).
*/
function findValidCutPoints(
entries: SessionEntry[],
startIndex: number,
endIndex: number,
): number[] {
const cutPoints: number[] = [];
for (let i = startIndex; i < endIndex; i++) {
const entry = entries[i];
switch (entry.type) {
case "message": {
const role = entry.message.role;
switch (role) {
case "bashExecution":
case "custom":
case "branchSummary":
case "compactionSummary":
case "user":
case "assistant":
cutPoints.push(i);
break;
case "toolResult":
break;
}
break;
}
case "thinking_level_change":
case "model_change":
case "compaction":
case "branch_summary":
case "custom":
case "custom_message":
case "label":
}
// branch_summary and custom_message are user-role messages, valid cut points
if (entry.type === "branch_summary" || entry.type === "custom_message") {
cutPoints.push(i);
}
}
return cutPoints;
}
/**
* Find the user message (or bashExecution) that starts the turn containing the given entry index.
* Returns -1 if no turn start found before the index.
* BashExecutionMessage is treated like a user message for turn boundaries.
*/
export function findTurnStartIndex(
entries: SessionEntry[],
entryIndex: number,
startIndex: number,
): number {
for (let i = entryIndex; i >= startIndex; i--) {
const entry = entries[i];
// branch_summary and custom_message are user-role messages, can start a turn
if (entry.type === "branch_summary" || entry.type === "custom_message") {
return i;
}
if (entry.type === "message") {
const role = entry.message.role;
if (role === "user" || role === "bashExecution") {
return i;
}
}
}
return -1;
}
export interface CutPointResult {
/** Index of first entry to keep */
firstKeptEntryIndex: number;
/** Index of user message that starts the turn being split, or -1 if not splitting */
turnStartIndex: number;
/** Whether this cut splits a turn (cut point is not a user message) */
isSplitTurn: boolean;
}
/**
* Find the cut point in session entries that keeps approximately `keepRecentTokens`.
*
* Algorithm: Walk backwards from newest, accumulating estimated message sizes.
* Stop when we've accumulated >= keepRecentTokens. Cut at that point.
*
* Can cut at user OR assistant messages (never tool results). When cutting at an
* assistant message with tool calls, its tool results come after and will be kept.
*
* Returns CutPointResult with:
* - firstKeptEntryIndex: the entry index to start keeping from
* - turnStartIndex: if cutting mid-turn, the user message that started that turn
* - isSplitTurn: whether we're cutting in the middle of a turn
*
* Only considers entries between `startIndex` and `endIndex` (exclusive).
*/
export function findCutPoint(
entries: SessionEntry[],
startIndex: number,
endIndex: number,
keepRecentTokens: number,
): CutPointResult {
const cutPoints = findValidCutPoints(entries, startIndex, endIndex);
if (cutPoints.length === 0) {
return {
firstKeptEntryIndex: startIndex,
turnStartIndex: -1,
isSplitTurn: false,
};
}
// Walk backwards from newest, accumulating estimated message sizes
let accumulatedTokens = 0;
let cutIndex = cutPoints[0]; // Default: keep from first message (not header)
for (let i = endIndex - 1; i >= startIndex; i--) {
const entry = entries[i];
if (entry.type !== "message") continue;
// Estimate this message's size
const messageTokens = estimateTokens(entry.message);
accumulatedTokens += messageTokens;
// Check if we've exceeded the budget
if (accumulatedTokens >= keepRecentTokens) {
// Find the closest valid cut point at or after this entry
for (let c = 0; c < cutPoints.length; c++) {
if (cutPoints[c] >= i) {
cutIndex = cutPoints[c];
break;
}
}
break;
}
}
// Scan backwards from cutIndex to include any non-message entries (bash, settings, etc.)
while (cutIndex > startIndex) {
const prevEntry = entries[cutIndex - 1];
// Stop at session header or compaction boundaries
if (prevEntry.type === "compaction") {
break;
}
if (prevEntry.type === "message") {
// Stop if we hit any message
break;
}
// Include this non-message entry (bash, settings change, etc.)
cutIndex--;
}
// Determine if this is a split turn
const cutEntry = entries[cutIndex];
const isUserMessage =
cutEntry.type === "message" && cutEntry.message.role === "user";
const turnStartIndex = isUserMessage
? -1
: findTurnStartIndex(entries, cutIndex, startIndex);
return {
firstKeptEntryIndex: cutIndex,
turnStartIndex,
isSplitTurn: !isUserMessage && turnStartIndex !== -1,
};
}
// ============================================================================
// Summarization
// ============================================================================
const SUMMARIZATION_PROMPT = `The messages above are a conversation to summarize. Create a structured context checkpoint summary that another LLM will use to continue the work.
Use this EXACT format:
## Goal
[What is the user trying to accomplish? Can be multiple items if the session covers different tasks.]
## Constraints & Preferences
- [Any constraints, preferences, or requirements mentioned by user]
- [Or "(none)" if none were mentioned]
## Progress
### Done
- [x] [Completed tasks/changes]
### In Progress
- [ ] [Current work]
### Blocked
- [Issues preventing progress, if any]
## Key Decisions
- **[Decision]**: [Brief rationale]
## Next Steps
1. [Ordered list of what should happen next]
## Critical Context
- [Any data, examples, or references needed to continue]
- [Or "(none)" if not applicable]
Keep each section concise. Preserve exact file paths, function names, and error messages.`;
const UPDATE_SUMMARIZATION_PROMPT = `The messages above are NEW conversation messages to incorporate into the existing summary provided in <previous-summary> tags.
Update the existing structured summary with new information. RULES:
- PRESERVE all existing information from the previous summary
- ADD new progress, decisions, and context from the new messages
- UPDATE the Progress section: move items from "In Progress" to "Done" when completed
- UPDATE "Next Steps" based on what was accomplished
- PRESERVE exact file paths, function names, and error messages
- If something is no longer relevant, you may remove it
Use this EXACT format:
## Goal
[Preserve existing goals, add new ones if the task expanded]
## Constraints & Preferences
- [Preserve existing, add new ones discovered]
## Progress
### Done
- [x] [Include previously done items AND newly completed items]
### In Progress
- [ ] [Current work - update based on progress]
### Blocked
- [Current blockers - remove if resolved]
## Key Decisions
- **[Decision]**: [Brief rationale] (preserve all previous, add new)
## Next Steps
1. [Update based on current state]
## Critical Context
- [Preserve important context, add new if needed]
Keep each section concise. Preserve exact file paths, function names, and error messages.`;
/**
* Generate a summary of the conversation using the LLM.
* If previousSummary is provided, uses the update prompt to merge.
*/
export async function generateSummary(
currentMessages: AgentMessage[],
model: Model<any>,
reserveTokens: number,
apiKey: string,
signal?: AbortSignal,
customInstructions?: string,
previousSummary?: string,
): Promise<string> {
const maxTokens = Math.floor(0.8 * reserveTokens);
// Use update prompt if we have a previous summary, otherwise initial prompt
let basePrompt = previousSummary
? UPDATE_SUMMARIZATION_PROMPT
: SUMMARIZATION_PROMPT;
if (customInstructions) {
basePrompt = `${basePrompt}\n\nAdditional focus: ${customInstructions}`;
}
// Serialize conversation to text so model doesn't try to continue it
// Convert to LLM messages first (handles custom types like bashExecution, custom, etc.)
const llmMessages = convertToLlm(currentMessages);
const conversationText = serializeConversation(llmMessages);
// Build the prompt with conversation wrapped in tags
let promptText = `<conversation>\n${conversationText}\n</conversation>\n\n`;
if (previousSummary) {
promptText += `<previous-summary>\n${previousSummary}\n</previous-summary>\n\n`;
}
promptText += basePrompt;
const summarizationMessages = [
{
role: "user" as const,
content: [{ type: "text" as const, text: promptText }],
timestamp: Date.now(),
},
];
const completionOptions = model.reasoning
? { maxTokens, signal, apiKey, reasoning: "high" as const }
: { maxTokens, signal, apiKey };
const response = await completeSimple(
model,
{
systemPrompt: SUMMARIZATION_SYSTEM_PROMPT,
messages: summarizationMessages,
},
completionOptions,
);
if (response.stopReason === "error") {
throw new Error(
`Summarization failed: ${response.errorMessage || "Unknown error"}`,
);
}
const textContent = response.content
.filter((c): c is { type: "text"; text: string } => c.type === "text")
.map((c) => c.text)
.join("\n");
return textContent;
}
// ============================================================================
// Compaction Preparation (for extensions)
// ============================================================================
export interface CompactionPreparation {
/** UUID of first entry to keep */
firstKeptEntryId: string;
/** Messages that will be summarized and discarded */
messagesToSummarize: AgentMessage[];
/** Messages that will be turned into turn prefix summary (if splitting) */
turnPrefixMessages: AgentMessage[];
/** Whether this is a split turn (cut point in middle of turn) */
isSplitTurn: boolean;
tokensBefore: number;
/** Summary from previous compaction, for iterative update */
previousSummary?: string;
/** File operations extracted from messagesToSummarize */
fileOps: FileOperations;
/** Compaction settions from settings.jsonl */
settings: CompactionSettings;
}
export function prepareCompaction(
pathEntries: SessionEntry[],
settings: CompactionSettings,
): CompactionPreparation | undefined {
if (
pathEntries.length > 0 &&
pathEntries[pathEntries.length - 1].type === "compaction"
) {
return undefined;
}
let prevCompactionIndex = -1;
for (let i = pathEntries.length - 1; i >= 0; i--) {
if (pathEntries[i].type === "compaction") {
prevCompactionIndex = i;
break;
}
}
const boundaryStart = prevCompactionIndex + 1;
const boundaryEnd = pathEntries.length;
const usageStart = prevCompactionIndex >= 0 ? prevCompactionIndex : 0;
const usageMessages: AgentMessage[] = [];
for (let i = usageStart; i < boundaryEnd; i++) {
const msg = getMessageFromEntry(pathEntries[i]);
if (msg) usageMessages.push(msg);
}
const tokensBefore = estimateContextTokens(usageMessages).tokens;
const cutPoint = findCutPoint(
pathEntries,
boundaryStart,
boundaryEnd,
settings.keepRecentTokens,
);
// Get UUID of first kept entry
const firstKeptEntry = pathEntries[cutPoint.firstKeptEntryIndex];
if (!firstKeptEntry?.id) {
return undefined; // Session needs migration
}
const firstKeptEntryId = firstKeptEntry.id;
const historyEnd = cutPoint.isSplitTurn
? cutPoint.turnStartIndex
: cutPoint.firstKeptEntryIndex;
// Messages to summarize (will be discarded after summary)
const messagesToSummarize: AgentMessage[] = [];
for (let i = boundaryStart; i < historyEnd; i++) {
const msg = getMessageFromEntry(pathEntries[i]);
if (msg) messagesToSummarize.push(msg);
}
// Messages for turn prefix summary (if splitting a turn)
const turnPrefixMessages: AgentMessage[] = [];
if (cutPoint.isSplitTurn) {
for (
let i = cutPoint.turnStartIndex;
i < cutPoint.firstKeptEntryIndex;
i++
) {
const msg = getMessageFromEntry(pathEntries[i]);
if (msg) turnPrefixMessages.push(msg);
}
}
// Get previous summary for iterative update
let previousSummary: string | undefined;
if (prevCompactionIndex >= 0) {
const prevCompaction = pathEntries[prevCompactionIndex] as CompactionEntry;
previousSummary = prevCompaction.summary;
}
// Extract file operations from messages and previous compaction
const fileOps = extractFileOperations(
messagesToSummarize,
pathEntries,
prevCompactionIndex,
);
// Also extract file ops from turn prefix if splitting
if (cutPoint.isSplitTurn) {
for (const msg of turnPrefixMessages) {
extractFileOpsFromMessage(msg, fileOps);
}
}
return {
firstKeptEntryId,
messagesToSummarize,
turnPrefixMessages,
isSplitTurn: cutPoint.isSplitTurn,
tokensBefore,
previousSummary,
fileOps,
settings,
};
}
// ============================================================================
// Main compaction function
// ============================================================================
const TURN_PREFIX_SUMMARIZATION_PROMPT = `This is the PREFIX of a turn that was too large to keep. The SUFFIX (recent work) is retained.
Summarize the prefix to provide context for the retained suffix:
## Original Request
[What did the user ask for in this turn?]
## Early Progress
- [Key decisions and work done in the prefix]
## Context for Suffix
- [Information needed to understand the retained recent work]
Be concise. Focus on what's needed to understand the kept suffix.`;
/**
* Generate summaries for compaction using prepared data.
* Returns CompactionResult - SessionManager adds uuid/parentUuid when saving.
*
* @param preparation - Pre-calculated preparation from prepareCompaction()
* @param customInstructions - Optional custom focus for the summary
*/
export async function compact(
preparation: CompactionPreparation,
model: Model<any>,
apiKey: string,
customInstructions?: string,
signal?: AbortSignal,
): Promise<CompactionResult> {
const {
firstKeptEntryId,
messagesToSummarize,
turnPrefixMessages,
isSplitTurn,
tokensBefore,
previousSummary,
fileOps,
settings,
} = preparation;
// Generate summaries (can be parallel if both needed) and merge into one
let summary: string;
if (isSplitTurn && turnPrefixMessages.length > 0) {
// Generate both summaries in parallel
const [historyResult, turnPrefixResult] = await Promise.all([
messagesToSummarize.length > 0
? generateSummary(
messagesToSummarize,
model,
settings.reserveTokens,
apiKey,
signal,
customInstructions,
previousSummary,
)
: Promise.resolve("No prior history."),
generateTurnPrefixSummary(
turnPrefixMessages,
model,
settings.reserveTokens,
apiKey,
signal,
),
]);
// Merge into single summary
summary = `${historyResult}\n\n---\n\n**Turn Context (split turn):**\n\n${turnPrefixResult}`;
} else {
// Just generate history summary
summary = await generateSummary(
messagesToSummarize,
model,
settings.reserveTokens,
apiKey,
signal,
customInstructions,
previousSummary,
);
}
// Compute file lists and append to summary
const { readFiles, modifiedFiles } = computeFileLists(fileOps);
summary += formatFileOperations(readFiles, modifiedFiles);
if (!firstKeptEntryId) {
throw new Error(
"First kept entry has no UUID - session may need migration",
);
}
return {
summary,
firstKeptEntryId,
tokensBefore,
details: { readFiles, modifiedFiles } as CompactionDetails,
};
}
/**
* Generate a summary for a turn prefix (when splitting a turn).
*/
async function generateTurnPrefixSummary(
messages: AgentMessage[],
model: Model<any>,
reserveTokens: number,
apiKey: string,
signal?: AbortSignal,
): Promise<string> {
const maxTokens = Math.floor(0.5 * reserveTokens); // Smaller budget for turn prefix
const llmMessages = convertToLlm(messages);
const conversationText = serializeConversation(llmMessages);
const promptText = `<conversation>\n${conversationText}\n</conversation>\n\n${TURN_PREFIX_SUMMARIZATION_PROMPT}`;
const summarizationMessages = [
{
role: "user" as const,
content: [{ type: "text" as const, text: promptText }],
timestamp: Date.now(),
},
];
const response = await completeSimple(
model,
{
systemPrompt: SUMMARIZATION_SYSTEM_PROMPT,
messages: summarizationMessages,
},
{ maxTokens, signal, apiKey },
);
if (response.stopReason === "error") {
throw new Error(
`Turn prefix summarization failed: ${response.errorMessage || "Unknown error"}`,
);
}
return response.content
.filter((c): c is { type: "text"; text: string } => c.type === "text")
.map((c) => c.text)
.join("\n");
}

View file

@ -0,0 +1,7 @@
/**
* Compaction and summarization utilities.
*/
export * from "./branch-summarization.js";
export * from "./compaction.js";
export * from "./utils.js";

View file

@ -0,0 +1,167 @@
/**
* Shared utilities for compaction and branch summarization.
*/
import type { AgentMessage } from "@mariozechner/pi-agent-core";
import type { Message } from "@mariozechner/pi-ai";
// ============================================================================
// File Operation Tracking
// ============================================================================
export interface FileOperations {
read: Set<string>;
written: Set<string>;
edited: Set<string>;
}
export function createFileOps(): FileOperations {
return {
read: new Set(),
written: new Set(),
edited: new Set(),
};
}
/**
* Extract file operations from tool calls in an assistant message.
*/
export function extractFileOpsFromMessage(
message: AgentMessage,
fileOps: FileOperations,
): void {
if (message.role !== "assistant") return;
if (!("content" in message) || !Array.isArray(message.content)) return;
for (const block of message.content) {
if (typeof block !== "object" || block === null) continue;
if (!("type" in block) || block.type !== "toolCall") continue;
if (!("arguments" in block) || !("name" in block)) continue;
const args = block.arguments as Record<string, unknown> | undefined;
if (!args) continue;
const path = typeof args.path === "string" ? args.path : undefined;
if (!path) continue;
switch (block.name) {
case "read":
fileOps.read.add(path);
break;
case "write":
fileOps.written.add(path);
break;
case "edit":
fileOps.edited.add(path);
break;
}
}
}
/**
* Compute final file lists from file operations.
* Returns readFiles (files only read, not modified) and modifiedFiles.
*/
export function computeFileLists(fileOps: FileOperations): {
readFiles: string[];
modifiedFiles: string[];
} {
const modified = new Set([...fileOps.edited, ...fileOps.written]);
const readOnly = [...fileOps.read].filter((f) => !modified.has(f)).sort();
const modifiedFiles = [...modified].sort();
return { readFiles: readOnly, modifiedFiles };
}
/**
* Format file operations as XML tags for summary.
*/
export function formatFileOperations(
readFiles: string[],
modifiedFiles: string[],
): string {
const sections: string[] = [];
if (readFiles.length > 0) {
sections.push(`<read-files>\n${readFiles.join("\n")}\n</read-files>`);
}
if (modifiedFiles.length > 0) {
sections.push(
`<modified-files>\n${modifiedFiles.join("\n")}\n</modified-files>`,
);
}
if (sections.length === 0) return "";
return `\n\n${sections.join("\n\n")}`;
}
// ============================================================================
// Message Serialization
// ============================================================================
/**
* Serialize LLM messages to text for summarization.
* This prevents the model from treating it as a conversation to continue.
* Call convertToLlm() first to handle custom message types.
*/
export function serializeConversation(messages: Message[]): string {
const parts: string[] = [];
for (const msg of messages) {
if (msg.role === "user") {
const content =
typeof msg.content === "string"
? msg.content
: msg.content
.filter(
(c): c is { type: "text"; text: string } => c.type === "text",
)
.map((c) => c.text)
.join("");
if (content) parts.push(`[User]: ${content}`);
} else if (msg.role === "assistant") {
const textParts: string[] = [];
const thinkingParts: string[] = [];
const toolCalls: string[] = [];
for (const block of msg.content) {
if (block.type === "text") {
textParts.push(block.text);
} else if (block.type === "thinking") {
thinkingParts.push(block.thinking);
} else if (block.type === "toolCall") {
const args = block.arguments as Record<string, unknown>;
const argsStr = Object.entries(args)
.map(([k, v]) => `${k}=${JSON.stringify(v)}`)
.join(", ");
toolCalls.push(`${block.name}(${argsStr})`);
}
}
if (thinkingParts.length > 0) {
parts.push(`[Assistant thinking]: ${thinkingParts.join("\n")}`);
}
if (textParts.length > 0) {
parts.push(`[Assistant]: ${textParts.join("\n")}`);
}
if (toolCalls.length > 0) {
parts.push(`[Assistant tool calls]: ${toolCalls.join("; ")}`);
}
} else if (msg.role === "toolResult") {
const content = msg.content
.filter((c): c is { type: "text"; text: string } => c.type === "text")
.map((c) => c.text)
.join("");
if (content) {
parts.push(`[Tool result]: ${content}`);
}
}
}
return parts.join("\n\n");
}
// ============================================================================
// Summarization System Prompt
// ============================================================================
export const SUMMARIZATION_SYSTEM_PROMPT = `You are a context summarization assistant. Your task is to read a conversation between a user and an AI coding assistant, then produce a structured summary following the exact format specified.
Do NOT continue the conversation. Do NOT respond to any questions in the conversation. ONLY output the structured summary.`;

View file

@ -0,0 +1,3 @@
import type { ThinkingLevel } from "@mariozechner/pi-agent-core";
export const DEFAULT_THINKING_LEVEL: ThinkingLevel = "medium";

View file

@ -0,0 +1,15 @@
export interface ResourceCollision {
resourceType: "extension" | "skill" | "prompt" | "theme";
name: string; // skill name, command/tool/flag name, prompt name, theme name
winnerPath: string;
loserPath: string;
winnerSource?: string; // e.g., "npm:foo", "git:...", "local"
loserSource?: string;
}
export interface ResourceDiagnostic {
type: "warning" | "error" | "collision";
message: string;
path?: string;
collision?: ResourceCollision;
}

View file

@ -0,0 +1,33 @@
import { EventEmitter } from "node:events";
export interface EventBus {
emit(channel: string, data: unknown): void;
on(channel: string, handler: (data: unknown) => void): () => void;
}
export interface EventBusController extends EventBus {
clear(): void;
}
export function createEventBus(): EventBusController {
const emitter = new EventEmitter();
return {
emit: (channel, data) => {
emitter.emit(channel, data);
},
on: (channel, handler) => {
const safeHandler = async (data: unknown) => {
try {
await handler(data);
} catch (err) {
console.error(`Event handler error (${channel}):`, err);
}
};
emitter.on(channel, safeHandler);
return () => emitter.off(channel, safeHandler);
},
clear: () => {
emitter.removeAllListeners();
},
};
}

View file

@ -0,0 +1,104 @@
/**
* Shared command execution utilities for extensions and custom tools.
*/
import { spawn } from "node:child_process";
/**
* Options for executing shell commands.
*/
export interface ExecOptions {
/** AbortSignal to cancel the command */
signal?: AbortSignal;
/** Timeout in milliseconds */
timeout?: number;
/** Working directory */
cwd?: string;
}
/**
* Result of executing a shell command.
*/
export interface ExecResult {
stdout: string;
stderr: string;
code: number;
killed: boolean;
}
/**
* Execute a shell command and return stdout/stderr/code.
* Supports timeout and abort signal.
*/
export async function execCommand(
command: string,
args: string[],
cwd: string,
options?: ExecOptions,
): Promise<ExecResult> {
return new Promise((resolve) => {
const proc = spawn(command, args, {
cwd,
shell: false,
stdio: ["ignore", "pipe", "pipe"],
});
let stdout = "";
let stderr = "";
let killed = false;
let timeoutId: NodeJS.Timeout | undefined;
const killProcess = () => {
if (!killed) {
killed = true;
proc.kill("SIGTERM");
// Force kill after 5 seconds if SIGTERM doesn't work
setTimeout(() => {
if (!proc.killed) {
proc.kill("SIGKILL");
}
}, 5000);
}
};
// Handle abort signal
if (options?.signal) {
if (options.signal.aborted) {
killProcess();
} else {
options.signal.addEventListener("abort", killProcess, { once: true });
}
}
// Handle timeout
if (options?.timeout && options.timeout > 0) {
timeoutId = setTimeout(() => {
killProcess();
}, options.timeout);
}
proc.stdout?.on("data", (data) => {
stdout += data.toString();
});
proc.stderr?.on("data", (data) => {
stderr += data.toString();
});
proc.on("close", (code) => {
if (timeoutId) clearTimeout(timeoutId);
if (options?.signal) {
options.signal.removeEventListener("abort", killProcess);
}
resolve({ stdout, stderr, code: code ?? 0, killed });
});
proc.on("error", (_err) => {
if (timeoutId) clearTimeout(timeoutId);
if (options?.signal) {
options.signal.removeEventListener("abort", killProcess);
}
resolve({ stdout, stderr, code: 1, killed });
});
});
}

View file

@ -0,0 +1,271 @@
/**
* ANSI escape code to HTML converter.
*
* Converts terminal ANSI color/style codes to HTML with inline styles.
* Supports:
* - Standard foreground colors (30-37) and bright variants (90-97)
* - Standard background colors (40-47) and bright variants (100-107)
* - 256-color palette (38;5;N and 48;5;N)
* - RGB true color (38;2;R;G;B and 48;2;R;G;B)
* - Text styles: bold (1), dim (2), italic (3), underline (4)
* - Reset (0)
*/
// Standard ANSI color palette (0-15)
const ANSI_COLORS = [
"#000000", // 0: black
"#800000", // 1: red
"#008000", // 2: green
"#808000", // 3: yellow
"#000080", // 4: blue
"#800080", // 5: magenta
"#008080", // 6: cyan
"#c0c0c0", // 7: white
"#808080", // 8: bright black
"#ff0000", // 9: bright red
"#00ff00", // 10: bright green
"#ffff00", // 11: bright yellow
"#0000ff", // 12: bright blue
"#ff00ff", // 13: bright magenta
"#00ffff", // 14: bright cyan
"#ffffff", // 15: bright white
];
/**
* Convert 256-color index to hex.
*/
function color256ToHex(index: number): string {
// Standard colors (0-15)
if (index < 16) {
return ANSI_COLORS[index];
}
// Color cube (16-231): 6x6x6 = 216 colors
if (index < 232) {
const cubeIndex = index - 16;
const r = Math.floor(cubeIndex / 36);
const g = Math.floor((cubeIndex % 36) / 6);
const b = cubeIndex % 6;
const toComponent = (n: number) => (n === 0 ? 0 : 55 + n * 40);
const toHex = (n: number) => toComponent(n).toString(16).padStart(2, "0");
return `#${toHex(r)}${toHex(g)}${toHex(b)}`;
}
// Grayscale (232-255): 24 shades
const gray = 8 + (index - 232) * 10;
const grayHex = gray.toString(16).padStart(2, "0");
return `#${grayHex}${grayHex}${grayHex}`;
}
/**
* Escape HTML special characters.
*/
function escapeHtml(text: string): string {
return text
.replace(/&/g, "&amp;")
.replace(/</g, "&lt;")
.replace(/>/g, "&gt;")
.replace(/"/g, "&quot;")
.replace(/'/g, "&#039;");
}
interface TextStyle {
fg: string | null;
bg: string | null;
bold: boolean;
dim: boolean;
italic: boolean;
underline: boolean;
}
function createEmptyStyle(): TextStyle {
return {
fg: null,
bg: null,
bold: false,
dim: false,
italic: false,
underline: false,
};
}
function styleToInlineCSS(style: TextStyle): string {
const parts: string[] = [];
if (style.fg) parts.push(`color:${style.fg}`);
if (style.bg) parts.push(`background-color:${style.bg}`);
if (style.bold) parts.push("font-weight:bold");
if (style.dim) parts.push("opacity:0.6");
if (style.italic) parts.push("font-style:italic");
if (style.underline) parts.push("text-decoration:underline");
return parts.join(";");
}
function hasStyle(style: TextStyle): boolean {
return (
style.fg !== null ||
style.bg !== null ||
style.bold ||
style.dim ||
style.italic ||
style.underline
);
}
/**
* Parse ANSI SGR (Select Graphic Rendition) codes and update style.
*/
function applySgrCode(params: number[], style: TextStyle): void {
let i = 0;
while (i < params.length) {
const code = params[i];
if (code === 0) {
// Reset all
style.fg = null;
style.bg = null;
style.bold = false;
style.dim = false;
style.italic = false;
style.underline = false;
} else if (code === 1) {
style.bold = true;
} else if (code === 2) {
style.dim = true;
} else if (code === 3) {
style.italic = true;
} else if (code === 4) {
style.underline = true;
} else if (code === 22) {
// Reset bold/dim
style.bold = false;
style.dim = false;
} else if (code === 23) {
style.italic = false;
} else if (code === 24) {
style.underline = false;
} else if (code >= 30 && code <= 37) {
// Standard foreground colors
style.fg = ANSI_COLORS[code - 30];
} else if (code === 38) {
// Extended foreground color
if (params[i + 1] === 5 && params.length > i + 2) {
// 256-color: 38;5;N
style.fg = color256ToHex(params[i + 2]);
i += 2;
} else if (params[i + 1] === 2 && params.length > i + 4) {
// RGB: 38;2;R;G;B
const r = params[i + 2];
const g = params[i + 3];
const b = params[i + 4];
style.fg = `rgb(${r},${g},${b})`;
i += 4;
}
} else if (code === 39) {
// Default foreground
style.fg = null;
} else if (code >= 40 && code <= 47) {
// Standard background colors
style.bg = ANSI_COLORS[code - 40];
} else if (code === 48) {
// Extended background color
if (params[i + 1] === 5 && params.length > i + 2) {
// 256-color: 48;5;N
style.bg = color256ToHex(params[i + 2]);
i += 2;
} else if (params[i + 1] === 2 && params.length > i + 4) {
// RGB: 48;2;R;G;B
const r = params[i + 2];
const g = params[i + 3];
const b = params[i + 4];
style.bg = `rgb(${r},${g},${b})`;
i += 4;
}
} else if (code === 49) {
// Default background
style.bg = null;
} else if (code >= 90 && code <= 97) {
// Bright foreground colors
style.fg = ANSI_COLORS[code - 90 + 8];
} else if (code >= 100 && code <= 107) {
// Bright background colors
style.bg = ANSI_COLORS[code - 100 + 8];
}
// Ignore unrecognized codes
i++;
}
}
// Match ANSI escape sequences: ESC[ followed by params and ending with 'm'
const ANSI_REGEX = /\x1b\[([\d;]*)m/g;
/**
* Convert ANSI-escaped text to HTML with inline styles.
*/
export function ansiToHtml(text: string): string {
const style = createEmptyStyle();
let result = "";
let lastIndex = 0;
let inSpan = false;
// Reset regex state
ANSI_REGEX.lastIndex = 0;
let match = ANSI_REGEX.exec(text);
while (match !== null) {
// Add text before this escape sequence
const beforeText = text.slice(lastIndex, match.index);
if (beforeText) {
result += escapeHtml(beforeText);
}
// Parse SGR parameters
const paramStr = match[1];
const params = paramStr
? paramStr.split(";").map((p) => parseInt(p, 10) || 0)
: [0];
// Close existing span if we have one
if (inSpan) {
result += "</span>";
inSpan = false;
}
// Apply the codes
applySgrCode(params, style);
// Open new span if we have any styling
if (hasStyle(style)) {
result += `<span style="${styleToInlineCSS(style)}">`;
inSpan = true;
}
lastIndex = match.index + match[0].length;
match = ANSI_REGEX.exec(text);
}
// Add remaining text
const remainingText = text.slice(lastIndex);
if (remainingText) {
result += escapeHtml(remainingText);
}
// Close any open span
if (inSpan) {
result += "</span>";
}
return result;
}
/**
* Convert array of ANSI-escaped lines to HTML.
* Each line is wrapped in a div element.
*/
export function ansiLinesToHtml(lines: string[]): string {
return lines
.map(
(line) => `<div class="ansi-line">${ansiToHtml(line) || "&nbsp;"}</div>`,
)
.join("\n");
}

View file

@ -0,0 +1,353 @@
import type { AgentState } from "@mariozechner/pi-agent-core";
import { existsSync, readFileSync, writeFileSync } from "fs";
import { basename, join } from "path";
import { APP_NAME, getExportTemplateDir } from "../../config.js";
import {
getResolvedThemeColors,
getThemeExportColors,
} from "../../modes/interactive/theme/theme.js";
import type { ToolInfo } from "../extensions/types.js";
import type { SessionEntry } from "../session-manager.js";
import { SessionManager } from "../session-manager.js";
/**
* Interface for rendering custom tools to HTML.
* Used by agent-session to pre-render extension tool output.
*/
export interface ToolHtmlRenderer {
/** Render a tool call to HTML. Returns undefined if tool has no custom renderer. */
renderCall(toolName: string, args: unknown): string | undefined;
/** Render a tool result to HTML. Returns undefined if tool has no custom renderer. */
renderResult(
toolName: string,
result: Array<{
type: string;
text?: string;
data?: string;
mimeType?: string;
}>,
details: unknown,
isError: boolean,
): string | undefined;
}
/** Pre-rendered HTML for a custom tool call and result */
interface RenderedToolHtml {
callHtml?: string;
resultHtml?: string;
}
export interface ExportOptions {
outputPath?: string;
themeName?: string;
/** Optional tool renderer for custom tools */
toolRenderer?: ToolHtmlRenderer;
}
/** Parse a color string to RGB values. Supports hex (#RRGGBB) and rgb(r,g,b) formats. */
function parseColor(
color: string,
): { r: number; g: number; b: number } | undefined {
const hexMatch = color.match(
/^#([0-9a-fA-F]{2})([0-9a-fA-F]{2})([0-9a-fA-F]{2})$/,
);
if (hexMatch) {
return {
r: Number.parseInt(hexMatch[1], 16),
g: Number.parseInt(hexMatch[2], 16),
b: Number.parseInt(hexMatch[3], 16),
};
}
const rgbMatch = color.match(
/^rgb\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)$/,
);
if (rgbMatch) {
return {
r: Number.parseInt(rgbMatch[1], 10),
g: Number.parseInt(rgbMatch[2], 10),
b: Number.parseInt(rgbMatch[3], 10),
};
}
return undefined;
}
/** Calculate relative luminance of a color (0-1, higher = lighter). */
function getLuminance(r: number, g: number, b: number): number {
const toLinear = (c: number) => {
const s = c / 255;
return s <= 0.03928 ? s / 12.92 : ((s + 0.055) / 1.055) ** 2.4;
};
return 0.2126 * toLinear(r) + 0.7152 * toLinear(g) + 0.0722 * toLinear(b);
}
/** Adjust color brightness. Factor > 1 lightens, < 1 darkens. */
function adjustBrightness(color: string, factor: number): string {
const parsed = parseColor(color);
if (!parsed) return color;
const adjust = (c: number) =>
Math.min(255, Math.max(0, Math.round(c * factor)));
return `rgb(${adjust(parsed.r)}, ${adjust(parsed.g)}, ${adjust(parsed.b)})`;
}
/** Derive export background colors from a base color (e.g., userMessageBg). */
function deriveExportColors(baseColor: string): {
pageBg: string;
cardBg: string;
infoBg: string;
} {
const parsed = parseColor(baseColor);
if (!parsed) {
return {
pageBg: "rgb(24, 24, 30)",
cardBg: "rgb(30, 30, 36)",
infoBg: "rgb(60, 55, 40)",
};
}
const luminance = getLuminance(parsed.r, parsed.g, parsed.b);
const isLight = luminance > 0.5;
if (isLight) {
return {
pageBg: adjustBrightness(baseColor, 0.96),
cardBg: baseColor,
infoBg: `rgb(${Math.min(255, parsed.r + 10)}, ${Math.min(255, parsed.g + 5)}, ${Math.max(0, parsed.b - 20)})`,
};
}
return {
pageBg: adjustBrightness(baseColor, 0.7),
cardBg: adjustBrightness(baseColor, 0.85),
infoBg: `rgb(${Math.min(255, parsed.r + 20)}, ${Math.min(255, parsed.g + 15)}, ${parsed.b})`,
};
}
/**
* Generate CSS custom property declarations from theme colors.
*/
function generateThemeVars(themeName?: string): string {
const colors = getResolvedThemeColors(themeName);
const lines: string[] = [];
for (const [key, value] of Object.entries(colors)) {
lines.push(`--${key}: ${value};`);
}
// Use explicit theme export colors if available, otherwise derive from userMessageBg
const themeExport = getThemeExportColors(themeName);
const userMessageBg = colors.userMessageBg || "#343541";
const derivedColors = deriveExportColors(userMessageBg);
lines.push(`--exportPageBg: ${themeExport.pageBg ?? derivedColors.pageBg};`);
lines.push(`--exportCardBg: ${themeExport.cardBg ?? derivedColors.cardBg};`);
lines.push(`--exportInfoBg: ${themeExport.infoBg ?? derivedColors.infoBg};`);
return lines.join("\n ");
}
interface SessionData {
header: ReturnType<SessionManager["getHeader"]>;
entries: ReturnType<SessionManager["getEntries"]>;
leafId: string | null;
systemPrompt?: string;
tools?: ToolInfo[];
/** Pre-rendered HTML for custom tool calls/results, keyed by tool call ID */
renderedTools?: Record<string, RenderedToolHtml>;
}
/**
* Core HTML generation logic shared by both export functions.
*/
function generateHtml(sessionData: SessionData, themeName?: string): string {
const templateDir = getExportTemplateDir();
const template = readFileSync(join(templateDir, "template.html"), "utf-8");
const templateCss = readFileSync(join(templateDir, "template.css"), "utf-8");
const templateJs = readFileSync(join(templateDir, "template.js"), "utf-8");
const markedJs = readFileSync(
join(templateDir, "vendor", "marked.min.js"),
"utf-8",
);
const hljsJs = readFileSync(
join(templateDir, "vendor", "highlight.min.js"),
"utf-8",
);
const themeVars = generateThemeVars(themeName);
const colors = getResolvedThemeColors(themeName);
const exportColors = deriveExportColors(colors.userMessageBg || "#343541");
const bodyBg = exportColors.pageBg;
const containerBg = exportColors.cardBg;
const infoBg = exportColors.infoBg;
// Base64 encode session data to avoid escaping issues
const sessionDataBase64 = Buffer.from(JSON.stringify(sessionData)).toString(
"base64",
);
// Build the CSS with theme variables injected
const css = templateCss
.replace("{{THEME_VARS}}", themeVars)
.replace("{{BODY_BG}}", bodyBg)
.replace("{{CONTAINER_BG}}", containerBg)
.replace("{{INFO_BG}}", infoBg);
return template
.replace("{{CSS}}", css)
.replace("{{JS}}", templateJs)
.replace("{{SESSION_DATA}}", sessionDataBase64)
.replace("{{MARKED_JS}}", markedJs)
.replace("{{HIGHLIGHT_JS}}", hljsJs);
}
/** Built-in tool names that have custom rendering in template.js */
const BUILTIN_TOOLS = new Set([
"bash",
"read",
"write",
"edit",
"ls",
"find",
"grep",
]);
/**
* Pre-render custom tools to HTML using their TUI renderers.
*/
function preRenderCustomTools(
entries: SessionEntry[],
toolRenderer: ToolHtmlRenderer,
): Record<string, RenderedToolHtml> {
const renderedTools: Record<string, RenderedToolHtml> = {};
for (const entry of entries) {
if (entry.type !== "message") continue;
const msg = entry.message;
// Find tool calls in assistant messages
if (msg.role === "assistant" && Array.isArray(msg.content)) {
for (const block of msg.content) {
if (block.type === "toolCall" && !BUILTIN_TOOLS.has(block.name)) {
const callHtml = toolRenderer.renderCall(block.name, block.arguments);
if (callHtml) {
renderedTools[block.id] = { callHtml };
}
}
}
}
// Find tool results
if (msg.role === "toolResult" && msg.toolCallId) {
const toolName = msg.toolName || "";
// Only render if we have a pre-rendered call OR it's not a built-in tool
const existing = renderedTools[msg.toolCallId];
if (existing || !BUILTIN_TOOLS.has(toolName)) {
const resultHtml = toolRenderer.renderResult(
toolName,
msg.content,
msg.details,
msg.isError || false,
);
if (resultHtml) {
renderedTools[msg.toolCallId] = {
...existing,
resultHtml,
};
}
}
}
}
return renderedTools;
}
/**
* Export session to HTML using SessionManager and AgentState.
* Used by TUI's /export command.
*/
export async function exportSessionToHtml(
sm: SessionManager,
state?: AgentState,
options?: ExportOptions | string,
): Promise<string> {
const opts: ExportOptions =
typeof options === "string" ? { outputPath: options } : options || {};
const sessionFile = sm.getSessionFile();
if (!sessionFile) {
throw new Error("Cannot export in-memory session to HTML");
}
if (!existsSync(sessionFile)) {
throw new Error("Nothing to export yet - start a conversation first");
}
const entries = sm.getEntries();
// Pre-render custom tools if a tool renderer is provided
let renderedTools: Record<string, RenderedToolHtml> | undefined;
if (opts.toolRenderer) {
renderedTools = preRenderCustomTools(entries, opts.toolRenderer);
// Only include if we actually rendered something
if (Object.keys(renderedTools).length === 0) {
renderedTools = undefined;
}
}
const sessionData: SessionData = {
header: sm.getHeader(),
entries,
leafId: sm.getLeafId(),
systemPrompt: state?.systemPrompt,
tools: state?.tools?.map((t) => ({
name: t.name,
description: t.description,
parameters: t.parameters,
})),
renderedTools,
};
const html = generateHtml(sessionData, opts.themeName);
let outputPath = opts.outputPath;
if (!outputPath) {
const sessionBasename = basename(sessionFile, ".jsonl");
outputPath = `${APP_NAME}-session-${sessionBasename}.html`;
}
writeFileSync(outputPath, html, "utf8");
return outputPath;
}
/**
* Export session file to HTML (standalone, without AgentState).
* Used by CLI for exporting arbitrary session files.
*/
export async function exportFromFile(
inputPath: string,
options?: ExportOptions | string,
): Promise<string> {
const opts: ExportOptions =
typeof options === "string" ? { outputPath: options } : options || {};
if (!existsSync(inputPath)) {
throw new Error(`File not found: ${inputPath}`);
}
const sm = SessionManager.open(inputPath);
const sessionData: SessionData = {
header: sm.getHeader(),
entries: sm.getEntries(),
leafId: sm.getLeafId(),
systemPrompt: undefined,
tools: undefined,
};
const html = generateHtml(sessionData, opts.themeName);
let outputPath = opts.outputPath;
if (!outputPath) {
const inputBasename = basename(inputPath, ".jsonl");
outputPath = `${APP_NAME}-session-${inputBasename}.html`;
}
writeFileSync(outputPath, html, "utf8");
return outputPath;
}

View file

@ -0,0 +1,971 @@
:root {
{{THEME_VARS}}
--body-bg: {{BODY_BG}};
--container-bg: {{CONTAINER_BG}};
--info-bg: {{INFO_BG}};
}
* { margin: 0; padding: 0; box-sizing: border-box; }
:root {
--line-height: 18px; /* 12px font * 1.5 */
}
body {
font-family: ui-monospace, 'Cascadia Code', 'Source Code Pro', Menlo, Consolas, 'DejaVu Sans Mono', monospace;
font-size: 12px;
line-height: var(--line-height);
color: var(--text);
background: var(--body-bg);
}
#app {
display: flex;
min-height: 100vh;
}
/* Sidebar */
#sidebar {
width: 400px;
background: var(--container-bg);
flex-shrink: 0;
display: flex;
flex-direction: column;
position: sticky;
top: 0;
height: 100vh;
border-right: 1px solid var(--dim);
}
.sidebar-header {
padding: 8px 12px;
flex-shrink: 0;
}
.sidebar-controls {
padding: 8px 8px 4px 8px;
}
.sidebar-search {
width: 100%;
box-sizing: border-box;
padding: 4px 8px;
font-size: 11px;
font-family: inherit;
background: var(--body-bg);
color: var(--text);
border: 1px solid var(--dim);
border-radius: 3px;
}
.sidebar-filters {
display: flex;
padding: 4px 8px 8px 8px;
gap: 4px;
align-items: center;
flex-wrap: wrap;
}
.sidebar-search:focus {
outline: none;
border-color: var(--accent);
}
.sidebar-search::placeholder {
color: var(--muted);
}
.filter-btn {
padding: 3px 8px;
font-size: 10px;
font-family: inherit;
background: transparent;
color: var(--muted);
border: 1px solid var(--dim);
border-radius: 3px;
cursor: pointer;
}
.filter-btn:hover {
color: var(--text);
border-color: var(--text);
}
.filter-btn.active {
background: var(--accent);
color: var(--body-bg);
border-color: var(--accent);
}
.sidebar-close {
display: none;
padding: 3px 8px;
font-size: 12px;
font-family: inherit;
background: transparent;
color: var(--muted);
border: 1px solid var(--dim);
border-radius: 3px;
cursor: pointer;
margin-left: auto;
}
.sidebar-close:hover {
color: var(--text);
border-color: var(--text);
}
.tree-container {
flex: 1;
overflow: auto;
padding: 4px 0;
}
.tree-node {
padding: 0 8px;
cursor: pointer;
display: flex;
align-items: baseline;
font-size: 11px;
line-height: 13px;
white-space: nowrap;
}
.tree-node:hover {
background: var(--selectedBg);
}
.tree-node.active {
background: var(--selectedBg);
}
.tree-node.active .tree-content {
font-weight: bold;
}
.tree-node.in-path {
background: color-mix(in srgb, var(--accent) 10%, transparent);
}
.tree-node:not(.in-path) {
opacity: 0.5;
}
.tree-node:not(.in-path):hover {
opacity: 1;
}
.tree-prefix {
color: var(--muted);
flex-shrink: 0;
font-family: monospace;
white-space: pre;
}
.tree-marker {
color: var(--accent);
flex-shrink: 0;
}
.tree-content {
color: var(--text);
}
.tree-role-user {
color: var(--accent);
}
.tree-role-assistant {
color: var(--success);
}
.tree-role-tool {
color: var(--muted);
}
.tree-muted {
color: var(--muted);
}
.tree-error {
color: var(--error);
}
.tree-compaction {
color: var(--borderAccent);
}
.tree-branch-summary {
color: var(--warning);
}
.tree-custom-message {
color: var(--customMessageLabel);
}
.tree-status {
padding: 4px 12px;
font-size: 10px;
color: var(--muted);
flex-shrink: 0;
}
/* Main content */
#content {
flex: 1;
overflow-y: auto;
padding: var(--line-height) calc(var(--line-height) * 2);
display: flex;
flex-direction: column;
align-items: center;
}
#content > * {
width: 100%;
max-width: 800px;
}
/* Help bar */
.help-bar {
font-size: 11px;
color: var(--warning);
margin-bottom: var(--line-height);
display: flex;
align-items: center;
gap: 12px;
}
.download-json-btn {
font-size: 10px;
padding: 2px 8px;
background: var(--container-bg);
border: 1px solid var(--border);
border-radius: 3px;
color: var(--text);
cursor: pointer;
font-family: inherit;
}
.download-json-btn:hover {
background: var(--hover);
border-color: var(--borderAccent);
}
/* Header */
.header {
background: var(--container-bg);
border-radius: 4px;
padding: var(--line-height);
margin-bottom: var(--line-height);
}
.header h1 {
font-size: 12px;
font-weight: bold;
color: var(--borderAccent);
margin-bottom: var(--line-height);
}
.header-info {
display: flex;
flex-direction: column;
gap: 0;
font-size: 11px;
}
.info-item {
color: var(--dim);
display: flex;
align-items: baseline;
}
.info-label {
font-weight: 600;
margin-right: 8px;
min-width: 100px;
}
.info-value {
color: var(--text);
flex: 1;
}
/* Messages */
#messages {
display: flex;
flex-direction: column;
gap: var(--line-height);
}
.message-timestamp {
font-size: 10px;
color: var(--dim);
opacity: 0.8;
}
.user-message {
background: var(--userMessageBg);
color: var(--userMessageText);
padding: var(--line-height);
border-radius: 4px;
position: relative;
}
.assistant-message {
padding: 0;
position: relative;
}
/* Copy link button - appears on hover */
.copy-link-btn {
position: absolute;
top: 8px;
right: 8px;
width: 28px;
height: 28px;
padding: 6px;
background: var(--container-bg);
border: 1px solid var(--dim);
border-radius: 4px;
color: var(--muted);
cursor: pointer;
opacity: 0;
transition: opacity 0.15s, background 0.15s, color 0.15s;
display: flex;
align-items: center;
justify-content: center;
z-index: 10;
}
.user-message:hover .copy-link-btn,
.assistant-message:hover .copy-link-btn {
opacity: 1;
}
.copy-link-btn:hover {
background: var(--accent);
color: var(--body-bg);
border-color: var(--accent);
}
.copy-link-btn.copied {
background: var(--success, #22c55e);
color: white;
border-color: var(--success, #22c55e);
}
/* Highlight effect for deep-linked messages */
.user-message.highlight,
.assistant-message.highlight {
animation: highlight-pulse 2s ease-out;
}
@keyframes highlight-pulse {
0% {
box-shadow: 0 0 0 3px var(--accent);
}
100% {
box-shadow: 0 0 0 0 transparent;
}
}
.assistant-message > .message-timestamp {
padding-left: var(--line-height);
}
.assistant-text {
padding: var(--line-height);
padding-bottom: 0;
}
.message-timestamp + .assistant-text,
.message-timestamp + .thinking-block {
padding-top: 0;
}
.thinking-block + .assistant-text {
padding-top: 0;
}
.thinking-text {
padding: var(--line-height);
color: var(--thinkingText);
font-style: italic;
white-space: pre-wrap;
}
.message-timestamp + .thinking-block .thinking-text,
.message-timestamp + .thinking-block .thinking-collapsed {
padding-top: 0;
}
.thinking-collapsed {
display: none;
padding: var(--line-height);
color: var(--thinkingText);
font-style: italic;
}
/* Tool execution */
.tool-execution {
padding: var(--line-height);
border-radius: 4px;
}
.tool-execution + .tool-execution {
margin-top: var(--line-height);
}
.assistant-text + .tool-execution {
margin-top: var(--line-height);
}
.tool-execution.pending { background: var(--toolPendingBg); }
.tool-execution.success { background: var(--toolSuccessBg); }
.tool-execution.error { background: var(--toolErrorBg); }
.tool-header, .tool-name {
font-weight: bold;
}
.tool-path {
color: var(--accent);
word-break: break-all;
}
.line-numbers {
color: var(--warning);
}
.line-count {
color: var(--dim);
}
.tool-command {
font-weight: bold;
white-space: pre-wrap;
word-wrap: break-word;
overflow-wrap: break-word;
word-break: break-word;
}
.tool-output {
margin-top: var(--line-height);
color: var(--toolOutput);
word-wrap: break-word;
overflow-wrap: break-word;
word-break: break-word;
font-family: inherit;
overflow-x: auto;
}
.tool-output > div,
.output-preview,
.output-full {
margin: 0;
padding: 0;
line-height: var(--line-height);
}
.tool-output pre {
margin: 0;
padding: 0;
font-family: inherit;
color: inherit;
white-space: pre-wrap;
word-wrap: break-word;
overflow-wrap: break-word;
}
.tool-output code {
padding: 0;
background: none;
color: var(--text);
}
.tool-output.expandable {
cursor: pointer;
}
.tool-output.expandable:hover {
opacity: 0.9;
}
.tool-output.expandable .output-full {
display: none;
}
.tool-output.expandable.expanded .output-preview {
display: none;
}
.tool-output.expandable.expanded .output-full {
display: block;
}
.ansi-line {
white-space: pre-wrap;
}
.tool-images {
}
.tool-image {
max-width: 100%;
max-height: 500px;
border-radius: 4px;
margin: var(--line-height) 0;
}
.expand-hint {
color: var(--toolOutput);
}
/* Diff */
.tool-diff {
font-size: 11px;
overflow-x: auto;
white-space: pre;
}
.diff-added { color: var(--toolDiffAdded); }
.diff-removed { color: var(--toolDiffRemoved); }
.diff-context { color: var(--toolDiffContext); }
/* Model change */
.model-change {
padding: 0 var(--line-height);
color: var(--dim);
font-size: 11px;
}
.model-name {
color: var(--borderAccent);
font-weight: bold;
}
/* Compaction / Branch Summary - matches customMessage colors from TUI */
.compaction {
background: var(--customMessageBg);
border-radius: 4px;
padding: var(--line-height);
cursor: pointer;
}
.compaction-label {
color: var(--customMessageLabel);
font-weight: bold;
}
.compaction-collapsed {
color: var(--customMessageText);
}
.compaction-content {
display: none;
color: var(--customMessageText);
white-space: pre-wrap;
margin-top: var(--line-height);
}
.compaction.expanded .compaction-collapsed {
display: none;
}
.compaction.expanded .compaction-content {
display: block;
}
/* System prompt */
.system-prompt {
background: var(--customMessageBg);
padding: var(--line-height);
border-radius: 4px;
margin-bottom: var(--line-height);
}
.system-prompt.expandable {
cursor: pointer;
}
.system-prompt-header {
font-weight: bold;
color: var(--customMessageLabel);
}
.system-prompt-preview {
color: var(--customMessageText);
white-space: pre-wrap;
word-wrap: break-word;
font-size: 11px;
margin-top: var(--line-height);
}
.system-prompt-expand-hint {
color: var(--muted);
font-style: italic;
margin-top: 4px;
}
.system-prompt-full {
display: none;
color: var(--customMessageText);
white-space: pre-wrap;
word-wrap: break-word;
font-size: 11px;
margin-top: var(--line-height);
}
.system-prompt.expanded .system-prompt-preview,
.system-prompt.expanded .system-prompt-expand-hint {
display: none;
}
.system-prompt.expanded .system-prompt-full {
display: block;
}
.system-prompt.provider-prompt {
border-left: 3px solid var(--warning);
}
.system-prompt-note {
font-size: 10px;
font-style: italic;
color: var(--muted);
margin-top: 4px;
}
/* Tools list */
.tools-list {
background: var(--customMessageBg);
padding: var(--line-height);
border-radius: 4px;
margin-bottom: var(--line-height);
}
.tools-header {
font-weight: bold;
color: var(--customMessageLabel);
margin-bottom: var(--line-height);
}
.tool-item {
font-size: 11px;
}
.tool-item-name {
font-weight: bold;
color: var(--text);
}
.tool-item-desc {
color: var(--dim);
}
.tool-params-hint {
color: var(--muted);
font-style: italic;
}
.tool-item:has(.tool-params-hint) {
cursor: pointer;
}
.tool-params-hint::after {
content: '[click to show parameters]';
}
.tool-item.params-expanded .tool-params-hint::after {
content: '[hide parameters]';
}
.tool-params-content {
display: none;
margin-top: 4px;
margin-left: 12px;
padding-left: 8px;
border-left: 1px solid var(--dim);
}
.tool-item.params-expanded .tool-params-content {
display: block;
}
.tool-param {
margin-bottom: 4px;
font-size: 11px;
}
.tool-param-name {
font-weight: bold;
color: var(--text);
}
.tool-param-type {
color: var(--dim);
font-style: italic;
}
.tool-param-required {
color: var(--warning, #e8a838);
font-size: 10px;
}
.tool-param-optional {
color: var(--dim);
font-size: 10px;
}
.tool-param-desc {
color: var(--dim);
margin-left: 8px;
}
/* Hook/custom messages */
.hook-message {
background: var(--customMessageBg);
color: var(--customMessageText);
padding: var(--line-height);
border-radius: 4px;
}
.hook-type {
color: var(--customMessageLabel);
font-weight: bold;
}
/* Branch summary */
.branch-summary {
background: var(--customMessageBg);
padding: var(--line-height);
border-radius: 4px;
}
.branch-summary-header {
font-weight: bold;
color: var(--borderAccent);
}
/* Error */
.error-text {
color: var(--error);
padding: 0 var(--line-height);
}
.tool-error {
color: var(--error);
}
/* Images */
.message-images {
margin-bottom: 12px;
}
.message-image {
max-width: 100%;
max-height: 400px;
border-radius: 4px;
margin: var(--line-height) 0;
}
/* Markdown content */
.markdown-content h1,
.markdown-content h2,
.markdown-content h3,
.markdown-content h4,
.markdown-content h5,
.markdown-content h6 {
color: var(--mdHeading);
margin: var(--line-height) 0 0 0;
font-weight: bold;
}
.markdown-content h1 { font-size: 1em; }
.markdown-content h2 { font-size: 1em; }
.markdown-content h3 { font-size: 1em; }
.markdown-content h4 { font-size: 1em; }
.markdown-content h5 { font-size: 1em; }
.markdown-content h6 { font-size: 1em; }
.markdown-content p { margin: 0; }
.markdown-content p + p { margin-top: var(--line-height); }
.markdown-content a {
color: var(--mdLink);
text-decoration: underline;
}
.markdown-content code {
background: rgba(128, 128, 128, 0.2);
color: var(--mdCode);
padding: 0 4px;
border-radius: 3px;
font-family: inherit;
}
.markdown-content pre {
background: transparent;
margin: var(--line-height) 0;
overflow-x: auto;
}
.markdown-content pre code {
display: block;
background: none;
color: var(--text);
}
.markdown-content blockquote {
border-left: 3px solid var(--mdQuoteBorder);
padding-left: var(--line-height);
margin: var(--line-height) 0;
color: var(--mdQuote);
font-style: italic;
}
.markdown-content ul,
.markdown-content ol {
margin: var(--line-height) 0;
padding-left: calc(var(--line-height) * 2);
}
.markdown-content li { margin: 0; }
.markdown-content li::marker { color: var(--mdListBullet); }
.markdown-content hr {
border: none;
border-top: 1px solid var(--mdHr);
margin: var(--line-height) 0;
}
.markdown-content table {
border-collapse: collapse;
margin: 0.5em 0;
width: 100%;
}
.markdown-content th,
.markdown-content td {
border: 1px solid var(--mdCodeBlockBorder);
padding: 6px 10px;
text-align: left;
}
.markdown-content th {
background: rgba(128, 128, 128, 0.1);
font-weight: bold;
}
.markdown-content img {
max-width: 100%;
border-radius: 4px;
}
/* Syntax highlighting */
.hljs { background: transparent; color: var(--text); }
.hljs-comment, .hljs-quote { color: var(--syntaxComment); }
.hljs-keyword, .hljs-selector-tag { color: var(--syntaxKeyword); }
.hljs-number, .hljs-literal { color: var(--syntaxNumber); }
.hljs-string, .hljs-doctag { color: var(--syntaxString); }
/* Function names: hljs v11 uses .hljs-title.function_ compound class */
.hljs-function, .hljs-title, .hljs-title.function_, .hljs-section, .hljs-name { color: var(--syntaxFunction); }
/* Types: hljs v11 uses .hljs-title.class_ for class names */
.hljs-type, .hljs-class, .hljs-title.class_, .hljs-built_in { color: var(--syntaxType); }
.hljs-attr, .hljs-variable, .hljs-variable.language_, .hljs-params, .hljs-property { color: var(--syntaxVariable); }
.hljs-meta, .hljs-meta .hljs-keyword, .hljs-meta .hljs-string { color: var(--syntaxKeyword); }
.hljs-operator { color: var(--syntaxOperator); }
.hljs-punctuation { color: var(--syntaxPunctuation); }
.hljs-subst { color: var(--text); }
/* Footer */
.footer {
margin-top: 48px;
padding: 20px;
text-align: center;
color: var(--dim);
font-size: 10px;
}
/* Mobile */
#hamburger {
display: none;
position: fixed;
top: 10px;
left: 10px;
z-index: 100;
padding: 3px 8px;
font-size: 12px;
font-family: inherit;
background: transparent;
color: var(--muted);
border: 1px solid var(--dim);
border-radius: 3px;
cursor: pointer;
}
#hamburger:hover {
color: var(--text);
border-color: var(--text);
}
#sidebar-overlay {
display: none;
position: fixed;
top: 0;
left: 0;
right: 0;
bottom: 0;
background: rgba(0, 0, 0, 0.5);
z-index: 98;
}
@media (max-width: 900px) {
#sidebar {
position: fixed;
left: -400px;
width: 400px;
top: 0;
bottom: 0;
height: 100vh;
z-index: 99;
transition: left 0.3s;
}
#sidebar.open {
left: 0;
}
#sidebar-overlay.open {
display: block;
}
#hamburger {
display: block;
}
.sidebar-close {
display: block;
}
#content {
padding: var(--line-height) 16px;
}
#content > * {
max-width: 100%;
}
}
@media (max-width: 500px) {
#sidebar {
width: 100vw;
left: -100vw;
}
}
@media print {
#sidebar, #sidebar-toggle { display: none !important; }
body { background: white; color: black; }
#content { max-width: none; }
}

View file

@ -0,0 +1,54 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Session Export</title>
<style>
{{CSS}}
</style>
</head>
<body>
<button id="hamburger" title="Open sidebar"><svg width="14" height="14" viewBox="0 0 24 24" fill="currentColor" stroke="none"><circle cx="6" cy="6" r="2.5"/><circle cx="6" cy="18" r="2.5"/><circle cx="18" cy="12" r="2.5"/><rect x="5" y="6" width="2" height="12"/><path d="M6 12h10c1 0 2 0 2-2V8"/></svg></button>
<div id="sidebar-overlay"></div>
<div id="app">
<aside id="sidebar">
<div class="sidebar-header">
<div class="sidebar-controls">
<input type="text" class="sidebar-search" id="tree-search" placeholder="Search...">
</div>
<div class="sidebar-filters">
<button class="filter-btn active" data-filter="default" title="Hide settings entries">Default</button>
<button class="filter-btn" data-filter="no-tools" title="Default minus tool results">No-tools</button>
<button class="filter-btn" data-filter="user-only" title="Only user messages">User</button>
<button class="filter-btn" data-filter="labeled-only" title="Only labeled entries">Labeled</button>
<button class="filter-btn" data-filter="all" title="Show everything">All</button>
<button class="sidebar-close" id="sidebar-close" title="Close"></button>
</div>
</div>
<div class="tree-container" id="tree-container"></div>
<div class="tree-status" id="tree-status"></div>
</aside>
<main id="content">
<div id="header-container"></div>
<div id="messages"></div>
</main>
<div id="image-modal" class="image-modal">
<img id="modal-image" src="" alt="">
</div>
</div>
<script id="session-data" type="application/json">{{SESSION_DATA}}</script>
<!-- Vendored libraries -->
<script>{{MARKED_JS}}</script>
<!-- highlight.js -->
<script>{{HIGHLIGHT_JS}}</script>
<!-- Main application code -->
<script>
{{JS}}
</script>
</body>
</html>

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,112 @@
/**
* Tool HTML renderer for custom tools in HTML export.
*
* Renders custom tool calls and results to HTML by invoking their TUI renderers
* and converting the ANSI output to HTML.
*/
import type { ImageContent, TextContent } from "@mariozechner/pi-ai";
import type { Theme } from "../../modes/interactive/theme/theme.js";
import type { ToolDefinition } from "../extensions/types.js";
import { ansiLinesToHtml } from "./ansi-to-html.js";
export interface ToolHtmlRendererDeps {
/** Function to look up tool definition by name */
getToolDefinition: (name: string) => ToolDefinition | undefined;
/** Theme for styling */
theme: Theme;
/** Terminal width for rendering (default: 100) */
width?: number;
}
export interface ToolHtmlRenderer {
/** Render a tool call to HTML. Returns undefined if tool has no custom renderer. */
renderCall(toolName: string, args: unknown): string | undefined;
/** Render a tool result to HTML. Returns undefined if tool has no custom renderer. */
renderResult(
toolName: string,
result: Array<{
type: string;
text?: string;
data?: string;
mimeType?: string;
}>,
details: unknown,
isError: boolean,
): string | undefined;
}
/**
* Create a tool HTML renderer.
*
* The renderer looks up tool definitions and invokes their renderCall/renderResult
* methods, converting the resulting TUI Component output (ANSI) to HTML.
*/
export function createToolHtmlRenderer(
deps: ToolHtmlRendererDeps,
): ToolHtmlRenderer {
const { getToolDefinition, theme, width = 100 } = deps;
return {
renderCall(toolName: string, args: unknown): string | undefined {
try {
const toolDef = getToolDefinition(toolName);
if (!toolDef?.renderCall) {
return undefined;
}
const component = toolDef.renderCall(args, theme);
if (!component) {
return undefined;
}
const lines = component.render(width);
return ansiLinesToHtml(lines);
} catch {
// On error, return undefined to trigger JSON fallback
return undefined;
}
},
renderResult(
toolName: string,
result: Array<{
type: string;
text?: string;
data?: string;
mimeType?: string;
}>,
details: unknown,
isError: boolean,
): string | undefined {
try {
const toolDef = getToolDefinition(toolName);
if (!toolDef?.renderResult) {
return undefined;
}
// Build AgentToolResult from content array
// Cast content since session storage uses generic object types
const agentToolResult = {
content: result as (TextContent | ImageContent)[],
details,
isError,
};
// Always render expanded, client-side will apply truncation
const component = toolDef.renderResult(
agentToolResult,
{ expanded: true, isPartial: false },
theme,
);
if (!component) {
return undefined;
}
const lines = component.render(width);
return ansiLinesToHtml(lines);
} catch {
// On error, return undefined to trigger JSON fallback
return undefined;
}
},
};
}

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,170 @@
/**
* Extension system for lifecycle events and custom tools.
*/
export type {
SlashCommandInfo,
SlashCommandLocation,
SlashCommandSource,
} from "../slash-commands.js";
export {
createExtensionRuntime,
discoverAndLoadExtensions,
loadExtensionFromFactory,
loadExtensions,
} from "./loader.js";
export type {
ExtensionErrorListener,
ForkHandler,
NavigateTreeHandler,
NewSessionHandler,
ShutdownHandler,
SwitchSessionHandler,
} from "./runner.js";
export { ExtensionRunner } from "./runner.js";
export type {
AgentEndEvent,
AgentStartEvent,
// Re-exports
AgentToolResult,
AgentToolUpdateCallback,
// App keybindings (for custom editors)
AppAction,
AppendEntryHandler,
// Events - Tool (ToolCallEvent types)
BashToolCallEvent,
BashToolResultEvent,
BeforeAgentStartEvent,
BeforeAgentStartEventResult,
// Context
CompactOptions,
// Events - Agent
ContextEvent,
// Event Results
ContextEventResult,
ContextUsage,
CustomToolCallEvent,
CustomToolResultEvent,
EditToolCallEvent,
EditToolResultEvent,
ExecOptions,
ExecResult,
Extension,
ExtensionActions,
// API
ExtensionAPI,
ExtensionCommandContext,
ExtensionCommandContextActions,
ExtensionContext,
ExtensionContextActions,
// Errors
ExtensionError,
ExtensionEvent,
ExtensionFactory,
ExtensionFlag,
ExtensionHandler,
// Runtime
ExtensionRuntime,
ExtensionShortcut,
ExtensionUIContext,
ExtensionUIDialogOptions,
ExtensionWidgetOptions,
FindToolCallEvent,
FindToolResultEvent,
GetActiveToolsHandler,
GetAllToolsHandler,
GetCommandsHandler,
GetThinkingLevelHandler,
GrepToolCallEvent,
GrepToolResultEvent,
// Events - Input
InputEvent,
InputEventResult,
InputSource,
KeybindingsManager,
LoadExtensionsResult,
LsToolCallEvent,
LsToolResultEvent,
// Events - Message
MessageEndEvent,
// Message Rendering
MessageRenderer,
MessageRenderOptions,
MessageStartEvent,
MessageUpdateEvent,
ModelSelectEvent,
ModelSelectSource,
// Provider Registration
ProviderConfig,
ProviderModelConfig,
ReadToolCallEvent,
ReadToolResultEvent,
// Commands
RegisteredCommand,
RegisteredTool,
// Events - Resources
ResourcesDiscoverEvent,
ResourcesDiscoverResult,
SendMessageHandler,
SendUserMessageHandler,
SessionBeforeCompactEvent,
SessionBeforeCompactResult,
SessionBeforeForkEvent,
SessionBeforeForkResult,
SessionBeforeSwitchEvent,
SessionBeforeSwitchResult,
SessionBeforeTreeEvent,
SessionBeforeTreeResult,
SessionCompactEvent,
SessionEvent,
SessionForkEvent,
SessionShutdownEvent,
// Events - Session
SessionStartEvent,
SessionSwitchEvent,
SessionTreeEvent,
SetActiveToolsHandler,
SetLabelHandler,
SetModelHandler,
SetThinkingLevelHandler,
TerminalInputHandler,
// Events - Tool
ToolCallEvent,
ToolCallEventResult,
// Tools
ToolDefinition,
// Events - Tool Execution
ToolExecutionEndEvent,
ToolExecutionStartEvent,
ToolExecutionUpdateEvent,
ToolInfo,
ToolRenderResultOptions,
ToolResultEvent,
ToolResultEventResult,
TreePreparation,
TurnEndEvent,
TurnStartEvent,
// Events - User Bash
UserBashEvent,
UserBashEventResult,
WidgetPlacement,
WriteToolCallEvent,
WriteToolResultEvent,
} from "./types.js";
// Type guards
export {
isBashToolResult,
isEditToolResult,
isFindToolResult,
isGrepToolResult,
isLsToolResult,
isReadToolResult,
isToolCallEventType,
isWriteToolResult,
} from "./types.js";
export {
wrapRegisteredTool,
wrapRegisteredTools,
wrapToolsWithExtensions,
wrapToolWithExtensions,
} from "./wrapper.js";

View file

@ -0,0 +1,607 @@
/**
* Extension loader - loads TypeScript extension modules using jiti.
*
* Uses @mariozechner/jiti fork with virtualModules support for compiled Bun binaries.
*/
import * as fs from "node:fs";
import { createRequire } from "node:module";
import * as os from "node:os";
import * as path from "node:path";
import { fileURLToPath } from "node:url";
import { createJiti } from "@mariozechner/jiti";
import * as _bundledPiAgentCore from "@mariozechner/pi-agent-core";
import * as _bundledPiAi from "@mariozechner/pi-ai";
import * as _bundledPiAiOauth from "@mariozechner/pi-ai/oauth";
import type { KeyId } from "@mariozechner/pi-tui";
import * as _bundledPiTui from "@mariozechner/pi-tui";
// Static imports of packages that extensions may use.
// These MUST be static so Bun bundles them into the compiled binary.
// The virtualModules option then makes them available to extensions.
import * as _bundledTypebox from "@sinclair/typebox";
import { getAgentDir, isBunBinary } from "../../config.js";
// NOTE: This import works because loader.ts exports are NOT re-exported from index.ts,
// avoiding a circular dependency. Extensions can import from @mariozechner/pi-coding-agent.
import * as _bundledPiCodingAgent from "../../index.js";
import { createEventBus, type EventBus } from "../event-bus.js";
import type { ExecOptions } from "../exec.js";
import { execCommand } from "../exec.js";
import type {
Extension,
ExtensionAPI,
ExtensionFactory,
ExtensionRuntime,
LoadExtensionsResult,
MessageRenderer,
ProviderConfig,
RegisteredCommand,
ToolDefinition,
} from "./types.js";
/** Modules available to extensions via virtualModules (for compiled Bun binary) */
const VIRTUAL_MODULES: Record<string, unknown> = {
"@sinclair/typebox": _bundledTypebox,
"@mariozechner/pi-agent-core": _bundledPiAgentCore,
"@mariozechner/pi-tui": _bundledPiTui,
"@mariozechner/pi-ai": _bundledPiAi,
"@mariozechner/pi-ai/oauth": _bundledPiAiOauth,
"@mariozechner/pi-coding-agent": _bundledPiCodingAgent,
};
const require = createRequire(import.meta.url);
/**
* Get aliases for jiti (used in Node.js/development mode).
* In Bun binary mode, virtualModules is used instead.
*/
let _aliases: Record<string, string> | null = null;
function getAliases(): Record<string, string> {
if (_aliases) return _aliases;
const __dirname = path.dirname(fileURLToPath(import.meta.url));
const packageIndex = path.resolve(__dirname, "../..", "index.js");
const typeboxEntry = require.resolve("@sinclair/typebox");
const typeboxRoot = typeboxEntry.replace(
/[\\/]build[\\/]cjs[\\/]index\.js$/,
"",
);
const packagesRoot = path.resolve(__dirname, "../../../../");
const resolveWorkspaceOrImport = (
workspaceRelativePath: string,
specifier: string,
): string => {
const workspacePath = path.join(packagesRoot, workspaceRelativePath);
if (fs.existsSync(workspacePath)) {
return workspacePath;
}
return fileURLToPath(import.meta.resolve(specifier));
};
_aliases = {
"@mariozechner/pi-coding-agent": packageIndex,
"@mariozechner/pi-agent-core": resolveWorkspaceOrImport(
"agent/dist/index.js",
"@mariozechner/pi-agent-core",
),
"@mariozechner/pi-tui": resolveWorkspaceOrImport(
"tui/dist/index.js",
"@mariozechner/pi-tui",
),
"@mariozechner/pi-ai": resolveWorkspaceOrImport(
"ai/dist/index.js",
"@mariozechner/pi-ai",
),
"@mariozechner/pi-ai/oauth": resolveWorkspaceOrImport(
"ai/dist/oauth.js",
"@mariozechner/pi-ai/oauth",
),
"@sinclair/typebox": typeboxRoot,
};
return _aliases;
}
const UNICODE_SPACES = /[\u00A0\u2000-\u200A\u202F\u205F\u3000]/g;
function normalizeUnicodeSpaces(str: string): string {
return str.replace(UNICODE_SPACES, " ");
}
function expandPath(p: string): string {
const normalized = normalizeUnicodeSpaces(p);
if (normalized.startsWith("~/")) {
return path.join(os.homedir(), normalized.slice(2));
}
if (normalized.startsWith("~")) {
return path.join(os.homedir(), normalized.slice(1));
}
return normalized;
}
function resolvePath(extPath: string, cwd: string): string {
const expanded = expandPath(extPath);
if (path.isAbsolute(expanded)) {
return expanded;
}
return path.resolve(cwd, expanded);
}
type HandlerFn = (...args: unknown[]) => Promise<unknown>;
/**
* Create a runtime with throwing stubs for action methods.
* Runner.bindCore() replaces these with real implementations.
*/
export function createExtensionRuntime(): ExtensionRuntime {
const notInitialized = () => {
throw new Error(
"Extension runtime not initialized. Action methods cannot be called during extension loading.",
);
};
const runtime: ExtensionRuntime = {
sendMessage: notInitialized,
sendUserMessage: notInitialized,
appendEntry: notInitialized,
setSessionName: notInitialized,
getSessionName: notInitialized,
setLabel: notInitialized,
getActiveTools: notInitialized,
getAllTools: notInitialized,
setActiveTools: notInitialized,
// registerTool() is valid during extension load; refresh is only needed post-bind.
refreshTools: () => {},
getCommands: notInitialized,
setModel: () =>
Promise.reject(new Error("Extension runtime not initialized")),
getThinkingLevel: notInitialized,
setThinkingLevel: notInitialized,
flagValues: new Map(),
pendingProviderRegistrations: [],
// Pre-bind: queue registrations so bindCore() can flush them once the
// model registry is available. bindCore() replaces both with direct calls.
registerProvider: (name, config) => {
runtime.pendingProviderRegistrations.push({ name, config });
},
unregisterProvider: (name) => {
runtime.pendingProviderRegistrations =
runtime.pendingProviderRegistrations.filter((r) => r.name !== name);
},
};
return runtime;
}
/**
* Create the ExtensionAPI for an extension.
* Registration methods write to the extension object.
* Action methods delegate to the shared runtime.
*/
function createExtensionAPI(
extension: Extension,
runtime: ExtensionRuntime,
cwd: string,
eventBus: EventBus,
): ExtensionAPI {
const api = {
// Registration methods - write to extension
on(event: string, handler: HandlerFn): void {
const list = extension.handlers.get(event) ?? [];
list.push(handler);
extension.handlers.set(event, list);
},
registerTool(tool: ToolDefinition): void {
extension.tools.set(tool.name, {
definition: tool,
extensionPath: extension.path,
});
runtime.refreshTools();
},
registerCommand(
name: string,
options: Omit<RegisteredCommand, "name">,
): void {
extension.commands.set(name, { name, ...options });
},
registerShortcut(
shortcut: KeyId,
options: {
description?: string;
handler: (
ctx: import("./types.js").ExtensionContext,
) => Promise<void> | void;
},
): void {
extension.shortcuts.set(shortcut, {
shortcut,
extensionPath: extension.path,
...options,
});
},
registerFlag(
name: string,
options: {
description?: string;
type: "boolean" | "string";
default?: boolean | string;
},
): void {
extension.flags.set(name, {
name,
extensionPath: extension.path,
...options,
});
if (options.default !== undefined && !runtime.flagValues.has(name)) {
runtime.flagValues.set(name, options.default);
}
},
registerMessageRenderer<T>(
customType: string,
renderer: MessageRenderer<T>,
): void {
extension.messageRenderers.set(customType, renderer as MessageRenderer);
},
// Flag access - checks extension registered it, reads from runtime
getFlag(name: string): boolean | string | undefined {
if (!extension.flags.has(name)) return undefined;
return runtime.flagValues.get(name);
},
// Action methods - delegate to shared runtime
sendMessage(message, options): void {
runtime.sendMessage(message, options);
},
sendUserMessage(content, options): void {
runtime.sendUserMessage(content, options);
},
appendEntry(customType: string, data?: unknown): void {
runtime.appendEntry(customType, data);
},
setSessionName(name: string): void {
runtime.setSessionName(name);
},
getSessionName(): string | undefined {
return runtime.getSessionName();
},
setLabel(entryId: string, label: string | undefined): void {
runtime.setLabel(entryId, label);
},
exec(command: string, args: string[], options?: ExecOptions) {
return execCommand(command, args, options?.cwd ?? cwd, options);
},
getActiveTools(): string[] {
return runtime.getActiveTools();
},
getAllTools() {
return runtime.getAllTools();
},
setActiveTools(toolNames: string[]): void {
runtime.setActiveTools(toolNames);
},
getCommands() {
return runtime.getCommands();
},
setModel(model) {
return runtime.setModel(model);
},
getThinkingLevel() {
return runtime.getThinkingLevel();
},
setThinkingLevel(level) {
runtime.setThinkingLevel(level);
},
registerProvider(name: string, config: ProviderConfig) {
runtime.registerProvider(name, config);
},
unregisterProvider(name: string) {
runtime.unregisterProvider(name);
},
events: eventBus,
} as ExtensionAPI;
return api;
}
async function loadExtensionModule(extensionPath: string) {
const jiti = createJiti(import.meta.url, {
moduleCache: false,
// In Bun binary: use virtualModules for bundled packages (no filesystem resolution)
// Also disable tryNative so jiti handles ALL imports (not just the entry point)
// In Node.js/dev: use aliases to resolve to node_modules paths
...(isBunBinary
? { virtualModules: VIRTUAL_MODULES, tryNative: false }
: { alias: getAliases() }),
});
const module = await jiti.import(extensionPath, { default: true });
const factory = module as ExtensionFactory;
return typeof factory !== "function" ? undefined : factory;
}
/**
* Create an Extension object with empty collections.
*/
function createExtension(
extensionPath: string,
resolvedPath: string,
): Extension {
return {
path: extensionPath,
resolvedPath,
handlers: new Map(),
tools: new Map(),
messageRenderers: new Map(),
commands: new Map(),
flags: new Map(),
shortcuts: new Map(),
};
}
async function loadExtension(
extensionPath: string,
cwd: string,
eventBus: EventBus,
runtime: ExtensionRuntime,
): Promise<{ extension: Extension | null; error: string | null }> {
const resolvedPath = resolvePath(extensionPath, cwd);
try {
const factory = await loadExtensionModule(resolvedPath);
if (!factory) {
return {
extension: null,
error: `Extension does not export a valid factory function: ${extensionPath}`,
};
}
const extension = createExtension(extensionPath, resolvedPath);
const api = createExtensionAPI(extension, runtime, cwd, eventBus);
await factory(api);
return { extension, error: null };
} catch (err) {
const message = err instanceof Error ? err.message : String(err);
return { extension: null, error: `Failed to load extension: ${message}` };
}
}
/**
* Create an Extension from an inline factory function.
*/
export async function loadExtensionFromFactory(
factory: ExtensionFactory,
cwd: string,
eventBus: EventBus,
runtime: ExtensionRuntime,
extensionPath = "<inline>",
): Promise<Extension> {
const extension = createExtension(extensionPath, extensionPath);
const api = createExtensionAPI(extension, runtime, cwd, eventBus);
await factory(api);
return extension;
}
/**
* Load extensions from paths.
*/
export async function loadExtensions(
paths: string[],
cwd: string,
eventBus?: EventBus,
): Promise<LoadExtensionsResult> {
const extensions: Extension[] = [];
const errors: Array<{ path: string; error: string }> = [];
const resolvedEventBus = eventBus ?? createEventBus();
const runtime = createExtensionRuntime();
for (const extPath of paths) {
const { extension, error } = await loadExtension(
extPath,
cwd,
resolvedEventBus,
runtime,
);
if (error) {
errors.push({ path: extPath, error });
continue;
}
if (extension) {
extensions.push(extension);
}
}
return {
extensions,
errors,
runtime,
};
}
interface PiManifest {
extensions?: string[];
themes?: string[];
skills?: string[];
prompts?: string[];
}
function readPiManifest(packageJsonPath: string): PiManifest | null {
try {
const content = fs.readFileSync(packageJsonPath, "utf-8");
const pkg = JSON.parse(content);
if (pkg.pi && typeof pkg.pi === "object") {
return pkg.pi as PiManifest;
}
return null;
} catch {
return null;
}
}
function isExtensionFile(name: string): boolean {
return name.endsWith(".ts") || name.endsWith(".js");
}
/**
* Resolve extension entry points from a directory.
*
* Checks for:
* 1. package.json with "pi.extensions" field -> returns declared paths
* 2. index.ts or index.js -> returns the index file
*
* Returns resolved paths or null if no entry points found.
*/
function resolveExtensionEntries(dir: string): string[] | null {
// Check for package.json with "pi" field first
const packageJsonPath = path.join(dir, "package.json");
if (fs.existsSync(packageJsonPath)) {
const manifest = readPiManifest(packageJsonPath);
if (manifest?.extensions?.length) {
const entries: string[] = [];
for (const extPath of manifest.extensions) {
const resolvedExtPath = path.resolve(dir, extPath);
if (fs.existsSync(resolvedExtPath)) {
entries.push(resolvedExtPath);
}
}
if (entries.length > 0) {
return entries;
}
}
}
// Check for index.ts or index.js
const indexTs = path.join(dir, "index.ts");
const indexJs = path.join(dir, "index.js");
if (fs.existsSync(indexTs)) {
return [indexTs];
}
if (fs.existsSync(indexJs)) {
return [indexJs];
}
return null;
}
/**
* Discover extensions in a directory.
*
* Discovery rules:
* 1. Direct files: `extensions/*.ts` or `*.js` load
* 2. Subdirectory with index: `extensions/* /index.ts` or `index.js` load
* 3. Subdirectory with package.json: `extensions/* /package.json` with "pi" field load what it declares
*
* No recursion beyond one level. Complex packages must use package.json manifest.
*/
function discoverExtensionsInDir(dir: string): string[] {
if (!fs.existsSync(dir)) {
return [];
}
const discovered: string[] = [];
try {
const entries = fs.readdirSync(dir, { withFileTypes: true });
for (const entry of entries) {
const entryPath = path.join(dir, entry.name);
// 1. Direct files: *.ts or *.js
if (
(entry.isFile() || entry.isSymbolicLink()) &&
isExtensionFile(entry.name)
) {
discovered.push(entryPath);
continue;
}
// 2 & 3. Subdirectories
if (entry.isDirectory() || entry.isSymbolicLink()) {
const entries = resolveExtensionEntries(entryPath);
if (entries) {
discovered.push(...entries);
}
}
}
} catch {
return [];
}
return discovered;
}
/**
* Discover and load extensions from standard locations.
*/
export async function discoverAndLoadExtensions(
configuredPaths: string[],
cwd: string,
agentDir: string = getAgentDir(),
eventBus?: EventBus,
): Promise<LoadExtensionsResult> {
const allPaths: string[] = [];
const seen = new Set<string>();
const addPaths = (paths: string[]) => {
for (const p of paths) {
const resolved = path.resolve(p);
if (!seen.has(resolved)) {
seen.add(resolved);
allPaths.push(p);
}
}
};
// 1. Project-local extensions: cwd/.pi/extensions/
const localExtDir = path.join(cwd, ".pi", "extensions");
addPaths(discoverExtensionsInDir(localExtDir));
// 2. Global extensions: agentDir/extensions/
const globalExtDir = path.join(agentDir, "extensions");
addPaths(discoverExtensionsInDir(globalExtDir));
// 3. Explicitly configured paths
for (const p of configuredPaths) {
const resolved = resolvePath(p, cwd);
if (fs.existsSync(resolved) && fs.statSync(resolved).isDirectory()) {
// Check for package.json with pi manifest or index.ts
const entries = resolveExtensionEntries(resolved);
if (entries) {
addPaths(entries);
continue;
}
// No explicit entries - discover individual files in directory
addPaths(discoverExtensionsInDir(resolved));
continue;
}
addPaths([resolved]);
}
return loadExtensions(allPaths, cwd, eventBus);
}

View file

@ -0,0 +1,950 @@
/**
* Extension runner - executes extensions and manages their lifecycle.
*/
import type { AgentMessage } from "@mariozechner/pi-agent-core";
import type { ImageContent, Model } from "@mariozechner/pi-ai";
import type { KeyId } from "@mariozechner/pi-tui";
import { type Theme, theme } from "../../modes/interactive/theme/theme.js";
import type { ResourceDiagnostic } from "../diagnostics.js";
import type { KeyAction, KeybindingsConfig } from "../keybindings.js";
import type { ModelRegistry } from "../model-registry.js";
import type { SessionManager } from "../session-manager.js";
import type {
BeforeAgentStartEvent,
BeforeAgentStartEventResult,
CompactOptions,
ContextEvent,
ContextEventResult,
ContextUsage,
Extension,
ExtensionActions,
ExtensionCommandContext,
ExtensionCommandContextActions,
ExtensionContext,
ExtensionContextActions,
ExtensionError,
ExtensionEvent,
ExtensionFlag,
ExtensionRuntime,
ExtensionShortcut,
ExtensionUIContext,
InputEvent,
InputEventResult,
InputSource,
MessageRenderer,
RegisteredCommand,
RegisteredTool,
ResourcesDiscoverEvent,
ResourcesDiscoverResult,
SessionBeforeCompactResult,
SessionBeforeForkResult,
SessionBeforeSwitchResult,
SessionBeforeTreeResult,
ToolCallEvent,
ToolCallEventResult,
ToolResultEvent,
ToolResultEventResult,
UserBashEvent,
UserBashEventResult,
} from "./types.js";
// Keybindings for these actions cannot be overridden by extensions
const RESERVED_ACTIONS_FOR_EXTENSION_CONFLICTS: ReadonlyArray<KeyAction> = [
"interrupt",
"clear",
"exit",
"suspend",
"cycleThinkingLevel",
"cycleModelForward",
"cycleModelBackward",
"selectModel",
"expandTools",
"toggleThinking",
"externalEditor",
"followUp",
"submit",
"selectConfirm",
"selectCancel",
"copy",
"deleteToLineEnd",
];
type BuiltInKeyBindings = Partial<
Record<KeyId, { action: KeyAction; restrictOverride: boolean }>
>;
const buildBuiltinKeybindings = (
effectiveKeybindings: Required<KeybindingsConfig>,
): BuiltInKeyBindings => {
const builtinKeybindings = {} as BuiltInKeyBindings;
for (const [action, keys] of Object.entries(effectiveKeybindings)) {
const keyAction = action as KeyAction;
const keyList = Array.isArray(keys) ? keys : [keys];
const restrictOverride =
RESERVED_ACTIONS_FOR_EXTENSION_CONFLICTS.includes(keyAction);
for (const key of keyList) {
const normalizedKey = key.toLowerCase() as KeyId;
builtinKeybindings[normalizedKey] = {
action: keyAction,
restrictOverride: restrictOverride,
};
}
}
return builtinKeybindings;
};
/** Combined result from all before_agent_start handlers */
interface BeforeAgentStartCombinedResult {
messages?: NonNullable<BeforeAgentStartEventResult["message"]>[];
systemPrompt?: string;
}
/**
* Events handled by the generic emit() method.
* Events with dedicated emitXxx() methods are excluded for stronger type safety.
*/
type RunnerEmitEvent = Exclude<
ExtensionEvent,
| ToolCallEvent
| ToolResultEvent
| UserBashEvent
| ContextEvent
| BeforeAgentStartEvent
| ResourcesDiscoverEvent
| InputEvent
>;
type SessionBeforeEvent = Extract<
RunnerEmitEvent,
{
type:
| "session_before_switch"
| "session_before_fork"
| "session_before_compact"
| "session_before_tree";
}
>;
type SessionBeforeEventResult =
| SessionBeforeSwitchResult
| SessionBeforeForkResult
| SessionBeforeCompactResult
| SessionBeforeTreeResult;
type RunnerEmitResult<TEvent extends RunnerEmitEvent> = TEvent extends {
type: "session_before_switch";
}
? SessionBeforeSwitchResult | undefined
: TEvent extends { type: "session_before_fork" }
? SessionBeforeForkResult | undefined
: TEvent extends { type: "session_before_compact" }
? SessionBeforeCompactResult | undefined
: TEvent extends { type: "session_before_tree" }
? SessionBeforeTreeResult | undefined
: undefined;
export type ExtensionErrorListener = (error: ExtensionError) => void;
export type NewSessionHandler = (options?: {
parentSession?: string;
setup?: (sessionManager: SessionManager) => Promise<void>;
}) => Promise<{ cancelled: boolean }>;
export type ForkHandler = (entryId: string) => Promise<{ cancelled: boolean }>;
export type NavigateTreeHandler = (
targetId: string,
options?: {
summarize?: boolean;
customInstructions?: string;
replaceInstructions?: boolean;
label?: string;
},
) => Promise<{ cancelled: boolean }>;
export type SwitchSessionHandler = (
sessionPath: string,
) => Promise<{ cancelled: boolean }>;
export type ReloadHandler = () => Promise<void>;
export type ShutdownHandler = () => void;
/**
* Helper function to emit session_shutdown event to extensions.
* Returns true if the event was emitted, false if there were no handlers.
*/
export async function emitSessionShutdownEvent(
extensionRunner: ExtensionRunner | undefined,
): Promise<boolean> {
if (extensionRunner?.hasHandlers("session_shutdown")) {
await extensionRunner.emit({
type: "session_shutdown",
});
return true;
}
return false;
}
const noOpUIContext: ExtensionUIContext = {
select: async () => undefined,
confirm: async () => false,
input: async () => undefined,
notify: () => {},
onTerminalInput: () => () => {},
setStatus: () => {},
setWorkingMessage: () => {},
setWidget: () => {},
setFooter: () => {},
setHeader: () => {},
setTitle: () => {},
custom: async () => undefined as never,
pasteToEditor: () => {},
setEditorText: () => {},
getEditorText: () => "",
editor: async () => undefined,
setEditorComponent: () => {},
get theme() {
return theme;
},
getAllThemes: () => [],
getTheme: () => undefined,
setTheme: (_theme: string | Theme) => ({
success: false,
error: "UI not available",
}),
getToolsExpanded: () => false,
setToolsExpanded: () => {},
};
export class ExtensionRunner {
private extensions: Extension[];
private runtime: ExtensionRuntime;
private uiContext: ExtensionUIContext;
private cwd: string;
private sessionManager: SessionManager;
private modelRegistry: ModelRegistry;
private errorListeners: Set<ExtensionErrorListener> = new Set();
private getModel: () => Model<any> | undefined = () => undefined;
private isIdleFn: () => boolean = () => true;
private waitForIdleFn: () => Promise<void> = async () => {};
private abortFn: () => void = () => {};
private hasPendingMessagesFn: () => boolean = () => false;
private getContextUsageFn: () => ContextUsage | undefined = () => undefined;
private compactFn: (options?: CompactOptions) => void = () => {};
private getSystemPromptFn: () => string = () => "";
private newSessionHandler: NewSessionHandler = async () => ({
cancelled: false,
});
private forkHandler: ForkHandler = async () => ({ cancelled: false });
private navigateTreeHandler: NavigateTreeHandler = async () => ({
cancelled: false,
});
private switchSessionHandler: SwitchSessionHandler = async () => ({
cancelled: false,
});
private reloadHandler: ReloadHandler = async () => {};
private shutdownHandler: ShutdownHandler = () => {};
private shortcutDiagnostics: ResourceDiagnostic[] = [];
private commandDiagnostics: ResourceDiagnostic[] = [];
constructor(
extensions: Extension[],
runtime: ExtensionRuntime,
cwd: string,
sessionManager: SessionManager,
modelRegistry: ModelRegistry,
) {
this.extensions = extensions;
this.runtime = runtime;
this.uiContext = noOpUIContext;
this.cwd = cwd;
this.sessionManager = sessionManager;
this.modelRegistry = modelRegistry;
}
bindCore(
actions: ExtensionActions,
contextActions: ExtensionContextActions,
): void {
// Copy actions into the shared runtime (all extension APIs reference this)
this.runtime.sendMessage = actions.sendMessage;
this.runtime.sendUserMessage = actions.sendUserMessage;
this.runtime.appendEntry = actions.appendEntry;
this.runtime.setSessionName = actions.setSessionName;
this.runtime.getSessionName = actions.getSessionName;
this.runtime.setLabel = actions.setLabel;
this.runtime.getActiveTools = actions.getActiveTools;
this.runtime.getAllTools = actions.getAllTools;
this.runtime.setActiveTools = actions.setActiveTools;
this.runtime.refreshTools = actions.refreshTools;
this.runtime.getCommands = actions.getCommands;
this.runtime.setModel = actions.setModel;
this.runtime.getThinkingLevel = actions.getThinkingLevel;
this.runtime.setThinkingLevel = actions.setThinkingLevel;
// Context actions (required)
this.getModel = contextActions.getModel;
this.isIdleFn = contextActions.isIdle;
this.abortFn = contextActions.abort;
this.hasPendingMessagesFn = contextActions.hasPendingMessages;
this.shutdownHandler = contextActions.shutdown;
this.getContextUsageFn = contextActions.getContextUsage;
this.compactFn = contextActions.compact;
this.getSystemPromptFn = contextActions.getSystemPrompt;
// Flush provider registrations queued during extension loading
for (const { name, config } of this.runtime.pendingProviderRegistrations) {
this.modelRegistry.registerProvider(name, config);
}
this.runtime.pendingProviderRegistrations = [];
// From this point on, provider registration/unregistration takes effect immediately
// without requiring a /reload.
this.runtime.registerProvider = (name, config) =>
this.modelRegistry.registerProvider(name, config);
this.runtime.unregisterProvider = (name) =>
this.modelRegistry.unregisterProvider(name);
}
bindCommandContext(actions?: ExtensionCommandContextActions): void {
if (actions) {
this.waitForIdleFn = actions.waitForIdle;
this.newSessionHandler = actions.newSession;
this.forkHandler = actions.fork;
this.navigateTreeHandler = actions.navigateTree;
this.switchSessionHandler = actions.switchSession;
this.reloadHandler = actions.reload;
return;
}
this.waitForIdleFn = async () => {};
this.newSessionHandler = async () => ({ cancelled: false });
this.forkHandler = async () => ({ cancelled: false });
this.navigateTreeHandler = async () => ({ cancelled: false });
this.switchSessionHandler = async () => ({ cancelled: false });
this.reloadHandler = async () => {};
}
setUIContext(uiContext?: ExtensionUIContext): void {
this.uiContext = uiContext ?? noOpUIContext;
}
getUIContext(): ExtensionUIContext {
return this.uiContext;
}
hasUI(): boolean {
return this.uiContext !== noOpUIContext;
}
getExtensionPaths(): string[] {
return this.extensions.map((e) => e.path);
}
/** Get all registered tools from all extensions (first registration per name wins). */
getAllRegisteredTools(): RegisteredTool[] {
const toolsByName = new Map<string, RegisteredTool>();
for (const ext of this.extensions) {
for (const tool of ext.tools.values()) {
if (!toolsByName.has(tool.definition.name)) {
toolsByName.set(tool.definition.name, tool);
}
}
}
return Array.from(toolsByName.values());
}
/** Get a tool definition by name. Returns undefined if not found. */
getToolDefinition(
toolName: string,
): RegisteredTool["definition"] | undefined {
for (const ext of this.extensions) {
const tool = ext.tools.get(toolName);
if (tool) {
return tool.definition;
}
}
return undefined;
}
getFlags(): Map<string, ExtensionFlag> {
const allFlags = new Map<string, ExtensionFlag>();
for (const ext of this.extensions) {
for (const [name, flag] of ext.flags) {
if (!allFlags.has(name)) {
allFlags.set(name, flag);
}
}
}
return allFlags;
}
setFlagValue(name: string, value: boolean | string): void {
this.runtime.flagValues.set(name, value);
}
getFlagValues(): Map<string, boolean | string> {
return new Map(this.runtime.flagValues);
}
getShortcuts(
effectiveKeybindings: Required<KeybindingsConfig>,
): Map<KeyId, ExtensionShortcut> {
this.shortcutDiagnostics = [];
const builtinKeybindings = buildBuiltinKeybindings(effectiveKeybindings);
const extensionShortcuts = new Map<KeyId, ExtensionShortcut>();
const addDiagnostic = (message: string, extensionPath: string) => {
this.shortcutDiagnostics.push({
type: "warning",
message,
path: extensionPath,
});
if (!this.hasUI()) {
console.warn(message);
}
};
for (const ext of this.extensions) {
for (const [key, shortcut] of ext.shortcuts) {
const normalizedKey = key.toLowerCase() as KeyId;
const builtInKeybinding = builtinKeybindings[normalizedKey];
if (builtInKeybinding?.restrictOverride === true) {
addDiagnostic(
`Extension shortcut '${key}' from ${shortcut.extensionPath} conflicts with built-in shortcut. Skipping.`,
shortcut.extensionPath,
);
continue;
}
if (builtInKeybinding?.restrictOverride === false) {
addDiagnostic(
`Extension shortcut conflict: '${key}' is built-in shortcut for ${builtInKeybinding.action} and ${shortcut.extensionPath}. Using ${shortcut.extensionPath}.`,
shortcut.extensionPath,
);
}
const existingExtensionShortcut = extensionShortcuts.get(normalizedKey);
if (existingExtensionShortcut) {
addDiagnostic(
`Extension shortcut conflict: '${key}' registered by both ${existingExtensionShortcut.extensionPath} and ${shortcut.extensionPath}. Using ${shortcut.extensionPath}.`,
shortcut.extensionPath,
);
}
extensionShortcuts.set(normalizedKey, shortcut);
}
}
return extensionShortcuts;
}
getShortcutDiagnostics(): ResourceDiagnostic[] {
return this.shortcutDiagnostics;
}
onError(listener: ExtensionErrorListener): () => void {
this.errorListeners.add(listener);
return () => this.errorListeners.delete(listener);
}
emitError(error: ExtensionError): void {
for (const listener of this.errorListeners) {
listener(error);
}
}
hasHandlers(eventType: string): boolean {
for (const ext of this.extensions) {
const handlers = ext.handlers.get(eventType);
if (handlers && handlers.length > 0) {
return true;
}
}
return false;
}
getMessageRenderer(customType: string): MessageRenderer | undefined {
for (const ext of this.extensions) {
const renderer = ext.messageRenderers.get(customType);
if (renderer) {
return renderer;
}
}
return undefined;
}
getRegisteredCommands(reserved?: Set<string>): RegisteredCommand[] {
this.commandDiagnostics = [];
const commands: RegisteredCommand[] = [];
const commandOwners = new Map<string, string>();
for (const ext of this.extensions) {
for (const command of ext.commands.values()) {
if (reserved?.has(command.name)) {
const message = `Extension command '${command.name}' from ${ext.path} conflicts with built-in commands. Skipping.`;
this.commandDiagnostics.push({
type: "warning",
message,
path: ext.path,
});
if (!this.hasUI()) {
console.warn(message);
}
continue;
}
const existingOwner = commandOwners.get(command.name);
if (existingOwner) {
const message = `Extension command '${command.name}' from ${ext.path} conflicts with ${existingOwner}. Skipping.`;
this.commandDiagnostics.push({
type: "warning",
message,
path: ext.path,
});
if (!this.hasUI()) {
console.warn(message);
}
continue;
}
commandOwners.set(command.name, ext.path);
commands.push(command);
}
}
return commands;
}
getCommandDiagnostics(): ResourceDiagnostic[] {
return this.commandDiagnostics;
}
getRegisteredCommandsWithPaths(): Array<{
command: RegisteredCommand;
extensionPath: string;
}> {
const result: Array<{ command: RegisteredCommand; extensionPath: string }> =
[];
for (const ext of this.extensions) {
for (const command of ext.commands.values()) {
result.push({ command, extensionPath: ext.path });
}
}
return result;
}
getCommand(name: string): RegisteredCommand | undefined {
for (const ext of this.extensions) {
const command = ext.commands.get(name);
if (command) {
return command;
}
}
return undefined;
}
/**
* Request a graceful shutdown. Called by extension tools and event handlers.
* The actual shutdown behavior is provided by the mode via bindExtensions().
*/
shutdown(): void {
this.shutdownHandler();
}
/**
* Create an ExtensionContext for use in event handlers and tool execution.
* Context values are resolved at call time, so changes via bindCore/bindUI are reflected.
*/
createContext(): ExtensionContext {
const getModel = this.getModel;
return {
ui: this.uiContext,
hasUI: this.hasUI(),
cwd: this.cwd,
sessionManager: this.sessionManager,
modelRegistry: this.modelRegistry,
get model() {
return getModel();
},
isIdle: () => this.isIdleFn(),
abort: () => this.abortFn(),
hasPendingMessages: () => this.hasPendingMessagesFn(),
shutdown: () => this.shutdownHandler(),
getContextUsage: () => this.getContextUsageFn(),
compact: (options) => this.compactFn(options),
getSystemPrompt: () => this.getSystemPromptFn(),
};
}
createCommandContext(): ExtensionCommandContext {
return {
...this.createContext(),
waitForIdle: () => this.waitForIdleFn(),
newSession: (options) => this.newSessionHandler(options),
fork: (entryId) => this.forkHandler(entryId),
navigateTree: (targetId, options) =>
this.navigateTreeHandler(targetId, options),
switchSession: (sessionPath) => this.switchSessionHandler(sessionPath),
reload: () => this.reloadHandler(),
};
}
private isSessionBeforeEvent(
event: RunnerEmitEvent,
): event is SessionBeforeEvent {
return (
event.type === "session_before_switch" ||
event.type === "session_before_fork" ||
event.type === "session_before_compact" ||
event.type === "session_before_tree"
);
}
async emit<TEvent extends RunnerEmitEvent>(
event: TEvent,
): Promise<RunnerEmitResult<TEvent>> {
const ctx = this.createContext();
let result: SessionBeforeEventResult | undefined;
for (const ext of this.extensions) {
const handlers = ext.handlers.get(event.type);
if (!handlers || handlers.length === 0) continue;
for (const handler of handlers) {
try {
const handlerResult = await handler(event, ctx);
if (this.isSessionBeforeEvent(event) && handlerResult) {
result = handlerResult as SessionBeforeEventResult;
if (result.cancel) {
return result as RunnerEmitResult<TEvent>;
}
}
} catch (err) {
const message = err instanceof Error ? err.message : String(err);
const stack = err instanceof Error ? err.stack : undefined;
this.emitError({
extensionPath: ext.path,
event: event.type,
error: message,
stack,
});
}
}
}
return result as RunnerEmitResult<TEvent>;
}
async emitToolResult(
event: ToolResultEvent,
): Promise<ToolResultEventResult | undefined> {
const ctx = this.createContext();
const currentEvent: ToolResultEvent = { ...event };
let modified = false;
for (const ext of this.extensions) {
const handlers = ext.handlers.get("tool_result");
if (!handlers || handlers.length === 0) continue;
for (const handler of handlers) {
try {
const handlerResult = (await handler(currentEvent, ctx)) as
| ToolResultEventResult
| undefined;
if (!handlerResult) continue;
if (handlerResult.content !== undefined) {
currentEvent.content = handlerResult.content;
modified = true;
}
if (handlerResult.details !== undefined) {
currentEvent.details = handlerResult.details;
modified = true;
}
if (handlerResult.isError !== undefined) {
currentEvent.isError = handlerResult.isError;
modified = true;
}
} catch (err) {
const message = err instanceof Error ? err.message : String(err);
const stack = err instanceof Error ? err.stack : undefined;
this.emitError({
extensionPath: ext.path,
event: "tool_result",
error: message,
stack,
});
}
}
}
if (!modified) {
return undefined;
}
return {
content: currentEvent.content,
details: currentEvent.details,
isError: currentEvent.isError,
};
}
async emitToolCall(
event: ToolCallEvent,
): Promise<ToolCallEventResult | undefined> {
const ctx = this.createContext();
let result: ToolCallEventResult | undefined;
for (const ext of this.extensions) {
const handlers = ext.handlers.get("tool_call");
if (!handlers || handlers.length === 0) continue;
for (const handler of handlers) {
const handlerResult = await handler(event, ctx);
if (handlerResult) {
result = handlerResult as ToolCallEventResult;
if (result.block) {
return result;
}
}
}
}
return result;
}
async emitUserBash(
event: UserBashEvent,
): Promise<UserBashEventResult | undefined> {
const ctx = this.createContext();
for (const ext of this.extensions) {
const handlers = ext.handlers.get("user_bash");
if (!handlers || handlers.length === 0) continue;
for (const handler of handlers) {
try {
const handlerResult = await handler(event, ctx);
if (handlerResult) {
return handlerResult as UserBashEventResult;
}
} catch (err) {
const message = err instanceof Error ? err.message : String(err);
const stack = err instanceof Error ? err.stack : undefined;
this.emitError({
extensionPath: ext.path,
event: "user_bash",
error: message,
stack,
});
}
}
}
return undefined;
}
async emitContext(messages: AgentMessage[]): Promise<AgentMessage[]> {
const ctx = this.createContext();
let currentMessages = structuredClone(messages);
for (const ext of this.extensions) {
const handlers = ext.handlers.get("context");
if (!handlers || handlers.length === 0) continue;
for (const handler of handlers) {
try {
const event: ContextEvent = {
type: "context",
messages: currentMessages,
};
const handlerResult = await handler(event, ctx);
if (handlerResult && (handlerResult as ContextEventResult).messages) {
currentMessages = (handlerResult as ContextEventResult).messages!;
}
} catch (err) {
const message = err instanceof Error ? err.message : String(err);
const stack = err instanceof Error ? err.stack : undefined;
this.emitError({
extensionPath: ext.path,
event: "context",
error: message,
stack,
});
}
}
}
return currentMessages;
}
async emitBeforeAgentStart(
prompt: string,
images: ImageContent[] | undefined,
systemPrompt: string,
): Promise<BeforeAgentStartCombinedResult | undefined> {
const ctx = this.createContext();
const messages: NonNullable<BeforeAgentStartEventResult["message"]>[] = [];
let currentSystemPrompt = systemPrompt;
let systemPromptModified = false;
for (const ext of this.extensions) {
const handlers = ext.handlers.get("before_agent_start");
if (!handlers || handlers.length === 0) continue;
for (const handler of handlers) {
try {
const event: BeforeAgentStartEvent = {
type: "before_agent_start",
prompt,
images,
systemPrompt: currentSystemPrompt,
};
const handlerResult = await handler(event, ctx);
if (handlerResult) {
const result = handlerResult as BeforeAgentStartEventResult;
if (result.message) {
messages.push(result.message);
}
if (result.systemPrompt !== undefined) {
currentSystemPrompt = result.systemPrompt;
systemPromptModified = true;
}
}
} catch (err) {
const message = err instanceof Error ? err.message : String(err);
const stack = err instanceof Error ? err.stack : undefined;
this.emitError({
extensionPath: ext.path,
event: "before_agent_start",
error: message,
stack,
});
}
}
}
if (messages.length > 0 || systemPromptModified) {
return {
messages: messages.length > 0 ? messages : undefined,
systemPrompt: systemPromptModified ? currentSystemPrompt : undefined,
};
}
return undefined;
}
async emitResourcesDiscover(
cwd: string,
reason: ResourcesDiscoverEvent["reason"],
): Promise<{
skillPaths: Array<{ path: string; extensionPath: string }>;
promptPaths: Array<{ path: string; extensionPath: string }>;
themePaths: Array<{ path: string; extensionPath: string }>;
}> {
const ctx = this.createContext();
const skillPaths: Array<{ path: string; extensionPath: string }> = [];
const promptPaths: Array<{ path: string; extensionPath: string }> = [];
const themePaths: Array<{ path: string; extensionPath: string }> = [];
for (const ext of this.extensions) {
const handlers = ext.handlers.get("resources_discover");
if (!handlers || handlers.length === 0) continue;
for (const handler of handlers) {
try {
const event: ResourcesDiscoverEvent = {
type: "resources_discover",
cwd,
reason,
};
const handlerResult = await handler(event, ctx);
const result = handlerResult as ResourcesDiscoverResult | undefined;
if (result?.skillPaths?.length) {
skillPaths.push(
...result.skillPaths.map((path) => ({
path,
extensionPath: ext.path,
})),
);
}
if (result?.promptPaths?.length) {
promptPaths.push(
...result.promptPaths.map((path) => ({
path,
extensionPath: ext.path,
})),
);
}
if (result?.themePaths?.length) {
themePaths.push(
...result.themePaths.map((path) => ({
path,
extensionPath: ext.path,
})),
);
}
} catch (err) {
const message = err instanceof Error ? err.message : String(err);
const stack = err instanceof Error ? err.stack : undefined;
this.emitError({
extensionPath: ext.path,
event: "resources_discover",
error: message,
stack,
});
}
}
}
return { skillPaths, promptPaths, themePaths };
}
/** Emit input event. Transforms chain, "handled" short-circuits. */
async emitInput(
text: string,
images: ImageContent[] | undefined,
source: InputSource,
): Promise<InputEventResult> {
const ctx = this.createContext();
let currentText = text;
let currentImages = images;
for (const ext of this.extensions) {
for (const handler of ext.handlers.get("input") ?? []) {
try {
const event: InputEvent = {
type: "input",
text: currentText,
images: currentImages,
source,
};
const result = (await handler(event, ctx)) as
| InputEventResult
| undefined;
if (result?.action === "handled") return result;
if (result?.action === "transform") {
currentText = result.text;
currentImages = result.images ?? currentImages;
}
} catch (err) {
this.emitError({
extensionPath: ext.path,
event: "input",
error: err instanceof Error ? err.message : String(err),
stack: err instanceof Error ? err.stack : undefined,
});
}
}
}
return currentText !== text || currentImages !== images
? { action: "transform", text: currentText, images: currentImages }
: { action: "continue" };
}
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,147 @@
/**
* Tool wrappers for extensions.
*/
import type {
AgentTool,
AgentToolUpdateCallback,
} from "@mariozechner/pi-agent-core";
import type { ExtensionRunner } from "./runner.js";
import type { RegisteredTool, ToolCallEventResult } from "./types.js";
/**
* Wrap a RegisteredTool into an AgentTool.
* Uses the runner's createContext() for consistent context across tools and event handlers.
*/
export function wrapRegisteredTool(
registeredTool: RegisteredTool,
runner: ExtensionRunner,
): AgentTool {
const { definition } = registeredTool;
return {
name: definition.name,
label: definition.label,
description: definition.description,
parameters: definition.parameters,
execute: (toolCallId, params, signal, onUpdate) =>
definition.execute(
toolCallId,
params,
signal,
onUpdate,
runner.createContext(),
),
};
}
/**
* Wrap all registered tools into AgentTools.
* Uses the runner's createContext() for consistent context across tools and event handlers.
*/
export function wrapRegisteredTools(
registeredTools: RegisteredTool[],
runner: ExtensionRunner,
): AgentTool[] {
return registeredTools.map((rt) => wrapRegisteredTool(rt, runner));
}
/**
* Wrap a tool with extension callbacks for interception.
* - Emits tool_call event before execution (can block)
* - Emits tool_result event after execution (can modify result)
*/
export function wrapToolWithExtensions<T>(
tool: AgentTool<any, T>,
runner: ExtensionRunner,
): AgentTool<any, T> {
return {
...tool,
execute: async (
toolCallId: string,
params: Record<string, unknown>,
signal?: AbortSignal,
onUpdate?: AgentToolUpdateCallback<T>,
) => {
// Emit tool_call event - extensions can block execution
if (runner.hasHandlers("tool_call")) {
try {
const callResult = (await runner.emitToolCall({
type: "tool_call",
toolName: tool.name,
toolCallId,
input: params,
})) as ToolCallEventResult | undefined;
if (callResult?.block) {
const reason =
callResult.reason || "Tool execution was blocked by an extension";
throw new Error(reason);
}
} catch (err) {
if (err instanceof Error) {
throw err;
}
throw new Error(
`Extension failed, blocking execution: ${String(err)}`,
);
}
}
// Execute the actual tool
try {
const result = await tool.execute(toolCallId, params, signal, onUpdate);
// Emit tool_result event - extensions can modify the result
if (runner.hasHandlers("tool_result")) {
const resultResult = await runner.emitToolResult({
type: "tool_result",
toolName: tool.name,
toolCallId,
input: params,
content: result.content,
details: result.details,
isError: false,
});
if (resultResult) {
return {
content: resultResult.content ?? result.content,
details: (resultResult.details ?? result.details) as T,
};
}
}
return result;
} catch (err) {
// Emit tool_result event for errors
if (runner.hasHandlers("tool_result")) {
await runner.emitToolResult({
type: "tool_result",
toolName: tool.name,
toolCallId,
input: params,
content: [
{
type: "text",
text: err instanceof Error ? err.message : String(err),
},
],
details: undefined,
isError: true,
});
}
throw err;
}
},
};
}
/**
* Wrap all tools with extension callbacks.
*/
export function wrapToolsWithExtensions<T>(
tools: AgentTool<any, T>[],
runner: ExtensionRunner,
): AgentTool<any, T>[] {
return tools.map((tool) => wrapToolWithExtensions(tool, runner));
}

View file

@ -0,0 +1,149 @@
import { existsSync, type FSWatcher, readFileSync, statSync, watch } from "fs";
import { dirname, join, resolve } from "path";
/**
* Find the git HEAD path by walking up from cwd.
* Handles both regular git repos (.git is a directory) and worktrees (.git is a file).
*/
function findGitHeadPath(): string | null {
let dir = process.cwd();
while (true) {
const gitPath = join(dir, ".git");
if (existsSync(gitPath)) {
try {
const stat = statSync(gitPath);
if (stat.isFile()) {
const content = readFileSync(gitPath, "utf8").trim();
if (content.startsWith("gitdir: ")) {
const gitDir = content.slice(8);
const headPath = resolve(dir, gitDir, "HEAD");
if (existsSync(headPath)) return headPath;
}
} else if (stat.isDirectory()) {
const headPath = join(gitPath, "HEAD");
if (existsSync(headPath)) return headPath;
}
} catch {
return null;
}
}
const parent = dirname(dir);
if (parent === dir) return null;
dir = parent;
}
}
/**
* Provides git branch and extension statuses - data not otherwise accessible to extensions.
* Token stats, model info available via ctx.sessionManager and ctx.model.
*/
export class FooterDataProvider {
private extensionStatuses = new Map<string, string>();
private cachedBranch: string | null | undefined = undefined;
private gitWatcher: FSWatcher | null = null;
private branchChangeCallbacks = new Set<() => void>();
private availableProviderCount = 0;
constructor() {
this.setupGitWatcher();
}
/** Current git branch, null if not in repo, "detached" if detached HEAD */
getGitBranch(): string | null {
if (this.cachedBranch !== undefined) return this.cachedBranch;
try {
const gitHeadPath = findGitHeadPath();
if (!gitHeadPath) {
this.cachedBranch = null;
return null;
}
const content = readFileSync(gitHeadPath, "utf8").trim();
this.cachedBranch = content.startsWith("ref: refs/heads/")
? content.slice(16)
: "detached";
} catch {
this.cachedBranch = null;
}
return this.cachedBranch;
}
/** Extension status texts set via ctx.ui.setStatus() */
getExtensionStatuses(): ReadonlyMap<string, string> {
return this.extensionStatuses;
}
/** Subscribe to git branch changes. Returns unsubscribe function. */
onBranchChange(callback: () => void): () => void {
this.branchChangeCallbacks.add(callback);
return () => this.branchChangeCallbacks.delete(callback);
}
/** Internal: set extension status */
setExtensionStatus(key: string, text: string | undefined): void {
if (text === undefined) {
this.extensionStatuses.delete(key);
} else {
this.extensionStatuses.set(key, text);
}
}
/** Internal: clear extension statuses */
clearExtensionStatuses(): void {
this.extensionStatuses.clear();
}
/** Number of unique providers with available models (for footer display) */
getAvailableProviderCount(): number {
return this.availableProviderCount;
}
/** Internal: update available provider count */
setAvailableProviderCount(count: number): void {
this.availableProviderCount = count;
}
/** Internal: cleanup */
dispose(): void {
if (this.gitWatcher) {
this.gitWatcher.close();
this.gitWatcher = null;
}
this.branchChangeCallbacks.clear();
}
private setupGitWatcher(): void {
if (this.gitWatcher) {
this.gitWatcher.close();
this.gitWatcher = null;
}
const gitHeadPath = findGitHeadPath();
if (!gitHeadPath) return;
// Watch the directory containing HEAD, not HEAD itself.
// Git uses atomic writes (write temp, rename over HEAD), which changes the inode.
// fs.watch on a file stops working after the inode changes.
const gitDir = dirname(gitHeadPath);
try {
this.gitWatcher = watch(gitDir, (_eventType, filename) => {
if (filename === "HEAD") {
this.cachedBranch = undefined;
for (const cb of this.branchChangeCallbacks) cb();
}
});
} catch {
// Silently fail if we can't watch
}
}
}
/** Read-only view for extensions - excludes setExtensionStatus, setAvailableProviderCount and dispose */
export type ReadonlyFooterDataProvider = Pick<
FooterDataProvider,
| "getGitBranch"
| "getExtensionStatuses"
| "getAvailableProviderCount"
| "onBranchChange"
>;

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,70 @@
/**
* Core modules shared between all run modes.
*/
export {
AgentSession,
type AgentSessionConfig,
type AgentSessionEvent,
type AgentSessionEventListener,
type ModelCycleResult,
type PromptOptions,
type SessionStats,
} from "./agent-session.js";
export {
type BashExecutorOptions,
type BashResult,
executeBash,
executeBashWithOperations,
} from "./bash-executor.js";
export type { CompactionResult } from "./compaction/index.js";
export {
createEventBus,
type EventBus,
type EventBusController,
} from "./event-bus.js";
// Extensions system
export {
type AgentEndEvent,
type AgentStartEvent,
type AgentToolResult,
type AgentToolUpdateCallback,
type BeforeAgentStartEvent,
type ContextEvent,
discoverAndLoadExtensions,
type ExecOptions,
type ExecResult,
type Extension,
type ExtensionAPI,
type ExtensionCommandContext,
type ExtensionContext,
type ExtensionError,
type ExtensionEvent,
type ExtensionFactory,
type ExtensionFlag,
type ExtensionHandler,
ExtensionRunner,
type ExtensionShortcut,
type ExtensionUIContext,
type LoadExtensionsResult,
type MessageRenderer,
type RegisteredCommand,
type SessionBeforeCompactEvent,
type SessionBeforeForkEvent,
type SessionBeforeSwitchEvent,
type SessionBeforeTreeEvent,
type SessionCompactEvent,
type SessionForkEvent,
type SessionShutdownEvent,
type SessionStartEvent,
type SessionSwitchEvent,
type SessionTreeEvent,
type ToolCallEvent,
type ToolDefinition,
type ToolRenderResultOptions,
type ToolResultEvent,
type TurnEndEvent,
type TurnStartEvent,
wrapToolsWithExtensions,
} from "./extensions/index.js";

View file

@ -0,0 +1,211 @@
import {
DEFAULT_EDITOR_KEYBINDINGS,
type EditorAction,
type EditorKeybindingsConfig,
EditorKeybindingsManager,
type KeyId,
matchesKey,
setEditorKeybindings,
} from "@mariozechner/pi-tui";
import { existsSync, readFileSync } from "fs";
import { join } from "path";
import { getAgentDir } from "../config.js";
/**
* Application-level actions (coding agent specific).
*/
export type AppAction =
| "interrupt"
| "clear"
| "exit"
| "suspend"
| "cycleThinkingLevel"
| "cycleModelForward"
| "cycleModelBackward"
| "selectModel"
| "expandTools"
| "toggleThinking"
| "toggleSessionNamedFilter"
| "externalEditor"
| "followUp"
| "dequeue"
| "pasteImage"
| "newSession"
| "tree"
| "fork"
| "resume";
/**
* All configurable actions.
*/
export type KeyAction = AppAction | EditorAction;
/**
* Full keybindings configuration (app + editor actions).
*/
export type KeybindingsConfig = {
[K in KeyAction]?: KeyId | KeyId[];
};
/**
* Default application keybindings.
*/
export const DEFAULT_APP_KEYBINDINGS: Record<AppAction, KeyId | KeyId[]> = {
interrupt: "escape",
clear: "ctrl+c",
exit: "ctrl+d",
suspend: "ctrl+z",
cycleThinkingLevel: "shift+tab",
cycleModelForward: "ctrl+p",
cycleModelBackward: "shift+ctrl+p",
selectModel: "ctrl+l",
expandTools: "ctrl+o",
toggleThinking: "ctrl+t",
toggleSessionNamedFilter: "ctrl+n",
externalEditor: "ctrl+g",
followUp: "alt+enter",
dequeue: "alt+up",
pasteImage: process.platform === "win32" ? "alt+v" : "ctrl+v",
newSession: [],
tree: [],
fork: [],
resume: [],
};
/**
* All default keybindings (app + editor).
*/
export const DEFAULT_KEYBINDINGS: Required<KeybindingsConfig> = {
...DEFAULT_EDITOR_KEYBINDINGS,
...DEFAULT_APP_KEYBINDINGS,
};
// App actions list for type checking
const APP_ACTIONS: AppAction[] = [
"interrupt",
"clear",
"exit",
"suspend",
"cycleThinkingLevel",
"cycleModelForward",
"cycleModelBackward",
"selectModel",
"expandTools",
"toggleThinking",
"toggleSessionNamedFilter",
"externalEditor",
"followUp",
"dequeue",
"pasteImage",
"newSession",
"tree",
"fork",
"resume",
];
function isAppAction(action: string): action is AppAction {
return APP_ACTIONS.includes(action as AppAction);
}
/**
* Manages all keybindings (app + editor).
*/
export class KeybindingsManager {
private config: KeybindingsConfig;
private appActionToKeys: Map<AppAction, KeyId[]>;
private constructor(config: KeybindingsConfig) {
this.config = config;
this.appActionToKeys = new Map();
this.buildMaps();
}
/**
* Create from config file and set up editor keybindings.
*/
static create(agentDir: string = getAgentDir()): KeybindingsManager {
const configPath = join(agentDir, "keybindings.json");
const config = KeybindingsManager.loadFromFile(configPath);
const manager = new KeybindingsManager(config);
// Set up editor keybindings globally
// Include both editor actions and expandTools (shared between app and editor)
const editorConfig: EditorKeybindingsConfig = {};
for (const [action, keys] of Object.entries(config)) {
if (!isAppAction(action) || action === "expandTools") {
editorConfig[action as EditorAction] = keys;
}
}
setEditorKeybindings(new EditorKeybindingsManager(editorConfig));
return manager;
}
/**
* Create in-memory.
*/
static inMemory(config: KeybindingsConfig = {}): KeybindingsManager {
return new KeybindingsManager(config);
}
private static loadFromFile(path: string): KeybindingsConfig {
if (!existsSync(path)) return {};
try {
return JSON.parse(readFileSync(path, "utf-8"));
} catch {
return {};
}
}
private buildMaps(): void {
this.appActionToKeys.clear();
// Set defaults for app actions
for (const [action, keys] of Object.entries(DEFAULT_APP_KEYBINDINGS)) {
const keyArray = Array.isArray(keys) ? keys : [keys];
this.appActionToKeys.set(action as AppAction, [...keyArray]);
}
// Override with user config (app actions only)
for (const [action, keys] of Object.entries(this.config)) {
if (keys === undefined || !isAppAction(action)) continue;
const keyArray = Array.isArray(keys) ? keys : [keys];
this.appActionToKeys.set(action, keyArray);
}
}
/**
* Check if input matches an app action.
*/
matches(data: string, action: AppAction): boolean {
const keys = this.appActionToKeys.get(action);
if (!keys) return false;
for (const key of keys) {
if (matchesKey(data, key)) return true;
}
return false;
}
/**
* Get keys bound to an app action.
*/
getKeys(action: AppAction): KeyId[] {
return this.appActionToKeys.get(action) ?? [];
}
/**
* Get the full effective config.
*/
getEffectiveConfig(): Required<KeybindingsConfig> {
const result = { ...DEFAULT_KEYBINDINGS };
for (const [action, keys] of Object.entries(this.config)) {
if (keys !== undefined) {
(result as KeybindingsConfig)[action as KeyAction] = keys;
}
}
return result;
}
}
// Re-export for convenience
export type { EditorAction, KeyId };

View file

@ -0,0 +1,217 @@
/**
* Custom message types and transformers for the coding agent.
*
* Extends the base AgentMessage type with coding-agent specific message types,
* and provides a transformer to convert them to LLM-compatible messages.
*/
import type { AgentMessage } from "@mariozechner/pi-agent-core";
import type { ImageContent, Message, TextContent } from "@mariozechner/pi-ai";
export const COMPACTION_SUMMARY_PREFIX = `The conversation history before this point was compacted into the following summary:
<summary>
`;
export const COMPACTION_SUMMARY_SUFFIX = `
</summary>`;
export const BRANCH_SUMMARY_PREFIX = `The following is a summary of a branch that this conversation came back from:
<summary>
`;
export const BRANCH_SUMMARY_SUFFIX = `</summary>`;
/**
* Message type for bash executions via the ! command.
*/
export interface BashExecutionMessage {
role: "bashExecution";
command: string;
output: string;
exitCode: number | undefined;
cancelled: boolean;
truncated: boolean;
fullOutputPath?: string;
timestamp: number;
/** If true, this message is excluded from LLM context (!! prefix) */
excludeFromContext?: boolean;
}
/**
* Message type for extension-injected messages via sendMessage().
* These are custom messages that extensions can inject into the conversation.
*/
export interface CustomMessage<T = unknown> {
role: "custom";
customType: string;
content: string | (TextContent | ImageContent)[];
display: boolean;
details?: T;
timestamp: number;
}
export interface BranchSummaryMessage {
role: "branchSummary";
summary: string;
fromId: string;
timestamp: number;
}
export interface CompactionSummaryMessage {
role: "compactionSummary";
summary: string;
tokensBefore: number;
timestamp: number;
}
// Extend CustomAgentMessages via declaration merging
declare module "@mariozechner/pi-agent-core" {
interface CustomAgentMessages {
bashExecution: BashExecutionMessage;
custom: CustomMessage;
branchSummary: BranchSummaryMessage;
compactionSummary: CompactionSummaryMessage;
}
}
/**
* Convert a BashExecutionMessage to user message text for LLM context.
*/
export function bashExecutionToText(msg: BashExecutionMessage): string {
let text = `Ran \`${msg.command}\`\n`;
if (msg.output) {
text += `\`\`\`\n${msg.output}\n\`\`\``;
} else {
text += "(no output)";
}
if (msg.cancelled) {
text += "\n\n(command cancelled)";
} else if (
msg.exitCode !== null &&
msg.exitCode !== undefined &&
msg.exitCode !== 0
) {
text += `\n\nCommand exited with code ${msg.exitCode}`;
}
if (msg.truncated && msg.fullOutputPath) {
text += `\n\n[Output truncated. Full output: ${msg.fullOutputPath}]`;
}
return text;
}
export function createBranchSummaryMessage(
summary: string,
fromId: string,
timestamp: string,
): BranchSummaryMessage {
return {
role: "branchSummary",
summary,
fromId,
timestamp: new Date(timestamp).getTime(),
};
}
export function createCompactionSummaryMessage(
summary: string,
tokensBefore: number,
timestamp: string,
): CompactionSummaryMessage {
return {
role: "compactionSummary",
summary: summary,
tokensBefore,
timestamp: new Date(timestamp).getTime(),
};
}
/** Convert CustomMessageEntry to AgentMessage format */
export function createCustomMessage(
customType: string,
content: string | (TextContent | ImageContent)[],
display: boolean,
details: unknown | undefined,
timestamp: string,
): CustomMessage {
return {
role: "custom",
customType,
content,
display,
details,
timestamp: new Date(timestamp).getTime(),
};
}
/**
* Transform AgentMessages (including custom types) to LLM-compatible Messages.
*
* This is used by:
* - Agent's transormToLlm option (for prompt calls and queued messages)
* - Compaction's generateSummary (for summarization)
* - Custom extensions and tools
*/
export function convertToLlm(messages: AgentMessage[]): Message[] {
return messages
.map((m): Message | undefined => {
switch (m.role) {
case "bashExecution":
// Skip messages excluded from context (!! prefix)
if (m.excludeFromContext) {
return undefined;
}
return {
role: "user",
content: [{ type: "text", text: bashExecutionToText(m) }],
timestamp: m.timestamp,
};
case "custom": {
const content =
typeof m.content === "string"
? [{ type: "text" as const, text: m.content }]
: m.content;
return {
role: "user",
content,
timestamp: m.timestamp,
};
}
case "branchSummary":
return {
role: "user",
content: [
{
type: "text" as const,
text: BRANCH_SUMMARY_PREFIX + m.summary + BRANCH_SUMMARY_SUFFIX,
},
],
timestamp: m.timestamp,
};
case "compactionSummary":
return {
role: "user",
content: [
{
type: "text" as const,
text:
COMPACTION_SUMMARY_PREFIX +
m.summary +
COMPACTION_SUMMARY_SUFFIX,
},
],
timestamp: m.timestamp,
};
case "user":
case "assistant":
case "toolResult":
return m;
default:
// biome-ignore lint/correctness/noSwitchDeclarations: fine
const _exhaustiveCheck: never = m;
return undefined;
}
})
.filter((m) => m !== undefined);
}

View file

@ -0,0 +1,822 @@
/**
* Model registry - manages built-in and custom models, provides API key resolution.
*/
import {
type Api,
type AssistantMessageEventStream,
type Context,
getModels,
getProviders,
type KnownProvider,
type Model,
type OAuthProviderInterface,
type OpenAICompletionsCompat,
type OpenAIResponsesCompat,
registerApiProvider,
resetApiProviders,
type SimpleStreamOptions,
} from "@mariozechner/pi-ai";
import {
registerOAuthProvider,
resetOAuthProviders,
} from "@mariozechner/pi-ai/oauth";
import { type Static, Type } from "@sinclair/typebox";
import AjvModule from "ajv";
import { existsSync, readFileSync } from "fs";
import { join } from "path";
import { getAgentDir } from "../config.js";
import type { AuthStorage } from "./auth-storage.js";
import {
clearConfigValueCache,
resolveConfigValue,
resolveHeaders,
} from "./resolve-config-value.js";
const Ajv = (AjvModule as any).default || AjvModule;
const ajv = new Ajv();
// Schema for OpenRouter routing preferences
const OpenRouterRoutingSchema = Type.Object({
only: Type.Optional(Type.Array(Type.String())),
order: Type.Optional(Type.Array(Type.String())),
});
// Schema for Vercel AI Gateway routing preferences
const VercelGatewayRoutingSchema = Type.Object({
only: Type.Optional(Type.Array(Type.String())),
order: Type.Optional(Type.Array(Type.String())),
});
// Schema for OpenAI compatibility settings
const OpenAICompletionsCompatSchema = Type.Object({
supportsStore: Type.Optional(Type.Boolean()),
supportsDeveloperRole: Type.Optional(Type.Boolean()),
supportsReasoningEffort: Type.Optional(Type.Boolean()),
supportsUsageInStreaming: Type.Optional(Type.Boolean()),
maxTokensField: Type.Optional(
Type.Union([
Type.Literal("max_completion_tokens"),
Type.Literal("max_tokens"),
]),
),
requiresToolResultName: Type.Optional(Type.Boolean()),
requiresAssistantAfterToolResult: Type.Optional(Type.Boolean()),
requiresThinkingAsText: Type.Optional(Type.Boolean()),
requiresMistralToolIds: Type.Optional(Type.Boolean()),
thinkingFormat: Type.Optional(
Type.Union([
Type.Literal("openai"),
Type.Literal("zai"),
Type.Literal("qwen"),
]),
),
openRouterRouting: Type.Optional(OpenRouterRoutingSchema),
vercelGatewayRouting: Type.Optional(VercelGatewayRoutingSchema),
});
const OpenAIResponsesCompatSchema = Type.Object({
// Reserved for future use
});
const OpenAICompatSchema = Type.Union([
OpenAICompletionsCompatSchema,
OpenAIResponsesCompatSchema,
]);
// Schema for custom model definition
// Most fields are optional with sensible defaults for local models (Ollama, LM Studio, etc.)
const ModelDefinitionSchema = Type.Object({
id: Type.String({ minLength: 1 }),
name: Type.Optional(Type.String({ minLength: 1 })),
api: Type.Optional(Type.String({ minLength: 1 })),
baseUrl: Type.Optional(Type.String({ minLength: 1 })),
reasoning: Type.Optional(Type.Boolean()),
input: Type.Optional(
Type.Array(Type.Union([Type.Literal("text"), Type.Literal("image")])),
),
cost: Type.Optional(
Type.Object({
input: Type.Number(),
output: Type.Number(),
cacheRead: Type.Number(),
cacheWrite: Type.Number(),
}),
),
contextWindow: Type.Optional(Type.Number()),
maxTokens: Type.Optional(Type.Number()),
headers: Type.Optional(Type.Record(Type.String(), Type.String())),
compat: Type.Optional(OpenAICompatSchema),
});
// Schema for per-model overrides (all fields optional, merged with built-in model)
const ModelOverrideSchema = Type.Object({
name: Type.Optional(Type.String({ minLength: 1 })),
reasoning: Type.Optional(Type.Boolean()),
input: Type.Optional(
Type.Array(Type.Union([Type.Literal("text"), Type.Literal("image")])),
),
cost: Type.Optional(
Type.Object({
input: Type.Optional(Type.Number()),
output: Type.Optional(Type.Number()),
cacheRead: Type.Optional(Type.Number()),
cacheWrite: Type.Optional(Type.Number()),
}),
),
contextWindow: Type.Optional(Type.Number()),
maxTokens: Type.Optional(Type.Number()),
headers: Type.Optional(Type.Record(Type.String(), Type.String())),
compat: Type.Optional(OpenAICompatSchema),
});
type ModelOverride = Static<typeof ModelOverrideSchema>;
const ProviderConfigSchema = Type.Object({
baseUrl: Type.Optional(Type.String({ minLength: 1 })),
apiKey: Type.Optional(Type.String({ minLength: 1 })),
api: Type.Optional(Type.String({ minLength: 1 })),
headers: Type.Optional(Type.Record(Type.String(), Type.String())),
authHeader: Type.Optional(Type.Boolean()),
models: Type.Optional(Type.Array(ModelDefinitionSchema)),
modelOverrides: Type.Optional(
Type.Record(Type.String(), ModelOverrideSchema),
),
});
const ModelsConfigSchema = Type.Object({
providers: Type.Record(Type.String(), ProviderConfigSchema),
});
ajv.addSchema(ModelsConfigSchema, "ModelsConfig");
type ModelsConfig = Static<typeof ModelsConfigSchema>;
/** Provider override config (baseUrl, headers, apiKey) without custom models */
interface ProviderOverride {
baseUrl?: string;
headers?: Record<string, string>;
apiKey?: string;
}
/** Result of loading custom models from models.json */
interface CustomModelsResult {
models: Model<Api>[];
/** Providers with baseUrl/headers/apiKey overrides for built-in models */
overrides: Map<string, ProviderOverride>;
/** Per-model overrides: provider -> modelId -> override */
modelOverrides: Map<string, Map<string, ModelOverride>>;
error: string | undefined;
}
function emptyCustomModelsResult(error?: string): CustomModelsResult {
return { models: [], overrides: new Map(), modelOverrides: new Map(), error };
}
function mergeCompat(
baseCompat: Model<Api>["compat"],
overrideCompat: ModelOverride["compat"],
): Model<Api>["compat"] | undefined {
if (!overrideCompat) return baseCompat;
const base = baseCompat as
| OpenAICompletionsCompat
| OpenAIResponsesCompat
| undefined;
const override = overrideCompat as
| OpenAICompletionsCompat
| OpenAIResponsesCompat;
const merged = { ...base, ...override } as
| OpenAICompletionsCompat
| OpenAIResponsesCompat;
const baseCompletions = base as OpenAICompletionsCompat | undefined;
const overrideCompletions = override as OpenAICompletionsCompat;
const mergedCompletions = merged as OpenAICompletionsCompat;
if (
baseCompletions?.openRouterRouting ||
overrideCompletions.openRouterRouting
) {
mergedCompletions.openRouterRouting = {
...baseCompletions?.openRouterRouting,
...overrideCompletions.openRouterRouting,
};
}
if (
baseCompletions?.vercelGatewayRouting ||
overrideCompletions.vercelGatewayRouting
) {
mergedCompletions.vercelGatewayRouting = {
...baseCompletions?.vercelGatewayRouting,
...overrideCompletions.vercelGatewayRouting,
};
}
return merged as Model<Api>["compat"];
}
/**
* Deep merge a model override into a model.
* Handles nested objects (cost, compat) by merging rather than replacing.
*/
function applyModelOverride(
model: Model<Api>,
override: ModelOverride,
): Model<Api> {
const result = { ...model };
// Simple field overrides
if (override.name !== undefined) result.name = override.name;
if (override.reasoning !== undefined) result.reasoning = override.reasoning;
if (override.input !== undefined)
result.input = override.input as ("text" | "image")[];
if (override.contextWindow !== undefined)
result.contextWindow = override.contextWindow;
if (override.maxTokens !== undefined) result.maxTokens = override.maxTokens;
// Merge cost (partial override)
if (override.cost) {
result.cost = {
input: override.cost.input ?? model.cost.input,
output: override.cost.output ?? model.cost.output,
cacheRead: override.cost.cacheRead ?? model.cost.cacheRead,
cacheWrite: override.cost.cacheWrite ?? model.cost.cacheWrite,
};
}
// Merge headers
if (override.headers) {
const resolvedHeaders = resolveHeaders(override.headers);
result.headers = resolvedHeaders
? { ...model.headers, ...resolvedHeaders }
: model.headers;
}
// Deep merge compat
result.compat = mergeCompat(model.compat, override.compat);
return result;
}
/** Clear the config value command cache. Exported for testing. */
export const clearApiKeyCache = clearConfigValueCache;
/**
* Model registry - loads and manages models, resolves API keys via AuthStorage.
*/
export class ModelRegistry {
private models: Model<Api>[] = [];
private customProviderApiKeys: Map<string, string> = new Map();
private registeredProviders: Map<string, ProviderConfigInput> = new Map();
private loadError: string | undefined = undefined;
constructor(
readonly authStorage: AuthStorage,
private modelsJsonPath: string | undefined = join(
getAgentDir(),
"models.json",
),
) {
// Set up fallback resolver for custom provider API keys
this.authStorage.setFallbackResolver((provider) => {
const keyConfig = this.customProviderApiKeys.get(provider);
if (keyConfig) {
return resolveConfigValue(keyConfig);
}
return undefined;
});
// Load models
this.loadModels();
}
/**
* Reload models from disk (built-in + custom from models.json).
*/
refresh(): void {
this.customProviderApiKeys.clear();
this.loadError = undefined;
// Ensure dynamic API/OAuth registrations are rebuilt from current provider state.
resetApiProviders();
resetOAuthProviders();
this.loadModels();
for (const [providerName, config] of this.registeredProviders.entries()) {
this.applyProviderConfig(providerName, config);
}
}
/**
* Get any error from loading models.json (undefined if no error).
*/
getError(): string | undefined {
return this.loadError;
}
private loadModels(): void {
// Load custom models and overrides from models.json
const {
models: customModels,
overrides,
modelOverrides,
error,
} = this.modelsJsonPath
? this.loadCustomModels(this.modelsJsonPath)
: emptyCustomModelsResult();
if (error) {
this.loadError = error;
// Keep built-in models even if custom models failed to load
}
const builtInModels = this.loadBuiltInModels(overrides, modelOverrides);
let combined = this.mergeCustomModels(builtInModels, customModels);
// Let OAuth providers modify their models (e.g., update baseUrl)
for (const oauthProvider of this.authStorage.getOAuthProviders()) {
const cred = this.authStorage.get(oauthProvider.id);
if (cred?.type === "oauth" && oauthProvider.modifyModels) {
combined = oauthProvider.modifyModels(combined, cred);
}
}
this.models = combined;
}
/** Load built-in models and apply provider/model overrides */
private loadBuiltInModels(
overrides: Map<string, ProviderOverride>,
modelOverrides: Map<string, Map<string, ModelOverride>>,
): Model<Api>[] {
return getProviders().flatMap((provider) => {
const models = getModels(provider as KnownProvider) as Model<Api>[];
const providerOverride = overrides.get(provider);
const perModelOverrides = modelOverrides.get(provider);
return models.map((m) => {
let model = m;
// Apply provider-level baseUrl/headers override
if (providerOverride) {
const resolvedHeaders = resolveHeaders(providerOverride.headers);
model = {
...model,
baseUrl: providerOverride.baseUrl ?? model.baseUrl,
headers: resolvedHeaders
? { ...model.headers, ...resolvedHeaders }
: model.headers,
};
}
// Apply per-model override
const modelOverride = perModelOverrides?.get(m.id);
if (modelOverride) {
model = applyModelOverride(model, modelOverride);
}
return model;
});
});
}
/** Merge custom models into built-in list by provider+id (custom wins on conflicts). */
private mergeCustomModels(
builtInModels: Model<Api>[],
customModels: Model<Api>[],
): Model<Api>[] {
const merged = [...builtInModels];
for (const customModel of customModels) {
const existingIndex = merged.findIndex(
(m) => m.provider === customModel.provider && m.id === customModel.id,
);
if (existingIndex >= 0) {
merged[existingIndex] = customModel;
} else {
merged.push(customModel);
}
}
return merged;
}
private loadCustomModels(modelsJsonPath: string): CustomModelsResult {
if (!existsSync(modelsJsonPath)) {
return emptyCustomModelsResult();
}
try {
const content = readFileSync(modelsJsonPath, "utf-8");
const config: ModelsConfig = JSON.parse(content);
// Validate schema
const validate = ajv.getSchema("ModelsConfig")!;
if (!validate(config)) {
const errors =
validate.errors
?.map((e: any) => ` - ${e.instancePath || "root"}: ${e.message}`)
.join("\n") || "Unknown schema error";
return emptyCustomModelsResult(
`Invalid models.json schema:\n${errors}\n\nFile: ${modelsJsonPath}`,
);
}
// Additional validation
this.validateConfig(config);
const overrides = new Map<string, ProviderOverride>();
const modelOverrides = new Map<string, Map<string, ModelOverride>>();
for (const [providerName, providerConfig] of Object.entries(
config.providers,
)) {
// Apply provider-level baseUrl/headers/apiKey override to built-in models when configured.
if (
providerConfig.baseUrl ||
providerConfig.headers ||
providerConfig.apiKey
) {
overrides.set(providerName, {
baseUrl: providerConfig.baseUrl,
headers: providerConfig.headers,
apiKey: providerConfig.apiKey,
});
}
// Store API key for fallback resolver.
if (providerConfig.apiKey) {
this.customProviderApiKeys.set(providerName, providerConfig.apiKey);
}
if (providerConfig.modelOverrides) {
modelOverrides.set(
providerName,
new Map(Object.entries(providerConfig.modelOverrides)),
);
}
}
return {
models: this.parseModels(config),
overrides,
modelOverrides,
error: undefined,
};
} catch (error) {
if (error instanceof SyntaxError) {
return emptyCustomModelsResult(
`Failed to parse models.json: ${error.message}\n\nFile: ${modelsJsonPath}`,
);
}
return emptyCustomModelsResult(
`Failed to load models.json: ${error instanceof Error ? error.message : error}\n\nFile: ${modelsJsonPath}`,
);
}
}
private validateConfig(config: ModelsConfig): void {
for (const [providerName, providerConfig] of Object.entries(
config.providers,
)) {
const hasProviderApi = !!providerConfig.api;
const models = providerConfig.models ?? [];
const hasModelOverrides =
providerConfig.modelOverrides &&
Object.keys(providerConfig.modelOverrides).length > 0;
if (models.length === 0) {
// Override-only config: needs baseUrl OR modelOverrides (or both)
if (!providerConfig.baseUrl && !hasModelOverrides) {
throw new Error(
`Provider ${providerName}: must specify "baseUrl", "modelOverrides", or "models".`,
);
}
} else {
// Custom models are merged into provider models and require endpoint + auth.
if (!providerConfig.baseUrl) {
throw new Error(
`Provider ${providerName}: "baseUrl" is required when defining custom models.`,
);
}
if (!providerConfig.apiKey) {
throw new Error(
`Provider ${providerName}: "apiKey" is required when defining custom models.`,
);
}
}
for (const modelDef of models) {
const hasModelApi = !!modelDef.api;
if (!hasProviderApi && !hasModelApi) {
throw new Error(
`Provider ${providerName}, model ${modelDef.id}: no "api" specified. Set at provider or model level.`,
);
}
if (!modelDef.id)
throw new Error(`Provider ${providerName}: model missing "id"`);
// Validate contextWindow/maxTokens only if provided (they have defaults)
if (modelDef.contextWindow !== undefined && modelDef.contextWindow <= 0)
throw new Error(
`Provider ${providerName}, model ${modelDef.id}: invalid contextWindow`,
);
if (modelDef.maxTokens !== undefined && modelDef.maxTokens <= 0)
throw new Error(
`Provider ${providerName}, model ${modelDef.id}: invalid maxTokens`,
);
}
}
}
private parseModels(config: ModelsConfig): Model<Api>[] {
const models: Model<Api>[] = [];
for (const [providerName, providerConfig] of Object.entries(
config.providers,
)) {
const modelDefs = providerConfig.models ?? [];
if (modelDefs.length === 0) continue; // Override-only, no custom models
// Store API key config for fallback resolver
if (providerConfig.apiKey) {
this.customProviderApiKeys.set(providerName, providerConfig.apiKey);
}
for (const modelDef of modelDefs) {
const api = modelDef.api || providerConfig.api;
if (!api) continue;
// Merge headers: provider headers are base, model headers override
// Resolve env vars and shell commands in header values
const providerHeaders = resolveHeaders(providerConfig.headers);
const modelHeaders = resolveHeaders(modelDef.headers);
let headers =
providerHeaders || modelHeaders
? { ...providerHeaders, ...modelHeaders }
: undefined;
// If authHeader is true, add Authorization header with resolved API key
if (providerConfig.authHeader && providerConfig.apiKey) {
const resolvedKey = resolveConfigValue(providerConfig.apiKey);
if (resolvedKey) {
headers = { ...headers, Authorization: `Bearer ${resolvedKey}` };
}
}
// Provider baseUrl is required when custom models are defined.
// Individual models can override it with modelDef.baseUrl.
const defaultCost = {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
};
models.push({
id: modelDef.id,
name: modelDef.name ?? modelDef.id,
api: api as Api,
provider: providerName,
baseUrl: modelDef.baseUrl ?? providerConfig.baseUrl!,
reasoning: modelDef.reasoning ?? false,
input: (modelDef.input ?? ["text"]) as ("text" | "image")[],
cost: modelDef.cost ?? defaultCost,
contextWindow: modelDef.contextWindow ?? 128000,
maxTokens: modelDef.maxTokens ?? 16384,
headers,
compat: modelDef.compat,
} as Model<Api>);
}
}
return models;
}
/**
* Get all models (built-in + custom).
* If models.json had errors, returns only built-in models.
*/
getAll(): Model<Api>[] {
return this.models;
}
/**
* Get only models that have auth configured.
* This is a fast check that doesn't refresh OAuth tokens.
*/
getAvailable(): Model<Api>[] {
return this.models.filter((m) => this.authStorage.hasAuth(m.provider));
}
/**
* Find a model by provider and ID.
*/
find(provider: string, modelId: string): Model<Api> | undefined {
return this.models.find((m) => m.provider === provider && m.id === modelId);
}
/**
* Get API key for a model.
*/
async getApiKey(model: Model<Api>): Promise<string | undefined> {
return this.authStorage.getApiKey(model.provider);
}
/**
* Get API key for a provider.
*/
async getApiKeyForProvider(provider: string): Promise<string | undefined> {
return this.authStorage.getApiKey(provider);
}
/**
* Check if a model is using OAuth credentials (subscription).
*/
isUsingOAuth(model: Model<Api>): boolean {
const cred = this.authStorage.get(model.provider);
return cred?.type === "oauth";
}
/**
* Register a provider dynamically (from extensions).
*
* If provider has models: replaces all existing models for this provider.
* If provider has only baseUrl/headers: overrides existing models' URLs.
* If provider has oauth: registers OAuth provider for /login support.
*/
registerProvider(providerName: string, config: ProviderConfigInput): void {
this.registeredProviders.set(providerName, config);
this.applyProviderConfig(providerName, config);
}
/**
* Unregister a previously registered provider.
*
* Removes the provider from the registry and reloads models from disk so that
* built-in models overridden by this provider are restored to their original state.
* Also resets dynamic OAuth and API stream registrations before reapplying
* remaining dynamic providers.
* Has no effect if the provider was never registered.
*/
unregisterProvider(providerName: string): void {
if (!this.registeredProviders.has(providerName)) return;
this.registeredProviders.delete(providerName);
this.customProviderApiKeys.delete(providerName);
this.refresh();
}
private applyProviderConfig(
providerName: string,
config: ProviderConfigInput,
): void {
// Register OAuth provider if provided
if (config.oauth) {
// Ensure the OAuth provider ID matches the provider name
const oauthProvider: OAuthProviderInterface = {
...config.oauth,
id: providerName,
};
registerOAuthProvider(oauthProvider);
}
if (config.streamSimple) {
if (!config.api) {
throw new Error(
`Provider ${providerName}: "api" is required when registering streamSimple.`,
);
}
const streamSimple = config.streamSimple;
registerApiProvider(
{
api: config.api,
stream: (model, context, options) =>
streamSimple(model, context, options as SimpleStreamOptions),
streamSimple,
},
`provider:${providerName}`,
);
}
// Store API key for auth resolution
if (config.apiKey) {
this.customProviderApiKeys.set(providerName, config.apiKey);
}
if (config.models && config.models.length > 0) {
// Full replacement: remove existing models for this provider
this.models = this.models.filter((m) => m.provider !== providerName);
// Validate required fields
if (!config.baseUrl) {
throw new Error(
`Provider ${providerName}: "baseUrl" is required when defining models.`,
);
}
if (!config.apiKey && !config.oauth) {
throw new Error(
`Provider ${providerName}: "apiKey" or "oauth" is required when defining models.`,
);
}
// Parse and add new models
for (const modelDef of config.models) {
const api = modelDef.api || config.api;
if (!api) {
throw new Error(
`Provider ${providerName}, model ${modelDef.id}: no "api" specified.`,
);
}
// Merge headers
const providerHeaders = resolveHeaders(config.headers);
const modelHeaders = resolveHeaders(modelDef.headers);
let headers =
providerHeaders || modelHeaders
? { ...providerHeaders, ...modelHeaders }
: undefined;
// If authHeader is true, add Authorization header
if (config.authHeader && config.apiKey) {
const resolvedKey = resolveConfigValue(config.apiKey);
if (resolvedKey) {
headers = { ...headers, Authorization: `Bearer ${resolvedKey}` };
}
}
this.models.push({
id: modelDef.id,
name: modelDef.name,
api: api as Api,
provider: providerName,
baseUrl: config.baseUrl,
reasoning: modelDef.reasoning,
input: modelDef.input as ("text" | "image")[],
cost: modelDef.cost,
contextWindow: modelDef.contextWindow,
maxTokens: modelDef.maxTokens,
headers,
compat: modelDef.compat,
} as Model<Api>);
}
// Apply OAuth modifyModels if credentials exist (e.g., to update baseUrl)
if (config.oauth?.modifyModels) {
const cred = this.authStorage.get(providerName);
if (cred?.type === "oauth") {
this.models = config.oauth.modifyModels(this.models, cred);
}
}
} else if (config.baseUrl) {
// Override-only: update baseUrl/headers for existing models
const resolvedHeaders = resolveHeaders(config.headers);
this.models = this.models.map((m) => {
if (m.provider !== providerName) return m;
return {
...m,
baseUrl: config.baseUrl ?? m.baseUrl,
headers: resolvedHeaders
? { ...m.headers, ...resolvedHeaders }
: m.headers,
};
});
}
}
}
/**
* Input type for registerProvider API.
*/
export interface ProviderConfigInput {
baseUrl?: string;
apiKey?: string;
api?: Api;
streamSimple?: (
model: Model<Api>,
context: Context,
options?: SimpleStreamOptions,
) => AssistantMessageEventStream;
headers?: Record<string, string>;
authHeader?: boolean;
/** OAuth provider for /login support */
oauth?: Omit<OAuthProviderInterface, "id">;
models?: Array<{
id: string;
name: string;
api?: Api;
baseUrl?: string;
reasoning: boolean;
input: ("text" | "image")[];
cost: {
input: number;
output: number;
cacheRead: number;
cacheWrite: number;
};
contextWindow: number;
maxTokens: number;
headers?: Record<string, string>;
compat?: Model<Api>["compat"];
}>;
}

View file

@ -0,0 +1,707 @@
/**
* Model resolution, scoping, and initial selection
*/
import type { ThinkingLevel } from "@mariozechner/pi-agent-core";
import {
type Api,
type KnownProvider,
type Model,
modelsAreEqual,
} from "@mariozechner/pi-ai";
import chalk from "chalk";
import { minimatch } from "minimatch";
import { isValidThinkingLevel } from "../cli/args.js";
import { DEFAULT_THINKING_LEVEL } from "./defaults.js";
import type { ModelRegistry } from "./model-registry.js";
/** Default model IDs for each known provider */
export const defaultModelPerProvider: Record<KnownProvider, string> = {
"amazon-bedrock": "us.anthropic.claude-opus-4-6-v1",
anthropic: "claude-opus-4-6",
openai: "gpt-5.4",
"azure-openai-responses": "gpt-5.2",
"openai-codex": "gpt-5.4",
google: "gemini-2.5-pro",
"google-gemini-cli": "gemini-2.5-pro",
"google-antigravity": "gemini-3.1-pro-high",
"google-vertex": "gemini-3-pro-preview",
"github-copilot": "gpt-4o",
openrouter: "openai/gpt-5.1-codex",
"vercel-ai-gateway": "anthropic/claude-opus-4-6",
xai: "grok-4-fast-non-reasoning",
groq: "openai/gpt-oss-120b",
cerebras: "zai-glm-4.6",
zai: "glm-4.6",
mistral: "devstral-medium-latest",
minimax: "MiniMax-M2.1",
"minimax-cn": "MiniMax-M2.1",
huggingface: "moonshotai/Kimi-K2.5",
opencode: "claude-opus-4-6",
"opencode-go": "kimi-k2.5",
"kimi-coding": "kimi-k2-thinking",
};
export interface ScopedModel {
model: Model<Api>;
/** Thinking level if explicitly specified in pattern (e.g., "model:high"), undefined otherwise */
thinkingLevel?: ThinkingLevel;
}
/**
* Helper to check if a model ID looks like an alias (no date suffix)
* Dates are typically in format: -20241022 or -20250929
*/
function isAlias(id: string): boolean {
// Check if ID ends with -latest
if (id.endsWith("-latest")) return true;
// Check if ID ends with a date pattern (-YYYYMMDD)
const datePattern = /-\d{8}$/;
return !datePattern.test(id);
}
/**
* Try to match a pattern to a model from the available models list.
* Returns the matched model or undefined if no match found.
*/
function tryMatchModel(
modelPattern: string,
availableModels: Model<Api>[],
): Model<Api> | undefined {
// Check for provider/modelId format (provider is everything before the first /)
const slashIndex = modelPattern.indexOf("/");
if (slashIndex !== -1) {
const provider = modelPattern.substring(0, slashIndex);
const modelId = modelPattern.substring(slashIndex + 1);
const providerMatch = availableModels.find(
(m) =>
m.provider.toLowerCase() === provider.toLowerCase() &&
m.id.toLowerCase() === modelId.toLowerCase(),
);
if (providerMatch) {
return providerMatch;
}
// No exact provider/model match - fall through to other matching
}
// Check for exact ID match (case-insensitive)
const exactMatch = availableModels.find(
(m) => m.id.toLowerCase() === modelPattern.toLowerCase(),
);
if (exactMatch) {
return exactMatch;
}
// No exact match - fall back to partial matching
const matches = availableModels.filter(
(m) =>
m.id.toLowerCase().includes(modelPattern.toLowerCase()) ||
m.name?.toLowerCase().includes(modelPattern.toLowerCase()),
);
if (matches.length === 0) {
return undefined;
}
// Separate into aliases and dated versions
const aliases = matches.filter((m) => isAlias(m.id));
const datedVersions = matches.filter((m) => !isAlias(m.id));
if (aliases.length > 0) {
// Prefer alias - if multiple aliases, pick the one that sorts highest
aliases.sort((a, b) => b.id.localeCompare(a.id));
return aliases[0];
} else {
// No alias found, pick latest dated version
datedVersions.sort((a, b) => b.id.localeCompare(a.id));
return datedVersions[0];
}
}
export interface ParsedModelResult {
model: Model<Api> | undefined;
/** Thinking level if explicitly specified in pattern, undefined otherwise */
thinkingLevel?: ThinkingLevel;
warning: string | undefined;
}
function buildFallbackModel(
provider: string,
modelId: string,
availableModels: Model<Api>[],
): Model<Api> | undefined {
const providerModels = availableModels.filter((m) => m.provider === provider);
if (providerModels.length === 0) return undefined;
const defaultId = defaultModelPerProvider[provider as KnownProvider];
const baseModel = defaultId
? (providerModels.find((m) => m.id === defaultId) ?? providerModels[0])
: providerModels[0];
return {
...baseModel,
id: modelId,
name: modelId,
};
}
/**
* Parse a pattern to extract model and thinking level.
* Handles models with colons in their IDs (e.g., OpenRouter's :exacto suffix).
*
* Algorithm:
* 1. Try to match full pattern as a model
* 2. If found, return it with "off" thinking level
* 3. If not found and has colons, split on last colon:
* - If suffix is valid thinking level, use it and recurse on prefix
* - If suffix is invalid, warn and recurse on prefix with "off"
*
* @internal Exported for testing
*/
export function parseModelPattern(
pattern: string,
availableModels: Model<Api>[],
options?: { allowInvalidThinkingLevelFallback?: boolean },
): ParsedModelResult {
// Try exact match first
const exactMatch = tryMatchModel(pattern, availableModels);
if (exactMatch) {
return { model: exactMatch, thinkingLevel: undefined, warning: undefined };
}
// No match - try splitting on last colon if present
const lastColonIndex = pattern.lastIndexOf(":");
if (lastColonIndex === -1) {
// No colons, pattern simply doesn't match any model
return { model: undefined, thinkingLevel: undefined, warning: undefined };
}
const prefix = pattern.substring(0, lastColonIndex);
const suffix = pattern.substring(lastColonIndex + 1);
if (isValidThinkingLevel(suffix)) {
// Valid thinking level - recurse on prefix and use this level
const result = parseModelPattern(prefix, availableModels, options);
if (result.model) {
// Only use this thinking level if no warning from inner recursion
return {
model: result.model,
thinkingLevel: result.warning ? undefined : suffix,
warning: result.warning,
};
}
return result;
} else {
// Invalid suffix
const allowFallback = options?.allowInvalidThinkingLevelFallback ?? true;
if (!allowFallback) {
// In strict mode (CLI --model parsing), treat it as part of the model id and fail.
// This avoids accidentally resolving to a different model.
return { model: undefined, thinkingLevel: undefined, warning: undefined };
}
// Scope mode: recurse on prefix and warn
const result = parseModelPattern(prefix, availableModels, options);
if (result.model) {
return {
model: result.model,
thinkingLevel: undefined,
warning: `Invalid thinking level "${suffix}" in pattern "${pattern}". Using default instead.`,
};
}
return result;
}
}
/**
* Resolve model patterns to actual Model objects with optional thinking levels
* Format: "pattern:level" where :level is optional
* For each pattern, finds all matching models and picks the best version:
* 1. Prefer alias (e.g., claude-sonnet-4-5) over dated versions (claude-sonnet-4-5-20250929)
* 2. If no alias, pick the latest dated version
*
* Supports models with colons in their IDs (e.g., OpenRouter's model:exacto).
* The algorithm tries to match the full pattern first, then progressively
* strips colon-suffixes to find a match.
*/
export async function resolveModelScope(
patterns: string[],
modelRegistry: ModelRegistry,
): Promise<ScopedModel[]> {
const availableModels = await modelRegistry.getAvailable();
const scopedModels: ScopedModel[] = [];
for (const pattern of patterns) {
// Check if pattern contains glob characters
if (
pattern.includes("*") ||
pattern.includes("?") ||
pattern.includes("[")
) {
// Extract optional thinking level suffix (e.g., "provider/*:high")
const colonIdx = pattern.lastIndexOf(":");
let globPattern = pattern;
let thinkingLevel: ThinkingLevel | undefined;
if (colonIdx !== -1) {
const suffix = pattern.substring(colonIdx + 1);
if (isValidThinkingLevel(suffix)) {
thinkingLevel = suffix;
globPattern = pattern.substring(0, colonIdx);
}
}
// Match against "provider/modelId" format OR just model ID
// This allows "*sonnet*" to match without requiring "anthropic/*sonnet*"
const matchingModels = availableModels.filter((m) => {
const fullId = `${m.provider}/${m.id}`;
return (
minimatch(fullId, globPattern, { nocase: true }) ||
minimatch(m.id, globPattern, { nocase: true })
);
});
if (matchingModels.length === 0) {
console.warn(
chalk.yellow(`Warning: No models match pattern "${pattern}"`),
);
continue;
}
for (const model of matchingModels) {
if (!scopedModels.find((sm) => modelsAreEqual(sm.model, model))) {
scopedModels.push({ model, thinkingLevel });
}
}
continue;
}
const { model, thinkingLevel, warning } = parseModelPattern(
pattern,
availableModels,
);
if (warning) {
console.warn(chalk.yellow(`Warning: ${warning}`));
}
if (!model) {
console.warn(
chalk.yellow(`Warning: No models match pattern "${pattern}"`),
);
continue;
}
// Avoid duplicates
if (!scopedModels.find((sm) => modelsAreEqual(sm.model, model))) {
scopedModels.push({ model, thinkingLevel });
}
}
return scopedModels;
}
export interface ResolveCliModelResult {
model: Model<Api> | undefined;
thinkingLevel?: ThinkingLevel;
warning: string | undefined;
/**
* Error message suitable for CLI display.
* When set, model will be undefined.
*/
error: string | undefined;
}
/**
* Resolve a single model from CLI flags.
*
* Supports:
* - --provider <provider> --model <pattern>
* - --model <provider>/<pattern>
* - Fuzzy matching (same rules as model scoping: exact id, then partial id/name)
*
* Note: This does not apply the thinking level by itself, but it may *parse* and
* return a thinking level from "<pattern>:<thinking>" so the caller can apply it.
*/
export function resolveCliModel(options: {
cliProvider?: string;
cliModel?: string;
modelRegistry: ModelRegistry;
}): ResolveCliModelResult {
const { cliProvider, cliModel, modelRegistry } = options;
if (!cliModel) {
return { model: undefined, warning: undefined, error: undefined };
}
// Important: use *all* models here, not just models with pre-configured auth.
// This allows "--api-key" to be used for first-time setup.
const availableModels = modelRegistry.getAll();
if (availableModels.length === 0) {
return {
model: undefined,
warning: undefined,
error:
"No models available. Check your installation or add models to models.json.",
};
}
// Build canonical provider lookup (case-insensitive)
const providerMap = new Map<string, string>();
for (const m of availableModels) {
providerMap.set(m.provider.toLowerCase(), m.provider);
}
let provider = cliProvider
? providerMap.get(cliProvider.toLowerCase())
: undefined;
if (cliProvider && !provider) {
return {
model: undefined,
warning: undefined,
error: `Unknown provider "${cliProvider}". Use --list-models to see available providers/models.`,
};
}
// If no explicit --provider, try to interpret "provider/model" format first.
// When the prefix before the first slash matches a known provider, prefer that
// interpretation over matching models whose IDs literally contain slashes
// (e.g. "zai/glm-5" should resolve to provider=zai, model=glm-5, not to a
// vercel-ai-gateway model with id "zai/glm-5").
let pattern = cliModel;
let inferredProvider = false;
if (!provider) {
const slashIndex = cliModel.indexOf("/");
if (slashIndex !== -1) {
const maybeProvider = cliModel.substring(0, slashIndex);
const canonical = providerMap.get(maybeProvider.toLowerCase());
if (canonical) {
provider = canonical;
pattern = cliModel.substring(slashIndex + 1);
inferredProvider = true;
}
}
}
// If no provider was inferred from the slash, try exact matches without provider inference.
// This handles models whose IDs naturally contain slashes (e.g. OpenRouter-style IDs).
if (!provider) {
const lower = cliModel.toLowerCase();
const exact = availableModels.find(
(m) =>
m.id.toLowerCase() === lower ||
`${m.provider}/${m.id}`.toLowerCase() === lower,
);
if (exact) {
return {
model: exact,
warning: undefined,
thinkingLevel: undefined,
error: undefined,
};
}
}
if (cliProvider && provider) {
// If both were provided, tolerate --model <provider>/<pattern> by stripping the provider prefix
const prefix = `${provider}/`;
if (cliModel.toLowerCase().startsWith(prefix.toLowerCase())) {
pattern = cliModel.substring(prefix.length);
}
}
const candidates = provider
? availableModels.filter((m) => m.provider === provider)
: availableModels;
const { model, thinkingLevel, warning } = parseModelPattern(
pattern,
candidates,
{
allowInvalidThinkingLevelFallback: false,
},
);
if (model) {
return { model, thinkingLevel, warning, error: undefined };
}
// If we inferred a provider from the slash but found no match within that provider,
// fall back to matching the full input as a raw model id across all models.
// This handles OpenRouter-style IDs like "openai/gpt-4o:extended" where "openai"
// looks like a provider but the full string is actually a model id on openrouter.
if (inferredProvider) {
const lower = cliModel.toLowerCase();
const exact = availableModels.find(
(m) =>
m.id.toLowerCase() === lower ||
`${m.provider}/${m.id}`.toLowerCase() === lower,
);
if (exact) {
return {
model: exact,
warning: undefined,
thinkingLevel: undefined,
error: undefined,
};
}
// Also try parseModelPattern on the full input against all models
const fallback = parseModelPattern(cliModel, availableModels, {
allowInvalidThinkingLevelFallback: false,
});
if (fallback.model) {
return {
model: fallback.model,
thinkingLevel: fallback.thinkingLevel,
warning: fallback.warning,
error: undefined,
};
}
}
if (provider) {
const fallbackModel = buildFallbackModel(
provider,
pattern,
availableModels,
);
if (fallbackModel) {
const fallbackWarning = warning
? `${warning} Model "${pattern}" not found for provider "${provider}". Using custom model id.`
: `Model "${pattern}" not found for provider "${provider}". Using custom model id.`;
return {
model: fallbackModel,
thinkingLevel: undefined,
warning: fallbackWarning,
error: undefined,
};
}
}
const display = provider ? `${provider}/${pattern}` : cliModel;
return {
model: undefined,
thinkingLevel: undefined,
warning,
error: `Model "${display}" not found. Use --list-models to see available models.`,
};
}
export interface InitialModelResult {
model: Model<Api> | undefined;
thinkingLevel: ThinkingLevel;
fallbackMessage: string | undefined;
}
/**
* Find the initial model to use based on priority:
* 1. CLI args (provider + model)
* 2. First model from scoped models (if not continuing/resuming)
* 3. Restored from session (if continuing/resuming)
* 4. Saved default from settings
* 5. First available model with valid API key
*/
export async function findInitialModel(options: {
cliProvider?: string;
cliModel?: string;
scopedModels: ScopedModel[];
isContinuing: boolean;
defaultProvider?: string;
defaultModelId?: string;
defaultThinkingLevel?: ThinkingLevel;
modelRegistry: ModelRegistry;
}): Promise<InitialModelResult> {
const {
cliProvider,
cliModel,
scopedModels,
isContinuing,
defaultProvider,
defaultModelId,
defaultThinkingLevel,
modelRegistry,
} = options;
let model: Model<Api> | undefined;
let thinkingLevel: ThinkingLevel = DEFAULT_THINKING_LEVEL;
// 1. CLI args take priority
if (cliProvider && cliModel) {
const resolved = resolveCliModel({
cliProvider,
cliModel,
modelRegistry,
});
if (resolved.error) {
console.error(chalk.red(resolved.error));
process.exit(1);
}
if (resolved.model) {
return {
model: resolved.model,
thinkingLevel: DEFAULT_THINKING_LEVEL,
fallbackMessage: undefined,
};
}
}
// 2. Use first model from scoped models (skip if continuing/resuming)
if (scopedModels.length > 0 && !isContinuing) {
return {
model: scopedModels[0].model,
thinkingLevel:
scopedModels[0].thinkingLevel ??
defaultThinkingLevel ??
DEFAULT_THINKING_LEVEL,
fallbackMessage: undefined,
};
}
// 3. Try saved default from settings
if (defaultProvider && defaultModelId) {
const found = modelRegistry.find(defaultProvider, defaultModelId);
if (found) {
model = found;
if (defaultThinkingLevel) {
thinkingLevel = defaultThinkingLevel;
}
return { model, thinkingLevel, fallbackMessage: undefined };
}
}
// 4. Try first available model with valid API key
const availableModels = await modelRegistry.getAvailable();
if (availableModels.length > 0) {
// Try to find a default model from known providers
for (const provider of Object.keys(
defaultModelPerProvider,
) as KnownProvider[]) {
const defaultId = defaultModelPerProvider[provider];
const match = availableModels.find(
(m) => m.provider === provider && m.id === defaultId,
);
if (match) {
return {
model: match,
thinkingLevel: DEFAULT_THINKING_LEVEL,
fallbackMessage: undefined,
};
}
}
// If no default found, use first available
return {
model: availableModels[0],
thinkingLevel: DEFAULT_THINKING_LEVEL,
fallbackMessage: undefined,
};
}
// 5. No model found
return {
model: undefined,
thinkingLevel: DEFAULT_THINKING_LEVEL,
fallbackMessage: undefined,
};
}
/**
* Restore model from session, with fallback to available models
*/
export async function restoreModelFromSession(
savedProvider: string,
savedModelId: string,
currentModel: Model<Api> | undefined,
shouldPrintMessages: boolean,
modelRegistry: ModelRegistry,
): Promise<{
model: Model<Api> | undefined;
fallbackMessage: string | undefined;
}> {
const restoredModel = modelRegistry.find(savedProvider, savedModelId);
// Check if restored model exists and has a valid API key
const hasApiKey = restoredModel
? !!(await modelRegistry.getApiKey(restoredModel))
: false;
if (restoredModel && hasApiKey) {
if (shouldPrintMessages) {
console.log(
chalk.dim(`Restored model: ${savedProvider}/${savedModelId}`),
);
}
return { model: restoredModel, fallbackMessage: undefined };
}
// Model not found or no API key - fall back
const reason = !restoredModel
? "model no longer exists"
: "no API key available";
if (shouldPrintMessages) {
console.error(
chalk.yellow(
`Warning: Could not restore model ${savedProvider}/${savedModelId} (${reason}).`,
),
);
}
// If we already have a model, use it as fallback
if (currentModel) {
if (shouldPrintMessages) {
console.log(
chalk.dim(
`Falling back to: ${currentModel.provider}/${currentModel.id}`,
),
);
}
return {
model: currentModel,
fallbackMessage: `Could not restore model ${savedProvider}/${savedModelId} (${reason}). Using ${currentModel.provider}/${currentModel.id}.`,
};
}
// Try to find any available model
const availableModels = await modelRegistry.getAvailable();
if (availableModels.length > 0) {
// Try to find a default model from known providers
let fallbackModel: Model<Api> | undefined;
for (const provider of Object.keys(
defaultModelPerProvider,
) as KnownProvider[]) {
const defaultId = defaultModelPerProvider[provider];
const match = availableModels.find(
(m) => m.provider === provider && m.id === defaultId,
);
if (match) {
fallbackModel = match;
break;
}
}
// If no default found, use first available
if (!fallbackModel) {
fallbackModel = availableModels[0];
}
if (shouldPrintMessages) {
console.log(
chalk.dim(
`Falling back to: ${fallbackModel.provider}/${fallbackModel.id}`,
),
);
}
return {
model: fallbackModel,
fallbackMessage: `Could not restore model ${savedProvider}/${savedModelId} (${reason}). Using ${fallbackModel.provider}/${fallbackModel.id}.`,
};
}
// No models available
return { model: undefined, fallbackMessage: undefined };
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,327 @@
import { existsSync, readdirSync, readFileSync, statSync } from "fs";
import { homedir } from "os";
import { basename, isAbsolute, join, resolve, sep } from "path";
import { CONFIG_DIR_NAME, getPromptsDir } from "../config.js";
import { parseFrontmatter } from "../utils/frontmatter.js";
/**
* Represents a prompt template loaded from a markdown file
*/
export interface PromptTemplate {
name: string;
description: string;
content: string;
source: string; // "user", "project", or "path"
filePath: string; // Absolute path to the template file
}
/**
* Parse command arguments respecting quoted strings (bash-style)
* Returns array of arguments
*/
export function parseCommandArgs(argsString: string): string[] {
const args: string[] = [];
let current = "";
let inQuote: string | null = null;
for (let i = 0; i < argsString.length; i++) {
const char = argsString[i];
if (inQuote) {
if (char === inQuote) {
inQuote = null;
} else {
current += char;
}
} else if (char === '"' || char === "'") {
inQuote = char;
} else if (char === " " || char === "\t") {
if (current) {
args.push(current);
current = "";
}
} else {
current += char;
}
}
if (current) {
args.push(current);
}
return args;
}
/**
* Substitute argument placeholders in template content
* Supports:
* - $1, $2, ... for positional args
* - $@ and $ARGUMENTS for all args
* - ${@:N} for args from Nth onwards (bash-style slicing)
* - ${@:N:L} for L args starting from Nth
*
* Note: Replacement happens on the template string only. Argument values
* containing patterns like $1, $@, or $ARGUMENTS are NOT recursively substituted.
*/
export function substituteArgs(content: string, args: string[]): string {
let result = content;
// Replace $1, $2, etc. with positional args FIRST (before wildcards)
// This prevents wildcard replacement values containing $<digit> patterns from being re-substituted
result = result.replace(/\$(\d+)/g, (_, num) => {
const index = parseInt(num, 10) - 1;
return args[index] ?? "";
});
// Replace ${@:start} or ${@:start:length} with sliced args (bash-style)
// Process BEFORE simple $@ to avoid conflicts
result = result.replace(
/\$\{@:(\d+)(?::(\d+))?\}/g,
(_, startStr, lengthStr) => {
let start = parseInt(startStr, 10) - 1; // Convert to 0-indexed (user provides 1-indexed)
// Treat 0 as 1 (bash convention: args start at 1)
if (start < 0) start = 0;
if (lengthStr) {
const length = parseInt(lengthStr, 10);
return args.slice(start, start + length).join(" ");
}
return args.slice(start).join(" ");
},
);
// Pre-compute all args joined (optimization)
const allArgs = args.join(" ");
// Replace $ARGUMENTS with all args joined (new syntax, aligns with Claude, Codex, OpenCode)
result = result.replace(/\$ARGUMENTS/g, allArgs);
// Replace $@ with all args joined (existing syntax)
result = result.replace(/\$@/g, allArgs);
return result;
}
function loadTemplateFromFile(
filePath: string,
source: string,
sourceLabel: string,
): PromptTemplate | null {
try {
const rawContent = readFileSync(filePath, "utf-8");
const { frontmatter, body } =
parseFrontmatter<Record<string, string>>(rawContent);
const name = basename(filePath).replace(/\.md$/, "");
// Get description from frontmatter or first non-empty line
let description = frontmatter.description || "";
if (!description) {
const firstLine = body.split("\n").find((line) => line.trim());
if (firstLine) {
// Truncate if too long
description = firstLine.slice(0, 60);
if (firstLine.length > 60) description += "...";
}
}
// Append source to description
description = description ? `${description} ${sourceLabel}` : sourceLabel;
return {
name,
description,
content: body,
source,
filePath,
};
} catch {
return null;
}
}
/**
* Scan a directory for .md files (non-recursive) and load them as prompt templates.
*/
function loadTemplatesFromDir(
dir: string,
source: string,
sourceLabel: string,
): PromptTemplate[] {
const templates: PromptTemplate[] = [];
if (!existsSync(dir)) {
return templates;
}
try {
const entries = readdirSync(dir, { withFileTypes: true });
for (const entry of entries) {
const fullPath = join(dir, entry.name);
// For symlinks, check if they point to a file
let isFile = entry.isFile();
if (entry.isSymbolicLink()) {
try {
const stats = statSync(fullPath);
isFile = stats.isFile();
} catch {
// Broken symlink, skip it
continue;
}
}
if (isFile && entry.name.endsWith(".md")) {
const template = loadTemplateFromFile(fullPath, source, sourceLabel);
if (template) {
templates.push(template);
}
}
}
} catch {
return templates;
}
return templates;
}
export interface LoadPromptTemplatesOptions {
/** Working directory for project-local templates. Default: process.cwd() */
cwd?: string;
/** Agent config directory for global templates. Default: from getPromptsDir() */
agentDir?: string;
/** Explicit prompt template paths (files or directories) */
promptPaths?: string[];
/** Include default prompt directories. Default: true */
includeDefaults?: boolean;
}
function normalizePath(input: string): string {
const trimmed = input.trim();
if (trimmed === "~") return homedir();
if (trimmed.startsWith("~/")) return join(homedir(), trimmed.slice(2));
if (trimmed.startsWith("~")) return join(homedir(), trimmed.slice(1));
return trimmed;
}
function resolvePromptPath(p: string, cwd: string): string {
const normalized = normalizePath(p);
return isAbsolute(normalized) ? normalized : resolve(cwd, normalized);
}
function buildPathSourceLabel(p: string): string {
const base = basename(p).replace(/\.md$/, "") || "path";
return `(path:${base})`;
}
/**
* Load all prompt templates from:
* 1. Global: agentDir/prompts/
* 2. Project: cwd/{CONFIG_DIR_NAME}/prompts/
* 3. Explicit prompt paths
*/
export function loadPromptTemplates(
options: LoadPromptTemplatesOptions = {},
): PromptTemplate[] {
const resolvedCwd = options.cwd ?? process.cwd();
const resolvedAgentDir = options.agentDir ?? getPromptsDir();
const promptPaths = options.promptPaths ?? [];
const includeDefaults = options.includeDefaults ?? true;
const templates: PromptTemplate[] = [];
if (includeDefaults) {
// 1. Load global templates from agentDir/prompts/
// Note: if agentDir is provided, it should be the agent dir, not the prompts dir
const globalPromptsDir = options.agentDir
? join(options.agentDir, "prompts")
: resolvedAgentDir;
templates.push(...loadTemplatesFromDir(globalPromptsDir, "user", "(user)"));
// 2. Load project templates from cwd/{CONFIG_DIR_NAME}/prompts/
const projectPromptsDir = resolve(resolvedCwd, CONFIG_DIR_NAME, "prompts");
templates.push(
...loadTemplatesFromDir(projectPromptsDir, "project", "(project)"),
);
}
const userPromptsDir = options.agentDir
? join(options.agentDir, "prompts")
: resolvedAgentDir;
const projectPromptsDir = resolve(resolvedCwd, CONFIG_DIR_NAME, "prompts");
const isUnderPath = (target: string, root: string): boolean => {
const normalizedRoot = resolve(root);
if (target === normalizedRoot) {
return true;
}
const prefix = normalizedRoot.endsWith(sep)
? normalizedRoot
: `${normalizedRoot}${sep}`;
return target.startsWith(prefix);
};
const getSourceInfo = (
resolvedPath: string,
): { source: string; label: string } => {
if (!includeDefaults) {
if (isUnderPath(resolvedPath, userPromptsDir)) {
return { source: "user", label: "(user)" };
}
if (isUnderPath(resolvedPath, projectPromptsDir)) {
return { source: "project", label: "(project)" };
}
}
return { source: "path", label: buildPathSourceLabel(resolvedPath) };
};
// 3. Load explicit prompt paths
for (const rawPath of promptPaths) {
const resolvedPath = resolvePromptPath(rawPath, resolvedCwd);
if (!existsSync(resolvedPath)) {
continue;
}
try {
const stats = statSync(resolvedPath);
const { source, label } = getSourceInfo(resolvedPath);
if (stats.isDirectory()) {
templates.push(...loadTemplatesFromDir(resolvedPath, source, label));
} else if (stats.isFile() && resolvedPath.endsWith(".md")) {
const template = loadTemplateFromFile(resolvedPath, source, label);
if (template) {
templates.push(template);
}
}
} catch {
// Ignore read failures
}
}
return templates;
}
/**
* Expand a prompt template if it matches a template name.
* Returns the expanded content or the original text if not a template.
*/
export function expandPromptTemplate(
text: string,
templates: PromptTemplate[],
): string {
if (!text.startsWith("/")) return text;
const spaceIndex = text.indexOf(" ");
const templateName =
spaceIndex === -1 ? text.slice(1) : text.slice(1, spaceIndex);
const argsString = spaceIndex === -1 ? "" : text.slice(spaceIndex + 1);
const template = templates.find((t) => t.name === templateName);
if (template) {
const args = parseCommandArgs(argsString);
return substituteArgs(template.content, args);
}
return text;
}

View file

@ -0,0 +1,66 @@
/**
* Resolve configuration values that may be shell commands, environment variables, or literals.
* Used by auth-storage.ts and model-registry.ts.
*/
import { execSync } from "child_process";
// Cache for shell command results (persists for process lifetime)
const commandResultCache = new Map<string, string | undefined>();
/**
* Resolve a config value (API key, header value, etc.) to an actual value.
* - If starts with "!", executes the rest as a shell command and uses stdout (cached)
* - Otherwise checks environment variable first, then treats as literal (not cached)
*/
export function resolveConfigValue(config: string): string | undefined {
if (config.startsWith("!")) {
return executeCommand(config);
}
const envValue = process.env[config];
return envValue || config;
}
function executeCommand(commandConfig: string): string | undefined {
if (commandResultCache.has(commandConfig)) {
return commandResultCache.get(commandConfig);
}
const command = commandConfig.slice(1);
let result: string | undefined;
try {
const output = execSync(command, {
encoding: "utf-8",
timeout: 10000,
stdio: ["ignore", "pipe", "ignore"],
});
result = output.trim() || undefined;
} catch {
result = undefined;
}
commandResultCache.set(commandConfig, result);
return result;
}
/**
* Resolve all header values using the same resolution logic as API keys.
*/
export function resolveHeaders(
headers: Record<string, string> | undefined,
): Record<string, string> | undefined {
if (!headers) return undefined;
const resolved: Record<string, string> = {};
for (const [key, value] of Object.entries(headers)) {
const resolvedValue = resolveConfigValue(value);
if (resolvedValue) {
resolved[key] = resolvedValue;
}
}
return Object.keys(resolved).length > 0 ? resolved : undefined;
}
/** Clear the config value command cache. Exported for testing. */
export function clearConfigValueCache(): void {
commandResultCache.clear();
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,398 @@
import { join } from "node:path";
import {
Agent,
type AgentMessage,
type ThinkingLevel,
} from "@mariozechner/pi-agent-core";
import type { Message, Model } from "@mariozechner/pi-ai";
import { getAgentDir, getDocsPath } from "../config.js";
import { AgentSession } from "./agent-session.js";
import { AuthStorage } from "./auth-storage.js";
import { DEFAULT_THINKING_LEVEL } from "./defaults.js";
import type {
ExtensionRunner,
LoadExtensionsResult,
ToolDefinition,
} from "./extensions/index.js";
import { convertToLlm } from "./messages.js";
import { ModelRegistry } from "./model-registry.js";
import { findInitialModel } from "./model-resolver.js";
import type { ResourceLoader } from "./resource-loader.js";
import { DefaultResourceLoader } from "./resource-loader.js";
import { SessionManager } from "./session-manager.js";
import { SettingsManager } from "./settings-manager.js";
import { time } from "./timings.js";
import {
allTools,
bashTool,
codingTools,
createBashTool,
createCodingTools,
createEditTool,
createFindTool,
createGrepTool,
createLsTool,
createReadOnlyTools,
createReadTool,
createWriteTool,
editTool,
findTool,
grepTool,
lsTool,
readOnlyTools,
readTool,
type Tool,
type ToolName,
writeTool,
} from "./tools/index.js";
export interface CreateAgentSessionOptions {
/** Working directory for project-local discovery. Default: process.cwd() */
cwd?: string;
/** Global config directory. Default: ~/.pi/agent */
agentDir?: string;
/** Auth storage for credentials. Default: AuthStorage.create(agentDir/auth.json) */
authStorage?: AuthStorage;
/** Model registry. Default: new ModelRegistry(authStorage, agentDir/models.json) */
modelRegistry?: ModelRegistry;
/** Model to use. Default: from settings, else first available */
model?: Model<any>;
/** Thinking level. Default: from settings, else 'medium' (clamped to model capabilities) */
thinkingLevel?: ThinkingLevel;
/** Models available for cycling (Ctrl+P in interactive mode) */
scopedModels?: Array<{ model: Model<any>; thinkingLevel?: ThinkingLevel }>;
/** Built-in tools to use. Default: codingTools [read, bash, edit, write] */
tools?: Tool[];
/** Custom tools to register (in addition to built-in tools). */
customTools?: ToolDefinition[];
/** Resource loader. When omitted, DefaultResourceLoader is used. */
resourceLoader?: ResourceLoader;
/** Session manager. Default: SessionManager.create(cwd) */
sessionManager?: SessionManager;
/** Settings manager. Default: SettingsManager.create(cwd, agentDir) */
settingsManager?: SettingsManager;
}
/** Result from createAgentSession */
export interface CreateAgentSessionResult {
/** The created session */
session: AgentSession;
/** Extensions result (for UI context setup in interactive mode) */
extensionsResult: LoadExtensionsResult;
/** Warning if session was restored with a different model than saved */
modelFallbackMessage?: string;
}
// Re-exports
export type {
ExtensionAPI,
ExtensionCommandContext,
ExtensionContext,
ExtensionFactory,
SlashCommandInfo,
SlashCommandLocation,
SlashCommandSource,
ToolDefinition,
} from "./extensions/index.js";
export type { PromptTemplate } from "./prompt-templates.js";
export type { Skill } from "./skills.js";
export type { Tool } from "./tools/index.js";
export {
// Pre-built tools (use process.cwd())
readTool,
bashTool,
editTool,
writeTool,
grepTool,
findTool,
lsTool,
codingTools,
readOnlyTools,
allTools as allBuiltInTools,
// Tool factories (for custom cwd)
createCodingTools,
createReadOnlyTools,
createReadTool,
createBashTool,
createEditTool,
createWriteTool,
createGrepTool,
createFindTool,
createLsTool,
};
// Helper Functions
function getDefaultAgentDir(): string {
return getAgentDir();
}
/**
* Create an AgentSession with the specified options.
*
* @example
* ```typescript
* // Minimal - uses defaults
* const { session } = await createAgentSession();
*
* // With explicit model
* import { getModel } from '@mariozechner/pi-ai';
* const { session } = await createAgentSession({
* model: getModel('anthropic', 'claude-opus-4-5'),
* thinkingLevel: 'high',
* });
*
* // Continue previous session
* const { session, modelFallbackMessage } = await createAgentSession({
* continueSession: true,
* });
*
* // Full control
* const loader = new DefaultResourceLoader({
* cwd: process.cwd(),
* agentDir: getAgentDir(),
* settingsManager: SettingsManager.create(),
* });
* await loader.reload();
* const { session } = await createAgentSession({
* model: myModel,
* tools: [readTool, bashTool],
* resourceLoader: loader,
* sessionManager: SessionManager.inMemory(),
* });
* ```
*/
export async function createAgentSession(
options: CreateAgentSessionOptions = {},
): Promise<CreateAgentSessionResult> {
const cwd = options.cwd ?? process.cwd();
const agentDir = options.agentDir ?? getDefaultAgentDir();
let resourceLoader = options.resourceLoader;
// Use provided or create AuthStorage and ModelRegistry
const authPath = options.agentDir ? join(agentDir, "auth.json") : undefined;
const modelsPath = options.agentDir
? join(agentDir, "models.json")
: undefined;
const authStorage = options.authStorage ?? AuthStorage.create(authPath);
const modelRegistry =
options.modelRegistry ?? new ModelRegistry(authStorage, modelsPath);
const settingsManager =
options.settingsManager ?? SettingsManager.create(cwd, agentDir);
const sessionManager = options.sessionManager ?? SessionManager.create(cwd);
if (!resourceLoader) {
resourceLoader = new DefaultResourceLoader({
cwd,
agentDir,
settingsManager,
});
await resourceLoader.reload();
time("resourceLoader.reload");
}
// Check if session has existing data to restore
const existingSession = sessionManager.buildSessionContext();
const hasExistingSession = existingSession.messages.length > 0;
const hasThinkingEntry = sessionManager
.getBranch()
.some((entry) => entry.type === "thinking_level_change");
let model = options.model;
let modelFallbackMessage: string | undefined;
// If session has data, try to restore model from it
if (!model && hasExistingSession && existingSession.model) {
const restoredModel = modelRegistry.find(
existingSession.model.provider,
existingSession.model.modelId,
);
if (restoredModel && (await modelRegistry.getApiKey(restoredModel))) {
model = restoredModel;
}
if (!model) {
modelFallbackMessage = `Could not restore model ${existingSession.model.provider}/${existingSession.model.modelId}`;
}
}
// If still no model, use findInitialModel (checks settings default, then provider defaults)
if (!model) {
const result = await findInitialModel({
scopedModels: [],
isContinuing: hasExistingSession,
defaultProvider: settingsManager.getDefaultProvider(),
defaultModelId: settingsManager.getDefaultModel(),
defaultThinkingLevel: settingsManager.getDefaultThinkingLevel(),
modelRegistry,
});
model = result.model;
if (!model) {
modelFallbackMessage = `No models available. Use /login or set an API key environment variable. See ${join(getDocsPath(), "providers.md")}. Then use /model to select a model.`;
} else if (modelFallbackMessage) {
modelFallbackMessage += `. Using ${model.provider}/${model.id}`;
}
}
let thinkingLevel = options.thinkingLevel;
// If session has data, restore thinking level from it
if (thinkingLevel === undefined && hasExistingSession) {
thinkingLevel = hasThinkingEntry
? (existingSession.thinkingLevel as ThinkingLevel)
: (settingsManager.getDefaultThinkingLevel() ?? DEFAULT_THINKING_LEVEL);
}
// Fall back to settings default
if (thinkingLevel === undefined) {
thinkingLevel =
settingsManager.getDefaultThinkingLevel() ?? DEFAULT_THINKING_LEVEL;
}
// Clamp to model capabilities
if (!model || !model.reasoning) {
thinkingLevel = "off";
}
const defaultActiveToolNames: ToolName[] = ["read", "bash", "edit", "write"];
const initialActiveToolNames: ToolName[] = options.tools
? options.tools
.map((t) => t.name)
.filter((n): n is ToolName => n in allTools)
: defaultActiveToolNames;
let agent: Agent;
// Create convertToLlm wrapper that filters images if blockImages is enabled (defense-in-depth)
const convertToLlmWithBlockImages = (messages: AgentMessage[]): Message[] => {
const converted = convertToLlm(messages);
// Check setting dynamically so mid-session changes take effect
if (!settingsManager.getBlockImages()) {
return converted;
}
// Filter out ImageContent from all messages, replacing with text placeholder
return converted.map((msg) => {
if (msg.role === "user" || msg.role === "toolResult") {
const content = msg.content;
if (Array.isArray(content)) {
const hasImages = content.some((c) => c.type === "image");
if (hasImages) {
const filteredContent = content
.map((c) =>
c.type === "image"
? {
type: "text" as const,
text: "Image reading is disabled.",
}
: c,
)
.filter(
(c, i, arr) =>
// Dedupe consecutive "Image reading is disabled." texts
!(
c.type === "text" &&
c.text === "Image reading is disabled." &&
i > 0 &&
arr[i - 1].type === "text" &&
(arr[i - 1] as { type: "text"; text: string }).text ===
"Image reading is disabled."
),
);
return { ...msg, content: filteredContent };
}
}
}
return msg;
});
};
const extensionRunnerRef: { current?: ExtensionRunner } = {};
agent = new Agent({
initialState: {
systemPrompt: "",
model,
thinkingLevel,
tools: [],
},
convertToLlm: convertToLlmWithBlockImages,
sessionId: sessionManager.getSessionId(),
transformContext: async (messages) => {
const runner = extensionRunnerRef.current;
if (!runner) return messages;
return runner.emitContext(messages);
},
steeringMode: settingsManager.getSteeringMode(),
followUpMode: settingsManager.getFollowUpMode(),
transport: settingsManager.getTransport(),
thinkingBudgets: settingsManager.getThinkingBudgets(),
maxRetryDelayMs: settingsManager.getRetrySettings().maxDelayMs,
getApiKey: async (provider) => {
// Use the provider argument from the in-flight request;
// agent.state.model may already be switched mid-turn.
const resolvedProvider = provider || agent.state.model?.provider;
if (!resolvedProvider) {
throw new Error("No model selected");
}
const key = await modelRegistry.getApiKeyForProvider(resolvedProvider);
if (!key) {
const model = agent.state.model;
const isOAuth = model && modelRegistry.isUsingOAuth(model);
if (isOAuth) {
throw new Error(
`Authentication failed for "${resolvedProvider}". ` +
`Credentials may have expired or network is unavailable. ` +
`Run '/login ${resolvedProvider}' to re-authenticate.`,
);
}
throw new Error(
`No API key found for "${resolvedProvider}". ` +
`Set an API key environment variable or run '/login ${resolvedProvider}'.`,
);
}
return key;
},
});
// Restore messages if session has existing data
if (hasExistingSession) {
agent.replaceMessages(existingSession.messages);
if (!hasThinkingEntry) {
sessionManager.appendThinkingLevelChange(thinkingLevel);
}
} else {
// Save initial model and thinking level for new sessions so they can be restored on resume
if (model) {
sessionManager.appendModelChange(model.provider, model.id);
}
sessionManager.appendThinkingLevelChange(thinkingLevel);
}
const session = new AgentSession({
agent,
sessionManager,
settingsManager,
cwd,
scopedModels: options.scopedModels,
resourceLoader,
customTools: options.customTools,
modelRegistry,
initialActiveToolNames,
extensionRunnerRef,
});
const extensionsResult = resourceLoader.getExtensions();
return {
session,
extensionsResult,
modelFallbackMessage,
};
}

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,518 @@
import {
existsSync,
readdirSync,
readFileSync,
realpathSync,
statSync,
} from "fs";
import ignore from "ignore";
import { homedir } from "os";
import {
basename,
dirname,
isAbsolute,
join,
relative,
resolve,
sep,
} from "path";
import { CONFIG_DIR_NAME, getAgentDir } from "../config.js";
import { parseFrontmatter } from "../utils/frontmatter.js";
import type { ResourceDiagnostic } from "./diagnostics.js";
/** Max name length per spec */
const MAX_NAME_LENGTH = 64;
/** Max description length per spec */
const MAX_DESCRIPTION_LENGTH = 1024;
const IGNORE_FILE_NAMES = [".gitignore", ".ignore", ".fdignore"];
type IgnoreMatcher = ReturnType<typeof ignore>;
function toPosixPath(p: string): string {
return p.split(sep).join("/");
}
function prefixIgnorePattern(line: string, prefix: string): string | null {
const trimmed = line.trim();
if (!trimmed) return null;
if (trimmed.startsWith("#") && !trimmed.startsWith("\\#")) return null;
let pattern = line;
let negated = false;
if (pattern.startsWith("!")) {
negated = true;
pattern = pattern.slice(1);
} else if (pattern.startsWith("\\!")) {
pattern = pattern.slice(1);
}
if (pattern.startsWith("/")) {
pattern = pattern.slice(1);
}
const prefixed = prefix ? `${prefix}${pattern}` : pattern;
return negated ? `!${prefixed}` : prefixed;
}
function addIgnoreRules(ig: IgnoreMatcher, dir: string, rootDir: string): void {
const relativeDir = relative(rootDir, dir);
const prefix = relativeDir ? `${toPosixPath(relativeDir)}/` : "";
for (const filename of IGNORE_FILE_NAMES) {
const ignorePath = join(dir, filename);
if (!existsSync(ignorePath)) continue;
try {
const content = readFileSync(ignorePath, "utf-8");
const patterns = content
.split(/\r?\n/)
.map((line) => prefixIgnorePattern(line, prefix))
.filter((line): line is string => Boolean(line));
if (patterns.length > 0) {
ig.add(patterns);
}
} catch {}
}
}
export interface SkillFrontmatter {
name?: string;
description?: string;
"disable-model-invocation"?: boolean;
[key: string]: unknown;
}
export interface Skill {
name: string;
description: string;
filePath: string;
baseDir: string;
source: string;
disableModelInvocation: boolean;
}
export interface LoadSkillsResult {
skills: Skill[];
diagnostics: ResourceDiagnostic[];
}
/**
* Validate skill name per Agent Skills spec.
* Returns array of validation error messages (empty if valid).
*/
function validateName(name: string, parentDirName: string): string[] {
const errors: string[] = [];
if (name !== parentDirName) {
errors.push(
`name "${name}" does not match parent directory "${parentDirName}"`,
);
}
if (name.length > MAX_NAME_LENGTH) {
errors.push(`name exceeds ${MAX_NAME_LENGTH} characters (${name.length})`);
}
if (!/^[a-z0-9-]+$/.test(name)) {
errors.push(
`name contains invalid characters (must be lowercase a-z, 0-9, hyphens only)`,
);
}
if (name.startsWith("-") || name.endsWith("-")) {
errors.push(`name must not start or end with a hyphen`);
}
if (name.includes("--")) {
errors.push(`name must not contain consecutive hyphens`);
}
return errors;
}
/**
* Validate description per Agent Skills spec.
*/
function validateDescription(description: string | undefined): string[] {
const errors: string[] = [];
if (!description || description.trim() === "") {
errors.push("description is required");
} else if (description.length > MAX_DESCRIPTION_LENGTH) {
errors.push(
`description exceeds ${MAX_DESCRIPTION_LENGTH} characters (${description.length})`,
);
}
return errors;
}
export interface LoadSkillsFromDirOptions {
/** Directory to scan for skills */
dir: string;
/** Source identifier for these skills */
source: string;
}
/**
* Load skills from a directory.
*
* Discovery rules:
* - direct .md children in the root
* - recursive SKILL.md under subdirectories
*/
export function loadSkillsFromDir(
options: LoadSkillsFromDirOptions,
): LoadSkillsResult {
const { dir, source } = options;
return loadSkillsFromDirInternal(dir, source, true);
}
function loadSkillsFromDirInternal(
dir: string,
source: string,
includeRootFiles: boolean,
ignoreMatcher?: IgnoreMatcher,
rootDir?: string,
): LoadSkillsResult {
const skills: Skill[] = [];
const diagnostics: ResourceDiagnostic[] = [];
if (!existsSync(dir)) {
return { skills, diagnostics };
}
const root = rootDir ?? dir;
const ig = ignoreMatcher ?? ignore();
addIgnoreRules(ig, dir, root);
try {
const entries = readdirSync(dir, { withFileTypes: true });
for (const entry of entries) {
if (entry.name.startsWith(".")) {
continue;
}
// Skip node_modules to avoid scanning dependencies
if (entry.name === "node_modules") {
continue;
}
const fullPath = join(dir, entry.name);
// For symlinks, check if they point to a directory and follow them
let isDirectory = entry.isDirectory();
let isFile = entry.isFile();
if (entry.isSymbolicLink()) {
try {
const stats = statSync(fullPath);
isDirectory = stats.isDirectory();
isFile = stats.isFile();
} catch {
// Broken symlink, skip it
continue;
}
}
const relPath = toPosixPath(relative(root, fullPath));
const ignorePath = isDirectory ? `${relPath}/` : relPath;
if (ig.ignores(ignorePath)) {
continue;
}
if (isDirectory) {
const subResult = loadSkillsFromDirInternal(
fullPath,
source,
false,
ig,
root,
);
skills.push(...subResult.skills);
diagnostics.push(...subResult.diagnostics);
continue;
}
if (!isFile) {
continue;
}
const isRootMd = includeRootFiles && entry.name.endsWith(".md");
const isSkillMd = !includeRootFiles && entry.name === "SKILL.md";
if (!isRootMd && !isSkillMd) {
continue;
}
const result = loadSkillFromFile(fullPath, source);
if (result.skill) {
skills.push(result.skill);
}
diagnostics.push(...result.diagnostics);
}
} catch {}
return { skills, diagnostics };
}
function loadSkillFromFile(
filePath: string,
source: string,
): { skill: Skill | null; diagnostics: ResourceDiagnostic[] } {
const diagnostics: ResourceDiagnostic[] = [];
try {
const rawContent = readFileSync(filePath, "utf-8");
const { frontmatter } = parseFrontmatter<SkillFrontmatter>(rawContent);
const skillDir = dirname(filePath);
const parentDirName = basename(skillDir);
// Validate description
const descErrors = validateDescription(frontmatter.description);
for (const error of descErrors) {
diagnostics.push({ type: "warning", message: error, path: filePath });
}
// Use name from frontmatter, or fall back to parent directory name
const name = frontmatter.name || parentDirName;
// Validate name
const nameErrors = validateName(name, parentDirName);
for (const error of nameErrors) {
diagnostics.push({ type: "warning", message: error, path: filePath });
}
// Still load the skill even with warnings (unless description is completely missing)
if (!frontmatter.description || frontmatter.description.trim() === "") {
return { skill: null, diagnostics };
}
return {
skill: {
name,
description: frontmatter.description,
filePath,
baseDir: skillDir,
source,
disableModelInvocation:
frontmatter["disable-model-invocation"] === true,
},
diagnostics,
};
} catch (error) {
const message =
error instanceof Error ? error.message : "failed to parse skill file";
diagnostics.push({ type: "warning", message, path: filePath });
return { skill: null, diagnostics };
}
}
/**
* Format skills for inclusion in a system prompt.
* Uses XML format per Agent Skills standard.
* See: https://agentskills.io/integrate-skills
*
* Skills with disableModelInvocation=true are excluded from the prompt
* (they can only be invoked explicitly via /skill:name commands).
*/
export function formatSkillsForPrompt(skills: Skill[]): string {
const visibleSkills = skills.filter((s) => !s.disableModelInvocation);
if (visibleSkills.length === 0) {
return "";
}
const lines = [
"\n\nThe following skills provide specialized instructions for specific tasks.",
"Use the read tool to load a skill's file when the task matches its description.",
"When a skill file references a relative path, resolve it against the skill directory (parent of SKILL.md / dirname of the path) and use that absolute path in tool commands.",
"",
"<available_skills>",
];
for (const skill of visibleSkills) {
lines.push(" <skill>");
lines.push(` <name>${escapeXml(skill.name)}</name>`);
lines.push(
` <description>${escapeXml(skill.description)}</description>`,
);
lines.push(` <location>${escapeXml(skill.filePath)}</location>`);
lines.push(" </skill>");
}
lines.push("</available_skills>");
return lines.join("\n");
}
function escapeXml(str: string): string {
return str
.replace(/&/g, "&amp;")
.replace(/</g, "&lt;")
.replace(/>/g, "&gt;")
.replace(/"/g, "&quot;")
.replace(/'/g, "&apos;");
}
export interface LoadSkillsOptions {
/** Working directory for project-local skills. Default: process.cwd() */
cwd?: string;
/** Agent config directory for global skills. Default: ~/.pi/agent */
agentDir?: string;
/** Explicit skill paths (files or directories) */
skillPaths?: string[];
/** Include default skills directories. Default: true */
includeDefaults?: boolean;
}
function normalizePath(input: string): string {
const trimmed = input.trim();
if (trimmed === "~") return homedir();
if (trimmed.startsWith("~/")) return join(homedir(), trimmed.slice(2));
if (trimmed.startsWith("~")) return join(homedir(), trimmed.slice(1));
return trimmed;
}
function resolveSkillPath(p: string, cwd: string): string {
const normalized = normalizePath(p);
return isAbsolute(normalized) ? normalized : resolve(cwd, normalized);
}
/**
* Load skills from all configured locations.
* Returns skills and any validation diagnostics.
*/
export function loadSkills(options: LoadSkillsOptions = {}): LoadSkillsResult {
const {
cwd = process.cwd(),
agentDir,
skillPaths = [],
includeDefaults = true,
} = options;
// Resolve agentDir - if not provided, use default from config
const resolvedAgentDir = agentDir ?? getAgentDir();
const skillMap = new Map<string, Skill>();
const realPathSet = new Set<string>();
const allDiagnostics: ResourceDiagnostic[] = [];
const collisionDiagnostics: ResourceDiagnostic[] = [];
function addSkills(result: LoadSkillsResult) {
allDiagnostics.push(...result.diagnostics);
for (const skill of result.skills) {
// Resolve symlinks to detect duplicate files
let realPath: string;
try {
realPath = realpathSync(skill.filePath);
} catch {
realPath = skill.filePath;
}
// Skip silently if we've already loaded this exact file (via symlink)
if (realPathSet.has(realPath)) {
continue;
}
const existing = skillMap.get(skill.name);
if (existing) {
collisionDiagnostics.push({
type: "collision",
message: `name "${skill.name}" collision`,
path: skill.filePath,
collision: {
resourceType: "skill",
name: skill.name,
winnerPath: existing.filePath,
loserPath: skill.filePath,
},
});
} else {
skillMap.set(skill.name, skill);
realPathSet.add(realPath);
}
}
}
if (includeDefaults) {
addSkills(
loadSkillsFromDirInternal(join(resolvedAgentDir, "skills"), "user", true),
);
addSkills(
loadSkillsFromDirInternal(
resolve(cwd, CONFIG_DIR_NAME, "skills"),
"project",
true,
),
);
}
const userSkillsDir = join(resolvedAgentDir, "skills");
const projectSkillsDir = resolve(cwd, CONFIG_DIR_NAME, "skills");
const isUnderPath = (target: string, root: string): boolean => {
const normalizedRoot = resolve(root);
if (target === normalizedRoot) {
return true;
}
const prefix = normalizedRoot.endsWith(sep)
? normalizedRoot
: `${normalizedRoot}${sep}`;
return target.startsWith(prefix);
};
const getSource = (resolvedPath: string): "user" | "project" | "path" => {
if (!includeDefaults) {
if (isUnderPath(resolvedPath, userSkillsDir)) return "user";
if (isUnderPath(resolvedPath, projectSkillsDir)) return "project";
}
return "path";
};
for (const rawPath of skillPaths) {
const resolvedPath = resolveSkillPath(rawPath, cwd);
if (!existsSync(resolvedPath)) {
allDiagnostics.push({
type: "warning",
message: "skill path does not exist",
path: resolvedPath,
});
continue;
}
try {
const stats = statSync(resolvedPath);
const source = getSource(resolvedPath);
if (stats.isDirectory()) {
addSkills(loadSkillsFromDirInternal(resolvedPath, source, true));
} else if (stats.isFile() && resolvedPath.endsWith(".md")) {
const result = loadSkillFromFile(resolvedPath, source);
if (result.skill) {
addSkills({
skills: [result.skill],
diagnostics: result.diagnostics,
});
} else {
allDiagnostics.push(...result.diagnostics);
}
} else {
allDiagnostics.push({
type: "warning",
message: "skill path is not a markdown file",
path: resolvedPath,
});
}
} catch (error) {
const message =
error instanceof Error ? error.message : "failed to read skill path";
allDiagnostics.push({ type: "warning", message, path: resolvedPath });
}
}
return {
skills: Array.from(skillMap.values()),
diagnostics: [...allDiagnostics, ...collisionDiagnostics],
};
}

View file

@ -0,0 +1,44 @@
export type SlashCommandSource = "extension" | "prompt" | "skill";
export type SlashCommandLocation = "user" | "project" | "path";
export interface SlashCommandInfo {
name: string;
description?: string;
source: SlashCommandSource;
location?: SlashCommandLocation;
path?: string;
}
export interface BuiltinSlashCommand {
name: string;
description: string;
}
export const BUILTIN_SLASH_COMMANDS: ReadonlyArray<BuiltinSlashCommand> = [
{ name: "settings", description: "Open settings menu" },
{ name: "model", description: "Select model (opens selector UI)" },
{
name: "scoped-models",
description: "Enable/disable models for Ctrl+P cycling",
},
{ name: "export", description: "Export session to HTML file" },
{ name: "share", description: "Share session as a secret GitHub gist" },
{ name: "copy", description: "Copy last agent message to clipboard" },
{ name: "name", description: "Set session display name" },
{ name: "session", description: "Show session info and stats" },
{ name: "changelog", description: "Show changelog entries" },
{ name: "hotkeys", description: "Show all keyboard shortcuts" },
{ name: "fork", description: "Create a new fork from a previous message" },
{ name: "tree", description: "Navigate session tree (switch branches)" },
{ name: "login", description: "Login with OAuth provider" },
{ name: "logout", description: "Logout from OAuth provider" },
{ name: "new", description: "Start a new session" },
{ name: "compact", description: "Manually compact the session context" },
{ name: "resume", description: "Resume a different session" },
{
name: "reload",
description: "Reload extensions, skills, prompts, and themes",
},
{ name: "quit", description: "Quit pi" },
];

View file

@ -0,0 +1,237 @@
/**
* System prompt construction and project context loading
*/
import { getDocsPath, getReadmePath } from "../config.js";
import { formatSkillsForPrompt, type Skill } from "./skills.js";
/** Tool descriptions for system prompt */
const toolDescriptions: Record<string, string> = {
read: "Read file contents",
bash: "Execute bash commands (ls, grep, find, etc.)",
edit: "Make surgical edits to files (find exact text and replace)",
write: "Create or overwrite files",
grep: "Search file contents for patterns (respects .gitignore)",
find: "Find files by glob pattern (respects .gitignore)",
ls: "List directory contents",
};
export interface BuildSystemPromptOptions {
/** Custom system prompt (replaces default). */
customPrompt?: string;
/** Tools to include in prompt. Default: [read, bash, edit, write] */
selectedTools?: string[];
/** Optional one-line tool snippets keyed by tool name. */
toolSnippets?: Record<string, string>;
/** Additional guideline bullets appended to the default system prompt guidelines. */
promptGuidelines?: string[];
/** Text to append to system prompt. */
appendSystemPrompt?: string;
/** Working directory. Default: process.cwd() */
cwd?: string;
/** Pre-loaded context files. */
contextFiles?: Array<{ path: string; content: string }>;
/** Pre-loaded skills. */
skills?: Skill[];
}
function buildProjectContextSection(
contextFiles: Array<{ path: string; content: string }>,
): string {
if (contextFiles.length === 0) {
return "";
}
const hasSoulFile = contextFiles.some(
({ path }) =>
path.replaceAll("\\", "/").endsWith("/SOUL.md") || path === "SOUL.md",
);
let section = "\n\n# Project Context\n\n";
section += "Project-specific instructions and guidelines:\n";
if (hasSoulFile) {
section +=
"\nIf SOUL.md is present, embody its persona and tone. Avoid generic assistant filler and follow its guidance unless higher-priority instructions override it.\n";
}
section += "\n";
for (const { path: filePath, content } of contextFiles) {
section += `## ${filePath}\n\n${content}\n\n`;
}
return section;
}
/** Build the system prompt with tools, guidelines, and context */
export function buildSystemPrompt(
options: BuildSystemPromptOptions = {},
): string {
const {
customPrompt,
selectedTools,
toolSnippets,
promptGuidelines,
appendSystemPrompt,
cwd,
contextFiles: providedContextFiles,
skills: providedSkills,
} = options;
const resolvedCwd = cwd ?? process.cwd();
const now = new Date();
const dateTime = now.toLocaleString("en-US", {
weekday: "long",
year: "numeric",
month: "long",
day: "numeric",
hour: "2-digit",
minute: "2-digit",
second: "2-digit",
timeZoneName: "short",
});
const appendSection = appendSystemPrompt ? `\n\n${appendSystemPrompt}` : "";
const contextFiles = providedContextFiles ?? [];
const skills = providedSkills ?? [];
if (customPrompt) {
let prompt = customPrompt;
if (appendSection) {
prompt += appendSection;
}
// Append project context files
prompt += buildProjectContextSection(contextFiles);
// Append skills section (only if read tool is available)
const customPromptHasRead =
!selectedTools || selectedTools.includes("read");
if (customPromptHasRead && skills.length > 0) {
prompt += formatSkillsForPrompt(skills);
}
// Add date/time and working directory last
prompt += `\nCurrent date and time: ${dateTime}`;
prompt += `\nCurrent working directory: ${resolvedCwd}`;
return prompt;
}
// Get absolute paths to documentation
const readmePath = getReadmePath();
const docsPath = getDocsPath();
// Build tools list based on selected tools.
// Built-ins use toolDescriptions. Custom tools can provide one-line snippets.
const tools = selectedTools || ["read", "bash", "edit", "write"];
const toolsList =
tools.length > 0
? tools
.map((name) => {
const snippet =
toolSnippets?.[name] ?? toolDescriptions[name] ?? name;
return `- ${name}: ${snippet}`;
})
.join("\n")
: "(none)";
// Build guidelines based on which tools are actually available
const guidelinesList: string[] = [];
const guidelinesSet = new Set<string>();
const addGuideline = (guideline: string): void => {
if (guidelinesSet.has(guideline)) {
return;
}
guidelinesSet.add(guideline);
guidelinesList.push(guideline);
};
const hasBash = tools.includes("bash");
const hasEdit = tools.includes("edit");
const hasWrite = tools.includes("write");
const hasGrep = tools.includes("grep");
const hasFind = tools.includes("find");
const hasLs = tools.includes("ls");
const hasRead = tools.includes("read");
// File exploration guidelines
if (hasBash && !hasGrep && !hasFind && !hasLs) {
addGuideline("Use bash for file operations like ls, rg, find");
} else if (hasBash && (hasGrep || hasFind || hasLs)) {
addGuideline(
"Prefer grep/find/ls tools over bash for file exploration (faster, respects .gitignore)",
);
}
// Read before edit guideline
if (hasRead && hasEdit) {
addGuideline(
"Use read to examine files before editing. You must use this tool instead of cat or sed.",
);
}
// Edit guideline
if (hasEdit) {
addGuideline("Use edit for precise changes (old text must match exactly)");
}
// Write guideline
if (hasWrite) {
addGuideline("Use write only for new files or complete rewrites");
}
// Output guideline (only when actually writing or executing)
if (hasEdit || hasWrite) {
addGuideline(
"When summarizing your actions, output plain text directly - do NOT use cat or bash to display what you did",
);
}
for (const guideline of promptGuidelines ?? []) {
const normalized = guideline.trim();
if (normalized.length > 0) {
addGuideline(normalized);
}
}
// Always include these
addGuideline("Be concise in your responses");
addGuideline("Show file paths clearly when working with files");
const guidelines = guidelinesList.map((g) => `- ${g}`).join("\n");
let prompt = `You are an expert coding assistant operating inside pi, a coding agent harness. You help users by reading files, executing commands, editing code, and writing new files.
Available tools:
${toolsList}
In addition to the tools above, you may have access to other custom tools depending on the project.
Guidelines:
${guidelines}
Pi documentation (read only when the user asks about pi itself, its SDK, extensions, themes, skills, or TUI):
- Main documentation: ${readmePath}
- Additional docs: ${docsPath}
- When asked about: extensions (docs/extensions.md), themes (docs/themes.md), skills (docs/skills.md), prompt templates (docs/prompt-templates.md), TUI components (docs/tui.md), keybindings (docs/keybindings.md), SDK integrations (docs/sdk.md), custom providers (docs/custom-provider.md), adding models (docs/models.md), pi packages (docs/packages.md)
- When working on pi topics, read the docs and follow .md cross-references before implementing
- Always read pi .md files completely and follow links to related docs (e.g., tui.md for TUI API details)`;
if (appendSection) {
prompt += appendSection;
}
// Append project context files
prompt += buildProjectContextSection(contextFiles);
// Append skills section (only if read tool is available)
if (hasRead && skills.length > 0) {
prompt += formatSkillsForPrompt(skills);
}
// Add date/time and working directory last
prompt += `\nCurrent date and time: ${dateTime}`;
prompt += `\nCurrent working directory: ${resolvedCwd}`;
return prompt;
}

View file

@ -0,0 +1,25 @@
/**
* Central timing instrumentation for startup profiling.
* Enable with PI_TIMING=1 environment variable.
*/
const ENABLED = process.env.PI_TIMING === "1";
const timings: Array<{ label: string; ms: number }> = [];
let lastTime = Date.now();
export function time(label: string): void {
if (!ENABLED) return;
const now = Date.now();
timings.push({ label, ms: now - lastTime });
lastTime = now;
}
export function printTimings(): void {
if (!ENABLED || timings.length === 0) return;
console.error("\n--- Startup Timings ---");
for (const t of timings) {
console.error(` ${t.label}: ${t.ms}ms`);
}
console.error(` TOTAL: ${timings.reduce((a, b) => a + b.ms, 0)}ms`);
console.error("------------------------\n");
}

View file

@ -0,0 +1,358 @@
import { randomBytes } from "node:crypto";
import { createWriteStream, existsSync } from "node:fs";
import { tmpdir } from "node:os";
import { join } from "node:path";
import type { AgentTool } from "@mariozechner/pi-agent-core";
import { type Static, Type } from "@sinclair/typebox";
import { spawn } from "child_process";
import {
getShellConfig,
getShellEnv,
killProcessTree,
} from "../../utils/shell.js";
import {
DEFAULT_MAX_BYTES,
DEFAULT_MAX_LINES,
formatSize,
type TruncationResult,
truncateTail,
} from "./truncate.js";
/**
* Generate a unique temp file path for bash output
*/
function getTempFilePath(): string {
const id = randomBytes(8).toString("hex");
return join(tmpdir(), `pi-bash-${id}.log`);
}
const bashSchema = Type.Object({
command: Type.String({ description: "Bash command to execute" }),
timeout: Type.Optional(
Type.Number({
description: "Timeout in seconds (optional, no default timeout)",
}),
),
});
export type BashToolInput = Static<typeof bashSchema>;
export interface BashToolDetails {
truncation?: TruncationResult;
fullOutputPath?: string;
}
/**
* Pluggable operations for the bash tool.
* Override these to delegate command execution to remote systems (e.g., SSH).
*/
export interface BashOperations {
/**
* Execute a command and stream output.
* @param command - The command to execute
* @param cwd - Working directory
* @param options - Execution options
* @returns Promise resolving to exit code (null if killed)
*/
exec: (
command: string,
cwd: string,
options: {
onData: (data: Buffer) => void;
signal?: AbortSignal;
timeout?: number;
env?: NodeJS.ProcessEnv;
},
) => Promise<{ exitCode: number | null }>;
}
/**
* Default bash operations using local shell
*/
const defaultBashOperations: BashOperations = {
exec: (command, cwd, { onData, signal, timeout, env }) => {
return new Promise((resolve, reject) => {
const { shell, args } = getShellConfig();
if (!existsSync(cwd)) {
reject(
new Error(
`Working directory does not exist: ${cwd}\nCannot execute bash commands.`,
),
);
return;
}
const child = spawn(shell, [...args, command], {
cwd,
detached: true,
env: env ?? getShellEnv(),
stdio: ["ignore", "pipe", "pipe"],
});
let timedOut = false;
// Set timeout if provided
let timeoutHandle: NodeJS.Timeout | undefined;
if (timeout !== undefined && timeout > 0) {
timeoutHandle = setTimeout(() => {
timedOut = true;
if (child.pid) {
killProcessTree(child.pid);
}
}, timeout * 1000);
}
// Stream stdout and stderr
if (child.stdout) {
child.stdout.on("data", onData);
}
if (child.stderr) {
child.stderr.on("data", onData);
}
// Handle shell spawn errors
child.on("error", (err) => {
if (timeoutHandle) clearTimeout(timeoutHandle);
if (signal) signal.removeEventListener("abort", onAbort);
reject(err);
});
// Handle abort signal - kill entire process tree
const onAbort = () => {
if (child.pid) {
killProcessTree(child.pid);
}
};
if (signal) {
if (signal.aborted) {
onAbort();
} else {
signal.addEventListener("abort", onAbort, { once: true });
}
}
// Handle process exit
child.on("close", (code) => {
if (timeoutHandle) clearTimeout(timeoutHandle);
if (signal) signal.removeEventListener("abort", onAbort);
if (signal?.aborted) {
reject(new Error("aborted"));
return;
}
if (timedOut) {
reject(new Error(`timeout:${timeout}`));
return;
}
resolve({ exitCode: code });
});
});
},
};
export interface BashSpawnContext {
command: string;
cwd: string;
env: NodeJS.ProcessEnv;
}
export type BashSpawnHook = (context: BashSpawnContext) => BashSpawnContext;
function resolveSpawnContext(
command: string,
cwd: string,
spawnHook?: BashSpawnHook,
): BashSpawnContext {
const baseContext: BashSpawnContext = {
command,
cwd,
env: { ...getShellEnv() },
};
return spawnHook ? spawnHook(baseContext) : baseContext;
}
export interface BashToolOptions {
/** Custom operations for command execution. Default: local shell */
operations?: BashOperations;
/** Command prefix prepended to every command (e.g., "shopt -s expand_aliases" for alias support) */
commandPrefix?: string;
/** Hook to adjust command, cwd, or env before execution */
spawnHook?: BashSpawnHook;
}
export function createBashTool(
cwd: string,
options?: BashToolOptions,
): AgentTool<typeof bashSchema> {
const ops = options?.operations ?? defaultBashOperations;
const commandPrefix = options?.commandPrefix;
const spawnHook = options?.spawnHook;
return {
name: "bash",
label: "bash",
description: `Execute a bash command in the current working directory. Returns stdout and stderr. Output is truncated to last ${DEFAULT_MAX_LINES} lines or ${DEFAULT_MAX_BYTES / 1024}KB (whichever is hit first). If truncated, full output is saved to a temp file. Optionally provide a timeout in seconds.`,
parameters: bashSchema,
execute: async (
_toolCallId: string,
{ command, timeout }: { command: string; timeout?: number },
signal?: AbortSignal,
onUpdate?,
) => {
// Apply command prefix if configured (e.g., "shopt -s expand_aliases" for alias support)
const resolvedCommand = commandPrefix
? `${commandPrefix}\n${command}`
: command;
const spawnContext = resolveSpawnContext(resolvedCommand, cwd, spawnHook);
return new Promise((resolve, reject) => {
// We'll stream to a temp file if output gets large
let tempFilePath: string | undefined;
let tempFileStream: ReturnType<typeof createWriteStream> | undefined;
let totalBytes = 0;
// Keep a rolling buffer of the last chunk for tail truncation
const chunks: Buffer[] = [];
let chunksBytes = 0;
// Keep more than we need so we have enough for truncation
const maxChunksBytes = DEFAULT_MAX_BYTES * 2;
const handleData = (data: Buffer) => {
totalBytes += data.length;
// Start writing to temp file once we exceed the threshold
if (totalBytes > DEFAULT_MAX_BYTES && !tempFilePath) {
tempFilePath = getTempFilePath();
tempFileStream = createWriteStream(tempFilePath);
// Write all buffered chunks to the file
for (const chunk of chunks) {
tempFileStream.write(chunk);
}
}
// Write to temp file if we have one
if (tempFileStream) {
tempFileStream.write(data);
}
// Keep rolling buffer of recent data
chunks.push(data);
chunksBytes += data.length;
// Trim old chunks if buffer is too large
while (chunksBytes > maxChunksBytes && chunks.length > 1) {
const removed = chunks.shift()!;
chunksBytes -= removed.length;
}
// Stream partial output to callback (truncated rolling buffer)
if (onUpdate) {
const fullBuffer = Buffer.concat(chunks);
const fullText = fullBuffer.toString("utf-8");
const truncation = truncateTail(fullText);
onUpdate({
content: [{ type: "text", text: truncation.content || "" }],
details: {
truncation: truncation.truncated ? truncation : undefined,
fullOutputPath: tempFilePath,
},
});
}
};
ops
.exec(spawnContext.command, spawnContext.cwd, {
onData: handleData,
signal,
timeout,
env: spawnContext.env,
})
.then(({ exitCode }) => {
// Close temp file stream
if (tempFileStream) {
tempFileStream.end();
}
// Combine all buffered chunks
const fullBuffer = Buffer.concat(chunks);
const fullOutput = fullBuffer.toString("utf-8");
// Apply tail truncation
const truncation = truncateTail(fullOutput);
let outputText = truncation.content || "(no output)";
// Build details with truncation info
let details: BashToolDetails | undefined;
if (truncation.truncated) {
details = {
truncation,
fullOutputPath: tempFilePath,
};
// Build actionable notice
const startLine =
truncation.totalLines - truncation.outputLines + 1;
const endLine = truncation.totalLines;
if (truncation.lastLinePartial) {
// Edge case: last line alone > 30KB
const lastLineSize = formatSize(
Buffer.byteLength(
fullOutput.split("\n").pop() || "",
"utf-8",
),
);
outputText += `\n\n[Showing last ${formatSize(truncation.outputBytes)} of line ${endLine} (line is ${lastLineSize}). Full output: ${tempFilePath}]`;
} else if (truncation.truncatedBy === "lines") {
outputText += `\n\n[Showing lines ${startLine}-${endLine} of ${truncation.totalLines}. Full output: ${tempFilePath}]`;
} else {
outputText += `\n\n[Showing lines ${startLine}-${endLine} of ${truncation.totalLines} (${formatSize(DEFAULT_MAX_BYTES)} limit). Full output: ${tempFilePath}]`;
}
}
if (exitCode !== 0 && exitCode !== null) {
outputText += `\n\nCommand exited with code ${exitCode}`;
reject(new Error(outputText));
} else {
resolve({
content: [{ type: "text", text: outputText }],
details,
});
}
})
.catch((err: Error) => {
// Close temp file stream
if (tempFileStream) {
tempFileStream.end();
}
// Combine all buffered chunks for error output
const fullBuffer = Buffer.concat(chunks);
let output = fullBuffer.toString("utf-8");
if (err.message === "aborted") {
if (output) output += "\n\n";
output += "Command aborted";
reject(new Error(output));
} else if (err.message.startsWith("timeout:")) {
const timeoutSecs = err.message.split(":")[1];
if (output) output += "\n\n";
output += `Command timed out after ${timeoutSecs} seconds`;
reject(new Error(output));
} else {
reject(err);
}
});
});
},
};
}
/** Default bash tool using process.cwd() - for backwards compatibility */
export const bashTool = createBashTool(process.cwd());

View file

@ -0,0 +1,317 @@
/**
* Shared diff computation utilities for the edit tool.
* Used by both edit.ts (for execution) and tool-execution.ts (for preview rendering).
*/
import * as Diff from "diff";
import { constants } from "fs";
import { access, readFile } from "fs/promises";
import { resolveToCwd } from "./path-utils.js";
export function detectLineEnding(content: string): "\r\n" | "\n" {
const crlfIdx = content.indexOf("\r\n");
const lfIdx = content.indexOf("\n");
if (lfIdx === -1) return "\n";
if (crlfIdx === -1) return "\n";
return crlfIdx < lfIdx ? "\r\n" : "\n";
}
export function normalizeToLF(text: string): string {
return text.replace(/\r\n/g, "\n").replace(/\r/g, "\n");
}
export function restoreLineEndings(
text: string,
ending: "\r\n" | "\n",
): string {
return ending === "\r\n" ? text.replace(/\n/g, "\r\n") : text;
}
/**
* Normalize text for fuzzy matching. Applies progressive transformations:
* - Strip trailing whitespace from each line
* - Normalize smart quotes to ASCII equivalents
* - Normalize Unicode dashes/hyphens to ASCII hyphen
* - Normalize special Unicode spaces to regular space
*/
export function normalizeForFuzzyMatch(text: string): string {
return (
text
// Strip trailing whitespace per line
.split("\n")
.map((line) => line.trimEnd())
.join("\n")
// Smart single quotes → '
.replace(/[\u2018\u2019\u201A\u201B]/g, "'")
// Smart double quotes → "
.replace(/[\u201C\u201D\u201E\u201F]/g, '"')
// Various dashes/hyphens → -
// U+2010 hyphen, U+2011 non-breaking hyphen, U+2012 figure dash,
// U+2013 en-dash, U+2014 em-dash, U+2015 horizontal bar, U+2212 minus
.replace(/[\u2010\u2011\u2012\u2013\u2014\u2015\u2212]/g, "-")
// Special spaces → regular space
// U+00A0 NBSP, U+2002-U+200A various spaces, U+202F narrow NBSP,
// U+205F medium math space, U+3000 ideographic space
.replace(/[\u00A0\u2002-\u200A\u202F\u205F\u3000]/g, " ")
);
}
export interface FuzzyMatchResult {
/** Whether a match was found */
found: boolean;
/** The index where the match starts (in the content that should be used for replacement) */
index: number;
/** Length of the matched text */
matchLength: number;
/** Whether fuzzy matching was used (false = exact match) */
usedFuzzyMatch: boolean;
/**
* The content to use for replacement operations.
* When exact match: original content. When fuzzy match: normalized content.
*/
contentForReplacement: string;
}
/**
* Find oldText in content, trying exact match first, then fuzzy match.
* When fuzzy matching is used, the returned contentForReplacement is the
* fuzzy-normalized version of the content (trailing whitespace stripped,
* Unicode quotes/dashes normalized to ASCII).
*/
export function fuzzyFindText(
content: string,
oldText: string,
): FuzzyMatchResult {
// Try exact match first
const exactIndex = content.indexOf(oldText);
if (exactIndex !== -1) {
return {
found: true,
index: exactIndex,
matchLength: oldText.length,
usedFuzzyMatch: false,
contentForReplacement: content,
};
}
// Try fuzzy match - work entirely in normalized space
const fuzzyContent = normalizeForFuzzyMatch(content);
const fuzzyOldText = normalizeForFuzzyMatch(oldText);
const fuzzyIndex = fuzzyContent.indexOf(fuzzyOldText);
if (fuzzyIndex === -1) {
return {
found: false,
index: -1,
matchLength: 0,
usedFuzzyMatch: false,
contentForReplacement: content,
};
}
// When fuzzy matching, we work in the normalized space for replacement.
// This means the output will have normalized whitespace/quotes/dashes,
// which is acceptable since we're fixing minor formatting differences anyway.
return {
found: true,
index: fuzzyIndex,
matchLength: fuzzyOldText.length,
usedFuzzyMatch: true,
contentForReplacement: fuzzyContent,
};
}
/** Strip UTF-8 BOM if present, return both the BOM (if any) and the text without it */
export function stripBom(content: string): { bom: string; text: string } {
return content.startsWith("\uFEFF")
? { bom: "\uFEFF", text: content.slice(1) }
: { bom: "", text: content };
}
/**
* Generate a unified diff string with line numbers and context.
* Returns both the diff string and the first changed line number (in the new file).
*/
export function generateDiffString(
oldContent: string,
newContent: string,
contextLines = 4,
): { diff: string; firstChangedLine: number | undefined } {
const parts = Diff.diffLines(oldContent, newContent);
const output: string[] = [];
const oldLines = oldContent.split("\n");
const newLines = newContent.split("\n");
const maxLineNum = Math.max(oldLines.length, newLines.length);
const lineNumWidth = String(maxLineNum).length;
let oldLineNum = 1;
let newLineNum = 1;
let lastWasChange = false;
let firstChangedLine: number | undefined;
for (let i = 0; i < parts.length; i++) {
const part = parts[i];
const raw = part.value.split("\n");
if (raw[raw.length - 1] === "") {
raw.pop();
}
if (part.added || part.removed) {
// Capture the first changed line (in the new file)
if (firstChangedLine === undefined) {
firstChangedLine = newLineNum;
}
// Show the change
for (const line of raw) {
if (part.added) {
const lineNum = String(newLineNum).padStart(lineNumWidth, " ");
output.push(`+${lineNum} ${line}`);
newLineNum++;
} else {
// removed
const lineNum = String(oldLineNum).padStart(lineNumWidth, " ");
output.push(`-${lineNum} ${line}`);
oldLineNum++;
}
}
lastWasChange = true;
} else {
// Context lines - only show a few before/after changes
const nextPartIsChange =
i < parts.length - 1 && (parts[i + 1].added || parts[i + 1].removed);
if (lastWasChange || nextPartIsChange) {
// Show context
let linesToShow = raw;
let skipStart = 0;
let skipEnd = 0;
if (!lastWasChange) {
// Show only last N lines as leading context
skipStart = Math.max(0, raw.length - contextLines);
linesToShow = raw.slice(skipStart);
}
if (!nextPartIsChange && linesToShow.length > contextLines) {
// Show only first N lines as trailing context
skipEnd = linesToShow.length - contextLines;
linesToShow = linesToShow.slice(0, contextLines);
}
// Add ellipsis if we skipped lines at start
if (skipStart > 0) {
output.push(` ${"".padStart(lineNumWidth, " ")} ...`);
// Update line numbers for the skipped leading context
oldLineNum += skipStart;
newLineNum += skipStart;
}
for (const line of linesToShow) {
const lineNum = String(oldLineNum).padStart(lineNumWidth, " ");
output.push(` ${lineNum} ${line}`);
oldLineNum++;
newLineNum++;
}
// Add ellipsis if we skipped lines at end
if (skipEnd > 0) {
output.push(` ${"".padStart(lineNumWidth, " ")} ...`);
// Update line numbers for the skipped trailing context
oldLineNum += skipEnd;
newLineNum += skipEnd;
}
} else {
// Skip these context lines entirely
oldLineNum += raw.length;
newLineNum += raw.length;
}
lastWasChange = false;
}
}
return { diff: output.join("\n"), firstChangedLine };
}
export interface EditDiffResult {
diff: string;
firstChangedLine: number | undefined;
}
export interface EditDiffError {
error: string;
}
/**
* Compute the diff for an edit operation without applying it.
* Used for preview rendering in the TUI before the tool executes.
*/
export async function computeEditDiff(
path: string,
oldText: string,
newText: string,
cwd: string,
): Promise<EditDiffResult | EditDiffError> {
const absolutePath = resolveToCwd(path, cwd);
try {
// Check if file exists and is readable
try {
await access(absolutePath, constants.R_OK);
} catch {
return { error: `File not found: ${path}` };
}
// Read the file
const rawContent = await readFile(absolutePath, "utf-8");
// Strip BOM before matching (LLM won't include invisible BOM in oldText)
const { text: content } = stripBom(rawContent);
const normalizedContent = normalizeToLF(content);
const normalizedOldText = normalizeToLF(oldText);
const normalizedNewText = normalizeToLF(newText);
// Find the old text using fuzzy matching (tries exact match first, then fuzzy)
const matchResult = fuzzyFindText(normalizedContent, normalizedOldText);
if (!matchResult.found) {
return {
error: `Could not find the exact text in ${path}. The old text must match exactly including all whitespace and newlines.`,
};
}
// Count occurrences using fuzzy-normalized content for consistency
const fuzzyContent = normalizeForFuzzyMatch(normalizedContent);
const fuzzyOldText = normalizeForFuzzyMatch(normalizedOldText);
const occurrences = fuzzyContent.split(fuzzyOldText).length - 1;
if (occurrences > 1) {
return {
error: `Found ${occurrences} occurrences of the text in ${path}. The text must be unique. Please provide more context to make it unique.`,
};
}
// Compute the new content using the matched position
// When fuzzy matching was used, contentForReplacement is the normalized version
const baseContent = matchResult.contentForReplacement;
const newContent =
baseContent.substring(0, matchResult.index) +
normalizedNewText +
baseContent.substring(matchResult.index + matchResult.matchLength);
// Check if it would actually change anything
if (baseContent === newContent) {
return {
error: `No changes would be made to ${path}. The replacement produces identical content.`,
};
}
// Generate the diff
return generateDiffString(baseContent, newContent);
} catch (err) {
return { error: err instanceof Error ? err.message : String(err) };
}
}

View file

@ -0,0 +1,253 @@
import type { AgentTool } from "@mariozechner/pi-agent-core";
import { type Static, Type } from "@sinclair/typebox";
import { constants } from "fs";
import {
access as fsAccess,
readFile as fsReadFile,
writeFile as fsWriteFile,
} from "fs/promises";
import {
detectLineEnding,
fuzzyFindText,
generateDiffString,
normalizeForFuzzyMatch,
normalizeToLF,
restoreLineEndings,
stripBom,
} from "./edit-diff.js";
import { resolveToCwd } from "./path-utils.js";
const editSchema = Type.Object({
path: Type.String({
description: "Path to the file to edit (relative or absolute)",
}),
oldText: Type.String({
description: "Exact text to find and replace (must match exactly)",
}),
newText: Type.String({
description: "New text to replace the old text with",
}),
});
export type EditToolInput = Static<typeof editSchema>;
export interface EditToolDetails {
/** Unified diff of the changes made */
diff: string;
/** Line number of the first change in the new file (for editor navigation) */
firstChangedLine?: number;
}
/**
* Pluggable operations for the edit tool.
* Override these to delegate file editing to remote systems (e.g., SSH).
*/
export interface EditOperations {
/** Read file contents as a Buffer */
readFile: (absolutePath: string) => Promise<Buffer>;
/** Write content to a file */
writeFile: (absolutePath: string, content: string) => Promise<void>;
/** Check if file is readable and writable (throw if not) */
access: (absolutePath: string) => Promise<void>;
}
const defaultEditOperations: EditOperations = {
readFile: (path) => fsReadFile(path),
writeFile: (path, content) => fsWriteFile(path, content, "utf-8"),
access: (path) => fsAccess(path, constants.R_OK | constants.W_OK),
};
export interface EditToolOptions {
/** Custom operations for file editing. Default: local filesystem */
operations?: EditOperations;
}
export function createEditTool(
cwd: string,
options?: EditToolOptions,
): AgentTool<typeof editSchema> {
const ops = options?.operations ?? defaultEditOperations;
return {
name: "edit",
label: "edit",
description:
"Edit a file by replacing exact text. The oldText must match exactly (including whitespace). Use this for precise, surgical edits.",
parameters: editSchema,
execute: async (
_toolCallId: string,
{
path,
oldText,
newText,
}: { path: string; oldText: string; newText: string },
signal?: AbortSignal,
) => {
const absolutePath = resolveToCwd(path, cwd);
return new Promise<{
content: Array<{ type: "text"; text: string }>;
details: EditToolDetails | undefined;
}>((resolve, reject) => {
// Check if already aborted
if (signal?.aborted) {
reject(new Error("Operation aborted"));
return;
}
let aborted = false;
// Set up abort handler
const onAbort = () => {
aborted = true;
reject(new Error("Operation aborted"));
};
if (signal) {
signal.addEventListener("abort", onAbort, { once: true });
}
// Perform the edit operation
(async () => {
try {
// Check if file exists
try {
await ops.access(absolutePath);
} catch {
if (signal) {
signal.removeEventListener("abort", onAbort);
}
reject(new Error(`File not found: ${path}`));
return;
}
// Check if aborted before reading
if (aborted) {
return;
}
// Read the file
const buffer = await ops.readFile(absolutePath);
const rawContent = buffer.toString("utf-8");
// Check if aborted after reading
if (aborted) {
return;
}
// Strip BOM before matching (LLM won't include invisible BOM in oldText)
const { bom, text: content } = stripBom(rawContent);
const originalEnding = detectLineEnding(content);
const normalizedContent = normalizeToLF(content);
const normalizedOldText = normalizeToLF(oldText);
const normalizedNewText = normalizeToLF(newText);
// Find the old text using fuzzy matching (tries exact match first, then fuzzy)
const matchResult = fuzzyFindText(
normalizedContent,
normalizedOldText,
);
if (!matchResult.found) {
if (signal) {
signal.removeEventListener("abort", onAbort);
}
reject(
new Error(
`Could not find the exact text in ${path}. The old text must match exactly including all whitespace and newlines.`,
),
);
return;
}
// Count occurrences using fuzzy-normalized content for consistency
const fuzzyContent = normalizeForFuzzyMatch(normalizedContent);
const fuzzyOldText = normalizeForFuzzyMatch(normalizedOldText);
const occurrences = fuzzyContent.split(fuzzyOldText).length - 1;
if (occurrences > 1) {
if (signal) {
signal.removeEventListener("abort", onAbort);
}
reject(
new Error(
`Found ${occurrences} occurrences of the text in ${path}. The text must be unique. Please provide more context to make it unique.`,
),
);
return;
}
// Check if aborted before writing
if (aborted) {
return;
}
// Perform replacement using the matched text position
// When fuzzy matching was used, contentForReplacement is the normalized version
const baseContent = matchResult.contentForReplacement;
const newContent =
baseContent.substring(0, matchResult.index) +
normalizedNewText +
baseContent.substring(
matchResult.index + matchResult.matchLength,
);
// Verify the replacement actually changed something
if (baseContent === newContent) {
if (signal) {
signal.removeEventListener("abort", onAbort);
}
reject(
new Error(
`No changes made to ${path}. The replacement produced identical content. This might indicate an issue with special characters or the text not existing as expected.`,
),
);
return;
}
const finalContent =
bom + restoreLineEndings(newContent, originalEnding);
await ops.writeFile(absolutePath, finalContent);
// Check if aborted after writing
if (aborted) {
return;
}
// Clean up abort handler
if (signal) {
signal.removeEventListener("abort", onAbort);
}
const diffResult = generateDiffString(baseContent, newContent);
resolve({
content: [
{
type: "text",
text: `Successfully replaced text in ${path}.`,
},
],
details: {
diff: diffResult.diff,
firstChangedLine: diffResult.firstChangedLine,
},
});
} catch (error: any) {
// Clean up abort handler
if (signal) {
signal.removeEventListener("abort", onAbort);
}
if (!aborted) {
reject(error);
}
}
})();
});
},
};
}
/** Default edit tool using process.cwd() - for backwards compatibility */
export const editTool = createEditTool(process.cwd());

View file

@ -0,0 +1,308 @@
import type { AgentTool } from "@mariozechner/pi-agent-core";
import { type Static, Type } from "@sinclair/typebox";
import { spawnSync } from "child_process";
import { existsSync } from "fs";
import { globSync } from "glob";
import path from "path";
import { ensureTool } from "../../utils/tools-manager.js";
import { resolveToCwd } from "./path-utils.js";
import {
DEFAULT_MAX_BYTES,
formatSize,
type TruncationResult,
truncateHead,
} from "./truncate.js";
const findSchema = Type.Object({
pattern: Type.String({
description:
"Glob pattern to match files, e.g. '*.ts', '**/*.json', or 'src/**/*.spec.ts'",
}),
path: Type.Optional(
Type.String({
description: "Directory to search in (default: current directory)",
}),
),
limit: Type.Optional(
Type.Number({ description: "Maximum number of results (default: 1000)" }),
),
});
export type FindToolInput = Static<typeof findSchema>;
const DEFAULT_LIMIT = 1000;
export interface FindToolDetails {
truncation?: TruncationResult;
resultLimitReached?: number;
}
/**
* Pluggable operations for the find tool.
* Override these to delegate file search to remote systems (e.g., SSH).
*/
export interface FindOperations {
/** Check if path exists */
exists: (absolutePath: string) => Promise<boolean> | boolean;
/** Find files matching glob pattern. Returns relative paths. */
glob: (
pattern: string,
cwd: string,
options: { ignore: string[]; limit: number },
) => Promise<string[]> | string[];
}
const defaultFindOperations: FindOperations = {
exists: existsSync,
glob: (_pattern, _searchCwd, _options) => {
// This is a placeholder - actual fd execution happens in execute
return [];
},
};
export interface FindToolOptions {
/** Custom operations for find. Default: local filesystem + fd */
operations?: FindOperations;
}
export function createFindTool(
cwd: string,
options?: FindToolOptions,
): AgentTool<typeof findSchema> {
const customOps = options?.operations;
return {
name: "find",
label: "find",
description: `Search for files by glob pattern. Returns matching file paths relative to the search directory. Respects .gitignore. Output is truncated to ${DEFAULT_LIMIT} results or ${DEFAULT_MAX_BYTES / 1024}KB (whichever is hit first).`,
parameters: findSchema,
execute: async (
_toolCallId: string,
{
pattern,
path: searchDir,
limit,
}: { pattern: string; path?: string; limit?: number },
signal?: AbortSignal,
) => {
return new Promise((resolve, reject) => {
if (signal?.aborted) {
reject(new Error("Operation aborted"));
return;
}
const onAbort = () => reject(new Error("Operation aborted"));
signal?.addEventListener("abort", onAbort, { once: true });
(async () => {
try {
const searchPath = resolveToCwd(searchDir || ".", cwd);
const effectiveLimit = limit ?? DEFAULT_LIMIT;
const ops = customOps ?? defaultFindOperations;
// If custom operations provided with glob, use that
if (customOps?.glob) {
if (!(await ops.exists(searchPath))) {
reject(new Error(`Path not found: ${searchPath}`));
return;
}
const results = await ops.glob(pattern, searchPath, {
ignore: ["**/node_modules/**", "**/.git/**"],
limit: effectiveLimit,
});
signal?.removeEventListener("abort", onAbort);
if (results.length === 0) {
resolve({
content: [
{ type: "text", text: "No files found matching pattern" },
],
details: undefined,
});
return;
}
// Relativize paths
const relativized = results.map((p) => {
if (p.startsWith(searchPath)) {
return p.slice(searchPath.length + 1);
}
return path.relative(searchPath, p);
});
const resultLimitReached = relativized.length >= effectiveLimit;
const rawOutput = relativized.join("\n");
const truncation = truncateHead(rawOutput, {
maxLines: Number.MAX_SAFE_INTEGER,
});
let resultOutput = truncation.content;
const details: FindToolDetails = {};
const notices: string[] = [];
if (resultLimitReached) {
notices.push(`${effectiveLimit} results limit reached`);
details.resultLimitReached = effectiveLimit;
}
if (truncation.truncated) {
notices.push(`${formatSize(DEFAULT_MAX_BYTES)} limit reached`);
details.truncation = truncation;
}
if (notices.length > 0) {
resultOutput += `\n\n[${notices.join(". ")}]`;
}
resolve({
content: [{ type: "text", text: resultOutput }],
details: Object.keys(details).length > 0 ? details : undefined,
});
return;
}
// Default: use fd
const fdPath = await ensureTool("fd", true);
if (!fdPath) {
reject(
new Error("fd is not available and could not be downloaded"),
);
return;
}
// Build fd arguments
const args: string[] = [
"--glob",
"--color=never",
"--hidden",
"--max-results",
String(effectiveLimit),
];
// Include .gitignore files
const gitignoreFiles = new Set<string>();
const rootGitignore = path.join(searchPath, ".gitignore");
if (existsSync(rootGitignore)) {
gitignoreFiles.add(rootGitignore);
}
try {
const nestedGitignores = globSync("**/.gitignore", {
cwd: searchPath,
dot: true,
absolute: true,
ignore: ["**/node_modules/**", "**/.git/**"],
});
for (const file of nestedGitignores) {
gitignoreFiles.add(file);
}
} catch {
// Ignore glob errors
}
for (const gitignorePath of gitignoreFiles) {
args.push("--ignore-file", gitignorePath);
}
args.push(pattern, searchPath);
const result = spawnSync(fdPath, args, {
encoding: "utf-8",
maxBuffer: 10 * 1024 * 1024,
});
signal?.removeEventListener("abort", onAbort);
if (result.error) {
reject(new Error(`Failed to run fd: ${result.error.message}`));
return;
}
const output = result.stdout?.trim() || "";
if (result.status !== 0) {
const errorMsg =
result.stderr?.trim() || `fd exited with code ${result.status}`;
if (!output) {
reject(new Error(errorMsg));
return;
}
}
if (!output) {
resolve({
content: [
{ type: "text", text: "No files found matching pattern" },
],
details: undefined,
});
return;
}
const lines = output.split("\n");
const relativized: string[] = [];
for (const rawLine of lines) {
const line = rawLine.replace(/\r$/, "").trim();
if (!line) continue;
const hadTrailingSlash =
line.endsWith("/") || line.endsWith("\\");
let relativePath = line;
if (line.startsWith(searchPath)) {
relativePath = line.slice(searchPath.length + 1);
} else {
relativePath = path.relative(searchPath, line);
}
if (hadTrailingSlash && !relativePath.endsWith("/")) {
relativePath += "/";
}
relativized.push(relativePath);
}
const resultLimitReached = relativized.length >= effectiveLimit;
const rawOutput = relativized.join("\n");
const truncation = truncateHead(rawOutput, {
maxLines: Number.MAX_SAFE_INTEGER,
});
let resultOutput = truncation.content;
const details: FindToolDetails = {};
const notices: string[] = [];
if (resultLimitReached) {
notices.push(
`${effectiveLimit} results limit reached. Use limit=${effectiveLimit * 2} for more, or refine pattern`,
);
details.resultLimitReached = effectiveLimit;
}
if (truncation.truncated) {
notices.push(`${formatSize(DEFAULT_MAX_BYTES)} limit reached`);
details.truncation = truncation;
}
if (notices.length > 0) {
resultOutput += `\n\n[${notices.join(". ")}]`;
}
resolve({
content: [{ type: "text", text: resultOutput }],
details: Object.keys(details).length > 0 ? details : undefined,
});
} catch (e: any) {
signal?.removeEventListener("abort", onAbort);
reject(e);
}
})();
});
},
};
}
/** Default find tool using process.cwd() - for backwards compatibility */
export const findTool = createFindTool(process.cwd());

View file

@ -0,0 +1,412 @@
import { createInterface } from "node:readline";
import type { AgentTool } from "@mariozechner/pi-agent-core";
import { type Static, Type } from "@sinclair/typebox";
import { spawn } from "child_process";
import { readFileSync, statSync } from "fs";
import path from "path";
import { ensureTool } from "../../utils/tools-manager.js";
import { resolveToCwd } from "./path-utils.js";
import {
DEFAULT_MAX_BYTES,
formatSize,
GREP_MAX_LINE_LENGTH,
type TruncationResult,
truncateHead,
truncateLine,
} from "./truncate.js";
const grepSchema = Type.Object({
pattern: Type.String({
description: "Search pattern (regex or literal string)",
}),
path: Type.Optional(
Type.String({
description: "Directory or file to search (default: current directory)",
}),
),
glob: Type.Optional(
Type.String({
description:
"Filter files by glob pattern, e.g. '*.ts' or '**/*.spec.ts'",
}),
),
ignoreCase: Type.Optional(
Type.Boolean({ description: "Case-insensitive search (default: false)" }),
),
literal: Type.Optional(
Type.Boolean({
description:
"Treat pattern as literal string instead of regex (default: false)",
}),
),
context: Type.Optional(
Type.Number({
description:
"Number of lines to show before and after each match (default: 0)",
}),
),
limit: Type.Optional(
Type.Number({
description: "Maximum number of matches to return (default: 100)",
}),
),
});
export type GrepToolInput = Static<typeof grepSchema>;
const DEFAULT_LIMIT = 100;
export interface GrepToolDetails {
truncation?: TruncationResult;
matchLimitReached?: number;
linesTruncated?: boolean;
}
/**
* Pluggable operations for the grep tool.
* Override these to delegate search to remote systems (e.g., SSH).
*/
export interface GrepOperations {
/** Check if path is a directory. Throws if path doesn't exist. */
isDirectory: (absolutePath: string) => Promise<boolean> | boolean;
/** Read file contents for context lines */
readFile: (absolutePath: string) => Promise<string> | string;
}
const defaultGrepOperations: GrepOperations = {
isDirectory: (p) => statSync(p).isDirectory(),
readFile: (p) => readFileSync(p, "utf-8"),
};
export interface GrepToolOptions {
/** Custom operations for grep. Default: local filesystem + ripgrep */
operations?: GrepOperations;
}
export function createGrepTool(
cwd: string,
options?: GrepToolOptions,
): AgentTool<typeof grepSchema> {
const customOps = options?.operations;
return {
name: "grep",
label: "grep",
description: `Search file contents for a pattern. Returns matching lines with file paths and line numbers. Respects .gitignore. Output is truncated to ${DEFAULT_LIMIT} matches or ${DEFAULT_MAX_BYTES / 1024}KB (whichever is hit first). Long lines are truncated to ${GREP_MAX_LINE_LENGTH} chars.`,
parameters: grepSchema,
execute: async (
_toolCallId: string,
{
pattern,
path: searchDir,
glob,
ignoreCase,
literal,
context,
limit,
}: {
pattern: string;
path?: string;
glob?: string;
ignoreCase?: boolean;
literal?: boolean;
context?: number;
limit?: number;
},
signal?: AbortSignal,
) => {
return new Promise((resolve, reject) => {
if (signal?.aborted) {
reject(new Error("Operation aborted"));
return;
}
let settled = false;
const settle = (fn: () => void) => {
if (!settled) {
settled = true;
fn();
}
};
(async () => {
try {
const rgPath = await ensureTool("rg", true);
if (!rgPath) {
settle(() =>
reject(
new Error(
"ripgrep (rg) is not available and could not be downloaded",
),
),
);
return;
}
const searchPath = resolveToCwd(searchDir || ".", cwd);
const ops = customOps ?? defaultGrepOperations;
let isDirectory: boolean;
try {
isDirectory = await ops.isDirectory(searchPath);
} catch (_err) {
settle(() => reject(new Error(`Path not found: ${searchPath}`)));
return;
}
const contextValue = context && context > 0 ? context : 0;
const effectiveLimit = Math.max(1, limit ?? DEFAULT_LIMIT);
const formatPath = (filePath: string): string => {
if (isDirectory) {
const relative = path.relative(searchPath, filePath);
if (relative && !relative.startsWith("..")) {
return relative.replace(/\\/g, "/");
}
}
return path.basename(filePath);
};
const fileCache = new Map<string, string[]>();
const getFileLines = async (
filePath: string,
): Promise<string[]> => {
let lines = fileCache.get(filePath);
if (!lines) {
try {
const content = await ops.readFile(filePath);
lines = content
.replace(/\r\n/g, "\n")
.replace(/\r/g, "\n")
.split("\n");
} catch {
lines = [];
}
fileCache.set(filePath, lines);
}
return lines;
};
const args: string[] = [
"--json",
"--line-number",
"--color=never",
"--hidden",
];
if (ignoreCase) {
args.push("--ignore-case");
}
if (literal) {
args.push("--fixed-strings");
}
if (glob) {
args.push("--glob", glob);
}
args.push(pattern, searchPath);
const child = spawn(rgPath, args, {
stdio: ["ignore", "pipe", "pipe"],
});
const rl = createInterface({ input: child.stdout });
let stderr = "";
let matchCount = 0;
let matchLimitReached = false;
let linesTruncated = false;
let aborted = false;
let killedDueToLimit = false;
const outputLines: string[] = [];
const cleanup = () => {
rl.close();
signal?.removeEventListener("abort", onAbort);
};
const stopChild = (dueToLimit: boolean = false) => {
if (!child.killed) {
killedDueToLimit = dueToLimit;
child.kill();
}
};
const onAbort = () => {
aborted = true;
stopChild();
};
signal?.addEventListener("abort", onAbort, { once: true });
child.stderr?.on("data", (chunk) => {
stderr += chunk.toString();
});
const formatBlock = async (
filePath: string,
lineNumber: number,
): Promise<string[]> => {
const relativePath = formatPath(filePath);
const lines = await getFileLines(filePath);
if (!lines.length) {
return [`${relativePath}:${lineNumber}: (unable to read file)`];
}
const block: string[] = [];
const start =
contextValue > 0
? Math.max(1, lineNumber - contextValue)
: lineNumber;
const end =
contextValue > 0
? Math.min(lines.length, lineNumber + contextValue)
: lineNumber;
for (let current = start; current <= end; current++) {
const lineText = lines[current - 1] ?? "";
const sanitized = lineText.replace(/\r/g, "");
const isMatchLine = current === lineNumber;
// Truncate long lines
const { text: truncatedText, wasTruncated } =
truncateLine(sanitized);
if (wasTruncated) {
linesTruncated = true;
}
if (isMatchLine) {
block.push(`${relativePath}:${current}: ${truncatedText}`);
} else {
block.push(`${relativePath}-${current}- ${truncatedText}`);
}
}
return block;
};
// Collect matches during streaming, format after
const matches: Array<{ filePath: string; lineNumber: number }> = [];
rl.on("line", (line) => {
if (!line.trim() || matchCount >= effectiveLimit) {
return;
}
let event: any;
try {
event = JSON.parse(line);
} catch {
return;
}
if (event.type === "match") {
matchCount++;
const filePath = event.data?.path?.text;
const lineNumber = event.data?.line_number;
if (filePath && typeof lineNumber === "number") {
matches.push({ filePath, lineNumber });
}
if (matchCount >= effectiveLimit) {
matchLimitReached = true;
stopChild(true);
}
}
});
child.on("error", (error) => {
cleanup();
settle(() =>
reject(new Error(`Failed to run ripgrep: ${error.message}`)),
);
});
child.on("close", async (code) => {
cleanup();
if (aborted) {
settle(() => reject(new Error("Operation aborted")));
return;
}
if (!killedDueToLimit && code !== 0 && code !== 1) {
const errorMsg =
stderr.trim() || `ripgrep exited with code ${code}`;
settle(() => reject(new Error(errorMsg)));
return;
}
if (matchCount === 0) {
settle(() =>
resolve({
content: [{ type: "text", text: "No matches found" }],
details: undefined,
}),
);
return;
}
// Format matches (async to support remote file reading)
for (const match of matches) {
const block = await formatBlock(
match.filePath,
match.lineNumber,
);
outputLines.push(...block);
}
// Apply byte truncation (no line limit since we already have match limit)
const rawOutput = outputLines.join("\n");
const truncation = truncateHead(rawOutput, {
maxLines: Number.MAX_SAFE_INTEGER,
});
let output = truncation.content;
const details: GrepToolDetails = {};
// Build notices
const notices: string[] = [];
if (matchLimitReached) {
notices.push(
`${effectiveLimit} matches limit reached. Use limit=${effectiveLimit * 2} for more, or refine pattern`,
);
details.matchLimitReached = effectiveLimit;
}
if (truncation.truncated) {
notices.push(`${formatSize(DEFAULT_MAX_BYTES)} limit reached`);
details.truncation = truncation;
}
if (linesTruncated) {
notices.push(
`Some lines truncated to ${GREP_MAX_LINE_LENGTH} chars. Use read tool to see full lines`,
);
details.linesTruncated = true;
}
if (notices.length > 0) {
output += `\n\n[${notices.join(". ")}]`;
}
settle(() =>
resolve({
content: [{ type: "text", text: output }],
details:
Object.keys(details).length > 0 ? details : undefined,
}),
);
});
} catch (err) {
settle(() => reject(err as Error));
}
})();
});
},
};
}
/** Default grep tool using process.cwd() - for backwards compatibility */
export const grepTool = createGrepTool(process.cwd());

View file

@ -0,0 +1,150 @@
export {
type BashOperations,
type BashSpawnContext,
type BashSpawnHook,
type BashToolDetails,
type BashToolInput,
type BashToolOptions,
bashTool,
createBashTool,
} from "./bash.js";
export {
createEditTool,
type EditOperations,
type EditToolDetails,
type EditToolInput,
type EditToolOptions,
editTool,
} from "./edit.js";
export {
createFindTool,
type FindOperations,
type FindToolDetails,
type FindToolInput,
type FindToolOptions,
findTool,
} from "./find.js";
export {
createGrepTool,
type GrepOperations,
type GrepToolDetails,
type GrepToolInput,
type GrepToolOptions,
grepTool,
} from "./grep.js";
export {
createLsTool,
type LsOperations,
type LsToolDetails,
type LsToolInput,
type LsToolOptions,
lsTool,
} from "./ls.js";
export {
createReadTool,
type ReadOperations,
type ReadToolDetails,
type ReadToolInput,
type ReadToolOptions,
readTool,
} from "./read.js";
export {
DEFAULT_MAX_BYTES,
DEFAULT_MAX_LINES,
formatSize,
type TruncationOptions,
type TruncationResult,
truncateHead,
truncateLine,
truncateTail,
} from "./truncate.js";
export {
createWriteTool,
type WriteOperations,
type WriteToolInput,
type WriteToolOptions,
writeTool,
} from "./write.js";
import type { AgentTool } from "@mariozechner/pi-agent-core";
import { type BashToolOptions, bashTool, createBashTool } from "./bash.js";
import { createEditTool, editTool } from "./edit.js";
import { createFindTool, findTool } from "./find.js";
import { createGrepTool, grepTool } from "./grep.js";
import { createLsTool, lsTool } from "./ls.js";
import { createReadTool, type ReadToolOptions, readTool } from "./read.js";
import { createWriteTool, writeTool } from "./write.js";
/** Tool type (AgentTool from pi-ai) */
export type Tool = AgentTool<any>;
// Default tools for full access mode (using process.cwd())
export const codingTools: Tool[] = [readTool, bashTool, editTool, writeTool];
// Read-only tools for exploration without modification (using process.cwd())
export const readOnlyTools: Tool[] = [readTool, grepTool, findTool, lsTool];
// All available tools (using process.cwd())
export const allTools = {
read: readTool,
bash: bashTool,
edit: editTool,
write: writeTool,
grep: grepTool,
find: findTool,
ls: lsTool,
};
export type ToolName = keyof typeof allTools;
export interface ToolsOptions {
/** Options for the read tool */
read?: ReadToolOptions;
/** Options for the bash tool */
bash?: BashToolOptions;
}
/**
* Create coding tools configured for a specific working directory.
*/
export function createCodingTools(cwd: string, options?: ToolsOptions): Tool[] {
return [
createReadTool(cwd, options?.read),
createBashTool(cwd, options?.bash),
createEditTool(cwd),
createWriteTool(cwd),
];
}
/**
* Create read-only tools configured for a specific working directory.
*/
export function createReadOnlyTools(
cwd: string,
options?: ToolsOptions,
): Tool[] {
return [
createReadTool(cwd, options?.read),
createGrepTool(cwd),
createFindTool(cwd),
createLsTool(cwd),
];
}
/**
* Create all tools configured for a specific working directory.
*/
export function createAllTools(
cwd: string,
options?: ToolsOptions,
): Record<ToolName, Tool> {
return {
read: createReadTool(cwd, options?.read),
bash: createBashTool(cwd, options?.bash),
edit: createEditTool(cwd),
write: createWriteTool(cwd),
grep: createGrepTool(cwd),
find: createFindTool(cwd),
ls: createLsTool(cwd),
};
}

View file

@ -0,0 +1,197 @@
import type { AgentTool } from "@mariozechner/pi-agent-core";
import { type Static, Type } from "@sinclair/typebox";
import { existsSync, readdirSync, statSync } from "fs";
import nodePath from "path";
import { resolveToCwd } from "./path-utils.js";
import {
DEFAULT_MAX_BYTES,
formatSize,
type TruncationResult,
truncateHead,
} from "./truncate.js";
const lsSchema = Type.Object({
path: Type.Optional(
Type.String({
description: "Directory to list (default: current directory)",
}),
),
limit: Type.Optional(
Type.Number({
description: "Maximum number of entries to return (default: 500)",
}),
),
});
export type LsToolInput = Static<typeof lsSchema>;
const DEFAULT_LIMIT = 500;
export interface LsToolDetails {
truncation?: TruncationResult;
entryLimitReached?: number;
}
/**
* Pluggable operations for the ls tool.
* Override these to delegate directory listing to remote systems (e.g., SSH).
*/
export interface LsOperations {
/** Check if path exists */
exists: (absolutePath: string) => Promise<boolean> | boolean;
/** Get file/directory stats. Throws if not found. */
stat: (
absolutePath: string,
) => Promise<{ isDirectory: () => boolean }> | { isDirectory: () => boolean };
/** Read directory entries */
readdir: (absolutePath: string) => Promise<string[]> | string[];
}
const defaultLsOperations: LsOperations = {
exists: existsSync,
stat: statSync,
readdir: readdirSync,
};
export interface LsToolOptions {
/** Custom operations for directory listing. Default: local filesystem */
operations?: LsOperations;
}
export function createLsTool(
cwd: string,
options?: LsToolOptions,
): AgentTool<typeof lsSchema> {
const ops = options?.operations ?? defaultLsOperations;
return {
name: "ls",
label: "ls",
description: `List directory contents. Returns entries sorted alphabetically, with '/' suffix for directories. Includes dotfiles. Output is truncated to ${DEFAULT_LIMIT} entries or ${DEFAULT_MAX_BYTES / 1024}KB (whichever is hit first).`,
parameters: lsSchema,
execute: async (
_toolCallId: string,
{ path, limit }: { path?: string; limit?: number },
signal?: AbortSignal,
) => {
return new Promise((resolve, reject) => {
if (signal?.aborted) {
reject(new Error("Operation aborted"));
return;
}
const onAbort = () => reject(new Error("Operation aborted"));
signal?.addEventListener("abort", onAbort, { once: true });
(async () => {
try {
const dirPath = resolveToCwd(path || ".", cwd);
const effectiveLimit = limit ?? DEFAULT_LIMIT;
// Check if path exists
if (!(await ops.exists(dirPath))) {
reject(new Error(`Path not found: ${dirPath}`));
return;
}
// Check if path is a directory
const stat = await ops.stat(dirPath);
if (!stat.isDirectory()) {
reject(new Error(`Not a directory: ${dirPath}`));
return;
}
// Read directory entries
let entries: string[];
try {
entries = await ops.readdir(dirPath);
} catch (e: any) {
reject(new Error(`Cannot read directory: ${e.message}`));
return;
}
// Sort alphabetically (case-insensitive)
entries.sort((a, b) =>
a.toLowerCase().localeCompare(b.toLowerCase()),
);
// Format entries with directory indicators
const results: string[] = [];
let entryLimitReached = false;
for (const entry of entries) {
if (results.length >= effectiveLimit) {
entryLimitReached = true;
break;
}
const fullPath = nodePath.join(dirPath, entry);
let suffix = "";
try {
const entryStat = await ops.stat(fullPath);
if (entryStat.isDirectory()) {
suffix = "/";
}
} catch {
// Skip entries we can't stat
continue;
}
results.push(entry + suffix);
}
signal?.removeEventListener("abort", onAbort);
if (results.length === 0) {
resolve({
content: [{ type: "text", text: "(empty directory)" }],
details: undefined,
});
return;
}
// Apply byte truncation (no line limit since we already have entry limit)
const rawOutput = results.join("\n");
const truncation = truncateHead(rawOutput, {
maxLines: Number.MAX_SAFE_INTEGER,
});
let output = truncation.content;
const details: LsToolDetails = {};
// Build notices
const notices: string[] = [];
if (entryLimitReached) {
notices.push(
`${effectiveLimit} entries limit reached. Use limit=${effectiveLimit * 2} for more`,
);
details.entryLimitReached = effectiveLimit;
}
if (truncation.truncated) {
notices.push(`${formatSize(DEFAULT_MAX_BYTES)} limit reached`);
details.truncation = truncation;
}
if (notices.length > 0) {
output += `\n\n[${notices.join(". ")}]`;
}
resolve({
content: [{ type: "text", text: output }],
details: Object.keys(details).length > 0 ? details : undefined,
});
} catch (e: any) {
signal?.removeEventListener("abort", onAbort);
reject(e);
}
})();
});
},
};
}
/** Default ls tool using process.cwd() - for backwards compatibility */
export const lsTool = createLsTool(process.cwd());

View file

@ -0,0 +1,94 @@
import { accessSync, constants } from "node:fs";
import * as os from "node:os";
import { isAbsolute, resolve as resolvePath } from "node:path";
const UNICODE_SPACES = /[\u00A0\u2000-\u200A\u202F\u205F\u3000]/g;
const NARROW_NO_BREAK_SPACE = "\u202F";
function normalizeUnicodeSpaces(str: string): string {
return str.replace(UNICODE_SPACES, " ");
}
function tryMacOSScreenshotPath(filePath: string): string {
return filePath.replace(/ (AM|PM)\./g, `${NARROW_NO_BREAK_SPACE}$1.`);
}
function tryNFDVariant(filePath: string): string {
// macOS stores filenames in NFD (decomposed) form, try converting user input to NFD
return filePath.normalize("NFD");
}
function tryCurlyQuoteVariant(filePath: string): string {
// macOS uses U+2019 (right single quotation mark) in screenshot names like "Capture d'écran"
// Users typically type U+0027 (straight apostrophe)
return filePath.replace(/'/g, "\u2019");
}
function fileExists(filePath: string): boolean {
try {
accessSync(filePath, constants.F_OK);
return true;
} catch {
return false;
}
}
function normalizeAtPrefix(filePath: string): string {
return filePath.startsWith("@") ? filePath.slice(1) : filePath;
}
export function expandPath(filePath: string): string {
const normalized = normalizeUnicodeSpaces(normalizeAtPrefix(filePath));
if (normalized === "~") {
return os.homedir();
}
if (normalized.startsWith("~/")) {
return os.homedir() + normalized.slice(1);
}
return normalized;
}
/**
* Resolve a path relative to the given cwd.
* Handles ~ expansion and absolute paths.
*/
export function resolveToCwd(filePath: string, cwd: string): string {
const expanded = expandPath(filePath);
if (isAbsolute(expanded)) {
return expanded;
}
return resolvePath(cwd, expanded);
}
export function resolveReadPath(filePath: string, cwd: string): string {
const resolved = resolveToCwd(filePath, cwd);
if (fileExists(resolved)) {
return resolved;
}
// Try macOS AM/PM variant (narrow no-break space before AM/PM)
const amPmVariant = tryMacOSScreenshotPath(resolved);
if (amPmVariant !== resolved && fileExists(amPmVariant)) {
return amPmVariant;
}
// Try NFD variant (macOS stores filenames in NFD form)
const nfdVariant = tryNFDVariant(resolved);
if (nfdVariant !== resolved && fileExists(nfdVariant)) {
return nfdVariant;
}
// Try curly quote variant (macOS uses U+2019 in screenshot names)
const curlyVariant = tryCurlyQuoteVariant(resolved);
if (curlyVariant !== resolved && fileExists(curlyVariant)) {
return curlyVariant;
}
// Try combined NFD + curly quote (for French macOS screenshots like "Capture d'écran")
const nfdCurlyVariant = tryCurlyQuoteVariant(nfdVariant);
if (nfdCurlyVariant !== resolved && fileExists(nfdCurlyVariant)) {
return nfdCurlyVariant;
}
return resolved;
}

View file

@ -0,0 +1,265 @@
import type { AgentTool } from "@mariozechner/pi-agent-core";
import type { ImageContent, TextContent } from "@mariozechner/pi-ai";
import { type Static, Type } from "@sinclair/typebox";
import { constants } from "fs";
import { access as fsAccess, readFile as fsReadFile } from "fs/promises";
import { formatDimensionNote, resizeImage } from "../../utils/image-resize.js";
import { detectSupportedImageMimeTypeFromFile } from "../../utils/mime.js";
import { resolveReadPath } from "./path-utils.js";
import {
DEFAULT_MAX_BYTES,
DEFAULT_MAX_LINES,
formatSize,
type TruncationResult,
truncateHead,
} from "./truncate.js";
const readSchema = Type.Object({
path: Type.String({
description: "Path to the file to read (relative or absolute)",
}),
offset: Type.Optional(
Type.Number({
description: "Line number to start reading from (1-indexed)",
}),
),
limit: Type.Optional(
Type.Number({ description: "Maximum number of lines to read" }),
),
});
export type ReadToolInput = Static<typeof readSchema>;
export interface ReadToolDetails {
truncation?: TruncationResult;
}
/**
* Pluggable operations for the read tool.
* Override these to delegate file reading to remote systems (e.g., SSH).
*/
export interface ReadOperations {
/** Read file contents as a Buffer */
readFile: (absolutePath: string) => Promise<Buffer>;
/** Check if file is readable (throw if not) */
access: (absolutePath: string) => Promise<void>;
/** Detect image MIME type, return null/undefined for non-images */
detectImageMimeType?: (
absolutePath: string,
) => Promise<string | null | undefined>;
}
const defaultReadOperations: ReadOperations = {
readFile: (path) => fsReadFile(path),
access: (path) => fsAccess(path, constants.R_OK),
detectImageMimeType: detectSupportedImageMimeTypeFromFile,
};
export interface ReadToolOptions {
/** Whether to auto-resize images to 2000x2000 max. Default: true */
autoResizeImages?: boolean;
/** Custom operations for file reading. Default: local filesystem */
operations?: ReadOperations;
}
export function createReadTool(
cwd: string,
options?: ReadToolOptions,
): AgentTool<typeof readSchema> {
const autoResizeImages = options?.autoResizeImages ?? true;
const ops = options?.operations ?? defaultReadOperations;
return {
name: "read",
label: "read",
description: `Read the contents of a file. Supports text files and images (jpg, png, gif, webp). Images are sent as attachments. For text files, output is truncated to ${DEFAULT_MAX_LINES} lines or ${DEFAULT_MAX_BYTES / 1024}KB (whichever is hit first). Use offset/limit for large files. When you need the full file, continue with offset until complete.`,
parameters: readSchema,
execute: async (
_toolCallId: string,
{
path,
offset,
limit,
}: { path: string; offset?: number; limit?: number },
signal?: AbortSignal,
) => {
const absolutePath = resolveReadPath(path, cwd);
return new Promise<{
content: (TextContent | ImageContent)[];
details: ReadToolDetails | undefined;
}>((resolve, reject) => {
// Check if already aborted
if (signal?.aborted) {
reject(new Error("Operation aborted"));
return;
}
let aborted = false;
// Set up abort handler
const onAbort = () => {
aborted = true;
reject(new Error("Operation aborted"));
};
if (signal) {
signal.addEventListener("abort", onAbort, { once: true });
}
// Perform the read operation
(async () => {
try {
// Check if file exists
await ops.access(absolutePath);
// Check if aborted before reading
if (aborted) {
return;
}
const mimeType = ops.detectImageMimeType
? await ops.detectImageMimeType(absolutePath)
: undefined;
// Read the file based on type
let content: (TextContent | ImageContent)[];
let details: ReadToolDetails | undefined;
if (mimeType) {
// Read as image (binary)
const buffer = await ops.readFile(absolutePath);
const base64 = buffer.toString("base64");
if (autoResizeImages) {
// Resize image if needed
const resized = await resizeImage({
type: "image",
data: base64,
mimeType,
});
const dimensionNote = formatDimensionNote(resized);
let textNote = `Read image file [${resized.mimeType}]`;
if (dimensionNote) {
textNote += `\n${dimensionNote}`;
}
content = [
{ type: "text", text: textNote },
{
type: "image",
data: resized.data,
mimeType: resized.mimeType,
},
];
} else {
const textNote = `Read image file [${mimeType}]`;
content = [
{ type: "text", text: textNote },
{ type: "image", data: base64, mimeType },
];
}
} else {
// Read as text
const buffer = await ops.readFile(absolutePath);
const textContent = buffer.toString("utf-8");
const allLines = textContent.split("\n");
const totalFileLines = allLines.length;
// Apply offset if specified (1-indexed to 0-indexed)
const startLine = offset ? Math.max(0, offset - 1) : 0;
const startLineDisplay = startLine + 1; // For display (1-indexed)
// Check if offset is out of bounds
if (startLine >= allLines.length) {
throw new Error(
`Offset ${offset} is beyond end of file (${allLines.length} lines total)`,
);
}
// If limit is specified by user, use it; otherwise we'll let truncateHead decide
let selectedContent: string;
let userLimitedLines: number | undefined;
if (limit !== undefined) {
const endLine = Math.min(startLine + limit, allLines.length);
selectedContent = allLines.slice(startLine, endLine).join("\n");
userLimitedLines = endLine - startLine;
} else {
selectedContent = allLines.slice(startLine).join("\n");
}
// Apply truncation (respects both line and byte limits)
const truncation = truncateHead(selectedContent);
let outputText: string;
if (truncation.firstLineExceedsLimit) {
// First line at offset exceeds 30KB - tell model to use bash
const firstLineSize = formatSize(
Buffer.byteLength(allLines[startLine], "utf-8"),
);
outputText = `[Line ${startLineDisplay} is ${firstLineSize}, exceeds ${formatSize(DEFAULT_MAX_BYTES)} limit. Use bash: sed -n '${startLineDisplay}p' ${path} | head -c ${DEFAULT_MAX_BYTES}]`;
details = { truncation };
} else if (truncation.truncated) {
// Truncation occurred - build actionable notice
const endLineDisplay =
startLineDisplay + truncation.outputLines - 1;
const nextOffset = endLineDisplay + 1;
outputText = truncation.content;
if (truncation.truncatedBy === "lines") {
outputText += `\n\n[Showing lines ${startLineDisplay}-${endLineDisplay} of ${totalFileLines}. Use offset=${nextOffset} to continue.]`;
} else {
outputText += `\n\n[Showing lines ${startLineDisplay}-${endLineDisplay} of ${totalFileLines} (${formatSize(DEFAULT_MAX_BYTES)} limit). Use offset=${nextOffset} to continue.]`;
}
details = { truncation };
} else if (
userLimitedLines !== undefined &&
startLine + userLimitedLines < allLines.length
) {
// User specified limit, there's more content, but no truncation
const remaining =
allLines.length - (startLine + userLimitedLines);
const nextOffset = startLine + userLimitedLines + 1;
outputText = truncation.content;
outputText += `\n\n[${remaining} more lines in file. Use offset=${nextOffset} to continue.]`;
} else {
// No truncation, no user limit exceeded
outputText = truncation.content;
}
content = [{ type: "text", text: outputText }];
}
// Check if aborted after reading
if (aborted) {
return;
}
// Clean up abort handler
if (signal) {
signal.removeEventListener("abort", onAbort);
}
resolve({ content, details });
} catch (error: any) {
// Clean up abort handler
if (signal) {
signal.removeEventListener("abort", onAbort);
}
if (!aborted) {
reject(error);
}
}
})();
});
},
};
}
/** Default read tool using process.cwd() - for backwards compatibility */
export const readTool = createReadTool(process.cwd());

View file

@ -0,0 +1,279 @@
/**
* Shared truncation utilities for tool outputs.
*
* Truncation is based on two independent limits - whichever is hit first wins:
* - Line limit (default: 2000 lines)
* - Byte limit (default: 50KB)
*
* Never returns partial lines (except bash tail truncation edge case).
*/
export const DEFAULT_MAX_LINES = 2000;
export const DEFAULT_MAX_BYTES = 50 * 1024; // 50KB
export const GREP_MAX_LINE_LENGTH = 500; // Max chars per grep match line
export interface TruncationResult {
/** The truncated content */
content: string;
/** Whether truncation occurred */
truncated: boolean;
/** Which limit was hit: "lines", "bytes", or null if not truncated */
truncatedBy: "lines" | "bytes" | null;
/** Total number of lines in the original content */
totalLines: number;
/** Total number of bytes in the original content */
totalBytes: number;
/** Number of complete lines in the truncated output */
outputLines: number;
/** Number of bytes in the truncated output */
outputBytes: number;
/** Whether the last line was partially truncated (only for tail truncation edge case) */
lastLinePartial: boolean;
/** Whether the first line exceeded the byte limit (for head truncation) */
firstLineExceedsLimit: boolean;
/** The max lines limit that was applied */
maxLines: number;
/** The max bytes limit that was applied */
maxBytes: number;
}
export interface TruncationOptions {
/** Maximum number of lines (default: 2000) */
maxLines?: number;
/** Maximum number of bytes (default: 50KB) */
maxBytes?: number;
}
/**
* Format bytes as human-readable size.
*/
export function formatSize(bytes: number): string {
if (bytes < 1024) {
return `${bytes}B`;
} else if (bytes < 1024 * 1024) {
return `${(bytes / 1024).toFixed(1)}KB`;
} else {
return `${(bytes / (1024 * 1024)).toFixed(1)}MB`;
}
}
/**
* Truncate content from the head (keep first N lines/bytes).
* Suitable for file reads where you want to see the beginning.
*
* Never returns partial lines. If first line exceeds byte limit,
* returns empty content with firstLineExceedsLimit=true.
*/
export function truncateHead(
content: string,
options: TruncationOptions = {},
): TruncationResult {
const maxLines = options.maxLines ?? DEFAULT_MAX_LINES;
const maxBytes = options.maxBytes ?? DEFAULT_MAX_BYTES;
const totalBytes = Buffer.byteLength(content, "utf-8");
const lines = content.split("\n");
const totalLines = lines.length;
// Check if no truncation needed
if (totalLines <= maxLines && totalBytes <= maxBytes) {
return {
content,
truncated: false,
truncatedBy: null,
totalLines,
totalBytes,
outputLines: totalLines,
outputBytes: totalBytes,
lastLinePartial: false,
firstLineExceedsLimit: false,
maxLines,
maxBytes,
};
}
// Check if first line alone exceeds byte limit
const firstLineBytes = Buffer.byteLength(lines[0], "utf-8");
if (firstLineBytes > maxBytes) {
return {
content: "",
truncated: true,
truncatedBy: "bytes",
totalLines,
totalBytes,
outputLines: 0,
outputBytes: 0,
lastLinePartial: false,
firstLineExceedsLimit: true,
maxLines,
maxBytes,
};
}
// Collect complete lines that fit
const outputLinesArr: string[] = [];
let outputBytesCount = 0;
let truncatedBy: "lines" | "bytes" = "lines";
for (let i = 0; i < lines.length && i < maxLines; i++) {
const line = lines[i];
const lineBytes = Buffer.byteLength(line, "utf-8") + (i > 0 ? 1 : 0); // +1 for newline
if (outputBytesCount + lineBytes > maxBytes) {
truncatedBy = "bytes";
break;
}
outputLinesArr.push(line);
outputBytesCount += lineBytes;
}
// If we exited due to line limit
if (outputLinesArr.length >= maxLines && outputBytesCount <= maxBytes) {
truncatedBy = "lines";
}
const outputContent = outputLinesArr.join("\n");
const finalOutputBytes = Buffer.byteLength(outputContent, "utf-8");
return {
content: outputContent,
truncated: true,
truncatedBy,
totalLines,
totalBytes,
outputLines: outputLinesArr.length,
outputBytes: finalOutputBytes,
lastLinePartial: false,
firstLineExceedsLimit: false,
maxLines,
maxBytes,
};
}
/**
* Truncate content from the tail (keep last N lines/bytes).
* Suitable for bash output where you want to see the end (errors, final results).
*
* May return partial first line if the last line of original content exceeds byte limit.
*/
export function truncateTail(
content: string,
options: TruncationOptions = {},
): TruncationResult {
const maxLines = options.maxLines ?? DEFAULT_MAX_LINES;
const maxBytes = options.maxBytes ?? DEFAULT_MAX_BYTES;
const totalBytes = Buffer.byteLength(content, "utf-8");
const lines = content.split("\n");
const totalLines = lines.length;
// Check if no truncation needed
if (totalLines <= maxLines && totalBytes <= maxBytes) {
return {
content,
truncated: false,
truncatedBy: null,
totalLines,
totalBytes,
outputLines: totalLines,
outputBytes: totalBytes,
lastLinePartial: false,
firstLineExceedsLimit: false,
maxLines,
maxBytes,
};
}
// Work backwards from the end
const outputLinesArr: string[] = [];
let outputBytesCount = 0;
let truncatedBy: "lines" | "bytes" = "lines";
let lastLinePartial = false;
for (
let i = lines.length - 1;
i >= 0 && outputLinesArr.length < maxLines;
i--
) {
const line = lines[i];
const lineBytes =
Buffer.byteLength(line, "utf-8") + (outputLinesArr.length > 0 ? 1 : 0); // +1 for newline
if (outputBytesCount + lineBytes > maxBytes) {
truncatedBy = "bytes";
// Edge case: if we haven't added ANY lines yet and this line exceeds maxBytes,
// take the end of the line (partial)
if (outputLinesArr.length === 0) {
const truncatedLine = truncateStringToBytesFromEnd(line, maxBytes);
outputLinesArr.unshift(truncatedLine);
outputBytesCount = Buffer.byteLength(truncatedLine, "utf-8");
lastLinePartial = true;
}
break;
}
outputLinesArr.unshift(line);
outputBytesCount += lineBytes;
}
// If we exited due to line limit
if (outputLinesArr.length >= maxLines && outputBytesCount <= maxBytes) {
truncatedBy = "lines";
}
const outputContent = outputLinesArr.join("\n");
const finalOutputBytes = Buffer.byteLength(outputContent, "utf-8");
return {
content: outputContent,
truncated: true,
truncatedBy,
totalLines,
totalBytes,
outputLines: outputLinesArr.length,
outputBytes: finalOutputBytes,
lastLinePartial,
firstLineExceedsLimit: false,
maxLines,
maxBytes,
};
}
/**
* Truncate a string to fit within a byte limit (from the end).
* Handles multi-byte UTF-8 characters correctly.
*/
function truncateStringToBytesFromEnd(str: string, maxBytes: number): string {
const buf = Buffer.from(str, "utf-8");
if (buf.length <= maxBytes) {
return str;
}
// Start from the end, skip maxBytes back
let start = buf.length - maxBytes;
// Find a valid UTF-8 boundary (start of a character)
while (start < buf.length && (buf[start] & 0xc0) === 0x80) {
start++;
}
return buf.slice(start).toString("utf-8");
}
/**
* Truncate a single line to max characters, adding [truncated] suffix.
* Used for grep match lines.
*/
export function truncateLine(
line: string,
maxChars: number = GREP_MAX_LINE_LENGTH,
): { text: string; wasTruncated: boolean } {
if (line.length <= maxChars) {
return { text: line, wasTruncated: false };
}
return {
text: `${line.slice(0, maxChars)}... [truncated]`,
wasTruncated: true,
};
}

View file

@ -0,0 +1,129 @@
import type { AgentTool } from "@mariozechner/pi-agent-core";
import { type Static, Type } from "@sinclair/typebox";
import { mkdir as fsMkdir, writeFile as fsWriteFile } from "fs/promises";
import { dirname } from "path";
import { resolveToCwd } from "./path-utils.js";
const writeSchema = Type.Object({
path: Type.String({
description: "Path to the file to write (relative or absolute)",
}),
content: Type.String({ description: "Content to write to the file" }),
});
export type WriteToolInput = Static<typeof writeSchema>;
/**
* Pluggable operations for the write tool.
* Override these to delegate file writing to remote systems (e.g., SSH).
*/
export interface WriteOperations {
/** Write content to a file */
writeFile: (absolutePath: string, content: string) => Promise<void>;
/** Create directory (recursively) */
mkdir: (dir: string) => Promise<void>;
}
const defaultWriteOperations: WriteOperations = {
writeFile: (path, content) => fsWriteFile(path, content, "utf-8"),
mkdir: (dir) => fsMkdir(dir, { recursive: true }).then(() => {}),
};
export interface WriteToolOptions {
/** Custom operations for file writing. Default: local filesystem */
operations?: WriteOperations;
}
export function createWriteTool(
cwd: string,
options?: WriteToolOptions,
): AgentTool<typeof writeSchema> {
const ops = options?.operations ?? defaultWriteOperations;
return {
name: "write",
label: "write",
description:
"Write content to a file. Creates the file if it doesn't exist, overwrites if it does. Automatically creates parent directories.",
parameters: writeSchema,
execute: async (
_toolCallId: string,
{ path, content }: { path: string; content: string },
signal?: AbortSignal,
) => {
const absolutePath = resolveToCwd(path, cwd);
const dir = dirname(absolutePath);
return new Promise<{
content: Array<{ type: "text"; text: string }>;
details: undefined;
}>((resolve, reject) => {
// Check if already aborted
if (signal?.aborted) {
reject(new Error("Operation aborted"));
return;
}
let aborted = false;
// Set up abort handler
const onAbort = () => {
aborted = true;
reject(new Error("Operation aborted"));
};
if (signal) {
signal.addEventListener("abort", onAbort, { once: true });
}
// Perform the write operation
(async () => {
try {
// Create parent directories if needed
await ops.mkdir(dir);
// Check if aborted before writing
if (aborted) {
return;
}
// Write the file
await ops.writeFile(absolutePath, content);
// Check if aborted after writing
if (aborted) {
return;
}
// Clean up abort handler
if (signal) {
signal.removeEventListener("abort", onAbort);
}
resolve({
content: [
{
type: "text",
text: `Successfully wrote ${content.length} bytes to ${path}`,
},
],
details: undefined,
});
} catch (error: any) {
// Clean up abort handler
if (signal) {
signal.removeEventListener("abort", onAbort);
}
if (!aborted) {
reject(error);
}
}
})();
});
},
};
}
/** Default write tool using process.cwd() - for backwards compatibility */
export const writeTool = createWriteTool(process.cwd());

View file

@ -0,0 +1,205 @@
import { randomUUID } from "node:crypto";
import type { ServerResponse } from "node:http";
import type { AgentSessionEvent } from "./agent-session.js";
/**
* Write a single Vercel AI SDK v5+ SSE chunk to the response.
* Format: `data: <JSON>\n\n`
* For the terminal [DONE] sentinel: `data: [DONE]\n\n`
*/
function writeChunk(response: ServerResponse, chunk: object | string): void {
if (response.writableEnded) return;
const payload = typeof chunk === "string" ? chunk : JSON.stringify(chunk);
response.write(`data: ${payload}\n\n`);
}
/**
* Extract the user's text from the request body.
* Supports both useChat format ({ messages: UIMessage[] }) and simple gateway format ({ text: string }).
*/
export function extractUserText(body: Record<string, unknown>): string | null {
// Simple gateway format
if (typeof body.text === "string" && body.text.trim()) {
return body.text;
}
// Convenience format
if (typeof body.prompt === "string" && body.prompt.trim()) {
return body.prompt;
}
// Vercel AI SDK useChat format - extract last user message
if (Array.isArray(body.messages)) {
for (let i = body.messages.length - 1; i >= 0; i--) {
const msg = body.messages[i] as Record<string, unknown>;
if (msg.role !== "user") continue;
// v5+ format with parts array
if (Array.isArray(msg.parts)) {
for (const part of msg.parts as Array<Record<string, unknown>>) {
if (part.type === "text" && typeof part.text === "string") {
return part.text;
}
}
}
// v4 format with content string
if (typeof msg.content === "string" && msg.content.trim()) {
return msg.content;
}
}
}
return null;
}
/**
* Create an AgentSessionEvent listener that translates events to Vercel AI SDK v5+ SSE
* chunks and writes them to the HTTP response.
*
* Returns the listener function. The caller is responsible for subscribing/unsubscribing.
*/
export function createVercelStreamListener(
response: ServerResponse,
messageId?: string,
): (event: AgentSessionEvent) => void {
// Gate: only forward events within a single prompt's agent_start -> agent_end lifecycle.
// handleChat now subscribes this listener immediately before the queued prompt starts,
// so these guards only need to bound the stream to that prompt's event span.
let active = false;
const msgId = messageId ?? randomUUID();
return (event: AgentSessionEvent) => {
if (response.writableEnded) return;
// Activate on our agent_start, deactivate on agent_end
if (event.type === "agent_start") {
if (!active) {
active = true;
writeChunk(response, { type: "start", messageId: msgId });
}
return;
}
if (event.type === "agent_end") {
active = false;
return;
}
// Drop events that don't belong to our message
if (!active) return;
switch (event.type) {
case "turn_start":
writeChunk(response, { type: "start-step" });
return;
case "message_update": {
const inner = event.assistantMessageEvent;
switch (inner.type) {
case "text_start":
writeChunk(response, {
type: "text-start",
id: `text_${inner.contentIndex}`,
});
return;
case "text_delta":
writeChunk(response, {
type: "text-delta",
id: `text_${inner.contentIndex}`,
delta: inner.delta,
});
return;
case "text_end":
writeChunk(response, {
type: "text-end",
id: `text_${inner.contentIndex}`,
});
return;
case "toolcall_start": {
const content = inner.partial.content[inner.contentIndex];
if (content?.type === "toolCall") {
writeChunk(response, {
type: "tool-input-start",
toolCallId: content.id,
toolName: content.name,
});
}
return;
}
case "toolcall_delta": {
const content = inner.partial.content[inner.contentIndex];
if (content?.type === "toolCall") {
writeChunk(response, {
type: "tool-input-delta",
toolCallId: content.id,
inputTextDelta: inner.delta,
});
}
return;
}
case "toolcall_end":
writeChunk(response, {
type: "tool-input-available",
toolCallId: inner.toolCall.id,
toolName: inner.toolCall.name,
input: inner.toolCall.arguments,
});
return;
case "thinking_start":
writeChunk(response, {
type: "reasoning-start",
id: `reasoning_${inner.contentIndex}`,
});
return;
case "thinking_delta":
writeChunk(response, {
type: "reasoning-delta",
id: `reasoning_${inner.contentIndex}`,
delta: inner.delta,
});
return;
case "thinking_end":
writeChunk(response, {
type: "reasoning-end",
id: `reasoning_${inner.contentIndex}`,
});
return;
}
return;
}
case "turn_end":
writeChunk(response, { type: "finish-step" });
return;
case "tool_execution_end":
writeChunk(response, {
type: "tool-output-available",
toolCallId: event.toolCallId,
output: event.result,
});
return;
}
};
}
/**
* Write the terminal finish sequence and end the response.
*/
export function finishVercelStream(
response: ServerResponse,
finishReason: string = "stop",
): void {
if (response.writableEnded) return;
writeChunk(response, { type: "finish", finishReason });
writeChunk(response, "[DONE]");
response.end();
}
/**
* Write an error chunk and end the response.
*/
export function errorVercelStream(
response: ServerResponse,
errorText: string,
): void {
if (response.writableEnded) return;
writeChunk(response, { type: "error", errorText });
writeChunk(response, "[DONE]");
response.end();
}

View file

@ -0,0 +1,353 @@
// Core session management
// Config paths
export { getAgentDir, VERSION } from "./config.js";
export {
AgentSession,
type AgentSessionConfig,
type AgentSessionEvent,
type AgentSessionEventListener,
type ModelCycleResult,
type ParsedSkillBlock,
type PromptOptions,
parseSkillBlock,
type SessionStats,
} from "./core/agent-session.js";
// Auth and model registry
export {
type ApiKeyCredential,
type AuthCredential,
AuthStorage,
type AuthStorageBackend,
FileAuthStorageBackend,
InMemoryAuthStorageBackend,
type OAuthCredential,
} from "./core/auth-storage.js";
// Compaction
export {
type BranchPreparation,
type BranchSummaryResult,
type CollectEntriesResult,
type CompactionResult,
type CutPointResult,
calculateContextTokens,
collectEntriesForBranchSummary,
compact,
DEFAULT_COMPACTION_SETTINGS,
estimateTokens,
type FileOperations,
findCutPoint,
findTurnStartIndex,
type GenerateBranchSummaryOptions,
generateBranchSummary,
generateSummary,
getLastAssistantUsage,
prepareBranchEntries,
serializeConversation,
shouldCompact,
} from "./core/compaction/index.js";
export {
createEventBus,
type EventBus,
type EventBusController,
} from "./core/event-bus.js";
// Extension system
export type {
AgentEndEvent,
AgentStartEvent,
AgentToolResult,
AgentToolUpdateCallback,
AppAction,
BashToolCallEvent,
BeforeAgentStartEvent,
CompactOptions,
ContextEvent,
ContextUsage,
CustomToolCallEvent,
EditToolCallEvent,
ExecOptions,
ExecResult,
Extension,
ExtensionActions,
ExtensionAPI,
ExtensionCommandContext,
ExtensionCommandContextActions,
ExtensionContext,
ExtensionContextActions,
ExtensionError,
ExtensionEvent,
ExtensionFactory,
ExtensionFlag,
ExtensionHandler,
ExtensionRuntime,
ExtensionShortcut,
ExtensionUIContext,
ExtensionUIDialogOptions,
ExtensionWidgetOptions,
FindToolCallEvent,
GrepToolCallEvent,
InputEvent,
InputEventResult,
InputSource,
KeybindingsManager,
LoadExtensionsResult,
LsToolCallEvent,
MessageRenderer,
MessageRenderOptions,
ProviderConfig,
ProviderModelConfig,
ReadToolCallEvent,
RegisteredCommand,
RegisteredTool,
SessionBeforeCompactEvent,
SessionBeforeForkEvent,
SessionBeforeSwitchEvent,
SessionBeforeTreeEvent,
SessionCompactEvent,
SessionForkEvent,
SessionShutdownEvent,
SessionStartEvent,
SessionSwitchEvent,
SessionTreeEvent,
SlashCommandInfo,
SlashCommandLocation,
SlashCommandSource,
TerminalInputHandler,
ToolCallEvent,
ToolDefinition,
ToolInfo,
ToolRenderResultOptions,
ToolResultEvent,
TurnEndEvent,
TurnStartEvent,
UserBashEvent,
UserBashEventResult,
WidgetPlacement,
WriteToolCallEvent,
} from "./core/extensions/index.js";
export {
createExtensionRuntime,
discoverAndLoadExtensions,
ExtensionRunner,
isBashToolResult,
isEditToolResult,
isFindToolResult,
isGrepToolResult,
isLsToolResult,
isReadToolResult,
isToolCallEventType,
isWriteToolResult,
wrapRegisteredTool,
wrapRegisteredTools,
wrapToolsWithExtensions,
wrapToolWithExtensions,
} from "./core/extensions/index.js";
// Footer data provider (git branch + extension statuses - data not otherwise available to extensions)
export type { ReadonlyFooterDataProvider } from "./core/footer-data-provider.js";
export {
createGatewaySessionManager,
type GatewayConfig,
type GatewayMessageRequest,
type GatewayMessageResult,
GatewayRuntime,
type GatewayRuntimeOptions,
type GatewaySessionFactory,
type GatewaySessionSnapshot,
getActiveGatewayRuntime,
sanitizeSessionKey,
setActiveGatewayRuntime,
} from "./core/gateway-runtime.js";
export { convertToLlm } from "./core/messages.js";
export { ModelRegistry } from "./core/model-registry.js";
export type {
PackageManager,
PathMetadata,
ProgressCallback,
ProgressEvent,
ResolvedPaths,
ResolvedResource,
} from "./core/package-manager.js";
export { DefaultPackageManager } from "./core/package-manager.js";
export type {
ResourceCollision,
ResourceDiagnostic,
ResourceLoader,
} from "./core/resource-loader.js";
export { DefaultResourceLoader } from "./core/resource-loader.js";
// SDK for programmatic usage
export {
type CreateAgentSessionOptions,
type CreateAgentSessionResult,
// Factory
createAgentSession,
createBashTool,
// Tool factories (for custom cwd)
createCodingTools,
createEditTool,
createFindTool,
createGrepTool,
createLsTool,
createReadOnlyTools,
createReadTool,
createWriteTool,
type PromptTemplate,
// Pre-built tools (use process.cwd())
readOnlyTools,
} from "./core/sdk.js";
export {
type BranchSummaryEntry,
buildSessionContext,
type CompactionEntry,
CURRENT_SESSION_VERSION,
type CustomEntry,
type CustomMessageEntry,
type FileEntry,
getLatestCompactionEntry,
type ModelChangeEntry,
migrateSessionEntries,
type NewSessionOptions,
parseSessionEntries,
type SessionContext,
type SessionEntry,
type SessionEntryBase,
type SessionHeader,
type SessionInfo,
type SessionInfoEntry,
SessionManager,
type SessionMessageEntry,
type ThinkingLevelChangeEntry,
} from "./core/session-manager.js";
export {
type CompactionSettings,
type GatewaySettings,
type ImageSettings,
type PackageSource,
type RetrySettings,
SettingsManager,
} from "./core/settings-manager.js";
// Skills
export {
formatSkillsForPrompt,
type LoadSkillsFromDirOptions,
type LoadSkillsResult,
loadSkills,
loadSkillsFromDir,
type Skill,
type SkillFrontmatter,
} from "./core/skills.js";
// Tools
export {
type BashOperations,
type BashSpawnContext,
type BashSpawnHook,
type BashToolDetails,
type BashToolInput,
type BashToolOptions,
bashTool,
codingTools,
DEFAULT_MAX_BYTES,
DEFAULT_MAX_LINES,
type EditOperations,
type EditToolDetails,
type EditToolInput,
type EditToolOptions,
editTool,
type FindOperations,
type FindToolDetails,
type FindToolInput,
type FindToolOptions,
findTool,
formatSize,
type GrepOperations,
type GrepToolDetails,
type GrepToolInput,
type GrepToolOptions,
grepTool,
type LsOperations,
type LsToolDetails,
type LsToolInput,
type LsToolOptions,
lsTool,
type ReadOperations,
type ReadToolDetails,
type ReadToolInput,
type ReadToolOptions,
readTool,
type ToolsOptions,
type TruncationOptions,
type TruncationResult,
truncateHead,
truncateLine,
truncateTail,
type WriteOperations,
type WriteToolInput,
type WriteToolOptions,
writeTool,
} from "./core/tools/index.js";
// Main entry point
export { main } from "./main.js";
// Run modes for programmatic SDK usage
export {
InteractiveMode,
type InteractiveModeOptions,
type PrintModeOptions,
runPrintMode,
runRpcMode,
} from "./modes/index.js";
// UI components for extensions
export {
ArminComponent,
AssistantMessageComponent,
appKey,
appKeyHint,
BashExecutionComponent,
BorderedLoader,
BranchSummaryMessageComponent,
CompactionSummaryMessageComponent,
CustomEditor,
CustomMessageComponent,
DynamicBorder,
ExtensionEditorComponent,
ExtensionInputComponent,
ExtensionSelectorComponent,
editorKey,
FooterComponent,
keyHint,
LoginDialogComponent,
ModelSelectorComponent,
OAuthSelectorComponent,
type RenderDiffOptions,
rawKeyHint,
renderDiff,
SessionSelectorComponent,
type SettingsCallbacks,
type SettingsConfig,
SettingsSelectorComponent,
ShowImagesSelectorComponent,
SkillInvocationMessageComponent,
ThemeSelectorComponent,
ThinkingSelectorComponent,
ToolExecutionComponent,
type ToolExecutionOptions,
TreeSelectorComponent,
truncateToVisualLines,
UserMessageComponent,
UserMessageSelectorComponent,
type VisualTruncateResult,
} from "./modes/interactive/components/index.js";
// Theme utilities for custom tools and extensions
export {
getLanguageFromPath,
getMarkdownTheme,
getSelectListTheme,
getSettingsListTheme,
highlightCode,
initTheme,
Theme,
type ThemeColor,
} from "./modes/interactive/theme/theme.js";
// Clipboard utilities
export { copyToClipboard } from "./utils/clipboard.js";
export { parseFrontmatter, stripFrontmatter } from "./utils/frontmatter.js";
// Shell utilities
export { getShellConfig } from "./utils/shell.js";

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,317 @@
/**
* One-time migrations that run on startup.
*/
import chalk from "chalk";
import {
existsSync,
mkdirSync,
readdirSync,
readFileSync,
renameSync,
rmSync,
writeFileSync,
} from "fs";
import { dirname, join } from "path";
import { CONFIG_DIR_NAME, getAgentDir, getBinDir } from "./config.js";
const MIGRATION_GUIDE_URL =
"https://github.com/badlogic/pi-mono/blob/main/packages/coding-agent/CHANGELOG.md#extensions-migration";
const EXTENSIONS_DOC_URL =
"https://github.com/badlogic/pi-mono/blob/main/packages/coding-agent/docs/extensions.md";
/**
* Migrate legacy oauth.json and settings.json apiKeys to auth.json.
*
* @returns Array of provider names that were migrated
*/
export function migrateAuthToAuthJson(): string[] {
const agentDir = getAgentDir();
const authPath = join(agentDir, "auth.json");
const oauthPath = join(agentDir, "oauth.json");
const settingsPath = join(agentDir, "settings.json");
// Skip if auth.json already exists
if (existsSync(authPath)) return [];
const migrated: Record<string, unknown> = {};
const providers: string[] = [];
// Migrate oauth.json
if (existsSync(oauthPath)) {
try {
const oauth = JSON.parse(readFileSync(oauthPath, "utf-8"));
for (const [provider, cred] of Object.entries(oauth)) {
migrated[provider] = { type: "oauth", ...(cred as object) };
providers.push(provider);
}
renameSync(oauthPath, `${oauthPath}.migrated`);
} catch {
// Skip on error
}
}
// Migrate settings.json apiKeys
if (existsSync(settingsPath)) {
try {
const content = readFileSync(settingsPath, "utf-8");
const settings = JSON.parse(content);
if (settings.apiKeys && typeof settings.apiKeys === "object") {
for (const [provider, key] of Object.entries(settings.apiKeys)) {
if (!migrated[provider] && typeof key === "string") {
migrated[provider] = { type: "api_key", key };
providers.push(provider);
}
}
delete settings.apiKeys;
writeFileSync(settingsPath, JSON.stringify(settings, null, 2));
}
} catch {
// Skip on error
}
}
if (Object.keys(migrated).length > 0) {
mkdirSync(dirname(authPath), { recursive: true });
writeFileSync(authPath, JSON.stringify(migrated, null, 2), { mode: 0o600 });
}
return providers;
}
/**
* Migrate sessions from ~/.pi/agent/*.jsonl to proper session directories.
*
* Bug in v0.30.0: Sessions were saved to ~/.pi/agent/ instead of
* ~/.pi/agent/sessions/<encoded-cwd>/. This migration moves them
* to the correct location based on the cwd in their session header.
*
* See: https://github.com/badlogic/pi-mono/issues/320
*/
export function migrateSessionsFromAgentRoot(): void {
const agentDir = getAgentDir();
// Find all .jsonl files directly in agentDir (not in subdirectories)
let files: string[];
try {
files = readdirSync(agentDir)
.filter((f) => f.endsWith(".jsonl"))
.map((f) => join(agentDir, f));
} catch {
return;
}
if (files.length === 0) return;
for (const file of files) {
try {
// Read first line to get session header
const content = readFileSync(file, "utf8");
const firstLine = content.split("\n")[0];
if (!firstLine?.trim()) continue;
const header = JSON.parse(firstLine);
if (header.type !== "session" || !header.cwd) continue;
const cwd: string = header.cwd;
// Compute the correct session directory (same encoding as session-manager.ts)
const safePath = `--${cwd.replace(/^[/\\]/, "").replace(/[/\\:]/g, "-")}--`;
const correctDir = join(agentDir, "sessions", safePath);
// Create directory if needed
if (!existsSync(correctDir)) {
mkdirSync(correctDir, { recursive: true });
}
// Move the file
const fileName = file.split("/").pop() || file.split("\\").pop();
const newPath = join(correctDir, fileName!);
if (existsSync(newPath)) continue; // Skip if target exists
renameSync(file, newPath);
} catch {
// Skip files that can't be migrated
}
}
}
/**
* Migrate commands/ to prompts/ if needed.
* Works for both regular directories and symlinks.
*/
function migrateCommandsToPrompts(baseDir: string, label: string): boolean {
const commandsDir = join(baseDir, "commands");
const promptsDir = join(baseDir, "prompts");
if (existsSync(commandsDir) && !existsSync(promptsDir)) {
try {
renameSync(commandsDir, promptsDir);
console.log(chalk.green(`Migrated ${label} commands/ → prompts/`));
return true;
} catch (err) {
console.log(
chalk.yellow(
`Warning: Could not migrate ${label} commands/ to prompts/: ${err instanceof Error ? err.message : err}`,
),
);
}
}
return false;
}
/**
* Move fd/rg binaries from tools/ to bin/ if they exist.
*/
function migrateToolsToBin(): void {
const agentDir = getAgentDir();
const toolsDir = join(agentDir, "tools");
const binDir = getBinDir();
if (!existsSync(toolsDir)) return;
const binaries = ["fd", "rg", "fd.exe", "rg.exe"];
let movedAny = false;
for (const bin of binaries) {
const oldPath = join(toolsDir, bin);
const newPath = join(binDir, bin);
if (existsSync(oldPath)) {
if (!existsSync(binDir)) {
mkdirSync(binDir, { recursive: true });
}
if (!existsSync(newPath)) {
try {
renameSync(oldPath, newPath);
movedAny = true;
} catch {
// Ignore errors
}
} else {
// Target exists, just delete the old one
try {
rmSync?.(oldPath, { force: true });
} catch {
// Ignore
}
}
}
}
if (movedAny) {
console.log(chalk.green(`Migrated managed binaries tools/ → bin/`));
}
}
/**
* Check for deprecated hooks/ and tools/ directories.
* Note: tools/ may contain fd/rg binaries extracted by pi, so only warn if it has other files.
*/
function checkDeprecatedExtensionDirs(
baseDir: string,
label: string,
): string[] {
const hooksDir = join(baseDir, "hooks");
const toolsDir = join(baseDir, "tools");
const warnings: string[] = [];
if (existsSync(hooksDir)) {
warnings.push(
`${label} hooks/ directory found. Hooks have been renamed to extensions.`,
);
}
if (existsSync(toolsDir)) {
// Check if tools/ contains anything other than fd/rg (which are auto-extracted binaries)
try {
const entries = readdirSync(toolsDir);
const customTools = entries.filter((e) => {
const lower = e.toLowerCase();
return (
lower !== "fd" &&
lower !== "rg" &&
lower !== "fd.exe" &&
lower !== "rg.exe" &&
!e.startsWith(".") // Ignore .DS_Store and other hidden files
);
});
if (customTools.length > 0) {
warnings.push(
`${label} tools/ directory contains custom tools. Custom tools have been merged into extensions.`,
);
}
} catch {
// Ignore read errors
}
}
return warnings;
}
/**
* Run extension system migrations (commandsprompts) and collect warnings about deprecated directories.
*/
function migrateExtensionSystem(cwd: string): string[] {
const agentDir = getAgentDir();
const projectDir = join(cwd, CONFIG_DIR_NAME);
// Migrate commands/ to prompts/
migrateCommandsToPrompts(agentDir, "Global");
migrateCommandsToPrompts(projectDir, "Project");
// Check for deprecated directories
const warnings = [
...checkDeprecatedExtensionDirs(agentDir, "Global"),
...checkDeprecatedExtensionDirs(projectDir, "Project"),
];
return warnings;
}
/**
* Print deprecation warnings and wait for keypress.
*/
export async function showDeprecationWarnings(
warnings: string[],
): Promise<void> {
if (warnings.length === 0) return;
for (const warning of warnings) {
console.log(chalk.yellow(`Warning: ${warning}`));
}
console.log(
chalk.yellow(`\nMove your extensions to the extensions/ directory.`),
);
console.log(chalk.yellow(`Migration guide: ${MIGRATION_GUIDE_URL}`));
console.log(chalk.yellow(`Documentation: ${EXTENSIONS_DOC_URL}`));
console.log(chalk.dim(`\nPress any key to continue...`));
await new Promise<void>((resolve) => {
process.stdin.setRawMode?.(true);
process.stdin.resume();
process.stdin.once("data", () => {
process.stdin.setRawMode?.(false);
process.stdin.pause();
resolve();
});
});
console.log();
}
/**
* Run all migrations. Called once on startup.
*
* @returns Object with migration results and deprecation warnings
*/
export function runMigrations(cwd: string = process.cwd()): {
migratedAuthProviders: string[];
deprecationWarnings: string[];
} {
const migratedAuthProviders = migrateAuthToAuthJson();
migrateSessionsFromAgentRoot();
migrateToolsToBin();
const deprecationWarnings = migrateExtensionSystem(cwd);
return { migratedAuthProviders, deprecationWarnings };
}

View file

@ -0,0 +1,233 @@
/**
* Daemon mode (always-on background execution).
*
* Starts agent extensions, accepts messages from extension sources
* (webhooks, queues, Telegram/Slack gateways, etc.), and stays alive
* until explicitly stopped.
*/
import type { ImageContent } from "@mariozechner/pi-ai";
import type { AgentSession } from "../core/agent-session.js";
import {
GatewayRuntime,
type GatewaySessionFactory,
setActiveGatewayRuntime,
} from "../core/gateway-runtime.js";
import type { GatewaySettings } from "../core/settings-manager.js";
/**
* Options for daemon mode.
*/
export interface DaemonModeOptions {
/** First message to send at startup (can include @file content expansion by caller). */
initialMessage?: string;
/** Images to attach to the startup message. */
initialImages?: ImageContent[];
/** Additional startup messages (sent after initialMessage, one by one). */
messages?: string[];
/** Factory for creating additional gateway-owned sessions. */
createSession: GatewaySessionFactory;
/** Gateway config from settings/env. */
gateway: GatewaySettings;
}
export interface DaemonModeResult {
reason: "shutdown";
}
function createCommandContextActions(session: AgentSession) {
return {
waitForIdle: () => session.agent.waitForIdle(),
newSession: async (options?: {
parentSession?: string;
setup?: (
sessionManager: typeof session.sessionManager,
) => Promise<void> | void;
}) => {
const success = await session.newSession({
parentSession: options?.parentSession,
});
if (success && options?.setup) {
await options.setup(session.sessionManager);
}
return { cancelled: !success };
},
fork: async (entryId: string) => {
const result = await session.fork(entryId);
return { cancelled: result.cancelled };
},
navigateTree: async (
targetId: string,
options?: {
summarize?: boolean;
customInstructions?: string;
replaceInstructions?: boolean;
label?: string;
},
) => {
const result = await session.navigateTree(targetId, {
summarize: options?.summarize,
customInstructions: options?.customInstructions,
replaceInstructions: options?.replaceInstructions,
label: options?.label,
});
return { cancelled: result.cancelled };
},
switchSession: async (sessionPath: string) => {
const success = await session.switchSession(sessionPath);
return { cancelled: !success };
},
reload: async () => {
await session.reload();
},
};
}
/**
* Run in daemon mode.
* Stays alive indefinitely unless stopped by signal or extension trigger.
*/
export async function runDaemonMode(
session: AgentSession,
options: DaemonModeOptions,
): Promise<DaemonModeResult> {
const { initialMessage, initialImages, messages = [] } = options;
let isShuttingDown = false;
let resolveReady: (result: DaemonModeResult) => void = () => {};
const ready = new Promise<DaemonModeResult>((resolve) => {
resolveReady = resolve;
});
const gatewayBind =
process.env.PI_GATEWAY_BIND ?? options.gateway.bind ?? "127.0.0.1";
const gatewayPort =
Number.parseInt(process.env.PI_GATEWAY_PORT ?? "", 10) ||
options.gateway.port ||
8787;
const gatewayToken =
process.env.PI_GATEWAY_TOKEN ?? options.gateway.bearerToken;
const gateway = new GatewayRuntime({
config: {
bind: gatewayBind,
port: gatewayPort,
bearerToken: gatewayToken,
session: {
idleMinutes: options.gateway.session?.idleMinutes ?? 60,
maxQueuePerSession: options.gateway.session?.maxQueuePerSession ?? 8,
},
webhook: {
enabled: options.gateway.webhook?.enabled ?? true,
basePath: options.gateway.webhook?.basePath ?? "/webhooks",
secret:
process.env.PI_GATEWAY_WEBHOOK_SECRET ??
options.gateway.webhook?.secret,
},
},
primarySessionKey: "web:main",
primarySession: session,
createSession: options.createSession,
log: (message) => {
console.error(`[pi-gateway] ${message}`);
},
});
setActiveGatewayRuntime(gateway);
const shutdown = async (reason: "signal" | "extension"): Promise<void> => {
if (isShuttingDown) return;
isShuttingDown = true;
console.error(`[pi-gateway] shutdown requested: ${reason}`);
setActiveGatewayRuntime(null);
await gateway.stop();
const runner = session.extensionRunner;
if (runner?.hasHandlers("session_shutdown")) {
await runner.emit({ type: "session_shutdown" });
}
session.dispose();
resolveReady({ reason: "shutdown" });
};
const handleShutdownSignal = (signal: NodeJS.Signals) => {
void shutdown("signal").catch((error) => {
console.error(
`[pi-gateway] shutdown failed for ${signal}: ${error instanceof Error ? error.message : String(error)}`,
);
resolveReady({ reason: "shutdown" });
});
};
const sigintHandler = () => handleShutdownSignal("SIGINT");
const sigtermHandler = () => handleShutdownSignal("SIGTERM");
const sigquitHandler = () => handleShutdownSignal("SIGQUIT");
const sighupHandler = () => handleShutdownSignal("SIGHUP");
const unhandledRejectionHandler = (error: unknown) => {
console.error(
`[pi-gateway] unhandled rejection: ${error instanceof Error ? error.message : String(error)}`,
);
};
process.once("SIGINT", sigintHandler);
process.once("SIGTERM", sigtermHandler);
process.once("SIGQUIT", sigquitHandler);
process.once("SIGHUP", sighupHandler);
process.on("unhandledRejection", unhandledRejectionHandler);
await session.bindExtensions({
commandContextActions: createCommandContextActions(session),
shutdownHandler: () => {
void shutdown("extension").catch((error) => {
console.error(
`[pi-gateway] extension shutdown failed: ${error instanceof Error ? error.message : String(error)}`,
);
resolveReady({ reason: "shutdown" });
});
},
onError: (err) => {
console.error(`Extension error (${err.extensionPath}): ${err.error}`);
},
});
// Emit structured events to stderr for supervisor logs.
session.subscribe((event) => {
console.error(
JSON.stringify({
type: event.type,
sessionId: session.sessionId,
messageCount: session.messages.length,
}),
);
});
// Startup probes/messages.
if (initialMessage) {
await session.prompt(initialMessage, { images: initialImages });
}
for (const message of messages) {
await session.prompt(message);
}
await gateway.start();
console.error(
`[pi-gateway] startup complete (session=${session.sessionId ?? "unknown"}, bind=${gatewayBind}, port=${gatewayPort})`,
);
// Keep process alive forever.
const keepAlive = setInterval(() => {
// Intentionally keep the daemon event loop active.
}, 1000);
const cleanup = () => {
clearInterval(keepAlive);
process.removeListener("SIGINT", sigintHandler);
process.removeListener("SIGTERM", sigtermHandler);
process.removeListener("SIGQUIT", sigquitHandler);
process.removeListener("SIGHUP", sighupHandler);
process.removeListener("unhandledRejection", unhandledRejectionHandler);
};
try {
return await ready;
} finally {
cleanup();
}
}

View file

@ -0,0 +1,26 @@
/**
* Run modes for the coding agent.
*/
export {
type DaemonModeOptions,
type DaemonModeResult,
runDaemonMode,
} from "./daemon-mode.js";
export {
InteractiveMode,
type InteractiveModeOptions,
} from "./interactive/interactive-mode.js";
export { type PrintModeOptions, runPrintMode } from "./print-mode.js";
export {
type ModelInfo,
RpcClient,
type RpcClientOptions,
type RpcEventListener,
} from "./rpc/rpc-client.js";
export { runRpcMode } from "./rpc/rpc-mode.js";
export type {
RpcCommand,
RpcResponse,
RpcSessionState,
} from "./rpc/rpc-types.js";

View file

@ -0,0 +1,422 @@
/**
* Armin says hi! A fun easter egg with animated XBM art.
*/
import type { Component, TUI } from "@mariozechner/pi-tui";
import { theme } from "../theme/theme.js";
// XBM image: 31x36 pixels, LSB first, 1=background, 0=foreground
const WIDTH = 31;
const HEIGHT = 36;
const BITS = [
0xff, 0xff, 0xff, 0x7f, 0xff, 0xf0, 0xff, 0x7f, 0xff, 0xed, 0xff, 0x7f, 0xff,
0xdb, 0xff, 0x7f, 0xff, 0xb7, 0xff, 0x7f, 0xff, 0x77, 0xfe, 0x7f, 0x3f, 0xf8,
0xfe, 0x7f, 0xdf, 0xff, 0xfe, 0x7f, 0xdf, 0x3f, 0xfc, 0x7f, 0x9f, 0xc3, 0xfb,
0x7f, 0x6f, 0xfc, 0xf4, 0x7f, 0xf7, 0x0f, 0xf7, 0x7f, 0xf7, 0xff, 0xf7, 0x7f,
0xf7, 0xff, 0xe3, 0x7f, 0xf7, 0x07, 0xe8, 0x7f, 0xef, 0xf8, 0x67, 0x70, 0x0f,
0xff, 0xbb, 0x6f, 0xf1, 0x00, 0xd0, 0x5b, 0xfd, 0x3f, 0xec, 0x53, 0xc1, 0xff,
0xef, 0x57, 0x9f, 0xfd, 0xee, 0x5f, 0x9f, 0xfc, 0xae, 0x5f, 0x1f, 0x78, 0xac,
0x5f, 0x3f, 0x00, 0x50, 0x6c, 0x7f, 0x00, 0xdc, 0x77, 0xff, 0xc0, 0x3f, 0x78,
0xff, 0x01, 0xf8, 0x7f, 0xff, 0x03, 0x9c, 0x78, 0xff, 0x07, 0x8c, 0x7c, 0xff,
0x0f, 0xce, 0x78, 0xff, 0xff, 0xcf, 0x7f, 0xff, 0xff, 0xcf, 0x78, 0xff, 0xff,
0xdf, 0x78, 0xff, 0xff, 0xdf, 0x7d, 0xff, 0xff, 0x3f, 0x7e, 0xff, 0xff, 0xff,
0x7f,
];
const BYTES_PER_ROW = Math.ceil(WIDTH / 8);
const DISPLAY_HEIGHT = Math.ceil(HEIGHT / 2); // Half-block rendering
type Effect =
| "typewriter"
| "scanline"
| "rain"
| "fade"
| "crt"
| "glitch"
| "dissolve";
const EFFECTS: Effect[] = [
"typewriter",
"scanline",
"rain",
"fade",
"crt",
"glitch",
"dissolve",
];
// Get pixel at (x, y): true = foreground, false = background
function getPixel(x: number, y: number): boolean {
if (y >= HEIGHT) return false;
const byteIndex = y * BYTES_PER_ROW + Math.floor(x / 8);
const bitIndex = x % 8;
return ((BITS[byteIndex] >> bitIndex) & 1) === 0;
}
// Get the character for a cell (2 vertical pixels packed)
function getChar(x: number, row: number): string {
const upper = getPixel(x, row * 2);
const lower = getPixel(x, row * 2 + 1);
if (upper && lower) return "█";
if (upper) return "▀";
if (lower) return "▄";
return " ";
}
// Build the final image grid
function buildFinalGrid(): string[][] {
const grid: string[][] = [];
for (let row = 0; row < DISPLAY_HEIGHT; row++) {
const line: string[] = [];
for (let x = 0; x < WIDTH; x++) {
line.push(getChar(x, row));
}
grid.push(line);
}
return grid;
}
export class ArminComponent implements Component {
private ui: TUI;
private interval: ReturnType<typeof setInterval> | null = null;
private effect: Effect;
private finalGrid: string[][];
private currentGrid: string[][];
private effectState: Record<string, unknown> = {};
private cachedLines: string[] = [];
private cachedWidth = 0;
private gridVersion = 0;
private cachedVersion = -1;
constructor(ui: TUI) {
this.ui = ui;
this.effect = EFFECTS[Math.floor(Math.random() * EFFECTS.length)];
this.finalGrid = buildFinalGrid();
this.currentGrid = this.createEmptyGrid();
this.initEffect();
this.startAnimation();
}
invalidate(): void {
this.cachedWidth = 0;
}
render(width: number): string[] {
if (width === this.cachedWidth && this.cachedVersion === this.gridVersion) {
return this.cachedLines;
}
const padding = 1;
const availableWidth = width - padding;
this.cachedLines = this.currentGrid.map((row) => {
// Clip row to available width before applying color
const clipped = row.slice(0, availableWidth).join("");
const padRight = Math.max(0, width - padding - clipped.length);
return ` ${theme.fg("accent", clipped)}${" ".repeat(padRight)}`;
});
// Add "ARMIN SAYS HI" at the end
const message = "ARMIN SAYS HI";
const msgPadRight = Math.max(0, width - padding - message.length);
this.cachedLines.push(
` ${theme.fg("accent", message)}${" ".repeat(msgPadRight)}`,
);
this.cachedWidth = width;
this.cachedVersion = this.gridVersion;
return this.cachedLines;
}
private createEmptyGrid(): string[][] {
return Array.from({ length: DISPLAY_HEIGHT }, () => Array(WIDTH).fill(" "));
}
private initEffect(): void {
switch (this.effect) {
case "typewriter":
this.effectState = { pos: 0 };
break;
case "scanline":
this.effectState = { row: 0 };
break;
case "rain":
// Track falling position for each column
this.effectState = {
drops: Array.from({ length: WIDTH }, () => ({
y: -Math.floor(Math.random() * DISPLAY_HEIGHT * 2),
settled: 0,
})),
};
break;
case "fade": {
// Shuffle all pixel positions
const positions: [number, number][] = [];
for (let row = 0; row < DISPLAY_HEIGHT; row++) {
for (let x = 0; x < WIDTH; x++) {
positions.push([row, x]);
}
}
// Fisher-Yates shuffle
for (let i = positions.length - 1; i > 0; i--) {
const j = Math.floor(Math.random() * (i + 1));
[positions[i], positions[j]] = [positions[j], positions[i]];
}
this.effectState = { positions, idx: 0 };
break;
}
case "crt":
this.effectState = { expansion: 0 };
break;
case "glitch":
this.effectState = { phase: 0, glitchFrames: 8 };
break;
case "dissolve": {
// Start with random noise
this.currentGrid = Array.from({ length: DISPLAY_HEIGHT }, () =>
Array.from({ length: WIDTH }, () => {
const chars = [" ", "░", "▒", "▓", "█", "▀", "▄"];
return chars[Math.floor(Math.random() * chars.length)];
}),
);
// Shuffle positions for gradual resolve
const dissolvePositions: [number, number][] = [];
for (let row = 0; row < DISPLAY_HEIGHT; row++) {
for (let x = 0; x < WIDTH; x++) {
dissolvePositions.push([row, x]);
}
}
for (let i = dissolvePositions.length - 1; i > 0; i--) {
const j = Math.floor(Math.random() * (i + 1));
[dissolvePositions[i], dissolvePositions[j]] = [
dissolvePositions[j],
dissolvePositions[i],
];
}
this.effectState = { positions: dissolvePositions, idx: 0 };
break;
}
}
}
private startAnimation(): void {
const fps = this.effect === "glitch" ? 60 : 30;
this.interval = setInterval(() => {
const done = this.tickEffect();
this.updateDisplay();
this.ui.requestRender();
if (done) {
this.stopAnimation();
}
}, 1000 / fps);
}
private stopAnimation(): void {
if (this.interval) {
clearInterval(this.interval);
this.interval = null;
}
}
private tickEffect(): boolean {
switch (this.effect) {
case "typewriter":
return this.tickTypewriter();
case "scanline":
return this.tickScanline();
case "rain":
return this.tickRain();
case "fade":
return this.tickFade();
case "crt":
return this.tickCrt();
case "glitch":
return this.tickGlitch();
case "dissolve":
return this.tickDissolve();
default:
return true;
}
}
private tickTypewriter(): boolean {
const state = this.effectState as { pos: number };
const pixelsPerFrame = 3;
for (let i = 0; i < pixelsPerFrame; i++) {
const row = Math.floor(state.pos / WIDTH);
const x = state.pos % WIDTH;
if (row >= DISPLAY_HEIGHT) return true;
this.currentGrid[row][x] = this.finalGrid[row][x];
state.pos++;
}
return false;
}
private tickScanline(): boolean {
const state = this.effectState as { row: number };
if (state.row >= DISPLAY_HEIGHT) return true;
// Copy row
for (let x = 0; x < WIDTH; x++) {
this.currentGrid[state.row][x] = this.finalGrid[state.row][x];
}
state.row++;
return false;
}
private tickRain(): boolean {
const state = this.effectState as {
drops: { y: number; settled: number }[];
};
let allSettled = true;
this.currentGrid = this.createEmptyGrid();
for (let x = 0; x < WIDTH; x++) {
const drop = state.drops[x];
// Draw settled pixels
for (
let row = DISPLAY_HEIGHT - 1;
row >= DISPLAY_HEIGHT - drop.settled;
row--
) {
if (row >= 0) {
this.currentGrid[row][x] = this.finalGrid[row][x];
}
}
// Check if this column is done
if (drop.settled >= DISPLAY_HEIGHT) continue;
allSettled = false;
// Find the target row for this column (lowest non-space pixel)
let targetRow = -1;
for (let row = DISPLAY_HEIGHT - 1 - drop.settled; row >= 0; row--) {
if (this.finalGrid[row][x] !== " ") {
targetRow = row;
break;
}
}
// Move drop down
drop.y++;
// Draw falling drop
if (drop.y >= 0 && drop.y < DISPLAY_HEIGHT) {
if (targetRow >= 0 && drop.y >= targetRow) {
// Settle
drop.settled = DISPLAY_HEIGHT - targetRow;
drop.y = -Math.floor(Math.random() * 5) - 1;
} else {
// Still falling
this.currentGrid[drop.y][x] = "▓";
}
}
}
return allSettled;
}
private tickFade(): boolean {
const state = this.effectState as {
positions: [number, number][];
idx: number;
};
const pixelsPerFrame = 15;
for (let i = 0; i < pixelsPerFrame; i++) {
if (state.idx >= state.positions.length) return true;
const [row, x] = state.positions[state.idx];
this.currentGrid[row][x] = this.finalGrid[row][x];
state.idx++;
}
return false;
}
private tickCrt(): boolean {
const state = this.effectState as { expansion: number };
const midRow = Math.floor(DISPLAY_HEIGHT / 2);
this.currentGrid = this.createEmptyGrid();
// Draw from middle expanding outward
const top = midRow - state.expansion;
const bottom = midRow + state.expansion;
for (
let row = Math.max(0, top);
row <= Math.min(DISPLAY_HEIGHT - 1, bottom);
row++
) {
for (let x = 0; x < WIDTH; x++) {
this.currentGrid[row][x] = this.finalGrid[row][x];
}
}
state.expansion++;
return state.expansion > DISPLAY_HEIGHT;
}
private tickGlitch(): boolean {
const state = this.effectState as { phase: number; glitchFrames: number };
if (state.phase < state.glitchFrames) {
// Glitch phase: show corrupted version
this.currentGrid = this.finalGrid.map((row) => {
const offset = Math.floor(Math.random() * 7) - 3;
const glitchRow = [...row];
// Random horizontal offset
if (Math.random() < 0.3) {
const shifted = glitchRow
.slice(offset)
.concat(glitchRow.slice(0, offset));
return shifted.slice(0, WIDTH);
}
// Random vertical swap
if (Math.random() < 0.2) {
const swapRow = Math.floor(Math.random() * DISPLAY_HEIGHT);
return [...this.finalGrid[swapRow]];
}
return glitchRow;
});
state.phase++;
return false;
}
// Final frame: show clean image
this.currentGrid = this.finalGrid.map((row) => [...row]);
return true;
}
private tickDissolve(): boolean {
const state = this.effectState as {
positions: [number, number][];
idx: number;
};
const pixelsPerFrame = 20;
for (let i = 0; i < pixelsPerFrame; i++) {
if (state.idx >= state.positions.length) return true;
const [row, x] = state.positions[state.idx];
this.currentGrid[row][x] = this.finalGrid[row][x];
state.idx++;
}
return false;
}
private updateDisplay(): void {
this.gridVersion++;
}
dispose(): void {
this.stopAnimation();
}
}

View file

@ -0,0 +1,139 @@
import type { AssistantMessage } from "@mariozechner/pi-ai";
import {
Container,
Markdown,
type MarkdownTheme,
Spacer,
Text,
} from "@mariozechner/pi-tui";
import { getMarkdownTheme, theme } from "../theme/theme.js";
/**
* Component that renders a complete assistant message
*/
export class AssistantMessageComponent extends Container {
private contentContainer: Container;
private hideThinkingBlock: boolean;
private markdownTheme: MarkdownTheme;
private lastMessage?: AssistantMessage;
constructor(
message?: AssistantMessage,
hideThinkingBlock = false,
markdownTheme: MarkdownTheme = getMarkdownTheme(),
) {
super();
this.hideThinkingBlock = hideThinkingBlock;
this.markdownTheme = markdownTheme;
// Container for text/thinking content
this.contentContainer = new Container();
this.addChild(this.contentContainer);
if (message) {
this.updateContent(message);
}
}
override invalidate(): void {
super.invalidate();
if (this.lastMessage) {
this.updateContent(this.lastMessage);
}
}
setHideThinkingBlock(hide: boolean): void {
this.hideThinkingBlock = hide;
}
updateContent(message: AssistantMessage): void {
this.lastMessage = message;
// Clear content container
this.contentContainer.clear();
const hasVisibleContent = message.content.some(
(c) =>
(c.type === "text" && c.text.trim()) ||
(c.type === "thinking" && c.thinking.trim()),
);
if (hasVisibleContent) {
this.contentContainer.addChild(new Spacer(1));
}
// Render content in order
for (let i = 0; i < message.content.length; i++) {
const content = message.content[i];
if (content.type === "text" && content.text.trim()) {
// Assistant text messages with no background - trim the text
// Set paddingY=0 to avoid extra spacing before tool executions
this.contentContainer.addChild(
new Markdown(content.text.trim(), 1, 0, this.markdownTheme),
);
} else if (content.type === "thinking" && content.thinking.trim()) {
// Add spacing only when another visible assistant content block follows.
// This avoids a superfluous blank line before separately-rendered tool execution blocks.
const hasVisibleContentAfter = message.content
.slice(i + 1)
.some(
(c) =>
(c.type === "text" && c.text.trim()) ||
(c.type === "thinking" && c.thinking.trim()),
);
if (this.hideThinkingBlock) {
// Show static "Thinking..." label when hidden
this.contentContainer.addChild(
new Text(
theme.italic(theme.fg("thinkingText", "Thinking...")),
1,
0,
),
);
if (hasVisibleContentAfter) {
this.contentContainer.addChild(new Spacer(1));
}
} else {
// Thinking traces in thinkingText color, italic
this.contentContainer.addChild(
new Markdown(content.thinking.trim(), 1, 0, this.markdownTheme, {
color: (text: string) => theme.fg("thinkingText", text),
italic: true,
}),
);
if (hasVisibleContentAfter) {
this.contentContainer.addChild(new Spacer(1));
}
}
}
}
// Check if aborted - show after partial content
// But only if there are no tool calls (tool execution components will show the error)
const hasToolCalls = message.content.some((c) => c.type === "toolCall");
if (!hasToolCalls) {
if (message.stopReason === "aborted") {
const abortMessage =
message.errorMessage && message.errorMessage !== "Request was aborted"
? message.errorMessage
: "Operation aborted";
if (hasVisibleContent) {
this.contentContainer.addChild(new Spacer(1));
} else {
this.contentContainer.addChild(new Spacer(1));
}
this.contentContainer.addChild(
new Text(theme.fg("error", abortMessage), 1, 0),
);
} else if (message.stopReason === "error") {
const errorMsg = message.errorMessage || "Unknown error";
this.contentContainer.addChild(new Spacer(1));
this.contentContainer.addChild(
new Text(theme.fg("error", `Error: ${errorMsg}`), 1, 0),
);
}
}
}
}

View file

@ -0,0 +1,241 @@
/**
* Component for displaying bash command execution with streaming output.
*/
import {
Container,
Loader,
Spacer,
Text,
type TUI,
} from "@mariozechner/pi-tui";
import stripAnsi from "strip-ansi";
import {
DEFAULT_MAX_BYTES,
DEFAULT_MAX_LINES,
type TruncationResult,
truncateTail,
} from "../../../core/tools/truncate.js";
import { theme } from "../theme/theme.js";
import { DynamicBorder } from "./dynamic-border.js";
import { editorKey, keyHint } from "./keybinding-hints.js";
import { truncateToVisualLines } from "./visual-truncate.js";
// Preview line limit when not expanded (matches tool execution behavior)
const PREVIEW_LINES = 20;
export class BashExecutionComponent extends Container {
private command: string;
private outputLines: string[] = [];
private status: "running" | "complete" | "cancelled" | "error" = "running";
private exitCode: number | undefined = undefined;
private loader: Loader;
private truncationResult?: TruncationResult;
private fullOutputPath?: string;
private expanded = false;
private contentContainer: Container;
private ui: TUI;
constructor(command: string, ui: TUI, excludeFromContext = false) {
super();
this.command = command;
this.ui = ui;
// Use dim border for excluded-from-context commands (!! prefix)
const colorKey = excludeFromContext ? "dim" : "bashMode";
const borderColor = (str: string) => theme.fg(colorKey, str);
// Add spacer
this.addChild(new Spacer(1));
// Top border
this.addChild(new DynamicBorder(borderColor));
// Content container (holds dynamic content between borders)
this.contentContainer = new Container();
this.addChild(this.contentContainer);
// Command header
const header = new Text(
theme.fg(colorKey, theme.bold(`$ ${command}`)),
1,
0,
);
this.contentContainer.addChild(header);
// Loader
this.loader = new Loader(
ui,
(spinner) => theme.fg(colorKey, spinner),
(text) => theme.fg("muted", text),
`Running... (${editorKey("selectCancel")} to cancel)`, // Plain text for loader
);
this.contentContainer.addChild(this.loader);
// Bottom border
this.addChild(new DynamicBorder(borderColor));
}
/**
* Set whether the output is expanded (shows full output) or collapsed (preview only).
*/
setExpanded(expanded: boolean): void {
this.expanded = expanded;
this.updateDisplay();
}
override invalidate(): void {
super.invalidate();
this.updateDisplay();
}
appendOutput(chunk: string): void {
// Strip ANSI codes and normalize line endings
// Note: binary data is already sanitized in tui-renderer.ts executeBashCommand
const clean = stripAnsi(chunk).replace(/\r\n/g, "\n").replace(/\r/g, "\n");
// Append to output lines
const newLines = clean.split("\n");
if (this.outputLines.length > 0 && newLines.length > 0) {
// Append first chunk to last line (incomplete line continuation)
this.outputLines[this.outputLines.length - 1] += newLines[0];
this.outputLines.push(...newLines.slice(1));
} else {
this.outputLines.push(...newLines);
}
this.updateDisplay();
}
setComplete(
exitCode: number | undefined,
cancelled: boolean,
truncationResult?: TruncationResult,
fullOutputPath?: string,
): void {
this.exitCode = exitCode;
this.status = cancelled
? "cancelled"
: exitCode !== 0 && exitCode !== undefined && exitCode !== null
? "error"
: "complete";
this.truncationResult = truncationResult;
this.fullOutputPath = fullOutputPath;
// Stop loader
this.loader.stop();
this.updateDisplay();
}
private updateDisplay(): void {
// Apply truncation for LLM context limits (same limits as bash tool)
const fullOutput = this.outputLines.join("\n");
const contextTruncation = truncateTail(fullOutput, {
maxLines: DEFAULT_MAX_LINES,
maxBytes: DEFAULT_MAX_BYTES,
});
// Get the lines to potentially display (after context truncation)
const availableLines = contextTruncation.content
? contextTruncation.content.split("\n")
: [];
// Apply preview truncation based on expanded state
const previewLogicalLines = availableLines.slice(-PREVIEW_LINES);
const hiddenLineCount = availableLines.length - previewLogicalLines.length;
// Rebuild content container
this.contentContainer.clear();
// Command header
const header = new Text(
theme.fg("bashMode", theme.bold(`$ ${this.command}`)),
1,
0,
);
this.contentContainer.addChild(header);
// Output
if (availableLines.length > 0) {
if (this.expanded) {
// Show all lines
const displayText = availableLines
.map((line) => theme.fg("muted", line))
.join("\n");
this.contentContainer.addChild(new Text(`\n${displayText}`, 1, 0));
} else {
// Use shared visual truncation utility
const styledOutput = previewLogicalLines
.map((line) => theme.fg("muted", line))
.join("\n");
const { visualLines } = truncateToVisualLines(
`\n${styledOutput}`,
PREVIEW_LINES,
this.ui.terminal.columns,
1, // padding
);
this.contentContainer.addChild({
render: () => visualLines,
invalidate: () => {},
});
}
}
// Loader or status
if (this.status === "running") {
this.contentContainer.addChild(this.loader);
} else {
const statusParts: string[] = [];
// Show how many lines are hidden (collapsed preview)
if (hiddenLineCount > 0) {
if (this.expanded) {
statusParts.push(`(${keyHint("expandTools", "to collapse")})`);
} else {
statusParts.push(
`${theme.fg("muted", `... ${hiddenLineCount} more lines`)} (${keyHint("expandTools", "to expand")})`,
);
}
}
if (this.status === "cancelled") {
statusParts.push(theme.fg("warning", "(cancelled)"));
} else if (this.status === "error") {
statusParts.push(theme.fg("error", `(exit ${this.exitCode})`));
}
// Add truncation warning (context truncation, not preview truncation)
const wasTruncated =
this.truncationResult?.truncated || contextTruncation.truncated;
if (wasTruncated && this.fullOutputPath) {
statusParts.push(
theme.fg(
"warning",
`Output truncated. Full output: ${this.fullOutputPath}`,
),
);
}
if (statusParts.length > 0) {
this.contentContainer.addChild(
new Text(`\n${statusParts.join("\n")}`, 1, 0),
);
}
}
}
/**
* Get the raw output for creating BashExecutionMessage.
*/
getOutput(): string {
return this.outputLines.join("\n");
}
/**
* Get the command that was executed.
*/
getCommand(): string {
return this.command;
}
}

View file

@ -0,0 +1,78 @@
import {
CancellableLoader,
Container,
Loader,
Spacer,
Text,
type TUI,
} from "@mariozechner/pi-tui";
import type { Theme } from "../theme/theme.js";
import { DynamicBorder } from "./dynamic-border.js";
import { keyHint } from "./keybinding-hints.js";
/** Loader wrapped with borders for extension UI */
export class BorderedLoader extends Container {
private loader: CancellableLoader | Loader;
private cancellable: boolean;
private signalController?: AbortController;
constructor(
tui: TUI,
theme: Theme,
message: string,
options?: { cancellable?: boolean },
) {
super();
this.cancellable = options?.cancellable ?? true;
const borderColor = (s: string) => theme.fg("border", s);
this.addChild(new DynamicBorder(borderColor));
if (this.cancellable) {
this.loader = new CancellableLoader(
tui,
(s) => theme.fg("accent", s),
(s) => theme.fg("muted", s),
message,
);
} else {
this.signalController = new AbortController();
this.loader = new Loader(
tui,
(s) => theme.fg("accent", s),
(s) => theme.fg("muted", s),
message,
);
}
this.addChild(this.loader);
if (this.cancellable) {
this.addChild(new Spacer(1));
this.addChild(new Text(keyHint("selectCancel", "cancel"), 1, 0));
}
this.addChild(new Spacer(1));
this.addChild(new DynamicBorder(borderColor));
}
get signal(): AbortSignal {
if (this.cancellable) {
return (this.loader as CancellableLoader).signal;
}
return this.signalController?.signal ?? new AbortController().signal;
}
set onAbort(fn: (() => void) | undefined) {
if (this.cancellable) {
(this.loader as CancellableLoader).onAbort = fn;
}
}
handleInput(data: string): void {
if (this.cancellable) {
(this.loader as CancellableLoader).handleInput(data);
}
}
dispose(): void {
if ("dispose" in this.loader && typeof this.loader.dispose === "function") {
this.loader.dispose();
}
}
}

View file

@ -0,0 +1,67 @@
import {
Box,
Markdown,
type MarkdownTheme,
Spacer,
Text,
} from "@mariozechner/pi-tui";
import type { BranchSummaryMessage } from "../../../core/messages.js";
import { getMarkdownTheme, theme } from "../theme/theme.js";
import { editorKey } from "./keybinding-hints.js";
/**
* Component that renders a branch summary message with collapsed/expanded state.
* Uses same background color as custom messages for visual consistency.
*/
export class BranchSummaryMessageComponent extends Box {
private expanded = false;
private message: BranchSummaryMessage;
private markdownTheme: MarkdownTheme;
constructor(
message: BranchSummaryMessage,
markdownTheme: MarkdownTheme = getMarkdownTheme(),
) {
super(1, 1, (t) => theme.bg("customMessageBg", t));
this.message = message;
this.markdownTheme = markdownTheme;
this.updateDisplay();
}
setExpanded(expanded: boolean): void {
this.expanded = expanded;
this.updateDisplay();
}
override invalidate(): void {
super.invalidate();
this.updateDisplay();
}
private updateDisplay(): void {
this.clear();
const label = theme.fg("customMessageLabel", `\x1b[1m[branch]\x1b[22m`);
this.addChild(new Text(label, 0, 0));
this.addChild(new Spacer(1));
if (this.expanded) {
const header = "**Branch Summary**\n\n";
this.addChild(
new Markdown(header + this.message.summary, 0, 0, this.markdownTheme, {
color: (text: string) => theme.fg("customMessageText", text),
}),
);
} else {
this.addChild(
new Text(
theme.fg("customMessageText", "Branch summary (") +
theme.fg("dim", editorKey("expandTools")) +
theme.fg("customMessageText", " to expand)"),
0,
0,
),
);
}
}
}

View file

@ -0,0 +1,68 @@
import {
Box,
Markdown,
type MarkdownTheme,
Spacer,
Text,
} from "@mariozechner/pi-tui";
import type { CompactionSummaryMessage } from "../../../core/messages.js";
import { getMarkdownTheme, theme } from "../theme/theme.js";
import { editorKey } from "./keybinding-hints.js";
/**
* Component that renders a compaction message with collapsed/expanded state.
* Uses same background color as custom messages for visual consistency.
*/
export class CompactionSummaryMessageComponent extends Box {
private expanded = false;
private message: CompactionSummaryMessage;
private markdownTheme: MarkdownTheme;
constructor(
message: CompactionSummaryMessage,
markdownTheme: MarkdownTheme = getMarkdownTheme(),
) {
super(1, 1, (t) => theme.bg("customMessageBg", t));
this.message = message;
this.markdownTheme = markdownTheme;
this.updateDisplay();
}
setExpanded(expanded: boolean): void {
this.expanded = expanded;
this.updateDisplay();
}
override invalidate(): void {
super.invalidate();
this.updateDisplay();
}
private updateDisplay(): void {
this.clear();
const tokenStr = this.message.tokensBefore.toLocaleString();
const label = theme.fg("customMessageLabel", `\x1b[1m[compaction]\x1b[22m`);
this.addChild(new Text(label, 0, 0));
this.addChild(new Spacer(1));
if (this.expanded) {
const header = `**Compacted from ${tokenStr} tokens**\n\n`;
this.addChild(
new Markdown(header + this.message.summary, 0, 0, this.markdownTheme, {
color: (text: string) => theme.fg("customMessageText", text),
}),
);
} else {
this.addChild(
new Text(
theme.fg("customMessageText", `Compacted from ${tokenStr} tokens (`) +
theme.fg("dim", editorKey("expandTools")) +
theme.fg("customMessageText", " to expand)"),
0,
0,
),
);
}
}
}

View file

@ -0,0 +1,669 @@
/**
* TUI component for managing package resources (enable/disable)
*/
import { basename, dirname, join, relative } from "node:path";
import {
type Component,
Container,
type Focusable,
getEditorKeybindings,
Input,
matchesKey,
Spacer,
truncateToWidth,
visibleWidth,
} from "@mariozechner/pi-tui";
import { CONFIG_DIR_NAME } from "../../../config.js";
import type {
PathMetadata,
ResolvedPaths,
ResolvedResource,
} from "../../../core/package-manager.js";
import type {
PackageSource,
SettingsManager,
} from "../../../core/settings-manager.js";
import { theme } from "../theme/theme.js";
import { DynamicBorder } from "./dynamic-border.js";
import { rawKeyHint } from "./keybinding-hints.js";
type ResourceType = "extensions" | "skills" | "prompts" | "themes";
const RESOURCE_TYPE_LABELS: Record<ResourceType, string> = {
extensions: "Extensions",
skills: "Skills",
prompts: "Prompts",
themes: "Themes",
};
interface ResourceItem {
path: string;
enabled: boolean;
metadata: PathMetadata;
resourceType: ResourceType;
displayName: string;
groupKey: string;
subgroupKey: string;
}
interface ResourceSubgroup {
type: ResourceType;
label: string;
items: ResourceItem[];
}
interface ResourceGroup {
key: string;
label: string;
scope: "user" | "project" | "temporary";
origin: "package" | "top-level";
source: string;
subgroups: ResourceSubgroup[];
}
function getGroupLabel(metadata: PathMetadata): string {
if (metadata.origin === "package") {
return `${metadata.source} (${metadata.scope})`;
}
// Top-level resources
if (metadata.source === "auto") {
return metadata.scope === "user" ? "User (~/.pi/agent/)" : "Project (.pi/)";
}
return metadata.scope === "user" ? "User settings" : "Project settings";
}
function buildGroups(resolved: ResolvedPaths): ResourceGroup[] {
const groupMap = new Map<string, ResourceGroup>();
const addToGroup = (
resources: ResolvedResource[],
resourceType: ResourceType,
) => {
for (const res of resources) {
const { path, enabled, metadata } = res;
const groupKey = `${metadata.origin}:${metadata.scope}:${metadata.source}`;
if (!groupMap.has(groupKey)) {
groupMap.set(groupKey, {
key: groupKey,
label: getGroupLabel(metadata),
scope: metadata.scope,
origin: metadata.origin,
source: metadata.source,
subgroups: [],
});
}
const group = groupMap.get(groupKey)!;
const subgroupKey = `${groupKey}:${resourceType}`;
let subgroup = group.subgroups.find((sg) => sg.type === resourceType);
if (!subgroup) {
subgroup = {
type: resourceType,
label: RESOURCE_TYPE_LABELS[resourceType],
items: [],
};
group.subgroups.push(subgroup);
}
const fileName = basename(path);
const parentFolder = basename(dirname(path));
let displayName: string;
if (resourceType === "extensions" && parentFolder !== "extensions") {
displayName = `${parentFolder}/${fileName}`;
} else if (resourceType === "skills" && fileName === "SKILL.md") {
displayName = parentFolder;
} else {
displayName = fileName;
}
subgroup.items.push({
path,
enabled,
metadata,
resourceType,
displayName,
groupKey,
subgroupKey,
});
}
};
addToGroup(resolved.extensions, "extensions");
addToGroup(resolved.skills, "skills");
addToGroup(resolved.prompts, "prompts");
addToGroup(resolved.themes, "themes");
// Sort groups: packages first, then top-level; user before project
const groups = Array.from(groupMap.values());
groups.sort((a, b) => {
if (a.origin !== b.origin) {
return a.origin === "package" ? -1 : 1;
}
if (a.scope !== b.scope) {
return a.scope === "user" ? -1 : 1;
}
return a.source.localeCompare(b.source);
});
// Sort subgroups within each group by type order, and items by name
const typeOrder: Record<ResourceType, number> = {
extensions: 0,
skills: 1,
prompts: 2,
themes: 3,
};
for (const group of groups) {
group.subgroups.sort((a, b) => typeOrder[a.type] - typeOrder[b.type]);
for (const subgroup of group.subgroups) {
subgroup.items.sort((a, b) => a.displayName.localeCompare(b.displayName));
}
}
return groups;
}
type FlatEntry =
| { type: "group"; group: ResourceGroup }
| { type: "subgroup"; subgroup: ResourceSubgroup; group: ResourceGroup }
| { type: "item"; item: ResourceItem };
class ConfigSelectorHeader implements Component {
invalidate(): void {}
render(width: number): string[] {
const title = theme.bold("Resource Configuration");
const sep = theme.fg("muted", " · ");
const hint =
rawKeyHint("space", "toggle") + sep + rawKeyHint("esc", "close");
const hintWidth = visibleWidth(hint);
const titleWidth = visibleWidth(title);
const spacing = Math.max(1, width - titleWidth - hintWidth);
return [
truncateToWidth(`${title}${" ".repeat(spacing)}${hint}`, width, ""),
theme.fg("muted", "Type to filter resources"),
];
}
}
class ResourceList implements Component, Focusable {
private groups: ResourceGroup[];
private flatItems: FlatEntry[] = [];
private filteredItems: FlatEntry[] = [];
private selectedIndex = 0;
private searchInput: Input;
private maxVisible = 15;
private settingsManager: SettingsManager;
private cwd: string;
private agentDir: string;
public onCancel?: () => void;
public onExit?: () => void;
public onToggle?: (item: ResourceItem, newEnabled: boolean) => void;
private _focused = false;
get focused(): boolean {
return this._focused;
}
set focused(value: boolean) {
this._focused = value;
this.searchInput.focused = value;
}
constructor(
groups: ResourceGroup[],
settingsManager: SettingsManager,
cwd: string,
agentDir: string,
) {
this.groups = groups;
this.settingsManager = settingsManager;
this.cwd = cwd;
this.agentDir = agentDir;
this.searchInput = new Input();
this.buildFlatList();
this.filteredItems = [...this.flatItems];
}
private buildFlatList(): void {
this.flatItems = [];
for (const group of this.groups) {
this.flatItems.push({ type: "group", group });
for (const subgroup of group.subgroups) {
this.flatItems.push({ type: "subgroup", subgroup, group });
for (const item of subgroup.items) {
this.flatItems.push({ type: "item", item });
}
}
}
// Start selection on first item (not header)
this.selectedIndex = this.flatItems.findIndex((e) => e.type === "item");
if (this.selectedIndex < 0) this.selectedIndex = 0;
}
private findNextItem(fromIndex: number, direction: 1 | -1): number {
let idx = fromIndex + direction;
while (idx >= 0 && idx < this.filteredItems.length) {
if (this.filteredItems[idx].type === "item") {
return idx;
}
idx += direction;
}
return fromIndex; // Stay at current if no item found
}
private filterItems(query: string): void {
if (!query.trim()) {
this.filteredItems = [...this.flatItems];
this.selectFirstItem();
return;
}
const lowerQuery = query.toLowerCase();
const matchingItems = new Set<ResourceItem>();
const matchingSubgroups = new Set<ResourceSubgroup>();
const matchingGroups = new Set<ResourceGroup>();
for (const entry of this.flatItems) {
if (entry.type === "item") {
const item = entry.item;
if (
item.displayName.toLowerCase().includes(lowerQuery) ||
item.resourceType.toLowerCase().includes(lowerQuery) ||
item.path.toLowerCase().includes(lowerQuery)
) {
matchingItems.add(item);
}
}
}
// Find which subgroups and groups contain matching items
for (const group of this.groups) {
for (const subgroup of group.subgroups) {
for (const item of subgroup.items) {
if (matchingItems.has(item)) {
matchingSubgroups.add(subgroup);
matchingGroups.add(group);
}
}
}
}
this.filteredItems = [];
for (const entry of this.flatItems) {
if (entry.type === "group" && matchingGroups.has(entry.group)) {
this.filteredItems.push(entry);
} else if (
entry.type === "subgroup" &&
matchingSubgroups.has(entry.subgroup)
) {
this.filteredItems.push(entry);
} else if (entry.type === "item" && matchingItems.has(entry.item)) {
this.filteredItems.push(entry);
}
}
this.selectFirstItem();
}
private selectFirstItem(): void {
const firstItemIndex = this.filteredItems.findIndex(
(e) => e.type === "item",
);
this.selectedIndex = firstItemIndex >= 0 ? firstItemIndex : 0;
}
updateItem(item: ResourceItem, enabled: boolean): void {
item.enabled = enabled;
// Update in groups too
for (const group of this.groups) {
for (const subgroup of group.subgroups) {
const found = subgroup.items.find(
(i) => i.path === item.path && i.resourceType === item.resourceType,
);
if (found) {
found.enabled = enabled;
return;
}
}
}
}
invalidate(): void {}
render(width: number): string[] {
const lines: string[] = [];
// Search input
lines.push(...this.searchInput.render(width));
lines.push("");
if (this.filteredItems.length === 0) {
lines.push(theme.fg("muted", " No resources found"));
return lines;
}
// Calculate visible range
const startIndex = Math.max(
0,
Math.min(
this.selectedIndex - Math.floor(this.maxVisible / 2),
this.filteredItems.length - this.maxVisible,
),
);
const endIndex = Math.min(
startIndex + this.maxVisible,
this.filteredItems.length,
);
for (let i = startIndex; i < endIndex; i++) {
const entry = this.filteredItems[i];
const isSelected = i === this.selectedIndex;
if (entry.type === "group") {
// Main group header (no cursor)
const groupLine = theme.fg("accent", theme.bold(entry.group.label));
lines.push(truncateToWidth(` ${groupLine}`, width, ""));
} else if (entry.type === "subgroup") {
// Subgroup header (indented, no cursor)
const subgroupLine = theme.fg("muted", entry.subgroup.label);
lines.push(truncateToWidth(` ${subgroupLine}`, width, ""));
} else {
// Resource item (cursor only on items)
const item = entry.item;
const cursor = isSelected ? "> " : " ";
const checkbox = item.enabled
? theme.fg("success", "[x]")
: theme.fg("dim", "[ ]");
const name = isSelected
? theme.bold(item.displayName)
: item.displayName;
lines.push(
truncateToWidth(`${cursor} ${checkbox} ${name}`, width, "..."),
);
}
}
// Scroll indicator
if (startIndex > 0 || endIndex < this.filteredItems.length) {
lines.push(
theme.fg(
"dim",
` (${this.selectedIndex + 1}/${this.filteredItems.length})`,
),
);
}
return lines;
}
handleInput(data: string): void {
const kb = getEditorKeybindings();
if (kb.matches(data, "selectUp")) {
this.selectedIndex = this.findNextItem(this.selectedIndex, -1);
return;
}
if (kb.matches(data, "selectDown")) {
this.selectedIndex = this.findNextItem(this.selectedIndex, 1);
return;
}
if (kb.matches(data, "selectPageUp")) {
// Jump up by maxVisible, then find nearest item
let target = Math.max(0, this.selectedIndex - this.maxVisible);
while (
target < this.filteredItems.length &&
this.filteredItems[target].type !== "item"
) {
target++;
}
if (target < this.filteredItems.length) {
this.selectedIndex = target;
}
return;
}
if (kb.matches(data, "selectPageDown")) {
// Jump down by maxVisible, then find nearest item
let target = Math.min(
this.filteredItems.length - 1,
this.selectedIndex + this.maxVisible,
);
while (target >= 0 && this.filteredItems[target].type !== "item") {
target--;
}
if (target >= 0) {
this.selectedIndex = target;
}
return;
}
if (kb.matches(data, "selectCancel")) {
this.onCancel?.();
return;
}
if (matchesKey(data, "ctrl+c")) {
this.onExit?.();
return;
}
if (data === " " || kb.matches(data, "selectConfirm")) {
const entry = this.filteredItems[this.selectedIndex];
if (entry?.type === "item") {
const newEnabled = !entry.item.enabled;
this.toggleResource(entry.item, newEnabled);
this.updateItem(entry.item, newEnabled);
this.onToggle?.(entry.item, newEnabled);
}
return;
}
// Pass to search input
this.searchInput.handleInput(data);
this.filterItems(this.searchInput.getValue());
}
private toggleResource(item: ResourceItem, enabled: boolean): void {
if (item.metadata.origin === "top-level") {
this.toggleTopLevelResource(item, enabled);
} else {
this.togglePackageResource(item, enabled);
}
}
private toggleTopLevelResource(item: ResourceItem, enabled: boolean): void {
const scope = item.metadata.scope as "user" | "project";
const settings =
scope === "project"
? this.settingsManager.getProjectSettings()
: this.settingsManager.getGlobalSettings();
const arrayKey = item.resourceType as
| "extensions"
| "skills"
| "prompts"
| "themes";
const current = (settings[arrayKey] ?? []) as string[];
// Generate pattern for this resource
const pattern = this.getResourcePattern(item);
const disablePattern = `-${pattern}`;
const enablePattern = `+${pattern}`;
// Filter out existing patterns for this resource
const updated = current.filter((p) => {
const stripped =
p.startsWith("!") || p.startsWith("+") || p.startsWith("-")
? p.slice(1)
: p;
return stripped !== pattern;
});
if (enabled) {
updated.push(enablePattern);
} else {
updated.push(disablePattern);
}
if (scope === "project") {
if (arrayKey === "extensions") {
this.settingsManager.setProjectExtensionPaths(updated);
} else if (arrayKey === "skills") {
this.settingsManager.setProjectSkillPaths(updated);
} else if (arrayKey === "prompts") {
this.settingsManager.setProjectPromptTemplatePaths(updated);
} else if (arrayKey === "themes") {
this.settingsManager.setProjectThemePaths(updated);
}
} else {
if (arrayKey === "extensions") {
this.settingsManager.setExtensionPaths(updated);
} else if (arrayKey === "skills") {
this.settingsManager.setSkillPaths(updated);
} else if (arrayKey === "prompts") {
this.settingsManager.setPromptTemplatePaths(updated);
} else if (arrayKey === "themes") {
this.settingsManager.setThemePaths(updated);
}
}
}
private togglePackageResource(item: ResourceItem, enabled: boolean): void {
const scope = item.metadata.scope as "user" | "project";
const settings =
scope === "project"
? this.settingsManager.getProjectSettings()
: this.settingsManager.getGlobalSettings();
const packages = [...(settings.packages ?? [])] as PackageSource[];
const pkgIndex = packages.findIndex((pkg) => {
const source = typeof pkg === "string" ? pkg : pkg.source;
return source === item.metadata.source;
});
if (pkgIndex === -1) return;
let pkg = packages[pkgIndex];
// Convert string to object form if needed
if (typeof pkg === "string") {
pkg = { source: pkg };
packages[pkgIndex] = pkg;
}
// Get the resource array for this type
const arrayKey = item.resourceType as
| "extensions"
| "skills"
| "prompts"
| "themes";
const current = (pkg[arrayKey] ?? []) as string[];
// Generate pattern relative to package root
const pattern = this.getPackageResourcePattern(item);
const disablePattern = `-${pattern}`;
const enablePattern = `+${pattern}`;
// Filter out existing patterns for this resource
const updated = current.filter((p) => {
const stripped =
p.startsWith("!") || p.startsWith("+") || p.startsWith("-")
? p.slice(1)
: p;
return stripped !== pattern;
});
if (enabled) {
updated.push(enablePattern);
} else {
updated.push(disablePattern);
}
(pkg as Record<string, unknown>)[arrayKey] =
updated.length > 0 ? updated : undefined;
// Clean up empty filter object
const hasFilters = ["extensions", "skills", "prompts", "themes"].some(
(k) => (pkg as Record<string, unknown>)[k] !== undefined,
);
if (!hasFilters) {
packages[pkgIndex] = (pkg as { source: string }).source;
}
if (scope === "project") {
this.settingsManager.setProjectPackages(packages);
} else {
this.settingsManager.setPackages(packages);
}
}
private getTopLevelBaseDir(scope: "user" | "project"): string {
return scope === "project"
? join(this.cwd, CONFIG_DIR_NAME)
: this.agentDir;
}
private getResourcePattern(item: ResourceItem): string {
const scope = item.metadata.scope as "user" | "project";
const baseDir = this.getTopLevelBaseDir(scope);
return relative(baseDir, item.path);
}
private getPackageResourcePattern(item: ResourceItem): string {
const baseDir = item.metadata.baseDir ?? dirname(item.path);
return relative(baseDir, item.path);
}
}
export class ConfigSelectorComponent extends Container implements Focusable {
private resourceList: ResourceList;
private _focused = false;
get focused(): boolean {
return this._focused;
}
set focused(value: boolean) {
this._focused = value;
this.resourceList.focused = value;
}
constructor(
resolvedPaths: ResolvedPaths,
settingsManager: SettingsManager,
cwd: string,
agentDir: string,
onClose: () => void,
onExit: () => void,
requestRender: () => void,
) {
super();
const groups = buildGroups(resolvedPaths);
// Add header
this.addChild(new Spacer(1));
this.addChild(new DynamicBorder());
this.addChild(new Spacer(1));
this.addChild(new ConfigSelectorHeader());
this.addChild(new Spacer(1));
// Resource list
this.resourceList = new ResourceList(
groups,
settingsManager,
cwd,
agentDir,
);
this.resourceList.onCancel = onClose;
this.resourceList.onExit = onExit;
this.resourceList.onToggle = () => requestRender();
this.addChild(this.resourceList);
// Bottom border
this.addChild(new Spacer(1));
this.addChild(new DynamicBorder());
}
getResourceList(): ResourceList {
return this.resourceList;
}
}

View file

@ -0,0 +1,38 @@
/**
* Reusable countdown timer for dialog components.
*/
import type { TUI } from "@mariozechner/pi-tui";
export class CountdownTimer {
private intervalId: ReturnType<typeof setInterval> | undefined;
private remainingSeconds: number;
constructor(
timeoutMs: number,
private tui: TUI | undefined,
private onTick: (seconds: number) => void,
private onExpire: () => void,
) {
this.remainingSeconds = Math.ceil(timeoutMs / 1000);
this.onTick(this.remainingSeconds);
this.intervalId = setInterval(() => {
this.remainingSeconds--;
this.onTick(this.remainingSeconds);
this.tui?.requestRender();
if (this.remainingSeconds <= 0) {
this.dispose();
this.onExpire();
}
}, 1000);
}
dispose(): void {
if (this.intervalId) {
clearInterval(this.intervalId);
this.intervalId = undefined;
}
}
}

View file

@ -0,0 +1,97 @@
import {
Editor,
type EditorOptions,
type EditorTheme,
type TUI,
} from "@mariozechner/pi-tui";
import type {
AppAction,
KeybindingsManager,
} from "../../../core/keybindings.js";
/**
* Custom editor that handles app-level keybindings for coding-agent.
*/
export class CustomEditor extends Editor {
private keybindings: KeybindingsManager;
public actionHandlers: Map<AppAction, () => void> = new Map();
// Special handlers that can be dynamically replaced
public onEscape?: () => void;
public onCtrlD?: () => void;
public onPasteImage?: () => void;
/** Handler for extension-registered shortcuts. Returns true if handled. */
public onExtensionShortcut?: (data: string) => boolean;
constructor(
tui: TUI,
theme: EditorTheme,
keybindings: KeybindingsManager,
options?: EditorOptions,
) {
super(tui, theme, options);
this.keybindings = keybindings;
}
/**
* Register a handler for an app action.
*/
onAction(action: AppAction, handler: () => void): void {
this.actionHandlers.set(action, handler);
}
handleInput(data: string): void {
// Check extension-registered shortcuts first
if (this.onExtensionShortcut?.(data)) {
return;
}
// Check for paste image keybinding
if (this.keybindings.matches(data, "pasteImage")) {
this.onPasteImage?.();
return;
}
// Check app keybindings first
// Escape/interrupt - only if autocomplete is NOT active
if (this.keybindings.matches(data, "interrupt")) {
if (!this.isShowingAutocomplete()) {
// Use dynamic onEscape if set, otherwise registered handler
const handler = this.onEscape ?? this.actionHandlers.get("interrupt");
if (handler) {
handler();
return;
}
}
// Let parent handle escape for autocomplete cancellation
super.handleInput(data);
return;
}
// Exit (Ctrl+D) - only when editor is empty
if (this.keybindings.matches(data, "exit")) {
if (this.getText().length === 0) {
const handler = this.onCtrlD ?? this.actionHandlers.get("exit");
if (handler) handler();
return;
}
// Fall through to editor handling for delete-char-forward when not empty
}
// Check all other app actions
for (const [action, handler] of this.actionHandlers) {
if (
action !== "interrupt" &&
action !== "exit" &&
this.keybindings.matches(data, action)
) {
handler();
return;
}
}
// Pass to parent for editor handling
super.handleInput(data);
}
}

View file

@ -0,0 +1,113 @@
import type { TextContent } from "@mariozechner/pi-ai";
import type { Component } from "@mariozechner/pi-tui";
import {
Box,
Container,
Markdown,
type MarkdownTheme,
Spacer,
Text,
} from "@mariozechner/pi-tui";
import type { MessageRenderer } from "../../../core/extensions/types.js";
import type { CustomMessage } from "../../../core/messages.js";
import { getMarkdownTheme, theme } from "../theme/theme.js";
/**
* Component that renders a custom message entry from extensions.
* Uses distinct styling to differentiate from user messages.
*/
export class CustomMessageComponent extends Container {
private message: CustomMessage<unknown>;
private customRenderer?: MessageRenderer;
private box: Box;
private customComponent?: Component;
private markdownTheme: MarkdownTheme;
private _expanded = false;
constructor(
message: CustomMessage<unknown>,
customRenderer?: MessageRenderer,
markdownTheme: MarkdownTheme = getMarkdownTheme(),
) {
super();
this.message = message;
this.customRenderer = customRenderer;
this.markdownTheme = markdownTheme;
this.addChild(new Spacer(1));
// Create box with purple background (used for default rendering)
this.box = new Box(1, 1, (t) => theme.bg("customMessageBg", t));
this.rebuild();
}
setExpanded(expanded: boolean): void {
if (this._expanded !== expanded) {
this._expanded = expanded;
this.rebuild();
}
}
override invalidate(): void {
super.invalidate();
this.rebuild();
}
private rebuild(): void {
// Remove previous content component
if (this.customComponent) {
this.removeChild(this.customComponent);
this.customComponent = undefined;
}
this.removeChild(this.box);
// Try custom renderer first - it handles its own styling
if (this.customRenderer) {
try {
const component = this.customRenderer(
this.message,
{ expanded: this._expanded },
theme,
);
if (component) {
// Custom renderer provides its own styled component
this.customComponent = component;
this.addChild(component);
return;
}
} catch {
// Fall through to default rendering
}
}
// Default rendering uses our box
this.addChild(this.box);
this.box.clear();
// Default rendering: label + content
const label = theme.fg(
"customMessageLabel",
`\x1b[1m[${this.message.customType}]\x1b[22m`,
);
this.box.addChild(new Text(label, 0, 0));
this.box.addChild(new Spacer(1));
// Extract text content
let text: string;
if (typeof this.message.content === "string") {
text = this.message.content;
} else {
text = this.message.content
.filter((c): c is TextContent => c.type === "text")
.map((c) => c.text)
.join("\n");
}
this.box.addChild(
new Markdown(text, 0, 0, this.markdownTheme, {
color: (text: string) => theme.fg("customMessageText", text),
}),
);
}
}

File diff suppressed because one or more lines are too long

View file

@ -0,0 +1,179 @@
import * as Diff from "diff";
import { theme } from "../theme/theme.js";
/**
* Parse diff line to extract prefix, line number, and content.
* Format: "+123 content" or "-123 content" or " 123 content" or " ..."
*/
function parseDiffLine(
line: string,
): { prefix: string; lineNum: string; content: string } | null {
const match = line.match(/^([+-\s])(\s*\d*)\s(.*)$/);
if (!match) return null;
return { prefix: match[1], lineNum: match[2], content: match[3] };
}
/**
* Replace tabs with spaces for consistent rendering.
*/
function replaceTabs(text: string): string {
return text.replace(/\t/g, " ");
}
/**
* Compute word-level diff and render with inverse on changed parts.
* Uses diffWords which groups whitespace with adjacent words for cleaner highlighting.
* Strips leading whitespace from inverse to avoid highlighting indentation.
*/
function renderIntraLineDiff(
oldContent: string,
newContent: string,
): { removedLine: string; addedLine: string } {
const wordDiff = Diff.diffWords(oldContent, newContent);
let removedLine = "";
let addedLine = "";
let isFirstRemoved = true;
let isFirstAdded = true;
for (const part of wordDiff) {
if (part.removed) {
let value = part.value;
// Strip leading whitespace from the first removed part
if (isFirstRemoved) {
const leadingWs = value.match(/^(\s*)/)?.[1] || "";
value = value.slice(leadingWs.length);
removedLine += leadingWs;
isFirstRemoved = false;
}
if (value) {
removedLine += theme.inverse(value);
}
} else if (part.added) {
let value = part.value;
// Strip leading whitespace from the first added part
if (isFirstAdded) {
const leadingWs = value.match(/^(\s*)/)?.[1] || "";
value = value.slice(leadingWs.length);
addedLine += leadingWs;
isFirstAdded = false;
}
if (value) {
addedLine += theme.inverse(value);
}
} else {
removedLine += part.value;
addedLine += part.value;
}
}
return { removedLine, addedLine };
}
export interface RenderDiffOptions {
/** File path (unused, kept for API compatibility) */
filePath?: string;
}
/**
* Render a diff string with colored lines and intra-line change highlighting.
* - Context lines: dim/gray
* - Removed lines: red, with inverse on changed tokens
* - Added lines: green, with inverse on changed tokens
*/
export function renderDiff(
diffText: string,
_options: RenderDiffOptions = {},
): string {
const lines = diffText.split("\n");
const result: string[] = [];
let i = 0;
while (i < lines.length) {
const line = lines[i];
const parsed = parseDiffLine(line);
if (!parsed) {
result.push(theme.fg("toolDiffContext", line));
i++;
continue;
}
if (parsed.prefix === "-") {
// Collect consecutive removed lines
const removedLines: { lineNum: string; content: string }[] = [];
while (i < lines.length) {
const p = parseDiffLine(lines[i]);
if (!p || p.prefix !== "-") break;
removedLines.push({ lineNum: p.lineNum, content: p.content });
i++;
}
// Collect consecutive added lines
const addedLines: { lineNum: string; content: string }[] = [];
while (i < lines.length) {
const p = parseDiffLine(lines[i]);
if (!p || p.prefix !== "+") break;
addedLines.push({ lineNum: p.lineNum, content: p.content });
i++;
}
// Only do intra-line diffing when there's exactly one removed and one added line
// (indicating a single line modification). Otherwise, show lines as-is.
if (removedLines.length === 1 && addedLines.length === 1) {
const removed = removedLines[0];
const added = addedLines[0];
const { removedLine, addedLine } = renderIntraLineDiff(
replaceTabs(removed.content),
replaceTabs(added.content),
);
result.push(
theme.fg("toolDiffRemoved", `-${removed.lineNum} ${removedLine}`),
);
result.push(
theme.fg("toolDiffAdded", `+${added.lineNum} ${addedLine}`),
);
} else {
// Show all removed lines first, then all added lines
for (const removed of removedLines) {
result.push(
theme.fg(
"toolDiffRemoved",
`-${removed.lineNum} ${replaceTabs(removed.content)}`,
),
);
}
for (const added of addedLines) {
result.push(
theme.fg(
"toolDiffAdded",
`+${added.lineNum} ${replaceTabs(added.content)}`,
),
);
}
}
} else if (parsed.prefix === "+") {
// Standalone added line
result.push(
theme.fg(
"toolDiffAdded",
`+${parsed.lineNum} ${replaceTabs(parsed.content)}`,
),
);
i++;
} else {
// Context line
result.push(
theme.fg(
"toolDiffContext",
` ${parsed.lineNum} ${replaceTabs(parsed.content)}`,
),
);
i++;
}
}
return result.join("\n");
}

View file

@ -0,0 +1,27 @@
import type { Component } from "@mariozechner/pi-tui";
import { theme } from "../theme/theme.js";
/**
* Dynamic border component that adjusts to viewport width.
*
* Note: When used from extensions loaded via jiti, the global `theme` may be undefined
* because jiti creates a separate module cache. Always pass an explicit color
* function when using DynamicBorder in components exported for extension use.
*/
export class DynamicBorder implements Component {
private color: (str: string) => string;
constructor(
color: (str: string) => string = (str) => theme.fg("border", str),
) {
this.color = color;
}
invalidate(): void {
// No cached state to invalidate currently
}
render(width: number): string[] {
return [this.color("─".repeat(Math.max(1, width)))];
}
}

View file

@ -0,0 +1,151 @@
/**
* Multi-line editor component for extensions.
* Supports Ctrl+G for external editor.
*/
import { spawnSync } from "node:child_process";
import * as fs from "node:fs";
import * as os from "node:os";
import * as path from "node:path";
import {
Container,
Editor,
type EditorOptions,
type Focusable,
getEditorKeybindings,
Spacer,
Text,
type TUI,
} from "@mariozechner/pi-tui";
import type { KeybindingsManager } from "../../../core/keybindings.js";
import { getEditorTheme, theme } from "../theme/theme.js";
import { DynamicBorder } from "./dynamic-border.js";
import { appKeyHint, keyHint } from "./keybinding-hints.js";
export class ExtensionEditorComponent extends Container implements Focusable {
private editor: Editor;
private onSubmitCallback: (value: string) => void;
private onCancelCallback: () => void;
private tui: TUI;
private keybindings: KeybindingsManager;
private _focused = false;
get focused(): boolean {
return this._focused;
}
set focused(value: boolean) {
this._focused = value;
this.editor.focused = value;
}
constructor(
tui: TUI,
keybindings: KeybindingsManager,
title: string,
prefill: string | undefined,
onSubmit: (value: string) => void,
onCancel: () => void,
options?: EditorOptions,
) {
super();
this.tui = tui;
this.keybindings = keybindings;
this.onSubmitCallback = onSubmit;
this.onCancelCallback = onCancel;
// Add top border
this.addChild(new DynamicBorder());
this.addChild(new Spacer(1));
// Add title
this.addChild(new Text(theme.fg("accent", title), 1, 0));
this.addChild(new Spacer(1));
// Create editor
this.editor = new Editor(tui, getEditorTheme(), options);
if (prefill) {
this.editor.setText(prefill);
}
// Wire up Enter to submit (Shift+Enter for newlines, like the main editor)
this.editor.onSubmit = (text: string) => {
this.onSubmitCallback(text);
};
this.addChild(this.editor);
this.addChild(new Spacer(1));
// Add hint
const hasExternalEditor = !!(process.env.VISUAL || process.env.EDITOR);
const hint =
keyHint("selectConfirm", "submit") +
" " +
keyHint("newLine", "newline") +
" " +
keyHint("selectCancel", "cancel") +
(hasExternalEditor
? ` ${appKeyHint(this.keybindings, "externalEditor", "external editor")}`
: "");
this.addChild(new Text(hint, 1, 0));
this.addChild(new Spacer(1));
// Add bottom border
this.addChild(new DynamicBorder());
}
handleInput(keyData: string): void {
const kb = getEditorKeybindings();
// Escape or Ctrl+C to cancel
if (kb.matches(keyData, "selectCancel")) {
this.onCancelCallback();
return;
}
// External editor (app keybinding)
if (this.keybindings.matches(keyData, "externalEditor")) {
this.openExternalEditor();
return;
}
// Forward to editor
this.editor.handleInput(keyData);
}
private openExternalEditor(): void {
const editorCmd = process.env.VISUAL || process.env.EDITOR;
if (!editorCmd) {
return;
}
const currentText = this.editor.getText();
const tmpFile = path.join(
os.tmpdir(),
`pi-extension-editor-${Date.now()}.md`,
);
try {
fs.writeFileSync(tmpFile, currentText, "utf-8");
this.tui.stop();
const [editor, ...editorArgs] = editorCmd.split(" ");
const result = spawnSync(editor, [...editorArgs, tmpFile], {
stdio: "inherit",
});
if (result.status === 0) {
const newContent = fs.readFileSync(tmpFile, "utf-8").replace(/\n$/, "");
this.editor.setText(newContent);
}
} finally {
try {
fs.unlinkSync(tmpFile);
} catch {
// Ignore cleanup errors
}
this.tui.start();
// Force full re-render since external editor uses alternate screen
this.tui.requestRender(true);
}
}
}

View file

@ -0,0 +1,102 @@
/**
* Simple text input component for extensions.
*/
import {
Container,
type Focusable,
getEditorKeybindings,
Input,
Spacer,
Text,
type TUI,
} from "@mariozechner/pi-tui";
import { theme } from "../theme/theme.js";
import { CountdownTimer } from "./countdown-timer.js";
import { DynamicBorder } from "./dynamic-border.js";
import { keyHint } from "./keybinding-hints.js";
export interface ExtensionInputOptions {
tui?: TUI;
timeout?: number;
}
export class ExtensionInputComponent extends Container implements Focusable {
private input: Input;
private onSubmitCallback: (value: string) => void;
private onCancelCallback: () => void;
private titleText: Text;
private baseTitle: string;
private countdown: CountdownTimer | undefined;
// Focusable implementation - propagate to input for IME cursor positioning
private _focused = false;
get focused(): boolean {
return this._focused;
}
set focused(value: boolean) {
this._focused = value;
this.input.focused = value;
}
constructor(
title: string,
_placeholder: string | undefined,
onSubmit: (value: string) => void,
onCancel: () => void,
opts?: ExtensionInputOptions,
) {
super();
this.onSubmitCallback = onSubmit;
this.onCancelCallback = onCancel;
this.baseTitle = title;
this.addChild(new DynamicBorder());
this.addChild(new Spacer(1));
this.titleText = new Text(theme.fg("accent", title), 1, 0);
this.addChild(this.titleText);
this.addChild(new Spacer(1));
if (opts?.timeout && opts.timeout > 0 && opts.tui) {
this.countdown = new CountdownTimer(
opts.timeout,
opts.tui,
(s) =>
this.titleText.setText(
theme.fg("accent", `${this.baseTitle} (${s}s)`),
),
() => this.onCancelCallback(),
);
}
this.input = new Input();
this.addChild(this.input);
this.addChild(new Spacer(1));
this.addChild(
new Text(
`${keyHint("selectConfirm", "submit")} ${keyHint("selectCancel", "cancel")}`,
1,
0,
),
);
this.addChild(new Spacer(1));
this.addChild(new DynamicBorder());
}
handleInput(keyData: string): void {
const kb = getEditorKeybindings();
if (kb.matches(keyData, "selectConfirm") || keyData === "\n") {
this.onSubmitCallback(this.input.getValue());
} else if (kb.matches(keyData, "selectCancel")) {
this.onCancelCallback();
} else {
this.input.handleInput(keyData);
}
}
dispose(): void {
this.countdown?.dispose();
}
}

View file

@ -0,0 +1,119 @@
/**
* Generic selector component for extensions.
* Displays a list of string options with keyboard navigation.
*/
import {
Container,
getEditorKeybindings,
Spacer,
Text,
type TUI,
} from "@mariozechner/pi-tui";
import { theme } from "../theme/theme.js";
import { CountdownTimer } from "./countdown-timer.js";
import { DynamicBorder } from "./dynamic-border.js";
import { keyHint, rawKeyHint } from "./keybinding-hints.js";
export interface ExtensionSelectorOptions {
tui?: TUI;
timeout?: number;
}
export class ExtensionSelectorComponent extends Container {
private options: string[];
private selectedIndex = 0;
private listContainer: Container;
private onSelectCallback: (option: string) => void;
private onCancelCallback: () => void;
private titleText: Text;
private baseTitle: string;
private countdown: CountdownTimer | undefined;
constructor(
title: string,
options: string[],
onSelect: (option: string) => void,
onCancel: () => void,
opts?: ExtensionSelectorOptions,
) {
super();
this.options = options;
this.onSelectCallback = onSelect;
this.onCancelCallback = onCancel;
this.baseTitle = title;
this.addChild(new DynamicBorder());
this.addChild(new Spacer(1));
this.titleText = new Text(theme.fg("accent", title), 1, 0);
this.addChild(this.titleText);
this.addChild(new Spacer(1));
if (opts?.timeout && opts.timeout > 0 && opts.tui) {
this.countdown = new CountdownTimer(
opts.timeout,
opts.tui,
(s) =>
this.titleText.setText(
theme.fg("accent", `${this.baseTitle} (${s}s)`),
),
() => this.onCancelCallback(),
);
}
this.listContainer = new Container();
this.addChild(this.listContainer);
this.addChild(new Spacer(1));
this.addChild(
new Text(
rawKeyHint("↑↓", "navigate") +
" " +
keyHint("selectConfirm", "select") +
" " +
keyHint("selectCancel", "cancel"),
1,
0,
),
);
this.addChild(new Spacer(1));
this.addChild(new DynamicBorder());
this.updateList();
}
private updateList(): void {
this.listContainer.clear();
for (let i = 0; i < this.options.length; i++) {
const isSelected = i === this.selectedIndex;
const text = isSelected
? theme.fg("accent", "→ ") + theme.fg("accent", this.options[i])
: ` ${theme.fg("text", this.options[i])}`;
this.listContainer.addChild(new Text(text, 1, 0));
}
}
handleInput(keyData: string): void {
const kb = getEditorKeybindings();
if (kb.matches(keyData, "selectUp") || keyData === "k") {
this.selectedIndex = Math.max(0, this.selectedIndex - 1);
this.updateList();
} else if (kb.matches(keyData, "selectDown") || keyData === "j") {
this.selectedIndex = Math.min(
this.options.length - 1,
this.selectedIndex + 1,
);
this.updateList();
} else if (kb.matches(keyData, "selectConfirm") || keyData === "\n") {
const selected = this.options[this.selectedIndex];
if (selected) this.onSelectCallback(selected);
} else if (kb.matches(keyData, "selectCancel")) {
this.onCancelCallback();
}
}
dispose(): void {
this.countdown?.dispose();
}
}

View file

@ -0,0 +1,236 @@
import {
type Component,
truncateToWidth,
visibleWidth,
} from "@mariozechner/pi-tui";
import type { AgentSession } from "../../../core/agent-session.js";
import type { ReadonlyFooterDataProvider } from "../../../core/footer-data-provider.js";
import { theme } from "../theme/theme.js";
/**
* Sanitize text for display in a single-line status.
* Removes newlines, tabs, carriage returns, and other control characters.
*/
function sanitizeStatusText(text: string): string {
// Replace newlines, tabs, carriage returns with space, then collapse multiple spaces
return text
.replace(/[\r\n\t]/g, " ")
.replace(/ +/g, " ")
.trim();
}
/**
* Format token counts (similar to web-ui)
*/
function formatTokens(count: number): string {
if (count < 1000) return count.toString();
if (count < 10000) return `${(count / 1000).toFixed(1)}k`;
if (count < 1000000) return `${Math.round(count / 1000)}k`;
if (count < 10000000) return `${(count / 1000000).toFixed(1)}M`;
return `${Math.round(count / 1000000)}M`;
}
/**
* Footer component that shows pwd, token stats, and context usage.
* Computes token/context stats from session, gets git branch and extension statuses from provider.
*/
export class FooterComponent implements Component {
private autoCompactEnabled = true;
constructor(
private session: AgentSession,
private footerData: ReadonlyFooterDataProvider,
) {}
setAutoCompactEnabled(enabled: boolean): void {
this.autoCompactEnabled = enabled;
}
/**
* No-op: git branch caching now handled by provider.
* Kept for compatibility with existing call sites in interactive-mode.
*/
invalidate(): void {
// No-op: git branch is cached/invalidated by provider
}
/**
* Clean up resources.
* Git watcher cleanup now handled by provider.
*/
dispose(): void {
// Git watcher cleanup handled by provider
}
render(width: number): string[] {
const state = this.session.state;
// Calculate cumulative usage from ALL session entries (not just post-compaction messages)
let totalInput = 0;
let totalOutput = 0;
let totalCacheRead = 0;
let totalCacheWrite = 0;
let totalCost = 0;
for (const entry of this.session.sessionManager.getEntries()) {
if (entry.type === "message" && entry.message.role === "assistant") {
totalInput += entry.message.usage.input;
totalOutput += entry.message.usage.output;
totalCacheRead += entry.message.usage.cacheRead;
totalCacheWrite += entry.message.usage.cacheWrite;
totalCost += entry.message.usage.cost.total;
}
}
// Calculate context usage from session (handles compaction correctly).
// After compaction, tokens are unknown until the next LLM response.
const contextUsage = this.session.getContextUsage();
const contextWindow =
contextUsage?.contextWindow ?? state.model?.contextWindow ?? 0;
const contextPercentValue = contextUsage?.percent ?? 0;
const contextPercent =
contextUsage?.percent !== null ? contextPercentValue.toFixed(1) : "?";
// Replace home directory with ~
let pwd = process.cwd();
const home = process.env.HOME || process.env.USERPROFILE;
if (home && pwd.startsWith(home)) {
pwd = `~${pwd.slice(home.length)}`;
}
// Add git branch if available
const branch = this.footerData.getGitBranch();
if (branch) {
pwd = `${pwd} (${branch})`;
}
// Add session name if set
const sessionName = this.session.sessionManager.getSessionName();
if (sessionName) {
pwd = `${pwd}${sessionName}`;
}
// Build stats line
const statsParts = [];
if (totalInput) statsParts.push(`${formatTokens(totalInput)}`);
if (totalOutput) statsParts.push(`${formatTokens(totalOutput)}`);
if (totalCacheRead) statsParts.push(`R${formatTokens(totalCacheRead)}`);
if (totalCacheWrite) statsParts.push(`W${formatTokens(totalCacheWrite)}`);
// Show cost with "(sub)" indicator if using OAuth subscription
const usingSubscription = state.model
? this.session.modelRegistry.isUsingOAuth(state.model)
: false;
if (totalCost || usingSubscription) {
const costStr = `$${totalCost.toFixed(3)}${usingSubscription ? " (sub)" : ""}`;
statsParts.push(costStr);
}
// Colorize context percentage based on usage
let contextPercentStr: string;
const autoIndicator = this.autoCompactEnabled ? " (auto)" : "";
const contextPercentDisplay =
contextPercent === "?"
? `?/${formatTokens(contextWindow)}${autoIndicator}`
: `${contextPercent}%/${formatTokens(contextWindow)}${autoIndicator}`;
if (contextPercentValue > 90) {
contextPercentStr = theme.fg("error", contextPercentDisplay);
} else if (contextPercentValue > 70) {
contextPercentStr = theme.fg("warning", contextPercentDisplay);
} else {
contextPercentStr = contextPercentDisplay;
}
statsParts.push(contextPercentStr);
let statsLeft = statsParts.join(" ");
// Add model name on the right side, plus thinking level if model supports it
const modelName = state.model?.id || "no-model";
let statsLeftWidth = visibleWidth(statsLeft);
// If statsLeft is too wide, truncate it
if (statsLeftWidth > width) {
statsLeft = truncateToWidth(statsLeft, width, "...");
statsLeftWidth = visibleWidth(statsLeft);
}
// Calculate available space for padding (minimum 2 spaces between stats and model)
const minPadding = 2;
// Add thinking level indicator if model supports reasoning
let rightSideWithoutProvider = modelName;
if (state.model?.reasoning) {
const thinkingLevel = state.thinkingLevel || "off";
rightSideWithoutProvider =
thinkingLevel === "off"
? `${modelName} • thinking off`
: `${modelName}${thinkingLevel}`;
}
// Prepend the provider in parentheses if there are multiple providers and there's enough room
let rightSide = rightSideWithoutProvider;
if (this.footerData.getAvailableProviderCount() > 1 && state.model) {
rightSide = `(${state.model!.provider}) ${rightSideWithoutProvider}`;
if (statsLeftWidth + minPadding + visibleWidth(rightSide) > width) {
// Too wide, fall back
rightSide = rightSideWithoutProvider;
}
}
const rightSideWidth = visibleWidth(rightSide);
const totalNeeded = statsLeftWidth + minPadding + rightSideWidth;
let statsLine: string;
if (totalNeeded <= width) {
// Both fit - add padding to right-align model
const padding = " ".repeat(width - statsLeftWidth - rightSideWidth);
statsLine = statsLeft + padding + rightSide;
} else {
// Need to truncate right side
const availableForRight = width - statsLeftWidth - minPadding;
if (availableForRight > 0) {
const truncatedRight = truncateToWidth(
rightSide,
availableForRight,
"",
);
const truncatedRightWidth = visibleWidth(truncatedRight);
const padding = " ".repeat(
Math.max(0, width - statsLeftWidth - truncatedRightWidth),
);
statsLine = statsLeft + padding + truncatedRight;
} else {
// Not enough space for right side at all
statsLine = statsLeft;
}
}
// Apply dim to each part separately. statsLeft may contain color codes (for context %)
// that end with a reset, which would clear an outer dim wrapper. So we dim the parts
// before and after the colored section independently.
const dimStatsLeft = theme.fg("dim", statsLeft);
const remainder = statsLine.slice(statsLeft.length); // padding + rightSide
const dimRemainder = theme.fg("dim", remainder);
const pwdLine = truncateToWidth(
theme.fg("dim", pwd),
width,
theme.fg("dim", "..."),
);
const lines = [pwdLine, dimStatsLeft + dimRemainder];
// Add extension statuses on a single line, sorted by key alphabetically
const extensionStatuses = this.footerData.getExtensionStatuses();
if (extensionStatuses.size > 0) {
const sortedStatuses = Array.from(extensionStatuses.entries())
.sort(([a], [b]) => a.localeCompare(b))
.map(([, text]) => sanitizeStatusText(text));
const statusLine = sortedStatuses.join(" ");
// Truncate to terminal width with dim ellipsis for consistency with footer style
lines.push(truncateToWidth(statusLine, width, theme.fg("dim", "...")));
}
return lines;
}
}

View file

@ -0,0 +1,52 @@
// UI Components for extensions
export { ArminComponent } from "./armin.js";
export { AssistantMessageComponent } from "./assistant-message.js";
export { BashExecutionComponent } from "./bash-execution.js";
export { BorderedLoader } from "./bordered-loader.js";
export { BranchSummaryMessageComponent } from "./branch-summary-message.js";
export { CompactionSummaryMessageComponent } from "./compaction-summary-message.js";
export { CustomEditor } from "./custom-editor.js";
export { CustomMessageComponent } from "./custom-message.js";
export { DaxnutsComponent } from "./daxnuts.js";
export { type RenderDiffOptions, renderDiff } from "./diff.js";
export { DynamicBorder } from "./dynamic-border.js";
export { ExtensionEditorComponent } from "./extension-editor.js";
export { ExtensionInputComponent } from "./extension-input.js";
export { ExtensionSelectorComponent } from "./extension-selector.js";
export { FooterComponent } from "./footer.js";
export {
appKey,
appKeyHint,
editorKey,
keyHint,
rawKeyHint,
} from "./keybinding-hints.js";
export { LoginDialogComponent } from "./login-dialog.js";
export { ModelSelectorComponent } from "./model-selector.js";
export { OAuthSelectorComponent } from "./oauth-selector.js";
export {
type ModelsCallbacks,
type ModelsConfig,
ScopedModelsSelectorComponent,
} from "./scoped-models-selector.js";
export { SessionSelectorComponent } from "./session-selector.js";
export {
type SettingsCallbacks,
type SettingsConfig,
SettingsSelectorComponent,
} from "./settings-selector.js";
export { ShowImagesSelectorComponent } from "./show-images-selector.js";
export { SkillInvocationMessageComponent } from "./skill-invocation-message.js";
export { ThemeSelectorComponent } from "./theme-selector.js";
export { ThinkingSelectorComponent } from "./thinking-selector.js";
export {
ToolExecutionComponent,
type ToolExecutionOptions,
} from "./tool-execution.js";
export { TreeSelectorComponent } from "./tree-selector.js";
export { UserMessageComponent } from "./user-message.js";
export { UserMessageSelectorComponent } from "./user-message-selector.js";
export {
truncateToVisualLines,
type VisualTruncateResult,
} from "./visual-truncate.js";

View file

@ -0,0 +1,85 @@
/**
* Utilities for formatting keybinding hints in the UI.
*/
import {
type EditorAction,
getEditorKeybindings,
type KeyId,
} from "@mariozechner/pi-tui";
import type {
AppAction,
KeybindingsManager,
} from "../../../core/keybindings.js";
import { theme } from "../theme/theme.js";
/**
* Format keys array as display string (e.g., ["ctrl+c", "escape"] -> "ctrl+c/escape").
*/
function formatKeys(keys: KeyId[]): string {
if (keys.length === 0) return "";
if (keys.length === 1) return keys[0]!;
return keys.join("/");
}
/**
* Get display string for an editor action.
*/
export function editorKey(action: EditorAction): string {
return formatKeys(getEditorKeybindings().getKeys(action));
}
/**
* Get display string for an app action.
*/
export function appKey(
keybindings: KeybindingsManager,
action: AppAction,
): string {
return formatKeys(keybindings.getKeys(action));
}
/**
* Format a keybinding hint with consistent styling: dim key, muted description.
* Looks up the key from editor keybindings automatically.
*
* @param action - Editor action name (e.g., "selectConfirm", "expandTools")
* @param description - Description text (e.g., "to expand", "cancel")
* @returns Formatted string with dim key and muted description
*/
export function keyHint(action: EditorAction, description: string): string {
return (
theme.fg("dim", editorKey(action)) + theme.fg("muted", ` ${description}`)
);
}
/**
* Format a keybinding hint for app-level actions.
* Requires the KeybindingsManager instance.
*
* @param keybindings - KeybindingsManager instance
* @param action - App action name (e.g., "interrupt", "externalEditor")
* @param description - Description text
* @returns Formatted string with dim key and muted description
*/
export function appKeyHint(
keybindings: KeybindingsManager,
action: AppAction,
description: string,
): string {
return (
theme.fg("dim", appKey(keybindings, action)) +
theme.fg("muted", ` ${description}`)
);
}
/**
* Format a raw key string with description (for non-configurable keys like ).
*
* @param key - Raw key string
* @param description - Description text
* @returns Formatted string with dim key and muted description
*/
export function rawKeyHint(key: string, description: string): string {
return theme.fg("dim", key) + theme.fg("muted", ` ${description}`);
}

View file

@ -0,0 +1,204 @@
import { getOAuthProviders } from "@mariozechner/pi-ai/oauth";
import {
Container,
type Focusable,
getEditorKeybindings,
Input,
Spacer,
Text,
type TUI,
} from "@mariozechner/pi-tui";
import { exec } from "child_process";
import { theme } from "../theme/theme.js";
import { DynamicBorder } from "./dynamic-border.js";
import { keyHint } from "./keybinding-hints.js";
/**
* Login dialog component - replaces editor during OAuth login flow
*/
export class LoginDialogComponent extends Container implements Focusable {
private contentContainer: Container;
private input: Input;
private tui: TUI;
private abortController = new AbortController();
private inputResolver?: (value: string) => void;
private inputRejecter?: (error: Error) => void;
// Focusable implementation - propagate to input for IME cursor positioning
private _focused = false;
get focused(): boolean {
return this._focused;
}
set focused(value: boolean) {
this._focused = value;
this.input.focused = value;
}
constructor(
tui: TUI,
providerId: string,
private onComplete: (success: boolean, message?: string) => void,
) {
super();
this.tui = tui;
const providerInfo = getOAuthProviders().find((p) => p.id === providerId);
const providerName = providerInfo?.name || providerId;
// Top border
this.addChild(new DynamicBorder());
// Title
this.addChild(
new Text(theme.fg("warning", `Login to ${providerName}`), 1, 0),
);
// Dynamic content area
this.contentContainer = new Container();
this.addChild(this.contentContainer);
// Input (always present, used when needed)
this.input = new Input();
this.input.onSubmit = () => {
if (this.inputResolver) {
this.inputResolver(this.input.getValue());
this.inputResolver = undefined;
this.inputRejecter = undefined;
}
};
this.input.onEscape = () => {
this.cancel();
};
// Bottom border
this.addChild(new DynamicBorder());
}
get signal(): AbortSignal {
return this.abortController.signal;
}
private cancel(): void {
this.abortController.abort();
if (this.inputRejecter) {
this.inputRejecter(new Error("Login cancelled"));
this.inputResolver = undefined;
this.inputRejecter = undefined;
}
this.onComplete(false, "Login cancelled");
}
/**
* Called by onAuth callback - show URL and optional instructions
*/
showAuth(url: string, instructions?: string): void {
this.contentContainer.clear();
this.contentContainer.addChild(new Spacer(1));
this.contentContainer.addChild(new Text(theme.fg("accent", url), 1, 0));
const clickHint =
process.platform === "darwin"
? "Cmd+click to open"
: "Ctrl+click to open";
const hyperlink = `\x1b]8;;${url}\x07${clickHint}\x1b]8;;\x07`;
this.contentContainer.addChild(new Text(theme.fg("dim", hyperlink), 1, 0));
if (instructions) {
this.contentContainer.addChild(new Spacer(1));
this.contentContainer.addChild(
new Text(theme.fg("warning", instructions), 1, 0),
);
}
// Try to open browser
const openCmd =
process.platform === "darwin"
? "open"
: process.platform === "win32"
? "start"
: "xdg-open";
exec(`${openCmd} "${url}"`);
this.tui.requestRender();
}
/**
* Show input for manual code/URL entry (for callback server providers)
*/
showManualInput(prompt: string): Promise<string> {
this.contentContainer.addChild(new Spacer(1));
this.contentContainer.addChild(new Text(theme.fg("dim", prompt), 1, 0));
this.contentContainer.addChild(this.input);
this.contentContainer.addChild(
new Text(`(${keyHint("selectCancel", "to cancel")})`, 1, 0),
);
this.tui.requestRender();
return new Promise((resolve, reject) => {
this.inputResolver = resolve;
this.inputRejecter = reject;
});
}
/**
* Called by onPrompt callback - show prompt and wait for input
* Note: Does NOT clear content, appends to existing (preserves URL from showAuth)
*/
showPrompt(message: string, placeholder?: string): Promise<string> {
this.contentContainer.addChild(new Spacer(1));
this.contentContainer.addChild(new Text(theme.fg("text", message), 1, 0));
if (placeholder) {
this.contentContainer.addChild(
new Text(theme.fg("dim", `e.g., ${placeholder}`), 1, 0),
);
}
this.contentContainer.addChild(this.input);
this.contentContainer.addChild(
new Text(
`(${keyHint("selectCancel", "to cancel,")} ${keyHint("selectConfirm", "to submit")})`,
1,
0,
),
);
this.input.setValue("");
this.tui.requestRender();
return new Promise((resolve, reject) => {
this.inputResolver = resolve;
this.inputRejecter = reject;
});
}
/**
* Show waiting message (for polling flows like GitHub Copilot)
*/
showWaiting(message: string): void {
this.contentContainer.addChild(new Spacer(1));
this.contentContainer.addChild(new Text(theme.fg("dim", message), 1, 0));
this.contentContainer.addChild(
new Text(`(${keyHint("selectCancel", "to cancel")})`, 1, 0),
);
this.tui.requestRender();
}
/**
* Called by onProgress callback
*/
showProgress(message: string): void {
this.contentContainer.addChild(new Text(theme.fg("dim", message), 1, 0));
this.tui.requestRender();
}
handleInput(data: string): void {
const kb = getEditorKeybindings();
if (kb.matches(data, "selectCancel")) {
this.cancel();
return;
}
// Pass to input
this.input.handleInput(data);
}
}

View file

@ -0,0 +1,372 @@
import { type Model, modelsAreEqual } from "@mariozechner/pi-ai";
import {
Container,
type Focusable,
fuzzyFilter,
getEditorKeybindings,
Input,
Spacer,
Text,
type TUI,
} from "@mariozechner/pi-tui";
import type { ModelRegistry } from "../../../core/model-registry.js";
import type { SettingsManager } from "../../../core/settings-manager.js";
import { theme } from "../theme/theme.js";
import { DynamicBorder } from "./dynamic-border.js";
import { keyHint } from "./keybinding-hints.js";
interface ModelItem {
provider: string;
id: string;
model: Model<any>;
}
interface ScopedModelItem {
model: Model<any>;
thinkingLevel?: string;
}
type ModelScope = "all" | "scoped";
/**
* Component that renders a model selector with search
*/
export class ModelSelectorComponent extends Container implements Focusable {
private searchInput: Input;
// Focusable implementation - propagate to searchInput for IME cursor positioning
private _focused = false;
get focused(): boolean {
return this._focused;
}
set focused(value: boolean) {
this._focused = value;
this.searchInput.focused = value;
}
private listContainer: Container;
private allModels: ModelItem[] = [];
private scopedModelItems: ModelItem[] = [];
private activeModels: ModelItem[] = [];
private filteredModels: ModelItem[] = [];
private selectedIndex: number = 0;
private currentModel?: Model<any>;
private settingsManager: SettingsManager;
private modelRegistry: ModelRegistry;
private onSelectCallback: (model: Model<any>) => void;
private onCancelCallback: () => void;
private errorMessage?: string;
private tui: TUI;
private scopedModels: ReadonlyArray<ScopedModelItem>;
private scope: ModelScope = "all";
private scopeText?: Text;
private scopeHintText?: Text;
constructor(
tui: TUI,
currentModel: Model<any> | undefined,
settingsManager: SettingsManager,
modelRegistry: ModelRegistry,
scopedModels: ReadonlyArray<ScopedModelItem>,
onSelect: (model: Model<any>) => void,
onCancel: () => void,
initialSearchInput?: string,
) {
super();
this.tui = tui;
this.currentModel = currentModel;
this.settingsManager = settingsManager;
this.modelRegistry = modelRegistry;
this.scopedModels = scopedModels;
this.scope = scopedModels.length > 0 ? "scoped" : "all";
this.onSelectCallback = onSelect;
this.onCancelCallback = onCancel;
// Add top border
this.addChild(new DynamicBorder());
this.addChild(new Spacer(1));
// Add hint about model filtering
if (scopedModels.length > 0) {
this.scopeText = new Text(this.getScopeText(), 0, 0);
this.addChild(this.scopeText);
this.scopeHintText = new Text(this.getScopeHintText(), 0, 0);
this.addChild(this.scopeHintText);
} else {
const hintText =
"Only showing models with configured API keys (see README for details)";
this.addChild(new Text(theme.fg("warning", hintText), 0, 0));
}
this.addChild(new Spacer(1));
// Create search input
this.searchInput = new Input();
if (initialSearchInput) {
this.searchInput.setValue(initialSearchInput);
}
this.searchInput.onSubmit = () => {
// Enter on search input selects the first filtered item
if (this.filteredModels[this.selectedIndex]) {
this.handleSelect(this.filteredModels[this.selectedIndex].model);
}
};
this.addChild(this.searchInput);
this.addChild(new Spacer(1));
// Create list container
this.listContainer = new Container();
this.addChild(this.listContainer);
this.addChild(new Spacer(1));
// Add bottom border
this.addChild(new DynamicBorder());
// Load models and do initial render
this.loadModels().then(() => {
if (initialSearchInput) {
this.filterModels(initialSearchInput);
} else {
this.updateList();
}
// Request re-render after models are loaded
this.tui.requestRender();
});
}
private async loadModels(): Promise<void> {
let models: ModelItem[];
// Refresh to pick up any changes to models.json
this.modelRegistry.refresh();
// Check for models.json errors
const loadError = this.modelRegistry.getError();
if (loadError) {
this.errorMessage = loadError;
}
// Load available models (built-in models still work even if models.json failed)
try {
const availableModels = await this.modelRegistry.getAvailable();
models = availableModels.map((model: Model<any>) => ({
provider: model.provider,
id: model.id,
model,
}));
} catch (error) {
this.allModels = [];
this.scopedModelItems = [];
this.activeModels = [];
this.filteredModels = [];
this.errorMessage =
error instanceof Error ? error.message : String(error);
return;
}
this.allModels = this.sortModels(models);
this.scopedModelItems = this.sortModels(
this.scopedModels.map((scoped) => ({
provider: scoped.model.provider,
id: scoped.model.id,
model: scoped.model,
})),
);
this.activeModels =
this.scope === "scoped" ? this.scopedModelItems : this.allModels;
this.filteredModels = this.activeModels;
this.selectedIndex = Math.min(
this.selectedIndex,
Math.max(0, this.filteredModels.length - 1),
);
}
private sortModels(models: ModelItem[]): ModelItem[] {
const sorted = [...models];
// Sort: current model first, then by provider
sorted.sort((a, b) => {
const aIsCurrent = modelsAreEqual(this.currentModel, a.model);
const bIsCurrent = modelsAreEqual(this.currentModel, b.model);
if (aIsCurrent && !bIsCurrent) return -1;
if (!aIsCurrent && bIsCurrent) return 1;
return a.provider.localeCompare(b.provider);
});
return sorted;
}
private getScopeText(): string {
const allText =
this.scope === "all"
? theme.fg("accent", "all")
: theme.fg("muted", "all");
const scopedText =
this.scope === "scoped"
? theme.fg("accent", "scoped")
: theme.fg("muted", "scoped");
return `${theme.fg("muted", "Scope: ")}${allText}${theme.fg("muted", " | ")}${scopedText}`;
}
private getScopeHintText(): string {
return keyHint("tab", "scope") + theme.fg("muted", " (all/scoped)");
}
private setScope(scope: ModelScope): void {
if (this.scope === scope) return;
this.scope = scope;
this.activeModels =
this.scope === "scoped" ? this.scopedModelItems : this.allModels;
this.selectedIndex = 0;
this.filterModels(this.searchInput.getValue());
if (this.scopeText) {
this.scopeText.setText(this.getScopeText());
}
}
private filterModels(query: string): void {
this.filteredModels = query
? fuzzyFilter(
this.activeModels,
query,
({ id, provider }) => `${id} ${provider}`,
)
: this.activeModels;
this.selectedIndex = Math.min(
this.selectedIndex,
Math.max(0, this.filteredModels.length - 1),
);
this.updateList();
}
private updateList(): void {
this.listContainer.clear();
const maxVisible = 10;
const startIndex = Math.max(
0,
Math.min(
this.selectedIndex - Math.floor(maxVisible / 2),
this.filteredModels.length - maxVisible,
),
);
const endIndex = Math.min(
startIndex + maxVisible,
this.filteredModels.length,
);
// Show visible slice of filtered models
for (let i = startIndex; i < endIndex; i++) {
const item = this.filteredModels[i];
if (!item) continue;
const isSelected = i === this.selectedIndex;
const isCurrent = modelsAreEqual(this.currentModel, item.model);
let line = "";
if (isSelected) {
const prefix = theme.fg("accent", "→ ");
const modelText = `${item.id}`;
const providerBadge = theme.fg("muted", `[${item.provider}]`);
const checkmark = isCurrent ? theme.fg("success", " ✓") : "";
line = `${prefix + theme.fg("accent", modelText)} ${providerBadge}${checkmark}`;
} else {
const modelText = ` ${item.id}`;
const providerBadge = theme.fg("muted", `[${item.provider}]`);
const checkmark = isCurrent ? theme.fg("success", " ✓") : "";
line = `${modelText} ${providerBadge}${checkmark}`;
}
this.listContainer.addChild(new Text(line, 0, 0));
}
// Add scroll indicator if needed
if (startIndex > 0 || endIndex < this.filteredModels.length) {
const scrollInfo = theme.fg(
"muted",
` (${this.selectedIndex + 1}/${this.filteredModels.length})`,
);
this.listContainer.addChild(new Text(scrollInfo, 0, 0));
}
// Show error message or "no results" if empty
if (this.errorMessage) {
// Show error in red
const errorLines = this.errorMessage.split("\n");
for (const line of errorLines) {
this.listContainer.addChild(new Text(theme.fg("error", line), 0, 0));
}
} else if (this.filteredModels.length === 0) {
this.listContainer.addChild(
new Text(theme.fg("muted", " No matching models"), 0, 0),
);
} else {
const selected = this.filteredModels[this.selectedIndex];
this.listContainer.addChild(new Spacer(1));
this.listContainer.addChild(
new Text(
theme.fg("muted", ` Model Name: ${selected.model.name}`),
0,
0,
),
);
}
}
handleInput(keyData: string): void {
const kb = getEditorKeybindings();
if (kb.matches(keyData, "tab")) {
if (this.scopedModelItems.length > 0) {
const nextScope: ModelScope = this.scope === "all" ? "scoped" : "all";
this.setScope(nextScope);
if (this.scopeHintText) {
this.scopeHintText.setText(this.getScopeHintText());
}
}
return;
}
// Up arrow - wrap to bottom when at top
if (kb.matches(keyData, "selectUp")) {
if (this.filteredModels.length === 0) return;
this.selectedIndex =
this.selectedIndex === 0
? this.filteredModels.length - 1
: this.selectedIndex - 1;
this.updateList();
}
// Down arrow - wrap to top when at bottom
else if (kb.matches(keyData, "selectDown")) {
if (this.filteredModels.length === 0) return;
this.selectedIndex =
this.selectedIndex === this.filteredModels.length - 1
? 0
: this.selectedIndex + 1;
this.updateList();
}
// Enter
else if (kb.matches(keyData, "selectConfirm")) {
const selectedModel = this.filteredModels[this.selectedIndex];
if (selectedModel) {
this.handleSelect(selectedModel.model);
}
}
// Escape or Ctrl+C
else if (kb.matches(keyData, "selectCancel")) {
this.onCancelCallback();
}
// Pass everything else to search input
else {
this.searchInput.handleInput(keyData);
this.filterModels(this.searchInput.getValue());
}
}
private handleSelect(model: Model<any>): void {
// Save as new default
this.settingsManager.setDefaultModelAndProvider(model.provider, model.id);
this.onSelectCallback(model);
}
getSearchInput(): Input {
return this.searchInput;
}
}

View file

@ -0,0 +1,138 @@
import type { OAuthProviderInterface } from "@mariozechner/pi-ai";
import { getOAuthProviders } from "@mariozechner/pi-ai/oauth";
import {
Container,
getEditorKeybindings,
Spacer,
TruncatedText,
} from "@mariozechner/pi-tui";
import type { AuthStorage } from "../../../core/auth-storage.js";
import { theme } from "../theme/theme.js";
import { DynamicBorder } from "./dynamic-border.js";
/**
* Component that renders an OAuth provider selector
*/
export class OAuthSelectorComponent extends Container {
private listContainer: Container;
private allProviders: OAuthProviderInterface[] = [];
private selectedIndex: number = 0;
private mode: "login" | "logout";
private authStorage: AuthStorage;
private onSelectCallback: (providerId: string) => void;
private onCancelCallback: () => void;
constructor(
mode: "login" | "logout",
authStorage: AuthStorage,
onSelect: (providerId: string) => void,
onCancel: () => void,
) {
super();
this.mode = mode;
this.authStorage = authStorage;
this.onSelectCallback = onSelect;
this.onCancelCallback = onCancel;
// Load all OAuth providers
this.loadProviders();
// Add top border
this.addChild(new DynamicBorder());
this.addChild(new Spacer(1));
// Add title
const title =
mode === "login"
? "Select provider to login:"
: "Select provider to logout:";
this.addChild(new TruncatedText(theme.bold(title)));
this.addChild(new Spacer(1));
// Create list container
this.listContainer = new Container();
this.addChild(this.listContainer);
this.addChild(new Spacer(1));
// Add bottom border
this.addChild(new DynamicBorder());
// Initial render
this.updateList();
}
private loadProviders(): void {
this.allProviders = getOAuthProviders();
}
private updateList(): void {
this.listContainer.clear();
for (let i = 0; i < this.allProviders.length; i++) {
const provider = this.allProviders[i];
if (!provider) continue;
const isSelected = i === this.selectedIndex;
// Check if user is logged in for this provider
const credentials = this.authStorage.get(provider.id);
const isLoggedIn = credentials?.type === "oauth";
const statusIndicator = isLoggedIn
? theme.fg("success", " ✓ logged in")
: "";
let line = "";
if (isSelected) {
const prefix = theme.fg("accent", "→ ");
const text = theme.fg("accent", provider.name);
line = prefix + text + statusIndicator;
} else {
const text = ` ${provider.name}`;
line = text + statusIndicator;
}
this.listContainer.addChild(new TruncatedText(line, 0, 0));
}
// Show "no providers" if empty
if (this.allProviders.length === 0) {
const message =
this.mode === "login"
? "No OAuth providers available"
: "No OAuth providers logged in. Use /login first.";
this.listContainer.addChild(
new TruncatedText(theme.fg("muted", ` ${message}`), 0, 0),
);
}
}
handleInput(keyData: string): void {
const kb = getEditorKeybindings();
// Up arrow
if (kb.matches(keyData, "selectUp")) {
this.selectedIndex = Math.max(0, this.selectedIndex - 1);
this.updateList();
}
// Down arrow
else if (kb.matches(keyData, "selectDown")) {
this.selectedIndex = Math.min(
this.allProviders.length - 1,
this.selectedIndex + 1,
);
this.updateList();
}
// Enter
else if (kb.matches(keyData, "selectConfirm")) {
const selectedProvider = this.allProviders[this.selectedIndex];
if (selectedProvider) {
this.onSelectCallback(selectedProvider.id);
}
}
// Escape or Ctrl+C
else if (kb.matches(keyData, "selectCancel")) {
this.onCancelCallback();
}
}
}

View file

@ -0,0 +1,444 @@
import type { Model } from "@mariozechner/pi-ai";
import {
Container,
type Focusable,
fuzzyFilter,
getEditorKeybindings,
Input,
Key,
matchesKey,
Spacer,
Text,
} from "@mariozechner/pi-tui";
import { theme } from "../theme/theme.js";
import { DynamicBorder } from "./dynamic-border.js";
// EnabledIds: null = all enabled (no filter), string[] = explicit ordered list
type EnabledIds = string[] | null;
function isEnabled(enabledIds: EnabledIds, id: string): boolean {
return enabledIds === null || enabledIds.includes(id);
}
function toggle(enabledIds: EnabledIds, id: string): EnabledIds {
if (enabledIds === null) return [id]; // First toggle: start with only this one
const index = enabledIds.indexOf(id);
if (index >= 0)
return [...enabledIds.slice(0, index), ...enabledIds.slice(index + 1)];
return [...enabledIds, id];
}
function enableAll(
enabledIds: EnabledIds,
allIds: string[],
targetIds?: string[],
): EnabledIds {
if (enabledIds === null) return null; // Already all enabled
const targets = targetIds ?? allIds;
const result = [...enabledIds];
for (const id of targets) {
if (!result.includes(id)) result.push(id);
}
return result.length === allIds.length ? null : result;
}
function clearAll(
enabledIds: EnabledIds,
allIds: string[],
targetIds?: string[],
): EnabledIds {
if (enabledIds === null) {
return targetIds ? allIds.filter((id) => !targetIds.includes(id)) : [];
}
const targets = new Set(targetIds ?? enabledIds);
return enabledIds.filter((id) => !targets.has(id));
}
function move(
enabledIds: EnabledIds,
allIds: string[],
id: string,
delta: number,
): EnabledIds {
const list = enabledIds ?? [...allIds];
const index = list.indexOf(id);
if (index < 0) return list;
const newIndex = index + delta;
if (newIndex < 0 || newIndex >= list.length) return list;
const result = [...list];
[result[index], result[newIndex]] = [result[newIndex], result[index]];
return result;
}
function getSortedIds(enabledIds: EnabledIds, allIds: string[]): string[] {
if (enabledIds === null) return allIds;
const enabledSet = new Set(enabledIds);
return [...enabledIds, ...allIds.filter((id) => !enabledSet.has(id))];
}
interface ModelItem {
fullId: string;
model: Model<any>;
enabled: boolean;
}
export interface ModelsConfig {
allModels: Model<any>[];
enabledModelIds: Set<string>;
/** true if enabledModels setting is defined (empty = all enabled) */
hasEnabledModelsFilter: boolean;
}
export interface ModelsCallbacks {
/** Called when a model is toggled (session-only, no persist) */
onModelToggle: (modelId: string, enabled: boolean) => void;
/** Called when user wants to persist current selection to settings */
onPersist: (enabledModelIds: string[]) => void;
/** Called when user enables all models. Returns list of all model IDs. */
onEnableAll: (allModelIds: string[]) => void;
/** Called when user clears all models */
onClearAll: () => void;
/** Called when user toggles all models for a provider. Returns affected model IDs. */
onToggleProvider: (
provider: string,
modelIds: string[],
enabled: boolean,
) => void;
onCancel: () => void;
}
/**
* Component for enabling/disabling models for Ctrl+P cycling.
* Changes are session-only until explicitly persisted with Ctrl+S.
*/
export class ScopedModelsSelectorComponent
extends Container
implements Focusable
{
private modelsById: Map<string, Model<any>> = new Map();
private allIds: string[] = [];
private enabledIds: EnabledIds = null;
private filteredItems: ModelItem[] = [];
private selectedIndex = 0;
private searchInput: Input;
// Focusable implementation - propagate to searchInput for IME cursor positioning
private _focused = false;
get focused(): boolean {
return this._focused;
}
set focused(value: boolean) {
this._focused = value;
this.searchInput.focused = value;
}
private listContainer: Container;
private footerText: Text;
private callbacks: ModelsCallbacks;
private maxVisible = 15;
private isDirty = false;
constructor(config: ModelsConfig, callbacks: ModelsCallbacks) {
super();
this.callbacks = callbacks;
for (const model of config.allModels) {
const fullId = `${model.provider}/${model.id}`;
this.modelsById.set(fullId, model);
this.allIds.push(fullId);
}
this.enabledIds = config.hasEnabledModelsFilter
? [...config.enabledModelIds]
: null;
this.filteredItems = this.buildItems();
// Header
this.addChild(new DynamicBorder());
this.addChild(new Spacer(1));
this.addChild(
new Text(theme.fg("accent", theme.bold("Model Configuration")), 0, 0),
);
this.addChild(
new Text(
theme.fg("muted", "Session-only. Ctrl+S to save to settings."),
0,
0,
),
);
this.addChild(new Spacer(1));
// Search input
this.searchInput = new Input();
this.addChild(this.searchInput);
this.addChild(new Spacer(1));
// List container
this.listContainer = new Container();
this.addChild(this.listContainer);
// Footer hint
this.addChild(new Spacer(1));
this.footerText = new Text(this.getFooterText(), 0, 0);
this.addChild(this.footerText);
this.addChild(new DynamicBorder());
this.updateList();
}
private buildItems(): ModelItem[] {
// Filter out IDs that no longer have a corresponding model (e.g., after logout)
return getSortedIds(this.enabledIds, this.allIds)
.filter((id) => this.modelsById.has(id))
.map((id) => ({
fullId: id,
model: this.modelsById.get(id)!,
enabled: isEnabled(this.enabledIds, id),
}));
}
private getFooterText(): string {
const enabledCount = this.enabledIds?.length ?? this.allIds.length;
const allEnabled = this.enabledIds === null;
const countText = allEnabled
? "all enabled"
: `${enabledCount}/${this.allIds.length} enabled`;
const parts = [
"Enter toggle",
"^A all",
"^X clear",
"^P provider",
"Alt+↑↓ reorder",
"^S save",
countText,
];
return this.isDirty
? theme.fg("dim", ` ${parts.join(" · ")} `) +
theme.fg("warning", "(unsaved)")
: theme.fg("dim", ` ${parts.join(" · ")}`);
}
private refresh(): void {
const query = this.searchInput.getValue();
const items = this.buildItems();
this.filteredItems = query
? fuzzyFilter(items, query, (i) => `${i.model.id} ${i.model.provider}`)
: items;
this.selectedIndex = Math.min(
this.selectedIndex,
Math.max(0, this.filteredItems.length - 1),
);
this.updateList();
this.footerText.setText(this.getFooterText());
}
private updateList(): void {
this.listContainer.clear();
if (this.filteredItems.length === 0) {
this.listContainer.addChild(
new Text(theme.fg("muted", " No matching models"), 0, 0),
);
return;
}
const startIndex = Math.max(
0,
Math.min(
this.selectedIndex - Math.floor(this.maxVisible / 2),
this.filteredItems.length - this.maxVisible,
),
);
const endIndex = Math.min(
startIndex + this.maxVisible,
this.filteredItems.length,
);
const allEnabled = this.enabledIds === null;
for (let i = startIndex; i < endIndex; i++) {
const item = this.filteredItems[i]!;
const isSelected = i === this.selectedIndex;
const prefix = isSelected ? theme.fg("accent", "→ ") : " ";
const modelText = isSelected
? theme.fg("accent", item.model.id)
: item.model.id;
const providerBadge = theme.fg("muted", ` [${item.model.provider}]`);
const status = allEnabled
? ""
: item.enabled
? theme.fg("success", " ✓")
: theme.fg("dim", " ✗");
this.listContainer.addChild(
new Text(`${prefix}${modelText}${providerBadge}${status}`, 0, 0),
);
}
// Add scroll indicator if needed
if (startIndex > 0 || endIndex < this.filteredItems.length) {
this.listContainer.addChild(
new Text(
theme.fg(
"muted",
` (${this.selectedIndex + 1}/${this.filteredItems.length})`,
),
0,
0,
),
);
}
if (this.filteredItems.length > 0) {
const selected = this.filteredItems[this.selectedIndex];
this.listContainer.addChild(new Spacer(1));
this.listContainer.addChild(
new Text(
theme.fg("muted", ` Model Name: ${selected.model.name}`),
0,
0,
),
);
}
}
handleInput(data: string): void {
const kb = getEditorKeybindings();
// Navigation
if (kb.matches(data, "selectUp")) {
if (this.filteredItems.length === 0) return;
this.selectedIndex =
this.selectedIndex === 0
? this.filteredItems.length - 1
: this.selectedIndex - 1;
this.updateList();
return;
}
if (kb.matches(data, "selectDown")) {
if (this.filteredItems.length === 0) return;
this.selectedIndex =
this.selectedIndex === this.filteredItems.length - 1
? 0
: this.selectedIndex + 1;
this.updateList();
return;
}
// Alt+Up/Down - Reorder enabled models
if (matchesKey(data, Key.alt("up")) || matchesKey(data, Key.alt("down"))) {
const item = this.filteredItems[this.selectedIndex];
if (item && isEnabled(this.enabledIds, item.fullId)) {
const delta = matchesKey(data, Key.alt("up")) ? -1 : 1;
const enabledList = this.enabledIds ?? this.allIds;
const currentIndex = enabledList.indexOf(item.fullId);
const newIndex = currentIndex + delta;
// Only move if within bounds
if (newIndex >= 0 && newIndex < enabledList.length) {
this.enabledIds = move(
this.enabledIds,
this.allIds,
item.fullId,
delta,
);
this.isDirty = true;
this.selectedIndex += delta;
this.refresh();
}
}
return;
}
// Toggle on Enter
if (matchesKey(data, Key.enter)) {
const item = this.filteredItems[this.selectedIndex];
if (item) {
const wasAllEnabled = this.enabledIds === null;
this.enabledIds = toggle(this.enabledIds, item.fullId);
this.isDirty = true;
if (wasAllEnabled) this.callbacks.onClearAll();
this.callbacks.onModelToggle(
item.fullId,
isEnabled(this.enabledIds, item.fullId),
);
this.refresh();
}
return;
}
// Ctrl+A - Enable all (filtered if search active, otherwise all)
if (matchesKey(data, Key.ctrl("a"))) {
const targetIds = this.searchInput.getValue()
? this.filteredItems.map((i) => i.fullId)
: undefined;
this.enabledIds = enableAll(this.enabledIds, this.allIds, targetIds);
this.isDirty = true;
this.callbacks.onEnableAll(targetIds ?? this.allIds);
this.refresh();
return;
}
// Ctrl+X - Clear all (filtered if search active, otherwise all)
if (matchesKey(data, Key.ctrl("x"))) {
const targetIds = this.searchInput.getValue()
? this.filteredItems.map((i) => i.fullId)
: undefined;
this.enabledIds = clearAll(this.enabledIds, this.allIds, targetIds);
this.isDirty = true;
this.callbacks.onClearAll();
this.refresh();
return;
}
// Ctrl+P - Toggle provider of current item
if (matchesKey(data, Key.ctrl("p"))) {
const item = this.filteredItems[this.selectedIndex];
if (item) {
const provider = item.model.provider;
const providerIds = this.allIds.filter(
(id) => this.modelsById.get(id)!.provider === provider,
);
const allEnabled = providerIds.every((id) =>
isEnabled(this.enabledIds, id),
);
this.enabledIds = allEnabled
? clearAll(this.enabledIds, this.allIds, providerIds)
: enableAll(this.enabledIds, this.allIds, providerIds);
this.isDirty = true;
this.callbacks.onToggleProvider(provider, providerIds, !allEnabled);
this.refresh();
}
return;
}
// Ctrl+S - Save/persist to settings
if (matchesKey(data, Key.ctrl("s"))) {
this.callbacks.onPersist(this.enabledIds ?? [...this.allIds]);
this.isDirty = false;
this.footerText.setText(this.getFooterText());
return;
}
// Ctrl+C - clear search or cancel if empty
if (matchesKey(data, Key.ctrl("c"))) {
if (this.searchInput.getValue()) {
this.searchInput.setValue("");
this.refresh();
} else {
this.callbacks.onCancel();
}
return;
}
// Escape - cancel
if (matchesKey(data, Key.escape)) {
this.callbacks.onCancel();
return;
}
// Pass everything else to search input
this.searchInput.handleInput(data);
this.refresh();
}
getSearchInput(): Input {
return this.searchInput;
}
}

View file

@ -0,0 +1,199 @@
import { fuzzyMatch } from "@mariozechner/pi-tui";
import type { SessionInfo } from "../../../core/session-manager.js";
export type SortMode = "threaded" | "recent" | "relevance";
export type NameFilter = "all" | "named";
export interface ParsedSearchQuery {
mode: "tokens" | "regex";
tokens: { kind: "fuzzy" | "phrase"; value: string }[];
regex: RegExp | null;
/** If set, parsing failed and we should treat query as non-matching. */
error?: string;
}
export interface MatchResult {
matches: boolean;
/** Lower is better; only meaningful when matches === true */
score: number;
}
function normalizeWhitespaceLower(text: string): string {
return text.toLowerCase().replace(/\s+/g, " ").trim();
}
function getSessionSearchText(session: SessionInfo): string {
return `${session.id} ${session.name ?? ""} ${session.allMessagesText} ${session.cwd}`;
}
export function hasSessionName(session: SessionInfo): boolean {
return Boolean(session.name?.trim());
}
function matchesNameFilter(session: SessionInfo, filter: NameFilter): boolean {
if (filter === "all") return true;
return hasSessionName(session);
}
export function parseSearchQuery(query: string): ParsedSearchQuery {
const trimmed = query.trim();
if (!trimmed) {
return { mode: "tokens", tokens: [], regex: null };
}
// Regex mode: re:<pattern>
if (trimmed.startsWith("re:")) {
const pattern = trimmed.slice(3).trim();
if (!pattern) {
return { mode: "regex", tokens: [], regex: null, error: "Empty regex" };
}
try {
return { mode: "regex", tokens: [], regex: new RegExp(pattern, "i") };
} catch (err) {
const msg = err instanceof Error ? err.message : String(err);
return { mode: "regex", tokens: [], regex: null, error: msg };
}
}
// Token mode with quote support.
// Example: foo "node cve" bar
const tokens: { kind: "fuzzy" | "phrase"; value: string }[] = [];
let buf = "";
let inQuote = false;
let hadUnclosedQuote = false;
const flush = (kind: "fuzzy" | "phrase"): void => {
const v = buf.trim();
buf = "";
if (!v) return;
tokens.push({ kind, value: v });
};
for (let i = 0; i < trimmed.length; i++) {
const ch = trimmed[i]!;
if (ch === '"') {
if (inQuote) {
flush("phrase");
inQuote = false;
} else {
flush("fuzzy");
inQuote = true;
}
continue;
}
if (!inQuote && /\s/.test(ch)) {
flush("fuzzy");
continue;
}
buf += ch;
}
if (inQuote) {
hadUnclosedQuote = true;
}
// If quotes were unbalanced, fall back to plain whitespace tokenization.
if (hadUnclosedQuote) {
return {
mode: "tokens",
tokens: trimmed
.split(/\s+/)
.map((t) => t.trim())
.filter((t) => t.length > 0)
.map((t) => ({ kind: "fuzzy" as const, value: t })),
regex: null,
};
}
flush(inQuote ? "phrase" : "fuzzy");
return { mode: "tokens", tokens, regex: null };
}
export function matchSession(
session: SessionInfo,
parsed: ParsedSearchQuery,
): MatchResult {
const text = getSessionSearchText(session);
if (parsed.mode === "regex") {
if (!parsed.regex) {
return { matches: false, score: 0 };
}
const idx = text.search(parsed.regex);
if (idx < 0) return { matches: false, score: 0 };
return { matches: true, score: idx * 0.1 };
}
if (parsed.tokens.length === 0) {
return { matches: true, score: 0 };
}
let totalScore = 0;
let normalizedText: string | null = null;
for (const token of parsed.tokens) {
if (token.kind === "phrase") {
if (normalizedText === null) {
normalizedText = normalizeWhitespaceLower(text);
}
const phrase = normalizeWhitespaceLower(token.value);
if (!phrase) continue;
const idx = normalizedText.indexOf(phrase);
if (idx < 0) return { matches: false, score: 0 };
totalScore += idx * 0.1;
continue;
}
const m = fuzzyMatch(token.value, text);
if (!m.matches) return { matches: false, score: 0 };
totalScore += m.score;
}
return { matches: true, score: totalScore };
}
export function filterAndSortSessions(
sessions: SessionInfo[],
query: string,
sortMode: SortMode,
nameFilter: NameFilter = "all",
): SessionInfo[] {
const nameFiltered =
nameFilter === "all"
? sessions
: sessions.filter((session) => matchesNameFilter(session, nameFilter));
const trimmed = query.trim();
if (!trimmed) return nameFiltered;
const parsed = parseSearchQuery(query);
if (parsed.error) return [];
// Recent mode: filter only, keep incoming order.
if (sortMode === "recent") {
const filtered: SessionInfo[] = [];
for (const s of nameFiltered) {
const res = matchSession(s, parsed);
if (res.matches) filtered.push(s);
}
return filtered;
}
// Relevance mode: sort by score, tie-break by modified desc.
const scored: { session: SessionInfo; score: number }[] = [];
for (const s of nameFiltered) {
const res = matchSession(s, parsed);
if (!res.matches) continue;
scored.push({ session: s, score: res.score });
}
scored.sort((a, b) => {
if (a.score !== b.score) return a.score - b.score;
return b.session.modified.getTime() - a.session.modified.getTime();
});
return scored.map((r) => r.session);
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,453 @@
import type { ThinkingLevel } from "@mariozechner/pi-agent-core";
import type { Transport } from "@mariozechner/pi-ai";
import {
Container,
getCapabilities,
type SelectItem,
SelectList,
type SettingItem,
SettingsList,
Spacer,
Text,
} from "@mariozechner/pi-tui";
import {
getSelectListTheme,
getSettingsListTheme,
theme,
} from "../theme/theme.js";
import { DynamicBorder } from "./dynamic-border.js";
const THINKING_DESCRIPTIONS: Record<ThinkingLevel, string> = {
off: "No reasoning",
minimal: "Very brief reasoning (~1k tokens)",
low: "Light reasoning (~2k tokens)",
medium: "Moderate reasoning (~8k tokens)",
high: "Deep reasoning (~16k tokens)",
xhigh: "Maximum reasoning (~32k tokens)",
};
export interface SettingsConfig {
autoCompact: boolean;
showImages: boolean;
autoResizeImages: boolean;
blockImages: boolean;
enableSkillCommands: boolean;
steeringMode: "all" | "one-at-a-time";
followUpMode: "all" | "one-at-a-time";
transport: Transport;
thinkingLevel: ThinkingLevel;
availableThinkingLevels: ThinkingLevel[];
currentTheme: string;
availableThemes: string[];
hideThinkingBlock: boolean;
collapseChangelog: boolean;
doubleEscapeAction: "fork" | "tree" | "none";
treeFilterMode: "default" | "no-tools" | "user-only" | "labeled-only" | "all";
showHardwareCursor: boolean;
editorPaddingX: number;
autocompleteMaxVisible: number;
quietStartup: boolean;
clearOnShrink: boolean;
}
export interface SettingsCallbacks {
onAutoCompactChange: (enabled: boolean) => void;
onShowImagesChange: (enabled: boolean) => void;
onAutoResizeImagesChange: (enabled: boolean) => void;
onBlockImagesChange: (blocked: boolean) => void;
onEnableSkillCommandsChange: (enabled: boolean) => void;
onSteeringModeChange: (mode: "all" | "one-at-a-time") => void;
onFollowUpModeChange: (mode: "all" | "one-at-a-time") => void;
onTransportChange: (transport: Transport) => void;
onThinkingLevelChange: (level: ThinkingLevel) => void;
onThemeChange: (theme: string) => void;
onThemePreview?: (theme: string) => void;
onHideThinkingBlockChange: (hidden: boolean) => void;
onCollapseChangelogChange: (collapsed: boolean) => void;
onDoubleEscapeActionChange: (action: "fork" | "tree" | "none") => void;
onTreeFilterModeChange: (
mode: "default" | "no-tools" | "user-only" | "labeled-only" | "all",
) => void;
onShowHardwareCursorChange: (enabled: boolean) => void;
onEditorPaddingXChange: (padding: number) => void;
onAutocompleteMaxVisibleChange: (maxVisible: number) => void;
onQuietStartupChange: (enabled: boolean) => void;
onClearOnShrinkChange: (enabled: boolean) => void;
onCancel: () => void;
}
/**
* A submenu component for selecting from a list of options.
*/
class SelectSubmenu extends Container {
private selectList: SelectList;
constructor(
title: string,
description: string,
options: SelectItem[],
currentValue: string,
onSelect: (value: string) => void,
onCancel: () => void,
onSelectionChange?: (value: string) => void,
) {
super();
// Title
this.addChild(new Text(theme.bold(theme.fg("accent", title)), 0, 0));
// Description
if (description) {
this.addChild(new Spacer(1));
this.addChild(new Text(theme.fg("muted", description), 0, 0));
}
// Spacer
this.addChild(new Spacer(1));
// Select list
this.selectList = new SelectList(
options,
Math.min(options.length, 10),
getSelectListTheme(),
);
// Pre-select current value
const currentIndex = options.findIndex((o) => o.value === currentValue);
if (currentIndex !== -1) {
this.selectList.setSelectedIndex(currentIndex);
}
this.selectList.onSelect = (item) => {
onSelect(item.value);
};
this.selectList.onCancel = onCancel;
if (onSelectionChange) {
this.selectList.onSelectionChange = (item) => {
onSelectionChange(item.value);
};
}
this.addChild(this.selectList);
// Hint
this.addChild(new Spacer(1));
this.addChild(
new Text(theme.fg("dim", " Enter to select · Esc to go back"), 0, 0),
);
}
handleInput(data: string): void {
this.selectList.handleInput(data);
}
}
/**
* Main settings selector component.
*/
export class SettingsSelectorComponent extends Container {
private settingsList: SettingsList;
constructor(config: SettingsConfig, callbacks: SettingsCallbacks) {
super();
const supportsImages = getCapabilities().images;
const items: SettingItem[] = [
{
id: "autocompact",
label: "Auto-compact",
description: "Automatically compact context when it gets too large",
currentValue: config.autoCompact ? "true" : "false",
values: ["true", "false"],
},
{
id: "steering-mode",
label: "Steering mode",
description:
"Enter while streaming queues steering messages. 'one-at-a-time': deliver one, wait for response. 'all': deliver all at once.",
currentValue: config.steeringMode,
values: ["one-at-a-time", "all"],
},
{
id: "follow-up-mode",
label: "Follow-up mode",
description:
"Alt+Enter queues follow-up messages until agent stops. 'one-at-a-time': deliver one, wait for response. 'all': deliver all at once.",
currentValue: config.followUpMode,
values: ["one-at-a-time", "all"],
},
{
id: "transport",
label: "Transport",
description:
"Preferred transport for providers that support multiple transports",
currentValue: config.transport,
values: ["sse", "websocket", "auto"],
},
{
id: "hide-thinking",
label: "Hide thinking",
description: "Hide thinking blocks in assistant responses",
currentValue: config.hideThinkingBlock ? "true" : "false",
values: ["true", "false"],
},
{
id: "collapse-changelog",
label: "Collapse changelog",
description: "Show condensed changelog after updates",
currentValue: config.collapseChangelog ? "true" : "false",
values: ["true", "false"],
},
{
id: "quiet-startup",
label: "Quiet startup",
description: "Disable verbose printing at startup",
currentValue: config.quietStartup ? "true" : "false",
values: ["true", "false"],
},
{
id: "double-escape-action",
label: "Double-escape action",
description: "Action when pressing Escape twice with empty editor",
currentValue: config.doubleEscapeAction,
values: ["tree", "fork", "none"],
},
{
id: "tree-filter-mode",
label: "Tree filter mode",
description: "Default filter when opening /tree",
currentValue: config.treeFilterMode,
values: ["default", "no-tools", "user-only", "labeled-only", "all"],
},
{
id: "thinking",
label: "Thinking level",
description: "Reasoning depth for thinking-capable models",
currentValue: config.thinkingLevel,
submenu: (currentValue, done) =>
new SelectSubmenu(
"Thinking Level",
"Select reasoning depth for thinking-capable models",
config.availableThinkingLevels.map((level) => ({
value: level,
label: level,
description: THINKING_DESCRIPTIONS[level],
})),
currentValue,
(value) => {
callbacks.onThinkingLevelChange(value as ThinkingLevel);
done(value);
},
() => done(),
),
},
{
id: "theme",
label: "Theme",
description: "Color theme for the interface",
currentValue: config.currentTheme,
submenu: (currentValue, done) =>
new SelectSubmenu(
"Theme",
"Select color theme",
config.availableThemes.map((t) => ({
value: t,
label: t,
})),
currentValue,
(value) => {
callbacks.onThemeChange(value);
done(value);
},
() => {
// Restore original theme on cancel
callbacks.onThemePreview?.(currentValue);
done();
},
(value) => {
// Preview theme on selection change
callbacks.onThemePreview?.(value);
},
),
},
];
// Only show image toggle if terminal supports it
if (supportsImages) {
// Insert after autocompact
items.splice(1, 0, {
id: "show-images",
label: "Show images",
description: "Render images inline in terminal",
currentValue: config.showImages ? "true" : "false",
values: ["true", "false"],
});
}
// Image auto-resize toggle (always available, affects both attached and read images)
items.splice(supportsImages ? 2 : 1, 0, {
id: "auto-resize-images",
label: "Auto-resize images",
description:
"Resize large images to 2000x2000 max for better model compatibility",
currentValue: config.autoResizeImages ? "true" : "false",
values: ["true", "false"],
});
// Block images toggle (always available, insert after auto-resize-images)
const autoResizeIndex = items.findIndex(
(item) => item.id === "auto-resize-images",
);
items.splice(autoResizeIndex + 1, 0, {
id: "block-images",
label: "Block images",
description: "Prevent images from being sent to LLM providers",
currentValue: config.blockImages ? "true" : "false",
values: ["true", "false"],
});
// Skill commands toggle (insert after block-images)
const blockImagesIndex = items.findIndex(
(item) => item.id === "block-images",
);
items.splice(blockImagesIndex + 1, 0, {
id: "skill-commands",
label: "Skill commands",
description: "Register skills as /skill:name commands",
currentValue: config.enableSkillCommands ? "true" : "false",
values: ["true", "false"],
});
// Hardware cursor toggle (insert after skill-commands)
const skillCommandsIndex = items.findIndex(
(item) => item.id === "skill-commands",
);
items.splice(skillCommandsIndex + 1, 0, {
id: "show-hardware-cursor",
label: "Show hardware cursor",
description:
"Show the terminal cursor while still positioning it for IME support",
currentValue: config.showHardwareCursor ? "true" : "false",
values: ["true", "false"],
});
// Editor padding toggle (insert after show-hardware-cursor)
const hardwareCursorIndex = items.findIndex(
(item) => item.id === "show-hardware-cursor",
);
items.splice(hardwareCursorIndex + 1, 0, {
id: "editor-padding",
label: "Editor padding",
description: "Horizontal padding for input editor (0-3)",
currentValue: String(config.editorPaddingX),
values: ["0", "1", "2", "3"],
});
// Autocomplete max visible toggle (insert after editor-padding)
const editorPaddingIndex = items.findIndex(
(item) => item.id === "editor-padding",
);
items.splice(editorPaddingIndex + 1, 0, {
id: "autocomplete-max-visible",
label: "Autocomplete max items",
description: "Max visible items in autocomplete dropdown (3-20)",
currentValue: String(config.autocompleteMaxVisible),
values: ["3", "5", "7", "10", "15", "20"],
});
// Clear on shrink toggle (insert after autocomplete-max-visible)
const autocompleteIndex = items.findIndex(
(item) => item.id === "autocomplete-max-visible",
);
items.splice(autocompleteIndex + 1, 0, {
id: "clear-on-shrink",
label: "Clear on shrink",
description: "Clear empty rows when content shrinks (may cause flicker)",
currentValue: config.clearOnShrink ? "true" : "false",
values: ["true", "false"],
});
// Add borders
this.addChild(new DynamicBorder());
this.settingsList = new SettingsList(
items,
10,
getSettingsListTheme(),
(id, newValue) => {
switch (id) {
case "autocompact":
callbacks.onAutoCompactChange(newValue === "true");
break;
case "show-images":
callbacks.onShowImagesChange(newValue === "true");
break;
case "auto-resize-images":
callbacks.onAutoResizeImagesChange(newValue === "true");
break;
case "block-images":
callbacks.onBlockImagesChange(newValue === "true");
break;
case "skill-commands":
callbacks.onEnableSkillCommandsChange(newValue === "true");
break;
case "steering-mode":
callbacks.onSteeringModeChange(newValue as "all" | "one-at-a-time");
break;
case "follow-up-mode":
callbacks.onFollowUpModeChange(newValue as "all" | "one-at-a-time");
break;
case "transport":
callbacks.onTransportChange(newValue as Transport);
break;
case "hide-thinking":
callbacks.onHideThinkingBlockChange(newValue === "true");
break;
case "collapse-changelog":
callbacks.onCollapseChangelogChange(newValue === "true");
break;
case "quiet-startup":
callbacks.onQuietStartupChange(newValue === "true");
break;
case "double-escape-action":
callbacks.onDoubleEscapeActionChange(newValue as "fork" | "tree");
break;
case "tree-filter-mode":
callbacks.onTreeFilterModeChange(
newValue as
| "default"
| "no-tools"
| "user-only"
| "labeled-only"
| "all",
);
break;
case "show-hardware-cursor":
callbacks.onShowHardwareCursorChange(newValue === "true");
break;
case "editor-padding":
callbacks.onEditorPaddingXChange(parseInt(newValue, 10));
break;
case "autocomplete-max-visible":
callbacks.onAutocompleteMaxVisibleChange(parseInt(newValue, 10));
break;
case "clear-on-shrink":
callbacks.onClearOnShrinkChange(newValue === "true");
break;
}
},
callbacks.onCancel,
{ enableSearch: true },
);
this.addChild(this.settingsList);
this.addChild(new DynamicBorder());
}
getSettingsList(): SettingsList {
return this.settingsList;
}
}

View file

@ -0,0 +1,57 @@
import { Container, type SelectItem, SelectList } from "@mariozechner/pi-tui";
import { getSelectListTheme } from "../theme/theme.js";
import { DynamicBorder } from "./dynamic-border.js";
/**
* Component that renders a show images selector with borders
*/
export class ShowImagesSelectorComponent extends Container {
private selectList: SelectList;
constructor(
currentValue: boolean,
onSelect: (show: boolean) => void,
onCancel: () => void,
) {
super();
const items: SelectItem[] = [
{
value: "yes",
label: "Yes",
description: "Show images inline in terminal",
},
{
value: "no",
label: "No",
description: "Show text placeholder instead",
},
];
// Add top border
this.addChild(new DynamicBorder());
// Create selector
this.selectList = new SelectList(items, 5, getSelectListTheme());
// Preselect current value
this.selectList.setSelectedIndex(currentValue ? 0 : 1);
this.selectList.onSelect = (item) => {
onSelect(item.value === "yes");
};
this.selectList.onCancel = () => {
onCancel();
};
this.addChild(this.selectList);
// Add bottom border
this.addChild(new DynamicBorder());
}
getSelectList(): SelectList {
return this.selectList;
}
}

View file

@ -0,0 +1,64 @@
import { Box, Markdown, type MarkdownTheme, Text } from "@mariozechner/pi-tui";
import type { ParsedSkillBlock } from "../../../core/agent-session.js";
import { getMarkdownTheme, theme } from "../theme/theme.js";
import { editorKey } from "./keybinding-hints.js";
/**
* Component that renders a skill invocation message with collapsed/expanded state.
* Uses same background color as custom messages for visual consistency.
* Only renders the skill block itself - user message is rendered separately.
*/
export class SkillInvocationMessageComponent extends Box {
private expanded = false;
private skillBlock: ParsedSkillBlock;
private markdownTheme: MarkdownTheme;
constructor(
skillBlock: ParsedSkillBlock,
markdownTheme: MarkdownTheme = getMarkdownTheme(),
) {
super(1, 1, (t) => theme.bg("customMessageBg", t));
this.skillBlock = skillBlock;
this.markdownTheme = markdownTheme;
this.updateDisplay();
}
setExpanded(expanded: boolean): void {
this.expanded = expanded;
this.updateDisplay();
}
override invalidate(): void {
super.invalidate();
this.updateDisplay();
}
private updateDisplay(): void {
this.clear();
if (this.expanded) {
// Expanded: label + skill name header + full content
const label = theme.fg("customMessageLabel", `\x1b[1m[skill]\x1b[22m`);
this.addChild(new Text(label, 0, 0));
const header = `**${this.skillBlock.name}**\n\n`;
this.addChild(
new Markdown(
header + this.skillBlock.content,
0,
0,
this.markdownTheme,
{
color: (text: string) => theme.fg("customMessageText", text),
},
),
);
} else {
// Collapsed: single line - [skill] name (hint to expand)
const line =
theme.fg("customMessageLabel", `\x1b[1m[skill]\x1b[22m `) +
theme.fg("customMessageText", this.skillBlock.name) +
theme.fg("dim", ` (${editorKey("expandTools")} to expand)`);
this.addChild(new Text(line, 0, 0));
}
}
}

View file

@ -0,0 +1,62 @@
import { Container, type SelectItem, SelectList } from "@mariozechner/pi-tui";
import { getAvailableThemes, getSelectListTheme } from "../theme/theme.js";
import { DynamicBorder } from "./dynamic-border.js";
/**
* Component that renders a theme selector
*/
export class ThemeSelectorComponent extends Container {
private selectList: SelectList;
private onPreview: (themeName: string) => void;
constructor(
currentTheme: string,
onSelect: (themeName: string) => void,
onCancel: () => void,
onPreview: (themeName: string) => void,
) {
super();
this.onPreview = onPreview;
// Get available themes and create select items
const themes = getAvailableThemes();
const themeItems: SelectItem[] = themes.map((name) => ({
value: name,
label: name,
description: name === currentTheme ? "(current)" : undefined,
}));
// Add top border
this.addChild(new DynamicBorder());
// Create selector
this.selectList = new SelectList(themeItems, 10, getSelectListTheme());
// Preselect current theme
const currentIndex = themes.indexOf(currentTheme);
if (currentIndex !== -1) {
this.selectList.setSelectedIndex(currentIndex);
}
this.selectList.onSelect = (item) => {
onSelect(item.value);
};
this.selectList.onCancel = () => {
onCancel();
};
this.selectList.onSelectionChange = (item) => {
this.onPreview(item.value);
};
this.addChild(this.selectList);
// Add bottom border
this.addChild(new DynamicBorder());
}
getSelectList(): SelectList {
return this.selectList;
}
}

View file

@ -0,0 +1,70 @@
import type { ThinkingLevel } from "@mariozechner/pi-agent-core";
import { Container, type SelectItem, SelectList } from "@mariozechner/pi-tui";
import { getSelectListTheme } from "../theme/theme.js";
import { DynamicBorder } from "./dynamic-border.js";
const LEVEL_DESCRIPTIONS: Record<ThinkingLevel, string> = {
off: "No reasoning",
minimal: "Very brief reasoning (~1k tokens)",
low: "Light reasoning (~2k tokens)",
medium: "Moderate reasoning (~8k tokens)",
high: "Deep reasoning (~16k tokens)",
xhigh: "Maximum reasoning (~32k tokens)",
};
/**
* Component that renders a thinking level selector with borders
*/
export class ThinkingSelectorComponent extends Container {
private selectList: SelectList;
constructor(
currentLevel: ThinkingLevel,
availableLevels: ThinkingLevel[],
onSelect: (level: ThinkingLevel) => void,
onCancel: () => void,
) {
super();
const thinkingLevels: SelectItem[] = availableLevels.map((level) => ({
value: level,
label: level,
description: LEVEL_DESCRIPTIONS[level],
}));
// Add top border
this.addChild(new DynamicBorder());
// Create selector
this.selectList = new SelectList(
thinkingLevels,
thinkingLevels.length,
getSelectListTheme(),
);
// Preselect current level
const currentIndex = thinkingLevels.findIndex(
(item) => item.value === currentLevel,
);
if (currentIndex !== -1) {
this.selectList.setSelectedIndex(currentIndex);
}
this.selectList.onSelect = (item) => {
onSelect(item.value as ThinkingLevel);
};
this.selectList.onCancel = () => {
onCancel();
};
this.addChild(this.selectList);
// Add bottom border
this.addChild(new DynamicBorder());
}
getSelectList(): SelectList {
return this.selectList;
}
}

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,179 @@
import {
type Component,
Container,
getEditorKeybindings,
Spacer,
Text,
truncateToWidth,
} from "@mariozechner/pi-tui";
import { theme } from "../theme/theme.js";
import { DynamicBorder } from "./dynamic-border.js";
interface UserMessageItem {
id: string; // Entry ID in the session
text: string; // The message text
timestamp?: string; // Optional timestamp if available
}
/**
* Custom user message list component with selection
*/
class UserMessageList implements Component {
private messages: UserMessageItem[] = [];
private selectedIndex: number = 0;
public onSelect?: (entryId: string) => void;
public onCancel?: () => void;
private maxVisible: number = 10; // Max messages visible
constructor(messages: UserMessageItem[]) {
// Store messages in chronological order (oldest to newest)
this.messages = messages;
// Start with the last (most recent) message selected
this.selectedIndex = Math.max(0, messages.length - 1);
}
invalidate(): void {
// No cached state to invalidate currently
}
render(width: number): string[] {
const lines: string[] = [];
if (this.messages.length === 0) {
lines.push(theme.fg("muted", " No user messages found"));
return lines;
}
// Calculate visible range with scrolling
const startIndex = Math.max(
0,
Math.min(
this.selectedIndex - Math.floor(this.maxVisible / 2),
this.messages.length - this.maxVisible,
),
);
const endIndex = Math.min(
startIndex + this.maxVisible,
this.messages.length,
);
// Render visible messages (2 lines per message + blank line)
for (let i = startIndex; i < endIndex; i++) {
const message = this.messages[i];
const isSelected = i === this.selectedIndex;
// Normalize message to single line
const normalizedMessage = message.text.replace(/\n/g, " ").trim();
// First line: cursor + message
const cursor = isSelected ? theme.fg("accent", " ") : " ";
const maxMsgWidth = width - 2; // Account for cursor (2 chars)
const truncatedMsg = truncateToWidth(normalizedMessage, maxMsgWidth);
const messageLine =
cursor + (isSelected ? theme.bold(truncatedMsg) : truncatedMsg);
lines.push(messageLine);
// Second line: metadata (position in history)
const position = i + 1;
const metadata = ` Message ${position} of ${this.messages.length}`;
const metadataLine = theme.fg("muted", metadata);
lines.push(metadataLine);
lines.push(""); // Blank line between messages
}
// Add scroll indicator if needed
if (startIndex > 0 || endIndex < this.messages.length) {
const scrollInfo = theme.fg(
"muted",
` (${this.selectedIndex + 1}/${this.messages.length})`,
);
lines.push(scrollInfo);
}
return lines;
}
handleInput(keyData: string): void {
const kb = getEditorKeybindings();
// Up arrow - go to previous (older) message, wrap to bottom when at top
if (kb.matches(keyData, "selectUp")) {
this.selectedIndex =
this.selectedIndex === 0
? this.messages.length - 1
: this.selectedIndex - 1;
}
// Down arrow - go to next (newer) message, wrap to top when at bottom
else if (kb.matches(keyData, "selectDown")) {
this.selectedIndex =
this.selectedIndex === this.messages.length - 1
? 0
: this.selectedIndex + 1;
}
// Enter - select message and branch
else if (kb.matches(keyData, "selectConfirm")) {
const selected = this.messages[this.selectedIndex];
if (selected && this.onSelect) {
this.onSelect(selected.id);
}
}
// Escape - cancel
else if (kb.matches(keyData, "selectCancel")) {
if (this.onCancel) {
this.onCancel();
}
}
}
}
/**
* Component that renders a user message selector for branching
*/
export class UserMessageSelectorComponent extends Container {
private messageList: UserMessageList;
constructor(
messages: UserMessageItem[],
onSelect: (entryId: string) => void,
onCancel: () => void,
) {
super();
// Add header
this.addChild(new Spacer(1));
this.addChild(new Text(theme.bold("Branch from Message"), 1, 0));
this.addChild(
new Text(
theme.fg(
"muted",
"Select a message to create a new branch from that point",
),
1,
0,
),
);
this.addChild(new Spacer(1));
this.addChild(new DynamicBorder());
this.addChild(new Spacer(1));
// Create message list
this.messageList = new UserMessageList(messages);
this.messageList.onSelect = onSelect;
this.messageList.onCancel = onCancel;
this.addChild(this.messageList);
// Add bottom border
this.addChild(new Spacer(1));
this.addChild(new DynamicBorder());
// Auto-cancel if no messages
if (messages.length === 0) {
setTimeout(() => onCancel(), 100);
}
}
getMessageList(): UserMessageList {
return this.messageList;
}
}

View file

@ -0,0 +1,37 @@
import {
Container,
Markdown,
type MarkdownTheme,
Spacer,
} from "@mariozechner/pi-tui";
import { getMarkdownTheme, theme } from "../theme/theme.js";
const OSC133_ZONE_START = "\x1b]133;A\x07";
const OSC133_ZONE_END = "\x1b]133;B\x07";
/**
* Component that renders a user message
*/
export class UserMessageComponent extends Container {
constructor(text: string, markdownTheme: MarkdownTheme = getMarkdownTheme()) {
super();
this.addChild(new Spacer(1));
this.addChild(
new Markdown(text, 1, 1, markdownTheme, {
bgColor: (text: string) => theme.bg("userMessageBg", text),
color: (text: string) => theme.fg("userMessageText", text),
}),
);
}
override render(width: number): string[] {
const lines = super.render(width);
if (lines.length === 0) {
return lines;
}
lines[0] = OSC133_ZONE_START + lines[0];
lines[lines.length - 1] = lines[lines.length - 1] + OSC133_ZONE_END;
return lines;
}
}

Some files were not shown because too many files have changed in this diff Show more