diff --git a/packages/coding-agent/src/core/extensions/index.ts b/packages/coding-agent/src/core/extensions/index.ts new file mode 100644 index 00000000..0a563b70 --- /dev/null +++ b/packages/coding-agent/src/core/extensions/index.ts @@ -0,0 +1,100 @@ +/** + * Extension system - unified hooks and custom tools. + */ + +export { discoverAndLoadExtensions, loadExtensions } from "./loader.js"; +export type { BranchHandler, ExtensionErrorListener, NavigateTreeHandler, NewSessionHandler } from "./runner.js"; +export { ExtensionRunner } from "./runner.js"; +export type { + AgentEndEvent, + AgentStartEvent, + // Re-exports + AgentToolResult, + AgentToolUpdateCallback, + AppendEntryHandler, + BashToolResultEvent, + BeforeAgentStartEvent, + BeforeAgentStartEventResult, + // Events - Agent + ContextEvent, + // Event Results + ContextEventResult, + CustomToolResultEvent, + EditToolResultEvent, + ExecOptions, + ExecResult, + // API + ExtensionAPI, + ExtensionCommandContext, + // Context + ExtensionContext, + // Errors + ExtensionError, + ExtensionEvent, + ExtensionFactory, + ExtensionFlag, + ExtensionHandler, + ExtensionShortcut, + ExtensionUIContext, + FindToolResultEvent, + GetActiveToolsHandler, + GetAllToolsHandler, + GrepToolResultEvent, + LoadExtensionsResult, + // Loaded Extension + LoadedExtension, + LsToolResultEvent, + // Message Rendering + MessageRenderer, + MessageRenderOptions, + ReadToolResultEvent, + // Commands + RegisteredCommand, + RegisteredTool, + SendMessageHandler, + SessionBeforeBranchEvent, + SessionBeforeBranchResult, + SessionBeforeCompactEvent, + SessionBeforeCompactResult, + SessionBeforeSwitchEvent, + SessionBeforeSwitchResult, + SessionBeforeTreeEvent, + SessionBeforeTreeResult, + SessionBranchEvent, + SessionCompactEvent, + SessionEvent, + SessionShutdownEvent, + // Events - Session + SessionStartEvent, + SessionSwitchEvent, + SessionTreeEvent, + SetActiveToolsHandler, + // Events - Tool + ToolCallEvent, + ToolCallEventResult, + // Tools + ToolDefinition, + ToolRenderResultOptions, + ToolResultEvent, + ToolResultEventResult, + TreePreparation, + TurnEndEvent, + TurnStartEvent, + WriteToolResultEvent, +} from "./types.js"; +// Type guards +export { + isBashToolResult, + isEditToolResult, + isFindToolResult, + isGrepToolResult, + isLsToolResult, + isReadToolResult, + isWriteToolResult, +} from "./types.js"; +export { + wrapRegisteredTool, + wrapRegisteredTools, + wrapToolsWithExtensions, + wrapToolWithExtensions, +} from "./wrapper.js"; diff --git a/packages/coding-agent/src/core/extensions/loader.ts b/packages/coding-agent/src/core/extensions/loader.ts new file mode 100644 index 00000000..85b63dda --- /dev/null +++ b/packages/coding-agent/src/core/extensions/loader.ts @@ -0,0 +1,459 @@ +/** + * Extension loader - loads TypeScript extension modules using jiti. + */ + +import * as fs from "node:fs"; +import { createRequire } from "node:module"; +import * as os from "node:os"; +import * as path from "node:path"; +import { fileURLToPath } from "node:url"; +import type { KeyId } from "@mariozechner/pi-tui"; +import { createJiti } from "jiti"; +import { getAgentDir, isBunBinary } from "../../config.js"; +import { theme } from "../../modes/interactive/theme/theme.js"; +import { createEventBus, type EventBus } from "../event-bus.js"; +import type { ExecOptions } from "../exec.js"; +import { execCommand } from "../exec.js"; +import type { + AppendEntryHandler, + ExtensionAPI, + ExtensionFactory, + ExtensionFlag, + ExtensionShortcut, + ExtensionUIContext, + GetActiveToolsHandler, + GetAllToolsHandler, + LoadExtensionsResult, + LoadedExtension, + MessageRenderer, + RegisteredCommand, + RegisteredTool, + SendMessageHandler, + SetActiveToolsHandler, + ToolDefinition, +} from "./types.js"; + +const require = createRequire(import.meta.url); + +let _aliases: Record | null = null; +function getAliases(): Record { + if (_aliases) return _aliases; + + const __dirname = path.dirname(fileURLToPath(import.meta.url)); + const packageIndex = path.resolve(__dirname, "../..", "index.js"); + + const typeboxEntry = require.resolve("@sinclair/typebox"); + const typeboxRoot = typeboxEntry.replace(/\/build\/cjs\/index\.js$/, ""); + + _aliases = { + "@mariozechner/pi-coding-agent": packageIndex, + "@mariozechner/pi-coding-agent/extensions": path.resolve(__dirname, "index.js"), + "@mariozechner/pi-tui": require.resolve("@mariozechner/pi-tui"), + "@mariozechner/pi-ai": require.resolve("@mariozechner/pi-ai"), + "@sinclair/typebox": typeboxRoot, + }; + return _aliases; +} + +const UNICODE_SPACES = /[\u00A0\u2000-\u200A\u202F\u205F\u3000]/g; + +function normalizeUnicodeSpaces(str: string): string { + return str.replace(UNICODE_SPACES, " "); +} + +function expandPath(p: string): string { + const normalized = normalizeUnicodeSpaces(p); + if (normalized.startsWith("~/")) { + return path.join(os.homedir(), normalized.slice(2)); + } + if (normalized.startsWith("~")) { + return path.join(os.homedir(), normalized.slice(1)); + } + return normalized; +} + +function resolvePath(extPath: string, cwd: string): string { + const expanded = expandPath(extPath); + if (path.isAbsolute(expanded)) { + return expanded; + } + return path.resolve(cwd, expanded); +} + +function createNoOpUIContext(): ExtensionUIContext { + return { + select: async () => undefined, + confirm: async () => false, + input: async () => undefined, + notify: () => {}, + setStatus: () => {}, + setWidget: () => {}, + setTitle: () => {}, + custom: async () => undefined as never, + setEditorText: () => {}, + getEditorText: () => "", + editor: async () => undefined, + get theme() { + return theme; + }, + }; +} + +type HandlerFn = (...args: unknown[]) => Promise; + +function createExtensionAPI( + handlers: Map, + tools: Map, + cwd: string, + extensionPath: string, + eventBus: EventBus, + _sharedUI: { ui: ExtensionUIContext; hasUI: boolean }, +): { + api: ExtensionAPI; + messageRenderers: Map; + commands: Map; + flags: Map; + flagValues: Map; + shortcuts: Map; + setSendMessageHandler: (handler: SendMessageHandler) => void; + setAppendEntryHandler: (handler: AppendEntryHandler) => void; + setGetActiveToolsHandler: (handler: GetActiveToolsHandler) => void; + setGetAllToolsHandler: (handler: GetAllToolsHandler) => void; + setSetActiveToolsHandler: (handler: SetActiveToolsHandler) => void; + setFlagValue: (name: string, value: boolean | string) => void; +} { + let sendMessageHandler: SendMessageHandler = () => {}; + let appendEntryHandler: AppendEntryHandler = () => {}; + let getActiveToolsHandler: GetActiveToolsHandler = () => []; + let getAllToolsHandler: GetAllToolsHandler = () => []; + let setActiveToolsHandler: SetActiveToolsHandler = () => {}; + + const messageRenderers = new Map(); + const commands = new Map(); + const flags = new Map(); + const flagValues = new Map(); + const shortcuts = new Map(); + + const api = { + on(event: string, handler: HandlerFn): void { + const list = handlers.get(event) ?? []; + list.push(handler); + handlers.set(event, list); + }, + + registerTool(tool: ToolDefinition): void { + tools.set(tool.name, { + definition: tool, + extensionPath, + }); + }, + + registerCommand(name: string, options: { description?: string; handler: RegisteredCommand["handler"] }): void { + commands.set(name, { name, ...options }); + }, + + registerShortcut( + shortcut: KeyId, + options: { + description?: string; + handler: (ctx: import("./types.js").ExtensionContext) => Promise | void; + }, + ): void { + shortcuts.set(shortcut, { shortcut, extensionPath, ...options }); + }, + + registerFlag( + name: string, + options: { description?: string; type: "boolean" | "string"; default?: boolean | string }, + ): void { + flags.set(name, { name, extensionPath, ...options }); + if (options.default !== undefined) { + flagValues.set(name, options.default); + } + }, + + getFlag(name: string): boolean | string | undefined { + return flagValues.get(name); + }, + + registerMessageRenderer(customType: string, renderer: MessageRenderer): void { + messageRenderers.set(customType, renderer as MessageRenderer); + }, + + sendMessage(message, options): void { + sendMessageHandler(message, options); + }, + + appendEntry(customType: string, data?: unknown): void { + appendEntryHandler(customType, data); + }, + + exec(command: string, args: string[], options?: ExecOptions) { + return execCommand(command, args, options?.cwd ?? cwd, options); + }, + + getActiveTools(): string[] { + return getActiveToolsHandler(); + }, + + getAllTools(): string[] { + return getAllToolsHandler(); + }, + + setActiveTools(toolNames: string[]): void { + setActiveToolsHandler(toolNames); + }, + + events: eventBus, + } as ExtensionAPI; + + return { + api, + messageRenderers, + commands, + flags, + flagValues, + shortcuts, + setSendMessageHandler: (handler: SendMessageHandler) => { + sendMessageHandler = handler; + }, + setAppendEntryHandler: (handler: AppendEntryHandler) => { + appendEntryHandler = handler; + }, + setGetActiveToolsHandler: (handler: GetActiveToolsHandler) => { + getActiveToolsHandler = handler; + }, + setGetAllToolsHandler: (handler: GetAllToolsHandler) => { + getAllToolsHandler = handler; + }, + setSetActiveToolsHandler: (handler: SetActiveToolsHandler) => { + setActiveToolsHandler = handler; + }, + setFlagValue: (name: string, value: boolean | string) => { + flagValues.set(name, value); + }, + }; +} + +async function loadExtensionWithBun( + resolvedPath: string, + cwd: string, + extensionPath: string, + eventBus: EventBus, + sharedUI: { ui: ExtensionUIContext; hasUI: boolean }, +): Promise<{ extension: LoadedExtension | null; error: string | null }> { + try { + const module = await import(resolvedPath); + const factory = (module.default ?? module) as ExtensionFactory; + + if (typeof factory !== "function") { + return { extension: null, error: "Extension must export a default function" }; + } + + const handlers = new Map(); + const tools = new Map(); + const { + api, + messageRenderers, + commands, + flags, + flagValues, + shortcuts, + setSendMessageHandler, + setAppendEntryHandler, + setGetActiveToolsHandler, + setGetAllToolsHandler, + setSetActiveToolsHandler, + setFlagValue, + } = createExtensionAPI(handlers, tools, cwd, extensionPath, eventBus, sharedUI); + + factory(api); + + return { + extension: { + path: extensionPath, + resolvedPath, + handlers, + tools, + messageRenderers, + commands, + flags, + flagValues, + shortcuts, + setSendMessageHandler, + setAppendEntryHandler, + setGetActiveToolsHandler, + setGetAllToolsHandler, + setSetActiveToolsHandler, + setFlagValue, + }, + error: null, + }; + } catch (err) { + const message = err instanceof Error ? err.message : String(err); + + if (message.includes("Cannot find module") && message.includes("@mariozechner/")) { + return { + extension: null, + error: + `${message}\n` + + "Note: Extensions importing from @mariozechner/* packages are not supported in the standalone binary.\n" + + "Please install pi via npm: npm install -g @mariozechner/pi-coding-agent", + }; + } + + return { extension: null, error: `Failed to load extension: ${message}` }; + } +} + +async function loadExtension( + extensionPath: string, + cwd: string, + eventBus: EventBus, + sharedUI: { ui: ExtensionUIContext; hasUI: boolean }, +): Promise<{ extension: LoadedExtension | null; error: string | null }> { + const resolvedPath = resolvePath(extensionPath, cwd); + + if (isBunBinary) { + return loadExtensionWithBun(resolvedPath, cwd, extensionPath, eventBus, sharedUI); + } + + try { + const jiti = createJiti(import.meta.url, { + alias: getAliases(), + }); + + const module = await jiti.import(resolvedPath, { default: true }); + const factory = module as ExtensionFactory; + + if (typeof factory !== "function") { + return { extension: null, error: "Extension must export a default function" }; + } + + const handlers = new Map(); + const tools = new Map(); + const { + api, + messageRenderers, + commands, + flags, + flagValues, + shortcuts, + setSendMessageHandler, + setAppendEntryHandler, + setGetActiveToolsHandler, + setGetAllToolsHandler, + setSetActiveToolsHandler, + setFlagValue, + } = createExtensionAPI(handlers, tools, cwd, extensionPath, eventBus, sharedUI); + + factory(api); + + return { + extension: { + path: extensionPath, + resolvedPath, + handlers, + tools, + messageRenderers, + commands, + flags, + flagValues, + shortcuts, + setSendMessageHandler, + setAppendEntryHandler, + setGetActiveToolsHandler, + setGetAllToolsHandler, + setSetActiveToolsHandler, + setFlagValue, + }, + error: null, + }; + } catch (err) { + const message = err instanceof Error ? err.message : String(err); + return { extension: null, error: `Failed to load extension: ${message}` }; + } +} + +/** + * Load extensions from paths. + */ +export async function loadExtensions(paths: string[], cwd: string, eventBus?: EventBus): Promise { + const extensions: LoadedExtension[] = []; + const errors: Array<{ path: string; error: string }> = []; + const resolvedEventBus = eventBus ?? createEventBus(); + const sharedUI = { ui: createNoOpUIContext(), hasUI: false }; + + for (const extPath of paths) { + const { extension, error } = await loadExtension(extPath, cwd, resolvedEventBus, sharedUI); + + if (error) { + errors.push({ path: extPath, error }); + continue; + } + + if (extension) { + extensions.push(extension); + } + } + + return { + extensions, + errors, + setUIContext(uiContext, hasUI) { + sharedUI.ui = uiContext; + sharedUI.hasUI = hasUI; + }, + }; +} + +function discoverExtensionsInDir(dir: string): string[] { + if (!fs.existsSync(dir)) { + return []; + } + + try { + const entries = fs.readdirSync(dir, { withFileTypes: true }); + return entries + .filter((e) => (e.isFile() || e.isSymbolicLink()) && e.name.endsWith(".ts")) + .map((e) => path.join(dir, e.name)); + } catch { + return []; + } +} + +/** + * Discover and load extensions from standard locations. + */ +export async function discoverAndLoadExtensions( + configuredPaths: string[], + cwd: string, + agentDir: string = getAgentDir(), + eventBus?: EventBus, +): Promise { + const allPaths: string[] = []; + const seen = new Set(); + + const addPaths = (paths: string[]) => { + for (const p of paths) { + const resolved = path.resolve(p); + if (!seen.has(resolved)) { + seen.add(resolved); + allPaths.push(p); + } + } + }; + + // 1. Global extensions: agentDir/extensions/ + const globalExtDir = path.join(agentDir, "extensions"); + addPaths(discoverExtensionsInDir(globalExtDir)); + + // 2. Project-local extensions: cwd/.pi/extensions/ + const localExtDir = path.join(cwd, ".pi", "extensions"); + addPaths(discoverExtensionsInDir(localExtDir)); + + // 3. Explicitly configured paths + addPaths(configuredPaths.map((p) => resolvePath(p, cwd))); + + return loadExtensions(allPaths, cwd, eventBus); +} diff --git a/packages/coding-agent/src/core/extensions/runner.ts b/packages/coding-agent/src/core/extensions/runner.ts new file mode 100644 index 00000000..c6a3528f --- /dev/null +++ b/packages/coding-agent/src/core/extensions/runner.ts @@ -0,0 +1,464 @@ +/** + * Extension runner - executes extensions and manages their lifecycle. + */ + +import type { AgentMessage } from "@mariozechner/pi-agent-core"; +import type { ImageContent, Model } from "@mariozechner/pi-ai"; +import type { KeyId } from "@mariozechner/pi-tui"; +import { theme } from "../../modes/interactive/theme/theme.js"; +import type { ModelRegistry } from "../model-registry.js"; +import type { SessionManager } from "../session-manager.js"; +import type { + AppendEntryHandler, + BeforeAgentStartEvent, + BeforeAgentStartEventResult, + ContextEvent, + ContextEventResult, + ExtensionCommandContext, + ExtensionContext, + ExtensionError, + ExtensionEvent, + ExtensionFlag, + ExtensionShortcut, + ExtensionUIContext, + GetActiveToolsHandler, + GetAllToolsHandler, + LoadedExtension, + MessageRenderer, + RegisteredCommand, + RegisteredTool, + SendMessageHandler, + SessionBeforeCompactResult, + SessionBeforeTreeResult, + SetActiveToolsHandler, + ToolCallEvent, + ToolCallEventResult, + ToolResultEventResult, +} from "./types.js"; + +/** Combined result from all before_agent_start handlers */ +interface BeforeAgentStartCombinedResult { + messages?: NonNullable[]; + systemPromptAppend?: string; +} + +export type ExtensionErrorListener = (error: ExtensionError) => void; + +export type NewSessionHandler = (options?: { + parentSession?: string; + setup?: (sessionManager: SessionManager) => Promise; +}) => Promise<{ cancelled: boolean }>; + +export type BranchHandler = (entryId: string) => Promise<{ cancelled: boolean }>; + +export type NavigateTreeHandler = ( + targetId: string, + options?: { summarize?: boolean }, +) => Promise<{ cancelled: boolean }>; + +const noOpUIContext: ExtensionUIContext = { + select: async () => undefined, + confirm: async () => false, + input: async () => undefined, + notify: () => {}, + setStatus: () => {}, + setWidget: () => {}, + setTitle: () => {}, + custom: async () => undefined as never, + setEditorText: () => {}, + getEditorText: () => "", + editor: async () => undefined, + get theme() { + return theme; + }, +}; + +export class ExtensionRunner { + private extensions: LoadedExtension[]; + private uiContext: ExtensionUIContext; + private hasUI: boolean; + private cwd: string; + private sessionManager: SessionManager; + private modelRegistry: ModelRegistry; + private errorListeners: Set = new Set(); + private getModel: () => Model | undefined = () => undefined; + private isIdleFn: () => boolean = () => true; + private waitForIdleFn: () => Promise = async () => {}; + private abortFn: () => void = () => {}; + private hasPendingMessagesFn: () => boolean = () => false; + private newSessionHandler: NewSessionHandler = async () => ({ cancelled: false }); + private branchHandler: BranchHandler = async () => ({ cancelled: false }); + private navigateTreeHandler: NavigateTreeHandler = async () => ({ cancelled: false }); + + constructor( + extensions: LoadedExtension[], + cwd: string, + sessionManager: SessionManager, + modelRegistry: ModelRegistry, + ) { + this.extensions = extensions; + this.uiContext = noOpUIContext; + this.hasUI = false; + this.cwd = cwd; + this.sessionManager = sessionManager; + this.modelRegistry = modelRegistry; + } + + initialize(options: { + getModel: () => Model | undefined; + sendMessageHandler: SendMessageHandler; + appendEntryHandler: AppendEntryHandler; + getActiveToolsHandler: GetActiveToolsHandler; + getAllToolsHandler: GetAllToolsHandler; + setActiveToolsHandler: SetActiveToolsHandler; + newSessionHandler?: NewSessionHandler; + branchHandler?: BranchHandler; + navigateTreeHandler?: NavigateTreeHandler; + isIdle?: () => boolean; + waitForIdle?: () => Promise; + abort?: () => void; + hasPendingMessages?: () => boolean; + uiContext?: ExtensionUIContext; + hasUI?: boolean; + }): void { + this.getModel = options.getModel; + this.isIdleFn = options.isIdle ?? (() => true); + this.waitForIdleFn = options.waitForIdle ?? (async () => {}); + this.abortFn = options.abort ?? (() => {}); + this.hasPendingMessagesFn = options.hasPendingMessages ?? (() => false); + + if (options.newSessionHandler) { + this.newSessionHandler = options.newSessionHandler; + } + if (options.branchHandler) { + this.branchHandler = options.branchHandler; + } + if (options.navigateTreeHandler) { + this.navigateTreeHandler = options.navigateTreeHandler; + } + + for (const ext of this.extensions) { + ext.setSendMessageHandler(options.sendMessageHandler); + ext.setAppendEntryHandler(options.appendEntryHandler); + ext.setGetActiveToolsHandler(options.getActiveToolsHandler); + ext.setGetAllToolsHandler(options.getAllToolsHandler); + ext.setSetActiveToolsHandler(options.setActiveToolsHandler); + } + + this.uiContext = options.uiContext ?? noOpUIContext; + this.hasUI = options.hasUI ?? false; + } + + getUIContext(): ExtensionUIContext | null { + return this.uiContext; + } + + getHasUI(): boolean { + return this.hasUI; + } + + getExtensionPaths(): string[] { + return this.extensions.map((e) => e.path); + } + + /** Get all registered tools from all extensions. */ + getAllRegisteredTools(): RegisteredTool[] { + const tools: RegisteredTool[] = []; + for (const ext of this.extensions) { + for (const tool of ext.tools.values()) { + tools.push(tool); + } + } + return tools; + } + + getFlags(): Map { + const allFlags = new Map(); + for (const ext of this.extensions) { + for (const [name, flag] of ext.flags) { + allFlags.set(name, flag); + } + } + return allFlags; + } + + setFlagValue(name: string, value: boolean | string): void { + for (const ext of this.extensions) { + if (ext.flags.has(name)) { + ext.setFlagValue(name, value); + } + } + } + + private static readonly RESERVED_SHORTCUTS = new Set([ + "ctrl+c", + "ctrl+d", + "ctrl+z", + "ctrl+k", + "ctrl+p", + "ctrl+l", + "ctrl+o", + "ctrl+t", + "ctrl+g", + "shift+tab", + "shift+ctrl+p", + "alt+enter", + "escape", + "enter", + ]); + + getShortcuts(): Map { + const allShortcuts = new Map(); + for (const ext of this.extensions) { + for (const [key, shortcut] of ext.shortcuts) { + const normalizedKey = key.toLowerCase() as KeyId; + + if (ExtensionRunner.RESERVED_SHORTCUTS.has(normalizedKey)) { + console.warn( + `Extension shortcut '${key}' from ${shortcut.extensionPath} conflicts with built-in shortcut. Skipping.`, + ); + continue; + } + + const existing = allShortcuts.get(normalizedKey); + if (existing) { + console.warn( + `Extension shortcut conflict: '${key}' registered by both ${existing.extensionPath} and ${shortcut.extensionPath}. Using ${shortcut.extensionPath}.`, + ); + } + allShortcuts.set(normalizedKey, shortcut); + } + } + return allShortcuts; + } + + onError(listener: ExtensionErrorListener): () => void { + this.errorListeners.add(listener); + return () => this.errorListeners.delete(listener); + } + + emitError(error: ExtensionError): void { + for (const listener of this.errorListeners) { + listener(error); + } + } + + hasHandlers(eventType: string): boolean { + for (const ext of this.extensions) { + const handlers = ext.handlers.get(eventType); + if (handlers && handlers.length > 0) { + return true; + } + } + return false; + } + + getMessageRenderer(customType: string): MessageRenderer | undefined { + for (const ext of this.extensions) { + const renderer = ext.messageRenderers.get(customType); + if (renderer) { + return renderer; + } + } + return undefined; + } + + getRegisteredCommands(): RegisteredCommand[] { + const commands: RegisteredCommand[] = []; + for (const ext of this.extensions) { + for (const command of ext.commands.values()) { + commands.push(command); + } + } + return commands; + } + + getCommand(name: string): RegisteredCommand | undefined { + for (const ext of this.extensions) { + const command = ext.commands.get(name); + if (command) { + return command; + } + } + return undefined; + } + + private createContext(): ExtensionContext { + return { + ui: this.uiContext, + hasUI: this.hasUI, + cwd: this.cwd, + sessionManager: this.sessionManager, + modelRegistry: this.modelRegistry, + model: this.getModel(), + isIdle: () => this.isIdleFn(), + abort: () => this.abortFn(), + hasPendingMessages: () => this.hasPendingMessagesFn(), + }; + } + + createCommandContext(): ExtensionCommandContext { + return { + ...this.createContext(), + waitForIdle: () => this.waitForIdleFn(), + newSession: (options) => this.newSessionHandler(options), + branch: (entryId) => this.branchHandler(entryId), + navigateTree: (targetId, options) => this.navigateTreeHandler(targetId, options), + }; + } + + private isSessionBeforeEvent( + type: string, + ): type is "session_before_switch" | "session_before_branch" | "session_before_compact" | "session_before_tree" { + return ( + type === "session_before_switch" || + type === "session_before_branch" || + type === "session_before_compact" || + type === "session_before_tree" + ); + } + + async emit( + event: ExtensionEvent, + ): Promise { + const ctx = this.createContext(); + let result: SessionBeforeCompactResult | SessionBeforeTreeResult | ToolResultEventResult | undefined; + + for (const ext of this.extensions) { + const handlers = ext.handlers.get(event.type); + if (!handlers || handlers.length === 0) continue; + + for (const handler of handlers) { + try { + const handlerResult = await handler(event, ctx); + + if (this.isSessionBeforeEvent(event.type) && handlerResult) { + result = handlerResult as SessionBeforeCompactResult | SessionBeforeTreeResult; + if (result.cancel) { + return result; + } + } + + if (event.type === "tool_result" && handlerResult) { + result = handlerResult as ToolResultEventResult; + } + } catch (err) { + const message = err instanceof Error ? err.message : String(err); + const stack = err instanceof Error ? err.stack : undefined; + this.emitError({ + extensionPath: ext.path, + event: event.type, + error: message, + stack, + }); + } + } + } + + return result; + } + + async emitToolCall(event: ToolCallEvent): Promise { + const ctx = this.createContext(); + let result: ToolCallEventResult | undefined; + + for (const ext of this.extensions) { + const handlers = ext.handlers.get("tool_call"); + if (!handlers || handlers.length === 0) continue; + + for (const handler of handlers) { + const handlerResult = await handler(event, ctx); + + if (handlerResult) { + result = handlerResult as ToolCallEventResult; + if (result.block) { + return result; + } + } + } + } + + return result; + } + + async emitContext(messages: AgentMessage[]): Promise { + const ctx = this.createContext(); + let currentMessages = structuredClone(messages); + + for (const ext of this.extensions) { + const handlers = ext.handlers.get("context"); + if (!handlers || handlers.length === 0) continue; + + for (const handler of handlers) { + try { + const event: ContextEvent = { type: "context", messages: currentMessages }; + const handlerResult = await handler(event, ctx); + + if (handlerResult && (handlerResult as ContextEventResult).messages) { + currentMessages = (handlerResult as ContextEventResult).messages!; + } + } catch (err) { + const message = err instanceof Error ? err.message : String(err); + const stack = err instanceof Error ? err.stack : undefined; + this.emitError({ + extensionPath: ext.path, + event: "context", + error: message, + stack, + }); + } + } + } + + return currentMessages; + } + + async emitBeforeAgentStart( + prompt: string, + images?: ImageContent[], + ): Promise { + const ctx = this.createContext(); + const messages: NonNullable[] = []; + const systemPromptAppends: string[] = []; + + for (const ext of this.extensions) { + const handlers = ext.handlers.get("before_agent_start"); + if (!handlers || handlers.length === 0) continue; + + for (const handler of handlers) { + try { + const event: BeforeAgentStartEvent = { type: "before_agent_start", prompt, images }; + const handlerResult = await handler(event, ctx); + + if (handlerResult) { + const result = handlerResult as BeforeAgentStartEventResult; + if (result.message) { + messages.push(result.message); + } + if (result.systemPromptAppend) { + systemPromptAppends.push(result.systemPromptAppend); + } + } + } catch (err) { + const message = err instanceof Error ? err.message : String(err); + const stack = err instanceof Error ? err.stack : undefined; + this.emitError({ + extensionPath: ext.path, + event: "before_agent_start", + error: message, + stack, + }); + } + } + } + + if (messages.length > 0 || systemPromptAppends.length > 0) { + return { + messages: messages.length > 0 ? messages : undefined, + systemPromptAppend: systemPromptAppends.length > 0 ? systemPromptAppends.join("\n\n") : undefined, + }; + } + + return undefined; + } +} diff --git a/packages/coding-agent/src/core/extensions/types.ts b/packages/coding-agent/src/core/extensions/types.ts new file mode 100644 index 00000000..39791161 --- /dev/null +++ b/packages/coding-agent/src/core/extensions/types.ts @@ -0,0 +1,688 @@ +/** + * Extension system types. + * + * Extensions are TypeScript modules that can: + * - Subscribe to agent lifecycle events + * - Register LLM-callable tools + * - Register slash commands, keyboard shortcuts, and CLI flags + * - Interact with the user via UI primitives + */ + +import type { AgentMessage, AgentToolResult, AgentToolUpdateCallback } from "@mariozechner/pi-agent-core"; +import type { ImageContent, Model, TextContent, ToolResultMessage } from "@mariozechner/pi-ai"; +import type { Component, KeyId, TUI } from "@mariozechner/pi-tui"; +import type { Static, TSchema } from "@sinclair/typebox"; +import type { Theme } from "../../modes/interactive/theme/theme.js"; +import type { CompactionPreparation, CompactionResult } from "../compaction/index.js"; +import type { EventBus } from "../event-bus.js"; +import type { ExecOptions, ExecResult } from "../exec.js"; +import type { HookMessage } from "../messages.js"; +import type { ModelRegistry } from "../model-registry.js"; +import type { + BranchSummaryEntry, + CompactionEntry, + ReadonlySessionManager, + SessionEntry, + SessionManager, +} from "../session-manager.js"; +import type { EditToolDetails } from "../tools/edit.js"; +import type { + BashToolDetails, + FindToolDetails, + GrepToolDetails, + LsToolDetails, + ReadToolDetails, +} from "../tools/index.js"; + +export type { ExecOptions, ExecResult } from "../exec.js"; +export type { AgentToolResult, AgentToolUpdateCallback }; + +// ============================================================================ +// UI Context +// ============================================================================ + +/** + * UI context for extensions to request interactive UI. + * Each mode (interactive, RPC, print) provides its own implementation. + */ +export interface ExtensionUIContext { + /** Show a selector and return the user's choice. */ + select(title: string, options: string[]): Promise; + + /** Show a confirmation dialog. */ + confirm(title: string, message: string): Promise; + + /** Show a text input dialog. */ + input(title: string, placeholder?: string): Promise; + + /** Show a notification to the user. */ + notify(message: string, type?: "info" | "warning" | "error"): void; + + /** Set status text in the footer/status bar. Pass undefined to clear. */ + setStatus(key: string, text: string | undefined): void; + + /** Set a widget to display above the editor. Accepts string array or component factory. */ + setWidget(key: string, content: string[] | undefined): void; + setWidget(key: string, content: ((tui: TUI, theme: Theme) => Component & { dispose?(): void }) | undefined): void; + + /** Set the terminal window/tab title. */ + setTitle(title: string): void; + + /** Show a custom component with keyboard focus. */ + custom( + factory: ( + tui: TUI, + theme: Theme, + done: (result: T) => void, + ) => (Component & { dispose?(): void }) | Promise, + ): Promise; + + /** Set the text in the core input editor. */ + setEditorText(text: string): void; + + /** Get the current text from the core input editor. */ + getEditorText(): string; + + /** Show a multi-line editor for text editing. */ + editor(title: string, prefill?: string): Promise; + + /** Get the current theme for styling. */ + readonly theme: Theme; +} + +// ============================================================================ +// Extension Context +// ============================================================================ + +/** + * Context passed to extension event handlers. + */ +export interface ExtensionContext { + /** UI methods for user interaction */ + ui: ExtensionUIContext; + /** Whether UI is available (false in print/RPC mode) */ + hasUI: boolean; + /** Current working directory */ + cwd: string; + /** Session manager (read-only) */ + sessionManager: ReadonlySessionManager; + /** Model registry for API key resolution */ + modelRegistry: ModelRegistry; + /** Current model (may be undefined) */ + model: Model | undefined; + /** Whether the agent is idle (not streaming) */ + isIdle(): boolean; + /** Abort the current agent operation */ + abort(): void; + /** Whether there are queued messages waiting */ + hasPendingMessages(): boolean; +} + +/** + * Extended context for slash command handlers. + * Includes session control methods only safe in user-initiated commands. + */ +export interface ExtensionCommandContext extends ExtensionContext { + /** Wait for the agent to finish streaming */ + waitForIdle(): Promise; + + /** Start a new session, optionally with initialization. */ + newSession(options?: { + parentSession?: string; + setup?: (sessionManager: SessionManager) => Promise; + }): Promise<{ cancelled: boolean }>; + + /** Branch from a specific entry, creating a new session file. */ + branch(entryId: string): Promise<{ cancelled: boolean }>; + + /** Navigate to a different point in the session tree. */ + navigateTree(targetId: string, options?: { summarize?: boolean }): Promise<{ cancelled: boolean }>; +} + +// ============================================================================ +// Tool Types +// ============================================================================ + +/** Rendering options for tool results */ +export interface ToolRenderResultOptions { + /** Whether the result view is expanded */ + expanded: boolean; + /** Whether this is a partial/streaming result */ + isPartial: boolean; +} + +/** + * Tool definition for registerTool(). + */ +export interface ToolDefinition { + /** Tool name (used in LLM tool calls) */ + name: string; + /** Human-readable label for UI */ + label: string; + /** Description for LLM */ + description: string; + /** Parameter schema (TypeBox) */ + parameters: TParams; + + /** Execute the tool. */ + execute( + toolCallId: string, + params: Static, + onUpdate: AgentToolUpdateCallback | undefined, + ctx: ExtensionContext, + signal?: AbortSignal, + ): Promise>; + + /** Custom rendering for tool call display */ + renderCall?: (args: Static, theme: Theme) => Component; + + /** Custom rendering for tool result display */ + renderResult?: (result: AgentToolResult, options: ToolRenderResultOptions, theme: Theme) => Component; +} + +// ============================================================================ +// Session Events +// ============================================================================ + +/** Fired on initial session load */ +export interface SessionStartEvent { + type: "session_start"; +} + +/** Fired before switching to another session (can be cancelled) */ +export interface SessionBeforeSwitchEvent { + type: "session_before_switch"; + reason: "new" | "resume"; + targetSessionFile?: string; +} + +/** Fired after switching to another session */ +export interface SessionSwitchEvent { + type: "session_switch"; + reason: "new" | "resume"; + previousSessionFile: string | undefined; +} + +/** Fired before branching a session (can be cancelled) */ +export interface SessionBeforeBranchEvent { + type: "session_before_branch"; + entryId: string; +} + +/** Fired after branching a session */ +export interface SessionBranchEvent { + type: "session_branch"; + previousSessionFile: string | undefined; +} + +/** Fired before context compaction (can be cancelled or customized) */ +export interface SessionBeforeCompactEvent { + type: "session_before_compact"; + preparation: CompactionPreparation; + branchEntries: SessionEntry[]; + customInstructions?: string; + signal: AbortSignal; +} + +/** Fired after context compaction */ +export interface SessionCompactEvent { + type: "session_compact"; + compactionEntry: CompactionEntry; + fromHook: boolean; +} + +/** Fired on process exit */ +export interface SessionShutdownEvent { + type: "session_shutdown"; +} + +/** Preparation data for tree navigation */ +export interface TreePreparation { + targetId: string; + oldLeafId: string | null; + commonAncestorId: string | null; + entriesToSummarize: SessionEntry[]; + userWantsSummary: boolean; +} + +/** Fired before navigating in the session tree (can be cancelled) */ +export interface SessionBeforeTreeEvent { + type: "session_before_tree"; + preparation: TreePreparation; + signal: AbortSignal; +} + +/** Fired after navigating in the session tree */ +export interface SessionTreeEvent { + type: "session_tree"; + newLeafId: string | null; + oldLeafId: string | null; + summaryEntry?: BranchSummaryEntry; + fromHook?: boolean; +} + +export type SessionEvent = + | SessionStartEvent + | SessionBeforeSwitchEvent + | SessionSwitchEvent + | SessionBeforeBranchEvent + | SessionBranchEvent + | SessionBeforeCompactEvent + | SessionCompactEvent + | SessionShutdownEvent + | SessionBeforeTreeEvent + | SessionTreeEvent; + +// ============================================================================ +// Agent Events +// ============================================================================ + +/** Fired before each LLM call. Can modify messages. */ +export interface ContextEvent { + type: "context"; + messages: AgentMessage[]; +} + +/** Fired after user submits prompt but before agent loop. */ +export interface BeforeAgentStartEvent { + type: "before_agent_start"; + prompt: string; + images?: ImageContent[]; +} + +/** Fired when an agent loop starts */ +export interface AgentStartEvent { + type: "agent_start"; +} + +/** Fired when an agent loop ends */ +export interface AgentEndEvent { + type: "agent_end"; + messages: AgentMessage[]; +} + +/** Fired at the start of each turn */ +export interface TurnStartEvent { + type: "turn_start"; + turnIndex: number; + timestamp: number; +} + +/** Fired at the end of each turn */ +export interface TurnEndEvent { + type: "turn_end"; + turnIndex: number; + message: AgentMessage; + toolResults: ToolResultMessage[]; +} + +// ============================================================================ +// Tool Events +// ============================================================================ + +/** Fired before a tool executes. Can block. */ +export interface ToolCallEvent { + type: "tool_call"; + toolName: string; + toolCallId: string; + input: Record; +} + +interface ToolResultEventBase { + type: "tool_result"; + toolCallId: string; + input: Record; + content: (TextContent | ImageContent)[]; + isError: boolean; +} + +export interface BashToolResultEvent extends ToolResultEventBase { + toolName: "bash"; + details: BashToolDetails | undefined; +} + +export interface ReadToolResultEvent extends ToolResultEventBase { + toolName: "read"; + details: ReadToolDetails | undefined; +} + +export interface EditToolResultEvent extends ToolResultEventBase { + toolName: "edit"; + details: EditToolDetails | undefined; +} + +export interface WriteToolResultEvent extends ToolResultEventBase { + toolName: "write"; + details: undefined; +} + +export interface GrepToolResultEvent extends ToolResultEventBase { + toolName: "grep"; + details: GrepToolDetails | undefined; +} + +export interface FindToolResultEvent extends ToolResultEventBase { + toolName: "find"; + details: FindToolDetails | undefined; +} + +export interface LsToolResultEvent extends ToolResultEventBase { + toolName: "ls"; + details: LsToolDetails | undefined; +} + +export interface CustomToolResultEvent extends ToolResultEventBase { + toolName: string; + details: unknown; +} + +/** Fired after a tool executes. Can modify result. */ +export type ToolResultEvent = + | BashToolResultEvent + | ReadToolResultEvent + | EditToolResultEvent + | WriteToolResultEvent + | GrepToolResultEvent + | FindToolResultEvent + | LsToolResultEvent + | CustomToolResultEvent; + +// Type guards +export function isBashToolResult(e: ToolResultEvent): e is BashToolResultEvent { + return e.toolName === "bash"; +} +export function isReadToolResult(e: ToolResultEvent): e is ReadToolResultEvent { + return e.toolName === "read"; +} +export function isEditToolResult(e: ToolResultEvent): e is EditToolResultEvent { + return e.toolName === "edit"; +} +export function isWriteToolResult(e: ToolResultEvent): e is WriteToolResultEvent { + return e.toolName === "write"; +} +export function isGrepToolResult(e: ToolResultEvent): e is GrepToolResultEvent { + return e.toolName === "grep"; +} +export function isFindToolResult(e: ToolResultEvent): e is FindToolResultEvent { + return e.toolName === "find"; +} +export function isLsToolResult(e: ToolResultEvent): e is LsToolResultEvent { + return e.toolName === "ls"; +} + +/** Union of all event types */ +export type ExtensionEvent = + | SessionEvent + | ContextEvent + | BeforeAgentStartEvent + | AgentStartEvent + | AgentEndEvent + | TurnStartEvent + | TurnEndEvent + | ToolCallEvent + | ToolResultEvent; + +// ============================================================================ +// Event Results +// ============================================================================ + +export interface ContextEventResult { + messages?: AgentMessage[]; +} + +export interface ToolCallEventResult { + block?: boolean; + reason?: string; +} + +export interface ToolResultEventResult { + content?: (TextContent | ImageContent)[]; + details?: unknown; + isError?: boolean; +} + +export interface BeforeAgentStartEventResult { + message?: Pick; + systemPromptAppend?: string; +} + +export interface SessionBeforeSwitchResult { + cancel?: boolean; +} + +export interface SessionBeforeBranchResult { + cancel?: boolean; + skipConversationRestore?: boolean; +} + +export interface SessionBeforeCompactResult { + cancel?: boolean; + compaction?: CompactionResult; +} + +export interface SessionBeforeTreeResult { + cancel?: boolean; + summary?: { + summary: string; + details?: unknown; + }; +} + +// ============================================================================ +// Message Rendering +// ============================================================================ + +export interface MessageRenderOptions { + expanded: boolean; +} + +export type MessageRenderer = ( + message: HookMessage, + options: MessageRenderOptions, + theme: Theme, +) => Component | undefined; + +// ============================================================================ +// Command Registration +// ============================================================================ + +export interface RegisteredCommand { + name: string; + description?: string; + handler: (args: string, ctx: ExtensionCommandContext) => Promise; +} + +// ============================================================================ +// Extension API +// ============================================================================ + +/** Handler function type for events */ +// biome-ignore lint/suspicious/noConfusingVoidType: void allows bare return statements +export type ExtensionHandler = (event: E, ctx: ExtensionContext) => Promise | R | void; + +/** + * ExtensionAPI passed to extension factory functions. + */ +export interface ExtensionAPI { + // ========================================================================= + // Event Subscription + // ========================================================================= + + on(event: "session_start", handler: ExtensionHandler): void; + on( + event: "session_before_switch", + handler: ExtensionHandler, + ): void; + on(event: "session_switch", handler: ExtensionHandler): void; + on( + event: "session_before_branch", + handler: ExtensionHandler, + ): void; + on(event: "session_branch", handler: ExtensionHandler): void; + on( + event: "session_before_compact", + handler: ExtensionHandler, + ): void; + on(event: "session_compact", handler: ExtensionHandler): void; + on(event: "session_shutdown", handler: ExtensionHandler): void; + on(event: "session_before_tree", handler: ExtensionHandler): void; + on(event: "session_tree", handler: ExtensionHandler): void; + on(event: "context", handler: ExtensionHandler): void; + on(event: "before_agent_start", handler: ExtensionHandler): void; + on(event: "agent_start", handler: ExtensionHandler): void; + on(event: "agent_end", handler: ExtensionHandler): void; + on(event: "turn_start", handler: ExtensionHandler): void; + on(event: "turn_end", handler: ExtensionHandler): void; + on(event: "tool_call", handler: ExtensionHandler): void; + on(event: "tool_result", handler: ExtensionHandler): void; + + // ========================================================================= + // Tool Registration + // ========================================================================= + + /** Register a tool that the LLM can call. */ + registerTool(tool: ToolDefinition): void; + + // ========================================================================= + // Command, Shortcut, Flag Registration + // ========================================================================= + + /** Register a custom slash command. */ + registerCommand(name: string, options: { description?: string; handler: RegisteredCommand["handler"] }): void; + + /** Register a keyboard shortcut. */ + registerShortcut( + shortcut: KeyId, + options: { + description?: string; + handler: (ctx: ExtensionContext) => Promise | void; + }, + ): void; + + /** Register a CLI flag. */ + registerFlag( + name: string, + options: { + description?: string; + type: "boolean" | "string"; + default?: boolean | string; + }, + ): void; + + /** Get the value of a registered CLI flag. */ + getFlag(name: string): boolean | string | undefined; + + // ========================================================================= + // Message Rendering + // ========================================================================= + + /** Register a custom renderer for HookMessageEntry. */ + registerMessageRenderer(customType: string, renderer: MessageRenderer): void; + + // ========================================================================= + // Actions + // ========================================================================= + + /** Send a custom message to the session. */ + sendMessage( + message: Pick, "customType" | "content" | "display" | "details">, + options?: { triggerTurn?: boolean; deliverAs?: "steer" | "followUp" | "nextTurn" }, + ): void; + + /** Append a custom entry to the session for state persistence (not sent to LLM). */ + appendEntry(customType: string, data?: T): void; + + /** Execute a shell command. */ + exec(command: string, args: string[], options?: ExecOptions): Promise; + + /** Get the list of currently active tool names. */ + getActiveTools(): string[]; + + /** Get all configured tools (built-in + extension tools). */ + getAllTools(): string[]; + + /** Set the active tools by name. */ + setActiveTools(toolNames: string[]): void; + + /** Shared event bus for extension communication. */ + events: EventBus; +} + +/** Extension factory function type. */ +export type ExtensionFactory = (pi: ExtensionAPI) => void; + +// ============================================================================ +// Loaded Extension Types +// ============================================================================ + +export interface RegisteredTool { + definition: ToolDefinition; + extensionPath: string; +} + +export interface ExtensionFlag { + name: string; + description?: string; + type: "boolean" | "string"; + default?: boolean | string; + extensionPath: string; +} + +export interface ExtensionShortcut { + shortcut: KeyId; + description?: string; + handler: (ctx: ExtensionContext) => Promise | void; + extensionPath: string; +} + +type HandlerFn = (...args: unknown[]) => Promise; + +export type SendMessageHandler = ( + message: Pick, "customType" | "content" | "display" | "details">, + options?: { triggerTurn?: boolean; deliverAs?: "steer" | "followUp" | "nextTurn" }, +) => void; + +export type AppendEntryHandler = (customType: string, data?: T) => void; + +export type GetActiveToolsHandler = () => string[]; + +export type GetAllToolsHandler = () => string[]; + +export type SetActiveToolsHandler = (toolNames: string[]) => void; + +/** Loaded extension with all registered items. */ +export interface LoadedExtension { + path: string; + resolvedPath: string; + handlers: Map; + tools: Map; + messageRenderers: Map; + commands: Map; + flags: Map; + flagValues: Map; + shortcuts: Map; + setSendMessageHandler: (handler: SendMessageHandler) => void; + setAppendEntryHandler: (handler: AppendEntryHandler) => void; + setGetActiveToolsHandler: (handler: GetActiveToolsHandler) => void; + setGetAllToolsHandler: (handler: GetAllToolsHandler) => void; + setSetActiveToolsHandler: (handler: SetActiveToolsHandler) => void; + setFlagValue: (name: string, value: boolean | string) => void; +} + +/** Result of loading extensions. */ +export interface LoadExtensionsResult { + extensions: LoadedExtension[]; + errors: Array<{ path: string; error: string }>; + setUIContext(uiContext: ExtensionUIContext, hasUI: boolean): void; +} + +// ============================================================================ +// Extension Error +// ============================================================================ + +export interface ExtensionError { + extensionPath: string; + event: string; + error: string; + stack?: string; +} diff --git a/packages/coding-agent/src/core/extensions/wrapper.ts b/packages/coding-agent/src/core/extensions/wrapper.ts new file mode 100644 index 00000000..cd98fa97 --- /dev/null +++ b/packages/coding-agent/src/core/extensions/wrapper.ts @@ -0,0 +1,119 @@ +/** + * Tool wrappers for extensions. + */ + +import type { AgentTool, AgentToolUpdateCallback } from "@mariozechner/pi-agent-core"; +import type { ExtensionRunner } from "./runner.js"; +import type { ExtensionContext, RegisteredTool, ToolCallEventResult, ToolResultEventResult } from "./types.js"; + +/** + * Wrap a RegisteredTool into an AgentTool. + */ +export function wrapRegisteredTool(registeredTool: RegisteredTool, getContext: () => ExtensionContext): AgentTool { + const { definition } = registeredTool; + return { + name: definition.name, + label: definition.label, + description: definition.description, + parameters: definition.parameters, + execute: (toolCallId, params, signal, onUpdate) => + definition.execute(toolCallId, params, onUpdate, getContext(), signal), + }; +} + +/** + * Wrap all registered tools into AgentTools. + */ +export function wrapRegisteredTools( + registeredTools: RegisteredTool[], + getContext: () => ExtensionContext, +): AgentTool[] { + return registeredTools.map((rt) => wrapRegisteredTool(rt, getContext)); +} + +/** + * Wrap a tool with extension callbacks for interception. + * - Emits tool_call event before execution (can block) + * - Emits tool_result event after execution (can modify result) + */ +export function wrapToolWithExtensions(tool: AgentTool, runner: ExtensionRunner): AgentTool { + return { + ...tool, + execute: async ( + toolCallId: string, + params: Record, + signal?: AbortSignal, + onUpdate?: AgentToolUpdateCallback, + ) => { + // Emit tool_call event - extensions can block execution + if (runner.hasHandlers("tool_call")) { + try { + const callResult = (await runner.emitToolCall({ + type: "tool_call", + toolName: tool.name, + toolCallId, + input: params, + })) as ToolCallEventResult | undefined; + + if (callResult?.block) { + const reason = callResult.reason || "Tool execution was blocked by an extension"; + throw new Error(reason); + } + } catch (err) { + if (err instanceof Error) { + throw err; + } + throw new Error(`Extension failed, blocking execution: ${String(err)}`); + } + } + + // Execute the actual tool + try { + const result = await tool.execute(toolCallId, params, signal, onUpdate); + + // Emit tool_result event - extensions can modify the result + if (runner.hasHandlers("tool_result")) { + const resultResult = (await runner.emit({ + type: "tool_result", + toolName: tool.name, + toolCallId, + input: params, + content: result.content, + details: result.details, + isError: false, + })) as ToolResultEventResult | undefined; + + if (resultResult) { + return { + content: resultResult.content ?? result.content, + details: (resultResult.details ?? result.details) as T, + }; + } + } + + return result; + } catch (err) { + // Emit tool_result event for errors + if (runner.hasHandlers("tool_result")) { + await runner.emit({ + type: "tool_result", + toolName: tool.name, + toolCallId, + input: params, + content: [{ type: "text", text: err instanceof Error ? err.message : String(err) }], + details: undefined, + isError: true, + }); + } + throw err; + } + }, + }; +} + +/** + * Wrap all tools with extension callbacks. + */ +export function wrapToolsWithExtensions(tools: AgentTool[], runner: ExtensionRunner): AgentTool[] { + return tools.map((tool) => wrapToolWithExtensions(tool, runner)); +}