refactor(coding-agent): simplify extension runtime architecture

- Replace per-extension closures with shared ExtensionRuntime
- Split context actions: ExtensionContextActions (required) + ExtensionCommandContextActions (optional)
- Rename LoadedExtension to Extension, remove setter methods
- Change runner.initialize() from options object to positional params
- Derive hasUI from uiContext presence (no separate param)
- Add warning when extensions override built-in tools
- RPC and print modes now provide full command context actions

BREAKING CHANGE: Extension system types and initialization API changed.
See CHANGELOG.md for migration details.
This commit is contained in:
Mario Zechner 2026-01-07 23:50:18 +01:00
parent faa26ffbf9
commit cb3ac0ba9e
16 changed files with 580 additions and 736 deletions

View file

@ -5,11 +5,31 @@
### Breaking Changes ### Breaking Changes
- `ctx.ui.custom()` factory signature changed from `(tui, theme, done)` to `(tui, theme, keybindings, done)` for consistency with other input-handling factories - `ctx.ui.custom()` factory signature changed from `(tui, theme, done)` to `(tui, theme, keybindings, done)` for consistency with other input-handling factories
- Extension system refactored: `LoadedExtension` renamed to `Extension`, setter methods removed
- `LoadExtensionsResult.setUIContext()` removed, replaced with `runtime: ExtensionRuntime`
- `ExtensionRunner` constructor now requires `runtime: ExtensionRuntime` as second parameter
- `ExtensionRunner.initialize()` signature changed from options object to `(actions, contextActions, commandContextActions?, uiContext?)`
- `ExtensionRunner.getHasUI()` renamed to `hasUI()`
- `CreateAgentSessionOptions.preloadedExtensions` renamed to `preloadedExtensionsResult`
- `CreateAgentSessionResult` now returns `extensionsResult: LoadExtensionsResult` instead of `customToolsResult`
### Added ### Added
- Extension UI dialogs (`ctx.ui.select()`, `ctx.ui.confirm()`, `ctx.ui.input()`) now support a `timeout` option that auto-dismisses the dialog with a live countdown display. Simpler alternative to `AbortSignal` for timed dialogs. - Extension UI dialogs (`ctx.ui.select()`, `ctx.ui.confirm()`, `ctx.ui.input()`) now support a `timeout` option that auto-dismisses the dialog with a live countdown display. Simpler alternative to `AbortSignal` for timed dialogs.
- Extensions can now provide custom editor components via `ctx.ui.setEditorComponent((tui, theme, keybindings) => ...)`. Extend `CustomEditor` for full app keybinding support (escape, ctrl+d, model switching, etc.). See `examples/extensions/modal-editor.ts`, `examples/extensions/rainbow-editor.ts`, and `docs/tui.md` Pattern 7. - Extensions can now provide custom editor components via `ctx.ui.setEditorComponent((tui, theme, keybindings) => ...)`. Extend `CustomEditor` for full app keybinding support (escape, ctrl+d, model switching, etc.). See `examples/extensions/modal-editor.ts`, `examples/extensions/rainbow-editor.ts`, and `docs/tui.md` Pattern 7.
- `ExtensionRuntime` interface for shared runtime state and action methods
- `ExtensionActions` interface for `pi.*` API methods
- `ExtensionContextActions` interface for `ctx.*` in event handlers
- `ExtensionCommandContextActions` interface for `ctx.*` in command handlers (session control)
- `createExtensionRuntime()` function to create runtime with throwing stubs
- `Extension` type exported (cleaner name for loaded extension data)
- Interactive mode now warns when extensions override built-in tools (read, bash, edit, write, grep, find, ls)
### Changed
- Extension loader simplified: shared runtime instead of per-extension closures
- `hasUI` now derived from whether `uiContext` is provided (no longer a separate parameter)
- RPC and print modes now provide `ExtensionCommandContextActions` for full command support
### Fixed ### Fixed

View file

@ -924,6 +924,21 @@ pi.registerTool({
**Important:** Use `StringEnum` from `@mariozechner/pi-ai` for string enums. `Type.Union`/`Type.Literal` doesn't work with Google's API. **Important:** Use `StringEnum` from `@mariozechner/pi-ai` for string enums. `Type.Union`/`Type.Literal` doesn't work with Google's API.
### Overriding Built-in Tools
Extensions can override built-in tools (`read`, `bash`, `edit`, `write`, `grep`, `find`, `ls`) by registering a tool with the same name. Interactive mode displays a warning when this happens.
**Your implementation must match the exact result shape**, including the `details` type. The UI and session logic depend on these shapes for rendering and state tracking.
Built-in tool implementations:
- [read.ts](https://github.com/badlogic/pi-mono/blob/main/packages/coding-agent/src/core/tools/read.ts) - `ReadToolDetails`
- [bash.ts](https://github.com/badlogic/pi-mono/blob/main/packages/coding-agent/src/core/tools/bash.ts) - `BashToolDetails`
- [edit.ts](https://github.com/badlogic/pi-mono/blob/main/packages/coding-agent/src/core/tools/edit.ts)
- [write.ts](https://github.com/badlogic/pi-mono/blob/main/packages/coding-agent/src/core/tools/write.ts)
- [grep.ts](https://github.com/badlogic/pi-mono/blob/main/packages/coding-agent/src/core/tools/grep.ts) - `GrepToolDetails`
- [find.ts](https://github.com/badlogic/pi-mono/blob/main/packages/coding-agent/src/core/tools/find.ts) - `FindToolDetails`
- [ls.ts](https://github.com/badlogic/pi-mono/blob/main/packages/coding-agent/src/core/tools/ls.ts) - `LsToolDetails`
### Output Truncation ### Output Truncation
**Tools MUST truncate their output** to avoid overwhelming the LLM context. Large outputs can cause: **Tools MUST truncate their output** to avoid overwhelming the LLM context. Large outputs can cause:

View file

@ -784,15 +784,18 @@ interface CreateAgentSessionResult {
// The session // The session
session: AgentSession; session: AgentSession;
// Custom tools (for UI setup) // Extensions result (for runner setup)
customToolsResult: { extensionsResult: LoadExtensionsResult;
tools: LoadedCustomTool[];
setUIContext: (ctx, hasUI) => void;
};
// Warning if session model couldn't be restored // Warning if session model couldn't be restored
modelFallbackMessage?: string; modelFallbackMessage?: string;
} }
interface LoadExtensionsResult {
extensions: Extension[];
errors: Array<{ path: string; error: string }>;
runtime: ExtensionRuntime;
}
``` ```
## Complete Example ## Complete Example

View file

@ -2,7 +2,12 @@
* Extension system for lifecycle events and custom tools. * Extension system for lifecycle events and custom tools.
*/ */
export { discoverAndLoadExtensions, loadExtensionFromFactory, loadExtensions } from "./loader.js"; export {
createExtensionRuntime,
discoverAndLoadExtensions,
loadExtensionFromFactory,
loadExtensions,
} from "./loader.js";
export type { BranchHandler, ExtensionErrorListener, NavigateTreeHandler, NewSessionHandler } from "./runner.js"; export type { BranchHandler, ExtensionErrorListener, NavigateTreeHandler, NewSessionHandler } from "./runner.js";
export { ExtensionRunner } from "./runner.js"; export { ExtensionRunner } from "./runner.js";
export type { export type {
@ -25,17 +30,23 @@ export type {
EditToolResultEvent, EditToolResultEvent,
ExecOptions, ExecOptions,
ExecResult, ExecResult,
Extension,
ExtensionActions,
// API // API
ExtensionAPI, ExtensionAPI,
ExtensionCommandContext, ExtensionCommandContext,
ExtensionCommandContextActions,
// Context // Context
ExtensionContext, ExtensionContext,
ExtensionContextActions,
// Errors // Errors
ExtensionError, ExtensionError,
ExtensionEvent, ExtensionEvent,
ExtensionFactory, ExtensionFactory,
ExtensionFlag, ExtensionFlag,
ExtensionHandler, ExtensionHandler,
// Runtime
ExtensionRuntime,
ExtensionShortcut, ExtensionShortcut,
ExtensionUIContext, ExtensionUIContext,
ExtensionUIDialogOptions, ExtensionUIDialogOptions,
@ -46,8 +57,6 @@ export type {
GrepToolResultEvent, GrepToolResultEvent,
KeybindingsManager, KeybindingsManager,
LoadExtensionsResult, LoadExtensionsResult,
// Loaded Extension
LoadedExtension,
LsToolResultEvent, LsToolResultEvent,
// Message Rendering // Message Rendering
MessageRenderer, MessageRenderer,

View file

@ -10,30 +10,17 @@ import { fileURLToPath } from "node:url";
import type { KeyId } from "@mariozechner/pi-tui"; import type { KeyId } from "@mariozechner/pi-tui";
import { createJiti } from "jiti"; import { createJiti } from "jiti";
import { getAgentDir, isBunBinary } from "../../config.js"; import { getAgentDir, isBunBinary } from "../../config.js";
import { theme } from "../../modes/interactive/theme/theme.js";
import { createEventBus, type EventBus } from "../event-bus.js"; import { createEventBus, type EventBus } from "../event-bus.js";
import type { ExecOptions } from "../exec.js"; import type { ExecOptions } from "../exec.js";
import { execCommand } from "../exec.js"; import { execCommand } from "../exec.js";
import type { import type {
AppendEntryHandler, Extension,
ExtensionAPI, ExtensionAPI,
ExtensionFactory, ExtensionFactory,
ExtensionFlag, ExtensionRuntime,
ExtensionShortcut,
ExtensionUIContext,
GetActiveToolsHandler,
GetAllToolsHandler,
GetThinkingLevelHandler,
LoadExtensionsResult, LoadExtensionsResult,
LoadedExtension,
MessageRenderer, MessageRenderer,
RegisteredCommand, RegisteredCommand,
RegisteredTool,
SendMessageHandler,
SendUserMessageHandler,
SetActiveToolsHandler,
SetModelHandler,
SetThinkingLevelHandler,
ToolDefinition, ToolDefinition,
} from "./types.js"; } from "./types.js";
@ -84,87 +71,59 @@ function resolvePath(extPath: string, cwd: string): string {
return path.resolve(cwd, expanded); return path.resolve(cwd, expanded);
} }
function createNoOpUIContext(): ExtensionUIContext { type HandlerFn = (...args: unknown[]) => Promise<unknown>;
/**
* Create a runtime with throwing stubs for action methods.
* Runner.initialize() replaces these with real implementations.
*/
export function createExtensionRuntime(): ExtensionRuntime {
const notInitialized = () => {
throw new Error("Extension runtime not initialized. Action methods cannot be called during extension loading.");
};
return { return {
select: async () => undefined, sendMessage: notInitialized,
confirm: async () => false, sendUserMessage: notInitialized,
input: async () => undefined, appendEntry: notInitialized,
notify: () => {}, getActiveTools: notInitialized,
setStatus: () => {}, getAllTools: notInitialized,
setWidget: () => {}, setActiveTools: notInitialized,
setFooter: () => {}, setModel: () => Promise.reject(new Error("Extension runtime not initialized")),
setHeader: () => {}, getThinkingLevel: notInitialized,
setTitle: () => {}, setThinkingLevel: notInitialized,
custom: async () => undefined as never, flagValues: new Map(),
setEditorText: () => {},
getEditorText: () => "",
editor: async () => undefined,
setEditorComponent: () => {},
get theme() {
return theme;
},
}; };
} }
type HandlerFn = (...args: unknown[]) => Promise<unknown>; /**
* Create the ExtensionAPI for an extension.
* Registration methods write to the extension object.
* Action methods delegate to the shared runtime.
*/
function createExtensionAPI( function createExtensionAPI(
handlers: Map<string, HandlerFn[]>, extension: Extension,
tools: Map<string, RegisteredTool>, runtime: ExtensionRuntime,
cwd: string, cwd: string,
extensionPath: string,
eventBus: EventBus, eventBus: EventBus,
_sharedUI: { ui: ExtensionUIContext; hasUI: boolean }, ): ExtensionAPI {
): {
api: ExtensionAPI;
messageRenderers: Map<string, MessageRenderer>;
commands: Map<string, RegisteredCommand>;
flags: Map<string, ExtensionFlag>;
flagValues: Map<string, boolean | string>;
shortcuts: Map<KeyId, ExtensionShortcut>;
setSendMessageHandler: (handler: SendMessageHandler) => void;
setSendUserMessageHandler: (handler: SendUserMessageHandler) => void;
setAppendEntryHandler: (handler: AppendEntryHandler) => void;
setGetActiveToolsHandler: (handler: GetActiveToolsHandler) => void;
setGetAllToolsHandler: (handler: GetAllToolsHandler) => void;
setSetActiveToolsHandler: (handler: SetActiveToolsHandler) => void;
setSetModelHandler: (handler: SetModelHandler) => void;
setGetThinkingLevelHandler: (handler: GetThinkingLevelHandler) => void;
setSetThinkingLevelHandler: (handler: SetThinkingLevelHandler) => void;
setFlagValue: (name: string, value: boolean | string) => void;
} {
let sendMessageHandler: SendMessageHandler = () => {};
let sendUserMessageHandler: SendUserMessageHandler = () => {};
let appendEntryHandler: AppendEntryHandler = () => {};
let getActiveToolsHandler: GetActiveToolsHandler = () => [];
let getAllToolsHandler: GetAllToolsHandler = () => [];
let setActiveToolsHandler: SetActiveToolsHandler = () => {};
let setModelHandler: SetModelHandler = async () => false;
let getThinkingLevelHandler: GetThinkingLevelHandler = () => "off";
let setThinkingLevelHandler: SetThinkingLevelHandler = () => {};
const messageRenderers = new Map<string, MessageRenderer>();
const commands = new Map<string, RegisteredCommand>();
const flags = new Map<string, ExtensionFlag>();
const flagValues = new Map<string, boolean | string>();
const shortcuts = new Map<KeyId, ExtensionShortcut>();
const api = { const api = {
// Registration methods - write to extension
on(event: string, handler: HandlerFn): void { on(event: string, handler: HandlerFn): void {
const list = handlers.get(event) ?? []; const list = extension.handlers.get(event) ?? [];
list.push(handler); list.push(handler);
handlers.set(event, list); extension.handlers.set(event, list);
}, },
registerTool(tool: ToolDefinition): void { registerTool(tool: ToolDefinition): void {
tools.set(tool.name, { extension.tools.set(tool.name, {
definition: tool, definition: tool,
extensionPath, extensionPath: extension.path,
}); });
}, },
registerCommand(name: string, options: { description?: string; handler: RegisteredCommand["handler"] }): void { registerCommand(name: string, options: { description?: string; handler: RegisteredCommand["handler"] }): void {
commands.set(name, { name, ...options }); extension.commands.set(name, { name, ...options });
}, },
registerShortcut( registerShortcut(
@ -174,37 +133,40 @@ function createExtensionAPI(
handler: (ctx: import("./types.js").ExtensionContext) => Promise<void> | void; handler: (ctx: import("./types.js").ExtensionContext) => Promise<void> | void;
}, },
): void { ): void {
shortcuts.set(shortcut, { shortcut, extensionPath, ...options }); extension.shortcuts.set(shortcut, { shortcut, extensionPath: extension.path, ...options });
}, },
registerFlag( registerFlag(
name: string, name: string,
options: { description?: string; type: "boolean" | "string"; default?: boolean | string }, options: { description?: string; type: "boolean" | "string"; default?: boolean | string },
): void { ): void {
flags.set(name, { name, extensionPath, ...options }); extension.flags.set(name, { name, extensionPath: extension.path, ...options });
if (options.default !== undefined) { if (options.default !== undefined) {
flagValues.set(name, options.default); runtime.flagValues.set(name, options.default);
} }
}, },
getFlag(name: string): boolean | string | undefined {
return flagValues.get(name);
},
registerMessageRenderer<T>(customType: string, renderer: MessageRenderer<T>): void { registerMessageRenderer<T>(customType: string, renderer: MessageRenderer<T>): void {
messageRenderers.set(customType, renderer as MessageRenderer); extension.messageRenderers.set(customType, renderer as MessageRenderer);
}, },
// Flag access - checks extension registered it, reads from runtime
getFlag(name: string): boolean | string | undefined {
if (!extension.flags.has(name)) return undefined;
return runtime.flagValues.get(name);
},
// Action methods - delegate to shared runtime
sendMessage(message, options): void { sendMessage(message, options): void {
sendMessageHandler(message, options); runtime.sendMessage(message, options);
}, },
sendUserMessage(content, options): void { sendUserMessage(content, options): void {
sendUserMessageHandler(content, options); runtime.sendUserMessage(content, options);
}, },
appendEntry(customType: string, data?: unknown): void { appendEntry(customType: string, data?: unknown): void {
appendEntryHandler(customType, data); runtime.appendEntry(customType, data);
}, },
exec(command: string, args: string[], options?: ExecOptions) { exec(command: string, args: string[], options?: ExecOptions) {
@ -212,222 +174,86 @@ function createExtensionAPI(
}, },
getActiveTools(): string[] { getActiveTools(): string[] {
return getActiveToolsHandler(); return runtime.getActiveTools();
}, },
getAllTools(): string[] { getAllTools(): string[] {
return getAllToolsHandler(); return runtime.getAllTools();
}, },
setActiveTools(toolNames: string[]): void { setActiveTools(toolNames: string[]): void {
setActiveToolsHandler(toolNames); runtime.setActiveTools(toolNames);
}, },
setModel(model) { setModel(model) {
return setModelHandler(model); return runtime.setModel(model);
}, },
getThinkingLevel() { getThinkingLevel() {
return getThinkingLevelHandler(); return runtime.getThinkingLevel();
}, },
setThinkingLevel(level) { setThinkingLevel(level) {
setThinkingLevelHandler(level); runtime.setThinkingLevel(level);
}, },
events: eventBus, events: eventBus,
} as ExtensionAPI; } as ExtensionAPI;
return { return api;
api,
messageRenderers,
commands,
flags,
flagValues,
shortcuts,
setSendMessageHandler: (handler: SendMessageHandler) => {
sendMessageHandler = handler;
},
setSendUserMessageHandler: (handler: SendUserMessageHandler) => {
sendUserMessageHandler = handler;
},
setAppendEntryHandler: (handler: AppendEntryHandler) => {
appendEntryHandler = handler;
},
setGetActiveToolsHandler: (handler: GetActiveToolsHandler) => {
getActiveToolsHandler = handler;
},
setGetAllToolsHandler: (handler: GetAllToolsHandler) => {
getAllToolsHandler = handler;
},
setSetActiveToolsHandler: (handler: SetActiveToolsHandler) => {
setActiveToolsHandler = handler;
},
setSetModelHandler: (handler: SetModelHandler) => {
setModelHandler = handler;
},
setGetThinkingLevelHandler: (handler: GetThinkingLevelHandler) => {
getThinkingLevelHandler = handler;
},
setSetThinkingLevelHandler: (handler: SetThinkingLevelHandler) => {
setThinkingLevelHandler = handler;
},
setFlagValue: (name: string, value: boolean | string) => {
flagValues.set(name, value);
},
};
} }
async function loadExtensionWithBun( async function loadBun(path: string) {
resolvedPath: string, const module = await import(path);
cwd: string, const factory = (module.default ?? module) as ExtensionFactory;
extensionPath: string, return typeof factory !== "function" ? undefined : factory;
eventBus: EventBus, }
sharedUI: { ui: ExtensionUIContext; hasUI: boolean },
): Promise<{ extension: LoadedExtension | null; error: string | null }> {
try {
const module = await import(resolvedPath);
const factory = (module.default ?? module) as ExtensionFactory;
if (typeof factory !== "function") { async function loadJiti(path: string) {
return { extension: null, error: "Extension must export a default function" }; const jiti = createJiti(import.meta.url, {
} alias: getAliases(),
});
const handlers = new Map<string, HandlerFn[]>(); const module = await jiti.import(path, { default: true });
const tools = new Map<string, RegisteredTool>(); const factory = module as ExtensionFactory;
const { return typeof factory !== "function" ? undefined : factory;
api, }
messageRenderers,
commands,
flags,
flagValues,
shortcuts,
setSendMessageHandler,
setSendUserMessageHandler,
setAppendEntryHandler,
setGetActiveToolsHandler,
setGetAllToolsHandler,
setSetActiveToolsHandler,
setSetModelHandler,
setGetThinkingLevelHandler,
setSetThinkingLevelHandler,
setFlagValue,
} = createExtensionAPI(handlers, tools, cwd, extensionPath, eventBus, sharedUI);
await factory(api); /**
* Create an Extension object with empty collections.
return { */
extension: { function createExtension(extensionPath: string, resolvedPath: string): Extension {
path: extensionPath, return {
resolvedPath, path: extensionPath,
handlers, resolvedPath,
tools, handlers: new Map(),
messageRenderers, tools: new Map(),
commands, messageRenderers: new Map(),
flags, commands: new Map(),
flagValues, flags: new Map(),
shortcuts, shortcuts: new Map(),
setSendMessageHandler, };
setSendUserMessageHandler,
setAppendEntryHandler,
setGetActiveToolsHandler,
setGetAllToolsHandler,
setSetActiveToolsHandler,
setSetModelHandler,
setGetThinkingLevelHandler,
setSetThinkingLevelHandler,
setFlagValue,
},
error: null,
};
} catch (err) {
const message = err instanceof Error ? err.message : String(err);
if (message.includes("Cannot find module") && message.includes("@mariozechner/")) {
return {
extension: null,
error:
`${message}\n` +
"Note: Extensions importing from @mariozechner/* packages are not supported in the standalone binary.\n" +
"Please install pi via npm: npm install -g @mariozechner/pi-coding-agent",
};
}
return { extension: null, error: `Failed to load extension: ${message}` };
}
} }
async function loadExtension( async function loadExtension(
extensionPath: string, extensionPath: string,
cwd: string, cwd: string,
eventBus: EventBus, eventBus: EventBus,
sharedUI: { ui: ExtensionUIContext; hasUI: boolean }, runtime: ExtensionRuntime,
): Promise<{ extension: LoadedExtension | null; error: string | null }> { ): Promise<{ extension: Extension | null; error: string | null }> {
const resolvedPath = resolvePath(extensionPath, cwd); const resolvedPath = resolvePath(extensionPath, cwd);
if (isBunBinary) {
return loadExtensionWithBun(resolvedPath, cwd, extensionPath, eventBus, sharedUI);
}
try { try {
const jiti = createJiti(import.meta.url, { const factory = isBunBinary ? await loadBun(resolvedPath) : await loadJiti(resolvedPath);
alias: getAliases(), if (!factory) {
}); return { extension: null, error: `Extension does not export a valid factory function: ${extensionPath}` };
const module = await jiti.import(resolvedPath, { default: true });
const factory = module as ExtensionFactory;
if (typeof factory !== "function") {
return { extension: null, error: "Extension must export a default function" };
} }
const handlers = new Map<string, HandlerFn[]>(); const extension = createExtension(extensionPath, resolvedPath);
const tools = new Map<string, RegisteredTool>(); const api = createExtensionAPI(extension, runtime, cwd, eventBus);
const {
api,
messageRenderers,
commands,
flags,
flagValues,
shortcuts,
setSendMessageHandler,
setSendUserMessageHandler,
setAppendEntryHandler,
setGetActiveToolsHandler,
setGetAllToolsHandler,
setSetActiveToolsHandler,
setSetModelHandler,
setGetThinkingLevelHandler,
setSetThinkingLevelHandler,
setFlagValue,
} = createExtensionAPI(handlers, tools, cwd, extensionPath, eventBus, sharedUI);
await factory(api); await factory(api);
return { return { extension, error: null };
extension: {
path: extensionPath,
resolvedPath,
handlers,
tools,
messageRenderers,
commands,
flags,
flagValues,
shortcuts,
setSendMessageHandler,
setSendUserMessageHandler,
setAppendEntryHandler,
setGetActiveToolsHandler,
setGetAllToolsHandler,
setSetActiveToolsHandler,
setSetModelHandler,
setGetThinkingLevelHandler,
setSetThinkingLevelHandler,
setFlagValue,
},
error: null,
};
} catch (err) { } catch (err) {
const message = err instanceof Error ? err.message : String(err); const message = err instanceof Error ? err.message : String(err);
return { extension: null, error: `Failed to load extension: ${message}` }; return { extension: null, error: `Failed to load extension: ${message}` };
@ -435,72 +261,32 @@ async function loadExtension(
} }
/** /**
* Create a LoadedExtension from an inline factory function. * Create an Extension from an inline factory function.
*/ */
export async function loadExtensionFromFactory( export async function loadExtensionFromFactory(
factory: ExtensionFactory, factory: ExtensionFactory,
cwd: string, cwd: string,
eventBus: EventBus, eventBus: EventBus,
sharedUI: { ui: ExtensionUIContext; hasUI: boolean }, runtime: ExtensionRuntime,
name = "<inline>", extensionPath = "<inline>",
): Promise<LoadedExtension> { ): Promise<Extension> {
const handlers = new Map<string, HandlerFn[]>(); const extension = createExtension(extensionPath, extensionPath);
const tools = new Map<string, RegisteredTool>(); const api = createExtensionAPI(extension, runtime, cwd, eventBus);
const {
api,
messageRenderers,
commands,
flags,
flagValues,
shortcuts,
setSendMessageHandler,
setSendUserMessageHandler,
setAppendEntryHandler,
setGetActiveToolsHandler,
setGetAllToolsHandler,
setSetActiveToolsHandler,
setSetModelHandler,
setGetThinkingLevelHandler,
setSetThinkingLevelHandler,
setFlagValue,
} = createExtensionAPI(handlers, tools, cwd, name, eventBus, sharedUI);
await factory(api); await factory(api);
return extension;
return {
path: name,
resolvedPath: name,
handlers,
tools,
messageRenderers,
commands,
flags,
flagValues,
shortcuts,
setSendMessageHandler,
setSendUserMessageHandler,
setAppendEntryHandler,
setGetActiveToolsHandler,
setGetAllToolsHandler,
setSetActiveToolsHandler,
setSetModelHandler,
setGetThinkingLevelHandler,
setSetThinkingLevelHandler,
setFlagValue,
};
} }
/** /**
* Load extensions from paths. * Load extensions from paths.
*/ */
export async function loadExtensions(paths: string[], cwd: string, eventBus?: EventBus): Promise<LoadExtensionsResult> { export async function loadExtensions(paths: string[], cwd: string, eventBus?: EventBus): Promise<LoadExtensionsResult> {
const extensions: LoadedExtension[] = []; const extensions: Extension[] = [];
const errors: Array<{ path: string; error: string }> = []; const errors: Array<{ path: string; error: string }> = [];
const resolvedEventBus = eventBus ?? createEventBus(); const resolvedEventBus = eventBus ?? createEventBus();
const sharedUI = { ui: createNoOpUIContext(), hasUI: false }; const runtime = createExtensionRuntime();
for (const extPath of paths) { for (const extPath of paths) {
const { extension, error } = await loadExtension(extPath, cwd, resolvedEventBus, sharedUI); const { extension, error } = await loadExtension(extPath, cwd, resolvedEventBus, runtime);
if (error) { if (error) {
errors.push({ path: extPath, error }); errors.push({ path: extPath, error });
@ -515,10 +301,7 @@ export async function loadExtensions(paths: string[], cwd: string, eventBus?: Ev
return { return {
extensions, extensions,
errors, errors,
setUIContext(uiContext, hasUI) { runtime,
sharedUI.ui = uiContext;
sharedUI.hasUI = hasUI;
},
}; };
} }

View file

@ -9,32 +9,27 @@ import { theme } from "../../modes/interactive/theme/theme.js";
import type { ModelRegistry } from "../model-registry.js"; import type { ModelRegistry } from "../model-registry.js";
import type { SessionManager } from "../session-manager.js"; import type { SessionManager } from "../session-manager.js";
import type { import type {
AppendEntryHandler,
BeforeAgentStartEvent, BeforeAgentStartEvent,
BeforeAgentStartEventResult, BeforeAgentStartEventResult,
ContextEvent, ContextEvent,
ContextEventResult, ContextEventResult,
Extension,
ExtensionActions,
ExtensionCommandContext, ExtensionCommandContext,
ExtensionCommandContextActions,
ExtensionContext, ExtensionContext,
ExtensionContextActions,
ExtensionError, ExtensionError,
ExtensionEvent, ExtensionEvent,
ExtensionFlag, ExtensionFlag,
ExtensionRuntime,
ExtensionShortcut, ExtensionShortcut,
ExtensionUIContext, ExtensionUIContext,
GetActiveToolsHandler,
GetAllToolsHandler,
GetThinkingLevelHandler,
LoadedExtension,
MessageRenderer, MessageRenderer,
RegisteredCommand, RegisteredCommand,
RegisteredTool, RegisteredTool,
SendMessageHandler,
SendUserMessageHandler,
SessionBeforeCompactResult, SessionBeforeCompactResult,
SessionBeforeTreeResult, SessionBeforeTreeResult,
SetActiveToolsHandler,
SetModelHandler,
SetThinkingLevelHandler,
ToolCallEvent, ToolCallEvent,
ToolCallEventResult, ToolCallEventResult,
ToolResultEventResult, ToolResultEventResult,
@ -81,9 +76,9 @@ const noOpUIContext: ExtensionUIContext = {
}; };
export class ExtensionRunner { export class ExtensionRunner {
private extensions: LoadedExtension[]; private extensions: Extension[];
private runtime: ExtensionRuntime;
private uiContext: ExtensionUIContext; private uiContext: ExtensionUIContext;
private hasUI: boolean;
private cwd: string; private cwd: string;
private sessionManager: SessionManager; private sessionManager: SessionManager;
private modelRegistry: ModelRegistry; private modelRegistry: ModelRegistry;
@ -98,78 +93,60 @@ export class ExtensionRunner {
private navigateTreeHandler: NavigateTreeHandler = async () => ({ cancelled: false }); private navigateTreeHandler: NavigateTreeHandler = async () => ({ cancelled: false });
constructor( constructor(
extensions: LoadedExtension[], extensions: Extension[],
runtime: ExtensionRuntime,
cwd: string, cwd: string,
sessionManager: SessionManager, sessionManager: SessionManager,
modelRegistry: ModelRegistry, modelRegistry: ModelRegistry,
) { ) {
this.extensions = extensions; this.extensions = extensions;
this.runtime = runtime;
this.uiContext = noOpUIContext; this.uiContext = noOpUIContext;
this.hasUI = false;
this.cwd = cwd; this.cwd = cwd;
this.sessionManager = sessionManager; this.sessionManager = sessionManager;
this.modelRegistry = modelRegistry; this.modelRegistry = modelRegistry;
} }
initialize(options: { initialize(
getModel: () => Model<any> | undefined; actions: ExtensionActions,
sendMessageHandler: SendMessageHandler; contextActions: ExtensionContextActions,
sendUserMessageHandler: SendUserMessageHandler; commandContextActions?: ExtensionCommandContextActions,
appendEntryHandler: AppendEntryHandler; uiContext?: ExtensionUIContext,
getActiveToolsHandler: GetActiveToolsHandler; ): void {
getAllToolsHandler: GetAllToolsHandler; // Copy actions into the shared runtime (all extension APIs reference this)
setActiveToolsHandler: SetActiveToolsHandler; this.runtime.sendMessage = actions.sendMessage;
setModelHandler: SetModelHandler; this.runtime.sendUserMessage = actions.sendUserMessage;
getThinkingLevelHandler: GetThinkingLevelHandler; this.runtime.appendEntry = actions.appendEntry;
setThinkingLevelHandler: SetThinkingLevelHandler; this.runtime.getActiveTools = actions.getActiveTools;
newSessionHandler?: NewSessionHandler; this.runtime.getAllTools = actions.getAllTools;
branchHandler?: BranchHandler; this.runtime.setActiveTools = actions.setActiveTools;
navigateTreeHandler?: NavigateTreeHandler; this.runtime.setModel = actions.setModel;
isIdle?: () => boolean; this.runtime.getThinkingLevel = actions.getThinkingLevel;
waitForIdle?: () => Promise<void>; this.runtime.setThinkingLevel = actions.setThinkingLevel;
abort?: () => void;
hasPendingMessages?: () => boolean;
uiContext?: ExtensionUIContext;
hasUI?: boolean;
}): void {
this.getModel = options.getModel;
this.isIdleFn = options.isIdle ?? (() => true);
this.waitForIdleFn = options.waitForIdle ?? (async () => {});
this.abortFn = options.abort ?? (() => {});
this.hasPendingMessagesFn = options.hasPendingMessages ?? (() => false);
if (options.newSessionHandler) { // Context actions (required)
this.newSessionHandler = options.newSessionHandler; this.getModel = contextActions.getModel;
} this.isIdleFn = contextActions.isIdle;
if (options.branchHandler) { this.abortFn = contextActions.abort;
this.branchHandler = options.branchHandler; this.hasPendingMessagesFn = contextActions.hasPendingMessages;
}
if (options.navigateTreeHandler) { // Command context actions (optional, only for interactive mode)
this.navigateTreeHandler = options.navigateTreeHandler; if (commandContextActions) {
this.waitForIdleFn = commandContextActions.waitForIdle;
this.newSessionHandler = commandContextActions.newSession;
this.branchHandler = commandContextActions.branch;
this.navigateTreeHandler = commandContextActions.navigateTree;
} }
for (const ext of this.extensions) { this.uiContext = uiContext ?? noOpUIContext;
ext.setSendMessageHandler(options.sendMessageHandler);
ext.setSendUserMessageHandler(options.sendUserMessageHandler);
ext.setAppendEntryHandler(options.appendEntryHandler);
ext.setGetActiveToolsHandler(options.getActiveToolsHandler);
ext.setGetAllToolsHandler(options.getAllToolsHandler);
ext.setSetActiveToolsHandler(options.setActiveToolsHandler);
ext.setSetModelHandler(options.setModelHandler);
ext.setGetThinkingLevelHandler(options.getThinkingLevelHandler);
ext.setSetThinkingLevelHandler(options.setThinkingLevelHandler);
}
this.uiContext = options.uiContext ?? noOpUIContext;
this.hasUI = options.hasUI ?? false;
} }
getUIContext(): ExtensionUIContext | null { getUIContext(): ExtensionUIContext {
return this.uiContext; return this.uiContext;
} }
getHasUI(): boolean { hasUI(): boolean {
return this.hasUI; return this.uiContext !== noOpUIContext;
} }
getExtensionPaths(): string[] { getExtensionPaths(): string[] {
@ -198,11 +175,7 @@ export class ExtensionRunner {
} }
setFlagValue(name: string, value: boolean | string): void { setFlagValue(name: string, value: boolean | string): void {
for (const ext of this.extensions) { this.runtime.flagValues.set(name, value);
if (ext.flags.has(name)) {
ext.setFlagValue(name, value);
}
}
} }
private static readonly RESERVED_SHORTCUTS = new Set([ private static readonly RESERVED_SHORTCUTS = new Set([
@ -301,7 +274,7 @@ export class ExtensionRunner {
private createContext(): ExtensionContext { private createContext(): ExtensionContext {
return { return {
ui: this.uiContext, ui: this.uiContext,
hasUI: this.hasUI, hasUI: this.hasUI(),
cwd: this.cwd, cwd: this.cwd,
sessionManager: this.sessionManager, sessionManager: this.sessionManager,
modelRegistry: this.modelRegistry, modelRegistry: this.modelRegistry,

View file

@ -21,7 +21,7 @@ import type { Theme } from "../../modes/interactive/theme/theme.js";
import type { CompactionPreparation, CompactionResult } from "../compaction/index.js"; import type { CompactionPreparation, CompactionResult } from "../compaction/index.js";
import type { EventBus } from "../event-bus.js"; import type { EventBus } from "../event-bus.js";
import type { ExecOptions, ExecResult } from "../exec.js"; import type { ExecOptions, ExecResult } from "../exec.js";
import type { AppAction, KeybindingsManager } from "../keybindings.js"; import type { KeybindingsManager } from "../keybindings.js";
import type { CustomMessage } from "../messages.js"; import type { CustomMessage } from "../messages.js";
import type { ModelRegistry } from "../model-registry.js"; import type { ModelRegistry } from "../model-registry.js";
import type { import type {
@ -742,8 +742,63 @@ export type GetThinkingLevelHandler = () => ThinkingLevel;
export type SetThinkingLevelHandler = (level: ThinkingLevel) => void; export type SetThinkingLevelHandler = (level: ThinkingLevel) => void;
/**
* Shared state created by loader, used during registration and runtime.
* Contains flag values (defaults set during registration, CLI values set after).
*/
export interface ExtensionRuntimeState {
flagValues: Map<string, boolean | string>;
}
/**
* Action implementations for pi.* API methods.
* Provided to runner.initialize(), copied into the shared runtime.
*/
export interface ExtensionActions {
sendMessage: SendMessageHandler;
sendUserMessage: SendUserMessageHandler;
appendEntry: AppendEntryHandler;
getActiveTools: GetActiveToolsHandler;
getAllTools: GetAllToolsHandler;
setActiveTools: SetActiveToolsHandler;
setModel: SetModelHandler;
getThinkingLevel: GetThinkingLevelHandler;
setThinkingLevel: SetThinkingLevelHandler;
}
/**
* Actions for ExtensionContext (ctx.* in event handlers).
* Required by all modes.
*/
export interface ExtensionContextActions {
getModel: () => Model<any> | undefined;
isIdle: () => boolean;
abort: () => void;
hasPendingMessages: () => boolean;
}
/**
* Actions for ExtensionCommandContext (ctx.* in command handlers).
* Only needed for interactive mode where extension commands are invokable.
*/
export interface ExtensionCommandContextActions {
waitForIdle: () => Promise<void>;
newSession: (options?: {
parentSession?: string;
setup?: (sessionManager: SessionManager) => Promise<void>;
}) => Promise<{ cancelled: boolean }>;
branch: (entryId: string) => Promise<{ cancelled: boolean }>;
navigateTree: (targetId: string, options?: { summarize?: boolean }) => Promise<{ cancelled: boolean }>;
}
/**
* Full runtime = state + actions.
* Created by loader with throwing action stubs, completed by runner.initialize().
*/
export interface ExtensionRuntime extends ExtensionRuntimeState, ExtensionActions {}
/** Loaded extension with all registered items. */ /** Loaded extension with all registered items. */
export interface LoadedExtension { export interface Extension {
path: string; path: string;
resolvedPath: string; resolvedPath: string;
handlers: Map<string, HandlerFn[]>; handlers: Map<string, HandlerFn[]>;
@ -751,25 +806,15 @@ export interface LoadedExtension {
messageRenderers: Map<string, MessageRenderer>; messageRenderers: Map<string, MessageRenderer>;
commands: Map<string, RegisteredCommand>; commands: Map<string, RegisteredCommand>;
flags: Map<string, ExtensionFlag>; flags: Map<string, ExtensionFlag>;
flagValues: Map<string, boolean | string>;
shortcuts: Map<KeyId, ExtensionShortcut>; shortcuts: Map<KeyId, ExtensionShortcut>;
setSendMessageHandler: (handler: SendMessageHandler) => void;
setSendUserMessageHandler: (handler: SendUserMessageHandler) => void;
setAppendEntryHandler: (handler: AppendEntryHandler) => void;
setGetActiveToolsHandler: (handler: GetActiveToolsHandler) => void;
setGetAllToolsHandler: (handler: GetAllToolsHandler) => void;
setSetActiveToolsHandler: (handler: SetActiveToolsHandler) => void;
setSetModelHandler: (handler: SetModelHandler) => void;
setGetThinkingLevelHandler: (handler: GetThinkingLevelHandler) => void;
setSetThinkingLevelHandler: (handler: SetThinkingLevelHandler) => void;
setFlagValue: (name: string, value: boolean | string) => void;
} }
/** Result of loading extensions. */ /** Result of loading extensions. */
export interface LoadExtensionsResult { export interface LoadExtensionsResult {
extensions: LoadedExtension[]; extensions: Extension[];
errors: Array<{ path: string; error: string }>; errors: Array<{ path: string; error: string }>;
setUIContext(uiContext: ExtensionUIContext, hasUI: boolean): void; /** Shared runtime - actions are throwing stubs until runner.initialize() */
runtime: ExtensionRuntime;
} }
// ============================================================================ // ============================================================================

View file

@ -26,6 +26,7 @@ export {
discoverAndLoadExtensions, discoverAndLoadExtensions,
type ExecOptions, type ExecOptions,
type ExecResult, type ExecResult,
type Extension,
type ExtensionAPI, type ExtensionAPI,
type ExtensionCommandContext, type ExtensionCommandContext,
type ExtensionContext, type ExtensionContext,
@ -38,7 +39,6 @@ export {
type ExtensionShortcut, type ExtensionShortcut,
type ExtensionUIContext, type ExtensionUIContext,
type LoadExtensionsResult, type LoadExtensionsResult,
type LoadedExtension,
type MessageRenderer, type MessageRenderer,
type RegisteredCommand, type RegisteredCommand,
type SessionBeforeBranchEvent, type SessionBeforeBranchEvent,

View file

@ -28,11 +28,11 @@ import { AgentSession } from "./agent-session.js";
import { AuthStorage } from "./auth-storage.js"; import { AuthStorage } from "./auth-storage.js";
import { createEventBus, type EventBus } from "./event-bus.js"; import { createEventBus, type EventBus } from "./event-bus.js";
import { import {
createExtensionRuntime,
discoverAndLoadExtensions, discoverAndLoadExtensions,
type ExtensionFactory, type ExtensionFactory,
ExtensionRunner, ExtensionRunner,
type LoadExtensionsResult, type LoadExtensionsResult,
type LoadedExtension,
loadExtensionFromFactory, loadExtensionFromFactory,
type ToolDefinition, type ToolDefinition,
wrapRegisteredTools, wrapRegisteredTools,
@ -106,10 +106,10 @@ export interface CreateAgentSessionOptions {
/** Additional extension paths to load (merged with discovery). */ /** Additional extension paths to load (merged with discovery). */
additionalExtensionPaths?: string[]; additionalExtensionPaths?: string[];
/** /**
* Pre-loaded extensions (skips file discovery). * Pre-loaded extensions result (skips file discovery).
* @internal Used by CLI when extensions are loaded early to parse custom flags. * @internal Used by CLI when extensions are loaded early to parse custom flags.
*/ */
preloadedExtensions?: LoadedExtension[]; preloadedExtensionsResult?: LoadExtensionsResult;
/** Shared event bus for tool/extension communication. Default: creates new bus. */ /** Shared event bus for tool/extension communication. Default: creates new bus. */
eventBus?: EventBus; eventBus?: EventBus;
@ -438,20 +438,17 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
// Load extensions (discovers from standard locations + configured paths) // Load extensions (discovers from standard locations + configured paths)
let extensionsResult: LoadExtensionsResult; let extensionsResult: LoadExtensionsResult;
if (options.preloadedExtensions !== undefined && options.preloadedExtensions.length > 0) { if (options.preloadedExtensionsResult !== undefined) {
// Use pre-loaded extensions (from early CLI flag discovery) // Use pre-loaded extensions (from early CLI flag discovery)
extensionsResult = { extensionsResult = options.preloadedExtensionsResult;
extensions: options.preloadedExtensions,
errors: [],
setUIContext: () => {},
};
} else if (options.extensions !== undefined) { } else if (options.extensions !== undefined) {
// User explicitly provided extensions array (even if empty) - skip discovery // User explicitly provided extensions array (even if empty) - skip discovery
// Inline factories from options.extensions are loaded below // Create runtime for inline extensions
const runtime = createExtensionRuntime();
extensionsResult = { extensionsResult = {
extensions: [], extensions: [],
errors: [], errors: [],
setUIContext: () => {}, runtime,
}; };
} else { } else {
// Discover extensions, merging with additional paths // Discover extensions, merging with additional paths
@ -465,45 +462,29 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
// Load inline extensions from factories // Load inline extensions from factories
if (options.extensions && options.extensions.length > 0) { if (options.extensions && options.extensions.length > 0) {
// Create shared UI context holder that will be set later
const uiHolder: { ui: any; hasUI: boolean } = {
ui: {
select: async () => undefined,
confirm: async () => false,
input: async () => undefined,
notify: () => {},
setStatus: () => {},
setWidget: () => {},
setFooter: () => {},
setTitle: () => {},
custom: async () => undefined as never,
setEditorText: () => {},
getEditorText: () => "",
editor: async () => undefined,
get theme() {
return {} as any;
},
},
hasUI: false,
};
for (let i = 0; i < options.extensions.length; i++) { for (let i = 0; i < options.extensions.length; i++) {
const factory = options.extensions[i]; const factory = options.extensions[i];
const loaded = await loadExtensionFromFactory(factory, cwd, eventBus, uiHolder, `<inline-${i}>`); const loaded = await loadExtensionFromFactory(
factory,
cwd,
eventBus,
extensionsResult.runtime,
`<inline-${i}>`,
);
extensionsResult.extensions.push(loaded); extensionsResult.extensions.push(loaded);
} }
// Extend setUIContext to update inline extensions too
const originalSetUIContext = extensionsResult.setUIContext;
extensionsResult.setUIContext = (uiContext, hasUI) => {
originalSetUIContext(uiContext, hasUI);
uiHolder.ui = uiContext;
uiHolder.hasUI = hasUI;
};
} }
// Create extension runner if we have extensions // Create extension runner if we have extensions
let extensionRunner: ExtensionRunner | undefined; let extensionRunner: ExtensionRunner | undefined;
if (extensionsResult.extensions.length > 0) { if (extensionsResult.extensions.length > 0) {
extensionRunner = new ExtensionRunner(extensionsResult.extensions, cwd, sessionManager, modelRegistry); extensionRunner = new ExtensionRunner(
extensionsResult.extensions,
extensionsResult.runtime,
cwd,
sessionManager,
modelRegistry,
);
} }
// Wrap extension-registered tools and SDK-provided custom tools with context getter // Wrap extension-registered tools and SDK-provided custom tools with context getter
@ -536,7 +517,7 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
return {} as any; return {} as any;
}, },
}, },
hasUI: extensionRunner?.getHasUI() ?? false, hasUI: extensionRunner?.hasUI() ?? false,
cwd, cwd,
sessionManager, sessionManager,
modelRegistry, modelRegistry,

View file

@ -45,20 +45,24 @@ export type {
ContextEvent, ContextEvent,
ExecOptions, ExecOptions,
ExecResult, ExecResult,
Extension,
ExtensionActions,
ExtensionAPI, ExtensionAPI,
ExtensionCommandContext, ExtensionCommandContext,
ExtensionCommandContextActions,
ExtensionContext, ExtensionContext,
ExtensionContextActions,
ExtensionError, ExtensionError,
ExtensionEvent, ExtensionEvent,
ExtensionFactory, ExtensionFactory,
ExtensionFlag, ExtensionFlag,
ExtensionHandler, ExtensionHandler,
ExtensionRuntime,
ExtensionShortcut, ExtensionShortcut,
ExtensionUIContext, ExtensionUIContext,
ExtensionUIDialogOptions, ExtensionUIDialogOptions,
KeybindingsManager, KeybindingsManager,
LoadExtensionsResult, LoadExtensionsResult,
LoadedExtension,
MessageRenderer, MessageRenderer,
MessageRenderOptions, MessageRenderOptions,
RegisteredCommand, RegisteredCommand,
@ -81,6 +85,7 @@ export type {
TurnStartEvent, TurnStartEvent,
} from "./core/extensions/index.js"; } from "./core/extensions/index.js";
export { export {
createExtensionRuntime,
discoverAndLoadExtensions, discoverAndLoadExtensions,
ExtensionRunner, ExtensionRunner,
isBashToolResult, isBashToolResult,

View file

@ -18,7 +18,7 @@ import type { AgentSession } from "./core/agent-session.js";
import { createEventBus } from "./core/event-bus.js"; import { createEventBus } from "./core/event-bus.js";
import { exportFromFile } from "./core/export-html/index.js"; import { exportFromFile } from "./core/export-html/index.js";
import { discoverAndLoadExtensions, type ExtensionUIContext, type LoadedExtension } from "./core/extensions/index.js"; import { discoverAndLoadExtensions, type LoadExtensionsResult } from "./core/extensions/index.js";
import type { ModelRegistry } from "./core/model-registry.js"; import type { ModelRegistry } from "./core/model-registry.js";
import { resolveModelScope, type ScopedModel } from "./core/model-resolver.js"; import { resolveModelScope, type ScopedModel } from "./core/model-resolver.js";
import { type CreateAgentSessionOptions, createAgentSession, discoverAuthStorage, discoverModels } from "./core/sdk.js"; import { type CreateAgentSessionOptions, createAgentSession, discoverAuthStorage, discoverModels } from "./core/sdk.js";
@ -60,13 +60,11 @@ async function runInteractiveMode(
migratedProviders: string[], migratedProviders: string[],
versionCheckPromise: Promise<string | undefined>, versionCheckPromise: Promise<string | undefined>,
initialMessages: string[], initialMessages: string[],
extensions: LoadedExtension[],
setExtensionUIContext: (uiContext: ExtensionUIContext, hasUI: boolean) => void,
initialMessage?: string, initialMessage?: string,
initialImages?: ImageContent[], initialImages?: ImageContent[],
fdPath: string | undefined = undefined, fdPath: string | undefined = undefined,
): Promise<void> { ): Promise<void> {
const mode = new InteractiveMode(session, version, changelogMarkdown, extensions, setExtensionUIContext, fdPath); const mode = new InteractiveMode(session, version, changelogMarkdown, fdPath);
await mode.init(); await mode.init();
@ -236,7 +234,7 @@ function buildSessionOptions(
sessionManager: SessionManager | undefined, sessionManager: SessionManager | undefined,
modelRegistry: ModelRegistry, modelRegistry: ModelRegistry,
settingsManager: SettingsManager, settingsManager: SettingsManager,
preloadedExtensions?: LoadedExtension[], extensionsResult?: LoadExtensionsResult,
): CreateAgentSessionOptions { ): CreateAgentSessionOptions {
const options: CreateAgentSessionOptions = {}; const options: CreateAgentSessionOptions = {};
@ -302,8 +300,8 @@ function buildSessionOptions(
} }
// Pre-loaded extensions (from early CLI flag discovery) // Pre-loaded extensions (from early CLI flag discovery)
if (preloadedExtensions && preloadedExtensions.length > 0) { if (extensionsResult && extensionsResult.extensions.length > 0) {
options.preloadedExtensions = preloadedExtensions; options.preloadedExtensionsResult = extensionsResult;
} }
return options; return options;
@ -332,12 +330,12 @@ export async function main(args: string[]) {
time("SettingsManager.create"); time("SettingsManager.create");
// Merge CLI --extension args with settings.json extensions // Merge CLI --extension args with settings.json extensions
const extensionPaths = [...settingsManager.getExtensionPaths(), ...(firstPass.extensions ?? [])]; const extensionPaths = [...settingsManager.getExtensionPaths(), ...(firstPass.extensions ?? [])];
const { extensions: loadedExtensions } = await discoverAndLoadExtensions(extensionPaths, cwd, agentDir, eventBus); const extensionsResult = await discoverAndLoadExtensions(extensionPaths, cwd, agentDir, eventBus);
time("discoverExtensionFlags"); time("discoverExtensionFlags");
// Collect all extension flags // Collect all extension flags
const extensionFlags = new Map<string, { type: "boolean" | "string" }>(); const extensionFlags = new Map<string, { type: "boolean" | "string" }>();
for (const ext of loadedExtensions) { for (const ext of extensionsResult.extensions) {
for (const [name, flag] of ext.flags) { for (const [name, flag] of ext.flags) {
extensionFlags.set(name, { type: flag.type }); extensionFlags.set(name, { type: flag.type });
} }
@ -347,13 +345,9 @@ export async function main(args: string[]) {
const parsed = parseArgs(args, extensionFlags); const parsed = parseArgs(args, extensionFlags);
time("parseArgs"); time("parseArgs");
// Pass flag values to extensions // Pass flag values to extensions via runtime
for (const [name, value] of parsed.unknownFlags) { for (const [name, value] of parsed.unknownFlags) {
for (const ext of loadedExtensions) { extensionsResult.runtime.flagValues.set(name, value);
if (ext.flags.has(name)) {
ext.setFlagValue(name, value);
}
}
} }
if (parsed.version) { if (parsed.version) {
@ -436,7 +430,7 @@ export async function main(args: string[]) {
sessionManager, sessionManager,
modelRegistry, modelRegistry,
settingsManager, settingsManager,
loadedExtensions, extensionsResult,
); );
sessionOptions.authStorage = authStorage; sessionOptions.authStorage = authStorage;
sessionOptions.modelRegistry = modelRegistry; sessionOptions.modelRegistry = modelRegistry;
@ -452,7 +446,7 @@ export async function main(args: string[]) {
} }
time("buildSessionOptions"); time("buildSessionOptions");
const { session, extensionsResult, modelFallbackMessage } = await createAgentSession(sessionOptions); const { session, modelFallbackMessage } = await createAgentSession(sessionOptions);
time("createAgentSession"); time("createAgentSession");
if (!isInteractive && !session.model) { if (!isInteractive && !session.model) {
@ -505,8 +499,6 @@ export async function main(args: string[]) {
migratedProviders, migratedProviders,
versionCheckPromise, versionCheckPromise,
parsed.messages, parsed.messages,
extensionsResult.extensions,
extensionsResult.setUIContext,
initialMessage, initialMessage,
initialImages, initialImages,
fdPath, fdPath,

View file

@ -33,13 +33,13 @@ import type {
ExtensionRunner, ExtensionRunner,
ExtensionUIContext, ExtensionUIContext,
ExtensionUIDialogOptions, ExtensionUIDialogOptions,
LoadedExtension,
} from "../../core/extensions/index.js"; } from "../../core/extensions/index.js";
import { KeybindingsManager } from "../../core/keybindings.js"; import { KeybindingsManager } from "../../core/keybindings.js";
import { createCompactionSummaryMessage } from "../../core/messages.js"; import { createCompactionSummaryMessage } from "../../core/messages.js";
import { type SessionContext, SessionManager } from "../../core/session-manager.js"; import { type SessionContext, SessionManager } from "../../core/session-manager.js";
import { loadSkills } from "../../core/skills.js"; import { loadSkills } from "../../core/skills.js";
import { loadProjectContextFiles } from "../../core/system-prompt.js"; import { loadProjectContextFiles } from "../../core/system-prompt.js";
import { allTools } from "../../core/tools/index.js";
import type { TruncationResult } from "../../core/tools/truncate.js"; import type { TruncationResult } from "../../core/tools/truncate.js";
import { getChangelogPath, parseChangelog } from "../../utils/changelog.js"; import { getChangelogPath, parseChangelog } from "../../utils/changelog.js";
import { copyToClipboard } from "../../utils/clipboard.js"; import { copyToClipboard } from "../../utils/clipboard.js";
@ -184,8 +184,6 @@ export class InteractiveMode {
session: AgentSession, session: AgentSession,
version: string, version: string,
changelogMarkdown: string | undefined = undefined, changelogMarkdown: string | undefined = undefined,
_extensions: LoadedExtension[] = [],
private setExtensionUIContext: (uiContext: ExtensionUIContext, hasUI: boolean) => void = () => {},
fdPath: string | undefined = undefined, fdPath: string | undefined = undefined,
) { ) {
this.session = session; this.session = session;
@ -429,123 +427,124 @@ export class InteractiveMode {
} }
} }
// Create and set extension UI context
const uiContext = this.createExtensionUIContext();
this.setExtensionUIContext(uiContext, true);
const extensionRunner = this.session.extensionRunner; const extensionRunner = this.session.extensionRunner;
if (!extensionRunner) { if (!extensionRunner) {
return; // No extensions loaded return; // No extensions loaded
} }
extensionRunner.initialize({ // Create extension UI context
getModel: () => this.session.model, const uiContext = this.createExtensionUIContext();
sendMessageHandler: (message, options) => {
const wasStreaming = this.session.isStreaming; extensionRunner.initialize(
this.session // ExtensionActions - for pi.* API
.sendCustomMessage(message, options) {
.then(() => { sendMessage: (message, options) => {
// For non-streaming cases with display=true, update UI const wasStreaming = this.session.isStreaming;
// (streaming cases update via message_end event) this.session
if (!wasStreaming && message.display) { .sendCustomMessage(message, options)
this.rebuildChatFromMessages(); .then(() => {
} if (!wasStreaming && message.display) {
}) this.rebuildChatFromMessages();
.catch((err) => { }
this.showError(`Extension sendMessage failed: ${err instanceof Error ? err.message : String(err)}`); })
.catch((err) => {
this.showError(
`Extension sendMessage failed: ${err instanceof Error ? err.message : String(err)}`,
);
});
},
sendUserMessage: (content, options) => {
this.session.sendUserMessage(content, options).catch((err) => {
this.showError(
`Extension sendUserMessage failed: ${err instanceof Error ? err.message : String(err)}`,
);
}); });
},
appendEntry: (customType, data) => {
this.sessionManager.appendCustomEntry(customType, data);
},
getActiveTools: () => this.session.getActiveToolNames(),
getAllTools: () => this.session.getAllToolNames(),
setActiveTools: (toolNames) => this.session.setActiveToolsByName(toolNames),
setModel: async (model) => {
const key = await this.session.modelRegistry.getApiKey(model);
if (!key) return false;
await this.session.setModel(model);
return true;
},
getThinkingLevel: () => this.session.thinkingLevel,
setThinkingLevel: (level) => this.session.setThinkingLevel(level),
}, },
sendUserMessageHandler: (content, options) => { // ExtensionContextActions - for ctx.* in event handlers
this.session.sendUserMessage(content, options).catch((err) => { {
this.showError(`Extension sendUserMessage failed: ${err instanceof Error ? err.message : String(err)}`); getModel: () => this.session.model,
}); isIdle: () => !this.session.isStreaming,
abort: () => this.session.abort(),
hasPendingMessages: () => this.session.pendingMessageCount > 0,
}, },
appendEntryHandler: (customType, data) => { // ExtensionCommandContextActions - for ctx.* in command handlers
this.sessionManager.appendCustomEntry(customType, data); {
waitForIdle: () => this.session.agent.waitForIdle(),
newSession: async (options) => {
if (this.loadingAnimation) {
this.loadingAnimation.stop();
this.loadingAnimation = undefined;
}
this.statusContainer.clear();
const success = await this.session.newSession({ parentSession: options?.parentSession });
if (!success) {
return { cancelled: true };
}
if (options?.setup) {
await options.setup(this.sessionManager);
}
this.chatContainer.clear();
this.pendingMessagesContainer.clear();
this.compactionQueuedMessages = [];
this.streamingComponent = undefined;
this.streamingMessage = undefined;
this.pendingTools.clear();
this.chatContainer.addChild(new Spacer(1));
this.chatContainer.addChild(new Text(`${theme.fg("accent", "✓ New session started")}`, 1, 1));
this.ui.requestRender();
return { cancelled: false };
},
branch: async (entryId) => {
const result = await this.session.branch(entryId);
if (result.cancelled) {
return { cancelled: true };
}
this.chatContainer.clear();
this.renderInitialMessages();
this.editor.setText(result.selectedText);
this.showStatus("Branched to new session");
return { cancelled: false };
},
navigateTree: async (targetId, options) => {
const result = await this.session.navigateTree(targetId, { summarize: options?.summarize });
if (result.cancelled) {
return { cancelled: true };
}
this.chatContainer.clear();
this.renderInitialMessages();
if (result.editorText) {
this.editor.setText(result.editorText);
}
this.showStatus("Navigated to selected point");
return { cancelled: false };
},
}, },
getActiveToolsHandler: () => this.session.getActiveToolNames(),
getAllToolsHandler: () => this.session.getAllToolNames(),
setActiveToolsHandler: (toolNames) => this.session.setActiveToolsByName(toolNames),
newSessionHandler: async (options) => {
// Stop any loading animation
if (this.loadingAnimation) {
this.loadingAnimation.stop();
this.loadingAnimation = undefined;
}
this.statusContainer.clear();
// Create new session
const success = await this.session.newSession({ parentSession: options?.parentSession });
if (!success) {
return { cancelled: true };
}
// Call setup callback if provided
if (options?.setup) {
await options.setup(this.sessionManager);
}
// Clear UI state
this.chatContainer.clear();
this.pendingMessagesContainer.clear();
this.compactionQueuedMessages = [];
this.streamingComponent = undefined;
this.streamingMessage = undefined;
this.pendingTools.clear();
this.chatContainer.addChild(new Spacer(1));
this.chatContainer.addChild(new Text(`${theme.fg("accent", "✓ New session started")}`, 1, 1));
this.ui.requestRender();
return { cancelled: false };
},
branchHandler: async (entryId) => {
const result = await this.session.branch(entryId);
if (result.cancelled) {
return { cancelled: true };
}
// Update UI
this.chatContainer.clear();
this.renderInitialMessages();
this.editor.setText(result.selectedText);
this.showStatus("Branched to new session");
return { cancelled: false };
},
navigateTreeHandler: async (targetId, options) => {
const result = await this.session.navigateTree(targetId, { summarize: options?.summarize });
if (result.cancelled) {
return { cancelled: true };
}
// Update UI
this.chatContainer.clear();
this.renderInitialMessages();
if (result.editorText) {
this.editor.setText(result.editorText);
}
this.showStatus("Navigated to selected point");
return { cancelled: false };
},
setModelHandler: async (model) => {
const key = await this.session.modelRegistry.getApiKey(model);
if (!key) return false;
await this.session.setModel(model);
return true;
},
getThinkingLevelHandler: () => this.session.thinkingLevel,
setThinkingLevelHandler: (level) => this.session.setThinkingLevel(level),
isIdle: () => !this.session.isStreaming,
waitForIdle: () => this.session.agent.waitForIdle(),
abort: () => {
this.session.abort();
},
hasPendingMessages: () => this.session.pendingMessageCount > 0,
uiContext, uiContext,
hasUI: true, );
});
// Subscribe to extension errors // Subscribe to extension errors
extensionRunner.onError((error) => { extensionRunner.onError((error) => {
@ -563,6 +562,24 @@ export class InteractiveMode {
this.chatContainer.addChild(new Spacer(1)); this.chatContainer.addChild(new Spacer(1));
} }
// Warn about built-in tool overrides
const builtInToolNames = new Set(Object.keys(allTools));
const registeredTools = extensionRunner.getAllRegisteredTools();
for (const tool of registeredTools) {
if (builtInToolNames.has(tool.definition.name)) {
this.chatContainer.addChild(
new Text(
theme.fg(
"warning",
`Warning: Extension "${tool.extensionPath}" overrides built-in tool "${tool.definition.name}"`,
),
0,
0,
),
);
}
}
// Emit session_start event // Emit session_start event
await extensionRunner.emit({ await extensionRunner.emit({
type: "session_start", type: "session_start",

View file

@ -26,37 +26,65 @@ export async function runPrintMode(
initialMessage?: string, initialMessage?: string,
initialImages?: ImageContent[], initialImages?: ImageContent[],
): Promise<void> { ): Promise<void> {
// Extension runner already has no-op UI context by default (set in loader) // Set up extensions for print mode (no UI, no command context)
// Set up extensions for print mode (no UI)
const extensionRunner = session.extensionRunner; const extensionRunner = session.extensionRunner;
if (extensionRunner) { if (extensionRunner) {
extensionRunner.initialize({ extensionRunner.initialize(
getModel: () => session.model, // ExtensionActions
sendMessageHandler: (message, options) => { {
session.sendCustomMessage(message, options).catch((e) => { sendMessage: (message, options) => {
console.error(`Extension sendMessage failed: ${e instanceof Error ? e.message : String(e)}`); session.sendCustomMessage(message, options).catch((e) => {
}); console.error(`Extension sendMessage failed: ${e instanceof Error ? e.message : String(e)}`);
});
},
sendUserMessage: (content, options) => {
session.sendUserMessage(content, options).catch((e) => {
console.error(`Extension sendUserMessage failed: ${e instanceof Error ? e.message : String(e)}`);
});
},
appendEntry: (customType, data) => {
session.sessionManager.appendCustomEntry(customType, data);
},
getActiveTools: () => session.getActiveToolNames(),
getAllTools: () => session.getAllToolNames(),
setActiveTools: (toolNames: string[]) => session.setActiveToolsByName(toolNames),
setModel: async (model) => {
const key = await session.modelRegistry.getApiKey(model);
if (!key) return false;
await session.setModel(model);
return true;
},
getThinkingLevel: () => session.thinkingLevel,
setThinkingLevel: (level) => session.setThinkingLevel(level),
}, },
sendUserMessageHandler: (content, options) => { // ExtensionContextActions
session.sendUserMessage(content, options).catch((e) => { {
console.error(`Extension sendUserMessage failed: ${e instanceof Error ? e.message : String(e)}`); getModel: () => session.model,
}); isIdle: () => !session.isStreaming,
abort: () => session.abort(),
hasPendingMessages: () => session.pendingMessageCount > 0,
}, },
appendEntryHandler: (customType, data) => { // ExtensionCommandContextActions - commands invokable via prompt("/command")
session.sessionManager.appendCustomEntry(customType, data); {
waitForIdle: () => session.agent.waitForIdle(),
newSession: async (options) => {
const success = await session.newSession({ parentSession: options?.parentSession });
if (success && options?.setup) {
await options.setup(session.sessionManager);
}
return { cancelled: !success };
},
branch: async (entryId) => {
const result = await session.branch(entryId);
return { cancelled: result.cancelled };
},
navigateTree: async (targetId, options) => {
const result = await session.navigateTree(targetId, { summarize: options?.summarize });
return { cancelled: result.cancelled };
},
}, },
getActiveToolsHandler: () => session.getActiveToolNames(), // No UI context
getAllToolsHandler: () => session.getAllToolNames(), );
setActiveToolsHandler: (toolNames: string[]) => session.setActiveToolsByName(toolNames),
setModelHandler: async (model) => {
const key = await session.modelRegistry.getApiKey(model);
if (!key) return false;
await session.setModel(model);
return true;
},
getThinkingLevelHandler: () => session.thinkingLevel,
setThinkingLevelHandler: (level) => session.setThinkingLevel(level),
});
extensionRunner.onError((err) => { extensionRunner.onError((err) => {
console.error(`Extension error (${err.extensionPath}): ${err.error}`); console.error(`Extension error (${err.extensionPath}): ${err.error}`);
}); });

View file

@ -231,35 +231,63 @@ export async function runRpcMode(session: AgentSession): Promise<never> {
// Set up extensions with RPC-based UI context // Set up extensions with RPC-based UI context
const extensionRunner = session.extensionRunner; const extensionRunner = session.extensionRunner;
if (extensionRunner) { if (extensionRunner) {
extensionRunner.initialize({ extensionRunner.initialize(
getModel: () => session.agent.state.model, // ExtensionActions
sendMessageHandler: (message, options) => { {
session.sendCustomMessage(message, options).catch((e) => { sendMessage: (message, options) => {
output(error(undefined, "extension_send", e.message)); session.sendCustomMessage(message, options).catch((e) => {
}); output(error(undefined, "extension_send", e.message));
});
},
sendUserMessage: (content, options) => {
session.sendUserMessage(content, options).catch((e) => {
output(error(undefined, "extension_send_user", e.message));
});
},
appendEntry: (customType, data) => {
session.sessionManager.appendCustomEntry(customType, data);
},
getActiveTools: () => session.getActiveToolNames(),
getAllTools: () => session.getAllToolNames(),
setActiveTools: (toolNames: string[]) => session.setActiveToolsByName(toolNames),
setModel: async (model) => {
const key = await session.modelRegistry.getApiKey(model);
if (!key) return false;
await session.setModel(model);
return true;
},
getThinkingLevel: () => session.thinkingLevel,
setThinkingLevel: (level) => session.setThinkingLevel(level),
}, },
sendUserMessageHandler: (content, options) => { // ExtensionContextActions
session.sendUserMessage(content, options).catch((e) => { {
output(error(undefined, "extension_send_user", e.message)); getModel: () => session.agent.state.model,
}); isIdle: () => !session.isStreaming,
abort: () => session.abort(),
hasPendingMessages: () => session.pendingMessageCount > 0,
}, },
appendEntryHandler: (customType, data) => { // ExtensionCommandContextActions - commands invokable via prompt("/command")
session.sessionManager.appendCustomEntry(customType, data); {
waitForIdle: () => session.agent.waitForIdle(),
newSession: async (options) => {
const success = await session.newSession({ parentSession: options?.parentSession });
// Note: setup callback runs but no UI feedback in RPC mode
if (success && options?.setup) {
await options.setup(session.sessionManager);
}
return { cancelled: !success };
},
branch: async (entryId) => {
const result = await session.branch(entryId);
return { cancelled: result.cancelled };
},
navigateTree: async (targetId, options) => {
const result = await session.navigateTree(targetId, { summarize: options?.summarize });
return { cancelled: result.cancelled };
},
}, },
getActiveToolsHandler: () => session.getActiveToolNames(), createExtensionUIContext(),
getAllToolsHandler: () => session.getAllToolNames(), );
setActiveToolsHandler: (toolNames: string[]) => session.setActiveToolsByName(toolNames),
setModelHandler: async (model) => {
const key = await session.modelRegistry.getApiKey(model);
if (!key) return false;
await session.setModel(model);
return true;
},
getThinkingLevelHandler: () => session.thinkingLevel,
setThinkingLevelHandler: (level) => session.setThinkingLevel(level),
uiContext: createExtensionUIContext(),
hasUI: false,
});
extensionRunner.onError((err) => { extensionRunner.onError((err) => {
output({ type: "extension_error", extensionPath: err.extensionPath, event: err.event, error: err.error }); output({ type: "extension_error", extensionPath: err.extensionPath, event: err.event, error: err.error });
}); });

View file

@ -11,8 +11,9 @@ import { afterEach, beforeEach, describe, expect, it } from "vitest";
import { AgentSession } from "../src/core/agent-session.js"; import { AgentSession } from "../src/core/agent-session.js";
import { AuthStorage } from "../src/core/auth-storage.js"; import { AuthStorage } from "../src/core/auth-storage.js";
import { import {
createExtensionRuntime,
type Extension,
ExtensionRunner, ExtensionRunner,
type LoadedExtension,
type SessionBeforeCompactEvent, type SessionBeforeCompactEvent,
type SessionCompactEvent, type SessionCompactEvent,
type SessionEvent, type SessionEvent,
@ -21,7 +22,6 @@ import { ModelRegistry } from "../src/core/model-registry.js";
import { SessionManager } from "../src/core/session-manager.js"; import { SessionManager } from "../src/core/session-manager.js";
import { SettingsManager } from "../src/core/settings-manager.js"; import { SettingsManager } from "../src/core/settings-manager.js";
import { codingTools } from "../src/core/tools/index.js"; import { codingTools } from "../src/core/tools/index.js";
import { theme } from "../src/modes/interactive/theme/theme.js";
const API_KEY = process.env.ANTHROPIC_OAUTH_TOKEN || process.env.ANTHROPIC_API_KEY; const API_KEY = process.env.ANTHROPIC_OAUTH_TOKEN || process.env.ANTHROPIC_API_KEY;
@ -49,7 +49,7 @@ describe.skipIf(!API_KEY)("Compaction extensions", () => {
function createExtension( function createExtension(
onBeforeCompact?: (event: SessionBeforeCompactEvent) => { cancel?: boolean; compaction?: any } | undefined, onBeforeCompact?: (event: SessionBeforeCompactEvent) => { cancel?: boolean; compaction?: any } | undefined,
onCompact?: (event: SessionCompactEvent) => void, onCompact?: (event: SessionCompactEvent) => void,
): LoadedExtension { ): Extension {
const handlers = new Map<string, ((event: any, ctx: any) => Promise<any>)[]>(); const handlers = new Map<string, ((event: any, ctx: any) => Promise<any>)[]>();
handlers.set("session_before_compact", [ handlers.set("session_before_compact", [
@ -80,22 +80,11 @@ describe.skipIf(!API_KEY)("Compaction extensions", () => {
messageRenderers: new Map(), messageRenderers: new Map(),
commands: new Map(), commands: new Map(),
flags: new Map(), flags: new Map(),
flagValues: new Map(),
shortcuts: new Map(), shortcuts: new Map(),
setSendMessageHandler: () => {},
setSendUserMessageHandler: () => {},
setAppendEntryHandler: () => {},
setGetActiveToolsHandler: () => {},
setGetAllToolsHandler: () => {},
setSetActiveToolsHandler: () => {},
setSetModelHandler: () => {},
setGetThinkingLevelHandler: () => {},
setSetThinkingLevelHandler: () => {},
setFlagValue: () => {},
}; };
} }
function createSession(extensions: LoadedExtension[]) { function createSession(extensions: Extension[]) {
const model = getModel("anthropic", "claude-sonnet-4-5")!; const model = getModel("anthropic", "claude-sonnet-4-5")!;
const agent = new Agent({ const agent = new Agent({
getApiKey: () => API_KEY, getApiKey: () => API_KEY,
@ -111,39 +100,29 @@ describe.skipIf(!API_KEY)("Compaction extensions", () => {
const authStorage = new AuthStorage(join(tempDir, "auth.json")); const authStorage = new AuthStorage(join(tempDir, "auth.json"));
const modelRegistry = new ModelRegistry(authStorage); const modelRegistry = new ModelRegistry(authStorage);
extensionRunner = new ExtensionRunner(extensions, tempDir, sessionManager, modelRegistry); const runtime = createExtensionRuntime();
extensionRunner.initialize({ extensionRunner = new ExtensionRunner(extensions, runtime, tempDir, sessionManager, modelRegistry);
getModel: () => session.model, extensionRunner.initialize(
sendMessageHandler: async () => {}, // ExtensionActions
sendUserMessageHandler: async () => {}, {
appendEntryHandler: async () => {}, sendMessage: async () => {},
getActiveToolsHandler: () => [], sendUserMessage: async () => {},
getAllToolsHandler: () => [], appendEntry: async () => {},
setActiveToolsHandler: () => {}, getActiveTools: () => [],
setModelHandler: async () => false, getAllTools: () => [],
getThinkingLevelHandler: () => "off", setActiveTools: () => {},
setThinkingLevelHandler: () => {}, setModel: async () => false,
uiContext: { getThinkingLevel: () => "off",
select: async () => undefined, setThinkingLevel: () => {},
confirm: async () => false,
input: async () => undefined,
notify: () => {},
setStatus: () => {},
setWidget: () => {},
setFooter: () => {},
setHeader: () => {},
setTitle: () => {},
custom: async () => undefined as never,
setEditorText: () => {},
getEditorText: () => "",
editor: async () => undefined,
setEditorComponent: () => {},
get theme() {
return theme;
},
}, },
hasUI: false, // ExtensionContextActions
}); {
getModel: () => session.model,
isIdle: () => !session.isStreaming,
abort: () => session.abort(),
hasPendingMessages: () => session.pendingMessageCount > 0,
},
);
session = new AgentSession({ session = new AgentSession({
agent, agent,
@ -264,7 +243,7 @@ describe.skipIf(!API_KEY)("Compaction extensions", () => {
}, 120000); }, 120000);
it("should continue with default compaction if extension throws error", async () => { it("should continue with default compaction if extension throws error", async () => {
const throwingExtension: LoadedExtension = { const throwingExtension: Extension = {
path: "throwing-extension", path: "throwing-extension",
resolvedPath: "/test/throwing-extension.ts", resolvedPath: "/test/throwing-extension.ts",
handlers: new Map<string, ((event: any, ctx: any) => Promise<any>)[]>([ handlers: new Map<string, ((event: any, ctx: any) => Promise<any>)[]>([
@ -291,18 +270,7 @@ describe.skipIf(!API_KEY)("Compaction extensions", () => {
messageRenderers: new Map(), messageRenderers: new Map(),
commands: new Map(), commands: new Map(),
flags: new Map(), flags: new Map(),
flagValues: new Map(),
shortcuts: new Map(), shortcuts: new Map(),
setSendMessageHandler: () => {},
setSendUserMessageHandler: () => {},
setAppendEntryHandler: () => {},
setGetActiveToolsHandler: () => {},
setGetAllToolsHandler: () => {},
setSetActiveToolsHandler: () => {},
setSetModelHandler: () => {},
setGetThinkingLevelHandler: () => {},
setSetThinkingLevelHandler: () => {},
setFlagValue: () => {},
}; };
createSession([throwingExtension]); createSession([throwingExtension]);
@ -323,7 +291,7 @@ describe.skipIf(!API_KEY)("Compaction extensions", () => {
it("should call multiple extensions in order", async () => { it("should call multiple extensions in order", async () => {
const callOrder: string[] = []; const callOrder: string[] = [];
const extension1: LoadedExtension = { const extension1: Extension = {
path: "extension1", path: "extension1",
resolvedPath: "/test/extension1.ts", resolvedPath: "/test/extension1.ts",
handlers: new Map<string, ((event: any, ctx: any) => Promise<any>)[]>([ handlers: new Map<string, ((event: any, ctx: any) => Promise<any>)[]>([
@ -350,21 +318,10 @@ describe.skipIf(!API_KEY)("Compaction extensions", () => {
messageRenderers: new Map(), messageRenderers: new Map(),
commands: new Map(), commands: new Map(),
flags: new Map(), flags: new Map(),
flagValues: new Map(),
shortcuts: new Map(), shortcuts: new Map(),
setSendMessageHandler: () => {},
setSendUserMessageHandler: () => {},
setAppendEntryHandler: () => {},
setGetActiveToolsHandler: () => {},
setGetAllToolsHandler: () => {},
setSetActiveToolsHandler: () => {},
setSetModelHandler: () => {},
setGetThinkingLevelHandler: () => {},
setSetThinkingLevelHandler: () => {},
setFlagValue: () => {},
}; };
const extension2: LoadedExtension = { const extension2: Extension = {
path: "extension2", path: "extension2",
resolvedPath: "/test/extension2.ts", resolvedPath: "/test/extension2.ts",
handlers: new Map<string, ((event: any, ctx: any) => Promise<any>)[]>([ handlers: new Map<string, ((event: any, ctx: any) => Promise<any>)[]>([
@ -391,18 +348,7 @@ describe.skipIf(!API_KEY)("Compaction extensions", () => {
messageRenderers: new Map(), messageRenderers: new Map(),
commands: new Map(), commands: new Map(),
flags: new Map(), flags: new Map(),
flagValues: new Map(),
shortcuts: new Map(), shortcuts: new Map(),
setSendMessageHandler: () => {},
setSendUserMessageHandler: () => {},
setAppendEntryHandler: () => {},
setGetActiveToolsHandler: () => {},
setGetAllToolsHandler: () => {},
setSetActiveToolsHandler: () => {},
setSetModelHandler: () => {},
setGetThinkingLevelHandler: () => {},
setSetThinkingLevelHandler: () => {},
setFlagValue: () => {},
}; };
createSession([extension1, extension2]); createSession([extension1, extension2]);

View file

@ -46,7 +46,7 @@ describe("ExtensionRunner", () => {
const warnSpy = vi.spyOn(console, "warn").mockImplementation(() => {}); const warnSpy = vi.spyOn(console, "warn").mockImplementation(() => {});
const result = await discoverAndLoadExtensions([], tempDir, tempDir); const result = await discoverAndLoadExtensions([], tempDir, tempDir);
const runner = new ExtensionRunner(result.extensions, tempDir, sessionManager, modelRegistry); const runner = new ExtensionRunner(result.extensions, result.runtime, tempDir, sessionManager, modelRegistry);
const shortcuts = runner.getShortcuts(); const shortcuts = runner.getShortcuts();
expect(warnSpy).toHaveBeenCalledWith(expect.stringContaining("conflicts with built-in")); expect(warnSpy).toHaveBeenCalledWith(expect.stringContaining("conflicts with built-in"));
@ -79,7 +79,7 @@ describe("ExtensionRunner", () => {
const warnSpy = vi.spyOn(console, "warn").mockImplementation(() => {}); const warnSpy = vi.spyOn(console, "warn").mockImplementation(() => {});
const result = await discoverAndLoadExtensions([], tempDir, tempDir); const result = await discoverAndLoadExtensions([], tempDir, tempDir);
const runner = new ExtensionRunner(result.extensions, tempDir, sessionManager, modelRegistry); const runner = new ExtensionRunner(result.extensions, result.runtime, tempDir, sessionManager, modelRegistry);
const shortcuts = runner.getShortcuts(); const shortcuts = runner.getShortcuts();
expect(warnSpy).toHaveBeenCalledWith(expect.stringContaining("shortcut conflict")); expect(warnSpy).toHaveBeenCalledWith(expect.stringContaining("shortcut conflict"));
@ -108,7 +108,7 @@ describe("ExtensionRunner", () => {
fs.writeFileSync(path.join(extensionsDir, "tool-b.ts"), toolCode("tool_b")); fs.writeFileSync(path.join(extensionsDir, "tool-b.ts"), toolCode("tool_b"));
const result = await discoverAndLoadExtensions([], tempDir, tempDir); const result = await discoverAndLoadExtensions([], tempDir, tempDir);
const runner = new ExtensionRunner(result.extensions, tempDir, sessionManager, modelRegistry); const runner = new ExtensionRunner(result.extensions, result.runtime, tempDir, sessionManager, modelRegistry);
const tools = runner.getAllRegisteredTools(); const tools = runner.getAllRegisteredTools();
expect(tools.length).toBe(2); expect(tools.length).toBe(2);
@ -130,7 +130,7 @@ describe("ExtensionRunner", () => {
fs.writeFileSync(path.join(extensionsDir, "cmd-b.ts"), cmdCode("cmd-b")); fs.writeFileSync(path.join(extensionsDir, "cmd-b.ts"), cmdCode("cmd-b"));
const result = await discoverAndLoadExtensions([], tempDir, tempDir); const result = await discoverAndLoadExtensions([], tempDir, tempDir);
const runner = new ExtensionRunner(result.extensions, tempDir, sessionManager, modelRegistry); const runner = new ExtensionRunner(result.extensions, result.runtime, tempDir, sessionManager, modelRegistry);
const commands = runner.getRegisteredCommands(); const commands = runner.getRegisteredCommands();
expect(commands.length).toBe(2); expect(commands.length).toBe(2);
@ -149,7 +149,7 @@ describe("ExtensionRunner", () => {
fs.writeFileSync(path.join(extensionsDir, "cmd.ts"), cmdCode); fs.writeFileSync(path.join(extensionsDir, "cmd.ts"), cmdCode);
const result = await discoverAndLoadExtensions([], tempDir, tempDir); const result = await discoverAndLoadExtensions([], tempDir, tempDir);
const runner = new ExtensionRunner(result.extensions, tempDir, sessionManager, modelRegistry); const runner = new ExtensionRunner(result.extensions, result.runtime, tempDir, sessionManager, modelRegistry);
const cmd = runner.getCommand("my-cmd"); const cmd = runner.getCommand("my-cmd");
expect(cmd).toBeDefined(); expect(cmd).toBeDefined();
@ -173,7 +173,7 @@ describe("ExtensionRunner", () => {
fs.writeFileSync(path.join(extensionsDir, "throws.ts"), extCode); fs.writeFileSync(path.join(extensionsDir, "throws.ts"), extCode);
const result = await discoverAndLoadExtensions([], tempDir, tempDir); const result = await discoverAndLoadExtensions([], tempDir, tempDir);
const runner = new ExtensionRunner(result.extensions, tempDir, sessionManager, modelRegistry); const runner = new ExtensionRunner(result.extensions, result.runtime, tempDir, sessionManager, modelRegistry);
const errors: Array<{ extensionPath: string; event: string; error: string }> = []; const errors: Array<{ extensionPath: string; event: string; error: string }> = [];
runner.onError((err) => { runner.onError((err) => {
@ -199,7 +199,7 @@ describe("ExtensionRunner", () => {
fs.writeFileSync(path.join(extensionsDir, "renderer.ts"), extCode); fs.writeFileSync(path.join(extensionsDir, "renderer.ts"), extCode);
const result = await discoverAndLoadExtensions([], tempDir, tempDir); const result = await discoverAndLoadExtensions([], tempDir, tempDir);
const runner = new ExtensionRunner(result.extensions, tempDir, sessionManager, modelRegistry); const runner = new ExtensionRunner(result.extensions, result.runtime, tempDir, sessionManager, modelRegistry);
const renderer = runner.getMessageRenderer("my-type"); const renderer = runner.getMessageRenderer("my-type");
expect(renderer).toBeDefined(); expect(renderer).toBeDefined();
@ -222,7 +222,7 @@ describe("ExtensionRunner", () => {
fs.writeFileSync(path.join(extensionsDir, "with-flag.ts"), extCode); fs.writeFileSync(path.join(extensionsDir, "with-flag.ts"), extCode);
const result = await discoverAndLoadExtensions([], tempDir, tempDir); const result = await discoverAndLoadExtensions([], tempDir, tempDir);
const runner = new ExtensionRunner(result.extensions, tempDir, sessionManager, modelRegistry); const runner = new ExtensionRunner(result.extensions, result.runtime, tempDir, sessionManager, modelRegistry);
const flags = runner.getFlags(); const flags = runner.getFlags();
expect(flags.has("--my-flag")).toBe(true); expect(flags.has("--my-flag")).toBe(true);
@ -240,14 +240,13 @@ describe("ExtensionRunner", () => {
fs.writeFileSync(path.join(extensionsDir, "flag.ts"), extCode); fs.writeFileSync(path.join(extensionsDir, "flag.ts"), extCode);
const result = await discoverAndLoadExtensions([], tempDir, tempDir); const result = await discoverAndLoadExtensions([], tempDir, tempDir);
const runner = new ExtensionRunner(result.extensions, tempDir, sessionManager, modelRegistry); const runner = new ExtensionRunner(result.extensions, result.runtime, tempDir, sessionManager, modelRegistry);
// Setting a flag value should not throw // Setting a flag value should not throw
runner.setFlagValue("--test-flag", true); runner.setFlagValue("--test-flag", true);
// The flag values are stored in the extension's flagValues map // The flag values are stored in the shared runtime
const ext = result.extensions[0]; expect(result.runtime.flagValues.get("--test-flag")).toBe(true);
expect(ext.flagValues.get("--test-flag")).toBe(true);
}); });
}); });
@ -261,7 +260,7 @@ describe("ExtensionRunner", () => {
fs.writeFileSync(path.join(extensionsDir, "handler.ts"), extCode); fs.writeFileSync(path.join(extensionsDir, "handler.ts"), extCode);
const result = await discoverAndLoadExtensions([], tempDir, tempDir); const result = await discoverAndLoadExtensions([], tempDir, tempDir);
const runner = new ExtensionRunner(result.extensions, tempDir, sessionManager, modelRegistry); const runner = new ExtensionRunner(result.extensions, result.runtime, tempDir, sessionManager, modelRegistry);
expect(runner.hasHandlers("tool_call")).toBe(true); expect(runner.hasHandlers("tool_call")).toBe(true);
expect(runner.hasHandlers("agent_end")).toBe(false); expect(runner.hasHandlers("agent_end")).toBe(false);