/**
 * Extension runner - executes extensions and manages their lifecycle.
 */

import type { AgentMessage } from "@mariozechner/pi-agent-core";
import type { ImageContent, Model } from "@mariozechner/pi-ai";
import type { KeyId } from "@mariozechner/pi-tui";
import { type Theme, theme } from "../../modes/interactive/theme/theme.js";
import type { ModelRegistry } from "../model-registry.js";
import type { SessionManager } from "../session-manager.js";
import type {
  BeforeAgentStartEvent,
  BeforeAgentStartEventResult,
  ContextEvent,
  ContextEventResult,
  Extension,
  ExtensionActions,
  ExtensionCommandContext,
  ExtensionCommandContextActions,
  ExtensionContext,
  ExtensionContextActions,
  ExtensionError,
  ExtensionEvent,
  ExtensionFlag,
  ExtensionRuntime,
  ExtensionShortcut,
  ExtensionUIContext,
  MessageRenderer,
  RegisteredCommand,
  RegisteredTool,
  SessionBeforeCompactResult,
  SessionBeforeTreeResult,
  ToolCallEvent,
  ToolCallEventResult,
  ToolResultEventResult,
  UserBashEvent,
  UserBashEventResult,
} from "./types.js";

/** A message contributed by a single before_agent_start handler. */
type BeforeAgentStartMessage = NonNullable<BeforeAgentStartEventResult["message"]>;

/** Combined result from all before_agent_start handlers */
interface BeforeAgentStartCombinedResult {
  messages?: BeforeAgentStartMessage[];
  systemPrompt?: string;
}

/** Callback notified whenever an extension handler throws. */
export type ExtensionErrorListener = (error: ExtensionError) => void;

/** Starts a new session; resolves with whether the operation was cancelled. */
export type NewSessionHandler = (options?: {
  parentSession?: string;
  setup?: (sessionManager: SessionManager) => Promise<void>;
}) => Promise<{ cancelled: boolean }>;

/** Branches from the given entry; resolves with whether the operation was cancelled. */
export type BranchHandler = (entryId: string) => Promise<{ cancelled: boolean }>;

/** Navigates the session tree; resolves with whether the operation was cancelled. */
export type NavigateTreeHandler = (
  targetId: string,
  options?: { summarize?: boolean },
) => Promise<{ cancelled: boolean }>;

/** Requests a shutdown of the running mode. */
export type ShutdownHandler = () => void;

/**
 * Helper function to emit session_shutdown event to extensions.
 *
 * @param extensionRunner - Runner to emit through; may be undefined.
 * @returns true if the event was emitted, false if there was no runner or no
 *   handler registered for session_shutdown.
 */
export async function emitSessionShutdownEvent(
  extensionRunner: ExtensionRunner | undefined,
): Promise<boolean> {
  if (!extensionRunner?.hasHandlers("session_shutdown")) {
    return false;
  }
  await extensionRunner.emit({ type: "session_shutdown" });
  return true;
}

/**
 * UI context used while no real UI is attached. Every operation is a no-op or
 * returns an "empty" answer (undefined / false / "" / []).
 * `ExtensionRunner.hasUI()` detects this object by identity.
 */
const noOpUIContext: ExtensionUIContext = {
  select: async () => undefined,
  confirm: async () => false,
  input: async () => undefined,
  notify: () => {},
  setStatus: () => {},
  setWorkingMessage: () => {},
  setWidget: () => {},
  setFooter: () => {},
  setHeader: () => {},
  setTitle: () => {},
  custom: async () => undefined as never,
  setEditorText: () => {},
  getEditorText: () => "",
  editor: async () => undefined,
  setEditorComponent: () => {},
  get theme() {
    return theme;
  },
  getAllThemes: () => [],
  getTheme: () => undefined,
  setTheme: (_theme: string | Theme) => ({ success: false, error: "UI not available" }),
};

/**
 * Dispatches events to loaded extensions and exposes what they registered
 * (tools, flags, shortcuts, commands, message renderers).
 *
 * Until `initialize()` is called the runner uses inert defaults: no UI, no
 * model, always idle, and session handlers that report `cancelled: false`.
 */
export class ExtensionRunner {
  private readonly extensions: Extension[];
  private readonly runtime: ExtensionRuntime;
  private uiContext: ExtensionUIContext;
  private readonly cwd: string;
  private readonly sessionManager: SessionManager;
  private readonly modelRegistry: ModelRegistry;
  private readonly errorListeners: Set<ExtensionErrorListener> = new Set();

  // Typed via ExtensionContextActions so it always matches what initialize() assigns.
  private getModel: ExtensionContextActions["getModel"] = () => undefined;
  private isIdleFn: () => boolean = () => true;
  private waitForIdleFn: () => Promise<void> = async () => {};
  private abortFn: () => void = () => {};
  private hasPendingMessagesFn: () => boolean = () => false;
  private newSessionHandler: NewSessionHandler = async () => ({ cancelled: false });
  private branchHandler: BranchHandler = async () => ({ cancelled: false });
  private navigateTreeHandler: NavigateTreeHandler = async () => ({ cancelled: false });
  private shutdownHandler: ShutdownHandler = () => {};

  /** Shortcut keys (lower-cased) that extensions are not allowed to register. */
  private static readonly RESERVED_SHORTCUTS: ReadonlySet<string> = new Set<string>([
    "ctrl+c",
    "ctrl+d",
    "ctrl+z",
    "ctrl+k",
    "ctrl+p",
    "ctrl+l",
    "ctrl+o",
    "ctrl+t",
    "ctrl+g",
    "shift+tab",
    "shift+ctrl+p",
    "alt+enter",
    "escape",
    "enter",
  ]);

  constructor(
    extensions: Extension[],
    runtime: ExtensionRuntime,
    cwd: string,
    sessionManager: SessionManager,
    modelRegistry: ModelRegistry,
  ) {
    this.extensions = extensions;
    this.runtime = runtime;
    this.uiContext = noOpUIContext;
    this.cwd = cwd;
    this.sessionManager = sessionManager;
    this.modelRegistry = modelRegistry;
  }

  /**
   * Wire the runner to the hosting mode.
   *
   * @param actions - Copied onto the shared runtime object.
   * @param contextActions - Required callbacks backing `createContext()`.
   * @param commandContextActions - Optional callbacks backing
   *   `createCommandContext()`; when omitted the inert defaults stay in place.
   * @param uiContext - UI implementation; the no-op UI is used when omitted.
   */
  initialize(
    actions: ExtensionActions,
    contextActions: ExtensionContextActions,
    commandContextActions?: ExtensionCommandContextActions,
    uiContext?: ExtensionUIContext,
  ): void {
    // Copy actions into the shared runtime (all extension APIs reference this)
    this.runtime.sendMessage = actions.sendMessage;
    this.runtime.sendUserMessage = actions.sendUserMessage;
    this.runtime.appendEntry = actions.appendEntry;
    this.runtime.getActiveTools = actions.getActiveTools;
    this.runtime.getAllTools = actions.getAllTools;
    this.runtime.setActiveTools = actions.setActiveTools;
    this.runtime.setModel = actions.setModel;
    this.runtime.getThinkingLevel = actions.getThinkingLevel;
    this.runtime.setThinkingLevel = actions.setThinkingLevel;

    // Context actions (required)
    this.getModel = contextActions.getModel;
    this.isIdleFn = contextActions.isIdle;
    this.abortFn = contextActions.abort;
    this.hasPendingMessagesFn = contextActions.hasPendingMessages;
    this.shutdownHandler = contextActions.shutdown;

    // Command context actions (optional, only for interactive mode)
    if (commandContextActions) {
      this.waitForIdleFn = commandContextActions.waitForIdle;
      this.newSessionHandler = commandContextActions.newSession;
      this.branchHandler = commandContextActions.branch;
      this.navigateTreeHandler = commandContextActions.navigateTree;
    }

    this.uiContext = uiContext ?? noOpUIContext;
  }

  /** Current UI context (the no-op UI if none was supplied). */
  getUIContext(): ExtensionUIContext {
    return this.uiContext;
  }

  /** True when a real UI context (anything other than the no-op one) is installed. */
  hasUI(): boolean {
    return this.uiContext !== noOpUIContext;
  }

  /** Paths of all loaded extensions, in load order. */
  getExtensionPaths(): string[] {
    return this.extensions.map((ext) => ext.path);
  }

  /** Get all registered tools from all extensions. */
  getAllRegisteredTools(): RegisteredTool[] {
    return this.extensions.flatMap((ext) => [...ext.tools.values()]);
  }

  /**
   * Merge the flags of every extension into one map.
   * When two extensions declare the same flag name, the later extension wins.
   */
  getFlags(): Map<string, ExtensionFlag> {
    const allFlags = new Map<string, ExtensionFlag>();
    for (const ext of this.extensions) {
      for (const [name, flag] of ext.flags) {
        allFlags.set(name, flag);
      }
    }
    return allFlags;
  }

  /** Store a flag value on the shared runtime. */
  setFlagValue(name: string, value: boolean | string): void {
    this.runtime.flagValues.set(name, value);
  }

  /**
   * Collect extension shortcuts keyed by lower-cased key id.
   *
   * Shortcuts that collide with a reserved built-in key are skipped with a
   * console warning. When two extensions register the same key, a warning is
   * logged and the later registration replaces the earlier one.
   */
  getShortcuts(): Map<KeyId, ExtensionShortcut> {
    const allShortcuts = new Map<KeyId, ExtensionShortcut>();

    for (const ext of this.extensions) {
      for (const [key, shortcut] of ext.shortcuts) {
        const normalizedKey = key.toLowerCase() as KeyId;

        if (ExtensionRunner.RESERVED_SHORTCUTS.has(normalizedKey)) {
          console.warn(
            `Extension shortcut '${key}' from ${shortcut.extensionPath} conflicts with built-in shortcut. Skipping.`,
          );
          continue;
        }

        const existing = allShortcuts.get(normalizedKey);
        if (existing) {
          // The message intentionally keeps its original two-line layout.
          console.warn(
            `Extension shortcut conflict: '${key}' registered by both ${existing.extensionPath} and ${shortcut.extensionPath}. \nUsing ${shortcut.extensionPath}.`,
          );
        }
        allShortcuts.set(normalizedKey, shortcut);
      }
    }

    return allShortcuts;
  }

  /**
   * Subscribe to extension errors.
   *
   * @returns A function that removes the listener again.
   */
  onError(listener: ExtensionErrorListener): () => void {
    this.errorListeners.add(listener);
    return () => this.errorListeners.delete(listener);
  }

  /** Deliver an error to every registered listener, in registration order. */
  emitError(error: ExtensionError): void {
    for (const listener of this.errorListeners) {
      listener(error);
    }
  }

  /** True if at least one extension has a non-empty handler list for `eventType`. */
  hasHandlers(eventType: string): boolean {
    return this.extensions.some((ext) => {
      const handlers = ext.handlers.get(eventType);
      return handlers !== undefined && handlers.length > 0;
    });
  }

  /** First renderer registered for `customType`, searching extensions in load order. */
  getMessageRenderer(customType: string): MessageRenderer | undefined {
    for (const ext of this.extensions) {
      const renderer = ext.messageRenderers.get(customType);
      if (renderer) {
        return renderer;
      }
    }
    return undefined;
  }

  /** All commands registered by all extensions. */
  getRegisteredCommands(): RegisteredCommand[] {
    return this.extensions.flatMap((ext) => [...ext.commands.values()]);
  }

  /** First command registered under `name`, searching extensions in load order. */
  getCommand(name: string): RegisteredCommand | undefined {
    for (const ext of this.extensions) {
      const command = ext.commands.get(name);
      if (command) {
        return command;
      }
    }
    return undefined;
  }

  /**
   * Request a graceful shutdown. Called by extension tools and event handlers.
   * The actual shutdown behavior is provided by the mode via initialize().
   */
  shutdown(): void {
    this.shutdownHandler();
  }

  /**
   * Create an ExtensionContext for use in event handlers and tool execution.
   *
   * `ui`, `hasUI` and `model` are captured at the moment this method is
   * called. The function members (`isIdle`, `abort`, `hasPendingMessages`,
   * `shutdown`) look up the currently installed action on every invocation,
   * so they reflect a later `initialize()` call.
   */
  createContext(): ExtensionContext {
    return {
      ui: this.uiContext,
      hasUI: this.hasUI(),
      cwd: this.cwd,
      sessionManager: this.sessionManager,
      modelRegistry: this.modelRegistry,
      model: this.getModel(),
      isIdle: () => this.isIdleFn(),
      abort: () => this.abortFn(),
      hasPendingMessages: () => this.hasPendingMessagesFn(),
      shutdown: () => this.shutdownHandler(),
    };
  }

  /** Create a context for command execution: the base context plus session controls. */
  createCommandContext(): ExtensionCommandContext {
    return {
      ...this.createContext(),
      waitForIdle: () => this.waitForIdleFn(),
      newSession: (options) => this.newSessionHandler(options),
      branch: (entryId) => this.branchHandler(entryId),
      navigateTree: (targetId, options) => this.navigateTreeHandler(targetId, options),
    };
  }

  /** Type guard for the session_before_* events, whose handler results may cancel. */
  private isSessionBeforeEvent(
    type: string,
  ): type is
    | "session_before_switch"
    | "session_before_branch"
    | "session_before_compact"
    | "session_before_tree" {
    return (
      type === "session_before_switch" ||
      type === "session_before_branch" ||
      type === "session_before_compact" ||
      type === "session_before_tree"
    );
  }

  /** Turn a value thrown by a handler into an ExtensionError and notify the listeners. */
  private reportHandlerError(
    extensionPath: string,
    event: ExtensionError["event"],
    err: unknown,
  ): void {
    const isError = err instanceof Error;
    this.emitError({
      extensionPath,
      event,
      error: isError ? err.message : String(err),
      stack: isError ? err.stack : undefined,
    });
  }

  /**
   * Run every handler registered for `event.type`, extension by extension.
   *
   * - For session_before_* events the latest truthy handler result is kept,
   *   and a result with `cancel` set is returned immediately.
   * - For tool_result events the latest truthy handler result is kept.
   * - For all other events the return value is undefined.
   *
   * A handler that throws is reported through `emitError` and does not stop
   * the remaining handlers.
   */
  async emit(
    event: ExtensionEvent,
  ): Promise<SessionBeforeCompactResult | SessionBeforeTreeResult | ToolResultEventResult | undefined> {
    const ctx = this.createContext();
    let result: SessionBeforeCompactResult | SessionBeforeTreeResult | ToolResultEventResult | undefined;

    for (const ext of this.extensions) {
      const handlers = ext.handlers.get(event.type);
      if (!handlers || handlers.length === 0) continue;

      for (const handler of handlers) {
        try {
          const handlerResult = await handler(event, ctx);
          if (!handlerResult) continue;

          if (this.isSessionBeforeEvent(event.type)) {
            result = handlerResult as SessionBeforeCompactResult | SessionBeforeTreeResult;
            if (result.cancel) {
              return result;
            }
          }

          if (event.type === "tool_result") {
            result = handlerResult as ToolResultEventResult;
          }
        } catch (err) {
          this.reportHandlerError(ext.path, event.type, err);
        }
      }
    }

    return result;
  }

  /**
   * Run all tool_call handlers.
   *
   * Returns immediately with the first result that has `block` set; otherwise
   * returns the last truthy handler result (or undefined).
   *
   * Unlike the other emit methods, exceptions thrown by handlers are NOT
   * caught here; they propagate to the caller.
   * NOTE(review): presumably deliberate so a failing handler cannot silently
   * let a tool call through - confirm against the caller.
   */
  async emitToolCall(event: ToolCallEvent): Promise<ToolCallEventResult | undefined> {
    const ctx = this.createContext();
    let result: ToolCallEventResult | undefined;

    for (const ext of this.extensions) {
      const handlers = ext.handlers.get("tool_call");
      if (!handlers || handlers.length === 0) continue;

      for (const handler of handlers) {
        const handlerResult = await handler(event, ctx);
        if (!handlerResult) continue;

        result = handlerResult as ToolCallEventResult;
        if (result.block) {
          return result;
        }
      }
    }

    return result;
  }

  /**
   * Run user_bash handlers until one returns a truthy result, which is
   * returned as-is. Throwing handlers are reported through `emitError` and
   * skipped. Returns undefined when no handler produced a result.
   */
  async emitUserBash(event: UserBashEvent): Promise<UserBashEventResult | undefined> {
    const ctx = this.createContext();

    for (const ext of this.extensions) {
      const handlers = ext.handlers.get("user_bash");
      if (!handlers || handlers.length === 0) continue;

      for (const handler of handlers) {
        try {
          const handlerResult = await handler(event, ctx);
          if (handlerResult) {
            return handlerResult as UserBashEventResult;
          }
        } catch (err) {
          this.reportHandlerError(ext.path, "user_bash", err);
        }
      }
    }

    return undefined;
  }

  /**
   * Pass the message list through every context handler in turn.
   *
   * Handlers operate on a structured clone of `messages`, not on the caller's
   * array. A handler that returns `messages` replaces the list seen by later
   * handlers. Throwing handlers are reported through `emitError` and leave
   * the list unchanged.
   *
   * @returns The message list after all handlers have run.
   */
  async emitContext(messages: AgentMessage[]): Promise<AgentMessage[]> {
    const ctx = this.createContext();
    let currentMessages = structuredClone(messages);

    for (const ext of this.extensions) {
      const handlers = ext.handlers.get("context");
      if (!handlers || handlers.length === 0) continue;

      for (const handler of handlers) {
        try {
          const event: ContextEvent = { type: "context", messages: currentMessages };
          const handlerResult = (await handler(event, ctx)) as ContextEventResult | undefined;
          const replacement = handlerResult ? handlerResult.messages : undefined;
          if (replacement) {
            currentMessages = replacement;
          }
        } catch (err) {
          this.reportHandlerError(ext.path, "context", err);
        }
      }
    }

    return currentMessages;
  }

  /**
   * Run all before_agent_start handlers.
   *
   * Every handler receives the system prompt as left by the handlers before
   * it. Messages returned by handlers are collected in order. Throwing
   * handlers are reported through `emitError` and skipped.
   *
   * @returns undefined when no handler contributed a message or a system
   *   prompt; otherwise the collected messages (undefined if none) and the
   *   final system prompt (undefined if no handler set one).
   */
  async emitBeforeAgentStart(
    prompt: string,
    images: ImageContent[] | undefined,
    systemPrompt: string,
  ): Promise<BeforeAgentStartCombinedResult | undefined> {
    const ctx = this.createContext();
    const messages: BeforeAgentStartMessage[] = [];
    let currentSystemPrompt = systemPrompt;
    let systemPromptModified = false;

    for (const ext of this.extensions) {
      const handlers = ext.handlers.get("before_agent_start");
      if (!handlers || handlers.length === 0) continue;

      for (const handler of handlers) {
        try {
          const event: BeforeAgentStartEvent = {
            type: "before_agent_start",
            prompt,
            images,
            systemPrompt: currentSystemPrompt,
          };
          const handlerResult = await handler(event, ctx);
          if (!handlerResult) continue;

          const result = handlerResult as BeforeAgentStartEventResult;
          if (result.message) {
            messages.push(result.message);
          }
          // An empty string is a valid replacement, hence the explicit undefined check.
          if (result.systemPrompt !== undefined) {
            currentSystemPrompt = result.systemPrompt;
            systemPromptModified = true;
          }
        } catch (err) {
          this.reportHandlerError(ext.path, "before_agent_start", err);
        }
      }
    }

    if (messages.length === 0 && !systemPromptModified) {
      return undefined;
    }

    return {
      messages: messages.length > 0 ? messages : undefined,
      systemPrompt: systemPromptModified ? currentSystemPrompt : undefined,
    };
  }
}