Rework custom tools API with CustomToolContext

- CustomAgentTool renamed to CustomTool
- ToolAPI renamed to CustomToolAPI
- ToolContext renamed to CustomToolContext
- ToolSessionEvent renamed to CustomToolSessionEvent
- Added CustomToolContext parameter to execute() and onSession()
- CustomToolFactory now returns CustomTool<any, any> for type compatibility
- dispose() replaced with onSession({ reason: 'shutdown' })
- Added wrapCustomTool() to convert CustomTool to AgentTool
- Session exposes setToolUIContext() instead of leaking internals
- Fix ToolExecutionComponent to sync with toolOutputExpanded state
- Update all custom tool examples for new API
This commit is contained in:
Mario Zechner 2025-12-31 12:05:24 +01:00
parent b123df5fab
commit 568150f18b
27 changed files with 336 additions and 289 deletions

View file

@ -10,9 +10,10 @@ const factory: CustomToolFactory = (_pi) => ({
}), }),
async execute(_toolCallId, params) { async execute(_toolCallId, params) {
const { name } = params as { name: string };
return { return {
content: [{ type: "text", text: `Hello, ${params.name}!` }], content: [{ type: "text", text: `Hello, ${name}!` }],
details: { greeted: params.name }, details: { greeted: name },
}; };
}, },
}); });

View file

@ -2,7 +2,7 @@
* Question Tool - Let the LLM ask the user a question with options * Question Tool - Let the LLM ask the user a question with options
*/ */
import type { CustomAgentTool, CustomToolFactory } from "@mariozechner/pi-coding-agent"; import type { CustomTool, CustomToolFactory } from "@mariozechner/pi-coding-agent";
import { Text } from "@mariozechner/pi-tui"; import { Text } from "@mariozechner/pi-tui";
import { Type } from "@sinclair/typebox"; import { Type } from "@sinclair/typebox";
@ -18,7 +18,7 @@ const QuestionParams = Type.Object({
}); });
const factory: CustomToolFactory = (pi) => { const factory: CustomToolFactory = (pi) => {
const tool: CustomAgentTool<typeof QuestionParams, QuestionDetails> = { const tool: CustomTool<typeof QuestionParams, QuestionDetails> = {
name: "question", name: "question",
label: "Question", label: "Question",
description: "Ask the user a question and let them pick from options. Use when you need user input to proceed.", description: "Ask the user a question and let them pick from options. Use when you need user input to proceed.",

View file

@ -20,10 +20,10 @@ import type { AgentToolResult } from "@mariozechner/pi-agent-core";
import type { Message } from "@mariozechner/pi-ai"; import type { Message } from "@mariozechner/pi-ai";
import { StringEnum } from "@mariozechner/pi-ai"; import { StringEnum } from "@mariozechner/pi-ai";
import { import {
type CustomAgentTool, type CustomTool,
type CustomToolAPI,
type CustomToolFactory, type CustomToolFactory,
getMarkdownTheme, getMarkdownTheme,
type ToolAPI,
} from "@mariozechner/pi-coding-agent"; } from "@mariozechner/pi-coding-agent";
import { Container, Markdown, Spacer, Text } from "@mariozechner/pi-tui"; import { Container, Markdown, Spacer, Text } from "@mariozechner/pi-tui";
import { Type } from "@sinclair/typebox"; import { Type } from "@sinclair/typebox";
@ -224,7 +224,7 @@ function writePromptToTempFile(agentName: string, prompt: string): { dir: string
type OnUpdateCallback = (partial: AgentToolResult<SubagentDetails>) => void; type OnUpdateCallback = (partial: AgentToolResult<SubagentDetails>) => void;
async function runSingleAgent( async function runSingleAgent(
pi: ToolAPI, pi: CustomToolAPI,
agents: AgentConfig[], agents: AgentConfig[],
agentName: string, agentName: string,
task: string, task: string,
@ -411,7 +411,7 @@ const SubagentParams = Type.Object({
}); });
const factory: CustomToolFactory = (pi) => { const factory: CustomToolFactory = (pi) => {
const tool: CustomAgentTool<typeof SubagentParams, SubagentDetails> = { const tool: CustomTool<typeof SubagentParams, SubagentDetails> = {
name: "subagent", name: "subagent",
label: "Subagent", label: "Subagent",
get description() { get description() {
@ -433,7 +433,7 @@ const factory: CustomToolFactory = (pi) => {
}, },
parameters: SubagentParams, parameters: SubagentParams,
async execute(_toolCallId, params, signal, onUpdate) { async execute(_toolCallId, params, signal, onUpdate, _ctx) {
const agentScope: AgentScope = params.agentScope ?? "user"; const agentScope: AgentScope = params.agentScope ?? "user";
const discovery = discoverAgents(pi.cwd, agentScope); const discovery = discoverAgents(pi.cwd, agentScope);
const agents = discovery.agents; const agents = discovery.agents;

View file

@ -9,7 +9,12 @@
*/ */
import { StringEnum } from "@mariozechner/pi-ai"; import { StringEnum } from "@mariozechner/pi-ai";
import type { CustomAgentTool, CustomToolFactory, ToolSessionEvent } from "@mariozechner/pi-coding-agent"; import type {
CustomTool,
CustomToolContext,
CustomToolFactory,
CustomToolSessionEvent,
} from "@mariozechner/pi-coding-agent";
import { Text } from "@mariozechner/pi-tui"; import { Text } from "@mariozechner/pi-tui";
import { Type } from "@sinclair/typebox"; import { Type } from "@sinclair/typebox";
@ -43,11 +48,12 @@ const factory: CustomToolFactory = (_pi) => {
* Reconstruct state from session entries. * Reconstruct state from session entries.
* Scans tool results for this tool and applies them in order. * Scans tool results for this tool and applies them in order.
*/ */
const reconstructState = (event: ToolSessionEvent) => { const reconstructState = (_event: CustomToolSessionEvent, ctx: CustomToolContext) => {
todos = []; todos = [];
nextId = 1; nextId = 1;
for (const entry of event.entries) { // Use getBranch() to get entries on the current branch
for (const entry of ctx.sessionManager.getBranch()) {
if (entry.type !== "message") continue; if (entry.type !== "message") continue;
const msg = entry.message; const msg = entry.message;
@ -63,7 +69,7 @@ const factory: CustomToolFactory = (_pi) => {
} }
}; };
const tool: CustomAgentTool<typeof TodoParams, TodoDetails> = { const tool: CustomTool<typeof TodoParams, TodoDetails> = {
name: "todo", name: "todo",
label: "Todo", label: "Todo",
description: "Manage a todo list. Actions: list, add (text), toggle (id), clear", description: "Manage a todo list. Actions: list, add (text), toggle (id), clear",

View file

@ -14,7 +14,6 @@
*/ */
import { complete, getModel } from "@mariozechner/pi-ai"; import { complete, getModel } from "@mariozechner/pi-ai";
import type { CompactionEntry } from "@mariozechner/pi-coding-agent";
import { convertToLlm } from "@mariozechner/pi-coding-agent"; import { convertToLlm } from "@mariozechner/pi-coding-agent";
import type { HookAPI } from "@mariozechner/pi-coding-agent/hooks"; import type { HookAPI } from "@mariozechner/pi-coding-agent/hooks";
@ -22,7 +21,7 @@ export default function (pi: HookAPI) {
pi.on("session_before_compact", async (event, ctx) => { pi.on("session_before_compact", async (event, ctx) => {
ctx.ui.notify("Custom compaction hook triggered", "info"); ctx.ui.notify("Custom compaction hook triggered", "info");
const { preparation, branchEntries, signal } = event; const { preparation, branchEntries: _, signal } = event;
const { messagesToSummarize, turnPrefixMessages, tokensBefore, firstKeptEntryId, previousSummary } = preparation; const { messagesToSummarize, turnPrefixMessages, tokensBefore, firstKeptEntryId, previousSummary } = preparation;
// Use Gemini Flash for summarization (cheaper/faster than most conversation models) // Use Gemini Flash for summarization (cheaper/faster than most conversation models)

View file

@ -11,7 +11,7 @@
import { Type } from "@sinclair/typebox"; import { Type } from "@sinclair/typebox";
import { import {
bashTool, // read, bash, edit, write - uses process.cwd() bashTool, // read, bash, edit, write - uses process.cwd()
type CustomAgentTool, type CustomTool,
createAgentSession, createAgentSession,
createBashTool, createBashTool,
createCodingTools, // Factory: creates tools for specific cwd createCodingTools, // Factory: creates tools for specific cwd
@ -55,7 +55,7 @@ await createAgentSession({
console.log("Specific tools with custom cwd session created"); console.log("Specific tools with custom cwd session created");
// Inline custom tool (needs TypeBox schema) // Inline custom tool (needs TypeBox schema)
const weatherTool: CustomAgentTool = { const weatherTool: CustomTool = {
name: "get_weather", name: "get_weather",
label: "Get Weather", label: "Get Weather",
description: "Get current weather for a city", description: "Get current weather for a city",

View file

@ -12,7 +12,7 @@ import { getModel } from "@mariozechner/pi-ai";
import { Type } from "@sinclair/typebox"; import { Type } from "@sinclair/typebox";
import { import {
AuthStorage, AuthStorage,
type CustomAgentTool, type CustomTool,
createAgentSession, createAgentSession,
createBashTool, createBashTool,
createReadTool, createReadTool,
@ -42,7 +42,7 @@ const auditHook: HookFactory = (api) => {
}; };
// Inline custom tool // Inline custom tool
const statusTool: CustomAgentTool = { const statusTool: CustomTool = {
name: "status", name: "status",
label: "Status", label: "Status",
description: "Get system status", description: "Get system status",
@ -68,15 +68,12 @@ const cwd = process.cwd();
const { session } = await createAgentSession({ const { session } = await createAgentSession({
cwd, cwd,
agentDir: "/tmp/my-agent", agentDir: "/tmp/my-agent",
model, model,
thinkingLevel: "off", thinkingLevel: "off",
authStorage, authStorage,
modelRegistry, modelRegistry,
systemPrompt: `You are a minimal assistant. systemPrompt: `You are a minimal assistant.
Available: read, bash, status. Be concise.`, Available: read, bash, status. Be concise.`,
// Use factory functions with the same cwd to ensure path resolution works correctly // Use factory functions with the same cwd to ensure path resolution works correctly
tools: [createReadTool(cwd), createBashTool(cwd)], tools: [createReadTool(cwd), createBashTool(cwd)],
customTools: [{ tool: statusTool }], customTools: [{ tool: statusTool }],

View file

@ -27,7 +27,7 @@ import {
prepareCompaction, prepareCompaction,
shouldCompact, shouldCompact,
} from "./compaction/index.js"; } from "./compaction/index.js";
import type { LoadedCustomTool, SessionEvent as ToolSessionEvent } from "./custom-tools/index.js"; import type { CustomToolContext, CustomToolSessionEvent, LoadedCustomTool } from "./custom-tools/index.js";
import { exportSessionToHtml } from "./export-html.js"; import { exportSessionToHtml } from "./export-html.js";
import type { import type {
HookContext, HookContext,
@ -698,7 +698,7 @@ export class AgentSession {
} }
// Emit session event to custom tools // Emit session event to custom tools
await this.emitToolSessionEvent("new", previousSessionFile); await this.emitCustomToolSessionEvent("new", previousSessionFile);
return true; return true;
} }
@ -895,7 +895,7 @@ export class AgentSession {
throw new Error(`No API key for ${this.model.provider}`); throw new Error(`No API key for ${this.model.provider}`);
} }
const pathEntries = this.sessionManager.getPath(); const pathEntries = this.sessionManager.getBranch();
const settings = this.settingsManager.getCompactionSettings(); const settings = this.settingsManager.getCompactionSettings();
const preparation = prepareCompaction(pathEntries, settings); const preparation = prepareCompaction(pathEntries, settings);
@ -1068,7 +1068,7 @@ export class AgentSession {
return; return;
} }
const pathEntries = this.sessionManager.getPath(); const pathEntries = this.sessionManager.getBranch();
const preparation = prepareCompaction(pathEntries, settings); const preparation = prepareCompaction(pathEntries, settings);
if (!preparation) { if (!preparation) {
@ -1473,7 +1473,7 @@ export class AgentSession {
} }
// Emit session event to custom tools // Emit session event to custom tools
await this.emitToolSessionEvent("switch", previousSessionFile); await this.emitCustomToolSessionEvent("switch", previousSessionFile);
this.agent.replaceMessages(sessionContext.messages); this.agent.replaceMessages(sessionContext.messages);
@ -1550,7 +1550,7 @@ export class AgentSession {
} }
// Emit session event to custom tools (with reason "branch") // Emit session event to custom tools (with reason "branch")
await this.emitToolSessionEvent("branch", previousSessionFile); await this.emitCustomToolSessionEvent("branch", previousSessionFile);
if (!skipConversationRestore) { if (!skipConversationRestore) {
this.agent.replaceMessages(sessionContext.messages); this.agent.replaceMessages(sessionContext.messages);
@ -1720,7 +1720,7 @@ export class AgentSession {
} }
// Emit to custom tools // Emit to custom tools
await this.emitToolSessionEvent("tree", this.sessionFile); await this.emitCustomToolSessionEvent("tree", this.sessionFile);
this._branchSummaryAbortController = undefined; this._branchSummaryAbortController = undefined;
return { editorText, cancelled: false, summaryEntry }; return { editorText, cancelled: false, summaryEntry };
@ -1877,20 +1877,23 @@ export class AgentSession {
* Emit session event to all custom tools. * Emit session event to all custom tools.
* Called on session switch, branch, tree navigation, and shutdown. * Called on session switch, branch, tree navigation, and shutdown.
*/ */
async emitToolSessionEvent( async emitCustomToolSessionEvent(
reason: ToolSessionEvent["reason"], reason: CustomToolSessionEvent["reason"],
previousSessionFile?: string | undefined, previousSessionFile?: string | undefined,
): Promise<void> { ): Promise<void> {
const event: ToolSessionEvent = { if (!this._customTools) return;
entries: this.sessionManager.getEntries(),
sessionFile: this.sessionFile, const event: CustomToolSessionEvent = { reason, previousSessionFile };
previousSessionFile, const ctx: CustomToolContext = {
reason, sessionManager: this.sessionManager,
modelRegistry: this._modelRegistry,
model: this.agent.state.model,
}; };
for (const { tool } of this._customTools) { for (const { tool } of this._customTools) {
if (tool.onSession) { if (tool.onSession) {
try { try {
await tool.onSession(event); await tool.onSession(event, ctx);
} catch (_err) { } catch (_err) {
// Silently ignore tool errors during session events // Silently ignore tool errors during session events
} }

View file

@ -102,8 +102,8 @@ export function collectEntriesForBranchSummary(
} }
// Find common ancestor (deepest node that's on both paths) // Find common ancestor (deepest node that's on both paths)
const oldPath = new Set(session.getPath(oldLeafId).map((e) => e.id)); const oldPath = new Set(session.getBranch(oldLeafId).map((e) => e.id));
const targetPath = session.getPath(targetId); const targetPath = session.getBranch(targetId);
// targetPath is root-first, so iterate backwards to find deepest common ancestor // targetPath is root-first, so iterate backwards to find deepest common ancestor
let commonAncestorId: string | null = null; let commonAncestorId: string | null = null;

View file

@ -4,14 +4,18 @@
export { discoverAndLoadCustomTools, loadCustomTools } from "./loader.js"; export { discoverAndLoadCustomTools, loadCustomTools } from "./loader.js";
export type { export type {
AgentToolResult,
AgentToolUpdateCallback, AgentToolUpdateCallback,
CustomAgentTool, CustomTool,
CustomToolAPI,
CustomToolContext,
CustomToolFactory, CustomToolFactory,
CustomToolResult,
CustomToolSessionEvent,
CustomToolsLoadResult, CustomToolsLoadResult,
CustomToolUIContext,
ExecResult, ExecResult,
LoadedCustomTool, LoadedCustomTool,
RenderResultOptions, RenderResultOptions,
SessionEvent,
ToolAPI,
ToolUIContext,
} from "./types.js"; } from "./types.js";
export { wrapCustomTool, wrapCustomTools } from "./wrapper.js";

View file

@ -17,7 +17,7 @@ import { getAgentDir, isBunBinary } from "../../config.js";
import type { ExecOptions } from "../exec.js"; import type { ExecOptions } from "../exec.js";
import { execCommand } from "../exec.js"; import { execCommand } from "../exec.js";
import type { HookUIContext } from "../hooks/types.js"; import type { HookUIContext } from "../hooks/types.js";
import type { CustomToolFactory, CustomToolsLoadResult, LoadedCustomTool, ToolAPI } from "./types.js"; import type { CustomToolAPI, CustomToolFactory, CustomToolsLoadResult, LoadedCustomTool } from "./types.js";
// Create require function to resolve module paths at runtime // Create require function to resolve module paths at runtime
const require = createRequire(import.meta.url); const require = createRequire(import.meta.url);
@ -104,7 +104,7 @@ function createNoOpUIContext(): HookUIContext {
*/ */
async function loadToolWithBun( async function loadToolWithBun(
resolvedPath: string, resolvedPath: string,
sharedApi: ToolAPI, sharedApi: CustomToolAPI,
): Promise<{ tools: LoadedCustomTool[] | null; error: string | null }> { ): Promise<{ tools: LoadedCustomTool[] | null; error: string | null }> {
try { try {
// Try to import directly - will work for tools without @mariozechner/* imports // Try to import directly - will work for tools without @mariozechner/* imports
@ -149,7 +149,7 @@ async function loadToolWithBun(
async function loadTool( async function loadTool(
toolPath: string, toolPath: string,
cwd: string, cwd: string,
sharedApi: ToolAPI, sharedApi: CustomToolAPI,
): Promise<{ tools: LoadedCustomTool[] | null; error: string | null }> { ): Promise<{ tools: LoadedCustomTool[] | null; error: string | null }> {
const resolvedPath = resolveToolPath(toolPath, cwd); const resolvedPath = resolveToolPath(toolPath, cwd);
@ -209,7 +209,7 @@ export async function loadCustomTools(
const seenNames = new Set<string>(builtInToolNames); const seenNames = new Set<string>(builtInToolNames);
// Shared API object - all tools get the same instance // Shared API object - all tools get the same instance
const sharedApi: ToolAPI = { const sharedApi: CustomToolAPI = {
cwd, cwd,
exec: (command: string, args: string[], options?: ExecOptions) => exec: (command: string, args: string[], options?: ExecOptions) =>
execCommand(command, args, options?.cwd ?? cwd, options), execCommand(command, args, options?.cwd ?? cwd, options),

View file

@ -5,45 +5,56 @@
* They can provide custom rendering for tool calls and results in the TUI. * They can provide custom rendering for tool calls and results in the TUI.
*/ */
import type { AgentTool, AgentToolResult, AgentToolUpdateCallback } from "@mariozechner/pi-agent-core"; import type { AgentToolResult, AgentToolUpdateCallback } from "@mariozechner/pi-agent-core";
import type { Model } from "@mariozechner/pi-ai";
import type { Component } from "@mariozechner/pi-tui"; import type { Component } from "@mariozechner/pi-tui";
import type { Static, TSchema } from "@sinclair/typebox"; import type { Static, TSchema } from "@sinclair/typebox";
import type { Theme } from "../../modes/interactive/theme/theme.js"; import type { Theme } from "../../modes/interactive/theme/theme.js";
import type { ExecOptions, ExecResult } from "../exec.js"; import type { ExecOptions, ExecResult } from "../exec.js";
import type { HookUIContext } from "../hooks/types.js"; import type { HookUIContext } from "../hooks/types.js";
import type { SessionEntry } from "../session-manager.js"; import type { ModelRegistry } from "../model-registry.js";
import type { ReadonlySessionManager } from "../session-manager.js";
/** Alias for clarity */ /** Alias for clarity */
export type ToolUIContext = HookUIContext; export type CustomToolUIContext = HookUIContext;
/** Re-export for custom tools to use in execute signature */ /** Re-export for custom tools to use in execute signature */
export type { AgentToolUpdateCallback }; export type { AgentToolResult, AgentToolUpdateCallback };
// Re-export for backward compatibility // Re-export for backward compatibility
export type { ExecOptions, ExecResult } from "../exec.js"; export type { ExecOptions, ExecResult } from "../exec.js";
/** API passed to custom tool factory (stable across session changes) */ /** API passed to custom tool factory (stable across session changes) */
export interface ToolAPI { export interface CustomToolAPI {
/** Current working directory */ /** Current working directory */
cwd: string; cwd: string;
/** Execute a command */ /** Execute a command */
exec(command: string, args: string[], options?: ExecOptions): Promise<ExecResult>; exec(command: string, args: string[], options?: ExecOptions): Promise<ExecResult>;
/** UI methods for user interaction (select, confirm, input, notify) */ /** UI methods for user interaction (select, confirm, input, notify, custom) */
ui: ToolUIContext; ui: CustomToolUIContext;
/** Whether UI is available (false in print/RPC mode) */ /** Whether UI is available (false in print/RPC mode) */
hasUI: boolean; hasUI: boolean;
} }
/**
* Context passed to tool execute and onSession callbacks.
* Provides access to session state and model information.
*/
export interface CustomToolContext {
/** Session manager (read-only) */
sessionManager: ReadonlySessionManager;
/** Model registry - use for API key resolution and model retrieval */
modelRegistry: ModelRegistry;
/** Current model (may be undefined if no model is selected yet) */
model: Model<any> | undefined;
}
/** Session event passed to onSession callback */ /** Session event passed to onSession callback */
export interface SessionEvent { export interface CustomToolSessionEvent {
/** All session entries (including pre-compaction history) */
entries: SessionEntry[];
/** Current session file path, or undefined in --no-session mode */
sessionFile: string | undefined;
/** Previous session file path, or undefined for "start", "new", and "shutdown" */
previousSessionFile: string | undefined;
/** Reason for the session event */ /** Reason for the session event */
reason: "start" | "switch" | "branch" | "new" | "tree" | "shutdown"; reason: "start" | "switch" | "branch" | "new" | "tree" | "shutdown";
/** Previous session file path, or undefined for "start", "new", and "shutdown" */
previousSessionFile: string | undefined;
} }
/** Rendering options passed to renderResult */ /** Rendering options passed to renderResult */
@ -54,58 +65,89 @@ export interface RenderResultOptions {
isPartial: boolean; isPartial: boolean;
} }
export type CustomToolResult<TDetails = any> = AgentToolResult<TDetails>;
/** /**
* Custom tool with optional lifecycle and rendering methods. * Custom tool definition.
* *
* The execute signature inherited from AgentTool includes an optional onUpdate callback * Custom tools are standalone - they don't extend AgentTool directly.
* for streaming progress updates during long-running operations: * When loaded, they are wrapped in an AgentTool for the agent to use.
* - The callback emits partial results to subscribers (e.g. TUI/RPC), not to the LLM. *
* - Partial updates should use the same TDetails type as the final result (use a union if needed). * The execute callback receives a ToolContext with access to session state,
* model registry, and current model.
* *
* @example * @example
* ```typescript * ```typescript
* type Details = * const factory: CustomToolFactory = (pi) => ({
* | { status: "running"; step: number; total: number } * name: "my_tool",
* | { status: "done"; count: number }; * label: "My Tool",
* description: "Does something useful",
* parameters: Type.Object({ input: Type.String() }),
* *
* async execute(toolCallId, params, signal, onUpdate) { * async execute(toolCallId, params, signal, onUpdate, ctx) {
* const items = params.items || []; * // Access session state via ctx.sessionManager
* for (let i = 0; i < items.length; i++) { * // Access model registry via ctx.modelRegistry
* onUpdate?.({ * // Current model via ctx.model
* content: [{ type: "text", text: `Step ${i + 1}/${items.length}...` }], * return { content: [{ type: "text", text: "Done" }] };
* details: { status: "running", step: i + 1, total: items.length }, * },
*
* onSession(event, ctx) {
* if (event.reason === "shutdown") {
* // Cleanup
* }
* // Reconstruct state from ctx.sessionManager.getEntries()
* }
* }); * });
* await processItem(items[i], signal);
* }
* return { content: [{ type: "text", text: "Done" }], details: { status: "done", count: items.length } };
* }
* ``` * ```
*
* Progress updates are rendered via renderResult with isPartial: true.
*/ */
export interface CustomAgentTool<TParams extends TSchema = TSchema, TDetails = any> export interface CustomTool<TParams extends TSchema = TSchema, TDetails = any> {
extends AgentTool<TParams, TDetails> { /** Tool name (used in LLM tool calls) */
name: string;
/** Human-readable label for UI */
label: string;
/** Description for LLM */
description: string;
/** Parameter schema (TypeBox) */
parameters: TParams;
/**
* Execute the tool.
* @param toolCallId - Unique ID for this tool call
* @param params - Parsed parameters matching the schema
* @param signal - AbortSignal for cancellation
* @param onUpdate - Callback for streaming partial results (for UI, not LLM)
* @param ctx - Context with session manager, model registry, and current model
*/
execute(
toolCallId: string,
params: Static<TParams>,
signal: AbortSignal | undefined,
onUpdate: AgentToolUpdateCallback<TDetails> | undefined,
ctx: CustomToolContext,
): Promise<AgentToolResult<TDetails>>;
/** Called on session lifecycle events - use to reconstruct state or cleanup resources */ /** Called on session lifecycle events - use to reconstruct state or cleanup resources */
onSession?: (event: SessionEvent) => void | Promise<void>; onSession?: (event: CustomToolSessionEvent, ctx: CustomToolContext) => void | Promise<void>;
/** Custom rendering for tool call display - return a Component */ /** Custom rendering for tool call display - return a Component */
renderCall?: (args: Static<TParams>, theme: Theme) => Component; renderCall?: (args: Static<TParams>, theme: Theme) => Component;
/** Custom rendering for tool result display - return a Component */ /** Custom rendering for tool result display - return a Component */
renderResult?: (result: AgentToolResult<TDetails>, options: RenderResultOptions, theme: Theme) => Component; renderResult?: (result: CustomToolResult<TDetails>, options: RenderResultOptions, theme: Theme) => Component;
} }
/** Factory function that creates a custom tool or array of tools */ /** Factory function that creates a custom tool or array of tools */
export type CustomToolFactory = ( export type CustomToolFactory = (
pi: ToolAPI, pi: CustomToolAPI,
) => CustomAgentTool<any> | CustomAgentTool[] | Promise<CustomAgentTool | CustomAgentTool[]>; ) => CustomTool<any, any> | CustomTool<any, any>[] | Promise<CustomTool<any, any> | CustomTool<any, any>[]>;
/** Loaded custom tool with metadata */ /** Loaded custom tool with metadata and wrapped AgentTool */
export interface LoadedCustomTool { export interface LoadedCustomTool {
/** Original path (as specified) */ /** Original path (as specified) */
path: string; path: string;
/** Resolved absolute path */ /** Resolved absolute path */
resolvedPath: string; resolvedPath: string;
/** The tool instance */ /** The original custom tool instance */
tool: CustomAgentTool; tool: CustomTool;
} }
/** Result from loading custom tools */ /** Result from loading custom tools */
@ -113,5 +155,5 @@ export interface CustomToolsLoadResult {
tools: LoadedCustomTool[]; tools: LoadedCustomTool[];
errors: Array<{ path: string; error: string }>; errors: Array<{ path: string; error: string }>;
/** Update the UI context for all loaded tools. Call when mode initializes. */ /** Update the UI context for all loaded tools. Call when mode initializes. */
setUIContext(uiContext: ToolUIContext, hasUI: boolean): void; setUIContext(uiContext: CustomToolUIContext, hasUI: boolean): void;
} }

View file

@ -0,0 +1,28 @@
/**
* Wraps CustomTool instances into AgentTool for use with the agent.
*/
import type { AgentTool } from "@mariozechner/pi-agent-core";
import type { CustomTool, CustomToolContext, LoadedCustomTool } from "./types.js";
/**
* Wrap a CustomTool into an AgentTool.
* The wrapper injects the ToolContext into execute calls.
*/
export function wrapCustomTool(tool: CustomTool, getContext: () => CustomToolContext): AgentTool {
return {
name: tool.name,
label: tool.label,
description: tool.description,
parameters: tool.parameters,
execute: (toolCallId, params, signal, onUpdate) =>
tool.execute(toolCallId, params, signal, onUpdate, getContext()),
};
}
/**
* Wrap all loaded custom tools into AgentTools.
*/
export function wrapCustomTools(loadedTools: LoadedCustomTool[], getContext: () => CustomToolContext): AgentTool[] {
return loadedTools.map((lt) => wrapCustomTool(lt.tool, getContext));
}

View file

@ -108,20 +108,12 @@ export class HookRunner {
hasUI?: boolean; hasUI?: boolean;
}): void { }): void {
this.getModel = options.getModel; this.getModel = options.getModel;
this.setSendMessageHandler(options.sendMessageHandler); for (const hook of this.hooks) {
this.setAppendEntryHandler(options.appendEntryHandler); hook.setSendMessageHandler(options.sendMessageHandler);
if (options.uiContext) { hook.setAppendEntryHandler(options.appendEntryHandler);
this.setUIContext(options.uiContext, options.hasUI ?? false);
} }
} this.uiContext = options.uiContext ?? noOpUIContext;
this.hasUI = options.hasUI ?? false;
/**
* Set the UI context for hooks.
* Call this when the mode initializes and UI is available.
*/
setUIContext(uiContext: HookUIContext, hasUI: boolean): void {
this.uiContext = uiContext;
this.hasUI = hasUI;
} }
/** /**
@ -145,26 +137,6 @@ export class HookRunner {
return this.hooks.map((h) => h.path); return this.hooks.map((h) => h.path);
} }
/**
* Set the send message handler for all hooks' pi.sendMessage().
* Call this when the mode initializes.
*/
setSendMessageHandler(handler: SendMessageHandler): void {
for (const hook of this.hooks) {
hook.setSendMessageHandler(handler);
}
}
/**
* Set the append entry handler for all hooks' pi.appendEntry().
* Call this when the mode initializes.
*/
setAppendEntryHandler(handler: AppendEntryHandler): void {
for (const hook of this.hooks) {
hook.setAppendEntryHandler(handler);
}
}
/** /**
* Subscribe to hook errors. * Subscribe to hook errors.
* @returns Unsubscribe function * @returns Unsubscribe function

View file

@ -13,27 +13,7 @@ import type { CompactionPreparation, CompactionResult } from "../compaction/inde
import type { ExecOptions, ExecResult } from "../exec.js"; import type { ExecOptions, ExecResult } from "../exec.js";
import type { HookMessage } from "../messages.js"; import type { HookMessage } from "../messages.js";
import type { ModelRegistry } from "../model-registry.js"; import type { ModelRegistry } from "../model-registry.js";
import type { BranchSummaryEntry, CompactionEntry, SessionEntry, SessionManager } from "../session-manager.js"; import type { BranchSummaryEntry, CompactionEntry, ReadonlySessionManager, SessionEntry } from "../session-manager.js";
/**
* Read-only view of SessionManager for hooks.
* Hooks should use pi.sendMessage() and pi.appendEntry() for writes.
*/
export type ReadonlySessionManager = Pick<
SessionManager,
| "getCwd"
| "getSessionDir"
| "getSessionId"
| "getSessionFile"
| "getLeafId"
| "getLeafEntry"
| "getEntry"
| "getLabel"
| "getPath"
| "getHeader"
| "getEntries"
| "getTree"
>;
import type { EditToolDetails } from "../tools/edit.js"; import type { EditToolDetails } from "../tools/edit.js";
import type { import type {

View file

@ -14,16 +14,16 @@ export {
export { type BashExecutorOptions, type BashResult, executeBash } from "./bash-executor.js"; export { type BashExecutorOptions, type BashResult, executeBash } from "./bash-executor.js";
export type { CompactionResult } from "./compaction/index.js"; export type { CompactionResult } from "./compaction/index.js";
export { export {
type CustomAgentTool, type CustomTool,
type CustomToolAPI,
type CustomToolFactory, type CustomToolFactory,
type CustomToolsLoadResult, type CustomToolsLoadResult,
type CustomToolUIContext,
discoverAndLoadCustomTools, discoverAndLoadCustomTools,
type ExecResult, type ExecResult,
type LoadedCustomTool, type LoadedCustomTool,
loadCustomTools, loadCustomTools,
type RenderResultOptions, type RenderResultOptions,
type ToolAPI,
type ToolUIContext,
} from "./custom-tools/index.js"; } from "./custom-tools/index.js";
export { export {
type HookAPI, type HookAPI,

View file

@ -35,8 +35,13 @@ import { join } from "path";
import { getAgentDir } from "../config.js"; import { getAgentDir } from "../config.js";
import { AgentSession } from "./agent-session.js"; import { AgentSession } from "./agent-session.js";
import { AuthStorage } from "./auth-storage.js"; import { AuthStorage } from "./auth-storage.js";
import { discoverAndLoadCustomTools, type LoadedCustomTool } from "./custom-tools/index.js"; import {
import type { CustomAgentTool } from "./custom-tools/types.js"; type CustomToolsLoadResult,
discoverAndLoadCustomTools,
type LoadedCustomTool,
wrapCustomTools,
} from "./custom-tools/index.js";
import type { CustomTool } from "./custom-tools/types.js";
import { discoverAndLoadHooks, HookRunner, type LoadedHook, wrapToolsWithHooks } from "./hooks/index.js"; import { discoverAndLoadHooks, HookRunner, type LoadedHook, wrapToolsWithHooks } from "./hooks/index.js";
import type { HookFactory } from "./hooks/types.js"; import type { HookFactory } from "./hooks/types.js";
import { convertToLlm } from "./messages.js"; import { convertToLlm } from "./messages.js";
@ -99,7 +104,7 @@ export interface CreateAgentSessionOptions {
/** Built-in tools to use. Default: codingTools [read, bash, edit, write] */ /** Built-in tools to use. Default: codingTools [read, bash, edit, write] */
tools?: Tool[]; tools?: Tool[];
/** Custom tools (replaces discovery). */ /** Custom tools (replaces discovery). */
customTools?: Array<{ path?: string; tool: CustomAgentTool }>; customTools?: Array<{ path?: string; tool: CustomTool }>;
/** Additional custom tool paths to load (merged with discovery). */ /** Additional custom tool paths to load (merged with discovery). */
additionalCustomToolPaths?: string[]; additionalCustomToolPaths?: string[];
@ -127,17 +132,14 @@ export interface CreateAgentSessionResult {
/** The created session */ /** The created session */
session: AgentSession; session: AgentSession;
/** Custom tools result (for UI context setup in interactive mode) */ /** Custom tools result (for UI context setup in interactive mode) */
customToolsResult: { customToolsResult: CustomToolsLoadResult;
tools: LoadedCustomTool[];
setUIContext: (uiContext: any, hasUI: boolean) => void;
};
/** Warning if session was restored with a different model than saved */ /** Warning if session was restored with a different model than saved */
modelFallbackMessage?: string; modelFallbackMessage?: string;
} }
// Re-exports // Re-exports
export type { CustomAgentTool } from "./custom-tools/types.js"; export type { CustomTool } from "./custom-tools/types.js";
export type { HookAPI, HookFactory } from "./hooks/types.js"; export type { HookAPI, HookFactory } from "./hooks/types.js";
export type { Settings, SkillsSettings } from "./settings-manager.js"; export type { Settings, SkillsSettings } from "./settings-manager.js";
export type { Skill } from "./skills.js"; export type { Skill } from "./skills.js";
@ -219,7 +221,7 @@ export async function discoverHooks(
export async function discoverCustomTools( export async function discoverCustomTools(
cwd?: string, cwd?: string,
agentDir?: string, agentDir?: string,
): Promise<Array<{ path: string; tool: CustomAgentTool }>> { ): Promise<Array<{ path: string; tool: CustomTool }>> {
const resolvedCwd = cwd ?? process.cwd(); const resolvedCwd = cwd ?? process.cwd();
const resolvedAgentDir = agentDir ?? getDefaultAgentDir(); const resolvedAgentDir = agentDir ?? getDefaultAgentDir();
@ -507,7 +509,7 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
const builtInTools = options.tools ?? createCodingTools(cwd); const builtInTools = options.tools ?? createCodingTools(cwd);
time("createCodingTools"); time("createCodingTools");
let customToolsResult: { tools: LoadedCustomTool[]; setUIContext: (ctx: any, hasUI: boolean) => void }; let customToolsResult: CustomToolsLoadResult;
if (options.customTools !== undefined) { if (options.customTools !== undefined) {
// Use provided custom tools // Use provided custom tools
const loadedTools: LoadedCustomTool[] = options.customTools.map((ct) => ({ const loadedTools: LoadedCustomTool[] = options.customTools.map((ct) => ({
@ -517,17 +519,17 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
})); }));
customToolsResult = { customToolsResult = {
tools: loadedTools, tools: loadedTools,
errors: [],
setUIContext: () => {}, setUIContext: () => {},
}; };
} else { } else {
// Discover custom tools, merging with additional paths // Discover custom tools, merging with additional paths
const configuredPaths = [...settingsManager.getCustomToolPaths(), ...(options.additionalCustomToolPaths ?? [])]; const configuredPaths = [...settingsManager.getCustomToolPaths(), ...(options.additionalCustomToolPaths ?? [])];
const result = await discoverAndLoadCustomTools(configuredPaths, cwd, Object.keys(allTools), agentDir); customToolsResult = await discoverAndLoadCustomTools(configuredPaths, cwd, Object.keys(allTools), agentDir);
time("discoverAndLoadCustomTools"); time("discoverAndLoadCustomTools");
for (const { path, error } of result.errors) { for (const { path, error } of customToolsResult.errors) {
console.error(`Failed to load custom tool "${path}": ${error}`); console.error(`Failed to load custom tool "${path}": ${error}`);
} }
customToolsResult = result;
} }
let hookRunner: HookRunner | undefined; let hookRunner: HookRunner | undefined;
@ -549,7 +551,15 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
} }
} }
let allToolsArray: Tool[] = [...builtInTools, ...customToolsResult.tools.map((lt) => lt.tool as unknown as Tool)]; // Wrap custom tools with context getter (agent is assigned below, accessed at execute time)
let agent: Agent;
const wrappedCustomTools = wrapCustomTools(customToolsResult.tools, () => ({
sessionManager,
modelRegistry,
model: agent.state.model,
}));
let allToolsArray: Tool[] = [...builtInTools, ...wrappedCustomTools];
time("combineTools"); time("combineTools");
if (hookRunner) { if (hookRunner) {
allToolsArray = wrapToolsWithHooks(allToolsArray, hookRunner) as Tool[]; allToolsArray = wrapToolsWithHooks(allToolsArray, hookRunner) as Tool[];
@ -581,7 +591,7 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
const slashCommands = options.slashCommands ?? discoverSlashCommands(cwd, agentDir); const slashCommands = options.slashCommands ?? discoverSlashCommands(cwd, agentDir);
time("discoverSlashCommands"); time("discoverSlashCommands");
const agent = new Agent({ agent = new Agent({
initialState: { initialState: {
systemPrompt, systemPrompt,
model, model,

View file

@ -159,19 +159,21 @@ export interface SessionInfo {
allMessagesText: string; allMessagesText: string;
} }
/** export type ReadonlySessionManager = Pick<
* Read-only interface for SessionManager. SessionManager,
* Used by compaction/summarization utilities that only need to read session data. | "getCwd"
*/ | "getSessionDir"
export interface ReadonlySessionManager { | "getSessionId"
getLeafId(): string | null; | "getSessionFile"
getEntry(id: string): SessionEntry | undefined; | "getLeafId"
getPath(fromId?: string): SessionEntry[]; | "getLeafEntry"
getEntries(): SessionEntry[]; | "getEntry"
getChildren(parentId: string): SessionEntry[]; | "getLabel"
getTree(): SessionTreeNode[]; | "getBranch"
getLabel(id: string): string | undefined; | "getHeader"
} | "getEntries"
| "getTree"
>;
/** Generate a unique short ID (8 hex chars, collision-checked) */ /** Generate a unique short ID (8 hex chars, collision-checked) */
function generateId(byId: { has(id: string): boolean }): string { function generateId(byId: { has(id: string): boolean }): string {
@ -772,7 +774,7 @@ export class SessionManager {
* Includes all entry types (messages, compaction, model changes, etc.). * Includes all entry types (messages, compaction, model changes, etc.).
* Use buildSessionContext() to get the resolved messages for the LLM. * Use buildSessionContext() to get the resolved messages for the LLM.
*/ */
getPath(fromId?: string): SessionEntry[] { getBranch(fromId?: string): SessionEntry[] {
const path: SessionEntry[] = []; const path: SessionEntry[] = [];
const startId = fromId ?? this.leafId; const startId = fromId ?? this.leafId;
let current = startId ? this.byId.get(startId) : undefined; let current = startId ? this.byId.get(startId) : undefined;
@ -908,7 +910,7 @@ export class SessionManager {
* Returns the new session file path, or undefined if not persisting. * Returns the new session file path, or undefined if not persisting.
*/ */
createBranchedSession(leafId: string): string | undefined { createBranchedSession(leafId: string): string | undefined {
const path = this.getPath(leafId); const path = this.getBranch(leafId);
if (path.length === 0) { if (path.length === 0) {
throw new Error(`Entry ${leafId} not found`); throw new Error(`Entry ${leafId} not found`);
} }

View file

@ -35,15 +35,16 @@ export {
// Custom tools // Custom tools
export type { export type {
AgentToolUpdateCallback, AgentToolUpdateCallback,
CustomAgentTool, CustomTool,
CustomToolAPI,
CustomToolContext,
CustomToolFactory, CustomToolFactory,
CustomToolSessionEvent,
CustomToolsLoadResult, CustomToolsLoadResult,
CustomToolUIContext,
ExecResult, ExecResult,
LoadedCustomTool, LoadedCustomTool,
RenderResultOptions, RenderResultOptions,
SessionEvent as ToolSessionEvent,
ToolAPI,
ToolUIContext,
} from "./core/custom-tools/index.js"; } from "./core/custom-tools/index.js";
export { discoverAndLoadCustomTools, loadCustomTools } from "./core/custom-tools/index.js"; export { discoverAndLoadCustomTools, loadCustomTools } from "./core/custom-tools/index.js";
export type * from "./core/hooks/index.js"; export type * from "./core/hooks/index.js";

View file

@ -11,7 +11,7 @@ import {
type TUI, type TUI,
} from "@mariozechner/pi-tui"; } from "@mariozechner/pi-tui";
import stripAnsi from "strip-ansi"; import stripAnsi from "strip-ansi";
import type { CustomAgentTool } from "../../../core/custom-tools/types.js"; import type { CustomTool } from "../../../core/custom-tools/types.js";
import { DEFAULT_MAX_BYTES, DEFAULT_MAX_LINES, formatSize } from "../../../core/tools/truncate.js"; import { DEFAULT_MAX_BYTES, DEFAULT_MAX_LINES, formatSize } from "../../../core/tools/truncate.js";
import { getLanguageFromPath, highlightCode, theme } from "../theme/theme.js"; import { getLanguageFromPath, highlightCode, theme } from "../theme/theme.js";
import { renderDiff } from "./diff.js"; import { renderDiff } from "./diff.js";
@ -55,7 +55,7 @@ export class ToolExecutionComponent extends Container {
private expanded = false; private expanded = false;
private showImages: boolean; private showImages: boolean;
private isPartial = true; private isPartial = true;
private customTool?: CustomAgentTool; private customTool?: CustomTool;
private ui: TUI; private ui: TUI;
private result?: { private result?: {
content: Array<{ type: string; text?: string; data?: string; mimeType?: string }>; content: Array<{ type: string; text?: string; data?: string; mimeType?: string }>;
@ -67,7 +67,7 @@ export class ToolExecutionComponent extends Container {
toolName: string, toolName: string,
args: any, args: any,
options: ToolExecutionOptions = {}, options: ToolExecutionOptions = {},
customTool: CustomAgentTool | undefined, customTool: CustomTool | undefined,
ui: TUI, ui: TUI,
) { ) {
super(); super();

View file

@ -26,7 +26,7 @@ import {
import { exec, spawnSync } from "child_process"; import { exec, spawnSync } from "child_process";
import { APP_NAME, getAuthPath, getDebugLogPath } from "../../config.js"; import { APP_NAME, getAuthPath, getDebugLogPath } from "../../config.js";
import type { AgentSession, AgentSessionEvent } from "../../core/agent-session.js"; import type { AgentSession, AgentSessionEvent } from "../../core/agent-session.js";
import type { LoadedCustomTool, SessionEvent as ToolSessionEvent } from "../../core/custom-tools/index.js"; import type { CustomToolSessionEvent, LoadedCustomTool } from "../../core/custom-tools/index.js";
import type { HookUIContext } from "../../core/hooks/index.js"; import type { HookUIContext } from "../../core/hooks/index.js";
import { createCompactionSummaryMessage } from "../../core/messages.js"; import { createCompactionSummaryMessage } from "../../core/messages.js";
import { type SessionContext, SessionManager } from "../../core/session-manager.js"; import { type SessionContext, SessionManager } from "../../core/session-manager.js";
@ -350,19 +350,20 @@ export class InteractiveMode {
this.chatContainer.addChild(new Spacer(1)); this.chatContainer.addChild(new Spacer(1));
} }
// Load session entries if any // Create and set hook & tool UI context
const entries = this.session.sessionManager.getEntries(); const uiContext: HookUIContext = {
select: (title, options) => this.showHookSelector(title, options),
// Set TUI-based UI context for custom tools confirm: (title, message) => this.showHookConfirm(title, message),
const uiContext = this.createHookUIContext(); input: (title, placeholder) => this.showHookInput(title, placeholder),
notify: (message, type) => this.showHookNotify(message, type),
custom: (component) => this.showHookCustom(component),
};
this.setToolUIContext(uiContext, true); this.setToolUIContext(uiContext, true);
// Notify custom tools of session start // Notify custom tools of session start
await this.emitToolSessionEvent({ await this.emitCustomToolSessionEvent({
entries,
sessionFile: this.session.sessionFile,
previousSessionFile: undefined,
reason: "start", reason: "start",
previousSessionFile: undefined,
}); });
const hookRunner = this.session.hookRunner; const hookRunner = this.session.hookRunner;
@ -370,16 +371,9 @@ export class InteractiveMode {
return; // No hooks loaded return; // No hooks loaded
} }
// Set UI context on hook runner hookRunner.initialize({
hookRunner.setUIContext(uiContext, true); getModel: () => this.session.model,
sendMessageHandler: (message, triggerTurn) => {
// Subscribe to hook errors
hookRunner.onError((error) => {
this.showHookError(error.hookPath, error.error);
});
// Set up handlers for pi.sendMessage() and pi.appendEntry()
hookRunner.setSendMessageHandler((message, triggerTurn) => {
const wasStreaming = this.session.isStreaming; const wasStreaming = this.session.isStreaming;
this.session this.session
.sendHookMessage(message, triggerTurn) .sendHookMessage(message, triggerTurn)
@ -393,9 +387,17 @@ export class InteractiveMode {
.catch((err) => { .catch((err) => {
this.showError(`Hook sendMessage failed: ${err instanceof Error ? err.message : String(err)}`); this.showError(`Hook sendMessage failed: ${err instanceof Error ? err.message : String(err)}`);
}); });
}); },
hookRunner.setAppendEntryHandler((customType, data) => { appendEntryHandler: (customType, data) => {
this.sessionManager.appendCustomEntry(customType, data); this.sessionManager.appendCustomEntry(customType, data);
},
uiContext,
hasUI: true,
});
// Subscribe to hook errors
hookRunner.onError((error) => {
this.showHookError(error.hookPath, error.error);
}); });
// Show loaded hooks // Show loaded hooks
@ -415,11 +417,15 @@ export class InteractiveMode {
/** /**
* Emit session event to all custom tools. * Emit session event to all custom tools.
*/ */
private async emitToolSessionEvent(event: ToolSessionEvent): Promise<void> { private async emitCustomToolSessionEvent(event: CustomToolSessionEvent): Promise<void> {
for (const { tool } of this.customTools.values()) { for (const { tool } of this.customTools.values()) {
if (tool.onSession) { if (tool.onSession) {
try { try {
await tool.onSession(event); await tool.onSession(event, {
sessionManager: this.session.sessionManager,
modelRegistry: this.session.modelRegistry,
model: this.session.model,
});
} catch (err) { } catch (err) {
this.showToolError(tool.name, err instanceof Error ? err.message : String(err)); this.showToolError(tool.name, err instanceof Error ? err.message : String(err));
} }
@ -436,19 +442,6 @@ export class InteractiveMode {
this.ui.requestRender(); this.ui.requestRender();
} }
/**
* Create the UI context for hooks.
*/
private createHookUIContext(): HookUIContext {
return {
select: (title, options) => this.showHookSelector(title, options),
confirm: (title, message) => this.showHookConfirm(title, message),
input: (title, placeholder) => this.showHookInput(title, placeholder),
notify: (message, type) => this.showHookNotify(message, type),
custom: (component) => this.showHookCustom(component),
};
}
/** /**
* Show a selector for hooks. * Show a selector for hooks.
*/ */
@ -861,6 +854,7 @@ export class InteractiveMode {
this.customTools.get(content.name)?.tool, this.customTools.get(content.name)?.tool,
this.ui, this.ui,
); );
component.setExpanded(this.toolOutputExpanded);
this.chatContainer.addChild(component); this.chatContainer.addChild(component);
this.pendingTools.set(content.id, component); this.pendingTools.set(content.id, component);
} else { } else {
@ -909,6 +903,7 @@ export class InteractiveMode {
this.customTools.get(event.toolName)?.tool, this.customTools.get(event.toolName)?.tool,
this.ui, this.ui,
); );
component.setExpanded(this.toolOutputExpanded);
this.chatContainer.addChild(component); this.chatContainer.addChild(component);
this.pendingTools.set(event.toolCallId, component); this.pendingTools.set(event.toolCallId, component);
this.ui.requestRender(); this.ui.requestRender();
@ -1158,6 +1153,7 @@ export class InteractiveMode {
this.customTools.get(content.name)?.tool, this.customTools.get(content.name)?.tool,
this.ui, this.ui,
); );
component.setExpanded(this.toolOutputExpanded);
this.chatContainer.addChild(component); this.chatContainer.addChild(component);
if (message.stopReason === "aborted" || message.stopReason === "error") { if (message.stopReason === "aborted" || message.stopReason === "error") {
@ -1251,7 +1247,7 @@ export class InteractiveMode {
} }
// Emit shutdown event to custom tools // Emit shutdown event to custom tools
await this.session.emitToolSessionEvent("shutdown"); await this.session.emitCustomToolSessionEvent("shutdown");
this.stop(); this.stop();
process.exit(0); process.exit(0);

View file

@ -26,24 +26,23 @@ export async function runPrintMode(
initialMessage?: string, initialMessage?: string,
initialImages?: ImageContent[], initialImages?: ImageContent[],
): Promise<void> { ): Promise<void> {
// Load entries once for session start events
const entries = session.sessionManager.getEntries();
// Hook runner already has no-op UI context by default (set in main.ts) // Hook runner already has no-op UI context by default (set in main.ts)
// Set up hooks for print mode (no UI) // Set up hooks for print mode (no UI)
const hookRunner = session.hookRunner; const hookRunner = session.hookRunner;
if (hookRunner) { if (hookRunner) {
hookRunner.onError((err) => { hookRunner.initialize({
console.error(`Hook error (${err.hookPath}): ${err.error}`); getModel: () => session.model,
}); sendMessageHandler: (message, triggerTurn) => {
// Set up handlers - sendHookMessage handles queuing/direct append as needed
hookRunner.setSendMessageHandler((message, triggerTurn) => {
session.sendHookMessage(message, triggerTurn).catch((e) => { session.sendHookMessage(message, triggerTurn).catch((e) => {
console.error(`Hook sendMessage failed: ${e instanceof Error ? e.message : String(e)}`); console.error(`Hook sendMessage failed: ${e instanceof Error ? e.message : String(e)}`);
}); });
}); },
hookRunner.setAppendEntryHandler((customType, data) => { appendEntryHandler: (customType, data) => {
session.sessionManager.appendCustomEntry(customType, data); session.sessionManager.appendCustomEntry(customType, data);
},
});
hookRunner.onError((err) => {
console.error(`Hook error (${err.hookPath}): ${err.error}`);
}); });
// Emit session_start event // Emit session_start event
await hookRunner.emit({ await hookRunner.emit({
@ -55,12 +54,17 @@ export async function runPrintMode(
for (const { tool } of session.customTools) { for (const { tool } of session.customTools) {
if (tool.onSession) { if (tool.onSession) {
try { try {
await tool.onSession({ await tool.onSession(
entries, {
sessionFile: session.sessionFile,
previousSessionFile: undefined,
reason: "start", reason: "start",
}); previousSessionFile: undefined,
},
{
sessionManager: session.sessionManager,
modelRegistry: session.modelRegistry,
model: session.model,
},
);
} catch (_err) { } catch (_err) {
// Silently ignore tool errors // Silently ignore tool errors
} }

View file

@ -125,24 +125,24 @@ export async function runRpcMode(session: AgentSession): Promise<never> {
}, },
}); });
// Load entries once for session start events
const entries = session.sessionManager.getEntries();
// Set up hooks with RPC-based UI context // Set up hooks with RPC-based UI context
const hookRunner = session.hookRunner; const hookRunner = session.hookRunner;
if (hookRunner) { if (hookRunner) {
hookRunner.setUIContext(createHookUIContext(), false); hookRunner.initialize({
hookRunner.onError((err) => { getModel: () => session.agent.state.model,
output({ type: "hook_error", hookPath: err.hookPath, event: err.event, error: err.error }); sendMessageHandler: (message, triggerTurn) => {
});
// Set up handlers for pi.sendMessage() and pi.appendEntry()
hookRunner.setSendMessageHandler((message, triggerTurn) => {
session.sendHookMessage(message, triggerTurn).catch((e) => { session.sendHookMessage(message, triggerTurn).catch((e) => {
output(error(undefined, "hook_send", e.message)); output(error(undefined, "hook_send", e.message));
}); });
}); },
hookRunner.setAppendEntryHandler((customType, data) => { appendEntryHandler: (customType, data) => {
session.sessionManager.appendCustomEntry(customType, data); session.sessionManager.appendCustomEntry(customType, data);
},
uiContext: createHookUIContext(),
hasUI: false,
});
hookRunner.onError((err) => {
output({ type: "hook_error", hookPath: err.hookPath, event: err.event, error: err.error });
}); });
// Emit session_start event // Emit session_start event
await hookRunner.emit({ await hookRunner.emit({
@ -155,12 +155,17 @@ export async function runRpcMode(session: AgentSession): Promise<never> {
for (const { tool } of session.customTools) { for (const { tool } of session.customTools) {
if (tool.onSession) { if (tool.onSession) {
try { try {
await tool.onSession({ await tool.onSession(
entries, {
sessionFile: session.sessionFile,
previousSessionFile: undefined, previousSessionFile: undefined,
reason: "start", reason: "start",
}); },
{
sessionManager: session.sessionManager,
modelRegistry: session.modelRegistry,
model: session.model,
},
);
} catch (_err) { } catch (_err) {
// Silently ignore tool errors // Silently ignore tool errors
} }

View file

@ -11,17 +11,11 @@ describe("Documentation example", () => {
const exampleHook = (pi: HookAPI) => { const exampleHook = (pi: HookAPI) => {
pi.on("session_before_compact", async (event: SessionBeforeCompactEvent, ctx) => { pi.on("session_before_compact", async (event: SessionBeforeCompactEvent, ctx) => {
// All these should be accessible on the event // All these should be accessible on the event
const { preparation, branchEntries, signal } = event; const { preparation, branchEntries } = event;
// sessionManager, modelRegistry, and model come from ctx // sessionManager, modelRegistry, and model come from ctx
const { sessionManager, modelRegistry, model } = ctx; const { sessionManager, modelRegistry } = ctx;
const { const { messagesToSummarize, turnPrefixMessages, tokensBefore, firstKeptEntryId, isSplitTurn } =
messagesToSummarize, preparation;
turnPrefixMessages,
tokensBefore,
firstKeptEntryId,
isSplitTurn,
previousSummary,
} = preparation;
// Verify types // Verify types
expect(Array.isArray(messagesToSummarize)).toBe(true); expect(Array.isArray(messagesToSummarize)).toBe(true);

View file

@ -99,16 +99,19 @@ describe.skipIf(!API_KEY)("Compaction hooks", () => {
const modelRegistry = new ModelRegistry(authStorage); const modelRegistry = new ModelRegistry(authStorage);
hookRunner = new HookRunner(hooks, tempDir, sessionManager, modelRegistry); hookRunner = new HookRunner(hooks, tempDir, sessionManager, modelRegistry);
hookRunner.setUIContext( hookRunner.initialize({
{ getModel: () => session.model,
sendMessageHandler: async () => {},
appendEntryHandler: async () => {},
uiContext: {
select: async () => undefined, select: async () => undefined,
confirm: async () => false, confirm: async () => false,
input: async () => undefined, input: async () => undefined,
notify: () => {}, notify: () => {},
custom: () => ({ close: () => {}, requestRender: () => {} }), custom: () => ({ close: () => {}, requestRender: () => {} }),
}, },
false, hasUI: false,
); });
session = new AgentSession({ session = new AgentSession({
agent, agent,

View file

@ -42,7 +42,7 @@ describe("SessionManager.saveCustomEntry", () => {
expect(customEntry.parentId).toBe(msgId); expect(customEntry.parentId).toBe(msgId);
// Tree structure should be correct // Tree structure should be correct
const path = session.getPath(); const path = session.getBranch();
expect(path).toHaveLength(3); expect(path).toHaveLength(3);
expect(path[0].id).toBe(msgId); expect(path[0].id).toBe(msgId);
expect(path[1].id).toBe(customId); expect(path[1].id).toBe(customId);

View file

@ -122,14 +122,14 @@ describe("SessionManager append and tree traversal", () => {
describe("getPath", () => { describe("getPath", () => {
it("returns empty array for empty session", () => { it("returns empty array for empty session", () => {
const session = SessionManager.inMemory(); const session = SessionManager.inMemory();
expect(session.getPath()).toEqual([]); expect(session.getBranch()).toEqual([]);
}); });
it("returns single entry path", () => { it("returns single entry path", () => {
const session = SessionManager.inMemory(); const session = SessionManager.inMemory();
const id = session.appendMessage(userMsg("hello")); const id = session.appendMessage(userMsg("hello"));
const path = session.getPath(); const path = session.getBranch();
expect(path).toHaveLength(1); expect(path).toHaveLength(1);
expect(path[0].id).toBe(id); expect(path[0].id).toBe(id);
}); });
@ -142,7 +142,7 @@ describe("SessionManager append and tree traversal", () => {
const id3 = session.appendThinkingLevelChange("high"); const id3 = session.appendThinkingLevelChange("high");
const id4 = session.appendMessage(userMsg("3")); const id4 = session.appendMessage(userMsg("3"));
const path = session.getPath(); const path = session.getBranch();
expect(path).toHaveLength(4); expect(path).toHaveLength(4);
expect(path.map((e) => e.id)).toEqual([id1, id2, id3, id4]); expect(path.map((e) => e.id)).toEqual([id1, id2, id3, id4]);
}); });
@ -155,7 +155,7 @@ describe("SessionManager append and tree traversal", () => {
const _id3 = session.appendMessage(userMsg("3")); const _id3 = session.appendMessage(userMsg("3"));
const _id4 = session.appendMessage(assistantMsg("4")); const _id4 = session.appendMessage(assistantMsg("4"));
const path = session.getPath(id2); const path = session.getBranch(id2);
expect(path).toHaveLength(2); expect(path).toHaveLength(2);
expect(path.map((e) => e.id)).toEqual([id1, id2]); expect(path.map((e) => e.id)).toEqual([id1, id2]);
}); });