From cb3ac0ba9e82ba06ca309f7da4fef7e68bf9ef00 Mon Sep 17 00:00:00 2001 From: Mario Zechner Date: Wed, 7 Jan 2026 23:50:18 +0100 Subject: [PATCH] refactor(coding-agent): simplify extension runtime architecture - Replace per-extension closures with shared ExtensionRuntime - Split context actions: ExtensionContextActions (required) + ExtensionCommandContextActions (optional) - Rename LoadedExtension to Extension, remove setter methods - Change runner.initialize() from options object to positional params - Derive hasUI from uiContext presence (no separate param) - Add warning when extensions override built-in tools - RPC and print modes now provide full command context actions BREAKING CHANGE: Extension system types and initialization API changed. See CHANGELOG.md for migration details. --- packages/coding-agent/CHANGELOG.md | 20 + packages/coding-agent/docs/extensions.md | 15 + packages/coding-agent/docs/sdk.md | 13 +- .../coding-agent/src/core/extensions/index.ts | 15 +- .../src/core/extensions/loader.ts | 425 +++++------------- .../src/core/extensions/runner.ts | 115 ++--- .../coding-agent/src/core/extensions/types.ts | 75 +++- packages/coding-agent/src/core/index.ts | 2 +- packages/coding-agent/src/core/sdk.ts | 65 +-- packages/coding-agent/src/index.ts | 7 +- packages/coding-agent/src/main.ts | 30 +- .../src/modes/interactive/interactive-mode.ts | 237 +++++----- packages/coding-agent/src/modes/print-mode.ts | 80 ++-- .../coding-agent/src/modes/rpc/rpc-mode.ts | 80 ++-- .../test/compaction-extensions.test.ts | 112 ++--- .../test/extensions-runner.test.ts | 25 +- 16 files changed, 580 insertions(+), 736 deletions(-) diff --git a/packages/coding-agent/CHANGELOG.md b/packages/coding-agent/CHANGELOG.md index 4363836a..c4f322b2 100644 --- a/packages/coding-agent/CHANGELOG.md +++ b/packages/coding-agent/CHANGELOG.md @@ -5,11 +5,31 @@ ### Breaking Changes - `ctx.ui.custom()` factory signature changed from `(tui, theme, done)` to `(tui, theme, keybindings, done)` for consistency with other input-handling factories +- Extension system refactored: `LoadedExtension` renamed to `Extension`, setter methods removed +- `LoadExtensionsResult.setUIContext()` removed, replaced with `runtime: ExtensionRuntime` +- `ExtensionRunner` constructor now requires `runtime: ExtensionRuntime` as second parameter +- `ExtensionRunner.initialize()` signature changed from options object to `(actions, contextActions, commandContextActions?, uiContext?)` +- `ExtensionRunner.getHasUI()` renamed to `hasUI()` +- `CreateAgentSessionOptions.preloadedExtensions` renamed to `preloadedExtensionsResult` +- `CreateAgentSessionResult` now returns `extensionsResult: LoadExtensionsResult` instead of `customToolsResult` ### Added - Extension UI dialogs (`ctx.ui.select()`, `ctx.ui.confirm()`, `ctx.ui.input()`) now support a `timeout` option that auto-dismisses the dialog with a live countdown display. Simpler alternative to `AbortSignal` for timed dialogs. - Extensions can now provide custom editor components via `ctx.ui.setEditorComponent((tui, theme, keybindings) => ...)`. Extend `CustomEditor` for full app keybinding support (escape, ctrl+d, model switching, etc.). See `examples/extensions/modal-editor.ts`, `examples/extensions/rainbow-editor.ts`, and `docs/tui.md` Pattern 7. +- `ExtensionRuntime` interface for shared runtime state and action methods +- `ExtensionActions` interface for `pi.*` API methods +- `ExtensionContextActions` interface for `ctx.*` in event handlers +- `ExtensionCommandContextActions` interface for `ctx.*` in command handlers (session control) +- `createExtensionRuntime()` function to create runtime with throwing stubs +- `Extension` type exported (cleaner name for loaded extension data) +- Interactive mode now warns when extensions override built-in tools (read, bash, edit, write, grep, find, ls) + +### Changed + +- Extension loader simplified: shared runtime instead of per-extension closures +- `hasUI` now derived from whether `uiContext` is provided (no longer a separate parameter) +- RPC and print modes now provide `ExtensionCommandContextActions` for full command support ### Fixed diff --git a/packages/coding-agent/docs/extensions.md b/packages/coding-agent/docs/extensions.md index 75718389..97a163e3 100644 --- a/packages/coding-agent/docs/extensions.md +++ b/packages/coding-agent/docs/extensions.md @@ -924,6 +924,21 @@ pi.registerTool({ **Important:** Use `StringEnum` from `@mariozechner/pi-ai` for string enums. `Type.Union`/`Type.Literal` doesn't work with Google's API. +### Overriding Built-in Tools + +Extensions can override built-in tools (`read`, `bash`, `edit`, `write`, `grep`, `find`, `ls`) by registering a tool with the same name. Interactive mode displays a warning when this happens. + +**Your implementation must match the exact result shape**, including the `details` type. The UI and session logic depend on these shapes for rendering and state tracking. + +Built-in tool implementations: +- [read.ts](https://github.com/badlogic/pi-mono/blob/main/packages/coding-agent/src/core/tools/read.ts) - `ReadToolDetails` +- [bash.ts](https://github.com/badlogic/pi-mono/blob/main/packages/coding-agent/src/core/tools/bash.ts) - `BashToolDetails` +- [edit.ts](https://github.com/badlogic/pi-mono/blob/main/packages/coding-agent/src/core/tools/edit.ts) +- [write.ts](https://github.com/badlogic/pi-mono/blob/main/packages/coding-agent/src/core/tools/write.ts) +- [grep.ts](https://github.com/badlogic/pi-mono/blob/main/packages/coding-agent/src/core/tools/grep.ts) - `GrepToolDetails` +- [find.ts](https://github.com/badlogic/pi-mono/blob/main/packages/coding-agent/src/core/tools/find.ts) - `FindToolDetails` +- [ls.ts](https://github.com/badlogic/pi-mono/blob/main/packages/coding-agent/src/core/tools/ls.ts) - `LsToolDetails` + ### Output Truncation **Tools MUST truncate their output** to avoid overwhelming the LLM context. Large outputs can cause: diff --git a/packages/coding-agent/docs/sdk.md b/packages/coding-agent/docs/sdk.md index 8d62c0f4..b2a29193 100644 --- a/packages/coding-agent/docs/sdk.md +++ b/packages/coding-agent/docs/sdk.md @@ -784,15 +784,18 @@ interface CreateAgentSessionResult { // The session session: AgentSession; - // Custom tools (for UI setup) - customToolsResult: { - tools: LoadedCustomTool[]; - setUIContext: (ctx, hasUI) => void; - }; + // Extensions result (for runner setup) + extensionsResult: LoadExtensionsResult; // Warning if session model couldn't be restored modelFallbackMessage?: string; } + +interface LoadExtensionsResult { + extensions: Extension[]; + errors: Array<{ path: string; error: string }>; + runtime: ExtensionRuntime; +} ``` ## Complete Example diff --git a/packages/coding-agent/src/core/extensions/index.ts b/packages/coding-agent/src/core/extensions/index.ts index 66f8756f..755f4ee6 100644 --- a/packages/coding-agent/src/core/extensions/index.ts +++ b/packages/coding-agent/src/core/extensions/index.ts @@ -2,7 +2,12 @@ * Extension system for lifecycle events and custom tools. */ -export { discoverAndLoadExtensions, loadExtensionFromFactory, loadExtensions } from "./loader.js"; +export { + createExtensionRuntime, + discoverAndLoadExtensions, + loadExtensionFromFactory, + loadExtensions, +} from "./loader.js"; export type { BranchHandler, ExtensionErrorListener, NavigateTreeHandler, NewSessionHandler } from "./runner.js"; export { ExtensionRunner } from "./runner.js"; export type { @@ -25,17 +30,23 @@ export type { EditToolResultEvent, ExecOptions, ExecResult, + Extension, + ExtensionActions, // API ExtensionAPI, ExtensionCommandContext, + ExtensionCommandContextActions, // Context ExtensionContext, + ExtensionContextActions, // Errors ExtensionError, ExtensionEvent, ExtensionFactory, ExtensionFlag, ExtensionHandler, + // Runtime + ExtensionRuntime, ExtensionShortcut, ExtensionUIContext, ExtensionUIDialogOptions, @@ -46,8 +57,6 @@ export type { GrepToolResultEvent, KeybindingsManager, LoadExtensionsResult, - // Loaded Extension - LoadedExtension, LsToolResultEvent, // Message Rendering MessageRenderer, diff --git a/packages/coding-agent/src/core/extensions/loader.ts b/packages/coding-agent/src/core/extensions/loader.ts index 16460b16..917cb9ee 100644 --- a/packages/coding-agent/src/core/extensions/loader.ts +++ b/packages/coding-agent/src/core/extensions/loader.ts @@ -10,30 +10,17 @@ import { fileURLToPath } from "node:url"; import type { KeyId } from "@mariozechner/pi-tui"; import { createJiti } from "jiti"; import { getAgentDir, isBunBinary } from "../../config.js"; -import { theme } from "../../modes/interactive/theme/theme.js"; import { createEventBus, type EventBus } from "../event-bus.js"; import type { ExecOptions } from "../exec.js"; import { execCommand } from "../exec.js"; import type { - AppendEntryHandler, + Extension, ExtensionAPI, ExtensionFactory, - ExtensionFlag, - ExtensionShortcut, - ExtensionUIContext, - GetActiveToolsHandler, - GetAllToolsHandler, - GetThinkingLevelHandler, + ExtensionRuntime, LoadExtensionsResult, - LoadedExtension, MessageRenderer, RegisteredCommand, - RegisteredTool, - SendMessageHandler, - SendUserMessageHandler, - SetActiveToolsHandler, - SetModelHandler, - SetThinkingLevelHandler, ToolDefinition, } from "./types.js"; @@ -84,87 +71,59 @@ function resolvePath(extPath: string, cwd: string): string { return path.resolve(cwd, expanded); } -function createNoOpUIContext(): ExtensionUIContext { +type HandlerFn = (...args: unknown[]) => Promise; + +/** + * Create a runtime with throwing stubs for action methods. + * Runner.initialize() replaces these with real implementations. + */ +export function createExtensionRuntime(): ExtensionRuntime { + const notInitialized = () => { + throw new Error("Extension runtime not initialized. Action methods cannot be called during extension loading."); + }; + return { - select: async () => undefined, - confirm: async () => false, - input: async () => undefined, - notify: () => {}, - setStatus: () => {}, - setWidget: () => {}, - setFooter: () => {}, - setHeader: () => {}, - setTitle: () => {}, - custom: async () => undefined as never, - setEditorText: () => {}, - getEditorText: () => "", - editor: async () => undefined, - setEditorComponent: () => {}, - get theme() { - return theme; - }, + sendMessage: notInitialized, + sendUserMessage: notInitialized, + appendEntry: notInitialized, + getActiveTools: notInitialized, + getAllTools: notInitialized, + setActiveTools: notInitialized, + setModel: () => Promise.reject(new Error("Extension runtime not initialized")), + getThinkingLevel: notInitialized, + setThinkingLevel: notInitialized, + flagValues: new Map(), }; } -type HandlerFn = (...args: unknown[]) => Promise; - +/** + * Create the ExtensionAPI for an extension. + * Registration methods write to the extension object. + * Action methods delegate to the shared runtime. + */ function createExtensionAPI( - handlers: Map, - tools: Map, + extension: Extension, + runtime: ExtensionRuntime, cwd: string, - extensionPath: string, eventBus: EventBus, - _sharedUI: { ui: ExtensionUIContext; hasUI: boolean }, -): { - api: ExtensionAPI; - messageRenderers: Map; - commands: Map; - flags: Map; - flagValues: Map; - shortcuts: Map; - setSendMessageHandler: (handler: SendMessageHandler) => void; - setSendUserMessageHandler: (handler: SendUserMessageHandler) => void; - setAppendEntryHandler: (handler: AppendEntryHandler) => void; - setGetActiveToolsHandler: (handler: GetActiveToolsHandler) => void; - setGetAllToolsHandler: (handler: GetAllToolsHandler) => void; - setSetActiveToolsHandler: (handler: SetActiveToolsHandler) => void; - setSetModelHandler: (handler: SetModelHandler) => void; - setGetThinkingLevelHandler: (handler: GetThinkingLevelHandler) => void; - setSetThinkingLevelHandler: (handler: SetThinkingLevelHandler) => void; - setFlagValue: (name: string, value: boolean | string) => void; -} { - let sendMessageHandler: SendMessageHandler = () => {}; - let sendUserMessageHandler: SendUserMessageHandler = () => {}; - let appendEntryHandler: AppendEntryHandler = () => {}; - let getActiveToolsHandler: GetActiveToolsHandler = () => []; - let getAllToolsHandler: GetAllToolsHandler = () => []; - let setActiveToolsHandler: SetActiveToolsHandler = () => {}; - let setModelHandler: SetModelHandler = async () => false; - let getThinkingLevelHandler: GetThinkingLevelHandler = () => "off"; - let setThinkingLevelHandler: SetThinkingLevelHandler = () => {}; - - const messageRenderers = new Map(); - const commands = new Map(); - const flags = new Map(); - const flagValues = new Map(); - const shortcuts = new Map(); - +): ExtensionAPI { const api = { + // Registration methods - write to extension on(event: string, handler: HandlerFn): void { - const list = handlers.get(event) ?? []; + const list = extension.handlers.get(event) ?? []; list.push(handler); - handlers.set(event, list); + extension.handlers.set(event, list); }, registerTool(tool: ToolDefinition): void { - tools.set(tool.name, { + extension.tools.set(tool.name, { definition: tool, - extensionPath, + extensionPath: extension.path, }); }, registerCommand(name: string, options: { description?: string; handler: RegisteredCommand["handler"] }): void { - commands.set(name, { name, ...options }); + extension.commands.set(name, { name, ...options }); }, registerShortcut( @@ -174,37 +133,40 @@ function createExtensionAPI( handler: (ctx: import("./types.js").ExtensionContext) => Promise | void; }, ): void { - shortcuts.set(shortcut, { shortcut, extensionPath, ...options }); + extension.shortcuts.set(shortcut, { shortcut, extensionPath: extension.path, ...options }); }, registerFlag( name: string, options: { description?: string; type: "boolean" | "string"; default?: boolean | string }, ): void { - flags.set(name, { name, extensionPath, ...options }); + extension.flags.set(name, { name, extensionPath: extension.path, ...options }); if (options.default !== undefined) { - flagValues.set(name, options.default); + runtime.flagValues.set(name, options.default); } }, - getFlag(name: string): boolean | string | undefined { - return flagValues.get(name); - }, - registerMessageRenderer(customType: string, renderer: MessageRenderer): void { - messageRenderers.set(customType, renderer as MessageRenderer); + extension.messageRenderers.set(customType, renderer as MessageRenderer); }, + // Flag access - checks extension registered it, reads from runtime + getFlag(name: string): boolean | string | undefined { + if (!extension.flags.has(name)) return undefined; + return runtime.flagValues.get(name); + }, + + // Action methods - delegate to shared runtime sendMessage(message, options): void { - sendMessageHandler(message, options); + runtime.sendMessage(message, options); }, sendUserMessage(content, options): void { - sendUserMessageHandler(content, options); + runtime.sendUserMessage(content, options); }, appendEntry(customType: string, data?: unknown): void { - appendEntryHandler(customType, data); + runtime.appendEntry(customType, data); }, exec(command: string, args: string[], options?: ExecOptions) { @@ -212,222 +174,86 @@ function createExtensionAPI( }, getActiveTools(): string[] { - return getActiveToolsHandler(); + return runtime.getActiveTools(); }, getAllTools(): string[] { - return getAllToolsHandler(); + return runtime.getAllTools(); }, setActiveTools(toolNames: string[]): void { - setActiveToolsHandler(toolNames); + runtime.setActiveTools(toolNames); }, setModel(model) { - return setModelHandler(model); + return runtime.setModel(model); }, getThinkingLevel() { - return getThinkingLevelHandler(); + return runtime.getThinkingLevel(); }, setThinkingLevel(level) { - setThinkingLevelHandler(level); + runtime.setThinkingLevel(level); }, events: eventBus, } as ExtensionAPI; - return { - api, - messageRenderers, - commands, - flags, - flagValues, - shortcuts, - setSendMessageHandler: (handler: SendMessageHandler) => { - sendMessageHandler = handler; - }, - setSendUserMessageHandler: (handler: SendUserMessageHandler) => { - sendUserMessageHandler = handler; - }, - setAppendEntryHandler: (handler: AppendEntryHandler) => { - appendEntryHandler = handler; - }, - setGetActiveToolsHandler: (handler: GetActiveToolsHandler) => { - getActiveToolsHandler = handler; - }, - setGetAllToolsHandler: (handler: GetAllToolsHandler) => { - getAllToolsHandler = handler; - }, - setSetActiveToolsHandler: (handler: SetActiveToolsHandler) => { - setActiveToolsHandler = handler; - }, - setSetModelHandler: (handler: SetModelHandler) => { - setModelHandler = handler; - }, - setGetThinkingLevelHandler: (handler: GetThinkingLevelHandler) => { - getThinkingLevelHandler = handler; - }, - setSetThinkingLevelHandler: (handler: SetThinkingLevelHandler) => { - setThinkingLevelHandler = handler; - }, - setFlagValue: (name: string, value: boolean | string) => { - flagValues.set(name, value); - }, - }; + return api; } -async function loadExtensionWithBun( - resolvedPath: string, - cwd: string, - extensionPath: string, - eventBus: EventBus, - sharedUI: { ui: ExtensionUIContext; hasUI: boolean }, -): Promise<{ extension: LoadedExtension | null; error: string | null }> { - try { - const module = await import(resolvedPath); - const factory = (module.default ?? module) as ExtensionFactory; +async function loadBun(path: string) { + const module = await import(path); + const factory = (module.default ?? module) as ExtensionFactory; + return typeof factory !== "function" ? undefined : factory; +} - if (typeof factory !== "function") { - return { extension: null, error: "Extension must export a default function" }; - } +async function loadJiti(path: string) { + const jiti = createJiti(import.meta.url, { + alias: getAliases(), + }); - const handlers = new Map(); - const tools = new Map(); - const { - api, - messageRenderers, - commands, - flags, - flagValues, - shortcuts, - setSendMessageHandler, - setSendUserMessageHandler, - setAppendEntryHandler, - setGetActiveToolsHandler, - setGetAllToolsHandler, - setSetActiveToolsHandler, - setSetModelHandler, - setGetThinkingLevelHandler, - setSetThinkingLevelHandler, - setFlagValue, - } = createExtensionAPI(handlers, tools, cwd, extensionPath, eventBus, sharedUI); + const module = await jiti.import(path, { default: true }); + const factory = module as ExtensionFactory; + return typeof factory !== "function" ? undefined : factory; +} - await factory(api); - - return { - extension: { - path: extensionPath, - resolvedPath, - handlers, - tools, - messageRenderers, - commands, - flags, - flagValues, - shortcuts, - setSendMessageHandler, - setSendUserMessageHandler, - setAppendEntryHandler, - setGetActiveToolsHandler, - setGetAllToolsHandler, - setSetActiveToolsHandler, - setSetModelHandler, - setGetThinkingLevelHandler, - setSetThinkingLevelHandler, - setFlagValue, - }, - error: null, - }; - } catch (err) { - const message = err instanceof Error ? err.message : String(err); - - if (message.includes("Cannot find module") && message.includes("@mariozechner/")) { - return { - extension: null, - error: - `${message}\n` + - "Note: Extensions importing from @mariozechner/* packages are not supported in the standalone binary.\n" + - "Please install pi via npm: npm install -g @mariozechner/pi-coding-agent", - }; - } - - return { extension: null, error: `Failed to load extension: ${message}` }; - } +/** + * Create an Extension object with empty collections. + */ +function createExtension(extensionPath: string, resolvedPath: string): Extension { + return { + path: extensionPath, + resolvedPath, + handlers: new Map(), + tools: new Map(), + messageRenderers: new Map(), + commands: new Map(), + flags: new Map(), + shortcuts: new Map(), + }; } async function loadExtension( extensionPath: string, cwd: string, eventBus: EventBus, - sharedUI: { ui: ExtensionUIContext; hasUI: boolean }, -): Promise<{ extension: LoadedExtension | null; error: string | null }> { + runtime: ExtensionRuntime, +): Promise<{ extension: Extension | null; error: string | null }> { const resolvedPath = resolvePath(extensionPath, cwd); - if (isBunBinary) { - return loadExtensionWithBun(resolvedPath, cwd, extensionPath, eventBus, sharedUI); - } - try { - const jiti = createJiti(import.meta.url, { - alias: getAliases(), - }); - - const module = await jiti.import(resolvedPath, { default: true }); - const factory = module as ExtensionFactory; - - if (typeof factory !== "function") { - return { extension: null, error: "Extension must export a default function" }; + const factory = isBunBinary ? await loadBun(resolvedPath) : await loadJiti(resolvedPath); + if (!factory) { + return { extension: null, error: `Extension does not export a valid factory function: ${extensionPath}` }; } - const handlers = new Map(); - const tools = new Map(); - const { - api, - messageRenderers, - commands, - flags, - flagValues, - shortcuts, - setSendMessageHandler, - setSendUserMessageHandler, - setAppendEntryHandler, - setGetActiveToolsHandler, - setGetAllToolsHandler, - setSetActiveToolsHandler, - setSetModelHandler, - setGetThinkingLevelHandler, - setSetThinkingLevelHandler, - setFlagValue, - } = createExtensionAPI(handlers, tools, cwd, extensionPath, eventBus, sharedUI); - + const extension = createExtension(extensionPath, resolvedPath); + const api = createExtensionAPI(extension, runtime, cwd, eventBus); await factory(api); - return { - extension: { - path: extensionPath, - resolvedPath, - handlers, - tools, - messageRenderers, - commands, - flags, - flagValues, - shortcuts, - setSendMessageHandler, - setSendUserMessageHandler, - setAppendEntryHandler, - setGetActiveToolsHandler, - setGetAllToolsHandler, - setSetActiveToolsHandler, - setSetModelHandler, - setGetThinkingLevelHandler, - setSetThinkingLevelHandler, - setFlagValue, - }, - error: null, - }; + return { extension, error: null }; } catch (err) { const message = err instanceof Error ? err.message : String(err); return { extension: null, error: `Failed to load extension: ${message}` }; @@ -435,72 +261,32 @@ async function loadExtension( } /** - * Create a LoadedExtension from an inline factory function. + * Create an Extension from an inline factory function. */ export async function loadExtensionFromFactory( factory: ExtensionFactory, cwd: string, eventBus: EventBus, - sharedUI: { ui: ExtensionUIContext; hasUI: boolean }, - name = "", -): Promise { - const handlers = new Map(); - const tools = new Map(); - const { - api, - messageRenderers, - commands, - flags, - flagValues, - shortcuts, - setSendMessageHandler, - setSendUserMessageHandler, - setAppendEntryHandler, - setGetActiveToolsHandler, - setGetAllToolsHandler, - setSetActiveToolsHandler, - setSetModelHandler, - setGetThinkingLevelHandler, - setSetThinkingLevelHandler, - setFlagValue, - } = createExtensionAPI(handlers, tools, cwd, name, eventBus, sharedUI); - + runtime: ExtensionRuntime, + extensionPath = "", +): Promise { + const extension = createExtension(extensionPath, extensionPath); + const api = createExtensionAPI(extension, runtime, cwd, eventBus); await factory(api); - - return { - path: name, - resolvedPath: name, - handlers, - tools, - messageRenderers, - commands, - flags, - flagValues, - shortcuts, - setSendMessageHandler, - setSendUserMessageHandler, - setAppendEntryHandler, - setGetActiveToolsHandler, - setGetAllToolsHandler, - setSetActiveToolsHandler, - setSetModelHandler, - setGetThinkingLevelHandler, - setSetThinkingLevelHandler, - setFlagValue, - }; + return extension; } /** * Load extensions from paths. */ export async function loadExtensions(paths: string[], cwd: string, eventBus?: EventBus): Promise { - const extensions: LoadedExtension[] = []; + const extensions: Extension[] = []; const errors: Array<{ path: string; error: string }> = []; const resolvedEventBus = eventBus ?? createEventBus(); - const sharedUI = { ui: createNoOpUIContext(), hasUI: false }; + const runtime = createExtensionRuntime(); for (const extPath of paths) { - const { extension, error } = await loadExtension(extPath, cwd, resolvedEventBus, sharedUI); + const { extension, error } = await loadExtension(extPath, cwd, resolvedEventBus, runtime); if (error) { errors.push({ path: extPath, error }); @@ -515,10 +301,7 @@ export async function loadExtensions(paths: string[], cwd: string, eventBus?: Ev return { extensions, errors, - setUIContext(uiContext, hasUI) { - sharedUI.ui = uiContext; - sharedUI.hasUI = hasUI; - }, + runtime, }; } diff --git a/packages/coding-agent/src/core/extensions/runner.ts b/packages/coding-agent/src/core/extensions/runner.ts index 79ddcd65..c81f7624 100644 --- a/packages/coding-agent/src/core/extensions/runner.ts +++ b/packages/coding-agent/src/core/extensions/runner.ts @@ -9,32 +9,27 @@ import { theme } from "../../modes/interactive/theme/theme.js"; import type { ModelRegistry } from "../model-registry.js"; import type { SessionManager } from "../session-manager.js"; import type { - AppendEntryHandler, BeforeAgentStartEvent, BeforeAgentStartEventResult, ContextEvent, ContextEventResult, + Extension, + ExtensionActions, ExtensionCommandContext, + ExtensionCommandContextActions, ExtensionContext, + ExtensionContextActions, ExtensionError, ExtensionEvent, ExtensionFlag, + ExtensionRuntime, ExtensionShortcut, ExtensionUIContext, - GetActiveToolsHandler, - GetAllToolsHandler, - GetThinkingLevelHandler, - LoadedExtension, MessageRenderer, RegisteredCommand, RegisteredTool, - SendMessageHandler, - SendUserMessageHandler, SessionBeforeCompactResult, SessionBeforeTreeResult, - SetActiveToolsHandler, - SetModelHandler, - SetThinkingLevelHandler, ToolCallEvent, ToolCallEventResult, ToolResultEventResult, @@ -81,9 +76,9 @@ const noOpUIContext: ExtensionUIContext = { }; export class ExtensionRunner { - private extensions: LoadedExtension[]; + private extensions: Extension[]; + private runtime: ExtensionRuntime; private uiContext: ExtensionUIContext; - private hasUI: boolean; private cwd: string; private sessionManager: SessionManager; private modelRegistry: ModelRegistry; @@ -98,78 +93,60 @@ export class ExtensionRunner { private navigateTreeHandler: NavigateTreeHandler = async () => ({ cancelled: false }); constructor( - extensions: LoadedExtension[], + extensions: Extension[], + runtime: ExtensionRuntime, cwd: string, sessionManager: SessionManager, modelRegistry: ModelRegistry, ) { this.extensions = extensions; + this.runtime = runtime; this.uiContext = noOpUIContext; - this.hasUI = false; this.cwd = cwd; this.sessionManager = sessionManager; this.modelRegistry = modelRegistry; } - initialize(options: { - getModel: () => Model | undefined; - sendMessageHandler: SendMessageHandler; - sendUserMessageHandler: SendUserMessageHandler; - appendEntryHandler: AppendEntryHandler; - getActiveToolsHandler: GetActiveToolsHandler; - getAllToolsHandler: GetAllToolsHandler; - setActiveToolsHandler: SetActiveToolsHandler; - setModelHandler: SetModelHandler; - getThinkingLevelHandler: GetThinkingLevelHandler; - setThinkingLevelHandler: SetThinkingLevelHandler; - newSessionHandler?: NewSessionHandler; - branchHandler?: BranchHandler; - navigateTreeHandler?: NavigateTreeHandler; - isIdle?: () => boolean; - waitForIdle?: () => Promise; - abort?: () => void; - hasPendingMessages?: () => boolean; - uiContext?: ExtensionUIContext; - hasUI?: boolean; - }): void { - this.getModel = options.getModel; - this.isIdleFn = options.isIdle ?? (() => true); - this.waitForIdleFn = options.waitForIdle ?? (async () => {}); - this.abortFn = options.abort ?? (() => {}); - this.hasPendingMessagesFn = options.hasPendingMessages ?? (() => false); + initialize( + actions: ExtensionActions, + contextActions: ExtensionContextActions, + commandContextActions?: ExtensionCommandContextActions, + uiContext?: ExtensionUIContext, + ): void { + // Copy actions into the shared runtime (all extension APIs reference this) + this.runtime.sendMessage = actions.sendMessage; + this.runtime.sendUserMessage = actions.sendUserMessage; + this.runtime.appendEntry = actions.appendEntry; + this.runtime.getActiveTools = actions.getActiveTools; + this.runtime.getAllTools = actions.getAllTools; + this.runtime.setActiveTools = actions.setActiveTools; + this.runtime.setModel = actions.setModel; + this.runtime.getThinkingLevel = actions.getThinkingLevel; + this.runtime.setThinkingLevel = actions.setThinkingLevel; - if (options.newSessionHandler) { - this.newSessionHandler = options.newSessionHandler; - } - if (options.branchHandler) { - this.branchHandler = options.branchHandler; - } - if (options.navigateTreeHandler) { - this.navigateTreeHandler = options.navigateTreeHandler; + // Context actions (required) + this.getModel = contextActions.getModel; + this.isIdleFn = contextActions.isIdle; + this.abortFn = contextActions.abort; + this.hasPendingMessagesFn = contextActions.hasPendingMessages; + + // Command context actions (optional, only for interactive mode) + if (commandContextActions) { + this.waitForIdleFn = commandContextActions.waitForIdle; + this.newSessionHandler = commandContextActions.newSession; + this.branchHandler = commandContextActions.branch; + this.navigateTreeHandler = commandContextActions.navigateTree; } - for (const ext of this.extensions) { - ext.setSendMessageHandler(options.sendMessageHandler); - ext.setSendUserMessageHandler(options.sendUserMessageHandler); - ext.setAppendEntryHandler(options.appendEntryHandler); - ext.setGetActiveToolsHandler(options.getActiveToolsHandler); - ext.setGetAllToolsHandler(options.getAllToolsHandler); - ext.setSetActiveToolsHandler(options.setActiveToolsHandler); - ext.setSetModelHandler(options.setModelHandler); - ext.setGetThinkingLevelHandler(options.getThinkingLevelHandler); - ext.setSetThinkingLevelHandler(options.setThinkingLevelHandler); - } - - this.uiContext = options.uiContext ?? noOpUIContext; - this.hasUI = options.hasUI ?? false; + this.uiContext = uiContext ?? noOpUIContext; } - getUIContext(): ExtensionUIContext | null { + getUIContext(): ExtensionUIContext { return this.uiContext; } - getHasUI(): boolean { - return this.hasUI; + hasUI(): boolean { + return this.uiContext !== noOpUIContext; } getExtensionPaths(): string[] { @@ -198,11 +175,7 @@ export class ExtensionRunner { } setFlagValue(name: string, value: boolean | string): void { - for (const ext of this.extensions) { - if (ext.flags.has(name)) { - ext.setFlagValue(name, value); - } - } + this.runtime.flagValues.set(name, value); } private static readonly RESERVED_SHORTCUTS = new Set([ @@ -301,7 +274,7 @@ export class ExtensionRunner { private createContext(): ExtensionContext { return { ui: this.uiContext, - hasUI: this.hasUI, + hasUI: this.hasUI(), cwd: this.cwd, sessionManager: this.sessionManager, modelRegistry: this.modelRegistry, diff --git a/packages/coding-agent/src/core/extensions/types.ts b/packages/coding-agent/src/core/extensions/types.ts index dd5de3ab..e333ca04 100644 --- a/packages/coding-agent/src/core/extensions/types.ts +++ b/packages/coding-agent/src/core/extensions/types.ts @@ -21,7 +21,7 @@ import type { Theme } from "../../modes/interactive/theme/theme.js"; import type { CompactionPreparation, CompactionResult } from "../compaction/index.js"; import type { EventBus } from "../event-bus.js"; import type { ExecOptions, ExecResult } from "../exec.js"; -import type { AppAction, KeybindingsManager } from "../keybindings.js"; +import type { KeybindingsManager } from "../keybindings.js"; import type { CustomMessage } from "../messages.js"; import type { ModelRegistry } from "../model-registry.js"; import type { @@ -742,8 +742,63 @@ export type GetThinkingLevelHandler = () => ThinkingLevel; export type SetThinkingLevelHandler = (level: ThinkingLevel) => void; +/** + * Shared state created by loader, used during registration and runtime. + * Contains flag values (defaults set during registration, CLI values set after). + */ +export interface ExtensionRuntimeState { + flagValues: Map; +} + +/** + * Action implementations for pi.* API methods. + * Provided to runner.initialize(), copied into the shared runtime. + */ +export interface ExtensionActions { + sendMessage: SendMessageHandler; + sendUserMessage: SendUserMessageHandler; + appendEntry: AppendEntryHandler; + getActiveTools: GetActiveToolsHandler; + getAllTools: GetAllToolsHandler; + setActiveTools: SetActiveToolsHandler; + setModel: SetModelHandler; + getThinkingLevel: GetThinkingLevelHandler; + setThinkingLevel: SetThinkingLevelHandler; +} + +/** + * Actions for ExtensionContext (ctx.* in event handlers). + * Required by all modes. + */ +export interface ExtensionContextActions { + getModel: () => Model | undefined; + isIdle: () => boolean; + abort: () => void; + hasPendingMessages: () => boolean; +} + +/** + * Actions for ExtensionCommandContext (ctx.* in command handlers). + * Only needed for interactive mode where extension commands are invokable. + */ +export interface ExtensionCommandContextActions { + waitForIdle: () => Promise; + newSession: (options?: { + parentSession?: string; + setup?: (sessionManager: SessionManager) => Promise; + }) => Promise<{ cancelled: boolean }>; + branch: (entryId: string) => Promise<{ cancelled: boolean }>; + navigateTree: (targetId: string, options?: { summarize?: boolean }) => Promise<{ cancelled: boolean }>; +} + +/** + * Full runtime = state + actions. + * Created by loader with throwing action stubs, completed by runner.initialize(). + */ +export interface ExtensionRuntime extends ExtensionRuntimeState, ExtensionActions {} + /** Loaded extension with all registered items. */ -export interface LoadedExtension { +export interface Extension { path: string; resolvedPath: string; handlers: Map; @@ -751,25 +806,15 @@ export interface LoadedExtension { messageRenderers: Map; commands: Map; flags: Map; - flagValues: Map; shortcuts: Map; - setSendMessageHandler: (handler: SendMessageHandler) => void; - setSendUserMessageHandler: (handler: SendUserMessageHandler) => void; - setAppendEntryHandler: (handler: AppendEntryHandler) => void; - setGetActiveToolsHandler: (handler: GetActiveToolsHandler) => void; - setGetAllToolsHandler: (handler: GetAllToolsHandler) => void; - setSetActiveToolsHandler: (handler: SetActiveToolsHandler) => void; - setSetModelHandler: (handler: SetModelHandler) => void; - setGetThinkingLevelHandler: (handler: GetThinkingLevelHandler) => void; - setSetThinkingLevelHandler: (handler: SetThinkingLevelHandler) => void; - setFlagValue: (name: string, value: boolean | string) => void; } /** Result of loading extensions. */ export interface LoadExtensionsResult { - extensions: LoadedExtension[]; + extensions: Extension[]; errors: Array<{ path: string; error: string }>; - setUIContext(uiContext: ExtensionUIContext, hasUI: boolean): void; + /** Shared runtime - actions are throwing stubs until runner.initialize() */ + runtime: ExtensionRuntime; } // ============================================================================ diff --git a/packages/coding-agent/src/core/index.ts b/packages/coding-agent/src/core/index.ts index e0b95f4c..6aa4276c 100644 --- a/packages/coding-agent/src/core/index.ts +++ b/packages/coding-agent/src/core/index.ts @@ -26,6 +26,7 @@ export { discoverAndLoadExtensions, type ExecOptions, type ExecResult, + type Extension, type ExtensionAPI, type ExtensionCommandContext, type ExtensionContext, @@ -38,7 +39,6 @@ export { type ExtensionShortcut, type ExtensionUIContext, type LoadExtensionsResult, - type LoadedExtension, type MessageRenderer, type RegisteredCommand, type SessionBeforeBranchEvent, diff --git a/packages/coding-agent/src/core/sdk.ts b/packages/coding-agent/src/core/sdk.ts index 97812b58..c99ca8a2 100644 --- a/packages/coding-agent/src/core/sdk.ts +++ b/packages/coding-agent/src/core/sdk.ts @@ -28,11 +28,11 @@ import { AgentSession } from "./agent-session.js"; import { AuthStorage } from "./auth-storage.js"; import { createEventBus, type EventBus } from "./event-bus.js"; import { + createExtensionRuntime, discoverAndLoadExtensions, type ExtensionFactory, ExtensionRunner, type LoadExtensionsResult, - type LoadedExtension, loadExtensionFromFactory, type ToolDefinition, wrapRegisteredTools, @@ -106,10 +106,10 @@ export interface CreateAgentSessionOptions { /** Additional extension paths to load (merged with discovery). */ additionalExtensionPaths?: string[]; /** - * Pre-loaded extensions (skips file discovery). + * Pre-loaded extensions result (skips file discovery). * @internal Used by CLI when extensions are loaded early to parse custom flags. */ - preloadedExtensions?: LoadedExtension[]; + preloadedExtensionsResult?: LoadExtensionsResult; /** Shared event bus for tool/extension communication. Default: creates new bus. */ eventBus?: EventBus; @@ -438,20 +438,17 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {} // Load extensions (discovers from standard locations + configured paths) let extensionsResult: LoadExtensionsResult; - if (options.preloadedExtensions !== undefined && options.preloadedExtensions.length > 0) { + if (options.preloadedExtensionsResult !== undefined) { // Use pre-loaded extensions (from early CLI flag discovery) - extensionsResult = { - extensions: options.preloadedExtensions, - errors: [], - setUIContext: () => {}, - }; + extensionsResult = options.preloadedExtensionsResult; } else if (options.extensions !== undefined) { // User explicitly provided extensions array (even if empty) - skip discovery - // Inline factories from options.extensions are loaded below + // Create runtime for inline extensions + const runtime = createExtensionRuntime(); extensionsResult = { extensions: [], errors: [], - setUIContext: () => {}, + runtime, }; } else { // Discover extensions, merging with additional paths @@ -465,45 +462,29 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {} // Load inline extensions from factories if (options.extensions && options.extensions.length > 0) { - // Create shared UI context holder that will be set later - const uiHolder: { ui: any; hasUI: boolean } = { - ui: { - select: async () => undefined, - confirm: async () => false, - input: async () => undefined, - notify: () => {}, - setStatus: () => {}, - setWidget: () => {}, - setFooter: () => {}, - setTitle: () => {}, - custom: async () => undefined as never, - setEditorText: () => {}, - getEditorText: () => "", - editor: async () => undefined, - get theme() { - return {} as any; - }, - }, - hasUI: false, - }; for (let i = 0; i < options.extensions.length; i++) { const factory = options.extensions[i]; - const loaded = await loadExtensionFromFactory(factory, cwd, eventBus, uiHolder, ``); + const loaded = await loadExtensionFromFactory( + factory, + cwd, + eventBus, + extensionsResult.runtime, + ``, + ); extensionsResult.extensions.push(loaded); } - // Extend setUIContext to update inline extensions too - const originalSetUIContext = extensionsResult.setUIContext; - extensionsResult.setUIContext = (uiContext, hasUI) => { - originalSetUIContext(uiContext, hasUI); - uiHolder.ui = uiContext; - uiHolder.hasUI = hasUI; - }; } // Create extension runner if we have extensions let extensionRunner: ExtensionRunner | undefined; if (extensionsResult.extensions.length > 0) { - extensionRunner = new ExtensionRunner(extensionsResult.extensions, cwd, sessionManager, modelRegistry); + extensionRunner = new ExtensionRunner( + extensionsResult.extensions, + extensionsResult.runtime, + cwd, + sessionManager, + modelRegistry, + ); } // Wrap extension-registered tools and SDK-provided custom tools with context getter @@ -536,7 +517,7 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {} return {} as any; }, }, - hasUI: extensionRunner?.getHasUI() ?? false, + hasUI: extensionRunner?.hasUI() ?? false, cwd, sessionManager, modelRegistry, diff --git a/packages/coding-agent/src/index.ts b/packages/coding-agent/src/index.ts index 7b81f39c..2aa9f924 100644 --- a/packages/coding-agent/src/index.ts +++ b/packages/coding-agent/src/index.ts @@ -45,20 +45,24 @@ export type { ContextEvent, ExecOptions, ExecResult, + Extension, + ExtensionActions, ExtensionAPI, ExtensionCommandContext, + ExtensionCommandContextActions, ExtensionContext, + ExtensionContextActions, ExtensionError, ExtensionEvent, ExtensionFactory, ExtensionFlag, ExtensionHandler, + ExtensionRuntime, ExtensionShortcut, ExtensionUIContext, ExtensionUIDialogOptions, KeybindingsManager, LoadExtensionsResult, - LoadedExtension, MessageRenderer, MessageRenderOptions, RegisteredCommand, @@ -81,6 +85,7 @@ export type { TurnStartEvent, } from "./core/extensions/index.js"; export { + createExtensionRuntime, discoverAndLoadExtensions, ExtensionRunner, isBashToolResult, diff --git a/packages/coding-agent/src/main.ts b/packages/coding-agent/src/main.ts index a05c46a8..276a9697 100644 --- a/packages/coding-agent/src/main.ts +++ b/packages/coding-agent/src/main.ts @@ -18,7 +18,7 @@ import type { AgentSession } from "./core/agent-session.js"; import { createEventBus } from "./core/event-bus.js"; import { exportFromFile } from "./core/export-html/index.js"; -import { discoverAndLoadExtensions, type ExtensionUIContext, type LoadedExtension } from "./core/extensions/index.js"; +import { discoverAndLoadExtensions, type LoadExtensionsResult } from "./core/extensions/index.js"; import type { ModelRegistry } from "./core/model-registry.js"; import { resolveModelScope, type ScopedModel } from "./core/model-resolver.js"; import { type CreateAgentSessionOptions, createAgentSession, discoverAuthStorage, discoverModels } from "./core/sdk.js"; @@ -60,13 +60,11 @@ async function runInteractiveMode( migratedProviders: string[], versionCheckPromise: Promise, initialMessages: string[], - extensions: LoadedExtension[], - setExtensionUIContext: (uiContext: ExtensionUIContext, hasUI: boolean) => void, initialMessage?: string, initialImages?: ImageContent[], fdPath: string | undefined = undefined, ): Promise { - const mode = new InteractiveMode(session, version, changelogMarkdown, extensions, setExtensionUIContext, fdPath); + const mode = new InteractiveMode(session, version, changelogMarkdown, fdPath); await mode.init(); @@ -236,7 +234,7 @@ function buildSessionOptions( sessionManager: SessionManager | undefined, modelRegistry: ModelRegistry, settingsManager: SettingsManager, - preloadedExtensions?: LoadedExtension[], + extensionsResult?: LoadExtensionsResult, ): CreateAgentSessionOptions { const options: CreateAgentSessionOptions = {}; @@ -302,8 +300,8 @@ function buildSessionOptions( } // Pre-loaded extensions (from early CLI flag discovery) - if (preloadedExtensions && preloadedExtensions.length > 0) { - options.preloadedExtensions = preloadedExtensions; + if (extensionsResult && extensionsResult.extensions.length > 0) { + options.preloadedExtensionsResult = extensionsResult; } return options; @@ -332,12 +330,12 @@ export async function main(args: string[]) { time("SettingsManager.create"); // Merge CLI --extension args with settings.json extensions const extensionPaths = [...settingsManager.getExtensionPaths(), ...(firstPass.extensions ?? [])]; - const { extensions: loadedExtensions } = await discoverAndLoadExtensions(extensionPaths, cwd, agentDir, eventBus); + const extensionsResult = await discoverAndLoadExtensions(extensionPaths, cwd, agentDir, eventBus); time("discoverExtensionFlags"); // Collect all extension flags const extensionFlags = new Map(); - for (const ext of loadedExtensions) { + for (const ext of extensionsResult.extensions) { for (const [name, flag] of ext.flags) { extensionFlags.set(name, { type: flag.type }); } @@ -347,13 +345,9 @@ export async function main(args: string[]) { const parsed = parseArgs(args, extensionFlags); time("parseArgs"); - // Pass flag values to extensions + // Pass flag values to extensions via runtime for (const [name, value] of parsed.unknownFlags) { - for (const ext of loadedExtensions) { - if (ext.flags.has(name)) { - ext.setFlagValue(name, value); - } - } + extensionsResult.runtime.flagValues.set(name, value); } if (parsed.version) { @@ -436,7 +430,7 @@ export async function main(args: string[]) { sessionManager, modelRegistry, settingsManager, - loadedExtensions, + extensionsResult, ); sessionOptions.authStorage = authStorage; sessionOptions.modelRegistry = modelRegistry; @@ -452,7 +446,7 @@ export async function main(args: string[]) { } time("buildSessionOptions"); - const { session, extensionsResult, modelFallbackMessage } = await createAgentSession(sessionOptions); + const { session, modelFallbackMessage } = await createAgentSession(sessionOptions); time("createAgentSession"); if (!isInteractive && !session.model) { @@ -505,8 +499,6 @@ export async function main(args: string[]) { migratedProviders, versionCheckPromise, parsed.messages, - extensionsResult.extensions, - extensionsResult.setUIContext, initialMessage, initialImages, fdPath, diff --git a/packages/coding-agent/src/modes/interactive/interactive-mode.ts b/packages/coding-agent/src/modes/interactive/interactive-mode.ts index 6ae1f24f..de32bdcb 100644 --- a/packages/coding-agent/src/modes/interactive/interactive-mode.ts +++ b/packages/coding-agent/src/modes/interactive/interactive-mode.ts @@ -33,13 +33,13 @@ import type { ExtensionRunner, ExtensionUIContext, ExtensionUIDialogOptions, - LoadedExtension, } from "../../core/extensions/index.js"; import { KeybindingsManager } from "../../core/keybindings.js"; import { createCompactionSummaryMessage } from "../../core/messages.js"; import { type SessionContext, SessionManager } from "../../core/session-manager.js"; import { loadSkills } from "../../core/skills.js"; import { loadProjectContextFiles } from "../../core/system-prompt.js"; +import { allTools } from "../../core/tools/index.js"; import type { TruncationResult } from "../../core/tools/truncate.js"; import { getChangelogPath, parseChangelog } from "../../utils/changelog.js"; import { copyToClipboard } from "../../utils/clipboard.js"; @@ -184,8 +184,6 @@ export class InteractiveMode { session: AgentSession, version: string, changelogMarkdown: string | undefined = undefined, - _extensions: LoadedExtension[] = [], - private setExtensionUIContext: (uiContext: ExtensionUIContext, hasUI: boolean) => void = () => {}, fdPath: string | undefined = undefined, ) { this.session = session; @@ -429,123 +427,124 @@ export class InteractiveMode { } } - // Create and set extension UI context - const uiContext = this.createExtensionUIContext(); - this.setExtensionUIContext(uiContext, true); - const extensionRunner = this.session.extensionRunner; if (!extensionRunner) { return; // No extensions loaded } - extensionRunner.initialize({ - getModel: () => this.session.model, - sendMessageHandler: (message, options) => { - const wasStreaming = this.session.isStreaming; - this.session - .sendCustomMessage(message, options) - .then(() => { - // For non-streaming cases with display=true, update UI - // (streaming cases update via message_end event) - if (!wasStreaming && message.display) { - this.rebuildChatFromMessages(); - } - }) - .catch((err) => { - this.showError(`Extension sendMessage failed: ${err instanceof Error ? err.message : String(err)}`); + // Create extension UI context + const uiContext = this.createExtensionUIContext(); + + extensionRunner.initialize( + // ExtensionActions - for pi.* API + { + sendMessage: (message, options) => { + const wasStreaming = this.session.isStreaming; + this.session + .sendCustomMessage(message, options) + .then(() => { + if (!wasStreaming && message.display) { + this.rebuildChatFromMessages(); + } + }) + .catch((err) => { + this.showError( + `Extension sendMessage failed: ${err instanceof Error ? err.message : String(err)}`, + ); + }); + }, + sendUserMessage: (content, options) => { + this.session.sendUserMessage(content, options).catch((err) => { + this.showError( + `Extension sendUserMessage failed: ${err instanceof Error ? err.message : String(err)}`, + ); }); + }, + appendEntry: (customType, data) => { + this.sessionManager.appendCustomEntry(customType, data); + }, + getActiveTools: () => this.session.getActiveToolNames(), + getAllTools: () => this.session.getAllToolNames(), + setActiveTools: (toolNames) => this.session.setActiveToolsByName(toolNames), + setModel: async (model) => { + const key = await this.session.modelRegistry.getApiKey(model); + if (!key) return false; + await this.session.setModel(model); + return true; + }, + getThinkingLevel: () => this.session.thinkingLevel, + setThinkingLevel: (level) => this.session.setThinkingLevel(level), }, - sendUserMessageHandler: (content, options) => { - this.session.sendUserMessage(content, options).catch((err) => { - this.showError(`Extension sendUserMessage failed: ${err instanceof Error ? err.message : String(err)}`); - }); + // ExtensionContextActions - for ctx.* in event handlers + { + getModel: () => this.session.model, + isIdle: () => !this.session.isStreaming, + abort: () => this.session.abort(), + hasPendingMessages: () => this.session.pendingMessageCount > 0, }, - appendEntryHandler: (customType, data) => { - this.sessionManager.appendCustomEntry(customType, data); + // ExtensionCommandContextActions - for ctx.* in command handlers + { + waitForIdle: () => this.session.agent.waitForIdle(), + newSession: async (options) => { + if (this.loadingAnimation) { + this.loadingAnimation.stop(); + this.loadingAnimation = undefined; + } + this.statusContainer.clear(); + + const success = await this.session.newSession({ parentSession: options?.parentSession }); + if (!success) { + return { cancelled: true }; + } + + if (options?.setup) { + await options.setup(this.sessionManager); + } + + this.chatContainer.clear(); + this.pendingMessagesContainer.clear(); + this.compactionQueuedMessages = []; + this.streamingComponent = undefined; + this.streamingMessage = undefined; + this.pendingTools.clear(); + + this.chatContainer.addChild(new Spacer(1)); + this.chatContainer.addChild(new Text(`${theme.fg("accent", "✓ New session started")}`, 1, 1)); + this.ui.requestRender(); + + return { cancelled: false }; + }, + branch: async (entryId) => { + const result = await this.session.branch(entryId); + if (result.cancelled) { + return { cancelled: true }; + } + + this.chatContainer.clear(); + this.renderInitialMessages(); + this.editor.setText(result.selectedText); + this.showStatus("Branched to new session"); + + return { cancelled: false }; + }, + navigateTree: async (targetId, options) => { + const result = await this.session.navigateTree(targetId, { summarize: options?.summarize }); + if (result.cancelled) { + return { cancelled: true }; + } + + this.chatContainer.clear(); + this.renderInitialMessages(); + if (result.editorText) { + this.editor.setText(result.editorText); + } + this.showStatus("Navigated to selected point"); + + return { cancelled: false }; + }, }, - getActiveToolsHandler: () => this.session.getActiveToolNames(), - getAllToolsHandler: () => this.session.getAllToolNames(), - setActiveToolsHandler: (toolNames) => this.session.setActiveToolsByName(toolNames), - newSessionHandler: async (options) => { - // Stop any loading animation - if (this.loadingAnimation) { - this.loadingAnimation.stop(); - this.loadingAnimation = undefined; - } - this.statusContainer.clear(); - - // Create new session - const success = await this.session.newSession({ parentSession: options?.parentSession }); - if (!success) { - return { cancelled: true }; - } - - // Call setup callback if provided - if (options?.setup) { - await options.setup(this.sessionManager); - } - - // Clear UI state - this.chatContainer.clear(); - this.pendingMessagesContainer.clear(); - this.compactionQueuedMessages = []; - this.streamingComponent = undefined; - this.streamingMessage = undefined; - this.pendingTools.clear(); - - this.chatContainer.addChild(new Spacer(1)); - this.chatContainer.addChild(new Text(`${theme.fg("accent", "✓ New session started")}`, 1, 1)); - this.ui.requestRender(); - - return { cancelled: false }; - }, - branchHandler: async (entryId) => { - const result = await this.session.branch(entryId); - if (result.cancelled) { - return { cancelled: true }; - } - - // Update UI - this.chatContainer.clear(); - this.renderInitialMessages(); - this.editor.setText(result.selectedText); - this.showStatus("Branched to new session"); - - return { cancelled: false }; - }, - navigateTreeHandler: async (targetId, options) => { - const result = await this.session.navigateTree(targetId, { summarize: options?.summarize }); - if (result.cancelled) { - return { cancelled: true }; - } - - // Update UI - this.chatContainer.clear(); - this.renderInitialMessages(); - if (result.editorText) { - this.editor.setText(result.editorText); - } - this.showStatus("Navigated to selected point"); - - return { cancelled: false }; - }, - setModelHandler: async (model) => { - const key = await this.session.modelRegistry.getApiKey(model); - if (!key) return false; - await this.session.setModel(model); - return true; - }, - getThinkingLevelHandler: () => this.session.thinkingLevel, - setThinkingLevelHandler: (level) => this.session.setThinkingLevel(level), - isIdle: () => !this.session.isStreaming, - waitForIdle: () => this.session.agent.waitForIdle(), - abort: () => { - this.session.abort(); - }, - hasPendingMessages: () => this.session.pendingMessageCount > 0, uiContext, - hasUI: true, - }); + ); // Subscribe to extension errors extensionRunner.onError((error) => { @@ -563,6 +562,24 @@ export class InteractiveMode { this.chatContainer.addChild(new Spacer(1)); } + // Warn about built-in tool overrides + const builtInToolNames = new Set(Object.keys(allTools)); + const registeredTools = extensionRunner.getAllRegisteredTools(); + for (const tool of registeredTools) { + if (builtInToolNames.has(tool.definition.name)) { + this.chatContainer.addChild( + new Text( + theme.fg( + "warning", + `Warning: Extension "${tool.extensionPath}" overrides built-in tool "${tool.definition.name}"`, + ), + 0, + 0, + ), + ); + } + } + // Emit session_start event await extensionRunner.emit({ type: "session_start", diff --git a/packages/coding-agent/src/modes/print-mode.ts b/packages/coding-agent/src/modes/print-mode.ts index afa26853..87d9db09 100644 --- a/packages/coding-agent/src/modes/print-mode.ts +++ b/packages/coding-agent/src/modes/print-mode.ts @@ -26,37 +26,65 @@ export async function runPrintMode( initialMessage?: string, initialImages?: ImageContent[], ): Promise { - // Extension runner already has no-op UI context by default (set in loader) - // Set up extensions for print mode (no UI) + // Set up extensions for print mode (no UI, no command context) const extensionRunner = session.extensionRunner; if (extensionRunner) { - extensionRunner.initialize({ - getModel: () => session.model, - sendMessageHandler: (message, options) => { - session.sendCustomMessage(message, options).catch((e) => { - console.error(`Extension sendMessage failed: ${e instanceof Error ? e.message : String(e)}`); - }); + extensionRunner.initialize( + // ExtensionActions + { + sendMessage: (message, options) => { + session.sendCustomMessage(message, options).catch((e) => { + console.error(`Extension sendMessage failed: ${e instanceof Error ? e.message : String(e)}`); + }); + }, + sendUserMessage: (content, options) => { + session.sendUserMessage(content, options).catch((e) => { + console.error(`Extension sendUserMessage failed: ${e instanceof Error ? e.message : String(e)}`); + }); + }, + appendEntry: (customType, data) => { + session.sessionManager.appendCustomEntry(customType, data); + }, + getActiveTools: () => session.getActiveToolNames(), + getAllTools: () => session.getAllToolNames(), + setActiveTools: (toolNames: string[]) => session.setActiveToolsByName(toolNames), + setModel: async (model) => { + const key = await session.modelRegistry.getApiKey(model); + if (!key) return false; + await session.setModel(model); + return true; + }, + getThinkingLevel: () => session.thinkingLevel, + setThinkingLevel: (level) => session.setThinkingLevel(level), }, - sendUserMessageHandler: (content, options) => { - session.sendUserMessage(content, options).catch((e) => { - console.error(`Extension sendUserMessage failed: ${e instanceof Error ? e.message : String(e)}`); - }); + // ExtensionContextActions + { + getModel: () => session.model, + isIdle: () => !session.isStreaming, + abort: () => session.abort(), + hasPendingMessages: () => session.pendingMessageCount > 0, }, - appendEntryHandler: (customType, data) => { - session.sessionManager.appendCustomEntry(customType, data); + // ExtensionCommandContextActions - commands invokable via prompt("/command") + { + waitForIdle: () => session.agent.waitForIdle(), + newSession: async (options) => { + const success = await session.newSession({ parentSession: options?.parentSession }); + if (success && options?.setup) { + await options.setup(session.sessionManager); + } + return { cancelled: !success }; + }, + branch: async (entryId) => { + const result = await session.branch(entryId); + return { cancelled: result.cancelled }; + }, + navigateTree: async (targetId, options) => { + const result = await session.navigateTree(targetId, { summarize: options?.summarize }); + return { cancelled: result.cancelled }; + }, }, - getActiveToolsHandler: () => session.getActiveToolNames(), - getAllToolsHandler: () => session.getAllToolNames(), - setActiveToolsHandler: (toolNames: string[]) => session.setActiveToolsByName(toolNames), - setModelHandler: async (model) => { - const key = await session.modelRegistry.getApiKey(model); - if (!key) return false; - await session.setModel(model); - return true; - }, - getThinkingLevelHandler: () => session.thinkingLevel, - setThinkingLevelHandler: (level) => session.setThinkingLevel(level), - }); + // No UI context + ); extensionRunner.onError((err) => { console.error(`Extension error (${err.extensionPath}): ${err.error}`); }); diff --git a/packages/coding-agent/src/modes/rpc/rpc-mode.ts b/packages/coding-agent/src/modes/rpc/rpc-mode.ts index 4906b695..4808cd51 100644 --- a/packages/coding-agent/src/modes/rpc/rpc-mode.ts +++ b/packages/coding-agent/src/modes/rpc/rpc-mode.ts @@ -231,35 +231,63 @@ export async function runRpcMode(session: AgentSession): Promise { // Set up extensions with RPC-based UI context const extensionRunner = session.extensionRunner; if (extensionRunner) { - extensionRunner.initialize({ - getModel: () => session.agent.state.model, - sendMessageHandler: (message, options) => { - session.sendCustomMessage(message, options).catch((e) => { - output(error(undefined, "extension_send", e.message)); - }); + extensionRunner.initialize( + // ExtensionActions + { + sendMessage: (message, options) => { + session.sendCustomMessage(message, options).catch((e) => { + output(error(undefined, "extension_send", e.message)); + }); + }, + sendUserMessage: (content, options) => { + session.sendUserMessage(content, options).catch((e) => { + output(error(undefined, "extension_send_user", e.message)); + }); + }, + appendEntry: (customType, data) => { + session.sessionManager.appendCustomEntry(customType, data); + }, + getActiveTools: () => session.getActiveToolNames(), + getAllTools: () => session.getAllToolNames(), + setActiveTools: (toolNames: string[]) => session.setActiveToolsByName(toolNames), + setModel: async (model) => { + const key = await session.modelRegistry.getApiKey(model); + if (!key) return false; + await session.setModel(model); + return true; + }, + getThinkingLevel: () => session.thinkingLevel, + setThinkingLevel: (level) => session.setThinkingLevel(level), }, - sendUserMessageHandler: (content, options) => { - session.sendUserMessage(content, options).catch((e) => { - output(error(undefined, "extension_send_user", e.message)); - }); + // ExtensionContextActions + { + getModel: () => session.agent.state.model, + isIdle: () => !session.isStreaming, + abort: () => session.abort(), + hasPendingMessages: () => session.pendingMessageCount > 0, }, - appendEntryHandler: (customType, data) => { - session.sessionManager.appendCustomEntry(customType, data); + // ExtensionCommandContextActions - commands invokable via prompt("/command") + { + waitForIdle: () => session.agent.waitForIdle(), + newSession: async (options) => { + const success = await session.newSession({ parentSession: options?.parentSession }); + // Note: setup callback runs but no UI feedback in RPC mode + if (success && options?.setup) { + await options.setup(session.sessionManager); + } + return { cancelled: !success }; + }, + branch: async (entryId) => { + const result = await session.branch(entryId); + return { cancelled: result.cancelled }; + }, + navigateTree: async (targetId, options) => { + const result = await session.navigateTree(targetId, { summarize: options?.summarize }); + return { cancelled: result.cancelled }; + }, }, - getActiveToolsHandler: () => session.getActiveToolNames(), - getAllToolsHandler: () => session.getAllToolNames(), - setActiveToolsHandler: (toolNames: string[]) => session.setActiveToolsByName(toolNames), - setModelHandler: async (model) => { - const key = await session.modelRegistry.getApiKey(model); - if (!key) return false; - await session.setModel(model); - return true; - }, - getThinkingLevelHandler: () => session.thinkingLevel, - setThinkingLevelHandler: (level) => session.setThinkingLevel(level), - uiContext: createExtensionUIContext(), - hasUI: false, - }); + createExtensionUIContext(), + ); extensionRunner.onError((err) => { output({ type: "extension_error", extensionPath: err.extensionPath, event: err.event, error: err.error }); }); diff --git a/packages/coding-agent/test/compaction-extensions.test.ts b/packages/coding-agent/test/compaction-extensions.test.ts index bf906256..344227a4 100644 --- a/packages/coding-agent/test/compaction-extensions.test.ts +++ b/packages/coding-agent/test/compaction-extensions.test.ts @@ -11,8 +11,9 @@ import { afterEach, beforeEach, describe, expect, it } from "vitest"; import { AgentSession } from "../src/core/agent-session.js"; import { AuthStorage } from "../src/core/auth-storage.js"; import { + createExtensionRuntime, + type Extension, ExtensionRunner, - type LoadedExtension, type SessionBeforeCompactEvent, type SessionCompactEvent, type SessionEvent, @@ -21,7 +22,6 @@ import { ModelRegistry } from "../src/core/model-registry.js"; import { SessionManager } from "../src/core/session-manager.js"; import { SettingsManager } from "../src/core/settings-manager.js"; import { codingTools } from "../src/core/tools/index.js"; -import { theme } from "../src/modes/interactive/theme/theme.js"; const API_KEY = process.env.ANTHROPIC_OAUTH_TOKEN || process.env.ANTHROPIC_API_KEY; @@ -49,7 +49,7 @@ describe.skipIf(!API_KEY)("Compaction extensions", () => { function createExtension( onBeforeCompact?: (event: SessionBeforeCompactEvent) => { cancel?: boolean; compaction?: any } | undefined, onCompact?: (event: SessionCompactEvent) => void, - ): LoadedExtension { + ): Extension { const handlers = new Map Promise)[]>(); handlers.set("session_before_compact", [ @@ -80,22 +80,11 @@ describe.skipIf(!API_KEY)("Compaction extensions", () => { messageRenderers: new Map(), commands: new Map(), flags: new Map(), - flagValues: new Map(), shortcuts: new Map(), - setSendMessageHandler: () => {}, - setSendUserMessageHandler: () => {}, - setAppendEntryHandler: () => {}, - setGetActiveToolsHandler: () => {}, - setGetAllToolsHandler: () => {}, - setSetActiveToolsHandler: () => {}, - setSetModelHandler: () => {}, - setGetThinkingLevelHandler: () => {}, - setSetThinkingLevelHandler: () => {}, - setFlagValue: () => {}, }; } - function createSession(extensions: LoadedExtension[]) { + function createSession(extensions: Extension[]) { const model = getModel("anthropic", "claude-sonnet-4-5")!; const agent = new Agent({ getApiKey: () => API_KEY, @@ -111,39 +100,29 @@ describe.skipIf(!API_KEY)("Compaction extensions", () => { const authStorage = new AuthStorage(join(tempDir, "auth.json")); const modelRegistry = new ModelRegistry(authStorage); - extensionRunner = new ExtensionRunner(extensions, tempDir, sessionManager, modelRegistry); - extensionRunner.initialize({ - getModel: () => session.model, - sendMessageHandler: async () => {}, - sendUserMessageHandler: async () => {}, - appendEntryHandler: async () => {}, - getActiveToolsHandler: () => [], - getAllToolsHandler: () => [], - setActiveToolsHandler: () => {}, - setModelHandler: async () => false, - getThinkingLevelHandler: () => "off", - setThinkingLevelHandler: () => {}, - uiContext: { - select: async () => undefined, - confirm: async () => false, - input: async () => undefined, - notify: () => {}, - setStatus: () => {}, - setWidget: () => {}, - setFooter: () => {}, - setHeader: () => {}, - setTitle: () => {}, - custom: async () => undefined as never, - setEditorText: () => {}, - getEditorText: () => "", - editor: async () => undefined, - setEditorComponent: () => {}, - get theme() { - return theme; - }, + const runtime = createExtensionRuntime(); + extensionRunner = new ExtensionRunner(extensions, runtime, tempDir, sessionManager, modelRegistry); + extensionRunner.initialize( + // ExtensionActions + { + sendMessage: async () => {}, + sendUserMessage: async () => {}, + appendEntry: async () => {}, + getActiveTools: () => [], + getAllTools: () => [], + setActiveTools: () => {}, + setModel: async () => false, + getThinkingLevel: () => "off", + setThinkingLevel: () => {}, }, - hasUI: false, - }); + // ExtensionContextActions + { + getModel: () => session.model, + isIdle: () => !session.isStreaming, + abort: () => session.abort(), + hasPendingMessages: () => session.pendingMessageCount > 0, + }, + ); session = new AgentSession({ agent, @@ -264,7 +243,7 @@ describe.skipIf(!API_KEY)("Compaction extensions", () => { }, 120000); it("should continue with default compaction if extension throws error", async () => { - const throwingExtension: LoadedExtension = { + const throwingExtension: Extension = { path: "throwing-extension", resolvedPath: "/test/throwing-extension.ts", handlers: new Map Promise)[]>([ @@ -291,18 +270,7 @@ describe.skipIf(!API_KEY)("Compaction extensions", () => { messageRenderers: new Map(), commands: new Map(), flags: new Map(), - flagValues: new Map(), shortcuts: new Map(), - setSendMessageHandler: () => {}, - setSendUserMessageHandler: () => {}, - setAppendEntryHandler: () => {}, - setGetActiveToolsHandler: () => {}, - setGetAllToolsHandler: () => {}, - setSetActiveToolsHandler: () => {}, - setSetModelHandler: () => {}, - setGetThinkingLevelHandler: () => {}, - setSetThinkingLevelHandler: () => {}, - setFlagValue: () => {}, }; createSession([throwingExtension]); @@ -323,7 +291,7 @@ describe.skipIf(!API_KEY)("Compaction extensions", () => { it("should call multiple extensions in order", async () => { const callOrder: string[] = []; - const extension1: LoadedExtension = { + const extension1: Extension = { path: "extension1", resolvedPath: "/test/extension1.ts", handlers: new Map Promise)[]>([ @@ -350,21 +318,10 @@ describe.skipIf(!API_KEY)("Compaction extensions", () => { messageRenderers: new Map(), commands: new Map(), flags: new Map(), - flagValues: new Map(), shortcuts: new Map(), - setSendMessageHandler: () => {}, - setSendUserMessageHandler: () => {}, - setAppendEntryHandler: () => {}, - setGetActiveToolsHandler: () => {}, - setGetAllToolsHandler: () => {}, - setSetActiveToolsHandler: () => {}, - setSetModelHandler: () => {}, - setGetThinkingLevelHandler: () => {}, - setSetThinkingLevelHandler: () => {}, - setFlagValue: () => {}, }; - const extension2: LoadedExtension = { + const extension2: Extension = { path: "extension2", resolvedPath: "/test/extension2.ts", handlers: new Map Promise)[]>([ @@ -391,18 +348,7 @@ describe.skipIf(!API_KEY)("Compaction extensions", () => { messageRenderers: new Map(), commands: new Map(), flags: new Map(), - flagValues: new Map(), shortcuts: new Map(), - setSendMessageHandler: () => {}, - setSendUserMessageHandler: () => {}, - setAppendEntryHandler: () => {}, - setGetActiveToolsHandler: () => {}, - setGetAllToolsHandler: () => {}, - setSetActiveToolsHandler: () => {}, - setSetModelHandler: () => {}, - setGetThinkingLevelHandler: () => {}, - setSetThinkingLevelHandler: () => {}, - setFlagValue: () => {}, }; createSession([extension1, extension2]); diff --git a/packages/coding-agent/test/extensions-runner.test.ts b/packages/coding-agent/test/extensions-runner.test.ts index 6fd4f37e..82574da0 100644 --- a/packages/coding-agent/test/extensions-runner.test.ts +++ b/packages/coding-agent/test/extensions-runner.test.ts @@ -46,7 +46,7 @@ describe("ExtensionRunner", () => { const warnSpy = vi.spyOn(console, "warn").mockImplementation(() => {}); const result = await discoverAndLoadExtensions([], tempDir, tempDir); - const runner = new ExtensionRunner(result.extensions, tempDir, sessionManager, modelRegistry); + const runner = new ExtensionRunner(result.extensions, result.runtime, tempDir, sessionManager, modelRegistry); const shortcuts = runner.getShortcuts(); expect(warnSpy).toHaveBeenCalledWith(expect.stringContaining("conflicts with built-in")); @@ -79,7 +79,7 @@ describe("ExtensionRunner", () => { const warnSpy = vi.spyOn(console, "warn").mockImplementation(() => {}); const result = await discoverAndLoadExtensions([], tempDir, tempDir); - const runner = new ExtensionRunner(result.extensions, tempDir, sessionManager, modelRegistry); + const runner = new ExtensionRunner(result.extensions, result.runtime, tempDir, sessionManager, modelRegistry); const shortcuts = runner.getShortcuts(); expect(warnSpy).toHaveBeenCalledWith(expect.stringContaining("shortcut conflict")); @@ -108,7 +108,7 @@ describe("ExtensionRunner", () => { fs.writeFileSync(path.join(extensionsDir, "tool-b.ts"), toolCode("tool_b")); const result = await discoverAndLoadExtensions([], tempDir, tempDir); - const runner = new ExtensionRunner(result.extensions, tempDir, sessionManager, modelRegistry); + const runner = new ExtensionRunner(result.extensions, result.runtime, tempDir, sessionManager, modelRegistry); const tools = runner.getAllRegisteredTools(); expect(tools.length).toBe(2); @@ -130,7 +130,7 @@ describe("ExtensionRunner", () => { fs.writeFileSync(path.join(extensionsDir, "cmd-b.ts"), cmdCode("cmd-b")); const result = await discoverAndLoadExtensions([], tempDir, tempDir); - const runner = new ExtensionRunner(result.extensions, tempDir, sessionManager, modelRegistry); + const runner = new ExtensionRunner(result.extensions, result.runtime, tempDir, sessionManager, modelRegistry); const commands = runner.getRegisteredCommands(); expect(commands.length).toBe(2); @@ -149,7 +149,7 @@ describe("ExtensionRunner", () => { fs.writeFileSync(path.join(extensionsDir, "cmd.ts"), cmdCode); const result = await discoverAndLoadExtensions([], tempDir, tempDir); - const runner = new ExtensionRunner(result.extensions, tempDir, sessionManager, modelRegistry); + const runner = new ExtensionRunner(result.extensions, result.runtime, tempDir, sessionManager, modelRegistry); const cmd = runner.getCommand("my-cmd"); expect(cmd).toBeDefined(); @@ -173,7 +173,7 @@ describe("ExtensionRunner", () => { fs.writeFileSync(path.join(extensionsDir, "throws.ts"), extCode); const result = await discoverAndLoadExtensions([], tempDir, tempDir); - const runner = new ExtensionRunner(result.extensions, tempDir, sessionManager, modelRegistry); + const runner = new ExtensionRunner(result.extensions, result.runtime, tempDir, sessionManager, modelRegistry); const errors: Array<{ extensionPath: string; event: string; error: string }> = []; runner.onError((err) => { @@ -199,7 +199,7 @@ describe("ExtensionRunner", () => { fs.writeFileSync(path.join(extensionsDir, "renderer.ts"), extCode); const result = await discoverAndLoadExtensions([], tempDir, tempDir); - const runner = new ExtensionRunner(result.extensions, tempDir, sessionManager, modelRegistry); + const runner = new ExtensionRunner(result.extensions, result.runtime, tempDir, sessionManager, modelRegistry); const renderer = runner.getMessageRenderer("my-type"); expect(renderer).toBeDefined(); @@ -222,7 +222,7 @@ describe("ExtensionRunner", () => { fs.writeFileSync(path.join(extensionsDir, "with-flag.ts"), extCode); const result = await discoverAndLoadExtensions([], tempDir, tempDir); - const runner = new ExtensionRunner(result.extensions, tempDir, sessionManager, modelRegistry); + const runner = new ExtensionRunner(result.extensions, result.runtime, tempDir, sessionManager, modelRegistry); const flags = runner.getFlags(); expect(flags.has("--my-flag")).toBe(true); @@ -240,14 +240,13 @@ describe("ExtensionRunner", () => { fs.writeFileSync(path.join(extensionsDir, "flag.ts"), extCode); const result = await discoverAndLoadExtensions([], tempDir, tempDir); - const runner = new ExtensionRunner(result.extensions, tempDir, sessionManager, modelRegistry); + const runner = new ExtensionRunner(result.extensions, result.runtime, tempDir, sessionManager, modelRegistry); // Setting a flag value should not throw runner.setFlagValue("--test-flag", true); - // The flag values are stored in the extension's flagValues map - const ext = result.extensions[0]; - expect(ext.flagValues.get("--test-flag")).toBe(true); + // The flag values are stored in the shared runtime + expect(result.runtime.flagValues.get("--test-flag")).toBe(true); }); }); @@ -261,7 +260,7 @@ describe("ExtensionRunner", () => { fs.writeFileSync(path.join(extensionsDir, "handler.ts"), extCode); const result = await discoverAndLoadExtensions([], tempDir, tempDir); - const runner = new ExtensionRunner(result.extensions, tempDir, sessionManager, modelRegistry); + const runner = new ExtensionRunner(result.extensions, result.runtime, tempDir, sessionManager, modelRegistry); expect(runner.hasHandlers("tool_call")).toBe(true); expect(runner.hasHandlers("agent_end")).toBe(false);