Reorganize file structure: core/, utils/, modes/interactive/components/, modes/interactive/theme/

This commit is contained in:
Mario Zechner 2025-12-09 00:51:33 +01:00
parent 00982705f2
commit 83a6c26969
56 changed files with 133 additions and 128 deletions

View file

@ -15,15 +15,15 @@
import type { Agent, AgentEvent, AgentState, AppMessage, Attachment, ThinkingLevel } from "@mariozechner/pi-agent-core";
import type { AssistantMessage, Model } from "@mariozechner/pi-ai";
import { calculateContextTokens, compact, shouldCompact } from "../compaction.js";
import { getModelsPath } from "../config.js";
import { exportSessionToHtml } from "../export-html.js";
import type { BashExecutionMessage } from "../messages.js";
import { getApiKeyForModel, getAvailableModels } from "../model-config.js";
import { loadSessionFromEntries, type SessionManager } from "../session-manager.js";
import type { SettingsManager } from "../settings-manager.js";
import { expandSlashCommand, type FileSlashCommand } from "../slash-commands.js";
import { getModelsPath } from "../utils/config.js";
import { type BashResult, executeBash as executeBashCommand } from "./bash-executor.js";
import { calculateContextTokens, compact, shouldCompact } from "./compaction.js";
import { exportSessionToHtml } from "./export-html.js";
import type { BashExecutionMessage } from "./messages.js";
import { getApiKeyForModel, getAvailableModels } from "./model-config.js";
import { loadSessionFromEntries, type SessionManager } from "./session-manager.js";
import type { SettingsManager } from "./settings-manager.js";
import { expandSlashCommand, type FileSlashCommand } from "./slash-commands.js";
/** Listener function for agent events */
export type AgentEventListener = (event: AgentEvent) => void;

View file

@ -12,8 +12,8 @@ import { tmpdir } from "node:os";
import { join } from "node:path";
import { type ChildProcess, spawn } from "child_process";
import stripAnsi from "strip-ansi";
import { getShellConfig, killProcessTree, sanitizeBinaryOutput } from "../shell.js";
import { DEFAULT_MAX_BYTES, truncateTail } from "../tools/truncate.js";
import { getShellConfig, killProcessTree, sanitizeBinaryOutput } from "../utils/shell.js";
import { DEFAULT_MAX_BYTES, truncateTail } from "./tools/truncate.js";
// ============================================================================
// Types

View file

@ -0,0 +1,293 @@
/**
* Context compaction for long sessions.
*
* Pure functions for compaction logic. The session manager handles I/O,
* and after compaction the session is reloaded.
*/
import type { AppMessage } from "@mariozechner/pi-agent-core";
import type { AssistantMessage, Model, Usage } from "@mariozechner/pi-ai";
import { complete } from "@mariozechner/pi-ai";
import { messageTransformer } from "./messages.js";
import type { CompactionEntry, SessionEntry } from "./session-manager.js";
// ============================================================================
// Types
// ============================================================================
export interface CompactionSettings {
enabled: boolean;
reserveTokens: number;
keepRecentTokens: number;
}
export const DEFAULT_COMPACTION_SETTINGS: CompactionSettings = {
enabled: true,
reserveTokens: 16384,
keepRecentTokens: 20000,
};
// ============================================================================
// Token calculation
// ============================================================================
/**
* Calculate total context tokens from usage.
* Uses the native totalTokens field when available, falls back to computing from components.
*/
export function calculateContextTokens(usage: Usage): number {
return usage.totalTokens || usage.input + usage.output + usage.cacheRead + usage.cacheWrite;
}
/**
* Get usage from an assistant message if available.
*/
function getAssistantUsage(msg: AppMessage): Usage | null {
if (msg.role === "assistant" && "usage" in msg) {
const assistantMsg = msg as AssistantMessage;
if (assistantMsg.stopReason !== "aborted" && assistantMsg.usage) {
return assistantMsg.usage;
}
}
return null;
}
/**
* Find the last non-aborted assistant message usage from session entries.
*/
export function getLastAssistantUsage(entries: SessionEntry[]): Usage | null {
for (let i = entries.length - 1; i >= 0; i--) {
const entry = entries[i];
if (entry.type === "message") {
const usage = getAssistantUsage(entry.message);
if (usage) return usage;
}
}
return null;
}
/**
* Check if compaction should trigger based on context usage.
*/
export function shouldCompact(contextTokens: number, contextWindow: number, settings: CompactionSettings): boolean {
if (!settings.enabled) return false;
return contextTokens > contextWindow - settings.reserveTokens;
}
// ============================================================================
// Cut point detection
// ============================================================================
/**
* Find indices of message entries that are user messages (turn boundaries).
*/
function findTurnBoundaries(entries: SessionEntry[], startIndex: number, endIndex: number): number[] {
const boundaries: number[] = [];
for (let i = startIndex; i < endIndex; i++) {
const entry = entries[i];
if (entry.type === "message" && entry.message.role === "user") {
boundaries.push(i);
}
}
return boundaries;
}
/**
* Find the cut point in session entries that keeps approximately `keepRecentTokens`.
* Returns the entry index of the first message to keep (a user message for turn integrity).
*
* Only considers entries between `startIndex` and `endIndex` (exclusive).
*/
export function findCutPoint(
entries: SessionEntry[],
startIndex: number,
endIndex: number,
keepRecentTokens: number,
): number {
const boundaries = findTurnBoundaries(entries, startIndex, endIndex);
if (boundaries.length === 0) {
return startIndex; // No user messages, keep everything in range
}
// Collect assistant usages walking backwards from endIndex
const assistantUsages: Array<{ index: number; tokens: number }> = [];
for (let i = endIndex - 1; i >= startIndex; i--) {
const entry = entries[i];
if (entry.type === "message") {
const usage = getAssistantUsage(entry.message);
if (usage) {
assistantUsages.push({
index: i,
tokens: calculateContextTokens(usage),
});
}
}
}
if (assistantUsages.length === 0) {
// No usage info, keep last turn only
return boundaries[boundaries.length - 1];
}
// Walk through and find where cumulative token difference exceeds keepRecentTokens
const newestTokens = assistantUsages[0].tokens;
let cutIndex = startIndex; // Default: keep everything in range
for (let i = 1; i < assistantUsages.length; i++) {
const tokenDiff = newestTokens - assistantUsages[i].tokens;
if (tokenDiff >= keepRecentTokens) {
// Find the turn boundary at or before the assistant we want to keep
const lastKeptAssistantIndex = assistantUsages[i - 1].index;
for (let b = boundaries.length - 1; b >= 0; b--) {
if (boundaries[b] <= lastKeptAssistantIndex) {
cutIndex = boundaries[b];
break;
}
}
break;
}
}
return cutIndex;
}
// ============================================================================
// Summarization
// ============================================================================
const SUMMARIZATION_PROMPT = `You are performing a CONTEXT CHECKPOINT COMPACTION. Create a handoff summary for another LLM that will resume the task.
Include:
- Current progress and key decisions made
- Important context, constraints, or user preferences
- Absolute file paths of any relevant files that were read or modified
- What remains to be done (clear next steps)
- Any critical data, examples, or references needed to continue
Be concise, structured, and focused on helping the next LLM seamlessly continue the work.`;
/**
* Generate a summary of the conversation using the LLM.
*/
export async function generateSummary(
currentMessages: AppMessage[],
model: Model<any>,
reserveTokens: number,
apiKey: string,
signal?: AbortSignal,
customInstructions?: string,
): Promise<string> {
const maxTokens = Math.floor(0.8 * reserveTokens);
const prompt = customInstructions
? `${SUMMARIZATION_PROMPT}\n\nAdditional focus: ${customInstructions}`
: SUMMARIZATION_PROMPT;
// Transform custom messages (like bashExecution) to LLM-compatible messages
const transformedMessages = messageTransformer(currentMessages);
const summarizationMessages = [
...transformedMessages,
{
role: "user" as const,
content: [{ type: "text" as const, text: prompt }],
timestamp: Date.now(),
},
];
const response = await complete(model, { messages: summarizationMessages }, { maxTokens, signal, apiKey });
const textContent = response.content
.filter((c): c is { type: "text"; text: string } => c.type === "text")
.map((c) => c.text)
.join("\n");
return textContent;
}
// ============================================================================
// Main compaction function
// ============================================================================
/**
* Calculate compaction and generate summary.
* Returns the CompactionEntry to append to the session file.
*
* @param entries - All session entries
* @param model - Model to use for summarization
* @param settings - Compaction settings
* @param apiKey - API key for LLM
* @param signal - Optional abort signal
* @param customInstructions - Optional custom focus for the summary
*/
export async function compact(
entries: SessionEntry[],
model: Model<any>,
settings: CompactionSettings,
apiKey: string,
signal?: AbortSignal,
customInstructions?: string,
): Promise<CompactionEntry> {
// Don't compact if the last entry is already a compaction
if (entries.length > 0 && entries[entries.length - 1].type === "compaction") {
throw new Error("Already compacted");
}
// Find previous compaction boundary
let prevCompactionIndex = -1;
for (let i = entries.length - 1; i >= 0; i--) {
if (entries[i].type === "compaction") {
prevCompactionIndex = i;
break;
}
}
const boundaryStart = prevCompactionIndex + 1;
const boundaryEnd = entries.length;
// Get token count before compaction
const lastUsage = getLastAssistantUsage(entries);
const tokensBefore = lastUsage ? calculateContextTokens(lastUsage) : 0;
// Find cut point (entry index) within the valid range
const firstKeptEntryIndex = findCutPoint(entries, boundaryStart, boundaryEnd, settings.keepRecentTokens);
// Extract messages to summarize (before the cut point)
const messagesToSummarize: AppMessage[] = [];
for (let i = boundaryStart; i < firstKeptEntryIndex; i++) {
const entry = entries[i];
if (entry.type === "message") {
messagesToSummarize.push(entry.message);
}
}
// Also include the previous summary if there was a compaction
if (prevCompactionIndex >= 0) {
const prevCompaction = entries[prevCompactionIndex] as CompactionEntry;
// Prepend the previous summary as context
messagesToSummarize.unshift({
role: "user",
content: `Previous session summary:\n${prevCompaction.summary}`,
timestamp: Date.now(),
});
}
// Generate summary from messages before the cut point
const summary = await generateSummary(
messagesToSummarize,
model,
settings.reserveTokens,
apiKey,
signal,
customInstructions,
);
return {
type: "compaction",
timestamp: new Date().toISOString(),
summary,
firstKeptEntryIndex,
tokensBefore,
};
}

View file

@ -0,0 +1,842 @@
import type { AgentState } from "@mariozechner/pi-agent-core";
import type { AssistantMessage, Message, ToolResultMessage, UserMessage } from "@mariozechner/pi-ai";
import { existsSync, readFileSync, writeFileSync } from "fs";
import { homedir } from "os";
import { basename } from "path";
import { APP_NAME, VERSION } from "../utils/config.js";
import { type BashExecutionMessage, isBashExecutionMessage } from "./messages.js";
import type { SessionManager } from "./session-manager.js";
// ============================================================================
// Types
// ============================================================================
interface MessageEvent {
type: "message";
message: Message;
timestamp?: number;
}
interface ModelChangeEvent {
type: "model_change";
provider: string;
modelId: string;
timestamp?: number;
}
interface CompactionEvent {
type: "compaction";
timestamp: string;
summary: string;
tokensBefore: number;
}
type SessionEvent = MessageEvent | ModelChangeEvent | CompactionEvent;
interface ParsedSessionData {
sessionId: string;
timestamp: string;
systemPrompt?: string;
modelsUsed: Set<string>;
messages: Message[];
toolResultsMap: Map<string, ToolResultMessage>;
sessionEvents: SessionEvent[];
tokenStats: { input: number; output: number; cacheRead: number; cacheWrite: number };
costStats: { input: number; output: number; cacheRead: number; cacheWrite: number };
tools?: { name: string; description: string }[];
contextWindow?: number;
isStreamingFormat?: boolean;
}
// ============================================================================
// Color scheme (matching TUI)
// ============================================================================
const COLORS = {
userMessageBg: "rgb(52, 53, 65)",
toolPendingBg: "rgb(40, 40, 50)",
toolSuccessBg: "rgb(40, 50, 40)",
toolErrorBg: "rgb(60, 40, 40)",
userBashBg: "rgb(50, 48, 35)", // Faint yellow/brown for user-executed bash
userBashErrorBg: "rgb(60, 45, 35)", // Slightly more orange for errors
bodyBg: "rgb(24, 24, 30)",
containerBg: "rgb(30, 30, 36)",
text: "rgb(229, 229, 231)",
textDim: "rgb(161, 161, 170)",
cyan: "rgb(103, 232, 249)",
green: "rgb(34, 197, 94)",
red: "rgb(239, 68, 68)",
yellow: "rgb(234, 179, 8)",
};
// ============================================================================
// Utility functions
// ============================================================================
function escapeHtml(text: string): string {
return text
.replace(/&/g, "&amp;")
.replace(/</g, "&lt;")
.replace(/>/g, "&gt;")
.replace(/"/g, "&quot;")
.replace(/'/g, "&#039;");
}
function shortenPath(path: string): string {
const home = homedir();
return path.startsWith(home) ? "~" + path.slice(home.length) : path;
}
function replaceTabs(text: string): string {
return text.replace(/\t/g, " ");
}
function formatTimestamp(timestamp: number | string | undefined): string {
if (!timestamp) return "";
const date = new Date(typeof timestamp === "string" ? timestamp : timestamp);
return date.toLocaleTimeString(undefined, { hour: "2-digit", minute: "2-digit", second: "2-digit" });
}
function formatExpandableOutput(lines: string[], maxLines: number): string {
const displayLines = lines.slice(0, maxLines);
const remaining = lines.length - maxLines;
if (remaining > 0) {
let out = '<div class="tool-output expandable" onclick="this.classList.toggle(\'expanded\')">';
out += '<div class="output-preview">';
for (const line of displayLines) {
out += `<div>${escapeHtml(replaceTabs(line))}</div>`;
}
out += `<div class="expand-hint">... (${remaining} more lines) - click to expand</div>`;
out += "</div>";
out += '<div class="output-full">';
for (const line of lines) {
out += `<div>${escapeHtml(replaceTabs(line))}</div>`;
}
out += "</div></div>";
return out;
}
let out = '<div class="tool-output">';
for (const line of displayLines) {
out += `<div>${escapeHtml(replaceTabs(line))}</div>`;
}
out += "</div>";
return out;
}
// ============================================================================
// Parsing functions
// ============================================================================
function parseSessionManagerFormat(lines: string[]): ParsedSessionData {
const data: ParsedSessionData = {
sessionId: "unknown",
timestamp: new Date().toISOString(),
modelsUsed: new Set(),
messages: [],
toolResultsMap: new Map(),
sessionEvents: [],
tokenStats: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
costStats: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
};
for (const line of lines) {
let entry: { type: string; [key: string]: unknown };
try {
entry = JSON.parse(line) as { type: string; [key: string]: unknown };
} catch {
continue;
}
switch (entry.type) {
case "session":
data.sessionId = (entry.id as string) || "unknown";
data.timestamp = (entry.timestamp as string) || data.timestamp;
data.systemPrompt = entry.systemPrompt as string | undefined;
if (entry.modelId) {
const modelInfo = entry.provider ? `${entry.provider}/${entry.modelId}` : (entry.modelId as string);
data.modelsUsed.add(modelInfo);
}
break;
case "message": {
const message = entry.message as Message;
data.messages.push(message);
data.sessionEvents.push({
type: "message",
message,
timestamp: entry.timestamp as number | undefined,
});
if (message.role === "toolResult") {
const toolResult = message as ToolResultMessage;
data.toolResultsMap.set(toolResult.toolCallId, toolResult);
} else if (message.role === "assistant") {
const assistantMsg = message as AssistantMessage;
if (assistantMsg.usage) {
data.tokenStats.input += assistantMsg.usage.input || 0;
data.tokenStats.output += assistantMsg.usage.output || 0;
data.tokenStats.cacheRead += assistantMsg.usage.cacheRead || 0;
data.tokenStats.cacheWrite += assistantMsg.usage.cacheWrite || 0;
if (assistantMsg.usage.cost) {
data.costStats.input += assistantMsg.usage.cost.input || 0;
data.costStats.output += assistantMsg.usage.cost.output || 0;
data.costStats.cacheRead += assistantMsg.usage.cost.cacheRead || 0;
data.costStats.cacheWrite += assistantMsg.usage.cost.cacheWrite || 0;
}
}
}
break;
}
case "model_change":
data.sessionEvents.push({
type: "model_change",
provider: entry.provider as string,
modelId: entry.modelId as string,
timestamp: entry.timestamp as number | undefined,
});
if (entry.modelId) {
const modelInfo = entry.provider ? `${entry.provider}/${entry.modelId}` : (entry.modelId as string);
data.modelsUsed.add(modelInfo);
}
break;
case "compaction":
data.sessionEvents.push({
type: "compaction",
timestamp: entry.timestamp as string,
summary: entry.summary as string,
tokensBefore: entry.tokensBefore as number,
});
break;
}
}
return data;
}
function parseStreamingEventFormat(lines: string[]): ParsedSessionData {
const data: ParsedSessionData = {
sessionId: "unknown",
timestamp: new Date().toISOString(),
modelsUsed: new Set(),
messages: [],
toolResultsMap: new Map(),
sessionEvents: [],
tokenStats: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
costStats: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
isStreamingFormat: true,
};
let timestampSet = false;
for (const line of lines) {
let entry: { type: string; message?: Message };
try {
entry = JSON.parse(line) as { type: string; message?: Message };
} catch {
continue;
}
if (entry.type === "message_end" && entry.message) {
const msg = entry.message;
data.messages.push(msg);
data.sessionEvents.push({
type: "message",
message: msg,
timestamp: (msg as { timestamp?: number }).timestamp,
});
if (msg.role === "toolResult") {
const toolResult = msg as ToolResultMessage;
data.toolResultsMap.set(toolResult.toolCallId, toolResult);
} else if (msg.role === "assistant") {
const assistantMsg = msg as AssistantMessage;
if (assistantMsg.model) {
const modelInfo = assistantMsg.provider
? `${assistantMsg.provider}/${assistantMsg.model}`
: assistantMsg.model;
data.modelsUsed.add(modelInfo);
}
if (assistantMsg.usage) {
data.tokenStats.input += assistantMsg.usage.input || 0;
data.tokenStats.output += assistantMsg.usage.output || 0;
data.tokenStats.cacheRead += assistantMsg.usage.cacheRead || 0;
data.tokenStats.cacheWrite += assistantMsg.usage.cacheWrite || 0;
if (assistantMsg.usage.cost) {
data.costStats.input += assistantMsg.usage.cost.input || 0;
data.costStats.output += assistantMsg.usage.cost.output || 0;
data.costStats.cacheRead += assistantMsg.usage.cost.cacheRead || 0;
data.costStats.cacheWrite += assistantMsg.usage.cost.cacheWrite || 0;
}
}
}
if (!timestampSet && (msg as { timestamp?: number }).timestamp) {
data.timestamp = new Date((msg as { timestamp: number }).timestamp).toISOString();
timestampSet = true;
}
}
}
data.sessionId = `stream-${data.timestamp.replace(/[:.]/g, "-")}`;
return data;
}
function detectFormat(lines: string[]): "session-manager" | "streaming-events" | "unknown" {
for (const line of lines) {
try {
const entry = JSON.parse(line) as { type: string };
if (entry.type === "session") return "session-manager";
if (entry.type === "agent_start" || entry.type === "message_start" || entry.type === "turn_start") {
return "streaming-events";
}
} catch {}
}
return "unknown";
}
function parseSessionFile(content: string): ParsedSessionData {
const lines = content
.trim()
.split("\n")
.filter((l) => l.trim());
if (lines.length === 0) {
throw new Error("Empty session file");
}
const format = detectFormat(lines);
if (format === "unknown") {
throw new Error("Unknown session file format");
}
return format === "session-manager" ? parseSessionManagerFormat(lines) : parseStreamingEventFormat(lines);
}
// ============================================================================
// HTML formatting functions
// ============================================================================
function formatToolExecution(
toolName: string,
args: Record<string, unknown>,
result?: ToolResultMessage,
): { html: string; bgColor: string } {
let html = "";
const isError = result?.isError || false;
const bgColor = result ? (isError ? COLORS.toolErrorBg : COLORS.toolSuccessBg) : COLORS.toolPendingBg;
const getTextOutput = (): string => {
if (!result) return "";
const textBlocks = result.content.filter((c) => c.type === "text");
return textBlocks.map((c) => (c as { type: "text"; text: string }).text).join("\n");
};
switch (toolName) {
case "bash": {
const command = (args?.command as string) || "";
html = `<div class="tool-command">$ ${escapeHtml(command || "...")}</div>`;
if (result) {
const output = getTextOutput().trim();
if (output) {
html += formatExpandableOutput(output.split("\n"), 5);
}
}
break;
}
case "read": {
const path = shortenPath((args?.file_path as string) || (args?.path as string) || "");
html = `<div class="tool-header"><span class="tool-name">read</span> <span class="tool-path">${escapeHtml(path || "...")}</span></div>`;
if (result) {
const output = getTextOutput();
if (output) {
html += formatExpandableOutput(output.split("\n"), 10);
}
}
break;
}
case "write": {
const path = shortenPath((args?.file_path as string) || (args?.path as string) || "");
const fileContent = (args?.content as string) || "";
const lines = fileContent ? fileContent.split("\n") : [];
html = `<div class="tool-header"><span class="tool-name">write</span> <span class="tool-path">${escapeHtml(path || "...")}</span>`;
if (lines.length > 10) {
html += ` <span class="line-count">(${lines.length} lines)</span>`;
}
html += "</div>";
if (fileContent) {
html += formatExpandableOutput(lines, 10);
}
if (result) {
const output = getTextOutput().trim();
if (output) {
html += `<div class="tool-output"><div>${escapeHtml(output)}</div></div>`;
}
}
break;
}
case "edit": {
const path = shortenPath((args?.file_path as string) || (args?.path as string) || "");
html = `<div class="tool-header"><span class="tool-name">edit</span> <span class="tool-path">${escapeHtml(path || "...")}</span></div>`;
if (result?.details?.diff) {
const diffLines = result.details.diff.split("\n");
html += '<div class="tool-diff">';
for (const line of diffLines) {
if (line.startsWith("+")) {
html += `<div class="diff-line-new">${escapeHtml(line)}</div>`;
} else if (line.startsWith("-")) {
html += `<div class="diff-line-old">${escapeHtml(line)}</div>`;
} else {
html += `<div class="diff-line-context">${escapeHtml(line)}</div>`;
}
}
html += "</div>";
}
if (result) {
const output = getTextOutput().trim();
if (output) {
html += `<div class="tool-output"><div>${escapeHtml(output)}</div></div>`;
}
}
break;
}
default: {
html = `<div class="tool-header"><span class="tool-name">${escapeHtml(toolName)}</span></div>`;
html += `<div class="tool-output"><pre>${escapeHtml(JSON.stringify(args, null, 2))}</pre></div>`;
if (result) {
const output = getTextOutput();
if (output) {
html += `<div class="tool-output"><div>${escapeHtml(output)}</div></div>`;
}
}
}
}
return { html, bgColor };
}
function formatMessage(message: Message, toolResultsMap: Map<string, ToolResultMessage>): string {
let html = "";
const timestamp = (message as { timestamp?: number }).timestamp;
const timestampHtml = timestamp ? `<div class="message-timestamp">${formatTimestamp(timestamp)}</div>` : "";
// Handle bash execution messages (user-executed via ! command)
if (isBashExecutionMessage(message)) {
const bashMsg = message as unknown as BashExecutionMessage;
const isError = bashMsg.cancelled || (bashMsg.exitCode !== 0 && bashMsg.exitCode !== null);
const bgColor = isError ? COLORS.userBashErrorBg : COLORS.userBashBg;
html += `<div class="tool-execution" style="background-color: ${bgColor}">`;
html += timestampHtml;
html += `<div class="tool-command">$ ${escapeHtml(bashMsg.command)}</div>`;
if (bashMsg.output) {
const lines = bashMsg.output.split("\n");
html += formatExpandableOutput(lines, 10);
}
if (bashMsg.cancelled) {
html += `<div class="bash-status" style="color: ${COLORS.yellow}">(cancelled)</div>`;
} else if (bashMsg.exitCode !== 0 && bashMsg.exitCode !== null) {
html += `<div class="bash-status" style="color: ${COLORS.red}">(exit ${bashMsg.exitCode})</div>`;
}
if (bashMsg.truncated && bashMsg.fullOutputPath) {
html += `<div class="bash-truncation" style="color: ${COLORS.yellow}">Output truncated. Full output: ${escapeHtml(bashMsg.fullOutputPath)}</div>`;
}
html += `</div>`;
return html;
}
if (message.role === "user") {
const userMsg = message as UserMessage;
let textContent = "";
if (typeof userMsg.content === "string") {
textContent = userMsg.content;
} else {
const textBlocks = userMsg.content.filter((c) => c.type === "text");
textContent = textBlocks.map((c) => (c as { type: "text"; text: string }).text).join("");
}
if (textContent.trim()) {
html += `<div class="user-message">${timestampHtml}${escapeHtml(textContent).replace(/\n/g, "<br>")}</div>`;
}
} else if (message.role === "assistant") {
const assistantMsg = message as AssistantMessage;
html += timestampHtml ? `<div class="assistant-message">${timestampHtml}` : "";
for (const content of assistantMsg.content) {
if (content.type === "text" && content.text.trim()) {
html += `<div class="assistant-text">${escapeHtml(content.text.trim()).replace(/\n/g, "<br>")}</div>`;
} else if (content.type === "thinking" && content.thinking.trim()) {
html += `<div class="thinking-text">${escapeHtml(content.thinking.trim()).replace(/\n/g, "<br>")}</div>`;
}
}
for (const content of assistantMsg.content) {
if (content.type === "toolCall") {
const toolResult = toolResultsMap.get(content.id);
const { html: toolHtml, bgColor } = formatToolExecution(
content.name,
content.arguments as Record<string, unknown>,
toolResult,
);
html += `<div class="tool-execution" style="background-color: ${bgColor}">${toolHtml}</div>`;
}
}
const hasToolCalls = assistantMsg.content.some((c) => c.type === "toolCall");
if (!hasToolCalls) {
if (assistantMsg.stopReason === "aborted") {
html += '<div class="error-text">Aborted</div>';
} else if (assistantMsg.stopReason === "error") {
html += `<div class="error-text">Error: ${escapeHtml(assistantMsg.errorMessage || "Unknown error")}</div>`;
}
}
if (timestampHtml) {
html += "</div>";
}
}
return html;
}
function formatModelChange(event: ModelChangeEvent): string {
const timestamp = formatTimestamp(event.timestamp);
const timestampHtml = timestamp ? `<div class="message-timestamp">${timestamp}</div>` : "";
const modelInfo = `${event.provider}/${event.modelId}`;
return `<div class="model-change">${timestampHtml}<div class="model-change-text">Switched to model: <span class="model-name">${escapeHtml(modelInfo)}</span></div></div>`;
}
function formatCompaction(event: CompactionEvent): string {
const timestamp = formatTimestamp(event.timestamp);
const timestampHtml = timestamp ? `<div class="message-timestamp">${timestamp}</div>` : "";
const summaryHtml = escapeHtml(event.summary).replace(/\n/g, "<br>");
return `<div class="compaction-container">
<div class="compaction-header" onclick="this.parentElement.classList.toggle('expanded')">
${timestampHtml}
<div class="compaction-header-row">
<span class="compaction-toggle"></span>
<span class="compaction-title">Context compacted from ${event.tokensBefore.toLocaleString()} tokens</span>
<span class="compaction-hint">(click to expand summary)</span>
</div>
</div>
<div class="compaction-content">
<div class="compaction-summary">
<div class="compaction-summary-header">Summary sent to model</div>
<div class="compaction-summary-content">${summaryHtml}</div>
</div>
</div>
</div>`;
}
// ============================================================================
// HTML generation
// ============================================================================
function generateHtml(data: ParsedSessionData, filename: string): string {
const userMessages = data.messages.filter((m) => m.role === "user").length;
const assistantMessages = data.messages.filter((m) => m.role === "assistant").length;
let toolCallsCount = 0;
for (const message of data.messages) {
if (message.role === "assistant") {
toolCallsCount += (message as AssistantMessage).content.filter((c) => c.type === "toolCall").length;
}
}
const lastAssistantMessage = data.messages
.slice()
.reverse()
.find((m) => m.role === "assistant" && (m as AssistantMessage).stopReason !== "aborted") as
| AssistantMessage
| undefined;
const contextTokens = lastAssistantMessage
? lastAssistantMessage.usage.input +
lastAssistantMessage.usage.output +
lastAssistantMessage.usage.cacheRead +
lastAssistantMessage.usage.cacheWrite
: 0;
const lastModel = lastAssistantMessage?.model || "unknown";
const lastProvider = lastAssistantMessage?.provider || "";
const lastModelInfo = lastProvider ? `${lastProvider}/${lastModel}` : lastModel;
const contextWindow = data.contextWindow || 0;
const contextPercent = contextWindow > 0 ? ((contextTokens / contextWindow) * 100).toFixed(1) : null;
let messagesHtml = "";
for (const event of data.sessionEvents) {
switch (event.type) {
case "message":
if (event.message.role !== "toolResult") {
messagesHtml += formatMessage(event.message, data.toolResultsMap);
}
break;
case "model_change":
messagesHtml += formatModelChange(event);
break;
case "compaction":
messagesHtml += formatCompaction(event);
break;
}
}
const systemPromptHtml = data.systemPrompt
? `<div class="system-prompt">
<div class="system-prompt-header">System Prompt</div>
<div class="system-prompt-content">${escapeHtml(data.systemPrompt)}</div>
</div>`
: "";
const toolsHtml = data.tools
? `<div class="tools-list">
<div class="tools-header">Available Tools</div>
<div class="tools-content">
${data.tools.map((tool) => `<div class="tool-item"><span class="tool-item-name">${escapeHtml(tool.name)}</span> - ${escapeHtml(tool.description)}</div>`).join("")}
</div>
</div>`
: "";
const streamingNotice = data.isStreamingFormat
? `<div class="streaming-notice">
<em>Note: This session was reconstructed from raw agent event logs, which do not contain system prompt or tool definitions.</em>
</div>`
: "";
const contextUsageText = contextPercent
? `${contextTokens.toLocaleString()} / ${contextWindow.toLocaleString()} tokens (${contextPercent}%) - ${escapeHtml(lastModelInfo)}`
: `${contextTokens.toLocaleString()} tokens (last turn) - ${escapeHtml(lastModelInfo)}`;
return `<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Session Export - ${escapeHtml(filename)}</title>
<style>
* { margin: 0; padding: 0; box-sizing: border-box; }
body {
font-family: ui-monospace, 'Cascadia Code', 'Source Code Pro', Menlo, Consolas, 'DejaVu Sans Mono', monospace;
font-size: 12px;
line-height: 1.6;
color: ${COLORS.text};
background: ${COLORS.bodyBg};
padding: 24px;
}
.container { max-width: 700px; margin: 0 auto; }
.header {
margin-bottom: 24px;
padding: 16px;
background: ${COLORS.containerBg};
border-radius: 4px;
}
.header h1 {
font-size: 14px;
font-weight: bold;
margin-bottom: 12px;
color: ${COLORS.cyan};
}
.header-info { display: flex; flex-direction: column; gap: 3px; font-size: 11px; }
.info-item { color: ${COLORS.textDim}; display: flex; align-items: baseline; }
.info-label { font-weight: 600; margin-right: 8px; min-width: 100px; }
.info-value { color: ${COLORS.text}; flex: 1; }
.info-value.cost { font-family: 'SF Mono', monospace; }
.messages { display: flex; flex-direction: column; gap: 16px; }
.message-timestamp { font-size: 10px; color: ${COLORS.textDim}; margin-bottom: 4px; opacity: 0.8; }
.user-message {
background: ${COLORS.userMessageBg};
padding: 12px 16px;
border-radius: 4px;
white-space: pre-wrap;
word-wrap: break-word;
overflow-wrap: break-word;
word-break: break-word;
}
.assistant-message { padding: 0; }
.assistant-text, .thinking-text {
padding: 12px 16px;
white-space: pre-wrap;
word-wrap: break-word;
overflow-wrap: break-word;
word-break: break-word;
}
.thinking-text { color: ${COLORS.textDim}; font-style: italic; }
.model-change { padding: 8px 16px; background: rgb(40, 40, 50); border-radius: 4px; }
.model-change-text { color: ${COLORS.textDim}; font-size: 11px; }
.model-name { color: ${COLORS.cyan}; font-weight: bold; }
.compaction-container { background: rgb(60, 55, 35); border-radius: 4px; overflow: hidden; }
.compaction-header { padding: 12px 16px; cursor: pointer; }
.compaction-header:hover { background: rgba(255, 255, 255, 0.05); }
.compaction-header-row { display: flex; align-items: center; gap: 8px; }
.compaction-toggle { color: ${COLORS.cyan}; font-size: 10px; transition: transform 0.2s; }
.compaction-container.expanded .compaction-toggle { transform: rotate(90deg); }
.compaction-title { color: ${COLORS.text}; font-weight: bold; }
.compaction-hint { color: ${COLORS.textDim}; font-size: 11px; }
.compaction-content { display: none; padding: 0 16px 16px 16px; }
.compaction-container.expanded .compaction-content { display: block; }
.compaction-summary { background: rgba(0, 0, 0, 0.2); border-radius: 4px; padding: 12px; }
.compaction-summary-header { font-weight: bold; color: ${COLORS.cyan}; margin-bottom: 8px; font-size: 11px; }
.compaction-summary-content { color: ${COLORS.text}; white-space: pre-wrap; word-wrap: break-word; }
.tool-execution { padding: 12px 16px; border-radius: 4px; margin-top: 8px; }
.tool-header, .tool-name { font-weight: bold; }
.tool-path { color: ${COLORS.cyan}; word-break: break-all; }
.line-count { color: ${COLORS.textDim}; }
.tool-command { font-weight: bold; white-space: pre-wrap; word-wrap: break-word; overflow-wrap: break-word; word-break: break-word; }
.tool-output {
margin-top: 12px;
color: ${COLORS.textDim};
white-space: pre-wrap;
word-wrap: break-word;
overflow-wrap: break-word;
word-break: break-word;
font-family: inherit;
overflow-x: auto;
}
.tool-output > div { line-height: 1.4; }
.tool-output pre { margin: 0; font-family: inherit; color: inherit; white-space: pre-wrap; word-wrap: break-word; overflow-wrap: break-word; }
.tool-output.expandable { cursor: pointer; }
.tool-output.expandable:hover { opacity: 0.9; }
.tool-output.expandable .output-full { display: none; }
.tool-output.expandable.expanded .output-preview { display: none; }
.tool-output.expandable.expanded .output-full { display: block; }
.expand-hint { color: ${COLORS.cyan}; font-style: italic; margin-top: 4px; }
.system-prompt, .tools-list { background: rgb(60, 55, 40); padding: 12px 16px; border-radius: 4px; margin-bottom: 16px; }
.system-prompt-header, .tools-header { font-weight: bold; color: ${COLORS.yellow}; margin-bottom: 8px; }
.system-prompt-content, .tools-content { color: ${COLORS.textDim}; white-space: pre-wrap; word-wrap: break-word; overflow-wrap: break-word; word-break: break-word; font-size: 11px; }
.tool-item { margin: 4px 0; }
.tool-item-name { font-weight: bold; color: ${COLORS.text}; }
.tool-diff { margin-top: 12px; font-size: 11px; font-family: inherit; overflow-x: auto; max-width: 100%; }
.diff-line-old { color: ${COLORS.red}; white-space: pre-wrap; word-wrap: break-word; overflow-wrap: break-word; }
.diff-line-new { color: ${COLORS.green}; white-space: pre-wrap; word-wrap: break-word; overflow-wrap: break-word; }
.diff-line-context { color: ${COLORS.textDim}; white-space: pre-wrap; word-wrap: break-word; overflow-wrap: break-word; }
.error-text { color: ${COLORS.red}; padding: 12px 16px; }
.footer { margin-top: 48px; padding: 20px; text-align: center; color: ${COLORS.textDim}; font-size: 10px; }
.streaming-notice { background: rgb(50, 45, 35); padding: 12px 16px; border-radius: 4px; margin-bottom: 16px; color: ${COLORS.textDim}; font-size: 11px; }
@media print { body { background: white; color: black; } .tool-execution { border: 1px solid #ddd; } }
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>${APP_NAME} v${VERSION}</h1>
<div class="header-info">
<div class="info-item"><span class="info-label">Session:</span><span class="info-value">${escapeHtml(data.sessionId)}</span></div>
<div class="info-item"><span class="info-label">Date:</span><span class="info-value">${new Date(data.timestamp).toLocaleString()}</span></div>
<div class="info-item"><span class="info-label">Models:</span><span class="info-value">${
Array.from(data.modelsUsed)
.map((m) => escapeHtml(m))
.join(", ") || "unknown"
}</span></div>
</div>
</div>
<div class="header">
<h1>Messages</h1>
<div class="header-info">
<div class="info-item"><span class="info-label">User:</span><span class="info-value">${userMessages}</span></div>
<div class="info-item"><span class="info-label">Assistant:</span><span class="info-value">${assistantMessages}</span></div>
<div class="info-item"><span class="info-label">Tool Calls:</span><span class="info-value">${toolCallsCount}</span></div>
</div>
</div>
<div class="header">
<h1>Tokens & Cost</h1>
<div class="header-info">
<div class="info-item"><span class="info-label">Input:</span><span class="info-value">${data.tokenStats.input.toLocaleString()} tokens</span></div>
<div class="info-item"><span class="info-label">Output:</span><span class="info-value">${data.tokenStats.output.toLocaleString()} tokens</span></div>
<div class="info-item"><span class="info-label">Cache Read:</span><span class="info-value">${data.tokenStats.cacheRead.toLocaleString()} tokens</span></div>
<div class="info-item"><span class="info-label">Cache Write:</span><span class="info-value">${data.tokenStats.cacheWrite.toLocaleString()} tokens</span></div>
<div class="info-item"><span class="info-label">Total:</span><span class="info-value">${(data.tokenStats.input + data.tokenStats.output + data.tokenStats.cacheRead + data.tokenStats.cacheWrite).toLocaleString()} tokens</span></div>
<div class="info-item"><span class="info-label">Input Cost:</span><span class="info-value cost">$${data.costStats.input.toFixed(4)}</span></div>
<div class="info-item"><span class="info-label">Output Cost:</span><span class="info-value cost">$${data.costStats.output.toFixed(4)}</span></div>
<div class="info-item"><span class="info-label">Cache Read Cost:</span><span class="info-value cost">$${data.costStats.cacheRead.toFixed(4)}</span></div>
<div class="info-item"><span class="info-label">Cache Write Cost:</span><span class="info-value cost">$${data.costStats.cacheWrite.toFixed(4)}</span></div>
<div class="info-item"><span class="info-label">Total Cost:</span><span class="info-value cost"><strong>$${(data.costStats.input + data.costStats.output + data.costStats.cacheRead + data.costStats.cacheWrite).toFixed(4)}</strong></span></div>
<div class="info-item"><span class="info-label">Context Usage:</span><span class="info-value">${contextUsageText}</span></div>
</div>
</div>
${systemPromptHtml}
${toolsHtml}
${streamingNotice}
<div class="messages">
${messagesHtml}
</div>
<div class="footer">
Generated by ${APP_NAME} coding-agent on ${new Date().toLocaleString()}
</div>
</div>
</body>
</html>`;
}
// ============================================================================
// Public API
// ============================================================================
/**
* Export session to HTML using SessionManager and AgentState.
* Used by TUI's /export command.
*/
export function exportSessionToHtml(sessionManager: SessionManager, state: AgentState, outputPath?: string): string {
const sessionFile = sessionManager.getSessionFile();
const content = readFileSync(sessionFile, "utf8");
const data = parseSessionFile(content);
// Enrich with data from AgentState (tools, context window)
data.tools = state.tools.map((t) => ({ name: t.name, description: t.description }));
data.contextWindow = state.model?.contextWindow;
if (!data.systemPrompt) {
data.systemPrompt = state.systemPrompt;
}
if (!outputPath) {
const sessionBasename = basename(sessionFile, ".jsonl");
outputPath = `${APP_NAME}-session-${sessionBasename}.html`;
}
const html = generateHtml(data, basename(sessionFile));
writeFileSync(outputPath, html, "utf8");
return outputPath;
}
/**
* Export session file to HTML (standalone, without AgentState).
* Auto-detects format: session manager format or streaming event format.
* Used by CLI for exporting arbitrary session files.
*/
export function exportFromFile(inputPath: string, outputPath?: string): string {
if (!existsSync(inputPath)) {
throw new Error(`File not found: ${inputPath}`);
}
const content = readFileSync(inputPath, "utf8");
const data = parseSessionFile(content);
if (!outputPath) {
const inputBasename = basename(inputPath, ".jsonl");
outputPath = `${APP_NAME}-session-${inputBasename}.html`;
}
const html = generateHtml(data, basename(inputPath));
writeFileSync(outputPath, html, "utf8");
return outputPath;
}

View file

@ -0,0 +1,102 @@
/**
* Custom message types and transformers for the coding agent.
*
* Extends the base AppMessage type with coding-agent specific message types,
* and provides a transformer to convert them to LLM-compatible messages.
*/
import type { AppMessage } from "@mariozechner/pi-agent-core";
import type { Message } from "@mariozechner/pi-ai";
// ============================================================================
// Custom Message Types
// ============================================================================
/**
* Message type for bash executions via the ! command.
*/
export interface BashExecutionMessage {
role: "bashExecution";
command: string;
output: string;
exitCode: number | null;
cancelled: boolean;
truncated: boolean;
fullOutputPath?: string;
timestamp: number;
}
// Extend CustomMessages via declaration merging
declare module "@mariozechner/pi-agent-core" {
interface CustomMessages {
bashExecution: BashExecutionMessage;
}
}
// ============================================================================
// Type Guards
// ============================================================================
/**
* Type guard for BashExecutionMessage.
*/
export function isBashExecutionMessage(msg: AppMessage | Message): msg is BashExecutionMessage {
return (msg as BashExecutionMessage).role === "bashExecution";
}
// ============================================================================
// Message Formatting
// ============================================================================
/**
* Convert a BashExecutionMessage to user message text for LLM context.
*/
export function bashExecutionToText(msg: BashExecutionMessage): string {
let text = `Ran \`${msg.command}\`\n`;
if (msg.output) {
text += "```\n" + msg.output + "\n```";
} else {
text += "(no output)";
}
if (msg.cancelled) {
text += "\n\n(command cancelled)";
} else if (msg.exitCode !== null && msg.exitCode !== 0) {
text += `\n\nCommand exited with code ${msg.exitCode}`;
}
if (msg.truncated && msg.fullOutputPath) {
text += `\n\n[Output truncated. Full output: ${msg.fullOutputPath}]`;
}
return text;
}
// ============================================================================
// Message Transformer
// ============================================================================
/**
* Transform AppMessages (including custom types) to LLM-compatible Messages.
*
* This is used by:
* - Agent's messageTransformer option (for prompt calls)
* - Compaction's generateSummary (for summarization)
*/
export function messageTransformer(messages: AppMessage[]): Message[] {
return messages
.map((m): Message | null => {
if (isBashExecutionMessage(m)) {
// Convert bash execution to user message
return {
role: "user",
content: [{ type: "text", text: bashExecutionToText(m) }],
timestamp: m.timestamp,
};
}
// Pass through standard LLM roles
if (m.role === "user" || m.role === "assistant" || m.role === "toolResult") {
return m as Message;
}
// Filter out unknown message types
return null;
})
.filter((m): m is Message => m !== null);
}

View file

@ -0,0 +1,366 @@
import { type Api, getApiKey, getModels, getProviders, type KnownProvider, type Model } from "@mariozechner/pi-ai";
import { type Static, Type } from "@sinclair/typebox";
import AjvModule from "ajv";
import { existsSync, readFileSync } from "fs";
import { getModelsPath } from "../utils/config.js";
import { getOAuthToken, type SupportedOAuthProvider } from "./oauth/index.js";
import { loadOAuthCredentials } from "./oauth/storage.js";
// Handle both default and named exports
const Ajv = (AjvModule as any).default || AjvModule;
// Schema for OpenAI compatibility settings
const OpenAICompatSchema = Type.Object({
supportsStore: Type.Optional(Type.Boolean()),
supportsDeveloperRole: Type.Optional(Type.Boolean()),
supportsReasoningEffort: Type.Optional(Type.Boolean()),
maxTokensField: Type.Optional(Type.Union([Type.Literal("max_completion_tokens"), Type.Literal("max_tokens")])),
});
// Schema for custom model definition
const ModelDefinitionSchema = Type.Object({
id: Type.String({ minLength: 1 }),
name: Type.String({ minLength: 1 }),
api: Type.Optional(
Type.Union([
Type.Literal("openai-completions"),
Type.Literal("openai-responses"),
Type.Literal("anthropic-messages"),
Type.Literal("google-generative-ai"),
]),
),
reasoning: Type.Boolean(),
input: Type.Array(Type.Union([Type.Literal("text"), Type.Literal("image")])),
cost: Type.Object({
input: Type.Number(),
output: Type.Number(),
cacheRead: Type.Number(),
cacheWrite: Type.Number(),
}),
contextWindow: Type.Number(),
maxTokens: Type.Number(),
headers: Type.Optional(Type.Record(Type.String(), Type.String())),
compat: Type.Optional(OpenAICompatSchema),
});
const ProviderConfigSchema = Type.Object({
baseUrl: Type.String({ minLength: 1 }),
apiKey: Type.String({ minLength: 1 }),
api: Type.Optional(
Type.Union([
Type.Literal("openai-completions"),
Type.Literal("openai-responses"),
Type.Literal("anthropic-messages"),
Type.Literal("google-generative-ai"),
]),
),
headers: Type.Optional(Type.Record(Type.String(), Type.String())),
authHeader: Type.Optional(Type.Boolean()),
models: Type.Array(ModelDefinitionSchema),
});
const ModelsConfigSchema = Type.Object({
providers: Type.Record(Type.String(), ProviderConfigSchema),
});
type ModelsConfig = Static<typeof ModelsConfigSchema>;
type ProviderConfig = Static<typeof ProviderConfigSchema>;
type ModelDefinition = Static<typeof ModelDefinitionSchema>;
// Custom provider API key mappings (provider name -> apiKey config)
const customProviderApiKeys: Map<string, string> = new Map();
/**
* Resolve an API key config value to an actual key.
* First checks if it's an environment variable, then treats as literal.
*/
export function resolveApiKey(keyConfig: string): string | undefined {
// First check if it's an env var name
const envValue = process.env[keyConfig];
if (envValue) return envValue;
// Otherwise treat as literal API key
return keyConfig;
}
/**
* Load custom models from models.json in agent config dir
* Returns { models, error } - either models array or error message
*/
function loadCustomModels(): { models: Model<Api>[]; error: string | null } {
const configPath = getModelsPath();
if (!existsSync(configPath)) {
return { models: [], error: null };
}
try {
const content = readFileSync(configPath, "utf-8");
const config: ModelsConfig = JSON.parse(content);
// Validate schema
const ajv = new Ajv();
const validate = ajv.compile(ModelsConfigSchema);
if (!validate(config)) {
const errors =
validate.errors?.map((e: any) => ` - ${e.instancePath || "root"}: ${e.message}`).join("\n") ||
"Unknown schema error";
return {
models: [],
error: `Invalid models.json schema:\n${errors}\n\nFile: ${configPath}`,
};
}
// Additional validation
try {
validateConfig(config);
} catch (error) {
return {
models: [],
error: `Invalid models.json: ${error instanceof Error ? error.message : error}\n\nFile: ${configPath}`,
};
}
// Parse models
return { models: parseModels(config), error: null };
} catch (error) {
if (error instanceof SyntaxError) {
return {
models: [],
error: `Failed to parse models.json: ${error.message}\n\nFile: ${configPath}`,
};
}
return {
models: [],
error: `Failed to load models.json: ${error instanceof Error ? error.message : error}\n\nFile: ${configPath}`,
};
}
}
/**
* Validate config structure and requirements
*/
function validateConfig(config: ModelsConfig): void {
for (const [providerName, providerConfig] of Object.entries(config.providers)) {
const hasProviderApi = !!providerConfig.api;
for (const modelDef of providerConfig.models) {
const hasModelApi = !!modelDef.api;
if (!hasProviderApi && !hasModelApi) {
throw new Error(
`Provider ${providerName}, model ${modelDef.id}: no "api" specified. ` +
`Set at provider or model level.`,
);
}
// Validate required fields
if (!modelDef.id) throw new Error(`Provider ${providerName}: model missing "id"`);
if (!modelDef.name) throw new Error(`Provider ${providerName}: model missing "name"`);
if (modelDef.contextWindow <= 0)
throw new Error(`Provider ${providerName}, model ${modelDef.id}: invalid contextWindow`);
if (modelDef.maxTokens <= 0)
throw new Error(`Provider ${providerName}, model ${modelDef.id}: invalid maxTokens`);
}
}
}
/**
* Parse config into Model objects
*/
function parseModels(config: ModelsConfig): Model<Api>[] {
const models: Model<Api>[] = [];
// Clear and rebuild custom provider API key mappings
customProviderApiKeys.clear();
for (const [providerName, providerConfig] of Object.entries(config.providers)) {
// Store API key config for this provider
customProviderApiKeys.set(providerName, providerConfig.apiKey);
for (const modelDef of providerConfig.models) {
// Model-level api overrides provider-level api
const api = modelDef.api || providerConfig.api;
if (!api) {
// This should have been caught by validateConfig, but be safe
continue;
}
// Merge headers: provider headers are base, model headers override
let headers =
providerConfig.headers || modelDef.headers ? { ...providerConfig.headers, ...modelDef.headers } : undefined;
// If authHeader is true, add Authorization header with resolved API key
if (providerConfig.authHeader) {
const resolvedKey = resolveApiKey(providerConfig.apiKey);
if (resolvedKey) {
headers = { ...headers, Authorization: `Bearer ${resolvedKey}` };
}
}
models.push({
id: modelDef.id,
name: modelDef.name,
api: api as Api,
provider: providerName,
baseUrl: providerConfig.baseUrl,
reasoning: modelDef.reasoning,
input: modelDef.input as ("text" | "image")[],
cost: modelDef.cost,
contextWindow: modelDef.contextWindow,
maxTokens: modelDef.maxTokens,
headers,
compat: modelDef.compat,
} as Model<Api>);
}
}
return models;
}
/**
* Get all models (built-in + custom), freshly loaded
* Returns { models, error } - either models array or error message
*/
export function loadAndMergeModels(): { models: Model<Api>[]; error: string | null } {
const builtInModels: Model<Api>[] = [];
const providers = getProviders();
// Load all built-in models
for (const provider of providers) {
const providerModels = getModels(provider as KnownProvider);
builtInModels.push(...(providerModels as Model<Api>[]));
}
// Load custom models
const { models: customModels, error } = loadCustomModels();
if (error) {
return { models: [], error };
}
// Merge: custom models come after built-in
return { models: [...builtInModels, ...customModels], error: null };
}
/**
* Get API key for a model (checks custom providers first, then built-in)
* Now async to support OAuth token refresh
*/
export async function getApiKeyForModel(model: Model<Api>): Promise<string | undefined> {
// For custom providers, check their apiKey config
const customKeyConfig = customProviderApiKeys.get(model.provider);
if (customKeyConfig) {
return resolveApiKey(customKeyConfig);
}
// For Anthropic, check OAuth first
if (model.provider === "anthropic") {
// 1. Check OAuth storage (auto-refresh if needed)
const oauthToken = await getOAuthToken("anthropic");
if (oauthToken) {
return oauthToken;
}
// 2. Check ANTHROPIC_OAUTH_TOKEN env var (manual OAuth token)
const oauthEnv = process.env.ANTHROPIC_OAUTH_TOKEN;
if (oauthEnv) {
return oauthEnv;
}
// 3. Fall back to ANTHROPIC_API_KEY env var
}
// For built-in providers, use getApiKey from @mariozechner/pi-ai
return getApiKey(model.provider as KnownProvider);
}
/**
* Get only models that have valid API keys available
* Returns { models, error } - either models array or error message
*/
export async function getAvailableModels(): Promise<{ models: Model<Api>[]; error: string | null }> {
const { models: allModels, error } = loadAndMergeModels();
if (error) {
return { models: [], error };
}
const availableModels: Model<Api>[] = [];
for (const model of allModels) {
const apiKey = await getApiKeyForModel(model);
if (apiKey) {
availableModels.push(model);
}
}
return { models: availableModels, error: null };
}
/**
* Find a specific model by provider and ID
* Returns { model, error } - either model or error message
*/
export function findModel(provider: string, modelId: string): { model: Model<Api> | null; error: string | null } {
const { models: allModels, error } = loadAndMergeModels();
if (error) {
return { model: null, error };
}
const model = allModels.find((m) => m.provider === provider && m.id === modelId) || null;
return { model, error: null };
}
/**
* Mapping from model provider to OAuth provider ID.
* Only providers that support OAuth are listed here.
*/
const providerToOAuthProvider: Record<string, SupportedOAuthProvider> = {
anthropic: "anthropic",
// Add more mappings as OAuth support is added for other providers
};
// Cache for OAuth status per provider (avoids file reads on every render)
const oauthStatusCache: Map<string, boolean> = new Map();
/**
* Invalidate the OAuth status cache.
* Call this after login/logout operations.
*/
export function invalidateOAuthCache(): void {
oauthStatusCache.clear();
}
/**
* Check if a model is using OAuth credentials (subscription).
* This checks if OAuth credentials exist and would be used for the model,
* without actually fetching or refreshing the token.
* Results are cached until invalidateOAuthCache() is called.
*/
export function isModelUsingOAuth(model: Model<Api>): boolean {
const oauthProvider = providerToOAuthProvider[model.provider];
if (!oauthProvider) {
return false;
}
// Check cache first
if (oauthStatusCache.has(oauthProvider)) {
return oauthStatusCache.get(oauthProvider)!;
}
// Check if OAuth credentials exist for this provider
let usingOAuth = false;
const credentials = loadOAuthCredentials(oauthProvider);
if (credentials) {
usingOAuth = true;
}
// Also check for manual OAuth token env var (for Anthropic)
if (!usingOAuth && model.provider === "anthropic" && process.env.ANTHROPIC_OAUTH_TOKEN) {
usingOAuth = true;
}
oauthStatusCache.set(oauthProvider, usingOAuth);
return usingOAuth;
}

View file

@ -0,0 +1,128 @@
import { createHash, randomBytes } from "crypto";
import { type OAuthCredentials, saveOAuthCredentials } from "./storage.js";
const CLIENT_ID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e";
const AUTHORIZE_URL = "https://claude.ai/oauth/authorize";
const TOKEN_URL = "https://console.anthropic.com/v1/oauth/token";
const REDIRECT_URI = "https://console.anthropic.com/oauth/code/callback";
const SCOPES = "org:create_api_key user:profile user:inference";
/**
* Generate PKCE code verifier and challenge
*/
function generatePKCE(): { verifier: string; challenge: string } {
const verifier = randomBytes(32).toString("base64url");
const challenge = createHash("sha256").update(verifier).digest("base64url");
return { verifier, challenge };
}
/**
* Login with Anthropic OAuth (device code flow)
*/
export async function loginAnthropic(
onAuthUrl: (url: string) => void,
onPromptCode: () => Promise<string>,
): Promise<void> {
const { verifier, challenge } = generatePKCE();
// Build authorization URL
const authParams = new URLSearchParams({
code: "true",
client_id: CLIENT_ID,
response_type: "code",
redirect_uri: REDIRECT_URI,
scope: SCOPES,
code_challenge: challenge,
code_challenge_method: "S256",
state: verifier,
});
const authUrl = `${AUTHORIZE_URL}?${authParams.toString()}`;
// Notify caller with URL to open
onAuthUrl(authUrl);
// Wait for user to paste authorization code (format: code#state)
const authCode = await onPromptCode();
const splits = authCode.split("#");
const code = splits[0];
const state = splits[1];
// Exchange code for tokens
const tokenResponse = await fetch(TOKEN_URL, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
grant_type: "authorization_code",
client_id: CLIENT_ID,
code: code,
state: state,
redirect_uri: REDIRECT_URI,
code_verifier: verifier,
}),
});
if (!tokenResponse.ok) {
const error = await tokenResponse.text();
throw new Error(`Token exchange failed: ${error}`);
}
const tokenData = (await tokenResponse.json()) as {
access_token: string;
refresh_token: string;
expires_in: number;
};
// Calculate expiry time (current time + expires_in seconds - 5 min buffer)
const expiresAt = Date.now() + tokenData.expires_in * 1000 - 5 * 60 * 1000;
// Save credentials
const credentials: OAuthCredentials = {
type: "oauth",
refresh: tokenData.refresh_token,
access: tokenData.access_token,
expires: expiresAt,
};
saveOAuthCredentials("anthropic", credentials);
}
/**
* Refresh Anthropic OAuth token using refresh token
*/
export async function refreshAnthropicToken(refreshToken: string): Promise<OAuthCredentials> {
const tokenResponse = await fetch(TOKEN_URL, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
grant_type: "refresh_token",
client_id: CLIENT_ID,
refresh_token: refreshToken,
}),
});
if (!tokenResponse.ok) {
const error = await tokenResponse.text();
throw new Error(`Token refresh failed: ${error}`);
}
const tokenData = (await tokenResponse.json()) as {
access_token: string;
refresh_token: string;
expires_in: number;
};
// Calculate expiry time (current time + expires_in seconds - 5 min buffer)
const expiresAt = Date.now() + tokenData.expires_in * 1000 - 5 * 60 * 1000;
return {
type: "oauth",
refresh: tokenData.refresh_token,
access: tokenData.access_token,
expires: expiresAt,
};
}

View file

@ -0,0 +1,115 @@
import { loginAnthropic, refreshAnthropicToken } from "./anthropic.js";
import {
listOAuthProviders as listOAuthProvidersFromStorage,
loadOAuthCredentials,
type OAuthCredentials,
removeOAuthCredentials,
saveOAuthCredentials,
} from "./storage.js";
// Re-export for convenience
export { listOAuthProvidersFromStorage as listOAuthProviders };
export type SupportedOAuthProvider = "anthropic" | "github-copilot";
export interface OAuthProviderInfo {
id: SupportedOAuthProvider;
name: string;
available: boolean;
}
/**
* Get list of OAuth providers
*/
export function getOAuthProviders(): OAuthProviderInfo[] {
return [
{
id: "anthropic",
name: "Anthropic (Claude Pro/Max)",
available: true,
},
{
id: "github-copilot",
name: "GitHub Copilot (coming soon)",
available: false,
},
];
}
/**
* Login with OAuth provider
*/
export async function login(
provider: SupportedOAuthProvider,
onAuthUrl: (url: string) => void,
onPromptCode: () => Promise<string>,
): Promise<void> {
switch (provider) {
case "anthropic":
await loginAnthropic(onAuthUrl, onPromptCode);
break;
case "github-copilot":
throw new Error("GitHub Copilot OAuth is not yet implemented");
default:
throw new Error(`Unknown OAuth provider: ${provider}`);
}
}
/**
* Logout from OAuth provider
*/
export async function logout(provider: SupportedOAuthProvider): Promise<void> {
removeOAuthCredentials(provider);
}
/**
* Refresh OAuth token for provider
*/
export async function refreshToken(provider: SupportedOAuthProvider): Promise<string> {
const credentials = loadOAuthCredentials(provider);
if (!credentials) {
throw new Error(`No OAuth credentials found for ${provider}`);
}
let newCredentials: OAuthCredentials;
switch (provider) {
case "anthropic":
newCredentials = await refreshAnthropicToken(credentials.refresh);
break;
case "github-copilot":
throw new Error("GitHub Copilot OAuth is not yet implemented");
default:
throw new Error(`Unknown OAuth provider: ${provider}`);
}
// Save new credentials
saveOAuthCredentials(provider, newCredentials);
return newCredentials.access;
}
/**
* Get OAuth token for provider (auto-refreshes if expired)
*/
export async function getOAuthToken(provider: SupportedOAuthProvider): Promise<string | null> {
const credentials = loadOAuthCredentials(provider);
if (!credentials) {
return null;
}
// Check if token is expired (with 5 min buffer already applied)
if (Date.now() >= credentials.expires) {
// Token expired - refresh it
try {
return await refreshToken(provider);
} catch (error) {
console.error(`Failed to refresh OAuth token for ${provider}:`, error);
// Remove invalid credentials
removeOAuthCredentials(provider);
return null;
}
}
return credentials.access;
}

View file

@ -0,0 +1,86 @@
import { chmodSync, existsSync, mkdirSync, readFileSync, writeFileSync } from "fs";
import { getAgentDir, getOAuthPath } from "../../utils/config.js";
export interface OAuthCredentials {
type: "oauth";
refresh: string;
access: string;
expires: number;
}
interface OAuthStorageFormat {
[provider: string]: OAuthCredentials;
}
/**
* Ensure the config directory exists
*/
function ensureConfigDir(): void {
const configDir = getAgentDir();
if (!existsSync(configDir)) {
mkdirSync(configDir, { recursive: true, mode: 0o700 });
}
}
/**
* Load all OAuth credentials from oauth.json
*/
function loadStorage(): OAuthStorageFormat {
const filePath = getOAuthPath();
if (!existsSync(filePath)) {
return {};
}
try {
const content = readFileSync(filePath, "utf-8");
return JSON.parse(content);
} catch (error) {
console.error(`Warning: Failed to load OAuth credentials: ${error}`);
return {};
}
}
/**
* Save all OAuth credentials to oauth.json
*/
function saveStorage(storage: OAuthStorageFormat): void {
ensureConfigDir();
const filePath = getOAuthPath();
writeFileSync(filePath, JSON.stringify(storage, null, 2), "utf-8");
// Set permissions to owner read/write only
chmodSync(filePath, 0o600);
}
/**
* Load OAuth credentials for a specific provider
*/
export function loadOAuthCredentials(provider: string): OAuthCredentials | null {
const storage = loadStorage();
return storage[provider] || null;
}
/**
* Save OAuth credentials for a specific provider
*/
export function saveOAuthCredentials(provider: string, creds: OAuthCredentials): void {
const storage = loadStorage();
storage[provider] = creds;
saveStorage(storage);
}
/**
* Remove OAuth credentials for a specific provider
*/
export function removeOAuthCredentials(provider: string): void {
const storage = loadStorage();
delete storage[provider];
saveStorage(storage);
}
/**
* List all providers with OAuth credentials
*/
export function listOAuthProviders(): string[] {
const storage = loadStorage();
return Object.keys(storage);
}

View file

@ -0,0 +1,612 @@
import type { AgentState, AppMessage } from "@mariozechner/pi-agent-core";
import { randomBytes } from "crypto";
import { appendFileSync, existsSync, mkdirSync, readdirSync, readFileSync, statSync } from "fs";
import { join, resolve } from "path";
import { getAgentDir } from "../utils/config.js";
function uuidv4(): string {
const bytes = randomBytes(16);
bytes[6] = (bytes[6] & 0x0f) | 0x40;
bytes[8] = (bytes[8] & 0x3f) | 0x80;
const hex = bytes.toString("hex");
return `${hex.slice(0, 8)}-${hex.slice(8, 12)}-${hex.slice(12, 16)}-${hex.slice(16, 20)}-${hex.slice(20, 32)}`;
}
// ============================================================================
// Session entry types
// ============================================================================
export interface SessionHeader {
type: "session";
id: string;
timestamp: string;
cwd: string;
provider: string;
modelId: string;
thinkingLevel: string;
branchedFrom?: string;
}
export interface SessionMessageEntry {
type: "message";
timestamp: string;
message: AppMessage;
}
export interface ThinkingLevelChangeEntry {
type: "thinking_level_change";
timestamp: string;
thinkingLevel: string;
}
export interface ModelChangeEntry {
type: "model_change";
timestamp: string;
provider: string;
modelId: string;
}
export interface CompactionEntry {
type: "compaction";
timestamp: string;
summary: string;
firstKeptEntryIndex: number; // Index into session entries where we start keeping
tokensBefore: number;
}
/** Union of all session entry types */
export type SessionEntry =
| SessionHeader
| SessionMessageEntry
| ThinkingLevelChangeEntry
| ModelChangeEntry
| CompactionEntry;
// ============================================================================
// Session loading with compaction support
// ============================================================================
export interface LoadedSession {
messages: AppMessage[];
thinkingLevel: string;
model: { provider: string; modelId: string } | null;
}
export const SUMMARY_PREFIX = `The conversation history before this point was compacted into the following summary:
<summary>
`;
export const SUMMARY_SUFFIX = `
</summary>`;
/**
* Create a user message containing the summary with the standard prefix.
*/
export function createSummaryMessage(summary: string): AppMessage {
return {
role: "user",
content: SUMMARY_PREFIX + summary + SUMMARY_SUFFIX,
timestamp: Date.now(),
};
}
/**
* Parse session file content into entries.
*/
export function parseSessionEntries(content: string): SessionEntry[] {
const entries: SessionEntry[] = [];
const lines = content.trim().split("\n");
for (const line of lines) {
if (!line.trim()) continue;
try {
const entry = JSON.parse(line) as SessionEntry;
entries.push(entry);
} catch {
// Skip malformed lines
}
}
return entries;
}
/**
* Load session from entries, handling compaction events.
*
* Algorithm:
* 1. Find latest compaction event (if any)
* 2. Keep all entries from firstKeptEntryIndex onwards (extracting messages)
* 3. Prepend summary as user message
*/
/**
* Get the latest compaction entry from session entries, if any.
*/
export function getLatestCompactionEntry(entries: SessionEntry[]): CompactionEntry | null {
for (let i = entries.length - 1; i >= 0; i--) {
if (entries[i].type === "compaction") {
return entries[i] as CompactionEntry;
}
}
return null;
}
export function loadSessionFromEntries(entries: SessionEntry[]): LoadedSession {
// Find model and thinking level (always scan all entries)
let thinkingLevel = "off";
let model: { provider: string; modelId: string } | null = null;
for (const entry of entries) {
if (entry.type === "session") {
thinkingLevel = entry.thinkingLevel;
model = { provider: entry.provider, modelId: entry.modelId };
} else if (entry.type === "thinking_level_change") {
thinkingLevel = entry.thinkingLevel;
} else if (entry.type === "model_change") {
model = { provider: entry.provider, modelId: entry.modelId };
}
}
// Find latest compaction event
let latestCompactionIndex = -1;
for (let i = entries.length - 1; i >= 0; i--) {
if (entries[i].type === "compaction") {
latestCompactionIndex = i;
break;
}
}
// No compaction: return all messages
if (latestCompactionIndex === -1) {
const messages: AppMessage[] = [];
for (const entry of entries) {
if (entry.type === "message") {
messages.push(entry.message);
}
}
return { messages, thinkingLevel, model };
}
const compactionEvent = entries[latestCompactionIndex] as CompactionEntry;
// Extract messages from firstKeptEntryIndex to end (skipping compaction entries)
const keptMessages: AppMessage[] = [];
for (let i = compactionEvent.firstKeptEntryIndex; i < entries.length; i++) {
const entry = entries[i];
if (entry.type === "message") {
keptMessages.push(entry.message);
}
}
// Build final messages: summary + kept messages
const summaryMessage = createSummaryMessage(compactionEvent.summary);
const messages = [summaryMessage, ...keptMessages];
return { messages, thinkingLevel, model };
}
export class SessionManager {
private sessionId!: string;
private sessionFile!: string;
private sessionDir: string;
private enabled: boolean = true;
private sessionInitialized: boolean = false;
private pendingMessages: any[] = [];
constructor(continueSession: boolean = false, customSessionPath?: string) {
this.sessionDir = this.getSessionDirectory();
if (customSessionPath) {
// Use custom session file path
this.sessionFile = resolve(customSessionPath);
this.loadSessionId();
// Mark as initialized since we're loading an existing session
this.sessionInitialized = existsSync(this.sessionFile);
} else if (continueSession) {
const mostRecent = this.findMostRecentlyModifiedSession();
if (mostRecent) {
this.sessionFile = mostRecent;
this.loadSessionId();
// Mark as initialized since we're loading an existing session
this.sessionInitialized = true;
} else {
this.initNewSession();
}
} else {
this.initNewSession();
}
}
/** Disable session saving (for --no-session mode) */
disable() {
this.enabled = false;
}
private getSessionDirectory(): string {
const cwd = process.cwd();
// Replace all path separators and colons (for Windows drive letters) with dashes
const safePath = "--" + cwd.replace(/^[/\\]/, "").replace(/[/\\:]/g, "-") + "--";
const configDir = getAgentDir();
const sessionDir = join(configDir, "sessions", safePath);
if (!existsSync(sessionDir)) {
mkdirSync(sessionDir, { recursive: true });
}
return sessionDir;
}
private initNewSession(): void {
this.sessionId = uuidv4();
const timestamp = new Date().toISOString().replace(/[:.]/g, "-");
this.sessionFile = join(this.sessionDir, `${timestamp}_${this.sessionId}.jsonl`);
}
/** Reset to a fresh session. Clears pending messages and starts a new session file. */
reset(): void {
this.pendingMessages = [];
this.sessionInitialized = false;
this.initNewSession();
}
private findMostRecentlyModifiedSession(): string | null {
try {
const files = readdirSync(this.sessionDir)
.filter((f) => f.endsWith(".jsonl"))
.map((f) => ({
name: f,
path: join(this.sessionDir, f),
mtime: statSync(join(this.sessionDir, f)).mtime,
}))
.sort((a, b) => b.mtime.getTime() - a.mtime.getTime());
return files[0]?.path || null;
} catch {
return null;
}
}
private loadSessionId(): void {
if (!existsSync(this.sessionFile)) return;
const lines = readFileSync(this.sessionFile, "utf8").trim().split("\n");
for (const line of lines) {
try {
const entry = JSON.parse(line);
if (entry.type === "session") {
this.sessionId = entry.id;
return;
}
} catch {
// Skip malformed lines
}
}
this.sessionId = uuidv4();
}
startSession(state: AgentState): void {
if (!this.enabled || this.sessionInitialized) return;
this.sessionInitialized = true;
const entry: SessionHeader = {
type: "session",
id: this.sessionId,
timestamp: new Date().toISOString(),
cwd: process.cwd(),
provider: state.model.provider,
modelId: state.model.id,
thinkingLevel: state.thinkingLevel,
};
appendFileSync(this.sessionFile, JSON.stringify(entry) + "\n");
// Write any queued messages
for (const msg of this.pendingMessages) {
appendFileSync(this.sessionFile, JSON.stringify(msg) + "\n");
}
this.pendingMessages = [];
}
saveMessage(message: any): void {
if (!this.enabled) return;
const entry: SessionMessageEntry = {
type: "message",
timestamp: new Date().toISOString(),
message,
};
if (!this.sessionInitialized) {
this.pendingMessages.push(entry);
} else {
appendFileSync(this.sessionFile, JSON.stringify(entry) + "\n");
}
}
saveThinkingLevelChange(thinkingLevel: string): void {
if (!this.enabled) return;
const entry: ThinkingLevelChangeEntry = {
type: "thinking_level_change",
timestamp: new Date().toISOString(),
thinkingLevel,
};
if (!this.sessionInitialized) {
this.pendingMessages.push(entry);
} else {
appendFileSync(this.sessionFile, JSON.stringify(entry) + "\n");
}
}
saveModelChange(provider: string, modelId: string): void {
if (!this.enabled) return;
const entry: ModelChangeEntry = {
type: "model_change",
timestamp: new Date().toISOString(),
provider,
modelId,
};
if (!this.sessionInitialized) {
this.pendingMessages.push(entry);
} else {
appendFileSync(this.sessionFile, JSON.stringify(entry) + "\n");
}
}
saveCompaction(entry: CompactionEntry): void {
if (!this.enabled) return;
appendFileSync(this.sessionFile, JSON.stringify(entry) + "\n");
}
/**
* Load session data (messages, model, thinking level) with compaction support.
*/
loadSession(): LoadedSession {
const entries = this.loadEntries();
return loadSessionFromEntries(entries);
}
/**
* @deprecated Use loadSession().messages instead
*/
loadMessages(): AppMessage[] {
return this.loadSession().messages;
}
/**
* @deprecated Use loadSession().thinkingLevel instead
*/
loadThinkingLevel(): string {
return this.loadSession().thinkingLevel;
}
/**
* @deprecated Use loadSession().model instead
*/
loadModel(): { provider: string; modelId: string } | null {
return this.loadSession().model;
}
getSessionId(): string {
return this.sessionId;
}
getSessionFile(): string {
return this.sessionFile;
}
/**
* Load all entries from the session file.
*/
loadEntries(): SessionEntry[] {
if (!existsSync(this.sessionFile)) return [];
const content = readFileSync(this.sessionFile, "utf8");
const entries: SessionEntry[] = [];
const lines = content.trim().split("\n");
for (const line of lines) {
if (!line.trim()) continue;
try {
const entry = JSON.parse(line) as SessionEntry;
entries.push(entry);
} catch {
// Skip malformed lines
}
}
return entries;
}
/**
* Load all sessions for the current directory with metadata
*/
loadAllSessions(): Array<{
path: string;
id: string;
created: Date;
modified: Date;
messageCount: number;
firstMessage: string;
allMessagesText: string;
}> {
const sessions: Array<{
path: string;
id: string;
created: Date;
modified: Date;
messageCount: number;
firstMessage: string;
allMessagesText: string;
}> = [];
try {
const files = readdirSync(this.sessionDir)
.filter((f) => f.endsWith(".jsonl"))
.map((f) => join(this.sessionDir, f));
for (const file of files) {
try {
const stats = statSync(file);
const content = readFileSync(file, "utf8");
const lines = content.trim().split("\n");
let sessionId = "";
let created = stats.birthtime;
let messageCount = 0;
let firstMessage = "";
const allMessages: string[] = [];
for (const line of lines) {
try {
const entry = JSON.parse(line);
// Extract session ID from first session entry
if (entry.type === "session" && !sessionId) {
sessionId = entry.id;
created = new Date(entry.timestamp);
}
// Count messages and collect all text
if (entry.type === "message") {
messageCount++;
// Extract text from user and assistant messages
if (entry.message.role === "user" || entry.message.role === "assistant") {
const textContent = entry.message.content
.filter((c: any) => c.type === "text")
.map((c: any) => c.text)
.join(" ");
if (textContent) {
allMessages.push(textContent);
// Get first user message for display
if (!firstMessage && entry.message.role === "user") {
firstMessage = textContent;
}
}
}
}
} catch {
// Skip malformed lines
}
}
sessions.push({
path: file,
id: sessionId || "unknown",
created,
modified: stats.mtime,
messageCount,
firstMessage: firstMessage || "(no messages)",
allMessagesText: allMessages.join(" "),
});
} catch (error) {
// Skip files that can't be read
console.error(`Failed to read session file ${file}:`, error);
}
}
// Sort by modified date (most recent first)
sessions.sort((a, b) => b.modified.getTime() - a.modified.getTime());
} catch (error) {
console.error("Failed to load sessions:", error);
}
return sessions;
}
/**
* Set the session file to an existing session
*/
setSessionFile(path: string): void {
this.sessionFile = path;
this.loadSessionId();
// Mark as initialized since we're loading an existing session
this.sessionInitialized = existsSync(path);
}
/**
* Check if we should initialize the session based on message history.
* Session is initialized when we have at least 1 user message and 1 assistant message.
*/
shouldInitializeSession(messages: any[]): boolean {
if (this.sessionInitialized) return false;
const userMessages = messages.filter((m) => m.role === "user");
const assistantMessages = messages.filter((m) => m.role === "assistant");
return userMessages.length >= 1 && assistantMessages.length >= 1;
}
/**
* Create a branched session from a specific message index.
* If branchFromIndex is -1, creates an empty session.
* Returns the new session file path.
*/
createBranchedSession(state: any, branchFromIndex: number): string {
// Create a new session ID for the branch
const newSessionId = uuidv4();
const timestamp = new Date().toISOString().replace(/[:.]/g, "-");
const newSessionFile = join(this.sessionDir, `${timestamp}_${newSessionId}.jsonl`);
// Write session header
const entry: SessionHeader = {
type: "session",
id: newSessionId,
timestamp: new Date().toISOString(),
cwd: process.cwd(),
provider: state.model.provider,
modelId: state.model.id,
thinkingLevel: state.thinkingLevel,
branchedFrom: this.sessionFile,
};
appendFileSync(newSessionFile, JSON.stringify(entry) + "\n");
// Write messages up to and including the branch point (if >= 0)
if (branchFromIndex >= 0) {
const messagesToWrite = state.messages.slice(0, branchFromIndex + 1);
for (const message of messagesToWrite) {
const messageEntry: SessionMessageEntry = {
type: "message",
timestamp: new Date().toISOString(),
message,
};
appendFileSync(newSessionFile, JSON.stringify(messageEntry) + "\n");
}
}
return newSessionFile;
}
/**
* Create a branched session from session entries up to (but not including) a specific entry index.
* This preserves compaction events and all entry types.
* Returns the new session file path.
*/
createBranchedSessionFromEntries(entries: SessionEntry[], branchBeforeIndex: number): string {
const newSessionId = uuidv4();
const timestamp = new Date().toISOString().replace(/[:.]/g, "-");
const newSessionFile = join(this.sessionDir, `${timestamp}_${newSessionId}.jsonl`);
// Copy all entries up to (but not including) the branch point
for (let i = 0; i < branchBeforeIndex; i++) {
const entry = entries[i];
if (entry.type === "session") {
// Rewrite session header with new ID and branchedFrom
const newHeader: SessionHeader = {
...entry,
id: newSessionId,
timestamp: new Date().toISOString(),
branchedFrom: this.sessionFile,
};
appendFileSync(newSessionFile, JSON.stringify(newHeader) + "\n");
} else {
// Copy other entries as-is
appendFileSync(newSessionFile, JSON.stringify(entry) + "\n");
}
}
return newSessionFile;
}
}

View file

@ -0,0 +1,176 @@
import { existsSync, mkdirSync, readFileSync, writeFileSync } from "fs";
import { dirname, join } from "path";
import { getAgentDir } from "../utils/config.js";
export interface CompactionSettings {
enabled?: boolean; // default: true
reserveTokens?: number; // default: 16384
keepRecentTokens?: number; // default: 20000
}
export interface Settings {
lastChangelogVersion?: string;
defaultProvider?: string;
defaultModel?: string;
defaultThinkingLevel?: "off" | "minimal" | "low" | "medium" | "high" | "xhigh";
queueMode?: "all" | "one-at-a-time";
theme?: string;
compaction?: CompactionSettings;
hideThinkingBlock?: boolean;
shellPath?: string; // Custom shell path (e.g., for Cygwin users on Windows)
collapseChangelog?: boolean; // Show condensed changelog after update (use /changelog for full)
}
export class SettingsManager {
private settingsPath: string;
private settings: Settings;
constructor(baseDir?: string) {
const dir = baseDir || getAgentDir();
this.settingsPath = join(dir, "settings.json");
this.settings = this.load();
}
private load(): Settings {
if (!existsSync(this.settingsPath)) {
return {};
}
try {
const content = readFileSync(this.settingsPath, "utf-8");
return JSON.parse(content);
} catch (error) {
console.error(`Warning: Could not read settings file: ${error}`);
return {};
}
}
private save(): void {
try {
// Ensure directory exists
const dir = dirname(this.settingsPath);
if (!existsSync(dir)) {
mkdirSync(dir, { recursive: true });
}
writeFileSync(this.settingsPath, JSON.stringify(this.settings, null, 2), "utf-8");
} catch (error) {
console.error(`Warning: Could not save settings file: ${error}`);
}
}
getLastChangelogVersion(): string | undefined {
return this.settings.lastChangelogVersion;
}
setLastChangelogVersion(version: string): void {
this.settings.lastChangelogVersion = version;
this.save();
}
getDefaultProvider(): string | undefined {
return this.settings.defaultProvider;
}
getDefaultModel(): string | undefined {
return this.settings.defaultModel;
}
setDefaultProvider(provider: string): void {
this.settings.defaultProvider = provider;
this.save();
}
setDefaultModel(modelId: string): void {
this.settings.defaultModel = modelId;
this.save();
}
setDefaultModelAndProvider(provider: string, modelId: string): void {
this.settings.defaultProvider = provider;
this.settings.defaultModel = modelId;
this.save();
}
getQueueMode(): "all" | "one-at-a-time" {
return this.settings.queueMode || "one-at-a-time";
}
setQueueMode(mode: "all" | "one-at-a-time"): void {
this.settings.queueMode = mode;
this.save();
}
getTheme(): string | undefined {
return this.settings.theme;
}
setTheme(theme: string): void {
this.settings.theme = theme;
this.save();
}
getDefaultThinkingLevel(): "off" | "minimal" | "low" | "medium" | "high" | "xhigh" | undefined {
return this.settings.defaultThinkingLevel;
}
setDefaultThinkingLevel(level: "off" | "minimal" | "low" | "medium" | "high" | "xhigh"): void {
this.settings.defaultThinkingLevel = level;
this.save();
}
getCompactionEnabled(): boolean {
return this.settings.compaction?.enabled ?? true;
}
setCompactionEnabled(enabled: boolean): void {
if (!this.settings.compaction) {
this.settings.compaction = {};
}
this.settings.compaction.enabled = enabled;
this.save();
}
getCompactionReserveTokens(): number {
return this.settings.compaction?.reserveTokens ?? 16384;
}
getCompactionKeepRecentTokens(): number {
return this.settings.compaction?.keepRecentTokens ?? 20000;
}
getCompactionSettings(): { enabled: boolean; reserveTokens: number; keepRecentTokens: number } {
return {
enabled: this.getCompactionEnabled(),
reserveTokens: this.getCompactionReserveTokens(),
keepRecentTokens: this.getCompactionKeepRecentTokens(),
};
}
getHideThinkingBlock(): boolean {
return this.settings.hideThinkingBlock ?? false;
}
setHideThinkingBlock(hide: boolean): void {
this.settings.hideThinkingBlock = hide;
this.save();
}
getShellPath(): string | undefined {
return this.settings.shellPath;
}
setShellPath(path: string | undefined): void {
this.settings.shellPath = path;
this.save();
}
getCollapseChangelog(): boolean {
return this.settings.collapseChangelog ?? false;
}
setCollapseChangelog(collapse: boolean): void {
this.settings.collapseChangelog = collapse;
this.save();
}
}

View file

@ -0,0 +1,205 @@
import { existsSync, readdirSync, readFileSync } from "fs";
import { join, resolve } from "path";
import { CONFIG_DIR_NAME, getCommandsDir } from "../utils/config.js";
/**
* Represents a custom slash command loaded from a file
*/
export interface FileSlashCommand {
name: string;
description: string;
content: string;
source: string; // e.g., "(user)", "(project)", "(project:frontend)"
}
/**
* Parse YAML frontmatter from markdown content
* Returns { frontmatter, content } where content has frontmatter stripped
*/
function parseFrontmatter(content: string): { frontmatter: Record<string, string>; content: string } {
const frontmatter: Record<string, string> = {};
if (!content.startsWith("---")) {
return { frontmatter, content };
}
const endIndex = content.indexOf("\n---", 3);
if (endIndex === -1) {
return { frontmatter, content };
}
const frontmatterBlock = content.slice(4, endIndex);
const remainingContent = content.slice(endIndex + 4).trim();
// Simple YAML parsing - just key: value pairs
for (const line of frontmatterBlock.split("\n")) {
const match = line.match(/^(\w+):\s*(.*)$/);
if (match) {
frontmatter[match[1]] = match[2].trim();
}
}
return { frontmatter, content: remainingContent };
}
/**
* Parse command arguments respecting quoted strings (bash-style)
* Returns array of arguments
*/
export function parseCommandArgs(argsString: string): string[] {
const args: string[] = [];
let current = "";
let inQuote: string | null = null;
for (let i = 0; i < argsString.length; i++) {
const char = argsString[i];
if (inQuote) {
if (char === inQuote) {
inQuote = null;
} else {
current += char;
}
} else if (char === '"' || char === "'") {
inQuote = char;
} else if (char === " " || char === "\t") {
if (current) {
args.push(current);
current = "";
}
} else {
current += char;
}
}
if (current) {
args.push(current);
}
return args;
}
/**
* Substitute argument placeholders in command content
* Supports $1, $2, ... for positional args and $@ for all args
*/
export function substituteArgs(content: string, args: string[]): string {
let result = content;
// Replace $@ with all args joined
result = result.replace(/\$@/g, args.join(" "));
// Replace $1, $2, etc. with positional args
result = result.replace(/\$(\d+)/g, (_, num) => {
const index = parseInt(num, 10) - 1;
return args[index] ?? "";
});
return result;
}
/**
* Recursively scan a directory for .md files and load them as slash commands
*/
function loadCommandsFromDir(dir: string, source: "user" | "project", subdir: string = ""): FileSlashCommand[] {
const commands: FileSlashCommand[] = [];
if (!existsSync(dir)) {
return commands;
}
try {
const entries = readdirSync(dir, { withFileTypes: true });
for (const entry of entries) {
const fullPath = join(dir, entry.name);
if (entry.isDirectory()) {
// Recurse into subdirectory
const newSubdir = subdir ? `${subdir}:${entry.name}` : entry.name;
commands.push(...loadCommandsFromDir(fullPath, source, newSubdir));
} else if (entry.isFile() && entry.name.endsWith(".md")) {
try {
const rawContent = readFileSync(fullPath, "utf-8");
const { frontmatter, content } = parseFrontmatter(rawContent);
const name = entry.name.slice(0, -3); // Remove .md extension
// Build source string
let sourceStr: string;
if (source === "user") {
sourceStr = subdir ? `(user:${subdir})` : "(user)";
} else {
sourceStr = subdir ? `(project:${subdir})` : "(project)";
}
// Get description from frontmatter or first non-empty line
let description = frontmatter.description || "";
if (!description) {
const firstLine = content.split("\n").find((line) => line.trim());
if (firstLine) {
// Truncate if too long
description = firstLine.slice(0, 60);
if (firstLine.length > 60) description += "...";
}
}
// Append source to description
description = description ? `${description} ${sourceStr}` : sourceStr;
commands.push({
name,
description,
content,
source: sourceStr,
});
} catch (error) {
// Silently skip files that can't be read
}
}
}
} catch (error) {
// Silently skip directories that can't be read
}
return commands;
}
/**
* Load all custom slash commands from:
* 1. Global: ~/{CONFIG_DIR_NAME}/agent/commands/
* 2. Project: ./{CONFIG_DIR_NAME}/commands/
*/
export function loadSlashCommands(): FileSlashCommand[] {
const commands: FileSlashCommand[] = [];
// 1. Load global commands from ~/{CONFIG_DIR_NAME}/agent/commands/
const globalCommandsDir = getCommandsDir();
commands.push(...loadCommandsFromDir(globalCommandsDir, "user"));
// 2. Load project commands from ./{CONFIG_DIR_NAME}/commands/
const projectCommandsDir = resolve(process.cwd(), CONFIG_DIR_NAME, "commands");
commands.push(...loadCommandsFromDir(projectCommandsDir, "project"));
return commands;
}
/**
* Expand a slash command if it matches a file-based command.
* Returns the expanded content or the original text if not a slash command.
*/
export function expandSlashCommand(text: string, fileCommands: FileSlashCommand[]): string {
if (!text.startsWith("/")) return text;
const spaceIndex = text.indexOf(" ");
const commandName = spaceIndex === -1 ? text.slice(1) : text.slice(1, spaceIndex);
const argsString = spaceIndex === -1 ? "" : text.slice(spaceIndex + 1);
const fileCommand = fileCommands.find((cmd) => cmd.name === commandName);
if (fileCommand) {
const args = parseCommandArgs(argsString);
return substituteArgs(fileCommand.content, args);
}
return text;
}

View file

@ -0,0 +1,191 @@
import { randomBytes } from "node:crypto";
import { createWriteStream } from "node:fs";
import { tmpdir } from "node:os";
import { join } from "node:path";
import type { AgentTool } from "@mariozechner/pi-ai";
import { Type } from "@sinclair/typebox";
import { spawn } from "child_process";
import { getShellConfig, killProcessTree } from "../../utils/shell.js";
import { DEFAULT_MAX_BYTES, DEFAULT_MAX_LINES, formatSize, type TruncationResult, truncateTail } from "./truncate.js";
/**
* Generate a unique temp file path for bash output
*/
function getTempFilePath(): string {
const id = randomBytes(8).toString("hex");
return join(tmpdir(), `pi-bash-${id}.log`);
}
const bashSchema = Type.Object({
command: Type.String({ description: "Bash command to execute" }),
timeout: Type.Optional(Type.Number({ description: "Timeout in seconds (optional, no default timeout)" })),
});
interface BashToolDetails {
truncation?: TruncationResult;
fullOutputPath?: string;
}
export const bashTool: AgentTool<typeof bashSchema> = {
name: "bash",
label: "bash",
description: `Execute a bash command in the current working directory. Returns stdout and stderr. Output is truncated to last ${DEFAULT_MAX_LINES} lines or ${DEFAULT_MAX_BYTES / 1024}KB (whichever is hit first). If truncated, full output is saved to a temp file. Optionally provide a timeout in seconds.`,
parameters: bashSchema,
execute: async (
_toolCallId: string,
{ command, timeout }: { command: string; timeout?: number },
signal?: AbortSignal,
) => {
return new Promise((resolve, reject) => {
const { shell, args } = getShellConfig();
const child = spawn(shell, [...args, command], {
detached: true,
stdio: ["ignore", "pipe", "pipe"],
});
// We'll stream to a temp file if output gets large
let tempFilePath: string | undefined;
let tempFileStream: ReturnType<typeof createWriteStream> | undefined;
let totalBytes = 0;
// Keep a rolling buffer of the last chunk for tail truncation
const chunks: Buffer[] = [];
let chunksBytes = 0;
// Keep more than we need so we have enough for truncation
const maxChunksBytes = DEFAULT_MAX_BYTES * 2;
let timedOut = false;
// Set timeout if provided
let timeoutHandle: NodeJS.Timeout | undefined;
if (timeout !== undefined && timeout > 0) {
timeoutHandle = setTimeout(() => {
timedOut = true;
onAbort();
}, timeout * 1000);
}
const handleData = (data: Buffer) => {
totalBytes += data.length;
// Start writing to temp file once we exceed the threshold
if (totalBytes > DEFAULT_MAX_BYTES && !tempFilePath) {
tempFilePath = getTempFilePath();
tempFileStream = createWriteStream(tempFilePath);
// Write all buffered chunks to the file
for (const chunk of chunks) {
tempFileStream.write(chunk);
}
}
// Write to temp file if we have one
if (tempFileStream) {
tempFileStream.write(data);
}
// Keep rolling buffer of recent data
chunks.push(data);
chunksBytes += data.length;
// Trim old chunks if buffer is too large
while (chunksBytes > maxChunksBytes && chunks.length > 1) {
const removed = chunks.shift()!;
chunksBytes -= removed.length;
}
};
// Collect stdout and stderr together
if (child.stdout) {
child.stdout.on("data", handleData);
}
if (child.stderr) {
child.stderr.on("data", handleData);
}
// Handle process exit
child.on("close", (code) => {
if (timeoutHandle) {
clearTimeout(timeoutHandle);
}
if (signal) {
signal.removeEventListener("abort", onAbort);
}
// Close temp file stream
if (tempFileStream) {
tempFileStream.end();
}
// Combine all buffered chunks
const fullBuffer = Buffer.concat(chunks);
const fullOutput = fullBuffer.toString("utf-8");
if (signal?.aborted) {
let output = fullOutput;
if (output) output += "\n\n";
output += "Command aborted";
reject(new Error(output));
return;
}
if (timedOut) {
let output = fullOutput;
if (output) output += "\n\n";
output += `Command timed out after ${timeout} seconds`;
reject(new Error(output));
return;
}
// Apply tail truncation
const truncation = truncateTail(fullOutput);
let outputText = truncation.content || "(no output)";
// Build details with truncation info
let details: BashToolDetails | undefined;
if (truncation.truncated) {
details = {
truncation,
fullOutputPath: tempFilePath,
};
// Build actionable notice
const startLine = truncation.totalLines - truncation.outputLines + 1;
const endLine = truncation.totalLines;
if (truncation.lastLinePartial) {
// Edge case: last line alone > 30KB
const lastLineSize = formatSize(Buffer.byteLength(fullOutput.split("\n").pop() || "", "utf-8"));
outputText += `\n\n[Showing last ${formatSize(truncation.outputBytes)} of line ${endLine} (line is ${lastLineSize}). Full output: ${tempFilePath}]`;
} else if (truncation.truncatedBy === "lines") {
outputText += `\n\n[Showing lines ${startLine}-${endLine} of ${truncation.totalLines}. Full output: ${tempFilePath}]`;
} else {
outputText += `\n\n[Showing lines ${startLine}-${endLine} of ${truncation.totalLines} (${formatSize(DEFAULT_MAX_BYTES)} limit). Full output: ${tempFilePath}]`;
}
}
if (code !== 0 && code !== null) {
outputText += `\n\nCommand exited with code ${code}`;
reject(new Error(outputText));
} else {
resolve({ content: [{ type: "text", text: outputText }], details });
}
});
// Handle abort signal - kill entire process tree
const onAbort = () => {
if (child.pid) {
killProcessTree(child.pid);
}
};
if (signal) {
if (signal.aborted) {
onAbort();
} else {
signal.addEventListener("abort", onAbort, { once: true });
}
}
});
},
};

View file

@ -0,0 +1,270 @@
import * as os from "node:os";
import type { AgentTool } from "@mariozechner/pi-ai";
import { Type } from "@sinclair/typebox";
import * as Diff from "diff";
import { constants } from "fs";
import { access, readFile, writeFile } from "fs/promises";
import { resolve as resolvePath } from "path";
/**
* Expand ~ to home directory
*/
function expandPath(filePath: string): string {
if (filePath === "~") {
return os.homedir();
}
if (filePath.startsWith("~/")) {
return os.homedir() + filePath.slice(1);
}
return filePath;
}
/**
* Generate a unified diff string with line numbers and context
*/
function generateDiffString(oldContent: string, newContent: string, contextLines = 4): string {
const parts = Diff.diffLines(oldContent, newContent);
const output: string[] = [];
const oldLines = oldContent.split("\n");
const newLines = newContent.split("\n");
const maxLineNum = Math.max(oldLines.length, newLines.length);
const lineNumWidth = String(maxLineNum).length;
let oldLineNum = 1;
let newLineNum = 1;
let lastWasChange = false;
for (let i = 0; i < parts.length; i++) {
const part = parts[i];
const raw = part.value.split("\n");
if (raw[raw.length - 1] === "") {
raw.pop();
}
if (part.added || part.removed) {
// Show the change
for (const line of raw) {
if (part.added) {
const lineNum = String(newLineNum).padStart(lineNumWidth, " ");
output.push(`+${lineNum} ${line}`);
newLineNum++;
} else {
// removed
const lineNum = String(oldLineNum).padStart(lineNumWidth, " ");
output.push(`-${lineNum} ${line}`);
oldLineNum++;
}
}
lastWasChange = true;
} else {
// Context lines - only show a few before/after changes
const nextPartIsChange = i < parts.length - 1 && (parts[i + 1].added || parts[i + 1].removed);
if (lastWasChange || nextPartIsChange) {
// Show context
let linesToShow = raw;
let skipStart = 0;
let skipEnd = 0;
if (!lastWasChange) {
// Show only last N lines as leading context
skipStart = Math.max(0, raw.length - contextLines);
linesToShow = raw.slice(skipStart);
}
if (!nextPartIsChange && linesToShow.length > contextLines) {
// Show only first N lines as trailing context
skipEnd = linesToShow.length - contextLines;
linesToShow = linesToShow.slice(0, contextLines);
}
// Add ellipsis if we skipped lines at start
if (skipStart > 0) {
output.push(` ${"".padStart(lineNumWidth, " ")} ...`);
// Update line numbers for the skipped leading context
oldLineNum += skipStart;
newLineNum += skipStart;
}
for (const line of linesToShow) {
const lineNum = String(oldLineNum).padStart(lineNumWidth, " ");
output.push(` ${lineNum} ${line}`);
oldLineNum++;
newLineNum++;
}
// Add ellipsis if we skipped lines at end
if (skipEnd > 0) {
output.push(` ${"".padStart(lineNumWidth, " ")} ...`);
// Update line numbers for the skipped trailing context
oldLineNum += skipEnd;
newLineNum += skipEnd;
}
} else {
// Skip these context lines entirely
oldLineNum += raw.length;
newLineNum += raw.length;
}
lastWasChange = false;
}
}
return output.join("\n");
}
const editSchema = Type.Object({
path: Type.String({ description: "Path to the file to edit (relative or absolute)" }),
oldText: Type.String({ description: "Exact text to find and replace (must match exactly)" }),
newText: Type.String({ description: "New text to replace the old text with" }),
});
export const editTool: AgentTool<typeof editSchema> = {
name: "edit",
label: "edit",
description:
"Edit a file by replacing exact text. The oldText must match exactly (including whitespace). Use this for precise, surgical edits.",
parameters: editSchema,
execute: async (
_toolCallId: string,
{ path, oldText, newText }: { path: string; oldText: string; newText: string },
signal?: AbortSignal,
) => {
const absolutePath = resolvePath(expandPath(path));
return new Promise<{
content: Array<{ type: "text"; text: string }>;
details: { diff: string } | undefined;
}>((resolve, reject) => {
// Check if already aborted
if (signal?.aborted) {
reject(new Error("Operation aborted"));
return;
}
let aborted = false;
// Set up abort handler
const onAbort = () => {
aborted = true;
reject(new Error("Operation aborted"));
};
if (signal) {
signal.addEventListener("abort", onAbort, { once: true });
}
// Perform the edit operation
(async () => {
try {
// Check if file exists
try {
await access(absolutePath, constants.R_OK | constants.W_OK);
} catch {
if (signal) {
signal.removeEventListener("abort", onAbort);
}
reject(new Error(`File not found: ${path}`));
return;
}
// Check if aborted before reading
if (aborted) {
return;
}
// Read the file
const content = await readFile(absolutePath, "utf-8");
// Check if aborted after reading
if (aborted) {
return;
}
// Check if old text exists
if (!content.includes(oldText)) {
if (signal) {
signal.removeEventListener("abort", onAbort);
}
reject(
new Error(
`Could not find the exact text in ${path}. The old text must match exactly including all whitespace and newlines.`,
),
);
return;
}
// Count occurrences
const occurrences = content.split(oldText).length - 1;
if (occurrences > 1) {
if (signal) {
signal.removeEventListener("abort", onAbort);
}
reject(
new Error(
`Found ${occurrences} occurrences of the text in ${path}. The text must be unique. Please provide more context to make it unique.`,
),
);
return;
}
// Check if aborted before writing
if (aborted) {
return;
}
// Perform replacement using indexOf + substring (raw string replace, no special character interpretation)
// String.replace() interprets $ in the replacement string, so we do manual replacement
const index = content.indexOf(oldText);
const newContent = content.substring(0, index) + newText + content.substring(index + oldText.length);
// Verify the replacement actually changed something
if (content === newContent) {
if (signal) {
signal.removeEventListener("abort", onAbort);
}
reject(
new Error(
`No changes made to ${path}. The replacement produced identical content. This might indicate an issue with special characters or the text not existing as expected.`,
),
);
return;
}
await writeFile(absolutePath, newContent, "utf-8");
// Check if aborted after writing
if (aborted) {
return;
}
// Clean up abort handler
if (signal) {
signal.removeEventListener("abort", onAbort);
}
resolve({
content: [
{
type: "text",
text: `Successfully replaced text in ${path}. Changed ${oldText.length} characters to ${newText.length} characters.`,
},
],
details: { diff: generateDiffString(content, newContent) },
});
} catch (error: any) {
// Clean up abort handler
if (signal) {
signal.removeEventListener("abort", onAbort);
}
if (!aborted) {
reject(error);
}
}
})();
});
},
};

View file

@ -0,0 +1,203 @@
import type { AgentTool } from "@mariozechner/pi-ai";
import { Type } from "@sinclair/typebox";
import { spawnSync } from "child_process";
import { existsSync } from "fs";
import { globSync } from "glob";
import { homedir } from "os";
import path from "path";
import { ensureTool } from "../../utils/tools-manager.js";
import { DEFAULT_MAX_BYTES, formatSize, type TruncationResult, truncateHead } from "./truncate.js";
/**
* Expand ~ to home directory
*/
function expandPath(filePath: string): string {
if (filePath === "~") {
return homedir();
}
if (filePath.startsWith("~/")) {
return homedir() + filePath.slice(1);
}
return filePath;
}
const findSchema = Type.Object({
pattern: Type.String({
description: "Glob pattern to match files, e.g. '*.ts', '**/*.json', or 'src/**/*.spec.ts'",
}),
path: Type.Optional(Type.String({ description: "Directory to search in (default: current directory)" })),
limit: Type.Optional(Type.Number({ description: "Maximum number of results (default: 1000)" })),
});
const DEFAULT_LIMIT = 1000;
interface FindToolDetails {
truncation?: TruncationResult;
resultLimitReached?: number;
}
export const findTool: AgentTool<typeof findSchema> = {
name: "find",
label: "find",
description: `Search for files by glob pattern. Returns matching file paths relative to the search directory. Respects .gitignore. Output is truncated to ${DEFAULT_LIMIT} results or ${DEFAULT_MAX_BYTES / 1024}KB (whichever is hit first).`,
parameters: findSchema,
execute: async (
_toolCallId: string,
{ pattern, path: searchDir, limit }: { pattern: string; path?: string; limit?: number },
signal?: AbortSignal,
) => {
return new Promise((resolve, reject) => {
if (signal?.aborted) {
reject(new Error("Operation aborted"));
return;
}
const onAbort = () => reject(new Error("Operation aborted"));
signal?.addEventListener("abort", onAbort, { once: true });
(async () => {
try {
// Ensure fd is available
const fdPath = await ensureTool("fd", true);
if (!fdPath) {
reject(new Error("fd is not available and could not be downloaded"));
return;
}
const searchPath = path.resolve(expandPath(searchDir || "."));
const effectiveLimit = limit ?? DEFAULT_LIMIT;
// Build fd arguments
const args: string[] = [
"--glob", // Use glob pattern
"--color=never", // No ANSI colors
"--hidden", // Search hidden files (but still respect .gitignore)
"--max-results",
String(effectiveLimit),
];
// Include .gitignore files (root + nested) so fd respects them even outside git repos
const gitignoreFiles = new Set<string>();
const rootGitignore = path.join(searchPath, ".gitignore");
if (existsSync(rootGitignore)) {
gitignoreFiles.add(rootGitignore);
}
try {
const nestedGitignores = globSync("**/.gitignore", {
cwd: searchPath,
dot: true,
absolute: true,
ignore: ["**/node_modules/**", "**/.git/**"],
});
for (const file of nestedGitignores) {
gitignoreFiles.add(file);
}
} catch {
// Ignore glob errors
}
for (const gitignorePath of gitignoreFiles) {
args.push("--ignore-file", gitignorePath);
}
// Pattern and path
args.push(pattern, searchPath);
// Run fd
const result = spawnSync(fdPath, args, {
encoding: "utf-8",
maxBuffer: 10 * 1024 * 1024, // 10MB
});
signal?.removeEventListener("abort", onAbort);
if (result.error) {
reject(new Error(`Failed to run fd: ${result.error.message}`));
return;
}
const output = result.stdout?.trim() || "";
if (result.status !== 0) {
const errorMsg = result.stderr?.trim() || `fd exited with code ${result.status}`;
// fd returns non-zero for some errors but may still have partial output
if (!output) {
reject(new Error(errorMsg));
return;
}
}
if (!output) {
resolve({
content: [{ type: "text", text: "No files found matching pattern" }],
details: undefined,
});
return;
}
const lines = output.split("\n");
const relativized: string[] = [];
for (const rawLine of lines) {
const line = rawLine.replace(/\r$/, "").trim();
if (!line) {
continue;
}
const hadTrailingSlash = line.endsWith("/") || line.endsWith("\\");
let relativePath = line;
if (line.startsWith(searchPath)) {
relativePath = line.slice(searchPath.length + 1); // +1 for the /
} else {
relativePath = path.relative(searchPath, line);
}
if (hadTrailingSlash && !relativePath.endsWith("/")) {
relativePath += "/";
}
relativized.push(relativePath);
}
// Check if we hit the result limit
const resultLimitReached = relativized.length >= effectiveLimit;
// Apply byte truncation (no line limit since we already have result limit)
const rawOutput = relativized.join("\n");
const truncation = truncateHead(rawOutput, { maxLines: Number.MAX_SAFE_INTEGER });
let resultOutput = truncation.content;
const details: FindToolDetails = {};
// Build notices
const notices: string[] = [];
if (resultLimitReached) {
notices.push(
`${effectiveLimit} results limit reached. Use limit=${effectiveLimit * 2} for more, or refine pattern`,
);
details.resultLimitReached = effectiveLimit;
}
if (truncation.truncated) {
notices.push(`${formatSize(DEFAULT_MAX_BYTES)} limit reached`);
details.truncation = truncation;
}
if (notices.length > 0) {
resultOutput += `\n\n[${notices.join(". ")}]`;
}
resolve({
content: [{ type: "text", text: resultOutput }],
details: Object.keys(details).length > 0 ? details : undefined,
});
} catch (e: any) {
signal?.removeEventListener("abort", onAbort);
reject(e);
}
})();
});
},
};

View file

@ -0,0 +1,320 @@
import { createInterface } from "node:readline";
import type { AgentTool } from "@mariozechner/pi-ai";
import { Type } from "@sinclair/typebox";
import { spawn } from "child_process";
import { readFileSync, type Stats, statSync } from "fs";
import { homedir } from "os";
import path from "path";
import { ensureTool } from "../../utils/tools-manager.js";
import {
DEFAULT_MAX_BYTES,
formatSize,
GREP_MAX_LINE_LENGTH,
type TruncationResult,
truncateHead,
truncateLine,
} from "./truncate.js";
/**
* Expand ~ to home directory
*/
function expandPath(filePath: string): string {
if (filePath === "~") {
return homedir();
}
if (filePath.startsWith("~/")) {
return homedir() + filePath.slice(1);
}
return filePath;
}
const grepSchema = Type.Object({
pattern: Type.String({ description: "Search pattern (regex or literal string)" }),
path: Type.Optional(Type.String({ description: "Directory or file to search (default: current directory)" })),
glob: Type.Optional(Type.String({ description: "Filter files by glob pattern, e.g. '*.ts' or '**/*.spec.ts'" })),
ignoreCase: Type.Optional(Type.Boolean({ description: "Case-insensitive search (default: false)" })),
literal: Type.Optional(
Type.Boolean({ description: "Treat pattern as literal string instead of regex (default: false)" }),
),
context: Type.Optional(
Type.Number({ description: "Number of lines to show before and after each match (default: 0)" }),
),
limit: Type.Optional(Type.Number({ description: "Maximum number of matches to return (default: 100)" })),
});
const DEFAULT_LIMIT = 100;
interface GrepToolDetails {
truncation?: TruncationResult;
matchLimitReached?: number;
linesTruncated?: boolean;
}
export const grepTool: AgentTool<typeof grepSchema> = {
name: "grep",
label: "grep",
description: `Search file contents for a pattern. Returns matching lines with file paths and line numbers. Respects .gitignore. Output is truncated to ${DEFAULT_LIMIT} matches or ${DEFAULT_MAX_BYTES / 1024}KB (whichever is hit first). Long lines are truncated to ${GREP_MAX_LINE_LENGTH} chars.`,
parameters: grepSchema,
execute: async (
_toolCallId: string,
{
pattern,
path: searchDir,
glob,
ignoreCase,
literal,
context,
limit,
}: {
pattern: string;
path?: string;
glob?: string;
ignoreCase?: boolean;
literal?: boolean;
context?: number;
limit?: number;
},
signal?: AbortSignal,
) => {
return new Promise((resolve, reject) => {
if (signal?.aborted) {
reject(new Error("Operation aborted"));
return;
}
let settled = false;
const settle = (fn: () => void) => {
if (!settled) {
settled = true;
fn();
}
};
(async () => {
try {
const rgPath = await ensureTool("rg", true);
if (!rgPath) {
settle(() => reject(new Error("ripgrep (rg) is not available and could not be downloaded")));
return;
}
const searchPath = path.resolve(expandPath(searchDir || "."));
let searchStat: Stats;
try {
searchStat = statSync(searchPath);
} catch (err) {
settle(() => reject(new Error(`Path not found: ${searchPath}`)));
return;
}
const isDirectory = searchStat.isDirectory();
const contextValue = context && context > 0 ? context : 0;
const effectiveLimit = Math.max(1, limit ?? DEFAULT_LIMIT);
const formatPath = (filePath: string): string => {
if (isDirectory) {
const relative = path.relative(searchPath, filePath);
if (relative && !relative.startsWith("..")) {
return relative.replace(/\\/g, "/");
}
}
return path.basename(filePath);
};
const fileCache = new Map<string, string[]>();
const getFileLines = (filePath: string): string[] => {
let lines = fileCache.get(filePath);
if (!lines) {
try {
const content = readFileSync(filePath, "utf-8");
lines = content.replace(/\r\n/g, "\n").replace(/\r/g, "\n").split("\n");
} catch {
lines = [];
}
fileCache.set(filePath, lines);
}
return lines;
};
const args: string[] = ["--json", "--line-number", "--color=never", "--hidden"];
if (ignoreCase) {
args.push("--ignore-case");
}
if (literal) {
args.push("--fixed-strings");
}
if (glob) {
args.push("--glob", glob);
}
args.push(pattern, searchPath);
const child = spawn(rgPath, args, { stdio: ["ignore", "pipe", "pipe"] });
const rl = createInterface({ input: child.stdout });
let stderr = "";
let matchCount = 0;
let matchLimitReached = false;
let linesTruncated = false;
let aborted = false;
let killedDueToLimit = false;
const outputLines: string[] = [];
const cleanup = () => {
rl.close();
signal?.removeEventListener("abort", onAbort);
};
const stopChild = (dueToLimit: boolean = false) => {
if (!child.killed) {
killedDueToLimit = dueToLimit;
child.kill();
}
};
const onAbort = () => {
aborted = true;
stopChild();
};
signal?.addEventListener("abort", onAbort, { once: true });
child.stderr?.on("data", (chunk) => {
stderr += chunk.toString();
});
const formatBlock = (filePath: string, lineNumber: number): string[] => {
const relativePath = formatPath(filePath);
const lines = getFileLines(filePath);
if (!lines.length) {
return [`${relativePath}:${lineNumber}: (unable to read file)`];
}
const block: string[] = [];
const start = contextValue > 0 ? Math.max(1, lineNumber - contextValue) : lineNumber;
const end = contextValue > 0 ? Math.min(lines.length, lineNumber + contextValue) : lineNumber;
for (let current = start; current <= end; current++) {
const lineText = lines[current - 1] ?? "";
const sanitized = lineText.replace(/\r/g, "");
const isMatchLine = current === lineNumber;
// Truncate long lines
const { text: truncatedText, wasTruncated } = truncateLine(sanitized);
if (wasTruncated) {
linesTruncated = true;
}
if (isMatchLine) {
block.push(`${relativePath}:${current}: ${truncatedText}`);
} else {
block.push(`${relativePath}-${current}- ${truncatedText}`);
}
}
return block;
};
rl.on("line", (line) => {
if (!line.trim() || matchCount >= effectiveLimit) {
return;
}
let event: any;
try {
event = JSON.parse(line);
} catch {
return;
}
if (event.type === "match") {
matchCount++;
const filePath = event.data?.path?.text;
const lineNumber = event.data?.line_number;
if (filePath && typeof lineNumber === "number") {
outputLines.push(...formatBlock(filePath, lineNumber));
}
if (matchCount >= effectiveLimit) {
matchLimitReached = true;
stopChild(true);
}
}
});
child.on("error", (error) => {
cleanup();
settle(() => reject(new Error(`Failed to run ripgrep: ${error.message}`)));
});
child.on("close", (code) => {
cleanup();
if (aborted) {
settle(() => reject(new Error("Operation aborted")));
return;
}
if (!killedDueToLimit && code !== 0 && code !== 1) {
const errorMsg = stderr.trim() || `ripgrep exited with code ${code}`;
settle(() => reject(new Error(errorMsg)));
return;
}
if (matchCount === 0) {
settle(() =>
resolve({ content: [{ type: "text", text: "No matches found" }], details: undefined }),
);
return;
}
// Apply byte truncation (no line limit since we already have match limit)
const rawOutput = outputLines.join("\n");
const truncation = truncateHead(rawOutput, { maxLines: Number.MAX_SAFE_INTEGER });
let output = truncation.content;
const details: GrepToolDetails = {};
// Build notices
const notices: string[] = [];
if (matchLimitReached) {
notices.push(
`${effectiveLimit} matches limit reached. Use limit=${effectiveLimit * 2} for more, or refine pattern`,
);
details.matchLimitReached = effectiveLimit;
}
if (truncation.truncated) {
notices.push(`${formatSize(DEFAULT_MAX_BYTES)} limit reached`);
details.truncation = truncation;
}
if (linesTruncated) {
notices.push(
`Some lines truncated to ${GREP_MAX_LINE_LENGTH} chars. Use read tool to see full lines`,
);
details.linesTruncated = true;
}
if (notices.length > 0) {
output += `\n\n[${notices.join(". ")}]`;
}
settle(() =>
resolve({
content: [{ type: "text", text: output }],
details: Object.keys(details).length > 0 ? details : undefined,
}),
);
});
} catch (err) {
settle(() => reject(err as Error));
}
})();
});
},
};

View file

@ -0,0 +1,31 @@
export { bashTool } from "./bash.js";
export { editTool } from "./edit.js";
export { findTool } from "./find.js";
export { grepTool } from "./grep.js";
export { lsTool } from "./ls.js";
export { readTool } from "./read.js";
export { writeTool } from "./write.js";
import { bashTool } from "./bash.js";
import { editTool } from "./edit.js";
import { findTool } from "./find.js";
import { grepTool } from "./grep.js";
import { lsTool } from "./ls.js";
import { readTool } from "./read.js";
import { writeTool } from "./write.js";
// Default tools for full access mode
export const codingTools = [readTool, bashTool, editTool, writeTool];
// All available tools (including read-only exploration tools)
export const allTools = {
read: readTool,
bash: bashTool,
edit: editTool,
write: writeTool,
grep: grepTool,
find: findTool,
ls: lsTool,
};
export type ToolName = keyof typeof allTools;

View file

@ -0,0 +1,144 @@
import type { AgentTool } from "@mariozechner/pi-ai";
import { Type } from "@sinclair/typebox";
import { existsSync, readdirSync, statSync } from "fs";
import { homedir } from "os";
import nodePath from "path";
import { DEFAULT_MAX_BYTES, formatSize, type TruncationResult, truncateHead } from "./truncate.js";
/**
* Expand ~ to home directory
*/
function expandPath(filePath: string): string {
if (filePath === "~") {
return homedir();
}
if (filePath.startsWith("~/")) {
return homedir() + filePath.slice(1);
}
return filePath;
}
const lsSchema = Type.Object({
path: Type.Optional(Type.String({ description: "Directory to list (default: current directory)" })),
limit: Type.Optional(Type.Number({ description: "Maximum number of entries to return (default: 500)" })),
});
const DEFAULT_LIMIT = 500;
interface LsToolDetails {
truncation?: TruncationResult;
entryLimitReached?: number;
}
export const lsTool: AgentTool<typeof lsSchema> = {
name: "ls",
label: "ls",
description: `List directory contents. Returns entries sorted alphabetically, with '/' suffix for directories. Includes dotfiles. Output is truncated to ${DEFAULT_LIMIT} entries or ${DEFAULT_MAX_BYTES / 1024}KB (whichever is hit first).`,
parameters: lsSchema,
execute: async (_toolCallId: string, { path, limit }: { path?: string; limit?: number }, signal?: AbortSignal) => {
return new Promise((resolve, reject) => {
if (signal?.aborted) {
reject(new Error("Operation aborted"));
return;
}
const onAbort = () => reject(new Error("Operation aborted"));
signal?.addEventListener("abort", onAbort, { once: true });
try {
const dirPath = nodePath.resolve(expandPath(path || "."));
const effectiveLimit = limit ?? DEFAULT_LIMIT;
// Check if path exists
if (!existsSync(dirPath)) {
reject(new Error(`Path not found: ${dirPath}`));
return;
}
// Check if path is a directory
const stat = statSync(dirPath);
if (!stat.isDirectory()) {
reject(new Error(`Not a directory: ${dirPath}`));
return;
}
// Read directory entries
let entries: string[];
try {
entries = readdirSync(dirPath);
} catch (e: any) {
reject(new Error(`Cannot read directory: ${e.message}`));
return;
}
// Sort alphabetically (case-insensitive)
entries.sort((a, b) => a.toLowerCase().localeCompare(b.toLowerCase()));
// Format entries with directory indicators
const results: string[] = [];
let entryLimitReached = false;
for (const entry of entries) {
if (results.length >= effectiveLimit) {
entryLimitReached = true;
break;
}
const fullPath = nodePath.join(dirPath, entry);
let suffix = "";
try {
const entryStat = statSync(fullPath);
if (entryStat.isDirectory()) {
suffix = "/";
}
} catch {
// Skip entries we can't stat
continue;
}
results.push(entry + suffix);
}
signal?.removeEventListener("abort", onAbort);
if (results.length === 0) {
resolve({ content: [{ type: "text", text: "(empty directory)" }], details: undefined });
return;
}
// Apply byte truncation (no line limit since we already have entry limit)
const rawOutput = results.join("\n");
const truncation = truncateHead(rawOutput, { maxLines: Number.MAX_SAFE_INTEGER });
let output = truncation.content;
const details: LsToolDetails = {};
// Build notices
const notices: string[] = [];
if (entryLimitReached) {
notices.push(`${effectiveLimit} entries limit reached. Use limit=${effectiveLimit * 2} for more`);
details.entryLimitReached = effectiveLimit;
}
if (truncation.truncated) {
notices.push(`${formatSize(DEFAULT_MAX_BYTES)} limit reached`);
details.truncation = truncation;
}
if (notices.length > 0) {
output += `\n\n[${notices.join(". ")}]`;
}
resolve({
content: [{ type: "text", text: output }],
details: Object.keys(details).length > 0 ? details : undefined,
});
} catch (e: any) {
signal?.removeEventListener("abort", onAbort);
reject(e);
}
});
},
};

View file

@ -0,0 +1,198 @@
import * as os from "node:os";
import type { AgentTool, ImageContent, TextContent } from "@mariozechner/pi-ai";
import { Type } from "@sinclair/typebox";
import { constants } from "fs";
import { access, readFile } from "fs/promises";
import { extname, resolve as resolvePath } from "path";
import { DEFAULT_MAX_BYTES, DEFAULT_MAX_LINES, formatSize, type TruncationResult, truncateHead } from "./truncate.js";
/**
* Expand ~ to home directory
*/
function expandPath(filePath: string): string {
if (filePath === "~") {
return os.homedir();
}
if (filePath.startsWith("~/")) {
return os.homedir() + filePath.slice(1);
}
return filePath;
}
/**
* Map of file extensions to MIME types for common image formats
*/
const IMAGE_MIME_TYPES: Record<string, string> = {
".jpg": "image/jpeg",
".jpeg": "image/jpeg",
".png": "image/png",
".gif": "image/gif",
".webp": "image/webp",
};
/**
* Check if a file is an image based on its extension
*/
function isImageFile(filePath: string): string | null {
const ext = extname(filePath).toLowerCase();
return IMAGE_MIME_TYPES[ext] || null;
}
const readSchema = Type.Object({
path: Type.String({ description: "Path to the file to read (relative or absolute)" }),
offset: Type.Optional(Type.Number({ description: "Line number to start reading from (1-indexed)" })),
limit: Type.Optional(Type.Number({ description: "Maximum number of lines to read" })),
});
interface ReadToolDetails {
truncation?: TruncationResult;
}
export const readTool: AgentTool<typeof readSchema> = {
name: "read",
label: "read",
description: `Read the contents of a file. Supports text files and images (jpg, png, gif, webp). Images are sent as attachments. For text files, output is truncated to ${DEFAULT_MAX_LINES} lines or ${DEFAULT_MAX_BYTES / 1024}KB (whichever is hit first). Use offset/limit for large files.`,
parameters: readSchema,
execute: async (
_toolCallId: string,
{ path, offset, limit }: { path: string; offset?: number; limit?: number },
signal?: AbortSignal,
) => {
const absolutePath = resolvePath(expandPath(path));
const mimeType = isImageFile(absolutePath);
return new Promise<{ content: (TextContent | ImageContent)[]; details: ReadToolDetails | undefined }>(
(resolve, reject) => {
// Check if already aborted
if (signal?.aborted) {
reject(new Error("Operation aborted"));
return;
}
let aborted = false;
// Set up abort handler
const onAbort = () => {
aborted = true;
reject(new Error("Operation aborted"));
};
if (signal) {
signal.addEventListener("abort", onAbort, { once: true });
}
// Perform the read operation
(async () => {
try {
// Check if file exists
await access(absolutePath, constants.R_OK);
// Check if aborted before reading
if (aborted) {
return;
}
// Read the file based on type
let content: (TextContent | ImageContent)[];
let details: ReadToolDetails | undefined;
if (mimeType) {
// Read as image (binary)
const buffer = await readFile(absolutePath);
const base64 = buffer.toString("base64");
content = [
{ type: "text", text: `Read image file [${mimeType}]` },
{ type: "image", data: base64, mimeType },
];
} else {
// Read as text
const textContent = await readFile(absolutePath, "utf-8");
const allLines = textContent.split("\n");
const totalFileLines = allLines.length;
// Apply offset if specified (1-indexed to 0-indexed)
const startLine = offset ? Math.max(0, offset - 1) : 0;
const startLineDisplay = startLine + 1; // For display (1-indexed)
// Check if offset is out of bounds
if (startLine >= allLines.length) {
throw new Error(`Offset ${offset} is beyond end of file (${allLines.length} lines total)`);
}
// If limit is specified by user, use it; otherwise we'll let truncateHead decide
let selectedContent: string;
let userLimitedLines: number | undefined;
if (limit !== undefined) {
const endLine = Math.min(startLine + limit, allLines.length);
selectedContent = allLines.slice(startLine, endLine).join("\n");
userLimitedLines = endLine - startLine;
} else {
selectedContent = allLines.slice(startLine).join("\n");
}
// Apply truncation (respects both line and byte limits)
const truncation = truncateHead(selectedContent);
let outputText: string;
if (truncation.firstLineExceedsLimit) {
// First line at offset exceeds 30KB - tell model to use bash
const firstLineSize = formatSize(Buffer.byteLength(allLines[startLine], "utf-8"));
outputText = `[Line ${startLineDisplay} is ${firstLineSize}, exceeds ${formatSize(DEFAULT_MAX_BYTES)} limit. Use bash: sed -n '${startLineDisplay}p' ${path} | head -c ${DEFAULT_MAX_BYTES}]`;
details = { truncation };
} else if (truncation.truncated) {
// Truncation occurred - build actionable notice
const endLineDisplay = startLineDisplay + truncation.outputLines - 1;
const nextOffset = endLineDisplay + 1;
outputText = truncation.content;
if (truncation.truncatedBy === "lines") {
outputText += `\n\n[Showing lines ${startLineDisplay}-${endLineDisplay} of ${totalFileLines}. Use offset=${nextOffset} to continue]`;
} else {
outputText += `\n\n[Showing lines ${startLineDisplay}-${endLineDisplay} of ${totalFileLines} (${formatSize(DEFAULT_MAX_BYTES)} limit). Use offset=${nextOffset} to continue]`;
}
details = { truncation };
} else if (userLimitedLines !== undefined && startLine + userLimitedLines < allLines.length) {
// User specified limit, there's more content, but no truncation
const endLineDisplay = startLineDisplay + userLimitedLines - 1;
const remaining = allLines.length - (startLine + userLimitedLines);
const nextOffset = startLine + userLimitedLines + 1;
outputText = truncation.content;
outputText += `\n\n[${remaining} more lines in file. Use offset=${nextOffset} to continue]`;
} else {
// No truncation, no user limit exceeded
outputText = truncation.content;
}
content = [{ type: "text", text: outputText }];
}
// Check if aborted after reading
if (aborted) {
return;
}
// Clean up abort handler
if (signal) {
signal.removeEventListener("abort", onAbort);
}
resolve({ content, details });
} catch (error: any) {
// Clean up abort handler
if (signal) {
signal.removeEventListener("abort", onAbort);
}
if (!aborted) {
reject(error);
}
}
})();
},
);
},
};

View file

@ -0,0 +1,251 @@
/**
* Shared truncation utilities for tool outputs.
*
* Truncation is based on two independent limits - whichever is hit first wins:
* - Line limit (default: 2000 lines)
* - Byte limit (default: 30KB)
*
* Never returns partial lines (except bash tail truncation edge case).
*/
export const DEFAULT_MAX_LINES = 2000;
export const DEFAULT_MAX_BYTES = 50 * 1024; // 50KB
export const GREP_MAX_LINE_LENGTH = 500; // Max chars per grep match line
export interface TruncationResult {
/** The truncated content */
content: string;
/** Whether truncation occurred */
truncated: boolean;
/** Which limit was hit: "lines", "bytes", or null if not truncated */
truncatedBy: "lines" | "bytes" | null;
/** Total number of lines in the original content */
totalLines: number;
/** Total number of bytes in the original content */
totalBytes: number;
/** Number of complete lines in the truncated output */
outputLines: number;
/** Number of bytes in the truncated output */
outputBytes: number;
/** Whether the last line was partially truncated (only for tail truncation edge case) */
lastLinePartial: boolean;
/** Whether the first line exceeded the byte limit (for head truncation) */
firstLineExceedsLimit: boolean;
}
export interface TruncationOptions {
/** Maximum number of lines (default: 2000) */
maxLines?: number;
/** Maximum number of bytes (default: 30KB) */
maxBytes?: number;
}
/**
* Format bytes as human-readable size.
*/
export function formatSize(bytes: number): string {
if (bytes < 1024) {
return `${bytes}B`;
} else if (bytes < 1024 * 1024) {
return `${(bytes / 1024).toFixed(1)}KB`;
} else {
return `${(bytes / (1024 * 1024)).toFixed(1)}MB`;
}
}
/**
* Truncate content from the head (keep first N lines/bytes).
* Suitable for file reads where you want to see the beginning.
*
* Never returns partial lines. If first line exceeds byte limit,
* returns empty content with firstLineExceedsLimit=true.
*/
export function truncateHead(content: string, options: TruncationOptions = {}): TruncationResult {
const maxLines = options.maxLines ?? DEFAULT_MAX_LINES;
const maxBytes = options.maxBytes ?? DEFAULT_MAX_BYTES;
const totalBytes = Buffer.byteLength(content, "utf-8");
const lines = content.split("\n");
const totalLines = lines.length;
// Check if no truncation needed
if (totalLines <= maxLines && totalBytes <= maxBytes) {
return {
content,
truncated: false,
truncatedBy: null,
totalLines,
totalBytes,
outputLines: totalLines,
outputBytes: totalBytes,
lastLinePartial: false,
firstLineExceedsLimit: false,
};
}
// Check if first line alone exceeds byte limit
const firstLineBytes = Buffer.byteLength(lines[0], "utf-8");
if (firstLineBytes > maxBytes) {
return {
content: "",
truncated: true,
truncatedBy: "bytes",
totalLines,
totalBytes,
outputLines: 0,
outputBytes: 0,
lastLinePartial: false,
firstLineExceedsLimit: true,
};
}
// Collect complete lines that fit
const outputLinesArr: string[] = [];
let outputBytesCount = 0;
let truncatedBy: "lines" | "bytes" = "lines";
for (let i = 0; i < lines.length && i < maxLines; i++) {
const line = lines[i];
const lineBytes = Buffer.byteLength(line, "utf-8") + (i > 0 ? 1 : 0); // +1 for newline
if (outputBytesCount + lineBytes > maxBytes) {
truncatedBy = "bytes";
break;
}
outputLinesArr.push(line);
outputBytesCount += lineBytes;
}
// If we exited due to line limit
if (outputLinesArr.length >= maxLines && outputBytesCount <= maxBytes) {
truncatedBy = "lines";
}
const outputContent = outputLinesArr.join("\n");
const finalOutputBytes = Buffer.byteLength(outputContent, "utf-8");
return {
content: outputContent,
truncated: true,
truncatedBy,
totalLines,
totalBytes,
outputLines: outputLinesArr.length,
outputBytes: finalOutputBytes,
lastLinePartial: false,
firstLineExceedsLimit: false,
};
}
/**
* Truncate content from the tail (keep last N lines/bytes).
* Suitable for bash output where you want to see the end (errors, final results).
*
* May return partial first line if the last line of original content exceeds byte limit.
*/
export function truncateTail(content: string, options: TruncationOptions = {}): TruncationResult {
const maxLines = options.maxLines ?? DEFAULT_MAX_LINES;
const maxBytes = options.maxBytes ?? DEFAULT_MAX_BYTES;
const totalBytes = Buffer.byteLength(content, "utf-8");
const lines = content.split("\n");
const totalLines = lines.length;
// Check if no truncation needed
if (totalLines <= maxLines && totalBytes <= maxBytes) {
return {
content,
truncated: false,
truncatedBy: null,
totalLines,
totalBytes,
outputLines: totalLines,
outputBytes: totalBytes,
lastLinePartial: false,
firstLineExceedsLimit: false,
};
}
// Work backwards from the end
const outputLinesArr: string[] = [];
let outputBytesCount = 0;
let truncatedBy: "lines" | "bytes" = "lines";
let lastLinePartial = false;
for (let i = lines.length - 1; i >= 0 && outputLinesArr.length < maxLines; i--) {
const line = lines[i];
const lineBytes = Buffer.byteLength(line, "utf-8") + (outputLinesArr.length > 0 ? 1 : 0); // +1 for newline
if (outputBytesCount + lineBytes > maxBytes) {
truncatedBy = "bytes";
// Edge case: if we haven't added ANY lines yet and this line exceeds maxBytes,
// take the end of the line (partial)
if (outputLinesArr.length === 0) {
const truncatedLine = truncateStringToBytesFromEnd(line, maxBytes);
outputLinesArr.unshift(truncatedLine);
outputBytesCount = Buffer.byteLength(truncatedLine, "utf-8");
lastLinePartial = true;
}
break;
}
outputLinesArr.unshift(line);
outputBytesCount += lineBytes;
}
// If we exited due to line limit
if (outputLinesArr.length >= maxLines && outputBytesCount <= maxBytes) {
truncatedBy = "lines";
}
const outputContent = outputLinesArr.join("\n");
const finalOutputBytes = Buffer.byteLength(outputContent, "utf-8");
return {
content: outputContent,
truncated: true,
truncatedBy,
totalLines,
totalBytes,
outputLines: outputLinesArr.length,
outputBytes: finalOutputBytes,
lastLinePartial,
firstLineExceedsLimit: false,
};
}
/**
* Truncate a string to fit within a byte limit (from the end).
* Handles multi-byte UTF-8 characters correctly.
*/
function truncateStringToBytesFromEnd(str: string, maxBytes: number): string {
const buf = Buffer.from(str, "utf-8");
if (buf.length <= maxBytes) {
return str;
}
// Start from the end, skip maxBytes back
let start = buf.length - maxBytes;
// Find a valid UTF-8 boundary (start of a character)
while (start < buf.length && (buf[start] & 0xc0) === 0x80) {
start++;
}
return buf.slice(start).toString("utf-8");
}
/**
* Truncate a single line to max characters, adding [truncated] suffix.
* Used for grep match lines.
*/
export function truncateLine(
line: string,
maxChars: number = GREP_MAX_LINE_LENGTH,
): { text: string; wasTruncated: boolean } {
if (line.length <= maxChars) {
return { text: line, wasTruncated: false };
}
return { text: line.slice(0, maxChars) + "... [truncated]", wasTruncated: true };
}

View file

@ -0,0 +1,95 @@
import * as os from "node:os";
import type { AgentTool } from "@mariozechner/pi-ai";
import { Type } from "@sinclair/typebox";
import { mkdir, writeFile } from "fs/promises";
import { dirname, resolve as resolvePath } from "path";
/**
* Expand ~ to home directory
*/
function expandPath(filePath: string): string {
if (filePath === "~") {
return os.homedir();
}
if (filePath.startsWith("~/")) {
return os.homedir() + filePath.slice(1);
}
return filePath;
}
const writeSchema = Type.Object({
path: Type.String({ description: "Path to the file to write (relative or absolute)" }),
content: Type.String({ description: "Content to write to the file" }),
});
export const writeTool: AgentTool<typeof writeSchema> = {
name: "write",
label: "write",
description:
"Write content to a file. Creates the file if it doesn't exist, overwrites if it does. Automatically creates parent directories.",
parameters: writeSchema,
execute: async (_toolCallId: string, { path, content }: { path: string; content: string }, signal?: AbortSignal) => {
const absolutePath = resolvePath(expandPath(path));
const dir = dirname(absolutePath);
return new Promise<{ content: Array<{ type: "text"; text: string }>; details: undefined }>((resolve, reject) => {
// Check if already aborted
if (signal?.aborted) {
reject(new Error("Operation aborted"));
return;
}
let aborted = false;
// Set up abort handler
const onAbort = () => {
aborted = true;
reject(new Error("Operation aborted"));
};
if (signal) {
signal.addEventListener("abort", onAbort, { once: true });
}
// Perform the write operation
(async () => {
try {
// Create parent directories if needed
await mkdir(dir, { recursive: true });
// Check if aborted before writing
if (aborted) {
return;
}
// Write the file
await writeFile(absolutePath, content, "utf-8");
// Check if aborted after writing
if (aborted) {
return;
}
// Clean up abort handler
if (signal) {
signal.removeEventListener("abort", onAbort);
}
resolve({
content: [{ type: "text", text: `Successfully wrote ${content.length} bytes to ${path}` }],
details: undefined,
});
} catch (error: any) {
// Clean up abort handler
if (signal) {
signal.removeEventListener("abort", onAbort);
}
if (!aborted) {
reject(error);
}
}
})();
});
},
};