Rework custom tools API with CustomToolContext

- CustomAgentTool renamed to CustomTool
- ToolAPI renamed to CustomToolAPI
- ToolContext renamed to CustomToolContext
- ToolSessionEvent renamed to CustomToolSessionEvent
- Added CustomToolContext parameter to execute() and onSession()
- CustomToolFactory now returns CustomTool<any, any> for type compatibility
- dispose() replaced with onSession({ reason: 'shutdown' })
- Added wrapCustomTool() to convert CustomTool to AgentTool
- Session exposes setToolUIContext() instead of leaking internals
- Fix ToolExecutionComponent to sync with toolOutputExpanded state
- Update all custom tool examples for new API
This commit is contained in:
Mario Zechner 2025-12-31 12:05:24 +01:00
parent b123df5fab
commit 568150f18b
27 changed files with 336 additions and 289 deletions

View file

@ -27,7 +27,7 @@ import {
prepareCompaction,
shouldCompact,
} from "./compaction/index.js";
import type { LoadedCustomTool, SessionEvent as ToolSessionEvent } from "./custom-tools/index.js";
import type { CustomToolContext, CustomToolSessionEvent, LoadedCustomTool } from "./custom-tools/index.js";
import { exportSessionToHtml } from "./export-html.js";
import type {
HookContext,
@ -698,7 +698,7 @@ export class AgentSession {
}
// Emit session event to custom tools
await this.emitToolSessionEvent("new", previousSessionFile);
await this.emitCustomToolSessionEvent("new", previousSessionFile);
return true;
}
@ -895,7 +895,7 @@ export class AgentSession {
throw new Error(`No API key for ${this.model.provider}`);
}
const pathEntries = this.sessionManager.getPath();
const pathEntries = this.sessionManager.getBranch();
const settings = this.settingsManager.getCompactionSettings();
const preparation = prepareCompaction(pathEntries, settings);
@ -1068,7 +1068,7 @@ export class AgentSession {
return;
}
const pathEntries = this.sessionManager.getPath();
const pathEntries = this.sessionManager.getBranch();
const preparation = prepareCompaction(pathEntries, settings);
if (!preparation) {
@ -1473,7 +1473,7 @@ export class AgentSession {
}
// Emit session event to custom tools
await this.emitToolSessionEvent("switch", previousSessionFile);
await this.emitCustomToolSessionEvent("switch", previousSessionFile);
this.agent.replaceMessages(sessionContext.messages);
@ -1550,7 +1550,7 @@ export class AgentSession {
}
// Emit session event to custom tools (with reason "branch")
await this.emitToolSessionEvent("branch", previousSessionFile);
await this.emitCustomToolSessionEvent("branch", previousSessionFile);
if (!skipConversationRestore) {
this.agent.replaceMessages(sessionContext.messages);
@ -1720,7 +1720,7 @@ export class AgentSession {
}
// Emit to custom tools
await this.emitToolSessionEvent("tree", this.sessionFile);
await this.emitCustomToolSessionEvent("tree", this.sessionFile);
this._branchSummaryAbortController = undefined;
return { editorText, cancelled: false, summaryEntry };
@ -1877,20 +1877,23 @@ export class AgentSession {
* Emit session event to all custom tools.
* Called on session switch, branch, tree navigation, and shutdown.
*/
async emitToolSessionEvent(
reason: ToolSessionEvent["reason"],
async emitCustomToolSessionEvent(
reason: CustomToolSessionEvent["reason"],
previousSessionFile?: string | undefined,
): Promise<void> {
const event: ToolSessionEvent = {
entries: this.sessionManager.getEntries(),
sessionFile: this.sessionFile,
previousSessionFile,
reason,
if (!this._customTools) return;
const event: CustomToolSessionEvent = { reason, previousSessionFile };
const ctx: CustomToolContext = {
sessionManager: this.sessionManager,
modelRegistry: this._modelRegistry,
model: this.agent.state.model,
};
for (const { tool } of this._customTools) {
if (tool.onSession) {
try {
await tool.onSession(event);
await tool.onSession(event, ctx);
} catch (_err) {
// Silently ignore tool errors during session events
}

View file

@ -102,8 +102,8 @@ export function collectEntriesForBranchSummary(
}
// Find common ancestor (deepest node that's on both paths)
const oldPath = new Set(session.getPath(oldLeafId).map((e) => e.id));
const targetPath = session.getPath(targetId);
const oldPath = new Set(session.getBranch(oldLeafId).map((e) => e.id));
const targetPath = session.getBranch(targetId);
// targetPath is root-first, so iterate backwards to find deepest common ancestor
let commonAncestorId: string | null = null;

View file

@ -4,14 +4,18 @@
export { discoverAndLoadCustomTools, loadCustomTools } from "./loader.js";
export type {
AgentToolResult,
AgentToolUpdateCallback,
CustomAgentTool,
CustomTool,
CustomToolAPI,
CustomToolContext,
CustomToolFactory,
CustomToolResult,
CustomToolSessionEvent,
CustomToolsLoadResult,
CustomToolUIContext,
ExecResult,
LoadedCustomTool,
RenderResultOptions,
SessionEvent,
ToolAPI,
ToolUIContext,
} from "./types.js";
export { wrapCustomTool, wrapCustomTools } from "./wrapper.js";

View file

@ -17,7 +17,7 @@ import { getAgentDir, isBunBinary } from "../../config.js";
import type { ExecOptions } from "../exec.js";
import { execCommand } from "../exec.js";
import type { HookUIContext } from "../hooks/types.js";
import type { CustomToolFactory, CustomToolsLoadResult, LoadedCustomTool, ToolAPI } from "./types.js";
import type { CustomToolAPI, CustomToolFactory, CustomToolsLoadResult, LoadedCustomTool } from "./types.js";
// Create require function to resolve module paths at runtime
const require = createRequire(import.meta.url);
@ -104,7 +104,7 @@ function createNoOpUIContext(): HookUIContext {
*/
async function loadToolWithBun(
resolvedPath: string,
sharedApi: ToolAPI,
sharedApi: CustomToolAPI,
): Promise<{ tools: LoadedCustomTool[] | null; error: string | null }> {
try {
// Try to import directly - will work for tools without @mariozechner/* imports
@ -149,7 +149,7 @@ async function loadToolWithBun(
async function loadTool(
toolPath: string,
cwd: string,
sharedApi: ToolAPI,
sharedApi: CustomToolAPI,
): Promise<{ tools: LoadedCustomTool[] | null; error: string | null }> {
const resolvedPath = resolveToolPath(toolPath, cwd);
@ -209,7 +209,7 @@ export async function loadCustomTools(
const seenNames = new Set<string>(builtInToolNames);
// Shared API object - all tools get the same instance
const sharedApi: ToolAPI = {
const sharedApi: CustomToolAPI = {
cwd,
exec: (command: string, args: string[], options?: ExecOptions) =>
execCommand(command, args, options?.cwd ?? cwd, options),

View file

@ -5,45 +5,56 @@
* They can provide custom rendering for tool calls and results in the TUI.
*/
import type { AgentTool, AgentToolResult, AgentToolUpdateCallback } from "@mariozechner/pi-agent-core";
import type { AgentToolResult, AgentToolUpdateCallback } from "@mariozechner/pi-agent-core";
import type { Model } from "@mariozechner/pi-ai";
import type { Component } from "@mariozechner/pi-tui";
import type { Static, TSchema } from "@sinclair/typebox";
import type { Theme } from "../../modes/interactive/theme/theme.js";
import type { ExecOptions, ExecResult } from "../exec.js";
import type { HookUIContext } from "../hooks/types.js";
import type { SessionEntry } from "../session-manager.js";
import type { ModelRegistry } from "../model-registry.js";
import type { ReadonlySessionManager } from "../session-manager.js";
/** Alias for clarity */
export type ToolUIContext = HookUIContext;
export type CustomToolUIContext = HookUIContext;
/** Re-export for custom tools to use in execute signature */
export type { AgentToolUpdateCallback };
export type { AgentToolResult, AgentToolUpdateCallback };
// Re-export for backward compatibility
export type { ExecOptions, ExecResult } from "../exec.js";
/** API passed to custom tool factory (stable across session changes) */
export interface ToolAPI {
export interface CustomToolAPI {
/** Current working directory */
cwd: string;
/** Execute a command */
exec(command: string, args: string[], options?: ExecOptions): Promise<ExecResult>;
/** UI methods for user interaction (select, confirm, input, notify) */
ui: ToolUIContext;
/** UI methods for user interaction (select, confirm, input, notify, custom) */
ui: CustomToolUIContext;
/** Whether UI is available (false in print/RPC mode) */
hasUI: boolean;
}
/**
* Context passed to tool execute and onSession callbacks.
* Provides access to session state and model information.
*/
export interface CustomToolContext {
/** Session manager (read-only) */
sessionManager: ReadonlySessionManager;
/** Model registry - use for API key resolution and model retrieval */
modelRegistry: ModelRegistry;
/** Current model (may be undefined if no model is selected yet) */
model: Model<any> | undefined;
}
/** Session event passed to onSession callback */
export interface SessionEvent {
/** All session entries (including pre-compaction history) */
entries: SessionEntry[];
/** Current session file path, or undefined in --no-session mode */
sessionFile: string | undefined;
/** Previous session file path, or undefined for "start", "new", and "shutdown" */
previousSessionFile: string | undefined;
export interface CustomToolSessionEvent {
/** Reason for the session event */
reason: "start" | "switch" | "branch" | "new" | "tree" | "shutdown";
/** Previous session file path, or undefined for "start", "new", and "shutdown" */
previousSessionFile: string | undefined;
}
/** Rendering options passed to renderResult */
@ -54,58 +65,89 @@ export interface RenderResultOptions {
isPartial: boolean;
}
export type CustomToolResult<TDetails = any> = AgentToolResult<TDetails>;
/**
* Custom tool with optional lifecycle and rendering methods.
* Custom tool definition.
*
* The execute signature inherited from AgentTool includes an optional onUpdate callback
* for streaming progress updates during long-running operations:
* - The callback emits partial results to subscribers (e.g. TUI/RPC), not to the LLM.
* - Partial updates should use the same TDetails type as the final result (use a union if needed).
* Custom tools are standalone - they don't extend AgentTool directly.
* When loaded, they are wrapped in an AgentTool for the agent to use.
*
* The execute callback receives a ToolContext with access to session state,
* model registry, and current model.
*
* @example
* ```typescript
* type Details =
* | { status: "running"; step: number; total: number }
* | { status: "done"; count: number };
* const factory: CustomToolFactory = (pi) => ({
* name: "my_tool",
* label: "My Tool",
* description: "Does something useful",
* parameters: Type.Object({ input: Type.String() }),
*
* async execute(toolCallId, params, signal, onUpdate) {
* const items = params.items || [];
* for (let i = 0; i < items.length; i++) {
* onUpdate?.({
* content: [{ type: "text", text: `Step ${i + 1}/${items.length}...` }],
* details: { status: "running", step: i + 1, total: items.length },
* });
* await processItem(items[i], signal);
* async execute(toolCallId, params, signal, onUpdate, ctx) {
* // Access session state via ctx.sessionManager
* // Access model registry via ctx.modelRegistry
* // Current model via ctx.model
* return { content: [{ type: "text", text: "Done" }] };
* },
*
* onSession(event, ctx) {
* if (event.reason === "shutdown") {
* // Cleanup
* }
* // Reconstruct state from ctx.sessionManager.getEntries()
* }
* return { content: [{ type: "text", text: "Done" }], details: { status: "done", count: items.length } };
* }
* });
* ```
*
* Progress updates are rendered via renderResult with isPartial: true.
*/
export interface CustomAgentTool<TParams extends TSchema = TSchema, TDetails = any>
extends AgentTool<TParams, TDetails> {
export interface CustomTool<TParams extends TSchema = TSchema, TDetails = any> {
/** Tool name (used in LLM tool calls) */
name: string;
/** Human-readable label for UI */
label: string;
/** Description for LLM */
description: string;
/** Parameter schema (TypeBox) */
parameters: TParams;
/**
* Execute the tool.
* @param toolCallId - Unique ID for this tool call
* @param params - Parsed parameters matching the schema
* @param signal - AbortSignal for cancellation
* @param onUpdate - Callback for streaming partial results (for UI, not LLM)
* @param ctx - Context with session manager, model registry, and current model
*/
execute(
toolCallId: string,
params: Static<TParams>,
signal: AbortSignal | undefined,
onUpdate: AgentToolUpdateCallback<TDetails> | undefined,
ctx: CustomToolContext,
): Promise<AgentToolResult<TDetails>>;
/** Called on session lifecycle events - use to reconstruct state or cleanup resources */
onSession?: (event: SessionEvent) => void | Promise<void>;
onSession?: (event: CustomToolSessionEvent, ctx: CustomToolContext) => void | Promise<void>;
/** Custom rendering for tool call display - return a Component */
renderCall?: (args: Static<TParams>, theme: Theme) => Component;
/** Custom rendering for tool result display - return a Component */
renderResult?: (result: AgentToolResult<TDetails>, options: RenderResultOptions, theme: Theme) => Component;
renderResult?: (result: CustomToolResult<TDetails>, options: RenderResultOptions, theme: Theme) => Component;
}
/** Factory function that creates a custom tool or array of tools */
export type CustomToolFactory = (
pi: ToolAPI,
) => CustomAgentTool<any> | CustomAgentTool[] | Promise<CustomAgentTool | CustomAgentTool[]>;
pi: CustomToolAPI,
) => CustomTool<any, any> | CustomTool<any, any>[] | Promise<CustomTool<any, any> | CustomTool<any, any>[]>;
/** Loaded custom tool with metadata */
/** Loaded custom tool with metadata and wrapped AgentTool */
export interface LoadedCustomTool {
/** Original path (as specified) */
path: string;
/** Resolved absolute path */
resolvedPath: string;
/** The tool instance */
tool: CustomAgentTool;
/** The original custom tool instance */
tool: CustomTool;
}
/** Result from loading custom tools */
@ -113,5 +155,5 @@ export interface CustomToolsLoadResult {
tools: LoadedCustomTool[];
errors: Array<{ path: string; error: string }>;
/** Update the UI context for all loaded tools. Call when mode initializes. */
setUIContext(uiContext: ToolUIContext, hasUI: boolean): void;
setUIContext(uiContext: CustomToolUIContext, hasUI: boolean): void;
}

View file

@ -0,0 +1,28 @@
/**
* Wraps CustomTool instances into AgentTool for use with the agent.
*/
import type { AgentTool } from "@mariozechner/pi-agent-core";
import type { CustomTool, CustomToolContext, LoadedCustomTool } from "./types.js";
/**
* Wrap a CustomTool into an AgentTool.
* The wrapper injects the ToolContext into execute calls.
*/
export function wrapCustomTool(tool: CustomTool, getContext: () => CustomToolContext): AgentTool {
return {
name: tool.name,
label: tool.label,
description: tool.description,
parameters: tool.parameters,
execute: (toolCallId, params, signal, onUpdate) =>
tool.execute(toolCallId, params, signal, onUpdate, getContext()),
};
}
/**
* Wrap all loaded custom tools into AgentTools.
*/
export function wrapCustomTools(loadedTools: LoadedCustomTool[], getContext: () => CustomToolContext): AgentTool[] {
return loadedTools.map((lt) => wrapCustomTool(lt.tool, getContext));
}

View file

@ -108,20 +108,12 @@ export class HookRunner {
hasUI?: boolean;
}): void {
this.getModel = options.getModel;
this.setSendMessageHandler(options.sendMessageHandler);
this.setAppendEntryHandler(options.appendEntryHandler);
if (options.uiContext) {
this.setUIContext(options.uiContext, options.hasUI ?? false);
for (const hook of this.hooks) {
hook.setSendMessageHandler(options.sendMessageHandler);
hook.setAppendEntryHandler(options.appendEntryHandler);
}
}
/**
* Set the UI context for hooks.
* Call this when the mode initializes and UI is available.
*/
setUIContext(uiContext: HookUIContext, hasUI: boolean): void {
this.uiContext = uiContext;
this.hasUI = hasUI;
this.uiContext = options.uiContext ?? noOpUIContext;
this.hasUI = options.hasUI ?? false;
}
/**
@ -145,26 +137,6 @@ export class HookRunner {
return this.hooks.map((h) => h.path);
}
/**
* Set the send message handler for all hooks' pi.sendMessage().
* Call this when the mode initializes.
*/
setSendMessageHandler(handler: SendMessageHandler): void {
for (const hook of this.hooks) {
hook.setSendMessageHandler(handler);
}
}
/**
* Set the append entry handler for all hooks' pi.appendEntry().
* Call this when the mode initializes.
*/
setAppendEntryHandler(handler: AppendEntryHandler): void {
for (const hook of this.hooks) {
hook.setAppendEntryHandler(handler);
}
}
/**
* Subscribe to hook errors.
* @returns Unsubscribe function

View file

@ -13,27 +13,7 @@ import type { CompactionPreparation, CompactionResult } from "../compaction/inde
import type { ExecOptions, ExecResult } from "../exec.js";
import type { HookMessage } from "../messages.js";
import type { ModelRegistry } from "../model-registry.js";
import type { BranchSummaryEntry, CompactionEntry, SessionEntry, SessionManager } from "../session-manager.js";
/**
* Read-only view of SessionManager for hooks.
* Hooks should use pi.sendMessage() and pi.appendEntry() for writes.
*/
export type ReadonlySessionManager = Pick<
SessionManager,
| "getCwd"
| "getSessionDir"
| "getSessionId"
| "getSessionFile"
| "getLeafId"
| "getLeafEntry"
| "getEntry"
| "getLabel"
| "getPath"
| "getHeader"
| "getEntries"
| "getTree"
>;
import type { BranchSummaryEntry, CompactionEntry, ReadonlySessionManager, SessionEntry } from "../session-manager.js";
import type { EditToolDetails } from "../tools/edit.js";
import type {

View file

@ -14,16 +14,16 @@ export {
export { type BashExecutorOptions, type BashResult, executeBash } from "./bash-executor.js";
export type { CompactionResult } from "./compaction/index.js";
export {
type CustomAgentTool,
type CustomTool,
type CustomToolAPI,
type CustomToolFactory,
type CustomToolsLoadResult,
type CustomToolUIContext,
discoverAndLoadCustomTools,
type ExecResult,
type LoadedCustomTool,
loadCustomTools,
type RenderResultOptions,
type ToolAPI,
type ToolUIContext,
} from "./custom-tools/index.js";
export {
type HookAPI,

View file

@ -35,8 +35,13 @@ import { join } from "path";
import { getAgentDir } from "../config.js";
import { AgentSession } from "./agent-session.js";
import { AuthStorage } from "./auth-storage.js";
import { discoverAndLoadCustomTools, type LoadedCustomTool } from "./custom-tools/index.js";
import type { CustomAgentTool } from "./custom-tools/types.js";
import {
type CustomToolsLoadResult,
discoverAndLoadCustomTools,
type LoadedCustomTool,
wrapCustomTools,
} from "./custom-tools/index.js";
import type { CustomTool } from "./custom-tools/types.js";
import { discoverAndLoadHooks, HookRunner, type LoadedHook, wrapToolsWithHooks } from "./hooks/index.js";
import type { HookFactory } from "./hooks/types.js";
import { convertToLlm } from "./messages.js";
@ -99,7 +104,7 @@ export interface CreateAgentSessionOptions {
/** Built-in tools to use. Default: codingTools [read, bash, edit, write] */
tools?: Tool[];
/** Custom tools (replaces discovery). */
customTools?: Array<{ path?: string; tool: CustomAgentTool }>;
customTools?: Array<{ path?: string; tool: CustomTool }>;
/** Additional custom tool paths to load (merged with discovery). */
additionalCustomToolPaths?: string[];
@ -127,17 +132,14 @@ export interface CreateAgentSessionResult {
/** The created session */
session: AgentSession;
/** Custom tools result (for UI context setup in interactive mode) */
customToolsResult: {
tools: LoadedCustomTool[];
setUIContext: (uiContext: any, hasUI: boolean) => void;
};
customToolsResult: CustomToolsLoadResult;
/** Warning if session was restored with a different model than saved */
modelFallbackMessage?: string;
}
// Re-exports
export type { CustomAgentTool } from "./custom-tools/types.js";
export type { CustomTool } from "./custom-tools/types.js";
export type { HookAPI, HookFactory } from "./hooks/types.js";
export type { Settings, SkillsSettings } from "./settings-manager.js";
export type { Skill } from "./skills.js";
@ -219,7 +221,7 @@ export async function discoverHooks(
export async function discoverCustomTools(
cwd?: string,
agentDir?: string,
): Promise<Array<{ path: string; tool: CustomAgentTool }>> {
): Promise<Array<{ path: string; tool: CustomTool }>> {
const resolvedCwd = cwd ?? process.cwd();
const resolvedAgentDir = agentDir ?? getDefaultAgentDir();
@ -507,7 +509,7 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
const builtInTools = options.tools ?? createCodingTools(cwd);
time("createCodingTools");
let customToolsResult: { tools: LoadedCustomTool[]; setUIContext: (ctx: any, hasUI: boolean) => void };
let customToolsResult: CustomToolsLoadResult;
if (options.customTools !== undefined) {
// Use provided custom tools
const loadedTools: LoadedCustomTool[] = options.customTools.map((ct) => ({
@ -517,17 +519,17 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
}));
customToolsResult = {
tools: loadedTools,
errors: [],
setUIContext: () => {},
};
} else {
// Discover custom tools, merging with additional paths
const configuredPaths = [...settingsManager.getCustomToolPaths(), ...(options.additionalCustomToolPaths ?? [])];
const result = await discoverAndLoadCustomTools(configuredPaths, cwd, Object.keys(allTools), agentDir);
customToolsResult = await discoverAndLoadCustomTools(configuredPaths, cwd, Object.keys(allTools), agentDir);
time("discoverAndLoadCustomTools");
for (const { path, error } of result.errors) {
for (const { path, error } of customToolsResult.errors) {
console.error(`Failed to load custom tool "${path}": ${error}`);
}
customToolsResult = result;
}
let hookRunner: HookRunner | undefined;
@ -549,7 +551,15 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
}
}
let allToolsArray: Tool[] = [...builtInTools, ...customToolsResult.tools.map((lt) => lt.tool as unknown as Tool)];
// Wrap custom tools with context getter (agent is assigned below, accessed at execute time)
let agent: Agent;
const wrappedCustomTools = wrapCustomTools(customToolsResult.tools, () => ({
sessionManager,
modelRegistry,
model: agent.state.model,
}));
let allToolsArray: Tool[] = [...builtInTools, ...wrappedCustomTools];
time("combineTools");
if (hookRunner) {
allToolsArray = wrapToolsWithHooks(allToolsArray, hookRunner) as Tool[];
@ -581,7 +591,7 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
const slashCommands = options.slashCommands ?? discoverSlashCommands(cwd, agentDir);
time("discoverSlashCommands");
const agent = new Agent({
agent = new Agent({
initialState: {
systemPrompt,
model,

View file

@ -159,19 +159,21 @@ export interface SessionInfo {
allMessagesText: string;
}
/**
* Read-only interface for SessionManager.
* Used by compaction/summarization utilities that only need to read session data.
*/
export interface ReadonlySessionManager {
getLeafId(): string | null;
getEntry(id: string): SessionEntry | undefined;
getPath(fromId?: string): SessionEntry[];
getEntries(): SessionEntry[];
getChildren(parentId: string): SessionEntry[];
getTree(): SessionTreeNode[];
getLabel(id: string): string | undefined;
}
export type ReadonlySessionManager = Pick<
SessionManager,
| "getCwd"
| "getSessionDir"
| "getSessionId"
| "getSessionFile"
| "getLeafId"
| "getLeafEntry"
| "getEntry"
| "getLabel"
| "getBranch"
| "getHeader"
| "getEntries"
| "getTree"
>;
/** Generate a unique short ID (8 hex chars, collision-checked) */
function generateId(byId: { has(id: string): boolean }): string {
@ -772,7 +774,7 @@ export class SessionManager {
* Includes all entry types (messages, compaction, model changes, etc.).
* Use buildSessionContext() to get the resolved messages for the LLM.
*/
getPath(fromId?: string): SessionEntry[] {
getBranch(fromId?: string): SessionEntry[] {
const path: SessionEntry[] = [];
const startId = fromId ?? this.leafId;
let current = startId ? this.byId.get(startId) : undefined;
@ -908,7 +910,7 @@ export class SessionManager {
* Returns the new session file path, or undefined if not persisting.
*/
createBranchedSession(leafId: string): string | undefined {
const path = this.getPath(leafId);
const path = this.getBranch(leafId);
if (path.length === 0) {
throw new Error(`Entry ${leafId} not found`);
}

View file

@ -35,15 +35,16 @@ export {
// Custom tools
export type {
AgentToolUpdateCallback,
CustomAgentTool,
CustomTool,
CustomToolAPI,
CustomToolContext,
CustomToolFactory,
CustomToolSessionEvent,
CustomToolsLoadResult,
CustomToolUIContext,
ExecResult,
LoadedCustomTool,
RenderResultOptions,
SessionEvent as ToolSessionEvent,
ToolAPI,
ToolUIContext,
} from "./core/custom-tools/index.js";
export { discoverAndLoadCustomTools, loadCustomTools } from "./core/custom-tools/index.js";
export type * from "./core/hooks/index.js";

View file

@ -11,7 +11,7 @@ import {
type TUI,
} from "@mariozechner/pi-tui";
import stripAnsi from "strip-ansi";
import type { CustomAgentTool } from "../../../core/custom-tools/types.js";
import type { CustomTool } from "../../../core/custom-tools/types.js";
import { DEFAULT_MAX_BYTES, DEFAULT_MAX_LINES, formatSize } from "../../../core/tools/truncate.js";
import { getLanguageFromPath, highlightCode, theme } from "../theme/theme.js";
import { renderDiff } from "./diff.js";
@ -55,7 +55,7 @@ export class ToolExecutionComponent extends Container {
private expanded = false;
private showImages: boolean;
private isPartial = true;
private customTool?: CustomAgentTool;
private customTool?: CustomTool;
private ui: TUI;
private result?: {
content: Array<{ type: string; text?: string; data?: string; mimeType?: string }>;
@ -67,7 +67,7 @@ export class ToolExecutionComponent extends Container {
toolName: string,
args: any,
options: ToolExecutionOptions = {},
customTool: CustomAgentTool | undefined,
customTool: CustomTool | undefined,
ui: TUI,
) {
super();

View file

@ -26,7 +26,7 @@ import {
import { exec, spawnSync } from "child_process";
import { APP_NAME, getAuthPath, getDebugLogPath } from "../../config.js";
import type { AgentSession, AgentSessionEvent } from "../../core/agent-session.js";
import type { LoadedCustomTool, SessionEvent as ToolSessionEvent } from "../../core/custom-tools/index.js";
import type { CustomToolSessionEvent, LoadedCustomTool } from "../../core/custom-tools/index.js";
import type { HookUIContext } from "../../core/hooks/index.js";
import { createCompactionSummaryMessage } from "../../core/messages.js";
import { type SessionContext, SessionManager } from "../../core/session-manager.js";
@ -350,19 +350,20 @@ export class InteractiveMode {
this.chatContainer.addChild(new Spacer(1));
}
// Load session entries if any
const entries = this.session.sessionManager.getEntries();
// Set TUI-based UI context for custom tools
const uiContext = this.createHookUIContext();
// Create and set hook & tool UI context
const uiContext: HookUIContext = {
select: (title, options) => this.showHookSelector(title, options),
confirm: (title, message) => this.showHookConfirm(title, message),
input: (title, placeholder) => this.showHookInput(title, placeholder),
notify: (message, type) => this.showHookNotify(message, type),
custom: (component) => this.showHookCustom(component),
};
this.setToolUIContext(uiContext, true);
// Notify custom tools of session start
await this.emitToolSessionEvent({
entries,
sessionFile: this.session.sessionFile,
previousSessionFile: undefined,
await this.emitCustomToolSessionEvent({
reason: "start",
previousSessionFile: undefined,
});
const hookRunner = this.session.hookRunner;
@ -370,34 +371,35 @@ export class InteractiveMode {
return; // No hooks loaded
}
// Set UI context on hook runner
hookRunner.setUIContext(uiContext, true);
hookRunner.initialize({
getModel: () => this.session.model,
sendMessageHandler: (message, triggerTurn) => {
const wasStreaming = this.session.isStreaming;
this.session
.sendHookMessage(message, triggerTurn)
.then(() => {
// For non-streaming cases with display=true, update UI
// (streaming cases update via message_end event)
if (!wasStreaming && message.display) {
this.rebuildChatFromMessages();
}
})
.catch((err) => {
this.showError(`Hook sendMessage failed: ${err instanceof Error ? err.message : String(err)}`);
});
},
appendEntryHandler: (customType, data) => {
this.sessionManager.appendCustomEntry(customType, data);
},
uiContext,
hasUI: true,
});
// Subscribe to hook errors
hookRunner.onError((error) => {
this.showHookError(error.hookPath, error.error);
});
// Set up handlers for pi.sendMessage() and pi.appendEntry()
hookRunner.setSendMessageHandler((message, triggerTurn) => {
const wasStreaming = this.session.isStreaming;
this.session
.sendHookMessage(message, triggerTurn)
.then(() => {
// For non-streaming cases with display=true, update UI
// (streaming cases update via message_end event)
if (!wasStreaming && message.display) {
this.rebuildChatFromMessages();
}
})
.catch((err) => {
this.showError(`Hook sendMessage failed: ${err instanceof Error ? err.message : String(err)}`);
});
});
hookRunner.setAppendEntryHandler((customType, data) => {
this.sessionManager.appendCustomEntry(customType, data);
});
// Show loaded hooks
const hookPaths = hookRunner.getHookPaths();
if (hookPaths.length > 0) {
@ -415,11 +417,15 @@ export class InteractiveMode {
/**
* Emit session event to all custom tools.
*/
private async emitToolSessionEvent(event: ToolSessionEvent): Promise<void> {
private async emitCustomToolSessionEvent(event: CustomToolSessionEvent): Promise<void> {
for (const { tool } of this.customTools.values()) {
if (tool.onSession) {
try {
await tool.onSession(event);
await tool.onSession(event, {
sessionManager: this.session.sessionManager,
modelRegistry: this.session.modelRegistry,
model: this.session.model,
});
} catch (err) {
this.showToolError(tool.name, err instanceof Error ? err.message : String(err));
}
@ -436,19 +442,6 @@ export class InteractiveMode {
this.ui.requestRender();
}
/**
* Create the UI context for hooks.
*/
private createHookUIContext(): HookUIContext {
return {
select: (title, options) => this.showHookSelector(title, options),
confirm: (title, message) => this.showHookConfirm(title, message),
input: (title, placeholder) => this.showHookInput(title, placeholder),
notify: (message, type) => this.showHookNotify(message, type),
custom: (component) => this.showHookCustom(component),
};
}
/**
* Show a selector for hooks.
*/
@ -861,6 +854,7 @@ export class InteractiveMode {
this.customTools.get(content.name)?.tool,
this.ui,
);
component.setExpanded(this.toolOutputExpanded);
this.chatContainer.addChild(component);
this.pendingTools.set(content.id, component);
} else {
@ -909,6 +903,7 @@ export class InteractiveMode {
this.customTools.get(event.toolName)?.tool,
this.ui,
);
component.setExpanded(this.toolOutputExpanded);
this.chatContainer.addChild(component);
this.pendingTools.set(event.toolCallId, component);
this.ui.requestRender();
@ -1158,6 +1153,7 @@ export class InteractiveMode {
this.customTools.get(content.name)?.tool,
this.ui,
);
component.setExpanded(this.toolOutputExpanded);
this.chatContainer.addChild(component);
if (message.stopReason === "aborted" || message.stopReason === "error") {
@ -1251,7 +1247,7 @@ export class InteractiveMode {
}
// Emit shutdown event to custom tools
await this.session.emitToolSessionEvent("shutdown");
await this.session.emitCustomToolSessionEvent("shutdown");
this.stop();
process.exit(0);

View file

@ -26,25 +26,24 @@ export async function runPrintMode(
initialMessage?: string,
initialImages?: ImageContent[],
): Promise<void> {
// Load entries once for session start events
const entries = session.sessionManager.getEntries();
// Hook runner already has no-op UI context by default (set in main.ts)
// Set up hooks for print mode (no UI)
const hookRunner = session.hookRunner;
if (hookRunner) {
hookRunner.initialize({
getModel: () => session.model,
sendMessageHandler: (message, triggerTurn) => {
session.sendHookMessage(message, triggerTurn).catch((e) => {
console.error(`Hook sendMessage failed: ${e instanceof Error ? e.message : String(e)}`);
});
},
appendEntryHandler: (customType, data) => {
session.sessionManager.appendCustomEntry(customType, data);
},
});
hookRunner.onError((err) => {
console.error(`Hook error (${err.hookPath}): ${err.error}`);
});
// Set up handlers - sendHookMessage handles queuing/direct append as needed
hookRunner.setSendMessageHandler((message, triggerTurn) => {
session.sendHookMessage(message, triggerTurn).catch((e) => {
console.error(`Hook sendMessage failed: ${e instanceof Error ? e.message : String(e)}`);
});
});
hookRunner.setAppendEntryHandler((customType, data) => {
session.sessionManager.appendCustomEntry(customType, data);
});
// Emit session_start event
await hookRunner.emit({
type: "session_start",
@ -55,12 +54,17 @@ export async function runPrintMode(
for (const { tool } of session.customTools) {
if (tool.onSession) {
try {
await tool.onSession({
entries,
sessionFile: session.sessionFile,
previousSessionFile: undefined,
reason: "start",
});
await tool.onSession(
{
reason: "start",
previousSessionFile: undefined,
},
{
sessionManager: session.sessionManager,
modelRegistry: session.modelRegistry,
model: session.model,
},
);
} catch (_err) {
// Silently ignore tool errors
}

View file

@ -125,25 +125,25 @@ export async function runRpcMode(session: AgentSession): Promise<never> {
},
});
// Load entries once for session start events
const entries = session.sessionManager.getEntries();
// Set up hooks with RPC-based UI context
const hookRunner = session.hookRunner;
if (hookRunner) {
hookRunner.setUIContext(createHookUIContext(), false);
hookRunner.initialize({
getModel: () => session.agent.state.model,
sendMessageHandler: (message, triggerTurn) => {
session.sendHookMessage(message, triggerTurn).catch((e) => {
output(error(undefined, "hook_send", e.message));
});
},
appendEntryHandler: (customType, data) => {
session.sessionManager.appendCustomEntry(customType, data);
},
uiContext: createHookUIContext(),
hasUI: false,
});
hookRunner.onError((err) => {
output({ type: "hook_error", hookPath: err.hookPath, event: err.event, error: err.error });
});
// Set up handlers for pi.sendMessage() and pi.appendEntry()
hookRunner.setSendMessageHandler((message, triggerTurn) => {
session.sendHookMessage(message, triggerTurn).catch((e) => {
output(error(undefined, "hook_send", e.message));
});
});
hookRunner.setAppendEntryHandler((customType, data) => {
session.sessionManager.appendCustomEntry(customType, data);
});
// Emit session_start event
await hookRunner.emit({
type: "session_start",
@ -155,12 +155,17 @@ export async function runRpcMode(session: AgentSession): Promise<never> {
for (const { tool } of session.customTools) {
if (tool.onSession) {
try {
await tool.onSession({
entries,
sessionFile: session.sessionFile,
previousSessionFile: undefined,
reason: "start",
});
await tool.onSession(
{
previousSessionFile: undefined,
reason: "start",
},
{
sessionManager: session.sessionManager,
modelRegistry: session.modelRegistry,
model: session.model,
},
);
} catch (_err) {
// Silently ignore tool errors
}