co-mono/packages/coding-agent/src/core/extensions/runner.ts
2026-01-17 21:18:59 +01:00

584 lines
17 KiB
TypeScript

/**
* Extension runner - executes extensions and manages their lifecycle.
*/
import type { AgentMessage } from "@mariozechner/pi-agent-core";
import type { ImageContent, Model } from "@mariozechner/pi-ai";
import type { KeyId } from "@mariozechner/pi-tui";
import { type Theme, theme } from "../../modes/interactive/theme/theme.js";
import type { ModelRegistry } from "../model-registry.js";
import type { SessionManager } from "../session-manager.js";
import type {
BeforeAgentStartEvent,
BeforeAgentStartEventResult,
CompactOptions,
ContextEvent,
ContextEventResult,
ContextUsage,
Extension,
ExtensionActions,
ExtensionCommandContext,
ExtensionCommandContextActions,
ExtensionContext,
ExtensionContextActions,
ExtensionError,
ExtensionEvent,
ExtensionFlag,
ExtensionRuntime,
ExtensionShortcut,
ExtensionUIContext,
InputEvent,
InputEventResult,
InputSource,
MessageRenderer,
RegisteredCommand,
RegisteredTool,
SessionBeforeCompactResult,
SessionBeforeTreeResult,
ToolCallEvent,
ToolCallEventResult,
ToolResultEventResult,
UserBashEvent,
UserBashEventResult,
} from "./types.js";
/**
 * Combined result from all before_agent_start handlers.
 *
 * Produced by ExtensionRunner.emitBeforeAgentStart. A field is left undefined
 * when no handler contributed to it.
 */
interface BeforeAgentStartCombinedResult {
/** Every `message` returned by a handler, in the order the handlers ran. */
messages?: NonNullable<BeforeAgentStartEventResult["message"]>[];
/** The system prompt as left by the last handler that returned one. */
systemPrompt?: string;
}
/** Callback registered via ExtensionRunner.onError; invoked once per error passed to ExtensionRunner.emitError. */
export type ExtensionErrorListener = (error: ExtensionError) => void;
/**
 * Handler behind the command context's `newSession` action.
 * Resolves with a `cancelled` flag.
 * NOTE(review): the meaning of `parentSession` and `setup` is defined by whoever supplies the handler — confirm there.
 */
export type NewSessionHandler = (options?: {
parentSession?: string;
setup?: (sessionManager: SessionManager) => Promise<void>;
}) => Promise<{ cancelled: boolean }>;
/** Handler behind the command context's `fork` action; receives an entry id and resolves with a `cancelled` flag. */
export type ForkHandler = (entryId: string) => Promise<{ cancelled: boolean }>;
/**
 * Handler behind the command context's `navigateTree` action.
 * Receives a target id plus optional settings and resolves with a `cancelled` flag.
 */
export type NavigateTreeHandler = (
targetId: string,
options?: { summarize?: boolean; customInstructions?: string; replaceInstructions?: boolean; label?: string },
) => Promise<{ cancelled: boolean }>;
/** Handler invoked by ExtensionRunner.shutdown() and by the `shutdown` action of extension contexts. */
export type ShutdownHandler = () => void;
/**
 * Notify extensions that the session is shutting down.
 *
 * Nothing is emitted when no runner is given or when no extension has a
 * `session_shutdown` handler registered.
 *
 * @param extensionRunner Runner to emit through, if one exists.
 * @returns `true` when the event was emitted, `false` when it was skipped.
 */
export async function emitSessionShutdownEvent(extensionRunner: ExtensionRunner | undefined): Promise<boolean> {
	// Guard clause: a missing runner and a runner without listeners are treated alike.
	if (!extensionRunner?.hasHandlers("session_shutdown")) {
		return false;
	}

	await extensionRunner.emit({ type: "session_shutdown" });
	return true;
}
/**
 * Inert ExtensionUIContext used whenever no real UI has been supplied.
 *
 * ExtensionRunner starts with this object and falls back to it in initialize();
 * ExtensionRunner.hasUI() reports `false` by comparing against this exact
 * instance, so it must remain a single shared object.
 */
const noOpUIContext: ExtensionUIContext = {
// Interactive prompts resolve immediately with an "empty" answer.
select: async () => undefined,
confirm: async () => false,
input: async () => undefined,
// Display-related setters do nothing.
notify: () => {},
setStatus: () => {},
setWorkingMessage: () => {},
setWidget: () => {},
setFooter: () => {},
setHeader: () => {},
setTitle: () => {},
// Resolves with undefined at runtime; the `never` assertion only satisfies the declared return type.
custom: async () => undefined as never,
// Editor access: writes are dropped, reads yield an empty string / undefined.
setEditorText: () => {},
getEditorText: () => "",
editor: async () => undefined,
setEditorComponent: () => {},
// Getter, so each read returns the current value of the imported `theme` binding.
get theme() {
return theme;
},
getAllThemes: () => [],
getTheme: () => undefined,
// Theme changes always report failure with a fixed error message.
setTheme: (_theme: string | Theme) => ({ success: false, error: "UI not available" }),
};
/**
 * Runs a set of loaded extensions.
 *
 * The runner aggregates what the extensions registered (tools, flags,
 * shortcuts, commands, message renderers) and dispatches events to their
 * handlers, always iterating extensions in array order and each extension's
 * handlers in registration order.
 *
 * A freshly constructed runner uses inert defaults (no-op UI, no-op actions);
 * call {@link ExtensionRunner.initialize} to wire it to the host.
 */
export class ExtensionRunner {
	private extensions: Extension[];
	private runtime: ExtensionRuntime;
	private uiContext: ExtensionUIContext;
	private cwd: string;
	private sessionManager: SessionManager;
	private modelRegistry: ModelRegistry;
	private errorListeners: Set<ExtensionErrorListener> = new Set();

	// Host-provided callbacks. Each one starts as an inert default and is
	// replaced by initialize().
	private getModel: () => Model<any> | undefined = () => undefined;
	private isIdleFn: () => boolean = () => true;
	private waitForIdleFn: () => Promise<void> = async () => {};
	private abortFn: () => void = () => {};
	private hasPendingMessagesFn: () => boolean = () => false;
	private getContextUsageFn: () => ContextUsage | undefined = () => undefined;
	private compactFn: (options?: CompactOptions) => void = () => {};
	private newSessionHandler: NewSessionHandler = async () => ({ cancelled: false });
	private forkHandler: ForkHandler = async () => ({ cancelled: false });
	private navigateTreeHandler: NavigateTreeHandler = async () => ({ cancelled: false });
	private shutdownHandler: ShutdownHandler = () => {};

	/**
	 * @param extensions Loaded extensions, in the order they should be consulted.
	 * @param runtime Shared runtime object that initialize() fills with actions.
	 * @param cwd Working directory exposed to extensions through their context.
	 * @param sessionManager Session manager exposed through the context.
	 * @param modelRegistry Model registry exposed through the context.
	 */
	constructor(
		extensions: Extension[],
		runtime: ExtensionRuntime,
		cwd: string,
		sessionManager: SessionManager,
		modelRegistry: ModelRegistry,
	) {
		this.extensions = extensions;
		this.runtime = runtime;
		this.uiContext = noOpUIContext;
		this.cwd = cwd;
		this.sessionManager = sessionManager;
		this.modelRegistry = modelRegistry;
	}

	/**
	 * Wire the runner to the host.
	 *
	 * @param actions Copied onto the shared runtime object.
	 * @param contextActions Callbacks backing every extension context.
	 * @param commandContextActions Optional callbacks for command contexts; when
	 *   omitted, the previously installed (initially inert) ones stay in place.
	 * @param uiContext UI implementation; the no-op UI is used when omitted.
	 */
	initialize(
		actions: ExtensionActions,
		contextActions: ExtensionContextActions,
		commandContextActions?: ExtensionCommandContextActions,
		uiContext?: ExtensionUIContext,
	): void {
		// Copy actions into the shared runtime (all extension APIs reference this)
		this.runtime.sendMessage = actions.sendMessage;
		this.runtime.sendUserMessage = actions.sendUserMessage;
		this.runtime.appendEntry = actions.appendEntry;
		this.runtime.setSessionName = actions.setSessionName;
		this.runtime.getSessionName = actions.getSessionName;
		this.runtime.setLabel = actions.setLabel;
		this.runtime.getActiveTools = actions.getActiveTools;
		this.runtime.getAllTools = actions.getAllTools;
		this.runtime.setActiveTools = actions.setActiveTools;
		this.runtime.setModel = actions.setModel;
		this.runtime.getThinkingLevel = actions.getThinkingLevel;
		this.runtime.setThinkingLevel = actions.setThinkingLevel;

		// Context actions (required)
		this.getModel = contextActions.getModel;
		this.isIdleFn = contextActions.isIdle;
		this.abortFn = contextActions.abort;
		this.hasPendingMessagesFn = contextActions.hasPendingMessages;
		this.shutdownHandler = contextActions.shutdown;
		this.getContextUsageFn = contextActions.getContextUsage;
		this.compactFn = contextActions.compact;

		// Command context actions (optional, only for interactive mode)
		if (commandContextActions) {
			this.waitForIdleFn = commandContextActions.waitForIdle;
			this.newSessionHandler = commandContextActions.newSession;
			this.forkHandler = commandContextActions.fork;
			this.navigateTreeHandler = commandContextActions.navigateTree;
		}

		this.uiContext = uiContext ?? noOpUIContext;
	}

	/** The UI context currently in use (the no-op UI when none was supplied). */
	getUIContext(): ExtensionUIContext {
		return this.uiContext;
	}

	/** `true` when a UI context other than the no-op stub is installed. */
	hasUI(): boolean {
		return this.uiContext !== noOpUIContext;
	}

	/** Paths of all extensions, in extension order. */
	getExtensionPaths(): string[] {
		return this.extensions.map((e) => e.path);
	}

	/** Get all registered tools from all extensions. */
	getAllRegisteredTools(): RegisteredTool[] {
		const tools: RegisteredTool[] = [];
		for (const ext of this.extensions) {
			for (const tool of ext.tools.values()) {
				tools.push(tool);
			}
		}
		return tools;
	}

	/**
	 * Get a tool definition by name. Returns undefined if not found.
	 * When several extensions register the same name, the first extension wins.
	 */
	getToolDefinition(toolName: string): RegisteredTool["definition"] | undefined {
		for (const ext of this.extensions) {
			const tool = ext.tools.get(toolName);
			if (tool) {
				return tool.definition;
			}
		}
		return undefined;
	}

	/**
	 * Collect the flags of all extensions into one map.
	 * On a name clash the flag of the later extension replaces the earlier one.
	 */
	getFlags(): Map<string, ExtensionFlag> {
		const allFlags = new Map<string, ExtensionFlag>();
		for (const ext of this.extensions) {
			for (const [name, flag] of ext.flags) {
				allFlags.set(name, flag);
			}
		}
		return allFlags;
	}

	/** Store a flag value on the shared runtime. */
	setFlagValue(name: string, value: boolean | string): void {
		this.runtime.flagValues.set(name, value);
	}

	/** Key combinations extensions may not bind; compared against the lower-cased key. */
	private static readonly RESERVED_SHORTCUTS = new Set([
		"ctrl+c",
		"ctrl+d",
		"ctrl+z",
		"ctrl+k",
		"ctrl+p",
		"ctrl+l",
		"ctrl+o",
		"ctrl+t",
		"ctrl+g",
		"shift+tab",
		"shift+ctrl+p",
		"alt+enter",
		"escape",
		"enter",
	]);

	/**
	 * Collect the shortcuts of all extensions, keyed by lower-cased key id.
	 *
	 * Shortcuts that collide with a reserved key are skipped with a warning.
	 * When two extensions bind the same key, a warning is logged and the later
	 * extension wins.
	 */
	getShortcuts(): Map<KeyId, ExtensionShortcut> {
		const allShortcuts = new Map<KeyId, ExtensionShortcut>();
		for (const ext of this.extensions) {
			for (const [key, shortcut] of ext.shortcuts) {
				const normalizedKey = key.toLowerCase() as KeyId;
				if (ExtensionRunner.RESERVED_SHORTCUTS.has(normalizedKey)) {
					console.warn(
						`Extension shortcut '${key}' from ${shortcut.extensionPath} conflicts with built-in shortcut. Skipping.`,
					);
					continue;
				}
				const existing = allShortcuts.get(normalizedKey);
				if (existing) {
					console.warn(
						`Extension shortcut conflict: '${key}' registered by both ${existing.extensionPath} and ${shortcut.extensionPath}. Using ${shortcut.extensionPath}.`,
					);
				}
				allShortcuts.set(normalizedKey, shortcut);
			}
		}
		return allShortcuts;
	}

	/**
	 * Subscribe to extension errors.
	 * @returns A function that removes the listener again.
	 */
	onError(listener: ExtensionErrorListener): () => void {
		this.errorListeners.add(listener);
		return () => this.errorListeners.delete(listener);
	}

	/** Deliver an error to every registered listener, synchronously. */
	emitError(error: ExtensionError): void {
		for (const listener of this.errorListeners) {
			listener(error);
		}
	}

	/** `true` when at least one extension has a handler for `eventType`. */
	hasHandlers(eventType: string): boolean {
		for (const ext of this.extensions) {
			const handlers = ext.handlers.get(eventType);
			if (handlers && handlers.length > 0) {
				return true;
			}
		}
		return false;
	}

	/** First renderer registered for `customType`, searching extensions in order. */
	getMessageRenderer(customType: string): MessageRenderer | undefined {
		for (const ext of this.extensions) {
			const renderer = ext.messageRenderers.get(customType);
			if (renderer) {
				return renderer;
			}
		}
		return undefined;
	}

	/** All commands of all extensions, in extension order. */
	getRegisteredCommands(): RegisteredCommand[] {
		const commands: RegisteredCommand[] = [];
		for (const ext of this.extensions) {
			for (const command of ext.commands.values()) {
				commands.push(command);
			}
		}
		return commands;
	}

	/** First command registered under `name`, searching extensions in order. */
	getCommand(name: string): RegisteredCommand | undefined {
		for (const ext of this.extensions) {
			const command = ext.commands.get(name);
			if (command) {
				return command;
			}
		}
		return undefined;
	}

	/**
	 * Request a graceful shutdown. Called by extension tools and event handlers.
	 * The actual shutdown behavior is provided by the mode via initialize().
	 */
	shutdown(): void {
		this.shutdownHandler();
	}

	/**
	 * Create an ExtensionContext for use in event handlers and tool execution.
	 *
	 * `model` and the function members look up the runner's current callbacks
	 * each time they are used, so changes made by a later initialize() are
	 * reflected. `ui`, `hasUI`, `cwd`, `sessionManager` and `modelRegistry` are
	 * snapshots taken when the context is created.
	 */
	createContext(): ExtensionContext {
		// Resolve the accessor on every read instead of capturing it once, so the
		// getter behaves like the function members below.
		const resolveModel = (): Model<any> | undefined => this.getModel();
		return {
			ui: this.uiContext,
			hasUI: this.hasUI(),
			cwd: this.cwd,
			sessionManager: this.sessionManager,
			modelRegistry: this.modelRegistry,
			get model() {
				return resolveModel();
			},
			isIdle: () => this.isIdleFn(),
			abort: () => this.abortFn(),
			hasPendingMessages: () => this.hasPendingMessagesFn(),
			shutdown: () => this.shutdownHandler(),
			getContextUsage: () => this.getContextUsageFn(),
			compact: (options) => this.compactFn(options),
		};
	}

	/**
	 * Create the context handed to command handlers: everything from
	 * createContext() plus the session-control actions.
	 */
	createCommandContext(): ExtensionCommandContext {
		const base = this.createContext();
		return {
			...base,
			// Object spread copies the *current value* of accessor properties, which
			// would freeze `model` at creation time. Re-declare it as a getter so
			// command handlers always observe the live model.
			get model() {
				return base.model;
			},
			waitForIdle: () => this.waitForIdleFn(),
			newSession: (options) => this.newSessionHandler(options),
			fork: (entryId) => this.forkHandler(entryId),
			navigateTree: (targetId, options) => this.navigateTreeHandler(targetId, options),
		};
	}

	/** Type guard for the `session_before_*` events whose handlers may cancel the operation. */
	private isSessionBeforeEvent(
		type: string,
	): type is "session_before_switch" | "session_before_fork" | "session_before_compact" | "session_before_tree" {
		return (
			type === "session_before_switch" ||
			type === "session_before_fork" ||
			type === "session_before_compact" ||
			type === "session_before_tree"
		);
	}

	/**
	 * Normalise a value thrown by an extension handler and forward it to the
	 * error listeners. Non-Error values are stringified and carry no stack.
	 */
	private reportHandlerError(
		extensionPath: ExtensionError["extensionPath"],
		event: ExtensionError["event"],
		err: unknown,
	): void {
		this.emitError({
			extensionPath,
			event,
			error: err instanceof Error ? err.message : String(err),
			stack: err instanceof Error ? err.stack : undefined,
		});
	}

	/**
	 * Dispatch a generic event to every handler registered for its type.
	 *
	 * Handlers are awaited one after another. A throwing handler is reported via
	 * the error listeners and does not stop the remaining handlers.
	 *
	 * @returns For `session_before_*` events, the last truthy handler result, or
	 *   immediately the first result whose `cancel` is truthy. For `tool_result`,
	 *   the last truthy handler result. For every other event, `undefined`.
	 */
	async emit(
		event: ExtensionEvent,
	): Promise<SessionBeforeCompactResult | SessionBeforeTreeResult | ToolResultEventResult | undefined> {
		const ctx = this.createContext();
		let result: SessionBeforeCompactResult | SessionBeforeTreeResult | ToolResultEventResult | undefined;
		for (const ext of this.extensions) {
			const handlers = ext.handlers.get(event.type);
			if (!handlers || handlers.length === 0) continue;
			for (const handler of handlers) {
				try {
					const handlerResult = await handler(event, ctx);
					if (this.isSessionBeforeEvent(event.type) && handlerResult) {
						result = handlerResult as SessionBeforeCompactResult | SessionBeforeTreeResult;
						// A cancellation short-circuits: later handlers are not consulted.
						if (result.cancel) {
							return result;
						}
					}
					if (event.type === "tool_result" && handlerResult) {
						result = handlerResult as ToolResultEventResult;
					}
				} catch (err) {
					this.reportHandlerError(ext.path, event.type, err);
				}
			}
		}
		return result;
	}

	/**
	 * Dispatch a `tool_call` event.
	 *
	 * @returns The first result whose `block` is truthy (remaining handlers are
	 *   skipped), otherwise the last truthy handler result, otherwise `undefined`.
	 *
	 * NOTE(review): unlike the other emitters, exceptions thrown by handlers are
	 * not caught here and propagate to the caller — presumably so that a failing
	 * handler cannot silently let a tool call through; confirm with the caller.
	 */
	async emitToolCall(event: ToolCallEvent): Promise<ToolCallEventResult | undefined> {
		const ctx = this.createContext();
		let result: ToolCallEventResult | undefined;
		for (const ext of this.extensions) {
			const handlers = ext.handlers.get("tool_call");
			if (!handlers || handlers.length === 0) continue;
			for (const handler of handlers) {
				const handlerResult = await handler(event, ctx);
				if (handlerResult) {
					result = handlerResult as ToolCallEventResult;
					if (result.block) {
						return result;
					}
				}
			}
		}
		return result;
	}

	/**
	 * Dispatch a `user_bash` event.
	 *
	 * @returns The first truthy handler result (remaining handlers are skipped),
	 *   or `undefined` when no handler produced one. Handler errors are reported
	 *   via the error listeners and the next handler is tried.
	 */
	async emitUserBash(event: UserBashEvent): Promise<UserBashEventResult | undefined> {
		const ctx = this.createContext();
		for (const ext of this.extensions) {
			const handlers = ext.handlers.get("user_bash");
			if (!handlers || handlers.length === 0) continue;
			for (const handler of handlers) {
				try {
					const handlerResult = await handler(event, ctx);
					if (handlerResult) {
						return handlerResult as UserBashEventResult;
					}
				} catch (err) {
					this.reportHandlerError(ext.path, "user_bash", err);
				}
			}
		}
		return undefined;
	}

	/**
	 * Let `context` handlers rewrite the message list.
	 *
	 * Handlers operate on a deep copy of `messages`, so the caller's array is
	 * never handed to them. Each handler receives the list produced by the
	 * previous one; a result with a truthy `messages` replaces the current list.
	 * Handler errors are reported and leave the current list unchanged.
	 *
	 * @returns The message list after all handlers ran.
	 */
	async emitContext(messages: AgentMessage[]): Promise<AgentMessage[]> {
		const ctx = this.createContext();
		let currentMessages = structuredClone(messages);
		for (const ext of this.extensions) {
			const handlers = ext.handlers.get("context");
			if (!handlers || handlers.length === 0) continue;
			for (const handler of handlers) {
				try {
					const event: ContextEvent = { type: "context", messages: currentMessages };
					const handlerResult = await handler(event, ctx);
					if (handlerResult && (handlerResult as ContextEventResult).messages) {
						currentMessages = (handlerResult as ContextEventResult).messages!;
					}
				} catch (err) {
					this.reportHandlerError(ext.path, "context", err);
				}
			}
		}
		return currentMessages;
	}

	/**
	 * Dispatch `before_agent_start`.
	 *
	 * Every handler sees the system prompt as left by the handlers before it.
	 * Returned `message` values are collected; a returned `systemPrompt` (any
	 * value other than `undefined`) replaces the current prompt. Handler errors
	 * are reported and do not stop the remaining handlers.
	 *
	 * @returns The collected messages and/or the final system prompt, or
	 *   `undefined` when no handler contributed anything.
	 */
	async emitBeforeAgentStart(
		prompt: string,
		images: ImageContent[] | undefined,
		systemPrompt: string,
	): Promise<BeforeAgentStartCombinedResult | undefined> {
		const ctx = this.createContext();
		const messages: NonNullable<BeforeAgentStartEventResult["message"]>[] = [];
		let currentSystemPrompt = systemPrompt;
		let systemPromptModified = false;
		for (const ext of this.extensions) {
			const handlers = ext.handlers.get("before_agent_start");
			if (!handlers || handlers.length === 0) continue;
			for (const handler of handlers) {
				try {
					const event: BeforeAgentStartEvent = {
						type: "before_agent_start",
						prompt,
						images,
						systemPrompt: currentSystemPrompt,
					};
					const handlerResult = await handler(event, ctx);
					if (handlerResult) {
						const result = handlerResult as BeforeAgentStartEventResult;
						if (result.message) {
							messages.push(result.message);
						}
						if (result.systemPrompt !== undefined) {
							currentSystemPrompt = result.systemPrompt;
							systemPromptModified = true;
						}
					}
				} catch (err) {
					this.reportHandlerError(ext.path, "before_agent_start", err);
				}
			}
		}
		if (messages.length > 0 || systemPromptModified) {
			return {
				messages: messages.length > 0 ? messages : undefined,
				systemPrompt: systemPromptModified ? currentSystemPrompt : undefined,
			};
		}
		return undefined;
	}

	/**
	 * Emit input event. Transforms chain, "handled" short-circuits.
	 *
	 * @returns The handler's own result when one reports `"handled"`. Otherwise a
	 *   `"transform"` result if the text or the images reference differ from the
	 *   arguments, else `{ action: "continue" }`.
	 */
	async emitInput(text: string, images: ImageContent[] | undefined, source: InputSource): Promise<InputEventResult> {
		const ctx = this.createContext();
		let currentText = text;
		let currentImages = images;
		for (const ext of this.extensions) {
			for (const handler of ext.handlers.get("input") ?? []) {
				try {
					const event: InputEvent = { type: "input", text: currentText, images: currentImages, source };
					const result = (await handler(event, ctx)) as InputEventResult | undefined;
					if (result?.action === "handled") return result;
					if (result?.action === "transform") {
						currentText = result.text;
						// A transform that omits images keeps the ones currently in effect.
						currentImages = result.images ?? currentImages;
					}
				} catch (err) {
					this.reportHandlerError(ext.path, "input", err);
				}
			}
		}
		return currentText !== text || currentImages !== images
			? { action: "transform", text: currentText, images: currentImages }
			: { action: "continue" };
	}
}