mirror of
https://github.com/harivansh-afk/clanker-agent.git
synced 2026-04-15 22:03:44 +00:00
move pi-mono into companion-cloud as apps/companion-os
- Copy all pi-mono source into apps/companion-os/ - Update Dockerfile to COPY pre-built binary instead of downloading from GitHub Releases - Update deploy-staging.yml to build pi from source (bun compile) before Docker build - Add apps/companion-os/** to path triggers - No more cross-repo dispatch needed Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
commit
0250f72976
579 changed files with 206942 additions and 0 deletions
18
packages/coding-agent/src/cli.ts
Normal file
18
packages/coding-agent/src/cli.ts
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
#!/usr/bin/env node
/**
 * CLI entry point for the refactored coding agent.
 * Uses main.ts with AgentSession and new mode modules.
 *
 * Test with: npx tsx src/cli-new.ts [args...]
 */
process.title = "pi";

import { setBedrockProviderModule } from "@mariozechner/pi-ai";
import { bedrockProviderModule } from "@mariozechner/pi-ai/bedrock-provider";
import { EnvHttpProxyAgent, setGlobalDispatcher } from "undici";
import { main } from "./main.js";

// Route all undici HTTP(S) traffic through proxies configured via env vars.
setGlobalDispatcher(new EnvHttpProxyAgent());
// Register the Bedrock provider implementation with pi-ai before main runs.
setBedrockProviderModule(bedrockProviderModule);

// Forward argv minus the node executable and script path.
main(process.argv.slice(2));
|
||||
334
packages/coding-agent/src/cli/args.ts
Normal file
334
packages/coding-agent/src/cli/args.ts
Normal file
|
|
@ -0,0 +1,334 @@
|
|||
/**
|
||||
* CLI argument parsing and help display
|
||||
*/
|
||||
|
||||
import type { ThinkingLevel } from "@mariozechner/pi-agent-core";
|
||||
import chalk from "chalk";
|
||||
import { APP_NAME, CONFIG_DIR_NAME, ENV_AGENT_DIR } from "../config.js";
|
||||
import { allTools, type ToolName } from "../core/tools/index.js";
|
||||
|
||||
/** Output mode selected via --mode: plain text (default), JSON, or RPC. */
export type Mode = "text" | "json" | "rpc";

/**
 * Parsed command-line arguments (produced by parseArgs in this file).
 * All flag fields are optional; only messages/fileArgs/unknownFlags are
 * always present.
 */
export interface Args {
	provider?: string; // --provider <name>
	model?: string; // --model <pattern>
	apiKey?: string; // --api-key <key>
	systemPrompt?: string; // --system-prompt <text>
	appendSystemPrompt?: string; // --append-system-prompt <text>
	thinking?: ThinkingLevel; // --thinking <level>, validated by isValidThinkingLevel
	continue?: boolean; // --continue / -c
	resume?: boolean; // --resume / -r
	help?: boolean; // --help / -h
	version?: boolean; // --version / -v
	mode?: Mode; // --mode text|json|rpc
	noSession?: boolean; // --no-session
	session?: string; // --session <path>
	sessionDir?: string; // --session-dir <dir>
	models?: string[]; // --models, comma-separated patterns
	tools?: ToolName[]; // --tools, validated against allTools
	noTools?: boolean; // --no-tools
	extensions?: string[]; // --extension / -e (repeatable)
	noExtensions?: boolean; // --no-extensions / -ne
	print?: boolean; // --print / -p
	export?: string; // --export <file>
	noSkills?: boolean; // --no-skills / -ns
	skills?: string[]; // --skill (repeatable)
	promptTemplates?: string[]; // --prompt-template (repeatable)
	noPromptTemplates?: boolean; // --no-prompt-templates / -np
	themes?: string[]; // --theme (repeatable)
	noThemes?: boolean; // --no-themes
	listModels?: string | true; // --list-models [search]; true when no pattern given
	offline?: boolean; // --offline
	verbose?: boolean; // --verbose
	messages: string[]; // positional (non-flag) arguments
	fileArgs: string[]; // @file arguments with the "@" prefix stripped
	/** Unknown flags (potentially extension flags) - map of flag name to value */
	unknownFlags: Map<string, boolean | string>;
}
|
||||
|
||||
const VALID_THINKING_LEVELS = [
|
||||
"off",
|
||||
"minimal",
|
||||
"low",
|
||||
"medium",
|
||||
"high",
|
||||
"xhigh",
|
||||
] as const;
|
||||
|
||||
export function isValidThinkingLevel(level: string): level is ThinkingLevel {
|
||||
return VALID_THINKING_LEVELS.includes(level as ThinkingLevel);
|
||||
}
|
||||
|
||||
/**
 * Parse raw CLI arguments into an Args structure.
 *
 * Value-taking flags consume the next argument via `args[++i]`; a value flag
 * at the very end of argv (no value following) is ignored. Arguments starting
 * with "@" become fileArgs, plain arguments become messages, and unmatched
 * "-"/"--" flags are dropped unless registered in `extensionFlags`.
 *
 * @param args Raw argv slice (without the node executable and script path).
 * @param extensionFlags Flags registered by extensions; when provided,
 *   matching "--<name>" flags are collected into result.unknownFlags.
 * @returns The parsed Args; never throws, invalid values only warn on stderr.
 */
export function parseArgs(
	args: string[],
	extensionFlags?: Map<string, { type: "boolean" | "string" }>,
): Args {
	const result: Args = {
		messages: [],
		fileArgs: [],
		unknownFlags: new Map(),
	};

	for (let i = 0; i < args.length; i++) {
		const arg = args[i];

		if (arg === "--help" || arg === "-h") {
			result.help = true;
		} else if (arg === "--version" || arg === "-v") {
			result.version = true;
		} else if (arg === "--mode" && i + 1 < args.length) {
			// Invalid modes are silently ignored (result.mode stays unset).
			const mode = args[++i];
			if (mode === "text" || mode === "json" || mode === "rpc") {
				result.mode = mode;
			}
		} else if (arg === "--continue" || arg === "-c") {
			result.continue = true;
		} else if (arg === "--resume" || arg === "-r") {
			result.resume = true;
		} else if (arg === "--provider" && i + 1 < args.length) {
			result.provider = args[++i];
		} else if (arg === "--model" && i + 1 < args.length) {
			result.model = args[++i];
		} else if (arg === "--api-key" && i + 1 < args.length) {
			result.apiKey = args[++i];
		} else if (arg === "--system-prompt" && i + 1 < args.length) {
			result.systemPrompt = args[++i];
		} else if (arg === "--append-system-prompt" && i + 1 < args.length) {
			result.appendSystemPrompt = args[++i];
		} else if (arg === "--no-session") {
			result.noSession = true;
		} else if (arg === "--session" && i + 1 < args.length) {
			result.session = args[++i];
		} else if (arg === "--session-dir" && i + 1 < args.length) {
			result.sessionDir = args[++i];
		} else if (arg === "--models" && i + 1 < args.length) {
			result.models = args[++i].split(",").map((s) => s.trim());
		} else if (arg === "--no-tools") {
			result.noTools = true;
		} else if (arg === "--tools" && i + 1 < args.length) {
			// Keep only names present in allTools; warn (stderr) about the rest.
			const toolNames = args[++i].split(",").map((s) => s.trim());
			const validTools: ToolName[] = [];
			for (const name of toolNames) {
				if (name in allTools) {
					validTools.push(name as ToolName);
				} else {
					console.error(
						chalk.yellow(
							`Warning: Unknown tool "${name}". Valid tools: ${Object.keys(allTools).join(", ")}`,
						),
					);
				}
			}
			result.tools = validTools;
		} else if (arg === "--thinking" && i + 1 < args.length) {
			const level = args[++i];
			if (isValidThinkingLevel(level)) {
				result.thinking = level;
			} else {
				console.error(
					chalk.yellow(
						`Warning: Invalid thinking level "${level}". Valid values: ${VALID_THINKING_LEVELS.join(", ")}`,
					),
				);
			}
		} else if (arg === "--print" || arg === "-p") {
			result.print = true;
		} else if (arg === "--export" && i + 1 < args.length) {
			result.export = args[++i];
		} else if ((arg === "--extension" || arg === "-e") && i + 1 < args.length) {
			result.extensions = result.extensions ?? [];
			result.extensions.push(args[++i]);
		} else if (arg === "--no-extensions" || arg === "-ne") {
			result.noExtensions = true;
		} else if (arg === "--skill" && i + 1 < args.length) {
			result.skills = result.skills ?? [];
			result.skills.push(args[++i]);
		} else if (arg === "--prompt-template" && i + 1 < args.length) {
			result.promptTemplates = result.promptTemplates ?? [];
			result.promptTemplates.push(args[++i]);
		} else if (arg === "--theme" && i + 1 < args.length) {
			result.themes = result.themes ?? [];
			result.themes.push(args[++i]);
		} else if (arg === "--no-skills" || arg === "-ns") {
			result.noSkills = true;
		} else if (arg === "--no-prompt-templates" || arg === "-np") {
			result.noPromptTemplates = true;
		} else if (arg === "--no-themes") {
			result.noThemes = true;
		} else if (arg === "--list-models") {
			// Check if next arg is a search pattern (not a flag or file arg)
			if (
				i + 1 < args.length &&
				!args[i + 1].startsWith("-") &&
				!args[i + 1].startsWith("@")
			) {
				result.listModels = args[++i];
			} else {
				result.listModels = true;
			}
		} else if (arg === "--verbose") {
			result.verbose = true;
		} else if (arg === "--offline") {
			result.offline = true;
		} else if (arg.startsWith("@")) {
			result.fileArgs.push(arg.slice(1)); // Remove @ prefix
		} else if (arg.startsWith("--") && extensionFlags) {
			// Check if it's an extension-registered flag
			const flagName = arg.slice(2);
			const extFlag = extensionFlags.get(flagName);
			if (extFlag) {
				if (extFlag.type === "boolean") {
					result.unknownFlags.set(flagName, true);
				} else if (extFlag.type === "string" && i + 1 < args.length) {
					result.unknownFlags.set(flagName, args[++i]);
				}
			}
			// Unknown flags without extensionFlags are silently ignored (first pass)
		} else if (!arg.startsWith("-")) {
			result.messages.push(arg);
		}
	}

	return result;
}
|
||||
|
||||
/**
 * Print CLI usage, options, examples, and recognized environment variables
 * to stdout. Branding comes from APP_NAME / CONFIG_DIR_NAME / ENV_AGENT_DIR.
 * NOTE(review): column alignment below was reconstructed — the original
 * whitespace inside this template was lost; verify against upstream.
 */
export function printHelp(): void {
	console.log(`${chalk.bold(APP_NAME)} - AI coding assistant with read, bash, edit, write tools

${chalk.bold("Usage:")}
  ${APP_NAME} [options] [@files...] [messages...]

${chalk.bold("Commands:")}
  ${APP_NAME} install <source> [-l]   Install extension source and add to settings
  ${APP_NAME} remove <source> [-l]    Remove extension source from settings
  ${APP_NAME} update [source]         Update installed extensions (skips pinned sources)
  ${APP_NAME} list                    List installed extensions from settings
  ${APP_NAME} gateway                 Run the always-on gateway process
  ${APP_NAME} daemon                  Alias for gateway
  ${APP_NAME} config                  Open TUI to enable/disable package resources
  ${APP_NAME} <command> --help        Show help for install/remove/update/list

${chalk.bold("Options:")}
  --provider <name>              Provider name (default: google)
  --model <pattern>              Model pattern or ID (supports "provider/id" and optional ":<thinking>")
  --api-key <key>                API key (defaults to env vars)
  --system-prompt <text>         System prompt (default: coding assistant prompt)
  --append-system-prompt <text>  Append text or file contents to the system prompt
  --mode <mode>                  Output mode: text (default), json, or rpc
  --print, -p                    Non-interactive mode: process prompt and exit
  --continue, -c                 Continue previous session
  --resume, -r                   Select a session to resume
  --session <path>               Use specific session file
  --session-dir <dir>            Directory for session storage and lookup
  --no-session                   Don't save session (ephemeral)
  --models <patterns>            Comma-separated model patterns for Ctrl+P cycling
                                 Supports globs (anthropic/*, *sonnet*) and fuzzy matching
  --no-tools                     Disable all built-in tools
  --tools <tools>                Comma-separated list of tools to enable (default: read,bash,edit,write)
                                 Available: read, bash, edit, write, grep, find, ls
  --thinking <level>             Set thinking level: off, minimal, low, medium, high, xhigh
  --extension, -e <path>         Load an extension file (can be used multiple times)
  --no-extensions, -ne           Disable extension discovery (explicit -e paths still work)
  --skill <path>                 Load a skill file or directory (can be used multiple times)
  --no-skills, -ns               Disable skills discovery and loading
  --prompt-template <path>       Load a prompt template file or directory (can be used multiple times)
  --no-prompt-templates, -np     Disable prompt template discovery and loading
  --theme <path>                 Load a theme file or directory (can be used multiple times)
  --no-themes                    Disable theme discovery and loading
  --export <file>                Export session file to HTML and exit
  --list-models [search]         List available models (with optional fuzzy search)
  --verbose                      Force verbose startup (overrides quietStartup setting)
  --offline                      Disable startup network operations (same as PI_OFFLINE=1)
  --help, -h                     Show this help
  --version, -v                  Show version number

Extensions can register additional flags (e.g., --plan from plan-mode extension).

${chalk.bold("Examples:")}
  # Interactive mode
  ${APP_NAME}

  # Interactive mode with initial prompt
  ${APP_NAME} "List all .ts files in src/"

  # Include files in initial message
  ${APP_NAME} @prompt.md @image.png "What color is the sky?"

  # Non-interactive mode (process and exit)
  ${APP_NAME} -p "List all .ts files in src/"

  # Multiple messages (interactive)
  ${APP_NAME} "Read package.json" "What dependencies do we have?"

  # Continue previous session
  ${APP_NAME} --continue "What did we discuss?"

  # Use different model
  ${APP_NAME} --provider openai --model gpt-4o-mini "Help me refactor this code"

  # Use model with provider prefix (no --provider needed)
  ${APP_NAME} --model openai/gpt-4o "Help me refactor this code"

  # Use model with thinking level shorthand
  ${APP_NAME} --model sonnet:high "Solve this complex problem"

  # Limit model cycling to specific models
  ${APP_NAME} --models claude-sonnet,claude-haiku,gpt-4o

  # Limit to a specific provider with glob pattern
  ${APP_NAME} --models "github-copilot/*"

  # Cycle models with fixed thinking levels
  ${APP_NAME} --models sonnet:high,haiku:low

  # Start with a specific thinking level
  ${APP_NAME} --thinking high "Solve this complex problem"

  # Read-only mode (no file modifications possible)
  ${APP_NAME} --tools read,grep,find,ls -p "Review the code in src/"

  # Export a session file to HTML
  ${APP_NAME} --export ~/${CONFIG_DIR_NAME}/agent/sessions/--path--/session.jsonl
  ${APP_NAME} --export session.jsonl output.html

${chalk.bold("Environment Variables:")}
  ANTHROPIC_API_KEY                - Anthropic Claude API key
  ANTHROPIC_OAUTH_TOKEN            - Anthropic OAuth token (alternative to API key)
  OPENAI_API_KEY                   - OpenAI GPT API key
  AZURE_OPENAI_API_KEY             - Azure OpenAI API key
  AZURE_OPENAI_BASE_URL            - Azure OpenAI base URL (https://{resource}.openai.azure.com/openai/v1)
  AZURE_OPENAI_RESOURCE_NAME       - Azure OpenAI resource name (alternative to base URL)
  AZURE_OPENAI_API_VERSION         - Azure OpenAI API version (default: v1)
  AZURE_OPENAI_DEPLOYMENT_NAME_MAP - Azure OpenAI model=deployment map (comma-separated)
  GEMINI_API_KEY                   - Google Gemini API key
  GROQ_API_KEY                     - Groq API key
  CEREBRAS_API_KEY                 - Cerebras API key
  XAI_API_KEY                      - xAI Grok API key
  OPENROUTER_API_KEY               - OpenRouter API key
  AI_GATEWAY_API_KEY               - Vercel AI Gateway API key
  ZAI_API_KEY                      - ZAI API key
  MISTRAL_API_KEY                  - Mistral API key
  MINIMAX_API_KEY                  - MiniMax API key
  OPENCODE_API_KEY                 - OpenCode Zen/OpenCode Go API key
  KIMI_API_KEY                     - Kimi For Coding API key
  AWS_PROFILE                      - AWS profile for Amazon Bedrock
  AWS_ACCESS_KEY_ID                - AWS access key for Amazon Bedrock
  AWS_SECRET_ACCESS_KEY            - AWS secret key for Amazon Bedrock
  AWS_BEARER_TOKEN_BEDROCK         - Bedrock API key (bearer token)
  AWS_REGION                       - AWS region for Amazon Bedrock (e.g., us-east-1)
  ${ENV_AGENT_DIR.padEnd(32)} - Session storage directory (default: ~/${CONFIG_DIR_NAME}/agent)
  PI_PACKAGE_DIR                   - Override package directory (for Nix/Guix store paths)
  PI_OFFLINE                       - Disable startup network operations when set to 1/true/yes
  PI_SHARE_VIEWER_URL              - Base URL for /share command (default: https://pi.dev/session/)
  PI_AI_ANTIGRAVITY_VERSION        - Override Antigravity User-Agent version (e.g., 1.23.0)

${chalk.bold("Available Tools (default: read, bash, edit, write):")}
  read  - Read file contents
  bash  - Execute bash commands
  edit  - Edit files with find/replace
  write - Write files (creates/overwrites)
  grep  - Search file contents (read-only, off by default)
  find  - Find files by glob pattern (read-only, off by default)
  ls    - List directory contents (read-only, off by default)
`);
}
|
||||
57
packages/coding-agent/src/cli/config-selector.ts
Normal file
57
packages/coding-agent/src/cli/config-selector.ts
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
/**
|
||||
* TUI config selector for `pi config` command
|
||||
*/
|
||||
|
||||
import { ProcessTerminal, TUI } from "@mariozechner/pi-tui";
|
||||
import type { ResolvedPaths } from "../core/package-manager.js";
|
||||
import type { SettingsManager } from "../core/settings-manager.js";
|
||||
import { ConfigSelectorComponent } from "../modes/interactive/components/config-selector.js";
|
||||
import {
|
||||
initTheme,
|
||||
stopThemeWatcher,
|
||||
} from "../modes/interactive/theme/theme.js";
|
||||
|
||||
/** Inputs for the `pi config` TUI. */
export interface ConfigSelectorOptions {
	resolvedPaths: ResolvedPaths; // resolved package resource paths to display
	settingsManager: SettingsManager; // read/write enabled-resource settings
	cwd: string; // working directory shown/used by the selector
	agentDir: string; // agent config/session directory
}

/**
 * Show TUI config selector and return when closed.
 *
 * The returned promise resolves when the selector's close callback fires;
 * the second callback (hard exit) stops the UI and terminates the process
 * without ever resolving.
 */
export async function selectConfig(
	options: ConfigSelectorOptions,
): Promise<void> {
	// Initialize theme before showing TUI
	initTheme(options.settingsManager.getTheme(), true);

	return new Promise((resolve) => {
		const ui = new TUI(new ProcessTerminal());
		// Guard so the close callback resolves the promise at most once.
		let resolved = false;

		const selector = new ConfigSelectorComponent(
			options.resolvedPaths,
			options.settingsManager,
			options.cwd,
			options.agentDir,
			// Close: stop UI + theme watcher, then resolve (once).
			() => {
				if (!resolved) {
					resolved = true;
					ui.stop();
					stopThemeWatcher();
					resolve();
				}
			},
			// Hard exit: tear down and terminate the whole process.
			() => {
				ui.stop();
				stopThemeWatcher();
				process.exit(0);
			},
			// Re-render request from the component.
			() => ui.requestRender(),
		);

		ui.addChild(selector);
		ui.setFocus(selector.getResourceList());
		ui.start();
	});
}
|
||||
105
packages/coding-agent/src/cli/file-processor.ts
Normal file
105
packages/coding-agent/src/cli/file-processor.ts
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
/**
|
||||
* Process @file CLI arguments into text content and image attachments
|
||||
*/
|
||||
|
||||
import { access, readFile, stat } from "node:fs/promises";
|
||||
import type { ImageContent } from "@mariozechner/pi-ai";
|
||||
import chalk from "chalk";
|
||||
import { resolve } from "path";
|
||||
import { resolveReadPath } from "../core/tools/path-utils.js";
|
||||
import { formatDimensionNote, resizeImage } from "../utils/image-resize.js";
|
||||
import { detectSupportedImageMimeTypeFromFile } from "../utils/mime.js";
|
||||
|
||||
/** Result of expanding @file CLI arguments. */
export interface ProcessedFiles {
	text: string; // concatenated <file name="...">...</file> blocks
	images: ImageContent[]; // image attachments referenced from `text`
}

export interface ProcessFileOptions {
	/** Whether to auto-resize images to 2000x2000 max. Default: true */
	autoResizeImages?: boolean;
}

/**
 * Process @file arguments into text content and image attachments.
 *
 * Text files are inlined into `text`; supported images are attached to
 * `images` (optionally resized) with a placeholder in `text`. Empty files
 * are skipped silently.
 *
 * NOTE: exits the process with code 1 when a file is missing or unreadable.
 *
 * @param fileArgs Paths from the CLI with the leading "@" already stripped.
 * @param options Processing options (see ProcessFileOptions).
 * @returns Combined text plus image attachments.
 */
export async function processFileArguments(
	fileArgs: string[],
	options?: ProcessFileOptions,
): Promise<ProcessedFiles> {
	const autoResizeImages = options?.autoResizeImages ?? true;
	let text = "";
	const images: ImageContent[] = [];

	for (const fileArg of fileArgs) {
		// Expand and resolve path (handles ~ expansion and macOS screenshot Unicode spaces)
		const absolutePath = resolve(resolveReadPath(fileArg, process.cwd()));

		// Check if file exists
		try {
			await access(absolutePath);
		} catch {
			console.error(chalk.red(`Error: File not found: ${absolutePath}`));
			process.exit(1);
		}

		// Check if file is empty
		const stats = await stat(absolutePath);
		if (stats.size === 0) {
			// Skip empty files
			continue;
		}

		// Returns a mime type only for supported image formats; null/undefined → text.
		const mimeType = await detectSupportedImageMimeTypeFromFile(absolutePath);

		if (mimeType) {
			// Handle image file
			const content = await readFile(absolutePath);
			const base64Content = content.toString("base64");

			let attachment: ImageContent;
			let dimensionNote: string | undefined;

			if (autoResizeImages) {
				const resized = await resizeImage({
					type: "image",
					data: base64Content,
					mimeType,
				});
				dimensionNote = formatDimensionNote(resized);
				attachment = {
					type: "image",
					mimeType: resized.mimeType,
					data: resized.data,
				};
			} else {
				attachment = {
					type: "image",
					mimeType,
					data: base64Content,
				};
			}

			images.push(attachment);

			// Add text reference to image with optional dimension note
			if (dimensionNote) {
				text += `<file name="${absolutePath}">${dimensionNote}</file>\n`;
			} else {
				text += `<file name="${absolutePath}"></file>\n`;
			}
		} else {
			// Handle text file
			try {
				const content = await readFile(absolutePath, "utf-8");
				text += `<file name="${absolutePath}">\n${content}\n</file>\n`;
			} catch (error: unknown) {
				const message = error instanceof Error ? error.message : String(error);
				console.error(
					chalk.red(`Error: Could not read file ${absolutePath}: ${message}`),
				);
				process.exit(1);
			}
		}
	}

	return { text, images };
}
|
||||
126
packages/coding-agent/src/cli/list-models.ts
Normal file
126
packages/coding-agent/src/cli/list-models.ts
Normal file
|
|
@ -0,0 +1,126 @@
|
|||
/**
|
||||
* List available models with optional fuzzy search
|
||||
*/
|
||||
|
||||
import type { Api, Model } from "@mariozechner/pi-ai";
|
||||
import { fuzzyFilter } from "@mariozechner/pi-tui";
|
||||
import type { ModelRegistry } from "../core/model-registry.js";
|
||||
|
||||
/**
|
||||
* Format a number as human-readable (e.g., 200000 -> "200K", 1000000 -> "1M")
|
||||
*/
|
||||
function formatTokenCount(count: number): string {
|
||||
if (count >= 1_000_000) {
|
||||
const millions = count / 1_000_000;
|
||||
return millions % 1 === 0 ? `${millions}M` : `${millions.toFixed(1)}M`;
|
||||
}
|
||||
if (count >= 1_000) {
|
||||
const thousands = count / 1_000;
|
||||
return thousands % 1 === 0 ? `${thousands}K` : `${thousands.toFixed(1)}K`;
|
||||
}
|
||||
return count.toString();
|
||||
}
|
||||
|
||||
/**
|
||||
* List available models, optionally filtered by search pattern
|
||||
*/
|
||||
export async function listModels(
|
||||
modelRegistry: ModelRegistry,
|
||||
searchPattern?: string,
|
||||
): Promise<void> {
|
||||
const models = modelRegistry.getAvailable();
|
||||
|
||||
if (models.length === 0) {
|
||||
console.log("No models available. Set API keys in environment variables.");
|
||||
return;
|
||||
}
|
||||
|
||||
// Apply fuzzy filter if search pattern provided
|
||||
let filteredModels: Model<Api>[] = models;
|
||||
if (searchPattern) {
|
||||
filteredModels = fuzzyFilter(
|
||||
models,
|
||||
searchPattern,
|
||||
(m) => `${m.provider} ${m.id}`,
|
||||
);
|
||||
}
|
||||
|
||||
if (filteredModels.length === 0) {
|
||||
console.log(`No models matching "${searchPattern}"`);
|
||||
return;
|
||||
}
|
||||
|
||||
// Sort by provider, then by model id
|
||||
filteredModels.sort((a, b) => {
|
||||
const providerCmp = a.provider.localeCompare(b.provider);
|
||||
if (providerCmp !== 0) return providerCmp;
|
||||
return a.id.localeCompare(b.id);
|
||||
});
|
||||
|
||||
// Calculate column widths
|
||||
const rows = filteredModels.map((m) => ({
|
||||
provider: m.provider,
|
||||
model: m.id,
|
||||
context: formatTokenCount(m.contextWindow),
|
||||
maxOut: formatTokenCount(m.maxTokens),
|
||||
thinking: m.reasoning ? "yes" : "no",
|
||||
images: m.input.includes("image") ? "yes" : "no",
|
||||
}));
|
||||
|
||||
const headers = {
|
||||
provider: "provider",
|
||||
model: "model",
|
||||
context: "context",
|
||||
maxOut: "max-out",
|
||||
thinking: "thinking",
|
||||
images: "images",
|
||||
};
|
||||
|
||||
const widths = {
|
||||
provider: Math.max(
|
||||
headers.provider.length,
|
||||
...rows.map((r) => r.provider.length),
|
||||
),
|
||||
model: Math.max(headers.model.length, ...rows.map((r) => r.model.length)),
|
||||
context: Math.max(
|
||||
headers.context.length,
|
||||
...rows.map((r) => r.context.length),
|
||||
),
|
||||
maxOut: Math.max(
|
||||
headers.maxOut.length,
|
||||
...rows.map((r) => r.maxOut.length),
|
||||
),
|
||||
thinking: Math.max(
|
||||
headers.thinking.length,
|
||||
...rows.map((r) => r.thinking.length),
|
||||
),
|
||||
images: Math.max(
|
||||
headers.images.length,
|
||||
...rows.map((r) => r.images.length),
|
||||
),
|
||||
};
|
||||
|
||||
// Print header
|
||||
const headerLine = [
|
||||
headers.provider.padEnd(widths.provider),
|
||||
headers.model.padEnd(widths.model),
|
||||
headers.context.padEnd(widths.context),
|
||||
headers.maxOut.padEnd(widths.maxOut),
|
||||
headers.thinking.padEnd(widths.thinking),
|
||||
headers.images.padEnd(widths.images),
|
||||
].join(" ");
|
||||
console.log(headerLine);
|
||||
|
||||
// Print rows
|
||||
for (const row of rows) {
|
||||
const line = [
|
||||
row.provider.padEnd(widths.provider),
|
||||
row.model.padEnd(widths.model),
|
||||
row.context.padEnd(widths.context),
|
||||
row.maxOut.padEnd(widths.maxOut),
|
||||
row.thinking.padEnd(widths.thinking),
|
||||
row.images.padEnd(widths.images),
|
||||
].join(" ");
|
||||
console.log(line);
|
||||
}
|
||||
}
|
||||
56
packages/coding-agent/src/cli/session-picker.ts
Normal file
56
packages/coding-agent/src/cli/session-picker.ts
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
/**
|
||||
* TUI session selector for --resume flag
|
||||
*/
|
||||
|
||||
import { ProcessTerminal, TUI } from "@mariozechner/pi-tui";
|
||||
import { KeybindingsManager } from "../core/keybindings.js";
|
||||
import type {
|
||||
SessionInfo,
|
||||
SessionListProgress,
|
||||
} from "../core/session-manager.js";
|
||||
import { SessionSelectorComponent } from "../modes/interactive/components/session-selector.js";
|
||||
|
||||
/** Async loader for session metadata; may report incremental progress. */
type SessionsLoader = (
	onProgress?: SessionListProgress,
) => Promise<SessionInfo[]>;

/** Show TUI session selector and return selected session path or null if cancelled */
export async function selectSession(
	currentSessionsLoader: SessionsLoader,
	allSessionsLoader: SessionsLoader,
): Promise<string | null> {
	return new Promise((resolve) => {
		const ui = new TUI(new ProcessTerminal());
		const keybindings = KeybindingsManager.create();
		// Guard so the promise resolves at most once across callbacks.
		let resolved = false;

		const selector = new SessionSelectorComponent(
			currentSessionsLoader,
			allSessionsLoader,
			// Selection: stop UI and resolve with the chosen session path.
			(path: string) => {
				if (!resolved) {
					resolved = true;
					ui.stop();
					resolve(path);
				}
			},
			// Cancel: stop UI and resolve with null.
			() => {
				if (!resolved) {
					resolved = true;
					ui.stop();
					resolve(null);
				}
			},
			// Hard exit: stop UI and terminate the process (never resolves).
			() => {
				ui.stop();
				process.exit(0);
			},
			// Re-render request from the component.
			() => ui.requestRender(),
			{ showRenameHint: false, keybindings },
		);

		ui.addChild(selector);
		ui.setFocus(selector.getSessionList());
		ui.start();
	});
}
|
||||
256
packages/coding-agent/src/config.ts
Normal file
256
packages/coding-agent/src/config.ts
Normal file
|
|
@ -0,0 +1,256 @@
|
|||
import { existsSync, readFileSync } from "fs";
|
||||
import { homedir } from "os";
|
||||
import { dirname, join, resolve } from "path";
|
||||
import { fileURLToPath } from "url";
|
||||
|
||||
// =============================================================================
|
||||
// Package Detection
|
||||
// =============================================================================
|
||||
|
||||
const __filename = fileURLToPath(import.meta.url);
|
||||
const __dirname = dirname(__filename);
|
||||
|
||||
/**
 * Detect if we're running as a Bun compiled binary.
 * Bun binaries have import.meta.url containing "$bunfs", "~BUN", or "%7EBUN" (Bun's virtual filesystem path)
 * — the last is the URL-encoded form of "~BUN".
 */
export const isBunBinary =
	import.meta.url.includes("$bunfs") ||
	import.meta.url.includes("~BUN") ||
	import.meta.url.includes("%7EBUN");

/** Detect if Bun is the runtime (compiled binary or bun run) */
export const isBunRuntime = !!process.versions.bun;
|
||||
|
||||
// =============================================================================
|
||||
// Install Method Detection
|
||||
// =============================================================================
|
||||
|
||||
export type InstallMethod =
|
||||
| "bun-binary"
|
||||
| "npm"
|
||||
| "pnpm"
|
||||
| "yarn"
|
||||
| "bun"
|
||||
| "unknown";
|
||||
|
||||
export function detectInstallMethod(): InstallMethod {
|
||||
if (isBunBinary) {
|
||||
return "bun-binary";
|
||||
}
|
||||
|
||||
const resolvedPath = `${__dirname}\0${process.execPath || ""}`.toLowerCase();
|
||||
|
||||
if (
|
||||
resolvedPath.includes("/pnpm/") ||
|
||||
resolvedPath.includes("/.pnpm/") ||
|
||||
resolvedPath.includes("\\pnpm\\")
|
||||
) {
|
||||
return "pnpm";
|
||||
}
|
||||
if (
|
||||
resolvedPath.includes("/yarn/") ||
|
||||
resolvedPath.includes("/.yarn/") ||
|
||||
resolvedPath.includes("\\yarn\\")
|
||||
) {
|
||||
return "yarn";
|
||||
}
|
||||
if (isBunRuntime) {
|
||||
return "bun";
|
||||
}
|
||||
if (
|
||||
resolvedPath.includes("/npm/") ||
|
||||
resolvedPath.includes("/node_modules/") ||
|
||||
resolvedPath.includes("\\npm\\")
|
||||
) {
|
||||
return "npm";
|
||||
}
|
||||
|
||||
return "unknown";
|
||||
}
|
||||
|
||||
export function getUpdateInstruction(packageName: string): string {
|
||||
const method = detectInstallMethod();
|
||||
switch (method) {
|
||||
case "bun-binary":
|
||||
return `Download from: https://github.com/badlogic/pi-mono/releases/latest`;
|
||||
case "pnpm":
|
||||
return `Run: pnpm install -g ${packageName}`;
|
||||
case "yarn":
|
||||
return `Run: yarn global add ${packageName}`;
|
||||
case "bun":
|
||||
return `Run: bun install -g ${packageName}`;
|
||||
case "npm":
|
||||
return `Run: npm install -g ${packageName}`;
|
||||
default:
|
||||
return `Run: npm install -g ${packageName}`;
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Package Asset Paths (shipped with executable)
|
||||
// =============================================================================
|
||||
|
||||
/**
 * Get the base directory for resolving package assets (themes, package.json, README.md, CHANGELOG.md).
 * - For Bun binary: returns the directory containing the executable
 * - For Node.js (dist/): returns __dirname (the dist/ directory)
 * - For tsx (src/): returns parent directory (the package root)
 */
export function getPackageDir(): string {
	// Allow override via environment variable (useful for Nix/Guix where store paths tokenize poorly)
	const envDir = process.env.PI_PACKAGE_DIR;
	if (envDir) {
		// Minimal tilde expansion: bare "~" and "~/..." only.
		if (envDir === "~") return homedir();
		if (envDir.startsWith("~/")) return homedir() + envDir.slice(1);
		return envDir;
	}

	if (isBunBinary) {
		// Bun binary: process.execPath points to the compiled executable
		return dirname(process.execPath);
	}
	// Node.js: walk up from __dirname until we find package.json
	// (dirname(dir) === dir at the filesystem root, which ends the loop).
	let dir = __dirname;
	while (dir !== dirname(dir)) {
		if (existsSync(join(dir, "package.json"))) {
			return dir;
		}
		dir = dirname(dir);
	}
	// Fallback (shouldn't happen)
	return __dirname;
}
|
||||
|
||||
/**
|
||||
* Get path to built-in themes directory (shipped with package)
|
||||
* - For Bun binary: theme/ next to executable
|
||||
* - For Node.js (dist/): dist/modes/interactive/theme/
|
||||
* - For tsx (src/): src/modes/interactive/theme/
|
||||
*/
|
||||
export function getThemesDir(): string {
|
||||
if (isBunBinary) {
|
||||
return join(dirname(process.execPath), "theme");
|
||||
}
|
||||
// Theme is in modes/interactive/theme/ relative to src/ or dist/
|
||||
const packageDir = getPackageDir();
|
||||
const srcOrDist = existsSync(join(packageDir, "src")) ? "src" : "dist";
|
||||
return join(packageDir, srcOrDist, "modes", "interactive", "theme");
|
||||
}
|
||||
|
||||
/**
|
||||
* Get path to HTML export template directory (shipped with package)
|
||||
* - For Bun binary: export-html/ next to executable
|
||||
* - For Node.js (dist/): dist/core/export-html/
|
||||
* - For tsx (src/): src/core/export-html/
|
||||
*/
|
||||
export function getExportTemplateDir(): string {
|
||||
if (isBunBinary) {
|
||||
return join(dirname(process.execPath), "export-html");
|
||||
}
|
||||
const packageDir = getPackageDir();
|
||||
const srcOrDist = existsSync(join(packageDir, "src")) ? "src" : "dist";
|
||||
return join(packageDir, srcOrDist, "core", "export-html");
|
||||
}
|
||||
|
||||
/** Get path to package.json */
|
||||
export function getPackageJsonPath(): string {
|
||||
return join(getPackageDir(), "package.json");
|
||||
}
|
||||
|
||||
/** Get path to README.md */
|
||||
export function getReadmePath(): string {
|
||||
return resolve(join(getPackageDir(), "README.md"));
|
||||
}
|
||||
|
||||
/** Get path to docs directory */
|
||||
export function getDocsPath(): string {
|
||||
return resolve(join(getPackageDir(), "docs"));
|
||||
}
|
||||
|
||||
/** Get path to CHANGELOG.md */
|
||||
export function getChangelogPath(): string {
|
||||
return resolve(join(getPackageDir(), "CHANGELOG.md"));
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// App Config (from package.json piConfig)
|
||||
// =============================================================================
|
||||
|
||||
// Read the shipped package.json once at module load; its optional `piConfig`
// section lets downstream builds rebrand the binary (name + config dir).
const pkg = JSON.parse(readFileSync(getPackageJsonPath(), "utf-8"));

// App/binary name (defaults to "pi"; overridable via package.json piConfig.name).
export const APP_NAME: string = pkg.piConfig?.name || "pi";
// Dot-directory under $HOME for user config (defaults to ".pi").
export const CONFIG_DIR_NAME: string = pkg.piConfig?.configDir || ".pi";
// Version string taken directly from package.json.
export const VERSION: string = pkg.version;

// Name of the env var overriding the agent dir, derived from the app name —
// e.g., PI_CODING_AGENT_DIR or TAU_CODING_AGENT_DIR.
export const ENV_AGENT_DIR = `${APP_NAME.toUpperCase()}_CODING_AGENT_DIR`;
|
||||
|
||||
const DEFAULT_SHARE_VIEWER_URL = "https://pi.dev/session/";
|
||||
|
||||
/** Get the share viewer URL for a gist ID */
|
||||
export function getShareViewerUrl(gistId: string): string {
|
||||
const baseUrl = process.env.PI_SHARE_VIEWER_URL || DEFAULT_SHARE_VIEWER_URL;
|
||||
return `${baseUrl}#${gistId}`;
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// User Config Paths (~/.pi/agent/*)
|
||||
// =============================================================================
|
||||
|
||||
/** Get the agent config directory (e.g., ~/.pi/agent/) */
|
||||
export function getAgentDir(): string {
|
||||
const envDir = process.env[ENV_AGENT_DIR];
|
||||
if (envDir) {
|
||||
// Expand tilde to home directory
|
||||
if (envDir === "~") return homedir();
|
||||
if (envDir.startsWith("~/")) return homedir() + envDir.slice(1);
|
||||
return envDir;
|
||||
}
|
||||
return join(homedir(), CONFIG_DIR_NAME, "agent");
|
||||
}
|
||||
|
||||
/** Get path to user's custom themes directory */
|
||||
export function getCustomThemesDir(): string {
|
||||
return join(getAgentDir(), "themes");
|
||||
}
|
||||
|
||||
/** Get path to models.json */
|
||||
export function getModelsPath(): string {
|
||||
return join(getAgentDir(), "models.json");
|
||||
}
|
||||
|
||||
/** Get path to auth.json */
|
||||
export function getAuthPath(): string {
|
||||
return join(getAgentDir(), "auth.json");
|
||||
}
|
||||
|
||||
/** Get path to settings.json */
|
||||
export function getSettingsPath(): string {
|
||||
return join(getAgentDir(), "settings.json");
|
||||
}
|
||||
|
||||
/** Get path to tools directory */
|
||||
export function getToolsDir(): string {
|
||||
return join(getAgentDir(), "tools");
|
||||
}
|
||||
|
||||
/** Get path to managed binaries directory (fd, rg) */
|
||||
export function getBinDir(): string {
|
||||
return join(getAgentDir(), "bin");
|
||||
}
|
||||
|
||||
/** Get path to prompt templates directory */
|
||||
export function getPromptsDir(): string {
|
||||
return join(getAgentDir(), "prompts");
|
||||
}
|
||||
|
||||
/** Get path to sessions directory */
|
||||
export function getSessionsDir(): string {
|
||||
return join(getAgentDir(), "sessions");
|
||||
}
|
||||
|
||||
/** Get path to debug log file */
|
||||
export function getDebugLogPath(): string {
|
||||
return join(getAgentDir(), `${APP_NAME}-debug.log`);
|
||||
}
|
||||
3337
packages/coding-agent/src/core/agent-session.ts
Normal file
3337
packages/coding-agent/src/core/agent-session.ts
Normal file
File diff suppressed because it is too large
Load diff
503
packages/coding-agent/src/core/auth-storage.ts
Normal file
503
packages/coding-agent/src/core/auth-storage.ts
Normal file
|
|
@ -0,0 +1,503 @@
|
|||
/**
|
||||
* Credential storage for API keys and OAuth tokens.
|
||||
* Handles loading, saving, and refreshing credentials from auth.json.
|
||||
*
|
||||
* Uses file locking to prevent race conditions when multiple pi instances
|
||||
* try to refresh tokens simultaneously.
|
||||
*/
|
||||
|
||||
import {
|
||||
getEnvApiKey,
|
||||
type OAuthCredentials,
|
||||
type OAuthLoginCallbacks,
|
||||
type OAuthProviderId,
|
||||
} from "@mariozechner/pi-ai";
|
||||
import {
|
||||
getOAuthApiKey,
|
||||
getOAuthProvider,
|
||||
getOAuthProviders,
|
||||
} from "@mariozechner/pi-ai/oauth";
|
||||
import {
|
||||
chmodSync,
|
||||
existsSync,
|
||||
mkdirSync,
|
||||
readFileSync,
|
||||
writeFileSync,
|
||||
} from "fs";
|
||||
import { dirname, join } from "path";
|
||||
import lockfile from "proper-lockfile";
|
||||
import { getAgentDir } from "../config.js";
|
||||
import { resolveConfigValue } from "./resolve-config-value.js";
|
||||
|
||||
/** A plain API key credential (the key may be a config reference resolved at use time). */
export type ApiKeyCredential = {
	type: "api_key";
	key: string;
};

/** OAuth credential: token data from pi-ai's OAuthCredentials, tagged for the union. */
export type OAuthCredential = {
	type: "oauth";
} & OAuthCredentials;

/** Discriminated union over the `type` field. */
export type AuthCredential = ApiKeyCredential | OAuthCredential;

/** Provider id -> credential, i.e. the shape persisted in auth.json. */
export type AuthStorageData = Record<string, AuthCredential>;

// Outcome of a locked critical section: `result` is returned to the caller;
// when `next` is set it becomes the new serialized storage contents.
type LockResult<T> = {
	result: T;
	next?: string;
};

/**
 * Pluggable persistence for credentials. Implementations must make the
 * withLock* methods mutually exclusive so concurrent writers (e.g. multiple
 * instances refreshing OAuth tokens) cannot clobber each other.
 */
export interface AuthStorageBackend {
	withLock<T>(fn: (current: string | undefined) => LockResult<T>): T;
	withLockAsync<T>(
		fn: (current: string | undefined) => Promise<LockResult<T>>,
	): Promise<T>;
}
|
||||
|
||||
/**
 * File-backed credential storage using proper-lockfile for cross-process
 * mutual exclusion on auth.json. The file and its parent directory are
 * created with restrictive modes (0600 / 0700) since they hold secrets.
 */
export class FileAuthStorageBackend implements AuthStorageBackend {
	constructor(private authPath: string = join(getAgentDir(), "auth.json")) {}

	// Create the parent directory (owner-only) if it does not exist yet.
	private ensureParentDir(): void {
		const dir = dirname(this.authPath);
		if (!existsSync(dir)) {
			mkdirSync(dir, { recursive: true, mode: 0o700 });
		}
	}

	// Create an empty JSON file (owner-only) so the lockfile has a target.
	private ensureFileExists(): void {
		if (!existsSync(this.authPath)) {
			writeFileSync(this.authPath, "{}", "utf-8");
			chmodSync(this.authPath, 0o600);
		}
	}

	/**
	 * Run `fn` while holding a synchronous lock on auth.json.
	 * `fn` receives the current file contents; if it returns `next`,
	 * that string is written back (and the mode re-tightened) before unlock.
	 */
	withLock<T>(fn: (current: string | undefined) => LockResult<T>): T {
		this.ensureParentDir();
		this.ensureFileExists();

		let release: (() => void) | undefined;
		try {
			// realpath: false — lock the literal path without resolving symlinks.
			release = lockfile.lockSync(this.authPath, { realpath: false });
			const current = existsSync(this.authPath)
				? readFileSync(this.authPath, "utf-8")
				: undefined;
			const { result, next } = fn(current);
			if (next !== undefined) {
				writeFileSync(this.authPath, next, "utf-8");
				chmodSync(this.authPath, 0o600);
			}
			return result;
		} finally {
			if (release) {
				release();
			}
		}
	}

	/**
	 * Async variant of withLock with retry/backoff while waiting for the lock.
	 * If proper-lockfile reports the lock as compromised (e.g. stale takeover),
	 * the operation aborts before and after the callback and before any write,
	 * so a possibly-conflicting result is never persisted.
	 */
	async withLockAsync<T>(
		fn: (current: string | undefined) => Promise<LockResult<T>>,
	): Promise<T> {
		this.ensureParentDir();
		this.ensureFileExists();

		let release: (() => Promise<void>) | undefined;
		let lockCompromised = false;
		let lockCompromisedError: Error | undefined;
		const throwIfCompromised = () => {
			if (lockCompromised) {
				throw (
					lockCompromisedError ?? new Error("Auth storage lock was compromised")
				);
			}
		};

		try {
			release = await lockfile.lock(this.authPath, {
				// Exponential backoff: up to 10 attempts, 100ms..10s, jittered.
				retries: {
					retries: 10,
					factor: 2,
					minTimeout: 100,
					maxTimeout: 10000,
					randomize: true,
				},
				// Consider a held lock stale after 30s (holder likely crashed).
				stale: 30000,
				onCompromised: (err) => {
					lockCompromised = true;
					lockCompromisedError = err;
				},
			});

			throwIfCompromised();
			const current = existsSync(this.authPath)
				? readFileSync(this.authPath, "utf-8")
				: undefined;
			const { result, next } = await fn(current);
			throwIfCompromised();
			if (next !== undefined) {
				writeFileSync(this.authPath, next, "utf-8");
				chmodSync(this.authPath, 0o600);
			}
			throwIfCompromised();
			return result;
		} finally {
			if (release) {
				try {
					await release();
				} catch {
					// Ignore unlock errors when lock is compromised.
				}
			}
		}
	}
}
|
||||
|
||||
export class InMemoryAuthStorageBackend implements AuthStorageBackend {
|
||||
private value: string | undefined;
|
||||
|
||||
withLock<T>(fn: (current: string | undefined) => LockResult<T>): T {
|
||||
const { result, next } = fn(this.value);
|
||||
if (next !== undefined) {
|
||||
this.value = next;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
async withLockAsync<T>(
|
||||
fn: (current: string | undefined) => Promise<LockResult<T>>,
|
||||
): Promise<T> {
|
||||
const { result, next } = await fn(this.value);
|
||||
if (next !== undefined) {
|
||||
this.value = next;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Credential storage backed by a pluggable backend (file or in-memory).
 *
 * Keeps an in-memory snapshot of auth.json, layered with runtime overrides
 * (CLI --api-key) and fallbacks (env vars, models.json resolver). OAuth
 * tokens are refreshed lazily on access, under the backend's lock so that
 * concurrent instances don't race each other.
 */
export class AuthStorage {
	// In-memory snapshot of the persisted credentials.
	private data: AuthStorageData = {};
	// Runtime API-key overrides (never persisted); highest priority.
	private runtimeOverrides: Map<string, string> = new Map();
	// Last-resort key resolver (e.g. custom providers from models.json).
	private fallbackResolver?: (provider: string) => string | undefined;
	// Set when the last load/reload failed; while set, writes are suppressed
	// to avoid clobbering a file we could not read.
	private loadError: Error | null = null;
	// Non-fatal errors accumulated for later retrieval via drainErrors().
	private errors: Error[] = [];

	// Private: use the static factories below.
	private constructor(private storage: AuthStorageBackend) {
		this.reload();
	}

	/** Create a file-backed store (defaults to <agentDir>/auth.json). */
	static create(authPath?: string): AuthStorage {
		return new AuthStorage(
			new FileAuthStorageBackend(authPath ?? join(getAgentDir(), "auth.json")),
		);
	}

	/** Create a store over an arbitrary backend. */
	static fromStorage(storage: AuthStorageBackend): AuthStorage {
		return new AuthStorage(storage);
	}

	/** Create an in-memory store pre-seeded with `data` (tests, ephemeral use). */
	static inMemory(data: AuthStorageData = {}): AuthStorage {
		const storage = new InMemoryAuthStorageBackend();
		storage.withLock(() => ({
			result: undefined,
			next: JSON.stringify(data, null, 2),
		}));
		return AuthStorage.fromStorage(storage);
	}

	/**
	 * Set a runtime API key override (not persisted to disk).
	 * Used for CLI --api-key flag.
	 */
	setRuntimeApiKey(provider: string, apiKey: string): void {
		this.runtimeOverrides.set(provider, apiKey);
	}

	/**
	 * Remove a runtime API key override.
	 */
	removeRuntimeApiKey(provider: string): void {
		this.runtimeOverrides.delete(provider);
	}

	/**
	 * Set a fallback resolver for API keys not found in auth.json or env vars.
	 * Used for custom provider keys from models.json.
	 */
	setFallbackResolver(
		resolver: (provider: string) => string | undefined,
	): void {
		this.fallbackResolver = resolver;
	}

	// Normalize and queue a non-fatal error for drainErrors().
	private recordError(error: unknown): void {
		const normalizedError =
			error instanceof Error ? error : new Error(String(error));
		this.errors.push(normalizedError);
	}

	// Parse serialized storage contents; empty/missing content means no credentials.
	private parseStorageData(content: string | undefined): AuthStorageData {
		if (!content) {
			return {};
		}
		return JSON.parse(content) as AuthStorageData;
	}

	/**
	 * Reload credentials from storage.
	 * On failure the previous snapshot is kept, loadError is set (blocking
	 * subsequent writes), and the error is recorded.
	 */
	reload(): void {
		let content: string | undefined;
		try {
			this.storage.withLock((current) => {
				content = current;
				return { result: undefined };
			});
			this.data = this.parseStorageData(content);
			this.loadError = null;
		} catch (error) {
			this.loadError = error as Error;
			this.recordError(error);
		}
	}

	// Merge a single provider change into the persisted data under lock.
	// Re-reads current contents first so concurrent writers aren't clobbered.
	private persistProviderChange(
		provider: string,
		credential: AuthCredential | undefined,
	): void {
		if (this.loadError) {
			// Don't risk overwriting a file we failed to read.
			return;
		}

		try {
			this.storage.withLock((current) => {
				const currentData = this.parseStorageData(current);
				const merged: AuthStorageData = { ...currentData };
				if (credential) {
					merged[provider] = credential;
				} else {
					delete merged[provider];
				}
				return { result: undefined, next: JSON.stringify(merged, null, 2) };
			});
		} catch (error) {
			this.recordError(error);
		}
	}

	/**
	 * Get credential for a provider.
	 */
	get(provider: string): AuthCredential | undefined {
		return this.data[provider] ?? undefined;
	}

	/**
	 * Set credential for a provider (updates snapshot and persists).
	 */
	set(provider: string, credential: AuthCredential): void {
		this.data[provider] = credential;
		this.persistProviderChange(provider, credential);
	}

	/**
	 * Remove credential for a provider (updates snapshot and persists).
	 */
	remove(provider: string): void {
		delete this.data[provider];
		this.persistProviderChange(provider, undefined);
	}

	/**
	 * List all providers with credentials.
	 */
	list(): string[] {
		return Object.keys(this.data);
	}

	/**
	 * Check if credentials exist for a provider in auth.json.
	 */
	has(provider: string): boolean {
		return provider in this.data;
	}

	/**
	 * Check if any form of auth is configured for a provider.
	 * Unlike getApiKey(), this doesn't refresh OAuth tokens.
	 */
	hasAuth(provider: string): boolean {
		if (this.runtimeOverrides.has(provider)) return true;
		if (this.data[provider]) return true;
		if (getEnvApiKey(provider)) return true;
		if (this.fallbackResolver?.(provider)) return true;
		return false;
	}

	/**
	 * Get all credentials (for passing to getOAuthApiKey).
	 */
	getAll(): AuthStorageData {
		return { ...this.data };
	}

	/** Return and clear all accumulated non-fatal errors. */
	drainErrors(): Error[] {
		const drained = [...this.errors];
		this.errors = [];
		return drained;
	}

	/**
	 * Login to an OAuth provider and persist the resulting credentials.
	 */
	async login(
		providerId: OAuthProviderId,
		callbacks: OAuthLoginCallbacks,
	): Promise<void> {
		const provider = getOAuthProvider(providerId);
		if (!provider) {
			throw new Error(`Unknown OAuth provider: ${providerId}`);
		}

		const credentials = await provider.login(callbacks);
		this.set(providerId, { type: "oauth", ...credentials });
	}

	/**
	 * Logout from a provider.
	 */
	logout(provider: string): void {
		this.remove(provider);
	}

	/**
	 * Refresh OAuth token with backend locking to prevent race conditions.
	 * Multiple pi instances may try to refresh simultaneously when tokens expire.
	 * Inside the lock the credentials are re-read and re-checked, so if another
	 * instance already refreshed, its (still valid) token is reused as-is.
	 */
	private async refreshOAuthTokenWithLock(
		providerId: OAuthProviderId,
	): Promise<{ apiKey: string; newCredentials: OAuthCredentials } | null> {
		const provider = getOAuthProvider(providerId);
		if (!provider) {
			return null;
		}

		const result = await this.storage.withLockAsync(async (current) => {
			// Re-read under lock: another instance may have written meanwhile.
			const currentData = this.parseStorageData(current);
			this.data = currentData;
			this.loadError = null;

			const cred = currentData[providerId];
			if (cred?.type !== "oauth") {
				return { result: null };
			}

			// Another instance refreshed while we waited for the lock.
			if (Date.now() < cred.expires) {
				return {
					result: { apiKey: provider.getApiKey(cred), newCredentials: cred },
				};
			}

			const oauthCreds: Record<string, OAuthCredentials> = {};
			for (const [key, value] of Object.entries(currentData)) {
				if (value.type === "oauth") {
					oauthCreds[key] = value;
				}
			}

			const refreshed = await getOAuthApiKey(providerId, oauthCreds);
			if (!refreshed) {
				return { result: null };
			}

			// Persist the refreshed credentials while still holding the lock.
			const merged: AuthStorageData = {
				...currentData,
				[providerId]: { type: "oauth", ...refreshed.newCredentials },
			};
			this.data = merged;
			this.loadError = null;
			return { result: refreshed, next: JSON.stringify(merged, null, 2) };
		});

		return result;
	}

	/**
	 * Get API key for a provider.
	 * Priority:
	 * 1. Runtime override (CLI --api-key)
	 * 2. API key from auth.json
	 * 3. OAuth token from auth.json (auto-refreshed with locking)
	 * 4. Environment variable
	 * 5. Fallback resolver (models.json custom providers)
	 */
	async getApiKey(providerId: string): Promise<string | undefined> {
		// Runtime override takes highest priority
		const runtimeKey = this.runtimeOverrides.get(providerId);
		if (runtimeKey) {
			return runtimeKey;
		}

		const cred = this.data[providerId];

		if (cred?.type === "api_key") {
			return resolveConfigValue(cred.key);
		}

		if (cred?.type === "oauth") {
			const provider = getOAuthProvider(providerId);
			if (!provider) {
				// Unknown OAuth provider, can't get API key
				return undefined;
			}

			// Check if token needs refresh
			const needsRefresh = Date.now() >= cred.expires;

			if (needsRefresh) {
				// Use locked refresh to prevent race conditions
				try {
					const result = await this.refreshOAuthTokenWithLock(providerId);
					if (result) {
						return result.apiKey;
					}
				} catch (error) {
					this.recordError(error);
					// Refresh failed - re-read file to check if another instance succeeded
					this.reload();
					const updatedCred = this.data[providerId];

					if (
						updatedCred?.type === "oauth" &&
						Date.now() < updatedCred.expires
					) {
						// Another instance refreshed successfully, use those credentials
						return provider.getApiKey(updatedCred);
					}

					// Refresh truly failed - return undefined so model discovery skips this provider
					// User can /login to re-authenticate (credentials preserved for retry)
					return undefined;
				}
			} else {
				// Token not expired, use current access token
				return provider.getApiKey(cred);
			}
		}

		// Fall back to environment variable
		const envKey = getEnvApiKey(providerId);
		if (envKey) return envKey;

		// Fall back to custom resolver (e.g., models.json custom providers)
		return this.fallbackResolver?.(providerId) ?? undefined;
	}

	/**
	 * Get all registered OAuth providers.
	 * (The bare call inside resolves to the imported registry function,
	 * not this method — methods aren't in lexical scope without `this.`.)
	 */
	getOAuthProviders() {
		return getOAuthProviders();
	}
}
|
||||
296
packages/coding-agent/src/core/bash-executor.ts
Normal file
296
packages/coding-agent/src/core/bash-executor.ts
Normal file
|
|
@ -0,0 +1,296 @@
|
|||
/**
|
||||
* Bash command execution with streaming support and cancellation.
|
||||
*
|
||||
* This module provides a unified bash execution implementation used by:
|
||||
* - AgentSession.executeBash() for interactive and RPC modes
|
||||
* - Direct calls from modes that need bash execution
|
||||
*/
|
||||
|
||||
import { randomBytes } from "node:crypto";
|
||||
import { createWriteStream, type WriteStream } from "node:fs";
|
||||
import { tmpdir } from "node:os";
|
||||
import { join } from "node:path";
|
||||
import { type ChildProcess, spawn } from "child_process";
|
||||
import stripAnsi from "strip-ansi";
|
||||
import {
|
||||
getShellConfig,
|
||||
getShellEnv,
|
||||
killProcessTree,
|
||||
sanitizeBinaryOutput,
|
||||
} from "../utils/shell.js";
|
||||
import type { BashOperations } from "./tools/bash.js";
|
||||
import { DEFAULT_MAX_BYTES, truncateTail } from "./tools/truncate.js";
|
||||
|
||||
// ============================================================================
|
||||
// Types
|
||||
// ============================================================================
|
||||
|
||||
/** Options controlling streaming and cancellation for bash execution. */
export interface BashExecutorOptions {
	/** Callback for streaming output chunks (already sanitized) */
	onChunk?: (chunk: string) => void;
	/** AbortSignal for cancellation */
	signal?: AbortSignal;
}

/** Result of a bash execution (streamed output is combined and truncated here). */
export interface BashResult {
	/** Combined stdout + stderr output (sanitized, possibly truncated) */
	output: string;
	/** Process exit code (undefined if killed/cancelled) */
	exitCode: number | undefined;
	/** Whether the command was cancelled via signal */
	cancelled: boolean;
	/** Whether the output was truncated */
	truncated: boolean;
	/** Path to temp file containing full output (if output exceeded truncation threshold) */
	fullOutputPath?: string;
}
|
||||
|
||||
// ============================================================================
|
||||
// Implementation
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Execute a bash command with optional streaming and cancellation support.
|
||||
*
|
||||
* Features:
|
||||
* - Streams sanitized output via onChunk callback
|
||||
* - Writes large output to temp file for later retrieval
|
||||
* - Supports cancellation via AbortSignal
|
||||
* - Sanitizes output (strips ANSI, removes binary garbage, normalizes newlines)
|
||||
* - Truncates output if it exceeds the default max bytes
|
||||
*
|
||||
* @param command - The bash command to execute
|
||||
* @param options - Optional streaming callback and abort signal
|
||||
* @returns Promise resolving to execution result
|
||||
*/
|
||||
export function executeBash(
|
||||
command: string,
|
||||
options?: BashExecutorOptions,
|
||||
): Promise<BashResult> {
|
||||
return new Promise((resolve, reject) => {
|
||||
const { shell, args } = getShellConfig();
|
||||
const child: ChildProcess = spawn(shell, [...args, command], {
|
||||
detached: true,
|
||||
env: getShellEnv(),
|
||||
stdio: ["ignore", "pipe", "pipe"],
|
||||
});
|
||||
|
||||
// Track sanitized output for truncation
|
||||
const outputChunks: string[] = [];
|
||||
let outputBytes = 0;
|
||||
const maxOutputBytes = DEFAULT_MAX_BYTES * 2;
|
||||
|
||||
// Temp file for large output
|
||||
let tempFilePath: string | undefined;
|
||||
let tempFileStream: WriteStream | undefined;
|
||||
let totalBytes = 0;
|
||||
|
||||
// Handle abort signal
|
||||
const abortHandler = () => {
|
||||
if (child.pid) {
|
||||
killProcessTree(child.pid);
|
||||
}
|
||||
};
|
||||
|
||||
if (options?.signal) {
|
||||
if (options.signal.aborted) {
|
||||
// Already aborted, don't even start
|
||||
child.kill();
|
||||
resolve({
|
||||
output: "",
|
||||
exitCode: undefined,
|
||||
cancelled: true,
|
||||
truncated: false,
|
||||
});
|
||||
return;
|
||||
}
|
||||
options.signal.addEventListener("abort", abortHandler, { once: true });
|
||||
}
|
||||
|
||||
const decoder = new TextDecoder();
|
||||
|
||||
const handleData = (data: Buffer) => {
|
||||
totalBytes += data.length;
|
||||
|
||||
// Sanitize once at the source: strip ANSI, replace binary garbage, normalize newlines
|
||||
const text = sanitizeBinaryOutput(
|
||||
stripAnsi(decoder.decode(data, { stream: true })),
|
||||
).replace(/\r/g, "");
|
||||
|
||||
// Start writing to temp file if exceeds threshold
|
||||
if (totalBytes > DEFAULT_MAX_BYTES && !tempFilePath) {
|
||||
const id = randomBytes(8).toString("hex");
|
||||
tempFilePath = join(tmpdir(), `pi-bash-${id}.log`);
|
||||
tempFileStream = createWriteStream(tempFilePath);
|
||||
// Write already-buffered chunks to temp file
|
||||
for (const chunk of outputChunks) {
|
||||
tempFileStream.write(chunk);
|
||||
}
|
||||
}
|
||||
|
||||
if (tempFileStream) {
|
||||
tempFileStream.write(text);
|
||||
}
|
||||
|
||||
// Keep rolling buffer of sanitized text
|
||||
outputChunks.push(text);
|
||||
outputBytes += text.length;
|
||||
while (outputBytes > maxOutputBytes && outputChunks.length > 1) {
|
||||
const removed = outputChunks.shift()!;
|
||||
outputBytes -= removed.length;
|
||||
}
|
||||
|
||||
// Stream to callback if provided
|
||||
if (options?.onChunk) {
|
||||
options.onChunk(text);
|
||||
}
|
||||
};
|
||||
|
||||
child.stdout?.on("data", handleData);
|
||||
child.stderr?.on("data", handleData);
|
||||
|
||||
child.on("close", (code) => {
|
||||
// Clean up abort listener
|
||||
if (options?.signal) {
|
||||
options.signal.removeEventListener("abort", abortHandler);
|
||||
}
|
||||
|
||||
if (tempFileStream) {
|
||||
tempFileStream.end();
|
||||
}
|
||||
|
||||
// Combine buffered chunks for truncation (already sanitized)
|
||||
const fullOutput = outputChunks.join("");
|
||||
const truncationResult = truncateTail(fullOutput);
|
||||
|
||||
// code === null means killed (cancelled)
|
||||
const cancelled = code === null;
|
||||
|
||||
resolve({
|
||||
output: truncationResult.truncated
|
||||
? truncationResult.content
|
||||
: fullOutput,
|
||||
exitCode: cancelled ? undefined : code,
|
||||
cancelled,
|
||||
truncated: truncationResult.truncated,
|
||||
fullOutputPath: tempFilePath,
|
||||
});
|
||||
});
|
||||
|
||||
child.on("error", (err) => {
|
||||
// Clean up abort listener
|
||||
if (options?.signal) {
|
||||
options.signal.removeEventListener("abort", abortHandler);
|
||||
}
|
||||
|
||||
if (tempFileStream) {
|
||||
tempFileStream.end();
|
||||
}
|
||||
|
||||
reject(err);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute a bash command using custom BashOperations.
|
||||
* Used for remote execution (SSH, containers, etc.).
|
||||
*/
|
||||
export async function executeBashWithOperations(
|
||||
command: string,
|
||||
cwd: string,
|
||||
operations: BashOperations,
|
||||
options?: BashExecutorOptions,
|
||||
): Promise<BashResult> {
|
||||
const outputChunks: string[] = [];
|
||||
let outputBytes = 0;
|
||||
const maxOutputBytes = DEFAULT_MAX_BYTES * 2;
|
||||
|
||||
let tempFilePath: string | undefined;
|
||||
let tempFileStream: WriteStream | undefined;
|
||||
let totalBytes = 0;
|
||||
|
||||
const decoder = new TextDecoder();
|
||||
|
||||
const onData = (data: Buffer) => {
|
||||
totalBytes += data.length;
|
||||
|
||||
// Sanitize: strip ANSI, replace binary garbage, normalize newlines
|
||||
const text = sanitizeBinaryOutput(
|
||||
stripAnsi(decoder.decode(data, { stream: true })),
|
||||
).replace(/\r/g, "");
|
||||
|
||||
// Start writing to temp file if exceeds threshold
|
||||
if (totalBytes > DEFAULT_MAX_BYTES && !tempFilePath) {
|
||||
const id = randomBytes(8).toString("hex");
|
||||
tempFilePath = join(tmpdir(), `pi-bash-${id}.log`);
|
||||
tempFileStream = createWriteStream(tempFilePath);
|
||||
for (const chunk of outputChunks) {
|
||||
tempFileStream.write(chunk);
|
||||
}
|
||||
}
|
||||
|
||||
if (tempFileStream) {
|
||||
tempFileStream.write(text);
|
||||
}
|
||||
|
||||
// Keep rolling buffer
|
||||
outputChunks.push(text);
|
||||
outputBytes += text.length;
|
||||
while (outputBytes > maxOutputBytes && outputChunks.length > 1) {
|
||||
const removed = outputChunks.shift()!;
|
||||
outputBytes -= removed.length;
|
||||
}
|
||||
|
||||
// Stream to callback
|
||||
if (options?.onChunk) {
|
||||
options.onChunk(text);
|
||||
}
|
||||
};
|
||||
|
||||
try {
|
||||
const result = await operations.exec(command, cwd, {
|
||||
onData,
|
||||
signal: options?.signal,
|
||||
});
|
||||
|
||||
if (tempFileStream) {
|
||||
tempFileStream.end();
|
||||
}
|
||||
|
||||
const fullOutput = outputChunks.join("");
|
||||
const truncationResult = truncateTail(fullOutput);
|
||||
const cancelled = options?.signal?.aborted ?? false;
|
||||
|
||||
return {
|
||||
output: truncationResult.truncated
|
||||
? truncationResult.content
|
||||
: fullOutput,
|
||||
exitCode: cancelled ? undefined : (result.exitCode ?? undefined),
|
||||
cancelled,
|
||||
truncated: truncationResult.truncated,
|
||||
fullOutputPath: tempFilePath,
|
||||
};
|
||||
} catch (err) {
|
||||
if (tempFileStream) {
|
||||
tempFileStream.end();
|
||||
}
|
||||
|
||||
// Check if it was an abort
|
||||
if (options?.signal?.aborted) {
|
||||
const fullOutput = outputChunks.join("");
|
||||
const truncationResult = truncateTail(fullOutput);
|
||||
return {
|
||||
output: truncationResult.truncated
|
||||
? truncationResult.content
|
||||
: fullOutput,
|
||||
exitCode: undefined,
|
||||
cancelled: true,
|
||||
truncated: truncationResult.truncated,
|
||||
fullOutputPath: tempFilePath,
|
||||
};
|
||||
}
|
||||
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,382 @@
|
|||
/**
|
||||
* Branch summarization for tree navigation.
|
||||
*
|
||||
* When navigating to a different point in the session tree, this generates
|
||||
* a summary of the branch being left so context isn't lost.
|
||||
*/
|
||||
|
||||
import type { AgentMessage } from "@mariozechner/pi-agent-core";
|
||||
import type { Model } from "@mariozechner/pi-ai";
|
||||
import { completeSimple } from "@mariozechner/pi-ai";
|
||||
import {
|
||||
convertToLlm,
|
||||
createBranchSummaryMessage,
|
||||
createCompactionSummaryMessage,
|
||||
createCustomMessage,
|
||||
} from "../messages.js";
|
||||
import type {
|
||||
ReadonlySessionManager,
|
||||
SessionEntry,
|
||||
} from "../session-manager.js";
|
||||
import { estimateTokens } from "./compaction.js";
|
||||
import {
|
||||
computeFileLists,
|
||||
createFileOps,
|
||||
extractFileOpsFromMessage,
|
||||
type FileOperations,
|
||||
formatFileOperations,
|
||||
SUMMARIZATION_SYSTEM_PROMPT,
|
||||
serializeConversation,
|
||||
} from "./utils.js";
|
||||
|
||||
// ============================================================================
|
||||
// Types
|
||||
// ============================================================================
|
||||
|
||||
/** Outcome of a branch-summarization attempt. */
export interface BranchSummaryResult {
	/** Generated summary text (absent on abort/error) */
	summary?: string;
	/** Files read on the summarized branch */
	readFiles?: string[];
	/** Files modified on the summarized branch */
	modifiedFiles?: string[];
	/** True if summarization was cancelled via the abort signal */
	aborted?: boolean;
	/** Error message if summarization failed */
	error?: string;
}

/** Details stored in BranchSummaryEntry.details for file tracking */
export interface BranchSummaryDetails {
	readFiles: string[];
	modifiedFiles: string[];
}

export type { FileOperations } from "./utils.js";

/** Intermediate data prepared from a branch before calling the LLM. */
export interface BranchPreparation {
	/** Messages extracted for summarization, in chronological order */
	messages: AgentMessage[];
	/** File operations extracted from tool calls */
	fileOps: FileOperations;
	/** Total estimated tokens in messages */
	totalTokens: number;
}

/** Result of collecting the session entries to summarize during navigation. */
export interface CollectEntriesResult {
	/** Entries to summarize, in chronological order */
	entries: SessionEntry[];
	/** Common ancestor between old and new position, if any */
	commonAncestorId: string | null;
}

/** Options for generating a branch summary. */
export interface GenerateBranchSummaryOptions {
	/** Model to use for summarization */
	model: Model<any>;
	/** API key for the model */
	apiKey: string;
	/** Abort signal for cancellation */
	signal: AbortSignal;
	/** Optional custom instructions for summarization */
	customInstructions?: string;
	/** If true, customInstructions replaces the default prompt instead of being appended */
	replaceInstructions?: boolean;
	/** Tokens reserved for prompt + LLM response (default 16384) */
	reserveTokens?: number;
}
|
||||
|
||||
// ============================================================================
|
||||
// Entry Collection
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Collect entries that should be summarized when navigating from one position to another.
|
||||
*
|
||||
* Walks from oldLeafId back to the common ancestor with targetId, collecting entries
|
||||
* along the way. Does NOT stop at compaction boundaries - those are included and their
|
||||
* summaries become context.
|
||||
*
|
||||
* @param session - Session manager (read-only access)
|
||||
* @param oldLeafId - Current position (where we're navigating from)
|
||||
* @param targetId - Target position (where we're navigating to)
|
||||
* @returns Entries to summarize and the common ancestor
|
||||
*/
|
||||
export function collectEntriesForBranchSummary(
|
||||
session: ReadonlySessionManager,
|
||||
oldLeafId: string | null,
|
||||
targetId: string,
|
||||
): CollectEntriesResult {
|
||||
// If no old position, nothing to summarize
|
||||
if (!oldLeafId) {
|
||||
return { entries: [], commonAncestorId: null };
|
||||
}
|
||||
|
||||
// Find common ancestor (deepest node that's on both paths)
|
||||
const oldPath = new Set(session.getBranch(oldLeafId).map((e) => e.id));
|
||||
const targetPath = session.getBranch(targetId);
|
||||
|
||||
// targetPath is root-first, so iterate backwards to find deepest common ancestor
|
||||
let commonAncestorId: string | null = null;
|
||||
for (let i = targetPath.length - 1; i >= 0; i--) {
|
||||
if (oldPath.has(targetPath[i].id)) {
|
||||
commonAncestorId = targetPath[i].id;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Collect entries from old leaf back to common ancestor
|
||||
const entries: SessionEntry[] = [];
|
||||
let current: string | null = oldLeafId;
|
||||
|
||||
while (current && current !== commonAncestorId) {
|
||||
const entry = session.getEntry(current);
|
||||
if (!entry) break;
|
||||
entries.push(entry);
|
||||
current = entry.parentId;
|
||||
}
|
||||
|
||||
// Reverse to get chronological order
|
||||
entries.reverse();
|
||||
|
||||
return { entries, commonAncestorId };
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Entry to Message Conversion
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Extract AgentMessage from a session entry.
|
||||
* Similar to getMessageFromEntry in compaction.ts but also handles compaction entries.
|
||||
*/
|
||||
function getMessageFromEntry(entry: SessionEntry): AgentMessage | undefined {
|
||||
switch (entry.type) {
|
||||
case "message":
|
||||
// Skip tool results - context is in assistant's tool call
|
||||
if (entry.message.role === "toolResult") return undefined;
|
||||
return entry.message;
|
||||
|
||||
case "custom_message":
|
||||
return createCustomMessage(
|
||||
entry.customType,
|
||||
entry.content,
|
||||
entry.display,
|
||||
entry.details,
|
||||
entry.timestamp,
|
||||
);
|
||||
|
||||
case "branch_summary":
|
||||
return createBranchSummaryMessage(
|
||||
entry.summary,
|
||||
entry.fromId,
|
||||
entry.timestamp,
|
||||
);
|
||||
|
||||
case "compaction":
|
||||
return createCompactionSummaryMessage(
|
||||
entry.summary,
|
||||
entry.tokensBefore,
|
||||
entry.timestamp,
|
||||
);
|
||||
|
||||
// These don't contribute to conversation content
|
||||
case "thinking_level_change":
|
||||
case "model_change":
|
||||
case "custom":
|
||||
case "label":
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Prepare entries for summarization with token budget.
 *
 * Walks entries from NEWEST to OLDEST, adding messages until we hit the token budget.
 * This ensures we keep the most recent context when the branch is too long.
 *
 * Also collects file operations from:
 * - Tool calls in assistant messages
 * - Existing branch_summary entries' details (for cumulative tracking)
 *
 * @param entries - Entries in chronological order
 * @param tokenBudget - Maximum tokens to include (0 = no limit)
 * @returns Chronological messages, collected file operations, and estimated tokens
 */
export function prepareBranchEntries(
	entries: SessionEntry[],
	tokenBudget: number = 0,
): BranchPreparation {
	const messages: AgentMessage[] = [];
	const fileOps = createFileOps();
	let totalTokens = 0;

	// First pass: collect file ops from ALL entries (even if they don't fit in token budget).
	// This ensures we capture cumulative file tracking from nested branch summaries.
	// Only extract from pi-generated summaries (fromHook !== true), not extension-generated ones.
	for (const entry of entries) {
		if (entry.type === "branch_summary" && !entry.fromHook && entry.details) {
			const details = entry.details as BranchSummaryDetails;
			if (Array.isArray(details.readFiles)) {
				for (const f of details.readFiles) fileOps.read.add(f);
			}
			if (Array.isArray(details.modifiedFiles)) {
				// Modified files are folded into the edited set.
				for (const f of details.modifiedFiles) {
					fileOps.edited.add(f);
				}
			}
		}
	}

	// Second pass: walk from newest to oldest, adding messages until token budget.
	for (let i = entries.length - 1; i >= 0; i--) {
		const entry = entries[i];
		const message = getMessageFromEntry(entry);
		if (!message) continue;

		// Extract file ops from assistant messages (tool calls).
		extractFileOpsFromMessage(message, fileOps);

		const tokens = estimateTokens(message);

		// Check budget before adding.
		if (tokenBudget > 0 && totalTokens + tokens > tokenBudget) {
			// If this is a summary entry, try to fit it anyway as it's important
			// context - but only while we're still below 90% of the budget.
			if (entry.type === "compaction" || entry.type === "branch_summary") {
				if (totalTokens < tokenBudget * 0.9) {
					messages.unshift(message);
					totalTokens += tokens;
				}
			}
			// Stop - we've hit the budget.
			break;
		}

		messages.unshift(message);
		totalTokens += tokens;
	}

	return { messages, fileOps, totalTokens };
}
|
||||
|
||||
// ============================================================================
// Summary Generation
// ============================================================================

// Prefix prepended to every generated branch summary so readers (and the
// model) understand why this context block exists.
const BRANCH_SUMMARY_PREAMBLE = `The user explored a different conversation branch before returning here.
Summary of that exploration:

`;

// Default instruction block sent to the summarization model. Callers may
// append to it or replace it entirely via GenerateBranchSummaryOptions.
const BRANCH_SUMMARY_PROMPT = `Create a structured summary of this conversation branch for context when returning later.

Use this EXACT format:

## Goal
[What was the user trying to accomplish in this branch?]

## Constraints & Preferences
- [Any constraints, preferences, or requirements mentioned]
- [Or "(none)" if none were mentioned]

## Progress
### Done
- [x] [Completed tasks/changes]

### In Progress
- [ ] [Work that was started but not finished]

### Blocked
- [Issues preventing progress, if any]

## Key Decisions
- **[Decision]**: [Brief rationale]

## Next Steps
1. [What should happen next to continue this work]

Keep each section concise. Preserve exact file paths, function names, and error messages.`;
|
||||
|
||||
/**
|
||||
* Generate a summary of abandoned branch entries.
|
||||
*
|
||||
* @param entries - Session entries to summarize (chronological order)
|
||||
* @param options - Generation options
|
||||
*/
|
||||
export async function generateBranchSummary(
|
||||
entries: SessionEntry[],
|
||||
options: GenerateBranchSummaryOptions,
|
||||
): Promise<BranchSummaryResult> {
|
||||
const {
|
||||
model,
|
||||
apiKey,
|
||||
signal,
|
||||
customInstructions,
|
||||
replaceInstructions,
|
||||
reserveTokens = 16384,
|
||||
} = options;
|
||||
|
||||
// Token budget = context window minus reserved space for prompt + response
|
||||
const contextWindow = model.contextWindow || 128000;
|
||||
const tokenBudget = contextWindow - reserveTokens;
|
||||
|
||||
const { messages, fileOps } = prepareBranchEntries(entries, tokenBudget);
|
||||
|
||||
if (messages.length === 0) {
|
||||
return { summary: "No content to summarize" };
|
||||
}
|
||||
|
||||
// Transform to LLM-compatible messages, then serialize to text
|
||||
// Serialization prevents the model from treating it as a conversation to continue
|
||||
const llmMessages = convertToLlm(messages);
|
||||
const conversationText = serializeConversation(llmMessages);
|
||||
|
||||
// Build prompt
|
||||
let instructions: string;
|
||||
if (replaceInstructions && customInstructions) {
|
||||
instructions = customInstructions;
|
||||
} else if (customInstructions) {
|
||||
instructions = `${BRANCH_SUMMARY_PROMPT}\n\nAdditional focus: ${customInstructions}`;
|
||||
} else {
|
||||
instructions = BRANCH_SUMMARY_PROMPT;
|
||||
}
|
||||
const promptText = `<conversation>\n${conversationText}\n</conversation>\n\n${instructions}`;
|
||||
|
||||
const summarizationMessages = [
|
||||
{
|
||||
role: "user" as const,
|
||||
content: [{ type: "text" as const, text: promptText }],
|
||||
timestamp: Date.now(),
|
||||
},
|
||||
];
|
||||
|
||||
// Call LLM for summarization
|
||||
const response = await completeSimple(
|
||||
model,
|
||||
{
|
||||
systemPrompt: SUMMARIZATION_SYSTEM_PROMPT,
|
||||
messages: summarizationMessages,
|
||||
},
|
||||
{ apiKey, signal, maxTokens: 2048 },
|
||||
);
|
||||
|
||||
// Check if aborted or errored
|
||||
if (response.stopReason === "aborted") {
|
||||
return { aborted: true };
|
||||
}
|
||||
if (response.stopReason === "error") {
|
||||
return { error: response.errorMessage || "Summarization failed" };
|
||||
}
|
||||
|
||||
let summary = response.content
|
||||
.filter((c): c is { type: "text"; text: string } => c.type === "text")
|
||||
.map((c) => c.text)
|
||||
.join("\n");
|
||||
|
||||
// Prepend preamble to provide context about the branch summary
|
||||
summary = BRANCH_SUMMARY_PREAMBLE + summary;
|
||||
|
||||
// Compute file lists and append to summary
|
||||
const { readFiles, modifiedFiles } = computeFileLists(fileOps);
|
||||
summary += formatFileOperations(readFiles, modifiedFiles);
|
||||
|
||||
return {
|
||||
summary: summary || "No summary generated",
|
||||
readFiles,
|
||||
modifiedFiles,
|
||||
};
|
||||
}
|
||||
899
packages/coding-agent/src/core/compaction/compaction.ts
Normal file
899
packages/coding-agent/src/core/compaction/compaction.ts
Normal file
|
|
@ -0,0 +1,899 @@
|
|||
/**
|
||||
* Context compaction for long sessions.
|
||||
*
|
||||
* Pure functions for compaction logic. The session manager handles I/O,
|
||||
* and after compaction the session is reloaded.
|
||||
*/
|
||||
|
||||
import type { AgentMessage } from "@mariozechner/pi-agent-core";
|
||||
import type { AssistantMessage, Model, Usage } from "@mariozechner/pi-ai";
|
||||
import { completeSimple } from "@mariozechner/pi-ai";
|
||||
import {
|
||||
convertToLlm,
|
||||
createBranchSummaryMessage,
|
||||
createCompactionSummaryMessage,
|
||||
createCustomMessage,
|
||||
} from "../messages.js";
|
||||
import type { CompactionEntry, SessionEntry } from "../session-manager.js";
|
||||
import {
|
||||
computeFileLists,
|
||||
createFileOps,
|
||||
extractFileOpsFromMessage,
|
||||
type FileOperations,
|
||||
formatFileOperations,
|
||||
SUMMARIZATION_SYSTEM_PROMPT,
|
||||
serializeConversation,
|
||||
} from "./utils.js";
|
||||
|
||||
// ============================================================================
// File Operation Tracking
// ============================================================================

/** Details stored in CompactionEntry.details for file tracking */
export interface CompactionDetails {
	// Files read in the compacted range (carried forward across compactions).
	readFiles: string[];
	// Files modified in the compacted range.
	modifiedFiles: string[];
}
|
||||
|
||||
/**
|
||||
* Extract file operations from messages and previous compaction entries.
|
||||
*/
|
||||
function extractFileOperations(
|
||||
messages: AgentMessage[],
|
||||
entries: SessionEntry[],
|
||||
prevCompactionIndex: number,
|
||||
): FileOperations {
|
||||
const fileOps = createFileOps();
|
||||
|
||||
// Collect from previous compaction's details (if pi-generated)
|
||||
if (prevCompactionIndex >= 0) {
|
||||
const prevCompaction = entries[prevCompactionIndex] as CompactionEntry;
|
||||
if (!prevCompaction.fromHook && prevCompaction.details) {
|
||||
// fromHook field kept for session file compatibility
|
||||
const details = prevCompaction.details as CompactionDetails;
|
||||
if (Array.isArray(details.readFiles)) {
|
||||
for (const f of details.readFiles) fileOps.read.add(f);
|
||||
}
|
||||
if (Array.isArray(details.modifiedFiles)) {
|
||||
for (const f of details.modifiedFiles) fileOps.edited.add(f);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract from tool calls in messages
|
||||
for (const msg of messages) {
|
||||
extractFileOpsFromMessage(msg, fileOps);
|
||||
}
|
||||
|
||||
return fileOps;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Message Extraction
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Extract AgentMessage from an entry if it produces one.
|
||||
* Returns undefined for entries that don't contribute to LLM context.
|
||||
*/
|
||||
function getMessageFromEntry(entry: SessionEntry): AgentMessage | undefined {
|
||||
if (entry.type === "message") {
|
||||
return entry.message;
|
||||
}
|
||||
if (entry.type === "custom_message") {
|
||||
return createCustomMessage(
|
||||
entry.customType,
|
||||
entry.content,
|
||||
entry.display,
|
||||
entry.details,
|
||||
entry.timestamp,
|
||||
);
|
||||
}
|
||||
if (entry.type === "branch_summary") {
|
||||
return createBranchSummaryMessage(
|
||||
entry.summary,
|
||||
entry.fromId,
|
||||
entry.timestamp,
|
||||
);
|
||||
}
|
||||
if (entry.type === "compaction") {
|
||||
return createCompactionSummaryMessage(
|
||||
entry.summary,
|
||||
entry.tokensBefore,
|
||||
entry.timestamp,
|
||||
);
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/** Result from compact() - SessionManager adds uuid/parentUuid when saving */
export interface CompactionResult<T = unknown> {
	// LLM-generated summary of the discarded messages.
	summary: string;
	// ID of the first entry kept after the cut point.
	firstKeptEntryId: string;
	// Context token count before compaction ran.
	tokensBefore: number;
	/** Extension-specific data (e.g., ArtifactIndex, version markers for structured compaction) */
	details?: T;
}

// ============================================================================
// Types
// ============================================================================

/** Knobs controlling when compaction triggers and how much is kept. */
export interface CompactionSettings {
	// Whether automatic compaction is enabled at all.
	enabled: boolean;
	// Tokens reserved for the prompt + LLM response when deciding to compact.
	reserveTokens: number;
	// Approximate amount of recent conversation (in tokens) to keep verbatim.
	keepRecentTokens: number;
}

export const DEFAULT_COMPACTION_SETTINGS: CompactionSettings = {
	enabled: true,
	reserveTokens: 16384,
	keepRecentTokens: 20000,
};
|
||||
|
||||
// ============================================================================
|
||||
// Token calculation
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Calculate total context tokens from usage.
|
||||
* Uses the native totalTokens field when available, falls back to computing from components.
|
||||
*/
|
||||
export function calculateContextTokens(usage: Usage): number {
|
||||
return (
|
||||
usage.totalTokens ||
|
||||
usage.input + usage.output + usage.cacheRead + usage.cacheWrite
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get usage from an assistant message if available.
|
||||
* Skips aborted and error messages as they don't have valid usage data.
|
||||
*/
|
||||
function getAssistantUsage(msg: AgentMessage): Usage | undefined {
|
||||
if (msg.role === "assistant" && "usage" in msg) {
|
||||
const assistantMsg = msg as AssistantMessage;
|
||||
if (
|
||||
assistantMsg.stopReason !== "aborted" &&
|
||||
assistantMsg.stopReason !== "error" &&
|
||||
assistantMsg.usage
|
||||
) {
|
||||
return assistantMsg.usage;
|
||||
}
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Find the last non-aborted assistant message usage from session entries.
|
||||
*/
|
||||
export function getLastAssistantUsage(
|
||||
entries: SessionEntry[],
|
||||
): Usage | undefined {
|
||||
for (let i = entries.length - 1; i >= 0; i--) {
|
||||
const entry = entries[i];
|
||||
if (entry.type === "message") {
|
||||
const usage = getAssistantUsage(entry.message);
|
||||
if (usage) return usage;
|
||||
}
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/** Breakdown of an estimated context-token count. */
export interface ContextUsageEstimate {
	// Best estimate of total context tokens (usageTokens + trailingTokens).
	tokens: number;
	// Tokens reported by the last valid assistant usage (0 if none found).
	usageTokens: number;
	// Estimated tokens of messages after the last usage report.
	trailingTokens: number;
	// Index of the message carrying the last usage, or null if none found.
	lastUsageIndex: number | null;
}
|
||||
|
||||
function getLastAssistantUsageInfo(
|
||||
messages: AgentMessage[],
|
||||
): { usage: Usage; index: number } | undefined {
|
||||
for (let i = messages.length - 1; i >= 0; i--) {
|
||||
const usage = getAssistantUsage(messages[i]);
|
||||
if (usage) return { usage, index: i };
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Estimate context tokens from messages, using the last assistant usage when available.
|
||||
* If there are messages after the last usage, estimate their tokens with estimateTokens.
|
||||
*/
|
||||
export function estimateContextTokens(
|
||||
messages: AgentMessage[],
|
||||
): ContextUsageEstimate {
|
||||
const usageInfo = getLastAssistantUsageInfo(messages);
|
||||
|
||||
if (!usageInfo) {
|
||||
let estimated = 0;
|
||||
for (const message of messages) {
|
||||
estimated += estimateTokens(message);
|
||||
}
|
||||
return {
|
||||
tokens: estimated,
|
||||
usageTokens: 0,
|
||||
trailingTokens: estimated,
|
||||
lastUsageIndex: null,
|
||||
};
|
||||
}
|
||||
|
||||
const usageTokens = calculateContextTokens(usageInfo.usage);
|
||||
let trailingTokens = 0;
|
||||
for (let i = usageInfo.index + 1; i < messages.length; i++) {
|
||||
trailingTokens += estimateTokens(messages[i]);
|
||||
}
|
||||
|
||||
return {
|
||||
tokens: usageTokens + trailingTokens,
|
||||
usageTokens,
|
||||
trailingTokens,
|
||||
lastUsageIndex: usageInfo.index,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if compaction should trigger based on context usage.
|
||||
*/
|
||||
export function shouldCompact(
|
||||
contextTokens: number,
|
||||
contextWindow: number,
|
||||
settings: CompactionSettings,
|
||||
): boolean {
|
||||
if (!settings.enabled) return false;
|
||||
return contextTokens > contextWindow - settings.reserveTokens;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Cut point detection
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Estimate token count for a message using chars/4 heuristic.
|
||||
* This is conservative (overestimates tokens).
|
||||
*/
|
||||
export function estimateTokens(message: AgentMessage): number {
|
||||
let chars = 0;
|
||||
|
||||
switch (message.role) {
|
||||
case "user": {
|
||||
const content = (
|
||||
message as { content: string | Array<{ type: string; text?: string }> }
|
||||
).content;
|
||||
if (typeof content === "string") {
|
||||
chars = content.length;
|
||||
} else if (Array.isArray(content)) {
|
||||
for (const block of content) {
|
||||
if (block.type === "text" && block.text) {
|
||||
chars += block.text.length;
|
||||
}
|
||||
}
|
||||
}
|
||||
return Math.ceil(chars / 4);
|
||||
}
|
||||
case "assistant": {
|
||||
const assistant = message as AssistantMessage;
|
||||
for (const block of assistant.content) {
|
||||
if (block.type === "text") {
|
||||
chars += block.text.length;
|
||||
} else if (block.type === "thinking") {
|
||||
chars += block.thinking.length;
|
||||
} else if (block.type === "toolCall") {
|
||||
chars += block.name.length + JSON.stringify(block.arguments).length;
|
||||
}
|
||||
}
|
||||
return Math.ceil(chars / 4);
|
||||
}
|
||||
case "custom":
|
||||
case "toolResult": {
|
||||
if (typeof message.content === "string") {
|
||||
chars = message.content.length;
|
||||
} else {
|
||||
for (const block of message.content) {
|
||||
if (block.type === "text" && block.text) {
|
||||
chars += block.text.length;
|
||||
}
|
||||
if (block.type === "image") {
|
||||
chars += 4800; // Estimate images as 4000 chars, or 1200 tokens
|
||||
}
|
||||
}
|
||||
}
|
||||
return Math.ceil(chars / 4);
|
||||
}
|
||||
case "bashExecution": {
|
||||
chars = message.command.length + message.output.length;
|
||||
return Math.ceil(chars / 4);
|
||||
}
|
||||
case "branchSummary":
|
||||
case "compactionSummary": {
|
||||
chars = message.summary.length;
|
||||
return Math.ceil(chars / 4);
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Find valid cut points: indices of user, assistant, custom, or bashExecution messages.
|
||||
* Never cut at tool results (they must follow their tool call).
|
||||
* When we cut at an assistant message with tool calls, its tool results follow it
|
||||
* and will be kept.
|
||||
* BashExecutionMessage is treated like a user message (user-initiated context).
|
||||
*/
|
||||
function findValidCutPoints(
|
||||
entries: SessionEntry[],
|
||||
startIndex: number,
|
||||
endIndex: number,
|
||||
): number[] {
|
||||
const cutPoints: number[] = [];
|
||||
for (let i = startIndex; i < endIndex; i++) {
|
||||
const entry = entries[i];
|
||||
switch (entry.type) {
|
||||
case "message": {
|
||||
const role = entry.message.role;
|
||||
switch (role) {
|
||||
case "bashExecution":
|
||||
case "custom":
|
||||
case "branchSummary":
|
||||
case "compactionSummary":
|
||||
case "user":
|
||||
case "assistant":
|
||||
cutPoints.push(i);
|
||||
break;
|
||||
case "toolResult":
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case "thinking_level_change":
|
||||
case "model_change":
|
||||
case "compaction":
|
||||
case "branch_summary":
|
||||
case "custom":
|
||||
case "custom_message":
|
||||
case "label":
|
||||
}
|
||||
// branch_summary and custom_message are user-role messages, valid cut points
|
||||
if (entry.type === "branch_summary" || entry.type === "custom_message") {
|
||||
cutPoints.push(i);
|
||||
}
|
||||
}
|
||||
return cutPoints;
|
||||
}
|
||||
|
||||
/**
|
||||
* Find the user message (or bashExecution) that starts the turn containing the given entry index.
|
||||
* Returns -1 if no turn start found before the index.
|
||||
* BashExecutionMessage is treated like a user message for turn boundaries.
|
||||
*/
|
||||
export function findTurnStartIndex(
|
||||
entries: SessionEntry[],
|
||||
entryIndex: number,
|
||||
startIndex: number,
|
||||
): number {
|
||||
for (let i = entryIndex; i >= startIndex; i--) {
|
||||
const entry = entries[i];
|
||||
// branch_summary and custom_message are user-role messages, can start a turn
|
||||
if (entry.type === "branch_summary" || entry.type === "custom_message") {
|
||||
return i;
|
||||
}
|
||||
if (entry.type === "message") {
|
||||
const role = entry.message.role;
|
||||
if (role === "user" || role === "bashExecution") {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
/** Result of findCutPoint: where to cut and whether a turn is being split. */
export interface CutPointResult {
	/** Index of first entry to keep */
	firstKeptEntryIndex: number;
	/** Index of user message that starts the turn being split, or -1 if not splitting */
	turnStartIndex: number;
	/** Whether this cut splits a turn (cut point is not a user message) */
	isSplitTurn: boolean;
}
|
||||
|
||||
/**
 * Find the cut point in session entries that keeps approximately `keepRecentTokens`.
 *
 * Algorithm: Walk backwards from newest, accumulating estimated message sizes.
 * Stop when we've accumulated >= keepRecentTokens. Cut at that point.
 *
 * Can cut at user OR assistant messages (never tool results). When cutting at an
 * assistant message with tool calls, its tool results come after and will be kept.
 *
 * Returns CutPointResult with:
 * - firstKeptEntryIndex: the entry index to start keeping from
 * - turnStartIndex: if cutting mid-turn, the user message that started that turn
 * - isSplitTurn: whether we're cutting in the middle of a turn
 *
 * Only considers entries between `startIndex` and `endIndex` (exclusive).
 */
export function findCutPoint(
	entries: SessionEntry[],
	startIndex: number,
	endIndex: number,
	keepRecentTokens: number,
): CutPointResult {
	const cutPoints = findValidCutPoints(entries, startIndex, endIndex);

	// No valid cut point at all: keep everything from startIndex.
	if (cutPoints.length === 0) {
		return {
			firstKeptEntryIndex: startIndex,
			turnStartIndex: -1,
			isSplitTurn: false,
		};
	}

	// Walk backwards from newest, accumulating estimated message sizes
	let accumulatedTokens = 0;
	let cutIndex = cutPoints[0]; // Default: keep from first message (not header)

	for (let i = endIndex - 1; i >= startIndex; i--) {
		const entry = entries[i];
		if (entry.type !== "message") continue;

		// Estimate this message's size
		const messageTokens = estimateTokens(entry.message);
		accumulatedTokens += messageTokens;

		// Check if we've exceeded the budget
		if (accumulatedTokens >= keepRecentTokens) {
			// Find the closest valid cut point at or after this entry
			for (let c = 0; c < cutPoints.length; c++) {
				if (cutPoints[c] >= i) {
					cutIndex = cutPoints[c];
					break;
				}
			}
			break;
		}
	}

	// Scan backwards from cutIndex to include any non-message entries (bash, settings, etc.)
	while (cutIndex > startIndex) {
		const prevEntry = entries[cutIndex - 1];
		// Stop at session header or compaction boundaries
		if (prevEntry.type === "compaction") {
			break;
		}
		if (prevEntry.type === "message") {
			// Stop if we hit any message
			break;
		}
		// Include this non-message entry (bash, settings change, etc.)
		cutIndex--;
	}

	// Determine if this is a split turn: a cut at a non-user message means the
	// turn that contains it is being split in the middle.
	const cutEntry = entries[cutIndex];
	const isUserMessage =
		cutEntry.type === "message" && cutEntry.message.role === "user";
	const turnStartIndex = isUserMessage
		? -1
		: findTurnStartIndex(entries, cutIndex, startIndex);

	return {
		firstKeptEntryIndex: cutIndex,
		turnStartIndex,
		isSplitTurn: !isUserMessage && turnStartIndex !== -1,
	};
}
|
||||
|
||||
// ============================================================================
// Summarization
// ============================================================================

// Prompt used for the FIRST compaction of a session (no previous summary).
const SUMMARIZATION_PROMPT = `The messages above are a conversation to summarize. Create a structured context checkpoint summary that another LLM will use to continue the work.

Use this EXACT format:

## Goal
[What is the user trying to accomplish? Can be multiple items if the session covers different tasks.]

## Constraints & Preferences
- [Any constraints, preferences, or requirements mentioned by user]
- [Or "(none)" if none were mentioned]

## Progress
### Done
- [x] [Completed tasks/changes]

### In Progress
- [ ] [Current work]

### Blocked
- [Issues preventing progress, if any]

## Key Decisions
- **[Decision]**: [Brief rationale]

## Next Steps
1. [Ordered list of what should happen next]

## Critical Context
- [Any data, examples, or references needed to continue]
- [Or "(none)" if not applicable]

Keep each section concise. Preserve exact file paths, function names, and error messages.`;

// Prompt used for subsequent compactions: merges new messages into the
// previous summary provided in <previous-summary> tags.
const UPDATE_SUMMARIZATION_PROMPT = `The messages above are NEW conversation messages to incorporate into the existing summary provided in <previous-summary> tags.

Update the existing structured summary with new information. RULES:
- PRESERVE all existing information from the previous summary
- ADD new progress, decisions, and context from the new messages
- UPDATE the Progress section: move items from "In Progress" to "Done" when completed
- UPDATE "Next Steps" based on what was accomplished
- PRESERVE exact file paths, function names, and error messages
- If something is no longer relevant, you may remove it

Use this EXACT format:

## Goal
[Preserve existing goals, add new ones if the task expanded]

## Constraints & Preferences
- [Preserve existing, add new ones discovered]

## Progress
### Done
- [x] [Include previously done items AND newly completed items]

### In Progress
- [ ] [Current work - update based on progress]

### Blocked
- [Current blockers - remove if resolved]

## Key Decisions
- **[Decision]**: [Brief rationale] (preserve all previous, add new)

## Next Steps
1. [Update based on current state]

## Critical Context
- [Preserve important context, add new if needed]

Keep each section concise. Preserve exact file paths, function names, and error messages.`;
|
||||
|
||||
/**
|
||||
* Generate a summary of the conversation using the LLM.
|
||||
* If previousSummary is provided, uses the update prompt to merge.
|
||||
*/
|
||||
export async function generateSummary(
|
||||
currentMessages: AgentMessage[],
|
||||
model: Model<any>,
|
||||
reserveTokens: number,
|
||||
apiKey: string,
|
||||
signal?: AbortSignal,
|
||||
customInstructions?: string,
|
||||
previousSummary?: string,
|
||||
): Promise<string> {
|
||||
const maxTokens = Math.floor(0.8 * reserveTokens);
|
||||
|
||||
// Use update prompt if we have a previous summary, otherwise initial prompt
|
||||
let basePrompt = previousSummary
|
||||
? UPDATE_SUMMARIZATION_PROMPT
|
||||
: SUMMARIZATION_PROMPT;
|
||||
if (customInstructions) {
|
||||
basePrompt = `${basePrompt}\n\nAdditional focus: ${customInstructions}`;
|
||||
}
|
||||
|
||||
// Serialize conversation to text so model doesn't try to continue it
|
||||
// Convert to LLM messages first (handles custom types like bashExecution, custom, etc.)
|
||||
const llmMessages = convertToLlm(currentMessages);
|
||||
const conversationText = serializeConversation(llmMessages);
|
||||
|
||||
// Build the prompt with conversation wrapped in tags
|
||||
let promptText = `<conversation>\n${conversationText}\n</conversation>\n\n`;
|
||||
if (previousSummary) {
|
||||
promptText += `<previous-summary>\n${previousSummary}\n</previous-summary>\n\n`;
|
||||
}
|
||||
promptText += basePrompt;
|
||||
|
||||
const summarizationMessages = [
|
||||
{
|
||||
role: "user" as const,
|
||||
content: [{ type: "text" as const, text: promptText }],
|
||||
timestamp: Date.now(),
|
||||
},
|
||||
];
|
||||
|
||||
const completionOptions = model.reasoning
|
||||
? { maxTokens, signal, apiKey, reasoning: "high" as const }
|
||||
: { maxTokens, signal, apiKey };
|
||||
|
||||
const response = await completeSimple(
|
||||
model,
|
||||
{
|
||||
systemPrompt: SUMMARIZATION_SYSTEM_PROMPT,
|
||||
messages: summarizationMessages,
|
||||
},
|
||||
completionOptions,
|
||||
);
|
||||
|
||||
if (response.stopReason === "error") {
|
||||
throw new Error(
|
||||
`Summarization failed: ${response.errorMessage || "Unknown error"}`,
|
||||
);
|
||||
}
|
||||
|
||||
const textContent = response.content
|
||||
.filter((c): c is { type: "text"; text: string } => c.type === "text")
|
||||
.map((c) => c.text)
|
||||
.join("\n");
|
||||
|
||||
return textContent;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Compaction Preparation (for extensions)
|
||||
// ============================================================================
|
||||
|
||||
/**
 * Everything a compaction pass needs, computed up-front by prepareCompaction()
 * and consumed by compact().
 */
export interface CompactionPreparation {
	/** UUID of first entry to keep */
	firstKeptEntryId: string;
	/** Messages that will be summarized and discarded */
	messagesToSummarize: AgentMessage[];
	/** Messages that will be turned into turn prefix summary (if splitting) */
	turnPrefixMessages: AgentMessage[];
	/** Whether this is a split turn (cut point in middle of turn) */
	isSplitTurn: boolean;
	/** Estimated context tokens before compaction */
	tokensBefore: number;
	/** Summary from previous compaction, for iterative update */
	previousSummary?: string;
	/** File operations extracted from messagesToSummarize */
	fileOps: FileOperations;
	/** Compaction settings from settings.jsonl */
	settings: CompactionSettings;
}
|
||||
|
||||
/**
 * Compute everything a compaction pass needs without calling the LLM.
 * Returns undefined when there is nothing to do: the most recent entry is
 * already a compaction, or the entry at the cut point has no UUID
 * (session needs migration).
 */
export function prepareCompaction(
	pathEntries: SessionEntry[],
	settings: CompactionSettings,
): CompactionPreparation | undefined {
	// Nothing to do if the last entry is already a compaction.
	if (
		pathEntries.length > 0 &&
		pathEntries[pathEntries.length - 1].type === "compaction"
	) {
		return undefined;
	}

	// Find the most recent previous compaction; entries before it are
	// already summarized and excluded from this pass.
	let prevCompactionIndex = -1;
	for (let i = pathEntries.length - 1; i >= 0; i--) {
		if (pathEntries[i].type === "compaction") {
			prevCompactionIndex = i;
			break;
		}
	}
	const boundaryStart = prevCompactionIndex + 1;
	const boundaryEnd = pathEntries.length;

	// Token estimate includes the previous compaction entry itself (its
	// summary occupies context too), hence usageStart may precede boundaryStart.
	const usageStart = prevCompactionIndex >= 0 ? prevCompactionIndex : 0;
	const usageMessages: AgentMessage[] = [];
	for (let i = usageStart; i < boundaryEnd; i++) {
		const msg = getMessageFromEntry(pathEntries[i]);
		if (msg) usageMessages.push(msg);
	}
	const tokensBefore = estimateContextTokens(usageMessages).tokens;

	// Decide where to cut so roughly keepRecentTokens worth of recent
	// entries are retained.
	const cutPoint = findCutPoint(
		pathEntries,
		boundaryStart,
		boundaryEnd,
		settings.keepRecentTokens,
	);

	// Get UUID of first kept entry.
	const firstKeptEntry = pathEntries[cutPoint.firstKeptEntryIndex];
	if (!firstKeptEntry?.id) {
		return undefined; // Session needs migration
	}
	const firstKeptEntryId = firstKeptEntry.id;

	// If the cut lands mid-turn, history ends at the turn start; the turn's
	// prefix is summarized separately.
	const historyEnd = cutPoint.isSplitTurn
		? cutPoint.turnStartIndex
		: cutPoint.firstKeptEntryIndex;

	// Messages to summarize (will be discarded after summary).
	const messagesToSummarize: AgentMessage[] = [];
	for (let i = boundaryStart; i < historyEnd; i++) {
		const msg = getMessageFromEntry(pathEntries[i]);
		if (msg) messagesToSummarize.push(msg);
	}

	// Messages for turn prefix summary (if splitting a turn).
	const turnPrefixMessages: AgentMessage[] = [];
	if (cutPoint.isSplitTurn) {
		for (
			let i = cutPoint.turnStartIndex;
			i < cutPoint.firstKeptEntryIndex;
			i++
		) {
			const msg = getMessageFromEntry(pathEntries[i]);
			if (msg) turnPrefixMessages.push(msg);
		}
	}

	// Carry the previous summary forward for iterative update.
	let previousSummary: string | undefined;
	if (prevCompactionIndex >= 0) {
		const prevCompaction = pathEntries[prevCompactionIndex] as CompactionEntry;
		previousSummary = prevCompaction.summary;
	}

	// Extract file operations from messages and previous compaction.
	const fileOps = extractFileOperations(
		messagesToSummarize,
		pathEntries,
		prevCompactionIndex,
	);

	// Also extract file ops from turn prefix if splitting.
	if (cutPoint.isSplitTurn) {
		for (const msg of turnPrefixMessages) {
			extractFileOpsFromMessage(msg, fileOps);
		}
	}

	return {
		firstKeptEntryId,
		messagesToSummarize,
		turnPrefixMessages,
		isSplitTurn: cutPoint.isSplitTurn,
		tokensBefore,
		previousSummary,
		fileOps,
		settings,
	};
}
|
||||
|
||||
// ============================================================================
|
||||
// Main compaction function
|
||||
// ============================================================================
|
||||
|
||||
/**
 * Prompt used when a single turn is split mid-turn: summarizes the discarded
 * PREFIX of the turn so the retained suffix remains understandable.
 */
const TURN_PREFIX_SUMMARIZATION_PROMPT = `This is the PREFIX of a turn that was too large to keep. The SUFFIX (recent work) is retained.

Summarize the prefix to provide context for the retained suffix:

## Original Request
[What did the user ask for in this turn?]

## Early Progress
- [Key decisions and work done in the prefix]

## Context for Suffix
- [Information needed to understand the retained recent work]

Be concise. Focus on what's needed to understand the kept suffix.`;
|
||||
|
||||
/**
|
||||
* Generate summaries for compaction using prepared data.
|
||||
* Returns CompactionResult - SessionManager adds uuid/parentUuid when saving.
|
||||
*
|
||||
* @param preparation - Pre-calculated preparation from prepareCompaction()
|
||||
* @param customInstructions - Optional custom focus for the summary
|
||||
*/
|
||||
export async function compact(
|
||||
preparation: CompactionPreparation,
|
||||
model: Model<any>,
|
||||
apiKey: string,
|
||||
customInstructions?: string,
|
||||
signal?: AbortSignal,
|
||||
): Promise<CompactionResult> {
|
||||
const {
|
||||
firstKeptEntryId,
|
||||
messagesToSummarize,
|
||||
turnPrefixMessages,
|
||||
isSplitTurn,
|
||||
tokensBefore,
|
||||
previousSummary,
|
||||
fileOps,
|
||||
settings,
|
||||
} = preparation;
|
||||
|
||||
// Generate summaries (can be parallel if both needed) and merge into one
|
||||
let summary: string;
|
||||
|
||||
if (isSplitTurn && turnPrefixMessages.length > 0) {
|
||||
// Generate both summaries in parallel
|
||||
const [historyResult, turnPrefixResult] = await Promise.all([
|
||||
messagesToSummarize.length > 0
|
||||
? generateSummary(
|
||||
messagesToSummarize,
|
||||
model,
|
||||
settings.reserveTokens,
|
||||
apiKey,
|
||||
signal,
|
||||
customInstructions,
|
||||
previousSummary,
|
||||
)
|
||||
: Promise.resolve("No prior history."),
|
||||
generateTurnPrefixSummary(
|
||||
turnPrefixMessages,
|
||||
model,
|
||||
settings.reserveTokens,
|
||||
apiKey,
|
||||
signal,
|
||||
),
|
||||
]);
|
||||
// Merge into single summary
|
||||
summary = `${historyResult}\n\n---\n\n**Turn Context (split turn):**\n\n${turnPrefixResult}`;
|
||||
} else {
|
||||
// Just generate history summary
|
||||
summary = await generateSummary(
|
||||
messagesToSummarize,
|
||||
model,
|
||||
settings.reserveTokens,
|
||||
apiKey,
|
||||
signal,
|
||||
customInstructions,
|
||||
previousSummary,
|
||||
);
|
||||
}
|
||||
|
||||
// Compute file lists and append to summary
|
||||
const { readFiles, modifiedFiles } = computeFileLists(fileOps);
|
||||
summary += formatFileOperations(readFiles, modifiedFiles);
|
||||
|
||||
if (!firstKeptEntryId) {
|
||||
throw new Error(
|
||||
"First kept entry has no UUID - session may need migration",
|
||||
);
|
||||
}
|
||||
|
||||
return {
|
||||
summary,
|
||||
firstKeptEntryId,
|
||||
tokensBefore,
|
||||
details: { readFiles, modifiedFiles } as CompactionDetails,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a summary for a turn prefix (when splitting a turn).
|
||||
*/
|
||||
async function generateTurnPrefixSummary(
|
||||
messages: AgentMessage[],
|
||||
model: Model<any>,
|
||||
reserveTokens: number,
|
||||
apiKey: string,
|
||||
signal?: AbortSignal,
|
||||
): Promise<string> {
|
||||
const maxTokens = Math.floor(0.5 * reserveTokens); // Smaller budget for turn prefix
|
||||
const llmMessages = convertToLlm(messages);
|
||||
const conversationText = serializeConversation(llmMessages);
|
||||
const promptText = `<conversation>\n${conversationText}\n</conversation>\n\n${TURN_PREFIX_SUMMARIZATION_PROMPT}`;
|
||||
const summarizationMessages = [
|
||||
{
|
||||
role: "user" as const,
|
||||
content: [{ type: "text" as const, text: promptText }],
|
||||
timestamp: Date.now(),
|
||||
},
|
||||
];
|
||||
|
||||
const response = await completeSimple(
|
||||
model,
|
||||
{
|
||||
systemPrompt: SUMMARIZATION_SYSTEM_PROMPT,
|
||||
messages: summarizationMessages,
|
||||
},
|
||||
{ maxTokens, signal, apiKey },
|
||||
);
|
||||
|
||||
if (response.stopReason === "error") {
|
||||
throw new Error(
|
||||
`Turn prefix summarization failed: ${response.errorMessage || "Unknown error"}`,
|
||||
);
|
||||
}
|
||||
|
||||
return response.content
|
||||
.filter((c): c is { type: "text"; text: string } => c.type === "text")
|
||||
.map((c) => c.text)
|
||||
.join("\n");
|
||||
}
|
||||
7
packages/coding-agent/src/core/compaction/index.ts
Normal file
7
packages/coding-agent/src/core/compaction/index.ts
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
/**
|
||||
* Compaction and summarization utilities.
|
||||
*/
|
||||
|
||||
export * from "./branch-summarization.js";
|
||||
export * from "./compaction.js";
|
||||
export * from "./utils.js";
|
||||
167
packages/coding-agent/src/core/compaction/utils.ts
Normal file
167
packages/coding-agent/src/core/compaction/utils.ts
Normal file
|
|
@ -0,0 +1,167 @@
|
|||
/**
|
||||
* Shared utilities for compaction and branch summarization.
|
||||
*/
|
||||
|
||||
import type { AgentMessage } from "@mariozechner/pi-agent-core";
|
||||
import type { Message } from "@mariozechner/pi-ai";
|
||||
|
||||
// ============================================================================
|
||||
// File Operation Tracking
|
||||
// ============================================================================
|
||||
|
||||
/**
 * Sets of file paths touched during a span of the conversation, keyed by
 * the kind of tool call that touched them.
 */
export interface FileOperations {
	/** Paths passed to the `read` tool */
	read: Set<string>;
	/** Paths passed to the `write` tool */
	written: Set<string>;
	/** Paths passed to the `edit` tool */
	edited: Set<string>;
}
|
||||
|
||||
export function createFileOps(): FileOperations {
|
||||
return {
|
||||
read: new Set(),
|
||||
written: new Set(),
|
||||
edited: new Set(),
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract file operations from tool calls in an assistant message.
|
||||
*/
|
||||
export function extractFileOpsFromMessage(
|
||||
message: AgentMessage,
|
||||
fileOps: FileOperations,
|
||||
): void {
|
||||
if (message.role !== "assistant") return;
|
||||
if (!("content" in message) || !Array.isArray(message.content)) return;
|
||||
|
||||
for (const block of message.content) {
|
||||
if (typeof block !== "object" || block === null) continue;
|
||||
if (!("type" in block) || block.type !== "toolCall") continue;
|
||||
if (!("arguments" in block) || !("name" in block)) continue;
|
||||
|
||||
const args = block.arguments as Record<string, unknown> | undefined;
|
||||
if (!args) continue;
|
||||
|
||||
const path = typeof args.path === "string" ? args.path : undefined;
|
||||
if (!path) continue;
|
||||
|
||||
switch (block.name) {
|
||||
case "read":
|
||||
fileOps.read.add(path);
|
||||
break;
|
||||
case "write":
|
||||
fileOps.written.add(path);
|
||||
break;
|
||||
case "edit":
|
||||
fileOps.edited.add(path);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute final file lists from file operations.
|
||||
* Returns readFiles (files only read, not modified) and modifiedFiles.
|
||||
*/
|
||||
export function computeFileLists(fileOps: FileOperations): {
|
||||
readFiles: string[];
|
||||
modifiedFiles: string[];
|
||||
} {
|
||||
const modified = new Set([...fileOps.edited, ...fileOps.written]);
|
||||
const readOnly = [...fileOps.read].filter((f) => !modified.has(f)).sort();
|
||||
const modifiedFiles = [...modified].sort();
|
||||
return { readFiles: readOnly, modifiedFiles };
|
||||
}
|
||||
|
||||
/**
|
||||
* Format file operations as XML tags for summary.
|
||||
*/
|
||||
export function formatFileOperations(
|
||||
readFiles: string[],
|
||||
modifiedFiles: string[],
|
||||
): string {
|
||||
const sections: string[] = [];
|
||||
if (readFiles.length > 0) {
|
||||
sections.push(`<read-files>\n${readFiles.join("\n")}\n</read-files>`);
|
||||
}
|
||||
if (modifiedFiles.length > 0) {
|
||||
sections.push(
|
||||
`<modified-files>\n${modifiedFiles.join("\n")}\n</modified-files>`,
|
||||
);
|
||||
}
|
||||
if (sections.length === 0) return "";
|
||||
return `\n\n${sections.join("\n\n")}`;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Message Serialization
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Serialize LLM messages to text for summarization.
|
||||
* This prevents the model from treating it as a conversation to continue.
|
||||
* Call convertToLlm() first to handle custom message types.
|
||||
*/
|
||||
export function serializeConversation(messages: Message[]): string {
|
||||
const parts: string[] = [];
|
||||
|
||||
for (const msg of messages) {
|
||||
if (msg.role === "user") {
|
||||
const content =
|
||||
typeof msg.content === "string"
|
||||
? msg.content
|
||||
: msg.content
|
||||
.filter(
|
||||
(c): c is { type: "text"; text: string } => c.type === "text",
|
||||
)
|
||||
.map((c) => c.text)
|
||||
.join("");
|
||||
if (content) parts.push(`[User]: ${content}`);
|
||||
} else if (msg.role === "assistant") {
|
||||
const textParts: string[] = [];
|
||||
const thinkingParts: string[] = [];
|
||||
const toolCalls: string[] = [];
|
||||
|
||||
for (const block of msg.content) {
|
||||
if (block.type === "text") {
|
||||
textParts.push(block.text);
|
||||
} else if (block.type === "thinking") {
|
||||
thinkingParts.push(block.thinking);
|
||||
} else if (block.type === "toolCall") {
|
||||
const args = block.arguments as Record<string, unknown>;
|
||||
const argsStr = Object.entries(args)
|
||||
.map(([k, v]) => `${k}=${JSON.stringify(v)}`)
|
||||
.join(", ");
|
||||
toolCalls.push(`${block.name}(${argsStr})`);
|
||||
}
|
||||
}
|
||||
|
||||
if (thinkingParts.length > 0) {
|
||||
parts.push(`[Assistant thinking]: ${thinkingParts.join("\n")}`);
|
||||
}
|
||||
if (textParts.length > 0) {
|
||||
parts.push(`[Assistant]: ${textParts.join("\n")}`);
|
||||
}
|
||||
if (toolCalls.length > 0) {
|
||||
parts.push(`[Assistant tool calls]: ${toolCalls.join("; ")}`);
|
||||
}
|
||||
} else if (msg.role === "toolResult") {
|
||||
const content = msg.content
|
||||
.filter((c): c is { type: "text"; text: string } => c.type === "text")
|
||||
.map((c) => c.text)
|
||||
.join("");
|
||||
if (content) {
|
||||
parts.push(`[Tool result]: ${content}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return parts.join("\n\n");
|
||||
}
|
||||
|
||||
// ============================================================================
// Summarization System Prompt
// ============================================================================

/**
 * System prompt shared by all summarization calls. Instructs the model to
 * emit only the structured summary, never a continuation of the conversation.
 */
export const SUMMARIZATION_SYSTEM_PROMPT = `You are a context summarization assistant. Your task is to read a conversation between a user and an AI coding assistant, then produce a structured summary following the exact format specified.

Do NOT continue the conversation. Do NOT respond to any questions in the conversation. ONLY output the structured summary.`;
|
||||
3
packages/coding-agent/src/core/defaults.ts
Normal file
3
packages/coding-agent/src/core/defaults.ts
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
import type { ThinkingLevel } from "@mariozechner/pi-agent-core";
|
||||
|
||||
/** Default thinking level ("medium"). */
export const DEFAULT_THINKING_LEVEL: ThinkingLevel = "medium";
|
||||
15
packages/coding-agent/src/core/diagnostics.ts
Normal file
15
packages/coding-agent/src/core/diagnostics.ts
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
/**
 * Two resources of the same type competing for one name: the winner is the
 * one that was kept, the loser the one that was shadowed.
 */
export interface ResourceCollision {
	resourceType: "extension" | "skill" | "prompt" | "theme";
	name: string; // skill name, command/tool/flag name, prompt name, theme name
	winnerPath: string;
	loserPath: string;
	winnerSource?: string; // e.g., "npm:foo", "git:...", "local"
	loserSource?: string;
}

/** A diagnostic produced while loading resources (extensions, skills, prompts, themes). */
export interface ResourceDiagnostic {
	type: "warning" | "error" | "collision";
	message: string;
	path?: string; // resource path the diagnostic refers to, if any
	collision?: ResourceCollision; // collision details — presumably set when type === "collision"; confirm at producers
}
|
||||
33
packages/coding-agent/src/core/event-bus.ts
Normal file
33
packages/coding-agent/src/core/event-bus.ts
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
import { EventEmitter } from "node:events";
|
||||
|
||||
/** Minimal pub/sub interface: emit data on a named channel, subscribe with on(). */
export interface EventBus {
	emit(channel: string, data: unknown): void;
	/** Subscribe to a channel; returns an unsubscribe function. */
	on(channel: string, handler: (data: unknown) => void): () => void;
}

/** EventBus plus lifecycle control for its owner (drop all listeners). */
export interface EventBusController extends EventBus {
	clear(): void;
}
||||
|
||||
export function createEventBus(): EventBusController {
|
||||
const emitter = new EventEmitter();
|
||||
return {
|
||||
emit: (channel, data) => {
|
||||
emitter.emit(channel, data);
|
||||
},
|
||||
on: (channel, handler) => {
|
||||
const safeHandler = async (data: unknown) => {
|
||||
try {
|
||||
await handler(data);
|
||||
} catch (err) {
|
||||
console.error(`Event handler error (${channel}):`, err);
|
||||
}
|
||||
};
|
||||
emitter.on(channel, safeHandler);
|
||||
return () => emitter.off(channel, safeHandler);
|
||||
},
|
||||
clear: () => {
|
||||
emitter.removeAllListeners();
|
||||
},
|
||||
};
|
||||
}
|
||||
104
packages/coding-agent/src/core/exec.ts
Normal file
104
packages/coding-agent/src/core/exec.ts
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
/**
|
||||
* Shared command execution utilities for extensions and custom tools.
|
||||
*/
|
||||
|
||||
import { spawn } from "node:child_process";
|
||||
|
||||
/**
 * Options for executing shell commands.
 */
export interface ExecOptions {
	/** AbortSignal to cancel the command */
	signal?: AbortSignal;
	/** Timeout in milliseconds; when exceeded the child is terminated */
	timeout?: number;
	/** Working directory */
	cwd?: string;
}

/**
 * Result of executing a shell command.
 */
export interface ExecResult {
	/** Captured standard output */
	stdout: string;
	/** Captured standard error */
	stderr: string;
	/** Exit code (0 when the process reported none) */
	code: number;
	/** True when the process was terminated via timeout or abort */
	killed: boolean;
}
|
||||
|
||||
/**
|
||||
* Execute a shell command and return stdout/stderr/code.
|
||||
* Supports timeout and abort signal.
|
||||
*/
|
||||
export async function execCommand(
|
||||
command: string,
|
||||
args: string[],
|
||||
cwd: string,
|
||||
options?: ExecOptions,
|
||||
): Promise<ExecResult> {
|
||||
return new Promise((resolve) => {
|
||||
const proc = spawn(command, args, {
|
||||
cwd,
|
||||
shell: false,
|
||||
stdio: ["ignore", "pipe", "pipe"],
|
||||
});
|
||||
|
||||
let stdout = "";
|
||||
let stderr = "";
|
||||
let killed = false;
|
||||
let timeoutId: NodeJS.Timeout | undefined;
|
||||
|
||||
const killProcess = () => {
|
||||
if (!killed) {
|
||||
killed = true;
|
||||
proc.kill("SIGTERM");
|
||||
// Force kill after 5 seconds if SIGTERM doesn't work
|
||||
setTimeout(() => {
|
||||
if (!proc.killed) {
|
||||
proc.kill("SIGKILL");
|
||||
}
|
||||
}, 5000);
|
||||
}
|
||||
};
|
||||
|
||||
// Handle abort signal
|
||||
if (options?.signal) {
|
||||
if (options.signal.aborted) {
|
||||
killProcess();
|
||||
} else {
|
||||
options.signal.addEventListener("abort", killProcess, { once: true });
|
||||
}
|
||||
}
|
||||
|
||||
// Handle timeout
|
||||
if (options?.timeout && options.timeout > 0) {
|
||||
timeoutId = setTimeout(() => {
|
||||
killProcess();
|
||||
}, options.timeout);
|
||||
}
|
||||
|
||||
proc.stdout?.on("data", (data) => {
|
||||
stdout += data.toString();
|
||||
});
|
||||
|
||||
proc.stderr?.on("data", (data) => {
|
||||
stderr += data.toString();
|
||||
});
|
||||
|
||||
proc.on("close", (code) => {
|
||||
if (timeoutId) clearTimeout(timeoutId);
|
||||
if (options?.signal) {
|
||||
options.signal.removeEventListener("abort", killProcess);
|
||||
}
|
||||
resolve({ stdout, stderr, code: code ?? 0, killed });
|
||||
});
|
||||
|
||||
proc.on("error", (_err) => {
|
||||
if (timeoutId) clearTimeout(timeoutId);
|
||||
if (options?.signal) {
|
||||
options.signal.removeEventListener("abort", killProcess);
|
||||
}
|
||||
resolve({ stdout, stderr, code: 1, killed });
|
||||
});
|
||||
});
|
||||
}
|
||||
271
packages/coding-agent/src/core/export-html/ansi-to-html.ts
Normal file
271
packages/coding-agent/src/core/export-html/ansi-to-html.ts
Normal file
|
|
@ -0,0 +1,271 @@
|
|||
/**
|
||||
* ANSI escape code to HTML converter.
|
||||
*
|
||||
* Converts terminal ANSI color/style codes to HTML with inline styles.
|
||||
* Supports:
|
||||
* - Standard foreground colors (30-37) and bright variants (90-97)
|
||||
* - Standard background colors (40-47) and bright variants (100-107)
|
||||
* - 256-color palette (38;5;N and 48;5;N)
|
||||
* - RGB true color (38;2;R;G;B and 48;2;R;G;B)
|
||||
* - Text styles: bold (1), dim (2), italic (3), underline (4)
|
||||
* - Reset (0)
|
||||
*/
|
||||
|
||||
// Standard ANSI color palette (0-15)
|
||||
const ANSI_COLORS = [
|
||||
"#000000", // 0: black
|
||||
"#800000", // 1: red
|
||||
"#008000", // 2: green
|
||||
"#808000", // 3: yellow
|
||||
"#000080", // 4: blue
|
||||
"#800080", // 5: magenta
|
||||
"#008080", // 6: cyan
|
||||
"#c0c0c0", // 7: white
|
||||
"#808080", // 8: bright black
|
||||
"#ff0000", // 9: bright red
|
||||
"#00ff00", // 10: bright green
|
||||
"#ffff00", // 11: bright yellow
|
||||
"#0000ff", // 12: bright blue
|
||||
"#ff00ff", // 13: bright magenta
|
||||
"#00ffff", // 14: bright cyan
|
||||
"#ffffff", // 15: bright white
|
||||
];
|
||||
|
||||
/**
|
||||
* Convert 256-color index to hex.
|
||||
*/
|
||||
function color256ToHex(index: number): string {
|
||||
// Standard colors (0-15)
|
||||
if (index < 16) {
|
||||
return ANSI_COLORS[index];
|
||||
}
|
||||
|
||||
// Color cube (16-231): 6x6x6 = 216 colors
|
||||
if (index < 232) {
|
||||
const cubeIndex = index - 16;
|
||||
const r = Math.floor(cubeIndex / 36);
|
||||
const g = Math.floor((cubeIndex % 36) / 6);
|
||||
const b = cubeIndex % 6;
|
||||
const toComponent = (n: number) => (n === 0 ? 0 : 55 + n * 40);
|
||||
const toHex = (n: number) => toComponent(n).toString(16).padStart(2, "0");
|
||||
return `#${toHex(r)}${toHex(g)}${toHex(b)}`;
|
||||
}
|
||||
|
||||
// Grayscale (232-255): 24 shades
|
||||
const gray = 8 + (index - 232) * 10;
|
||||
const grayHex = gray.toString(16).padStart(2, "0");
|
||||
return `#${grayHex}${grayHex}${grayHex}`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Escape HTML special characters.
|
||||
*/
|
||||
function escapeHtml(text: string): string {
|
||||
return text
|
||||
.replace(/&/g, "&")
|
||||
.replace(/</g, "<")
|
||||
.replace(/>/g, ">")
|
||||
.replace(/"/g, """)
|
||||
.replace(/'/g, "'");
|
||||
}
|
||||
|
||||
/** Mutable style state accumulated while parsing SGR codes. */
interface TextStyle {
	fg: string | null; // foreground color (CSS color string), null = terminal default
	bg: string | null; // background color, null = terminal default
	bold: boolean;
	dim: boolean;
	italic: boolean;
	underline: boolean;
}
|
||||
|
||||
function createEmptyStyle(): TextStyle {
|
||||
return {
|
||||
fg: null,
|
||||
bg: null,
|
||||
bold: false,
|
||||
dim: false,
|
||||
italic: false,
|
||||
underline: false,
|
||||
};
|
||||
}
|
||||
|
||||
function styleToInlineCSS(style: TextStyle): string {
|
||||
const parts: string[] = [];
|
||||
if (style.fg) parts.push(`color:${style.fg}`);
|
||||
if (style.bg) parts.push(`background-color:${style.bg}`);
|
||||
if (style.bold) parts.push("font-weight:bold");
|
||||
if (style.dim) parts.push("opacity:0.6");
|
||||
if (style.italic) parts.push("font-style:italic");
|
||||
if (style.underline) parts.push("text-decoration:underline");
|
||||
return parts.join(";");
|
||||
}
|
||||
|
||||
function hasStyle(style: TextStyle): boolean {
|
||||
return (
|
||||
style.fg !== null ||
|
||||
style.bg !== null ||
|
||||
style.bold ||
|
||||
style.dim ||
|
||||
style.italic ||
|
||||
style.underline
|
||||
);
|
||||
}
|
||||
|
||||
/**
 * Parse ANSI SGR (Select Graphic Rendition) codes and update style.
 * Mutates `style` in place. Extended color sequences (38/48 followed by
 * `5;N` or `2;R;G;B`) consume their extra parameters by advancing `i`
 * past them; unrecognized codes are silently ignored.
 */
function applySgrCode(params: number[], style: TextStyle): void {
	let i = 0;
	while (i < params.length) {
		const code = params[i];

		if (code === 0) {
			// Reset all
			style.fg = null;
			style.bg = null;
			style.bold = false;
			style.dim = false;
			style.italic = false;
			style.underline = false;
		} else if (code === 1) {
			style.bold = true;
		} else if (code === 2) {
			style.dim = true;
		} else if (code === 3) {
			style.italic = true;
		} else if (code === 4) {
			style.underline = true;
		} else if (code === 22) {
			// Reset bold/dim
			style.bold = false;
			style.dim = false;
		} else if (code === 23) {
			style.italic = false;
		} else if (code === 24) {
			style.underline = false;
		} else if (code >= 30 && code <= 37) {
			// Standard foreground colors
			style.fg = ANSI_COLORS[code - 30];
		} else if (code === 38) {
			// Extended foreground color
			if (params[i + 1] === 5 && params.length > i + 2) {
				// 256-color: 38;5;N
				style.fg = color256ToHex(params[i + 2]);
				i += 2;
			} else if (params[i + 1] === 2 && params.length > i + 4) {
				// RGB: 38;2;R;G;B
				const r = params[i + 2];
				const g = params[i + 3];
				const b = params[i + 4];
				style.fg = `rgb(${r},${g},${b})`;
				i += 4;
			}
		} else if (code === 39) {
			// Default foreground
			style.fg = null;
		} else if (code >= 40 && code <= 47) {
			// Standard background colors
			style.bg = ANSI_COLORS[code - 40];
		} else if (code === 48) {
			// Extended background color
			if (params[i + 1] === 5 && params.length > i + 2) {
				// 256-color: 48;5;N
				style.bg = color256ToHex(params[i + 2]);
				i += 2;
			} else if (params[i + 1] === 2 && params.length > i + 4) {
				// RGB: 48;2;R;G;B
				const r = params[i + 2];
				const g = params[i + 3];
				const b = params[i + 4];
				style.bg = `rgb(${r},${g},${b})`;
				i += 4;
			}
		} else if (code === 49) {
			// Default background
			style.bg = null;
		} else if (code >= 90 && code <= 97) {
			// Bright foreground colors (palette indices 8-15)
			style.fg = ANSI_COLORS[code - 90 + 8];
		} else if (code >= 100 && code <= 107) {
			// Bright background colors (palette indices 8-15)
			style.bg = ANSI_COLORS[code - 100 + 8];
		}
		// Ignore unrecognized codes

		i++;
	}
}
|
||||
|
||||
// Match ANSI escape sequences: ESC[ followed by params and ending with 'm'.
// NOTE: this is a module-level /g regex, so its lastIndex is shared; it is
// explicitly reset before each use below.
const ANSI_REGEX = /\x1b\[([\d;]*)m/g;

/**
 * Convert ANSI-escaped text to HTML with inline styles.
 * Text between escape sequences is HTML-escaped; styled runs are wrapped in
 * <span style="..."> elements, closed/reopened at each style change.
 */
export function ansiToHtml(text: string): string {
	const style = createEmptyStyle();
	let result = "";
	let lastIndex = 0;
	let inSpan = false;

	// Reset regex state (shared /g regex — see note above)
	ANSI_REGEX.lastIndex = 0;

	let match = ANSI_REGEX.exec(text);
	while (match !== null) {
		// Add text before this escape sequence
		const beforeText = text.slice(lastIndex, match.index);
		if (beforeText) {
			result += escapeHtml(beforeText);
		}

		// Parse SGR parameters; a bare ESC[m is treated as reset ([0])
		const paramStr = match[1];
		const params = paramStr
			? paramStr.split(";").map((p) => parseInt(p, 10) || 0)
			: [0];

		// Close existing span if we have one
		if (inSpan) {
			result += "</span>";
			inSpan = false;
		}

		// Apply the codes (mutates `style`)
		applySgrCode(params, style);

		// Open new span if we have any styling
		if (hasStyle(style)) {
			result += `<span style="${styleToInlineCSS(style)}">`;
			inSpan = true;
		}

		lastIndex = match.index + match[0].length;
		match = ANSI_REGEX.exec(text);
	}

	// Add remaining text
	const remainingText = text.slice(lastIndex);
	if (remainingText) {
		result += escapeHtml(remainingText);
	}

	// Close any open span
	if (inSpan) {
		result += "</span>";
	}

	return result;
}
|
||||
|
||||
/**
|
||||
* Convert array of ANSI-escaped lines to HTML.
|
||||
* Each line is wrapped in a div element.
|
||||
*/
|
||||
export function ansiLinesToHtml(lines: string[]): string {
|
||||
return lines
|
||||
.map(
|
||||
(line) => `<div class="ansi-line">${ansiToHtml(line) || " "}</div>`,
|
||||
)
|
||||
.join("\n");
|
||||
}
|
||||
353
packages/coding-agent/src/core/export-html/index.ts
Normal file
353
packages/coding-agent/src/core/export-html/index.ts
Normal file
|
|
@ -0,0 +1,353 @@
|
|||
import type { AgentState } from "@mariozechner/pi-agent-core";
|
||||
import { existsSync, readFileSync, writeFileSync } from "fs";
|
||||
import { basename, join } from "path";
|
||||
import { APP_NAME, getExportTemplateDir } from "../../config.js";
|
||||
import {
|
||||
getResolvedThemeColors,
|
||||
getThemeExportColors,
|
||||
} from "../../modes/interactive/theme/theme.js";
|
||||
import type { ToolInfo } from "../extensions/types.js";
|
||||
import type { SessionEntry } from "../session-manager.js";
|
||||
import { SessionManager } from "../session-manager.js";
|
||||
|
||||
/**
 * Interface for rendering custom tools to HTML.
 * Used by agent-session to pre-render extension tool output.
 */
export interface ToolHtmlRenderer {
	/** Render a tool call to HTML. Returns undefined if tool has no custom renderer. */
	renderCall(toolName: string, args: unknown): string | undefined;
	/** Render a tool result to HTML. Returns undefined if tool has no custom renderer. */
	renderResult(
		toolName: string,
		result: Array<{
			type: string;
			text?: string;
			data?: string; // base64 payload for non-text content, presumably — confirm at call sites
			mimeType?: string;
		}>,
		details: unknown,
		isError: boolean,
	): string | undefined;
}

/** Pre-rendered HTML for a custom tool call and result */
interface RenderedToolHtml {
	callHtml?: string;
	resultHtml?: string;
}

/** Options controlling HTML export of a session. */
export interface ExportOptions {
	/** File path the exported HTML is written to */
	outputPath?: string;
	/** Name of the theme whose colors are used for the export */
	themeName?: string;
	/** Optional tool renderer for custom tools */
	toolRenderer?: ToolHtmlRenderer;
}
|
||||
|
||||
/** Parse a color string to RGB values. Supports hex (#RRGGBB) and rgb(r,g,b) formats. */
|
||||
function parseColor(
|
||||
color: string,
|
||||
): { r: number; g: number; b: number } | undefined {
|
||||
const hexMatch = color.match(
|
||||
/^#([0-9a-fA-F]{2})([0-9a-fA-F]{2})([0-9a-fA-F]{2})$/,
|
||||
);
|
||||
if (hexMatch) {
|
||||
return {
|
||||
r: Number.parseInt(hexMatch[1], 16),
|
||||
g: Number.parseInt(hexMatch[2], 16),
|
||||
b: Number.parseInt(hexMatch[3], 16),
|
||||
};
|
||||
}
|
||||
const rgbMatch = color.match(
|
||||
/^rgb\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)$/,
|
||||
);
|
||||
if (rgbMatch) {
|
||||
return {
|
||||
r: Number.parseInt(rgbMatch[1], 10),
|
||||
g: Number.parseInt(rgbMatch[2], 10),
|
||||
b: Number.parseInt(rgbMatch[3], 10),
|
||||
};
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/** Calculate relative luminance of a color (0-1, higher = lighter). */
|
||||
function getLuminance(r: number, g: number, b: number): number {
|
||||
const toLinear = (c: number) => {
|
||||
const s = c / 255;
|
||||
return s <= 0.03928 ? s / 12.92 : ((s + 0.055) / 1.055) ** 2.4;
|
||||
};
|
||||
return 0.2126 * toLinear(r) + 0.7152 * toLinear(g) + 0.0722 * toLinear(b);
|
||||
}
|
||||
|
||||
/** Adjust color brightness. Factor > 1 lightens, < 1 darkens. */
|
||||
function adjustBrightness(color: string, factor: number): string {
|
||||
const parsed = parseColor(color);
|
||||
if (!parsed) return color;
|
||||
const adjust = (c: number) =>
|
||||
Math.min(255, Math.max(0, Math.round(c * factor)));
|
||||
return `rgb(${adjust(parsed.r)}, ${adjust(parsed.g)}, ${adjust(parsed.b)})`;
|
||||
}
|
||||
|
||||
/** Derive export background colors from a base color (e.g., userMessageBg). */
|
||||
function deriveExportColors(baseColor: string): {
|
||||
pageBg: string;
|
||||
cardBg: string;
|
||||
infoBg: string;
|
||||
} {
|
||||
const parsed = parseColor(baseColor);
|
||||
if (!parsed) {
|
||||
return {
|
||||
pageBg: "rgb(24, 24, 30)",
|
||||
cardBg: "rgb(30, 30, 36)",
|
||||
infoBg: "rgb(60, 55, 40)",
|
||||
};
|
||||
}
|
||||
|
||||
const luminance = getLuminance(parsed.r, parsed.g, parsed.b);
|
||||
const isLight = luminance > 0.5;
|
||||
|
||||
if (isLight) {
|
||||
return {
|
||||
pageBg: adjustBrightness(baseColor, 0.96),
|
||||
cardBg: baseColor,
|
||||
infoBg: `rgb(${Math.min(255, parsed.r + 10)}, ${Math.min(255, parsed.g + 5)}, ${Math.max(0, parsed.b - 20)})`,
|
||||
};
|
||||
}
|
||||
return {
|
||||
pageBg: adjustBrightness(baseColor, 0.7),
|
||||
cardBg: adjustBrightness(baseColor, 0.85),
|
||||
infoBg: `rgb(${Math.min(255, parsed.r + 20)}, ${Math.min(255, parsed.g + 15)}, ${parsed.b})`,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate CSS custom property declarations from theme colors.
|
||||
*/
|
||||
function generateThemeVars(themeName?: string): string {
|
||||
const colors = getResolvedThemeColors(themeName);
|
||||
const lines: string[] = [];
|
||||
for (const [key, value] of Object.entries(colors)) {
|
||||
lines.push(`--${key}: ${value};`);
|
||||
}
|
||||
|
||||
// Use explicit theme export colors if available, otherwise derive from userMessageBg
|
||||
const themeExport = getThemeExportColors(themeName);
|
||||
const userMessageBg = colors.userMessageBg || "#343541";
|
||||
const derivedColors = deriveExportColors(userMessageBg);
|
||||
|
||||
lines.push(`--exportPageBg: ${themeExport.pageBg ?? derivedColors.pageBg};`);
|
||||
lines.push(`--exportCardBg: ${themeExport.cardBg ?? derivedColors.cardBg};`);
|
||||
lines.push(`--exportInfoBg: ${themeExport.infoBg ?? derivedColors.infoBg};`);
|
||||
|
||||
return lines.join("\n ");
|
||||
}
|
||||
|
||||
/** Everything the exported page needs; serialized as base64 JSON into the template. */
interface SessionData {
  header: ReturnType<SessionManager["getHeader"]>;
  entries: ReturnType<SessionManager["getEntries"]>;
  leafId: string | null;
  /** Present only when exporting with a live AgentState. */
  systemPrompt?: string;
  /** Present only when exporting with a live AgentState. */
  tools?: ToolInfo[];
  /** Pre-rendered HTML for custom tool calls/results, keyed by tool call ID */
  renderedTools?: Record<string, RenderedToolHtml>;
}
|
||||
|
||||
/**
|
||||
* Core HTML generation logic shared by both export functions.
|
||||
*/
|
||||
function generateHtml(sessionData: SessionData, themeName?: string): string {
|
||||
const templateDir = getExportTemplateDir();
|
||||
const template = readFileSync(join(templateDir, "template.html"), "utf-8");
|
||||
const templateCss = readFileSync(join(templateDir, "template.css"), "utf-8");
|
||||
const templateJs = readFileSync(join(templateDir, "template.js"), "utf-8");
|
||||
const markedJs = readFileSync(
|
||||
join(templateDir, "vendor", "marked.min.js"),
|
||||
"utf-8",
|
||||
);
|
||||
const hljsJs = readFileSync(
|
||||
join(templateDir, "vendor", "highlight.min.js"),
|
||||
"utf-8",
|
||||
);
|
||||
|
||||
const themeVars = generateThemeVars(themeName);
|
||||
const colors = getResolvedThemeColors(themeName);
|
||||
const exportColors = deriveExportColors(colors.userMessageBg || "#343541");
|
||||
const bodyBg = exportColors.pageBg;
|
||||
const containerBg = exportColors.cardBg;
|
||||
const infoBg = exportColors.infoBg;
|
||||
|
||||
// Base64 encode session data to avoid escaping issues
|
||||
const sessionDataBase64 = Buffer.from(JSON.stringify(sessionData)).toString(
|
||||
"base64",
|
||||
);
|
||||
|
||||
// Build the CSS with theme variables injected
|
||||
const css = templateCss
|
||||
.replace("{{THEME_VARS}}", themeVars)
|
||||
.replace("{{BODY_BG}}", bodyBg)
|
||||
.replace("{{CONTAINER_BG}}", containerBg)
|
||||
.replace("{{INFO_BG}}", infoBg);
|
||||
|
||||
return template
|
||||
.replace("{{CSS}}", css)
|
||||
.replace("{{JS}}", templateJs)
|
||||
.replace("{{SESSION_DATA}}", sessionDataBase64)
|
||||
.replace("{{MARKED_JS}}", markedJs)
|
||||
.replace("{{HIGHLIGHT_JS}}", hljsJs);
|
||||
}
|
||||
|
||||
/** Built-in tool names that have custom rendering in template.js */
// preRenderCustomTools skips these: the exported page renders them itself.
const BUILTIN_TOOLS = new Set([
  "bash",
  "read",
  "write",
  "edit",
  "ls",
  "find",
  "grep",
]);
|
||||
|
||||
/**
|
||||
* Pre-render custom tools to HTML using their TUI renderers.
|
||||
*/
|
||||
function preRenderCustomTools(
|
||||
entries: SessionEntry[],
|
||||
toolRenderer: ToolHtmlRenderer,
|
||||
): Record<string, RenderedToolHtml> {
|
||||
const renderedTools: Record<string, RenderedToolHtml> = {};
|
||||
|
||||
for (const entry of entries) {
|
||||
if (entry.type !== "message") continue;
|
||||
const msg = entry.message;
|
||||
|
||||
// Find tool calls in assistant messages
|
||||
if (msg.role === "assistant" && Array.isArray(msg.content)) {
|
||||
for (const block of msg.content) {
|
||||
if (block.type === "toolCall" && !BUILTIN_TOOLS.has(block.name)) {
|
||||
const callHtml = toolRenderer.renderCall(block.name, block.arguments);
|
||||
if (callHtml) {
|
||||
renderedTools[block.id] = { callHtml };
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Find tool results
|
||||
if (msg.role === "toolResult" && msg.toolCallId) {
|
||||
const toolName = msg.toolName || "";
|
||||
// Only render if we have a pre-rendered call OR it's not a built-in tool
|
||||
const existing = renderedTools[msg.toolCallId];
|
||||
if (existing || !BUILTIN_TOOLS.has(toolName)) {
|
||||
const resultHtml = toolRenderer.renderResult(
|
||||
toolName,
|
||||
msg.content,
|
||||
msg.details,
|
||||
msg.isError || false,
|
||||
);
|
||||
if (resultHtml) {
|
||||
renderedTools[msg.toolCallId] = {
|
||||
...existing,
|
||||
resultHtml,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return renderedTools;
|
||||
}
|
||||
|
||||
/**
|
||||
* Export session to HTML using SessionManager and AgentState.
|
||||
* Used by TUI's /export command.
|
||||
*/
|
||||
export async function exportSessionToHtml(
|
||||
sm: SessionManager,
|
||||
state?: AgentState,
|
||||
options?: ExportOptions | string,
|
||||
): Promise<string> {
|
||||
const opts: ExportOptions =
|
||||
typeof options === "string" ? { outputPath: options } : options || {};
|
||||
|
||||
const sessionFile = sm.getSessionFile();
|
||||
if (!sessionFile) {
|
||||
throw new Error("Cannot export in-memory session to HTML");
|
||||
}
|
||||
if (!existsSync(sessionFile)) {
|
||||
throw new Error("Nothing to export yet - start a conversation first");
|
||||
}
|
||||
|
||||
const entries = sm.getEntries();
|
||||
|
||||
// Pre-render custom tools if a tool renderer is provided
|
||||
let renderedTools: Record<string, RenderedToolHtml> | undefined;
|
||||
if (opts.toolRenderer) {
|
||||
renderedTools = preRenderCustomTools(entries, opts.toolRenderer);
|
||||
// Only include if we actually rendered something
|
||||
if (Object.keys(renderedTools).length === 0) {
|
||||
renderedTools = undefined;
|
||||
}
|
||||
}
|
||||
|
||||
const sessionData: SessionData = {
|
||||
header: sm.getHeader(),
|
||||
entries,
|
||||
leafId: sm.getLeafId(),
|
||||
systemPrompt: state?.systemPrompt,
|
||||
tools: state?.tools?.map((t) => ({
|
||||
name: t.name,
|
||||
description: t.description,
|
||||
parameters: t.parameters,
|
||||
})),
|
||||
renderedTools,
|
||||
};
|
||||
|
||||
const html = generateHtml(sessionData, opts.themeName);
|
||||
|
||||
let outputPath = opts.outputPath;
|
||||
if (!outputPath) {
|
||||
const sessionBasename = basename(sessionFile, ".jsonl");
|
||||
outputPath = `${APP_NAME}-session-${sessionBasename}.html`;
|
||||
}
|
||||
|
||||
writeFileSync(outputPath, html, "utf8");
|
||||
return outputPath;
|
||||
}
|
||||
|
||||
/**
|
||||
* Export session file to HTML (standalone, without AgentState).
|
||||
* Used by CLI for exporting arbitrary session files.
|
||||
*/
|
||||
export async function exportFromFile(
|
||||
inputPath: string,
|
||||
options?: ExportOptions | string,
|
||||
): Promise<string> {
|
||||
const opts: ExportOptions =
|
||||
typeof options === "string" ? { outputPath: options } : options || {};
|
||||
|
||||
if (!existsSync(inputPath)) {
|
||||
throw new Error(`File not found: ${inputPath}`);
|
||||
}
|
||||
|
||||
const sm = SessionManager.open(inputPath);
|
||||
|
||||
const sessionData: SessionData = {
|
||||
header: sm.getHeader(),
|
||||
entries: sm.getEntries(),
|
||||
leafId: sm.getLeafId(),
|
||||
systemPrompt: undefined,
|
||||
tools: undefined,
|
||||
};
|
||||
|
||||
const html = generateHtml(sessionData, opts.themeName);
|
||||
|
||||
let outputPath = opts.outputPath;
|
||||
if (!outputPath) {
|
||||
const inputBasename = basename(inputPath, ".jsonl");
|
||||
outputPath = `${APP_NAME}-session-${inputBasename}.html`;
|
||||
}
|
||||
|
||||
writeFileSync(outputPath, html, "utf8");
|
||||
return outputPath;
|
||||
}
|
||||
971
packages/coding-agent/src/core/export-html/template.css
Normal file
971
packages/coding-agent/src/core/export-html/template.css
Normal file
|
|
@ -0,0 +1,971 @@
|
|||
:root {
|
||||
{{THEME_VARS}}
|
||||
--body-bg: {{BODY_BG}};
|
||||
--container-bg: {{CONTAINER_BG}};
|
||||
--info-bg: {{INFO_BG}};
|
||||
}
|
||||
|
||||
* { margin: 0; padding: 0; box-sizing: border-box; }
|
||||
|
||||
:root {
|
||||
--line-height: 18px; /* 12px font * 1.5 */
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: ui-monospace, 'Cascadia Code', 'Source Code Pro', Menlo, Consolas, 'DejaVu Sans Mono', monospace;
|
||||
font-size: 12px;
|
||||
line-height: var(--line-height);
|
||||
color: var(--text);
|
||||
background: var(--body-bg);
|
||||
}
|
||||
|
||||
#app {
|
||||
display: flex;
|
||||
min-height: 100vh;
|
||||
}
|
||||
|
||||
/* Sidebar */
|
||||
#sidebar {
|
||||
width: 400px;
|
||||
background: var(--container-bg);
|
||||
flex-shrink: 0;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
position: sticky;
|
||||
top: 0;
|
||||
height: 100vh;
|
||||
border-right: 1px solid var(--dim);
|
||||
}
|
||||
|
||||
.sidebar-header {
|
||||
padding: 8px 12px;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.sidebar-controls {
|
||||
padding: 8px 8px 4px 8px;
|
||||
}
|
||||
|
||||
.sidebar-search {
|
||||
width: 100%;
|
||||
box-sizing: border-box;
|
||||
padding: 4px 8px;
|
||||
font-size: 11px;
|
||||
font-family: inherit;
|
||||
background: var(--body-bg);
|
||||
color: var(--text);
|
||||
border: 1px solid var(--dim);
|
||||
border-radius: 3px;
|
||||
}
|
||||
|
||||
.sidebar-filters {
|
||||
display: flex;
|
||||
padding: 4px 8px 8px 8px;
|
||||
gap: 4px;
|
||||
align-items: center;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
.sidebar-search:focus {
|
||||
outline: none;
|
||||
border-color: var(--accent);
|
||||
}
|
||||
|
||||
.sidebar-search::placeholder {
|
||||
color: var(--muted);
|
||||
}
|
||||
|
||||
.filter-btn {
|
||||
padding: 3px 8px;
|
||||
font-size: 10px;
|
||||
font-family: inherit;
|
||||
background: transparent;
|
||||
color: var(--muted);
|
||||
border: 1px solid var(--dim);
|
||||
border-radius: 3px;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.filter-btn:hover {
|
||||
color: var(--text);
|
||||
border-color: var(--text);
|
||||
}
|
||||
|
||||
.filter-btn.active {
|
||||
background: var(--accent);
|
||||
color: var(--body-bg);
|
||||
border-color: var(--accent);
|
||||
}
|
||||
|
||||
.sidebar-close {
|
||||
display: none;
|
||||
padding: 3px 8px;
|
||||
font-size: 12px;
|
||||
font-family: inherit;
|
||||
background: transparent;
|
||||
color: var(--muted);
|
||||
border: 1px solid var(--dim);
|
||||
border-radius: 3px;
|
||||
cursor: pointer;
|
||||
margin-left: auto;
|
||||
}
|
||||
|
||||
.sidebar-close:hover {
|
||||
color: var(--text);
|
||||
border-color: var(--text);
|
||||
}
|
||||
|
||||
.tree-container {
|
||||
flex: 1;
|
||||
overflow: auto;
|
||||
padding: 4px 0;
|
||||
}
|
||||
|
||||
.tree-node {
|
||||
padding: 0 8px;
|
||||
cursor: pointer;
|
||||
display: flex;
|
||||
align-items: baseline;
|
||||
font-size: 11px;
|
||||
line-height: 13px;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.tree-node:hover {
|
||||
background: var(--selectedBg);
|
||||
}
|
||||
|
||||
.tree-node.active {
|
||||
background: var(--selectedBg);
|
||||
}
|
||||
|
||||
.tree-node.active .tree-content {
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.tree-node.in-path {
|
||||
background: color-mix(in srgb, var(--accent) 10%, transparent);
|
||||
}
|
||||
|
||||
.tree-node:not(.in-path) {
|
||||
opacity: 0.5;
|
||||
}
|
||||
|
||||
.tree-node:not(.in-path):hover {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
.tree-prefix {
|
||||
color: var(--muted);
|
||||
flex-shrink: 0;
|
||||
font-family: monospace;
|
||||
white-space: pre;
|
||||
}
|
||||
|
||||
.tree-marker {
|
||||
color: var(--accent);
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.tree-content {
|
||||
color: var(--text);
|
||||
}
|
||||
|
||||
.tree-role-user {
|
||||
color: var(--accent);
|
||||
}
|
||||
|
||||
.tree-role-assistant {
|
||||
color: var(--success);
|
||||
}
|
||||
|
||||
.tree-role-tool {
|
||||
color: var(--muted);
|
||||
}
|
||||
|
||||
.tree-muted {
|
||||
color: var(--muted);
|
||||
}
|
||||
|
||||
.tree-error {
|
||||
color: var(--error);
|
||||
}
|
||||
|
||||
.tree-compaction {
|
||||
color: var(--borderAccent);
|
||||
}
|
||||
|
||||
.tree-branch-summary {
|
||||
color: var(--warning);
|
||||
}
|
||||
|
||||
.tree-custom-message {
|
||||
color: var(--customMessageLabel);
|
||||
}
|
||||
|
||||
.tree-status {
|
||||
padding: 4px 12px;
|
||||
font-size: 10px;
|
||||
color: var(--muted);
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
/* Main content */
|
||||
#content {
|
||||
flex: 1;
|
||||
overflow-y: auto;
|
||||
padding: var(--line-height) calc(var(--line-height) * 2);
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
#content > * {
|
||||
width: 100%;
|
||||
max-width: 800px;
|
||||
}
|
||||
|
||||
/* Help bar */
|
||||
.help-bar {
|
||||
font-size: 11px;
|
||||
color: var(--warning);
|
||||
margin-bottom: var(--line-height);
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 12px;
|
||||
}
|
||||
|
||||
.download-json-btn {
|
||||
font-size: 10px;
|
||||
padding: 2px 8px;
|
||||
background: var(--container-bg);
|
||||
border: 1px solid var(--border);
|
||||
border-radius: 3px;
|
||||
color: var(--text);
|
||||
cursor: pointer;
|
||||
font-family: inherit;
|
||||
}
|
||||
|
||||
.download-json-btn:hover {
|
||||
background: var(--hover);
|
||||
border-color: var(--borderAccent);
|
||||
}
|
||||
|
||||
/* Header */
|
||||
.header {
|
||||
background: var(--container-bg);
|
||||
border-radius: 4px;
|
||||
padding: var(--line-height);
|
||||
margin-bottom: var(--line-height);
|
||||
}
|
||||
|
||||
.header h1 {
|
||||
font-size: 12px;
|
||||
font-weight: bold;
|
||||
color: var(--borderAccent);
|
||||
margin-bottom: var(--line-height);
|
||||
}
|
||||
|
||||
.header-info {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0;
|
||||
font-size: 11px;
|
||||
}
|
||||
|
||||
.info-item {
|
||||
color: var(--dim);
|
||||
display: flex;
|
||||
align-items: baseline;
|
||||
}
|
||||
|
||||
.info-label {
|
||||
font-weight: 600;
|
||||
margin-right: 8px;
|
||||
min-width: 100px;
|
||||
}
|
||||
|
||||
.info-value {
|
||||
color: var(--text);
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
/* Messages */
|
||||
#messages {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: var(--line-height);
|
||||
}
|
||||
|
||||
.message-timestamp {
|
||||
font-size: 10px;
|
||||
color: var(--dim);
|
||||
opacity: 0.8;
|
||||
}
|
||||
|
||||
.user-message {
|
||||
background: var(--userMessageBg);
|
||||
color: var(--userMessageText);
|
||||
padding: var(--line-height);
|
||||
border-radius: 4px;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.assistant-message {
|
||||
padding: 0;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
/* Copy link button - appears on hover */
|
||||
.copy-link-btn {
|
||||
position: absolute;
|
||||
top: 8px;
|
||||
right: 8px;
|
||||
width: 28px;
|
||||
height: 28px;
|
||||
padding: 6px;
|
||||
background: var(--container-bg);
|
||||
border: 1px solid var(--dim);
|
||||
border-radius: 4px;
|
||||
color: var(--muted);
|
||||
cursor: pointer;
|
||||
opacity: 0;
|
||||
transition: opacity 0.15s, background 0.15s, color 0.15s;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
z-index: 10;
|
||||
}
|
||||
|
||||
.user-message:hover .copy-link-btn,
|
||||
.assistant-message:hover .copy-link-btn {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
.copy-link-btn:hover {
|
||||
background: var(--accent);
|
||||
color: var(--body-bg);
|
||||
border-color: var(--accent);
|
||||
}
|
||||
|
||||
.copy-link-btn.copied {
|
||||
background: var(--success, #22c55e);
|
||||
color: white;
|
||||
border-color: var(--success, #22c55e);
|
||||
}
|
||||
|
||||
/* Highlight effect for deep-linked messages */
|
||||
.user-message.highlight,
|
||||
.assistant-message.highlight {
|
||||
animation: highlight-pulse 2s ease-out;
|
||||
}
|
||||
|
||||
@keyframes highlight-pulse {
|
||||
0% {
|
||||
box-shadow: 0 0 0 3px var(--accent);
|
||||
}
|
||||
100% {
|
||||
box-shadow: 0 0 0 0 transparent;
|
||||
}
|
||||
}
|
||||
|
||||
.assistant-message > .message-timestamp {
|
||||
padding-left: var(--line-height);
|
||||
}
|
||||
|
||||
.assistant-text {
|
||||
padding: var(--line-height);
|
||||
padding-bottom: 0;
|
||||
}
|
||||
|
||||
.message-timestamp + .assistant-text,
|
||||
.message-timestamp + .thinking-block {
|
||||
padding-top: 0;
|
||||
}
|
||||
|
||||
.thinking-block + .assistant-text {
|
||||
padding-top: 0;
|
||||
}
|
||||
|
||||
.thinking-text {
|
||||
padding: var(--line-height);
|
||||
color: var(--thinkingText);
|
||||
font-style: italic;
|
||||
white-space: pre-wrap;
|
||||
}
|
||||
|
||||
.message-timestamp + .thinking-block .thinking-text,
|
||||
.message-timestamp + .thinking-block .thinking-collapsed {
|
||||
padding-top: 0;
|
||||
}
|
||||
|
||||
.thinking-collapsed {
|
||||
display: none;
|
||||
padding: var(--line-height);
|
||||
color: var(--thinkingText);
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
/* Tool execution */
|
||||
.tool-execution {
|
||||
padding: var(--line-height);
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
.tool-execution + .tool-execution {
|
||||
margin-top: var(--line-height);
|
||||
}
|
||||
|
||||
.assistant-text + .tool-execution {
|
||||
margin-top: var(--line-height);
|
||||
}
|
||||
|
||||
.tool-execution.pending { background: var(--toolPendingBg); }
|
||||
.tool-execution.success { background: var(--toolSuccessBg); }
|
||||
.tool-execution.error { background: var(--toolErrorBg); }
|
||||
|
||||
.tool-header, .tool-name {
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.tool-path {
|
||||
color: var(--accent);
|
||||
word-break: break-all;
|
||||
}
|
||||
|
||||
.line-numbers {
|
||||
color: var(--warning);
|
||||
}
|
||||
|
||||
.line-count {
|
||||
color: var(--dim);
|
||||
}
|
||||
|
||||
.tool-command {
|
||||
font-weight: bold;
|
||||
white-space: pre-wrap;
|
||||
word-wrap: break-word;
|
||||
overflow-wrap: break-word;
|
||||
word-break: break-word;
|
||||
}
|
||||
|
||||
.tool-output {
|
||||
margin-top: var(--line-height);
|
||||
color: var(--toolOutput);
|
||||
word-wrap: break-word;
|
||||
overflow-wrap: break-word;
|
||||
word-break: break-word;
|
||||
font-family: inherit;
|
||||
overflow-x: auto;
|
||||
}
|
||||
|
||||
.tool-output > div,
|
||||
.output-preview,
|
||||
.output-full {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
line-height: var(--line-height);
|
||||
}
|
||||
|
||||
.tool-output pre {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
font-family: inherit;
|
||||
color: inherit;
|
||||
white-space: pre-wrap;
|
||||
word-wrap: break-word;
|
||||
overflow-wrap: break-word;
|
||||
}
|
||||
|
||||
.tool-output code {
|
||||
padding: 0;
|
||||
background: none;
|
||||
color: var(--text);
|
||||
}
|
||||
|
||||
.tool-output.expandable {
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.tool-output.expandable:hover {
|
||||
opacity: 0.9;
|
||||
}
|
||||
|
||||
.tool-output.expandable .output-full {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.tool-output.expandable.expanded .output-preview {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.tool-output.expandable.expanded .output-full {
|
||||
display: block;
|
||||
}
|
||||
|
||||
.ansi-line {
|
||||
white-space: pre-wrap;
|
||||
}
|
||||
|
||||
.tool-images {
|
||||
}
|
||||
|
||||
.tool-image {
|
||||
max-width: 100%;
|
||||
max-height: 500px;
|
||||
border-radius: 4px;
|
||||
margin: var(--line-height) 0;
|
||||
}
|
||||
|
||||
.expand-hint {
|
||||
color: var(--toolOutput);
|
||||
}
|
||||
|
||||
/* Diff */
|
||||
.tool-diff {
|
||||
font-size: 11px;
|
||||
overflow-x: auto;
|
||||
white-space: pre;
|
||||
}
|
||||
|
||||
.diff-added { color: var(--toolDiffAdded); }
|
||||
.diff-removed { color: var(--toolDiffRemoved); }
|
||||
.diff-context { color: var(--toolDiffContext); }
|
||||
|
||||
/* Model change */
|
||||
.model-change {
|
||||
padding: 0 var(--line-height);
|
||||
color: var(--dim);
|
||||
font-size: 11px;
|
||||
}
|
||||
|
||||
.model-name {
|
||||
color: var(--borderAccent);
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
/* Compaction / Branch Summary - matches customMessage colors from TUI */
|
||||
.compaction {
|
||||
background: var(--customMessageBg);
|
||||
border-radius: 4px;
|
||||
padding: var(--line-height);
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.compaction-label {
|
||||
color: var(--customMessageLabel);
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.compaction-collapsed {
|
||||
color: var(--customMessageText);
|
||||
}
|
||||
|
||||
.compaction-content {
|
||||
display: none;
|
||||
color: var(--customMessageText);
|
||||
white-space: pre-wrap;
|
||||
margin-top: var(--line-height);
|
||||
}
|
||||
|
||||
.compaction.expanded .compaction-collapsed {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.compaction.expanded .compaction-content {
|
||||
display: block;
|
||||
}
|
||||
|
||||
/* System prompt */
|
||||
.system-prompt {
|
||||
background: var(--customMessageBg);
|
||||
padding: var(--line-height);
|
||||
border-radius: 4px;
|
||||
margin-bottom: var(--line-height);
|
||||
}
|
||||
|
||||
.system-prompt.expandable {
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.system-prompt-header {
|
||||
font-weight: bold;
|
||||
color: var(--customMessageLabel);
|
||||
}
|
||||
|
||||
.system-prompt-preview {
|
||||
color: var(--customMessageText);
|
||||
white-space: pre-wrap;
|
||||
word-wrap: break-word;
|
||||
font-size: 11px;
|
||||
margin-top: var(--line-height);
|
||||
}
|
||||
|
||||
.system-prompt-expand-hint {
|
||||
color: var(--muted);
|
||||
font-style: italic;
|
||||
margin-top: 4px;
|
||||
}
|
||||
|
||||
.system-prompt-full {
|
||||
display: none;
|
||||
color: var(--customMessageText);
|
||||
white-space: pre-wrap;
|
||||
word-wrap: break-word;
|
||||
font-size: 11px;
|
||||
margin-top: var(--line-height);
|
||||
}
|
||||
|
||||
.system-prompt.expanded .system-prompt-preview,
|
||||
.system-prompt.expanded .system-prompt-expand-hint {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.system-prompt.expanded .system-prompt-full {
|
||||
display: block;
|
||||
}
|
||||
|
||||
.system-prompt.provider-prompt {
|
||||
border-left: 3px solid var(--warning);
|
||||
}
|
||||
|
||||
.system-prompt-note {
|
||||
font-size: 10px;
|
||||
font-style: italic;
|
||||
color: var(--muted);
|
||||
margin-top: 4px;
|
||||
}
|
||||
|
||||
/* Tools list */
|
||||
.tools-list {
|
||||
background: var(--customMessageBg);
|
||||
padding: var(--line-height);
|
||||
border-radius: 4px;
|
||||
margin-bottom: var(--line-height);
|
||||
}
|
||||
|
||||
.tools-header {
|
||||
font-weight: bold;
|
||||
color: var(--customMessageLabel);
|
||||
margin-bottom: var(--line-height);
|
||||
}
|
||||
|
||||
.tool-item {
|
||||
font-size: 11px;
|
||||
}
|
||||
|
||||
.tool-item-name {
|
||||
font-weight: bold;
|
||||
color: var(--text);
|
||||
}
|
||||
|
||||
.tool-item-desc {
|
||||
color: var(--dim);
|
||||
}
|
||||
|
||||
.tool-params-hint {
|
||||
color: var(--muted);
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
.tool-item:has(.tool-params-hint) {
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.tool-params-hint::after {
|
||||
content: '[click to show parameters]';
|
||||
}
|
||||
|
||||
.tool-item.params-expanded .tool-params-hint::after {
|
||||
content: '[hide parameters]';
|
||||
}
|
||||
|
||||
.tool-params-content {
|
||||
display: none;
|
||||
margin-top: 4px;
|
||||
margin-left: 12px;
|
||||
padding-left: 8px;
|
||||
border-left: 1px solid var(--dim);
|
||||
}
|
||||
|
||||
.tool-item.params-expanded .tool-params-content {
|
||||
display: block;
|
||||
}
|
||||
|
||||
.tool-param {
|
||||
margin-bottom: 4px;
|
||||
font-size: 11px;
|
||||
}
|
||||
|
||||
.tool-param-name {
|
||||
font-weight: bold;
|
||||
color: var(--text);
|
||||
}
|
||||
|
||||
.tool-param-type {
|
||||
color: var(--dim);
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
.tool-param-required {
|
||||
color: var(--warning, #e8a838);
|
||||
font-size: 10px;
|
||||
}
|
||||
|
||||
.tool-param-optional {
|
||||
color: var(--dim);
|
||||
font-size: 10px;
|
||||
}
|
||||
|
||||
.tool-param-desc {
|
||||
color: var(--dim);
|
||||
margin-left: 8px;
|
||||
}
|
||||
|
||||
/* Hook/custom messages */
|
||||
.hook-message {
|
||||
background: var(--customMessageBg);
|
||||
color: var(--customMessageText);
|
||||
padding: var(--line-height);
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
.hook-type {
|
||||
color: var(--customMessageLabel);
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
/* Branch summary */
|
||||
.branch-summary {
|
||||
background: var(--customMessageBg);
|
||||
padding: var(--line-height);
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
.branch-summary-header {
|
||||
font-weight: bold;
|
||||
color: var(--borderAccent);
|
||||
}
|
||||
|
||||
/* Error */
|
||||
.error-text {
|
||||
color: var(--error);
|
||||
padding: 0 var(--line-height);
|
||||
}
|
||||
.tool-error {
|
||||
color: var(--error);
|
||||
}
|
||||
|
||||
/* Images */
|
||||
.message-images {
|
||||
margin-bottom: 12px;
|
||||
}
|
||||
|
||||
.message-image {
|
||||
max-width: 100%;
|
||||
max-height: 400px;
|
||||
border-radius: 4px;
|
||||
margin: var(--line-height) 0;
|
||||
}
|
||||
|
||||
/* Markdown content */
|
||||
.markdown-content h1,
|
||||
.markdown-content h2,
|
||||
.markdown-content h3,
|
||||
.markdown-content h4,
|
||||
.markdown-content h5,
|
||||
.markdown-content h6 {
|
||||
color: var(--mdHeading);
|
||||
margin: var(--line-height) 0 0 0;
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.markdown-content h1 { font-size: 1em; }
|
||||
.markdown-content h2 { font-size: 1em; }
|
||||
.markdown-content h3 { font-size: 1em; }
|
||||
.markdown-content h4 { font-size: 1em; }
|
||||
.markdown-content h5 { font-size: 1em; }
|
||||
.markdown-content h6 { font-size: 1em; }
|
||||
.markdown-content p { margin: 0; }
|
||||
.markdown-content p + p { margin-top: var(--line-height); }
|
||||
|
||||
.markdown-content a {
|
||||
color: var(--mdLink);
|
||||
text-decoration: underline;
|
||||
}
|
||||
|
||||
.markdown-content code {
|
||||
background: rgba(128, 128, 128, 0.2);
|
||||
color: var(--mdCode);
|
||||
padding: 0 4px;
|
||||
border-radius: 3px;
|
||||
font-family: inherit;
|
||||
}
|
||||
|
||||
.markdown-content pre {
|
||||
background: transparent;
|
||||
margin: var(--line-height) 0;
|
||||
overflow-x: auto;
|
||||
}
|
||||
|
||||
.markdown-content pre code {
|
||||
display: block;
|
||||
background: none;
|
||||
color: var(--text);
|
||||
}
|
||||
|
||||
.markdown-content blockquote {
|
||||
border-left: 3px solid var(--mdQuoteBorder);
|
||||
padding-left: var(--line-height);
|
||||
margin: var(--line-height) 0;
|
||||
color: var(--mdQuote);
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
.markdown-content ul,
|
||||
.markdown-content ol {
|
||||
margin: var(--line-height) 0;
|
||||
padding-left: calc(var(--line-height) * 2);
|
||||
}
|
||||
|
||||
.markdown-content li { margin: 0; }
|
||||
.markdown-content li::marker { color: var(--mdListBullet); }
|
||||
|
||||
.markdown-content hr {
|
||||
border: none;
|
||||
border-top: 1px solid var(--mdHr);
|
||||
margin: var(--line-height) 0;
|
||||
}
|
||||
|
||||
.markdown-content table {
|
||||
border-collapse: collapse;
|
||||
margin: 0.5em 0;
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.markdown-content th,
|
||||
.markdown-content td {
|
||||
border: 1px solid var(--mdCodeBlockBorder);
|
||||
padding: 6px 10px;
|
||||
text-align: left;
|
||||
}
|
||||
|
||||
.markdown-content th {
|
||||
background: rgba(128, 128, 128, 0.1);
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.markdown-content img {
|
||||
max-width: 100%;
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
/* Syntax highlighting */
|
||||
.hljs { background: transparent; color: var(--text); }
|
||||
.hljs-comment, .hljs-quote { color: var(--syntaxComment); }
|
||||
.hljs-keyword, .hljs-selector-tag { color: var(--syntaxKeyword); }
|
||||
.hljs-number, .hljs-literal { color: var(--syntaxNumber); }
|
||||
.hljs-string, .hljs-doctag { color: var(--syntaxString); }
|
||||
/* Function names: hljs v11 uses .hljs-title.function_ compound class */
|
||||
.hljs-function, .hljs-title, .hljs-title.function_, .hljs-section, .hljs-name { color: var(--syntaxFunction); }
|
||||
/* Types: hljs v11 uses .hljs-title.class_ for class names */
|
||||
.hljs-type, .hljs-class, .hljs-title.class_, .hljs-built_in { color: var(--syntaxType); }
|
||||
.hljs-attr, .hljs-variable, .hljs-variable.language_, .hljs-params, .hljs-property { color: var(--syntaxVariable); }
|
||||
.hljs-meta, .hljs-meta .hljs-keyword, .hljs-meta .hljs-string { color: var(--syntaxKeyword); }
|
||||
.hljs-operator { color: var(--syntaxOperator); }
|
||||
.hljs-punctuation { color: var(--syntaxPunctuation); }
|
||||
.hljs-subst { color: var(--text); }
|
||||
|
||||
/* Footer */
|
||||
.footer {
|
||||
margin-top: 48px;
|
||||
padding: 20px;
|
||||
text-align: center;
|
||||
color: var(--dim);
|
||||
font-size: 10px;
|
||||
}
|
||||
|
||||
/* Mobile */
|
||||
#hamburger {
|
||||
display: none;
|
||||
position: fixed;
|
||||
top: 10px;
|
||||
left: 10px;
|
||||
z-index: 100;
|
||||
padding: 3px 8px;
|
||||
font-size: 12px;
|
||||
font-family: inherit;
|
||||
background: transparent;
|
||||
color: var(--muted);
|
||||
border: 1px solid var(--dim);
|
||||
border-radius: 3px;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
#hamburger:hover {
|
||||
color: var(--text);
|
||||
border-color: var(--text);
|
||||
}
|
||||
|
||||
|
||||
|
||||
#sidebar-overlay {
|
||||
display: none;
|
||||
position: fixed;
|
||||
top: 0;
|
||||
left: 0;
|
||||
right: 0;
|
||||
bottom: 0;
|
||||
background: rgba(0, 0, 0, 0.5);
|
||||
z-index: 98;
|
||||
}
|
||||
|
||||
@media (max-width: 900px) {
|
||||
#sidebar {
|
||||
position: fixed;
|
||||
left: -400px;
|
||||
width: 400px;
|
||||
top: 0;
|
||||
bottom: 0;
|
||||
height: 100vh;
|
||||
z-index: 99;
|
||||
transition: left 0.3s;
|
||||
}
|
||||
|
||||
#sidebar.open {
|
||||
left: 0;
|
||||
}
|
||||
|
||||
#sidebar-overlay.open {
|
||||
display: block;
|
||||
}
|
||||
|
||||
#hamburger {
|
||||
display: block;
|
||||
}
|
||||
|
||||
.sidebar-close {
|
||||
display: block;
|
||||
}
|
||||
|
||||
#content {
|
||||
padding: var(--line-height) 16px;
|
||||
}
|
||||
|
||||
#content > * {
|
||||
max-width: 100%;
|
||||
}
|
||||
}
|
||||
|
||||
@media (max-width: 500px) {
|
||||
#sidebar {
|
||||
width: 100vw;
|
||||
left: -100vw;
|
||||
}
|
||||
}
|
||||
|
||||
@media print {
|
||||
#sidebar, #sidebar-toggle { display: none !important; }
|
||||
body { background: white; color: black; }
|
||||
#content { max-width: none; }
|
||||
}
|
||||
54
packages/coding-agent/src/core/export-html/template.html
Normal file
54
packages/coding-agent/src/core/export-html/template.html
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Session Export</title>
|
||||
<style>
|
||||
{{CSS}}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<button id="hamburger" title="Open sidebar"><svg width="14" height="14" viewBox="0 0 24 24" fill="currentColor" stroke="none"><circle cx="6" cy="6" r="2.5"/><circle cx="6" cy="18" r="2.5"/><circle cx="18" cy="12" r="2.5"/><rect x="5" y="6" width="2" height="12"/><path d="M6 12h10c1 0 2 0 2-2V8"/></svg></button>
|
||||
<div id="sidebar-overlay"></div>
|
||||
<div id="app">
|
||||
<aside id="sidebar">
|
||||
<div class="sidebar-header">
|
||||
<div class="sidebar-controls">
|
||||
<input type="text" class="sidebar-search" id="tree-search" placeholder="Search...">
|
||||
</div>
|
||||
<div class="sidebar-filters">
|
||||
<button class="filter-btn active" data-filter="default" title="Hide settings entries">Default</button>
|
||||
<button class="filter-btn" data-filter="no-tools" title="Default minus tool results">No-tools</button>
|
||||
<button class="filter-btn" data-filter="user-only" title="Only user messages">User</button>
|
||||
<button class="filter-btn" data-filter="labeled-only" title="Only labeled entries">Labeled</button>
|
||||
<button class="filter-btn" data-filter="all" title="Show everything">All</button>
|
||||
<button class="sidebar-close" id="sidebar-close" title="Close">✕</button>
|
||||
</div>
|
||||
</div>
|
||||
<div class="tree-container" id="tree-container"></div>
|
||||
<div class="tree-status" id="tree-status"></div>
|
||||
</aside>
|
||||
<main id="content">
|
||||
<div id="header-container"></div>
|
||||
<div id="messages"></div>
|
||||
</main>
|
||||
<div id="image-modal" class="image-modal">
|
||||
<img id="modal-image" src="" alt="">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script id="session-data" type="application/json">{{SESSION_DATA}}</script>
|
||||
|
||||
<!-- Vendored libraries -->
|
||||
<script>{{MARKED_JS}}</script>
|
||||
|
||||
<!-- highlight.js -->
|
||||
<script>{{HIGHLIGHT_JS}}</script>
|
||||
|
||||
<!-- Main application code -->
|
||||
<script>
|
||||
{{JS}}
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
1831
packages/coding-agent/src/core/export-html/template.js
Normal file
1831
packages/coding-agent/src/core/export-html/template.js
Normal file
File diff suppressed because it is too large
Load diff
112
packages/coding-agent/src/core/export-html/tool-renderer.ts
Normal file
112
packages/coding-agent/src/core/export-html/tool-renderer.ts
Normal file
|
|
@ -0,0 +1,112 @@
|
|||
/**
|
||||
* Tool HTML renderer for custom tools in HTML export.
|
||||
*
|
||||
* Renders custom tool calls and results to HTML by invoking their TUI renderers
|
||||
* and converting the ANSI output to HTML.
|
||||
*/
|
||||
|
||||
import type { ImageContent, TextContent } from "@mariozechner/pi-ai";
|
||||
import type { Theme } from "../../modes/interactive/theme/theme.js";
|
||||
import type { ToolDefinition } from "../extensions/types.js";
|
||||
import { ansiLinesToHtml } from "./ansi-to-html.js";
|
||||
|
||||
export interface ToolHtmlRendererDeps {
|
||||
/** Function to look up tool definition by name */
|
||||
getToolDefinition: (name: string) => ToolDefinition | undefined;
|
||||
/** Theme for styling */
|
||||
theme: Theme;
|
||||
/** Terminal width for rendering (default: 100) */
|
||||
width?: number;
|
||||
}
|
||||
|
||||
export interface ToolHtmlRenderer {
|
||||
/** Render a tool call to HTML. Returns undefined if tool has no custom renderer. */
|
||||
renderCall(toolName: string, args: unknown): string | undefined;
|
||||
/** Render a tool result to HTML. Returns undefined if tool has no custom renderer. */
|
||||
renderResult(
|
||||
toolName: string,
|
||||
result: Array<{
|
||||
type: string;
|
||||
text?: string;
|
||||
data?: string;
|
||||
mimeType?: string;
|
||||
}>,
|
||||
details: unknown,
|
||||
isError: boolean,
|
||||
): string | undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a tool HTML renderer.
|
||||
*
|
||||
* The renderer looks up tool definitions and invokes their renderCall/renderResult
|
||||
* methods, converting the resulting TUI Component output (ANSI) to HTML.
|
||||
*/
|
||||
export function createToolHtmlRenderer(
|
||||
deps: ToolHtmlRendererDeps,
|
||||
): ToolHtmlRenderer {
|
||||
const { getToolDefinition, theme, width = 100 } = deps;
|
||||
|
||||
return {
|
||||
renderCall(toolName: string, args: unknown): string | undefined {
|
||||
try {
|
||||
const toolDef = getToolDefinition(toolName);
|
||||
if (!toolDef?.renderCall) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
const component = toolDef.renderCall(args, theme);
|
||||
if (!component) {
|
||||
return undefined;
|
||||
}
|
||||
const lines = component.render(width);
|
||||
return ansiLinesToHtml(lines);
|
||||
} catch {
|
||||
// On error, return undefined to trigger JSON fallback
|
||||
return undefined;
|
||||
}
|
||||
},
|
||||
|
||||
renderResult(
|
||||
toolName: string,
|
||||
result: Array<{
|
||||
type: string;
|
||||
text?: string;
|
||||
data?: string;
|
||||
mimeType?: string;
|
||||
}>,
|
||||
details: unknown,
|
||||
isError: boolean,
|
||||
): string | undefined {
|
||||
try {
|
||||
const toolDef = getToolDefinition(toolName);
|
||||
if (!toolDef?.renderResult) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
// Build AgentToolResult from content array
|
||||
// Cast content since session storage uses generic object types
|
||||
const agentToolResult = {
|
||||
content: result as (TextContent | ImageContent)[],
|
||||
details,
|
||||
isError,
|
||||
};
|
||||
|
||||
// Always render expanded, client-side will apply truncation
|
||||
const component = toolDef.renderResult(
|
||||
agentToolResult,
|
||||
{ expanded: true, isPartial: false },
|
||||
theme,
|
||||
);
|
||||
if (!component) {
|
||||
return undefined;
|
||||
}
|
||||
const lines = component.render(width);
|
||||
return ansiLinesToHtml(lines);
|
||||
} catch {
|
||||
// On error, return undefined to trigger JSON fallback
|
||||
return undefined;
|
||||
}
|
||||
},
|
||||
};
|
||||
}
|
||||
8426
packages/coding-agent/src/core/export-html/vendor/highlight.min.js
vendored
Normal file
8426
packages/coding-agent/src/core/export-html/vendor/highlight.min.js
vendored
Normal file
File diff suppressed because it is too large
Load diff
1998
packages/coding-agent/src/core/export-html/vendor/marked.min.js
vendored
Normal file
1998
packages/coding-agent/src/core/export-html/vendor/marked.min.js
vendored
Normal file
File diff suppressed because it is too large
Load diff
170
packages/coding-agent/src/core/extensions/index.ts
Normal file
170
packages/coding-agent/src/core/extensions/index.ts
Normal file
|
|
@ -0,0 +1,170 @@
|
|||
/**
|
||||
* Extension system for lifecycle events and custom tools.
|
||||
*/
|
||||
|
||||
export type {
|
||||
SlashCommandInfo,
|
||||
SlashCommandLocation,
|
||||
SlashCommandSource,
|
||||
} from "../slash-commands.js";
|
||||
export {
|
||||
createExtensionRuntime,
|
||||
discoverAndLoadExtensions,
|
||||
loadExtensionFromFactory,
|
||||
loadExtensions,
|
||||
} from "./loader.js";
|
||||
export type {
|
||||
ExtensionErrorListener,
|
||||
ForkHandler,
|
||||
NavigateTreeHandler,
|
||||
NewSessionHandler,
|
||||
ShutdownHandler,
|
||||
SwitchSessionHandler,
|
||||
} from "./runner.js";
|
||||
export { ExtensionRunner } from "./runner.js";
|
||||
export type {
|
||||
AgentEndEvent,
|
||||
AgentStartEvent,
|
||||
// Re-exports
|
||||
AgentToolResult,
|
||||
AgentToolUpdateCallback,
|
||||
// App keybindings (for custom editors)
|
||||
AppAction,
|
||||
AppendEntryHandler,
|
||||
// Events - Tool (ToolCallEvent types)
|
||||
BashToolCallEvent,
|
||||
BashToolResultEvent,
|
||||
BeforeAgentStartEvent,
|
||||
BeforeAgentStartEventResult,
|
||||
// Context
|
||||
CompactOptions,
|
||||
// Events - Agent
|
||||
ContextEvent,
|
||||
// Event Results
|
||||
ContextEventResult,
|
||||
ContextUsage,
|
||||
CustomToolCallEvent,
|
||||
CustomToolResultEvent,
|
||||
EditToolCallEvent,
|
||||
EditToolResultEvent,
|
||||
ExecOptions,
|
||||
ExecResult,
|
||||
Extension,
|
||||
ExtensionActions,
|
||||
// API
|
||||
ExtensionAPI,
|
||||
ExtensionCommandContext,
|
||||
ExtensionCommandContextActions,
|
||||
ExtensionContext,
|
||||
ExtensionContextActions,
|
||||
// Errors
|
||||
ExtensionError,
|
||||
ExtensionEvent,
|
||||
ExtensionFactory,
|
||||
ExtensionFlag,
|
||||
ExtensionHandler,
|
||||
// Runtime
|
||||
ExtensionRuntime,
|
||||
ExtensionShortcut,
|
||||
ExtensionUIContext,
|
||||
ExtensionUIDialogOptions,
|
||||
ExtensionWidgetOptions,
|
||||
FindToolCallEvent,
|
||||
FindToolResultEvent,
|
||||
GetActiveToolsHandler,
|
||||
GetAllToolsHandler,
|
||||
GetCommandsHandler,
|
||||
GetThinkingLevelHandler,
|
||||
GrepToolCallEvent,
|
||||
GrepToolResultEvent,
|
||||
// Events - Input
|
||||
InputEvent,
|
||||
InputEventResult,
|
||||
InputSource,
|
||||
KeybindingsManager,
|
||||
LoadExtensionsResult,
|
||||
LsToolCallEvent,
|
||||
LsToolResultEvent,
|
||||
// Events - Message
|
||||
MessageEndEvent,
|
||||
// Message Rendering
|
||||
MessageRenderer,
|
||||
MessageRenderOptions,
|
||||
MessageStartEvent,
|
||||
MessageUpdateEvent,
|
||||
ModelSelectEvent,
|
||||
ModelSelectSource,
|
||||
// Provider Registration
|
||||
ProviderConfig,
|
||||
ProviderModelConfig,
|
||||
ReadToolCallEvent,
|
||||
ReadToolResultEvent,
|
||||
// Commands
|
||||
RegisteredCommand,
|
||||
RegisteredTool,
|
||||
// Events - Resources
|
||||
ResourcesDiscoverEvent,
|
||||
ResourcesDiscoverResult,
|
||||
SendMessageHandler,
|
||||
SendUserMessageHandler,
|
||||
SessionBeforeCompactEvent,
|
||||
SessionBeforeCompactResult,
|
||||
SessionBeforeForkEvent,
|
||||
SessionBeforeForkResult,
|
||||
SessionBeforeSwitchEvent,
|
||||
SessionBeforeSwitchResult,
|
||||
SessionBeforeTreeEvent,
|
||||
SessionBeforeTreeResult,
|
||||
SessionCompactEvent,
|
||||
SessionEvent,
|
||||
SessionForkEvent,
|
||||
SessionShutdownEvent,
|
||||
// Events - Session
|
||||
SessionStartEvent,
|
||||
SessionSwitchEvent,
|
||||
SessionTreeEvent,
|
||||
SetActiveToolsHandler,
|
||||
SetLabelHandler,
|
||||
SetModelHandler,
|
||||
SetThinkingLevelHandler,
|
||||
TerminalInputHandler,
|
||||
// Events - Tool
|
||||
ToolCallEvent,
|
||||
ToolCallEventResult,
|
||||
// Tools
|
||||
ToolDefinition,
|
||||
// Events - Tool Execution
|
||||
ToolExecutionEndEvent,
|
||||
ToolExecutionStartEvent,
|
||||
ToolExecutionUpdateEvent,
|
||||
ToolInfo,
|
||||
ToolRenderResultOptions,
|
||||
ToolResultEvent,
|
||||
ToolResultEventResult,
|
||||
TreePreparation,
|
||||
TurnEndEvent,
|
||||
TurnStartEvent,
|
||||
// Events - User Bash
|
||||
UserBashEvent,
|
||||
UserBashEventResult,
|
||||
WidgetPlacement,
|
||||
WriteToolCallEvent,
|
||||
WriteToolResultEvent,
|
||||
} from "./types.js";
|
||||
// Type guards
|
||||
export {
|
||||
isBashToolResult,
|
||||
isEditToolResult,
|
||||
isFindToolResult,
|
||||
isGrepToolResult,
|
||||
isLsToolResult,
|
||||
isReadToolResult,
|
||||
isToolCallEventType,
|
||||
isWriteToolResult,
|
||||
} from "./types.js";
|
||||
export {
|
||||
wrapRegisteredTool,
|
||||
wrapRegisteredTools,
|
||||
wrapToolsWithExtensions,
|
||||
wrapToolWithExtensions,
|
||||
} from "./wrapper.js";
|
||||
607
packages/coding-agent/src/core/extensions/loader.ts
Normal file
607
packages/coding-agent/src/core/extensions/loader.ts
Normal file
|
|
@ -0,0 +1,607 @@
|
|||
/**
|
||||
* Extension loader - loads TypeScript extension modules using jiti.
|
||||
*
|
||||
* Uses @mariozechner/jiti fork with virtualModules support for compiled Bun binaries.
|
||||
*/
|
||||
|
||||
import * as fs from "node:fs";
|
||||
import { createRequire } from "node:module";
|
||||
import * as os from "node:os";
|
||||
import * as path from "node:path";
|
||||
import { fileURLToPath } from "node:url";
|
||||
import { createJiti } from "@mariozechner/jiti";
|
||||
import * as _bundledPiAgentCore from "@mariozechner/pi-agent-core";
|
||||
import * as _bundledPiAi from "@mariozechner/pi-ai";
|
||||
import * as _bundledPiAiOauth from "@mariozechner/pi-ai/oauth";
|
||||
import type { KeyId } from "@mariozechner/pi-tui";
|
||||
import * as _bundledPiTui from "@mariozechner/pi-tui";
|
||||
// Static imports of packages that extensions may use.
|
||||
// These MUST be static so Bun bundles them into the compiled binary.
|
||||
// The virtualModules option then makes them available to extensions.
|
||||
import * as _bundledTypebox from "@sinclair/typebox";
|
||||
import { getAgentDir, isBunBinary } from "../../config.js";
|
||||
// NOTE: This import works because loader.ts exports are NOT re-exported from index.ts,
|
||||
// avoiding a circular dependency. Extensions can import from @mariozechner/pi-coding-agent.
|
||||
import * as _bundledPiCodingAgent from "../../index.js";
|
||||
import { createEventBus, type EventBus } from "../event-bus.js";
|
||||
import type { ExecOptions } from "../exec.js";
|
||||
import { execCommand } from "../exec.js";
|
||||
import type {
|
||||
Extension,
|
||||
ExtensionAPI,
|
||||
ExtensionFactory,
|
||||
ExtensionRuntime,
|
||||
LoadExtensionsResult,
|
||||
MessageRenderer,
|
||||
ProviderConfig,
|
||||
RegisteredCommand,
|
||||
ToolDefinition,
|
||||
} from "./types.js";
|
||||
|
||||
/** Modules available to extensions via virtualModules (for compiled Bun binary) */
|
||||
const VIRTUAL_MODULES: Record<string, unknown> = {
|
||||
"@sinclair/typebox": _bundledTypebox,
|
||||
"@mariozechner/pi-agent-core": _bundledPiAgentCore,
|
||||
"@mariozechner/pi-tui": _bundledPiTui,
|
||||
"@mariozechner/pi-ai": _bundledPiAi,
|
||||
"@mariozechner/pi-ai/oauth": _bundledPiAiOauth,
|
||||
"@mariozechner/pi-coding-agent": _bundledPiCodingAgent,
|
||||
};
|
||||
|
||||
const require = createRequire(import.meta.url);
|
||||
|
||||
/**
|
||||
* Get aliases for jiti (used in Node.js/development mode).
|
||||
* In Bun binary mode, virtualModules is used instead.
|
||||
*/
|
||||
let _aliases: Record<string, string> | null = null;
|
||||
function getAliases(): Record<string, string> {
|
||||
if (_aliases) return _aliases;
|
||||
|
||||
const __dirname = path.dirname(fileURLToPath(import.meta.url));
|
||||
const packageIndex = path.resolve(__dirname, "../..", "index.js");
|
||||
|
||||
const typeboxEntry = require.resolve("@sinclair/typebox");
|
||||
const typeboxRoot = typeboxEntry.replace(
|
||||
/[\\/]build[\\/]cjs[\\/]index\.js$/,
|
||||
"",
|
||||
);
|
||||
|
||||
const packagesRoot = path.resolve(__dirname, "../../../../");
|
||||
const resolveWorkspaceOrImport = (
|
||||
workspaceRelativePath: string,
|
||||
specifier: string,
|
||||
): string => {
|
||||
const workspacePath = path.join(packagesRoot, workspaceRelativePath);
|
||||
if (fs.existsSync(workspacePath)) {
|
||||
return workspacePath;
|
||||
}
|
||||
return fileURLToPath(import.meta.resolve(specifier));
|
||||
};
|
||||
|
||||
_aliases = {
|
||||
"@mariozechner/pi-coding-agent": packageIndex,
|
||||
"@mariozechner/pi-agent-core": resolveWorkspaceOrImport(
|
||||
"agent/dist/index.js",
|
||||
"@mariozechner/pi-agent-core",
|
||||
),
|
||||
"@mariozechner/pi-tui": resolveWorkspaceOrImport(
|
||||
"tui/dist/index.js",
|
||||
"@mariozechner/pi-tui",
|
||||
),
|
||||
"@mariozechner/pi-ai": resolveWorkspaceOrImport(
|
||||
"ai/dist/index.js",
|
||||
"@mariozechner/pi-ai",
|
||||
),
|
||||
"@mariozechner/pi-ai/oauth": resolveWorkspaceOrImport(
|
||||
"ai/dist/oauth.js",
|
||||
"@mariozechner/pi-ai/oauth",
|
||||
),
|
||||
"@sinclair/typebox": typeboxRoot,
|
||||
};
|
||||
|
||||
return _aliases;
|
||||
}
|
||||
|
||||
const UNICODE_SPACES = /[\u00A0\u2000-\u200A\u202F\u205F\u3000]/g;
|
||||
|
||||
function normalizeUnicodeSpaces(str: string): string {
|
||||
return str.replace(UNICODE_SPACES, " ");
|
||||
}
|
||||
|
||||
function expandPath(p: string): string {
|
||||
const normalized = normalizeUnicodeSpaces(p);
|
||||
if (normalized.startsWith("~/")) {
|
||||
return path.join(os.homedir(), normalized.slice(2));
|
||||
}
|
||||
if (normalized.startsWith("~")) {
|
||||
return path.join(os.homedir(), normalized.slice(1));
|
||||
}
|
||||
return normalized;
|
||||
}
|
||||
|
||||
function resolvePath(extPath: string, cwd: string): string {
|
||||
const expanded = expandPath(extPath);
|
||||
if (path.isAbsolute(expanded)) {
|
||||
return expanded;
|
||||
}
|
||||
return path.resolve(cwd, expanded);
|
||||
}
|
||||
|
||||
type HandlerFn = (...args: unknown[]) => Promise<unknown>;
|
||||
|
||||
/**
|
||||
* Create a runtime with throwing stubs for action methods.
|
||||
* Runner.bindCore() replaces these with real implementations.
|
||||
*/
|
||||
export function createExtensionRuntime(): ExtensionRuntime {
|
||||
const notInitialized = () => {
|
||||
throw new Error(
|
||||
"Extension runtime not initialized. Action methods cannot be called during extension loading.",
|
||||
);
|
||||
};
|
||||
|
||||
const runtime: ExtensionRuntime = {
|
||||
sendMessage: notInitialized,
|
||||
sendUserMessage: notInitialized,
|
||||
appendEntry: notInitialized,
|
||||
setSessionName: notInitialized,
|
||||
getSessionName: notInitialized,
|
||||
setLabel: notInitialized,
|
||||
getActiveTools: notInitialized,
|
||||
getAllTools: notInitialized,
|
||||
setActiveTools: notInitialized,
|
||||
// registerTool() is valid during extension load; refresh is only needed post-bind.
|
||||
refreshTools: () => {},
|
||||
getCommands: notInitialized,
|
||||
setModel: () =>
|
||||
Promise.reject(new Error("Extension runtime not initialized")),
|
||||
getThinkingLevel: notInitialized,
|
||||
setThinkingLevel: notInitialized,
|
||||
flagValues: new Map(),
|
||||
pendingProviderRegistrations: [],
|
||||
// Pre-bind: queue registrations so bindCore() can flush them once the
|
||||
// model registry is available. bindCore() replaces both with direct calls.
|
||||
registerProvider: (name, config) => {
|
||||
runtime.pendingProviderRegistrations.push({ name, config });
|
||||
},
|
||||
unregisterProvider: (name) => {
|
||||
runtime.pendingProviderRegistrations =
|
||||
runtime.pendingProviderRegistrations.filter((r) => r.name !== name);
|
||||
},
|
||||
};
|
||||
|
||||
return runtime;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create the ExtensionAPI for an extension.
|
||||
* Registration methods write to the extension object.
|
||||
* Action methods delegate to the shared runtime.
|
||||
*/
|
||||
function createExtensionAPI(
|
||||
extension: Extension,
|
||||
runtime: ExtensionRuntime,
|
||||
cwd: string,
|
||||
eventBus: EventBus,
|
||||
): ExtensionAPI {
|
||||
const api = {
|
||||
// Registration methods - write to extension
|
||||
on(event: string, handler: HandlerFn): void {
|
||||
const list = extension.handlers.get(event) ?? [];
|
||||
list.push(handler);
|
||||
extension.handlers.set(event, list);
|
||||
},
|
||||
|
||||
registerTool(tool: ToolDefinition): void {
|
||||
extension.tools.set(tool.name, {
|
||||
definition: tool,
|
||||
extensionPath: extension.path,
|
||||
});
|
||||
runtime.refreshTools();
|
||||
},
|
||||
|
||||
registerCommand(
|
||||
name: string,
|
||||
options: Omit<RegisteredCommand, "name">,
|
||||
): void {
|
||||
extension.commands.set(name, { name, ...options });
|
||||
},
|
||||
|
||||
registerShortcut(
|
||||
shortcut: KeyId,
|
||||
options: {
|
||||
description?: string;
|
||||
handler: (
|
||||
ctx: import("./types.js").ExtensionContext,
|
||||
) => Promise<void> | void;
|
||||
},
|
||||
): void {
|
||||
extension.shortcuts.set(shortcut, {
|
||||
shortcut,
|
||||
extensionPath: extension.path,
|
||||
...options,
|
||||
});
|
||||
},
|
||||
|
||||
registerFlag(
|
||||
name: string,
|
||||
options: {
|
||||
description?: string;
|
||||
type: "boolean" | "string";
|
||||
default?: boolean | string;
|
||||
},
|
||||
): void {
|
||||
extension.flags.set(name, {
|
||||
name,
|
||||
extensionPath: extension.path,
|
||||
...options,
|
||||
});
|
||||
if (options.default !== undefined && !runtime.flagValues.has(name)) {
|
||||
runtime.flagValues.set(name, options.default);
|
||||
}
|
||||
},
|
||||
|
||||
registerMessageRenderer<T>(
|
||||
customType: string,
|
||||
renderer: MessageRenderer<T>,
|
||||
): void {
|
||||
extension.messageRenderers.set(customType, renderer as MessageRenderer);
|
||||
},
|
||||
|
||||
// Flag access - checks extension registered it, reads from runtime
|
||||
getFlag(name: string): boolean | string | undefined {
|
||||
if (!extension.flags.has(name)) return undefined;
|
||||
return runtime.flagValues.get(name);
|
||||
},
|
||||
|
||||
// Action methods - delegate to shared runtime
|
||||
sendMessage(message, options): void {
|
||||
runtime.sendMessage(message, options);
|
||||
},
|
||||
|
||||
sendUserMessage(content, options): void {
|
||||
runtime.sendUserMessage(content, options);
|
||||
},
|
||||
|
||||
appendEntry(customType: string, data?: unknown): void {
|
||||
runtime.appendEntry(customType, data);
|
||||
},
|
||||
|
||||
setSessionName(name: string): void {
|
||||
runtime.setSessionName(name);
|
||||
},
|
||||
|
||||
getSessionName(): string | undefined {
|
||||
return runtime.getSessionName();
|
||||
},
|
||||
|
||||
setLabel(entryId: string, label: string | undefined): void {
|
||||
runtime.setLabel(entryId, label);
|
||||
},
|
||||
|
||||
exec(command: string, args: string[], options?: ExecOptions) {
|
||||
return execCommand(command, args, options?.cwd ?? cwd, options);
|
||||
},
|
||||
|
||||
getActiveTools(): string[] {
|
||||
return runtime.getActiveTools();
|
||||
},
|
||||
|
||||
getAllTools() {
|
||||
return runtime.getAllTools();
|
||||
},
|
||||
|
||||
setActiveTools(toolNames: string[]): void {
|
||||
runtime.setActiveTools(toolNames);
|
||||
},
|
||||
|
||||
getCommands() {
|
||||
return runtime.getCommands();
|
||||
},
|
||||
|
||||
setModel(model) {
|
||||
return runtime.setModel(model);
|
||||
},
|
||||
|
||||
getThinkingLevel() {
|
||||
return runtime.getThinkingLevel();
|
||||
},
|
||||
|
||||
setThinkingLevel(level) {
|
||||
runtime.setThinkingLevel(level);
|
||||
},
|
||||
|
||||
registerProvider(name: string, config: ProviderConfig) {
|
||||
runtime.registerProvider(name, config);
|
||||
},
|
||||
|
||||
unregisterProvider(name: string) {
|
||||
runtime.unregisterProvider(name);
|
||||
},
|
||||
|
||||
events: eventBus,
|
||||
} as ExtensionAPI;
|
||||
|
||||
return api;
|
||||
}
|
||||
|
||||
async function loadExtensionModule(extensionPath: string) {
|
||||
const jiti = createJiti(import.meta.url, {
|
||||
moduleCache: false,
|
||||
// In Bun binary: use virtualModules for bundled packages (no filesystem resolution)
|
||||
// Also disable tryNative so jiti handles ALL imports (not just the entry point)
|
||||
// In Node.js/dev: use aliases to resolve to node_modules paths
|
||||
...(isBunBinary
|
||||
? { virtualModules: VIRTUAL_MODULES, tryNative: false }
|
||||
: { alias: getAliases() }),
|
||||
});
|
||||
|
||||
const module = await jiti.import(extensionPath, { default: true });
|
||||
const factory = module as ExtensionFactory;
|
||||
return typeof factory !== "function" ? undefined : factory;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create an Extension object with empty collections.
|
||||
*/
|
||||
function createExtension(
|
||||
extensionPath: string,
|
||||
resolvedPath: string,
|
||||
): Extension {
|
||||
return {
|
||||
path: extensionPath,
|
||||
resolvedPath,
|
||||
handlers: new Map(),
|
||||
tools: new Map(),
|
||||
messageRenderers: new Map(),
|
||||
commands: new Map(),
|
||||
flags: new Map(),
|
||||
shortcuts: new Map(),
|
||||
};
|
||||
}
|
||||
|
||||
async function loadExtension(
|
||||
extensionPath: string,
|
||||
cwd: string,
|
||||
eventBus: EventBus,
|
||||
runtime: ExtensionRuntime,
|
||||
): Promise<{ extension: Extension | null; error: string | null }> {
|
||||
const resolvedPath = resolvePath(extensionPath, cwd);
|
||||
|
||||
try {
|
||||
const factory = await loadExtensionModule(resolvedPath);
|
||||
if (!factory) {
|
||||
return {
|
||||
extension: null,
|
||||
error: `Extension does not export a valid factory function: ${extensionPath}`,
|
||||
};
|
||||
}
|
||||
|
||||
const extension = createExtension(extensionPath, resolvedPath);
|
||||
const api = createExtensionAPI(extension, runtime, cwd, eventBus);
|
||||
await factory(api);
|
||||
|
||||
return { extension, error: null };
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : String(err);
|
||||
return { extension: null, error: `Failed to load extension: ${message}` };
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create an Extension from an inline factory function.
|
||||
*/
|
||||
export async function loadExtensionFromFactory(
|
||||
factory: ExtensionFactory,
|
||||
cwd: string,
|
||||
eventBus: EventBus,
|
||||
runtime: ExtensionRuntime,
|
||||
extensionPath = "<inline>",
|
||||
): Promise<Extension> {
|
||||
const extension = createExtension(extensionPath, extensionPath);
|
||||
const api = createExtensionAPI(extension, runtime, cwd, eventBus);
|
||||
await factory(api);
|
||||
return extension;
|
||||
}
|
||||
|
||||
/**
|
||||
* Load extensions from paths.
|
||||
*/
|
||||
export async function loadExtensions(
|
||||
paths: string[],
|
||||
cwd: string,
|
||||
eventBus?: EventBus,
|
||||
): Promise<LoadExtensionsResult> {
|
||||
const extensions: Extension[] = [];
|
||||
const errors: Array<{ path: string; error: string }> = [];
|
||||
const resolvedEventBus = eventBus ?? createEventBus();
|
||||
const runtime = createExtensionRuntime();
|
||||
|
||||
for (const extPath of paths) {
|
||||
const { extension, error } = await loadExtension(
|
||||
extPath,
|
||||
cwd,
|
||||
resolvedEventBus,
|
||||
runtime,
|
||||
);
|
||||
|
||||
if (error) {
|
||||
errors.push({ path: extPath, error });
|
||||
continue;
|
||||
}
|
||||
|
||||
if (extension) {
|
||||
extensions.push(extension);
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
extensions,
|
||||
errors,
|
||||
runtime,
|
||||
};
|
||||
}
|
||||
|
||||
interface PiManifest {
|
||||
extensions?: string[];
|
||||
themes?: string[];
|
||||
skills?: string[];
|
||||
prompts?: string[];
|
||||
}
|
||||
|
||||
function readPiManifest(packageJsonPath: string): PiManifest | null {
|
||||
try {
|
||||
const content = fs.readFileSync(packageJsonPath, "utf-8");
|
||||
const pkg = JSON.parse(content);
|
||||
if (pkg.pi && typeof pkg.pi === "object") {
|
||||
return pkg.pi as PiManifest;
|
||||
}
|
||||
return null;
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
function isExtensionFile(name: string): boolean {
|
||||
return name.endsWith(".ts") || name.endsWith(".js");
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolve extension entry points from a directory.
|
||||
*
|
||||
* Checks for:
|
||||
* 1. package.json with "pi.extensions" field -> returns declared paths
|
||||
* 2. index.ts or index.js -> returns the index file
|
||||
*
|
||||
* Returns resolved paths or null if no entry points found.
|
||||
*/
|
||||
function resolveExtensionEntries(dir: string): string[] | null {
|
||||
// Check for package.json with "pi" field first
|
||||
const packageJsonPath = path.join(dir, "package.json");
|
||||
if (fs.existsSync(packageJsonPath)) {
|
||||
const manifest = readPiManifest(packageJsonPath);
|
||||
if (manifest?.extensions?.length) {
|
||||
const entries: string[] = [];
|
||||
for (const extPath of manifest.extensions) {
|
||||
const resolvedExtPath = path.resolve(dir, extPath);
|
||||
if (fs.existsSync(resolvedExtPath)) {
|
||||
entries.push(resolvedExtPath);
|
||||
}
|
||||
}
|
||||
if (entries.length > 0) {
|
||||
return entries;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for index.ts or index.js
|
||||
const indexTs = path.join(dir, "index.ts");
|
||||
const indexJs = path.join(dir, "index.js");
|
||||
if (fs.existsSync(indexTs)) {
|
||||
return [indexTs];
|
||||
}
|
||||
if (fs.existsSync(indexJs)) {
|
||||
return [indexJs];
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Discover extensions in a directory.
|
||||
*
|
||||
* Discovery rules:
|
||||
* 1. Direct files: `extensions/*.ts` or `*.js` → load
|
||||
* 2. Subdirectory with index: `extensions/* /index.ts` or `index.js` → load
|
||||
* 3. Subdirectory with package.json: `extensions/* /package.json` with "pi" field → load what it declares
|
||||
*
|
||||
* No recursion beyond one level. Complex packages must use package.json manifest.
|
||||
*/
|
||||
function discoverExtensionsInDir(dir: string): string[] {
|
||||
if (!fs.existsSync(dir)) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const discovered: string[] = [];
|
||||
|
||||
try {
|
||||
const entries = fs.readdirSync(dir, { withFileTypes: true });
|
||||
|
||||
for (const entry of entries) {
|
||||
const entryPath = path.join(dir, entry.name);
|
||||
|
||||
// 1. Direct files: *.ts or *.js
|
||||
if (
|
||||
(entry.isFile() || entry.isSymbolicLink()) &&
|
||||
isExtensionFile(entry.name)
|
||||
) {
|
||||
discovered.push(entryPath);
|
||||
continue;
|
||||
}
|
||||
|
||||
// 2 & 3. Subdirectories
|
||||
if (entry.isDirectory() || entry.isSymbolicLink()) {
|
||||
const entries = resolveExtensionEntries(entryPath);
|
||||
if (entries) {
|
||||
discovered.push(...entries);
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
return [];
|
||||
}
|
||||
|
||||
return discovered;
|
||||
}
|
||||
|
||||
/**
 * Discover and load extensions from standard locations.
 *
 * Search order (first occurrence of a resolved path wins for dedup):
 *   1. Project-local: `<cwd>/.pi/extensions/`
 *   2. Global: `<agentDir>/extensions/`
 *   3. Explicitly configured paths
 *
 * @param configuredPaths Extra extension paths from configuration; each may be
 *   a directory (scanned for entries) or a single file path.
 * @param cwd Working directory used to resolve relative paths.
 * @param agentDir Agent config directory; defaults to getAgentDir().
 * @param eventBus Optional event bus forwarded to loadExtensions().
 * @returns The result of loadExtensions() over all discovered paths.
 */
export async function discoverAndLoadExtensions(
	configuredPaths: string[],
	cwd: string,
	agentDir: string = getAgentDir(),
	eventBus?: EventBus,
): Promise<LoadExtensionsResult> {
	const allPaths: string[] = [];
	const seen = new Set<string>();

	// Dedupe on the absolute path, but keep the caller-supplied (possibly
	// relative) string in allPaths so loadExtensions sees the original form.
	const addPaths = (paths: string[]) => {
		for (const p of paths) {
			const resolved = path.resolve(p);
			if (!seen.has(resolved)) {
				seen.add(resolved);
				allPaths.push(p);
			}
		}
	};

	// 1. Project-local extensions: cwd/.pi/extensions/
	const localExtDir = path.join(cwd, ".pi", "extensions");
	addPaths(discoverExtensionsInDir(localExtDir));

	// 2. Global extensions: agentDir/extensions/
	const globalExtDir = path.join(agentDir, "extensions");
	addPaths(discoverExtensionsInDir(globalExtDir));

	// 3. Explicitly configured paths
	for (const p of configuredPaths) {
		const resolved = resolvePath(p, cwd);
		if (fs.existsSync(resolved) && fs.statSync(resolved).isDirectory()) {
			// Check for package.json with pi manifest or index.ts
			const entries = resolveExtensionEntries(resolved);
			if (entries) {
				addPaths(entries);
				continue;
			}
			// No explicit entries - discover individual files in directory
			addPaths(discoverExtensionsInDir(resolved));
			continue;
		}
		// Not an existing directory: treat the resolved path as a single
		// extension file and let loadExtensions report if it is missing.
		addPaths([resolved]);
	}

	return loadExtensions(allPaths, cwd, eventBus);
}
|
||||
950
packages/coding-agent/src/core/extensions/runner.ts
Normal file
950
packages/coding-agent/src/core/extensions/runner.ts
Normal file
|
|
@ -0,0 +1,950 @@
|
|||
/**
|
||||
* Extension runner - executes extensions and manages their lifecycle.
|
||||
*/
|
||||
|
||||
import type { AgentMessage } from "@mariozechner/pi-agent-core";
|
||||
import type { ImageContent, Model } from "@mariozechner/pi-ai";
|
||||
import type { KeyId } from "@mariozechner/pi-tui";
|
||||
import { type Theme, theme } from "../../modes/interactive/theme/theme.js";
|
||||
import type { ResourceDiagnostic } from "../diagnostics.js";
|
||||
import type { KeyAction, KeybindingsConfig } from "../keybindings.js";
|
||||
import type { ModelRegistry } from "../model-registry.js";
|
||||
import type { SessionManager } from "../session-manager.js";
|
||||
import type {
|
||||
BeforeAgentStartEvent,
|
||||
BeforeAgentStartEventResult,
|
||||
CompactOptions,
|
||||
ContextEvent,
|
||||
ContextEventResult,
|
||||
ContextUsage,
|
||||
Extension,
|
||||
ExtensionActions,
|
||||
ExtensionCommandContext,
|
||||
ExtensionCommandContextActions,
|
||||
ExtensionContext,
|
||||
ExtensionContextActions,
|
||||
ExtensionError,
|
||||
ExtensionEvent,
|
||||
ExtensionFlag,
|
||||
ExtensionRuntime,
|
||||
ExtensionShortcut,
|
||||
ExtensionUIContext,
|
||||
InputEvent,
|
||||
InputEventResult,
|
||||
InputSource,
|
||||
MessageRenderer,
|
||||
RegisteredCommand,
|
||||
RegisteredTool,
|
||||
ResourcesDiscoverEvent,
|
||||
ResourcesDiscoverResult,
|
||||
SessionBeforeCompactResult,
|
||||
SessionBeforeForkResult,
|
||||
SessionBeforeSwitchResult,
|
||||
SessionBeforeTreeResult,
|
||||
ToolCallEvent,
|
||||
ToolCallEventResult,
|
||||
ToolResultEvent,
|
||||
ToolResultEventResult,
|
||||
UserBashEvent,
|
||||
UserBashEventResult,
|
||||
} from "./types.js";
|
||||
|
||||
// Keybindings for these actions cannot be overridden by extensions.
// getShortcuts() skips (and reports a diagnostic for) any extension shortcut
// whose key collides with a binding for one of these actions.
const RESERVED_ACTIONS_FOR_EXTENSION_CONFLICTS: ReadonlyArray<KeyAction> = [
	"interrupt",
	"clear",
	"exit",
	"suspend",
	"cycleThinkingLevel",
	"cycleModelForward",
	"cycleModelBackward",
	"selectModel",
	"expandTools",
	"toggleThinking",
	"externalEditor",
	"followUp",
	"submit",
	"selectConfirm",
	"selectCancel",
	"copy",
	"deleteToLineEnd",
];
|
||||
|
||||
// Lookup from a (lowercased) key id to the built-in action bound to it, plus
// whether extensions are forbidden from overriding that binding.
type BuiltInKeyBindings = Partial<
	Record<KeyId, { action: KeyAction; restrictOverride: boolean }>
>;
|
||||
|
||||
const buildBuiltinKeybindings = (
|
||||
effectiveKeybindings: Required<KeybindingsConfig>,
|
||||
): BuiltInKeyBindings => {
|
||||
const builtinKeybindings = {} as BuiltInKeyBindings;
|
||||
for (const [action, keys] of Object.entries(effectiveKeybindings)) {
|
||||
const keyAction = action as KeyAction;
|
||||
const keyList = Array.isArray(keys) ? keys : [keys];
|
||||
const restrictOverride =
|
||||
RESERVED_ACTIONS_FOR_EXTENSION_CONFLICTS.includes(keyAction);
|
||||
for (const key of keyList) {
|
||||
const normalizedKey = key.toLowerCase() as KeyId;
|
||||
builtinKeybindings[normalizedKey] = {
|
||||
action: keyAction,
|
||||
restrictOverride: restrictOverride,
|
||||
};
|
||||
}
|
||||
}
|
||||
return builtinKeybindings;
|
||||
};
|
||||
|
||||
/** Combined result from all before_agent_start handlers */
interface BeforeAgentStartCombinedResult {
	// One entry per handler that returned a message, in extension order.
	messages?: NonNullable<BeforeAgentStartEventResult["message"]>[];
	// Present only when at least one handler replaced the system prompt.
	systemPrompt?: string;
}

/**
 * Events handled by the generic emit() method.
 * Events with dedicated emitXxx() methods are excluded for stronger type safety.
 */
type RunnerEmitEvent = Exclude<
	ExtensionEvent,
	| ToolCallEvent
	| ToolResultEvent
	| UserBashEvent
	| ContextEvent
	| BeforeAgentStartEvent
	| ResourcesDiscoverEvent
	| InputEvent
>;

// The subset of emit() events whose handlers may return a cancellable result.
type SessionBeforeEvent = Extract<
	RunnerEmitEvent,
	{
		type:
			| "session_before_switch"
			| "session_before_fork"
			| "session_before_compact"
			| "session_before_tree";
	}
>;

// Union of all result shapes a session_before_* handler can return.
type SessionBeforeEventResult =
	| SessionBeforeSwitchResult
	| SessionBeforeForkResult
	| SessionBeforeCompactResult
	| SessionBeforeTreeResult;

// Maps each emit() event type to its result type: session_before_* events
// produce their matching result (or undefined); all other events produce
// undefined.
type RunnerEmitResult<TEvent extends RunnerEmitEvent> = TEvent extends {
	type: "session_before_switch";
}
	? SessionBeforeSwitchResult | undefined
	: TEvent extends { type: "session_before_fork" }
		? SessionBeforeForkResult | undefined
		: TEvent extends { type: "session_before_compact" }
			? SessionBeforeCompactResult | undefined
			: TEvent extends { type: "session_before_tree" }
				? SessionBeforeTreeResult | undefined
				: undefined;
|
||||
|
||||
/** Callback invoked whenever an extension handler throws. */
export type ExtensionErrorListener = (error: ExtensionError) => void;

/**
 * Starts a new session, optionally parented to an existing one. `setup` runs
 * against the fresh SessionManager before the session is used.
 */
export type NewSessionHandler = (options?: {
	parentSession?: string;
	setup?: (sessionManager: SessionManager) => Promise<void>;
}) => Promise<{ cancelled: boolean }>;

/** Forks the session at the given entry id. */
export type ForkHandler = (entryId: string) => Promise<{ cancelled: boolean }>;

/** Navigates the session tree to a target entry, optionally summarizing. */
export type NavigateTreeHandler = (
	targetId: string,
	options?: {
		summarize?: boolean;
		customInstructions?: string;
		replaceInstructions?: boolean;
		label?: string;
	},
) => Promise<{ cancelled: boolean }>;

/** Switches to the session stored at the given path. */
export type SwitchSessionHandler = (
	sessionPath: string,
) => Promise<{ cancelled: boolean }>;

/** Reloads the current session/mode. */
export type ReloadHandler = () => Promise<void>;

/** Requests a host shutdown; actual behavior is supplied by the mode. */
export type ShutdownHandler = () => void;
||||
|
||||
/**
|
||||
* Helper function to emit session_shutdown event to extensions.
|
||||
* Returns true if the event was emitted, false if there were no handlers.
|
||||
*/
|
||||
export async function emitSessionShutdownEvent(
|
||||
extensionRunner: ExtensionRunner | undefined,
|
||||
): Promise<boolean> {
|
||||
if (extensionRunner?.hasHandlers("session_shutdown")) {
|
||||
await extensionRunner.emit({
|
||||
type: "session_shutdown",
|
||||
});
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Inert UI context used when no interactive UI is attached (e.g. headless
// modes). Queries resolve to "nothing" defaults (undefined/false/empty) and
// all mutating calls are no-ops; setTheme reports failure explicitly.
const noOpUIContext: ExtensionUIContext = {
	select: async () => undefined,
	confirm: async () => false,
	input: async () => undefined,
	notify: () => {},
	// Returns an unsubscribe function that does nothing.
	onTerminalInput: () => () => {},
	setStatus: () => {},
	setWorkingMessage: () => {},
	setWidget: () => {},
	setFooter: () => {},
	setHeader: () => {},
	setTitle: () => {},
	custom: async () => undefined as never,
	pasteToEditor: () => {},
	setEditorText: () => {},
	getEditorText: () => "",
	editor: async () => undefined,
	setEditorComponent: () => {},
	// The default theme is still readable without a UI.
	get theme() {
		return theme;
	},
	getAllThemes: () => [],
	getTheme: () => undefined,
	setTheme: (_theme: string | Theme) => ({
		success: false,
		error: "UI not available",
	}),
	getToolsExpanded: () => false,
	setToolsExpanded: () => {},
};
|
||||
|
||||
export class ExtensionRunner {
|
||||
private extensions: Extension[];
|
||||
private runtime: ExtensionRuntime;
|
||||
private uiContext: ExtensionUIContext;
|
||||
private cwd: string;
|
||||
private sessionManager: SessionManager;
|
||||
private modelRegistry: ModelRegistry;
|
||||
private errorListeners: Set<ExtensionErrorListener> = new Set();
|
||||
private getModel: () => Model<any> | undefined = () => undefined;
|
||||
private isIdleFn: () => boolean = () => true;
|
||||
private waitForIdleFn: () => Promise<void> = async () => {};
|
||||
private abortFn: () => void = () => {};
|
||||
private hasPendingMessagesFn: () => boolean = () => false;
|
||||
private getContextUsageFn: () => ContextUsage | undefined = () => undefined;
|
||||
private compactFn: (options?: CompactOptions) => void = () => {};
|
||||
private getSystemPromptFn: () => string = () => "";
|
||||
private newSessionHandler: NewSessionHandler = async () => ({
|
||||
cancelled: false,
|
||||
});
|
||||
private forkHandler: ForkHandler = async () => ({ cancelled: false });
|
||||
private navigateTreeHandler: NavigateTreeHandler = async () => ({
|
||||
cancelled: false,
|
||||
});
|
||||
private switchSessionHandler: SwitchSessionHandler = async () => ({
|
||||
cancelled: false,
|
||||
});
|
||||
private reloadHandler: ReloadHandler = async () => {};
|
||||
private shutdownHandler: ShutdownHandler = () => {};
|
||||
private shortcutDiagnostics: ResourceDiagnostic[] = [];
|
||||
private commandDiagnostics: ResourceDiagnostic[] = [];
|
||||
|
||||
constructor(
|
||||
extensions: Extension[],
|
||||
runtime: ExtensionRuntime,
|
||||
cwd: string,
|
||||
sessionManager: SessionManager,
|
||||
modelRegistry: ModelRegistry,
|
||||
) {
|
||||
this.extensions = extensions;
|
||||
this.runtime = runtime;
|
||||
this.uiContext = noOpUIContext;
|
||||
this.cwd = cwd;
|
||||
this.sessionManager = sessionManager;
|
||||
this.modelRegistry = modelRegistry;
|
||||
}
|
||||
|
||||
bindCore(
|
||||
actions: ExtensionActions,
|
||||
contextActions: ExtensionContextActions,
|
||||
): void {
|
||||
// Copy actions into the shared runtime (all extension APIs reference this)
|
||||
this.runtime.sendMessage = actions.sendMessage;
|
||||
this.runtime.sendUserMessage = actions.sendUserMessage;
|
||||
this.runtime.appendEntry = actions.appendEntry;
|
||||
this.runtime.setSessionName = actions.setSessionName;
|
||||
this.runtime.getSessionName = actions.getSessionName;
|
||||
this.runtime.setLabel = actions.setLabel;
|
||||
this.runtime.getActiveTools = actions.getActiveTools;
|
||||
this.runtime.getAllTools = actions.getAllTools;
|
||||
this.runtime.setActiveTools = actions.setActiveTools;
|
||||
this.runtime.refreshTools = actions.refreshTools;
|
||||
this.runtime.getCommands = actions.getCommands;
|
||||
this.runtime.setModel = actions.setModel;
|
||||
this.runtime.getThinkingLevel = actions.getThinkingLevel;
|
||||
this.runtime.setThinkingLevel = actions.setThinkingLevel;
|
||||
|
||||
// Context actions (required)
|
||||
this.getModel = contextActions.getModel;
|
||||
this.isIdleFn = contextActions.isIdle;
|
||||
this.abortFn = contextActions.abort;
|
||||
this.hasPendingMessagesFn = contextActions.hasPendingMessages;
|
||||
this.shutdownHandler = contextActions.shutdown;
|
||||
this.getContextUsageFn = contextActions.getContextUsage;
|
||||
this.compactFn = contextActions.compact;
|
||||
this.getSystemPromptFn = contextActions.getSystemPrompt;
|
||||
|
||||
// Flush provider registrations queued during extension loading
|
||||
for (const { name, config } of this.runtime.pendingProviderRegistrations) {
|
||||
this.modelRegistry.registerProvider(name, config);
|
||||
}
|
||||
this.runtime.pendingProviderRegistrations = [];
|
||||
|
||||
// From this point on, provider registration/unregistration takes effect immediately
|
||||
// without requiring a /reload.
|
||||
this.runtime.registerProvider = (name, config) =>
|
||||
this.modelRegistry.registerProvider(name, config);
|
||||
this.runtime.unregisterProvider = (name) =>
|
||||
this.modelRegistry.unregisterProvider(name);
|
||||
}
|
||||
|
||||
bindCommandContext(actions?: ExtensionCommandContextActions): void {
|
||||
if (actions) {
|
||||
this.waitForIdleFn = actions.waitForIdle;
|
||||
this.newSessionHandler = actions.newSession;
|
||||
this.forkHandler = actions.fork;
|
||||
this.navigateTreeHandler = actions.navigateTree;
|
||||
this.switchSessionHandler = actions.switchSession;
|
||||
this.reloadHandler = actions.reload;
|
||||
return;
|
||||
}
|
||||
|
||||
this.waitForIdleFn = async () => {};
|
||||
this.newSessionHandler = async () => ({ cancelled: false });
|
||||
this.forkHandler = async () => ({ cancelled: false });
|
||||
this.navigateTreeHandler = async () => ({ cancelled: false });
|
||||
this.switchSessionHandler = async () => ({ cancelled: false });
|
||||
this.reloadHandler = async () => {};
|
||||
}
|
||||
|
||||
setUIContext(uiContext?: ExtensionUIContext): void {
|
||||
this.uiContext = uiContext ?? noOpUIContext;
|
||||
}
|
||||
|
||||
getUIContext(): ExtensionUIContext {
|
||||
return this.uiContext;
|
||||
}
|
||||
|
||||
hasUI(): boolean {
|
||||
return this.uiContext !== noOpUIContext;
|
||||
}
|
||||
|
||||
getExtensionPaths(): string[] {
|
||||
return this.extensions.map((e) => e.path);
|
||||
}
|
||||
|
||||
/** Get all registered tools from all extensions (first registration per name wins). */
|
||||
getAllRegisteredTools(): RegisteredTool[] {
|
||||
const toolsByName = new Map<string, RegisteredTool>();
|
||||
for (const ext of this.extensions) {
|
||||
for (const tool of ext.tools.values()) {
|
||||
if (!toolsByName.has(tool.definition.name)) {
|
||||
toolsByName.set(tool.definition.name, tool);
|
||||
}
|
||||
}
|
||||
}
|
||||
return Array.from(toolsByName.values());
|
||||
}
|
||||
|
||||
/** Get a tool definition by name. Returns undefined if not found. */
|
||||
getToolDefinition(
|
||||
toolName: string,
|
||||
): RegisteredTool["definition"] | undefined {
|
||||
for (const ext of this.extensions) {
|
||||
const tool = ext.tools.get(toolName);
|
||||
if (tool) {
|
||||
return tool.definition;
|
||||
}
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
getFlags(): Map<string, ExtensionFlag> {
|
||||
const allFlags = new Map<string, ExtensionFlag>();
|
||||
for (const ext of this.extensions) {
|
||||
for (const [name, flag] of ext.flags) {
|
||||
if (!allFlags.has(name)) {
|
||||
allFlags.set(name, flag);
|
||||
}
|
||||
}
|
||||
}
|
||||
return allFlags;
|
||||
}
|
||||
|
||||
setFlagValue(name: string, value: boolean | string): void {
|
||||
this.runtime.flagValues.set(name, value);
|
||||
}
|
||||
|
||||
getFlagValues(): Map<string, boolean | string> {
|
||||
return new Map(this.runtime.flagValues);
|
||||
}
|
||||
|
||||
getShortcuts(
|
||||
effectiveKeybindings: Required<KeybindingsConfig>,
|
||||
): Map<KeyId, ExtensionShortcut> {
|
||||
this.shortcutDiagnostics = [];
|
||||
const builtinKeybindings = buildBuiltinKeybindings(effectiveKeybindings);
|
||||
const extensionShortcuts = new Map<KeyId, ExtensionShortcut>();
|
||||
|
||||
const addDiagnostic = (message: string, extensionPath: string) => {
|
||||
this.shortcutDiagnostics.push({
|
||||
type: "warning",
|
||||
message,
|
||||
path: extensionPath,
|
||||
});
|
||||
if (!this.hasUI()) {
|
||||
console.warn(message);
|
||||
}
|
||||
};
|
||||
|
||||
for (const ext of this.extensions) {
|
||||
for (const [key, shortcut] of ext.shortcuts) {
|
||||
const normalizedKey = key.toLowerCase() as KeyId;
|
||||
|
||||
const builtInKeybinding = builtinKeybindings[normalizedKey];
|
||||
if (builtInKeybinding?.restrictOverride === true) {
|
||||
addDiagnostic(
|
||||
`Extension shortcut '${key}' from ${shortcut.extensionPath} conflicts with built-in shortcut. Skipping.`,
|
||||
shortcut.extensionPath,
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (builtInKeybinding?.restrictOverride === false) {
|
||||
addDiagnostic(
|
||||
`Extension shortcut conflict: '${key}' is built-in shortcut for ${builtInKeybinding.action} and ${shortcut.extensionPath}. Using ${shortcut.extensionPath}.`,
|
||||
shortcut.extensionPath,
|
||||
);
|
||||
}
|
||||
|
||||
const existingExtensionShortcut = extensionShortcuts.get(normalizedKey);
|
||||
if (existingExtensionShortcut) {
|
||||
addDiagnostic(
|
||||
`Extension shortcut conflict: '${key}' registered by both ${existingExtensionShortcut.extensionPath} and ${shortcut.extensionPath}. Using ${shortcut.extensionPath}.`,
|
||||
shortcut.extensionPath,
|
||||
);
|
||||
}
|
||||
extensionShortcuts.set(normalizedKey, shortcut);
|
||||
}
|
||||
}
|
||||
return extensionShortcuts;
|
||||
}
|
||||
|
||||
getShortcutDiagnostics(): ResourceDiagnostic[] {
|
||||
return this.shortcutDiagnostics;
|
||||
}
|
||||
|
||||
onError(listener: ExtensionErrorListener): () => void {
|
||||
this.errorListeners.add(listener);
|
||||
return () => this.errorListeners.delete(listener);
|
||||
}
|
||||
|
||||
emitError(error: ExtensionError): void {
|
||||
for (const listener of this.errorListeners) {
|
||||
listener(error);
|
||||
}
|
||||
}
|
||||
|
||||
hasHandlers(eventType: string): boolean {
|
||||
for (const ext of this.extensions) {
|
||||
const handlers = ext.handlers.get(eventType);
|
||||
if (handlers && handlers.length > 0) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
getMessageRenderer(customType: string): MessageRenderer | undefined {
|
||||
for (const ext of this.extensions) {
|
||||
const renderer = ext.messageRenderers.get(customType);
|
||||
if (renderer) {
|
||||
return renderer;
|
||||
}
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
getRegisteredCommands(reserved?: Set<string>): RegisteredCommand[] {
|
||||
this.commandDiagnostics = [];
|
||||
|
||||
const commands: RegisteredCommand[] = [];
|
||||
const commandOwners = new Map<string, string>();
|
||||
for (const ext of this.extensions) {
|
||||
for (const command of ext.commands.values()) {
|
||||
if (reserved?.has(command.name)) {
|
||||
const message = `Extension command '${command.name}' from ${ext.path} conflicts with built-in commands. Skipping.`;
|
||||
this.commandDiagnostics.push({
|
||||
type: "warning",
|
||||
message,
|
||||
path: ext.path,
|
||||
});
|
||||
if (!this.hasUI()) {
|
||||
console.warn(message);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
const existingOwner = commandOwners.get(command.name);
|
||||
if (existingOwner) {
|
||||
const message = `Extension command '${command.name}' from ${ext.path} conflicts with ${existingOwner}. Skipping.`;
|
||||
this.commandDiagnostics.push({
|
||||
type: "warning",
|
||||
message,
|
||||
path: ext.path,
|
||||
});
|
||||
if (!this.hasUI()) {
|
||||
console.warn(message);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
commandOwners.set(command.name, ext.path);
|
||||
commands.push(command);
|
||||
}
|
||||
}
|
||||
return commands;
|
||||
}
|
||||
|
||||
getCommandDiagnostics(): ResourceDiagnostic[] {
|
||||
return this.commandDiagnostics;
|
||||
}
|
||||
|
||||
getRegisteredCommandsWithPaths(): Array<{
|
||||
command: RegisteredCommand;
|
||||
extensionPath: string;
|
||||
}> {
|
||||
const result: Array<{ command: RegisteredCommand; extensionPath: string }> =
|
||||
[];
|
||||
for (const ext of this.extensions) {
|
||||
for (const command of ext.commands.values()) {
|
||||
result.push({ command, extensionPath: ext.path });
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
getCommand(name: string): RegisteredCommand | undefined {
|
||||
for (const ext of this.extensions) {
|
||||
const command = ext.commands.get(name);
|
||||
if (command) {
|
||||
return command;
|
||||
}
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Request a graceful shutdown. Called by extension tools and event handlers.
|
||||
* The actual shutdown behavior is provided by the mode via bindExtensions().
|
||||
*/
|
||||
shutdown(): void {
|
||||
this.shutdownHandler();
|
||||
}
|
||||
|
||||
/**
|
||||
* Create an ExtensionContext for use in event handlers and tool execution.
|
||||
* Context values are resolved at call time, so changes via bindCore/bindUI are reflected.
|
||||
*/
|
||||
createContext(): ExtensionContext {
|
||||
const getModel = this.getModel;
|
||||
return {
|
||||
ui: this.uiContext,
|
||||
hasUI: this.hasUI(),
|
||||
cwd: this.cwd,
|
||||
sessionManager: this.sessionManager,
|
||||
modelRegistry: this.modelRegistry,
|
||||
get model() {
|
||||
return getModel();
|
||||
},
|
||||
isIdle: () => this.isIdleFn(),
|
||||
abort: () => this.abortFn(),
|
||||
hasPendingMessages: () => this.hasPendingMessagesFn(),
|
||||
shutdown: () => this.shutdownHandler(),
|
||||
getContextUsage: () => this.getContextUsageFn(),
|
||||
compact: (options) => this.compactFn(options),
|
||||
getSystemPrompt: () => this.getSystemPromptFn(),
|
||||
};
|
||||
}
|
||||
|
||||
createCommandContext(): ExtensionCommandContext {
|
||||
return {
|
||||
...this.createContext(),
|
||||
waitForIdle: () => this.waitForIdleFn(),
|
||||
newSession: (options) => this.newSessionHandler(options),
|
||||
fork: (entryId) => this.forkHandler(entryId),
|
||||
navigateTree: (targetId, options) =>
|
||||
this.navigateTreeHandler(targetId, options),
|
||||
switchSession: (sessionPath) => this.switchSessionHandler(sessionPath),
|
||||
reload: () => this.reloadHandler(),
|
||||
};
|
||||
}
|
||||
|
||||
private isSessionBeforeEvent(
|
||||
event: RunnerEmitEvent,
|
||||
): event is SessionBeforeEvent {
|
||||
return (
|
||||
event.type === "session_before_switch" ||
|
||||
event.type === "session_before_fork" ||
|
||||
event.type === "session_before_compact" ||
|
||||
event.type === "session_before_tree"
|
||||
);
|
||||
}
|
||||
|
||||
async emit<TEvent extends RunnerEmitEvent>(
|
||||
event: TEvent,
|
||||
): Promise<RunnerEmitResult<TEvent>> {
|
||||
const ctx = this.createContext();
|
||||
let result: SessionBeforeEventResult | undefined;
|
||||
|
||||
for (const ext of this.extensions) {
|
||||
const handlers = ext.handlers.get(event.type);
|
||||
if (!handlers || handlers.length === 0) continue;
|
||||
|
||||
for (const handler of handlers) {
|
||||
try {
|
||||
const handlerResult = await handler(event, ctx);
|
||||
|
||||
if (this.isSessionBeforeEvent(event) && handlerResult) {
|
||||
result = handlerResult as SessionBeforeEventResult;
|
||||
if (result.cancel) {
|
||||
return result as RunnerEmitResult<TEvent>;
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : String(err);
|
||||
const stack = err instanceof Error ? err.stack : undefined;
|
||||
this.emitError({
|
||||
extensionPath: ext.path,
|
||||
event: event.type,
|
||||
error: message,
|
||||
stack,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result as RunnerEmitResult<TEvent>;
|
||||
}
|
||||
|
||||
async emitToolResult(
|
||||
event: ToolResultEvent,
|
||||
): Promise<ToolResultEventResult | undefined> {
|
||||
const ctx = this.createContext();
|
||||
const currentEvent: ToolResultEvent = { ...event };
|
||||
let modified = false;
|
||||
|
||||
for (const ext of this.extensions) {
|
||||
const handlers = ext.handlers.get("tool_result");
|
||||
if (!handlers || handlers.length === 0) continue;
|
||||
|
||||
for (const handler of handlers) {
|
||||
try {
|
||||
const handlerResult = (await handler(currentEvent, ctx)) as
|
||||
| ToolResultEventResult
|
||||
| undefined;
|
||||
if (!handlerResult) continue;
|
||||
|
||||
if (handlerResult.content !== undefined) {
|
||||
currentEvent.content = handlerResult.content;
|
||||
modified = true;
|
||||
}
|
||||
if (handlerResult.details !== undefined) {
|
||||
currentEvent.details = handlerResult.details;
|
||||
modified = true;
|
||||
}
|
||||
if (handlerResult.isError !== undefined) {
|
||||
currentEvent.isError = handlerResult.isError;
|
||||
modified = true;
|
||||
}
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : String(err);
|
||||
const stack = err instanceof Error ? err.stack : undefined;
|
||||
this.emitError({
|
||||
extensionPath: ext.path,
|
||||
event: "tool_result",
|
||||
error: message,
|
||||
stack,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!modified) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
return {
|
||||
content: currentEvent.content,
|
||||
details: currentEvent.details,
|
||||
isError: currentEvent.isError,
|
||||
};
|
||||
}
|
||||
|
||||
async emitToolCall(
|
||||
event: ToolCallEvent,
|
||||
): Promise<ToolCallEventResult | undefined> {
|
||||
const ctx = this.createContext();
|
||||
let result: ToolCallEventResult | undefined;
|
||||
|
||||
for (const ext of this.extensions) {
|
||||
const handlers = ext.handlers.get("tool_call");
|
||||
if (!handlers || handlers.length === 0) continue;
|
||||
|
||||
for (const handler of handlers) {
|
||||
const handlerResult = await handler(event, ctx);
|
||||
|
||||
if (handlerResult) {
|
||||
result = handlerResult as ToolCallEventResult;
|
||||
if (result.block) {
|
||||
return result;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
async emitUserBash(
|
||||
event: UserBashEvent,
|
||||
): Promise<UserBashEventResult | undefined> {
|
||||
const ctx = this.createContext();
|
||||
|
||||
for (const ext of this.extensions) {
|
||||
const handlers = ext.handlers.get("user_bash");
|
||||
if (!handlers || handlers.length === 0) continue;
|
||||
|
||||
for (const handler of handlers) {
|
||||
try {
|
||||
const handlerResult = await handler(event, ctx);
|
||||
if (handlerResult) {
|
||||
return handlerResult as UserBashEventResult;
|
||||
}
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : String(err);
|
||||
const stack = err instanceof Error ? err.stack : undefined;
|
||||
this.emitError({
|
||||
extensionPath: ext.path,
|
||||
event: "user_bash",
|
||||
error: message,
|
||||
stack,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return undefined;
|
||||
}
|
||||
|
||||
async emitContext(messages: AgentMessage[]): Promise<AgentMessage[]> {
|
||||
const ctx = this.createContext();
|
||||
let currentMessages = structuredClone(messages);
|
||||
|
||||
for (const ext of this.extensions) {
|
||||
const handlers = ext.handlers.get("context");
|
||||
if (!handlers || handlers.length === 0) continue;
|
||||
|
||||
for (const handler of handlers) {
|
||||
try {
|
||||
const event: ContextEvent = {
|
||||
type: "context",
|
||||
messages: currentMessages,
|
||||
};
|
||||
const handlerResult = await handler(event, ctx);
|
||||
|
||||
if (handlerResult && (handlerResult as ContextEventResult).messages) {
|
||||
currentMessages = (handlerResult as ContextEventResult).messages!;
|
||||
}
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : String(err);
|
||||
const stack = err instanceof Error ? err.stack : undefined;
|
||||
this.emitError({
|
||||
extensionPath: ext.path,
|
||||
event: "context",
|
||||
error: message,
|
||||
stack,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return currentMessages;
|
||||
}
|
||||
|
||||
async emitBeforeAgentStart(
|
||||
prompt: string,
|
||||
images: ImageContent[] | undefined,
|
||||
systemPrompt: string,
|
||||
): Promise<BeforeAgentStartCombinedResult | undefined> {
|
||||
const ctx = this.createContext();
|
||||
const messages: NonNullable<BeforeAgentStartEventResult["message"]>[] = [];
|
||||
let currentSystemPrompt = systemPrompt;
|
||||
let systemPromptModified = false;
|
||||
|
||||
for (const ext of this.extensions) {
|
||||
const handlers = ext.handlers.get("before_agent_start");
|
||||
if (!handlers || handlers.length === 0) continue;
|
||||
|
||||
for (const handler of handlers) {
|
||||
try {
|
||||
const event: BeforeAgentStartEvent = {
|
||||
type: "before_agent_start",
|
||||
prompt,
|
||||
images,
|
||||
systemPrompt: currentSystemPrompt,
|
||||
};
|
||||
const handlerResult = await handler(event, ctx);
|
||||
|
||||
if (handlerResult) {
|
||||
const result = handlerResult as BeforeAgentStartEventResult;
|
||||
if (result.message) {
|
||||
messages.push(result.message);
|
||||
}
|
||||
if (result.systemPrompt !== undefined) {
|
||||
currentSystemPrompt = result.systemPrompt;
|
||||
systemPromptModified = true;
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : String(err);
|
||||
const stack = err instanceof Error ? err.stack : undefined;
|
||||
this.emitError({
|
||||
extensionPath: ext.path,
|
||||
event: "before_agent_start",
|
||||
error: message,
|
||||
stack,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (messages.length > 0 || systemPromptModified) {
|
||||
return {
|
||||
messages: messages.length > 0 ? messages : undefined,
|
||||
systemPrompt: systemPromptModified ? currentSystemPrompt : undefined,
|
||||
};
|
||||
}
|
||||
|
||||
return undefined;
|
||||
}
|
||||
|
||||
async emitResourcesDiscover(
|
||||
cwd: string,
|
||||
reason: ResourcesDiscoverEvent["reason"],
|
||||
): Promise<{
|
||||
skillPaths: Array<{ path: string; extensionPath: string }>;
|
||||
promptPaths: Array<{ path: string; extensionPath: string }>;
|
||||
themePaths: Array<{ path: string; extensionPath: string }>;
|
||||
}> {
|
||||
const ctx = this.createContext();
|
||||
const skillPaths: Array<{ path: string; extensionPath: string }> = [];
|
||||
const promptPaths: Array<{ path: string; extensionPath: string }> = [];
|
||||
const themePaths: Array<{ path: string; extensionPath: string }> = [];
|
||||
|
||||
for (const ext of this.extensions) {
|
||||
const handlers = ext.handlers.get("resources_discover");
|
||||
if (!handlers || handlers.length === 0) continue;
|
||||
|
||||
for (const handler of handlers) {
|
||||
try {
|
||||
const event: ResourcesDiscoverEvent = {
|
||||
type: "resources_discover",
|
||||
cwd,
|
||||
reason,
|
||||
};
|
||||
const handlerResult = await handler(event, ctx);
|
||||
const result = handlerResult as ResourcesDiscoverResult | undefined;
|
||||
|
||||
if (result?.skillPaths?.length) {
|
||||
skillPaths.push(
|
||||
...result.skillPaths.map((path) => ({
|
||||
path,
|
||||
extensionPath: ext.path,
|
||||
})),
|
||||
);
|
||||
}
|
||||
if (result?.promptPaths?.length) {
|
||||
promptPaths.push(
|
||||
...result.promptPaths.map((path) => ({
|
||||
path,
|
||||
extensionPath: ext.path,
|
||||
})),
|
||||
);
|
||||
}
|
||||
if (result?.themePaths?.length) {
|
||||
themePaths.push(
|
||||
...result.themePaths.map((path) => ({
|
||||
path,
|
||||
extensionPath: ext.path,
|
||||
})),
|
||||
);
|
||||
}
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : String(err);
|
||||
const stack = err instanceof Error ? err.stack : undefined;
|
||||
this.emitError({
|
||||
extensionPath: ext.path,
|
||||
event: "resources_discover",
|
||||
error: message,
|
||||
stack,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return { skillPaths, promptPaths, themePaths };
|
||||
}
|
||||
|
||||
/** Emit input event. Transforms chain, "handled" short-circuits. */
|
||||
async emitInput(
|
||||
text: string,
|
||||
images: ImageContent[] | undefined,
|
||||
source: InputSource,
|
||||
): Promise<InputEventResult> {
|
||||
const ctx = this.createContext();
|
||||
let currentText = text;
|
||||
let currentImages = images;
|
||||
|
||||
for (const ext of this.extensions) {
|
||||
for (const handler of ext.handlers.get("input") ?? []) {
|
||||
try {
|
||||
const event: InputEvent = {
|
||||
type: "input",
|
||||
text: currentText,
|
||||
images: currentImages,
|
||||
source,
|
||||
};
|
||||
const result = (await handler(event, ctx)) as
|
||||
| InputEventResult
|
||||
| undefined;
|
||||
if (result?.action === "handled") return result;
|
||||
if (result?.action === "transform") {
|
||||
currentText = result.text;
|
||||
currentImages = result.images ?? currentImages;
|
||||
}
|
||||
} catch (err) {
|
||||
this.emitError({
|
||||
extensionPath: ext.path,
|
||||
event: "input",
|
||||
error: err instanceof Error ? err.message : String(err),
|
||||
stack: err instanceof Error ? err.stack : undefined,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
return currentText !== text || currentImages !== images
|
||||
? { action: "transform", text: currentText, images: currentImages }
|
||||
: { action: "continue" };
|
||||
}
|
||||
}
|
||||
1575
packages/coding-agent/src/core/extensions/types.ts
Normal file
1575
packages/coding-agent/src/core/extensions/types.ts
Normal file
File diff suppressed because it is too large
Load diff
147
packages/coding-agent/src/core/extensions/wrapper.ts
Normal file
147
packages/coding-agent/src/core/extensions/wrapper.ts
Normal file
|
|
@ -0,0 +1,147 @@
|
|||
/**
|
||||
* Tool wrappers for extensions.
|
||||
*/
|
||||
|
||||
import type {
|
||||
AgentTool,
|
||||
AgentToolUpdateCallback,
|
||||
} from "@mariozechner/pi-agent-core";
|
||||
import type { ExtensionRunner } from "./runner.js";
|
||||
import type { RegisteredTool, ToolCallEventResult } from "./types.js";
|
||||
|
||||
/**
|
||||
* Wrap a RegisteredTool into an AgentTool.
|
||||
* Uses the runner's createContext() for consistent context across tools and event handlers.
|
||||
*/
|
||||
export function wrapRegisteredTool(
|
||||
registeredTool: RegisteredTool,
|
||||
runner: ExtensionRunner,
|
||||
): AgentTool {
|
||||
const { definition } = registeredTool;
|
||||
return {
|
||||
name: definition.name,
|
||||
label: definition.label,
|
||||
description: definition.description,
|
||||
parameters: definition.parameters,
|
||||
execute: (toolCallId, params, signal, onUpdate) =>
|
||||
definition.execute(
|
||||
toolCallId,
|
||||
params,
|
||||
signal,
|
||||
onUpdate,
|
||||
runner.createContext(),
|
||||
),
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Wrap all registered tools into AgentTools.
|
||||
* Uses the runner's createContext() for consistent context across tools and event handlers.
|
||||
*/
|
||||
export function wrapRegisteredTools(
|
||||
registeredTools: RegisteredTool[],
|
||||
runner: ExtensionRunner,
|
||||
): AgentTool[] {
|
||||
return registeredTools.map((rt) => wrapRegisteredTool(rt, runner));
|
||||
}
|
||||
|
||||
/**
 * Wrap a tool with extension callbacks for interception.
 * - Emits tool_call event before execution (can block)
 * - Emits tool_result event after execution (can modify result)
 *
 * Blocking works by throwing: a blocked call surfaces to the agent loop as a
 * regular tool error carrying the extension's reason.
 */
export function wrapToolWithExtensions<T>(
	tool: AgentTool<any, T>,
	runner: ExtensionRunner,
): AgentTool<any, T> {
	return {
		// Preserve name/label/description/parameters; only execute is intercepted.
		...tool,
		execute: async (
			toolCallId: string,
			params: Record<string, unknown>,
			signal?: AbortSignal,
			onUpdate?: AgentToolUpdateCallback<T>,
		) => {
			// Emit tool_call event - extensions can block execution
			if (runner.hasHandlers("tool_call")) {
				try {
					const callResult = (await runner.emitToolCall({
						type: "tool_call",
						toolName: tool.name,
						toolCallId,
						input: params,
					})) as ToolCallEventResult | undefined;

					if (callResult?.block) {
						const reason =
							callResult.reason || "Tool execution was blocked by an extension";
						throw new Error(reason);
					}
				} catch (err) {
					// Error instances (including the block Error above) propagate as-is;
					// non-Error throws from emitToolCall are wrapped for context.
					if (err instanceof Error) {
						throw err;
					}
					throw new Error(
						`Extension failed, blocking execution: ${String(err)}`,
					);
				}
			}

			// Execute the actual tool
			try {
				const result = await tool.execute(toolCallId, params, signal, onUpdate);

				// Emit tool_result event - extensions can modify the result
				if (runner.hasHandlers("tool_result")) {
					const resultResult = await runner.emitToolResult({
						type: "tool_result",
						toolName: tool.name,
						toolCallId,
						input: params,
						content: result.content,
						details: result.details,
						isError: false,
					});

					// A returned value overrides content/details field-by-field;
					// undefined fields fall back to the tool's own result.
					if (resultResult) {
						return {
							content: resultResult.content ?? result.content,
							details: (resultResult.details ?? result.details) as T,
						};
					}
				}

				return result;
			} catch (err) {
				// Emit tool_result event for errors
				// NOTE(review): if emitToolResult itself throws here, that error would
				// mask the original tool error — confirm runner handlers can't throw.
				if (runner.hasHandlers("tool_result")) {
					await runner.emitToolResult({
						type: "tool_result",
						toolName: tool.name,
						toolCallId,
						input: params,
						content: [
							{
								type: "text",
								text: err instanceof Error ? err.message : String(err),
							},
						],
						details: undefined,
						isError: true,
					});
				}
				// Error results cannot be modified by extensions; rethrow unchanged.
				throw err;
			}
		},
	};
}
|
||||
|
||||
/**
|
||||
* Wrap all tools with extension callbacks.
|
||||
*/
|
||||
export function wrapToolsWithExtensions<T>(
|
||||
tools: AgentTool<any, T>[],
|
||||
runner: ExtensionRunner,
|
||||
): AgentTool<any, T>[] {
|
||||
return tools.map((tool) => wrapToolWithExtensions(tool, runner));
|
||||
}
|
||||
149
packages/coding-agent/src/core/footer-data-provider.ts
Normal file
149
packages/coding-agent/src/core/footer-data-provider.ts
Normal file
|
|
@ -0,0 +1,149 @@
|
|||
import { existsSync, type FSWatcher, readFileSync, statSync, watch } from "fs";
|
||||
import { dirname, join, resolve } from "path";
|
||||
|
||||
/**
|
||||
* Find the git HEAD path by walking up from cwd.
|
||||
* Handles both regular git repos (.git is a directory) and worktrees (.git is a file).
|
||||
*/
|
||||
function findGitHeadPath(): string | null {
|
||||
let dir = process.cwd();
|
||||
while (true) {
|
||||
const gitPath = join(dir, ".git");
|
||||
if (existsSync(gitPath)) {
|
||||
try {
|
||||
const stat = statSync(gitPath);
|
||||
if (stat.isFile()) {
|
||||
const content = readFileSync(gitPath, "utf8").trim();
|
||||
if (content.startsWith("gitdir: ")) {
|
||||
const gitDir = content.slice(8);
|
||||
const headPath = resolve(dir, gitDir, "HEAD");
|
||||
if (existsSync(headPath)) return headPath;
|
||||
}
|
||||
} else if (stat.isDirectory()) {
|
||||
const headPath = join(gitPath, "HEAD");
|
||||
if (existsSync(headPath)) return headPath;
|
||||
}
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
const parent = dirname(dir);
|
||||
if (parent === dir) return null;
|
||||
dir = parent;
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Provides git branch and extension statuses - data not otherwise accessible to extensions.
 * Token stats, model info available via ctx.sessionManager and ctx.model.
 */
export class FooterDataProvider {
	// Status text per extension key, set via setExtensionStatus().
	private extensionStatuses = new Map<string, string>();
	// Tri-state: undefined = not yet computed, null = not in a git repo,
	// otherwise the branch name (or "detached").
	private cachedBranch: string | null | undefined = undefined;
	private gitWatcher: FSWatcher | null = null;
	private branchChangeCallbacks = new Set<() => void>();
	private availableProviderCount = 0;

	constructor() {
		this.setupGitWatcher();
	}

	/** Current git branch, null if not in repo, "detached" if detached HEAD */
	getGitBranch(): string | null {
		if (this.cachedBranch !== undefined) return this.cachedBranch;

		try {
			const gitHeadPath = findGitHeadPath();
			if (!gitHeadPath) {
				this.cachedBranch = null;
				return null;
			}
			const content = readFileSync(gitHeadPath, "utf8").trim();
			// 16 === "ref: refs/heads/".length — a symbolic ref names a branch;
			// anything else (a raw commit hash) means detached HEAD.
			this.cachedBranch = content.startsWith("ref: refs/heads/")
				? content.slice(16)
				: "detached";
		} catch {
			this.cachedBranch = null;
		}
		return this.cachedBranch;
	}

	/** Extension status texts set via ctx.ui.setStatus() */
	getExtensionStatuses(): ReadonlyMap<string, string> {
		return this.extensionStatuses;
	}

	/** Subscribe to git branch changes. Returns unsubscribe function. */
	onBranchChange(callback: () => void): () => void {
		this.branchChangeCallbacks.add(callback);
		return () => this.branchChangeCallbacks.delete(callback);
	}

	/** Internal: set extension status. Passing undefined removes the entry. */
	setExtensionStatus(key: string, text: string | undefined): void {
		if (text === undefined) {
			this.extensionStatuses.delete(key);
		} else {
			this.extensionStatuses.set(key, text);
		}
	}

	/** Internal: clear extension statuses */
	clearExtensionStatuses(): void {
		this.extensionStatuses.clear();
	}

	/** Number of unique providers with available models (for footer display) */
	getAvailableProviderCount(): number {
		return this.availableProviderCount;
	}

	/** Internal: update available provider count */
	setAvailableProviderCount(count: number): void {
		this.availableProviderCount = count;
	}

	/** Internal: cleanup — stops the fs watcher and drops all subscribers. */
	dispose(): void {
		if (this.gitWatcher) {
			this.gitWatcher.close();
			this.gitWatcher = null;
		}
		this.branchChangeCallbacks.clear();
	}

	private setupGitWatcher(): void {
		// Replace any existing watcher before creating a new one.
		if (this.gitWatcher) {
			this.gitWatcher.close();
			this.gitWatcher = null;
		}

		const gitHeadPath = findGitHeadPath();
		if (!gitHeadPath) return;

		// Watch the directory containing HEAD, not HEAD itself.
		// Git uses atomic writes (write temp, rename over HEAD), which changes the inode.
		// fs.watch on a file stops working after the inode changes.
		const gitDir = dirname(gitHeadPath);

		try {
			this.gitWatcher = watch(gitDir, (_eventType, filename) => {
				if (filename === "HEAD") {
					// Invalidate the cache; recomputed lazily on next getGitBranch().
					this.cachedBranch = undefined;
					for (const cb of this.branchChangeCallbacks) cb();
				}
			});
		} catch {
			// Silently fail if we can't watch
		}
	}
}
|
||||
|
||||
/**
 * Read-only view for extensions - excludes setExtensionStatus, setAvailableProviderCount and dispose.
 * Keeps the mutating/lifecycle surface internal while exposing all getters and
 * branch-change subscription.
 */
export type ReadonlyFooterDataProvider = Pick<
	FooterDataProvider,
	| "getGitBranch"
	| "getExtensionStatuses"
	| "getAvailableProviderCount"
	| "onBranchChange"
>;
|
||||
1290
packages/coding-agent/src/core/gateway-runtime.ts
Normal file
1290
packages/coding-agent/src/core/gateway-runtime.ts
Normal file
File diff suppressed because it is too large
Load diff
70
packages/coding-agent/src/core/index.ts
Normal file
70
packages/coding-agent/src/core/index.ts
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
/**
|
||||
* Core modules shared between all run modes.
|
||||
*/
|
||||
|
||||
export {
|
||||
AgentSession,
|
||||
type AgentSessionConfig,
|
||||
type AgentSessionEvent,
|
||||
type AgentSessionEventListener,
|
||||
type ModelCycleResult,
|
||||
type PromptOptions,
|
||||
type SessionStats,
|
||||
} from "./agent-session.js";
|
||||
export {
|
||||
type BashExecutorOptions,
|
||||
type BashResult,
|
||||
executeBash,
|
||||
executeBashWithOperations,
|
||||
} from "./bash-executor.js";
|
||||
export type { CompactionResult } from "./compaction/index.js";
|
||||
export {
|
||||
createEventBus,
|
||||
type EventBus,
|
||||
type EventBusController,
|
||||
} from "./event-bus.js";
|
||||
|
||||
// Extensions system
|
||||
export {
|
||||
type AgentEndEvent,
|
||||
type AgentStartEvent,
|
||||
type AgentToolResult,
|
||||
type AgentToolUpdateCallback,
|
||||
type BeforeAgentStartEvent,
|
||||
type ContextEvent,
|
||||
discoverAndLoadExtensions,
|
||||
type ExecOptions,
|
||||
type ExecResult,
|
||||
type Extension,
|
||||
type ExtensionAPI,
|
||||
type ExtensionCommandContext,
|
||||
type ExtensionContext,
|
||||
type ExtensionError,
|
||||
type ExtensionEvent,
|
||||
type ExtensionFactory,
|
||||
type ExtensionFlag,
|
||||
type ExtensionHandler,
|
||||
ExtensionRunner,
|
||||
type ExtensionShortcut,
|
||||
type ExtensionUIContext,
|
||||
type LoadExtensionsResult,
|
||||
type MessageRenderer,
|
||||
type RegisteredCommand,
|
||||
type SessionBeforeCompactEvent,
|
||||
type SessionBeforeForkEvent,
|
||||
type SessionBeforeSwitchEvent,
|
||||
type SessionBeforeTreeEvent,
|
||||
type SessionCompactEvent,
|
||||
type SessionForkEvent,
|
||||
type SessionShutdownEvent,
|
||||
type SessionStartEvent,
|
||||
type SessionSwitchEvent,
|
||||
type SessionTreeEvent,
|
||||
type ToolCallEvent,
|
||||
type ToolDefinition,
|
||||
type ToolRenderResultOptions,
|
||||
type ToolResultEvent,
|
||||
type TurnEndEvent,
|
||||
type TurnStartEvent,
|
||||
wrapToolsWithExtensions,
|
||||
} from "./extensions/index.js";
|
||||
211
packages/coding-agent/src/core/keybindings.ts
Normal file
211
packages/coding-agent/src/core/keybindings.ts
Normal file
|
|
@ -0,0 +1,211 @@
|
|||
import {
|
||||
DEFAULT_EDITOR_KEYBINDINGS,
|
||||
type EditorAction,
|
||||
type EditorKeybindingsConfig,
|
||||
EditorKeybindingsManager,
|
||||
type KeyId,
|
||||
matchesKey,
|
||||
setEditorKeybindings,
|
||||
} from "@mariozechner/pi-tui";
|
||||
import { existsSync, readFileSync } from "fs";
|
||||
import { join } from "path";
|
||||
import { getAgentDir } from "../config.js";
|
||||
|
||||
/**
 * Application-level actions (coding agent specific).
 * These are handled by the app itself, as opposed to EditorAction which is
 * handled by the text editor component.
 */
export type AppAction =
	| "interrupt"
	| "clear"
	| "exit"
	| "suspend"
	| "cycleThinkingLevel"
	| "cycleModelForward"
	| "cycleModelBackward"
	| "selectModel"
	| "expandTools"
	| "toggleThinking"
	| "toggleSessionNamedFilter"
	| "externalEditor"
	| "followUp"
	| "dequeue"
	| "pasteImage"
	| "newSession"
	| "tree"
	| "fork"
	| "resume";

/**
 * All configurable actions.
 */
export type KeyAction = AppAction | EditorAction;

/**
 * Full keybindings configuration (app + editor actions).
 * Each action maps to a single key or a list of alternative keys.
 */
export type KeybindingsConfig = {
	[K in KeyAction]?: KeyId | KeyId[];
};
|
||||
|
||||
/**
 * Default application keybindings.
 * A value of [] means the action is unbound by default.
 */
export const DEFAULT_APP_KEYBINDINGS: Record<AppAction, KeyId | KeyId[]> = {
	interrupt: "escape",
	clear: "ctrl+c",
	exit: "ctrl+d",
	suspend: "ctrl+z",
	cycleThinkingLevel: "shift+tab",
	cycleModelForward: "ctrl+p",
	cycleModelBackward: "shift+ctrl+p",
	selectModel: "ctrl+l",
	expandTools: "ctrl+o",
	toggleThinking: "ctrl+t",
	toggleSessionNamedFilter: "ctrl+n",
	externalEditor: "ctrl+g",
	followUp: "alt+enter",
	dequeue: "alt+up",
	// alt+v on Windows, ctrl+v elsewhere.
	pasteImage: process.platform === "win32" ? "alt+v" : "ctrl+v",
	// Unbound by default:
	newSession: [],
	tree: [],
	fork: [],
	resume: [],
};
|
||||
|
||||
/**
 * All default keybindings (app + editor).
 * App bindings are spread second, so on a colliding action name they win.
 */
export const DEFAULT_KEYBINDINGS: Required<KeybindingsConfig> = {
	...DEFAULT_EDITOR_KEYBINDINGS,
	...DEFAULT_APP_KEYBINDINGS,
};
|
||||
|
||||
// App actions list for type checking
|
||||
const APP_ACTIONS: AppAction[] = [
|
||||
"interrupt",
|
||||
"clear",
|
||||
"exit",
|
||||
"suspend",
|
||||
"cycleThinkingLevel",
|
||||
"cycleModelForward",
|
||||
"cycleModelBackward",
|
||||
"selectModel",
|
||||
"expandTools",
|
||||
"toggleThinking",
|
||||
"toggleSessionNamedFilter",
|
||||
"externalEditor",
|
||||
"followUp",
|
||||
"dequeue",
|
||||
"pasteImage",
|
||||
"newSession",
|
||||
"tree",
|
||||
"fork",
|
||||
"resume",
|
||||
];
|
||||
|
||||
function isAppAction(action: string): action is AppAction {
|
||||
return APP_ACTIONS.includes(action as AppAction);
|
||||
}
|
||||
|
||||
/**
 * Manages all keybindings (app + editor).
 *
 * App-action bindings are resolved through this class (see matches/getKeys);
 * editor-action bindings are installed globally via setEditorKeybindings()
 * in create().
 */
export class KeybindingsManager {
	// Raw user config as loaded (may contain both app and editor actions).
	private config: KeybindingsConfig;
	// Effective app-action bindings: defaults overridden by user config.
	private appActionToKeys: Map<AppAction, KeyId[]>;

	private constructor(config: KeybindingsConfig) {
		this.config = config;
		this.appActionToKeys = new Map();
		this.buildMaps();
	}

	/**
	 * Create from config file (<agentDir>/keybindings.json) and set up
	 * editor keybindings globally as a side effect.
	 */
	static create(agentDir: string = getAgentDir()): KeybindingsManager {
		const configPath = join(agentDir, "keybindings.json");
		const config = KeybindingsManager.loadFromFile(configPath);
		const manager = new KeybindingsManager(config);

		// Set up editor keybindings globally
		// Include both editor actions and expandTools (shared between app and editor)
		const editorConfig: EditorKeybindingsConfig = {};
		for (const [action, keys] of Object.entries(config)) {
			if (!isAppAction(action) || action === "expandTools") {
				editorConfig[action as EditorAction] = keys;
			}
		}
		setEditorKeybindings(new EditorKeybindingsManager(editorConfig));

		return manager;
	}

	/**
	 * Create in-memory (no file access, no global editor setup).
	 */
	static inMemory(config: KeybindingsConfig = {}): KeybindingsManager {
		return new KeybindingsManager(config);
	}

	// Missing or unparseable config files yield an empty config (all defaults).
	// NOTE(review): the parsed JSON is not validated against KeybindingsConfig.
	private static loadFromFile(path: string): KeybindingsConfig {
		if (!existsSync(path)) return {};
		try {
			return JSON.parse(readFileSync(path, "utf-8"));
		} catch {
			return {};
		}
	}

	// Rebuild appActionToKeys from defaults + user overrides.
	private buildMaps(): void {
		this.appActionToKeys.clear();

		// Set defaults for app actions
		for (const [action, keys] of Object.entries(DEFAULT_APP_KEYBINDINGS)) {
			const keyArray = Array.isArray(keys) ? keys : [keys];
			// Copy so later mutation of the map entry can't touch the defaults.
			this.appActionToKeys.set(action as AppAction, [...keyArray]);
		}

		// Override with user config (app actions only)
		for (const [action, keys] of Object.entries(this.config)) {
			if (keys === undefined || !isAppAction(action)) continue;
			const keyArray = Array.isArray(keys) ? keys : [keys];
			this.appActionToKeys.set(action, keyArray);
		}
	}

	/**
	 * Check if input matches an app action.
	 *
	 * @param data Raw terminal input sequence.
	 * @param action App action to test against.
	 */
	matches(data: string, action: AppAction): boolean {
		const keys = this.appActionToKeys.get(action);
		if (!keys) return false;
		for (const key of keys) {
			if (matchesKey(data, key)) return true;
		}
		return false;
	}

	/**
	 * Get keys bound to an app action (empty array when unbound).
	 */
	getKeys(action: AppAction): KeyId[] {
		return this.appActionToKeys.get(action) ?? [];
	}

	/**
	 * Get the full effective config: all defaults (app + editor) with user
	 * overrides applied on top.
	 */
	getEffectiveConfig(): Required<KeybindingsConfig> {
		const result = { ...DEFAULT_KEYBINDINGS };
		for (const [action, keys] of Object.entries(this.config)) {
			if (keys !== undefined) {
				(result as KeybindingsConfig)[action as KeyAction] = keys;
			}
		}
		return result;
	}
}
|
||||
|
||||
// Re-export for convenience
|
||||
export type { EditorAction, KeyId };
|
||||
217
packages/coding-agent/src/core/messages.ts
Normal file
217
packages/coding-agent/src/core/messages.ts
Normal file
|
|
@ -0,0 +1,217 @@
|
|||
/**
|
||||
* Custom message types and transformers for the coding agent.
|
||||
*
|
||||
* Extends the base AgentMessage type with coding-agent specific message types,
|
||||
* and provides a transformer to convert them to LLM-compatible messages.
|
||||
*/
|
||||
|
||||
import type { AgentMessage } from "@mariozechner/pi-agent-core";
|
||||
import type { ImageContent, Message, TextContent } from "@mariozechner/pi-ai";
|
||||
|
||||
// Text wrapped around a compaction summary when it is rendered into LLM
// context (see convertToLlm). The prefix ends with a newline inside <summary>.
export const COMPACTION_SUMMARY_PREFIX = `The conversation history before this point was compacted into the following summary:

<summary>
`;

// Leading newline pairs with the prefix's trailing one.
export const COMPACTION_SUMMARY_SUFFIX = `
</summary>`;

// Same wrapping for branch summaries; note the suffix has no leading newline.
export const BRANCH_SUMMARY_PREFIX = `The following is a summary of a branch that this conversation came back from:

<summary>
`;

export const BRANCH_SUMMARY_SUFFIX = `</summary>`;
|
||||
|
||||
/**
 * Message type for bash executions via the ! command.
 */
export interface BashExecutionMessage {
	role: "bashExecution";
	/** The shell command that was run. */
	command: string;
	/** Captured command output (possibly truncated). */
	output: string;
	/** Process exit code; undefined when none was recorded. */
	exitCode: number | undefined;
	/** True when the execution was cancelled. */
	cancelled: boolean;
	/** True when output was truncated. */
	truncated: boolean;
	/** Path to the full output file, when truncated. */
	fullOutputPath?: string;
	/** Epoch milliseconds. */
	timestamp: number;
	/** If true, this message is excluded from LLM context (!! prefix) */
	excludeFromContext?: boolean;
}

/**
 * Message type for extension-injected messages via sendMessage().
 * These are custom messages that extensions can inject into the conversation.
 */
export interface CustomMessage<T = unknown> {
	role: "custom";
	/** Extension-chosen discriminator for the message kind. */
	customType: string;
	content: string | (TextContent | ImageContent)[];
	/** Whether the message is shown in the UI. */
	display: boolean;
	/** Arbitrary extension payload. */
	details?: T;
	/** Epoch milliseconds. */
	timestamp: number;
}

/** Summary injected when the conversation returns from a branch. */
export interface BranchSummaryMessage {
	role: "branchSummary";
	summary: string;
	/** Id of the entry the branch came back from. */
	fromId: string;
	/** Epoch milliseconds. */
	timestamp: number;
}

/** Summary injected when earlier history is compacted. */
export interface CompactionSummaryMessage {
	role: "compactionSummary";
	summary: string;
	/** Token count of the history that was compacted away. */
	tokensBefore: number;
	/** Epoch milliseconds. */
	timestamp: number;
}

// Extend CustomAgentMessages via declaration merging so AgentMessage includes
// the coding-agent specific roles declared above.
declare module "@mariozechner/pi-agent-core" {
	interface CustomAgentMessages {
		bashExecution: BashExecutionMessage;
		custom: CustomMessage;
		branchSummary: BranchSummaryMessage;
		compactionSummary: CompactionSummaryMessage;
	}
}
|
||||
|
||||
/**
|
||||
* Convert a BashExecutionMessage to user message text for LLM context.
|
||||
*/
|
||||
export function bashExecutionToText(msg: BashExecutionMessage): string {
|
||||
let text = `Ran \`${msg.command}\`\n`;
|
||||
if (msg.output) {
|
||||
text += `\`\`\`\n${msg.output}\n\`\`\``;
|
||||
} else {
|
||||
text += "(no output)";
|
||||
}
|
||||
if (msg.cancelled) {
|
||||
text += "\n\n(command cancelled)";
|
||||
} else if (
|
||||
msg.exitCode !== null &&
|
||||
msg.exitCode !== undefined &&
|
||||
msg.exitCode !== 0
|
||||
) {
|
||||
text += `\n\nCommand exited with code ${msg.exitCode}`;
|
||||
}
|
||||
if (msg.truncated && msg.fullOutputPath) {
|
||||
text += `\n\n[Output truncated. Full output: ${msg.fullOutputPath}]`;
|
||||
}
|
||||
return text;
|
||||
}
|
||||
|
||||
export function createBranchSummaryMessage(
|
||||
summary: string,
|
||||
fromId: string,
|
||||
timestamp: string,
|
||||
): BranchSummaryMessage {
|
||||
return {
|
||||
role: "branchSummary",
|
||||
summary,
|
||||
fromId,
|
||||
timestamp: new Date(timestamp).getTime(),
|
||||
};
|
||||
}
|
||||
|
||||
export function createCompactionSummaryMessage(
|
||||
summary: string,
|
||||
tokensBefore: number,
|
||||
timestamp: string,
|
||||
): CompactionSummaryMessage {
|
||||
return {
|
||||
role: "compactionSummary",
|
||||
summary: summary,
|
||||
tokensBefore,
|
||||
timestamp: new Date(timestamp).getTime(),
|
||||
};
|
||||
}
|
||||
|
||||
/** Convert CustomMessageEntry to AgentMessage format */
|
||||
export function createCustomMessage(
|
||||
customType: string,
|
||||
content: string | (TextContent | ImageContent)[],
|
||||
display: boolean,
|
||||
details: unknown | undefined,
|
||||
timestamp: string,
|
||||
): CustomMessage {
|
||||
return {
|
||||
role: "custom",
|
||||
customType,
|
||||
content,
|
||||
display,
|
||||
details,
|
||||
timestamp: new Date(timestamp).getTime(),
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Transform AgentMessages (including custom types) to LLM-compatible Messages.
|
||||
*
|
||||
* This is used by:
|
||||
* - Agent's transormToLlm option (for prompt calls and queued messages)
|
||||
* - Compaction's generateSummary (for summarization)
|
||||
* - Custom extensions and tools
|
||||
*/
|
||||
export function convertToLlm(messages: AgentMessage[]): Message[] {
|
||||
return messages
|
||||
.map((m): Message | undefined => {
|
||||
switch (m.role) {
|
||||
case "bashExecution":
|
||||
// Skip messages excluded from context (!! prefix)
|
||||
if (m.excludeFromContext) {
|
||||
return undefined;
|
||||
}
|
||||
return {
|
||||
role: "user",
|
||||
content: [{ type: "text", text: bashExecutionToText(m) }],
|
||||
timestamp: m.timestamp,
|
||||
};
|
||||
case "custom": {
|
||||
const content =
|
||||
typeof m.content === "string"
|
||||
? [{ type: "text" as const, text: m.content }]
|
||||
: m.content;
|
||||
return {
|
||||
role: "user",
|
||||
content,
|
||||
timestamp: m.timestamp,
|
||||
};
|
||||
}
|
||||
case "branchSummary":
|
||||
return {
|
||||
role: "user",
|
||||
content: [
|
||||
{
|
||||
type: "text" as const,
|
||||
text: BRANCH_SUMMARY_PREFIX + m.summary + BRANCH_SUMMARY_SUFFIX,
|
||||
},
|
||||
],
|
||||
timestamp: m.timestamp,
|
||||
};
|
||||
case "compactionSummary":
|
||||
return {
|
||||
role: "user",
|
||||
content: [
|
||||
{
|
||||
type: "text" as const,
|
||||
text:
|
||||
COMPACTION_SUMMARY_PREFIX +
|
||||
m.summary +
|
||||
COMPACTION_SUMMARY_SUFFIX,
|
||||
},
|
||||
],
|
||||
timestamp: m.timestamp,
|
||||
};
|
||||
case "user":
|
||||
case "assistant":
|
||||
case "toolResult":
|
||||
return m;
|
||||
default:
|
||||
// biome-ignore lint/correctness/noSwitchDeclarations: fine
|
||||
const _exhaustiveCheck: never = m;
|
||||
return undefined;
|
||||
}
|
||||
})
|
||||
.filter((m) => m !== undefined);
|
||||
}
|
||||
822
packages/coding-agent/src/core/model-registry.ts
Normal file
822
packages/coding-agent/src/core/model-registry.ts
Normal file
|
|
@ -0,0 +1,822 @@
|
|||
/**
|
||||
* Model registry - manages built-in and custom models, provides API key resolution.
|
||||
*/
|
||||
|
||||
import {
|
||||
type Api,
|
||||
type AssistantMessageEventStream,
|
||||
type Context,
|
||||
getModels,
|
||||
getProviders,
|
||||
type KnownProvider,
|
||||
type Model,
|
||||
type OAuthProviderInterface,
|
||||
type OpenAICompletionsCompat,
|
||||
type OpenAIResponsesCompat,
|
||||
registerApiProvider,
|
||||
resetApiProviders,
|
||||
type SimpleStreamOptions,
|
||||
} from "@mariozechner/pi-ai";
|
||||
import {
|
||||
registerOAuthProvider,
|
||||
resetOAuthProviders,
|
||||
} from "@mariozechner/pi-ai/oauth";
|
||||
import { type Static, Type } from "@sinclair/typebox";
|
||||
import AjvModule from "ajv";
|
||||
import { existsSync, readFileSync } from "fs";
|
||||
import { join } from "path";
|
||||
import { getAgentDir } from "../config.js";
|
||||
import type { AuthStorage } from "./auth-storage.js";
|
||||
import {
|
||||
clearConfigValueCache,
|
||||
resolveConfigValue,
|
||||
resolveHeaders,
|
||||
} from "./resolve-config-value.js";
|
||||
|
||||
const Ajv = (AjvModule as any).default || AjvModule;
|
||||
const ajv = new Ajv();
|
||||
|
||||
// Schema for OpenRouter routing preferences
|
||||
const OpenRouterRoutingSchema = Type.Object({
|
||||
only: Type.Optional(Type.Array(Type.String())),
|
||||
order: Type.Optional(Type.Array(Type.String())),
|
||||
});
|
||||
|
||||
// Schema for Vercel AI Gateway routing preferences
|
||||
const VercelGatewayRoutingSchema = Type.Object({
|
||||
only: Type.Optional(Type.Array(Type.String())),
|
||||
order: Type.Optional(Type.Array(Type.String())),
|
||||
});
|
||||
|
||||
// Schema for OpenAI compatibility settings
|
||||
const OpenAICompletionsCompatSchema = Type.Object({
|
||||
supportsStore: Type.Optional(Type.Boolean()),
|
||||
supportsDeveloperRole: Type.Optional(Type.Boolean()),
|
||||
supportsReasoningEffort: Type.Optional(Type.Boolean()),
|
||||
supportsUsageInStreaming: Type.Optional(Type.Boolean()),
|
||||
maxTokensField: Type.Optional(
|
||||
Type.Union([
|
||||
Type.Literal("max_completion_tokens"),
|
||||
Type.Literal("max_tokens"),
|
||||
]),
|
||||
),
|
||||
requiresToolResultName: Type.Optional(Type.Boolean()),
|
||||
requiresAssistantAfterToolResult: Type.Optional(Type.Boolean()),
|
||||
requiresThinkingAsText: Type.Optional(Type.Boolean()),
|
||||
requiresMistralToolIds: Type.Optional(Type.Boolean()),
|
||||
thinkingFormat: Type.Optional(
|
||||
Type.Union([
|
||||
Type.Literal("openai"),
|
||||
Type.Literal("zai"),
|
||||
Type.Literal("qwen"),
|
||||
]),
|
||||
),
|
||||
openRouterRouting: Type.Optional(OpenRouterRoutingSchema),
|
||||
vercelGatewayRouting: Type.Optional(VercelGatewayRoutingSchema),
|
||||
});
|
||||
|
||||
const OpenAIResponsesCompatSchema = Type.Object({
|
||||
// Reserved for future use
|
||||
});
|
||||
|
||||
const OpenAICompatSchema = Type.Union([
|
||||
OpenAICompletionsCompatSchema,
|
||||
OpenAIResponsesCompatSchema,
|
||||
]);
|
||||
|
||||
// Schema for custom model definition
|
||||
// Most fields are optional with sensible defaults for local models (Ollama, LM Studio, etc.)
|
||||
const ModelDefinitionSchema = Type.Object({
|
||||
id: Type.String({ minLength: 1 }),
|
||||
name: Type.Optional(Type.String({ minLength: 1 })),
|
||||
api: Type.Optional(Type.String({ minLength: 1 })),
|
||||
baseUrl: Type.Optional(Type.String({ minLength: 1 })),
|
||||
reasoning: Type.Optional(Type.Boolean()),
|
||||
input: Type.Optional(
|
||||
Type.Array(Type.Union([Type.Literal("text"), Type.Literal("image")])),
|
||||
),
|
||||
cost: Type.Optional(
|
||||
Type.Object({
|
||||
input: Type.Number(),
|
||||
output: Type.Number(),
|
||||
cacheRead: Type.Number(),
|
||||
cacheWrite: Type.Number(),
|
||||
}),
|
||||
),
|
||||
contextWindow: Type.Optional(Type.Number()),
|
||||
maxTokens: Type.Optional(Type.Number()),
|
||||
headers: Type.Optional(Type.Record(Type.String(), Type.String())),
|
||||
compat: Type.Optional(OpenAICompatSchema),
|
||||
});
|
||||
|
||||
// Schema for per-model overrides (all fields optional, merged with built-in model)
|
||||
const ModelOverrideSchema = Type.Object({
|
||||
name: Type.Optional(Type.String({ minLength: 1 })),
|
||||
reasoning: Type.Optional(Type.Boolean()),
|
||||
input: Type.Optional(
|
||||
Type.Array(Type.Union([Type.Literal("text"), Type.Literal("image")])),
|
||||
),
|
||||
cost: Type.Optional(
|
||||
Type.Object({
|
||||
input: Type.Optional(Type.Number()),
|
||||
output: Type.Optional(Type.Number()),
|
||||
cacheRead: Type.Optional(Type.Number()),
|
||||
cacheWrite: Type.Optional(Type.Number()),
|
||||
}),
|
||||
),
|
||||
contextWindow: Type.Optional(Type.Number()),
|
||||
maxTokens: Type.Optional(Type.Number()),
|
||||
headers: Type.Optional(Type.Record(Type.String(), Type.String())),
|
||||
compat: Type.Optional(OpenAICompatSchema),
|
||||
});
|
||||
|
||||
type ModelOverride = Static<typeof ModelOverrideSchema>;
|
||||
|
||||
const ProviderConfigSchema = Type.Object({
|
||||
baseUrl: Type.Optional(Type.String({ minLength: 1 })),
|
||||
apiKey: Type.Optional(Type.String({ minLength: 1 })),
|
||||
api: Type.Optional(Type.String({ minLength: 1 })),
|
||||
headers: Type.Optional(Type.Record(Type.String(), Type.String())),
|
||||
authHeader: Type.Optional(Type.Boolean()),
|
||||
models: Type.Optional(Type.Array(ModelDefinitionSchema)),
|
||||
modelOverrides: Type.Optional(
|
||||
Type.Record(Type.String(), ModelOverrideSchema),
|
||||
),
|
||||
});
|
||||
|
||||
const ModelsConfigSchema = Type.Object({
|
||||
providers: Type.Record(Type.String(), ProviderConfigSchema),
|
||||
});
|
||||
|
||||
ajv.addSchema(ModelsConfigSchema, "ModelsConfig");
|
||||
|
||||
type ModelsConfig = Static<typeof ModelsConfigSchema>;
|
||||
|
||||
/** Provider override config (baseUrl, headers, apiKey) without custom models */
interface ProviderOverride {
	baseUrl?: string;
	headers?: Record<string, string>;
	apiKey?: string;
}

/** Result of loading custom models from models.json */
interface CustomModelsResult {
	/** Fully-defined custom models declared in models.json. */
	models: Model<Api>[];
	/** Providers with baseUrl/headers/apiKey overrides for built-in models */
	overrides: Map<string, ProviderOverride>;
	/** Per-model overrides: provider -> modelId -> override */
	modelOverrides: Map<string, Map<string, ModelOverride>>;
	/** Load/validation error message, or undefined on success. */
	error: string | undefined;
}
|
||||
|
||||
function emptyCustomModelsResult(error?: string): CustomModelsResult {
|
||||
return { models: [], overrides: new Map(), modelOverrides: new Map(), error };
|
||||
}
|
||||
|
||||
/**
 * Merge a model-override compat object into a base compat object.
 *
 * Top-level fields are shallow-merged (override wins); the two nested routing
 * objects (openRouterRouting, vercelGatewayRouting) are additionally merged
 * one level deep so an override can change a single routing field without
 * clobbering the rest.
 *
 * @returns The base unchanged when there is no override; otherwise the merged
 *   compat object.
 */
function mergeCompat(
	baseCompat: Model<Api>["compat"],
	overrideCompat: ModelOverride["compat"],
): Model<Api>["compat"] | undefined {
	if (!overrideCompat) return baseCompat;

	// Widen both sides to the union of compat shapes for the shallow merge.
	const base = baseCompat as
		| OpenAICompletionsCompat
		| OpenAIResponsesCompat
		| undefined;
	const override = overrideCompat as
		| OpenAICompletionsCompat
		| OpenAIResponsesCompat;
	const merged = { ...base, ...override } as
		| OpenAICompletionsCompat
		| OpenAIResponsesCompat;

	// View all three through the completions shape to reach the routing fields.
	// NOTE(review): these casts assume routing fields only matter on the
	// completions variant; absent fields read as undefined on the other shape.
	const baseCompletions = base as OpenAICompletionsCompat | undefined;
	const overrideCompletions = override as OpenAICompletionsCompat;
	const mergedCompletions = merged as OpenAICompletionsCompat;

	// Deep-merge openRouterRouting when either side has it (override wins per field).
	if (
		baseCompletions?.openRouterRouting ||
		overrideCompletions.openRouterRouting
	) {
		mergedCompletions.openRouterRouting = {
			...baseCompletions?.openRouterRouting,
			...overrideCompletions.openRouterRouting,
		};
	}

	// Same one-level merge for vercelGatewayRouting.
	if (
		baseCompletions?.vercelGatewayRouting ||
		overrideCompletions.vercelGatewayRouting
	) {
		mergedCompletions.vercelGatewayRouting = {
			...baseCompletions?.vercelGatewayRouting,
			...overrideCompletions.vercelGatewayRouting,
		};
	}

	return merged as Model<Api>["compat"];
}
|
||||
|
||||
/**
|
||||
* Deep merge a model override into a model.
|
||||
* Handles nested objects (cost, compat) by merging rather than replacing.
|
||||
*/
|
||||
function applyModelOverride(
|
||||
model: Model<Api>,
|
||||
override: ModelOverride,
|
||||
): Model<Api> {
|
||||
const result = { ...model };
|
||||
|
||||
// Simple field overrides
|
||||
if (override.name !== undefined) result.name = override.name;
|
||||
if (override.reasoning !== undefined) result.reasoning = override.reasoning;
|
||||
if (override.input !== undefined)
|
||||
result.input = override.input as ("text" | "image")[];
|
||||
if (override.contextWindow !== undefined)
|
||||
result.contextWindow = override.contextWindow;
|
||||
if (override.maxTokens !== undefined) result.maxTokens = override.maxTokens;
|
||||
|
||||
// Merge cost (partial override)
|
||||
if (override.cost) {
|
||||
result.cost = {
|
||||
input: override.cost.input ?? model.cost.input,
|
||||
output: override.cost.output ?? model.cost.output,
|
||||
cacheRead: override.cost.cacheRead ?? model.cost.cacheRead,
|
||||
cacheWrite: override.cost.cacheWrite ?? model.cost.cacheWrite,
|
||||
};
|
||||
}
|
||||
|
||||
// Merge headers
|
||||
if (override.headers) {
|
||||
const resolvedHeaders = resolveHeaders(override.headers);
|
||||
result.headers = resolvedHeaders
|
||||
? { ...model.headers, ...resolvedHeaders }
|
||||
: model.headers;
|
||||
}
|
||||
|
||||
// Deep merge compat
|
||||
result.compat = mergeCompat(model.compat, override.compat);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/** Clear the config value command cache. Exported for testing. */
// NOTE(review): historical "API key" naming — this simply aliases
// clearConfigValueCache from the config-value resolution module.
export const clearApiKeyCache = clearConfigValueCache;
|
||||
|
||||
/**
 * Model registry - loads and manages models, resolves API keys via AuthStorage.
 *
 * Built-in models come from the pi-ai provider catalog; custom models and
 * per-provider/per-model overrides are read from models.json. Extensions can
 * additionally register providers at runtime via registerProvider().
 */
export class ModelRegistry {
	// Full merged model list (built-in + custom + dynamic registrations).
	private models: Model<Api>[] = [];
	// provider name -> apiKey config string, consumed by the auth fallback resolver.
	private customProviderApiKeys: Map<string, string> = new Map();
	// Providers registered dynamically (extensions); reapplied on refresh().
	private registeredProviders: Map<string, ProviderConfigInput> = new Map();
	// Last models.json load error, if any (built-in models still load).
	private loadError: string | undefined = undefined;

	constructor(
		readonly authStorage: AuthStorage,
		private modelsJsonPath: string | undefined = join(
			getAgentDir(),
			"models.json",
		),
	) {
		// Set up fallback resolver for custom provider API keys
		this.authStorage.setFallbackResolver((provider) => {
			const keyConfig = this.customProviderApiKeys.get(provider);
			if (keyConfig) {
				return resolveConfigValue(keyConfig);
			}
			return undefined;
		});

		// Load models
		this.loadModels();
	}

	/**
	 * Reload models from disk (built-in + custom from models.json).
	 */
	refresh(): void {
		this.customProviderApiKeys.clear();
		this.loadError = undefined;

		// Ensure dynamic API/OAuth registrations are rebuilt from current provider state.
		resetApiProviders();
		resetOAuthProviders();

		this.loadModels();

		// Reapply extension-registered providers so they survive a refresh.
		for (const [providerName, config] of this.registeredProviders.entries()) {
			this.applyProviderConfig(providerName, config);
		}
	}

	/**
	 * Get any error from loading models.json (undefined if no error).
	 */
	getError(): string | undefined {
		return this.loadError;
	}

	// Rebuild this.models from built-ins plus models.json customizations.
	private loadModels(): void {
		// Load custom models and overrides from models.json
		const {
			models: customModels,
			overrides,
			modelOverrides,
			error,
		} = this.modelsJsonPath
			? this.loadCustomModels(this.modelsJsonPath)
			: emptyCustomModelsResult();

		if (error) {
			this.loadError = error;
			// Keep built-in models even if custom models failed to load
		}

		const builtInModels = this.loadBuiltInModels(overrides, modelOverrides);
		let combined = this.mergeCustomModels(builtInModels, customModels);

		// Let OAuth providers modify their models (e.g., update baseUrl)
		for (const oauthProvider of this.authStorage.getOAuthProviders()) {
			const cred = this.authStorage.get(oauthProvider.id);
			if (cred?.type === "oauth" && oauthProvider.modifyModels) {
				combined = oauthProvider.modifyModels(combined, cred);
			}
		}

		this.models = combined;
	}

	/** Load built-in models and apply provider/model overrides */
	private loadBuiltInModels(
		overrides: Map<string, ProviderOverride>,
		modelOverrides: Map<string, Map<string, ModelOverride>>,
	): Model<Api>[] {
		return getProviders().flatMap((provider) => {
			const models = getModels(provider as KnownProvider) as Model<Api>[];
			const providerOverride = overrides.get(provider);
			const perModelOverrides = modelOverrides.get(provider);

			return models.map((m) => {
				let model = m;

				// Apply provider-level baseUrl/headers override
				if (providerOverride) {
					const resolvedHeaders = resolveHeaders(providerOverride.headers);
					model = {
						...model,
						baseUrl: providerOverride.baseUrl ?? model.baseUrl,
						headers: resolvedHeaders
							? { ...model.headers, ...resolvedHeaders }
							: model.headers,
					};
				}

				// Apply per-model override
				const modelOverride = perModelOverrides?.get(m.id);
				if (modelOverride) {
					model = applyModelOverride(model, modelOverride);
				}

				return model;
			});
		});
	}

	/** Merge custom models into built-in list by provider+id (custom wins on conflicts). */
	private mergeCustomModels(
		builtInModels: Model<Api>[],
		customModels: Model<Api>[],
	): Model<Api>[] {
		const merged = [...builtInModels];
		for (const customModel of customModels) {
			const existingIndex = merged.findIndex(
				(m) => m.provider === customModel.provider && m.id === customModel.id,
			);
			if (existingIndex >= 0) {
				merged[existingIndex] = customModel;
			} else {
				merged.push(customModel);
			}
		}
		return merged;
	}

	// Parse models.json into custom models plus provider/model override maps.
	// Never throws; all failures are reported via the result's `error` field.
	private loadCustomModels(modelsJsonPath: string): CustomModelsResult {
		if (!existsSync(modelsJsonPath)) {
			return emptyCustomModelsResult();
		}

		try {
			const content = readFileSync(modelsJsonPath, "utf-8");
			const config: ModelsConfig = JSON.parse(content);

			// Validate schema
			const validate = ajv.getSchema("ModelsConfig")!;
			if (!validate(config)) {
				const errors =
					validate.errors
						?.map((e: any) => ` - ${e.instancePath || "root"}: ${e.message}`)
						.join("\n") || "Unknown schema error";
				return emptyCustomModelsResult(
					`Invalid models.json schema:\n${errors}\n\nFile: ${modelsJsonPath}`,
				);
			}

			// Additional validation
			this.validateConfig(config);

			const overrides = new Map<string, ProviderOverride>();
			const modelOverrides = new Map<string, Map<string, ModelOverride>>();

			for (const [providerName, providerConfig] of Object.entries(
				config.providers,
			)) {
				// Apply provider-level baseUrl/headers/apiKey override to built-in models when configured.
				if (
					providerConfig.baseUrl ||
					providerConfig.headers ||
					providerConfig.apiKey
				) {
					overrides.set(providerName, {
						baseUrl: providerConfig.baseUrl,
						headers: providerConfig.headers,
						apiKey: providerConfig.apiKey,
					});
				}

				// Store API key for fallback resolver.
				if (providerConfig.apiKey) {
					this.customProviderApiKeys.set(providerName, providerConfig.apiKey);
				}

				if (providerConfig.modelOverrides) {
					modelOverrides.set(
						providerName,
						new Map(Object.entries(providerConfig.modelOverrides)),
					);
				}
			}

			return {
				models: this.parseModels(config),
				overrides,
				modelOverrides,
				error: undefined,
			};
		} catch (error) {
			if (error instanceof SyntaxError) {
				return emptyCustomModelsResult(
					`Failed to parse models.json: ${error.message}\n\nFile: ${modelsJsonPath}`,
				);
			}
			return emptyCustomModelsResult(
				`Failed to load models.json: ${error instanceof Error ? error.message : error}\n\nFile: ${modelsJsonPath}`,
			);
		}
	}

	// Semantic validation beyond the JSON schema. Throws on invalid config;
	// the caller (loadCustomModels) converts the throw into a load error.
	private validateConfig(config: ModelsConfig): void {
		for (const [providerName, providerConfig] of Object.entries(
			config.providers,
		)) {
			const hasProviderApi = !!providerConfig.api;
			const models = providerConfig.models ?? [];
			const hasModelOverrides =
				providerConfig.modelOverrides &&
				Object.keys(providerConfig.modelOverrides).length > 0;

			if (models.length === 0) {
				// Override-only config: needs baseUrl OR modelOverrides (or both)
				if (!providerConfig.baseUrl && !hasModelOverrides) {
					throw new Error(
						`Provider ${providerName}: must specify "baseUrl", "modelOverrides", or "models".`,
					);
				}
			} else {
				// Custom models are merged into provider models and require endpoint + auth.
				if (!providerConfig.baseUrl) {
					throw new Error(
						`Provider ${providerName}: "baseUrl" is required when defining custom models.`,
					);
				}
				if (!providerConfig.apiKey) {
					throw new Error(
						`Provider ${providerName}: "apiKey" is required when defining custom models.`,
					);
				}
			}

			for (const modelDef of models) {
				const hasModelApi = !!modelDef.api;

				if (!hasProviderApi && !hasModelApi) {
					throw new Error(
						`Provider ${providerName}, model ${modelDef.id}: no "api" specified. Set at provider or model level.`,
					);
				}

				if (!modelDef.id)
					throw new Error(`Provider ${providerName}: model missing "id"`);
				// Validate contextWindow/maxTokens only if provided (they have defaults)
				if (modelDef.contextWindow !== undefined && modelDef.contextWindow <= 0)
					throw new Error(
						`Provider ${providerName}, model ${modelDef.id}: invalid contextWindow`,
					);
				if (modelDef.maxTokens !== undefined && modelDef.maxTokens <= 0)
					throw new Error(
						`Provider ${providerName}, model ${modelDef.id}: invalid maxTokens`,
					);
			}
		}
	}

	// Build Model objects for the custom model definitions in models.json.
	private parseModels(config: ModelsConfig): Model<Api>[] {
		const models: Model<Api>[] = [];

		for (const [providerName, providerConfig] of Object.entries(
			config.providers,
		)) {
			const modelDefs = providerConfig.models ?? [];
			if (modelDefs.length === 0) continue; // Override-only, no custom models

			// Store API key config for fallback resolver
			// NOTE(review): loadCustomModels already stores this key; the
			// duplicate set here is harmless (same value).
			if (providerConfig.apiKey) {
				this.customProviderApiKeys.set(providerName, providerConfig.apiKey);
			}

			for (const modelDef of modelDefs) {
				const api = modelDef.api || providerConfig.api;
				if (!api) continue;

				// Merge headers: provider headers are base, model headers override
				// Resolve env vars and shell commands in header values
				const providerHeaders = resolveHeaders(providerConfig.headers);
				const modelHeaders = resolveHeaders(modelDef.headers);
				let headers =
					providerHeaders || modelHeaders
						? { ...providerHeaders, ...modelHeaders }
						: undefined;

				// If authHeader is true, add Authorization header with resolved API key
				if (providerConfig.authHeader && providerConfig.apiKey) {
					const resolvedKey = resolveConfigValue(providerConfig.apiKey);
					if (resolvedKey) {
						headers = { ...headers, Authorization: `Bearer ${resolvedKey}` };
					}
				}

				// Provider baseUrl is required when custom models are defined.
				// Individual models can override it with modelDef.baseUrl.
				const defaultCost = {
					input: 0,
					output: 0,
					cacheRead: 0,
					cacheWrite: 0,
				};
				models.push({
					id: modelDef.id,
					name: modelDef.name ?? modelDef.id,
					api: api as Api,
					provider: providerName,
					baseUrl: modelDef.baseUrl ?? providerConfig.baseUrl!,
					reasoning: modelDef.reasoning ?? false,
					input: (modelDef.input ?? ["text"]) as ("text" | "image")[],
					cost: modelDef.cost ?? defaultCost,
					contextWindow: modelDef.contextWindow ?? 128000,
					maxTokens: modelDef.maxTokens ?? 16384,
					headers,
					compat: modelDef.compat,
				} as Model<Api>);
			}
		}

		return models;
	}

	/**
	 * Get all models (built-in + custom).
	 * If models.json had errors, returns only built-in models.
	 */
	getAll(): Model<Api>[] {
		return this.models;
	}

	/**
	 * Get only models that have auth configured.
	 * This is a fast check that doesn't refresh OAuth tokens.
	 */
	getAvailable(): Model<Api>[] {
		return this.models.filter((m) => this.authStorage.hasAuth(m.provider));
	}

	/**
	 * Find a model by provider and ID.
	 */
	find(provider: string, modelId: string): Model<Api> | undefined {
		return this.models.find((m) => m.provider === provider && m.id === modelId);
	}

	/**
	 * Get API key for a model.
	 */
	async getApiKey(model: Model<Api>): Promise<string | undefined> {
		return this.authStorage.getApiKey(model.provider);
	}

	/**
	 * Get API key for a provider.
	 */
	async getApiKeyForProvider(provider: string): Promise<string | undefined> {
		return this.authStorage.getApiKey(provider);
	}

	/**
	 * Check if a model is using OAuth credentials (subscription).
	 */
	isUsingOAuth(model: Model<Api>): boolean {
		const cred = this.authStorage.get(model.provider);
		return cred?.type === "oauth";
	}

	/**
	 * Register a provider dynamically (from extensions).
	 *
	 * If provider has models: replaces all existing models for this provider.
	 * If provider has only baseUrl/headers: overrides existing models' URLs.
	 * If provider has oauth: registers OAuth provider for /login support.
	 */
	registerProvider(providerName: string, config: ProviderConfigInput): void {
		this.registeredProviders.set(providerName, config);
		this.applyProviderConfig(providerName, config);
	}

	/**
	 * Unregister a previously registered provider.
	 *
	 * Removes the provider from the registry and reloads models from disk so that
	 * built-in models overridden by this provider are restored to their original state.
	 * Also resets dynamic OAuth and API stream registrations before reapplying
	 * remaining dynamic providers.
	 * Has no effect if the provider was never registered.
	 */
	unregisterProvider(providerName: string): void {
		if (!this.registeredProviders.has(providerName)) return;
		this.registeredProviders.delete(providerName);
		this.customProviderApiKeys.delete(providerName);
		this.refresh();
	}

	// Apply one dynamically registered provider config: OAuth + stream
	// registration, API key storage, then model replacement or override.
	private applyProviderConfig(
		providerName: string,
		config: ProviderConfigInput,
	): void {
		// Register OAuth provider if provided
		if (config.oauth) {
			// Ensure the OAuth provider ID matches the provider name
			const oauthProvider: OAuthProviderInterface = {
				...config.oauth,
				id: providerName,
			};
			registerOAuthProvider(oauthProvider);
		}

		if (config.streamSimple) {
			if (!config.api) {
				throw new Error(
					`Provider ${providerName}: "api" is required when registering streamSimple.`,
				);
			}
			const streamSimple = config.streamSimple;
			registerApiProvider(
				{
					api: config.api,
					stream: (model, context, options) =>
						streamSimple(model, context, options as SimpleStreamOptions),
					streamSimple,
				},
				`provider:${providerName}`,
			);
		}

		// Store API key for auth resolution
		if (config.apiKey) {
			this.customProviderApiKeys.set(providerName, config.apiKey);
		}

		if (config.models && config.models.length > 0) {
			// Full replacement: remove existing models for this provider
			this.models = this.models.filter((m) => m.provider !== providerName);

			// Validate required fields
			if (!config.baseUrl) {
				throw new Error(
					`Provider ${providerName}: "baseUrl" is required when defining models.`,
				);
			}
			if (!config.apiKey && !config.oauth) {
				throw new Error(
					`Provider ${providerName}: "apiKey" or "oauth" is required when defining models.`,
				);
			}

			// Parse and add new models
			for (const modelDef of config.models) {
				const api = modelDef.api || config.api;
				if (!api) {
					throw new Error(
						`Provider ${providerName}, model ${modelDef.id}: no "api" specified.`,
					);
				}

				// Merge headers
				const providerHeaders = resolveHeaders(config.headers);
				const modelHeaders = resolveHeaders(modelDef.headers);
				let headers =
					providerHeaders || modelHeaders
						? { ...providerHeaders, ...modelHeaders }
						: undefined;

				// If authHeader is true, add Authorization header
				if (config.authHeader && config.apiKey) {
					const resolvedKey = resolveConfigValue(config.apiKey);
					if (resolvedKey) {
						headers = { ...headers, Authorization: `Bearer ${resolvedKey}` };
					}
				}

				// NOTE(review): modelDef.baseUrl is declared on the input type
				// but not consulted here — config.baseUrl is always used.
				this.models.push({
					id: modelDef.id,
					name: modelDef.name,
					api: api as Api,
					provider: providerName,
					baseUrl: config.baseUrl,
					reasoning: modelDef.reasoning,
					input: modelDef.input as ("text" | "image")[],
					cost: modelDef.cost,
					contextWindow: modelDef.contextWindow,
					maxTokens: modelDef.maxTokens,
					headers,
					compat: modelDef.compat,
				} as Model<Api>);
			}

			// Apply OAuth modifyModels if credentials exist (e.g., to update baseUrl)
			if (config.oauth?.modifyModels) {
				const cred = this.authStorage.get(providerName);
				if (cred?.type === "oauth") {
					this.models = config.oauth.modifyModels(this.models, cred);
				}
			}
		} else if (config.baseUrl) {
			// Override-only: update baseUrl/headers for existing models
			const resolvedHeaders = resolveHeaders(config.headers);
			this.models = this.models.map((m) => {
				if (m.provider !== providerName) return m;
				return {
					...m,
					baseUrl: config.baseUrl ?? m.baseUrl,
					headers: resolvedHeaders
						? { ...m.headers, ...resolvedHeaders }
						: m.headers,
				};
			});
		}
	}
}
|
||||
|
||||
/**
 * Input type for registerProvider API.
 */
export interface ProviderConfigInput {
	/** Base URL for the provider's API; required when `models` is set. */
	baseUrl?: string;
	/** API key config string, resolved via the auth fallback resolver. */
	apiKey?: string;
	/** Wire API this provider speaks; required when `streamSimple` is set. */
	api?: Api;
	/** Custom streaming implementation registered as a dynamic API provider. */
	streamSimple?: (
		model: Model<Api>,
		context: Context,
		options?: SimpleStreamOptions,
	) => AssistantMessageEventStream;
	/** Extra HTTP headers applied to all of this provider's models. */
	headers?: Record<string, string>;
	/** When true (and apiKey is set), send the key as an Authorization bearer header. */
	authHeader?: boolean;
	/** OAuth provider for /login support */
	oauth?: Omit<OAuthProviderInterface, "id">;
	/** Model definitions; when non-empty, replaces all existing models for this provider. */
	models?: Array<{
		id: string;
		name: string;
		api?: Api;
		// NOTE(review): declared but currently ignored by applyProviderConfig,
		// which always uses the provider-level baseUrl — confirm intent.
		baseUrl?: string;
		reasoning: boolean;
		input: ("text" | "image")[];
		cost: {
			input: number;
			output: number;
			cacheRead: number;
			cacheWrite: number;
		};
		contextWindow: number;
		maxTokens: number;
		headers?: Record<string, string>;
		compat?: Model<Api>["compat"];
	}>;
}
|
||||
707
packages/coding-agent/src/core/model-resolver.ts
Normal file
707
packages/coding-agent/src/core/model-resolver.ts
Normal file
|
|
@ -0,0 +1,707 @@
|
|||
/**
|
||||
* Model resolution, scoping, and initial selection
|
||||
*/
|
||||
|
||||
import type { ThinkingLevel } from "@mariozechner/pi-agent-core";
|
||||
import {
|
||||
type Api,
|
||||
type KnownProvider,
|
||||
type Model,
|
||||
modelsAreEqual,
|
||||
} from "@mariozechner/pi-ai";
|
||||
import chalk from "chalk";
|
||||
import { minimatch } from "minimatch";
|
||||
import { isValidThinkingLevel } from "../cli/args.js";
|
||||
import { DEFAULT_THINKING_LEVEL } from "./defaults.js";
|
||||
import type { ModelRegistry } from "./model-registry.js";
|
||||
|
||||
/** Default model IDs for each known provider */
// Consulted by buildFallbackModel when cloning settings for unknown model IDs.
export const defaultModelPerProvider: Record<KnownProvider, string> = {
	"amazon-bedrock": "us.anthropic.claude-opus-4-6-v1",
	anthropic: "claude-opus-4-6",
	openai: "gpt-5.4",
	"azure-openai-responses": "gpt-5.2",
	"openai-codex": "gpt-5.4",
	google: "gemini-2.5-pro",
	"google-gemini-cli": "gemini-2.5-pro",
	"google-antigravity": "gemini-3.1-pro-high",
	"google-vertex": "gemini-3-pro-preview",
	"github-copilot": "gpt-4o",
	openrouter: "openai/gpt-5.1-codex",
	"vercel-ai-gateway": "anthropic/claude-opus-4-6",
	xai: "grok-4-fast-non-reasoning",
	groq: "openai/gpt-oss-120b",
	cerebras: "zai-glm-4.6",
	zai: "glm-4.6",
	mistral: "devstral-medium-latest",
	minimax: "MiniMax-M2.1",
	"minimax-cn": "MiniMax-M2.1",
	huggingface: "moonshotai/Kimi-K2.5",
	opencode: "claude-opus-4-6",
	"opencode-go": "kimi-k2.5",
	"kimi-coding": "kimi-k2-thinking",
};
|
||||
|
||||
/** A model selected into the active scope, with an optional thinking level. */
export interface ScopedModel {
	model: Model<Api>;
	/** Thinking level if explicitly specified in pattern (e.g., "model:high"), undefined otherwise */
	thinkingLevel?: ThinkingLevel;
}
|
||||
|
||||
/**
|
||||
* Helper to check if a model ID looks like an alias (no date suffix)
|
||||
* Dates are typically in format: -20241022 or -20250929
|
||||
*/
|
||||
function isAlias(id: string): boolean {
|
||||
// Check if ID ends with -latest
|
||||
if (id.endsWith("-latest")) return true;
|
||||
|
||||
// Check if ID ends with a date pattern (-YYYYMMDD)
|
||||
const datePattern = /-\d{8}$/;
|
||||
return !datePattern.test(id);
|
||||
}
|
||||
|
||||
/**
|
||||
* Try to match a pattern to a model from the available models list.
|
||||
* Returns the matched model or undefined if no match found.
|
||||
*/
|
||||
function tryMatchModel(
|
||||
modelPattern: string,
|
||||
availableModels: Model<Api>[],
|
||||
): Model<Api> | undefined {
|
||||
// Check for provider/modelId format (provider is everything before the first /)
|
||||
const slashIndex = modelPattern.indexOf("/");
|
||||
if (slashIndex !== -1) {
|
||||
const provider = modelPattern.substring(0, slashIndex);
|
||||
const modelId = modelPattern.substring(slashIndex + 1);
|
||||
const providerMatch = availableModels.find(
|
||||
(m) =>
|
||||
m.provider.toLowerCase() === provider.toLowerCase() &&
|
||||
m.id.toLowerCase() === modelId.toLowerCase(),
|
||||
);
|
||||
if (providerMatch) {
|
||||
return providerMatch;
|
||||
}
|
||||
// No exact provider/model match - fall through to other matching
|
||||
}
|
||||
|
||||
// Check for exact ID match (case-insensitive)
|
||||
const exactMatch = availableModels.find(
|
||||
(m) => m.id.toLowerCase() === modelPattern.toLowerCase(),
|
||||
);
|
||||
if (exactMatch) {
|
||||
return exactMatch;
|
||||
}
|
||||
|
||||
// No exact match - fall back to partial matching
|
||||
const matches = availableModels.filter(
|
||||
(m) =>
|
||||
m.id.toLowerCase().includes(modelPattern.toLowerCase()) ||
|
||||
m.name?.toLowerCase().includes(modelPattern.toLowerCase()),
|
||||
);
|
||||
|
||||
if (matches.length === 0) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
// Separate into aliases and dated versions
|
||||
const aliases = matches.filter((m) => isAlias(m.id));
|
||||
const datedVersions = matches.filter((m) => !isAlias(m.id));
|
||||
|
||||
if (aliases.length > 0) {
|
||||
// Prefer alias - if multiple aliases, pick the one that sorts highest
|
||||
aliases.sort((a, b) => b.id.localeCompare(a.id));
|
||||
return aliases[0];
|
||||
} else {
|
||||
// No alias found, pick latest dated version
|
||||
datedVersions.sort((a, b) => b.id.localeCompare(a.id));
|
||||
return datedVersions[0];
|
||||
}
|
||||
}
|
||||
|
||||
/** Outcome of parsing a "pattern[:thinkingLevel]" model pattern. */
export interface ParsedModelResult {
	model: Model<Api> | undefined;
	/** Thinking level if explicitly specified in pattern, undefined otherwise */
	thinkingLevel?: ThinkingLevel;
	/** Non-fatal issue encountered while parsing (e.g. invalid thinking level). */
	warning: string | undefined;
}
|
||||
|
||||
function buildFallbackModel(
|
||||
provider: string,
|
||||
modelId: string,
|
||||
availableModels: Model<Api>[],
|
||||
): Model<Api> | undefined {
|
||||
const providerModels = availableModels.filter((m) => m.provider === provider);
|
||||
if (providerModels.length === 0) return undefined;
|
||||
|
||||
const defaultId = defaultModelPerProvider[provider as KnownProvider];
|
||||
const baseModel = defaultId
|
||||
? (providerModels.find((m) => m.id === defaultId) ?? providerModels[0])
|
||||
: providerModels[0];
|
||||
|
||||
return {
|
||||
...baseModel,
|
||||
id: modelId,
|
||||
name: modelId,
|
||||
};
|
||||
}
|
||||
|
||||
/**
 * Parse a pattern to extract model and thinking level.
 * Handles models with colons in their IDs (e.g., OpenRouter's :exacto suffix).
 *
 * Algorithm:
 * 1. Try to match full pattern as a model
 * 2. If found, return it with "off" thinking level
 * 3. If not found and has colons, split on last colon:
 *    - If suffix is valid thinking level, use it and recurse on prefix
 *    - If suffix is invalid, warn and recurse on prefix with "off"
 *
 * @param pattern Raw pattern, possibly with a ":<level>" suffix.
 * @param availableModels Candidate models to match against.
 * @param options allowInvalidThinkingLevelFallback (default true): when false,
 *   an invalid suffix makes the whole pattern fail instead of being stripped.
 * @internal Exported for testing
 */
export function parseModelPattern(
	pattern: string,
	availableModels: Model<Api>[],
	options?: { allowInvalidThinkingLevelFallback?: boolean },
): ParsedModelResult {
	// Try exact match first
	const exactMatch = tryMatchModel(pattern, availableModels);
	if (exactMatch) {
		return { model: exactMatch, thinkingLevel: undefined, warning: undefined };
	}

	// No match - try splitting on last colon if present
	const lastColonIndex = pattern.lastIndexOf(":");
	if (lastColonIndex === -1) {
		// No colons, pattern simply doesn't match any model
		return { model: undefined, thinkingLevel: undefined, warning: undefined };
	}

	const prefix = pattern.substring(0, lastColonIndex);
	const suffix = pattern.substring(lastColonIndex + 1);

	if (isValidThinkingLevel(suffix)) {
		// Valid thinking level - recurse on prefix and use this level
		const result = parseModelPattern(prefix, availableModels, options);
		if (result.model) {
			// Only use this thinking level if no warning from inner recursion
			return {
				model: result.model,
				thinkingLevel: result.warning ? undefined : suffix,
				warning: result.warning,
			};
		}
		return result;
	} else {
		// Invalid suffix
		const allowFallback = options?.allowInvalidThinkingLevelFallback ?? true;
		if (!allowFallback) {
			// In strict mode (CLI --model parsing), treat it as part of the model id and fail.
			// This avoids accidentally resolving to a different model.
			return { model: undefined, thinkingLevel: undefined, warning: undefined };
		}

		// Scope mode: recurse on prefix and warn
		const result = parseModelPattern(prefix, availableModels, options);
		if (result.model) {
			return {
				model: result.model,
				thinkingLevel: undefined,
				warning: `Invalid thinking level "${suffix}" in pattern "${pattern}". Using default instead.`,
			};
		}
		return result;
	}
}
|
||||
|
||||
/**
 * Resolve model patterns to actual Model objects with optional thinking levels
 * Format: "pattern:level" where :level is optional
 * For each pattern, finds all matching models and picks the best version:
 * 1. Prefer alias (e.g., claude-sonnet-4-5) over dated versions (claude-sonnet-4-5-20250929)
 * 2. If no alias, pick the latest dated version
 *
 * Supports models with colons in their IDs (e.g., OpenRouter's model:exacto).
 * The algorithm tries to match the full pattern first, then progressively
 * strips colon-suffixes to find a match.
 *
 * Glob patterns (*, ?, [) are matched against both "provider/id" and the bare
 * id; unmatched patterns print a warning and are skipped. Duplicate models
 * across overlapping patterns are added only once.
 */
export async function resolveModelScope(
	patterns: string[],
	modelRegistry: ModelRegistry,
): Promise<ScopedModel[]> {
	const availableModels = await modelRegistry.getAvailable();
	const scopedModels: ScopedModel[] = [];

	for (const pattern of patterns) {
		// Check if pattern contains glob characters
		if (
			pattern.includes("*") ||
			pattern.includes("?") ||
			pattern.includes("[")
		) {
			// Extract optional thinking level suffix (e.g., "provider/*:high")
			const colonIdx = pattern.lastIndexOf(":");
			let globPattern = pattern;
			let thinkingLevel: ThinkingLevel | undefined;

			if (colonIdx !== -1) {
				const suffix = pattern.substring(colonIdx + 1);
				if (isValidThinkingLevel(suffix)) {
					thinkingLevel = suffix;
					globPattern = pattern.substring(0, colonIdx);
				}
			}

			// Match against "provider/modelId" format OR just model ID
			// This allows "*sonnet*" to match without requiring "anthropic/*sonnet*"
			const matchingModels = availableModels.filter((m) => {
				const fullId = `${m.provider}/${m.id}`;
				return (
					minimatch(fullId, globPattern, { nocase: true }) ||
					minimatch(m.id, globPattern, { nocase: true })
				);
			});

			if (matchingModels.length === 0) {
				console.warn(
					chalk.yellow(`Warning: No models match pattern "${pattern}"`),
				);
				continue;
			}

			// Add all glob matches, skipping models already in scope.
			for (const model of matchingModels) {
				if (!scopedModels.find((sm) => modelsAreEqual(sm.model, model))) {
					scopedModels.push({ model, thinkingLevel });
				}
			}
			continue;
		}

		// Non-glob pattern: exact/fuzzy matching with optional ":level" suffix.
		const { model, thinkingLevel, warning } = parseModelPattern(
			pattern,
			availableModels,
		);

		if (warning) {
			console.warn(chalk.yellow(`Warning: ${warning}`));
		}

		if (!model) {
			console.warn(
				chalk.yellow(`Warning: No models match pattern "${pattern}"`),
			);
			continue;
		}

		// Avoid duplicates
		if (!scopedModels.find((sm) => modelsAreEqual(sm.model, model))) {
			scopedModels.push({ model, thinkingLevel });
		}
	}

	return scopedModels;
}
|
||||
|
||||
/** Result of resolving a model from CLI flags via resolveCliModel. */
export interface ResolveCliModelResult {
	/** The resolved model, or undefined when resolution failed. */
	model: Model<Api> | undefined;
	/** Thinking level parsed from a "<pattern>:<level>" suffix, if present. */
	thinkingLevel?: ThinkingLevel;
	/** Non-fatal message produced while matching the pattern, if any. */
	warning: string | undefined;
	/**
	 * Error message suitable for CLI display.
	 * When set, model will be undefined.
	 */
	error: string | undefined;
}
|
||||
|
||||
/**
 * Resolve a single model from CLI flags.
 *
 * Supports:
 * - --provider <provider> --model <pattern>
 * - --model <provider>/<pattern>
 * - Fuzzy matching (same rules as model scoping: exact id, then partial id/name)
 *
 * Note: This does not apply the thinking level by itself, but it may *parse* and
 * return a thinking level from "<pattern>:<thinking>" so the caller can apply it.
 */
export function resolveCliModel(options: {
	cliProvider?: string;
	cliModel?: string;
	modelRegistry: ModelRegistry;
}): ResolveCliModelResult {
	const { cliProvider, cliModel, modelRegistry } = options;

	if (!cliModel) {
		return { model: undefined, warning: undefined, error: undefined };
	}

	// Important: use *all* models here, not just models with pre-configured auth.
	// This allows "--api-key" to be used for first-time setup.
	const availableModels = modelRegistry.getAll();
	if (availableModels.length === 0) {
		return {
			model: undefined,
			warning: undefined,
			error:
				"No models available. Check your installation or add models to models.json.",
		};
	}

	// Build canonical provider lookup (case-insensitive)
	const providerMap = new Map<string, string>();
	for (const m of availableModels) {
		providerMap.set(m.provider.toLowerCase(), m.provider);
	}

	// Step 1: validate an explicit --provider, if given.
	let provider = cliProvider
		? providerMap.get(cliProvider.toLowerCase())
		: undefined;
	if (cliProvider && !provider) {
		return {
			model: undefined,
			warning: undefined,
			error: `Unknown provider "${cliProvider}". Use --list-models to see available providers/models.`,
		};
	}

	// If no explicit --provider, try to interpret "provider/model" format first.
	// When the prefix before the first slash matches a known provider, prefer that
	// interpretation over matching models whose IDs literally contain slashes
	// (e.g. "zai/glm-5" should resolve to provider=zai, model=glm-5, not to a
	// vercel-ai-gateway model with id "zai/glm-5").
	let pattern = cliModel;
	let inferredProvider = false;

	if (!provider) {
		const slashIndex = cliModel.indexOf("/");
		if (slashIndex !== -1) {
			const maybeProvider = cliModel.substring(0, slashIndex);
			const canonical = providerMap.get(maybeProvider.toLowerCase());
			if (canonical) {
				provider = canonical;
				pattern = cliModel.substring(slashIndex + 1);
				inferredProvider = true;
			}
		}
	}

	// If no provider was inferred from the slash, try exact matches without provider inference.
	// This handles models whose IDs naturally contain slashes (e.g. OpenRouter-style IDs).
	if (!provider) {
		const lower = cliModel.toLowerCase();
		const exact = availableModels.find(
			(m) =>
				m.id.toLowerCase() === lower ||
				`${m.provider}/${m.id}`.toLowerCase() === lower,
		);
		if (exact) {
			return {
				model: exact,
				warning: undefined,
				thinkingLevel: undefined,
				error: undefined,
			};
		}
	}

	if (cliProvider && provider) {
		// If both were provided, tolerate --model <provider>/<pattern> by stripping the provider prefix
		const prefix = `${provider}/`;
		if (cliModel.toLowerCase().startsWith(prefix.toLowerCase())) {
			pattern = cliModel.substring(prefix.length);
		}
	}

	// Step 2: fuzzy-match the pattern, restricted to the provider when known.
	const candidates = provider
		? availableModels.filter((m) => m.provider === provider)
		: availableModels;
	const { model, thinkingLevel, warning } = parseModelPattern(
		pattern,
		candidates,
		{
			allowInvalidThinkingLevelFallback: false,
		},
	);

	if (model) {
		return { model, thinkingLevel, warning, error: undefined };
	}

	// If we inferred a provider from the slash but found no match within that provider,
	// fall back to matching the full input as a raw model id across all models.
	// This handles OpenRouter-style IDs like "openai/gpt-4o:extended" where "openai"
	// looks like a provider but the full string is actually a model id on openrouter.
	if (inferredProvider) {
		const lower = cliModel.toLowerCase();
		const exact = availableModels.find(
			(m) =>
				m.id.toLowerCase() === lower ||
				`${m.provider}/${m.id}`.toLowerCase() === lower,
		);
		if (exact) {
			return {
				model: exact,
				warning: undefined,
				thinkingLevel: undefined,
				error: undefined,
			};
		}
		// Also try parseModelPattern on the full input against all models
		const fallback = parseModelPattern(cliModel, availableModels, {
			allowInvalidThinkingLevelFallback: false,
		});
		if (fallback.model) {
			return {
				model: fallback.model,
				thinkingLevel: fallback.thinkingLevel,
				warning: fallback.warning,
				error: undefined,
			};
		}
	}

	// Step 3: accept an unknown model id for a known provider (custom model id).
	if (provider) {
		const fallbackModel = buildFallbackModel(
			provider,
			pattern,
			availableModels,
		);
		if (fallbackModel) {
			const fallbackWarning = warning
				? `${warning} Model "${pattern}" not found for provider "${provider}". Using custom model id.`
				: `Model "${pattern}" not found for provider "${provider}". Using custom model id.`;
			return {
				model: fallbackModel,
				thinkingLevel: undefined,
				warning: fallbackWarning,
				error: undefined,
			};
		}
	}

	// Step 4: nothing matched at all.
	const display = provider ? `${provider}/${pattern}` : cliModel;
	return {
		model: undefined,
		thinkingLevel: undefined,
		warning,
		error: `Model "${display}" not found. Use --list-models to see available models.`,
	};
}
|
||||
|
||||
/** Result of findInitialModel. */
export interface InitialModelResult {
	/** Selected model, or undefined when no model could be found. */
	model: Model<Api> | undefined;
	/** Thinking level to start with (defaults applied when none was specified). */
	thinkingLevel: ThinkingLevel;
	/** Message explaining a fallback model choice, if one was needed. */
	fallbackMessage: string | undefined;
}
|
||||
|
||||
/**
|
||||
* Find the initial model to use based on priority:
|
||||
* 1. CLI args (provider + model)
|
||||
* 2. First model from scoped models (if not continuing/resuming)
|
||||
* 3. Restored from session (if continuing/resuming)
|
||||
* 4. Saved default from settings
|
||||
* 5. First available model with valid API key
|
||||
*/
|
||||
export async function findInitialModel(options: {
|
||||
cliProvider?: string;
|
||||
cliModel?: string;
|
||||
scopedModels: ScopedModel[];
|
||||
isContinuing: boolean;
|
||||
defaultProvider?: string;
|
||||
defaultModelId?: string;
|
||||
defaultThinkingLevel?: ThinkingLevel;
|
||||
modelRegistry: ModelRegistry;
|
||||
}): Promise<InitialModelResult> {
|
||||
const {
|
||||
cliProvider,
|
||||
cliModel,
|
||||
scopedModels,
|
||||
isContinuing,
|
||||
defaultProvider,
|
||||
defaultModelId,
|
||||
defaultThinkingLevel,
|
||||
modelRegistry,
|
||||
} = options;
|
||||
|
||||
let model: Model<Api> | undefined;
|
||||
let thinkingLevel: ThinkingLevel = DEFAULT_THINKING_LEVEL;
|
||||
|
||||
// 1. CLI args take priority
|
||||
if (cliProvider && cliModel) {
|
||||
const resolved = resolveCliModel({
|
||||
cliProvider,
|
||||
cliModel,
|
||||
modelRegistry,
|
||||
});
|
||||
if (resolved.error) {
|
||||
console.error(chalk.red(resolved.error));
|
||||
process.exit(1);
|
||||
}
|
||||
if (resolved.model) {
|
||||
return {
|
||||
model: resolved.model,
|
||||
thinkingLevel: DEFAULT_THINKING_LEVEL,
|
||||
fallbackMessage: undefined,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Use first model from scoped models (skip if continuing/resuming)
|
||||
if (scopedModels.length > 0 && !isContinuing) {
|
||||
return {
|
||||
model: scopedModels[0].model,
|
||||
thinkingLevel:
|
||||
scopedModels[0].thinkingLevel ??
|
||||
defaultThinkingLevel ??
|
||||
DEFAULT_THINKING_LEVEL,
|
||||
fallbackMessage: undefined,
|
||||
};
|
||||
}
|
||||
|
||||
// 3. Try saved default from settings
|
||||
if (defaultProvider && defaultModelId) {
|
||||
const found = modelRegistry.find(defaultProvider, defaultModelId);
|
||||
if (found) {
|
||||
model = found;
|
||||
if (defaultThinkingLevel) {
|
||||
thinkingLevel = defaultThinkingLevel;
|
||||
}
|
||||
return { model, thinkingLevel, fallbackMessage: undefined };
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Try first available model with valid API key
|
||||
const availableModels = await modelRegistry.getAvailable();
|
||||
|
||||
if (availableModels.length > 0) {
|
||||
// Try to find a default model from known providers
|
||||
for (const provider of Object.keys(
|
||||
defaultModelPerProvider,
|
||||
) as KnownProvider[]) {
|
||||
const defaultId = defaultModelPerProvider[provider];
|
||||
const match = availableModels.find(
|
||||
(m) => m.provider === provider && m.id === defaultId,
|
||||
);
|
||||
if (match) {
|
||||
return {
|
||||
model: match,
|
||||
thinkingLevel: DEFAULT_THINKING_LEVEL,
|
||||
fallbackMessage: undefined,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// If no default found, use first available
|
||||
return {
|
||||
model: availableModels[0],
|
||||
thinkingLevel: DEFAULT_THINKING_LEVEL,
|
||||
fallbackMessage: undefined,
|
||||
};
|
||||
}
|
||||
|
||||
// 5. No model found
|
||||
return {
|
||||
model: undefined,
|
||||
thinkingLevel: DEFAULT_THINKING_LEVEL,
|
||||
fallbackMessage: undefined,
|
||||
};
|
||||
}
|
||||
|
||||
/**
 * Restore model from session, with fallback to available models.
 *
 * Fallback order when the saved model is missing or has no API key:
 * currentModel (if any) → a known provider default → first available model.
 * Returns a fallbackMessage whenever a substitute model was used.
 */
export async function restoreModelFromSession(
	savedProvider: string,
	savedModelId: string,
	currentModel: Model<Api> | undefined,
	shouldPrintMessages: boolean,
	modelRegistry: ModelRegistry,
): Promise<{
	model: Model<Api> | undefined;
	fallbackMessage: string | undefined;
}> {
	const restoredModel = modelRegistry.find(savedProvider, savedModelId);

	// Check if restored model exists and has a valid API key
	const hasApiKey = restoredModel
		? !!(await modelRegistry.getApiKey(restoredModel))
		: false;

	if (restoredModel && hasApiKey) {
		if (shouldPrintMessages) {
			console.log(
				chalk.dim(`Restored model: ${savedProvider}/${savedModelId}`),
			);
		}
		return { model: restoredModel, fallbackMessage: undefined };
	}

	// Model not found or no API key - fall back
	const reason = !restoredModel
		? "model no longer exists"
		: "no API key available";

	if (shouldPrintMessages) {
		console.error(
			chalk.yellow(
				`Warning: Could not restore model ${savedProvider}/${savedModelId} (${reason}).`,
			),
		);
	}

	// If we already have a model, use it as fallback
	if (currentModel) {
		if (shouldPrintMessages) {
			console.log(
				chalk.dim(
					`Falling back to: ${currentModel.provider}/${currentModel.id}`,
				),
			);
		}
		return {
			model: currentModel,
			fallbackMessage: `Could not restore model ${savedProvider}/${savedModelId} (${reason}). Using ${currentModel.provider}/${currentModel.id}.`,
		};
	}

	// Try to find any available model
	const availableModels = await modelRegistry.getAvailable();

	if (availableModels.length > 0) {
		// Try to find a default model from known providers
		let fallbackModel: Model<Api> | undefined;
		for (const provider of Object.keys(
			defaultModelPerProvider,
		) as KnownProvider[]) {
			const defaultId = defaultModelPerProvider[provider];
			const match = availableModels.find(
				(m) => m.provider === provider && m.id === defaultId,
			);
			if (match) {
				fallbackModel = match;
				break;
			}
		}

		// If no default found, use first available
		if (!fallbackModel) {
			fallbackModel = availableModels[0];
		}

		if (shouldPrintMessages) {
			console.log(
				chalk.dim(
					`Falling back to: ${fallbackModel.provider}/${fallbackModel.id}`,
				),
			);
		}

		return {
			model: fallbackModel,
			fallbackMessage: `Could not restore model ${savedProvider}/${savedModelId} (${reason}). Using ${fallbackModel.provider}/${fallbackModel.id}.`,
		};
	}

	// No models available
	return { model: undefined, fallbackMessage: undefined };
}
|
||||
2087
packages/coding-agent/src/core/package-manager.ts
Normal file
2087
packages/coding-agent/src/core/package-manager.ts
Normal file
File diff suppressed because it is too large
Load diff
327
packages/coding-agent/src/core/prompt-templates.ts
Normal file
327
packages/coding-agent/src/core/prompt-templates.ts
Normal file
|
|
@ -0,0 +1,327 @@
|
|||
import { existsSync, readdirSync, readFileSync, statSync } from "fs";
|
||||
import { homedir } from "os";
|
||||
import { basename, isAbsolute, join, resolve, sep } from "path";
|
||||
import { CONFIG_DIR_NAME, getPromptsDir } from "../config.js";
|
||||
import { parseFrontmatter } from "../utils/frontmatter.js";
|
||||
|
||||
/**
|
||||
* Represents a prompt template loaded from a markdown file
|
||||
*/
|
||||
/**
 * Represents a prompt template loaded from a markdown file
 */
export interface PromptTemplate {
	// File name without the ".md" extension; matched against "/name" invocations.
	name: string;
	// Short summary (frontmatter description or first body line), suffixed with a source label.
	description: string;
	// Template body with frontmatter stripped; may contain $1/$@/$ARGUMENTS placeholders.
	content: string;
	source: string; // "user", "project", or "path"
	filePath: string; // Absolute path to the template file
}
|
||||
|
||||
/**
|
||||
* Parse command arguments respecting quoted strings (bash-style)
|
||||
* Returns array of arguments
|
||||
*/
|
||||
export function parseCommandArgs(argsString: string): string[] {
|
||||
const args: string[] = [];
|
||||
let current = "";
|
||||
let inQuote: string | null = null;
|
||||
|
||||
for (let i = 0; i < argsString.length; i++) {
|
||||
const char = argsString[i];
|
||||
|
||||
if (inQuote) {
|
||||
if (char === inQuote) {
|
||||
inQuote = null;
|
||||
} else {
|
||||
current += char;
|
||||
}
|
||||
} else if (char === '"' || char === "'") {
|
||||
inQuote = char;
|
||||
} else if (char === " " || char === "\t") {
|
||||
if (current) {
|
||||
args.push(current);
|
||||
current = "";
|
||||
}
|
||||
} else {
|
||||
current += char;
|
||||
}
|
||||
}
|
||||
|
||||
if (current) {
|
||||
args.push(current);
|
||||
}
|
||||
|
||||
return args;
|
||||
}
|
||||
|
||||
/**
|
||||
* Substitute argument placeholders in template content
|
||||
* Supports:
|
||||
* - $1, $2, ... for positional args
|
||||
* - $@ and $ARGUMENTS for all args
|
||||
* - ${@:N} for args from Nth onwards (bash-style slicing)
|
||||
* - ${@:N:L} for L args starting from Nth
|
||||
*
|
||||
* Note: Replacement happens on the template string only. Argument values
|
||||
* containing patterns like $1, $@, or $ARGUMENTS are NOT recursively substituted.
|
||||
*/
|
||||
export function substituteArgs(content: string, args: string[]): string {
|
||||
let result = content;
|
||||
|
||||
// Replace $1, $2, etc. with positional args FIRST (before wildcards)
|
||||
// This prevents wildcard replacement values containing $<digit> patterns from being re-substituted
|
||||
result = result.replace(/\$(\d+)/g, (_, num) => {
|
||||
const index = parseInt(num, 10) - 1;
|
||||
return args[index] ?? "";
|
||||
});
|
||||
|
||||
// Replace ${@:start} or ${@:start:length} with sliced args (bash-style)
|
||||
// Process BEFORE simple $@ to avoid conflicts
|
||||
result = result.replace(
|
||||
/\$\{@:(\d+)(?::(\d+))?\}/g,
|
||||
(_, startStr, lengthStr) => {
|
||||
let start = parseInt(startStr, 10) - 1; // Convert to 0-indexed (user provides 1-indexed)
|
||||
// Treat 0 as 1 (bash convention: args start at 1)
|
||||
if (start < 0) start = 0;
|
||||
|
||||
if (lengthStr) {
|
||||
const length = parseInt(lengthStr, 10);
|
||||
return args.slice(start, start + length).join(" ");
|
||||
}
|
||||
return args.slice(start).join(" ");
|
||||
},
|
||||
);
|
||||
|
||||
// Pre-compute all args joined (optimization)
|
||||
const allArgs = args.join(" ");
|
||||
|
||||
// Replace $ARGUMENTS with all args joined (new syntax, aligns with Claude, Codex, OpenCode)
|
||||
result = result.replace(/\$ARGUMENTS/g, allArgs);
|
||||
|
||||
// Replace $@ with all args joined (existing syntax)
|
||||
result = result.replace(/\$@/g, allArgs);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
function loadTemplateFromFile(
|
||||
filePath: string,
|
||||
source: string,
|
||||
sourceLabel: string,
|
||||
): PromptTemplate | null {
|
||||
try {
|
||||
const rawContent = readFileSync(filePath, "utf-8");
|
||||
const { frontmatter, body } =
|
||||
parseFrontmatter<Record<string, string>>(rawContent);
|
||||
|
||||
const name = basename(filePath).replace(/\.md$/, "");
|
||||
|
||||
// Get description from frontmatter or first non-empty line
|
||||
let description = frontmatter.description || "";
|
||||
if (!description) {
|
||||
const firstLine = body.split("\n").find((line) => line.trim());
|
||||
if (firstLine) {
|
||||
// Truncate if too long
|
||||
description = firstLine.slice(0, 60);
|
||||
if (firstLine.length > 60) description += "...";
|
||||
}
|
||||
}
|
||||
|
||||
// Append source to description
|
||||
description = description ? `${description} ${sourceLabel}` : sourceLabel;
|
||||
|
||||
return {
|
||||
name,
|
||||
description,
|
||||
content: body,
|
||||
source,
|
||||
filePath,
|
||||
};
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Scan a directory for .md files (non-recursive) and load them as prompt templates.
|
||||
*/
|
||||
function loadTemplatesFromDir(
|
||||
dir: string,
|
||||
source: string,
|
||||
sourceLabel: string,
|
||||
): PromptTemplate[] {
|
||||
const templates: PromptTemplate[] = [];
|
||||
|
||||
if (!existsSync(dir)) {
|
||||
return templates;
|
||||
}
|
||||
|
||||
try {
|
||||
const entries = readdirSync(dir, { withFileTypes: true });
|
||||
|
||||
for (const entry of entries) {
|
||||
const fullPath = join(dir, entry.name);
|
||||
|
||||
// For symlinks, check if they point to a file
|
||||
let isFile = entry.isFile();
|
||||
if (entry.isSymbolicLink()) {
|
||||
try {
|
||||
const stats = statSync(fullPath);
|
||||
isFile = stats.isFile();
|
||||
} catch {
|
||||
// Broken symlink, skip it
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (isFile && entry.name.endsWith(".md")) {
|
||||
const template = loadTemplateFromFile(fullPath, source, sourceLabel);
|
||||
if (template) {
|
||||
templates.push(template);
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
return templates;
|
||||
}
|
||||
|
||||
return templates;
|
||||
}
|
||||
|
||||
/** Options controlling where loadPromptTemplates looks for templates. */
export interface LoadPromptTemplatesOptions {
	/** Working directory for project-local templates. Default: process.cwd() */
	cwd?: string;
	/** Agent config directory for global templates. Default: from getPromptsDir() */
	agentDir?: string;
	/** Explicit prompt template paths (files or directories) */
	promptPaths?: string[];
	/** Include default prompt directories. Default: true */
	includeDefaults?: boolean;
}
|
||||
|
||||
function normalizePath(input: string): string {
|
||||
const trimmed = input.trim();
|
||||
if (trimmed === "~") return homedir();
|
||||
if (trimmed.startsWith("~/")) return join(homedir(), trimmed.slice(2));
|
||||
if (trimmed.startsWith("~")) return join(homedir(), trimmed.slice(1));
|
||||
return trimmed;
|
||||
}
|
||||
|
||||
function resolvePromptPath(p: string, cwd: string): string {
|
||||
const normalized = normalizePath(p);
|
||||
return isAbsolute(normalized) ? normalized : resolve(cwd, normalized);
|
||||
}
|
||||
|
||||
function buildPathSourceLabel(p: string): string {
|
||||
const base = basename(p).replace(/\.md$/, "") || "path";
|
||||
return `(path:${base})`;
|
||||
}
|
||||
|
||||
/**
 * Load all prompt templates from:
 * 1. Global: agentDir/prompts/
 * 2. Project: cwd/{CONFIG_DIR_NAME}/prompts/
 * 3. Explicit prompt paths
 *
 * Templates are returned in that order; paths that do not exist or cannot
 * be read are skipped silently.
 */
export function loadPromptTemplates(
	options: LoadPromptTemplatesOptions = {},
): PromptTemplate[] {
	const resolvedCwd = options.cwd ?? process.cwd();
	const resolvedAgentDir = options.agentDir ?? getPromptsDir();
	const promptPaths = options.promptPaths ?? [];
	const includeDefaults = options.includeDefaults ?? true;

	const templates: PromptTemplate[] = [];

	if (includeDefaults) {
		// 1. Load global templates from agentDir/prompts/
		// Note: if agentDir is provided, it should be the agent dir, not the prompts dir
		const globalPromptsDir = options.agentDir
			? join(options.agentDir, "prompts")
			: resolvedAgentDir;
		templates.push(...loadTemplatesFromDir(globalPromptsDir, "user", "(user)"));

		// 2. Load project templates from cwd/{CONFIG_DIR_NAME}/prompts/
		const projectPromptsDir = resolve(resolvedCwd, CONFIG_DIR_NAME, "prompts");
		templates.push(
			...loadTemplatesFromDir(projectPromptsDir, "project", "(project)"),
		);
	}

	// Default directories, recomputed for source labeling of explicit paths below.
	const userPromptsDir = options.agentDir
		? join(options.agentDir, "prompts")
		: resolvedAgentDir;
	const projectPromptsDir = resolve(resolvedCwd, CONFIG_DIR_NAME, "prompts");

	// True when target equals root or lies inside it (path-separator aware).
	const isUnderPath = (target: string, root: string): boolean => {
		const normalizedRoot = resolve(root);
		if (target === normalizedRoot) {
			return true;
		}
		const prefix = normalizedRoot.endsWith(sep)
			? normalizedRoot
			: `${normalizedRoot}${sep}`;
		return target.startsWith(prefix);
	};

	// When defaults are disabled, explicit paths inside the default dirs keep
	// their "user"/"project" labels; everything else is labeled as "path".
	const getSourceInfo = (
		resolvedPath: string,
	): { source: string; label: string } => {
		if (!includeDefaults) {
			if (isUnderPath(resolvedPath, userPromptsDir)) {
				return { source: "user", label: "(user)" };
			}
			if (isUnderPath(resolvedPath, projectPromptsDir)) {
				return { source: "project", label: "(project)" };
			}
		}
		return { source: "path", label: buildPathSourceLabel(resolvedPath) };
	};

	// 3. Load explicit prompt paths
	for (const rawPath of promptPaths) {
		const resolvedPath = resolvePromptPath(rawPath, resolvedCwd);
		if (!existsSync(resolvedPath)) {
			continue;
		}

		try {
			const stats = statSync(resolvedPath);
			const { source, label } = getSourceInfo(resolvedPath);
			if (stats.isDirectory()) {
				templates.push(...loadTemplatesFromDir(resolvedPath, source, label));
			} else if (stats.isFile() && resolvedPath.endsWith(".md")) {
				const template = loadTemplateFromFile(resolvedPath, source, label);
				if (template) {
					templates.push(template);
				}
			}
		} catch {
			// Ignore read failures
		}
	}

	return templates;
}
|
||||
|
||||
/**
|
||||
* Expand a prompt template if it matches a template name.
|
||||
* Returns the expanded content or the original text if not a template.
|
||||
*/
|
||||
export function expandPromptTemplate(
|
||||
text: string,
|
||||
templates: PromptTemplate[],
|
||||
): string {
|
||||
if (!text.startsWith("/")) return text;
|
||||
|
||||
const spaceIndex = text.indexOf(" ");
|
||||
const templateName =
|
||||
spaceIndex === -1 ? text.slice(1) : text.slice(1, spaceIndex);
|
||||
const argsString = spaceIndex === -1 ? "" : text.slice(spaceIndex + 1);
|
||||
|
||||
const template = templates.find((t) => t.name === templateName);
|
||||
if (template) {
|
||||
const args = parseCommandArgs(argsString);
|
||||
return substituteArgs(template.content, args);
|
||||
}
|
||||
|
||||
return text;
|
||||
}
|
||||
66
packages/coding-agent/src/core/resolve-config-value.ts
Normal file
66
packages/coding-agent/src/core/resolve-config-value.ts
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
/**
|
||||
* Resolve configuration values that may be shell commands, environment variables, or literals.
|
||||
* Used by auth-storage.ts and model-registry.ts.
|
||||
*/
|
||||
|
||||
import { execSync } from "child_process";
|
||||
|
||||
// Cache for shell command results (persists for process lifetime).
// Keyed by the full "!command" config string; a stored undefined means the
// command already ran and produced no usable output.
const commandResultCache = new Map<string, string | undefined>();
|
||||
|
||||
/**
|
||||
* Resolve a config value (API key, header value, etc.) to an actual value.
|
||||
* - If starts with "!", executes the rest as a shell command and uses stdout (cached)
|
||||
* - Otherwise checks environment variable first, then treats as literal (not cached)
|
||||
*/
|
||||
export function resolveConfigValue(config: string): string | undefined {
|
||||
if (config.startsWith("!")) {
|
||||
return executeCommand(config);
|
||||
}
|
||||
const envValue = process.env[config];
|
||||
return envValue || config;
|
||||
}
|
||||
|
||||
function executeCommand(commandConfig: string): string | undefined {
|
||||
if (commandResultCache.has(commandConfig)) {
|
||||
return commandResultCache.get(commandConfig);
|
||||
}
|
||||
|
||||
const command = commandConfig.slice(1);
|
||||
let result: string | undefined;
|
||||
try {
|
||||
const output = execSync(command, {
|
||||
encoding: "utf-8",
|
||||
timeout: 10000,
|
||||
stdio: ["ignore", "pipe", "ignore"],
|
||||
});
|
||||
result = output.trim() || undefined;
|
||||
} catch {
|
||||
result = undefined;
|
||||
}
|
||||
|
||||
commandResultCache.set(commandConfig, result);
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolve all header values using the same resolution logic as API keys.
|
||||
*/
|
||||
export function resolveHeaders(
|
||||
headers: Record<string, string> | undefined,
|
||||
): Record<string, string> | undefined {
|
||||
if (!headers) return undefined;
|
||||
const resolved: Record<string, string> = {};
|
||||
for (const [key, value] of Object.entries(headers)) {
|
||||
const resolvedValue = resolveConfigValue(value);
|
||||
if (resolvedValue) {
|
||||
resolved[key] = resolvedValue;
|
||||
}
|
||||
}
|
||||
return Object.keys(resolved).length > 0 ? resolved : undefined;
|
||||
}
|
||||
|
||||
/**
 * Clear the config value command cache. Exported for testing.
 * Subsequent "!command" resolutions will re-execute their commands.
 */
export function clearConfigValueCache(): void {
	commandResultCache.clear();
}
|
||||
1094
packages/coding-agent/src/core/resource-loader.ts
Normal file
1094
packages/coding-agent/src/core/resource-loader.ts
Normal file
File diff suppressed because it is too large
Load diff
398
packages/coding-agent/src/core/sdk.ts
Normal file
398
packages/coding-agent/src/core/sdk.ts
Normal file
|
|
@ -0,0 +1,398 @@
|
|||
import { join } from "node:path";
|
||||
import {
|
||||
Agent,
|
||||
type AgentMessage,
|
||||
type ThinkingLevel,
|
||||
} from "@mariozechner/pi-agent-core";
|
||||
import type { Message, Model } from "@mariozechner/pi-ai";
|
||||
import { getAgentDir, getDocsPath } from "../config.js";
|
||||
import { AgentSession } from "./agent-session.js";
|
||||
import { AuthStorage } from "./auth-storage.js";
|
||||
import { DEFAULT_THINKING_LEVEL } from "./defaults.js";
|
||||
import type {
|
||||
ExtensionRunner,
|
||||
LoadExtensionsResult,
|
||||
ToolDefinition,
|
||||
} from "./extensions/index.js";
|
||||
import { convertToLlm } from "./messages.js";
|
||||
import { ModelRegistry } from "./model-registry.js";
|
||||
import { findInitialModel } from "./model-resolver.js";
|
||||
import type { ResourceLoader } from "./resource-loader.js";
|
||||
import { DefaultResourceLoader } from "./resource-loader.js";
|
||||
import { SessionManager } from "./session-manager.js";
|
||||
import { SettingsManager } from "./settings-manager.js";
|
||||
import { time } from "./timings.js";
|
||||
import {
|
||||
allTools,
|
||||
bashTool,
|
||||
codingTools,
|
||||
createBashTool,
|
||||
createCodingTools,
|
||||
createEditTool,
|
||||
createFindTool,
|
||||
createGrepTool,
|
||||
createLsTool,
|
||||
createReadOnlyTools,
|
||||
createReadTool,
|
||||
createWriteTool,
|
||||
editTool,
|
||||
findTool,
|
||||
grepTool,
|
||||
lsTool,
|
||||
readOnlyTools,
|
||||
readTool,
|
||||
type Tool,
|
||||
type ToolName,
|
||||
writeTool,
|
||||
} from "./tools/index.js";
|
||||
|
||||
/** Options for createAgentSession. Every field has a sensible default. */
export interface CreateAgentSessionOptions {
	/** Working directory for project-local discovery. Default: process.cwd() */
	cwd?: string;
	/** Global config directory. Default: ~/.pi/agent */
	agentDir?: string;

	/** Auth storage for credentials. Default: AuthStorage.create(agentDir/auth.json) */
	authStorage?: AuthStorage;
	/** Model registry. Default: new ModelRegistry(authStorage, agentDir/models.json) */
	modelRegistry?: ModelRegistry;

	/** Model to use. Default: from settings, else first available */
	model?: Model<any>;
	/** Thinking level. Default: from settings, else 'medium' (clamped to model capabilities) */
	thinkingLevel?: ThinkingLevel;
	/** Models available for cycling (Ctrl+P in interactive mode) */
	scopedModels?: Array<{ model: Model<any>; thinkingLevel?: ThinkingLevel }>;

	/** Built-in tools to use. Default: codingTools [read, bash, edit, write] */
	tools?: Tool[];
	/** Custom tools to register (in addition to built-in tools). */
	customTools?: ToolDefinition[];

	/** Resource loader. When omitted, DefaultResourceLoader is used. */
	resourceLoader?: ResourceLoader;

	/** Session manager. Default: SessionManager.create(cwd) */
	sessionManager?: SessionManager;

	/** Settings manager. Default: SettingsManager.create(cwd, agentDir) */
	settingsManager?: SettingsManager;
}
|
||||
|
||||
/** Result from createAgentSession */
export interface CreateAgentSessionResult {
	/** The created session */
	session: AgentSession;
	/** Extensions result (for UI context setup in interactive mode) */
	extensionsResult: LoadExtensionsResult;
	/** Warning if session was restored with a different model than the saved one */
	modelFallbackMessage?: string;
}
|
||||
|
||||
// Re-exports
|
||||
|
||||
export type {
|
||||
ExtensionAPI,
|
||||
ExtensionCommandContext,
|
||||
ExtensionContext,
|
||||
ExtensionFactory,
|
||||
SlashCommandInfo,
|
||||
SlashCommandLocation,
|
||||
SlashCommandSource,
|
||||
ToolDefinition,
|
||||
} from "./extensions/index.js";
|
||||
export type { PromptTemplate } from "./prompt-templates.js";
|
||||
export type { Skill } from "./skills.js";
|
||||
export type { Tool } from "./tools/index.js";
|
||||
|
||||
export {
|
||||
// Pre-built tools (use process.cwd())
|
||||
readTool,
|
||||
bashTool,
|
||||
editTool,
|
||||
writeTool,
|
||||
grepTool,
|
||||
findTool,
|
||||
lsTool,
|
||||
codingTools,
|
||||
readOnlyTools,
|
||||
allTools as allBuiltInTools,
|
||||
// Tool factories (for custom cwd)
|
||||
createCodingTools,
|
||||
createReadOnlyTools,
|
||||
createReadTool,
|
||||
createBashTool,
|
||||
createEditTool,
|
||||
createWriteTool,
|
||||
createGrepTool,
|
||||
createFindTool,
|
||||
createLsTool,
|
||||
};
|
||||
|
||||
// Helper Functions

/**
 * Resolve the default agent configuration directory.
 * Thin wrapper over getAgentDir() so option handling below reads uniformly.
 */
function getDefaultAgentDir(): string {
	return getAgentDir();
}
|
||||
|
||||
/**
 * Create an AgentSession with the specified options.
 *
 * Resolution order for the model: explicit option > model saved in an
 * existing session (if its API key still resolves) > findInitialModel
 * (settings default, then provider defaults). The thinking level follows
 * the same pattern and is clamped to "off" for non-reasoning models.
 *
 * @param options - All dependencies are optional; sensible defaults are
 *   created from cwd/agentDir when omitted.
 * @returns The session, the loaded extensions, and an optional
 *   human-readable fallback warning when the model could not be restored.
 *
 * @example
 * ```typescript
 * // Minimal - uses defaults
 * const { session } = await createAgentSession();
 *
 * // With explicit model
 * import { getModel } from '@mariozechner/pi-ai';
 * const { session } = await createAgentSession({
 *   model: getModel('anthropic', 'claude-opus-4-5'),
 *   thinkingLevel: 'high',
 * });
 *
 * // Continue previous session
 * const { session, modelFallbackMessage } = await createAgentSession({
 *   continueSession: true,
 * });
 *
 * // Full control
 * const loader = new DefaultResourceLoader({
 *   cwd: process.cwd(),
 *   agentDir: getAgentDir(),
 *   settingsManager: SettingsManager.create(),
 * });
 * await loader.reload();
 * const { session } = await createAgentSession({
 *   model: myModel,
 *   tools: [readTool, bashTool],
 *   resourceLoader: loader,
 *   sessionManager: SessionManager.inMemory(),
 * });
 * ```
 */
export async function createAgentSession(
	options: CreateAgentSessionOptions = {},
): Promise<CreateAgentSessionResult> {
	const cwd = options.cwd ?? process.cwd();
	const agentDir = options.agentDir ?? getDefaultAgentDir();
	let resourceLoader = options.resourceLoader;

	// Use provided or create AuthStorage and ModelRegistry.
	// Paths are only pinned when the caller passed an explicit agentDir;
	// otherwise AuthStorage/ModelRegistry fall back to their own defaults.
	const authPath = options.agentDir ? join(agentDir, "auth.json") : undefined;
	const modelsPath = options.agentDir
		? join(agentDir, "models.json")
		: undefined;
	const authStorage = options.authStorage ?? AuthStorage.create(authPath);
	const modelRegistry =
		options.modelRegistry ?? new ModelRegistry(authStorage, modelsPath);

	const settingsManager =
		options.settingsManager ?? SettingsManager.create(cwd, agentDir);
	const sessionManager = options.sessionManager ?? SessionManager.create(cwd);

	if (!resourceLoader) {
		resourceLoader = new DefaultResourceLoader({
			cwd,
			agentDir,
			settingsManager,
		});
		await resourceLoader.reload();
		time("resourceLoader.reload");
	}

	// Check if session has existing data to restore
	const existingSession = sessionManager.buildSessionContext();
	const hasExistingSession = existingSession.messages.length > 0;
	const hasThinkingEntry = sessionManager
		.getBranch()
		.some((entry) => entry.type === "thinking_level_change");

	let model = options.model;
	let modelFallbackMessage: string | undefined;

	// If session has data, try to restore model from it.
	// The restored model is only used if an API key still resolves for it.
	if (!model && hasExistingSession && existingSession.model) {
		const restoredModel = modelRegistry.find(
			existingSession.model.provider,
			existingSession.model.modelId,
		);
		if (restoredModel && (await modelRegistry.getApiKey(restoredModel))) {
			model = restoredModel;
		}
		if (!model) {
			modelFallbackMessage = `Could not restore model ${existingSession.model.provider}/${existingSession.model.modelId}`;
		}
	}

	// If still no model, use findInitialModel (checks settings default, then provider defaults)
	if (!model) {
		const result = await findInitialModel({
			scopedModels: [],
			isContinuing: hasExistingSession,
			defaultProvider: settingsManager.getDefaultProvider(),
			defaultModelId: settingsManager.getDefaultModel(),
			defaultThinkingLevel: settingsManager.getDefaultThinkingLevel(),
			modelRegistry,
		});
		model = result.model;
		if (!model) {
			modelFallbackMessage = `No models available. Use /login or set an API key environment variable. See ${join(getDocsPath(), "providers.md")}. Then use /model to select a model.`;
		} else if (modelFallbackMessage) {
			// Restore failed but a fallback was found: extend the warning.
			modelFallbackMessage += `. Using ${model.provider}/${model.id}`;
		}
	}

	let thinkingLevel = options.thinkingLevel;

	// If session has data, restore thinking level from it
	if (thinkingLevel === undefined && hasExistingSession) {
		thinkingLevel = hasThinkingEntry
			? (existingSession.thinkingLevel as ThinkingLevel)
			: (settingsManager.getDefaultThinkingLevel() ?? DEFAULT_THINKING_LEVEL);
	}

	// Fall back to settings default
	if (thinkingLevel === undefined) {
		thinkingLevel =
			settingsManager.getDefaultThinkingLevel() ?? DEFAULT_THINKING_LEVEL;
	}

	// Clamp to model capabilities
	if (!model || !model.reasoning) {
		thinkingLevel = "off";
	}

	// Active tools: names of caller-provided tools (restricted to known
	// built-ins), else the default coding set.
	const defaultActiveToolNames: ToolName[] = ["read", "bash", "edit", "write"];
	const initialActiveToolNames: ToolName[] = options.tools
		? options.tools
				.map((t) => t.name)
				.filter((n): n is ToolName => n in allTools)
		: defaultActiveToolNames;

	let agent: Agent;

	// Create convertToLlm wrapper that filters images if blockImages is enabled (defense-in-depth)
	const convertToLlmWithBlockImages = (messages: AgentMessage[]): Message[] => {
		const converted = convertToLlm(messages);
		// Check setting dynamically so mid-session changes take effect
		if (!settingsManager.getBlockImages()) {
			return converted;
		}
		// Filter out ImageContent from all messages, replacing with text placeholder
		return converted.map((msg) => {
			if (msg.role === "user" || msg.role === "toolResult") {
				const content = msg.content;
				if (Array.isArray(content)) {
					const hasImages = content.some((c) => c.type === "image");
					if (hasImages) {
						const filteredContent = content
							.map((c) =>
								c.type === "image"
									? {
											type: "text" as const,
											text: "Image reading is disabled.",
										}
									: c,
							)
							.filter(
								(c, i, arr) =>
									// Dedupe consecutive "Image reading is disabled." texts
									!(
										c.type === "text" &&
										c.text === "Image reading is disabled." &&
										i > 0 &&
										arr[i - 1].type === "text" &&
										(arr[i - 1] as { type: "text"; text: string }).text ===
											"Image reading is disabled."
									),
							);
						return { ...msg, content: filteredContent };
					}
				}
			}
			return msg;
		});
	};

	// Shared mutable ref: the ExtensionRunner is created later (by
	// AgentSession) but transformContext below needs to reach it.
	const extensionRunnerRef: { current?: ExtensionRunner } = {};

	agent = new Agent({
		initialState: {
			systemPrompt: "",
			model,
			thinkingLevel,
			tools: [],
		},
		convertToLlm: convertToLlmWithBlockImages,
		sessionId: sessionManager.getSessionId(),
		transformContext: async (messages) => {
			const runner = extensionRunnerRef.current;
			if (!runner) return messages;
			return runner.emitContext(messages);
		},
		steeringMode: settingsManager.getSteeringMode(),
		followUpMode: settingsManager.getFollowUpMode(),
		transport: settingsManager.getTransport(),
		thinkingBudgets: settingsManager.getThinkingBudgets(),
		maxRetryDelayMs: settingsManager.getRetrySettings().maxDelayMs,
		getApiKey: async (provider) => {
			// Use the provider argument from the in-flight request;
			// agent.state.model may already be switched mid-turn.
			const resolvedProvider = provider || agent.state.model?.provider;
			if (!resolvedProvider) {
				throw new Error("No model selected");
			}
			const key = await modelRegistry.getApiKeyForProvider(resolvedProvider);
			if (!key) {
				const model = agent.state.model;
				const isOAuth = model && modelRegistry.isUsingOAuth(model);
				if (isOAuth) {
					throw new Error(
						`Authentication failed for "${resolvedProvider}". ` +
							`Credentials may have expired or network is unavailable. ` +
							`Run '/login ${resolvedProvider}' to re-authenticate.`,
					);
				}
				throw new Error(
					`No API key found for "${resolvedProvider}". ` +
						`Set an API key environment variable or run '/login ${resolvedProvider}'.`,
				);
			}
			return key;
		},
	});

	// Restore messages if session has existing data
	if (hasExistingSession) {
		agent.replaceMessages(existingSession.messages);
		if (!hasThinkingEntry) {
			sessionManager.appendThinkingLevelChange(thinkingLevel);
		}
	} else {
		// Save initial model and thinking level for new sessions so they can be restored on resume
		if (model) {
			sessionManager.appendModelChange(model.provider, model.id);
		}
		sessionManager.appendThinkingLevelChange(thinkingLevel);
	}

	const session = new AgentSession({
		agent,
		sessionManager,
		settingsManager,
		cwd,
		scopedModels: options.scopedModels,
		resourceLoader,
		customTools: options.customTools,
		modelRegistry,
		initialActiveToolNames,
		extensionRunnerRef,
	});
	const extensionsResult = resourceLoader.getExtensions();

	return {
		session,
		extensionsResult,
		modelFallbackMessage,
	};
}
|
||||
1514
packages/coding-agent/src/core/session-manager.ts
Normal file
1514
packages/coding-agent/src/core/session-manager.ts
Normal file
File diff suppressed because it is too large
Load diff
1057
packages/coding-agent/src/core/settings-manager.ts
Normal file
1057
packages/coding-agent/src/core/settings-manager.ts
Normal file
File diff suppressed because it is too large
Load diff
518
packages/coding-agent/src/core/skills.ts
Normal file
518
packages/coding-agent/src/core/skills.ts
Normal file
|
|
@ -0,0 +1,518 @@
|
|||
import {
|
||||
existsSync,
|
||||
readdirSync,
|
||||
readFileSync,
|
||||
realpathSync,
|
||||
statSync,
|
||||
} from "fs";
|
||||
import ignore from "ignore";
|
||||
import { homedir } from "os";
|
||||
import {
|
||||
basename,
|
||||
dirname,
|
||||
isAbsolute,
|
||||
join,
|
||||
relative,
|
||||
resolve,
|
||||
sep,
|
||||
} from "path";
|
||||
import { CONFIG_DIR_NAME, getAgentDir } from "../config.js";
|
||||
import { parseFrontmatter } from "../utils/frontmatter.js";
|
||||
import type { ResourceDiagnostic } from "./diagnostics.js";
|
||||
|
||||
/** Max name length per spec */
const MAX_NAME_LENGTH = 64;

/** Max description length per spec */
const MAX_DESCRIPTION_LENGTH = 1024;

// Ignore files honored while scanning skill directories, in lookup order.
const IGNORE_FILE_NAMES = [".gitignore", ".ignore", ".fdignore"];

// The `ignore` package does not export its matcher type; derive it.
type IgnoreMatcher = ReturnType<typeof ignore>;

/** Convert a platform-specific path to forward-slash (POSIX) form. */
function toPosixPath(p: string): string {
	return p.split(sep).join("/");
}
|
||||
|
||||
function prefixIgnorePattern(line: string, prefix: string): string | null {
|
||||
const trimmed = line.trim();
|
||||
if (!trimmed) return null;
|
||||
if (trimmed.startsWith("#") && !trimmed.startsWith("\\#")) return null;
|
||||
|
||||
let pattern = line;
|
||||
let negated = false;
|
||||
|
||||
if (pattern.startsWith("!")) {
|
||||
negated = true;
|
||||
pattern = pattern.slice(1);
|
||||
} else if (pattern.startsWith("\\!")) {
|
||||
pattern = pattern.slice(1);
|
||||
}
|
||||
|
||||
if (pattern.startsWith("/")) {
|
||||
pattern = pattern.slice(1);
|
||||
}
|
||||
|
||||
const prefixed = prefix ? `${prefix}${pattern}` : pattern;
|
||||
return negated ? `!${prefixed}` : prefixed;
|
||||
}
|
||||
|
||||
/**
 * Read any ignore files present in `dir` and add their patterns to `ig`,
 * rewritten so they apply relative to `rootDir` (the scan root).
 * Unreadable ignore files are skipped silently (best-effort).
 */
function addIgnoreRules(ig: IgnoreMatcher, dir: string, rootDir: string): void {
	const relativeDir = relative(rootDir, dir);
	// "" when dir === rootDir; otherwise a slash-terminated POSIX prefix.
	const prefix = relativeDir ? `${toPosixPath(relativeDir)}/` : "";

	for (const filename of IGNORE_FILE_NAMES) {
		const ignorePath = join(dir, filename);
		if (!existsSync(ignorePath)) continue;
		try {
			const content = readFileSync(ignorePath, "utf-8");
			const patterns = content
				.split(/\r?\n/)
				.map((line) => prefixIgnorePattern(line, prefix))
				.filter((line): line is string => Boolean(line));
			if (patterns.length > 0) {
				ig.add(patterns);
			}
		} catch {}
	}
}
|
||||
|
||||
/** Frontmatter fields recognized in a skill markdown file. */
export interface SkillFrontmatter {
	// Skill name; must match the parent directory name per spec.
	name?: string;
	// Required; shown to the model so it can decide when to load the skill.
	description?: string;
	// When true, the skill is hidden from the model (explicit /skill:name only).
	"disable-model-invocation"?: boolean;
	// Unrecognized frontmatter keys are preserved but ignored.
	[key: string]: unknown;
}

/** A successfully loaded skill. */
export interface Skill {
	name: string;
	description: string;
	// Absolute path of the skill's markdown file.
	filePath: string;
	// Directory containing the skill file; relative references resolve here.
	baseDir: string;
	// Where the skill came from, e.g. "user" | "project" | "path".
	source: string;
	disableModelInvocation: boolean;
}

/** Skills plus any validation/load diagnostics produced along the way. */
export interface LoadSkillsResult {
	skills: Skill[];
	diagnostics: ResourceDiagnostic[];
}
|
||||
|
||||
/**
|
||||
* Validate skill name per Agent Skills spec.
|
||||
* Returns array of validation error messages (empty if valid).
|
||||
*/
|
||||
function validateName(name: string, parentDirName: string): string[] {
|
||||
const errors: string[] = [];
|
||||
|
||||
if (name !== parentDirName) {
|
||||
errors.push(
|
||||
`name "${name}" does not match parent directory "${parentDirName}"`,
|
||||
);
|
||||
}
|
||||
|
||||
if (name.length > MAX_NAME_LENGTH) {
|
||||
errors.push(`name exceeds ${MAX_NAME_LENGTH} characters (${name.length})`);
|
||||
}
|
||||
|
||||
if (!/^[a-z0-9-]+$/.test(name)) {
|
||||
errors.push(
|
||||
`name contains invalid characters (must be lowercase a-z, 0-9, hyphens only)`,
|
||||
);
|
||||
}
|
||||
|
||||
if (name.startsWith("-") || name.endsWith("-")) {
|
||||
errors.push(`name must not start or end with a hyphen`);
|
||||
}
|
||||
|
||||
if (name.includes("--")) {
|
||||
errors.push(`name must not contain consecutive hyphens`);
|
||||
}
|
||||
|
||||
return errors;
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate description per Agent Skills spec.
|
||||
*/
|
||||
function validateDescription(description: string | undefined): string[] {
|
||||
const errors: string[] = [];
|
||||
|
||||
if (!description || description.trim() === "") {
|
||||
errors.push("description is required");
|
||||
} else if (description.length > MAX_DESCRIPTION_LENGTH) {
|
||||
errors.push(
|
||||
`description exceeds ${MAX_DESCRIPTION_LENGTH} characters (${description.length})`,
|
||||
);
|
||||
}
|
||||
|
||||
return errors;
|
||||
}
|
||||
|
||||
export interface LoadSkillsFromDirOptions {
	/** Directory to scan for skills */
	dir: string;
	/** Source identifier for these skills */
	source: string;
}

/**
 * Load skills from a directory.
 *
 * Discovery rules:
 * - direct .md children in the root
 * - recursive SKILL.md under subdirectories
 *
 * Thin public wrapper over loadSkillsFromDirInternal with root-file
 * discovery enabled.
 */
export function loadSkillsFromDir(
	options: LoadSkillsFromDirOptions,
): LoadSkillsResult {
	const { dir, source } = options;
	return loadSkillsFromDirInternal(dir, source, true);
}
|
||||
|
||||
/**
 * Recursive worker for skill discovery.
 *
 * @param dir Directory to scan.
 * @param source Source label copied onto each loaded skill.
 * @param includeRootFiles At the top level any *.md file is a skill; in
 *   subdirectories only SKILL.md is.
 * @param ignoreMatcher Accumulated ignore rules (shared down the recursion).
 * @param rootDir Scan root used to compute ignore-relative paths.
 *
 * Hidden entries, node_modules, broken symlinks, and ignored paths are
 * skipped; directory read errors are swallowed (best-effort scan).
 */
function loadSkillsFromDirInternal(
	dir: string,
	source: string,
	includeRootFiles: boolean,
	ignoreMatcher?: IgnoreMatcher,
	rootDir?: string,
): LoadSkillsResult {
	const skills: Skill[] = [];
	const diagnostics: ResourceDiagnostic[] = [];

	if (!existsSync(dir)) {
		return { skills, diagnostics };
	}

	const root = rootDir ?? dir;
	const ig = ignoreMatcher ?? ignore();
	// Pick up ignore files at this level before walking entries.
	addIgnoreRules(ig, dir, root);

	try {
		const entries = readdirSync(dir, { withFileTypes: true });

		for (const entry of entries) {
			if (entry.name.startsWith(".")) {
				continue;
			}

			// Skip node_modules to avoid scanning dependencies
			if (entry.name === "node_modules") {
				continue;
			}

			const fullPath = join(dir, entry.name);

			// For symlinks, check if they point to a directory and follow them
			let isDirectory = entry.isDirectory();
			let isFile = entry.isFile();
			if (entry.isSymbolicLink()) {
				try {
					// statSync follows the link to the target.
					const stats = statSync(fullPath);
					isDirectory = stats.isDirectory();
					isFile = stats.isFile();
				} catch {
					// Broken symlink, skip it
					continue;
				}
			}

			// Ignore matching uses root-relative POSIX paths; directories
			// get a trailing slash so dir-only patterns match.
			const relPath = toPosixPath(relative(root, fullPath));
			const ignorePath = isDirectory ? `${relPath}/` : relPath;
			if (ig.ignores(ignorePath)) {
				continue;
			}

			if (isDirectory) {
				const subResult = loadSkillsFromDirInternal(
					fullPath,
					source,
					false,
					ig,
					root,
				);
				skills.push(...subResult.skills);
				diagnostics.push(...subResult.diagnostics);
				continue;
			}

			if (!isFile) {
				continue;
			}

			const isRootMd = includeRootFiles && entry.name.endsWith(".md");
			const isSkillMd = !includeRootFiles && entry.name === "SKILL.md";
			if (!isRootMd && !isSkillMd) {
				continue;
			}

			const result = loadSkillFromFile(fullPath, source);
			if (result.skill) {
				skills.push(result.skill);
			}
			diagnostics.push(...result.diagnostics);
		}
	} catch {}

	return { skills, diagnostics };
}
|
||||
|
||||
/**
 * Parse and validate a single skill markdown file.
 *
 * Name/description violations are reported as warnings but do not prevent
 * loading — except a missing/blank description, which rejects the skill.
 * Read/parse failures also produce a warning and a null skill.
 */
function loadSkillFromFile(
	filePath: string,
	source: string,
): { skill: Skill | null; diagnostics: ResourceDiagnostic[] } {
	const diagnostics: ResourceDiagnostic[] = [];

	try {
		const rawContent = readFileSync(filePath, "utf-8");
		const { frontmatter } = parseFrontmatter<SkillFrontmatter>(rawContent);
		const skillDir = dirname(filePath);
		const parentDirName = basename(skillDir);

		// Validate description
		const descErrors = validateDescription(frontmatter.description);
		for (const error of descErrors) {
			diagnostics.push({ type: "warning", message: error, path: filePath });
		}

		// Use name from frontmatter, or fall back to parent directory name
		const name = frontmatter.name || parentDirName;

		// Validate name
		const nameErrors = validateName(name, parentDirName);
		for (const error of nameErrors) {
			diagnostics.push({ type: "warning", message: error, path: filePath });
		}

		// Still load the skill even with warnings (unless description is completely missing)
		if (!frontmatter.description || frontmatter.description.trim() === "") {
			return { skill: null, diagnostics };
		}

		return {
			skill: {
				name,
				description: frontmatter.description,
				filePath,
				baseDir: skillDir,
				source,
				// Anything other than literal true (including undefined) means enabled.
				disableModelInvocation:
					frontmatter["disable-model-invocation"] === true,
			},
			diagnostics,
		};
	} catch (error) {
		const message =
			error instanceof Error ? error.message : "failed to parse skill file";
		diagnostics.push({ type: "warning", message, path: filePath });
		return { skill: null, diagnostics };
	}
}
|
||||
|
||||
/**
 * Format skills for inclusion in a system prompt.
 * Uses XML format per Agent Skills standard.
 * See: https://agentskills.io/integrate-skills
 *
 * Skills with disableModelInvocation=true are excluded from the prompt
 * (they can only be invoked explicitly via /skill:name commands).
 *
 * Returns "" when no skills are visible to the model.
 */
export function formatSkillsForPrompt(skills: Skill[]): string {
	const visibleSkills = skills.filter((s) => !s.disableModelInvocation);

	if (visibleSkills.length === 0) {
		return "";
	}

	const lines = [
		"\n\nThe following skills provide specialized instructions for specific tasks.",
		"Use the read tool to load a skill's file when the task matches its description.",
		"When a skill file references a relative path, resolve it against the skill directory (parent of SKILL.md / dirname of the path) and use that absolute path in tool commands.",
		"",
		"<available_skills>",
	];

	// One <skill> element per visible skill; all user-controlled text is
	// XML-escaped before interpolation.
	for (const skill of visibleSkills) {
		lines.push(" <skill>");
		lines.push(` <name>${escapeXml(skill.name)}</name>`);
		lines.push(
			` <description>${escapeXml(skill.description)}</description>`,
		);
		lines.push(` <location>${escapeXml(skill.filePath)}</location>`);
		lines.push(" </skill>");
	}

	lines.push("</available_skills>");

	return lines.join("\n");
}
|
||||
|
||||
function escapeXml(str: string): string {
|
||||
return str
|
||||
.replace(/&/g, "&")
|
||||
.replace(/</g, "<")
|
||||
.replace(/>/g, ">")
|
||||
.replace(/"/g, """)
|
||||
.replace(/'/g, "'");
|
||||
}
|
||||
|
||||
/** Options controlling which locations loadSkills() scans. */
export interface LoadSkillsOptions {
	/** Working directory for project-local skills. Default: process.cwd() */
	cwd?: string;
	/** Agent config directory for global skills. Default: ~/.pi/agent */
	agentDir?: string;
	/** Explicit skill paths (files or directories) */
	skillPaths?: string[];
	/** Include default skills directories. Default: true */
	includeDefaults?: boolean;
}
|
||||
|
||||
function normalizePath(input: string): string {
|
||||
const trimmed = input.trim();
|
||||
if (trimmed === "~") return homedir();
|
||||
if (trimmed.startsWith("~/")) return join(homedir(), trimmed.slice(2));
|
||||
if (trimmed.startsWith("~")) return join(homedir(), trimmed.slice(1));
|
||||
return trimmed;
|
||||
}
|
||||
|
||||
function resolveSkillPath(p: string, cwd: string): string {
|
||||
const normalized = normalizePath(p);
|
||||
return isAbsolute(normalized) ? normalized : resolve(cwd, normalized);
|
||||
}
|
||||
|
||||
/**
 * Load skills from all configured locations.
 * Returns skills and any validation diagnostics.
 *
 * Merge semantics: the first skill loaded under a given name wins; later
 * same-name skills produce "collision" diagnostics. The same physical file
 * reached twice (e.g. via a symlink) is skipped silently. Default locations
 * (user agentDir/skills and project .<config>/skills) are scanned before
 * any explicit skillPaths, so they take precedence on name clashes.
 */
export function loadSkills(options: LoadSkillsOptions = {}): LoadSkillsResult {
	const {
		cwd = process.cwd(),
		agentDir,
		skillPaths = [],
		includeDefaults = true,
	} = options;

	// Resolve agentDir - if not provided, use default from config
	const resolvedAgentDir = agentDir ?? getAgentDir();

	const skillMap = new Map<string, Skill>();
	const realPathSet = new Set<string>();
	const allDiagnostics: ResourceDiagnostic[] = [];
	const collisionDiagnostics: ResourceDiagnostic[] = [];

	// Merge one load result into the accumulators, handling symlink dedupe
	// and first-wins name collisions.
	function addSkills(result: LoadSkillsResult) {
		allDiagnostics.push(...result.diagnostics);
		for (const skill of result.skills) {
			// Resolve symlinks to detect duplicate files
			let realPath: string;
			try {
				realPath = realpathSync(skill.filePath);
			} catch {
				realPath = skill.filePath;
			}

			// Skip silently if we've already loaded this exact file (via symlink)
			if (realPathSet.has(realPath)) {
				continue;
			}

			const existing = skillMap.get(skill.name);
			if (existing) {
				collisionDiagnostics.push({
					type: "collision",
					message: `name "${skill.name}" collision`,
					path: skill.filePath,
					collision: {
						resourceType: "skill",
						name: skill.name,
						winnerPath: existing.filePath,
						loserPath: skill.filePath,
					},
				});
			} else {
				skillMap.set(skill.name, skill);
				realPathSet.add(realPath);
			}
		}
	}

	if (includeDefaults) {
		addSkills(
			loadSkillsFromDirInternal(join(resolvedAgentDir, "skills"), "user", true),
		);
		addSkills(
			loadSkillsFromDirInternal(
				resolve(cwd, CONFIG_DIR_NAME, "skills"),
				"project",
				true,
			),
		);
	}

	const userSkillsDir = join(resolvedAgentDir, "skills");
	const projectSkillsDir = resolve(cwd, CONFIG_DIR_NAME, "skills");

	// True when target equals root or lies beneath it (string-prefix check
	// on resolved paths, separator-terminated to avoid partial-name matches).
	const isUnderPath = (target: string, root: string): boolean => {
		const normalizedRoot = resolve(root);
		if (target === normalizedRoot) {
			return true;
		}
		const prefix = normalizedRoot.endsWith(sep)
			? normalizedRoot
			: `${normalizedRoot}${sep}`;
		return target.startsWith(prefix);
	};

	// Label explicit paths: when defaults were skipped, paths inside the
	// default locations still get their "user"/"project" labels.
	const getSource = (resolvedPath: string): "user" | "project" | "path" => {
		if (!includeDefaults) {
			if (isUnderPath(resolvedPath, userSkillsDir)) return "user";
			if (isUnderPath(resolvedPath, projectSkillsDir)) return "project";
		}
		return "path";
	};

	for (const rawPath of skillPaths) {
		const resolvedPath = resolveSkillPath(rawPath, cwd);
		if (!existsSync(resolvedPath)) {
			allDiagnostics.push({
				type: "warning",
				message: "skill path does not exist",
				path: resolvedPath,
			});
			continue;
		}

		try {
			const stats = statSync(resolvedPath);
			const source = getSource(resolvedPath);
			if (stats.isDirectory()) {
				addSkills(loadSkillsFromDirInternal(resolvedPath, source, true));
			} else if (stats.isFile() && resolvedPath.endsWith(".md")) {
				const result = loadSkillFromFile(resolvedPath, source);
				if (result.skill) {
					addSkills({
						skills: [result.skill],
						diagnostics: result.diagnostics,
					});
				} else {
					allDiagnostics.push(...result.diagnostics);
				}
			} else {
				allDiagnostics.push({
					type: "warning",
					message: "skill path is not a markdown file",
					path: resolvedPath,
				});
			}
		} catch (error) {
			const message =
				error instanceof Error ? error.message : "failed to read skill path";
			allDiagnostics.push({ type: "warning", message, path: resolvedPath });
		}
	}

	return {
		skills: Array.from(skillMap.values()),
		// Collisions are appended after all other diagnostics.
		diagnostics: [...allDiagnostics, ...collisionDiagnostics],
	};
}
|
||||
44
packages/coding-agent/src/core/slash-commands.ts
Normal file
44
packages/coding-agent/src/core/slash-commands.ts
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
/** Where a dynamically discovered slash command comes from. */
export type SlashCommandSource = "extension" | "prompt" | "skill";

/** Which scope a discovered command was found in. */
export type SlashCommandLocation = "user" | "project" | "path";

/** Metadata describing a discovered (non-builtin) slash command. */
export interface SlashCommandInfo {
	name: string;
	description?: string;
	source: SlashCommandSource;
	location?: SlashCommandLocation;
	// File the command was loaded from, when applicable.
	path?: string;
}

/** A built-in slash command (always available, no backing file). */
export interface BuiltinSlashCommand {
	name: string;
	description: string;
}

// Canonical list of built-in commands; order here is the display order.
export const BUILTIN_SLASH_COMMANDS: ReadonlyArray<BuiltinSlashCommand> = [
	{ name: "settings", description: "Open settings menu" },
	{ name: "model", description: "Select model (opens selector UI)" },
	{
		name: "scoped-models",
		description: "Enable/disable models for Ctrl+P cycling",
	},
	{ name: "export", description: "Export session to HTML file" },
	{ name: "share", description: "Share session as a secret GitHub gist" },
	{ name: "copy", description: "Copy last agent message to clipboard" },
	{ name: "name", description: "Set session display name" },
	{ name: "session", description: "Show session info and stats" },
	{ name: "changelog", description: "Show changelog entries" },
	{ name: "hotkeys", description: "Show all keyboard shortcuts" },
	{ name: "fork", description: "Create a new fork from a previous message" },
	{ name: "tree", description: "Navigate session tree (switch branches)" },
	{ name: "login", description: "Login with OAuth provider" },
	{ name: "logout", description: "Logout from OAuth provider" },
	{ name: "new", description: "Start a new session" },
	{ name: "compact", description: "Manually compact the session context" },
	{ name: "resume", description: "Resume a different session" },
	{
		name: "reload",
		description: "Reload extensions, skills, prompts, and themes",
	},
	{ name: "quit", description: "Quit pi" },
];
|
||||
237
packages/coding-agent/src/core/system-prompt.ts
Normal file
237
packages/coding-agent/src/core/system-prompt.ts
Normal file
|
|
@ -0,0 +1,237 @@
|
|||
/**
|
||||
* System prompt construction and project context loading
|
||||
*/
|
||||
|
||||
import { getDocsPath, getReadmePath } from "../config.js";
|
||||
import { formatSkillsForPrompt, type Skill } from "./skills.js";
|
||||
|
||||
/** Tool descriptions for system prompt */
const toolDescriptions: Record<string, string> = {
	read: "Read file contents",
	bash: "Execute bash commands (ls, grep, find, etc.)",
	edit: "Make surgical edits to files (find exact text and replace)",
	write: "Create or overwrite files",
	grep: "Search file contents for patterns (respects .gitignore)",
	find: "Find files by glob pattern (respects .gitignore)",
	ls: "List directory contents",
};

/** Options for buildSystemPrompt(). */
export interface BuildSystemPromptOptions {
	/** Custom system prompt (replaces default). */
	customPrompt?: string;
	/** Tools to include in prompt. Default: [read, bash, edit, write] */
	selectedTools?: string[];
	/** Optional one-line tool snippets keyed by tool name. */
	toolSnippets?: Record<string, string>;
	/** Additional guideline bullets appended to the default system prompt guidelines. */
	promptGuidelines?: string[];
	/** Text to append to system prompt. */
	appendSystemPrompt?: string;
	/** Working directory. Default: process.cwd() */
	cwd?: string;
	/** Pre-loaded context files. */
	contextFiles?: Array<{ path: string; content: string }>;
	/** Pre-loaded skills. */
	skills?: Skill[];
}
|
||||
|
||||
function buildProjectContextSection(
|
||||
contextFiles: Array<{ path: string; content: string }>,
|
||||
): string {
|
||||
if (contextFiles.length === 0) {
|
||||
return "";
|
||||
}
|
||||
|
||||
const hasSoulFile = contextFiles.some(
|
||||
({ path }) =>
|
||||
path.replaceAll("\\", "/").endsWith("/SOUL.md") || path === "SOUL.md",
|
||||
);
|
||||
let section = "\n\n# Project Context\n\n";
|
||||
section += "Project-specific instructions and guidelines:\n";
|
||||
if (hasSoulFile) {
|
||||
section +=
|
||||
"\nIf SOUL.md is present, embody its persona and tone. Avoid generic assistant filler and follow its guidance unless higher-priority instructions override it.\n";
|
||||
}
|
||||
section += "\n";
|
||||
for (const { path: filePath, content } of contextFiles) {
|
||||
section += `## ${filePath}\n\n${content}\n\n`;
|
||||
}
|
||||
|
||||
return section;
|
||||
}
|
||||
|
||||
/**
 * Build the system prompt with tools, guidelines, and context.
 *
 * When `customPrompt` is provided it replaces the built-in prompt entirely;
 * only the append section, project context files, skills, and the date/cwd
 * footer are added to it. Otherwise the default prompt is assembled from the
 * selected tools, tool-derived guidelines, and pi documentation pointers.
 *
 * @param options - See {@link BuildSystemPromptOptions}; all fields optional.
 * @returns The complete system prompt string.
 */
export function buildSystemPrompt(
	options: BuildSystemPromptOptions = {},
): string {
	const {
		customPrompt,
		selectedTools,
		toolSnippets,
		promptGuidelines,
		appendSystemPrompt,
		cwd,
		contextFiles: providedContextFiles,
		skills: providedSkills,
	} = options;
	const resolvedCwd = cwd ?? process.cwd();

	// Human-readable timestamp embedded in the prompt footer.
	const now = new Date();
	const dateTime = now.toLocaleString("en-US", {
		weekday: "long",
		year: "numeric",
		month: "long",
		day: "numeric",
		hour: "2-digit",
		minute: "2-digit",
		second: "2-digit",
		timeZoneName: "short",
	});

	const appendSection = appendSystemPrompt ? `\n\n${appendSystemPrompt}` : "";

	const contextFiles = providedContextFiles ?? [];
	const skills = providedSkills ?? [];

	// Custom prompt short-circuit: skip the default tools/guidelines sections.
	if (customPrompt) {
		let prompt = customPrompt;

		if (appendSection) {
			prompt += appendSection;
		}

		// Append project context files
		prompt += buildProjectContextSection(contextFiles);

		// Append skills section (only if read tool is available)
		const customPromptHasRead =
			!selectedTools || selectedTools.includes("read");
		if (customPromptHasRead && skills.length > 0) {
			prompt += formatSkillsForPrompt(skills);
		}

		// Add date/time and working directory last
		prompt += `\nCurrent date and time: ${dateTime}`;
		prompt += `\nCurrent working directory: ${resolvedCwd}`;

		return prompt;
	}

	// Get absolute paths to documentation
	const readmePath = getReadmePath();
	const docsPath = getDocsPath();

	// Build tools list based on selected tools.
	// Built-ins use toolDescriptions. Custom tools can provide one-line snippets.
	const tools = selectedTools || ["read", "bash", "edit", "write"];
	const toolsList =
		tools.length > 0
			? tools
					.map((name) => {
						// Fall back to the bare tool name when no snippet/description exists.
						const snippet =
							toolSnippets?.[name] ?? toolDescriptions[name] ?? name;
						return `- ${name}: ${snippet}`;
					})
					.join("\n")
			: "(none)";

	// Build guidelines based on which tools are actually available.
	// Set + array pair dedupes while preserving insertion order.
	const guidelinesList: string[] = [];
	const guidelinesSet = new Set<string>();
	const addGuideline = (guideline: string): void => {
		if (guidelinesSet.has(guideline)) {
			return;
		}
		guidelinesSet.add(guideline);
		guidelinesList.push(guideline);
	};

	const hasBash = tools.includes("bash");
	const hasEdit = tools.includes("edit");
	const hasWrite = tools.includes("write");
	const hasGrep = tools.includes("grep");
	const hasFind = tools.includes("find");
	const hasLs = tools.includes("ls");
	const hasRead = tools.includes("read");

	// File exploration guidelines
	if (hasBash && !hasGrep && !hasFind && !hasLs) {
		addGuideline("Use bash for file operations like ls, rg, find");
	} else if (hasBash && (hasGrep || hasFind || hasLs)) {
		addGuideline(
			"Prefer grep/find/ls tools over bash for file exploration (faster, respects .gitignore)",
		);
	}

	// Read before edit guideline
	if (hasRead && hasEdit) {
		addGuideline(
			"Use read to examine files before editing. You must use this tool instead of cat or sed.",
		);
	}

	// Edit guideline
	if (hasEdit) {
		addGuideline("Use edit for precise changes (old text must match exactly)");
	}

	// Write guideline
	if (hasWrite) {
		addGuideline("Use write only for new files or complete rewrites");
	}

	// Output guideline (only when actually writing or executing)
	if (hasEdit || hasWrite) {
		addGuideline(
			"When summarizing your actions, output plain text directly - do NOT use cat or bash to display what you did",
		);
	}

	// Caller-supplied extra guidelines (blank entries dropped, duplicates deduped).
	for (const guideline of promptGuidelines ?? []) {
		const normalized = guideline.trim();
		if (normalized.length > 0) {
			addGuideline(normalized);
		}
	}

	// Always include these
	addGuideline("Be concise in your responses");
	addGuideline("Show file paths clearly when working with files");

	const guidelines = guidelinesList.map((g) => `- ${g}`).join("\n");

	let prompt = `You are an expert coding assistant operating inside pi, a coding agent harness. You help users by reading files, executing commands, editing code, and writing new files.

Available tools:
${toolsList}

In addition to the tools above, you may have access to other custom tools depending on the project.

Guidelines:
${guidelines}

Pi documentation (read only when the user asks about pi itself, its SDK, extensions, themes, skills, or TUI):
- Main documentation: ${readmePath}
- Additional docs: ${docsPath}
- When asked about: extensions (docs/extensions.md), themes (docs/themes.md), skills (docs/skills.md), prompt templates (docs/prompt-templates.md), TUI components (docs/tui.md), keybindings (docs/keybindings.md), SDK integrations (docs/sdk.md), custom providers (docs/custom-provider.md), adding models (docs/models.md), pi packages (docs/packages.md)
- When working on pi topics, read the docs and follow .md cross-references before implementing
- Always read pi .md files completely and follow links to related docs (e.g., tui.md for TUI API details)`;

	if (appendSection) {
		prompt += appendSection;
	}

	// Append project context files
	prompt += buildProjectContextSection(contextFiles);

	// Append skills section (only if read tool is available)
	if (hasRead && skills.length > 0) {
		prompt += formatSkillsForPrompt(skills);
	}

	// Add date/time and working directory last
	prompt += `\nCurrent date and time: ${dateTime}`;
	prompt += `\nCurrent working directory: ${resolvedCwd}`;

	return prompt;
}
|
||||
25
packages/coding-agent/src/core/timings.ts
Normal file
25
packages/coding-agent/src/core/timings.ts
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
/**
|
||||
* Central timing instrumentation for startup profiling.
|
||||
* Enable with PI_TIMING=1 environment variable.
|
||||
*/
|
||||
|
||||
const ENABLED = process.env.PI_TIMING === "1";
|
||||
const timings: Array<{ label: string; ms: number }> = [];
|
||||
let lastTime = Date.now();
|
||||
|
||||
export function time(label: string): void {
|
||||
if (!ENABLED) return;
|
||||
const now = Date.now();
|
||||
timings.push({ label, ms: now - lastTime });
|
||||
lastTime = now;
|
||||
}
|
||||
|
||||
export function printTimings(): void {
|
||||
if (!ENABLED || timings.length === 0) return;
|
||||
console.error("\n--- Startup Timings ---");
|
||||
for (const t of timings) {
|
||||
console.error(` ${t.label}: ${t.ms}ms`);
|
||||
}
|
||||
console.error(` TOTAL: ${timings.reduce((a, b) => a + b.ms, 0)}ms`);
|
||||
console.error("------------------------\n");
|
||||
}
|
||||
358
packages/coding-agent/src/core/tools/bash.ts
Normal file
358
packages/coding-agent/src/core/tools/bash.ts
Normal file
|
|
@ -0,0 +1,358 @@
|
|||
import { randomBytes } from "node:crypto";
|
||||
import { createWriteStream, existsSync } from "node:fs";
|
||||
import { tmpdir } from "node:os";
|
||||
import { join } from "node:path";
|
||||
import type { AgentTool } from "@mariozechner/pi-agent-core";
|
||||
import { type Static, Type } from "@sinclair/typebox";
|
||||
import { spawn } from "child_process";
|
||||
import {
|
||||
getShellConfig,
|
||||
getShellEnv,
|
||||
killProcessTree,
|
||||
} from "../../utils/shell.js";
|
||||
import {
|
||||
DEFAULT_MAX_BYTES,
|
||||
DEFAULT_MAX_LINES,
|
||||
formatSize,
|
||||
type TruncationResult,
|
||||
truncateTail,
|
||||
} from "./truncate.js";
|
||||
|
||||
/**
|
||||
* Generate a unique temp file path for bash output
|
||||
*/
|
||||
function getTempFilePath(): string {
|
||||
const id = randomBytes(8).toString("hex");
|
||||
return join(tmpdir(), `pi-bash-${id}.log`);
|
||||
}
|
||||
|
||||
// TypeBox schema describing the bash tool's input parameters.
const bashSchema = Type.Object({
	command: Type.String({ description: "Bash command to execute" }),
	timeout: Type.Optional(
		Type.Number({
			description: "Timeout in seconds (optional, no default timeout)",
		}),
	),
});

/** Input payload for the bash tool, derived from the schema. */
export type BashToolInput = Static<typeof bashSchema>;

/** Extra metadata attached to bash tool results for renderers/callers. */
export interface BashToolDetails {
	/** Present only when the output exceeded the line/byte limits. */
	truncation?: TruncationResult;
	/** Temp file holding the complete, untruncated output, if one was written. */
	fullOutputPath?: string;
}

/**
 * Pluggable operations for the bash tool.
 * Override these to delegate command execution to remote systems (e.g., SSH).
 */
export interface BashOperations {
	/**
	 * Execute a command and stream output.
	 * @param command - The command to execute
	 * @param cwd - Working directory
	 * @param options - Execution options
	 * @returns Promise resolving to exit code (null if killed)
	 */
	exec: (
		command: string,
		cwd: string,
		options: {
			onData: (data: Buffer) => void;
			signal?: AbortSignal;
			timeout?: number;
			env?: NodeJS.ProcessEnv;
		},
	) => Promise<{ exitCode: number | null }>;
}
|
||||
|
||||
/**
 * Default bash operations using local shell.
 *
 * Spawns the configured shell, streams combined stdout/stderr through onData,
 * and rejects with sentinel error messages ("aborted", "timeout:<secs>") that
 * the bash tool's catch handler parses to produce user-facing messages.
 */
const defaultBashOperations: BashOperations = {
	exec: (command, cwd, { onData, signal, timeout, env }) => {
		return new Promise((resolve, reject) => {
			const { shell, args } = getShellConfig();

			// Fail fast with a clear error instead of a cryptic spawn failure.
			if (!existsSync(cwd)) {
				reject(
					new Error(
						`Working directory does not exist: ${cwd}\nCannot execute bash commands.`,
					),
				);
				return;
			}

			// detached: child gets its own process group so killProcessTree can
			// clean up the whole tree on timeout/abort.
			const child = spawn(shell, [...args, command], {
				cwd,
				detached: true,
				env: env ?? getShellEnv(),
				stdio: ["ignore", "pipe", "pipe"],
			});

			let timedOut = false;

			// Set timeout if provided
			let timeoutHandle: NodeJS.Timeout | undefined;
			if (timeout !== undefined && timeout > 0) {
				timeoutHandle = setTimeout(() => {
					timedOut = true;
					if (child.pid) {
						killProcessTree(child.pid);
					}
				}, timeout * 1000);
			}

			// Stream stdout and stderr
			if (child.stdout) {
				child.stdout.on("data", onData);
			}
			if (child.stderr) {
				child.stderr.on("data", onData);
			}

			// Handle shell spawn errors
			child.on("error", (err) => {
				if (timeoutHandle) clearTimeout(timeoutHandle);
				if (signal) signal.removeEventListener("abort", onAbort);
				reject(err);
			});

			// Handle abort signal - kill entire process tree
			const onAbort = () => {
				if (child.pid) {
					killProcessTree(child.pid);
				}
			};

			if (signal) {
				if (signal.aborted) {
					onAbort();
				} else {
					signal.addEventListener("abort", onAbort, { once: true });
				}
			}

			// Handle process exit. Abort takes precedence over timeout; both are
			// reported via sentinel reject messages parsed by the caller.
			child.on("close", (code) => {
				if (timeoutHandle) clearTimeout(timeoutHandle);
				if (signal) signal.removeEventListener("abort", onAbort);

				if (signal?.aborted) {
					reject(new Error("aborted"));
					return;
				}

				if (timedOut) {
					reject(new Error(`timeout:${timeout}`));
					return;
				}

				resolve({ exitCode: code });
			});
		});
	},
};
|
||||
|
||||
export interface BashSpawnContext {
|
||||
command: string;
|
||||
cwd: string;
|
||||
env: NodeJS.ProcessEnv;
|
||||
}
|
||||
|
||||
export type BashSpawnHook = (context: BashSpawnContext) => BashSpawnContext;
|
||||
|
||||
function resolveSpawnContext(
|
||||
command: string,
|
||||
cwd: string,
|
||||
spawnHook?: BashSpawnHook,
|
||||
): BashSpawnContext {
|
||||
const baseContext: BashSpawnContext = {
|
||||
command,
|
||||
cwd,
|
||||
env: { ...getShellEnv() },
|
||||
};
|
||||
|
||||
return spawnHook ? spawnHook(baseContext) : baseContext;
|
||||
}
|
||||
|
||||
/** Configuration accepted by {@link createBashTool}. */
export interface BashToolOptions {
	/** Custom operations for command execution. Default: local shell */
	operations?: BashOperations;
	/** Command prefix prepended to every command (e.g., "shopt -s expand_aliases" for alias support) */
	commandPrefix?: string;
	/** Hook to adjust command, cwd, or env before execution */
	spawnHook?: BashSpawnHook;
}
|
||||
|
||||
/**
 * Create the `bash` tool bound to a working directory.
 *
 * Runs commands via the configured {@link BashOperations} (local shell by
 * default), streams combined stdout/stderr, and tail-truncates large output.
 * Once total output exceeds DEFAULT_MAX_BYTES, the complete stream is also
 * written to a temp file whose path is reported in the result details.
 *
 * @param cwd - Working directory commands run in
 * @param options - Optional operations override, command prefix, and spawn hook
 */
export function createBashTool(
	cwd: string,
	options?: BashToolOptions,
): AgentTool<typeof bashSchema> {
	const ops = options?.operations ?? defaultBashOperations;
	const commandPrefix = options?.commandPrefix;
	const spawnHook = options?.spawnHook;

	return {
		name: "bash",
		label: "bash",
		description: `Execute a bash command in the current working directory. Returns stdout and stderr. Output is truncated to last ${DEFAULT_MAX_LINES} lines or ${DEFAULT_MAX_BYTES / 1024}KB (whichever is hit first). If truncated, full output is saved to a temp file. Optionally provide a timeout in seconds.`,
		parameters: bashSchema,
		execute: async (
			_toolCallId: string,
			{ command, timeout }: { command: string; timeout?: number },
			signal?: AbortSignal,
			onUpdate?,
		) => {
			// Apply command prefix if configured (e.g., "shopt -s expand_aliases" for alias support)
			const resolvedCommand = commandPrefix
				? `${commandPrefix}\n${command}`
				: command;
			const spawnContext = resolveSpawnContext(resolvedCommand, cwd, spawnHook);

			return new Promise((resolve, reject) => {
				// We'll stream to a temp file if output gets large
				let tempFilePath: string | undefined;
				let tempFileStream: ReturnType<typeof createWriteStream> | undefined;
				let totalBytes = 0;

				// Keep a rolling buffer of the last chunk for tail truncation
				const chunks: Buffer[] = [];
				let chunksBytes = 0;
				// Keep more than we need so we have enough for truncation
				const maxChunksBytes = DEFAULT_MAX_BYTES * 2;

				// Invoked for every stdout/stderr chunk: mirrors output into the
				// temp file (once past the threshold), maintains the rolling
				// buffer, and emits truncated partial output via onUpdate.
				const handleData = (data: Buffer) => {
					totalBytes += data.length;

					// Start writing to temp file once we exceed the threshold
					if (totalBytes > DEFAULT_MAX_BYTES && !tempFilePath) {
						tempFilePath = getTempFilePath();
						tempFileStream = createWriteStream(tempFilePath);
						// Write all buffered chunks to the file
						for (const chunk of chunks) {
							tempFileStream.write(chunk);
						}
					}

					// Write to temp file if we have one
					if (tempFileStream) {
						tempFileStream.write(data);
					}

					// Keep rolling buffer of recent data
					chunks.push(data);
					chunksBytes += data.length;

					// Trim old chunks if buffer is too large
					while (chunksBytes > maxChunksBytes && chunks.length > 1) {
						const removed = chunks.shift()!;
						chunksBytes -= removed.length;
					}

					// Stream partial output to callback (truncated rolling buffer)
					if (onUpdate) {
						const fullBuffer = Buffer.concat(chunks);
						const fullText = fullBuffer.toString("utf-8");
						const truncation = truncateTail(fullText);
						onUpdate({
							content: [{ type: "text", text: truncation.content || "" }],
							details: {
								truncation: truncation.truncated ? truncation : undefined,
								fullOutputPath: tempFilePath,
							},
						});
					}
				};

				ops
					.exec(spawnContext.command, spawnContext.cwd, {
						onData: handleData,
						signal,
						timeout,
						env: spawnContext.env,
					})
					.then(({ exitCode }) => {
						// Close temp file stream
						if (tempFileStream) {
							tempFileStream.end();
						}

						// Combine all buffered chunks
						const fullBuffer = Buffer.concat(chunks);
						const fullOutput = fullBuffer.toString("utf-8");

						// Apply tail truncation
						const truncation = truncateTail(fullOutput);
						let outputText = truncation.content || "(no output)";

						// Build details with truncation info
						let details: BashToolDetails | undefined;

						if (truncation.truncated) {
							details = {
								truncation,
								fullOutputPath: tempFilePath,
							};

							// Build actionable notice telling the model which lines
							// are shown and where the full output lives.
							const startLine =
								truncation.totalLines - truncation.outputLines + 1;
							const endLine = truncation.totalLines;

							if (truncation.lastLinePartial) {
								// Edge case: last line alone > 30KB
								const lastLineSize = formatSize(
									Buffer.byteLength(
										fullOutput.split("\n").pop() || "",
										"utf-8",
									),
								);
								outputText += `\n\n[Showing last ${formatSize(truncation.outputBytes)} of line ${endLine} (line is ${lastLineSize}). Full output: ${tempFilePath}]`;
							} else if (truncation.truncatedBy === "lines") {
								outputText += `\n\n[Showing lines ${startLine}-${endLine} of ${truncation.totalLines}. Full output: ${tempFilePath}]`;
							} else {
								outputText += `\n\n[Showing lines ${startLine}-${endLine} of ${truncation.totalLines} (${formatSize(DEFAULT_MAX_BYTES)} limit). Full output: ${tempFilePath}]`;
							}
						}

						// Non-zero exit is surfaced as a rejection (with the output
						// embedded) so the agent sees the failure.
						if (exitCode !== 0 && exitCode !== null) {
							outputText += `\n\nCommand exited with code ${exitCode}`;
							reject(new Error(outputText));
						} else {
							resolve({
								content: [{ type: "text", text: outputText }],
								details,
							});
						}
					})
					.catch((err: Error) => {
						// Close temp file stream
						if (tempFileStream) {
							tempFileStream.end();
						}

						// Combine all buffered chunks for error output
						const fullBuffer = Buffer.concat(chunks);
						let output = fullBuffer.toString("utf-8");

						// "aborted" / "timeout:<secs>" are sentinel messages produced
						// by the exec implementation (see defaultBashOperations).
						if (err.message === "aborted") {
							if (output) output += "\n\n";
							output += "Command aborted";
							reject(new Error(output));
						} else if (err.message.startsWith("timeout:")) {
							const timeoutSecs = err.message.split(":")[1];
							if (output) output += "\n\n";
							output += `Command timed out after ${timeoutSecs} seconds`;
							reject(new Error(output));
						} else {
							reject(err);
						}
					});
			});
		},
	};
}
|
||||
|
||||
/**
 * Default bash tool using process.cwd() - for backwards compatibility.
 * Note: captures the cwd at module load time; prefer createBashTool(cwd)
 * when the working directory should be explicit.
 */
export const bashTool = createBashTool(process.cwd());
|
||||
317
packages/coding-agent/src/core/tools/edit-diff.ts
Normal file
317
packages/coding-agent/src/core/tools/edit-diff.ts
Normal file
|
|
@ -0,0 +1,317 @@
|
|||
/**
|
||||
* Shared diff computation utilities for the edit tool.
|
||||
* Used by both edit.ts (for execution) and tool-execution.ts (for preview rendering).
|
||||
*/
|
||||
|
||||
import * as Diff from "diff";
|
||||
import { constants } from "fs";
|
||||
import { access, readFile } from "fs/promises";
|
||||
import { resolveToCwd } from "./path-utils.js";
|
||||
|
||||
export function detectLineEnding(content: string): "\r\n" | "\n" {
|
||||
const crlfIdx = content.indexOf("\r\n");
|
||||
const lfIdx = content.indexOf("\n");
|
||||
if (lfIdx === -1) return "\n";
|
||||
if (crlfIdx === -1) return "\n";
|
||||
return crlfIdx < lfIdx ? "\r\n" : "\n";
|
||||
}
|
||||
|
||||
export function normalizeToLF(text: string): string {
|
||||
return text.replace(/\r\n/g, "\n").replace(/\r/g, "\n");
|
||||
}
|
||||
|
||||
export function restoreLineEndings(
|
||||
text: string,
|
||||
ending: "\r\n" | "\n",
|
||||
): string {
|
||||
return ending === "\r\n" ? text.replace(/\n/g, "\r\n") : text;
|
||||
}
|
||||
|
||||
/**
|
||||
* Normalize text for fuzzy matching. Applies progressive transformations:
|
||||
* - Strip trailing whitespace from each line
|
||||
* - Normalize smart quotes to ASCII equivalents
|
||||
* - Normalize Unicode dashes/hyphens to ASCII hyphen
|
||||
* - Normalize special Unicode spaces to regular space
|
||||
*/
|
||||
export function normalizeForFuzzyMatch(text: string): string {
|
||||
return (
|
||||
text
|
||||
// Strip trailing whitespace per line
|
||||
.split("\n")
|
||||
.map((line) => line.trimEnd())
|
||||
.join("\n")
|
||||
// Smart single quotes → '
|
||||
.replace(/[\u2018\u2019\u201A\u201B]/g, "'")
|
||||
// Smart double quotes → "
|
||||
.replace(/[\u201C\u201D\u201E\u201F]/g, '"')
|
||||
// Various dashes/hyphens → -
|
||||
// U+2010 hyphen, U+2011 non-breaking hyphen, U+2012 figure dash,
|
||||
// U+2013 en-dash, U+2014 em-dash, U+2015 horizontal bar, U+2212 minus
|
||||
.replace(/[\u2010\u2011\u2012\u2013\u2014\u2015\u2212]/g, "-")
|
||||
// Special spaces → regular space
|
||||
// U+00A0 NBSP, U+2002-U+200A various spaces, U+202F narrow NBSP,
|
||||
// U+205F medium math space, U+3000 ideographic space
|
||||
.replace(/[\u00A0\u2002-\u200A\u202F\u205F\u3000]/g, " ")
|
||||
);
|
||||
}
|
||||
|
||||
/** Result of locating oldText within file content (exact or fuzzy). */
export interface FuzzyMatchResult {
	/** Whether a match was found */
	found: boolean;
	/** The index where the match starts (in the content that should be used for replacement) */
	index: number;
	/** Length of the matched text */
	matchLength: number;
	/** Whether fuzzy matching was used (false = exact match) */
	usedFuzzyMatch: boolean;
	/**
	 * The content to use for replacement operations.
	 * When exact match: original content. When fuzzy match: normalized content.
	 * Callers must splice replacements into THIS string, not the original.
	 */
	contentForReplacement: string;
}
|
||||
|
||||
/**
|
||||
* Find oldText in content, trying exact match first, then fuzzy match.
|
||||
* When fuzzy matching is used, the returned contentForReplacement is the
|
||||
* fuzzy-normalized version of the content (trailing whitespace stripped,
|
||||
* Unicode quotes/dashes normalized to ASCII).
|
||||
*/
|
||||
export function fuzzyFindText(
|
||||
content: string,
|
||||
oldText: string,
|
||||
): FuzzyMatchResult {
|
||||
// Try exact match first
|
||||
const exactIndex = content.indexOf(oldText);
|
||||
if (exactIndex !== -1) {
|
||||
return {
|
||||
found: true,
|
||||
index: exactIndex,
|
||||
matchLength: oldText.length,
|
||||
usedFuzzyMatch: false,
|
||||
contentForReplacement: content,
|
||||
};
|
||||
}
|
||||
|
||||
// Try fuzzy match - work entirely in normalized space
|
||||
const fuzzyContent = normalizeForFuzzyMatch(content);
|
||||
const fuzzyOldText = normalizeForFuzzyMatch(oldText);
|
||||
const fuzzyIndex = fuzzyContent.indexOf(fuzzyOldText);
|
||||
|
||||
if (fuzzyIndex === -1) {
|
||||
return {
|
||||
found: false,
|
||||
index: -1,
|
||||
matchLength: 0,
|
||||
usedFuzzyMatch: false,
|
||||
contentForReplacement: content,
|
||||
};
|
||||
}
|
||||
|
||||
// When fuzzy matching, we work in the normalized space for replacement.
|
||||
// This means the output will have normalized whitespace/quotes/dashes,
|
||||
// which is acceptable since we're fixing minor formatting differences anyway.
|
||||
return {
|
||||
found: true,
|
||||
index: fuzzyIndex,
|
||||
matchLength: fuzzyOldText.length,
|
||||
usedFuzzyMatch: true,
|
||||
contentForReplacement: fuzzyContent,
|
||||
};
|
||||
}
|
||||
|
||||
/** Strip UTF-8 BOM if present, return both the BOM (if any) and the text without it */
|
||||
export function stripBom(content: string): { bom: string; text: string } {
|
||||
return content.startsWith("\uFEFF")
|
||||
? { bom: "\uFEFF", text: content.slice(1) }
|
||||
: { bom: "", text: content };
|
||||
}
|
||||
|
||||
/**
 * Generate a unified diff string with line numbers and context.
 * Returns both the diff string and the first changed line number (in the new file).
 *
 * Lines are prefixed "+", "-", or " " followed by a right-aligned line number
 * (new-file numbers for additions, old-file numbers for removals/context).
 * Unchanged regions are collapsed to at most `contextLines` lines around each
 * change, with "..." markers where lines were skipped.
 */
export function generateDiffString(
	oldContent: string,
	newContent: string,
	contextLines = 4,
): { diff: string; firstChangedLine: number | undefined } {
	const parts = Diff.diffLines(oldContent, newContent);
	const output: string[] = [];

	// Width of the widest possible line number, for alignment.
	const oldLines = oldContent.split("\n");
	const newLines = newContent.split("\n");
	const maxLineNum = Math.max(oldLines.length, newLines.length);
	const lineNumWidth = String(maxLineNum).length;

	let oldLineNum = 1;
	let newLineNum = 1;
	let lastWasChange = false;
	let firstChangedLine: number | undefined;

	for (let i = 0; i < parts.length; i++) {
		const part = parts[i];
		// Each part's value ends with "\n"; drop the resulting empty tail entry.
		const raw = part.value.split("\n");
		if (raw[raw.length - 1] === "") {
			raw.pop();
		}

		if (part.added || part.removed) {
			// Capture the first changed line (in the new file)
			if (firstChangedLine === undefined) {
				firstChangedLine = newLineNum;
			}

			// Show the change
			for (const line of raw) {
				if (part.added) {
					const lineNum = String(newLineNum).padStart(lineNumWidth, " ");
					output.push(`+${lineNum} ${line}`);
					newLineNum++;
				} else {
					// removed
					const lineNum = String(oldLineNum).padStart(lineNumWidth, " ");
					output.push(`-${lineNum} ${line}`);
					oldLineNum++;
				}
			}
			lastWasChange = true;
		} else {
			// Context lines - only show a few before/after changes
			const nextPartIsChange =
				i < parts.length - 1 && (parts[i + 1].added || parts[i + 1].removed);

			if (lastWasChange || nextPartIsChange) {
				// Show context
				let linesToShow = raw;
				let skipStart = 0;
				let skipEnd = 0;

				if (!lastWasChange) {
					// Show only last N lines as leading context
					skipStart = Math.max(0, raw.length - contextLines);
					linesToShow = raw.slice(skipStart);
				}

				if (!nextPartIsChange && linesToShow.length > contextLines) {
					// Show only first N lines as trailing context
					skipEnd = linesToShow.length - contextLines;
					linesToShow = linesToShow.slice(0, contextLines);
				}

				// Add ellipsis if we skipped lines at start
				if (skipStart > 0) {
					output.push(` ${"".padStart(lineNumWidth, " ")} ...`);
					// Update line numbers for the skipped leading context
					oldLineNum += skipStart;
					newLineNum += skipStart;
				}

				// Context advances both old and new line counters.
				for (const line of linesToShow) {
					const lineNum = String(oldLineNum).padStart(lineNumWidth, " ");
					output.push(` ${lineNum} ${line}`);
					oldLineNum++;
					newLineNum++;
				}

				// Add ellipsis if we skipped lines at end
				if (skipEnd > 0) {
					output.push(` ${"".padStart(lineNumWidth, " ")} ...`);
					// Update line numbers for the skipped trailing context
					oldLineNum += skipEnd;
					newLineNum += skipEnd;
				}
			} else {
				// Skip these context lines entirely
				oldLineNum += raw.length;
				newLineNum += raw.length;
			}

			lastWasChange = false;
		}
	}

	return { diff: output.join("\n"), firstChangedLine };
}
|
||||
|
||||
/** Successful diff computation: the rendered diff plus a navigation hint. */
export interface EditDiffResult {
	/** Unified diff string as produced by generateDiffString. */
	diff: string;
	/** First changed line in the NEW file, or undefined when nothing changed. */
	firstChangedLine: number | undefined;
}

/** Failed diff computation: human-readable reason (also suitable for the model). */
export interface EditDiffError {
	error: string;
}
|
||||
|
||||
/**
 * Compute the diff for an edit operation without applying it.
 * Used for preview rendering in the TUI before the tool executes.
 *
 * Mirrors the edit tool's matching pipeline: strips a leading BOM, normalizes
 * line endings to LF, finds oldText via exact-then-fuzzy matching, and rejects
 * ambiguous (multi-occurrence) or no-op replacements.
 *
 * @param path - File to edit, relative to cwd or absolute
 * @param oldText - Text to find (must be unique within the file)
 * @param newText - Replacement text
 * @param cwd - Base directory used to resolve relative paths
 * @returns The rendered diff on success, or an `{ error }` object; never throws.
 */
export async function computeEditDiff(
	path: string,
	oldText: string,
	newText: string,
	cwd: string,
): Promise<EditDiffResult | EditDiffError> {
	const absolutePath = resolveToCwd(path, cwd);

	try {
		// Check if file exists and is readable
		try {
			await access(absolutePath, constants.R_OK);
		} catch {
			return { error: `File not found: ${path}` };
		}

		// Read the file
		const rawContent = await readFile(absolutePath, "utf-8");

		// Strip BOM before matching (LLM won't include invisible BOM in oldText)
		const { text: content } = stripBom(rawContent);

		const normalizedContent = normalizeToLF(content);
		const normalizedOldText = normalizeToLF(oldText);
		const normalizedNewText = normalizeToLF(newText);

		// Find the old text using fuzzy matching (tries exact match first, then fuzzy)
		const matchResult = fuzzyFindText(normalizedContent, normalizedOldText);

		if (!matchResult.found) {
			return {
				error: `Could not find the exact text in ${path}. The old text must match exactly including all whitespace and newlines.`,
			};
		}

		// Count occurrences using fuzzy-normalized content for consistency
		const fuzzyContent = normalizeForFuzzyMatch(normalizedContent);
		const fuzzyOldText = normalizeForFuzzyMatch(normalizedOldText);
		const occurrences = fuzzyContent.split(fuzzyOldText).length - 1;

		if (occurrences > 1) {
			return {
				error: `Found ${occurrences} occurrences of the text in ${path}. The text must be unique. Please provide more context to make it unique.`,
			};
		}

		// Compute the new content using the matched position
		// When fuzzy matching was used, contentForReplacement is the normalized version
		const baseContent = matchResult.contentForReplacement;
		const newContent =
			baseContent.substring(0, matchResult.index) +
			normalizedNewText +
			baseContent.substring(matchResult.index + matchResult.matchLength);

		// Check if it would actually change anything
		if (baseContent === newContent) {
			return {
				error: `No changes would be made to ${path}. The replacement produces identical content.`,
			};
		}

		// Generate the diff
		return generateDiffString(baseContent, newContent);
	} catch (err) {
		return { error: err instanceof Error ? err.message : String(err) };
	}
}
|
||||
253
packages/coding-agent/src/core/tools/edit.ts
Normal file
253
packages/coding-agent/src/core/tools/edit.ts
Normal file
|
|
@ -0,0 +1,253 @@
|
|||
import type { AgentTool } from "@mariozechner/pi-agent-core";
|
||||
import { type Static, Type } from "@sinclair/typebox";
|
||||
import { constants } from "fs";
|
||||
import {
|
||||
access as fsAccess,
|
||||
readFile as fsReadFile,
|
||||
writeFile as fsWriteFile,
|
||||
} from "fs/promises";
|
||||
import {
|
||||
detectLineEnding,
|
||||
fuzzyFindText,
|
||||
generateDiffString,
|
||||
normalizeForFuzzyMatch,
|
||||
normalizeToLF,
|
||||
restoreLineEndings,
|
||||
stripBom,
|
||||
} from "./edit-diff.js";
|
||||
import { resolveToCwd } from "./path-utils.js";
|
||||
|
||||
// TypeBox schema describing the edit tool's input parameters.
const editSchema = Type.Object({
	path: Type.String({
		description: "Path to the file to edit (relative or absolute)",
	}),
	oldText: Type.String({
		description: "Exact text to find and replace (must match exactly)",
	}),
	newText: Type.String({
		description: "New text to replace the old text with",
	}),
});

/** Input payload for the edit tool, derived from the schema. */
export type EditToolInput = Static<typeof editSchema>;

/** Extra metadata attached to edit tool results for renderers/callers. */
export interface EditToolDetails {
	/** Unified diff of the changes made */
	diff: string;
	/** Line number of the first change in the new file (for editor navigation) */
	firstChangedLine?: number;
}
|
||||
|
||||
/**
 * Pluggable operations for the edit tool.
 * Override these to delegate file editing to remote systems (e.g., SSH).
 */
export interface EditOperations {
	/** Read file contents as a Buffer */
	readFile: (absolutePath: string) => Promise<Buffer>;
	/** Write content to a file */
	writeFile: (absolutePath: string, content: string) => Promise<void>;
	/** Check if file is readable and writable (throw if not) */
	access: (absolutePath: string) => Promise<void>;
}

/** Local-filesystem implementation of {@link EditOperations}. */
const defaultEditOperations: EditOperations = {
	readFile: (path) => fsReadFile(path),
	writeFile: (path, content) => fsWriteFile(path, content, "utf-8"),
	access: (path) => fsAccess(path, constants.R_OK | constants.W_OK),
};

/** Configuration accepted by createEditTool. */
export interface EditToolOptions {
	/** Custom operations for file editing. Default: local filesystem */
	operations?: EditOperations;
}
|
||||
|
||||
export function createEditTool(
|
||||
cwd: string,
|
||||
options?: EditToolOptions,
|
||||
): AgentTool<typeof editSchema> {
|
||||
const ops = options?.operations ?? defaultEditOperations;
|
||||
|
||||
return {
|
||||
name: "edit",
|
||||
label: "edit",
|
||||
description:
|
||||
"Edit a file by replacing exact text. The oldText must match exactly (including whitespace). Use this for precise, surgical edits.",
|
||||
parameters: editSchema,
|
||||
execute: async (
|
||||
_toolCallId: string,
|
||||
{
|
||||
path,
|
||||
oldText,
|
||||
newText,
|
||||
}: { path: string; oldText: string; newText: string },
|
||||
signal?: AbortSignal,
|
||||
) => {
|
||||
const absolutePath = resolveToCwd(path, cwd);
|
||||
|
||||
return new Promise<{
|
||||
content: Array<{ type: "text"; text: string }>;
|
||||
details: EditToolDetails | undefined;
|
||||
}>((resolve, reject) => {
|
||||
// Check if already aborted
|
||||
if (signal?.aborted) {
|
||||
reject(new Error("Operation aborted"));
|
||||
return;
|
||||
}
|
||||
|
||||
let aborted = false;
|
||||
|
||||
// Set up abort handler
|
||||
const onAbort = () => {
|
||||
aborted = true;
|
||||
reject(new Error("Operation aborted"));
|
||||
};
|
||||
|
||||
if (signal) {
|
||||
signal.addEventListener("abort", onAbort, { once: true });
|
||||
}
|
||||
|
||||
// Perform the edit operation
|
||||
(async () => {
|
||||
try {
|
||||
// Check if file exists
|
||||
try {
|
||||
await ops.access(absolutePath);
|
||||
} catch {
|
||||
if (signal) {
|
||||
signal.removeEventListener("abort", onAbort);
|
||||
}
|
||||
reject(new Error(`File not found: ${path}`));
|
||||
return;
|
||||
}
|
||||
|
||||
// Check if aborted before reading
|
||||
if (aborted) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Read the file
|
||||
const buffer = await ops.readFile(absolutePath);
|
||||
const rawContent = buffer.toString("utf-8");
|
||||
|
||||
// Check if aborted after reading
|
||||
if (aborted) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Strip BOM before matching (LLM won't include invisible BOM in oldText)
|
||||
const { bom, text: content } = stripBom(rawContent);
|
||||
|
||||
const originalEnding = detectLineEnding(content);
|
||||
const normalizedContent = normalizeToLF(content);
|
||||
const normalizedOldText = normalizeToLF(oldText);
|
||||
const normalizedNewText = normalizeToLF(newText);
|
||||
|
||||
// Find the old text using fuzzy matching (tries exact match first, then fuzzy)
|
||||
const matchResult = fuzzyFindText(
|
||||
normalizedContent,
|
||||
normalizedOldText,
|
||||
);
|
||||
|
||||
if (!matchResult.found) {
|
||||
if (signal) {
|
||||
signal.removeEventListener("abort", onAbort);
|
||||
}
|
||||
reject(
|
||||
new Error(
|
||||
`Could not find the exact text in ${path}. The old text must match exactly including all whitespace and newlines.`,
|
||||
),
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Count occurrences using fuzzy-normalized content for consistency
|
||||
const fuzzyContent = normalizeForFuzzyMatch(normalizedContent);
|
||||
const fuzzyOldText = normalizeForFuzzyMatch(normalizedOldText);
|
||||
const occurrences = fuzzyContent.split(fuzzyOldText).length - 1;
|
||||
|
||||
if (occurrences > 1) {
|
||||
if (signal) {
|
||||
signal.removeEventListener("abort", onAbort);
|
||||
}
|
||||
reject(
|
||||
new Error(
|
||||
`Found ${occurrences} occurrences of the text in ${path}. The text must be unique. Please provide more context to make it unique.`,
|
||||
),
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Check if aborted before writing
|
||||
if (aborted) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Perform replacement using the matched text position
|
||||
// When fuzzy matching was used, contentForReplacement is the normalized version
|
||||
const baseContent = matchResult.contentForReplacement;
|
||||
const newContent =
|
||||
baseContent.substring(0, matchResult.index) +
|
||||
normalizedNewText +
|
||||
baseContent.substring(
|
||||
matchResult.index + matchResult.matchLength,
|
||||
);
|
||||
|
||||
// Verify the replacement actually changed something
|
||||
if (baseContent === newContent) {
|
||||
if (signal) {
|
||||
signal.removeEventListener("abort", onAbort);
|
||||
}
|
||||
reject(
|
||||
new Error(
|
||||
`No changes made to ${path}. The replacement produced identical content. This might indicate an issue with special characters or the text not existing as expected.`,
|
||||
),
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const finalContent =
|
||||
bom + restoreLineEndings(newContent, originalEnding);
|
||||
await ops.writeFile(absolutePath, finalContent);
|
||||
|
||||
// Check if aborted after writing
|
||||
if (aborted) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Clean up abort handler
|
||||
if (signal) {
|
||||
signal.removeEventListener("abort", onAbort);
|
||||
}
|
||||
|
||||
const diffResult = generateDiffString(baseContent, newContent);
|
||||
resolve({
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: `Successfully replaced text in ${path}.`,
|
||||
},
|
||||
],
|
||||
details: {
|
||||
diff: diffResult.diff,
|
||||
firstChangedLine: diffResult.firstChangedLine,
|
||||
},
|
||||
});
|
||||
} catch (error: any) {
|
||||
// Clean up abort handler
|
||||
if (signal) {
|
||||
signal.removeEventListener("abort", onAbort);
|
||||
}
|
||||
|
||||
if (!aborted) {
|
||||
reject(error);
|
||||
}
|
||||
}
|
||||
})();
|
||||
});
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
/** Default edit tool using process.cwd() - for backwards compatibility */
|
||||
export const editTool = createEditTool(process.cwd());
|
||||
packages/coding-agent/src/core/tools/find.ts — new file, 308 lines (@@ -0,0 +1,308 @@)
|
|||
import type { AgentTool } from "@mariozechner/pi-agent-core";
|
||||
import { type Static, Type } from "@sinclair/typebox";
|
||||
import { spawnSync } from "child_process";
|
||||
import { existsSync } from "fs";
|
||||
import { globSync } from "glob";
|
||||
import path from "path";
|
||||
import { ensureTool } from "../../utils/tools-manager.js";
|
||||
import { resolveToCwd } from "./path-utils.js";
|
||||
import {
|
||||
DEFAULT_MAX_BYTES,
|
||||
formatSize,
|
||||
type TruncationResult,
|
||||
truncateHead,
|
||||
} from "./truncate.js";
|
||||
|
||||
const findSchema = Type.Object({
|
||||
pattern: Type.String({
|
||||
description:
|
||||
"Glob pattern to match files, e.g. '*.ts', '**/*.json', or 'src/**/*.spec.ts'",
|
||||
}),
|
||||
path: Type.Optional(
|
||||
Type.String({
|
||||
description: "Directory to search in (default: current directory)",
|
||||
}),
|
||||
),
|
||||
limit: Type.Optional(
|
||||
Type.Number({ description: "Maximum number of results (default: 1000)" }),
|
||||
),
|
||||
});
|
||||
|
||||
export type FindToolInput = Static<typeof findSchema>;
|
||||
|
||||
const DEFAULT_LIMIT = 1000;
|
||||
|
||||
export interface FindToolDetails {
|
||||
truncation?: TruncationResult;
|
||||
resultLimitReached?: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Pluggable operations for the find tool.
|
||||
* Override these to delegate file search to remote systems (e.g., SSH).
|
||||
*/
|
||||
export interface FindOperations {
|
||||
/** Check if path exists */
|
||||
exists: (absolutePath: string) => Promise<boolean> | boolean;
|
||||
/** Find files matching glob pattern. Returns relative paths. */
|
||||
glob: (
|
||||
pattern: string,
|
||||
cwd: string,
|
||||
options: { ignore: string[]; limit: number },
|
||||
) => Promise<string[]> | string[];
|
||||
}
|
||||
|
||||
const defaultFindOperations: FindOperations = {
|
||||
exists: existsSync,
|
||||
glob: (_pattern, _searchCwd, _options) => {
|
||||
// This is a placeholder - actual fd execution happens in execute
|
||||
return [];
|
||||
},
|
||||
};
|
||||
|
||||
export interface FindToolOptions {
|
||||
/** Custom operations for find. Default: local filesystem + fd */
|
||||
operations?: FindOperations;
|
||||
}
|
||||
|
||||
export function createFindTool(
|
||||
cwd: string,
|
||||
options?: FindToolOptions,
|
||||
): AgentTool<typeof findSchema> {
|
||||
const customOps = options?.operations;
|
||||
|
||||
return {
|
||||
name: "find",
|
||||
label: "find",
|
||||
description: `Search for files by glob pattern. Returns matching file paths relative to the search directory. Respects .gitignore. Output is truncated to ${DEFAULT_LIMIT} results or ${DEFAULT_MAX_BYTES / 1024}KB (whichever is hit first).`,
|
||||
parameters: findSchema,
|
||||
execute: async (
|
||||
_toolCallId: string,
|
||||
{
|
||||
pattern,
|
||||
path: searchDir,
|
||||
limit,
|
||||
}: { pattern: string; path?: string; limit?: number },
|
||||
signal?: AbortSignal,
|
||||
) => {
|
||||
return new Promise((resolve, reject) => {
|
||||
if (signal?.aborted) {
|
||||
reject(new Error("Operation aborted"));
|
||||
return;
|
||||
}
|
||||
|
||||
const onAbort = () => reject(new Error("Operation aborted"));
|
||||
signal?.addEventListener("abort", onAbort, { once: true });
|
||||
|
||||
(async () => {
|
||||
try {
|
||||
const searchPath = resolveToCwd(searchDir || ".", cwd);
|
||||
const effectiveLimit = limit ?? DEFAULT_LIMIT;
|
||||
const ops = customOps ?? defaultFindOperations;
|
||||
|
||||
// If custom operations provided with glob, use that
|
||||
if (customOps?.glob) {
|
||||
if (!(await ops.exists(searchPath))) {
|
||||
reject(new Error(`Path not found: ${searchPath}`));
|
||||
return;
|
||||
}
|
||||
|
||||
const results = await ops.glob(pattern, searchPath, {
|
||||
ignore: ["**/node_modules/**", "**/.git/**"],
|
||||
limit: effectiveLimit,
|
||||
});
|
||||
|
||||
signal?.removeEventListener("abort", onAbort);
|
||||
|
||||
if (results.length === 0) {
|
||||
resolve({
|
||||
content: [
|
||||
{ type: "text", text: "No files found matching pattern" },
|
||||
],
|
||||
details: undefined,
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// Relativize paths
|
||||
const relativized = results.map((p) => {
|
||||
if (p.startsWith(searchPath)) {
|
||||
return p.slice(searchPath.length + 1);
|
||||
}
|
||||
return path.relative(searchPath, p);
|
||||
});
|
||||
|
||||
const resultLimitReached = relativized.length >= effectiveLimit;
|
||||
const rawOutput = relativized.join("\n");
|
||||
const truncation = truncateHead(rawOutput, {
|
||||
maxLines: Number.MAX_SAFE_INTEGER,
|
||||
});
|
||||
|
||||
let resultOutput = truncation.content;
|
||||
const details: FindToolDetails = {};
|
||||
const notices: string[] = [];
|
||||
|
||||
if (resultLimitReached) {
|
||||
notices.push(`${effectiveLimit} results limit reached`);
|
||||
details.resultLimitReached = effectiveLimit;
|
||||
}
|
||||
|
||||
if (truncation.truncated) {
|
||||
notices.push(`${formatSize(DEFAULT_MAX_BYTES)} limit reached`);
|
||||
details.truncation = truncation;
|
||||
}
|
||||
|
||||
if (notices.length > 0) {
|
||||
resultOutput += `\n\n[${notices.join(". ")}]`;
|
||||
}
|
||||
|
||||
resolve({
|
||||
content: [{ type: "text", text: resultOutput }],
|
||||
details: Object.keys(details).length > 0 ? details : undefined,
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// Default: use fd
|
||||
const fdPath = await ensureTool("fd", true);
|
||||
if (!fdPath) {
|
||||
reject(
|
||||
new Error("fd is not available and could not be downloaded"),
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Build fd arguments
|
||||
const args: string[] = [
|
||||
"--glob",
|
||||
"--color=never",
|
||||
"--hidden",
|
||||
"--max-results",
|
||||
String(effectiveLimit),
|
||||
];
|
||||
|
||||
// Include .gitignore files
|
||||
const gitignoreFiles = new Set<string>();
|
||||
const rootGitignore = path.join(searchPath, ".gitignore");
|
||||
if (existsSync(rootGitignore)) {
|
||||
gitignoreFiles.add(rootGitignore);
|
||||
}
|
||||
|
||||
try {
|
||||
const nestedGitignores = globSync("**/.gitignore", {
|
||||
cwd: searchPath,
|
||||
dot: true,
|
||||
absolute: true,
|
||||
ignore: ["**/node_modules/**", "**/.git/**"],
|
||||
});
|
||||
for (const file of nestedGitignores) {
|
||||
gitignoreFiles.add(file);
|
||||
}
|
||||
} catch {
|
||||
// Ignore glob errors
|
||||
}
|
||||
|
||||
for (const gitignorePath of gitignoreFiles) {
|
||||
args.push("--ignore-file", gitignorePath);
|
||||
}
|
||||
|
||||
args.push(pattern, searchPath);
|
||||
|
||||
const result = spawnSync(fdPath, args, {
|
||||
encoding: "utf-8",
|
||||
maxBuffer: 10 * 1024 * 1024,
|
||||
});
|
||||
|
||||
signal?.removeEventListener("abort", onAbort);
|
||||
|
||||
if (result.error) {
|
||||
reject(new Error(`Failed to run fd: ${result.error.message}`));
|
||||
return;
|
||||
}
|
||||
|
||||
const output = result.stdout?.trim() || "";
|
||||
|
||||
if (result.status !== 0) {
|
||||
const errorMsg =
|
||||
result.stderr?.trim() || `fd exited with code ${result.status}`;
|
||||
if (!output) {
|
||||
reject(new Error(errorMsg));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (!output) {
|
||||
resolve({
|
||||
content: [
|
||||
{ type: "text", text: "No files found matching pattern" },
|
||||
],
|
||||
details: undefined,
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const lines = output.split("\n");
|
||||
const relativized: string[] = [];
|
||||
|
||||
for (const rawLine of lines) {
|
||||
const line = rawLine.replace(/\r$/, "").trim();
|
||||
if (!line) continue;
|
||||
|
||||
const hadTrailingSlash =
|
||||
line.endsWith("/") || line.endsWith("\\");
|
||||
let relativePath = line;
|
||||
if (line.startsWith(searchPath)) {
|
||||
relativePath = line.slice(searchPath.length + 1);
|
||||
} else {
|
||||
relativePath = path.relative(searchPath, line);
|
||||
}
|
||||
|
||||
if (hadTrailingSlash && !relativePath.endsWith("/")) {
|
||||
relativePath += "/";
|
||||
}
|
||||
|
||||
relativized.push(relativePath);
|
||||
}
|
||||
|
||||
const resultLimitReached = relativized.length >= effectiveLimit;
|
||||
const rawOutput = relativized.join("\n");
|
||||
const truncation = truncateHead(rawOutput, {
|
||||
maxLines: Number.MAX_SAFE_INTEGER,
|
||||
});
|
||||
|
||||
let resultOutput = truncation.content;
|
||||
const details: FindToolDetails = {};
|
||||
const notices: string[] = [];
|
||||
|
||||
if (resultLimitReached) {
|
||||
notices.push(
|
||||
`${effectiveLimit} results limit reached. Use limit=${effectiveLimit * 2} for more, or refine pattern`,
|
||||
);
|
||||
details.resultLimitReached = effectiveLimit;
|
||||
}
|
||||
|
||||
if (truncation.truncated) {
|
||||
notices.push(`${formatSize(DEFAULT_MAX_BYTES)} limit reached`);
|
||||
details.truncation = truncation;
|
||||
}
|
||||
|
||||
if (notices.length > 0) {
|
||||
resultOutput += `\n\n[${notices.join(". ")}]`;
|
||||
}
|
||||
|
||||
resolve({
|
||||
content: [{ type: "text", text: resultOutput }],
|
||||
details: Object.keys(details).length > 0 ? details : undefined,
|
||||
});
|
||||
} catch (e: any) {
|
||||
signal?.removeEventListener("abort", onAbort);
|
||||
reject(e);
|
||||
}
|
||||
})();
|
||||
});
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
/** Default find tool using process.cwd() - for backwards compatibility */
|
||||
export const findTool = createFindTool(process.cwd());
|
||||
packages/coding-agent/src/core/tools/grep.ts — new file, 412 lines (@@ -0,0 +1,412 @@)
|
|||
import { createInterface } from "node:readline";
|
||||
import type { AgentTool } from "@mariozechner/pi-agent-core";
|
||||
import { type Static, Type } from "@sinclair/typebox";
|
||||
import { spawn } from "child_process";
|
||||
import { readFileSync, statSync } from "fs";
|
||||
import path from "path";
|
||||
import { ensureTool } from "../../utils/tools-manager.js";
|
||||
import { resolveToCwd } from "./path-utils.js";
|
||||
import {
|
||||
DEFAULT_MAX_BYTES,
|
||||
formatSize,
|
||||
GREP_MAX_LINE_LENGTH,
|
||||
type TruncationResult,
|
||||
truncateHead,
|
||||
truncateLine,
|
||||
} from "./truncate.js";
|
||||
|
||||
const grepSchema = Type.Object({
  pattern: Type.String({
    description: "Search pattern (regex or literal string)",
  }),
  path: Type.Optional(
    Type.String({
      description: "Directory or file to search (default: current directory)",
    }),
  ),
  glob: Type.Optional(
    Type.String({
      description:
        "Filter files by glob pattern, e.g. '*.ts' or '**/*.spec.ts'",
    }),
  ),
  ignoreCase: Type.Optional(
    Type.Boolean({ description: "Case-insensitive search (default: false)" }),
  ),
  literal: Type.Optional(
    Type.Boolean({
      description:
        "Treat pattern as literal string instead of regex (default: false)",
    }),
  ),
  context: Type.Optional(
    Type.Number({
      description:
        "Number of lines to show before and after each match (default: 0)",
    }),
  ),
  limit: Type.Optional(
    Type.Number({
      description: "Maximum number of matches to return (default: 100)",
    }),
  ),
});

export type GrepToolInput = Static<typeof grepSchema>;

const DEFAULT_LIMIT = 100;

export interface GrepToolDetails {
  // Present when the output was cut by the byte-size limit
  truncation?: TruncationResult;
  // Present when the match-count limit was hit; holds the limit used
  matchLimitReached?: number;
  // True when individual long lines were shortened to GREP_MAX_LINE_LENGTH
  linesTruncated?: boolean;
}

/**
 * Pluggable operations for the grep tool.
 * Override these to delegate search to remote systems (e.g., SSH).
 */
export interface GrepOperations {
  /** Check if path is a directory. Throws if path doesn't exist. */
  isDirectory: (absolutePath: string) => Promise<boolean> | boolean;
  /** Read file contents for context lines */
  readFile: (absolutePath: string) => Promise<string> | string;
}

const defaultGrepOperations: GrepOperations = {
  isDirectory: (p) => statSync(p).isDirectory(),
  readFile: (p) => readFileSync(p, "utf-8"),
};

export interface GrepToolOptions {
  /** Custom operations for grep. Default: local filesystem + ripgrep */
  operations?: GrepOperations;
}

/**
 * Create the `grep` tool bound to a working directory.
 *
 * Streams ripgrep's --json output, collecting match locations, then formats
 * matched lines (plus optional context) by re-reading the files through
 * `GrepOperations.readFile` — this is what allows context lines to come from
 * a remote filesystem while the search itself runs locally.
 *
 * @param cwd - Directory relative search paths are resolved against.
 * @param options - Optional custom file operations for context-line reads.
 */
export function createGrepTool(
  cwd: string,
  options?: GrepToolOptions,
): AgentTool<typeof grepSchema> {
  const customOps = options?.operations;

  return {
    name: "grep",
    label: "grep",
    description: `Search file contents for a pattern. Returns matching lines with file paths and line numbers. Respects .gitignore. Output is truncated to ${DEFAULT_LIMIT} matches or ${DEFAULT_MAX_BYTES / 1024}KB (whichever is hit first). Long lines are truncated to ${GREP_MAX_LINE_LENGTH} chars.`,
    parameters: grepSchema,
    execute: async (
      _toolCallId: string,
      {
        pattern,
        path: searchDir,
        glob,
        ignoreCase,
        literal,
        context,
        limit,
      }: {
        pattern: string;
        path?: string;
        glob?: string;
        ignoreCase?: boolean;
        literal?: boolean;
        context?: number;
        limit?: number;
      },
      signal?: AbortSignal,
    ) => {
      // A Promise constructor is needed here (rather than plain async/await)
      // because settlement is driven by child-process events: rl "line",
      // child "error", and child "close" all race to settle exactly once.
      return new Promise((resolve, reject) => {
        if (signal?.aborted) {
          reject(new Error("Operation aborted"));
          return;
        }

        // Guard so that only the first resolve/reject wins — "error" and
        // "close" can both fire for the same child process.
        let settled = false;
        const settle = (fn: () => void) => {
          if (!settled) {
            settled = true;
            fn();
          }
        };

        (async () => {
          try {
            const rgPath = await ensureTool("rg", true);
            if (!rgPath) {
              settle(() =>
                reject(
                  new Error(
                    "ripgrep (rg) is not available and could not be downloaded",
                  ),
                ),
              );
              return;
            }

            const searchPath = resolveToCwd(searchDir || ".", cwd);
            const ops = customOps ?? defaultGrepOperations;

            let isDirectory: boolean;
            try {
              isDirectory = await ops.isDirectory(searchPath);
            } catch (_err) {
              settle(() => reject(new Error(`Path not found: ${searchPath}`)));
              return;
            }
            const contextValue = context && context > 0 ? context : 0;
            const effectiveLimit = Math.max(1, limit ?? DEFAULT_LIMIT);

            // Directory searches show search-root-relative paths; single-file
            // searches show just the basename.
            const formatPath = (filePath: string): string => {
              if (isDirectory) {
                const relative = path.relative(searchPath, filePath);
                if (relative && !relative.startsWith("..")) {
                  return relative.replace(/\\/g, "/");
                }
              }
              return path.basename(filePath);
            };

            // Cache file contents so multiple matches in one file read it once.
            const fileCache = new Map<string, string[]>();
            const getFileLines = async (
              filePath: string,
            ): Promise<string[]> => {
              let lines = fileCache.get(filePath);
              if (!lines) {
                try {
                  const content = await ops.readFile(filePath);
                  lines = content
                    .replace(/\r\n/g, "\n")
                    .replace(/\r/g, "\n")
                    .split("\n");
                } catch {
                  // Unreadable files are cached as empty; formatBlock reports
                  // "(unable to read file)" for them.
                  lines = [];
                }
                fileCache.set(filePath, lines);
              }
              return lines;
            };

            const args: string[] = [
              "--json",
              "--line-number",
              "--color=never",
              "--hidden",
            ];

            if (ignoreCase) {
              args.push("--ignore-case");
            }

            if (literal) {
              args.push("--fixed-strings");
            }

            if (glob) {
              args.push("--glob", glob);
            }

            args.push(pattern, searchPath);

            const child = spawn(rgPath, args, {
              stdio: ["ignore", "pipe", "pipe"],
            });
            const rl = createInterface({ input: child.stdout });
            let stderr = "";
            let matchCount = 0;
            let matchLimitReached = false;
            let linesTruncated = false;
            let aborted = false;
            // Distinguishes "we killed rg because the match limit was hit"
            // (normal) from rg dying on its own (error) in the close handler.
            let killedDueToLimit = false;
            const outputLines: string[] = [];

            const cleanup = () => {
              rl.close();
              signal?.removeEventListener("abort", onAbort);
            };

            const stopChild = (dueToLimit: boolean = false) => {
              if (!child.killed) {
                killedDueToLimit = dueToLimit;
                child.kill();
              }
            };

            const onAbort = () => {
              aborted = true;
              stopChild();
            };

            signal?.addEventListener("abort", onAbort, { once: true });

            child.stderr?.on("data", (chunk) => {
              stderr += chunk.toString();
            });

            // Format one match: the matching line plus `contextValue` lines
            // of surrounding context, gnu-grep style (":" marks the match
            // line, "-" marks context lines).
            const formatBlock = async (
              filePath: string,
              lineNumber: number,
            ): Promise<string[]> => {
              const relativePath = formatPath(filePath);
              const lines = await getFileLines(filePath);
              if (!lines.length) {
                return [`${relativePath}:${lineNumber}: (unable to read file)`];
              }

              const block: string[] = [];
              const start =
                contextValue > 0
                  ? Math.max(1, lineNumber - contextValue)
                  : lineNumber;
              const end =
                contextValue > 0
                  ? Math.min(lines.length, lineNumber + contextValue)
                  : lineNumber;

              for (let current = start; current <= end; current++) {
                const lineText = lines[current - 1] ?? "";
                const sanitized = lineText.replace(/\r/g, "");
                const isMatchLine = current === lineNumber;

                // Truncate long lines
                const { text: truncatedText, wasTruncated } =
                  truncateLine(sanitized);
                if (wasTruncated) {
                  linesTruncated = true;
                }

                if (isMatchLine) {
                  block.push(`${relativePath}:${current}: ${truncatedText}`);
                } else {
                  block.push(`${relativePath}-${current}- ${truncatedText}`);
                }
              }

              return block;
            };

            // Collect matches during streaming, format after
            const matches: Array<{ filePath: string; lineNumber: number }> = [];

            rl.on("line", (line) => {
              // Lines already buffered when the limit was hit are dropped.
              if (!line.trim() || matchCount >= effectiveLimit) {
                return;
              }

              let event: any;
              try {
                event = JSON.parse(line);
              } catch {
                return;
              }

              if (event.type === "match") {
                matchCount++;
                const filePath = event.data?.path?.text;
                const lineNumber = event.data?.line_number;

                if (filePath && typeof lineNumber === "number") {
                  matches.push({ filePath, lineNumber });
                }

                if (matchCount >= effectiveLimit) {
                  matchLimitReached = true;
                  stopChild(true);
                }
              }
            });

            child.on("error", (error) => {
              cleanup();
              settle(() =>
                reject(new Error(`Failed to run ripgrep: ${error.message}`)),
              );
            });

            child.on("close", async (code) => {
              cleanup();

              if (aborted) {
                settle(() => reject(new Error("Operation aborted")));
                return;
              }

              // rg exit code 1 means "no matches" — not an error. Any other
              // non-zero code is an error unless we killed rg ourselves.
              if (!killedDueToLimit && code !== 0 && code !== 1) {
                const errorMsg =
                  stderr.trim() || `ripgrep exited with code ${code}`;
                settle(() => reject(new Error(errorMsg)));
                return;
              }

              if (matchCount === 0) {
                settle(() =>
                  resolve({
                    content: [{ type: "text", text: "No matches found" }],
                    details: undefined,
                  }),
                );
                return;
              }

              // Format matches (async to support remote file reading)
              for (const match of matches) {
                const block = await formatBlock(
                  match.filePath,
                  match.lineNumber,
                );
                outputLines.push(...block);
              }

              // Apply byte truncation (no line limit since we already have match limit)
              const rawOutput = outputLines.join("\n");
              const truncation = truncateHead(rawOutput, {
                maxLines: Number.MAX_SAFE_INTEGER,
              });

              let output = truncation.content;
              const details: GrepToolDetails = {};

              // Build notices
              const notices: string[] = [];

              if (matchLimitReached) {
                notices.push(
                  `${effectiveLimit} matches limit reached. Use limit=${effectiveLimit * 2} for more, or refine pattern`,
                );
                details.matchLimitReached = effectiveLimit;
              }

              if (truncation.truncated) {
                notices.push(`${formatSize(DEFAULT_MAX_BYTES)} limit reached`);
                details.truncation = truncation;
              }

              if (linesTruncated) {
                notices.push(
                  `Some lines truncated to ${GREP_MAX_LINE_LENGTH} chars. Use read tool to see full lines`,
                );
                details.linesTruncated = true;
              }

              if (notices.length > 0) {
                output += `\n\n[${notices.join(". ")}]`;
              }

              settle(() =>
                resolve({
                  content: [{ type: "text", text: output }],
                  details:
                    Object.keys(details).length > 0 ? details : undefined,
                }),
              );
            });
          } catch (err) {
            settle(() => reject(err as Error));
          }
        })();
      });
    },
  };
}

/** Default grep tool using process.cwd() - for backwards compatibility */
export const grepTool = createGrepTool(process.cwd());
|
||||
packages/coding-agent/src/core/tools/index.ts — new file, 150 lines (@@ -0,0 +1,150 @@)
|
|||
// Public surface of the tools package: each tool module re-exports its
// factory (create*Tool), a default instance bound to process.cwd(), and the
// types callers need (input schema type, details, pluggable operations).

// bash — run shell commands
export {
  type BashOperations,
  type BashSpawnContext,
  type BashSpawnHook,
  type BashToolDetails,
  type BashToolInput,
  type BashToolOptions,
  bashTool,
  createBashTool,
} from "./bash.js";
// edit — exact-text replacement in a file
export {
  createEditTool,
  type EditOperations,
  type EditToolDetails,
  type EditToolInput,
  type EditToolOptions,
  editTool,
} from "./edit.js";
// find — file search by glob pattern
export {
  createFindTool,
  type FindOperations,
  type FindToolDetails,
  type FindToolInput,
  type FindToolOptions,
  findTool,
} from "./find.js";
// grep — content search via ripgrep
export {
  createGrepTool,
  type GrepOperations,
  type GrepToolDetails,
  type GrepToolInput,
  type GrepToolOptions,
  grepTool,
} from "./grep.js";
// ls — directory listing
export {
  createLsTool,
  type LsOperations,
  type LsToolDetails,
  type LsToolInput,
  type LsToolOptions,
  lsTool,
} from "./ls.js";
// read — file reading
export {
  createReadTool,
  type ReadOperations,
  type ReadToolDetails,
  type ReadToolInput,
  type ReadToolOptions,
  readTool,
} from "./read.js";
// truncate — shared output-size limiting helpers used by the tools above
export {
  DEFAULT_MAX_BYTES,
  DEFAULT_MAX_LINES,
  formatSize,
  type TruncationOptions,
  type TruncationResult,
  truncateHead,
  truncateLine,
  truncateTail,
} from "./truncate.js";
// write — file writing
export {
  createWriteTool,
  type WriteOperations,
  type WriteToolInput,
  type WriteToolOptions,
  writeTool,
} from "./write.js";
||||
|
||||
import type { AgentTool } from "@mariozechner/pi-agent-core";
|
||||
import { type BashToolOptions, bashTool, createBashTool } from "./bash.js";
|
||||
import { createEditTool, editTool } from "./edit.js";
|
||||
import { createFindTool, findTool } from "./find.js";
|
||||
import { createGrepTool, grepTool } from "./grep.js";
|
||||
import { createLsTool, lsTool } from "./ls.js";
|
||||
import { createReadTool, type ReadToolOptions, readTool } from "./read.js";
|
||||
import { createWriteTool, writeTool } from "./write.js";
|
||||
|
||||
/** Tool type (AgentTool from pi-ai) */
|
||||
export type Tool = AgentTool<any>;
|
||||
|
||||
// Default tools for full access mode (using process.cwd())
|
||||
export const codingTools: Tool[] = [readTool, bashTool, editTool, writeTool];
|
||||
|
||||
// Read-only tools for exploration without modification (using process.cwd())
|
||||
export const readOnlyTools: Tool[] = [readTool, grepTool, findTool, lsTool];
|
||||
|
||||
// All available tools (using process.cwd())
|
||||
export const allTools = {
|
||||
read: readTool,
|
||||
bash: bashTool,
|
||||
edit: editTool,
|
||||
write: writeTool,
|
||||
grep: grepTool,
|
||||
find: findTool,
|
||||
ls: lsTool,
|
||||
};
|
||||
|
||||
export type ToolName = keyof typeof allTools;
|
||||
|
||||
export interface ToolsOptions {
|
||||
/** Options for the read tool */
|
||||
read?: ReadToolOptions;
|
||||
/** Options for the bash tool */
|
||||
bash?: BashToolOptions;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create coding tools configured for a specific working directory.
|
||||
*/
|
||||
export function createCodingTools(cwd: string, options?: ToolsOptions): Tool[] {
|
||||
return [
|
||||
createReadTool(cwd, options?.read),
|
||||
createBashTool(cwd, options?.bash),
|
||||
createEditTool(cwd),
|
||||
createWriteTool(cwd),
|
||||
];
|
||||
}
|
||||
|
||||
/**
|
||||
* Create read-only tools configured for a specific working directory.
|
||||
*/
|
||||
export function createReadOnlyTools(
|
||||
cwd: string,
|
||||
options?: ToolsOptions,
|
||||
): Tool[] {
|
||||
return [
|
||||
createReadTool(cwd, options?.read),
|
||||
createGrepTool(cwd),
|
||||
createFindTool(cwd),
|
||||
createLsTool(cwd),
|
||||
];
|
||||
}
|
||||
|
||||
/**
|
||||
* Create all tools configured for a specific working directory.
|
||||
*/
|
||||
export function createAllTools(
|
||||
cwd: string,
|
||||
options?: ToolsOptions,
|
||||
): Record<ToolName, Tool> {
|
||||
return {
|
||||
read: createReadTool(cwd, options?.read),
|
||||
bash: createBashTool(cwd, options?.bash),
|
||||
edit: createEditTool(cwd),
|
||||
write: createWriteTool(cwd),
|
||||
grep: createGrepTool(cwd),
|
||||
find: createFindTool(cwd),
|
||||
ls: createLsTool(cwd),
|
||||
};
|
||||
}
|
||||
packages/coding-agent/src/core/tools/ls.ts — new file, 197 lines (@@ -0,0 +1,197 @@)
|
|||
import type { AgentTool } from "@mariozechner/pi-agent-core";
|
||||
import { type Static, Type } from "@sinclair/typebox";
|
||||
import { existsSync, readdirSync, statSync } from "fs";
|
||||
import nodePath from "path";
|
||||
import { resolveToCwd } from "./path-utils.js";
|
||||
import {
|
||||
DEFAULT_MAX_BYTES,
|
||||
formatSize,
|
||||
type TruncationResult,
|
||||
truncateHead,
|
||||
} from "./truncate.js";
|
||||
|
||||
const lsSchema = Type.Object({
|
||||
path: Type.Optional(
|
||||
Type.String({
|
||||
description: "Directory to list (default: current directory)",
|
||||
}),
|
||||
),
|
||||
limit: Type.Optional(
|
||||
Type.Number({
|
||||
description: "Maximum number of entries to return (default: 500)",
|
||||
}),
|
||||
),
|
||||
});
|
||||
|
||||
export type LsToolInput = Static<typeof lsSchema>;
|
||||
|
||||
const DEFAULT_LIMIT = 500;
|
||||
|
||||
export interface LsToolDetails {
|
||||
truncation?: TruncationResult;
|
||||
entryLimitReached?: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Pluggable operations for the ls tool.
|
||||
* Override these to delegate directory listing to remote systems (e.g., SSH).
|
||||
*/
|
||||
export interface LsOperations {
|
||||
/** Check if path exists */
|
||||
exists: (absolutePath: string) => Promise<boolean> | boolean;
|
||||
/** Get file/directory stats. Throws if not found. */
|
||||
stat: (
|
||||
absolutePath: string,
|
||||
) => Promise<{ isDirectory: () => boolean }> | { isDirectory: () => boolean };
|
||||
/** Read directory entries */
|
||||
readdir: (absolutePath: string) => Promise<string[]> | string[];
|
||||
}
|
||||
|
||||
const defaultLsOperations: LsOperations = {
|
||||
exists: existsSync,
|
||||
stat: statSync,
|
||||
readdir: readdirSync,
|
||||
};
|
||||
|
||||
export interface LsToolOptions {
|
||||
/** Custom operations for directory listing. Default: local filesystem */
|
||||
operations?: LsOperations;
|
||||
}
|
||||
|
||||
export function createLsTool(
|
||||
cwd: string,
|
||||
options?: LsToolOptions,
|
||||
): AgentTool<typeof lsSchema> {
|
||||
const ops = options?.operations ?? defaultLsOperations;
|
||||
|
||||
return {
|
||||
name: "ls",
|
||||
label: "ls",
|
||||
description: `List directory contents. Returns entries sorted alphabetically, with '/' suffix for directories. Includes dotfiles. Output is truncated to ${DEFAULT_LIMIT} entries or ${DEFAULT_MAX_BYTES / 1024}KB (whichever is hit first).`,
|
||||
parameters: lsSchema,
|
||||
execute: async (
|
||||
_toolCallId: string,
|
||||
{ path, limit }: { path?: string; limit?: number },
|
||||
signal?: AbortSignal,
|
||||
) => {
|
||||
return new Promise((resolve, reject) => {
|
||||
if (signal?.aborted) {
|
||||
reject(new Error("Operation aborted"));
|
||||
return;
|
||||
}
|
||||
|
||||
const onAbort = () => reject(new Error("Operation aborted"));
|
||||
signal?.addEventListener("abort", onAbort, { once: true });
|
||||
|
||||
(async () => {
|
||||
try {
|
||||
const dirPath = resolveToCwd(path || ".", cwd);
|
||||
const effectiveLimit = limit ?? DEFAULT_LIMIT;
|
||||
|
||||
// Check if path exists
|
||||
if (!(await ops.exists(dirPath))) {
|
||||
reject(new Error(`Path not found: ${dirPath}`));
|
||||
return;
|
||||
}
|
||||
|
||||
// Check if path is a directory
|
||||
const stat = await ops.stat(dirPath);
|
||||
if (!stat.isDirectory()) {
|
||||
reject(new Error(`Not a directory: ${dirPath}`));
|
||||
return;
|
||||
}
|
||||
|
||||
// Read directory entries
|
||||
let entries: string[];
|
||||
try {
|
||||
entries = await ops.readdir(dirPath);
|
||||
} catch (e: any) {
|
||||
reject(new Error(`Cannot read directory: ${e.message}`));
|
||||
return;
|
||||
}
|
||||
|
||||
// Sort alphabetically (case-insensitive)
|
||||
entries.sort((a, b) =>
|
||||
a.toLowerCase().localeCompare(b.toLowerCase()),
|
||||
);
|
||||
|
||||
// Format entries with directory indicators
|
||||
const results: string[] = [];
|
||||
let entryLimitReached = false;
|
||||
|
||||
for (const entry of entries) {
|
||||
if (results.length >= effectiveLimit) {
|
||||
entryLimitReached = true;
|
||||
break;
|
||||
}
|
||||
|
||||
const fullPath = nodePath.join(dirPath, entry);
|
||||
let suffix = "";
|
||||
|
||||
try {
|
||||
const entryStat = await ops.stat(fullPath);
|
||||
if (entryStat.isDirectory()) {
|
||||
suffix = "/";
|
||||
}
|
||||
} catch {
|
||||
// Skip entries we can't stat
|
||||
continue;
|
||||
}
|
||||
|
||||
results.push(entry + suffix);
|
||||
}
|
||||
|
||||
signal?.removeEventListener("abort", onAbort);
|
||||
|
||||
if (results.length === 0) {
|
||||
resolve({
|
||||
content: [{ type: "text", text: "(empty directory)" }],
|
||||
details: undefined,
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// Apply byte truncation (no line limit since we already have entry limit)
|
||||
const rawOutput = results.join("\n");
|
||||
const truncation = truncateHead(rawOutput, {
|
||||
maxLines: Number.MAX_SAFE_INTEGER,
|
||||
});
|
||||
|
||||
let output = truncation.content;
|
||||
const details: LsToolDetails = {};
|
||||
|
||||
// Build notices
|
||||
const notices: string[] = [];
|
||||
|
||||
if (entryLimitReached) {
|
||||
notices.push(
|
||||
`${effectiveLimit} entries limit reached. Use limit=${effectiveLimit * 2} for more`,
|
||||
);
|
||||
details.entryLimitReached = effectiveLimit;
|
||||
}
|
||||
|
||||
if (truncation.truncated) {
|
||||
notices.push(`${formatSize(DEFAULT_MAX_BYTES)} limit reached`);
|
||||
details.truncation = truncation;
|
||||
}
|
||||
|
||||
if (notices.length > 0) {
|
||||
output += `\n\n[${notices.join(". ")}]`;
|
||||
}
|
||||
|
||||
resolve({
|
||||
content: [{ type: "text", text: output }],
|
||||
details: Object.keys(details).length > 0 ? details : undefined,
|
||||
});
|
||||
} catch (e: any) {
|
||||
signal?.removeEventListener("abort", onAbort);
|
||||
reject(e);
|
||||
}
|
||||
})();
|
||||
});
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
/** Default ls tool using process.cwd() - for backwards compatibility */
|
||||
export const lsTool = createLsTool(process.cwd());
|
||||
94
packages/coding-agent/src/core/tools/path-utils.ts
Normal file
94
packages/coding-agent/src/core/tools/path-utils.ts
Normal file
|
|
@ -0,0 +1,94 @@
|
|||
import { accessSync, constants } from "node:fs";
|
||||
import * as os from "node:os";
|
||||
import { isAbsolute, resolve as resolvePath } from "node:path";
|
||||
|
||||
const UNICODE_SPACES = /[\u00A0\u2000-\u200A\u202F\u205F\u3000]/g;
|
||||
const NARROW_NO_BREAK_SPACE = "\u202F";
|
||||
function normalizeUnicodeSpaces(str: string): string {
|
||||
return str.replace(UNICODE_SPACES, " ");
|
||||
}
|
||||
|
||||
function tryMacOSScreenshotPath(filePath: string): string {
|
||||
return filePath.replace(/ (AM|PM)\./g, `${NARROW_NO_BREAK_SPACE}$1.`);
|
||||
}
|
||||
|
||||
function tryNFDVariant(filePath: string): string {
|
||||
// macOS stores filenames in NFD (decomposed) form, try converting user input to NFD
|
||||
return filePath.normalize("NFD");
|
||||
}
|
||||
|
||||
function tryCurlyQuoteVariant(filePath: string): string {
|
||||
// macOS uses U+2019 (right single quotation mark) in screenshot names like "Capture d'écran"
|
||||
// Users typically type U+0027 (straight apostrophe)
|
||||
return filePath.replace(/'/g, "\u2019");
|
||||
}
|
||||
|
||||
function fileExists(filePath: string): boolean {
|
||||
try {
|
||||
accessSync(filePath, constants.F_OK);
|
||||
return true;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
function normalizeAtPrefix(filePath: string): string {
|
||||
return filePath.startsWith("@") ? filePath.slice(1) : filePath;
|
||||
}
|
||||
|
||||
export function expandPath(filePath: string): string {
|
||||
const normalized = normalizeUnicodeSpaces(normalizeAtPrefix(filePath));
|
||||
if (normalized === "~") {
|
||||
return os.homedir();
|
||||
}
|
||||
if (normalized.startsWith("~/")) {
|
||||
return os.homedir() + normalized.slice(1);
|
||||
}
|
||||
return normalized;
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolve a path relative to the given cwd.
|
||||
* Handles ~ expansion and absolute paths.
|
||||
*/
|
||||
export function resolveToCwd(filePath: string, cwd: string): string {
|
||||
const expanded = expandPath(filePath);
|
||||
if (isAbsolute(expanded)) {
|
||||
return expanded;
|
||||
}
|
||||
return resolvePath(cwd, expanded);
|
||||
}
|
||||
|
||||
export function resolveReadPath(filePath: string, cwd: string): string {
|
||||
const resolved = resolveToCwd(filePath, cwd);
|
||||
|
||||
if (fileExists(resolved)) {
|
||||
return resolved;
|
||||
}
|
||||
|
||||
// Try macOS AM/PM variant (narrow no-break space before AM/PM)
|
||||
const amPmVariant = tryMacOSScreenshotPath(resolved);
|
||||
if (amPmVariant !== resolved && fileExists(amPmVariant)) {
|
||||
return amPmVariant;
|
||||
}
|
||||
|
||||
// Try NFD variant (macOS stores filenames in NFD form)
|
||||
const nfdVariant = tryNFDVariant(resolved);
|
||||
if (nfdVariant !== resolved && fileExists(nfdVariant)) {
|
||||
return nfdVariant;
|
||||
}
|
||||
|
||||
// Try curly quote variant (macOS uses U+2019 in screenshot names)
|
||||
const curlyVariant = tryCurlyQuoteVariant(resolved);
|
||||
if (curlyVariant !== resolved && fileExists(curlyVariant)) {
|
||||
return curlyVariant;
|
||||
}
|
||||
|
||||
// Try combined NFD + curly quote (for French macOS screenshots like "Capture d'écran")
|
||||
const nfdCurlyVariant = tryCurlyQuoteVariant(nfdVariant);
|
||||
if (nfdCurlyVariant !== resolved && fileExists(nfdCurlyVariant)) {
|
||||
return nfdCurlyVariant;
|
||||
}
|
||||
|
||||
return resolved;
|
||||
}
|
||||
265
packages/coding-agent/src/core/tools/read.ts
Normal file
265
packages/coding-agent/src/core/tools/read.ts
Normal file
|
|
@ -0,0 +1,265 @@
|
|||
import type { AgentTool } from "@mariozechner/pi-agent-core";
|
||||
import type { ImageContent, TextContent } from "@mariozechner/pi-ai";
|
||||
import { type Static, Type } from "@sinclair/typebox";
|
||||
import { constants } from "fs";
|
||||
import { access as fsAccess, readFile as fsReadFile } from "fs/promises";
|
||||
import { formatDimensionNote, resizeImage } from "../../utils/image-resize.js";
|
||||
import { detectSupportedImageMimeTypeFromFile } from "../../utils/mime.js";
|
||||
import { resolveReadPath } from "./path-utils.js";
|
||||
import {
|
||||
DEFAULT_MAX_BYTES,
|
||||
DEFAULT_MAX_LINES,
|
||||
formatSize,
|
||||
type TruncationResult,
|
||||
truncateHead,
|
||||
} from "./truncate.js";
|
||||
|
||||
const readSchema = Type.Object({
|
||||
path: Type.String({
|
||||
description: "Path to the file to read (relative or absolute)",
|
||||
}),
|
||||
offset: Type.Optional(
|
||||
Type.Number({
|
||||
description: "Line number to start reading from (1-indexed)",
|
||||
}),
|
||||
),
|
||||
limit: Type.Optional(
|
||||
Type.Number({ description: "Maximum number of lines to read" }),
|
||||
),
|
||||
});
|
||||
|
||||
export type ReadToolInput = Static<typeof readSchema>;
|
||||
|
||||
export interface ReadToolDetails {
|
||||
truncation?: TruncationResult;
|
||||
}
|
||||
|
||||
/**
|
||||
* Pluggable operations for the read tool.
|
||||
* Override these to delegate file reading to remote systems (e.g., SSH).
|
||||
*/
|
||||
export interface ReadOperations {
|
||||
/** Read file contents as a Buffer */
|
||||
readFile: (absolutePath: string) => Promise<Buffer>;
|
||||
/** Check if file is readable (throw if not) */
|
||||
access: (absolutePath: string) => Promise<void>;
|
||||
/** Detect image MIME type, return null/undefined for non-images */
|
||||
detectImageMimeType?: (
|
||||
absolutePath: string,
|
||||
) => Promise<string | null | undefined>;
|
||||
}
|
||||
|
||||
const defaultReadOperations: ReadOperations = {
|
||||
readFile: (path) => fsReadFile(path),
|
||||
access: (path) => fsAccess(path, constants.R_OK),
|
||||
detectImageMimeType: detectSupportedImageMimeTypeFromFile,
|
||||
};
|
||||
|
||||
export interface ReadToolOptions {
|
||||
/** Whether to auto-resize images to 2000x2000 max. Default: true */
|
||||
autoResizeImages?: boolean;
|
||||
/** Custom operations for file reading. Default: local filesystem */
|
||||
operations?: ReadOperations;
|
||||
}
|
||||
|
||||
export function createReadTool(
|
||||
cwd: string,
|
||||
options?: ReadToolOptions,
|
||||
): AgentTool<typeof readSchema> {
|
||||
const autoResizeImages = options?.autoResizeImages ?? true;
|
||||
const ops = options?.operations ?? defaultReadOperations;
|
||||
|
||||
return {
|
||||
name: "read",
|
||||
label: "read",
|
||||
description: `Read the contents of a file. Supports text files and images (jpg, png, gif, webp). Images are sent as attachments. For text files, output is truncated to ${DEFAULT_MAX_LINES} lines or ${DEFAULT_MAX_BYTES / 1024}KB (whichever is hit first). Use offset/limit for large files. When you need the full file, continue with offset until complete.`,
|
||||
parameters: readSchema,
|
||||
execute: async (
|
||||
_toolCallId: string,
|
||||
{
|
||||
path,
|
||||
offset,
|
||||
limit,
|
||||
}: { path: string; offset?: number; limit?: number },
|
||||
signal?: AbortSignal,
|
||||
) => {
|
||||
const absolutePath = resolveReadPath(path, cwd);
|
||||
|
||||
return new Promise<{
|
||||
content: (TextContent | ImageContent)[];
|
||||
details: ReadToolDetails | undefined;
|
||||
}>((resolve, reject) => {
|
||||
// Check if already aborted
|
||||
if (signal?.aborted) {
|
||||
reject(new Error("Operation aborted"));
|
||||
return;
|
||||
}
|
||||
|
||||
let aborted = false;
|
||||
|
||||
// Set up abort handler
|
||||
const onAbort = () => {
|
||||
aborted = true;
|
||||
reject(new Error("Operation aborted"));
|
||||
};
|
||||
|
||||
if (signal) {
|
||||
signal.addEventListener("abort", onAbort, { once: true });
|
||||
}
|
||||
|
||||
// Perform the read operation
|
||||
(async () => {
|
||||
try {
|
||||
// Check if file exists
|
||||
await ops.access(absolutePath);
|
||||
|
||||
// Check if aborted before reading
|
||||
if (aborted) {
|
||||
return;
|
||||
}
|
||||
|
||||
const mimeType = ops.detectImageMimeType
|
||||
? await ops.detectImageMimeType(absolutePath)
|
||||
: undefined;
|
||||
|
||||
// Read the file based on type
|
||||
let content: (TextContent | ImageContent)[];
|
||||
let details: ReadToolDetails | undefined;
|
||||
|
||||
if (mimeType) {
|
||||
// Read as image (binary)
|
||||
const buffer = await ops.readFile(absolutePath);
|
||||
const base64 = buffer.toString("base64");
|
||||
|
||||
if (autoResizeImages) {
|
||||
// Resize image if needed
|
||||
const resized = await resizeImage({
|
||||
type: "image",
|
||||
data: base64,
|
||||
mimeType,
|
||||
});
|
||||
const dimensionNote = formatDimensionNote(resized);
|
||||
|
||||
let textNote = `Read image file [${resized.mimeType}]`;
|
||||
if (dimensionNote) {
|
||||
textNote += `\n${dimensionNote}`;
|
||||
}
|
||||
|
||||
content = [
|
||||
{ type: "text", text: textNote },
|
||||
{
|
||||
type: "image",
|
||||
data: resized.data,
|
||||
mimeType: resized.mimeType,
|
||||
},
|
||||
];
|
||||
} else {
|
||||
const textNote = `Read image file [${mimeType}]`;
|
||||
content = [
|
||||
{ type: "text", text: textNote },
|
||||
{ type: "image", data: base64, mimeType },
|
||||
];
|
||||
}
|
||||
} else {
|
||||
// Read as text
|
||||
const buffer = await ops.readFile(absolutePath);
|
||||
const textContent = buffer.toString("utf-8");
|
||||
const allLines = textContent.split("\n");
|
||||
const totalFileLines = allLines.length;
|
||||
|
||||
// Apply offset if specified (1-indexed to 0-indexed)
|
||||
const startLine = offset ? Math.max(0, offset - 1) : 0;
|
||||
const startLineDisplay = startLine + 1; // For display (1-indexed)
|
||||
|
||||
// Check if offset is out of bounds
|
||||
if (startLine >= allLines.length) {
|
||||
throw new Error(
|
||||
`Offset ${offset} is beyond end of file (${allLines.length} lines total)`,
|
||||
);
|
||||
}
|
||||
|
||||
// If limit is specified by user, use it; otherwise we'll let truncateHead decide
|
||||
let selectedContent: string;
|
||||
let userLimitedLines: number | undefined;
|
||||
if (limit !== undefined) {
|
||||
const endLine = Math.min(startLine + limit, allLines.length);
|
||||
selectedContent = allLines.slice(startLine, endLine).join("\n");
|
||||
userLimitedLines = endLine - startLine;
|
||||
} else {
|
||||
selectedContent = allLines.slice(startLine).join("\n");
|
||||
}
|
||||
|
||||
// Apply truncation (respects both line and byte limits)
|
||||
const truncation = truncateHead(selectedContent);
|
||||
|
||||
let outputText: string;
|
||||
|
||||
if (truncation.firstLineExceedsLimit) {
|
||||
// First line at offset exceeds 30KB - tell model to use bash
|
||||
const firstLineSize = formatSize(
|
||||
Buffer.byteLength(allLines[startLine], "utf-8"),
|
||||
);
|
||||
outputText = `[Line ${startLineDisplay} is ${firstLineSize}, exceeds ${formatSize(DEFAULT_MAX_BYTES)} limit. Use bash: sed -n '${startLineDisplay}p' ${path} | head -c ${DEFAULT_MAX_BYTES}]`;
|
||||
details = { truncation };
|
||||
} else if (truncation.truncated) {
|
||||
// Truncation occurred - build actionable notice
|
||||
const endLineDisplay =
|
||||
startLineDisplay + truncation.outputLines - 1;
|
||||
const nextOffset = endLineDisplay + 1;
|
||||
|
||||
outputText = truncation.content;
|
||||
|
||||
if (truncation.truncatedBy === "lines") {
|
||||
outputText += `\n\n[Showing lines ${startLineDisplay}-${endLineDisplay} of ${totalFileLines}. Use offset=${nextOffset} to continue.]`;
|
||||
} else {
|
||||
outputText += `\n\n[Showing lines ${startLineDisplay}-${endLineDisplay} of ${totalFileLines} (${formatSize(DEFAULT_MAX_BYTES)} limit). Use offset=${nextOffset} to continue.]`;
|
||||
}
|
||||
details = { truncation };
|
||||
} else if (
|
||||
userLimitedLines !== undefined &&
|
||||
startLine + userLimitedLines < allLines.length
|
||||
) {
|
||||
// User specified limit, there's more content, but no truncation
|
||||
const remaining =
|
||||
allLines.length - (startLine + userLimitedLines);
|
||||
const nextOffset = startLine + userLimitedLines + 1;
|
||||
|
||||
outputText = truncation.content;
|
||||
outputText += `\n\n[${remaining} more lines in file. Use offset=${nextOffset} to continue.]`;
|
||||
} else {
|
||||
// No truncation, no user limit exceeded
|
||||
outputText = truncation.content;
|
||||
}
|
||||
|
||||
content = [{ type: "text", text: outputText }];
|
||||
}
|
||||
|
||||
// Check if aborted after reading
|
||||
if (aborted) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Clean up abort handler
|
||||
if (signal) {
|
||||
signal.removeEventListener("abort", onAbort);
|
||||
}
|
||||
|
||||
resolve({ content, details });
|
||||
} catch (error: any) {
|
||||
// Clean up abort handler
|
||||
if (signal) {
|
||||
signal.removeEventListener("abort", onAbort);
|
||||
}
|
||||
|
||||
if (!aborted) {
|
||||
reject(error);
|
||||
}
|
||||
}
|
||||
})();
|
||||
});
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
/** Default read tool using process.cwd() - for backwards compatibility */
|
||||
export const readTool = createReadTool(process.cwd());
|
||||
279
packages/coding-agent/src/core/tools/truncate.ts
Normal file
279
packages/coding-agent/src/core/tools/truncate.ts
Normal file
|
|
@ -0,0 +1,279 @@
|
|||
/**
|
||||
* Shared truncation utilities for tool outputs.
|
||||
*
|
||||
* Truncation is based on two independent limits - whichever is hit first wins:
|
||||
* - Line limit (default: 2000 lines)
|
||||
* - Byte limit (default: 50KB)
|
||||
*
|
||||
* Never returns partial lines (except bash tail truncation edge case).
|
||||
*/
|
||||
|
||||
export const DEFAULT_MAX_LINES = 2000;
|
||||
export const DEFAULT_MAX_BYTES = 50 * 1024; // 50KB
|
||||
export const GREP_MAX_LINE_LENGTH = 500; // Max chars per grep match line
|
||||
|
||||
export interface TruncationResult {
|
||||
/** The truncated content */
|
||||
content: string;
|
||||
/** Whether truncation occurred */
|
||||
truncated: boolean;
|
||||
/** Which limit was hit: "lines", "bytes", or null if not truncated */
|
||||
truncatedBy: "lines" | "bytes" | null;
|
||||
/** Total number of lines in the original content */
|
||||
totalLines: number;
|
||||
/** Total number of bytes in the original content */
|
||||
totalBytes: number;
|
||||
/** Number of complete lines in the truncated output */
|
||||
outputLines: number;
|
||||
/** Number of bytes in the truncated output */
|
||||
outputBytes: number;
|
||||
/** Whether the last line was partially truncated (only for tail truncation edge case) */
|
||||
lastLinePartial: boolean;
|
||||
/** Whether the first line exceeded the byte limit (for head truncation) */
|
||||
firstLineExceedsLimit: boolean;
|
||||
/** The max lines limit that was applied */
|
||||
maxLines: number;
|
||||
/** The max bytes limit that was applied */
|
||||
maxBytes: number;
|
||||
}
|
||||
|
||||
export interface TruncationOptions {
|
||||
/** Maximum number of lines (default: 2000) */
|
||||
maxLines?: number;
|
||||
/** Maximum number of bytes (default: 50KB) */
|
||||
maxBytes?: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Format bytes as human-readable size.
|
||||
*/
|
||||
export function formatSize(bytes: number): string {
|
||||
if (bytes < 1024) {
|
||||
return `${bytes}B`;
|
||||
} else if (bytes < 1024 * 1024) {
|
||||
return `${(bytes / 1024).toFixed(1)}KB`;
|
||||
} else {
|
||||
return `${(bytes / (1024 * 1024)).toFixed(1)}MB`;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Truncate content from the head (keep first N lines/bytes).
|
||||
* Suitable for file reads where you want to see the beginning.
|
||||
*
|
||||
* Never returns partial lines. If first line exceeds byte limit,
|
||||
* returns empty content with firstLineExceedsLimit=true.
|
||||
*/
|
||||
export function truncateHead(
|
||||
content: string,
|
||||
options: TruncationOptions = {},
|
||||
): TruncationResult {
|
||||
const maxLines = options.maxLines ?? DEFAULT_MAX_LINES;
|
||||
const maxBytes = options.maxBytes ?? DEFAULT_MAX_BYTES;
|
||||
|
||||
const totalBytes = Buffer.byteLength(content, "utf-8");
|
||||
const lines = content.split("\n");
|
||||
const totalLines = lines.length;
|
||||
|
||||
// Check if no truncation needed
|
||||
if (totalLines <= maxLines && totalBytes <= maxBytes) {
|
||||
return {
|
||||
content,
|
||||
truncated: false,
|
||||
truncatedBy: null,
|
||||
totalLines,
|
||||
totalBytes,
|
||||
outputLines: totalLines,
|
||||
outputBytes: totalBytes,
|
||||
lastLinePartial: false,
|
||||
firstLineExceedsLimit: false,
|
||||
maxLines,
|
||||
maxBytes,
|
||||
};
|
||||
}
|
||||
|
||||
// Check if first line alone exceeds byte limit
|
||||
const firstLineBytes = Buffer.byteLength(lines[0], "utf-8");
|
||||
if (firstLineBytes > maxBytes) {
|
||||
return {
|
||||
content: "",
|
||||
truncated: true,
|
||||
truncatedBy: "bytes",
|
||||
totalLines,
|
||||
totalBytes,
|
||||
outputLines: 0,
|
||||
outputBytes: 0,
|
||||
lastLinePartial: false,
|
||||
firstLineExceedsLimit: true,
|
||||
maxLines,
|
||||
maxBytes,
|
||||
};
|
||||
}
|
||||
|
||||
// Collect complete lines that fit
|
||||
const outputLinesArr: string[] = [];
|
||||
let outputBytesCount = 0;
|
||||
let truncatedBy: "lines" | "bytes" = "lines";
|
||||
|
||||
for (let i = 0; i < lines.length && i < maxLines; i++) {
|
||||
const line = lines[i];
|
||||
const lineBytes = Buffer.byteLength(line, "utf-8") + (i > 0 ? 1 : 0); // +1 for newline
|
||||
|
||||
if (outputBytesCount + lineBytes > maxBytes) {
|
||||
truncatedBy = "bytes";
|
||||
break;
|
||||
}
|
||||
|
||||
outputLinesArr.push(line);
|
||||
outputBytesCount += lineBytes;
|
||||
}
|
||||
|
||||
// If we exited due to line limit
|
||||
if (outputLinesArr.length >= maxLines && outputBytesCount <= maxBytes) {
|
||||
truncatedBy = "lines";
|
||||
}
|
||||
|
||||
const outputContent = outputLinesArr.join("\n");
|
||||
const finalOutputBytes = Buffer.byteLength(outputContent, "utf-8");
|
||||
|
||||
return {
|
||||
content: outputContent,
|
||||
truncated: true,
|
||||
truncatedBy,
|
||||
totalLines,
|
||||
totalBytes,
|
||||
outputLines: outputLinesArr.length,
|
||||
outputBytes: finalOutputBytes,
|
||||
lastLinePartial: false,
|
||||
firstLineExceedsLimit: false,
|
||||
maxLines,
|
||||
maxBytes,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Truncate content from the tail (keep last N lines/bytes).
|
||||
* Suitable for bash output where you want to see the end (errors, final results).
|
||||
*
|
||||
* May return partial first line if the last line of original content exceeds byte limit.
|
||||
*/
|
||||
export function truncateTail(
|
||||
content: string,
|
||||
options: TruncationOptions = {},
|
||||
): TruncationResult {
|
||||
const maxLines = options.maxLines ?? DEFAULT_MAX_LINES;
|
||||
const maxBytes = options.maxBytes ?? DEFAULT_MAX_BYTES;
|
||||
|
||||
const totalBytes = Buffer.byteLength(content, "utf-8");
|
||||
const lines = content.split("\n");
|
||||
const totalLines = lines.length;
|
||||
|
||||
// Check if no truncation needed
|
||||
if (totalLines <= maxLines && totalBytes <= maxBytes) {
|
||||
return {
|
||||
content,
|
||||
truncated: false,
|
||||
truncatedBy: null,
|
||||
totalLines,
|
||||
totalBytes,
|
||||
outputLines: totalLines,
|
||||
outputBytes: totalBytes,
|
||||
lastLinePartial: false,
|
||||
firstLineExceedsLimit: false,
|
||||
maxLines,
|
||||
maxBytes,
|
||||
};
|
||||
}
|
||||
|
||||
// Work backwards from the end
|
||||
const outputLinesArr: string[] = [];
|
||||
let outputBytesCount = 0;
|
||||
let truncatedBy: "lines" | "bytes" = "lines";
|
||||
let lastLinePartial = false;
|
||||
|
||||
for (
|
||||
let i = lines.length - 1;
|
||||
i >= 0 && outputLinesArr.length < maxLines;
|
||||
i--
|
||||
) {
|
||||
const line = lines[i];
|
||||
const lineBytes =
|
||||
Buffer.byteLength(line, "utf-8") + (outputLinesArr.length > 0 ? 1 : 0); // +1 for newline
|
||||
|
||||
if (outputBytesCount + lineBytes > maxBytes) {
|
||||
truncatedBy = "bytes";
|
||||
// Edge case: if we haven't added ANY lines yet and this line exceeds maxBytes,
|
||||
// take the end of the line (partial)
|
||||
if (outputLinesArr.length === 0) {
|
||||
const truncatedLine = truncateStringToBytesFromEnd(line, maxBytes);
|
||||
outputLinesArr.unshift(truncatedLine);
|
||||
outputBytesCount = Buffer.byteLength(truncatedLine, "utf-8");
|
||||
lastLinePartial = true;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
outputLinesArr.unshift(line);
|
||||
outputBytesCount += lineBytes;
|
||||
}
|
||||
|
||||
// If we exited due to line limit
|
||||
if (outputLinesArr.length >= maxLines && outputBytesCount <= maxBytes) {
|
||||
truncatedBy = "lines";
|
||||
}
|
||||
|
||||
const outputContent = outputLinesArr.join("\n");
|
||||
const finalOutputBytes = Buffer.byteLength(outputContent, "utf-8");
|
||||
|
||||
return {
|
||||
content: outputContent,
|
||||
truncated: true,
|
||||
truncatedBy,
|
||||
totalLines,
|
||||
totalBytes,
|
||||
outputLines: outputLinesArr.length,
|
||||
outputBytes: finalOutputBytes,
|
||||
lastLinePartial,
|
||||
firstLineExceedsLimit: false,
|
||||
maxLines,
|
||||
maxBytes,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Truncate a string to fit within a byte limit (from the end).
|
||||
* Handles multi-byte UTF-8 characters correctly.
|
||||
*/
|
||||
function truncateStringToBytesFromEnd(str: string, maxBytes: number): string {
|
||||
const buf = Buffer.from(str, "utf-8");
|
||||
if (buf.length <= maxBytes) {
|
||||
return str;
|
||||
}
|
||||
|
||||
// Start from the end, skip maxBytes back
|
||||
let start = buf.length - maxBytes;
|
||||
|
||||
// Find a valid UTF-8 boundary (start of a character)
|
||||
while (start < buf.length && (buf[start] & 0xc0) === 0x80) {
|
||||
start++;
|
||||
}
|
||||
|
||||
return buf.slice(start).toString("utf-8");
|
||||
}
|
||||
|
||||
/**
|
||||
* Truncate a single line to max characters, adding [truncated] suffix.
|
||||
* Used for grep match lines.
|
||||
*/
|
||||
export function truncateLine(
|
||||
line: string,
|
||||
maxChars: number = GREP_MAX_LINE_LENGTH,
|
||||
): { text: string; wasTruncated: boolean } {
|
||||
if (line.length <= maxChars) {
|
||||
return { text: line, wasTruncated: false };
|
||||
}
|
||||
return {
|
||||
text: `${line.slice(0, maxChars)}... [truncated]`,
|
||||
wasTruncated: true,
|
||||
};
|
||||
}
|
||||
129
packages/coding-agent/src/core/tools/write.ts
Normal file
129
packages/coding-agent/src/core/tools/write.ts
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
import type { AgentTool } from "@mariozechner/pi-agent-core";
|
||||
import { type Static, Type } from "@sinclair/typebox";
|
||||
import { mkdir as fsMkdir, writeFile as fsWriteFile } from "fs/promises";
|
||||
import { dirname } from "path";
|
||||
import { resolveToCwd } from "./path-utils.js";
|
||||
|
||||
const writeSchema = Type.Object({
|
||||
path: Type.String({
|
||||
description: "Path to the file to write (relative or absolute)",
|
||||
}),
|
||||
content: Type.String({ description: "Content to write to the file" }),
|
||||
});
|
||||
|
||||
export type WriteToolInput = Static<typeof writeSchema>;
|
||||
|
||||
/**
|
||||
* Pluggable operations for the write tool.
|
||||
* Override these to delegate file writing to remote systems (e.g., SSH).
|
||||
*/
|
||||
export interface WriteOperations {
|
||||
/** Write content to a file */
|
||||
writeFile: (absolutePath: string, content: string) => Promise<void>;
|
||||
/** Create directory (recursively) */
|
||||
mkdir: (dir: string) => Promise<void>;
|
||||
}
|
||||
|
||||
const defaultWriteOperations: WriteOperations = {
|
||||
writeFile: (path, content) => fsWriteFile(path, content, "utf-8"),
|
||||
mkdir: (dir) => fsMkdir(dir, { recursive: true }).then(() => {}),
|
||||
};
|
||||
|
||||
export interface WriteToolOptions {
|
||||
/** Custom operations for file writing. Default: local filesystem */
|
||||
operations?: WriteOperations;
|
||||
}
|
||||
|
||||
export function createWriteTool(
|
||||
cwd: string,
|
||||
options?: WriteToolOptions,
|
||||
): AgentTool<typeof writeSchema> {
|
||||
const ops = options?.operations ?? defaultWriteOperations;
|
||||
|
||||
return {
|
||||
name: "write",
|
||||
label: "write",
|
||||
description:
|
||||
"Write content to a file. Creates the file if it doesn't exist, overwrites if it does. Automatically creates parent directories.",
|
||||
parameters: writeSchema,
|
||||
execute: async (
|
||||
_toolCallId: string,
|
||||
{ path, content }: { path: string; content: string },
|
||||
signal?: AbortSignal,
|
||||
) => {
|
||||
const absolutePath = resolveToCwd(path, cwd);
|
||||
const dir = dirname(absolutePath);
|
||||
|
||||
return new Promise<{
|
||||
content: Array<{ type: "text"; text: string }>;
|
||||
details: undefined;
|
||||
}>((resolve, reject) => {
|
||||
// Check if already aborted
|
||||
if (signal?.aborted) {
|
||||
reject(new Error("Operation aborted"));
|
||||
return;
|
||||
}
|
||||
|
||||
let aborted = false;
|
||||
|
||||
// Set up abort handler
|
||||
const onAbort = () => {
|
||||
aborted = true;
|
||||
reject(new Error("Operation aborted"));
|
||||
};
|
||||
|
||||
if (signal) {
|
||||
signal.addEventListener("abort", onAbort, { once: true });
|
||||
}
|
||||
|
||||
// Perform the write operation
|
||||
(async () => {
|
||||
try {
|
||||
// Create parent directories if needed
|
||||
await ops.mkdir(dir);
|
||||
|
||||
// Check if aborted before writing
|
||||
if (aborted) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Write the file
|
||||
await ops.writeFile(absolutePath, content);
|
||||
|
||||
// Check if aborted after writing
|
||||
if (aborted) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Clean up abort handler
|
||||
if (signal) {
|
||||
signal.removeEventListener("abort", onAbort);
|
||||
}
|
||||
|
||||
resolve({
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: `Successfully wrote ${content.length} bytes to ${path}`,
|
||||
},
|
||||
],
|
||||
details: undefined,
|
||||
});
|
||||
} catch (error: any) {
|
||||
// Clean up abort handler
|
||||
if (signal) {
|
||||
signal.removeEventListener("abort", onAbort);
|
||||
}
|
||||
|
||||
if (!aborted) {
|
||||
reject(error);
|
||||
}
|
||||
}
|
||||
})();
|
||||
});
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
/**
 * Default write tool using process.cwd() - for backwards compatibility.
 * Note: the cwd is captured once at module load time; callers that need a
 * different working directory should use createWriteTool(cwd) instead.
 */
export const writeTool = createWriteTool(process.cwd());
|
||||
205
packages/coding-agent/src/core/vercel-ai-stream.ts
Normal file
205
packages/coding-agent/src/core/vercel-ai-stream.ts
Normal file
|
|
@ -0,0 +1,205 @@
|
|||
import { randomUUID } from "node:crypto";
|
||||
import type { ServerResponse } from "node:http";
|
||||
import type { AgentSessionEvent } from "./agent-session.js";
|
||||
|
||||
/**
|
||||
* Write a single Vercel AI SDK v5+ SSE chunk to the response.
|
||||
* Format: `data: <JSON>\n\n`
|
||||
* For the terminal [DONE] sentinel: `data: [DONE]\n\n`
|
||||
*/
|
||||
function writeChunk(response: ServerResponse, chunk: object | string): void {
|
||||
if (response.writableEnded) return;
|
||||
const payload = typeof chunk === "string" ? chunk : JSON.stringify(chunk);
|
||||
response.write(`data: ${payload}\n\n`);
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract the user's text from the request body.
|
||||
* Supports both useChat format ({ messages: UIMessage[] }) and simple gateway format ({ text: string }).
|
||||
*/
|
||||
export function extractUserText(body: Record<string, unknown>): string | null {
|
||||
// Simple gateway format
|
||||
if (typeof body.text === "string" && body.text.trim()) {
|
||||
return body.text;
|
||||
}
|
||||
// Convenience format
|
||||
if (typeof body.prompt === "string" && body.prompt.trim()) {
|
||||
return body.prompt;
|
||||
}
|
||||
// Vercel AI SDK useChat format - extract last user message
|
||||
if (Array.isArray(body.messages)) {
|
||||
for (let i = body.messages.length - 1; i >= 0; i--) {
|
||||
const msg = body.messages[i] as Record<string, unknown>;
|
||||
if (msg.role !== "user") continue;
|
||||
// v5+ format with parts array
|
||||
if (Array.isArray(msg.parts)) {
|
||||
for (const part of msg.parts as Array<Record<string, unknown>>) {
|
||||
if (part.type === "text" && typeof part.text === "string") {
|
||||
return part.text;
|
||||
}
|
||||
}
|
||||
}
|
||||
// v4 format with content string
|
||||
if (typeof msg.content === "string" && msg.content.trim()) {
|
||||
return msg.content;
|
||||
}
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
 * Create an AgentSessionEvent listener that translates events to Vercel AI SDK v5+ SSE
 * chunks and writes them to the HTTP response.
 *
 * Event mapping (AgentSession -> AI SDK stream protocol):
 * - agent_start            -> start (once, with the message id)
 * - turn_start / turn_end  -> start-step / finish-step
 * - text_*                 -> text-start / text-delta / text-end
 * - toolcall_*             -> tool-input-start / tool-input-delta / tool-input-available
 * - thinking_*             -> reasoning-start / reasoning-delta / reasoning-end
 * - tool_execution_end     -> tool-output-available
 *
 * Returns the listener function. The caller is responsible for subscribing/unsubscribing,
 * and for emitting the terminal finish/[DONE] sequence (see finishVercelStream).
 *
 * @param response  HTTP response to write SSE chunks to
 * @param messageId Optional stable message id; a random UUID is generated when omitted
 */
export function createVercelStreamListener(
	response: ServerResponse,
	messageId?: string,
): (event: AgentSessionEvent) => void {
	// Gate: only forward events within a single prompt's agent_start -> agent_end lifecycle.
	// handleChat now subscribes this listener immediately before the queued prompt starts,
	// so these guards only need to bound the stream to that prompt's event span.
	let active = false;
	const msgId = messageId ?? randomUUID();

	return (event: AgentSessionEvent) => {
		// Never write to a response that has already been closed.
		if (response.writableEnded) return;

		// Activate on our agent_start, deactivate on agent_end
		if (event.type === "agent_start") {
			if (!active) {
				active = true;
				writeChunk(response, { type: "start", messageId: msgId });
			}
			return;
		}
		if (event.type === "agent_end") {
			active = false;
			return;
		}

		// Drop events that don't belong to our message
		if (!active) return;

		switch (event.type) {
			case "turn_start":
				writeChunk(response, { type: "start-step" });
				return;

			case "message_update": {
				// Streaming updates for the assistant message being built. The
				// `text_${contentIndex}` / `reasoning_${contentIndex}` ids let the
				// client correlate start/delta/end chunks for each content part.
				const inner = event.assistantMessageEvent;
				switch (inner.type) {
					case "text_start":
						writeChunk(response, {
							type: "text-start",
							id: `text_${inner.contentIndex}`,
						});
						return;
					case "text_delta":
						writeChunk(response, {
							type: "text-delta",
							id: `text_${inner.contentIndex}`,
							delta: inner.delta,
						});
						return;
					case "text_end":
						writeChunk(response, {
							type: "text-end",
							id: `text_${inner.contentIndex}`,
						});
						return;
					case "toolcall_start": {
						// The tool-call id/name live on the partial message's content
						// part, not on the event itself; skip if the part isn't a
						// toolCall (shouldn't happen, but guards against shape drift).
						const content = inner.partial.content[inner.contentIndex];
						if (content?.type === "toolCall") {
							writeChunk(response, {
								type: "tool-input-start",
								toolCallId: content.id,
								toolName: content.name,
							});
						}
						return;
					}
					case "toolcall_delta": {
						const content = inner.partial.content[inner.contentIndex];
						if (content?.type === "toolCall") {
							writeChunk(response, {
								type: "tool-input-delta",
								toolCallId: content.id,
								inputTextDelta: inner.delta,
							});
						}
						return;
					}
					case "toolcall_end":
						// Arguments are complete here; signal the full input is available.
						writeChunk(response, {
							type: "tool-input-available",
							toolCallId: inner.toolCall.id,
							toolName: inner.toolCall.name,
							input: inner.toolCall.arguments,
						});
						return;
					case "thinking_start":
						writeChunk(response, {
							type: "reasoning-start",
							id: `reasoning_${inner.contentIndex}`,
						});
						return;
					case "thinking_delta":
						writeChunk(response, {
							type: "reasoning-delta",
							id: `reasoning_${inner.contentIndex}`,
							delta: inner.delta,
						});
						return;
					case "thinking_end":
						writeChunk(response, {
							type: "reasoning-end",
							id: `reasoning_${inner.contentIndex}`,
						});
						return;
				}
				return;
			}

			case "turn_end":
				writeChunk(response, { type: "finish-step" });
				return;

			case "tool_execution_end":
				// Tool finished executing; forward its result to the client.
				writeChunk(response, {
					type: "tool-output-available",
					toolCallId: event.toolCallId,
					output: event.result,
				});
				return;
		}
	};
}
|
||||
|
||||
/**
|
||||
* Write the terminal finish sequence and end the response.
|
||||
*/
|
||||
export function finishVercelStream(
|
||||
response: ServerResponse,
|
||||
finishReason: string = "stop",
|
||||
): void {
|
||||
if (response.writableEnded) return;
|
||||
writeChunk(response, { type: "finish", finishReason });
|
||||
writeChunk(response, "[DONE]");
|
||||
response.end();
|
||||
}
|
||||
|
||||
/**
|
||||
* Write an error chunk and end the response.
|
||||
*/
|
||||
export function errorVercelStream(
|
||||
response: ServerResponse,
|
||||
errorText: string,
|
||||
): void {
|
||||
if (response.writableEnded) return;
|
||||
writeChunk(response, { type: "error", errorText });
|
||||
writeChunk(response, "[DONE]");
|
||||
response.end();
|
||||
}
|
||||
353
packages/coding-agent/src/index.ts
Normal file
353
packages/coding-agent/src/index.ts
Normal file
|
|
@ -0,0 +1,353 @@
|
|||
// Core session management
|
||||
|
||||
// Config paths
|
||||
export { getAgentDir, VERSION } from "./config.js";
|
||||
export {
|
||||
AgentSession,
|
||||
type AgentSessionConfig,
|
||||
type AgentSessionEvent,
|
||||
type AgentSessionEventListener,
|
||||
type ModelCycleResult,
|
||||
type ParsedSkillBlock,
|
||||
type PromptOptions,
|
||||
parseSkillBlock,
|
||||
type SessionStats,
|
||||
} from "./core/agent-session.js";
|
||||
// Auth and model registry
|
||||
export {
|
||||
type ApiKeyCredential,
|
||||
type AuthCredential,
|
||||
AuthStorage,
|
||||
type AuthStorageBackend,
|
||||
FileAuthStorageBackend,
|
||||
InMemoryAuthStorageBackend,
|
||||
type OAuthCredential,
|
||||
} from "./core/auth-storage.js";
|
||||
// Compaction
|
||||
export {
|
||||
type BranchPreparation,
|
||||
type BranchSummaryResult,
|
||||
type CollectEntriesResult,
|
||||
type CompactionResult,
|
||||
type CutPointResult,
|
||||
calculateContextTokens,
|
||||
collectEntriesForBranchSummary,
|
||||
compact,
|
||||
DEFAULT_COMPACTION_SETTINGS,
|
||||
estimateTokens,
|
||||
type FileOperations,
|
||||
findCutPoint,
|
||||
findTurnStartIndex,
|
||||
type GenerateBranchSummaryOptions,
|
||||
generateBranchSummary,
|
||||
generateSummary,
|
||||
getLastAssistantUsage,
|
||||
prepareBranchEntries,
|
||||
serializeConversation,
|
||||
shouldCompact,
|
||||
} from "./core/compaction/index.js";
|
||||
export {
|
||||
createEventBus,
|
||||
type EventBus,
|
||||
type EventBusController,
|
||||
} from "./core/event-bus.js";
|
||||
// Extension system
|
||||
export type {
|
||||
AgentEndEvent,
|
||||
AgentStartEvent,
|
||||
AgentToolResult,
|
||||
AgentToolUpdateCallback,
|
||||
AppAction,
|
||||
BashToolCallEvent,
|
||||
BeforeAgentStartEvent,
|
||||
CompactOptions,
|
||||
ContextEvent,
|
||||
ContextUsage,
|
||||
CustomToolCallEvent,
|
||||
EditToolCallEvent,
|
||||
ExecOptions,
|
||||
ExecResult,
|
||||
Extension,
|
||||
ExtensionActions,
|
||||
ExtensionAPI,
|
||||
ExtensionCommandContext,
|
||||
ExtensionCommandContextActions,
|
||||
ExtensionContext,
|
||||
ExtensionContextActions,
|
||||
ExtensionError,
|
||||
ExtensionEvent,
|
||||
ExtensionFactory,
|
||||
ExtensionFlag,
|
||||
ExtensionHandler,
|
||||
ExtensionRuntime,
|
||||
ExtensionShortcut,
|
||||
ExtensionUIContext,
|
||||
ExtensionUIDialogOptions,
|
||||
ExtensionWidgetOptions,
|
||||
FindToolCallEvent,
|
||||
GrepToolCallEvent,
|
||||
InputEvent,
|
||||
InputEventResult,
|
||||
InputSource,
|
||||
KeybindingsManager,
|
||||
LoadExtensionsResult,
|
||||
LsToolCallEvent,
|
||||
MessageRenderer,
|
||||
MessageRenderOptions,
|
||||
ProviderConfig,
|
||||
ProviderModelConfig,
|
||||
ReadToolCallEvent,
|
||||
RegisteredCommand,
|
||||
RegisteredTool,
|
||||
SessionBeforeCompactEvent,
|
||||
SessionBeforeForkEvent,
|
||||
SessionBeforeSwitchEvent,
|
||||
SessionBeforeTreeEvent,
|
||||
SessionCompactEvent,
|
||||
SessionForkEvent,
|
||||
SessionShutdownEvent,
|
||||
SessionStartEvent,
|
||||
SessionSwitchEvent,
|
||||
SessionTreeEvent,
|
||||
SlashCommandInfo,
|
||||
SlashCommandLocation,
|
||||
SlashCommandSource,
|
||||
TerminalInputHandler,
|
||||
ToolCallEvent,
|
||||
ToolDefinition,
|
||||
ToolInfo,
|
||||
ToolRenderResultOptions,
|
||||
ToolResultEvent,
|
||||
TurnEndEvent,
|
||||
TurnStartEvent,
|
||||
UserBashEvent,
|
||||
UserBashEventResult,
|
||||
WidgetPlacement,
|
||||
WriteToolCallEvent,
|
||||
} from "./core/extensions/index.js";
|
||||
export {
|
||||
createExtensionRuntime,
|
||||
discoverAndLoadExtensions,
|
||||
ExtensionRunner,
|
||||
isBashToolResult,
|
||||
isEditToolResult,
|
||||
isFindToolResult,
|
||||
isGrepToolResult,
|
||||
isLsToolResult,
|
||||
isReadToolResult,
|
||||
isToolCallEventType,
|
||||
isWriteToolResult,
|
||||
wrapRegisteredTool,
|
||||
wrapRegisteredTools,
|
||||
wrapToolsWithExtensions,
|
||||
wrapToolWithExtensions,
|
||||
} from "./core/extensions/index.js";
|
||||
// Footer data provider (git branch + extension statuses - data not otherwise available to extensions)
|
||||
export type { ReadonlyFooterDataProvider } from "./core/footer-data-provider.js";
|
||||
export {
|
||||
createGatewaySessionManager,
|
||||
type GatewayConfig,
|
||||
type GatewayMessageRequest,
|
||||
type GatewayMessageResult,
|
||||
GatewayRuntime,
|
||||
type GatewayRuntimeOptions,
|
||||
type GatewaySessionFactory,
|
||||
type GatewaySessionSnapshot,
|
||||
getActiveGatewayRuntime,
|
||||
sanitizeSessionKey,
|
||||
setActiveGatewayRuntime,
|
||||
} from "./core/gateway-runtime.js";
|
||||
export { convertToLlm } from "./core/messages.js";
|
||||
export { ModelRegistry } from "./core/model-registry.js";
|
||||
export type {
|
||||
PackageManager,
|
||||
PathMetadata,
|
||||
ProgressCallback,
|
||||
ProgressEvent,
|
||||
ResolvedPaths,
|
||||
ResolvedResource,
|
||||
} from "./core/package-manager.js";
|
||||
export { DefaultPackageManager } from "./core/package-manager.js";
|
||||
export type {
|
||||
ResourceCollision,
|
||||
ResourceDiagnostic,
|
||||
ResourceLoader,
|
||||
} from "./core/resource-loader.js";
|
||||
export { DefaultResourceLoader } from "./core/resource-loader.js";
|
||||
// SDK for programmatic usage
|
||||
export {
|
||||
type CreateAgentSessionOptions,
|
||||
type CreateAgentSessionResult,
|
||||
// Factory
|
||||
createAgentSession,
|
||||
createBashTool,
|
||||
// Tool factories (for custom cwd)
|
||||
createCodingTools,
|
||||
createEditTool,
|
||||
createFindTool,
|
||||
createGrepTool,
|
||||
createLsTool,
|
||||
createReadOnlyTools,
|
||||
createReadTool,
|
||||
createWriteTool,
|
||||
type PromptTemplate,
|
||||
// Pre-built tools (use process.cwd())
|
||||
readOnlyTools,
|
||||
} from "./core/sdk.js";
|
||||
export {
|
||||
type BranchSummaryEntry,
|
||||
buildSessionContext,
|
||||
type CompactionEntry,
|
||||
CURRENT_SESSION_VERSION,
|
||||
type CustomEntry,
|
||||
type CustomMessageEntry,
|
||||
type FileEntry,
|
||||
getLatestCompactionEntry,
|
||||
type ModelChangeEntry,
|
||||
migrateSessionEntries,
|
||||
type NewSessionOptions,
|
||||
parseSessionEntries,
|
||||
type SessionContext,
|
||||
type SessionEntry,
|
||||
type SessionEntryBase,
|
||||
type SessionHeader,
|
||||
type SessionInfo,
|
||||
type SessionInfoEntry,
|
||||
SessionManager,
|
||||
type SessionMessageEntry,
|
||||
type ThinkingLevelChangeEntry,
|
||||
} from "./core/session-manager.js";
|
||||
export {
|
||||
type CompactionSettings,
|
||||
type GatewaySettings,
|
||||
type ImageSettings,
|
||||
type PackageSource,
|
||||
type RetrySettings,
|
||||
SettingsManager,
|
||||
} from "./core/settings-manager.js";
|
||||
// Skills
|
||||
export {
|
||||
formatSkillsForPrompt,
|
||||
type LoadSkillsFromDirOptions,
|
||||
type LoadSkillsResult,
|
||||
loadSkills,
|
||||
loadSkillsFromDir,
|
||||
type Skill,
|
||||
type SkillFrontmatter,
|
||||
} from "./core/skills.js";
|
||||
// Tools
|
||||
export {
|
||||
type BashOperations,
|
||||
type BashSpawnContext,
|
||||
type BashSpawnHook,
|
||||
type BashToolDetails,
|
||||
type BashToolInput,
|
||||
type BashToolOptions,
|
||||
bashTool,
|
||||
codingTools,
|
||||
DEFAULT_MAX_BYTES,
|
||||
DEFAULT_MAX_LINES,
|
||||
type EditOperations,
|
||||
type EditToolDetails,
|
||||
type EditToolInput,
|
||||
type EditToolOptions,
|
||||
editTool,
|
||||
type FindOperations,
|
||||
type FindToolDetails,
|
||||
type FindToolInput,
|
||||
type FindToolOptions,
|
||||
findTool,
|
||||
formatSize,
|
||||
type GrepOperations,
|
||||
type GrepToolDetails,
|
||||
type GrepToolInput,
|
||||
type GrepToolOptions,
|
||||
grepTool,
|
||||
type LsOperations,
|
||||
type LsToolDetails,
|
||||
type LsToolInput,
|
||||
type LsToolOptions,
|
||||
lsTool,
|
||||
type ReadOperations,
|
||||
type ReadToolDetails,
|
||||
type ReadToolInput,
|
||||
type ReadToolOptions,
|
||||
readTool,
|
||||
type ToolsOptions,
|
||||
type TruncationOptions,
|
||||
type TruncationResult,
|
||||
truncateHead,
|
||||
truncateLine,
|
||||
truncateTail,
|
||||
type WriteOperations,
|
||||
type WriteToolInput,
|
||||
type WriteToolOptions,
|
||||
writeTool,
|
||||
} from "./core/tools/index.js";
|
||||
// Main entry point
|
||||
export { main } from "./main.js";
|
||||
// Run modes for programmatic SDK usage
|
||||
export {
|
||||
InteractiveMode,
|
||||
type InteractiveModeOptions,
|
||||
type PrintModeOptions,
|
||||
runPrintMode,
|
||||
runRpcMode,
|
||||
} from "./modes/index.js";
|
||||
// UI components for extensions
|
||||
export {
|
||||
ArminComponent,
|
||||
AssistantMessageComponent,
|
||||
appKey,
|
||||
appKeyHint,
|
||||
BashExecutionComponent,
|
||||
BorderedLoader,
|
||||
BranchSummaryMessageComponent,
|
||||
CompactionSummaryMessageComponent,
|
||||
CustomEditor,
|
||||
CustomMessageComponent,
|
||||
DynamicBorder,
|
||||
ExtensionEditorComponent,
|
||||
ExtensionInputComponent,
|
||||
ExtensionSelectorComponent,
|
||||
editorKey,
|
||||
FooterComponent,
|
||||
keyHint,
|
||||
LoginDialogComponent,
|
||||
ModelSelectorComponent,
|
||||
OAuthSelectorComponent,
|
||||
type RenderDiffOptions,
|
||||
rawKeyHint,
|
||||
renderDiff,
|
||||
SessionSelectorComponent,
|
||||
type SettingsCallbacks,
|
||||
type SettingsConfig,
|
||||
SettingsSelectorComponent,
|
||||
ShowImagesSelectorComponent,
|
||||
SkillInvocationMessageComponent,
|
||||
ThemeSelectorComponent,
|
||||
ThinkingSelectorComponent,
|
||||
ToolExecutionComponent,
|
||||
type ToolExecutionOptions,
|
||||
TreeSelectorComponent,
|
||||
truncateToVisualLines,
|
||||
UserMessageComponent,
|
||||
UserMessageSelectorComponent,
|
||||
type VisualTruncateResult,
|
||||
} from "./modes/interactive/components/index.js";
|
||||
// Theme utilities for custom tools and extensions
|
||||
export {
|
||||
getLanguageFromPath,
|
||||
getMarkdownTheme,
|
||||
getSelectListTheme,
|
||||
getSettingsListTheme,
|
||||
highlightCode,
|
||||
initTheme,
|
||||
Theme,
|
||||
type ThemeColor,
|
||||
} from "./modes/interactive/theme/theme.js";
|
||||
// Clipboard utilities
|
||||
export { copyToClipboard } from "./utils/clipboard.js";
|
||||
export { parseFrontmatter, stripFrontmatter } from "./utils/frontmatter.js";
|
||||
// Shell utilities
|
||||
export { getShellConfig } from "./utils/shell.js";
|
||||
1098
packages/coding-agent/src/main.ts
Normal file
1098
packages/coding-agent/src/main.ts
Normal file
File diff suppressed because it is too large
Load diff
317
packages/coding-agent/src/migrations.ts
Normal file
317
packages/coding-agent/src/migrations.ts
Normal file
|
|
@ -0,0 +1,317 @@
|
|||
/**
|
||||
* One-time migrations that run on startup.
|
||||
*/
|
||||
|
||||
import chalk from "chalk";
|
||||
import {
|
||||
existsSync,
|
||||
mkdirSync,
|
||||
readdirSync,
|
||||
readFileSync,
|
||||
renameSync,
|
||||
rmSync,
|
||||
writeFileSync,
|
||||
} from "fs";
|
||||
import { dirname, join } from "path";
|
||||
import { CONFIG_DIR_NAME, getAgentDir, getBinDir } from "./config.js";
|
||||
|
||||
// Links shown to users when deprecated extension directories are detected
// (see showDeprecationWarnings below).
const MIGRATION_GUIDE_URL =
	"https://github.com/badlogic/pi-mono/blob/main/packages/coding-agent/CHANGELOG.md#extensions-migration";
const EXTENSIONS_DOC_URL =
	"https://github.com/badlogic/pi-mono/blob/main/packages/coding-agent/docs/extensions.md";
|
||||
|
||||
/**
|
||||
* Migrate legacy oauth.json and settings.json apiKeys to auth.json.
|
||||
*
|
||||
* @returns Array of provider names that were migrated
|
||||
*/
|
||||
export function migrateAuthToAuthJson(): string[] {
|
||||
const agentDir = getAgentDir();
|
||||
const authPath = join(agentDir, "auth.json");
|
||||
const oauthPath = join(agentDir, "oauth.json");
|
||||
const settingsPath = join(agentDir, "settings.json");
|
||||
|
||||
// Skip if auth.json already exists
|
||||
if (existsSync(authPath)) return [];
|
||||
|
||||
const migrated: Record<string, unknown> = {};
|
||||
const providers: string[] = [];
|
||||
|
||||
// Migrate oauth.json
|
||||
if (existsSync(oauthPath)) {
|
||||
try {
|
||||
const oauth = JSON.parse(readFileSync(oauthPath, "utf-8"));
|
||||
for (const [provider, cred] of Object.entries(oauth)) {
|
||||
migrated[provider] = { type: "oauth", ...(cred as object) };
|
||||
providers.push(provider);
|
||||
}
|
||||
renameSync(oauthPath, `${oauthPath}.migrated`);
|
||||
} catch {
|
||||
// Skip on error
|
||||
}
|
||||
}
|
||||
|
||||
// Migrate settings.json apiKeys
|
||||
if (existsSync(settingsPath)) {
|
||||
try {
|
||||
const content = readFileSync(settingsPath, "utf-8");
|
||||
const settings = JSON.parse(content);
|
||||
if (settings.apiKeys && typeof settings.apiKeys === "object") {
|
||||
for (const [provider, key] of Object.entries(settings.apiKeys)) {
|
||||
if (!migrated[provider] && typeof key === "string") {
|
||||
migrated[provider] = { type: "api_key", key };
|
||||
providers.push(provider);
|
||||
}
|
||||
}
|
||||
delete settings.apiKeys;
|
||||
writeFileSync(settingsPath, JSON.stringify(settings, null, 2));
|
||||
}
|
||||
} catch {
|
||||
// Skip on error
|
||||
}
|
||||
}
|
||||
|
||||
if (Object.keys(migrated).length > 0) {
|
||||
mkdirSync(dirname(authPath), { recursive: true });
|
||||
writeFileSync(authPath, JSON.stringify(migrated, null, 2), { mode: 0o600 });
|
||||
}
|
||||
|
||||
return providers;
|
||||
}
|
||||
|
||||
/**
|
||||
* Migrate sessions from ~/.pi/agent/*.jsonl to proper session directories.
|
||||
*
|
||||
* Bug in v0.30.0: Sessions were saved to ~/.pi/agent/ instead of
|
||||
* ~/.pi/agent/sessions/<encoded-cwd>/. This migration moves them
|
||||
* to the correct location based on the cwd in their session header.
|
||||
*
|
||||
* See: https://github.com/badlogic/pi-mono/issues/320
|
||||
*/
|
||||
export function migrateSessionsFromAgentRoot(): void {
|
||||
const agentDir = getAgentDir();
|
||||
|
||||
// Find all .jsonl files directly in agentDir (not in subdirectories)
|
||||
let files: string[];
|
||||
try {
|
||||
files = readdirSync(agentDir)
|
||||
.filter((f) => f.endsWith(".jsonl"))
|
||||
.map((f) => join(agentDir, f));
|
||||
} catch {
|
||||
return;
|
||||
}
|
||||
|
||||
if (files.length === 0) return;
|
||||
|
||||
for (const file of files) {
|
||||
try {
|
||||
// Read first line to get session header
|
||||
const content = readFileSync(file, "utf8");
|
||||
const firstLine = content.split("\n")[0];
|
||||
if (!firstLine?.trim()) continue;
|
||||
|
||||
const header = JSON.parse(firstLine);
|
||||
if (header.type !== "session" || !header.cwd) continue;
|
||||
|
||||
const cwd: string = header.cwd;
|
||||
|
||||
// Compute the correct session directory (same encoding as session-manager.ts)
|
||||
const safePath = `--${cwd.replace(/^[/\\]/, "").replace(/[/\\:]/g, "-")}--`;
|
||||
const correctDir = join(agentDir, "sessions", safePath);
|
||||
|
||||
// Create directory if needed
|
||||
if (!existsSync(correctDir)) {
|
||||
mkdirSync(correctDir, { recursive: true });
|
||||
}
|
||||
|
||||
// Move the file
|
||||
const fileName = file.split("/").pop() || file.split("\\").pop();
|
||||
const newPath = join(correctDir, fileName!);
|
||||
|
||||
if (existsSync(newPath)) continue; // Skip if target exists
|
||||
|
||||
renameSync(file, newPath);
|
||||
} catch {
|
||||
// Skip files that can't be migrated
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Migrate commands/ to prompts/ if needed.
|
||||
* Works for both regular directories and symlinks.
|
||||
*/
|
||||
function migrateCommandsToPrompts(baseDir: string, label: string): boolean {
|
||||
const commandsDir = join(baseDir, "commands");
|
||||
const promptsDir = join(baseDir, "prompts");
|
||||
|
||||
if (existsSync(commandsDir) && !existsSync(promptsDir)) {
|
||||
try {
|
||||
renameSync(commandsDir, promptsDir);
|
||||
console.log(chalk.green(`Migrated ${label} commands/ → prompts/`));
|
||||
return true;
|
||||
} catch (err) {
|
||||
console.log(
|
||||
chalk.yellow(
|
||||
`Warning: Could not migrate ${label} commands/ to prompts/: ${err instanceof Error ? err.message : err}`,
|
||||
),
|
||||
);
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Move fd/rg binaries from tools/ to bin/ if they exist.
|
||||
*/
|
||||
function migrateToolsToBin(): void {
|
||||
const agentDir = getAgentDir();
|
||||
const toolsDir = join(agentDir, "tools");
|
||||
const binDir = getBinDir();
|
||||
|
||||
if (!existsSync(toolsDir)) return;
|
||||
|
||||
const binaries = ["fd", "rg", "fd.exe", "rg.exe"];
|
||||
let movedAny = false;
|
||||
|
||||
for (const bin of binaries) {
|
||||
const oldPath = join(toolsDir, bin);
|
||||
const newPath = join(binDir, bin);
|
||||
|
||||
if (existsSync(oldPath)) {
|
||||
if (!existsSync(binDir)) {
|
||||
mkdirSync(binDir, { recursive: true });
|
||||
}
|
||||
if (!existsSync(newPath)) {
|
||||
try {
|
||||
renameSync(oldPath, newPath);
|
||||
movedAny = true;
|
||||
} catch {
|
||||
// Ignore errors
|
||||
}
|
||||
} else {
|
||||
// Target exists, just delete the old one
|
||||
try {
|
||||
rmSync?.(oldPath, { force: true });
|
||||
} catch {
|
||||
// Ignore
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (movedAny) {
|
||||
console.log(chalk.green(`Migrated managed binaries tools/ → bin/`));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check for deprecated hooks/ and tools/ directories.
|
||||
* Note: tools/ may contain fd/rg binaries extracted by pi, so only warn if it has other files.
|
||||
*/
|
||||
function checkDeprecatedExtensionDirs(
|
||||
baseDir: string,
|
||||
label: string,
|
||||
): string[] {
|
||||
const hooksDir = join(baseDir, "hooks");
|
||||
const toolsDir = join(baseDir, "tools");
|
||||
const warnings: string[] = [];
|
||||
|
||||
if (existsSync(hooksDir)) {
|
||||
warnings.push(
|
||||
`${label} hooks/ directory found. Hooks have been renamed to extensions.`,
|
||||
);
|
||||
}
|
||||
|
||||
if (existsSync(toolsDir)) {
|
||||
// Check if tools/ contains anything other than fd/rg (which are auto-extracted binaries)
|
||||
try {
|
||||
const entries = readdirSync(toolsDir);
|
||||
const customTools = entries.filter((e) => {
|
||||
const lower = e.toLowerCase();
|
||||
return (
|
||||
lower !== "fd" &&
|
||||
lower !== "rg" &&
|
||||
lower !== "fd.exe" &&
|
||||
lower !== "rg.exe" &&
|
||||
!e.startsWith(".") // Ignore .DS_Store and other hidden files
|
||||
);
|
||||
});
|
||||
if (customTools.length > 0) {
|
||||
warnings.push(
|
||||
`${label} tools/ directory contains custom tools. Custom tools have been merged into extensions.`,
|
||||
);
|
||||
}
|
||||
} catch {
|
||||
// Ignore read errors
|
||||
}
|
||||
}
|
||||
|
||||
return warnings;
|
||||
}
|
||||
|
||||
/**
|
||||
* Run extension system migrations (commands→prompts) and collect warnings about deprecated directories.
|
||||
*/
|
||||
function migrateExtensionSystem(cwd: string): string[] {
|
||||
const agentDir = getAgentDir();
|
||||
const projectDir = join(cwd, CONFIG_DIR_NAME);
|
||||
|
||||
// Migrate commands/ to prompts/
|
||||
migrateCommandsToPrompts(agentDir, "Global");
|
||||
migrateCommandsToPrompts(projectDir, "Project");
|
||||
|
||||
// Check for deprecated directories
|
||||
const warnings = [
|
||||
...checkDeprecatedExtensionDirs(agentDir, "Global"),
|
||||
...checkDeprecatedExtensionDirs(projectDir, "Project"),
|
||||
];
|
||||
|
||||
return warnings;
|
||||
}
|
||||
|
||||
/**
 * Print deprecation warnings followed by migration pointers, then block until
 * the user presses any key. Resolves immediately when there are no warnings.
 *
 * Temporarily switches stdin into raw mode to capture a single keypress and
 * restores it afterwards (setRawMode is optional-chained because it is absent
 * when stdin is not a TTY, e.g. when piped).
 *
 * @param warnings Warning strings from migrateExtensionSystem; may be empty
 */
export async function showDeprecationWarnings(
	warnings: string[],
): Promise<void> {
	if (warnings.length === 0) return;

	for (const warning of warnings) {
		console.log(chalk.yellow(`Warning: ${warning}`));
	}
	console.log(
		chalk.yellow(`\nMove your extensions to the extensions/ directory.`),
	);
	console.log(chalk.yellow(`Migration guide: ${MIGRATION_GUIDE_URL}`));
	console.log(chalk.yellow(`Documentation: ${EXTENSIONS_DOC_URL}`));
	console.log(chalk.dim(`\nPress any key to continue...`));

	// Wait for exactly one chunk of stdin data, then restore stdin state.
	await new Promise<void>((resolve) => {
		process.stdin.setRawMode?.(true);
		process.stdin.resume();
		process.stdin.once("data", () => {
			process.stdin.setRawMode?.(false);
			process.stdin.pause();
			resolve();
		});
	});
	console.log();
}
|
||||
|
||||
/**
|
||||
* Run all migrations. Called once on startup.
|
||||
*
|
||||
* @returns Object with migration results and deprecation warnings
|
||||
*/
|
||||
export function runMigrations(cwd: string = process.cwd()): {
|
||||
migratedAuthProviders: string[];
|
||||
deprecationWarnings: string[];
|
||||
} {
|
||||
const migratedAuthProviders = migrateAuthToAuthJson();
|
||||
migrateSessionsFromAgentRoot();
|
||||
migrateToolsToBin();
|
||||
const deprecationWarnings = migrateExtensionSystem(cwd);
|
||||
return { migratedAuthProviders, deprecationWarnings };
|
||||
}
|
||||
233
packages/coding-agent/src/modes/daemon-mode.ts
Normal file
233
packages/coding-agent/src/modes/daemon-mode.ts
Normal file
|
|
@ -0,0 +1,233 @@
|
|||
/**
|
||||
* Daemon mode (always-on background execution).
|
||||
*
|
||||
* Starts agent extensions, accepts messages from extension sources
|
||||
* (webhooks, queues, Telegram/Slack gateways, etc.), and stays alive
|
||||
* until explicitly stopped.
|
||||
*/
|
||||
|
||||
import type { ImageContent } from "@mariozechner/pi-ai";
|
||||
import type { AgentSession } from "../core/agent-session.js";
|
||||
import {
|
||||
GatewayRuntime,
|
||||
type GatewaySessionFactory,
|
||||
setActiveGatewayRuntime,
|
||||
} from "../core/gateway-runtime.js";
|
||||
import type { GatewaySettings } from "../core/settings-manager.js";
|
||||
|
||||
/**
 * Options for daemon mode.
 */
export interface DaemonModeOptions {
	/** First message to send at startup (can include @file content expansion by caller). */
	initialMessage?: string;
	/** Images to attach to the startup message. */
	initialImages?: ImageContent[];
	/** Additional startup messages (sent after initialMessage, one by one). */
	messages?: string[];
	/** Factory for creating additional gateway-owned sessions. */
	createSession: GatewaySessionFactory;
	/** Gateway config from settings/env. */
	gateway: GatewaySettings;
}

/**
 * Result returned when daemon mode exits.
 * Daemon mode only ends via shutdown (signal or extension trigger).
 */
export interface DaemonModeResult {
	reason: "shutdown";
}
|
||||
|
||||
function createCommandContextActions(session: AgentSession) {
|
||||
return {
|
||||
waitForIdle: () => session.agent.waitForIdle(),
|
||||
newSession: async (options?: {
|
||||
parentSession?: string;
|
||||
setup?: (
|
||||
sessionManager: typeof session.sessionManager,
|
||||
) => Promise<void> | void;
|
||||
}) => {
|
||||
const success = await session.newSession({
|
||||
parentSession: options?.parentSession,
|
||||
});
|
||||
if (success && options?.setup) {
|
||||
await options.setup(session.sessionManager);
|
||||
}
|
||||
return { cancelled: !success };
|
||||
},
|
||||
fork: async (entryId: string) => {
|
||||
const result = await session.fork(entryId);
|
||||
return { cancelled: result.cancelled };
|
||||
},
|
||||
navigateTree: async (
|
||||
targetId: string,
|
||||
options?: {
|
||||
summarize?: boolean;
|
||||
customInstructions?: string;
|
||||
replaceInstructions?: boolean;
|
||||
label?: string;
|
||||
},
|
||||
) => {
|
||||
const result = await session.navigateTree(targetId, {
|
||||
summarize: options?.summarize,
|
||||
customInstructions: options?.customInstructions,
|
||||
replaceInstructions: options?.replaceInstructions,
|
||||
label: options?.label,
|
||||
});
|
||||
return { cancelled: result.cancelled };
|
||||
},
|
||||
switchSession: async (sessionPath: string) => {
|
||||
const success = await session.switchSession(sessionPath);
|
||||
return { cancelled: !success };
|
||||
},
|
||||
reload: async () => {
|
||||
await session.reload();
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
/**
 * Run in daemon mode.
 * Stays alive indefinitely unless stopped by signal or extension trigger.
 *
 * Starts a GatewayRuntime HTTP gateway around the given session, wires
 * process signal handlers and an extension-triggered shutdown path, then
 * blocks on a promise that is only resolved when shutdown completes.
 *
 * @param session - The primary agent session served by the gateway.
 * @param options - Daemon configuration (gateway binding, startup messages,
 *   session factory). Env vars PI_GATEWAY_BIND / PI_GATEWAY_PORT /
 *   PI_GATEWAY_TOKEN / PI_GATEWAY_WEBHOOK_SECRET override the config values.
 * @returns Resolves with `{ reason: "shutdown" }` once the daemon has shut down.
 */
export async function runDaemonMode(
	session: AgentSession,
	options: DaemonModeOptions,
): Promise<DaemonModeResult> {
	const { initialMessage, initialImages, messages = [] } = options;
	let isShuttingDown = false;
	// `ready` resolves only on shutdown; `resolveReady` is captured so the
	// shutdown paths below can release the final `await ready`.
	let resolveReady: (result: DaemonModeResult) => void = () => {};
	const ready = new Promise<DaemonModeResult>((resolve) => {
		resolveReady = resolve;
	});
	// Env vars take precedence over options; port falls back through
	// parse-failure (NaN is falsy) -> options.gateway.port -> 8787.
	const gatewayBind =
		process.env.PI_GATEWAY_BIND ?? options.gateway.bind ?? "127.0.0.1";
	const gatewayPort =
		Number.parseInt(process.env.PI_GATEWAY_PORT ?? "", 10) ||
		options.gateway.port ||
		8787;
	const gatewayToken =
		process.env.PI_GATEWAY_TOKEN ?? options.gateway.bearerToken;
	const gateway = new GatewayRuntime({
		config: {
			bind: gatewayBind,
			port: gatewayPort,
			bearerToken: gatewayToken,
			session: {
				idleMinutes: options.gateway.session?.idleMinutes ?? 60,
				maxQueuePerSession: options.gateway.session?.maxQueuePerSession ?? 8,
			},
			webhook: {
				enabled: options.gateway.webhook?.enabled ?? true,
				basePath: options.gateway.webhook?.basePath ?? "/webhooks",
				secret:
					process.env.PI_GATEWAY_WEBHOOK_SECRET ??
					options.gateway.webhook?.secret,
			},
		},
		primarySessionKey: "web:main",
		primarySession: session,
		createSession: options.createSession,
		log: (message) => {
			console.error(`[pi-gateway] ${message}`);
		},
	});
	setActiveGatewayRuntime(gateway);

	// Idempotent shutdown: stop gateway, notify extensions, dispose session,
	// then resolve the daemon promise. Guarded so signal + extension triggers
	// cannot run it twice.
	const shutdown = async (reason: "signal" | "extension"): Promise<void> => {
		if (isShuttingDown) return;
		isShuttingDown = true;

		console.error(`[pi-gateway] shutdown requested: ${reason}`);
		setActiveGatewayRuntime(null);
		await gateway.stop();

		// Let extensions run their session_shutdown hooks before disposal.
		const runner = session.extensionRunner;
		if (runner?.hasHandlers("session_shutdown")) {
			await runner.emit({ type: "session_shutdown" });
		}

		session.dispose();
		resolveReady({ reason: "shutdown" });
	};

	// Even if shutdown itself fails, resolve `ready` so the process can exit.
	const handleShutdownSignal = (signal: NodeJS.Signals) => {
		void shutdown("signal").catch((error) => {
			console.error(
				`[pi-gateway] shutdown failed for ${signal}: ${error instanceof Error ? error.message : String(error)}`,
			);
			resolveReady({ reason: "shutdown" });
		});
	};
	const sigintHandler = () => handleShutdownSignal("SIGINT");
	const sigtermHandler = () => handleShutdownSignal("SIGTERM");
	const sigquitHandler = () => handleShutdownSignal("SIGQUIT");
	const sighupHandler = () => handleShutdownSignal("SIGHUP");
	// Log-only: unhandled rejections must not kill the daemon.
	const unhandledRejectionHandler = (error: unknown) => {
		console.error(
			`[pi-gateway] unhandled rejection: ${error instanceof Error ? error.message : String(error)}`,
		);
	};

	// `once` so a second signal during a slow shutdown falls through to the
	// runtime's default behavior (force exit).
	process.once("SIGINT", sigintHandler);
	process.once("SIGTERM", sigtermHandler);
	process.once("SIGQUIT", sigquitHandler);
	process.once("SIGHUP", sighupHandler);
	process.on("unhandledRejection", unhandledRejectionHandler);

	await session.bindExtensions({
		commandContextActions: createCommandContextActions(session),
		// Extensions can request daemon shutdown programmatically.
		shutdownHandler: () => {
			void shutdown("extension").catch((error) => {
				console.error(
					`[pi-gateway] extension shutdown failed: ${error instanceof Error ? error.message : String(error)}`,
				);
				resolveReady({ reason: "shutdown" });
			});
		},
		onError: (err) => {
			console.error(`Extension error (${err.extensionPath}): ${err.error}`);
		},
	});

	// Emit structured events to stderr for supervisor logs.
	session.subscribe((event) => {
		console.error(
			JSON.stringify({
				type: event.type,
				sessionId: session.sessionId,
				messageCount: session.messages.length,
			}),
		);
	});

	// Startup probes/messages.
	// Note: these run sequentially, before the gateway starts accepting traffic.
	if (initialMessage) {
		await session.prompt(initialMessage, { images: initialImages });
	}
	for (const message of messages) {
		await session.prompt(message);
	}

	await gateway.start();
	console.error(
		`[pi-gateway] startup complete (session=${session.sessionId ?? "unknown"}, bind=${gatewayBind}, port=${gatewayPort})`,
	);

	// Keep process alive forever.
	const keepAlive = setInterval(() => {
		// Intentionally keep the daemon event loop active.
	}, 1000);

	// Remove all process-level handlers so repeated runs (e.g. in tests)
	// do not accumulate listeners.
	const cleanup = () => {
		clearInterval(keepAlive);
		process.removeListener("SIGINT", sigintHandler);
		process.removeListener("SIGTERM", sigtermHandler);
		process.removeListener("SIGQUIT", sigquitHandler);
		process.removeListener("SIGHUP", sighupHandler);
		process.removeListener("unhandledRejection", unhandledRejectionHandler);
	};

	try {
		return await ready;
	} finally {
		cleanup();
	}
}
|
||||
26
packages/coding-agent/src/modes/index.ts
Normal file
26
packages/coding-agent/src/modes/index.ts
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
/**
 * Run modes for the coding agent.
 *
 * Barrel module re-exporting the public surface of each run mode:
 * daemon (gateway server), interactive (TUI), print (one-shot), and RPC
 * (machine-readable protocol), plus the RPC client and its types.
 */

// Daemon mode: long-running gateway process.
export {
	type DaemonModeOptions,
	type DaemonModeResult,
	runDaemonMode,
} from "./daemon-mode.js";
// Interactive TUI mode.
export {
	InteractiveMode,
	type InteractiveModeOptions,
} from "./interactive/interactive-mode.js";
// One-shot print mode.
export { type PrintModeOptions, runPrintMode } from "./print-mode.js";
// RPC client for driving an agent over the RPC protocol.
export {
	type ModelInfo,
	RpcClient,
	type RpcClientOptions,
	type RpcEventListener,
} from "./rpc/rpc-client.js";
export { runRpcMode } from "./rpc/rpc-mode.js";
export type {
	RpcCommand,
	RpcResponse,
	RpcSessionState,
} from "./rpc/rpc-types.js";
|
||||
422
packages/coding-agent/src/modes/interactive/components/armin.ts
Normal file
422
packages/coding-agent/src/modes/interactive/components/armin.ts
Normal file
|
|
@ -0,0 +1,422 @@
|
|||
/**
|
||||
* Armin says hi! A fun easter egg with animated XBM art.
|
||||
*/
|
||||
|
||||
import type { Component, TUI } from "@mariozechner/pi-tui";
|
||||
import { theme } from "../theme/theme.js";
|
||||
|
||||
// XBM image: 31x36 pixels, LSB first, 1=background, 0=foreground
// (i.e. a CLEARED bit is a lit/foreground pixel — see getPixel).
const WIDTH = 31;
const HEIGHT = 36;
// Raw XBM bitmap data; each row is padded to whole bytes (BYTES_PER_ROW).
const BITS = [
	0xff, 0xff, 0xff, 0x7f, 0xff, 0xf0, 0xff, 0x7f, 0xff, 0xed, 0xff, 0x7f, 0xff,
	0xdb, 0xff, 0x7f, 0xff, 0xb7, 0xff, 0x7f, 0xff, 0x77, 0xfe, 0x7f, 0x3f, 0xf8,
	0xfe, 0x7f, 0xdf, 0xff, 0xfe, 0x7f, 0xdf, 0x3f, 0xfc, 0x7f, 0x9f, 0xc3, 0xfb,
	0x7f, 0x6f, 0xfc, 0xf4, 0x7f, 0xf7, 0x0f, 0xf7, 0x7f, 0xf7, 0xff, 0xf7, 0x7f,
	0xf7, 0xff, 0xe3, 0x7f, 0xf7, 0x07, 0xe8, 0x7f, 0xef, 0xf8, 0x67, 0x70, 0x0f,
	0xff, 0xbb, 0x6f, 0xf1, 0x00, 0xd0, 0x5b, 0xfd, 0x3f, 0xec, 0x53, 0xc1, 0xff,
	0xef, 0x57, 0x9f, 0xfd, 0xee, 0x5f, 0x9f, 0xfc, 0xae, 0x5f, 0x1f, 0x78, 0xac,
	0x5f, 0x3f, 0x00, 0x50, 0x6c, 0x7f, 0x00, 0xdc, 0x77, 0xff, 0xc0, 0x3f, 0x78,
	0xff, 0x01, 0xf8, 0x7f, 0xff, 0x03, 0x9c, 0x78, 0xff, 0x07, 0x8c, 0x7c, 0xff,
	0x0f, 0xce, 0x78, 0xff, 0xff, 0xcf, 0x7f, 0xff, 0xff, 0xcf, 0x78, 0xff, 0xff,
	0xdf, 0x78, 0xff, 0xff, 0xdf, 0x7d, 0xff, 0xff, 0x3f, 0x7e, 0xff, 0xff, 0xff,
	0x7f,
];

const BYTES_PER_ROW = Math.ceil(WIDTH / 8);
const DISPLAY_HEIGHT = Math.ceil(HEIGHT / 2); // Half-block rendering

// The reveal animations the component can play; one is picked at random.
type Effect =
	| "typewriter"
	| "scanline"
	| "rain"
	| "fade"
	| "crt"
	| "glitch"
	| "dissolve";

const EFFECTS: Effect[] = [
	"typewriter",
	"scanline",
	"rain",
	"fade",
	"crt",
	"glitch",
	"dissolve",
];
|
||||
|
||||
// Get pixel at (x, y): true = foreground, false = background
|
||||
function getPixel(x: number, y: number): boolean {
|
||||
if (y >= HEIGHT) return false;
|
||||
const byteIndex = y * BYTES_PER_ROW + Math.floor(x / 8);
|
||||
const bitIndex = x % 8;
|
||||
return ((BITS[byteIndex] >> bitIndex) & 1) === 0;
|
||||
}
|
||||
|
||||
// Get the character for a cell (2 vertical pixels packed)
|
||||
function getChar(x: number, row: number): string {
|
||||
const upper = getPixel(x, row * 2);
|
||||
const lower = getPixel(x, row * 2 + 1);
|
||||
if (upper && lower) return "█";
|
||||
if (upper) return "▀";
|
||||
if (lower) return "▄";
|
||||
return " ";
|
||||
}
|
||||
|
||||
// Build the final image grid
|
||||
function buildFinalGrid(): string[][] {
|
||||
const grid: string[][] = [];
|
||||
for (let row = 0; row < DISPLAY_HEIGHT; row++) {
|
||||
const line: string[] = [];
|
||||
for (let x = 0; x < WIDTH; x++) {
|
||||
line.push(getChar(x, row));
|
||||
}
|
||||
grid.push(line);
|
||||
}
|
||||
return grid;
|
||||
}
|
||||
|
||||
/**
 * TUI component that reveals the Armin bitmap via a randomly chosen
 * animation effect, then stops its timer once the effect completes.
 * Rendering is cached per (width, gridVersion) so repeated renders between
 * animation ticks are free.
 */
export class ArminComponent implements Component {
	private ui: TUI;
	// Animation timer; null once the effect has finished or been disposed.
	private interval: ReturnType<typeof setInterval> | null = null;
	private effect: Effect;
	// Target frame (fully revealed image).
	private finalGrid: string[][];
	// Frame currently shown; mutated by the tick* methods.
	private currentGrid: string[][];
	// Per-effect mutable state; shape depends on `effect` (see initEffect).
	private effectState: Record<string, unknown> = {};
	private cachedLines: string[] = [];
	private cachedWidth = 0;
	// Bumped on every animation tick to invalidate the render cache.
	private gridVersion = 0;
	private cachedVersion = -1;

	constructor(ui: TUI) {
		this.ui = ui;
		this.effect = EFFECTS[Math.floor(Math.random() * EFFECTS.length)];
		this.finalGrid = buildFinalGrid();
		this.currentGrid = this.createEmptyGrid();

		this.initEffect();
		this.startAnimation();
	}

	// Drop the width-keyed cache so the next render() recomputes lines.
	invalidate(): void {
		this.cachedWidth = 0;
	}

	/**
	 * Render the current frame clipped to `width`, plus a trailing
	 * "ARMIN SAYS HI" line. Returns cached lines when neither width nor the
	 * animation frame changed.
	 */
	render(width: number): string[] {
		if (width === this.cachedWidth && this.cachedVersion === this.gridVersion) {
			return this.cachedLines;
		}

		const padding = 1;
		const availableWidth = width - padding;

		this.cachedLines = this.currentGrid.map((row) => {
			// Clip row to available width before applying color
			const clipped = row.slice(0, availableWidth).join("");
			const padRight = Math.max(0, width - padding - clipped.length);
			return ` ${theme.fg("accent", clipped)}${" ".repeat(padRight)}`;
		});

		// Add "ARMIN SAYS HI" at the end
		const message = "ARMIN SAYS HI";
		const msgPadRight = Math.max(0, width - padding - message.length);
		this.cachedLines.push(
			` ${theme.fg("accent", message)}${" ".repeat(msgPadRight)}`,
		);

		this.cachedWidth = width;
		this.cachedVersion = this.gridVersion;

		return this.cachedLines;
	}

	private createEmptyGrid(): string[][] {
		return Array.from({ length: DISPLAY_HEIGHT }, () => Array(WIDTH).fill(" "));
	}

	// Seed effectState for the chosen effect (and, for "dissolve", start
	// from a noise grid instead of an empty one).
	private initEffect(): void {
		switch (this.effect) {
			case "typewriter":
				this.effectState = { pos: 0 };
				break;
			case "scanline":
				this.effectState = { row: 0 };
				break;
			case "rain":
				// Track falling position for each column
				this.effectState = {
					drops: Array.from({ length: WIDTH }, () => ({
						y: -Math.floor(Math.random() * DISPLAY_HEIGHT * 2),
						settled: 0,
					})),
				};
				break;
			case "fade": {
				// Shuffle all pixel positions
				const positions: [number, number][] = [];
				for (let row = 0; row < DISPLAY_HEIGHT; row++) {
					for (let x = 0; x < WIDTH; x++) {
						positions.push([row, x]);
					}
				}
				// Fisher-Yates shuffle
				for (let i = positions.length - 1; i > 0; i--) {
					const j = Math.floor(Math.random() * (i + 1));
					[positions[i], positions[j]] = [positions[j], positions[i]];
				}
				this.effectState = { positions, idx: 0 };
				break;
			}
			case "crt":
				this.effectState = { expansion: 0 };
				break;
			case "glitch":
				this.effectState = { phase: 0, glitchFrames: 8 };
				break;
			case "dissolve": {
				// Start with random noise
				this.currentGrid = Array.from({ length: DISPLAY_HEIGHT }, () =>
					Array.from({ length: WIDTH }, () => {
						const chars = [" ", "░", "▒", "▓", "█", "▀", "▄"];
						return chars[Math.floor(Math.random() * chars.length)];
					}),
				);
				// Shuffle positions for gradual resolve
				const dissolvePositions: [number, number][] = [];
				for (let row = 0; row < DISPLAY_HEIGHT; row++) {
					for (let x = 0; x < WIDTH; x++) {
						dissolvePositions.push([row, x]);
					}
				}
				for (let i = dissolvePositions.length - 1; i > 0; i--) {
					const j = Math.floor(Math.random() * (i + 1));
					[dissolvePositions[i], dissolvePositions[j]] = [
						dissolvePositions[j],
						dissolvePositions[i],
					];
				}
				this.effectState = { positions: dissolvePositions, idx: 0 };
				break;
			}
		}
	}

	// Run the effect's tick at 30fps (60fps for glitch) until it reports done.
	private startAnimation(): void {
		const fps = this.effect === "glitch" ? 60 : 30;
		this.interval = setInterval(() => {
			const done = this.tickEffect();
			this.updateDisplay();
			this.ui.requestRender();
			if (done) {
				this.stopAnimation();
			}
		}, 1000 / fps);
	}

	private stopAnimation(): void {
		if (this.interval) {
			clearInterval(this.interval);
			this.interval = null;
		}
	}

	// Dispatch one animation step; returns true when the effect is finished.
	private tickEffect(): boolean {
		switch (this.effect) {
			case "typewriter":
				return this.tickTypewriter();
			case "scanline":
				return this.tickScanline();
			case "rain":
				return this.tickRain();
			case "fade":
				return this.tickFade();
			case "crt":
				return this.tickCrt();
			case "glitch":
				return this.tickGlitch();
			case "dissolve":
				return this.tickDissolve();
			default:
				return true;
		}
	}

	// Reveal pixels left-to-right, top-to-bottom, a few per frame.
	private tickTypewriter(): boolean {
		const state = this.effectState as { pos: number };
		const pixelsPerFrame = 3;

		for (let i = 0; i < pixelsPerFrame; i++) {
			const row = Math.floor(state.pos / WIDTH);
			const x = state.pos % WIDTH;
			if (row >= DISPLAY_HEIGHT) return true;
			this.currentGrid[row][x] = this.finalGrid[row][x];
			state.pos++;
		}
		return false;
	}

	// Reveal one full row per frame, top to bottom.
	private tickScanline(): boolean {
		const state = this.effectState as { row: number };
		if (state.row >= DISPLAY_HEIGHT) return true;

		// Copy row
		for (let x = 0; x < WIDTH; x++) {
			this.currentGrid[state.row][x] = this.finalGrid[state.row][x];
		}
		state.row++;
		return false;
	}

	// Columns fill bottom-up as "drops" fall and settle on the image pixels.
	private tickRain(): boolean {
		const state = this.effectState as {
			drops: { y: number; settled: number }[];
		};

		let allSettled = true;
		this.currentGrid = this.createEmptyGrid();

		for (let x = 0; x < WIDTH; x++) {
			const drop = state.drops[x];

			// Draw settled pixels
			for (
				let row = DISPLAY_HEIGHT - 1;
				row >= DISPLAY_HEIGHT - drop.settled;
				row--
			) {
				if (row >= 0) {
					this.currentGrid[row][x] = this.finalGrid[row][x];
				}
			}

			// Check if this column is done
			if (drop.settled >= DISPLAY_HEIGHT) continue;

			allSettled = false;

			// Find the target row for this column (lowest non-space pixel)
			let targetRow = -1;
			for (let row = DISPLAY_HEIGHT - 1 - drop.settled; row >= 0; row--) {
				if (this.finalGrid[row][x] !== " ") {
					targetRow = row;
					break;
				}
			}

			// Move drop down
			drop.y++;

			// Draw falling drop
			if (drop.y >= 0 && drop.y < DISPLAY_HEIGHT) {
				if (targetRow >= 0 && drop.y >= targetRow) {
					// Settle
					drop.settled = DISPLAY_HEIGHT - targetRow;
					drop.y = -Math.floor(Math.random() * 5) - 1;
				} else {
					// Still falling
					this.currentGrid[drop.y][x] = "▓";
				}
			}
		}

		return allSettled;
	}

	// Reveal pixels in pre-shuffled random order.
	private tickFade(): boolean {
		const state = this.effectState as {
			positions: [number, number][];
			idx: number;
		};
		const pixelsPerFrame = 15;

		for (let i = 0; i < pixelsPerFrame; i++) {
			if (state.idx >= state.positions.length) return true;
			const [row, x] = state.positions[state.idx];
			this.currentGrid[row][x] = this.finalGrid[row][x];
			state.idx++;
		}
		return false;
	}

	// CRT power-on: reveal a band expanding from the middle row outward.
	private tickCrt(): boolean {
		const state = this.effectState as { expansion: number };
		const midRow = Math.floor(DISPLAY_HEIGHT / 2);

		this.currentGrid = this.createEmptyGrid();

		// Draw from middle expanding outward
		const top = midRow - state.expansion;
		const bottom = midRow + state.expansion;

		for (
			let row = Math.max(0, top);
			row <= Math.min(DISPLAY_HEIGHT - 1, bottom);
			row++
		) {
			for (let x = 0; x < WIDTH; x++) {
				this.currentGrid[row][x] = this.finalGrid[row][x];
			}
		}

		state.expansion++;
		return state.expansion > DISPLAY_HEIGHT;
	}

	// Show a few corrupted frames (row shifts/swaps), then the clean image.
	private tickGlitch(): boolean {
		const state = this.effectState as { phase: number; glitchFrames: number };

		if (state.phase < state.glitchFrames) {
			// Glitch phase: show corrupted version
			this.currentGrid = this.finalGrid.map((row) => {
				const offset = Math.floor(Math.random() * 7) - 3;
				const glitchRow = [...row];

				// Random horizontal offset
				if (Math.random() < 0.3) {
					const shifted = glitchRow
						.slice(offset)
						.concat(glitchRow.slice(0, offset));
					return shifted.slice(0, WIDTH);
				}

				// Random vertical swap
				if (Math.random() < 0.2) {
					const swapRow = Math.floor(Math.random() * DISPLAY_HEIGHT);
					return [...this.finalGrid[swapRow]];
				}

				return glitchRow;
			});
			state.phase++;
			return false;
		}

		// Final frame: show clean image
		this.currentGrid = this.finalGrid.map((row) => [...row]);
		return true;
	}

	// Resolve random noise into the image, a batch of cells per frame.
	private tickDissolve(): boolean {
		const state = this.effectState as {
			positions: [number, number][];
			idx: number;
		};
		const pixelsPerFrame = 20;

		for (let i = 0; i < pixelsPerFrame; i++) {
			if (state.idx >= state.positions.length) return true;
			const [row, x] = state.positions[state.idx];
			this.currentGrid[row][x] = this.finalGrid[row][x];
			state.idx++;
		}
		return false;
	}

	// Mark the frame dirty so render() rebuilds its cached lines.
	private updateDisplay(): void {
		this.gridVersion++;
	}

	dispose(): void {
		this.stopAnimation();
	}
}
|
||||
|
|
@ -0,0 +1,139 @@
|
|||
import type { AssistantMessage } from "@mariozechner/pi-ai";
|
||||
import {
|
||||
Container,
|
||||
Markdown,
|
||||
type MarkdownTheme,
|
||||
Spacer,
|
||||
Text,
|
||||
} from "@mariozechner/pi-tui";
|
||||
import { getMarkdownTheme, theme } from "../theme/theme.js";
|
||||
|
||||
/**
|
||||
* Component that renders a complete assistant message
|
||||
*/
|
||||
export class AssistantMessageComponent extends Container {
|
||||
private contentContainer: Container;
|
||||
private hideThinkingBlock: boolean;
|
||||
private markdownTheme: MarkdownTheme;
|
||||
private lastMessage?: AssistantMessage;
|
||||
|
||||
constructor(
|
||||
message?: AssistantMessage,
|
||||
hideThinkingBlock = false,
|
||||
markdownTheme: MarkdownTheme = getMarkdownTheme(),
|
||||
) {
|
||||
super();
|
||||
|
||||
this.hideThinkingBlock = hideThinkingBlock;
|
||||
this.markdownTheme = markdownTheme;
|
||||
|
||||
// Container for text/thinking content
|
||||
this.contentContainer = new Container();
|
||||
this.addChild(this.contentContainer);
|
||||
|
||||
if (message) {
|
||||
this.updateContent(message);
|
||||
}
|
||||
}
|
||||
|
||||
override invalidate(): void {
|
||||
super.invalidate();
|
||||
if (this.lastMessage) {
|
||||
this.updateContent(this.lastMessage);
|
||||
}
|
||||
}
|
||||
|
||||
setHideThinkingBlock(hide: boolean): void {
|
||||
this.hideThinkingBlock = hide;
|
||||
}
|
||||
|
||||
updateContent(message: AssistantMessage): void {
|
||||
this.lastMessage = message;
|
||||
|
||||
// Clear content container
|
||||
this.contentContainer.clear();
|
||||
|
||||
const hasVisibleContent = message.content.some(
|
||||
(c) =>
|
||||
(c.type === "text" && c.text.trim()) ||
|
||||
(c.type === "thinking" && c.thinking.trim()),
|
||||
);
|
||||
|
||||
if (hasVisibleContent) {
|
||||
this.contentContainer.addChild(new Spacer(1));
|
||||
}
|
||||
|
||||
// Render content in order
|
||||
for (let i = 0; i < message.content.length; i++) {
|
||||
const content = message.content[i];
|
||||
if (content.type === "text" && content.text.trim()) {
|
||||
// Assistant text messages with no background - trim the text
|
||||
// Set paddingY=0 to avoid extra spacing before tool executions
|
||||
this.contentContainer.addChild(
|
||||
new Markdown(content.text.trim(), 1, 0, this.markdownTheme),
|
||||
);
|
||||
} else if (content.type === "thinking" && content.thinking.trim()) {
|
||||
// Add spacing only when another visible assistant content block follows.
|
||||
// This avoids a superfluous blank line before separately-rendered tool execution blocks.
|
||||
const hasVisibleContentAfter = message.content
|
||||
.slice(i + 1)
|
||||
.some(
|
||||
(c) =>
|
||||
(c.type === "text" && c.text.trim()) ||
|
||||
(c.type === "thinking" && c.thinking.trim()),
|
||||
);
|
||||
|
||||
if (this.hideThinkingBlock) {
|
||||
// Show static "Thinking..." label when hidden
|
||||
this.contentContainer.addChild(
|
||||
new Text(
|
||||
theme.italic(theme.fg("thinkingText", "Thinking...")),
|
||||
1,
|
||||
0,
|
||||
),
|
||||
);
|
||||
if (hasVisibleContentAfter) {
|
||||
this.contentContainer.addChild(new Spacer(1));
|
||||
}
|
||||
} else {
|
||||
// Thinking traces in thinkingText color, italic
|
||||
this.contentContainer.addChild(
|
||||
new Markdown(content.thinking.trim(), 1, 0, this.markdownTheme, {
|
||||
color: (text: string) => theme.fg("thinkingText", text),
|
||||
italic: true,
|
||||
}),
|
||||
);
|
||||
if (hasVisibleContentAfter) {
|
||||
this.contentContainer.addChild(new Spacer(1));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if aborted - show after partial content
|
||||
// But only if there are no tool calls (tool execution components will show the error)
|
||||
const hasToolCalls = message.content.some((c) => c.type === "toolCall");
|
||||
if (!hasToolCalls) {
|
||||
if (message.stopReason === "aborted") {
|
||||
const abortMessage =
|
||||
message.errorMessage && message.errorMessage !== "Request was aborted"
|
||||
? message.errorMessage
|
||||
: "Operation aborted";
|
||||
if (hasVisibleContent) {
|
||||
this.contentContainer.addChild(new Spacer(1));
|
||||
} else {
|
||||
this.contentContainer.addChild(new Spacer(1));
|
||||
}
|
||||
this.contentContainer.addChild(
|
||||
new Text(theme.fg("error", abortMessage), 1, 0),
|
||||
);
|
||||
} else if (message.stopReason === "error") {
|
||||
const errorMsg = message.errorMessage || "Unknown error";
|
||||
this.contentContainer.addChild(new Spacer(1));
|
||||
this.contentContainer.addChild(
|
||||
new Text(theme.fg("error", `Error: ${errorMsg}`), 1, 0),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,241 @@
|
|||
/**
|
||||
* Component for displaying bash command execution with streaming output.
|
||||
*/
|
||||
|
||||
import {
|
||||
Container,
|
||||
Loader,
|
||||
Spacer,
|
||||
Text,
|
||||
type TUI,
|
||||
} from "@mariozechner/pi-tui";
|
||||
import stripAnsi from "strip-ansi";
|
||||
import {
|
||||
DEFAULT_MAX_BYTES,
|
||||
DEFAULT_MAX_LINES,
|
||||
type TruncationResult,
|
||||
truncateTail,
|
||||
} from "../../../core/tools/truncate.js";
|
||||
import { theme } from "../theme/theme.js";
|
||||
import { DynamicBorder } from "./dynamic-border.js";
|
||||
import { editorKey, keyHint } from "./keybinding-hints.js";
|
||||
import { truncateToVisualLines } from "./visual-truncate.js";
|
||||
|
||||
// Preview line limit when not expanded (matches tool execution behavior)
|
||||
const PREVIEW_LINES = 20;
|
||||
|
||||
export class BashExecutionComponent extends Container {
|
||||
private command: string;
|
||||
private outputLines: string[] = [];
|
||||
private status: "running" | "complete" | "cancelled" | "error" = "running";
|
||||
private exitCode: number | undefined = undefined;
|
||||
private loader: Loader;
|
||||
private truncationResult?: TruncationResult;
|
||||
private fullOutputPath?: string;
|
||||
private expanded = false;
|
||||
private contentContainer: Container;
|
||||
private ui: TUI;
|
||||
|
||||
constructor(command: string, ui: TUI, excludeFromContext = false) {
|
||||
super();
|
||||
this.command = command;
|
||||
this.ui = ui;
|
||||
|
||||
// Use dim border for excluded-from-context commands (!! prefix)
|
||||
const colorKey = excludeFromContext ? "dim" : "bashMode";
|
||||
const borderColor = (str: string) => theme.fg(colorKey, str);
|
||||
|
||||
// Add spacer
|
||||
this.addChild(new Spacer(1));
|
||||
|
||||
// Top border
|
||||
this.addChild(new DynamicBorder(borderColor));
|
||||
|
||||
// Content container (holds dynamic content between borders)
|
||||
this.contentContainer = new Container();
|
||||
this.addChild(this.contentContainer);
|
||||
|
||||
// Command header
|
||||
const header = new Text(
|
||||
theme.fg(colorKey, theme.bold(`$ ${command}`)),
|
||||
1,
|
||||
0,
|
||||
);
|
||||
this.contentContainer.addChild(header);
|
||||
|
||||
// Loader
|
||||
this.loader = new Loader(
|
||||
ui,
|
||||
(spinner) => theme.fg(colorKey, spinner),
|
||||
(text) => theme.fg("muted", text),
|
||||
`Running... (${editorKey("selectCancel")} to cancel)`, // Plain text for loader
|
||||
);
|
||||
this.contentContainer.addChild(this.loader);
|
||||
|
||||
// Bottom border
|
||||
this.addChild(new DynamicBorder(borderColor));
|
||||
}
|
||||
|
||||
/**
|
||||
* Set whether the output is expanded (shows full output) or collapsed (preview only).
|
||||
*/
|
||||
setExpanded(expanded: boolean): void {
|
||||
this.expanded = expanded;
|
||||
this.updateDisplay();
|
||||
}
|
||||
|
||||
override invalidate(): void {
|
||||
super.invalidate();
|
||||
this.updateDisplay();
|
||||
}
|
||||
|
||||
appendOutput(chunk: string): void {
|
||||
// Strip ANSI codes and normalize line endings
|
||||
// Note: binary data is already sanitized in tui-renderer.ts executeBashCommand
|
||||
const clean = stripAnsi(chunk).replace(/\r\n/g, "\n").replace(/\r/g, "\n");
|
||||
|
||||
// Append to output lines
|
||||
const newLines = clean.split("\n");
|
||||
if (this.outputLines.length > 0 && newLines.length > 0) {
|
||||
// Append first chunk to last line (incomplete line continuation)
|
||||
this.outputLines[this.outputLines.length - 1] += newLines[0];
|
||||
this.outputLines.push(...newLines.slice(1));
|
||||
} else {
|
||||
this.outputLines.push(...newLines);
|
||||
}
|
||||
|
||||
this.updateDisplay();
|
||||
}
|
||||
|
||||
setComplete(
|
||||
exitCode: number | undefined,
|
||||
cancelled: boolean,
|
||||
truncationResult?: TruncationResult,
|
||||
fullOutputPath?: string,
|
||||
): void {
|
||||
this.exitCode = exitCode;
|
||||
this.status = cancelled
|
||||
? "cancelled"
|
||||
: exitCode !== 0 && exitCode !== undefined && exitCode !== null
|
||||
? "error"
|
||||
: "complete";
|
||||
this.truncationResult = truncationResult;
|
||||
this.fullOutputPath = fullOutputPath;
|
||||
|
||||
// Stop loader
|
||||
this.loader.stop();
|
||||
|
||||
this.updateDisplay();
|
||||
}
|
||||
|
||||
private updateDisplay(): void {
|
||||
// Apply truncation for LLM context limits (same limits as bash tool)
|
||||
const fullOutput = this.outputLines.join("\n");
|
||||
const contextTruncation = truncateTail(fullOutput, {
|
||||
maxLines: DEFAULT_MAX_LINES,
|
||||
maxBytes: DEFAULT_MAX_BYTES,
|
||||
});
|
||||
|
||||
// Get the lines to potentially display (after context truncation)
|
||||
const availableLines = contextTruncation.content
|
||||
? contextTruncation.content.split("\n")
|
||||
: [];
|
||||
|
||||
// Apply preview truncation based on expanded state
|
||||
const previewLogicalLines = availableLines.slice(-PREVIEW_LINES);
|
||||
const hiddenLineCount = availableLines.length - previewLogicalLines.length;
|
||||
|
||||
// Rebuild content container
|
||||
this.contentContainer.clear();
|
||||
|
||||
// Command header
|
||||
const header = new Text(
|
||||
theme.fg("bashMode", theme.bold(`$ ${this.command}`)),
|
||||
1,
|
||||
0,
|
||||
);
|
||||
this.contentContainer.addChild(header);
|
||||
|
||||
// Output
|
||||
if (availableLines.length > 0) {
|
||||
if (this.expanded) {
|
||||
// Show all lines
|
||||
const displayText = availableLines
|
||||
.map((line) => theme.fg("muted", line))
|
||||
.join("\n");
|
||||
this.contentContainer.addChild(new Text(`\n${displayText}`, 1, 0));
|
||||
} else {
|
||||
// Use shared visual truncation utility
|
||||
const styledOutput = previewLogicalLines
|
||||
.map((line) => theme.fg("muted", line))
|
||||
.join("\n");
|
||||
const { visualLines } = truncateToVisualLines(
|
||||
`\n${styledOutput}`,
|
||||
PREVIEW_LINES,
|
||||
this.ui.terminal.columns,
|
||||
1, // padding
|
||||
);
|
||||
this.contentContainer.addChild({
|
||||
render: () => visualLines,
|
||||
invalidate: () => {},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Loader or status
|
||||
if (this.status === "running") {
|
||||
this.contentContainer.addChild(this.loader);
|
||||
} else {
|
||||
const statusParts: string[] = [];
|
||||
|
||||
// Show how many lines are hidden (collapsed preview)
|
||||
if (hiddenLineCount > 0) {
|
||||
if (this.expanded) {
|
||||
statusParts.push(`(${keyHint("expandTools", "to collapse")})`);
|
||||
} else {
|
||||
statusParts.push(
|
||||
`${theme.fg("muted", `... ${hiddenLineCount} more lines`)} (${keyHint("expandTools", "to expand")})`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if (this.status === "cancelled") {
|
||||
statusParts.push(theme.fg("warning", "(cancelled)"));
|
||||
} else if (this.status === "error") {
|
||||
statusParts.push(theme.fg("error", `(exit ${this.exitCode})`));
|
||||
}
|
||||
|
||||
// Add truncation warning (context truncation, not preview truncation)
|
||||
const wasTruncated =
|
||||
this.truncationResult?.truncated || contextTruncation.truncated;
|
||||
if (wasTruncated && this.fullOutputPath) {
|
||||
statusParts.push(
|
||||
theme.fg(
|
||||
"warning",
|
||||
`Output truncated. Full output: ${this.fullOutputPath}`,
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
if (statusParts.length > 0) {
|
||||
this.contentContainer.addChild(
|
||||
new Text(`\n${statusParts.join("\n")}`, 1, 0),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the raw output for creating BashExecutionMessage.
|
||||
*/
|
||||
getOutput(): string {
|
||||
return this.outputLines.join("\n");
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the command that was executed.
|
||||
*/
|
||||
getCommand(): string {
|
||||
return this.command;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,78 @@
|
|||
import {
|
||||
CancellableLoader,
|
||||
Container,
|
||||
Loader,
|
||||
Spacer,
|
||||
Text,
|
||||
type TUI,
|
||||
} from "@mariozechner/pi-tui";
|
||||
import type { Theme } from "../theme/theme.js";
|
||||
import { DynamicBorder } from "./dynamic-border.js";
|
||||
import { keyHint } from "./keybinding-hints.js";
|
||||
|
||||
/**
 * Loader wrapped with borders for extension UI.
 *
 * Wraps either a CancellableLoader (default) or a plain Loader between two
 * DynamicBorders. When cancellable, a "cancel" key hint is shown and the
 * loader exposes an abort signal; when not, a private AbortController's
 * signal is returned instead (NOTE(review): that controller is never
 * aborted anywhere in this class — presumably intentional as a
 * never-firing placeholder; confirm with callers).
 */
export class BorderedLoader extends Container {
	private loader: CancellableLoader | Loader;
	private cancellable: boolean;
	// Only set in the non-cancellable branch; provides the fallback signal.
	private signalController?: AbortController;

	constructor(
		tui: TUI,
		theme: Theme,
		message: string,
		options?: { cancellable?: boolean },
	) {
		super();
		this.cancellable = options?.cancellable ?? true;
		const borderColor = (s: string) => theme.fg("border", s);
		this.addChild(new DynamicBorder(borderColor));
		if (this.cancellable) {
			this.loader = new CancellableLoader(
				tui,
				(s) => theme.fg("accent", s),
				(s) => theme.fg("muted", s),
				message,
			);
		} else {
			this.signalController = new AbortController();
			this.loader = new Loader(
				tui,
				(s) => theme.fg("accent", s),
				(s) => theme.fg("muted", s),
				message,
			);
		}
		this.addChild(this.loader);
		if (this.cancellable) {
			this.addChild(new Spacer(1));
			this.addChild(new Text(keyHint("selectCancel", "cancel"), 1, 0));
		}
		this.addChild(new Spacer(1));
		this.addChild(new DynamicBorder(borderColor));
	}

	// Abort signal: the CancellableLoader's when cancellable, otherwise the
	// fallback controller's (or a fresh, never-aborted signal as last resort).
	get signal(): AbortSignal {
		if (this.cancellable) {
			return (this.loader as CancellableLoader).signal;
		}
		return this.signalController?.signal ?? new AbortController().signal;
	}

	// Abort callback; a no-op for non-cancellable loaders.
	set onAbort(fn: (() => void) | undefined) {
		if (this.cancellable) {
			(this.loader as CancellableLoader).onAbort = fn;
		}
	}

	// Forward keyboard input (e.g. ESC) to the cancellable loader only.
	handleInput(data: string): void {
		if (this.cancellable) {
			(this.loader as CancellableLoader).handleInput(data);
		}
	}

	// Dispose the inner loader if it supports disposal (duck-typed because
	// the union's two members differ).
	dispose(): void {
		if ("dispose" in this.loader && typeof this.loader.dispose === "function") {
			this.loader.dispose();
		}
	}
}
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
import {
|
||||
Box,
|
||||
Markdown,
|
||||
type MarkdownTheme,
|
||||
Spacer,
|
||||
Text,
|
||||
} from "@mariozechner/pi-tui";
|
||||
import type { BranchSummaryMessage } from "../../../core/messages.js";
|
||||
import { getMarkdownTheme, theme } from "../theme/theme.js";
|
||||
import { editorKey } from "./keybinding-hints.js";
|
||||
|
||||
/**
|
||||
* Component that renders a branch summary message with collapsed/expanded state.
|
||||
* Uses same background color as custom messages for visual consistency.
|
||||
*/
|
||||
export class BranchSummaryMessageComponent extends Box {
|
||||
private expanded = false;
|
||||
private message: BranchSummaryMessage;
|
||||
private markdownTheme: MarkdownTheme;
|
||||
|
||||
constructor(
|
||||
message: BranchSummaryMessage,
|
||||
markdownTheme: MarkdownTheme = getMarkdownTheme(),
|
||||
) {
|
||||
super(1, 1, (t) => theme.bg("customMessageBg", t));
|
||||
this.message = message;
|
||||
this.markdownTheme = markdownTheme;
|
||||
this.updateDisplay();
|
||||
}
|
||||
|
||||
setExpanded(expanded: boolean): void {
|
||||
this.expanded = expanded;
|
||||
this.updateDisplay();
|
||||
}
|
||||
|
||||
override invalidate(): void {
|
||||
super.invalidate();
|
||||
this.updateDisplay();
|
||||
}
|
||||
|
||||
private updateDisplay(): void {
|
||||
this.clear();
|
||||
|
||||
const label = theme.fg("customMessageLabel", `\x1b[1m[branch]\x1b[22m`);
|
||||
this.addChild(new Text(label, 0, 0));
|
||||
this.addChild(new Spacer(1));
|
||||
|
||||
if (this.expanded) {
|
||||
const header = "**Branch Summary**\n\n";
|
||||
this.addChild(
|
||||
new Markdown(header + this.message.summary, 0, 0, this.markdownTheme, {
|
||||
color: (text: string) => theme.fg("customMessageText", text),
|
||||
}),
|
||||
);
|
||||
} else {
|
||||
this.addChild(
|
||||
new Text(
|
||||
theme.fg("customMessageText", "Branch summary (") +
|
||||
theme.fg("dim", editorKey("expandTools")) +
|
||||
theme.fg("customMessageText", " to expand)"),
|
||||
0,
|
||||
0,
|
||||
),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,68 @@
|
|||
import {
|
||||
Box,
|
||||
Markdown,
|
||||
type MarkdownTheme,
|
||||
Spacer,
|
||||
Text,
|
||||
} from "@mariozechner/pi-tui";
|
||||
import type { CompactionSummaryMessage } from "../../../core/messages.js";
|
||||
import { getMarkdownTheme, theme } from "../theme/theme.js";
|
||||
import { editorKey } from "./keybinding-hints.js";
|
||||
|
||||
/**
|
||||
* Component that renders a compaction message with collapsed/expanded state.
|
||||
* Uses same background color as custom messages for visual consistency.
|
||||
*/
|
||||
export class CompactionSummaryMessageComponent extends Box {
|
||||
private expanded = false;
|
||||
private message: CompactionSummaryMessage;
|
||||
private markdownTheme: MarkdownTheme;
|
||||
|
||||
constructor(
|
||||
message: CompactionSummaryMessage,
|
||||
markdownTheme: MarkdownTheme = getMarkdownTheme(),
|
||||
) {
|
||||
super(1, 1, (t) => theme.bg("customMessageBg", t));
|
||||
this.message = message;
|
||||
this.markdownTheme = markdownTheme;
|
||||
this.updateDisplay();
|
||||
}
|
||||
|
||||
setExpanded(expanded: boolean): void {
|
||||
this.expanded = expanded;
|
||||
this.updateDisplay();
|
||||
}
|
||||
|
||||
override invalidate(): void {
|
||||
super.invalidate();
|
||||
this.updateDisplay();
|
||||
}
|
||||
|
||||
private updateDisplay(): void {
|
||||
this.clear();
|
||||
|
||||
const tokenStr = this.message.tokensBefore.toLocaleString();
|
||||
const label = theme.fg("customMessageLabel", `\x1b[1m[compaction]\x1b[22m`);
|
||||
this.addChild(new Text(label, 0, 0));
|
||||
this.addChild(new Spacer(1));
|
||||
|
||||
if (this.expanded) {
|
||||
const header = `**Compacted from ${tokenStr} tokens**\n\n`;
|
||||
this.addChild(
|
||||
new Markdown(header + this.message.summary, 0, 0, this.markdownTheme, {
|
||||
color: (text: string) => theme.fg("customMessageText", text),
|
||||
}),
|
||||
);
|
||||
} else {
|
||||
this.addChild(
|
||||
new Text(
|
||||
theme.fg("customMessageText", `Compacted from ${tokenStr} tokens (`) +
|
||||
theme.fg("dim", editorKey("expandTools")) +
|
||||
theme.fg("customMessageText", " to expand)"),
|
||||
0,
|
||||
0,
|
||||
),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,669 @@
|
|||
/**
|
||||
* TUI component for managing package resources (enable/disable)
|
||||
*/
|
||||
|
||||
import { basename, dirname, join, relative } from "node:path";
|
||||
import {
|
||||
type Component,
|
||||
Container,
|
||||
type Focusable,
|
||||
getEditorKeybindings,
|
||||
Input,
|
||||
matchesKey,
|
||||
Spacer,
|
||||
truncateToWidth,
|
||||
visibleWidth,
|
||||
} from "@mariozechner/pi-tui";
|
||||
import { CONFIG_DIR_NAME } from "../../../config.js";
|
||||
import type {
|
||||
PathMetadata,
|
||||
ResolvedPaths,
|
||||
ResolvedResource,
|
||||
} from "../../../core/package-manager.js";
|
||||
import type {
|
||||
PackageSource,
|
||||
SettingsManager,
|
||||
} from "../../../core/settings-manager.js";
|
||||
import { theme } from "../theme/theme.js";
|
||||
import { DynamicBorder } from "./dynamic-border.js";
|
||||
import { rawKeyHint } from "./keybinding-hints.js";
|
||||
|
||||
/** Categories of configurable resources shown in the selector. */
type ResourceType = "extensions" | "skills" | "prompts" | "themes";

/** Human-readable section headings for each resource type. */
const RESOURCE_TYPE_LABELS: Record<ResourceType, string> = {
	extensions: "Extensions",
	skills: "Skills",
	prompts: "Prompts",
	themes: "Themes",
};

/** One toggleable resource row in the selector list. */
interface ResourceItem {
	// Path to the resource on disk.
	path: string;
	// Current on/off state as resolved by the package manager.
	enabled: boolean;
	// Origin/scope/source information for this path.
	metadata: PathMetadata;
	resourceType: ResourceType;
	// Short name rendered in the list (file name, folder, or folder/file).
	displayName: string;
	// Key of the owning ResourceGroup.
	groupKey: string;
	// Key of the owning ResourceSubgroup.
	subgroupKey: string;
}

/** All resources of one type within a group (e.g. all skills of one package). */
interface ResourceSubgroup {
	type: ResourceType;
	label: string;
	items: ResourceItem[];
}

/** Top-level grouping: one package or one top-level settings scope. */
interface ResourceGroup {
	key: string;
	label: string;
	scope: "user" | "project" | "temporary";
	origin: "package" | "top-level";
	source: string;
	subgroups: ResourceSubgroup[];
}
|
||||
|
||||
function getGroupLabel(metadata: PathMetadata): string {
|
||||
if (metadata.origin === "package") {
|
||||
return `${metadata.source} (${metadata.scope})`;
|
||||
}
|
||||
// Top-level resources
|
||||
if (metadata.source === "auto") {
|
||||
return metadata.scope === "user" ? "User (~/.pi/agent/)" : "Project (.pi/)";
|
||||
}
|
||||
return metadata.scope === "user" ? "User settings" : "Project settings";
|
||||
}
|
||||
|
||||
/**
 * Organize resolved resources into two-level display groups.
 *
 * Level 1: one ResourceGroup per (origin, scope, source) — i.e. per package
 * or per top-level settings scope. Level 2: one ResourceSubgroup per resource
 * type that has at least one entry.
 *
 * Ordering: packages before top-level groups, "user" scope before "project",
 * then alphabetical by source; subgroups in fixed type order; items sorted
 * alphabetically by display name.
 */
function buildGroups(resolved: ResolvedPaths): ResourceGroup[] {
	const groupMap = new Map<string, ResourceGroup>();

	// Fold one resource list into groupMap, creating groups/subgroups on demand.
	const addToGroup = (
		resources: ResolvedResource[],
		resourceType: ResourceType,
	) => {
		for (const res of resources) {
			const { path, enabled, metadata } = res;
			const groupKey = `${metadata.origin}:${metadata.scope}:${metadata.source}`;

			if (!groupMap.has(groupKey)) {
				groupMap.set(groupKey, {
					key: groupKey,
					label: getGroupLabel(metadata),
					scope: metadata.scope,
					origin: metadata.origin,
					source: metadata.source,
					subgroups: [],
				});
			}

			const group = groupMap.get(groupKey)!;
			const subgroupKey = `${groupKey}:${resourceType}`;

			let subgroup = group.subgroups.find((sg) => sg.type === resourceType);
			if (!subgroup) {
				subgroup = {
					type: resourceType,
					label: RESOURCE_TYPE_LABELS[resourceType],
					items: [],
				};
				group.subgroups.push(subgroup);
			}

			// Derive a short display name:
			// - extensions nested in a non-"extensions" folder keep the folder prefix,
			// - skills named SKILL.md display as their containing folder,
			// - everything else displays as the file name.
			const fileName = basename(path);
			const parentFolder = basename(dirname(path));
			let displayName: string;
			if (resourceType === "extensions" && parentFolder !== "extensions") {
				displayName = `${parentFolder}/${fileName}`;
			} else if (resourceType === "skills" && fileName === "SKILL.md") {
				displayName = parentFolder;
			} else {
				displayName = fileName;
			}
			subgroup.items.push({
				path,
				enabled,
				metadata,
				resourceType,
				displayName,
				groupKey,
				subgroupKey,
			});
		}
	};

	addToGroup(resolved.extensions, "extensions");
	addToGroup(resolved.skills, "skills");
	addToGroup(resolved.prompts, "prompts");
	addToGroup(resolved.themes, "themes");

	// Sort groups: packages first, then top-level; user before project
	const groups = Array.from(groupMap.values());
	groups.sort((a, b) => {
		if (a.origin !== b.origin) {
			return a.origin === "package" ? -1 : 1;
		}
		if (a.scope !== b.scope) {
			return a.scope === "user" ? -1 : 1;
		}
		return a.source.localeCompare(b.source);
	});

	// Sort subgroups within each group by type order, and items by name
	const typeOrder: Record<ResourceType, number> = {
		extensions: 0,
		skills: 1,
		prompts: 2,
		themes: 3,
	};
	for (const group of groups) {
		group.subgroups.sort((a, b) => typeOrder[a.type] - typeOrder[b.type]);
		for (const subgroup of group.subgroups) {
			subgroup.items.sort((a, b) => a.displayName.localeCompare(b.displayName));
		}
	}

	return groups;
}
|
||||
|
||||
/**
 * One row of the flattened selector list: a group header, a subgroup header,
 * or a selectable resource item. Only "item" rows can take the cursor.
 */
type FlatEntry =
	| { type: "group"; group: ResourceGroup }
	| { type: "subgroup"; subgroup: ResourceSubgroup; group: ResourceGroup }
	| { type: "item"; item: ResourceItem };
|
||||
|
||||
class ConfigSelectorHeader implements Component {
|
||||
invalidate(): void {}
|
||||
|
||||
render(width: number): string[] {
|
||||
const title = theme.bold("Resource Configuration");
|
||||
const sep = theme.fg("muted", " · ");
|
||||
const hint =
|
||||
rawKeyHint("space", "toggle") + sep + rawKeyHint("esc", "close");
|
||||
const hintWidth = visibleWidth(hint);
|
||||
const titleWidth = visibleWidth(title);
|
||||
const spacing = Math.max(1, width - titleWidth - hintWidth);
|
||||
|
||||
return [
|
||||
truncateToWidth(`${title}${" ".repeat(spacing)}${hint}`, width, ""),
|
||||
theme.fg("muted", "Type to filter resources"),
|
||||
];
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Scrollable, filterable list of resource toggles.
 *
 * Renders a search input followed by group headers, subgroup headers, and
 * item rows; only item rows are selectable. Toggling an item immediately
 * persists the change through the SettingsManager — either to top-level
 * pattern lists or to per-package filter lists, depending on the item's
 * origin.
 */
class ResourceList implements Component, Focusable {
	private groups: ResourceGroup[];
	// Full flattened list (group headers, subgroup headers, items).
	private flatItems: FlatEntry[] = [];
	// Subset of flatItems matching the current search query.
	private filteredItems: FlatEntry[] = [];
	// Index into filteredItems; kept on an "item" entry whenever one exists.
	private selectedIndex = 0;
	private searchInput: Input;
	// Maximum rows rendered at once; the selection scrolls within this window.
	private maxVisible = 15;
	private settingsManager: SettingsManager;
	private cwd: string;
	private agentDir: string;

	public onCancel?: () => void;
	public onExit?: () => void;
	public onToggle?: (item: ResourceItem, newEnabled: boolean) => void;

	private _focused = false;
	get focused(): boolean {
		return this._focused;
	}
	set focused(value: boolean) {
		this._focused = value;
		// Forward focus so the search input shows its cursor.
		this.searchInput.focused = value;
	}

	constructor(
		groups: ResourceGroup[],
		settingsManager: SettingsManager,
		cwd: string,
		agentDir: string,
	) {
		this.groups = groups;
		this.settingsManager = settingsManager;
		this.cwd = cwd;
		this.agentDir = agentDir;
		this.searchInput = new Input();
		this.buildFlatList();
		this.filteredItems = [...this.flatItems];
	}

	// Flatten groups -> subgroups -> items into one renderable list.
	private buildFlatList(): void {
		this.flatItems = [];
		for (const group of this.groups) {
			this.flatItems.push({ type: "group", group });
			for (const subgroup of group.subgroups) {
				this.flatItems.push({ type: "subgroup", subgroup, group });
				for (const item of subgroup.items) {
					this.flatItems.push({ type: "item", item });
				}
			}
		}
		// Start selection on first item (not header)
		this.selectedIndex = this.flatItems.findIndex((e) => e.type === "item");
		if (this.selectedIndex < 0) this.selectedIndex = 0;
	}

	// Next selectable ("item") entry in the given direction, skipping headers.
	private findNextItem(fromIndex: number, direction: 1 | -1): number {
		let idx = fromIndex + direction;
		while (idx >= 0 && idx < this.filteredItems.length) {
			if (this.filteredItems[idx].type === "item") {
				return idx;
			}
			idx += direction;
		}
		return fromIndex; // Stay at current if no item found
	}

	// Rebuild filteredItems: matching items plus the headers of any
	// group/subgroup that still contains at least one matching item.
	// Matches against display name, resource type, and full path.
	private filterItems(query: string): void {
		if (!query.trim()) {
			this.filteredItems = [...this.flatItems];
			this.selectFirstItem();
			return;
		}

		const lowerQuery = query.toLowerCase();
		const matchingItems = new Set<ResourceItem>();
		const matchingSubgroups = new Set<ResourceSubgroup>();
		const matchingGroups = new Set<ResourceGroup>();

		for (const entry of this.flatItems) {
			if (entry.type === "item") {
				const item = entry.item;
				if (
					item.displayName.toLowerCase().includes(lowerQuery) ||
					item.resourceType.toLowerCase().includes(lowerQuery) ||
					item.path.toLowerCase().includes(lowerQuery)
				) {
					matchingItems.add(item);
				}
			}
		}

		// Find which subgroups and groups contain matching items
		for (const group of this.groups) {
			for (const subgroup of group.subgroups) {
				for (const item of subgroup.items) {
					if (matchingItems.has(item)) {
						matchingSubgroups.add(subgroup);
						matchingGroups.add(group);
					}
				}
			}
		}

		this.filteredItems = [];
		for (const entry of this.flatItems) {
			if (entry.type === "group" && matchingGroups.has(entry.group)) {
				this.filteredItems.push(entry);
			} else if (
				entry.type === "subgroup" &&
				matchingSubgroups.has(entry.subgroup)
			) {
				this.filteredItems.push(entry);
			} else if (entry.type === "item" && matchingItems.has(entry.item)) {
				this.filteredItems.push(entry);
			}
		}

		this.selectFirstItem();
	}

	// Move the cursor to the first selectable item (or row 0 if none).
	private selectFirstItem(): void {
		const firstItemIndex = this.filteredItems.findIndex(
			(e) => e.type === "item",
		);
		this.selectedIndex = firstItemIndex >= 0 ? firstItemIndex : 0;
	}

	// Sync an item's enabled flag into both the given item and its twin
	// stored inside this.groups (matched by path + resource type).
	updateItem(item: ResourceItem, enabled: boolean): void {
		item.enabled = enabled;
		// Update in groups too
		for (const group of this.groups) {
			for (const subgroup of group.subgroups) {
				const found = subgroup.items.find(
					(i) => i.path === item.path && i.resourceType === item.resourceType,
				);
				if (found) {
					found.enabled = enabled;
					return;
				}
			}
		}
	}

	invalidate(): void {}

	// Render the search input, a scrolled window of rows, and an optional
	// scroll-position indicator.
	render(width: number): string[] {
		const lines: string[] = [];

		// Search input
		lines.push(...this.searchInput.render(width));
		lines.push("");

		if (this.filteredItems.length === 0) {
			lines.push(theme.fg("muted", " No resources found"));
			return lines;
		}

		// Calculate visible range
		const startIndex = Math.max(
			0,
			Math.min(
				this.selectedIndex - Math.floor(this.maxVisible / 2),
				this.filteredItems.length - this.maxVisible,
			),
		);
		const endIndex = Math.min(
			startIndex + this.maxVisible,
			this.filteredItems.length,
		);

		for (let i = startIndex; i < endIndex; i++) {
			const entry = this.filteredItems[i];
			const isSelected = i === this.selectedIndex;

			if (entry.type === "group") {
				// Main group header (no cursor)
				const groupLine = theme.fg("accent", theme.bold(entry.group.label));
				lines.push(truncateToWidth(` ${groupLine}`, width, ""));
			} else if (entry.type === "subgroup") {
				// Subgroup header (indented, no cursor)
				const subgroupLine = theme.fg("muted", entry.subgroup.label);
				lines.push(truncateToWidth(` ${subgroupLine}`, width, ""));
			} else {
				// Resource item (cursor only on items)
				const item = entry.item;
				const cursor = isSelected ? "> " : " ";
				const checkbox = item.enabled
					? theme.fg("success", "[x]")
					: theme.fg("dim", "[ ]");
				const name = isSelected
					? theme.bold(item.displayName)
					: item.displayName;
				lines.push(
					truncateToWidth(`${cursor} ${checkbox} ${name}`, width, "..."),
				);
			}
		}

		// Scroll indicator
		if (startIndex > 0 || endIndex < this.filteredItems.length) {
			lines.push(
				theme.fg(
					"dim",
					` (${this.selectedIndex + 1}/${this.filteredItems.length})`,
				),
			);
		}

		return lines;
	}

	// Keyboard handling: navigation, cancel/exit, toggle; any other key is
	// forwarded to the search input and re-filters the list.
	handleInput(data: string): void {
		const kb = getEditorKeybindings();

		if (kb.matches(data, "selectUp")) {
			this.selectedIndex = this.findNextItem(this.selectedIndex, -1);
			return;
		}
		if (kb.matches(data, "selectDown")) {
			this.selectedIndex = this.findNextItem(this.selectedIndex, 1);
			return;
		}
		if (kb.matches(data, "selectPageUp")) {
			// Jump up by maxVisible, then find nearest item
			let target = Math.max(0, this.selectedIndex - this.maxVisible);
			while (
				target < this.filteredItems.length &&
				this.filteredItems[target].type !== "item"
			) {
				target++;
			}
			if (target < this.filteredItems.length) {
				this.selectedIndex = target;
			}
			return;
		}
		if (kb.matches(data, "selectPageDown")) {
			// Jump down by maxVisible, then find nearest item
			let target = Math.min(
				this.filteredItems.length - 1,
				this.selectedIndex + this.maxVisible,
			);
			while (target >= 0 && this.filteredItems[target].type !== "item") {
				target--;
			}
			if (target >= 0) {
				this.selectedIndex = target;
			}
			return;
		}
		if (kb.matches(data, "selectCancel")) {
			this.onCancel?.();
			return;
		}
		if (matchesKey(data, "ctrl+c")) {
			this.onExit?.();
			return;
		}
		if (data === " " || kb.matches(data, "selectConfirm")) {
			const entry = this.filteredItems[this.selectedIndex];
			if (entry?.type === "item") {
				const newEnabled = !entry.item.enabled;
				// Persist first, then mirror the state locally and notify.
				this.toggleResource(entry.item, newEnabled);
				this.updateItem(entry.item, newEnabled);
				this.onToggle?.(entry.item, newEnabled);
			}
			return;
		}

		// Pass to search input
		this.searchInput.handleInput(data);
		this.filterItems(this.searchInput.getValue());
	}

	// Dispatch persistence by origin: settings pattern lists for top-level
	// resources, per-package filter lists for packaged resources.
	private toggleResource(item: ResourceItem, enabled: boolean): void {
		if (item.metadata.origin === "top-level") {
			this.toggleTopLevelResource(item, enabled);
		} else {
			this.togglePackageResource(item, enabled);
		}
	}

	// Persist a toggle for a top-level resource by rewriting the matching
	// settings array with a "+pattern" (enable) or "-pattern" (disable) entry,
	// removing any previous entry for the same pattern first.
	private toggleTopLevelResource(item: ResourceItem, enabled: boolean): void {
		const scope = item.metadata.scope as "user" | "project";
		const settings =
			scope === "project"
				? this.settingsManager.getProjectSettings()
				: this.settingsManager.getGlobalSettings();

		const arrayKey = item.resourceType as
			| "extensions"
			| "skills"
			| "prompts"
			| "themes";
		const current = (settings[arrayKey] ?? []) as string[];

		// Generate pattern for this resource
		const pattern = this.getResourcePattern(item);
		const disablePattern = `-${pattern}`;
		const enablePattern = `+${pattern}`;

		// Filter out existing patterns for this resource
		const updated = current.filter((p) => {
			const stripped =
				p.startsWith("!") || p.startsWith("+") || p.startsWith("-")
					? p.slice(1)
					: p;
			return stripped !== pattern;
		});

		if (enabled) {
			updated.push(enablePattern);
		} else {
			updated.push(disablePattern);
		}

		if (scope === "project") {
			if (arrayKey === "extensions") {
				this.settingsManager.setProjectExtensionPaths(updated);
			} else if (arrayKey === "skills") {
				this.settingsManager.setProjectSkillPaths(updated);
			} else if (arrayKey === "prompts") {
				this.settingsManager.setProjectPromptTemplatePaths(updated);
			} else if (arrayKey === "themes") {
				this.settingsManager.setProjectThemePaths(updated);
			}
		} else {
			if (arrayKey === "extensions") {
				this.settingsManager.setExtensionPaths(updated);
			} else if (arrayKey === "skills") {
				this.settingsManager.setSkillPaths(updated);
			} else if (arrayKey === "prompts") {
				this.settingsManager.setPromptTemplatePaths(updated);
			} else if (arrayKey === "themes") {
				this.settingsManager.setThemePaths(updated);
			}
		}
	}

	// Persist a toggle for a packaged resource by editing the owning package
	// entry's per-type filter list. String entries are promoted to object form
	// when filters are added, and demoted back when no filters remain.
	private togglePackageResource(item: ResourceItem, enabled: boolean): void {
		const scope = item.metadata.scope as "user" | "project";
		const settings =
			scope === "project"
				? this.settingsManager.getProjectSettings()
				: this.settingsManager.getGlobalSettings();

		const packages = [...(settings.packages ?? [])] as PackageSource[];
		const pkgIndex = packages.findIndex((pkg) => {
			const source = typeof pkg === "string" ? pkg : pkg.source;
			return source === item.metadata.source;
		});

		if (pkgIndex === -1) return;

		let pkg = packages[pkgIndex];

		// Convert string to object form if needed
		if (typeof pkg === "string") {
			pkg = { source: pkg };
			packages[pkgIndex] = pkg;
		}

		// Get the resource array for this type
		const arrayKey = item.resourceType as
			| "extensions"
			| "skills"
			| "prompts"
			| "themes";
		const current = (pkg[arrayKey] ?? []) as string[];

		// Generate pattern relative to package root
		const pattern = this.getPackageResourcePattern(item);
		const disablePattern = `-${pattern}`;
		const enablePattern = `+${pattern}`;

		// Filter out existing patterns for this resource
		const updated = current.filter((p) => {
			const stripped =
				p.startsWith("!") || p.startsWith("+") || p.startsWith("-")
					? p.slice(1)
					: p;
			return stripped !== pattern;
		});

		if (enabled) {
			updated.push(enablePattern);
		} else {
			updated.push(disablePattern);
		}

		(pkg as Record<string, unknown>)[arrayKey] =
			updated.length > 0 ? updated : undefined;

		// Clean up empty filter object
		const hasFilters = ["extensions", "skills", "prompts", "themes"].some(
			(k) => (pkg as Record<string, unknown>)[k] !== undefined,
		);
		if (!hasFilters) {
			packages[pkgIndex] = (pkg as { source: string }).source;
		}

		if (scope === "project") {
			this.settingsManager.setProjectPackages(packages);
		} else {
			this.settingsManager.setPackages(packages);
		}
	}

	// Base directory that top-level resource patterns are relative to.
	private getTopLevelBaseDir(scope: "user" | "project"): string {
		return scope === "project"
			? join(this.cwd, CONFIG_DIR_NAME)
			: this.agentDir;
	}

	// Settings pattern for a top-level resource (relative to its scope dir).
	private getResourcePattern(item: ResourceItem): string {
		const scope = item.metadata.scope as "user" | "project";
		const baseDir = this.getTopLevelBaseDir(scope);
		return relative(baseDir, item.path);
	}

	// Filter pattern for a packaged resource (relative to the package root;
	// falls back to the item's own directory when baseDir is absent).
	private getPackageResourcePattern(item: ResourceItem): string {
		const baseDir = item.metadata.baseDir ?? dirname(item.path);
		return relative(baseDir, item.path);
	}
}
|
||||
|
||||
export class ConfigSelectorComponent extends Container implements Focusable {
|
||||
private resourceList: ResourceList;
|
||||
|
||||
private _focused = false;
|
||||
get focused(): boolean {
|
||||
return this._focused;
|
||||
}
|
||||
set focused(value: boolean) {
|
||||
this._focused = value;
|
||||
this.resourceList.focused = value;
|
||||
}
|
||||
|
||||
constructor(
|
||||
resolvedPaths: ResolvedPaths,
|
||||
settingsManager: SettingsManager,
|
||||
cwd: string,
|
||||
agentDir: string,
|
||||
onClose: () => void,
|
||||
onExit: () => void,
|
||||
requestRender: () => void,
|
||||
) {
|
||||
super();
|
||||
|
||||
const groups = buildGroups(resolvedPaths);
|
||||
|
||||
// Add header
|
||||
this.addChild(new Spacer(1));
|
||||
this.addChild(new DynamicBorder());
|
||||
this.addChild(new Spacer(1));
|
||||
this.addChild(new ConfigSelectorHeader());
|
||||
this.addChild(new Spacer(1));
|
||||
|
||||
// Resource list
|
||||
this.resourceList = new ResourceList(
|
||||
groups,
|
||||
settingsManager,
|
||||
cwd,
|
||||
agentDir,
|
||||
);
|
||||
this.resourceList.onCancel = onClose;
|
||||
this.resourceList.onExit = onExit;
|
||||
this.resourceList.onToggle = () => requestRender();
|
||||
this.addChild(this.resourceList);
|
||||
|
||||
// Bottom border
|
||||
this.addChild(new Spacer(1));
|
||||
this.addChild(new DynamicBorder());
|
||||
}
|
||||
|
||||
getResourceList(): ResourceList {
|
||||
return this.resourceList;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,38 @@
|
|||
/**
|
||||
* Reusable countdown timer for dialog components.
|
||||
*/
|
||||
|
||||
import type { TUI } from "@mariozechner/pi-tui";
|
||||
|
||||
export class CountdownTimer {
|
||||
private intervalId: ReturnType<typeof setInterval> | undefined;
|
||||
private remainingSeconds: number;
|
||||
|
||||
constructor(
|
||||
timeoutMs: number,
|
||||
private tui: TUI | undefined,
|
||||
private onTick: (seconds: number) => void,
|
||||
private onExpire: () => void,
|
||||
) {
|
||||
this.remainingSeconds = Math.ceil(timeoutMs / 1000);
|
||||
this.onTick(this.remainingSeconds);
|
||||
|
||||
this.intervalId = setInterval(() => {
|
||||
this.remainingSeconds--;
|
||||
this.onTick(this.remainingSeconds);
|
||||
this.tui?.requestRender();
|
||||
|
||||
if (this.remainingSeconds <= 0) {
|
||||
this.dispose();
|
||||
this.onExpire();
|
||||
}
|
||||
}, 1000);
|
||||
}
|
||||
|
||||
dispose(): void {
|
||||
if (this.intervalId) {
|
||||
clearInterval(this.intervalId);
|
||||
this.intervalId = undefined;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,97 @@
|
|||
import {
|
||||
Editor,
|
||||
type EditorOptions,
|
||||
type EditorTheme,
|
||||
type TUI,
|
||||
} from "@mariozechner/pi-tui";
|
||||
import type {
|
||||
AppAction,
|
||||
KeybindingsManager,
|
||||
} from "../../../core/keybindings.js";
|
||||
|
||||
/**
|
||||
* Custom editor that handles app-level keybindings for coding-agent.
|
||||
*/
|
||||
export class CustomEditor extends Editor {
|
||||
private keybindings: KeybindingsManager;
|
||||
public actionHandlers: Map<AppAction, () => void> = new Map();
|
||||
|
||||
// Special handlers that can be dynamically replaced
|
||||
public onEscape?: () => void;
|
||||
public onCtrlD?: () => void;
|
||||
public onPasteImage?: () => void;
|
||||
/** Handler for extension-registered shortcuts. Returns true if handled. */
|
||||
public onExtensionShortcut?: (data: string) => boolean;
|
||||
|
||||
constructor(
|
||||
tui: TUI,
|
||||
theme: EditorTheme,
|
||||
keybindings: KeybindingsManager,
|
||||
options?: EditorOptions,
|
||||
) {
|
||||
super(tui, theme, options);
|
||||
this.keybindings = keybindings;
|
||||
}
|
||||
|
||||
/**
|
||||
* Register a handler for an app action.
|
||||
*/
|
||||
onAction(action: AppAction, handler: () => void): void {
|
||||
this.actionHandlers.set(action, handler);
|
||||
}
|
||||
|
||||
handleInput(data: string): void {
|
||||
// Check extension-registered shortcuts first
|
||||
if (this.onExtensionShortcut?.(data)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Check for paste image keybinding
|
||||
if (this.keybindings.matches(data, "pasteImage")) {
|
||||
this.onPasteImage?.();
|
||||
return;
|
||||
}
|
||||
|
||||
// Check app keybindings first
|
||||
|
||||
// Escape/interrupt - only if autocomplete is NOT active
|
||||
if (this.keybindings.matches(data, "interrupt")) {
|
||||
if (!this.isShowingAutocomplete()) {
|
||||
// Use dynamic onEscape if set, otherwise registered handler
|
||||
const handler = this.onEscape ?? this.actionHandlers.get("interrupt");
|
||||
if (handler) {
|
||||
handler();
|
||||
return;
|
||||
}
|
||||
}
|
||||
// Let parent handle escape for autocomplete cancellation
|
||||
super.handleInput(data);
|
||||
return;
|
||||
}
|
||||
|
||||
// Exit (Ctrl+D) - only when editor is empty
|
||||
if (this.keybindings.matches(data, "exit")) {
|
||||
if (this.getText().length === 0) {
|
||||
const handler = this.onCtrlD ?? this.actionHandlers.get("exit");
|
||||
if (handler) handler();
|
||||
return;
|
||||
}
|
||||
// Fall through to editor handling for delete-char-forward when not empty
|
||||
}
|
||||
|
||||
// Check all other app actions
|
||||
for (const [action, handler] of this.actionHandlers) {
|
||||
if (
|
||||
action !== "interrupt" &&
|
||||
action !== "exit" &&
|
||||
this.keybindings.matches(data, action)
|
||||
) {
|
||||
handler();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Pass to parent for editor handling
|
||||
super.handleInput(data);
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,113 @@
|
|||
import type { TextContent } from "@mariozechner/pi-ai";
|
||||
import type { Component } from "@mariozechner/pi-tui";
|
||||
import {
|
||||
Box,
|
||||
Container,
|
||||
Markdown,
|
||||
type MarkdownTheme,
|
||||
Spacer,
|
||||
Text,
|
||||
} from "@mariozechner/pi-tui";
|
||||
import type { MessageRenderer } from "../../../core/extensions/types.js";
|
||||
import type { CustomMessage } from "../../../core/messages.js";
|
||||
import { getMarkdownTheme, theme } from "../theme/theme.js";
|
||||
|
||||
/**
 * Component that renders a custom message entry from extensions.
 * Uses distinct styling to differentiate from user messages.
 *
 * An extension may supply its own `MessageRenderer`; when it does (and it
 * returns a component without throwing), that renderer fully replaces the
 * default boxed label + markdown rendering.
 */
export class CustomMessageComponent extends Container {
	// Message being displayed; content may be a plain string or content parts.
	private message: CustomMessage<unknown>;
	// Optional extension-provided renderer; takes precedence over the default box.
	private customRenderer?: MessageRenderer;
	// Box used only on the default rendering path.
	private box: Box;
	// Component produced by the custom renderer on the last rebuild, if any.
	private customComponent?: Component;
	private markdownTheme: MarkdownTheme;
	// Expansion state forwarded to the custom renderer; default rendering ignores it.
	private _expanded = false;

	/**
	 * @param message The custom message to render.
	 * @param customRenderer Optional renderer supplied by the owning extension.
	 * @param markdownTheme Theme used for the default Markdown rendering.
	 */
	constructor(
		message: CustomMessage<unknown>,
		customRenderer?: MessageRenderer,
		markdownTheme: MarkdownTheme = getMarkdownTheme(),
	) {
		super();
		this.message = message;
		this.customRenderer = customRenderer;
		this.markdownTheme = markdownTheme;

		this.addChild(new Spacer(1));

		// Create box with purple background (used for default rendering)
		this.box = new Box(1, 1, (t) => theme.bg("customMessageBg", t));

		this.rebuild();
	}

	// Change expanded state; rebuilds only when the value actually changes.
	setExpanded(expanded: boolean): void {
		if (this._expanded !== expanded) {
			this._expanded = expanded;
			this.rebuild();
		}
	}

	// Rebuild on invalidation so theme/content changes are reflected.
	override invalidate(): void {
		super.invalidate();
		this.rebuild();
	}

	/**
	 * Recompute children from the current message and expanded state.
	 * Tries the custom renderer first; on a thrown error or null result,
	 * falls back to the default rendering (label + markdown in the box).
	 */
	private rebuild(): void {
		// Remove previous content component
		if (this.customComponent) {
			this.removeChild(this.customComponent);
			this.customComponent = undefined;
		}
		this.removeChild(this.box);

		// Try custom renderer first - it handles its own styling
		if (this.customRenderer) {
			try {
				const component = this.customRenderer(
					this.message,
					{ expanded: this._expanded },
					theme,
				);
				if (component) {
					// Custom renderer provides its own styled component
					this.customComponent = component;
					this.addChild(component);
					return;
				}
			} catch {
				// Fall through to default rendering
			}
		}

		// Default rendering uses our box
		this.addChild(this.box);
		this.box.clear();

		// Default rendering: label + content
		// The label uses raw SGR bold on/off (\x1b[1m / \x1b[22m) around the type tag.
		const label = theme.fg(
			"customMessageLabel",
			`\x1b[1m[${this.message.customType}]\x1b[22m`,
		);
		this.box.addChild(new Text(label, 0, 0));
		this.box.addChild(new Spacer(1));

		// Extract text content; non-text content parts are dropped.
		let text: string;
		if (typeof this.message.content === "string") {
			text = this.message.content;
		} else {
			text = this.message.content
				.filter((c): c is TextContent => c.type === "text")
				.map((c) => c.text)
				.join("\n");
		}

		this.box.addChild(
			new Markdown(text, 0, 0, this.markdownTheme, {
				color: (text: string) => theme.fg("customMessageText", text),
			}),
		);
	}
}
||||
File diff suppressed because one or more lines are too long
179
packages/coding-agent/src/modes/interactive/components/diff.ts
Normal file
179
packages/coding-agent/src/modes/interactive/components/diff.ts
Normal file
|
|
@ -0,0 +1,179 @@
|
|||
import * as Diff from "diff";
|
||||
import { theme } from "../theme/theme.js";
|
||||
|
||||
/**
|
||||
* Parse diff line to extract prefix, line number, and content.
|
||||
* Format: "+123 content" or "-123 content" or " 123 content" or " ..."
|
||||
*/
|
||||
function parseDiffLine(
|
||||
line: string,
|
||||
): { prefix: string; lineNum: string; content: string } | null {
|
||||
const match = line.match(/^([+-\s])(\s*\d*)\s(.*)$/);
|
||||
if (!match) return null;
|
||||
return { prefix: match[1], lineNum: match[2], content: match[3] };
|
||||
}
|
||||
|
||||
/**
 * Replace tabs with spaces for consistent rendering.
 * Terminals render tabs with device-dependent width, which would break the
 * column alignment of the diff view.
 * NOTE(review): replacement width reconstructed from a mangled extraction —
 * confirm the exact number of spaces against the original source.
 */
function replaceTabs(text: string): string {
	return text.replace(/\t/g, "   ");
}
||||
|
||||
/**
 * Compute word-level diff and render with inverse on changed parts.
 * Uses diffWords which groups whitespace with adjacent words for cleaner highlighting.
 * Strips leading whitespace from inverse to avoid highlighting indentation.
 *
 * @param oldContent Body of the removed line (diff prefix already stripped).
 * @param newContent Body of the added line (diff prefix already stripped).
 * @returns Both line bodies with changed tokens wrapped in theme.inverse();
 *          unchanged tokens are emitted verbatim into both lines.
 */
function renderIntraLineDiff(
	oldContent: string,
	newContent: string,
): { removedLine: string; addedLine: string } {
	const wordDiff = Diff.diffWords(oldContent, newContent);

	let removedLine = "";
	let addedLine = "";
	// Track whether we've seen the first removed/added part so that only the
	// leading indentation of the *first* changed token is excluded from inverse.
	let isFirstRemoved = true;
	let isFirstAdded = true;

	for (const part of wordDiff) {
		if (part.removed) {
			let value = part.value;
			// Strip leading whitespace from the first removed part
			if (isFirstRemoved) {
				const leadingWs = value.match(/^(\s*)/)?.[1] || "";
				value = value.slice(leadingWs.length);
				removedLine += leadingWs;
				isFirstRemoved = false;
			}
			if (value) {
				removedLine += theme.inverse(value);
			}
		} else if (part.added) {
			let value = part.value;
			// Strip leading whitespace from the first added part
			if (isFirstAdded) {
				const leadingWs = value.match(/^(\s*)/)?.[1] || "";
				value = value.slice(leadingWs.length);
				addedLine += leadingWs;
				isFirstAdded = false;
			}
			if (value) {
				addedLine += theme.inverse(value);
			}
		} else {
			// Unchanged part: appears in both lines, unstyled.
			removedLine += part.value;
			addedLine += part.value;
		}
	}

	return { removedLine, addedLine };
}
||||
|
||||
export interface RenderDiffOptions {
	/** File path (unused, kept for API compatibility) */
	filePath?: string;
}

/**
 * Render a diff string with colored lines and intra-line change highlighting.
 * - Context lines: dim/gray
 * - Removed lines: red, with inverse on changed tokens
 * - Added lines: green, with inverse on changed tokens
 *
 * Input lines are expected in the "<prefix><lineNum> <content>" shape produced
 * upstream (see parseDiffLine); unparseable lines pass through as context.
 *
 * @param diffText Pre-formatted diff text, one diff line per text line.
 * @param _options Unused; retained for API compatibility.
 * @returns The colorized diff joined back into a single string.
 */
export function renderDiff(
	diffText: string,
	_options: RenderDiffOptions = {},
): string {
	const lines = diffText.split("\n");
	const result: string[] = [];

	let i = 0;
	while (i < lines.length) {
		const line = lines[i];
		const parsed = parseDiffLine(line);

		if (!parsed) {
			// Not a recognized diff line (e.g. separator) — render as context.
			result.push(theme.fg("toolDiffContext", line));
			i++;
			continue;
		}

		if (parsed.prefix === "-") {
			// Collect consecutive removed lines
			const removedLines: { lineNum: string; content: string }[] = [];
			while (i < lines.length) {
				const p = parseDiffLine(lines[i]);
				if (!p || p.prefix !== "-") break;
				removedLines.push({ lineNum: p.lineNum, content: p.content });
				i++;
			}

			// Collect consecutive added lines
			const addedLines: { lineNum: string; content: string }[] = [];
			while (i < lines.length) {
				const p = parseDiffLine(lines[i]);
				if (!p || p.prefix !== "+") break;
				addedLines.push({ lineNum: p.lineNum, content: p.content });
				i++;
			}

			// Only do intra-line diffing when there's exactly one removed and one added line
			// (indicating a single line modification). Otherwise, show lines as-is.
			if (removedLines.length === 1 && addedLines.length === 1) {
				const removed = removedLines[0];
				const added = addedLines[0];

				const { removedLine, addedLine } = renderIntraLineDiff(
					replaceTabs(removed.content),
					replaceTabs(added.content),
				);

				result.push(
					theme.fg("toolDiffRemoved", `-${removed.lineNum} ${removedLine}`),
				);
				result.push(
					theme.fg("toolDiffAdded", `+${added.lineNum} ${addedLine}`),
				);
			} else {
				// Show all removed lines first, then all added lines
				for (const removed of removedLines) {
					result.push(
						theme.fg(
							"toolDiffRemoved",
							`-${removed.lineNum} ${replaceTabs(removed.content)}`,
						),
					);
				}
				for (const added of addedLines) {
					result.push(
						theme.fg(
							"toolDiffAdded",
							`+${added.lineNum} ${replaceTabs(added.content)}`,
						),
					);
				}
			}
		} else if (parsed.prefix === "+") {
			// Standalone added line (no preceding removal to pair with)
			result.push(
				theme.fg(
					"toolDiffAdded",
					`+${parsed.lineNum} ${replaceTabs(parsed.content)}`,
				),
			);
			i++;
		} else {
			// Context line
			result.push(
				theme.fg(
					"toolDiffContext",
					` ${parsed.lineNum} ${replaceTabs(parsed.content)}`,
				),
			);
			i++;
		}
	}

	return result.join("\n");
}
||||
|
|
@ -0,0 +1,27 @@
|
|||
import type { Component } from "@mariozechner/pi-tui";
|
||||
import { theme } from "../theme/theme.js";
|
||||
|
||||
/**
|
||||
* Dynamic border component that adjusts to viewport width.
|
||||
*
|
||||
* Note: When used from extensions loaded via jiti, the global `theme` may be undefined
|
||||
* because jiti creates a separate module cache. Always pass an explicit color
|
||||
* function when using DynamicBorder in components exported for extension use.
|
||||
*/
|
||||
export class DynamicBorder implements Component {
|
||||
private color: (str: string) => string;
|
||||
|
||||
constructor(
|
||||
color: (str: string) => string = (str) => theme.fg("border", str),
|
||||
) {
|
||||
this.color = color;
|
||||
}
|
||||
|
||||
invalidate(): void {
|
||||
// No cached state to invalidate currently
|
||||
}
|
||||
|
||||
render(width: number): string[] {
|
||||
return [this.color("─".repeat(Math.max(1, width)))];
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,151 @@
|
|||
/**
|
||||
* Multi-line editor component for extensions.
|
||||
* Supports Ctrl+G for external editor.
|
||||
*/
|
||||
|
||||
import { spawnSync } from "node:child_process";
|
||||
import * as fs from "node:fs";
|
||||
import * as os from "node:os";
|
||||
import * as path from "node:path";
|
||||
import {
|
||||
Container,
|
||||
Editor,
|
||||
type EditorOptions,
|
||||
type Focusable,
|
||||
getEditorKeybindings,
|
||||
Spacer,
|
||||
Text,
|
||||
type TUI,
|
||||
} from "@mariozechner/pi-tui";
|
||||
import type { KeybindingsManager } from "../../../core/keybindings.js";
|
||||
import { getEditorTheme, theme } from "../theme/theme.js";
|
||||
import { DynamicBorder } from "./dynamic-border.js";
|
||||
import { appKeyHint, keyHint } from "./keybinding-hints.js";
|
||||
|
||||
/**
 * Multi-line editor dialog shown on behalf of an extension: bordered title,
 * an Editor widget, and a keybinding hint line. Supports launching $VISUAL /
 * $EDITOR via the "externalEditor" app keybinding.
 */
export class ExtensionEditorComponent extends Container implements Focusable {
	private editor: Editor;
	private onSubmitCallback: (value: string) => void;
	private onCancelCallback: () => void;
	// Needed to suspend/resume terminal rendering around the external editor.
	private tui: TUI;
	private keybindings: KeybindingsManager;

	// Focus is mirrored into the inner editor (e.g. for cursor positioning).
	private _focused = false;
	get focused(): boolean {
		return this._focused;
	}
	set focused(value: boolean) {
		this._focused = value;
		this.editor.focused = value;
	}

	/**
	 * @param tui TUI instance, used to stop/restart rendering for the external editor.
	 * @param keybindings App keybindings (used for the external-editor binding).
	 * @param title Accent-colored heading shown above the editor.
	 * @param prefill Optional initial editor text.
	 * @param onSubmit Called with the editor text on Enter.
	 * @param onCancel Called on cancel (escape/ctrl+c per editor keybindings).
	 * @param options Passed through to the underlying Editor.
	 */
	constructor(
		tui: TUI,
		keybindings: KeybindingsManager,
		title: string,
		prefill: string | undefined,
		onSubmit: (value: string) => void,
		onCancel: () => void,
		options?: EditorOptions,
	) {
		super();

		this.tui = tui;
		this.keybindings = keybindings;
		this.onSubmitCallback = onSubmit;
		this.onCancelCallback = onCancel;

		// Add top border
		this.addChild(new DynamicBorder());
		this.addChild(new Spacer(1));

		// Add title
		this.addChild(new Text(theme.fg("accent", title), 1, 0));
		this.addChild(new Spacer(1));

		// Create editor
		this.editor = new Editor(tui, getEditorTheme(), options);
		if (prefill) {
			this.editor.setText(prefill);
		}
		// Wire up Enter to submit (Shift+Enter for newlines, like the main editor)
		this.editor.onSubmit = (text: string) => {
			this.onSubmitCallback(text);
		};
		this.addChild(this.editor);

		this.addChild(new Spacer(1));

		// Add hint; the external-editor hint only appears when $VISUAL/$EDITOR is set.
		const hasExternalEditor = !!(process.env.VISUAL || process.env.EDITOR);
		const hint =
			keyHint("selectConfirm", "submit") +
			" " +
			keyHint("newLine", "newline") +
			" " +
			keyHint("selectCancel", "cancel") +
			(hasExternalEditor
				? ` ${appKeyHint(this.keybindings, "externalEditor", "external editor")}`
				: "");
		this.addChild(new Text(hint, 1, 0));

		this.addChild(new Spacer(1));

		// Add bottom border
		this.addChild(new DynamicBorder());
	}

	/**
	 * Route raw key input: cancel and external-editor bindings are intercepted,
	 * everything else goes to the inner editor.
	 */
	handleInput(keyData: string): void {
		const kb = getEditorKeybindings();
		// Escape or Ctrl+C to cancel
		if (kb.matches(keyData, "selectCancel")) {
			this.onCancelCallback();
			return;
		}

		// External editor (app keybinding)
		if (this.keybindings.matches(keyData, "externalEditor")) {
			this.openExternalEditor();
			return;
		}

		// Forward to editor
		this.editor.handleInput(keyData);
	}

	/**
	 * Edit the current text in $VISUAL/$EDITOR via a temp file, suspending the
	 * TUI while the editor runs. The buffer is only replaced when the editor
	 * exits with status 0; the temp file is always cleaned up and the TUI
	 * restarted.
	 * NOTE(review): editorCmd is split on single spaces — paths containing
	 * spaces in $EDITOR would break; confirm whether that's acceptable here.
	 */
	private openExternalEditor(): void {
		const editorCmd = process.env.VISUAL || process.env.EDITOR;
		if (!editorCmd) {
			return;
		}

		const currentText = this.editor.getText();
		const tmpFile = path.join(
			os.tmpdir(),
			`pi-extension-editor-${Date.now()}.md`,
		);

		try {
			fs.writeFileSync(tmpFile, currentText, "utf-8");
			this.tui.stop();

			const [editor, ...editorArgs] = editorCmd.split(" ");
			const result = spawnSync(editor, [...editorArgs, tmpFile], {
				stdio: "inherit",
			});

			if (result.status === 0) {
				// Drop a single trailing newline most editors append on save.
				const newContent = fs.readFileSync(tmpFile, "utf-8").replace(/\n$/, "");
				this.editor.setText(newContent);
			}
		} finally {
			try {
				fs.unlinkSync(tmpFile);
			} catch {
				// Ignore cleanup errors
			}
			this.tui.start();
			// Force full re-render since external editor uses alternate screen
			this.tui.requestRender(true);
		}
	}
}
|
||||
|
|
@ -0,0 +1,102 @@
|
|||
/**
|
||||
* Simple text input component for extensions.
|
||||
*/
|
||||
|
||||
import {
|
||||
Container,
|
||||
type Focusable,
|
||||
getEditorKeybindings,
|
||||
Input,
|
||||
Spacer,
|
||||
Text,
|
||||
type TUI,
|
||||
} from "@mariozechner/pi-tui";
|
||||
import { theme } from "../theme/theme.js";
|
||||
import { CountdownTimer } from "./countdown-timer.js";
|
||||
import { DynamicBorder } from "./dynamic-border.js";
|
||||
import { keyHint } from "./keybinding-hints.js";
|
||||
|
||||
export interface ExtensionInputOptions {
	// TUI instance; required for the countdown timer to drive re-renders.
	tui?: TUI;
	// Auto-cancel timeout. NOTE(review): presumed to be in seconds based on
	// the "(Ns)" title suffix rendered by CountdownTimer — confirm.
	timeout?: number;
}

/**
 * Single-line text input dialog shown on behalf of an extension: bordered
 * title, an Input widget, a keyhint line, and an optional auto-cancel
 * countdown appended to the title.
 */
export class ExtensionInputComponent extends Container implements Focusable {
	private input: Input;
	private onSubmitCallback: (value: string) => void;
	private onCancelCallback: () => void;
	// Title text node; rewritten each countdown tick with the remaining time.
	private titleText: Text;
	private baseTitle: string;
	private countdown: CountdownTimer | undefined;

	// Focusable implementation - propagate to input for IME cursor positioning
	private _focused = false;
	get focused(): boolean {
		return this._focused;
	}
	set focused(value: boolean) {
		this._focused = value;
		this.input.focused = value;
	}

	/**
	 * @param title Accent-colored heading shown above the input.
	 * @param _placeholder Accepted for API symmetry but currently unused.
	 * @param onSubmit Called with the input value on confirm.
	 * @param onCancel Called on cancel or countdown expiry.
	 * @param opts Optional TUI handle + timeout for the countdown.
	 */
	constructor(
		title: string,
		_placeholder: string | undefined,
		onSubmit: (value: string) => void,
		onCancel: () => void,
		opts?: ExtensionInputOptions,
	) {
		super();

		this.onSubmitCallback = onSubmit;
		this.onCancelCallback = onCancel;
		this.baseTitle = title;

		this.addChild(new DynamicBorder());
		this.addChild(new Spacer(1));

		this.titleText = new Text(theme.fg("accent", title), 1, 0);
		this.addChild(this.titleText);
		this.addChild(new Spacer(1));

		// Countdown is only armed when both a positive timeout and a TUI are given.
		if (opts?.timeout && opts.timeout > 0 && opts.tui) {
			this.countdown = new CountdownTimer(
				opts.timeout,
				opts.tui,
				(s) =>
					this.titleText.setText(
						theme.fg("accent", `${this.baseTitle} (${s}s)`),
					),
				() => this.onCancelCallback(),
			);
		}

		this.input = new Input();
		this.addChild(this.input);
		this.addChild(new Spacer(1));
		this.addChild(
			new Text(
				`${keyHint("selectConfirm", "submit")} ${keyHint("selectCancel", "cancel")}`,
				1,
				0,
			),
		);
		this.addChild(new Spacer(1));
		this.addChild(new DynamicBorder());
	}

	/**
	 * Route raw key input: confirm submits the current value, cancel aborts,
	 * everything else is forwarded to the Input widget.
	 */
	handleInput(keyData: string): void {
		const kb = getEditorKeybindings();
		if (kb.matches(keyData, "selectConfirm") || keyData === "\n") {
			this.onSubmitCallback(this.input.getValue());
		} else if (kb.matches(keyData, "selectCancel")) {
			this.onCancelCallback();
		} else {
			this.input.handleInput(keyData);
		}
	}

	// Stop the countdown timer (if armed) when the dialog is torn down.
	dispose(): void {
		this.countdown?.dispose();
	}
}
|
||||
|
|
@ -0,0 +1,119 @@
|
|||
/**
|
||||
* Generic selector component for extensions.
|
||||
* Displays a list of string options with keyboard navigation.
|
||||
*/
|
||||
|
||||
import {
|
||||
Container,
|
||||
getEditorKeybindings,
|
||||
Spacer,
|
||||
Text,
|
||||
type TUI,
|
||||
} from "@mariozechner/pi-tui";
|
||||
import { theme } from "../theme/theme.js";
|
||||
import { CountdownTimer } from "./countdown-timer.js";
|
||||
import { DynamicBorder } from "./dynamic-border.js";
|
||||
import { keyHint, rawKeyHint } from "./keybinding-hints.js";
|
||||
|
||||
export interface ExtensionSelectorOptions {
	// TUI instance; required for the countdown timer to drive re-renders.
	tui?: TUI;
	// Auto-cancel timeout. NOTE(review): presumed to be in seconds based on
	// the "(Ns)" title suffix rendered by CountdownTimer — confirm.
	timeout?: number;
}

/**
 * Generic list selector shown on behalf of an extension: bordered title,
 * an arrow-marked option list with keyboard navigation (arrows or j/k),
 * a keyhint line, and an optional auto-cancel countdown in the title.
 */
export class ExtensionSelectorComponent extends Container {
	private options: string[];
	private selectedIndex = 0;
	// Holds one Text child per option; rebuilt on every selection change.
	private listContainer: Container;
	private onSelectCallback: (option: string) => void;
	private onCancelCallback: () => void;
	// Title text node; rewritten each countdown tick with the remaining time.
	private titleText: Text;
	private baseTitle: string;
	private countdown: CountdownTimer | undefined;

	/**
	 * @param title Accent-colored heading shown above the list.
	 * @param options Option labels; the selected label is passed to onSelect.
	 * @param onSelect Called with the chosen option on confirm.
	 * @param onCancel Called on cancel or countdown expiry.
	 * @param opts Optional TUI handle + timeout for the countdown.
	 */
	constructor(
		title: string,
		options: string[],
		onSelect: (option: string) => void,
		onCancel: () => void,
		opts?: ExtensionSelectorOptions,
	) {
		super();

		this.options = options;
		this.onSelectCallback = onSelect;
		this.onCancelCallback = onCancel;
		this.baseTitle = title;

		this.addChild(new DynamicBorder());
		this.addChild(new Spacer(1));

		this.titleText = new Text(theme.fg("accent", title), 1, 0);
		this.addChild(this.titleText);
		this.addChild(new Spacer(1));

		// Countdown is only armed when both a positive timeout and a TUI are given.
		if (opts?.timeout && opts.timeout > 0 && opts.tui) {
			this.countdown = new CountdownTimer(
				opts.timeout,
				opts.tui,
				(s) =>
					this.titleText.setText(
						theme.fg("accent", `${this.baseTitle} (${s}s)`),
					),
				() => this.onCancelCallback(),
			);
		}

		this.listContainer = new Container();
		this.addChild(this.listContainer);
		this.addChild(new Spacer(1));
		this.addChild(
			new Text(
				rawKeyHint("↑↓", "navigate") +
					" " +
					keyHint("selectConfirm", "select") +
					" " +
					keyHint("selectCancel", "cancel"),
				1,
				0,
			),
		);
		this.addChild(new Spacer(1));
		this.addChild(new DynamicBorder());

		this.updateList();
	}

	// Rebuild the option list, marking the selected row with an accent arrow.
	private updateList(): void {
		this.listContainer.clear();
		for (let i = 0; i < this.options.length; i++) {
			const isSelected = i === this.selectedIndex;
			const text = isSelected
				? theme.fg("accent", "→ ") + theme.fg("accent", this.options[i])
				: `  ${theme.fg("text", this.options[i])}`;
			this.listContainer.addChild(new Text(text, 1, 0));
		}
	}

	/**
	 * Keyboard navigation: up/down (or vi-style k/j) move the selection,
	 * clamped to list bounds; confirm fires onSelect; cancel fires onCancel.
	 */
	handleInput(keyData: string): void {
		const kb = getEditorKeybindings();
		if (kb.matches(keyData, "selectUp") || keyData === "k") {
			this.selectedIndex = Math.max(0, this.selectedIndex - 1);
			this.updateList();
		} else if (kb.matches(keyData, "selectDown") || keyData === "j") {
			this.selectedIndex = Math.min(
				this.options.length - 1,
				this.selectedIndex + 1,
			);
			this.updateList();
		} else if (kb.matches(keyData, "selectConfirm") || keyData === "\n") {
			const selected = this.options[this.selectedIndex];
			if (selected) this.onSelectCallback(selected);
		} else if (kb.matches(keyData, "selectCancel")) {
			this.onCancelCallback();
		}
	}

	// Stop the countdown timer (if armed) when the dialog is torn down.
	dispose(): void {
		this.countdown?.dispose();
	}
}
|
||||
236
packages/coding-agent/src/modes/interactive/components/footer.ts
Normal file
236
packages/coding-agent/src/modes/interactive/components/footer.ts
Normal file
|
|
@ -0,0 +1,236 @@
|
|||
import {
|
||||
type Component,
|
||||
truncateToWidth,
|
||||
visibleWidth,
|
||||
} from "@mariozechner/pi-tui";
|
||||
import type { AgentSession } from "../../../core/agent-session.js";
|
||||
import type { ReadonlyFooterDataProvider } from "../../../core/footer-data-provider.js";
|
||||
import { theme } from "../theme/theme.js";
|
||||
|
||||
/**
|
||||
* Sanitize text for display in a single-line status.
|
||||
* Removes newlines, tabs, carriage returns, and other control characters.
|
||||
*/
|
||||
function sanitizeStatusText(text: string): string {
|
||||
// Replace newlines, tabs, carriage returns with space, then collapse multiple spaces
|
||||
return text
|
||||
.replace(/[\r\n\t]/g, " ")
|
||||
.replace(/ +/g, " ")
|
||||
.trim();
|
||||
}
|
||||
|
||||
/**
|
||||
* Format token counts (similar to web-ui)
|
||||
*/
|
||||
function formatTokens(count: number): string {
|
||||
if (count < 1000) return count.toString();
|
||||
if (count < 10000) return `${(count / 1000).toFixed(1)}k`;
|
||||
if (count < 1000000) return `${Math.round(count / 1000)}k`;
|
||||
if (count < 10000000) return `${(count / 1000000).toFixed(1)}M`;
|
||||
return `${Math.round(count / 1000000)}M`;
|
||||
}
|
||||
|
||||
/**
 * Footer component that shows pwd, token stats, and context usage.
 * Computes token/context stats from session, gets git branch and extension statuses from provider.
 *
 * Rendered as 2–3 lines: pwd (+ branch/session name), a stats line with the
 * model right-aligned, and an optional extension-status line.
 */
export class FooterComponent implements Component {
	// Whether auto-compaction is on; only affects the "(auto)" suffix shown
	// next to the context percentage.
	private autoCompactEnabled = true;

	constructor(
		private session: AgentSession,
		private footerData: ReadonlyFooterDataProvider,
	) {}

	setAutoCompactEnabled(enabled: boolean): void {
		this.autoCompactEnabled = enabled;
	}

	/**
	 * No-op: git branch caching now handled by provider.
	 * Kept for compatibility with existing call sites in interactive-mode.
	 */
	invalidate(): void {
		// No-op: git branch is cached/invalidated by provider
	}

	/**
	 * Clean up resources.
	 * Git watcher cleanup now handled by provider.
	 */
	dispose(): void {
		// Git watcher cleanup handled by provider
	}

	/**
	 * Render the footer lines for the given terminal width.
	 * Recomputes cumulative usage from the full session history on every call.
	 */
	render(width: number): string[] {
		const state = this.session.state;

		// Calculate cumulative usage from ALL session entries (not just post-compaction messages)
		let totalInput = 0;
		let totalOutput = 0;
		let totalCacheRead = 0;
		let totalCacheWrite = 0;
		let totalCost = 0;

		for (const entry of this.session.sessionManager.getEntries()) {
			if (entry.type === "message" && entry.message.role === "assistant") {
				totalInput += entry.message.usage.input;
				totalOutput += entry.message.usage.output;
				totalCacheRead += entry.message.usage.cacheRead;
				totalCacheWrite += entry.message.usage.cacheWrite;
				totalCost += entry.message.usage.cost.total;
			}
		}

		// Calculate context usage from session (handles compaction correctly).
		// After compaction, tokens are unknown until the next LLM response.
		// NOTE(review): when getContextUsage() returns undefined (not just a
		// null percent), `undefined !== null` is true and "0.0" is shown rather
		// than "?" — confirm whether that case can occur.
		const contextUsage = this.session.getContextUsage();
		const contextWindow =
			contextUsage?.contextWindow ?? state.model?.contextWindow ?? 0;
		const contextPercentValue = contextUsage?.percent ?? 0;
		const contextPercent =
			contextUsage?.percent !== null ? contextPercentValue.toFixed(1) : "?";

		// Replace home directory with ~
		let pwd = process.cwd();
		const home = process.env.HOME || process.env.USERPROFILE;
		if (home && pwd.startsWith(home)) {
			pwd = `~${pwd.slice(home.length)}`;
		}

		// Add git branch if available
		const branch = this.footerData.getGitBranch();
		if (branch) {
			pwd = `${pwd} (${branch})`;
		}

		// Add session name if set
		const sessionName = this.session.sessionManager.getSessionName();
		if (sessionName) {
			pwd = `${pwd} • ${sessionName}`;
		}

		// Build stats line: ↑input ↓output Rcache-read Wcache-write $cost ctx%
		const statsParts = [];
		if (totalInput) statsParts.push(`↑${formatTokens(totalInput)}`);
		if (totalOutput) statsParts.push(`↓${formatTokens(totalOutput)}`);
		if (totalCacheRead) statsParts.push(`R${formatTokens(totalCacheRead)}`);
		if (totalCacheWrite) statsParts.push(`W${formatTokens(totalCacheWrite)}`);

		// Show cost with "(sub)" indicator if using OAuth subscription
		const usingSubscription = state.model
			? this.session.modelRegistry.isUsingOAuth(state.model)
			: false;
		if (totalCost || usingSubscription) {
			const costStr = `$${totalCost.toFixed(3)}${usingSubscription ? " (sub)" : ""}`;
			statsParts.push(costStr);
		}

		// Colorize context percentage based on usage (>90% error, >70% warning)
		let contextPercentStr: string;
		const autoIndicator = this.autoCompactEnabled ? " (auto)" : "";
		const contextPercentDisplay =
			contextPercent === "?"
				? `?/${formatTokens(contextWindow)}${autoIndicator}`
				: `${contextPercent}%/${formatTokens(contextWindow)}${autoIndicator}`;
		if (contextPercentValue > 90) {
			contextPercentStr = theme.fg("error", contextPercentDisplay);
		} else if (contextPercentValue > 70) {
			contextPercentStr = theme.fg("warning", contextPercentDisplay);
		} else {
			contextPercentStr = contextPercentDisplay;
		}
		statsParts.push(contextPercentStr);

		let statsLeft = statsParts.join(" ");

		// Add model name on the right side, plus thinking level if model supports it
		const modelName = state.model?.id || "no-model";

		let statsLeftWidth = visibleWidth(statsLeft);

		// If statsLeft is too wide, truncate it
		if (statsLeftWidth > width) {
			statsLeft = truncateToWidth(statsLeft, width, "...");
			statsLeftWidth = visibleWidth(statsLeft);
		}

		// Calculate available space for padding (minimum 2 spaces between stats and model)
		const minPadding = 2;

		// Add thinking level indicator if model supports reasoning
		let rightSideWithoutProvider = modelName;
		if (state.model?.reasoning) {
			const thinkingLevel = state.thinkingLevel || "off";
			rightSideWithoutProvider =
				thinkingLevel === "off"
					? `${modelName} • thinking off`
					: `${modelName} • ${thinkingLevel}`;
		}

		// Prepend the provider in parentheses if there are multiple providers and there's enough room
		let rightSide = rightSideWithoutProvider;
		if (this.footerData.getAvailableProviderCount() > 1 && state.model) {
			rightSide = `(${state.model!.provider}) ${rightSideWithoutProvider}`;
			if (statsLeftWidth + minPadding + visibleWidth(rightSide) > width) {
				// Too wide, fall back
				rightSide = rightSideWithoutProvider;
			}
		}

		const rightSideWidth = visibleWidth(rightSide);
		const totalNeeded = statsLeftWidth + minPadding + rightSideWidth;

		let statsLine: string;
		if (totalNeeded <= width) {
			// Both fit - add padding to right-align model
			const padding = " ".repeat(width - statsLeftWidth - rightSideWidth);
			statsLine = statsLeft + padding + rightSide;
		} else {
			// Need to truncate right side
			const availableForRight = width - statsLeftWidth - minPadding;
			if (availableForRight > 0) {
				const truncatedRight = truncateToWidth(
					rightSide,
					availableForRight,
					"",
				);
				const truncatedRightWidth = visibleWidth(truncatedRight);
				const padding = " ".repeat(
					Math.max(0, width - statsLeftWidth - truncatedRightWidth),
				);
				statsLine = statsLeft + padding + truncatedRight;
			} else {
				// Not enough space for right side at all
				statsLine = statsLeft;
			}
		}

		// Apply dim to each part separately. statsLeft may contain color codes (for context %)
		// that end with a reset, which would clear an outer dim wrapper. So we dim the parts
		// before and after the colored section independently.
		const dimStatsLeft = theme.fg("dim", statsLeft);
		const remainder = statsLine.slice(statsLeft.length); // padding + rightSide
		const dimRemainder = theme.fg("dim", remainder);

		const pwdLine = truncateToWidth(
			theme.fg("dim", pwd),
			width,
			theme.fg("dim", "..."),
		);
		const lines = [pwdLine, dimStatsLeft + dimRemainder];

		// Add extension statuses on a single line, sorted by key alphabetically
		const extensionStatuses = this.footerData.getExtensionStatuses();
		if (extensionStatuses.size > 0) {
			const sortedStatuses = Array.from(extensionStatuses.entries())
				.sort(([a], [b]) => a.localeCompare(b))
				.map(([, text]) => sanitizeStatusText(text));
			const statusLine = sortedStatuses.join(" ");
			// Truncate to terminal width with dim ellipsis for consistency with footer style
			lines.push(truncateToWidth(statusLine, width, theme.fg("dim", "...")));
		}

		return lines;
	}
}
|
||||
|
|
@ -0,0 +1,52 @@
|
|||
// UI components for extensions.
// Barrel module: re-exports every component in this directory (plus the
// keybinding-hint and visual-truncation helpers) so consumers can import
// from the package root instead of individual files.
export { ArminComponent } from "./armin.js";
export { AssistantMessageComponent } from "./assistant-message.js";
export { BashExecutionComponent } from "./bash-execution.js";
export { BorderedLoader } from "./bordered-loader.js";
export { BranchSummaryMessageComponent } from "./branch-summary-message.js";
export { CompactionSummaryMessageComponent } from "./compaction-summary-message.js";
export { CustomEditor } from "./custom-editor.js";
export { CustomMessageComponent } from "./custom-message.js";
export { DaxnutsComponent } from "./daxnuts.js";
export { type RenderDiffOptions, renderDiff } from "./diff.js";
export { DynamicBorder } from "./dynamic-border.js";
export { ExtensionEditorComponent } from "./extension-editor.js";
export { ExtensionInputComponent } from "./extension-input.js";
export { ExtensionSelectorComponent } from "./extension-selector.js";
export { FooterComponent } from "./footer.js";
export {
	appKey,
	appKeyHint,
	editorKey,
	keyHint,
	rawKeyHint,
} from "./keybinding-hints.js";
export { LoginDialogComponent } from "./login-dialog.js";
export { ModelSelectorComponent } from "./model-selector.js";
export { OAuthSelectorComponent } from "./oauth-selector.js";
export {
	type ModelsCallbacks,
	type ModelsConfig,
	ScopedModelsSelectorComponent,
} from "./scoped-models-selector.js";
export { SessionSelectorComponent } from "./session-selector.js";
export {
	type SettingsCallbacks,
	type SettingsConfig,
	SettingsSelectorComponent,
} from "./settings-selector.js";
export { ShowImagesSelectorComponent } from "./show-images-selector.js";
export { SkillInvocationMessageComponent } from "./skill-invocation-message.js";
export { ThemeSelectorComponent } from "./theme-selector.js";
export { ThinkingSelectorComponent } from "./thinking-selector.js";
export {
	ToolExecutionComponent,
	type ToolExecutionOptions,
} from "./tool-execution.js";
export { TreeSelectorComponent } from "./tree-selector.js";
export { UserMessageComponent } from "./user-message.js";
export { UserMessageSelectorComponent } from "./user-message-selector.js";
export {
	truncateToVisualLines,
	type VisualTruncateResult,
} from "./visual-truncate.js";
|
||||
|
|
@ -0,0 +1,85 @@
|
|||
/**
|
||||
* Utilities for formatting keybinding hints in the UI.
|
||||
*/
|
||||
|
||||
import {
|
||||
type EditorAction,
|
||||
getEditorKeybindings,
|
||||
type KeyId,
|
||||
} from "@mariozechner/pi-tui";
|
||||
import type {
|
||||
AppAction,
|
||||
KeybindingsManager,
|
||||
} from "../../../core/keybindings.js";
|
||||
import { theme } from "../theme/theme.js";
|
||||
|
||||
/**
|
||||
* Format keys array as display string (e.g., ["ctrl+c", "escape"] -> "ctrl+c/escape").
|
||||
*/
|
||||
function formatKeys(keys: KeyId[]): string {
|
||||
if (keys.length === 0) return "";
|
||||
if (keys.length === 1) return keys[0]!;
|
||||
return keys.join("/");
|
||||
}
|
||||
|
||||
/**
|
||||
* Get display string for an editor action.
|
||||
*/
|
||||
export function editorKey(action: EditorAction): string {
|
||||
return formatKeys(getEditorKeybindings().getKeys(action));
|
||||
}
|
||||
|
||||
/**
|
||||
* Get display string for an app action.
|
||||
*/
|
||||
export function appKey(
|
||||
keybindings: KeybindingsManager,
|
||||
action: AppAction,
|
||||
): string {
|
||||
return formatKeys(keybindings.getKeys(action));
|
||||
}
|
||||
|
||||
/**
|
||||
* Format a keybinding hint with consistent styling: dim key, muted description.
|
||||
* Looks up the key from editor keybindings automatically.
|
||||
*
|
||||
* @param action - Editor action name (e.g., "selectConfirm", "expandTools")
|
||||
* @param description - Description text (e.g., "to expand", "cancel")
|
||||
* @returns Formatted string with dim key and muted description
|
||||
*/
|
||||
export function keyHint(action: EditorAction, description: string): string {
|
||||
return (
|
||||
theme.fg("dim", editorKey(action)) + theme.fg("muted", ` ${description}`)
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Format a keybinding hint for app-level actions.
|
||||
* Requires the KeybindingsManager instance.
|
||||
*
|
||||
* @param keybindings - KeybindingsManager instance
|
||||
* @param action - App action name (e.g., "interrupt", "externalEditor")
|
||||
* @param description - Description text
|
||||
* @returns Formatted string with dim key and muted description
|
||||
*/
|
||||
export function appKeyHint(
|
||||
keybindings: KeybindingsManager,
|
||||
action: AppAction,
|
||||
description: string,
|
||||
): string {
|
||||
return (
|
||||
theme.fg("dim", appKey(keybindings, action)) +
|
||||
theme.fg("muted", ` ${description}`)
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Format a raw key string with description (for non-configurable keys like ↑↓).
|
||||
*
|
||||
* @param key - Raw key string
|
||||
* @param description - Description text
|
||||
* @returns Formatted string with dim key and muted description
|
||||
*/
|
||||
export function rawKeyHint(key: string, description: string): string {
|
||||
return theme.fg("dim", key) + theme.fg("muted", ` ${description}`);
|
||||
}
|
||||
|
|
@ -0,0 +1,204 @@
|
|||
import { getOAuthProviders } from "@mariozechner/pi-ai/oauth";
|
||||
import {
|
||||
Container,
|
||||
type Focusable,
|
||||
getEditorKeybindings,
|
||||
Input,
|
||||
Spacer,
|
||||
Text,
|
||||
type TUI,
|
||||
} from "@mariozechner/pi-tui";
|
||||
import { exec } from "child_process";
|
||||
import { theme } from "../theme/theme.js";
|
||||
import { DynamicBorder } from "./dynamic-border.js";
|
||||
import { keyHint } from "./keybinding-hints.js";
|
||||
|
||||
/**
|
||||
* Login dialog component - replaces editor during OAuth login flow
|
||||
*/
|
||||
export class LoginDialogComponent extends Container implements Focusable {
|
||||
private contentContainer: Container;
|
||||
private input: Input;
|
||||
private tui: TUI;
|
||||
private abortController = new AbortController();
|
||||
private inputResolver?: (value: string) => void;
|
||||
private inputRejecter?: (error: Error) => void;
|
||||
|
||||
// Focusable implementation - propagate to input for IME cursor positioning
|
||||
private _focused = false;
|
||||
get focused(): boolean {
|
||||
return this._focused;
|
||||
}
|
||||
set focused(value: boolean) {
|
||||
this._focused = value;
|
||||
this.input.focused = value;
|
||||
}
|
||||
|
||||
constructor(
|
||||
tui: TUI,
|
||||
providerId: string,
|
||||
private onComplete: (success: boolean, message?: string) => void,
|
||||
) {
|
||||
super();
|
||||
this.tui = tui;
|
||||
|
||||
const providerInfo = getOAuthProviders().find((p) => p.id === providerId);
|
||||
const providerName = providerInfo?.name || providerId;
|
||||
|
||||
// Top border
|
||||
this.addChild(new DynamicBorder());
|
||||
|
||||
// Title
|
||||
this.addChild(
|
||||
new Text(theme.fg("warning", `Login to ${providerName}`), 1, 0),
|
||||
);
|
||||
|
||||
// Dynamic content area
|
||||
this.contentContainer = new Container();
|
||||
this.addChild(this.contentContainer);
|
||||
|
||||
// Input (always present, used when needed)
|
||||
this.input = new Input();
|
||||
this.input.onSubmit = () => {
|
||||
if (this.inputResolver) {
|
||||
this.inputResolver(this.input.getValue());
|
||||
this.inputResolver = undefined;
|
||||
this.inputRejecter = undefined;
|
||||
}
|
||||
};
|
||||
this.input.onEscape = () => {
|
||||
this.cancel();
|
||||
};
|
||||
|
||||
// Bottom border
|
||||
this.addChild(new DynamicBorder());
|
||||
}
|
||||
|
||||
get signal(): AbortSignal {
|
||||
return this.abortController.signal;
|
||||
}
|
||||
|
||||
private cancel(): void {
|
||||
this.abortController.abort();
|
||||
if (this.inputRejecter) {
|
||||
this.inputRejecter(new Error("Login cancelled"));
|
||||
this.inputResolver = undefined;
|
||||
this.inputRejecter = undefined;
|
||||
}
|
||||
this.onComplete(false, "Login cancelled");
|
||||
}
|
||||
|
||||
/**
|
||||
* Called by onAuth callback - show URL and optional instructions
|
||||
*/
|
||||
showAuth(url: string, instructions?: string): void {
|
||||
this.contentContainer.clear();
|
||||
this.contentContainer.addChild(new Spacer(1));
|
||||
this.contentContainer.addChild(new Text(theme.fg("accent", url), 1, 0));
|
||||
|
||||
const clickHint =
|
||||
process.platform === "darwin"
|
||||
? "Cmd+click to open"
|
||||
: "Ctrl+click to open";
|
||||
const hyperlink = `\x1b]8;;${url}\x07${clickHint}\x1b]8;;\x07`;
|
||||
this.contentContainer.addChild(new Text(theme.fg("dim", hyperlink), 1, 0));
|
||||
|
||||
if (instructions) {
|
||||
this.contentContainer.addChild(new Spacer(1));
|
||||
this.contentContainer.addChild(
|
||||
new Text(theme.fg("warning", instructions), 1, 0),
|
||||
);
|
||||
}
|
||||
|
||||
// Try to open browser
|
||||
const openCmd =
|
||||
process.platform === "darwin"
|
||||
? "open"
|
||||
: process.platform === "win32"
|
||||
? "start"
|
||||
: "xdg-open";
|
||||
exec(`${openCmd} "${url}"`);
|
||||
|
||||
this.tui.requestRender();
|
||||
}
|
||||
|
||||
/**
|
||||
* Show input for manual code/URL entry (for callback server providers)
|
||||
*/
|
||||
showManualInput(prompt: string): Promise<string> {
|
||||
this.contentContainer.addChild(new Spacer(1));
|
||||
this.contentContainer.addChild(new Text(theme.fg("dim", prompt), 1, 0));
|
||||
this.contentContainer.addChild(this.input);
|
||||
this.contentContainer.addChild(
|
||||
new Text(`(${keyHint("selectCancel", "to cancel")})`, 1, 0),
|
||||
);
|
||||
this.tui.requestRender();
|
||||
|
||||
return new Promise((resolve, reject) => {
|
||||
this.inputResolver = resolve;
|
||||
this.inputRejecter = reject;
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Called by onPrompt callback - show prompt and wait for input
|
||||
* Note: Does NOT clear content, appends to existing (preserves URL from showAuth)
|
||||
*/
|
||||
showPrompt(message: string, placeholder?: string): Promise<string> {
|
||||
this.contentContainer.addChild(new Spacer(1));
|
||||
this.contentContainer.addChild(new Text(theme.fg("text", message), 1, 0));
|
||||
if (placeholder) {
|
||||
this.contentContainer.addChild(
|
||||
new Text(theme.fg("dim", `e.g., ${placeholder}`), 1, 0),
|
||||
);
|
||||
}
|
||||
this.contentContainer.addChild(this.input);
|
||||
this.contentContainer.addChild(
|
||||
new Text(
|
||||
`(${keyHint("selectCancel", "to cancel,")} ${keyHint("selectConfirm", "to submit")})`,
|
||||
1,
|
||||
0,
|
||||
),
|
||||
);
|
||||
|
||||
this.input.setValue("");
|
||||
this.tui.requestRender();
|
||||
|
||||
return new Promise((resolve, reject) => {
|
||||
this.inputResolver = resolve;
|
||||
this.inputRejecter = reject;
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Show waiting message (for polling flows like GitHub Copilot)
|
||||
*/
|
||||
showWaiting(message: string): void {
|
||||
this.contentContainer.addChild(new Spacer(1));
|
||||
this.contentContainer.addChild(new Text(theme.fg("dim", message), 1, 0));
|
||||
this.contentContainer.addChild(
|
||||
new Text(`(${keyHint("selectCancel", "to cancel")})`, 1, 0),
|
||||
);
|
||||
this.tui.requestRender();
|
||||
}
|
||||
|
||||
/**
|
||||
* Called by onProgress callback
|
||||
*/
|
||||
showProgress(message: string): void {
|
||||
this.contentContainer.addChild(new Text(theme.fg("dim", message), 1, 0));
|
||||
this.tui.requestRender();
|
||||
}
|
||||
|
||||
handleInput(data: string): void {
|
||||
const kb = getEditorKeybindings();
|
||||
|
||||
if (kb.matches(data, "selectCancel")) {
|
||||
this.cancel();
|
||||
return;
|
||||
}
|
||||
|
||||
// Pass to input
|
||||
this.input.handleInput(data);
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,372 @@
|
|||
import { type Model, modelsAreEqual } from "@mariozechner/pi-ai";
|
||||
import {
|
||||
Container,
|
||||
type Focusable,
|
||||
fuzzyFilter,
|
||||
getEditorKeybindings,
|
||||
Input,
|
||||
Spacer,
|
||||
Text,
|
||||
type TUI,
|
||||
} from "@mariozechner/pi-tui";
|
||||
import type { ModelRegistry } from "../../../core/model-registry.js";
|
||||
import type { SettingsManager } from "../../../core/settings-manager.js";
|
||||
import { theme } from "../theme/theme.js";
|
||||
import { DynamicBorder } from "./dynamic-border.js";
|
||||
import { keyHint } from "./keybinding-hints.js";
|
||||
|
||||
// One selectable row in the model list.
interface ModelItem {
	provider: string;
	id: string;
	model: Model<any>;
}

// A caller-supplied "scoped" model (e.g. enabled for quick cycling).
interface ScopedModelItem {
	model: Model<any>;
	// NOTE(review): carried alongside the model but not read by this component.
	thinkingLevel?: string;
}

// Which list the selector currently shows: every available model, or only
// the caller-supplied scoped subset.
type ModelScope = "all" | "scoped";
|
||||
|
||||
/**
 * Component that renders a model selector with search.
 *
 * Models are loaded asynchronously from the ModelRegistry after construction.
 * When scoped models are supplied, the list starts in "scoped" mode and the
 * tab keybinding toggles between scoped and all models. Confirming a row
 * persists it as the default model via SettingsManager, then fires onSelect.
 */
export class ModelSelectorComponent extends Container implements Focusable {
	private searchInput: Input;

	// Focusable implementation - propagate to searchInput for IME cursor positioning
	private _focused = false;
	get focused(): boolean {
		return this._focused;
	}
	set focused(value: boolean) {
		this._focused = value;
		this.searchInput.focused = value;
	}
	private listContainer: Container;
	// All available models (sorted: current model first, then by provider).
	private allModels: ModelItem[] = [];
	// Sorted ModelItem view of the caller-supplied scoped models.
	private scopedModelItems: ModelItem[] = [];
	// Whichever of allModels/scopedModelItems matches the active scope.
	private activeModels: ModelItem[] = [];
	// activeModels after applying the fuzzy search filter.
	private filteredModels: ModelItem[] = [];
	private selectedIndex: number = 0;
	private currentModel?: Model<any>;
	private settingsManager: SettingsManager;
	private modelRegistry: ModelRegistry;
	private onSelectCallback: (model: Model<any>) => void;
	private onCancelCallback: () => void;
	// Set when the registry reports a load error; rendered in the list area.
	private errorMessage?: string;
	private tui: TUI;
	private scopedModels: ReadonlyArray<ScopedModelItem>;
	private scope: ModelScope = "all";
	private scopeText?: Text;
	private scopeHintText?: Text;

	/**
	 * @param tui TUI used to request re-renders after async model loading
	 * @param currentModel Currently active model (gets a checkmark and sorts first)
	 * @param settingsManager Used to persist the chosen model as the default
	 * @param modelRegistry Source of available models
	 * @param scopedModels Optional subset shown by default when non-empty
	 * @param onSelect Called with the chosen model after it is persisted
	 * @param onCancel Called when the user dismisses the selector
	 * @param initialSearchInput Optional pre-filled search query
	 */
	constructor(
		tui: TUI,
		currentModel: Model<any> | undefined,
		settingsManager: SettingsManager,
		modelRegistry: ModelRegistry,
		scopedModels: ReadonlyArray<ScopedModelItem>,
		onSelect: (model: Model<any>) => void,
		onCancel: () => void,
		initialSearchInput?: string,
	) {
		super();

		this.tui = tui;
		this.currentModel = currentModel;
		this.settingsManager = settingsManager;
		this.modelRegistry = modelRegistry;
		this.scopedModels = scopedModels;
		// Default to the scoped view only when a scoped subset exists.
		this.scope = scopedModels.length > 0 ? "scoped" : "all";
		this.onSelectCallback = onSelect;
		this.onCancelCallback = onCancel;

		// Add top border
		this.addChild(new DynamicBorder());
		this.addChild(new Spacer(1));

		// Add hint about model filtering
		if (scopedModels.length > 0) {
			this.scopeText = new Text(this.getScopeText(), 0, 0);
			this.addChild(this.scopeText);
			this.scopeHintText = new Text(this.getScopeHintText(), 0, 0);
			this.addChild(this.scopeHintText);
		} else {
			const hintText =
				"Only showing models with configured API keys (see README for details)";
			this.addChild(new Text(theme.fg("warning", hintText), 0, 0));
		}
		this.addChild(new Spacer(1));

		// Create search input
		this.searchInput = new Input();
		if (initialSearchInput) {
			this.searchInput.setValue(initialSearchInput);
		}
		this.searchInput.onSubmit = () => {
			// Enter on search input selects the first filtered item
			if (this.filteredModels[this.selectedIndex]) {
				this.handleSelect(this.filteredModels[this.selectedIndex].model);
			}
		};
		this.addChild(this.searchInput);

		this.addChild(new Spacer(1));

		// Create list container
		this.listContainer = new Container();
		this.addChild(this.listContainer);

		this.addChild(new Spacer(1));

		// Add bottom border
		this.addChild(new DynamicBorder());

		// Load models and do initial render
		this.loadModels().then(() => {
			if (initialSearchInput) {
				this.filterModels(initialSearchInput);
			} else {
				this.updateList();
			}
			// Request re-render after models are loaded
			this.tui.requestRender();
		});
	}

	// Refresh the registry and rebuild allModels/scopedModelItems/activeModels/
	// filteredModels. On failure, empties every list and records the error.
	private async loadModels(): Promise<void> {
		let models: ModelItem[];

		// Refresh to pick up any changes to models.json
		this.modelRegistry.refresh();

		// Check for models.json errors
		const loadError = this.modelRegistry.getError();
		if (loadError) {
			this.errorMessage = loadError;
		}

		// Load available models (built-in models still work even if models.json failed)
		try {
			const availableModels = await this.modelRegistry.getAvailable();
			models = availableModels.map((model: Model<any>) => ({
				provider: model.provider,
				id: model.id,
				model,
			}));
		} catch (error) {
			this.allModels = [];
			this.scopedModelItems = [];
			this.activeModels = [];
			this.filteredModels = [];
			this.errorMessage =
				error instanceof Error ? error.message : String(error);
			return;
		}

		this.allModels = this.sortModels(models);
		this.scopedModelItems = this.sortModels(
			this.scopedModels.map((scoped) => ({
				provider: scoped.model.provider,
				id: scoped.model.id,
				model: scoped.model,
			})),
		);
		this.activeModels =
			this.scope === "scoped" ? this.scopedModelItems : this.allModels;
		this.filteredModels = this.activeModels;
		// Keep the selection in range after the list changed size.
		this.selectedIndex = Math.min(
			this.selectedIndex,
			Math.max(0, this.filteredModels.length - 1),
		);
	}

	// Return a sorted copy (input is not mutated).
	private sortModels(models: ModelItem[]): ModelItem[] {
		const sorted = [...models];
		// Sort: current model first, then by provider
		sorted.sort((a, b) => {
			const aIsCurrent = modelsAreEqual(this.currentModel, a.model);
			const bIsCurrent = modelsAreEqual(this.currentModel, b.model);
			if (aIsCurrent && !bIsCurrent) return -1;
			if (!aIsCurrent && bIsCurrent) return 1;
			return a.provider.localeCompare(b.provider);
		});
		return sorted;
	}

	// "Scope: all | scoped" line with the active scope highlighted.
	private getScopeText(): string {
		const allText =
			this.scope === "all"
				? theme.fg("accent", "all")
				: theme.fg("muted", "all");
		const scopedText =
			this.scope === "scoped"
				? theme.fg("accent", "scoped")
				: theme.fg("muted", "scoped");
		return `${theme.fg("muted", "Scope: ")}${allText}${theme.fg("muted", " | ")}${scopedText}`;
	}

	// Hint line explaining the tab keybinding toggles the scope.
	private getScopeHintText(): string {
		return keyHint("tab", "scope") + theme.fg("muted", " (all/scoped)");
	}

	// Switch between "all" and "scoped" and re-run the current search.
	private setScope(scope: ModelScope): void {
		if (this.scope === scope) return;
		this.scope = scope;
		this.activeModels =
			this.scope === "scoped" ? this.scopedModelItems : this.allModels;
		this.selectedIndex = 0;
		this.filterModels(this.searchInput.getValue());
		if (this.scopeText) {
			this.scopeText.setText(this.getScopeText());
		}
	}

	// Apply the fuzzy filter (empty query = no filtering), clamp the
	// selection, and redraw the list.
	private filterModels(query: string): void {
		this.filteredModels = query
			? fuzzyFilter(
					this.activeModels,
					query,
					({ id, provider }) => `${id} ${provider}`,
				)
			: this.activeModels;
		this.selectedIndex = Math.min(
			this.selectedIndex,
			Math.max(0, this.filteredModels.length - 1),
		);
		this.updateList();
	}

	// Redraw the visible window of the list (up to maxVisible rows centered
	// on the selection), plus scroll indicator, error text, or detail line.
	private updateList(): void {
		this.listContainer.clear();

		const maxVisible = 10;
		// Center the selection inside the window, clamped to list bounds.
		const startIndex = Math.max(
			0,
			Math.min(
				this.selectedIndex - Math.floor(maxVisible / 2),
				this.filteredModels.length - maxVisible,
			),
		);
		const endIndex = Math.min(
			startIndex + maxVisible,
			this.filteredModels.length,
		);

		// Show visible slice of filtered models
		for (let i = startIndex; i < endIndex; i++) {
			const item = this.filteredModels[i];
			if (!item) continue;

			const isSelected = i === this.selectedIndex;
			const isCurrent = modelsAreEqual(this.currentModel, item.model);

			let line = "";
			if (isSelected) {
				const prefix = theme.fg("accent", "→ ");
				const modelText = `${item.id}`;
				const providerBadge = theme.fg("muted", `[${item.provider}]`);
				const checkmark = isCurrent ? theme.fg("success", " ✓") : "";
				line = `${prefix + theme.fg("accent", modelText)} ${providerBadge}${checkmark}`;
			} else {
				const modelText = ` ${item.id}`;
				const providerBadge = theme.fg("muted", `[${item.provider}]`);
				const checkmark = isCurrent ? theme.fg("success", " ✓") : "";
				line = `${modelText} ${providerBadge}${checkmark}`;
			}

			this.listContainer.addChild(new Text(line, 0, 0));
		}

		// Add scroll indicator if needed
		if (startIndex > 0 || endIndex < this.filteredModels.length) {
			const scrollInfo = theme.fg(
				"muted",
				` (${this.selectedIndex + 1}/${this.filteredModels.length})`,
			);
			this.listContainer.addChild(new Text(scrollInfo, 0, 0));
		}

		// Show error message or "no results" if empty
		if (this.errorMessage) {
			// Show error in red
			const errorLines = this.errorMessage.split("\n");
			for (const line of errorLines) {
				this.listContainer.addChild(new Text(theme.fg("error", line), 0, 0));
			}
		} else if (this.filteredModels.length === 0) {
			this.listContainer.addChild(
				new Text(theme.fg("muted", " No matching models"), 0, 0),
			);
		} else {
			// Detail line: full display name of the highlighted model.
			const selected = this.filteredModels[this.selectedIndex];
			this.listContainer.addChild(new Spacer(1));
			this.listContainer.addChild(
				new Text(
					theme.fg("muted", ` Model Name: ${selected.model.name}`),
					0,
					0,
				),
			);
		}
	}

	// Keyboard handling: tab toggles scope (when scoped models exist),
	// up/down wrap around the list, confirm selects, cancel dismisses, and
	// anything else is forwarded to the search input.
	handleInput(keyData: string): void {
		const kb = getEditorKeybindings();
		if (kb.matches(keyData, "tab")) {
			if (this.scopedModelItems.length > 0) {
				const nextScope: ModelScope = this.scope === "all" ? "scoped" : "all";
				this.setScope(nextScope);
				if (this.scopeHintText) {
					this.scopeHintText.setText(this.getScopeHintText());
				}
			}
			return;
		}
		// Up arrow - wrap to bottom when at top
		if (kb.matches(keyData, "selectUp")) {
			if (this.filteredModels.length === 0) return;
			this.selectedIndex =
				this.selectedIndex === 0
					? this.filteredModels.length - 1
					: this.selectedIndex - 1;
			this.updateList();
		}
		// Down arrow - wrap to top when at bottom
		else if (kb.matches(keyData, "selectDown")) {
			if (this.filteredModels.length === 0) return;
			this.selectedIndex =
				this.selectedIndex === this.filteredModels.length - 1
					? 0
					: this.selectedIndex + 1;
			this.updateList();
		}
		// Enter
		else if (kb.matches(keyData, "selectConfirm")) {
			const selectedModel = this.filteredModels[this.selectedIndex];
			if (selectedModel) {
				this.handleSelect(selectedModel.model);
			}
		}
		// Escape or Ctrl+C
		else if (kb.matches(keyData, "selectCancel")) {
			this.onCancelCallback();
		}
		// Pass everything else to search input
		else {
			this.searchInput.handleInput(keyData);
			this.filterModels(this.searchInput.getValue());
		}
	}

	// Persist the choice as the new default, then notify the caller.
	private handleSelect(model: Model<any>): void {
		// Save as new default
		this.settingsManager.setDefaultModelAndProvider(model.provider, model.id);
		this.onSelectCallback(model);
	}

	// Expose the search input (e.g. so the host can focus it).
	getSearchInput(): Input {
		return this.searchInput;
	}
}
|
||||
|
|
@ -0,0 +1,138 @@
|
|||
import type { OAuthProviderInterface } from "@mariozechner/pi-ai";
|
||||
import { getOAuthProviders } from "@mariozechner/pi-ai/oauth";
|
||||
import {
|
||||
Container,
|
||||
getEditorKeybindings,
|
||||
Spacer,
|
||||
TruncatedText,
|
||||
} from "@mariozechner/pi-tui";
|
||||
import type { AuthStorage } from "../../../core/auth-storage.js";
|
||||
import { theme } from "../theme/theme.js";
|
||||
import { DynamicBorder } from "./dynamic-border.js";
|
||||
|
||||
/**
 * Component that renders an OAuth provider selector.
 *
 * Lists every registered OAuth provider (in both login and logout mode) and
 * appends a "logged in" marker for providers with stored OAuth credentials.
 * Navigation is keyboard driven via the shared editor keybindings.
 */
export class OAuthSelectorComponent extends Container {
	private listContainer: Container;
	private allProviders: OAuthProviderInterface[] = [];
	// Index of the highlighted row in allProviders.
	private selectedIndex: number = 0;
	// Mode only affects the title and the empty-list message.
	private mode: "login" | "logout";
	private authStorage: AuthStorage;
	private onSelectCallback: (providerId: string) => void;
	private onCancelCallback: () => void;

	/**
	 * @param mode Whether the list is shown for login or logout (labels only)
	 * @param authStorage Credential store used to mark logged-in providers
	 * @param onSelect Called with the chosen provider's id on confirm
	 * @param onCancel Called when the user dismisses the selector
	 */
	constructor(
		mode: "login" | "logout",
		authStorage: AuthStorage,
		onSelect: (providerId: string) => void,
		onCancel: () => void,
	) {
		super();

		this.mode = mode;
		this.authStorage = authStorage;
		this.onSelectCallback = onSelect;
		this.onCancelCallback = onCancel;

		// Load all OAuth providers
		this.loadProviders();

		// Add top border
		this.addChild(new DynamicBorder());
		this.addChild(new Spacer(1));

		// Add title
		const title =
			mode === "login"
				? "Select provider to login:"
				: "Select provider to logout:";
		this.addChild(new TruncatedText(theme.bold(title)));
		this.addChild(new Spacer(1));

		// Create list container
		this.listContainer = new Container();
		this.addChild(this.listContainer);

		this.addChild(new Spacer(1));

		// Add bottom border
		this.addChild(new DynamicBorder());

		// Initial render
		this.updateList();
	}

	// Snapshot the registered providers (same list for login and logout).
	private loadProviders(): void {
		this.allProviders = getOAuthProviders();
	}

	// Rebuild the provider list: accent the selected row and append a
	// success-colored marker for providers with stored OAuth credentials.
	private updateList(): void {
		this.listContainer.clear();

		for (let i = 0; i < this.allProviders.length; i++) {
			const provider = this.allProviders[i];
			if (!provider) continue;

			const isSelected = i === this.selectedIndex;

			// Check if user is logged in for this provider
			const credentials = this.authStorage.get(provider.id);
			const isLoggedIn = credentials?.type === "oauth";
			const statusIndicator = isLoggedIn
				? theme.fg("success", " ✓ logged in")
				: "";

			let line = "";
			if (isSelected) {
				const prefix = theme.fg("accent", "→ ");
				const text = theme.fg("accent", provider.name);
				line = prefix + text + statusIndicator;
			} else {
				const text = ` ${provider.name}`;
				line = text + statusIndicator;
			}

			this.listContainer.addChild(new TruncatedText(line, 0, 0));
		}

		// Show "no providers" if empty
		if (this.allProviders.length === 0) {
			const message =
				this.mode === "login"
					? "No OAuth providers available"
					: "No OAuth providers logged in. Use /login first.";
			this.listContainer.addChild(
				new TruncatedText(theme.fg("muted", ` ${message}`), 0, 0),
			);
		}
	}

	// Keyboard handling: up/down clamp at the list edges (no wrap-around),
	// confirm fires onSelect with the provider id, cancel fires onCancel.
	handleInput(keyData: string): void {
		const kb = getEditorKeybindings();
		// Up arrow
		if (kb.matches(keyData, "selectUp")) {
			this.selectedIndex = Math.max(0, this.selectedIndex - 1);
			this.updateList();
		}
		// Down arrow
		else if (kb.matches(keyData, "selectDown")) {
			this.selectedIndex = Math.min(
				this.allProviders.length - 1,
				this.selectedIndex + 1,
			);
			this.updateList();
		}
		// Enter
		else if (kb.matches(keyData, "selectConfirm")) {
			const selectedProvider = this.allProviders[this.selectedIndex];
			if (selectedProvider) {
				this.onSelectCallback(selectedProvider.id);
			}
		}
		// Escape or Ctrl+C
		else if (kb.matches(keyData, "selectCancel")) {
			this.onCancelCallback();
		}
	}
}
|
||||
|
|
@ -0,0 +1,444 @@
|
|||
import type { Model } from "@mariozechner/pi-ai";
|
||||
import {
|
||||
Container,
|
||||
type Focusable,
|
||||
fuzzyFilter,
|
||||
getEditorKeybindings,
|
||||
Input,
|
||||
Key,
|
||||
matchesKey,
|
||||
Spacer,
|
||||
Text,
|
||||
} from "@mariozechner/pi-tui";
|
||||
import { theme } from "../theme/theme.js";
|
||||
import { DynamicBorder } from "./dynamic-border.js";
|
||||
|
||||
// EnabledIds: null = all enabled (no filter), string[] = explicit ordered list
|
||||
type EnabledIds = string[] | null;
|
||||
|
||||
function isEnabled(enabledIds: EnabledIds, id: string): boolean {
|
||||
return enabledIds === null || enabledIds.includes(id);
|
||||
}
|
||||
|
||||
function toggle(enabledIds: EnabledIds, id: string): EnabledIds {
|
||||
if (enabledIds === null) return [id]; // First toggle: start with only this one
|
||||
const index = enabledIds.indexOf(id);
|
||||
if (index >= 0)
|
||||
return [...enabledIds.slice(0, index), ...enabledIds.slice(index + 1)];
|
||||
return [...enabledIds, id];
|
||||
}
|
||||
|
||||
function enableAll(
|
||||
enabledIds: EnabledIds,
|
||||
allIds: string[],
|
||||
targetIds?: string[],
|
||||
): EnabledIds {
|
||||
if (enabledIds === null) return null; // Already all enabled
|
||||
const targets = targetIds ?? allIds;
|
||||
const result = [...enabledIds];
|
||||
for (const id of targets) {
|
||||
if (!result.includes(id)) result.push(id);
|
||||
}
|
||||
return result.length === allIds.length ? null : result;
|
||||
}
|
||||
|
||||
function clearAll(
|
||||
enabledIds: EnabledIds,
|
||||
allIds: string[],
|
||||
targetIds?: string[],
|
||||
): EnabledIds {
|
||||
if (enabledIds === null) {
|
||||
return targetIds ? allIds.filter((id) => !targetIds.includes(id)) : [];
|
||||
}
|
||||
const targets = new Set(targetIds ?? enabledIds);
|
||||
return enabledIds.filter((id) => !targets.has(id));
|
||||
}
|
||||
|
||||
function move(
|
||||
enabledIds: EnabledIds,
|
||||
allIds: string[],
|
||||
id: string,
|
||||
delta: number,
|
||||
): EnabledIds {
|
||||
const list = enabledIds ?? [...allIds];
|
||||
const index = list.indexOf(id);
|
||||
if (index < 0) return list;
|
||||
const newIndex = index + delta;
|
||||
if (newIndex < 0 || newIndex >= list.length) return list;
|
||||
const result = [...list];
|
||||
[result[index], result[newIndex]] = [result[newIndex], result[index]];
|
||||
return result;
|
||||
}
|
||||
|
||||
function getSortedIds(enabledIds: EnabledIds, allIds: string[]): string[] {
|
||||
if (enabledIds === null) return allIds;
|
||||
const enabledSet = new Set(enabledIds);
|
||||
return [...enabledIds, ...allIds.filter((id) => !enabledSet.has(id))];
|
||||
}
|
||||
|
||||
// One row in the scoped-models selector list.
interface ModelItem {
	// "provider/modelId" composite key used throughout this component.
	fullId: string;
	model: Model<any>;
	// Enabled state under the current (possibly session-only) filter.
	enabled: boolean;
}

// Initial state handed to ScopedModelsSelectorComponent.
export interface ModelsConfig {
	// Every model known to the registry, regardless of enabled state.
	allModels: Model<any>[];
	// Ids ("provider/modelId") currently enabled for cycling.
	enabledModelIds: Set<string>;
	/** true if enabledModels setting is defined (empty = all enabled) */
	hasEnabledModelsFilter: boolean;
}

// Callbacks fired by ScopedModelsSelectorComponent in response to user actions.
export interface ModelsCallbacks {
	/** Called when a model is toggled (session-only, no persist) */
	onModelToggle: (modelId: string, enabled: boolean) => void;
	/** Called when user wants to persist current selection to settings */
	onPersist: (enabledModelIds: string[]) => void;
	/** Called when user enables all models. Returns list of all model IDs. */
	onEnableAll: (allModelIds: string[]) => void;
	/** Called when user clears all models */
	onClearAll: () => void;
	/** Called when user toggles all models for a provider. Returns affected model IDs. */
	onToggleProvider: (
		provider: string,
		modelIds: string[],
		enabled: boolean,
	) => void;
	onCancel: () => void;
}
|
||||
|
||||
/**
|
||||
* Component for enabling/disabling models for Ctrl+P cycling.
|
||||
* Changes are session-only until explicitly persisted with Ctrl+S.
|
||||
*/
|
||||
export class ScopedModelsSelectorComponent
|
||||
extends Container
|
||||
implements Focusable
|
||||
{
|
||||
private modelsById: Map<string, Model<any>> = new Map();
|
||||
private allIds: string[] = [];
|
||||
private enabledIds: EnabledIds = null;
|
||||
private filteredItems: ModelItem[] = [];
|
||||
private selectedIndex = 0;
|
||||
private searchInput: Input;
|
||||
|
||||
// Focusable implementation - propagate to searchInput for IME cursor positioning
|
||||
private _focused = false;
|
||||
get focused(): boolean {
|
||||
return this._focused;
|
||||
}
|
||||
set focused(value: boolean) {
|
||||
this._focused = value;
|
||||
this.searchInput.focused = value;
|
||||
}
|
||||
private listContainer: Container;
|
||||
private footerText: Text;
|
||||
private callbacks: ModelsCallbacks;
|
||||
private maxVisible = 15;
|
||||
private isDirty = false;
|
||||
|
||||
constructor(config: ModelsConfig, callbacks: ModelsCallbacks) {
|
||||
super();
|
||||
this.callbacks = callbacks;
|
||||
|
||||
for (const model of config.allModels) {
|
||||
const fullId = `${model.provider}/${model.id}`;
|
||||
this.modelsById.set(fullId, model);
|
||||
this.allIds.push(fullId);
|
||||
}
|
||||
|
||||
this.enabledIds = config.hasEnabledModelsFilter
|
||||
? [...config.enabledModelIds]
|
||||
: null;
|
||||
this.filteredItems = this.buildItems();
|
||||
|
||||
// Header
|
||||
this.addChild(new DynamicBorder());
|
||||
this.addChild(new Spacer(1));
|
||||
this.addChild(
|
||||
new Text(theme.fg("accent", theme.bold("Model Configuration")), 0, 0),
|
||||
);
|
||||
this.addChild(
|
||||
new Text(
|
||||
theme.fg("muted", "Session-only. Ctrl+S to save to settings."),
|
||||
0,
|
||||
0,
|
||||
),
|
||||
);
|
||||
this.addChild(new Spacer(1));
|
||||
|
||||
// Search input
|
||||
this.searchInput = new Input();
|
||||
this.addChild(this.searchInput);
|
||||
this.addChild(new Spacer(1));
|
||||
|
||||
// List container
|
||||
this.listContainer = new Container();
|
||||
this.addChild(this.listContainer);
|
||||
|
||||
// Footer hint
|
||||
this.addChild(new Spacer(1));
|
||||
this.footerText = new Text(this.getFooterText(), 0, 0);
|
||||
this.addChild(this.footerText);
|
||||
|
||||
this.addChild(new DynamicBorder());
|
||||
this.updateList();
|
||||
}
|
||||
|
||||
private buildItems(): ModelItem[] {
|
||||
// Filter out IDs that no longer have a corresponding model (e.g., after logout)
|
||||
return getSortedIds(this.enabledIds, this.allIds)
|
||||
.filter((id) => this.modelsById.has(id))
|
||||
.map((id) => ({
|
||||
fullId: id,
|
||||
model: this.modelsById.get(id)!,
|
||||
enabled: isEnabled(this.enabledIds, id),
|
||||
}));
|
||||
}
|
||||
|
||||
private getFooterText(): string {
|
||||
const enabledCount = this.enabledIds?.length ?? this.allIds.length;
|
||||
const allEnabled = this.enabledIds === null;
|
||||
const countText = allEnabled
|
||||
? "all enabled"
|
||||
: `${enabledCount}/${this.allIds.length} enabled`;
|
||||
const parts = [
|
||||
"Enter toggle",
|
||||
"^A all",
|
||||
"^X clear",
|
||||
"^P provider",
|
||||
"Alt+↑↓ reorder",
|
||||
"^S save",
|
||||
countText,
|
||||
];
|
||||
return this.isDirty
|
||||
? theme.fg("dim", ` ${parts.join(" · ")} `) +
|
||||
theme.fg("warning", "(unsaved)")
|
||||
: theme.fg("dim", ` ${parts.join(" · ")}`);
|
||||
}
|
||||
|
||||
private refresh(): void {
|
||||
const query = this.searchInput.getValue();
|
||||
const items = this.buildItems();
|
||||
this.filteredItems = query
|
||||
? fuzzyFilter(items, query, (i) => `${i.model.id} ${i.model.provider}`)
|
||||
: items;
|
||||
this.selectedIndex = Math.min(
|
||||
this.selectedIndex,
|
||||
Math.max(0, this.filteredItems.length - 1),
|
||||
);
|
||||
this.updateList();
|
||||
this.footerText.setText(this.getFooterText());
|
||||
}
|
||||
|
||||
private updateList(): void {
|
||||
this.listContainer.clear();
|
||||
|
||||
if (this.filteredItems.length === 0) {
|
||||
this.listContainer.addChild(
|
||||
new Text(theme.fg("muted", " No matching models"), 0, 0),
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const startIndex = Math.max(
|
||||
0,
|
||||
Math.min(
|
||||
this.selectedIndex - Math.floor(this.maxVisible / 2),
|
||||
this.filteredItems.length - this.maxVisible,
|
||||
),
|
||||
);
|
||||
const endIndex = Math.min(
|
||||
startIndex + this.maxVisible,
|
||||
this.filteredItems.length,
|
||||
);
|
||||
const allEnabled = this.enabledIds === null;
|
||||
|
||||
for (let i = startIndex; i < endIndex; i++) {
|
||||
const item = this.filteredItems[i]!;
|
||||
const isSelected = i === this.selectedIndex;
|
||||
const prefix = isSelected ? theme.fg("accent", "→ ") : " ";
|
||||
const modelText = isSelected
|
||||
? theme.fg("accent", item.model.id)
|
||||
: item.model.id;
|
||||
const providerBadge = theme.fg("muted", ` [${item.model.provider}]`);
|
||||
const status = allEnabled
|
||||
? ""
|
||||
: item.enabled
|
||||
? theme.fg("success", " ✓")
|
||||
: theme.fg("dim", " ✗");
|
||||
this.listContainer.addChild(
|
||||
new Text(`${prefix}${modelText}${providerBadge}${status}`, 0, 0),
|
||||
);
|
||||
}
|
||||
|
||||
// Add scroll indicator if needed
|
||||
if (startIndex > 0 || endIndex < this.filteredItems.length) {
|
||||
this.listContainer.addChild(
|
||||
new Text(
|
||||
theme.fg(
|
||||
"muted",
|
||||
` (${this.selectedIndex + 1}/${this.filteredItems.length})`,
|
||||
),
|
||||
0,
|
||||
0,
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
if (this.filteredItems.length > 0) {
|
||||
const selected = this.filteredItems[this.selectedIndex];
|
||||
this.listContainer.addChild(new Spacer(1));
|
||||
this.listContainer.addChild(
|
||||
new Text(
|
||||
theme.fg("muted", ` Model Name: ${selected.model.name}`),
|
||||
0,
|
||||
0,
|
||||
),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
handleInput(data: string): void {
|
||||
const kb = getEditorKeybindings();
|
||||
|
||||
// Navigation
|
||||
if (kb.matches(data, "selectUp")) {
|
||||
if (this.filteredItems.length === 0) return;
|
||||
this.selectedIndex =
|
||||
this.selectedIndex === 0
|
||||
? this.filteredItems.length - 1
|
||||
: this.selectedIndex - 1;
|
||||
this.updateList();
|
||||
return;
|
||||
}
|
||||
if (kb.matches(data, "selectDown")) {
|
||||
if (this.filteredItems.length === 0) return;
|
||||
this.selectedIndex =
|
||||
this.selectedIndex === this.filteredItems.length - 1
|
||||
? 0
|
||||
: this.selectedIndex + 1;
|
||||
this.updateList();
|
||||
return;
|
||||
}
|
||||
|
||||
// Alt+Up/Down - Reorder enabled models
|
||||
if (matchesKey(data, Key.alt("up")) || matchesKey(data, Key.alt("down"))) {
|
||||
const item = this.filteredItems[this.selectedIndex];
|
||||
if (item && isEnabled(this.enabledIds, item.fullId)) {
|
||||
const delta = matchesKey(data, Key.alt("up")) ? -1 : 1;
|
||||
const enabledList = this.enabledIds ?? this.allIds;
|
||||
const currentIndex = enabledList.indexOf(item.fullId);
|
||||
const newIndex = currentIndex + delta;
|
||||
// Only move if within bounds
|
||||
if (newIndex >= 0 && newIndex < enabledList.length) {
|
||||
this.enabledIds = move(
|
||||
this.enabledIds,
|
||||
this.allIds,
|
||||
item.fullId,
|
||||
delta,
|
||||
);
|
||||
this.isDirty = true;
|
||||
this.selectedIndex += delta;
|
||||
this.refresh();
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Toggle on Enter
|
||||
if (matchesKey(data, Key.enter)) {
|
||||
const item = this.filteredItems[this.selectedIndex];
|
||||
if (item) {
|
||||
const wasAllEnabled = this.enabledIds === null;
|
||||
this.enabledIds = toggle(this.enabledIds, item.fullId);
|
||||
this.isDirty = true;
|
||||
if (wasAllEnabled) this.callbacks.onClearAll();
|
||||
this.callbacks.onModelToggle(
|
||||
item.fullId,
|
||||
isEnabled(this.enabledIds, item.fullId),
|
||||
);
|
||||
this.refresh();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Ctrl+A - Enable all (filtered if search active, otherwise all)
|
||||
if (matchesKey(data, Key.ctrl("a"))) {
|
||||
const targetIds = this.searchInput.getValue()
|
||||
? this.filteredItems.map((i) => i.fullId)
|
||||
: undefined;
|
||||
this.enabledIds = enableAll(this.enabledIds, this.allIds, targetIds);
|
||||
this.isDirty = true;
|
||||
this.callbacks.onEnableAll(targetIds ?? this.allIds);
|
||||
this.refresh();
|
||||
return;
|
||||
}
|
||||
|
||||
// Ctrl+X - Clear all (filtered if search active, otherwise all)
|
||||
if (matchesKey(data, Key.ctrl("x"))) {
|
||||
const targetIds = this.searchInput.getValue()
|
||||
? this.filteredItems.map((i) => i.fullId)
|
||||
: undefined;
|
||||
this.enabledIds = clearAll(this.enabledIds, this.allIds, targetIds);
|
||||
this.isDirty = true;
|
||||
this.callbacks.onClearAll();
|
||||
this.refresh();
|
||||
return;
|
||||
}
|
||||
|
||||
// Ctrl+P - Toggle provider of current item
|
||||
if (matchesKey(data, Key.ctrl("p"))) {
|
||||
const item = this.filteredItems[this.selectedIndex];
|
||||
if (item) {
|
||||
const provider = item.model.provider;
|
||||
const providerIds = this.allIds.filter(
|
||||
(id) => this.modelsById.get(id)!.provider === provider,
|
||||
);
|
||||
const allEnabled = providerIds.every((id) =>
|
||||
isEnabled(this.enabledIds, id),
|
||||
);
|
||||
this.enabledIds = allEnabled
|
||||
? clearAll(this.enabledIds, this.allIds, providerIds)
|
||||
: enableAll(this.enabledIds, this.allIds, providerIds);
|
||||
this.isDirty = true;
|
||||
this.callbacks.onToggleProvider(provider, providerIds, !allEnabled);
|
||||
this.refresh();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Ctrl+S - Save/persist to settings
|
||||
if (matchesKey(data, Key.ctrl("s"))) {
|
||||
this.callbacks.onPersist(this.enabledIds ?? [...this.allIds]);
|
||||
this.isDirty = false;
|
||||
this.footerText.setText(this.getFooterText());
|
||||
return;
|
||||
}
|
||||
|
||||
// Ctrl+C - clear search or cancel if empty
|
||||
if (matchesKey(data, Key.ctrl("c"))) {
|
||||
if (this.searchInput.getValue()) {
|
||||
this.searchInput.setValue("");
|
||||
this.refresh();
|
||||
} else {
|
||||
this.callbacks.onCancel();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Escape - cancel
|
||||
if (matchesKey(data, Key.escape)) {
|
||||
this.callbacks.onCancel();
|
||||
return;
|
||||
}
|
||||
|
||||
// Pass everything else to search input
|
||||
this.searchInput.handleInput(data);
|
||||
this.refresh();
|
||||
}
|
||||
|
||||
getSearchInput(): Input {
|
||||
return this.searchInput;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,199 @@
|
|||
import { fuzzyMatch } from "@mariozechner/pi-tui";
|
||||
import type { SessionInfo } from "../../../core/session-manager.js";
|
||||
|
||||
export type SortMode = "threaded" | "recent" | "relevance";
|
||||
|
||||
export type NameFilter = "all" | "named";
|
||||
|
||||
export interface ParsedSearchQuery {
|
||||
mode: "tokens" | "regex";
|
||||
tokens: { kind: "fuzzy" | "phrase"; value: string }[];
|
||||
regex: RegExp | null;
|
||||
/** If set, parsing failed and we should treat query as non-matching. */
|
||||
error?: string;
|
||||
}
|
||||
|
||||
export interface MatchResult {
|
||||
matches: boolean;
|
||||
/** Lower is better; only meaningful when matches === true */
|
||||
score: number;
|
||||
}
|
||||
|
||||
function normalizeWhitespaceLower(text: string): string {
|
||||
return text.toLowerCase().replace(/\s+/g, " ").trim();
|
||||
}
|
||||
|
||||
function getSessionSearchText(session: SessionInfo): string {
|
||||
return `${session.id} ${session.name ?? ""} ${session.allMessagesText} ${session.cwd}`;
|
||||
}
|
||||
|
||||
export function hasSessionName(session: SessionInfo): boolean {
|
||||
return Boolean(session.name?.trim());
|
||||
}
|
||||
|
||||
function matchesNameFilter(session: SessionInfo, filter: NameFilter): boolean {
|
||||
if (filter === "all") return true;
|
||||
return hasSessionName(session);
|
||||
}
|
||||
|
||||
export function parseSearchQuery(query: string): ParsedSearchQuery {
|
||||
const trimmed = query.trim();
|
||||
if (!trimmed) {
|
||||
return { mode: "tokens", tokens: [], regex: null };
|
||||
}
|
||||
|
||||
// Regex mode: re:<pattern>
|
||||
if (trimmed.startsWith("re:")) {
|
||||
const pattern = trimmed.slice(3).trim();
|
||||
if (!pattern) {
|
||||
return { mode: "regex", tokens: [], regex: null, error: "Empty regex" };
|
||||
}
|
||||
try {
|
||||
return { mode: "regex", tokens: [], regex: new RegExp(pattern, "i") };
|
||||
} catch (err) {
|
||||
const msg = err instanceof Error ? err.message : String(err);
|
||||
return { mode: "regex", tokens: [], regex: null, error: msg };
|
||||
}
|
||||
}
|
||||
|
||||
// Token mode with quote support.
|
||||
// Example: foo "node cve" bar
|
||||
const tokens: { kind: "fuzzy" | "phrase"; value: string }[] = [];
|
||||
let buf = "";
|
||||
let inQuote = false;
|
||||
let hadUnclosedQuote = false;
|
||||
|
||||
const flush = (kind: "fuzzy" | "phrase"): void => {
|
||||
const v = buf.trim();
|
||||
buf = "";
|
||||
if (!v) return;
|
||||
tokens.push({ kind, value: v });
|
||||
};
|
||||
|
||||
for (let i = 0; i < trimmed.length; i++) {
|
||||
const ch = trimmed[i]!;
|
||||
if (ch === '"') {
|
||||
if (inQuote) {
|
||||
flush("phrase");
|
||||
inQuote = false;
|
||||
} else {
|
||||
flush("fuzzy");
|
||||
inQuote = true;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!inQuote && /\s/.test(ch)) {
|
||||
flush("fuzzy");
|
||||
continue;
|
||||
}
|
||||
|
||||
buf += ch;
|
||||
}
|
||||
|
||||
if (inQuote) {
|
||||
hadUnclosedQuote = true;
|
||||
}
|
||||
|
||||
// If quotes were unbalanced, fall back to plain whitespace tokenization.
|
||||
if (hadUnclosedQuote) {
|
||||
return {
|
||||
mode: "tokens",
|
||||
tokens: trimmed
|
||||
.split(/\s+/)
|
||||
.map((t) => t.trim())
|
||||
.filter((t) => t.length > 0)
|
||||
.map((t) => ({ kind: "fuzzy" as const, value: t })),
|
||||
regex: null,
|
||||
};
|
||||
}
|
||||
|
||||
flush(inQuote ? "phrase" : "fuzzy");
|
||||
|
||||
return { mode: "tokens", tokens, regex: null };
|
||||
}
|
||||
|
||||
export function matchSession(
|
||||
session: SessionInfo,
|
||||
parsed: ParsedSearchQuery,
|
||||
): MatchResult {
|
||||
const text = getSessionSearchText(session);
|
||||
|
||||
if (parsed.mode === "regex") {
|
||||
if (!parsed.regex) {
|
||||
return { matches: false, score: 0 };
|
||||
}
|
||||
const idx = text.search(parsed.regex);
|
||||
if (idx < 0) return { matches: false, score: 0 };
|
||||
return { matches: true, score: idx * 0.1 };
|
||||
}
|
||||
|
||||
if (parsed.tokens.length === 0) {
|
||||
return { matches: true, score: 0 };
|
||||
}
|
||||
|
||||
let totalScore = 0;
|
||||
let normalizedText: string | null = null;
|
||||
|
||||
for (const token of parsed.tokens) {
|
||||
if (token.kind === "phrase") {
|
||||
if (normalizedText === null) {
|
||||
normalizedText = normalizeWhitespaceLower(text);
|
||||
}
|
||||
const phrase = normalizeWhitespaceLower(token.value);
|
||||
if (!phrase) continue;
|
||||
const idx = normalizedText.indexOf(phrase);
|
||||
if (idx < 0) return { matches: false, score: 0 };
|
||||
totalScore += idx * 0.1;
|
||||
continue;
|
||||
}
|
||||
|
||||
const m = fuzzyMatch(token.value, text);
|
||||
if (!m.matches) return { matches: false, score: 0 };
|
||||
totalScore += m.score;
|
||||
}
|
||||
|
||||
return { matches: true, score: totalScore };
|
||||
}
|
||||
|
||||
export function filterAndSortSessions(
|
||||
sessions: SessionInfo[],
|
||||
query: string,
|
||||
sortMode: SortMode,
|
||||
nameFilter: NameFilter = "all",
|
||||
): SessionInfo[] {
|
||||
const nameFiltered =
|
||||
nameFilter === "all"
|
||||
? sessions
|
||||
: sessions.filter((session) => matchesNameFilter(session, nameFilter));
|
||||
const trimmed = query.trim();
|
||||
if (!trimmed) return nameFiltered;
|
||||
|
||||
const parsed = parseSearchQuery(query);
|
||||
if (parsed.error) return [];
|
||||
|
||||
// Recent mode: filter only, keep incoming order.
|
||||
if (sortMode === "recent") {
|
||||
const filtered: SessionInfo[] = [];
|
||||
for (const s of nameFiltered) {
|
||||
const res = matchSession(s, parsed);
|
||||
if (res.matches) filtered.push(s);
|
||||
}
|
||||
return filtered;
|
||||
}
|
||||
|
||||
// Relevance mode: sort by score, tie-break by modified desc.
|
||||
const scored: { session: SessionInfo; score: number }[] = [];
|
||||
for (const s of nameFiltered) {
|
||||
const res = matchSession(s, parsed);
|
||||
if (!res.matches) continue;
|
||||
scored.push({ session: s, score: res.score });
|
||||
}
|
||||
|
||||
scored.sort((a, b) => {
|
||||
if (a.score !== b.score) return a.score - b.score;
|
||||
return b.session.modified.getTime() - a.session.modified.getTime();
|
||||
});
|
||||
|
||||
return scored.map((r) => r.session);
|
||||
}
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -0,0 +1,453 @@
|
|||
import type { ThinkingLevel } from "@mariozechner/pi-agent-core";
|
||||
import type { Transport } from "@mariozechner/pi-ai";
|
||||
import {
|
||||
Container,
|
||||
getCapabilities,
|
||||
type SelectItem,
|
||||
SelectList,
|
||||
type SettingItem,
|
||||
SettingsList,
|
||||
Spacer,
|
||||
Text,
|
||||
} from "@mariozechner/pi-tui";
|
||||
import {
|
||||
getSelectListTheme,
|
||||
getSettingsListTheme,
|
||||
theme,
|
||||
} from "../theme/theme.js";
|
||||
import { DynamicBorder } from "./dynamic-border.js";
|
||||
|
||||
const THINKING_DESCRIPTIONS: Record<ThinkingLevel, string> = {
|
||||
off: "No reasoning",
|
||||
minimal: "Very brief reasoning (~1k tokens)",
|
||||
low: "Light reasoning (~2k tokens)",
|
||||
medium: "Moderate reasoning (~8k tokens)",
|
||||
high: "Deep reasoning (~16k tokens)",
|
||||
xhigh: "Maximum reasoning (~32k tokens)",
|
||||
};
|
||||
|
||||
export interface SettingsConfig {
|
||||
autoCompact: boolean;
|
||||
showImages: boolean;
|
||||
autoResizeImages: boolean;
|
||||
blockImages: boolean;
|
||||
enableSkillCommands: boolean;
|
||||
steeringMode: "all" | "one-at-a-time";
|
||||
followUpMode: "all" | "one-at-a-time";
|
||||
transport: Transport;
|
||||
thinkingLevel: ThinkingLevel;
|
||||
availableThinkingLevels: ThinkingLevel[];
|
||||
currentTheme: string;
|
||||
availableThemes: string[];
|
||||
hideThinkingBlock: boolean;
|
||||
collapseChangelog: boolean;
|
||||
doubleEscapeAction: "fork" | "tree" | "none";
|
||||
treeFilterMode: "default" | "no-tools" | "user-only" | "labeled-only" | "all";
|
||||
showHardwareCursor: boolean;
|
||||
editorPaddingX: number;
|
||||
autocompleteMaxVisible: number;
|
||||
quietStartup: boolean;
|
||||
clearOnShrink: boolean;
|
||||
}
|
||||
|
||||
export interface SettingsCallbacks {
|
||||
onAutoCompactChange: (enabled: boolean) => void;
|
||||
onShowImagesChange: (enabled: boolean) => void;
|
||||
onAutoResizeImagesChange: (enabled: boolean) => void;
|
||||
onBlockImagesChange: (blocked: boolean) => void;
|
||||
onEnableSkillCommandsChange: (enabled: boolean) => void;
|
||||
onSteeringModeChange: (mode: "all" | "one-at-a-time") => void;
|
||||
onFollowUpModeChange: (mode: "all" | "one-at-a-time") => void;
|
||||
onTransportChange: (transport: Transport) => void;
|
||||
onThinkingLevelChange: (level: ThinkingLevel) => void;
|
||||
onThemeChange: (theme: string) => void;
|
||||
onThemePreview?: (theme: string) => void;
|
||||
onHideThinkingBlockChange: (hidden: boolean) => void;
|
||||
onCollapseChangelogChange: (collapsed: boolean) => void;
|
||||
onDoubleEscapeActionChange: (action: "fork" | "tree" | "none") => void;
|
||||
onTreeFilterModeChange: (
|
||||
mode: "default" | "no-tools" | "user-only" | "labeled-only" | "all",
|
||||
) => void;
|
||||
onShowHardwareCursorChange: (enabled: boolean) => void;
|
||||
onEditorPaddingXChange: (padding: number) => void;
|
||||
onAutocompleteMaxVisibleChange: (maxVisible: number) => void;
|
||||
onQuietStartupChange: (enabled: boolean) => void;
|
||||
onClearOnShrinkChange: (enabled: boolean) => void;
|
||||
onCancel: () => void;
|
||||
}
|
||||
|
||||
/**
|
||||
* A submenu component for selecting from a list of options.
|
||||
*/
|
||||
class SelectSubmenu extends Container {
|
||||
private selectList: SelectList;
|
||||
|
||||
constructor(
|
||||
title: string,
|
||||
description: string,
|
||||
options: SelectItem[],
|
||||
currentValue: string,
|
||||
onSelect: (value: string) => void,
|
||||
onCancel: () => void,
|
||||
onSelectionChange?: (value: string) => void,
|
||||
) {
|
||||
super();
|
||||
|
||||
// Title
|
||||
this.addChild(new Text(theme.bold(theme.fg("accent", title)), 0, 0));
|
||||
|
||||
// Description
|
||||
if (description) {
|
||||
this.addChild(new Spacer(1));
|
||||
this.addChild(new Text(theme.fg("muted", description), 0, 0));
|
||||
}
|
||||
|
||||
// Spacer
|
||||
this.addChild(new Spacer(1));
|
||||
|
||||
// Select list
|
||||
this.selectList = new SelectList(
|
||||
options,
|
||||
Math.min(options.length, 10),
|
||||
getSelectListTheme(),
|
||||
);
|
||||
|
||||
// Pre-select current value
|
||||
const currentIndex = options.findIndex((o) => o.value === currentValue);
|
||||
if (currentIndex !== -1) {
|
||||
this.selectList.setSelectedIndex(currentIndex);
|
||||
}
|
||||
|
||||
this.selectList.onSelect = (item) => {
|
||||
onSelect(item.value);
|
||||
};
|
||||
|
||||
this.selectList.onCancel = onCancel;
|
||||
|
||||
if (onSelectionChange) {
|
||||
this.selectList.onSelectionChange = (item) => {
|
||||
onSelectionChange(item.value);
|
||||
};
|
||||
}
|
||||
|
||||
this.addChild(this.selectList);
|
||||
|
||||
// Hint
|
||||
this.addChild(new Spacer(1));
|
||||
this.addChild(
|
||||
new Text(theme.fg("dim", " Enter to select · Esc to go back"), 0, 0),
|
||||
);
|
||||
}
|
||||
|
||||
handleInput(data: string): void {
|
||||
this.selectList.handleInput(data);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Main settings selector component.
|
||||
*/
|
||||
export class SettingsSelectorComponent extends Container {
|
||||
private settingsList: SettingsList;
|
||||
|
||||
constructor(config: SettingsConfig, callbacks: SettingsCallbacks) {
|
||||
super();
|
||||
|
||||
const supportsImages = getCapabilities().images;
|
||||
|
||||
const items: SettingItem[] = [
|
||||
{
|
||||
id: "autocompact",
|
||||
label: "Auto-compact",
|
||||
description: "Automatically compact context when it gets too large",
|
||||
currentValue: config.autoCompact ? "true" : "false",
|
||||
values: ["true", "false"],
|
||||
},
|
||||
{
|
||||
id: "steering-mode",
|
||||
label: "Steering mode",
|
||||
description:
|
||||
"Enter while streaming queues steering messages. 'one-at-a-time': deliver one, wait for response. 'all': deliver all at once.",
|
||||
currentValue: config.steeringMode,
|
||||
values: ["one-at-a-time", "all"],
|
||||
},
|
||||
{
|
||||
id: "follow-up-mode",
|
||||
label: "Follow-up mode",
|
||||
description:
|
||||
"Alt+Enter queues follow-up messages until agent stops. 'one-at-a-time': deliver one, wait for response. 'all': deliver all at once.",
|
||||
currentValue: config.followUpMode,
|
||||
values: ["one-at-a-time", "all"],
|
||||
},
|
||||
{
|
||||
id: "transport",
|
||||
label: "Transport",
|
||||
description:
|
||||
"Preferred transport for providers that support multiple transports",
|
||||
currentValue: config.transport,
|
||||
values: ["sse", "websocket", "auto"],
|
||||
},
|
||||
{
|
||||
id: "hide-thinking",
|
||||
label: "Hide thinking",
|
||||
description: "Hide thinking blocks in assistant responses",
|
||||
currentValue: config.hideThinkingBlock ? "true" : "false",
|
||||
values: ["true", "false"],
|
||||
},
|
||||
{
|
||||
id: "collapse-changelog",
|
||||
label: "Collapse changelog",
|
||||
description: "Show condensed changelog after updates",
|
||||
currentValue: config.collapseChangelog ? "true" : "false",
|
||||
values: ["true", "false"],
|
||||
},
|
||||
{
|
||||
id: "quiet-startup",
|
||||
label: "Quiet startup",
|
||||
description: "Disable verbose printing at startup",
|
||||
currentValue: config.quietStartup ? "true" : "false",
|
||||
values: ["true", "false"],
|
||||
},
|
||||
{
|
||||
id: "double-escape-action",
|
||||
label: "Double-escape action",
|
||||
description: "Action when pressing Escape twice with empty editor",
|
||||
currentValue: config.doubleEscapeAction,
|
||||
values: ["tree", "fork", "none"],
|
||||
},
|
||||
{
|
||||
id: "tree-filter-mode",
|
||||
label: "Tree filter mode",
|
||||
description: "Default filter when opening /tree",
|
||||
currentValue: config.treeFilterMode,
|
||||
values: ["default", "no-tools", "user-only", "labeled-only", "all"],
|
||||
},
|
||||
{
|
||||
id: "thinking",
|
||||
label: "Thinking level",
|
||||
description: "Reasoning depth for thinking-capable models",
|
||||
currentValue: config.thinkingLevel,
|
||||
submenu: (currentValue, done) =>
|
||||
new SelectSubmenu(
|
||||
"Thinking Level",
|
||||
"Select reasoning depth for thinking-capable models",
|
||||
config.availableThinkingLevels.map((level) => ({
|
||||
value: level,
|
||||
label: level,
|
||||
description: THINKING_DESCRIPTIONS[level],
|
||||
})),
|
||||
currentValue,
|
||||
(value) => {
|
||||
callbacks.onThinkingLevelChange(value as ThinkingLevel);
|
||||
done(value);
|
||||
},
|
||||
() => done(),
|
||||
),
|
||||
},
|
||||
{
|
||||
id: "theme",
|
||||
label: "Theme",
|
||||
description: "Color theme for the interface",
|
||||
currentValue: config.currentTheme,
|
||||
submenu: (currentValue, done) =>
|
||||
new SelectSubmenu(
|
||||
"Theme",
|
||||
"Select color theme",
|
||||
config.availableThemes.map((t) => ({
|
||||
value: t,
|
||||
label: t,
|
||||
})),
|
||||
currentValue,
|
||||
(value) => {
|
||||
callbacks.onThemeChange(value);
|
||||
done(value);
|
||||
},
|
||||
() => {
|
||||
// Restore original theme on cancel
|
||||
callbacks.onThemePreview?.(currentValue);
|
||||
done();
|
||||
},
|
||||
(value) => {
|
||||
// Preview theme on selection change
|
||||
callbacks.onThemePreview?.(value);
|
||||
},
|
||||
),
|
||||
},
|
||||
];
|
||||
|
||||
// Only show image toggle if terminal supports it
|
||||
if (supportsImages) {
|
||||
// Insert after autocompact
|
||||
items.splice(1, 0, {
|
||||
id: "show-images",
|
||||
label: "Show images",
|
||||
description: "Render images inline in terminal",
|
||||
currentValue: config.showImages ? "true" : "false",
|
||||
values: ["true", "false"],
|
||||
});
|
||||
}
|
||||
|
||||
// Image auto-resize toggle (always available, affects both attached and read images)
|
||||
items.splice(supportsImages ? 2 : 1, 0, {
|
||||
id: "auto-resize-images",
|
||||
label: "Auto-resize images",
|
||||
description:
|
||||
"Resize large images to 2000x2000 max for better model compatibility",
|
||||
currentValue: config.autoResizeImages ? "true" : "false",
|
||||
values: ["true", "false"],
|
||||
});
|
||||
|
||||
// Block images toggle (always available, insert after auto-resize-images)
|
||||
const autoResizeIndex = items.findIndex(
|
||||
(item) => item.id === "auto-resize-images",
|
||||
);
|
||||
items.splice(autoResizeIndex + 1, 0, {
|
||||
id: "block-images",
|
||||
label: "Block images",
|
||||
description: "Prevent images from being sent to LLM providers",
|
||||
currentValue: config.blockImages ? "true" : "false",
|
||||
values: ["true", "false"],
|
||||
});
|
||||
|
||||
// Skill commands toggle (insert after block-images)
|
||||
const blockImagesIndex = items.findIndex(
|
||||
(item) => item.id === "block-images",
|
||||
);
|
||||
items.splice(blockImagesIndex + 1, 0, {
|
||||
id: "skill-commands",
|
||||
label: "Skill commands",
|
||||
description: "Register skills as /skill:name commands",
|
||||
currentValue: config.enableSkillCommands ? "true" : "false",
|
||||
values: ["true", "false"],
|
||||
});
|
||||
|
||||
// Hardware cursor toggle (insert after skill-commands)
|
||||
const skillCommandsIndex = items.findIndex(
|
||||
(item) => item.id === "skill-commands",
|
||||
);
|
||||
items.splice(skillCommandsIndex + 1, 0, {
|
||||
id: "show-hardware-cursor",
|
||||
label: "Show hardware cursor",
|
||||
description:
|
||||
"Show the terminal cursor while still positioning it for IME support",
|
||||
currentValue: config.showHardwareCursor ? "true" : "false",
|
||||
values: ["true", "false"],
|
||||
});
|
||||
|
||||
// Editor padding toggle (insert after show-hardware-cursor)
|
||||
const hardwareCursorIndex = items.findIndex(
|
||||
(item) => item.id === "show-hardware-cursor",
|
||||
);
|
||||
items.splice(hardwareCursorIndex + 1, 0, {
|
||||
id: "editor-padding",
|
||||
label: "Editor padding",
|
||||
description: "Horizontal padding for input editor (0-3)",
|
||||
currentValue: String(config.editorPaddingX),
|
||||
values: ["0", "1", "2", "3"],
|
||||
});
|
||||
|
||||
// Autocomplete max visible toggle (insert after editor-padding)
|
||||
const editorPaddingIndex = items.findIndex(
|
||||
(item) => item.id === "editor-padding",
|
||||
);
|
||||
items.splice(editorPaddingIndex + 1, 0, {
|
||||
id: "autocomplete-max-visible",
|
||||
label: "Autocomplete max items",
|
||||
description: "Max visible items in autocomplete dropdown (3-20)",
|
||||
currentValue: String(config.autocompleteMaxVisible),
|
||||
values: ["3", "5", "7", "10", "15", "20"],
|
||||
});
|
||||
|
||||
// Clear on shrink toggle (insert after autocomplete-max-visible)
|
||||
const autocompleteIndex = items.findIndex(
|
||||
(item) => item.id === "autocomplete-max-visible",
|
||||
);
|
||||
items.splice(autocompleteIndex + 1, 0, {
|
||||
id: "clear-on-shrink",
|
||||
label: "Clear on shrink",
|
||||
description: "Clear empty rows when content shrinks (may cause flicker)",
|
||||
currentValue: config.clearOnShrink ? "true" : "false",
|
||||
values: ["true", "false"],
|
||||
});
|
||||
|
||||
// Add borders
|
||||
this.addChild(new DynamicBorder());
|
||||
|
||||
this.settingsList = new SettingsList(
|
||||
items,
|
||||
10,
|
||||
getSettingsListTheme(),
|
||||
(id, newValue) => {
|
||||
switch (id) {
|
||||
case "autocompact":
|
||||
callbacks.onAutoCompactChange(newValue === "true");
|
||||
break;
|
||||
case "show-images":
|
||||
callbacks.onShowImagesChange(newValue === "true");
|
||||
break;
|
||||
case "auto-resize-images":
|
||||
callbacks.onAutoResizeImagesChange(newValue === "true");
|
||||
break;
|
||||
case "block-images":
|
||||
callbacks.onBlockImagesChange(newValue === "true");
|
||||
break;
|
||||
case "skill-commands":
|
||||
callbacks.onEnableSkillCommandsChange(newValue === "true");
|
||||
break;
|
||||
case "steering-mode":
|
||||
callbacks.onSteeringModeChange(newValue as "all" | "one-at-a-time");
|
||||
break;
|
||||
case "follow-up-mode":
|
||||
callbacks.onFollowUpModeChange(newValue as "all" | "one-at-a-time");
|
||||
break;
|
||||
case "transport":
|
||||
callbacks.onTransportChange(newValue as Transport);
|
||||
break;
|
||||
case "hide-thinking":
|
||||
callbacks.onHideThinkingBlockChange(newValue === "true");
|
||||
break;
|
||||
case "collapse-changelog":
|
||||
callbacks.onCollapseChangelogChange(newValue === "true");
|
||||
break;
|
||||
case "quiet-startup":
|
||||
callbacks.onQuietStartupChange(newValue === "true");
|
||||
break;
|
||||
case "double-escape-action":
|
||||
callbacks.onDoubleEscapeActionChange(newValue as "fork" | "tree");
|
||||
break;
|
||||
case "tree-filter-mode":
|
||||
callbacks.onTreeFilterModeChange(
|
||||
newValue as
|
||||
| "default"
|
||||
| "no-tools"
|
||||
| "user-only"
|
||||
| "labeled-only"
|
||||
| "all",
|
||||
);
|
||||
break;
|
||||
case "show-hardware-cursor":
|
||||
callbacks.onShowHardwareCursorChange(newValue === "true");
|
||||
break;
|
||||
case "editor-padding":
|
||||
callbacks.onEditorPaddingXChange(parseInt(newValue, 10));
|
||||
break;
|
||||
case "autocomplete-max-visible":
|
||||
callbacks.onAutocompleteMaxVisibleChange(parseInt(newValue, 10));
|
||||
break;
|
||||
case "clear-on-shrink":
|
||||
callbacks.onClearOnShrinkChange(newValue === "true");
|
||||
break;
|
||||
}
|
||||
},
|
||||
callbacks.onCancel,
|
||||
{ enableSearch: true },
|
||||
);
|
||||
|
||||
this.addChild(this.settingsList);
|
||||
this.addChild(new DynamicBorder());
|
||||
}
|
||||
|
||||
getSettingsList(): SettingsList {
|
||||
return this.settingsList;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,57 @@
|
|||
import { Container, type SelectItem, SelectList } from "@mariozechner/pi-tui";
|
||||
import { getSelectListTheme } from "../theme/theme.js";
|
||||
import { DynamicBorder } from "./dynamic-border.js";
|
||||
|
||||
/**
|
||||
* Component that renders a show images selector with borders
|
||||
*/
|
||||
export class ShowImagesSelectorComponent extends Container {
|
||||
private selectList: SelectList;
|
||||
|
||||
constructor(
|
||||
currentValue: boolean,
|
||||
onSelect: (show: boolean) => void,
|
||||
onCancel: () => void,
|
||||
) {
|
||||
super();
|
||||
|
||||
const items: SelectItem[] = [
|
||||
{
|
||||
value: "yes",
|
||||
label: "Yes",
|
||||
description: "Show images inline in terminal",
|
||||
},
|
||||
{
|
||||
value: "no",
|
||||
label: "No",
|
||||
description: "Show text placeholder instead",
|
||||
},
|
||||
];
|
||||
|
||||
// Add top border
|
||||
this.addChild(new DynamicBorder());
|
||||
|
||||
// Create selector
|
||||
this.selectList = new SelectList(items, 5, getSelectListTheme());
|
||||
|
||||
// Preselect current value
|
||||
this.selectList.setSelectedIndex(currentValue ? 0 : 1);
|
||||
|
||||
this.selectList.onSelect = (item) => {
|
||||
onSelect(item.value === "yes");
|
||||
};
|
||||
|
||||
this.selectList.onCancel = () => {
|
||||
onCancel();
|
||||
};
|
||||
|
||||
this.addChild(this.selectList);
|
||||
|
||||
// Add bottom border
|
||||
this.addChild(new DynamicBorder());
|
||||
}
|
||||
|
||||
getSelectList(): SelectList {
|
||||
return this.selectList;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,64 @@
|
|||
import { Box, Markdown, type MarkdownTheme, Text } from "@mariozechner/pi-tui";
|
||||
import type { ParsedSkillBlock } from "../../../core/agent-session.js";
|
||||
import { getMarkdownTheme, theme } from "../theme/theme.js";
|
||||
import { editorKey } from "./keybinding-hints.js";
|
||||
|
||||
/**
|
||||
* Component that renders a skill invocation message with collapsed/expanded state.
|
||||
* Uses same background color as custom messages for visual consistency.
|
||||
* Only renders the skill block itself - user message is rendered separately.
|
||||
*/
|
||||
export class SkillInvocationMessageComponent extends Box {
|
||||
private expanded = false;
|
||||
private skillBlock: ParsedSkillBlock;
|
||||
private markdownTheme: MarkdownTheme;
|
||||
|
||||
constructor(
|
||||
skillBlock: ParsedSkillBlock,
|
||||
markdownTheme: MarkdownTheme = getMarkdownTheme(),
|
||||
) {
|
||||
super(1, 1, (t) => theme.bg("customMessageBg", t));
|
||||
this.skillBlock = skillBlock;
|
||||
this.markdownTheme = markdownTheme;
|
||||
this.updateDisplay();
|
||||
}
|
||||
|
||||
setExpanded(expanded: boolean): void {
|
||||
this.expanded = expanded;
|
||||
this.updateDisplay();
|
||||
}
|
||||
|
||||
override invalidate(): void {
|
||||
super.invalidate();
|
||||
this.updateDisplay();
|
||||
}
|
||||
|
||||
private updateDisplay(): void {
|
||||
this.clear();
|
||||
|
||||
if (this.expanded) {
|
||||
// Expanded: label + skill name header + full content
|
||||
const label = theme.fg("customMessageLabel", `\x1b[1m[skill]\x1b[22m`);
|
||||
this.addChild(new Text(label, 0, 0));
|
||||
const header = `**${this.skillBlock.name}**\n\n`;
|
||||
this.addChild(
|
||||
new Markdown(
|
||||
header + this.skillBlock.content,
|
||||
0,
|
||||
0,
|
||||
this.markdownTheme,
|
||||
{
|
||||
color: (text: string) => theme.fg("customMessageText", text),
|
||||
},
|
||||
),
|
||||
);
|
||||
} else {
|
||||
// Collapsed: single line - [skill] name (hint to expand)
|
||||
const line =
|
||||
theme.fg("customMessageLabel", `\x1b[1m[skill]\x1b[22m `) +
|
||||
theme.fg("customMessageText", this.skillBlock.name) +
|
||||
theme.fg("dim", ` (${editorKey("expandTools")} to expand)`);
|
||||
this.addChild(new Text(line, 0, 0));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,62 @@
|
|||
import { Container, type SelectItem, SelectList } from "@mariozechner/pi-tui";
|
||||
import { getAvailableThemes, getSelectListTheme } from "../theme/theme.js";
|
||||
import { DynamicBorder } from "./dynamic-border.js";
|
||||
|
||||
/**
|
||||
* Component that renders a theme selector
|
||||
*/
|
||||
export class ThemeSelectorComponent extends Container {
|
||||
private selectList: SelectList;
|
||||
private onPreview: (themeName: string) => void;
|
||||
|
||||
constructor(
|
||||
currentTheme: string,
|
||||
onSelect: (themeName: string) => void,
|
||||
onCancel: () => void,
|
||||
onPreview: (themeName: string) => void,
|
||||
) {
|
||||
super();
|
||||
this.onPreview = onPreview;
|
||||
|
||||
// Get available themes and create select items
|
||||
const themes = getAvailableThemes();
|
||||
const themeItems: SelectItem[] = themes.map((name) => ({
|
||||
value: name,
|
||||
label: name,
|
||||
description: name === currentTheme ? "(current)" : undefined,
|
||||
}));
|
||||
|
||||
// Add top border
|
||||
this.addChild(new DynamicBorder());
|
||||
|
||||
// Create selector
|
||||
this.selectList = new SelectList(themeItems, 10, getSelectListTheme());
|
||||
|
||||
// Preselect current theme
|
||||
const currentIndex = themes.indexOf(currentTheme);
|
||||
if (currentIndex !== -1) {
|
||||
this.selectList.setSelectedIndex(currentIndex);
|
||||
}
|
||||
|
||||
this.selectList.onSelect = (item) => {
|
||||
onSelect(item.value);
|
||||
};
|
||||
|
||||
this.selectList.onCancel = () => {
|
||||
onCancel();
|
||||
};
|
||||
|
||||
this.selectList.onSelectionChange = (item) => {
|
||||
this.onPreview(item.value);
|
||||
};
|
||||
|
||||
this.addChild(this.selectList);
|
||||
|
||||
// Add bottom border
|
||||
this.addChild(new DynamicBorder());
|
||||
}
|
||||
|
||||
getSelectList(): SelectList {
|
||||
return this.selectList;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,70 @@
|
|||
import type { ThinkingLevel } from "@mariozechner/pi-agent-core";
|
||||
import { Container, type SelectItem, SelectList } from "@mariozechner/pi-tui";
|
||||
import { getSelectListTheme } from "../theme/theme.js";
|
||||
import { DynamicBorder } from "./dynamic-border.js";
|
||||
|
||||
// One human-readable description per thinking level, shown as the item
// description in the ThinkingSelectorComponent list below.
// Token counts are approximate — TODO confirm against the provider budgets.
const LEVEL_DESCRIPTIONS: Record<ThinkingLevel, string> = {
	off: "No reasoning",
	minimal: "Very brief reasoning (~1k tokens)",
	low: "Light reasoning (~2k tokens)",
	medium: "Moderate reasoning (~8k tokens)",
	high: "Deep reasoning (~16k tokens)",
	xhigh: "Maximum reasoning (~32k tokens)",
};
|
||||
|
||||
/**
|
||||
* Component that renders a thinking level selector with borders
|
||||
*/
|
||||
export class ThinkingSelectorComponent extends Container {
|
||||
private selectList: SelectList;
|
||||
|
||||
constructor(
|
||||
currentLevel: ThinkingLevel,
|
||||
availableLevels: ThinkingLevel[],
|
||||
onSelect: (level: ThinkingLevel) => void,
|
||||
onCancel: () => void,
|
||||
) {
|
||||
super();
|
||||
|
||||
const thinkingLevels: SelectItem[] = availableLevels.map((level) => ({
|
||||
value: level,
|
||||
label: level,
|
||||
description: LEVEL_DESCRIPTIONS[level],
|
||||
}));
|
||||
|
||||
// Add top border
|
||||
this.addChild(new DynamicBorder());
|
||||
|
||||
// Create selector
|
||||
this.selectList = new SelectList(
|
||||
thinkingLevels,
|
||||
thinkingLevels.length,
|
||||
getSelectListTheme(),
|
||||
);
|
||||
|
||||
// Preselect current level
|
||||
const currentIndex = thinkingLevels.findIndex(
|
||||
(item) => item.value === currentLevel,
|
||||
);
|
||||
if (currentIndex !== -1) {
|
||||
this.selectList.setSelectedIndex(currentIndex);
|
||||
}
|
||||
|
||||
this.selectList.onSelect = (item) => {
|
||||
onSelect(item.value as ThinkingLevel);
|
||||
};
|
||||
|
||||
this.selectList.onCancel = () => {
|
||||
onCancel();
|
||||
};
|
||||
|
||||
this.addChild(this.selectList);
|
||||
|
||||
// Add bottom border
|
||||
this.addChild(new DynamicBorder());
|
||||
}
|
||||
|
||||
getSelectList(): SelectList {
|
||||
return this.selectList;
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
|
|
@ -0,0 +1,179 @@
|
|||
import {
|
||||
type Component,
|
||||
Container,
|
||||
getEditorKeybindings,
|
||||
Spacer,
|
||||
Text,
|
||||
truncateToWidth,
|
||||
} from "@mariozechner/pi-tui";
|
||||
import { theme } from "../theme/theme.js";
|
||||
import { DynamicBorder } from "./dynamic-border.js";
|
||||
|
||||
interface UserMessageItem {
	id: string; // Entry ID in the session
	text: string; // The message text
	timestamp?: string; // Optional timestamp if available
}

/**
 * Custom user message list component with selection.
 *
 * Renders a keyboard-navigable, scrolling list of prior user messages
 * (two text lines plus a blank spacer per entry) and reports the chosen
 * entry ID through `onSelect`, or `onCancel` on escape. Navigation wraps
 * at both ends.
 */
class UserMessageList implements Component {
	private messages: UserMessageItem[] = [];
	private selectedIndex: number = 0;
	public onSelect?: (entryId: string) => void;
	public onCancel?: () => void;
	private maxVisible: number = 10; // Max messages visible

	constructor(messages: UserMessageItem[]) {
		// Store messages in chronological order (oldest to newest)
		this.messages = messages;
		// Start with the last (most recent) message selected
		this.selectedIndex = Math.max(0, messages.length - 1);
	}

	invalidate(): void {
		// No cached state to invalidate currently
	}

	render(width: number): string[] {
		const lines: string[] = [];

		if (this.messages.length === 0) {
			lines.push(theme.fg("muted", " No user messages found"));
			return lines;
		}

		// Calculate visible range with scrolling: keep the selection roughly
		// centered in the window, clamped so the window never runs past
		// either end of the list.
		const startIndex = Math.max(
			0,
			Math.min(
				this.selectedIndex - Math.floor(this.maxVisible / 2),
				this.messages.length - this.maxVisible,
			),
		);
		const endIndex = Math.min(
			startIndex + this.maxVisible,
			this.messages.length,
		);

		// Render visible messages (2 lines per message + blank line)
		for (let i = startIndex; i < endIndex; i++) {
			const message = this.messages[i];
			const isSelected = i === this.selectedIndex;

			// Normalize message to single line
			const normalizedMessage = message.text.replace(/\n/g, " ").trim();

			// First line: cursor + message
			const cursor = isSelected ? theme.fg("accent", "› ") : "  ";
			const maxMsgWidth = width - 2; // Account for cursor (2 chars)
			const truncatedMsg = truncateToWidth(normalizedMessage, maxMsgWidth);
			const messageLine =
				cursor + (isSelected ? theme.bold(truncatedMsg) : truncatedMsg);

			lines.push(messageLine);

			// Second line: metadata (position in history)
			const position = i + 1;
			const metadata = ` Message ${position} of ${this.messages.length}`;
			const metadataLine = theme.fg("muted", metadata);
			lines.push(metadataLine);
			lines.push(""); // Blank line between messages
		}

		// Add scroll indicator if needed (some entries are outside the window)
		if (startIndex > 0 || endIndex < this.messages.length) {
			const scrollInfo = theme.fg(
				"muted",
				` (${this.selectedIndex + 1}/${this.messages.length})`,
			);
			lines.push(scrollInfo);
		}

		return lines;
	}

	handleInput(keyData: string): void {
		const kb = getEditorKeybindings();
		// Up arrow - go to previous (older) message, wrap to bottom when at top
		if (kb.matches(keyData, "selectUp")) {
			this.selectedIndex =
				this.selectedIndex === 0
					? this.messages.length - 1
					: this.selectedIndex - 1;
		}
		// Down arrow - go to next (newer) message, wrap to top when at bottom
		else if (kb.matches(keyData, "selectDown")) {
			this.selectedIndex =
				this.selectedIndex === this.messages.length - 1
					? 0
					: this.selectedIndex + 1;
		}
		// Enter - select message and branch
		else if (kb.matches(keyData, "selectConfirm")) {
			const selected = this.messages[this.selectedIndex];
			if (selected && this.onSelect) {
				this.onSelect(selected.id);
			}
		}
		// Escape - cancel
		else if (kb.matches(keyData, "selectCancel")) {
			if (this.onCancel) {
				this.onCancel();
			}
		}
	}
}
|
||||
|
||||
/**
|
||||
* Component that renders a user message selector for branching
|
||||
*/
|
||||
export class UserMessageSelectorComponent extends Container {
|
||||
private messageList: UserMessageList;
|
||||
|
||||
constructor(
|
||||
messages: UserMessageItem[],
|
||||
onSelect: (entryId: string) => void,
|
||||
onCancel: () => void,
|
||||
) {
|
||||
super();
|
||||
|
||||
// Add header
|
||||
this.addChild(new Spacer(1));
|
||||
this.addChild(new Text(theme.bold("Branch from Message"), 1, 0));
|
||||
this.addChild(
|
||||
new Text(
|
||||
theme.fg(
|
||||
"muted",
|
||||
"Select a message to create a new branch from that point",
|
||||
),
|
||||
1,
|
||||
0,
|
||||
),
|
||||
);
|
||||
this.addChild(new Spacer(1));
|
||||
this.addChild(new DynamicBorder());
|
||||
this.addChild(new Spacer(1));
|
||||
|
||||
// Create message list
|
||||
this.messageList = new UserMessageList(messages);
|
||||
this.messageList.onSelect = onSelect;
|
||||
this.messageList.onCancel = onCancel;
|
||||
|
||||
this.addChild(this.messageList);
|
||||
|
||||
// Add bottom border
|
||||
this.addChild(new Spacer(1));
|
||||
this.addChild(new DynamicBorder());
|
||||
|
||||
// Auto-cancel if no messages
|
||||
if (messages.length === 0) {
|
||||
setTimeout(() => onCancel(), 100);
|
||||
}
|
||||
}
|
||||
|
||||
getMessageList(): UserMessageList {
|
||||
return this.messageList;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,37 @@
|
|||
import {
|
||||
Container,
|
||||
Markdown,
|
||||
type MarkdownTheme,
|
||||
Spacer,
|
||||
} from "@mariozechner/pi-tui";
|
||||
import { getMarkdownTheme, theme } from "../theme/theme.js";
|
||||
|
||||
// OSC 133 escape sequences (ESC ] 133 ; A/B BEL) used by UserMessageComponent
// below to bracket its rendered lines so the terminal can identify the region.
const OSC133_ZONE_START = "\x1b]133;A\x07";
const OSC133_ZONE_END = "\x1b]133;B\x07";
||||
|
||||
/**
|
||||
* Component that renders a user message
|
||||
*/
|
||||
export class UserMessageComponent extends Container {
|
||||
constructor(text: string, markdownTheme: MarkdownTheme = getMarkdownTheme()) {
|
||||
super();
|
||||
this.addChild(new Spacer(1));
|
||||
this.addChild(
|
||||
new Markdown(text, 1, 1, markdownTheme, {
|
||||
bgColor: (text: string) => theme.bg("userMessageBg", text),
|
||||
color: (text: string) => theme.fg("userMessageText", text),
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
override render(width: number): string[] {
|
||||
const lines = super.render(width);
|
||||
if (lines.length === 0) {
|
||||
return lines;
|
||||
}
|
||||
|
||||
lines[0] = OSC133_ZONE_START + lines[0];
|
||||
lines[lines.length - 1] = lines[lines.length - 1] + OSC133_ZONE_END;
|
||||
return lines;
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue