diff --git a/.gitignore b/.gitignore index 8ad2078d..d320e0b1 100644 --- a/.gitignore +++ b/.gitignore @@ -29,3 +29,4 @@ compaction-results/ .opencode/ syntax.jsonl out.jsonl +pi-*.html diff --git a/.pi/commands/review.md b/.pi/commands/review.md new file mode 100644 index 00000000..3e1db779 --- /dev/null +++ b/.pi/commands/review.md @@ -0,0 +1,12 @@ +--- +description: Review a file for issues +--- +Please review the following file for potential issues, bugs, or improvements: + +$1 + +Focus on: +- Logic errors +- Edge cases +- Code style +- Performance concerns diff --git a/.pi/hooks/test-command.ts b/.pi/hooks/test-command.ts new file mode 100644 index 00000000..7f737498 --- /dev/null +++ b/.pi/hooks/test-command.ts @@ -0,0 +1,86 @@ +/** + * Test hook demonstrating custom commands, message rendering, and before_agent_start. + */ +import type { BeforeAgentStartEvent, HookAPI } from "@mariozechner/pi-coding-agent"; +import { Box, Spacer, Text } from "@mariozechner/pi-tui"; + +export default function (pi: HookAPI) { + // Track whether injection is enabled + let injectEnabled = false; + + // Register a custom message renderer for our "test-info" type + pi.registerMessageRenderer("test-info", (message, options, theme) => { + const box = new Box(1, 1, (t) => theme.bg("customMessageBg", t)); + + const label = theme.fg("success", "[TEST INFO]"); + box.addChild(new Text(label, 0, 0)); + box.addChild(new Spacer(1)); + + const content = + typeof message.content === "string" + ? message.content + : message.content.map((c) => (c.type === "text" ? 
c.text : "[image]")).join(""); + + box.addChild(new Text(theme.fg("text", content), 0, 0)); + + if (options.expanded && message.details) { + box.addChild(new Spacer(1)); + box.addChild(new Text(theme.fg("dim", `Details: ${JSON.stringify(message.details)}`), 0, 0)); + } + + return box; + }); + + // Register /test-msg command + pi.registerCommand("test-msg", { + description: "Send a test custom message", + handler: async () => { + pi.sendMessage( + { + customType: "test-info", + content: "This is a test message with custom rendering!", + display: true, + details: { timestamp: Date.now(), source: "test-command hook" }, + }, + true, // triggerTurn: start agent run + ); + }, + }); + + // Register /test-hidden command + pi.registerCommand("test-hidden", { + description: "Send a hidden message (display: false)", + handler: async (ctx) => { + pi.sendMessage({ + customType: "test-info", + content: "This message is in context but not displayed", + display: false, + }); + ctx.ui.notify("Sent hidden message (check session file)"); + }, + }); + + // Register /test-inject command to toggle before_agent_start injection + pi.registerCommand("test-inject", { + description: "Toggle context injection before agent starts", + handler: async (ctx) => { + injectEnabled = !injectEnabled; + ctx.ui.notify(`Context injection ${injectEnabled ? 
"enabled" : "disabled"}`); + }, + }); + + // Demonstrate before_agent_start: inject context when enabled + pi.on("before_agent_start", async (event: BeforeAgentStartEvent) => { + if (!injectEnabled) return; + + // Return a message to inject before the user's prompt + return { + message: { + customType: "test-info", + content: `[Injected context for prompt: "${event.prompt.slice(0, 50)}..."]`, + display: true, + details: { injectedAt: Date.now() }, + }, + }; + }); +} diff --git a/AGENTS.md b/AGENTS.md index 3990816f..686a1a32 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -30,7 +30,7 @@ When reading issues: When creating issues: - Add `pkg:*` labels to indicate which package(s) the issue affects - - Available labels: `pkg:agent`, `pkg:ai`, `pkg:coding-agent`, `pkg:mom`, `pkg:pods`, `pkg:proxy`, `pkg:tui`, `pkg:web-ui` + - Available labels: `pkg:agent`, `pkg:ai`, `pkg:coding-agent`, `pkg:mom`, `pkg:pods`, `pkg:tui`, `pkg:web-ui` - If an issue spans multiple packages, add all relevant labels When closing issues via commit: @@ -39,7 +39,7 @@ When closing issues via commit: ## Tools - GitHub CLI for issues/PRs -- Add package labels to issues/PRs: pkg:agent, pkg:ai, pkg:coding-agent, pkg:mom, pkg:pods, pkg:proxy, pkg:tui, pkg:web-ui +- Add package labels to issues/PRs: pkg:agent, pkg:ai, pkg:coding-agent, pkg:mom, pkg:pods, pkg:tui, pkg:web-ui - TUI interaction: use tmux ## Style @@ -49,43 +49,37 @@ When closing issues via commit: - Technical prose only, be kind but direct (e.g., "Thanks @user" not "Thanks so much @user!") ## Changelog -Location: `packages/coding-agent/CHANGELOG.md`, `packages/ai/CHANGELOG.md`, `packages/tui/CHANGELOG.md`, pick the one relevant to the changes or ask user. 
+Location: `packages/*/CHANGELOG.md` (each package has its own) + +### Format +Use these sections under `## [Unreleased]`: +- `### Breaking Changes` - API changes requiring migration +- `### Added` - New features +- `### Changed` - Changes to existing functionality +- `### Fixed` - Bug fixes +- `### Removed` - Removed features + +### Rules - New entries ALWAYS go under `## [Unreleased]` section - NEVER modify already-released version sections (e.g., `## [0.12.2]`) - Each version section is immutable once released -- When releasing: rename `[Unreleased]` to the new version, then add a fresh empty `[Unreleased]` section -### Attribution format -- **Internal changes (from issues)**: Reference issue only - - Example: `Fixed foo bar ([#123](https://github.com/badlogic/pi-mono/issues/123))` -- **External contributions (PRs from others)**: Reference PR and credit the contributor - - Example: `Added feature X ([#456](https://github.com/badlogic/pi-mono/pull/456) by [@username](https://github.com/username))` -- If a PR addresses an issue, reference both: `([#123](...issues/123), [#456](...pull/456) by [@user](...))` or just the PR if the issue context is clear from the description +### Attribution +- **Internal changes (from issues)**: `Fixed foo bar ([#123](https://github.com/badlogic/pi-mono/issues/123))` +- **External contributions**: `Added feature X ([#456](https://github.com/badlogic/pi-mono/pull/456) by [@username](https://github.com/username))` ## Releasing -1. **Bump version** (all packages use lockstep versioning): +1. **Update CHANGELOGs**: Ensure all changes since last release are documented in the `[Unreleased]` section of each affected package's CHANGELOG.md + +2. **Run release script**: ```bash - npm run version:patch # For bug fixes - npm run version:minor # For new features - npm run version:major # For breaking changes + npm run release:patch # Bug fixes + npm run release:minor # New features + npm run release:major # Breaking changes ``` -2. 
**Finalize CHANGELOG.md**: Change `[Unreleased]` to the new version with today's date (e.g., `## [0.12.12] - 2025-12-05`) - -3. **Commit and tag**: - ```bash - git add . - git commit -m "Release v0.12.12" - git tag v0.12.12 - git push origin main - git push origin v0.12.12 - ``` - -4. **Publish to npm**: - ```bash - npm run publish - ``` - -5. **Add new [Unreleased] section** at top of CHANGELOG.md for next cycle, commit it +The script handles: version bump, CHANGELOG finalization, commit, tag, publish, and adding new `[Unreleased]` sections. +### Tool Usage +**CRITICAL**: NEVER use sed/cat to read a file or a range of a file. Always use the read tool (use offset + limit for ranged reads). \ No newline at end of file diff --git a/README.md b/README.md index d550a1df..5f4637fb 100644 --- a/README.md +++ b/README.md @@ -7,12 +7,11 @@ Tools for building AI agents and managing LLM deployments. | Package | Description | |---------|-------------| | **[@mariozechner/pi-ai](packages/ai)** | Unified multi-provider LLM API (OpenAI, Anthropic, Google, etc.) 
| -| **[@mariozechner/pi-agent](packages/agent)** | Agent runtime with tool calling and state management | +| **[@mariozechner/pi-agent-core](packages/agent)** | Agent runtime with tool calling and state management | | **[@mariozechner/pi-coding-agent](packages/coding-agent)** | Interactive coding agent CLI | | **[@mariozechner/pi-mom](packages/mom)** | Slack bot that delegates messages to the pi coding agent | | **[@mariozechner/pi-tui](packages/tui)** | Terminal UI library with differential rendering | | **[@mariozechner/pi-web-ui](packages/web-ui)** | Web components for AI chat interfaces | -| **[@mariozechner/pi-proxy](packages/proxy)** | CORS proxy for browser-based LLM API calls | | **[@mariozechner/pi-pods](packages/pods)** | CLI for managing vLLM deployments on GPU pods | ## Development @@ -71,55 +70,18 @@ These commands: ### Publishing -Complete release process: +```bash +npm run release:patch # Bug fixes +npm run release:minor # New features +npm run release:major # Breaking changes +``` -1. **Add changes to CHANGELOG.md** (if changes affect coding-agent): - ```bash - # Add your changes to the [Unreleased] section in packages/coding-agent/CHANGELOG.md - # Always add new entries under [Unreleased], never under already-released versions - ``` +This handles version bump, CHANGELOG updates, commit, tag, publish, and push. -2. **Bump version** (all packages): - ```bash - npm run version:patch # For bug fixes - npm run version:minor # For new features - npm run version:major # For breaking changes - ``` - -3. **Finalize CHANGELOG.md for release** (if changes affect coding-agent): - ```bash - # Change [Unreleased] to the new version number with today's date - # e.g., ## [0.7.16] - 2025-11-17 - # NEVER add entries to already-released version sections - # Each version section is immutable once released - ``` - -4. **Commit and tag**: - ```bash - git add . 
- git commit -m "Release v0.7.16" - git tag v0.7.16 - git push origin main - git push origin v0.7.16 - ``` - -5. **Publish to npm**: - ```bash - npm run publish # Publish all packages to npm - ``` - - **NPM Token Setup**: Publishing requires a granular access token with "Bypass 2FA on publish" enabled. - - Go to https://www.npmjs.com/settings/badlogic/tokens/ - - Create a new "Granular Access Token" - - Select "Bypass 2FA on publish" - - Tokens expire after 90 days, so regenerate when needed - - Set the token: `npm config set //registry.npmjs.org/:_authToken=YOUR_TOKEN` - -6. **Add new [Unreleased] section** (for next development cycle): - ```bash - # Add a new [Unreleased] section at the top of CHANGELOG.md - # Commit: git commit -am "Add [Unreleased] section" - ``` +**NPM Token Setup**: Requires a granular access token with "Bypass 2FA on publish" enabled. +- Go to https://www.npmjs.com/settings/badlogic/tokens/ +- Create a new "Granular Access Token" with "Bypass 2FA on publish" +- Set the token: `npm config set //registry.npmjs.org/:_authToken=YOUR_TOKEN` ## License diff --git a/package-lock.json b/package-lock.json index 7aa30116..3e84ed4d 100644 --- a/package-lock.json +++ b/package-lock.json @@ -12,6 +12,7 @@ "packages/web-ui/example" ], "dependencies": { + "@mariozechner/pi-coding-agent": "^0.30.2", "get-east-asian-width": "^1.4.0" }, "devDependencies": { @@ -695,18 +696,6 @@ } } }, - "node_modules/@hono/node-server": { - "version": "1.19.7", - "resolved": "https://registry.npmjs.org/@hono/node-server/-/node-server-1.19.7.tgz", - "integrity": "sha512-vUcD0uauS7EU2caukW8z5lJKtoGMokxNbJtBiwHgpqxEXokaHCBkQUmCHhjFB1VUTWdqj25QoMkMKzgjq+uhrw==", - "license": "MIT", - "engines": { - "node": ">=18.14.1" - }, - "peerDependencies": { - "hono": "^4" - } - }, "node_modules/@isaacs/balanced-match": { "version": "4.0.1", "resolved": "https://registry.npmjs.org/@isaacs/balanced-match/-/balanced-match-4.0.1.tgz", @@ -964,10 +953,6 @@ "resolved": "packages/mom", "link": 
true }, - "node_modules/@mariozechner/pi-proxy": { - "resolved": "packages/proxy", - "link": true - }, "node_modules/@mariozechner/pi-tui": { "resolved": "packages/tui", "link": true @@ -995,9 +980,9 @@ } }, "node_modules/@napi-rs/canvas": { - "version": "0.1.86", - "resolved": "https://registry.npmjs.org/@napi-rs/canvas/-/canvas-0.1.86.tgz", - "integrity": "sha512-hOkywnrkdFdVpsuaNsZWfEY7kc96eROV2DuMTTvGF15AZfwobzdG2w0eDlU5UBx3Lg/XlWUnqVT5zLUWyo5h6A==", + "version": "0.1.88", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas/-/canvas-0.1.88.tgz", + "integrity": "sha512-/p08f93LEbsL5mDZFQ3DBxcPv/I4QG9EDYRRq1WNlCOXVfAHBTHMSVMwxlqG/AtnSfUr9+vgfN7MKiyDo0+Weg==", "license": "MIT", "optional": true, "workspaces": [ @@ -1011,23 +996,23 @@ "url": "https://github.com/sponsors/Brooooooklyn" }, "optionalDependencies": { - "@napi-rs/canvas-android-arm64": "0.1.86", - "@napi-rs/canvas-darwin-arm64": "0.1.86", - "@napi-rs/canvas-darwin-x64": "0.1.86", - "@napi-rs/canvas-linux-arm-gnueabihf": "0.1.86", - "@napi-rs/canvas-linux-arm64-gnu": "0.1.86", - "@napi-rs/canvas-linux-arm64-musl": "0.1.86", - "@napi-rs/canvas-linux-riscv64-gnu": "0.1.86", - "@napi-rs/canvas-linux-x64-gnu": "0.1.86", - "@napi-rs/canvas-linux-x64-musl": "0.1.86", - "@napi-rs/canvas-win32-arm64-msvc": "0.1.86", - "@napi-rs/canvas-win32-x64-msvc": "0.1.86" + "@napi-rs/canvas-android-arm64": "0.1.88", + "@napi-rs/canvas-darwin-arm64": "0.1.88", + "@napi-rs/canvas-darwin-x64": "0.1.88", + "@napi-rs/canvas-linux-arm-gnueabihf": "0.1.88", + "@napi-rs/canvas-linux-arm64-gnu": "0.1.88", + "@napi-rs/canvas-linux-arm64-musl": "0.1.88", + "@napi-rs/canvas-linux-riscv64-gnu": "0.1.88", + "@napi-rs/canvas-linux-x64-gnu": "0.1.88", + "@napi-rs/canvas-linux-x64-musl": "0.1.88", + "@napi-rs/canvas-win32-arm64-msvc": "0.1.88", + "@napi-rs/canvas-win32-x64-msvc": "0.1.88" } }, "node_modules/@napi-rs/canvas-android-arm64": { - "version": "0.1.86", - "resolved": 
"https://registry.npmjs.org/@napi-rs/canvas-android-arm64/-/canvas-android-arm64-0.1.86.tgz", - "integrity": "sha512-IjkZFKUr6GzMzzrawJaN3v+yY3Fvpa71e0DcbePfxWelFKnESIir+XUcdAbim29JOd0JE0/hQJdfUCb5t/Fjrw==", + "version": "0.1.88", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-android-arm64/-/canvas-android-arm64-0.1.88.tgz", + "integrity": "sha512-KEaClPnZuVxJ8smUWjV1wWFkByBO/D+vy4lN+Dm5DFH514oqwukxKGeck9xcKJhaWJGjfruGmYGiwRe//+/zQQ==", "cpu": [ "arm64" ], @@ -1045,9 +1030,9 @@ } }, "node_modules/@napi-rs/canvas-darwin-arm64": { - "version": "0.1.86", - "resolved": "https://registry.npmjs.org/@napi-rs/canvas-darwin-arm64/-/canvas-darwin-arm64-0.1.86.tgz", - "integrity": "sha512-PUCxDq0wSSJbtaOqoKj3+t5tyDbtxWumziOTykdn3T839hu6koMaBFpGk9lXpsGaPNgyFpPqjxhtsPljBGnDHg==", + "version": "0.1.88", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-darwin-arm64/-/canvas-darwin-arm64-0.1.88.tgz", + "integrity": "sha512-Xgywz0dDxOKSgx3eZnK85WgGMmGrQEW7ZLA/E7raZdlEE+xXCozobgqz2ZvYigpB6DJFYkqnwHjqCOTSDGlFdg==", "cpu": [ "arm64" ], @@ -1065,9 +1050,9 @@ } }, "node_modules/@napi-rs/canvas-darwin-x64": { - "version": "0.1.86", - "resolved": "https://registry.npmjs.org/@napi-rs/canvas-darwin-x64/-/canvas-darwin-x64-0.1.86.tgz", - "integrity": "sha512-rlCFLv4Rrg45qFZq7mysrKnsUbMhwdNg3YPuVfo9u4RkOqm7ooAJvdyDFxiqfSsJJTqupYqa9VQCUt8WKxKhNQ==", + "version": "0.1.88", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-darwin-x64/-/canvas-darwin-x64-0.1.88.tgz", + "integrity": "sha512-Yz4wSCIQOUgNucgk+8NFtQxQxZV5NO8VKRl9ePKE6XoNyNVC8JDqtvhh3b3TPqKK8W5p2EQpAr1rjjm0mfBxdg==", "cpu": [ "x64" ], @@ -1085,9 +1070,9 @@ } }, "node_modules/@napi-rs/canvas-linux-arm-gnueabihf": { - "version": "0.1.86", - "resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-arm-gnueabihf/-/canvas-linux-arm-gnueabihf-0.1.86.tgz", - "integrity": "sha512-6xWwyMc9BlDBt+9XHN/GzUo3MozHta/2fxQHMb80x0K2zpZuAdDKUYHmYzx9dFWDY3SbPYnx6iRlQl6wxnwS1w==", + "version": "0.1.88", + "resolved": 
"https://registry.npmjs.org/@napi-rs/canvas-linux-arm-gnueabihf/-/canvas-linux-arm-gnueabihf-0.1.88.tgz", + "integrity": "sha512-9gQM2SlTo76hYhxHi2XxWTAqpTOb+JtxMPEIr+H5nAhHhyEtNmTSDRtz93SP7mGd2G3Ojf2oF5tP9OdgtgXyKg==", "cpu": [ "arm" ], @@ -1105,9 +1090,9 @@ } }, "node_modules/@napi-rs/canvas-linux-arm64-gnu": { - "version": "0.1.86", - "resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-arm64-gnu/-/canvas-linux-arm64-gnu-0.1.86.tgz", - "integrity": "sha512-r2OX3w50xHxrToTovOSQWwkVfSq752CUzH9dzlVXyr8UDKFV8dMjfa9hePXvAJhN3NBp4TkHcGx15QCdaCIwnA==", + "version": "0.1.88", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-arm64-gnu/-/canvas-linux-arm64-gnu-0.1.88.tgz", + "integrity": "sha512-7qgaOBMXuVRk9Fzztzr3BchQKXDxGbY+nwsovD3I/Sx81e+sX0ReEDYHTItNb0Je4NHbAl7D0MKyd4SvUc04sg==", "cpu": [ "arm64" ], @@ -1125,9 +1110,9 @@ } }, "node_modules/@napi-rs/canvas-linux-arm64-musl": { - "version": "0.1.86", - "resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-arm64-musl/-/canvas-linux-arm64-musl-0.1.86.tgz", - "integrity": "sha512-jbXuh8zVFUPw6a9SGpgc6EC+fRbGGyP1NFfeQiVqGLs6bN93ROtPLPL6MH9Bp6yt0CXUFallk2vgKdWDbmW+bw==", + "version": "0.1.88", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-arm64-musl/-/canvas-linux-arm64-musl-0.1.88.tgz", + "integrity": "sha512-kYyNrUsHLkoGHBc77u4Unh067GrfiCUMbGHC2+OTxbeWfZkPt2o32UOQkhnSswKd9Fko/wSqqGkY956bIUzruA==", "cpu": [ "arm64" ], @@ -1145,9 +1130,9 @@ } }, "node_modules/@napi-rs/canvas-linux-riscv64-gnu": { - "version": "0.1.86", - "resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-riscv64-gnu/-/canvas-linux-riscv64-gnu-0.1.86.tgz", - "integrity": "sha512-9IwHR2qbq2HceM9fgwyL7x37Jy3ptt1uxvikQEuWR0FisIx9QEdt7F3huljCky76aoouF2vSd0R2fHo3ESRoPw==", + "version": "0.1.88", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-riscv64-gnu/-/canvas-linux-riscv64-gnu-0.1.88.tgz", + "integrity": 
"sha512-HVuH7QgzB0yavYdNZDRyAsn/ejoXB0hn8twwFnOqUbCCdkV+REna7RXjSR7+PdfW0qMQ2YYWsLvVBT5iL/mGpw==", "cpu": [ "riscv64" ], @@ -1165,9 +1150,9 @@ } }, "node_modules/@napi-rs/canvas-linux-x64-gnu": { - "version": "0.1.86", - "resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-x64-gnu/-/canvas-linux-x64-gnu-0.1.86.tgz", - "integrity": "sha512-Jor+rhRN6ubix+D2QkNn9XlPPVAYl+2qFrkZ4oZN9UgtqIUZ+n+HljxhlkkDFRaX1mlxXOXPQjxaZg17zDSFcQ==", + "version": "0.1.88", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-x64-gnu/-/canvas-linux-x64-gnu-0.1.88.tgz", + "integrity": "sha512-hvcvKIcPEQrvvJtJnwD35B3qk6umFJ8dFIr8bSymfrSMem0EQsfn1ztys8ETIFndTwdNWJKWluvxztA41ivsEw==", "cpu": [ "x64" ], @@ -1185,9 +1170,9 @@ } }, "node_modules/@napi-rs/canvas-linux-x64-musl": { - "version": "0.1.86", - "resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-x64-musl/-/canvas-linux-x64-musl-0.1.86.tgz", - "integrity": "sha512-A28VTy91DbclopSGZ2tIon3p8hcVI1JhnNpDpJ5N9rYlUnVz1WQo4waEMh+FICTZF07O3coxBNZc4Vu4doFw7A==", + "version": "0.1.88", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-x64-musl/-/canvas-linux-x64-musl-0.1.88.tgz", + "integrity": "sha512-eSMpGYY2xnZSQ6UxYJ6plDboxq4KeJ4zT5HaVkUnbObNN6DlbJe0Mclh3wifAmquXfrlgTZt6zhHsUgz++AK6g==", "cpu": [ "x64" ], @@ -1205,9 +1190,9 @@ } }, "node_modules/@napi-rs/canvas-win32-arm64-msvc": { - "version": "0.1.86", - "resolved": "https://registry.npmjs.org/@napi-rs/canvas-win32-arm64-msvc/-/canvas-win32-arm64-msvc-0.1.86.tgz", - "integrity": "sha512-q6G1YXUt3gBCAS2bcDMCaBL4y20di8eVVBi1XhjUqZSVyZZxxwIuRQHy31NlPJUCMiyNiMuc6zeI0uqgkWwAmA==", + "version": "0.1.88", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-win32-arm64-msvc/-/canvas-win32-arm64-msvc-0.1.88.tgz", + "integrity": "sha512-qcIFfEgHrchyYqRrxsCeTQgpJZ/GqHiqPcU/Fvw/ARVlQeDX1VyFH+X+0gCR2tca6UJrq96vnW+5o7buCq+erA==", "cpu": [ "arm64" ], @@ -1225,9 +1210,9 @@ } }, "node_modules/@napi-rs/canvas-win32-x64-msvc": { - "version": "0.1.86", - 
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-win32-x64-msvc/-/canvas-win32-x64-msvc-0.1.86.tgz", - "integrity": "sha512-X0g46uRVgnvCM1cOjRXAOSFSG63ktUFIf/TIfbKCUc7QpmYUcHmSP9iR6DGOYfk+SggLsXoJCIhPTotYeZEAmg==", + "version": "0.1.88", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-win32-x64-msvc/-/canvas-win32-x64-msvc-0.1.88.tgz", + "integrity": "sha512-ROVqbfS4QyZxYkqmaIBBpbz/BQvAR+05FXM5PAtTYVc0uyY8Y4BHJSMdGAaMf6TdIVRsQsiq+FG/dH9XhvWCFQ==", "cpu": [ "x64" ], @@ -4008,16 +3993,6 @@ "node": "*" } }, - "node_modules/hono": { - "version": "4.11.2", - "resolved": "https://registry.npmjs.org/hono/-/hono-4.11.2.tgz", - "integrity": "sha512-o+avdUAD1v94oHkjGBhiMhBV4WBHxhbu0+CUVH78hhphKy/OKQLxtKjkmmNcrMlbYAhAbsM/9F+l3KnYxyD3Lg==", - "license": "MIT", - "peer": true, - "engines": { - "node": ">=16.9.0" - } - }, "node_modules/html-parse-string": { "version": "0.0.9", "resolved": "https://registry.npmjs.org/html-parse-string/-/html-parse-string-0.0.9.tgz", @@ -4583,7 +4558,6 @@ "resolved": "https://registry.npmjs.org/lit/-/lit-3.3.2.tgz", "integrity": "sha512-NF9zbsP79l4ao2SNrH3NkfmFgN/hBYSQo90saIVI1o5GpjAdCPVstVzO1MrLOakHoEhYkrtRjPK6Ob521aoYWQ==", "license": "BSD-3-Clause", - "peer": true, "dependencies": { "@lit/reactive-element": "^2.1.0", "lit-element": "^4.2.0", @@ -5704,7 +5678,6 @@ "resolved": "https://registry.npmjs.org/tailwind-merge/-/tailwind-merge-3.4.0.tgz", "integrity": "sha512-uSaO4gnW+b3Y2aWoWfFpX62vn2sR3skfhbjsEnaBI81WD1wBLlHZe5sWf0AqjksNdYTbGBEd0UasQMT3SNV15g==", "license": "MIT", - "peer": true, "funding": { "type": "github", "url": "https://github.com/sponsors/dcastil" @@ -5733,8 +5706,7 @@ "version": "4.1.18", "resolved": "https://registry.npmjs.org/tailwindcss/-/tailwindcss-4.1.18.tgz", "integrity": "sha512-4+Z+0yiYyEtUVCScyfHCxOYP06L5Ne+JiHhY2IjR2KWMIWhJOYZKLSGZaP5HkZ8+bY0cxfzwDE5uOmzFXyIwxw==", - "license": "MIT", - "peer": true + "license": "MIT" }, "node_modules/tapable": { "version": "2.3.0", @@ -5999,7 +5971,6 @@ "resolved": 
"https://registry.npmjs.org/vite/-/vite-7.3.0.tgz", "integrity": "sha512-dZwN5L1VlUBewiP6H9s2+B3e3Jg96D0vzN+Ry73sOefebhYr9f94wwkMNN/9ouoU8pV1BqA1d1zGk8928cx0rg==", "license": "MIT", - "peer": true, "dependencies": { "esbuild": "^0.27.0", "fdir": "^6.5.0", @@ -6442,9 +6413,9 @@ } }, "node_modules/zod-to-json-schema": { - "version": "3.25.0", - "resolved": "https://registry.npmjs.org/zod-to-json-schema/-/zod-to-json-schema-3.25.0.tgz", - "integrity": "sha512-HvWtU2UG41LALjajJrML6uQejQhNJx+JBO9IflpSja4R03iNWfKXrj6W2h7ljuLyc1nKS+9yDyL/9tD1U/yBnQ==", + "version": "3.25.1", + "resolved": "https://registry.npmjs.org/zod-to-json-schema/-/zod-to-json-schema-3.25.1.tgz", + "integrity": "sha512-pM/SU9d3YAggzi6MtR4h7ruuQlqKtad8e9S0fmxcMi+ueAK5Korys/aWcV9LIIHTVbj01NdzxcnXSN+O74ZIVA==", "license": "ISC", "peerDependencies": { "zod": "^3.25 || ^4" @@ -6636,22 +6607,6 @@ "node": ">=20.0.0" } }, - "packages/proxy": { - "name": "@mariozechner/pi-proxy", - "version": "0.30.2", - "dependencies": { - "@hono/node-server": "^1.14.0", - "hono": "^4.6.16" - }, - "bin": { - "pi-proxy": "dist/cli.js" - }, - "devDependencies": { - "@types/node": "^22.10.5", - "tsx": "^4.19.2", - "typescript": "^5.7.3" - } - }, "packages/tui": { "name": "@mariozechner/pi-tui", "version": "0.30.2", diff --git a/package.json b/package.json index e608b9f2..fd20638e 100644 --- a/package.json +++ b/package.json @@ -8,8 +8,8 @@ ], "scripts": { "clean": "npm run clean --workspaces", - "build": "npm run build -w @mariozechner/pi-tui && npm run build -w @mariozechner/pi-ai && npm run build -w @mariozechner/pi-agent-core && npm run build -w @mariozechner/pi-coding-agent && npm run build -w @mariozechner/pi-mom && npm run build -w @mariozechner/pi-web-ui && npm run build -w @mariozechner/pi-proxy && npm run build -w @mariozechner/pi", - "dev": "concurrently --names \"ai,agent,coding-agent,mom,web-ui,tui,proxy\" --prefix-colors \"cyan,yellow,red,white,green,magenta,blue\" \"npm run dev -w @mariozechner/pi-ai\" \"npm run 
dev -w @mariozechner/pi-agent-core\" \"npm run dev -w @mariozechner/pi-coding-agent\" \"npm run dev -w @mariozechner/pi-mom\" \"npm run dev -w @mariozechner/pi-web-ui\" \"npm run dev -w @mariozechner/pi-tui\" \"npm run dev -w @mariozechner/pi-proxy\"", + "build": "npm run build -w @mariozechner/pi-tui && npm run build -w @mariozechner/pi-ai && npm run build -w @mariozechner/pi-agent-core && npm run build -w @mariozechner/pi-coding-agent && npm run build -w @mariozechner/pi-mom && npm run build -w @mariozechner/pi-web-ui && npm run build -w @mariozechner/pi", + "dev": "concurrently --names \"ai,agent,coding-agent,mom,web-ui,tui\" --prefix-colors \"cyan,yellow,red,white,green,magenta\" \"npm run dev -w @mariozechner/pi-ai\" \"npm run dev -w @mariozechner/pi-agent-core\" \"npm run dev -w @mariozechner/pi-coding-agent\" \"npm run dev -w @mariozechner/pi-mom\" \"npm run dev -w @mariozechner/pi-web-ui\" \"npm run dev -w @mariozechner/pi-tui\"", "dev:tsc": "concurrently --names \"ai,web-ui\" --prefix-colors \"cyan,green\" \"npm run dev:tsc -w @mariozechner/pi-ai\" \"npm run dev:tsc -w @mariozechner/pi-web-ui\"", "check": "biome check --write . 
&& tsgo --noEmit && npm run check -w @mariozechner/pi-web-ui", "test": "npm run test --workspaces --if-present", @@ -20,6 +20,9 @@ "prepublishOnly": "npm run clean && npm run build && npm run check", "publish": "npm run prepublishOnly && npm publish -ws --access public", "publish:dry": "npm run prepublishOnly && npm publish -ws --access public --dry-run", + "release:patch": "node scripts/release.mjs patch", + "release:minor": "node scripts/release.mjs minor", + "release:major": "node scripts/release.mjs major", "prepare": "husky" }, "devDependencies": { @@ -36,6 +39,7 @@ }, "version": "0.0.3", "dependencies": { + "@mariozechner/pi-coding-agent": "^0.30.2", "get-east-asian-width": "^1.4.0" } } diff --git a/packages/agent/CHANGELOG.md b/packages/agent/CHANGELOG.md new file mode 100644 index 00000000..1e94f69c --- /dev/null +++ b/packages/agent/CHANGELOG.md @@ -0,0 +1,69 @@ +# Changelog + +## [Unreleased] + +### Breaking Changes + +- **Transport abstraction removed**: `ProviderTransport`, `AppTransport`, and `AgentTransport` interface have been removed. The `Agent` class now takes a `streamFn` option directly for custom streaming implementations. + +- **Agent options renamed**: + - `transport` → removed (use `streamFn` instead) + - `messageTransformer` → `convertToLlm` (converts `AgentMessage[]` to LLM-compatible `Message[]`) + - `preprocessor` → `transformContext` (transforms `AgentMessage[]` before `convertToLlm`) + +- **AppMessage renamed to AgentMessage**: All references to `AppMessage` have been renamed to `AgentMessage` for consistency. + +- **Agent loop moved from pi-ai**: The `agentLoop`, `agentLoopContinue`, and related types (`AgentContext`, `AgentEvent`, `AgentTool`, `AgentToolResult`, `AgentToolUpdateCallback`, `AgentLoopConfig`) have moved from `@mariozechner/pi-ai` to this package. + +### Added + +- **`streamFn` option**: Pass a custom stream function to the Agent for proxy backends or custom implementations. Default uses `streamSimple` from pi-ai. 
+ +- **`streamProxy` utility**: New helper function for browser apps that need to proxy through a backend server. Replaces `AppTransport`. + +- **`getApiKey` option**: Dynamic API key resolution for expiring OAuth tokens (e.g., GitHub Copilot). + +- **`AgentLoopContext` and `AgentLoopConfig`**: Exported types for the low-level agent loop API. + +- **`agentLoop` and `agentLoopContinue`**: Low-level functions for running the agent loop directly without the `Agent` class wrapper. + +### Migration Guide + +**Before (0.30.x):** +```typescript +import { Agent, ProviderTransport } from '@mariozechner/pi-agent-core'; + +const agent = new Agent({ + transport: new ProviderTransport({ apiKey: '...' }), + messageTransformer: (messages) => messages.filter(...), + preprocessor: async (messages) => compactMessages(messages) +}); +``` + +**After:** +```typescript +import { Agent } from '@mariozechner/pi-agent-core'; +import { streamSimple } from '@mariozechner/pi-ai'; + +const agent = new Agent({ + streamFn: streamSimple, // or omit for default + convertToLlm: (messages) => messages.filter(...), + transformContext: async (messages) => compactMessages(messages), + getApiKey: async (provider) => resolveApiKey(provider) +}); +``` + +**For proxy usage (replaces AppTransport):** +```typescript +import { Agent, streamProxy } from '@mariozechner/pi-agent-core'; + +const agent = new Agent({ + streamFn: (model, context, options) => streamProxy( + '/api/agent', + model, + context, + options, + { 'Authorization': 'Bearer ...' } + ) +}); +``` diff --git a/packages/agent/README.md b/packages/agent/README.md index 2e30bf74..3ac6783f 100644 --- a/packages/agent/README.md +++ b/packages/agent/README.md @@ -1,6 +1,6 @@ # @mariozechner/pi-agent-core -Stateful agent abstraction with transport layer for LLM interactions. Provides a reactive `Agent` class that manages conversation state, emits granular events, and supports pluggable transports for different deployment scenarios. 
+Stateful agent with tool execution, event streaming, and extensible message types. Built on `@mariozechner/pi-ai`. ## Installation @@ -11,12 +11,10 @@ npm install @mariozechner/pi-agent-core ## Quick Start ```typescript -import { Agent, ProviderTransport } from '@mariozechner/pi-agent-core'; +import { Agent } from '@mariozechner/pi-agent-core'; import { getModel } from '@mariozechner/pi-ai'; -// Create agent with direct provider transport const agent = new Agent({ - transport: new ProviderTransport(), initialState: { systemPrompt: 'You are a helpful assistant.', model: getModel('anthropic', 'claude-sonnet-4-20250514'), @@ -28,38 +26,105 @@ const agent = new Agent({ // Subscribe to events for reactive UI updates agent.subscribe((event) => { switch (event.type) { + case 'message_start': + console.log(`${event.message.role} message started`); + break; case 'message_update': - // Stream text to UI - const content = event.message.content; - for (const block of content) { - if (block.type === 'text') console.log(block.text); + // Only emitted for assistant messages during streaming + // event.message is partial - may have incomplete content + for (const block of event.message.content) { + if (block.type === 'text') process.stdout.write(block.text); } break; + case 'message_end': + console.log(`${event.message.role} message complete`); + break; case 'tool_execution_start': console.log(`Calling ${event.toolName}...`); break; - case 'tool_execution_update': - // Stream tool output (e.g., bash stdout) - console.log('Progress:', event.partialResult.content); - break; case 'tool_execution_end': console.log(`Result:`, event.result.content); break; } }); -// Send a prompt await agent.prompt('Hello, world!'); - -// Access conversation state console.log(agent.state.messages); ``` -## Core Concepts +## AgentMessage vs LLM Message -### Agent State +The agent internally works with `AgentMessage`, a flexible type that can include: +- Standard LLM messages (`user`, `assistant`, 
`toolResult`) +- Custom app-specific message types (via declaration merging) -The `Agent` maintains reactive state: +LLMs only understand a subset: `user`, `assistant`, and `toolResult` messages with specific content formats. The `convertToLlm` function bridges this gap. + +### Why This Separation? + +1. **Rich UI state**: Store UI-specific data (attachments metadata, custom message types) alongside the conversation +2. **Session persistence**: Save the full conversation state including app-specific messages +3. **Context manipulation**: Transform messages before sending to LLM (compaction, injection, filtering) + +### The Conversion Flow + +``` +AgentMessage[] → transformContext() → AgentMessage[] → convertToLlm() → Message[] → LLM + ↑ (optional) (required) + | + App state with custom types, + attachments, UI metadata +``` + +### Constraints + +**Messages passed to `prompt()` or queued via `queueMessage()` must convert to LLM messages with `role: "user"` or `role: "toolResult"`.** + +When calling `continue()`, the last message in the context must also convert to `user` or `toolResult`. The LLM expects to respond to a user or tool result, not to its own assistant message. + +```typescript +// OK: Standard user message +await agent.prompt('Hello'); + +// OK: Custom type that converts to user message +await agent.prompt({ role: 'hookMessage', content: 'System notification', timestamp: Date.now() }); +// But convertToLlm must handle this: +convertToLlm: (messages) => messages.map(m => { + if (m.role === 'hookMessage') { + return { role: 'user', content: m.content, timestamp: m.timestamp }; + } + return m; +}) + +// ERROR: Cannot prompt with assistant message +await agent.prompt({ role: 'assistant', content: [...], ... }); // Will fail at LLM +``` + +## Agent Options + +```typescript +interface AgentOptions { + initialState?: Partial; + + // Converts AgentMessage[] to LLM-compatible Message[] before each LLM call. 
+ // Default: filters to user/assistant/toolResult and converts image attachments. + convertToLlm?: (messages: AgentMessage[]) => Message[] | Promise; + + // Transform context before convertToLlm (for pruning, compaction, injecting context) + transformContext?: (messages: AgentMessage[], signal?: AbortSignal) => Promise; + + // Queue mode: 'all' sends all queued messages, 'one-at-a-time' sends one per turn + queueMode?: 'all' | 'one-at-a-time'; + + // Custom stream function (for proxy backends). Default: streamSimple from pi-ai + streamFn?: StreamFn; + + // Dynamic API key resolution (useful for expiring OAuth tokens) + getApiKey?: (provider: string) => Promise | string | undefined; +} +``` + +## Agent State ```typescript interface AgentState { @@ -67,17 +132,19 @@ interface AgentState { model: Model; thinkingLevel: ThinkingLevel; // 'off' | 'minimal' | 'low' | 'medium' | 'high' | 'xhigh' tools: AgentTool[]; - messages: AppMessage[]; + messages: AgentMessage[]; // Full conversation including custom types isStreaming: boolean; - streamMessage: Message | null; + streamMessage: AgentMessage | null; // Current partial message during streaming pendingToolCalls: Set; error?: string; } ``` -### Events +## Events -Events provide fine-grained lifecycle information: +Events provide fine-grained lifecycle information for building reactive UIs. 
+ +### Event Types | Event | Description | |-------|-------------| @@ -86,33 +153,118 @@ Events provide fine-grained lifecycle information: | `turn_start` | New turn begins (one LLM response + tool executions) | | `turn_end` | Turn completes with assistant message and tool results | | `message_start` | Message begins (user, assistant, or toolResult) | -| `message_update` | Assistant message streaming update | +| `message_update` | **Assistant messages only.** Partial message during streaming | | `message_end` | Message completes | | `tool_execution_start` | Tool begins execution | -| `tool_execution_update` | Tool streams progress (e.g., bash output) | +| `tool_execution_update` | Tool streams progress | | `tool_execution_end` | Tool completes with result | -### Transports +### Message Events for prompt() and queueMessage() -Transports abstract LLM communication: +When you call `prompt(message)`, the agent emits `message_start` and `message_end` events for that message before the assistant responds: -- **`ProviderTransport`**: Direct API calls using `@mariozechner/pi-ai` -- **`AppTransport`**: Proxy through a backend server (for browser apps) +``` +prompt(userMessage) + → agent_start + → turn_start + → message_start { message: userMessage } + → message_end { message: userMessage } + → message_start { message: assistantMessage } // LLM starts responding + → message_update { message: partialAssistant } // streaming... + → message_end { message: assistantMessage } + ... +``` + +Queued messages (via `queueMessage()`) emit the same events when injected: + +``` +// During tool execution, a message is queued +agent.queueMessage(interruptMessage) + +// After tool completes, before next LLM call: + → message_start { message: interruptMessage } + → message_end { message: interruptMessage } + → message_start { message: assistantMessage } // LLM responds to interrupt + ... 
+``` + +### Handling Partial Messages in Reactive UIs + +`message_update` events contain partial assistant messages during streaming. The `event.message` may have: +- Incomplete text (truncated mid-word) +- Partial tool call arguments +- Missing content blocks that haven't started streaming yet + +**Pattern for reactive UIs:** ```typescript -// Direct provider access (Node.js) -const agent = new Agent({ - transport: new ProviderTransport({ - apiKey: process.env.ANTHROPIC_API_KEY - }) -}); +agent.subscribe((event) => { + switch (event.type) { + case 'message_start': + if (event.message.role === 'assistant') { + // Create placeholder in UI + ui.addMessage({ id: tempId, role: 'assistant', content: [] }); + } + break; -// Via proxy (browser) + case 'message_update': + // Replace placeholder content with partial content + // This is only emitted for assistant messages + ui.updateMessage(tempId, event.message.content); + break; + + case 'message_end': + if (event.message.role === 'assistant') { + // Finalize with complete message + ui.finalizeMessage(tempId, event.message); + } + break; + } +}); +``` + +**Accessing the current partial message:** + +During streaming, `agent.state.streamMessage` contains the current partial message. 
This is useful for rendering outside the event handler: + +```typescript +// In a render loop or reactive binding +if (agent.state.isStreaming && agent.state.streamMessage) { + renderPartialMessage(agent.state.streamMessage); +} +``` + +## Custom Message Types + +Extend `AgentMessage` for app-specific messages via declaration merging: + +```typescript +declare module '@mariozechner/pi-agent-core' { + interface CustomAgentMessages { + artifact: { role: 'artifact'; code: string; language: string; timestamp: number }; + notification: { role: 'notification'; text: string; timestamp: number }; + } +} + +// AgentMessage now includes your custom types +const msg: AgentMessage = { role: 'artifact', code: '...', language: 'typescript', timestamp: Date.now() }; +``` + +Custom messages are stored in state but filtered out by the default `convertToLlm`. Provide your own converter to handle them: + +```typescript const agent = new Agent({ - transport: new AppTransport({ - endpoint: '/api/agent', - headers: { 'Authorization': 'Bearer ...' 
} - }) + convertToLlm: (messages) => { + return messages + .filter(m => m.role !== 'notification') // Filter out UI-only messages + .map(m => { + if (m.role === 'artifact') { + // Convert to user message so LLM sees the artifact + return { role: 'user', content: `[Artifact: ${m.language}]\n${m.code}`, timestamp: m.timestamp }; + } + return m; + }); + } }); ``` @@ -121,45 +273,76 @@ const agent = new Agent({ Queue messages to inject at the next turn: ```typescript -// Queue mode: 'all' or 'one-at-a-time' agent.setQueueMode('one-at-a-time'); -// Queue a message while agent is streaming -await agent.queueMessage({ +// Queue while agent is streaming +agent.queueMessage({ role: 'user', - content: 'Additional context...', + content: 'Stop what you are doing and focus on this instead.', timestamp: Date.now() }); ``` -## Attachments +When queued messages are detected after a tool call, remaining tool calls are skipped with error results ("Skipped due to queued user message"). The queued message is then injected before the next assistant response. 
-User messages can include attachments: +## Images + +User messages can include images: ```typescript -await agent.prompt('What is in this image?', [{ - id: 'img1', - type: 'image', - fileName: 'photo.jpg', - mimeType: 'image/jpeg', - size: 102400, - content: base64ImageData -}]); +await agent.prompt('What is in this image?', [ + { type: 'image', data: base64ImageData, mimeType: 'image/jpeg' } +]); ``` -## Custom Message Types +## Proxy Usage -Extend `AppMessage` for app-specific messages via declaration merging: +For browser apps that need to proxy through a backend, use `streamProxy`: ```typescript -declare module '@mariozechner/pi-agent-core' { - interface CustomMessages { - artifact: { role: 'artifact'; code: string; language: string }; - } +import { Agent, streamProxy } from '@mariozechner/pi-agent-core'; + +const agent = new Agent({ + streamFn: (model, context, options) => streamProxy( + '/api/agent', + model, + context, + options, + { 'Authorization': 'Bearer ...' } + ) +}); +``` + +## Low-Level API + +For more control, use `agentLoop` and `agentLoopContinue` directly: + +```typescript +import { agentLoop, agentLoopContinue, AgentLoopContext, AgentLoopConfig } from '@mariozechner/pi-agent-core'; +import { getModel, streamSimple } from '@mariozechner/pi-ai'; + +const context: AgentLoopContext = { + systemPrompt: 'You are helpful.', + messages: [], + tools: [myTool] +}; + +const config: AgentLoopConfig = { + model: getModel('openai', 'gpt-4o-mini'), + convertToLlm: (msgs) => msgs.filter(m => ['user', 'assistant', 'toolResult'].includes(m.role)) +}; + +const userMessage = { role: 'user', content: 'Hello', timestamp: Date.now() }; + +for await (const event of agentLoop(userMessage, context, config, undefined, streamSimple)) { + console.log(event.type); } -// Now AppMessage includes your custom type -const msg: AppMessage = { role: 'artifact', code: '...', language: 'typescript' }; +// Continue from existing context (e.g., after overflow recovery) +// Last 
message in context must convert to 'user' or 'toolResult' +for await (const event of agentLoopContinue(context, config, undefined, streamSimple)) { + console.log(event.type); +} ``` ## API Reference @@ -168,13 +351,14 @@ const msg: AppMessage = { role: 'artifact', code: '...', language: 'typescript' | Method | Description | |--------|-------------| -| `prompt(text, attachments?)` | Send a user prompt | -| `continue()` | Continue from current context (for retry after overflow) | +| `prompt(text, images?)` | Send a user prompt with optional images | +| `prompt(message)` | Send an AgentMessage directly (must convert to user/toolResult) | +| `continue()` | Continue from current context (last message must convert to user/toolResult) | | `abort()` | Abort current operation | -| `waitForIdle()` | Returns promise that resolves when agent is idle | +| `waitForIdle()` | Promise that resolves when agent is idle | | `reset()` | Clear all messages and state | | `subscribe(fn)` | Subscribe to events, returns unsubscribe function | -| `queueMessage(msg)` | Queue message for next turn | +| `queueMessage(msg)` | Queue message for next turn (must convert to user/toolResult) | | `clearMessageQueue()` | Clear queued messages | ### State Mutators @@ -184,7 +368,7 @@ const msg: AppMessage = { role: 'artifact', code: '...', language: 'typescript' | `setSystemPrompt(v)` | Update system prompt | | `setModel(m)` | Switch model | | `setThinkingLevel(l)` | Set reasoning level | -| `setQueueMode(m)` | Set queue mode ('all' or 'one-at-a-time') | +| `setQueueMode(m)` | Set queue mode | | `setTools(t)` | Update available tools | | `replaceMessages(ms)` | Replace all messages | | `appendMessage(m)` | Append a message | diff --git a/packages/ai/src/agent/agent-loop.ts b/packages/agent/src/agent-loop.ts similarity index 61% rename from packages/ai/src/agent/agent-loop.ts rename to packages/agent/src/agent-loop.ts index badb1bca..ee63b89f 100644 --- a/packages/ai/src/agent/agent-loop.ts +++ 
b/packages/agent/src/agent-loop.ts @@ -1,33 +1,52 @@ -import { streamSimple } from "../stream.js"; -import type { AssistantMessage, Context, Message, ToolResultMessage, UserMessage } from "../types.js"; -import { EventStream } from "../utils/event-stream.js"; -import { validateToolArguments } from "../utils/validation.js"; -import type { AgentContext, AgentEvent, AgentLoopConfig, AgentTool, AgentToolResult, QueuedMessage } from "./types.js"; +/** + * Agent loop that works with AgentMessage throughout. + * Transforms to Message[] only at the LLM call boundary. + */ + +import { + type AssistantMessage, + type Context, + EventStream, + streamSimple, + type ToolResultMessage, + validateToolArguments, +} from "@mariozechner/pi-ai"; +import type { + AgentContext, + AgentEvent, + AgentLoopConfig, + AgentMessage, + AgentTool, + AgentToolResult, + StreamFn, +} from "./types.js"; /** - * Start an agent loop with a new user message. + * Start an agent loop with a new prompt message. * The prompt is added to the context and events are emitted for it. 
*/ export function agentLoop( - prompt: UserMessage, + prompts: AgentMessage[], context: AgentContext, config: AgentLoopConfig, signal?: AbortSignal, - streamFn?: typeof streamSimple, -): EventStream { + streamFn?: StreamFn, +): EventStream { const stream = createAgentStream(); (async () => { - const newMessages: AgentContext["messages"] = [prompt]; + const newMessages: AgentMessage[] = [...prompts]; const currentContext: AgentContext = { ...context, - messages: [...context.messages, prompt], + messages: [...context.messages, ...prompts], }; stream.push({ type: "agent_start" }); stream.push({ type: "turn_start" }); - stream.push({ type: "message_start", message: prompt }); - stream.push({ type: "message_end", message: prompt }); + for (const prompt of prompts) { + stream.push({ type: "message_start", message: prompt }); + stream.push({ type: "message_end", message: prompt }); + } await runLoop(currentContext, newMessages, config, signal, stream, streamFn); })(); @@ -37,33 +56,34 @@ export function agentLoop( /** * Continue an agent loop from the current context without adding a new message. - * Used for retry after overflow - context already has user message or tool results. - * Throws if the last message is not a user message or tool result. + * Used for retries - context already has user message or tool results. + * + * **Important:** The last message in context must convert to a `user` or `toolResult` message + * via `convertToLlm`. If it doesn't, the LLM provider will reject the request. + * This cannot be validated here since `convertToLlm` is only called once per turn. 
*/ export function agentLoopContinue( context: AgentContext, config: AgentLoopConfig, signal?: AbortSignal, - streamFn?: typeof streamSimple, -): EventStream { - // Validate that we can continue from this context - const lastMessage = context.messages[context.messages.length - 1]; - if (!lastMessage) { + streamFn?: StreamFn, +): EventStream { + if (context.messages.length === 0) { throw new Error("Cannot continue: no messages in context"); } - if (lastMessage.role !== "user" && lastMessage.role !== "toolResult") { - throw new Error(`Cannot continue from message role: ${lastMessage.role}. Expected 'user' or 'toolResult'.`); + + if (context.messages[context.messages.length - 1].role === "assistant") { + throw new Error("Cannot continue from message role: assistant"); } const stream = createAgentStream(); (async () => { - const newMessages: AgentContext["messages"] = []; + const newMessages: AgentMessage[] = []; const currentContext: AgentContext = { ...context }; stream.push({ type: "agent_start" }); stream.push({ type: "turn_start" }); - // No user message events - we're continuing from existing context await runLoop(currentContext, newMessages, config, signal, stream, streamFn); })(); @@ -71,28 +91,28 @@ export function agentLoopContinue( return stream; } -function createAgentStream(): EventStream { - return new EventStream( +function createAgentStream(): EventStream { + return new EventStream( (event: AgentEvent) => event.type === "agent_end", (event: AgentEvent) => (event.type === "agent_end" ? event.messages : []), ); } /** - * Shared loop logic for both agentLoop and agentLoopContinue. + * Main loop logic shared by agentLoop and agentLoopContinue. 
*/ async function runLoop( currentContext: AgentContext, - newMessages: AgentContext["messages"], + newMessages: AgentMessage[], config: AgentLoopConfig, signal: AbortSignal | undefined, - stream: EventStream, - streamFn?: typeof streamSimple, + stream: EventStream, + streamFn?: StreamFn, ): Promise { let hasMoreToolCalls = true; let firstTurn = true; - let queuedMessages: QueuedMessage[] = (await config.getQueuedMessages?.()) || []; - let queuedAfterTools: QueuedMessage[] | null = null; + let queuedMessages: AgentMessage[] = (await config.getQueuedMessages?.()) || []; + let queuedAfterTools: AgentMessage[] | null = null; while (hasMoreToolCalls || queuedMessages.length > 0) { if (!firstTurn) { @@ -101,15 +121,13 @@ async function runLoop( firstTurn = false; } - // Process queued messages first (inject before next assistant response) + // Process queued messages (inject before next assistant response) if (queuedMessages.length > 0) { - for (const { original, llm } of queuedMessages) { - stream.push({ type: "message_start", message: original }); - stream.push({ type: "message_end", message: original }); - if (llm) { - currentContext.messages.push(llm); - newMessages.push(llm); - } + for (const message of queuedMessages) { + stream.push({ type: "message_start", message }); + stream.push({ type: "message_end", message }); + currentContext.messages.push(message); + newMessages.push(message); } queuedMessages = []; } @@ -119,7 +137,6 @@ async function runLoop( newMessages.push(message); if (message.stopReason === "error" || message.stopReason === "aborted") { - // Stop the loop on error or abort stream.push({ type: "turn_end", message, toolResults: [] }); stream.push({ type: "agent_end", messages: newMessages }); stream.end(newMessages); @@ -132,7 +149,6 @@ async function runLoop( const toolResults: ToolResultMessage[] = []; if (hasMoreToolCalls) { - // Execute tool calls const toolExecution = await executeToolCalls( currentContext.tools, message, @@ -142,10 +158,14 @@ 
async function runLoop( ); toolResults.push(...toolExecution.toolResults); queuedAfterTools = toolExecution.queuedMessages ?? null; - currentContext.messages.push(...toolResults); - newMessages.push(...toolResults); + + for (const result of toolResults) { + currentContext.messages.push(result); + newMessages.push(result); + } } - stream.push({ type: "turn_end", message, toolResults: toolResults }); + + stream.push({ type: "turn_end", message, toolResults }); // Get queued messages after turn completes if (queuedAfterTools && queuedAfterTools.length > 0) { @@ -160,41 +180,44 @@ async function runLoop( stream.end(newMessages); } -// Helper functions +/** + * Stream an assistant response from the LLM. + * This is where AgentMessage[] gets transformed to Message[] for the LLM. + */ async function streamAssistantResponse( context: AgentContext, config: AgentLoopConfig, signal: AbortSignal | undefined, - stream: EventStream, - streamFn?: typeof streamSimple, + stream: EventStream, + streamFn?: StreamFn, ): Promise { - // Convert AgentContext to Context for streamSimple - // Use a copy of messages to avoid mutating the original context - const processedMessages = config.preprocessor - ? 
await config.preprocessor(context.messages, signal) - : [...context.messages]; - const processedContext: Context = { + // Apply context transform if configured (AgentMessage[] → AgentMessage[]) + let messages = context.messages; + if (config.transformContext) { + messages = await config.transformContext(messages, signal); + } + + // Convert to LLM-compatible messages (AgentMessage[] → Message[]) + const llmMessages = await config.convertToLlm(messages); + + // Build LLM context + const llmContext: Context = { systemPrompt: context.systemPrompt, - messages: [...processedMessages].map((m) => { - if (m.role === "toolResult") { - // biome-ignore lint/correctness/noUnusedVariables: fine here - const { details, ...rest } = m; - return rest; - } else { - return m; - } - }), - tools: context.tools, // AgentTool extends Tool, so this works + messages: llmMessages, + tools: context.tools, }; - // Use custom stream function if provided, otherwise use default streamSimple const streamFunction = streamFn || streamSimple; - // Resolve API key for every assistant response (important for expiring tokens) + // Resolve API key (important for expiring tokens) const resolvedApiKey = (config.getApiKey ? 
await config.getApiKey(config.model.provider) : undefined) || config.apiKey; - const response = await streamFunction(config.model, processedContext, { ...config, apiKey: resolvedApiKey, signal }); + const response = await streamFunction(config.model, llmContext, { + ...config, + apiKey: resolvedApiKey, + signal, + }); let partialMessage: AssistantMessage | null = null; let addedPartial = false; @@ -220,7 +243,11 @@ async function streamAssistantResponse( if (partialMessage) { partialMessage = event.partial; context.messages[context.messages.length - 1] = partialMessage; - stream.push({ type: "message_update", assistantMessageEvent: event, message: { ...partialMessage } }); + stream.push({ + type: "message_update", + assistantMessageEvent: event, + message: { ...partialMessage }, + }); } break; @@ -244,16 +271,19 @@ async function streamAssistantResponse( return await response.result(); } -async function executeToolCalls( - tools: AgentTool[] | undefined, +/** + * Execute tool calls from an assistant message. 
+ */ +async function executeToolCalls( + tools: AgentTool[] | undefined, assistantMessage: AssistantMessage, signal: AbortSignal | undefined, - stream: EventStream, + stream: EventStream, getQueuedMessages?: AgentLoopConfig["getQueuedMessages"], -): Promise<{ toolResults: ToolResultMessage[]; queuedMessages?: QueuedMessage[] }> { +): Promise<{ toolResults: ToolResultMessage[]; queuedMessages?: AgentMessage[] }> { const toolCalls = assistantMessage.content.filter((c) => c.type === "toolCall"); - const results: ToolResultMessage[] = []; - let queuedMessages: QueuedMessage[] | undefined; + const results: ToolResultMessage[] = []; + let queuedMessages: AgentMessage[] | undefined; for (let index = 0; index < toolCalls.length; index++) { const toolCall = toolCalls[index]; @@ -266,16 +296,14 @@ async function executeToolCalls( args: toolCall.arguments, }); - let result: AgentToolResult; + let result: AgentToolResult; let isError = false; try { if (!tool) throw new Error(`Tool ${toolCall.name} not found`); - // Validate arguments using shared validation function const validatedArgs = validateToolArguments(tool, toolCall); - // Execute with validated, typed arguments, passing update callback result = await tool.execute(toolCall.id, validatedArgs, signal, (partialResult) => { stream.push({ type: "tool_execution_update", @@ -288,7 +316,7 @@ async function executeToolCalls( } catch (e) { result = { content: [{ type: "text", text: e instanceof Error ? 
e.message : String(e) }], - details: {} as T, + details: {}, }; isError = true; } @@ -301,7 +329,7 @@ async function executeToolCalls( isError, }); - const toolResultMessage: ToolResultMessage = { + const toolResultMessage: ToolResultMessage = { role: "toolResult", toolCallId: toolCall.id, toolName: toolCall.name, @@ -315,6 +343,7 @@ async function executeToolCalls( stream.push({ type: "message_start", message: toolResultMessage }); stream.push({ type: "message_end", message: toolResultMessage }); + // Check for queued messages - skip remaining tools if user interrupted if (getQueuedMessages) { const queued = await getQueuedMessages(); if (queued.length > 0) { @@ -331,13 +360,13 @@ async function executeToolCalls( return { toolResults: results, queuedMessages }; } -function skipToolCall( +function skipToolCall( toolCall: Extract, - stream: EventStream, -): ToolResultMessage { - const result: AgentToolResult = { + stream: EventStream, +): ToolResultMessage { + const result: AgentToolResult = { content: [{ type: "text", text: "Skipped due to queued user message." 
}], - details: {} as T, + details: {}, }; stream.push({ @@ -354,12 +383,12 @@ function skipToolCall( isError: true, }); - const toolResultMessage: ToolResultMessage = { + const toolResultMessage: ToolResultMessage = { role: "toolResult", toolCallId: toolCall.id, toolName: toolCall.name, content: result.content, - details: result.details, + details: {}, isError: true, timestamp: Date.now(), }; diff --git a/packages/agent/src/agent.ts b/packages/agent/src/agent.ts index 86febabb..078b707e 100644 --- a/packages/agent/src/agent.ts +++ b/packages/agent/src/agent.ts @@ -1,62 +1,66 @@ -import type { ImageContent, Message, QueuedMessage, ReasoningEffort, TextContent } from "@mariozechner/pi-ai"; -import { getModel } from "@mariozechner/pi-ai"; -import type { AgentTransport } from "./transports/types.js"; -import type { AgentEvent, AgentState, AppMessage, Attachment, ThinkingLevel } from "./types.js"; +/** + * Agent class that uses the agent-loop directly. + * No transport abstraction - calls streamSimple via the loop. + */ + +import { + getModel, + type ImageContent, + type Message, + type Model, + type ReasoningEffort, + streamSimple, + type TextContent, +} from "@mariozechner/pi-ai"; +import { agentLoop, agentLoopContinue } from "./agent-loop.js"; +import type { + AgentContext, + AgentEvent, + AgentLoopConfig, + AgentMessage, + AgentState, + AgentTool, + StreamFn, + ThinkingLevel, +} from "./types.js"; /** - * Default message transformer: Keep only LLM-compatible messages, strip app-specific fields. - * Converts attachments to proper content blocks (images → ImageContent, documents → TextContent). + * Default convertToLlm: Keep only LLM-compatible messages, convert attachments. 
*/ -function defaultMessageTransformer(messages: AppMessage[]): Message[] { - return messages - .filter((m) => { - // Only keep standard LLM message roles - return m.role === "user" || m.role === "assistant" || m.role === "toolResult"; - }) - .map((m) => { - if (m.role === "user") { - const { attachments, ...rest } = m as any; - - // If no attachments, return as-is - if (!attachments || attachments.length === 0) { - return rest as Message; - } - - // Convert attachments to content blocks - const content = Array.isArray(rest.content) ? [...rest.content] : [{ type: "text", text: rest.content }]; - - for (const attachment of attachments as Attachment[]) { - // Add image blocks for image attachments - if (attachment.type === "image") { - content.push({ - type: "image", - data: attachment.content, - mimeType: attachment.mimeType, - } as ImageContent); - } - // Add text blocks for documents with extracted text - else if (attachment.type === "document" && attachment.extractedText) { - content.push({ - type: "text", - text: `\n\n[Document: ${attachment.fileName}]\n${attachment.extractedText}`, - isDocument: true, - } as TextContent); - } - } - - return { ...rest, content } as Message; - } - return m as Message; - }); +function defaultConvertToLlm(messages: AgentMessage[]): Message[] { + return messages.filter((m) => m.role === "user" || m.role === "assistant" || m.role === "toolResult"); } export interface AgentOptions { initialState?: Partial; - transport: AgentTransport; - // Transform app messages to LLM-compatible messages before sending to transport - messageTransformer?: (messages: AppMessage[]) => Message[] | Promise; - // Queue mode: "all" = send all queued messages at once, "one-at-a-time" = send one queued message per turn + + /** + * Converts AgentMessage[] to LLM-compatible Message[] before each LLM call. + * Default filters to user/assistant/toolResult and converts attachments. 
+ */ + convertToLlm?: (messages: AgentMessage[]) => Message[] | Promise; + + /** + * Optional transform applied to context before convertToLlm. + * Use for context pruning, injecting external context, etc. + */ + transformContext?: (messages: AgentMessage[], signal?: AbortSignal) => Promise; + + /** + * Queue mode: "all" = send all queued messages at once, "one-at-a-time" = one per turn + */ queueMode?: "all" | "one-at-a-time"; + + /** + * Custom stream function (for proxy backends, etc.). Default uses streamSimple. + */ + streamFn?: StreamFn; + + /** + * Resolves an API key dynamically for each LLM call. + * Useful for expiring tokens (e.g., GitHub Copilot OAuth). + */ + getApiKey?: (provider: string) => Promise | string | undefined; } export class Agent { @@ -71,20 +75,25 @@ export class Agent { pendingToolCalls: new Set(), error: undefined, }; + private listeners = new Set<(e: AgentEvent) => void>(); private abortController?: AbortController; - private transport: AgentTransport; - private messageTransformer: (messages: AppMessage[]) => Message[] | Promise; - private messageQueue: Array> = []; + private convertToLlm: (messages: AgentMessage[]) => Message[] | Promise; + private transformContext?: (messages: AgentMessage[], signal?: AbortSignal) => Promise; + private messageQueue: AgentMessage[] = []; private queueMode: "all" | "one-at-a-time"; + public streamFn: StreamFn; + public getApiKey?: (provider: string) => Promise | string | undefined; private runningPrompt?: Promise; private resolveRunningPrompt?: () => void; - constructor(opts: AgentOptions) { + constructor(opts: AgentOptions = {}) { this._state = { ...this._state, ...opts.initialState }; - this.transport = opts.transport; - this.messageTransformer = opts.messageTransformer || defaultMessageTransformer; + this.convertToLlm = opts.convertToLlm || defaultConvertToLlm; + this.transformContext = opts.transformContext; this.queueMode = opts.queueMode || "one-at-a-time"; + this.streamFn = opts.streamFn || 
streamSimple; + this.getApiKey = opts.getApiKey; } get state(): AgentState { @@ -96,12 +105,12 @@ export class Agent { return () => this.listeners.delete(fn); } - // State mutators - update internal state without emitting events + // State mutators setSystemPrompt(v: string) { this._state.systemPrompt = v; } - setModel(m: typeof this._state.model) { + setModel(m: Model) { this._state.model = m; } @@ -117,25 +126,20 @@ export class Agent { return this.queueMode; } - setTools(t: typeof this._state.tools) { + setTools(t: AgentTool[]) { this._state.tools = t; } - replaceMessages(ms: AppMessage[]) { + replaceMessages(ms: AgentMessage[]) { this._state.messages = ms.slice(); } - appendMessage(m: AppMessage) { + appendMessage(m: AgentMessage) { this._state.messages = [...this._state.messages, m]; } - async queueMessage(m: AppMessage) { - // Transform message and queue it for injection at next turn - const transformed = await this.messageTransformer([m]); - this.messageQueue.push({ - original: m, - llm: transformed[0], // undefined if filtered out - }); + queueMessage(m: AgentMessage) { + this.messageQueue.push(m); } clearMessageQueue() { @@ -150,17 +154,10 @@ export class Agent { this.abortController?.abort(); } - /** - * Returns a promise that resolves when the current prompt completes. - * Returns immediately resolved promise if no prompt is running. - */ waitForIdle(): Promise { return this.runningPrompt ?? Promise.resolve(); } - /** - * Clear all messages and state. Call abort() first if a prompt is in flight. 
- */ reset() { this._state.messages = []; this._state.isStreaming = false; @@ -170,86 +167,57 @@ export class Agent { this.messageQueue = []; } - async prompt(input: string, attachments?: Attachment[]) { + /** Send a prompt with an AgentMessage */ + async prompt(message: AgentMessage | AgentMessage[]): Promise; + async prompt(input: string, images?: ImageContent[]): Promise; + async prompt(input: string | AgentMessage | AgentMessage[], images?: ImageContent[]) { const model = this._state.model; - if (!model) { - throw new Error("No model configured"); - } + if (!model) throw new Error("No model configured"); - // Build user message with attachments - const content: Array = [{ type: "text", text: input }]; - if (attachments?.length) { - for (const a of attachments) { - if (a.type === "image") { - content.push({ type: "image", data: a.content, mimeType: a.mimeType }); - } else if (a.type === "document" && a.extractedText) { - content.push({ - type: "text", - text: `\n\n[Document: ${a.fileName}]\n${a.extractedText}`, - isDocument: true, - } as TextContent); - } + let msgs: AgentMessage[]; + + if (Array.isArray(input)) { + msgs = input; + } else if (typeof input === "string") { + const content: Array = [{ type: "text", text: input }]; + if (images && images.length > 0) { + content.push(...images); } + msgs = [ + { + role: "user", + content, + timestamp: Date.now(), + }, + ]; + } else { + msgs = [input]; } - const userMessage: AppMessage = { - role: "user", - content, - attachments: attachments?.length ? attachments : undefined, - timestamp: Date.now(), - }; - - await this._runAgentLoop(userMessage); + await this._runLoop(msgs); } - /** - * Continue from the current context without adding a new user message. - * Used for retry after overflow recovery when context already has user message or tool results. 
- */ + /** Continue from current context (for retry after overflow) */ async continue() { const messages = this._state.messages; if (messages.length === 0) { throw new Error("No messages to continue from"); } - - const lastMessage = messages[messages.length - 1]; - if (lastMessage.role !== "user" && lastMessage.role !== "toolResult") { - throw new Error(`Cannot continue from message role: ${lastMessage.role}`); + if (messages[messages.length - 1].role === "assistant") { + throw new Error("Cannot continue from message role: assistant"); } - await this._runAgentLoopContinue(); + await this._runLoop(undefined); } /** - * Internal: Run the agent loop with a new user message. + * Run the agent loop. + * If messages are provided, starts a new conversation turn with those messages. + * Otherwise, continues from existing context. */ - private async _runAgentLoop(userMessage: AppMessage) { - const { llmMessages, cfg } = await this._prepareRun(); - - const events = this.transport.run(llmMessages, userMessage as Message, cfg, this.abortController!.signal); - - await this._processEvents(events); - } - - /** - * Internal: Continue the agent loop from current context. - */ - private async _runAgentLoopContinue() { - const { llmMessages, cfg } = await this._prepareRun(); - - const events = this.transport.continue(llmMessages, cfg, this.abortController!.signal); - - await this._processEvents(events); - } - - /** - * Prepare for running the agent loop. - */ - private async _prepareRun() { + private async _runLoop(messages?: AgentMessage[]) { const model = this._state.model; - if (!model) { - throw new Error("No model configured"); - } + if (!model) throw new Error("No model configured"); this.runningPrompt = new Promise((resolve) => { this.resolveRunningPrompt = resolve; @@ -265,87 +233,90 @@ export class Agent { ? undefined : this._state.thinkingLevel === "minimal" ? 
"low" - : this._state.thinkingLevel; + : (this._state.thinkingLevel as ReasoningEffort); - const cfg = { + const context: AgentContext = { systemPrompt: this._state.systemPrompt, + messages: this._state.messages.slice(), tools: this._state.tools, + }; + + const config: AgentLoopConfig = { model, reasoning, - getQueuedMessages: async () => { + convertToLlm: this.convertToLlm, + transformContext: this.transformContext, + getApiKey: this.getApiKey, + getQueuedMessages: async () => { if (this.queueMode === "one-at-a-time") { if (this.messageQueue.length > 0) { const first = this.messageQueue[0]; this.messageQueue = this.messageQueue.slice(1); - return [first] as QueuedMessage[]; + return [first]; } return []; } else { const queued = this.messageQueue.slice(); this.messageQueue = []; - return queued as QueuedMessage[]; + return queued; } }, }; - const llmMessages = await this.messageTransformer(this._state.messages); - - return { llmMessages, cfg, model }; - } - - /** - * Process events from the transport. - */ - private async _processEvents(events: AsyncIterable) { - const model = this._state.model!; - const generatedMessages: AppMessage[] = []; - let partial: AppMessage | null = null; + let partial: AgentMessage | null = null; try { - for await (const ev of events) { - switch (ev.type) { - case "message_start": { - partial = ev.message as AppMessage; - this._state.streamMessage = ev.message as Message; + const stream = messages + ? 
agentLoop(messages, context, config, this.abortController.signal, this.streamFn) + : agentLoopContinue(context, config, this.abortController.signal, this.streamFn); + + for await (const event of stream) { + // Update internal state based on events + switch (event.type) { + case "message_start": + partial = event.message; + this._state.streamMessage = event.message; break; - } - case "message_update": { - partial = ev.message as AppMessage; - this._state.streamMessage = ev.message as Message; + + case "message_update": + partial = event.message; + this._state.streamMessage = event.message; break; - } - case "message_end": { + + case "message_end": partial = null; this._state.streamMessage = null; - this.appendMessage(ev.message as AppMessage); - generatedMessages.push(ev.message as AppMessage); + this.appendMessage(event.message); break; - } + case "tool_execution_start": { const s = new Set(this._state.pendingToolCalls); - s.add(ev.toolCallId); + s.add(event.toolCallId); this._state.pendingToolCalls = s; break; } + case "tool_execution_end": { const s = new Set(this._state.pendingToolCalls); - s.delete(ev.toolCallId); + s.delete(event.toolCallId); this._state.pendingToolCalls = s; break; } - case "turn_end": { - if (ev.message.role === "assistant" && ev.message.errorMessage) { - this._state.error = ev.message.errorMessage; + + case "turn_end": + if (event.message.role === "assistant" && (event.message as any).errorMessage) { + this._state.error = (event.message as any).errorMessage; } break; - } - case "agent_end": { + + case "agent_end": + this._state.isStreaming = false; this._state.streamMessage = null; break; - } } - this.emit(ev as AgentEvent); + // Emit to listeners + this.emit(event); } // Handle any remaining partial message @@ -357,8 +328,7 @@ export class Agent { (c.type === "toolCall" && c.name.trim().length > 0), ); if (!onlyEmpty) { - this.appendMessage(partial as AppMessage); - generatedMessages.push(partial as AppMessage); + 
this.appendMessage(partial); } else { if (this.abortController?.signal.aborted) { throw new Error("Request was aborted"); @@ -366,7 +336,7 @@ export class Agent { } } } catch (err: any) { - const msg: Message = { + const errorMsg: AgentMessage = { role: "assistant", content: [{ type: "text", text: "" }], api: model.api, @@ -383,10 +353,11 @@ export class Agent { stopReason: this.abortController?.signal.aborted ? "aborted" : "error", errorMessage: err?.message || String(err), timestamp: Date.now(), - }; - this.appendMessage(msg as AppMessage); - generatedMessages.push(msg as AppMessage); + } as AgentMessage; + + this.appendMessage(errorMsg); this._state.error = err?.message || String(err); + this.emit({ type: "agent_end", messages: [errorMsg] }); } finally { this._state.isStreaming = false; this._state.streamMessage = null; diff --git a/packages/agent/src/index.ts b/packages/agent/src/index.ts index 95a5094a..1f287856 100644 --- a/packages/agent/src/index.ts +++ b/packages/agent/src/index.ts @@ -1,22 +1,6 @@ // Core Agent -export { Agent, type AgentOptions } from "./agent.js"; -// Transports -export { - type AgentRunConfig, - type AgentTransport, - AppTransport, - type AppTransportOptions, - ProviderTransport, - type ProviderTransportOptions, - type ProxyAssistantMessageEvent, -} from "./transports/index.js"; +export * from "./agent.js"; +// Loop functions +export * from "./agent-loop.js"; // Types -export type { - AgentEvent, - AgentState, - AppMessage, - Attachment, - CustomMessages, - ThinkingLevel, - UserMessageWithAttachments, -} from "./types.js"; +export * from "./types.js"; diff --git a/packages/agent/src/proxy.ts b/packages/agent/src/proxy.ts new file mode 100644 index 00000000..ba098cd6 --- /dev/null +++ b/packages/agent/src/proxy.ts @@ -0,0 +1,340 @@ +/** + * Proxy stream function for apps that route LLM calls through a server. + * The server manages auth and proxies requests to LLM providers. 
+ */ + +import { + type AssistantMessage, + type AssistantMessageEvent, + type Context, + EventStream, + type Model, + type SimpleStreamOptions, + type StopReason, + type ToolCall, +} from "@mariozechner/pi-ai"; +// Internal import for JSON parsing utility +import { parseStreamingJson } from "@mariozechner/pi-ai/dist/utils/json-parse.js"; + +// Create stream class matching ProxyMessageEventStream +class ProxyMessageEventStream extends EventStream { + constructor() { + super( + (event) => event.type === "done" || event.type === "error", + (event) => { + if (event.type === "done") return event.message; + if (event.type === "error") return event.error; + throw new Error("Unexpected event type"); + }, + ); + } +} + +/** + * Proxy event types - server sends these with partial field stripped to reduce bandwidth. + */ +export type ProxyAssistantMessageEvent = + | { type: "start" } + | { type: "text_start"; contentIndex: number } + | { type: "text_delta"; contentIndex: number; delta: string } + | { type: "text_end"; contentIndex: number; contentSignature?: string } + | { type: "thinking_start"; contentIndex: number } + | { type: "thinking_delta"; contentIndex: number; delta: string } + | { type: "thinking_end"; contentIndex: number; contentSignature?: string } + | { type: "toolcall_start"; contentIndex: number; id: string; toolName: string } + | { type: "toolcall_delta"; contentIndex: number; delta: string } + | { type: "toolcall_end"; contentIndex: number } + | { + type: "done"; + reason: Extract; + usage: AssistantMessage["usage"]; + } + | { + type: "error"; + reason: Extract; + errorMessage?: string; + usage: AssistantMessage["usage"]; + }; + +export interface ProxyStreamOptions extends SimpleStreamOptions { + /** Auth token for the proxy server */ + authToken: string; + /** Proxy server URL (e.g., "https://genai.example.com") */ + proxyUrl: string; +} + +/** + * Stream function that proxies through a server instead of calling LLM providers directly. 
+ * The server strips the partial field from delta events to reduce bandwidth. + * We reconstruct the partial message client-side. + * + * Use this as the `streamFn` option when creating an Agent that needs to go through a proxy. + * + * @example + * ```typescript + * const agent = new Agent({ + * streamFn: (model, context, options) => + * streamProxy(model, context, { + * ...options, + * authToken: await getAuthToken(), + * proxyUrl: "https://genai.example.com", + * }), + * }); + * ``` + */ +export function streamProxy(model: Model, context: Context, options: ProxyStreamOptions): ProxyMessageEventStream { + const stream = new ProxyMessageEventStream(); + + (async () => { + // Initialize the partial message that we'll build up from events + const partial: AssistantMessage = { + role: "assistant", + stopReason: "stop", + content: [], + api: model.api, + provider: model.provider, + model: model.id, + usage: { + input: 0, + output: 0, + cacheRead: 0, + cacheWrite: 0, + totalTokens: 0, + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, + }, + timestamp: Date.now(), + }; + + let reader: ReadableStreamDefaultReader | undefined; + + const abortHandler = () => { + if (reader) { + reader.cancel("Request aborted by user").catch(() => {}); + } + }; + + if (options.signal) { + options.signal.addEventListener("abort", abortHandler); + } + + try { + const response = await fetch(`${options.proxyUrl}/api/stream`, { + method: "POST", + headers: { + Authorization: `Bearer ${options.authToken}`, + "Content-Type": "application/json", + }, + body: JSON.stringify({ + model, + context, + options: { + temperature: options.temperature, + maxTokens: options.maxTokens, + reasoning: options.reasoning, + }, + }), + signal: options.signal, + }); + + if (!response.ok) { + let errorMessage = `Proxy error: ${response.status} ${response.statusText}`; + try { + const errorData = (await response.json()) as { error?: string }; + if (errorData.error) { + errorMessage = `Proxy 
error: ${errorData.error}`; + } + } catch { + // Couldn't parse error response + } + throw new Error(errorMessage); + } + + reader = response.body!.getReader(); + const decoder = new TextDecoder(); + let buffer = ""; + + while (true) { + const { done, value } = await reader.read(); + if (done) break; + + if (options.signal?.aborted) { + throw new Error("Request aborted by user"); + } + + buffer += decoder.decode(value, { stream: true }); + const lines = buffer.split("\n"); + buffer = lines.pop() || ""; + + for (const line of lines) { + if (line.startsWith("data: ")) { + const data = line.slice(6).trim(); + if (data) { + const proxyEvent = JSON.parse(data) as ProxyAssistantMessageEvent; + const event = processProxyEvent(proxyEvent, partial); + if (event) { + stream.push(event); + } + } + } + } + } + + if (options.signal?.aborted) { + throw new Error("Request aborted by user"); + } + + stream.end(); + } catch (error) { + const errorMessage = error instanceof Error ? error.message : String(error); + const reason = options.signal?.aborted ? "aborted" : "error"; + partial.stopReason = reason; + partial.errorMessage = errorMessage; + stream.push({ + type: "error", + reason, + error: partial, + }); + stream.end(); + } finally { + if (options.signal) { + options.signal.removeEventListener("abort", abortHandler); + } + } + })(); + + return stream; +} + +/** + * Process a proxy event and update the partial message. 
+ */ +function processProxyEvent( + proxyEvent: ProxyAssistantMessageEvent, + partial: AssistantMessage, +): AssistantMessageEvent | undefined { + switch (proxyEvent.type) { + case "start": + return { type: "start", partial }; + + case "text_start": + partial.content[proxyEvent.contentIndex] = { type: "text", text: "" }; + return { type: "text_start", contentIndex: proxyEvent.contentIndex, partial }; + + case "text_delta": { + const content = partial.content[proxyEvent.contentIndex]; + if (content?.type === "text") { + content.text += proxyEvent.delta; + return { + type: "text_delta", + contentIndex: proxyEvent.contentIndex, + delta: proxyEvent.delta, + partial, + }; + } + throw new Error("Received text_delta for non-text content"); + } + + case "text_end": { + const content = partial.content[proxyEvent.contentIndex]; + if (content?.type === "text") { + content.textSignature = proxyEvent.contentSignature; + return { + type: "text_end", + contentIndex: proxyEvent.contentIndex, + content: content.text, + partial, + }; + } + throw new Error("Received text_end for non-text content"); + } + + case "thinking_start": + partial.content[proxyEvent.contentIndex] = { type: "thinking", thinking: "" }; + return { type: "thinking_start", contentIndex: proxyEvent.contentIndex, partial }; + + case "thinking_delta": { + const content = partial.content[proxyEvent.contentIndex]; + if (content?.type === "thinking") { + content.thinking += proxyEvent.delta; + return { + type: "thinking_delta", + contentIndex: proxyEvent.contentIndex, + delta: proxyEvent.delta, + partial, + }; + } + throw new Error("Received thinking_delta for non-thinking content"); + } + + case "thinking_end": { + const content = partial.content[proxyEvent.contentIndex]; + if (content?.type === "thinking") { + content.thinkingSignature = proxyEvent.contentSignature; + return { + type: "thinking_end", + contentIndex: proxyEvent.contentIndex, + content: content.thinking, + partial, + }; + } + throw new Error("Received 
thinking_end for non-thinking content"); + } + + case "toolcall_start": + partial.content[proxyEvent.contentIndex] = { + type: "toolCall", + id: proxyEvent.id, + name: proxyEvent.toolName, + arguments: {}, + partialJson: "", + } satisfies ToolCall & { partialJson: string } as ToolCall; + return { type: "toolcall_start", contentIndex: proxyEvent.contentIndex, partial }; + + case "toolcall_delta": { + const content = partial.content[proxyEvent.contentIndex]; + if (content?.type === "toolCall") { + (content as any).partialJson += proxyEvent.delta; + content.arguments = parseStreamingJson((content as any).partialJson) || {}; + partial.content[proxyEvent.contentIndex] = { ...content }; // Trigger reactivity + return { + type: "toolcall_delta", + contentIndex: proxyEvent.contentIndex, + delta: proxyEvent.delta, + partial, + }; + } + throw new Error("Received toolcall_delta for non-toolCall content"); + } + + case "toolcall_end": { + const content = partial.content[proxyEvent.contentIndex]; + if (content?.type === "toolCall") { + delete (content as any).partialJson; + return { + type: "toolcall_end", + contentIndex: proxyEvent.contentIndex, + toolCall: content, + partial, + }; + } + return undefined; + } + + case "done": + partial.stopReason = proxyEvent.reason; + partial.usage = proxyEvent.usage; + return { type: "done", reason: proxyEvent.reason, message: partial }; + + case "error": + partial.stopReason = proxyEvent.reason; + partial.errorMessage = proxyEvent.errorMessage; + partial.usage = proxyEvent.usage; + return { type: "error", reason: proxyEvent.reason, error: partial }; + + default: { + const _exhaustiveCheck: never = proxyEvent; + console.warn(`Unhandled proxy event type: ${(proxyEvent as any).type}`); + return undefined; + } + } +} diff --git a/packages/agent/src/transports/AppTransport.ts b/packages/agent/src/transports/AppTransport.ts deleted file mode 100644 index 69b9af46..00000000 --- a/packages/agent/src/transports/AppTransport.ts +++ /dev/null @@ 
-1,397 +0,0 @@ -import type { - AgentContext, - AgentLoopConfig, - Api, - AssistantMessage, - AssistantMessageEvent, - Context, - Message, - Model, - SimpleStreamOptions, - ToolCall, - UserMessage, -} from "@mariozechner/pi-ai"; -import { agentLoop, agentLoopContinue } from "@mariozechner/pi-ai"; -import { AssistantMessageEventStream } from "@mariozechner/pi-ai/dist/utils/event-stream.js"; -import { parseStreamingJson } from "@mariozechner/pi-ai/dist/utils/json-parse.js"; -import type { ProxyAssistantMessageEvent } from "./proxy-types.js"; -import type { AgentRunConfig, AgentTransport } from "./types.js"; - -/** - * Stream function that proxies through a server instead of calling providers directly. - * The server strips the partial field from delta events to reduce bandwidth. - * We reconstruct the partial message client-side. - */ -function streamSimpleProxy( - model: Model, - context: Context, - options: SimpleStreamOptions & { authToken: string }, - proxyUrl: string, -): AssistantMessageEventStream { - const stream = new AssistantMessageEventStream(); - - (async () => { - // Initialize the partial message that we'll build up from events - const partial: AssistantMessage = { - role: "assistant", - stopReason: "stop", - content: [], - api: model.api, - provider: model.provider, - model: model.id, - usage: { - input: 0, - output: 0, - cacheRead: 0, - cacheWrite: 0, - totalTokens: 0, - cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, - }, - timestamp: Date.now(), - }; - - let reader: ReadableStreamDefaultReader | undefined; - - // Set up abort handler to cancel the reader - const abortHandler = () => { - if (reader) { - reader.cancel("Request aborted by user").catch(() => {}); - } - }; - - if (options.signal) { - options.signal.addEventListener("abort", abortHandler); - } - - try { - const response = await fetch(`${proxyUrl}/api/stream`, { - method: "POST", - headers: { - Authorization: `Bearer ${options.authToken}`, - "Content-Type": 
"application/json", - }, - body: JSON.stringify({ - model, - context, - options: { - temperature: options.temperature, - maxTokens: options.maxTokens, - reasoning: options.reasoning, - // Don't send apiKey or signal - those are added server-side - }, - }), - signal: options.signal, - }); - - if (!response.ok) { - let errorMessage = `Proxy error: ${response.status} ${response.statusText}`; - try { - const errorData = (await response.json()) as { error?: string }; - if (errorData.error) { - errorMessage = `Proxy error: ${errorData.error}`; - } - } catch { - // Couldn't parse error response, use default message - } - throw new Error(errorMessage); - } - - // Parse SSE stream - reader = response.body!.getReader(); - const decoder = new TextDecoder(); - let buffer = ""; - - while (true) { - const { done, value } = await reader.read(); - if (done) break; - - // Check if aborted after reading - if (options.signal?.aborted) { - throw new Error("Request aborted by user"); - } - - buffer += decoder.decode(value, { stream: true }); - const lines = buffer.split("\n"); - buffer = lines.pop() || ""; - - for (const line of lines) { - if (line.startsWith("data: ")) { - const data = line.slice(6).trim(); - if (data) { - const proxyEvent = JSON.parse(data) as ProxyAssistantMessageEvent; - let event: AssistantMessageEvent | undefined; - - // Handle different event types - // Server sends events with partial for non-delta events, - // and without partial for delta events - switch (proxyEvent.type) { - case "start": - event = { type: "start", partial }; - break; - - case "text_start": - partial.content[proxyEvent.contentIndex] = { - type: "text", - text: "", - }; - event = { type: "text_start", contentIndex: proxyEvent.contentIndex, partial }; - break; - - case "text_delta": { - const content = partial.content[proxyEvent.contentIndex]; - if (content?.type === "text") { - content.text += proxyEvent.delta; - event = { - type: "text_delta", - contentIndex: proxyEvent.contentIndex, - 
delta: proxyEvent.delta, - partial, - }; - } else { - throw new Error("Received text_delta for non-text content"); - } - break; - } - case "text_end": { - const content = partial.content[proxyEvent.contentIndex]; - if (content?.type === "text") { - content.textSignature = proxyEvent.contentSignature; - event = { - type: "text_end", - contentIndex: proxyEvent.contentIndex, - content: content.text, - partial, - }; - } else { - throw new Error("Received text_end for non-text content"); - } - break; - } - - case "thinking_start": - partial.content[proxyEvent.contentIndex] = { - type: "thinking", - thinking: "", - }; - event = { type: "thinking_start", contentIndex: proxyEvent.contentIndex, partial }; - break; - - case "thinking_delta": { - const content = partial.content[proxyEvent.contentIndex]; - if (content?.type === "thinking") { - content.thinking += proxyEvent.delta; - event = { - type: "thinking_delta", - contentIndex: proxyEvent.contentIndex, - delta: proxyEvent.delta, - partial, - }; - } else { - throw new Error("Received thinking_delta for non-thinking content"); - } - break; - } - - case "thinking_end": { - const content = partial.content[proxyEvent.contentIndex]; - if (content?.type === "thinking") { - content.thinkingSignature = proxyEvent.contentSignature; - event = { - type: "thinking_end", - contentIndex: proxyEvent.contentIndex, - content: content.thinking, - partial, - }; - } else { - throw new Error("Received thinking_end for non-thinking content"); - } - break; - } - - case "toolcall_start": - partial.content[proxyEvent.contentIndex] = { - type: "toolCall", - id: proxyEvent.id, - name: proxyEvent.toolName, - arguments: {}, - partialJson: "", - } satisfies ToolCall & { partialJson: string } as ToolCall; - event = { type: "toolcall_start", contentIndex: proxyEvent.contentIndex, partial }; - break; - - case "toolcall_delta": { - const content = partial.content[proxyEvent.contentIndex]; - if (content?.type === "toolCall") { - (content as 
any).partialJson += proxyEvent.delta; - content.arguments = parseStreamingJson((content as any).partialJson) || {}; - event = { - type: "toolcall_delta", - contentIndex: proxyEvent.contentIndex, - delta: proxyEvent.delta, - partial, - }; - partial.content[proxyEvent.contentIndex] = { ...content }; // Trigger reactivity - } else { - throw new Error("Received toolcall_delta for non-toolCall content"); - } - break; - } - - case "toolcall_end": { - const content = partial.content[proxyEvent.contentIndex]; - if (content?.type === "toolCall") { - delete (content as any).partialJson; - event = { - type: "toolcall_end", - contentIndex: proxyEvent.contentIndex, - toolCall: content, - partial, - }; - } - break; - } - - case "done": - partial.stopReason = proxyEvent.reason; - partial.usage = proxyEvent.usage; - event = { type: "done", reason: proxyEvent.reason, message: partial }; - break; - - case "error": - partial.stopReason = proxyEvent.reason; - partial.errorMessage = proxyEvent.errorMessage; - partial.usage = proxyEvent.usage; - event = { type: "error", reason: proxyEvent.reason, error: partial }; - break; - - default: { - // Exhaustive check - const _exhaustiveCheck: never = proxyEvent; - console.warn(`Unhandled event type: ${(proxyEvent as any).type}`); - break; - } - } - - // Push the event to stream - if (event) { - stream.push(event); - } else { - throw new Error("Failed to create event from proxy event"); - } - } - } - } - } - - // Check if aborted after reading - if (options.signal?.aborted) { - throw new Error("Request aborted by user"); - } - - stream.end(); - } catch (error) { - const errorMessage = error instanceof Error ? error.message : String(error); - partial.stopReason = options.signal?.aborted ? 
"aborted" : "error"; - partial.errorMessage = errorMessage; - stream.push({ - type: "error", - reason: partial.stopReason, - error: partial, - } satisfies AssistantMessageEvent); - stream.end(); - } finally { - // Clean up abort handler - if (options.signal) { - options.signal.removeEventListener("abort", abortHandler); - } - } - })(); - - return stream; -} - -export interface AppTransportOptions { - /** - * Proxy server URL. The server manages user accounts and proxies requests to LLM providers. - * Example: "https://genai.mariozechner.at" - */ - proxyUrl: string; - - /** - * Function to retrieve auth token for the proxy server. - * The token is used for user authentication and authorization. - */ - getAuthToken: () => Promise | string; -} - -/** - * Transport that uses an app server with user authentication tokens. - * The server manages user accounts and proxies requests to LLM providers. - */ -export class AppTransport implements AgentTransport { - private options: AppTransportOptions; - - constructor(options: AppTransportOptions) { - this.options = options; - } - - private async getStreamFn(authToken: string) { - return (model: Model, context: Context, options?: SimpleStreamOptions) => { - return streamSimpleProxy( - model, - context, - { - ...options, - authToken, - }, - this.options.proxyUrl, - ); - }; - } - - private buildContext(messages: Message[], cfg: AgentRunConfig): AgentContext { - return { - systemPrompt: cfg.systemPrompt, - messages, - tools: cfg.tools, - }; - } - - private buildLoopConfig(cfg: AgentRunConfig): AgentLoopConfig { - return { - model: cfg.model, - reasoning: cfg.reasoning, - getQueuedMessages: cfg.getQueuedMessages, - }; - } - - async *run(messages: Message[], userMessage: Message, cfg: AgentRunConfig, signal?: AbortSignal) { - const authToken = await this.options.getAuthToken(); - if (!authToken) { - throw new Error("Auth token is required for AppTransport"); - } - - const streamFn = await this.getStreamFn(authToken); - const context 
= this.buildContext(messages, cfg); - const pc = this.buildLoopConfig(cfg); - - for await (const ev of agentLoop(userMessage as unknown as UserMessage, context, pc, signal, streamFn as any)) { - yield ev; - } - } - - async *continue(messages: Message[], cfg: AgentRunConfig, signal?: AbortSignal) { - const authToken = await this.options.getAuthToken(); - if (!authToken) { - throw new Error("Auth token is required for AppTransport"); - } - - const streamFn = await this.getStreamFn(authToken); - const context = this.buildContext(messages, cfg); - const pc = this.buildLoopConfig(cfg); - - for await (const ev of agentLoopContinue(context, pc, signal, streamFn as any)) { - yield ev; - } - } -} diff --git a/packages/agent/src/transports/ProviderTransport.ts b/packages/agent/src/transports/ProviderTransport.ts deleted file mode 100644 index 024db0e4..00000000 --- a/packages/agent/src/transports/ProviderTransport.ts +++ /dev/null @@ -1,85 +0,0 @@ -import { - type AgentContext, - type AgentLoopConfig, - agentLoop, - agentLoopContinue, - type Message, - type UserMessage, -} from "@mariozechner/pi-ai"; -import type { AgentRunConfig, AgentTransport } from "./types.js"; - -export interface ProviderTransportOptions { - /** - * Function to retrieve API key for a given provider. - * If not provided, transport will try to use environment variables. - */ - getApiKey?: (provider: string) => Promise | string | undefined; - - /** - * Optional CORS proxy URL for browser environments. - * If provided, all requests will be routed through this proxy. - * Format: "https://proxy.example.com" - */ - corsProxyUrl?: string; -} - -/** - * Transport that calls LLM providers directly. - * Optionally routes calls through a CORS proxy if configured. 
- */ -export class ProviderTransport implements AgentTransport { - private options: ProviderTransportOptions; - - constructor(options: ProviderTransportOptions = {}) { - this.options = options; - } - - private getModel(cfg: AgentRunConfig) { - let model = cfg.model; - if (this.options.corsProxyUrl && cfg.model.baseUrl) { - model = { - ...cfg.model, - baseUrl: `${this.options.corsProxyUrl}/?url=${encodeURIComponent(cfg.model.baseUrl)}`, - }; - } - return model; - } - - private buildContext(messages: Message[], cfg: AgentRunConfig): AgentContext { - return { - systemPrompt: cfg.systemPrompt, - messages, - tools: cfg.tools, - }; - } - - private buildLoopConfig(model: AgentRunConfig["model"], cfg: AgentRunConfig): AgentLoopConfig { - return { - model, - reasoning: cfg.reasoning, - // Resolve API key per assistant response (important for expiring OAuth tokens) - getApiKey: this.options.getApiKey, - getQueuedMessages: cfg.getQueuedMessages, - }; - } - - async *run(messages: Message[], userMessage: Message, cfg: AgentRunConfig, signal?: AbortSignal) { - const model = this.getModel(cfg); - const context = this.buildContext(messages, cfg); - const pc = this.buildLoopConfig(model, cfg); - - for await (const ev of agentLoop(userMessage as unknown as UserMessage, context, pc, signal)) { - yield ev; - } - } - - async *continue(messages: Message[], cfg: AgentRunConfig, signal?: AbortSignal) { - const model = this.getModel(cfg); - const context = this.buildContext(messages, cfg); - const pc = this.buildLoopConfig(model, cfg); - - for await (const ev of agentLoopContinue(context, pc, signal)) { - yield ev; - } - } -} diff --git a/packages/agent/src/transports/index.ts b/packages/agent/src/transports/index.ts deleted file mode 100644 index 7dfc9702..00000000 --- a/packages/agent/src/transports/index.ts +++ /dev/null @@ -1,4 +0,0 @@ -export { AppTransport, type AppTransportOptions } from "./AppTransport.js"; -export { ProviderTransport, type ProviderTransportOptions } from 
"./ProviderTransport.js"; -export type { ProxyAssistantMessageEvent } from "./proxy-types.js"; -export type { AgentRunConfig, AgentTransport } from "./types.js"; diff --git a/packages/agent/src/transports/proxy-types.ts b/packages/agent/src/transports/proxy-types.ts deleted file mode 100644 index fbbe7848..00000000 --- a/packages/agent/src/transports/proxy-types.ts +++ /dev/null @@ -1,20 +0,0 @@ -import type { StopReason, Usage } from "@mariozechner/pi-ai"; - -/** - * Event types emitted by the proxy server. - * The server strips the `partial` field from delta events to reduce bandwidth. - * Clients reconstruct the partial message from these events. - */ -export type ProxyAssistantMessageEvent = - | { type: "start" } - | { type: "text_start"; contentIndex: number } - | { type: "text_delta"; contentIndex: number; delta: string } - | { type: "text_end"; contentIndex: number; contentSignature?: string } - | { type: "thinking_start"; contentIndex: number } - | { type: "thinking_delta"; contentIndex: number; delta: string } - | { type: "thinking_end"; contentIndex: number; contentSignature?: string } - | { type: "toolcall_start"; contentIndex: number; id: string; toolName: string } - | { type: "toolcall_delta"; contentIndex: number; delta: string } - | { type: "toolcall_end"; contentIndex: number } - | { type: "done"; reason: Extract; usage: Usage } - | { type: "error"; reason: Extract; errorMessage: string; usage: Usage }; diff --git a/packages/agent/src/transports/types.ts b/packages/agent/src/transports/types.ts deleted file mode 100644 index 736ba0c3..00000000 --- a/packages/agent/src/transports/types.ts +++ /dev/null @@ -1,32 +0,0 @@ -import type { AgentEvent, AgentTool, Message, Model, QueuedMessage, ReasoningEffort } from "@mariozechner/pi-ai"; - -/** - * The minimal configuration needed to run an agent turn. 
- */ -export interface AgentRunConfig { - systemPrompt: string; - tools: AgentTool[]; - model: Model; - reasoning?: ReasoningEffort; - getQueuedMessages?: () => Promise[]>; -} - -/** - * Transport interface for executing agent turns. - * Transports handle the communication with LLM providers, - * abstracting away the details of API calls, proxies, etc. - * - * Events yielded must match the @mariozechner/pi-ai AgentEvent types. - */ -export interface AgentTransport { - /** Run with a new user message */ - run( - messages: Message[], - userMessage: Message, - config: AgentRunConfig, - signal?: AbortSignal, - ): AsyncIterable; - - /** Continue from current context (no new user message) */ - continue(messages: Message[], config: AgentRunConfig, signal?: AbortSignal): AsyncIterable; -} diff --git a/packages/agent/src/types.ts b/packages/agent/src/types.ts index 46da1492..e8af618e 100644 --- a/packages/agent/src/types.ts +++ b/packages/agent/src/types.ts @@ -1,26 +1,86 @@ import type { - AgentTool, - AssistantMessage, AssistantMessageEvent, + ImageContent, Message, Model, + SimpleStreamOptions, + streamSimple, + TextContent, + Tool, ToolResultMessage, - UserMessage, } from "@mariozechner/pi-ai"; +import type { Static, TSchema } from "@sinclair/typebox"; + +/** Stream function - can return sync or Promise for async config lookup */ +export type StreamFn = ( + ...args: Parameters +) => ReturnType | Promise>; /** - * Attachment type definition. - * Processing is done by consumers (e.g., document extraction in web-ui). + * Configuration for the agent loop. 
*/ -export interface Attachment { - id: string; - type: "image" | "document"; - fileName: string; - mimeType: string; - size: number; - content: string; // base64 encoded (without data URL prefix) - extractedText?: string; // For documents - preview?: string; // base64 image preview +export interface AgentLoopConfig extends SimpleStreamOptions { + model: Model; + + /** + * Converts AgentMessage[] to LLM-compatible Message[] before each LLM call. + * + * Each AgentMessage must be converted to a UserMessage, AssistantMessage, or ToolResultMessage + * that the LLM can understand. AgentMessages that cannot be converted (e.g., UI-only notifications, + * status messages) should be filtered out. + * + * @example + * ```typescript + * convertToLlm: (messages) => messages.flatMap(m => { + * if (m.role === "hookMessage") { + * // Convert custom message to user message + * return [{ role: "user", content: m.content, timestamp: m.timestamp }]; + * } + * if (m.role === "notification") { + * // Filter out UI-only messages + * return []; + * } + * // Pass through standard LLM messages + * return [m]; + * }) + * ``` + */ + convertToLlm: (messages: AgentMessage[]) => Message[] | Promise; + + /** + * Optional transform applied to the context before `convertToLlm`. + * + * Use this for operations that work at the AgentMessage level: + * - Context window management (pruning old messages) + * - Injecting context from external sources + * + * @example + * ```typescript + * transformContext: async (messages) => { + * if (estimateTokens(messages) > MAX_TOKENS) { + * return pruneOldMessages(messages); + * } + * return messages; + * } + * ``` + */ + transformContext?: (messages: AgentMessage[], signal?: AbortSignal) => Promise; + + /** + * Resolves an API key dynamically for each LLM call. + * + * Useful for short-lived OAuth tokens (e.g., GitHub Copilot) that may expire + * during long-running tool execution phases. 
+ */ + getApiKey?: (provider: string) => Promise | string | undefined; + + /** + * Returns queued messages to inject into the conversation. + * + * Called after each turn to check for user interruptions or injected messages. + * If messages are returned, they're added to the context before the next LLM call. + */ + getQueuedMessages?: () => Promise; } /** @@ -29,11 +89,6 @@ export interface Attachment { */ export type ThinkingLevel = "off" | "minimal" | "low" | "medium" | "high" | "xhigh"; -/** - * User message with optional attachments. - */ -export type UserMessageWithAttachments = UserMessage & { attachments?: Attachment[] }; - /** * Extensible interface for custom app messages. * Apps can extend via declaration merging: @@ -41,27 +96,23 @@ export type UserMessageWithAttachments = UserMessage & { attachments?: Attachmen * @example * ```typescript * declare module "@mariozechner/agent" { - * interface CustomMessages { + * interface CustomAgentMessages { * artifact: ArtifactMessage; * notification: NotificationMessage; * } * } * ``` */ -export interface CustomMessages { +export interface CustomAgentMessages { // Empty by default - apps extend via declaration merging } /** - * AppMessage: Union of LLM messages + attachments + custom messages. + * AgentMessage: Union of LLM messages + custom messages. * This abstraction allows apps to add custom message types while maintaining * type safety and compatibility with the base LLM messages. */ -export type AppMessage = - | AssistantMessage - | UserMessageWithAttachments - | Message // Includes ToolResultMessage - | CustomMessages[keyof CustomMessages]; +export type AgentMessage = Message | CustomAgentMessages[keyof CustomAgentMessages]; /** * Agent state containing all configuration and conversation data. 
@@ -71,13 +122,42 @@ export interface AgentState { model: Model; thinkingLevel: ThinkingLevel; tools: AgentTool[]; - messages: AppMessage[]; // Can include attachments + custom message types + messages: AgentMessage[]; // Can include attachments + custom message types isStreaming: boolean; - streamMessage: Message | null; + streamMessage: AgentMessage | null; pendingToolCalls: Set; error?: string; } +export interface AgentToolResult { + // Content blocks supporting text and images + content: (TextContent | ImageContent)[]; + // Details to be displayed in a UI or logged + details: T; +} + +// Callback for streaming tool execution updates +export type AgentToolUpdateCallback = (partialResult: AgentToolResult) => void; + +// AgentTool extends Tool but adds the execute function +export interface AgentTool extends Tool { + // A human-readable label for the tool to be displayed in UI + label: string; + execute: ( + toolCallId: string, + params: Static, + signal?: AbortSignal, + onUpdate?: AgentToolUpdateCallback, + ) => Promise>; +} + +// AgentContext is like Context but uses AgentTool +export interface AgentContext { + systemPrompt: string; + messages: AgentMessage[]; + tools?: AgentTool[]; +} + /** * Events emitted by the Agent for UI updates. * These events provide fine-grained lifecycle information for messages, turns, and tool executions. 
@@ -85,15 +165,15 @@ export interface AgentState { export type AgentEvent = // Agent lifecycle | { type: "agent_start" } - | { type: "agent_end"; messages: AppMessage[] } + | { type: "agent_end"; messages: AgentMessage[] } // Turn lifecycle - a turn is one assistant response + any tool calls/results | { type: "turn_start" } - | { type: "turn_end"; message: AppMessage; toolResults: ToolResultMessage[] } + | { type: "turn_end"; message: AgentMessage; toolResults: ToolResultMessage[] } // Message lifecycle - emitted for user, assistant, and toolResult messages - | { type: "message_start"; message: AppMessage } + | { type: "message_start"; message: AgentMessage } // Only emitted for assistant messages during streaming - | { type: "message_update"; message: AppMessage; assistantMessageEvent: AssistantMessageEvent } - | { type: "message_end"; message: AppMessage } + | { type: "message_update"; message: AgentMessage; assistantMessageEvent: AssistantMessageEvent } + | { type: "message_end"; message: AgentMessage } // Tool execution lifecycle | { type: "tool_execution_start"; toolCallId: string; toolName: string; args: any } | { type: "tool_execution_update"; toolCallId: string; toolName: string; args: any; partialResult: any } diff --git a/packages/agent/test/agent-loop.test.ts b/packages/agent/test/agent-loop.test.ts new file mode 100644 index 00000000..b8295038 --- /dev/null +++ b/packages/agent/test/agent-loop.test.ts @@ -0,0 +1,535 @@ +import { + type AssistantMessage, + type AssistantMessageEvent, + EventStream, + type Message, + type Model, + type UserMessage, +} from "@mariozechner/pi-ai"; +import { Type } from "@sinclair/typebox"; +import { describe, expect, it } from "vitest"; +import { agentLoop, agentLoopContinue } from "../src/agent-loop.js"; +import type { AgentContext, AgentEvent, AgentLoopConfig, AgentMessage, AgentTool } from "../src/types.js"; + +// Mock stream for testing - mimics MockAssistantStream +class MockAssistantStream extends EventStream { + 
constructor() { + super( + (event) => event.type === "done" || event.type === "error", + (event) => { + if (event.type === "done") return event.message; + if (event.type === "error") return event.error; + throw new Error("Unexpected event type"); + }, + ); + } +} + +function createUsage() { + return { + input: 0, + output: 0, + cacheRead: 0, + cacheWrite: 0, + totalTokens: 0, + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, + }; +} + +function createModel(): Model<"openai-responses"> { + return { + id: "mock", + name: "mock", + api: "openai-responses", + provider: "openai", + baseUrl: "https://example.invalid", + reasoning: false, + input: ["text"], + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 }, + contextWindow: 8192, + maxTokens: 2048, + }; +} + +function createAssistantMessage( + content: AssistantMessage["content"], + stopReason: AssistantMessage["stopReason"] = "stop", +): AssistantMessage { + return { + role: "assistant", + content, + api: "openai-responses", + provider: "openai", + model: "mock", + usage: createUsage(), + stopReason, + timestamp: Date.now(), + }; +} + +function createUserMessage(text: string): UserMessage { + return { + role: "user", + content: text, + timestamp: Date.now(), + }; +} + +// Simple identity converter for tests - just passes through standard messages +function identityConverter(messages: AgentMessage[]): Message[] { + return messages.filter((m) => m.role === "user" || m.role === "assistant" || m.role === "toolResult") as Message[]; +} + +describe("agentLoop with AgentMessage", () => { + it("should emit events with AgentMessage types", async () => { + const context: AgentContext = { + systemPrompt: "You are helpful.", + messages: [], + tools: [], + }; + + const userPrompt: AgentMessage = createUserMessage("Hello"); + + const config: AgentLoopConfig = { + model: createModel(), + convertToLlm: identityConverter, + }; + + const streamFn = () => { + const stream = new MockAssistantStream(); + 
queueMicrotask(() => { + const message = createAssistantMessage([{ type: "text", text: "Hi there!" }]); + stream.push({ type: "done", reason: "stop", message }); + }); + return stream; + }; + + const events: AgentEvent[] = []; + const stream = agentLoop([userPrompt], context, config, undefined, streamFn); + + for await (const event of stream) { + events.push(event); + } + + const messages = await stream.result(); + + // Should have user message and assistant message + expect(messages.length).toBe(2); + expect(messages[0].role).toBe("user"); + expect(messages[1].role).toBe("assistant"); + + // Verify event sequence + const eventTypes = events.map((e) => e.type); + expect(eventTypes).toContain("agent_start"); + expect(eventTypes).toContain("turn_start"); + expect(eventTypes).toContain("message_start"); + expect(eventTypes).toContain("message_end"); + expect(eventTypes).toContain("turn_end"); + expect(eventTypes).toContain("agent_end"); + }); + + it("should handle custom message types via convertToLlm", async () => { + // Create a custom message type + interface CustomNotification { + role: "notification"; + text: string; + timestamp: number; + } + + const notification: CustomNotification = { + role: "notification", + text: "This is a notification", + timestamp: Date.now(), + }; + + const context: AgentContext = { + systemPrompt: "You are helpful.", + messages: [notification as unknown as AgentMessage], // Custom message in context + tools: [], + }; + + const userPrompt: AgentMessage = createUserMessage("Hello"); + + let convertedMessages: Message[] = []; + const config: AgentLoopConfig = { + model: createModel(), + convertToLlm: (messages) => { + // Filter out notifications, convert rest + convertedMessages = messages + .filter((m) => (m as { role: string }).role !== "notification") + .filter((m) => m.role === "user" || m.role === "assistant" || m.role === "toolResult") as Message[]; + return convertedMessages; + }, + }; + + const streamFn = () => { + const stream = 
new MockAssistantStream(); + queueMicrotask(() => { + const message = createAssistantMessage([{ type: "text", text: "Response" }]); + stream.push({ type: "done", reason: "stop", message }); + }); + return stream; + }; + + const events: AgentEvent[] = []; + const stream = agentLoop([userPrompt], context, config, undefined, streamFn); + + for await (const event of stream) { + events.push(event); + } + + // The notification should have been filtered out in convertToLlm + expect(convertedMessages.length).toBe(1); // Only user message + expect(convertedMessages[0].role).toBe("user"); + }); + + it("should apply transformContext before convertToLlm", async () => { + const context: AgentContext = { + systemPrompt: "You are helpful.", + messages: [ + createUserMessage("old message 1"), + createAssistantMessage([{ type: "text", text: "old response 1" }]), + createUserMessage("old message 2"), + createAssistantMessage([{ type: "text", text: "old response 2" }]), + ], + tools: [], + }; + + const userPrompt: AgentMessage = createUserMessage("new message"); + + let transformedMessages: AgentMessage[] = []; + let convertedMessages: Message[] = []; + + const config: AgentLoopConfig = { + model: createModel(), + transformContext: async (messages) => { + // Keep only last 2 messages (prune old ones) + transformedMessages = messages.slice(-2); + return transformedMessages; + }, + convertToLlm: (messages) => { + convertedMessages = messages.filter( + (m) => m.role === "user" || m.role === "assistant" || m.role === "toolResult", + ) as Message[]; + return convertedMessages; + }, + }; + + const streamFn = () => { + const stream = new MockAssistantStream(); + queueMicrotask(() => { + const message = createAssistantMessage([{ type: "text", text: "Response" }]); + stream.push({ type: "done", reason: "stop", message }); + }); + return stream; + }; + + const stream = agentLoop([userPrompt], context, config, undefined, streamFn); + + for await (const _ of stream) { + // consume + } + + // 
transformContext should have been called first, keeping only last 2 + expect(transformedMessages.length).toBe(2); + // Then convertToLlm receives the pruned messages + expect(convertedMessages.length).toBe(2); + }); + + it("should handle tool calls and results", async () => { + const toolSchema = Type.Object({ value: Type.String() }); + const executed: string[] = []; + const tool: AgentTool = { + name: "echo", + label: "Echo", + description: "Echo tool", + parameters: toolSchema, + async execute(_toolCallId, params) { + executed.push(params.value); + return { + content: [{ type: "text", text: `echoed: ${params.value}` }], + details: { value: params.value }, + }; + }, + }; + + const context: AgentContext = { + systemPrompt: "", + messages: [], + tools: [tool], + }; + + const userPrompt: AgentMessage = createUserMessage("echo something"); + + const config: AgentLoopConfig = { + model: createModel(), + convertToLlm: identityConverter, + }; + + let callIndex = 0; + const streamFn = () => { + const stream = new MockAssistantStream(); + queueMicrotask(() => { + if (callIndex === 0) { + // First call: return tool call + const message = createAssistantMessage( + [{ type: "toolCall", id: "tool-1", name: "echo", arguments: { value: "hello" } }], + "toolUse", + ); + stream.push({ type: "done", reason: "toolUse", message }); + } else { + // Second call: return final response + const message = createAssistantMessage([{ type: "text", text: "done" }]); + stream.push({ type: "done", reason: "stop", message }); + } + callIndex++; + }); + return stream; + }; + + const events: AgentEvent[] = []; + const stream = agentLoop([userPrompt], context, config, undefined, streamFn); + + for await (const event of stream) { + events.push(event); + } + + // Tool should have been executed + expect(executed).toEqual(["hello"]); + + // Should have tool execution events + const toolStart = events.find((e) => e.type === "tool_execution_start"); + const toolEnd = events.find((e) => e.type === 
"tool_execution_end"); + expect(toolStart).toBeDefined(); + expect(toolEnd).toBeDefined(); + if (toolEnd?.type === "tool_execution_end") { + expect(toolEnd.isError).toBe(false); + } + }); + + it("should inject queued messages and skip remaining tool calls", async () => { + const toolSchema = Type.Object({ value: Type.String() }); + const executed: string[] = []; + const tool: AgentTool = { + name: "echo", + label: "Echo", + description: "Echo tool", + parameters: toolSchema, + async execute(_toolCallId, params) { + executed.push(params.value); + return { + content: [{ type: "text", text: `ok:${params.value}` }], + details: { value: params.value }, + }; + }, + }; + + const context: AgentContext = { + systemPrompt: "", + messages: [], + tools: [tool], + }; + + const userPrompt: AgentMessage = createUserMessage("start"); + const queuedUserMessage: AgentMessage = createUserMessage("interrupt"); + + let queuedDelivered = false; + let callIndex = 0; + let sawInterruptInContext = false; + + const config: AgentLoopConfig = { + model: createModel(), + convertToLlm: identityConverter, + getQueuedMessages: async () => { + // Return queued message after first tool executes + if (executed.length === 1 && !queuedDelivered) { + queuedDelivered = true; + return [queuedUserMessage]; + } + return []; + }, + }; + + const events: AgentEvent[] = []; + const stream = agentLoop([userPrompt], context, config, undefined, (_model, ctx, _options) => { + // Check if interrupt message is in context on second call + if (callIndex === 1) { + sawInterruptInContext = ctx.messages.some( + (m) => m.role === "user" && typeof m.content === "string" && m.content === "interrupt", + ); + } + + const mockStream = new MockAssistantStream(); + queueMicrotask(() => { + if (callIndex === 0) { + // First call: return two tool calls + const message = createAssistantMessage( + [ + { type: "toolCall", id: "tool-1", name: "echo", arguments: { value: "first" } }, + { type: "toolCall", id: "tool-2", name: "echo", 
arguments: { value: "second" } }, + ], + "toolUse", + ); + mockStream.push({ type: "done", reason: "toolUse", message }); + } else { + // Second call: return final response + const message = createAssistantMessage([{ type: "text", text: "done" }]); + mockStream.push({ type: "done", reason: "stop", message }); + } + callIndex++; + }); + return mockStream; + }); + + for await (const event of stream) { + events.push(event); + } + + // Only first tool should have executed + expect(executed).toEqual(["first"]); + + // Second tool should be skipped + const toolEnds = events.filter( + (e): e is Extract => e.type === "tool_execution_end", + ); + expect(toolEnds.length).toBe(2); + expect(toolEnds[0].isError).toBe(false); + expect(toolEnds[1].isError).toBe(true); + if (toolEnds[1].result.content[0]?.type === "text") { + expect(toolEnds[1].result.content[0].text).toContain("Skipped due to queued user message"); + } + + // Queued message should appear in events + const queuedMessageEvent = events.find( + (e) => + e.type === "message_start" && + e.message.role === "user" && + typeof e.message.content === "string" && + e.message.content === "interrupt", + ); + expect(queuedMessageEvent).toBeDefined(); + + // Interrupt message should be in context when second LLM call is made + expect(sawInterruptInContext).toBe(true); + }); +}); + +describe("agentLoopContinue with AgentMessage", () => { + it("should throw when context has no messages", () => { + const context: AgentContext = { + systemPrompt: "You are helpful.", + messages: [], + tools: [], + }; + + const config: AgentLoopConfig = { + model: createModel(), + convertToLlm: identityConverter, + }; + + expect(() => agentLoopContinue(context, config)).toThrow("Cannot continue: no messages in context"); + }); + + it("should continue from existing context without emitting user message events", async () => { + const userMessage: AgentMessage = createUserMessage("Hello"); + + const context: AgentContext = { + systemPrompt: "You are 
helpful.", + messages: [userMessage], + tools: [], + }; + + const config: AgentLoopConfig = { + model: createModel(), + convertToLlm: identityConverter, + }; + + const streamFn = () => { + const stream = new MockAssistantStream(); + queueMicrotask(() => { + const message = createAssistantMessage([{ type: "text", text: "Response" }]); + stream.push({ type: "done", reason: "stop", message }); + }); + return stream; + }; + + const events: AgentEvent[] = []; + const stream = agentLoopContinue(context, config, undefined, streamFn); + + for await (const event of stream) { + events.push(event); + } + + const messages = await stream.result(); + + // Should only return the new assistant message (not the existing user message) + expect(messages.length).toBe(1); + expect(messages[0].role).toBe("assistant"); + + // Should NOT have user message events (that's the key difference from agentLoop) + const messageEndEvents = events.filter((e) => e.type === "message_end"); + expect(messageEndEvents.length).toBe(1); + expect((messageEndEvents[0] as any).message.role).toBe("assistant"); + }); + + it("should allow custom message types as last message (caller responsibility)", async () => { + // Custom message that will be converted to user message by convertToLlm + interface HookMessage { + role: "hookMessage"; + text: string; + timestamp: number; + } + + const hookMessage: HookMessage = { + role: "hookMessage", + text: "Hook content", + timestamp: Date.now(), + }; + + const context: AgentContext = { + systemPrompt: "You are helpful.", + messages: [hookMessage as unknown as AgentMessage], + tools: [], + }; + + const config: AgentLoopConfig = { + model: createModel(), + convertToLlm: (messages) => { + // Convert hookMessage to user message + return messages + .map((m) => { + if ((m as any).role === "hookMessage") { + return { + role: "user" as const, + content: (m as any).text, + timestamp: m.timestamp, + }; + } + return m; + }) + .filter((m) => m.role === "user" || m.role === 
"assistant" || m.role === "toolResult") as Message[]; + }, + }; + + const streamFn = () => { + const stream = new MockAssistantStream(); + queueMicrotask(() => { + const message = createAssistantMessage([{ type: "text", text: "Response to hook" }]); + stream.push({ type: "done", reason: "stop", message }); + }); + return stream; + }; + + // Should not throw - the hookMessage will be converted to user message + const stream = agentLoopContinue(context, config, undefined, streamFn); + + const events: AgentEvent[] = []; + for await (const event of stream) { + events.push(event); + } + + const messages = await stream.result(); + expect(messages.length).toBe(1); + expect(messages[0].role).toBe("assistant"); + }); +}); diff --git a/packages/agent/test/agent.test.ts b/packages/agent/test/agent.test.ts index 235136cf..8fee033f 100644 --- a/packages/agent/test/agent.test.ts +++ b/packages/agent/test/agent.test.ts @@ -1,12 +1,10 @@ import { getModel } from "@mariozechner/pi-ai"; import { describe, expect, it } from "vitest"; -import { Agent, ProviderTransport } from "../src/index.js"; +import { Agent } from "../src/index.js"; describe("Agent", () => { it("should create an agent instance with default state", () => { - const agent = new Agent({ - transport: new ProviderTransport(), - }); + const agent = new Agent(); expect(agent.state).toBeDefined(); expect(agent.state.systemPrompt).toBe(""); @@ -23,7 +21,6 @@ describe("Agent", () => { it("should create an agent instance with custom initial state", () => { const customModel = getModel("openai", "gpt-4o-mini"); const agent = new Agent({ - transport: new ProviderTransport(), initialState: { systemPrompt: "You are a helpful assistant.", model: customModel, @@ -37,9 +34,7 @@ describe("Agent", () => { }); it("should subscribe to events", () => { - const agent = new Agent({ - transport: new ProviderTransport(), - }); + const agent = new Agent(); let eventCount = 0; const unsubscribe = agent.subscribe((_event) => { @@ -61,9 +56,7 @@ 
describe("Agent", () => { }); it("should update state with mutators", () => { - const agent = new Agent({ - transport: new ProviderTransport(), - }); + const agent = new Agent(); // Test setSystemPrompt agent.setSystemPrompt("Custom prompt"); @@ -101,38 +94,19 @@ describe("Agent", () => { }); it("should support message queueing", async () => { - const agent = new Agent({ - transport: new ProviderTransport(), - }); + const agent = new Agent(); const message = { role: "user" as const, content: "Queued message", timestamp: Date.now() }; - await agent.queueMessage(message); + agent.queueMessage(message); // The message is queued but not yet in state.messages expect(agent.state.messages).not.toContainEqual(message); }); it("should handle abort controller", () => { - const agent = new Agent({ - transport: new ProviderTransport(), - }); + const agent = new Agent(); // Should not throw even if nothing is running expect(() => agent.abort()).not.toThrow(); }); }); - -describe("ProviderTransport", () => { - it("should create a provider transport instance", () => { - const transport = new ProviderTransport(); - expect(transport).toBeDefined(); - }); - - it("should create a provider transport with options", () => { - const transport = new ProviderTransport({ - getApiKey: async (provider) => `test-key-${provider}`, - corsProxyUrl: "https://proxy.example.com", - }); - expect(transport).toBeDefined(); - }); -}); diff --git a/packages/agent/test/e2e.test.ts b/packages/agent/test/e2e.test.ts index f18030b3..7996543a 100644 --- a/packages/agent/test/e2e.test.ts +++ b/packages/agent/test/e2e.test.ts @@ -1,25 +1,8 @@ import type { AssistantMessage, Model, ToolResultMessage, UserMessage } from "@mariozechner/pi-ai"; -import { calculateTool, getModel } from "@mariozechner/pi-ai"; +import { getModel } from "@mariozechner/pi-ai"; import { describe, expect, it } from "vitest"; -import { Agent, ProviderTransport } from "../src/index.js"; - -function createTransport() { - return new 
ProviderTransport({ - getApiKey: async (provider) => { - const envVarMap: Record = { - google: "GEMINI_API_KEY", - openai: "OPENAI_API_KEY", - anthropic: "ANTHROPIC_API_KEY", - xai: "XAI_API_KEY", - groq: "GROQ_API_KEY", - cerebras: "CEREBRAS_API_KEY", - zai: "ZAI_API_KEY", - }; - const envVar = envVarMap[provider] || `${provider.toUpperCase()}_API_KEY`; - return process.env[envVar]; - }, - }); -} +import { Agent } from "../src/index.js"; +import { calculateTool } from "./utils/calculate.js"; async function basicPrompt(model: Model) { const agent = new Agent({ @@ -29,7 +12,6 @@ async function basicPrompt(model: Model) { thinkingLevel: "off", tools: [], }, - transport: createTransport(), }); await agent.prompt("What is 2+2? Answer with just the number."); @@ -57,7 +39,6 @@ async function toolExecution(model: Model) { thinkingLevel: "off", tools: [calculateTool], }, - transport: createTransport(), }); await agent.prompt("Calculate 123 * 456 using the calculator tool."); @@ -99,7 +80,6 @@ async function abortExecution(model: Model) { thinkingLevel: "off", tools: [calculateTool], }, - transport: createTransport(), }); const promptPromise = agent.prompt("Calculate 100 * 200, then 300 * 400, then sum the results."); @@ -129,7 +109,6 @@ async function stateUpdates(model: Model) { thinkingLevel: "off", tools: [], }, - transport: createTransport(), }); const events: Array = []; @@ -162,7 +141,6 @@ async function multiTurnConversation(model: Model) { thinkingLevel: "off", tools: [], }, - transport: createTransport(), }); await agent.prompt("My name is Alice."); @@ -356,7 +334,6 @@ describe("Agent.continue()", () => { systemPrompt: "Test", model: getModel("anthropic", "claude-haiku-4-5"), }, - transport: createTransport(), }); await expect(agent.continue()).rejects.toThrow("No messages to continue from"); @@ -368,7 +345,6 @@ describe("Agent.continue()", () => { systemPrompt: "Test", model: getModel("anthropic", "claude-haiku-4-5"), }, - transport: createTransport(), }); const 
assistantMessage: AssistantMessage = { @@ -405,7 +381,6 @@ describe("Agent.continue()", () => { thinkingLevel: "off", tools: [], }, - transport: createTransport(), }); // Manually add a user message without calling prompt() @@ -445,7 +420,6 @@ describe("Agent.continue()", () => { thinkingLevel: "off", tools: [calculateTool], }, - transport: createTransport(), }); // Set up a conversation state as if tool was just executed diff --git a/packages/ai/src/agent/tools/calculate.ts b/packages/agent/test/utils/calculate.ts similarity index 93% rename from packages/ai/src/agent/tools/calculate.ts rename to packages/agent/test/utils/calculate.ts index afc75889..af5b17c4 100644 --- a/packages/ai/src/agent/tools/calculate.ts +++ b/packages/agent/test/utils/calculate.ts @@ -1,5 +1,5 @@ import { type Static, Type } from "@sinclair/typebox"; -import type { AgentTool, AgentToolResult } from "../../agent/types.js"; +import type { AgentTool, AgentToolResult } from "../../src/types.js"; export interface CalculateResult extends AgentToolResult { content: Array<{ type: "text"; text: string }>; diff --git a/packages/ai/src/agent/tools/get-current-time.ts b/packages/agent/test/utils/get-current-time.ts similarity index 93% rename from packages/ai/src/agent/tools/get-current-time.ts rename to packages/agent/test/utils/get-current-time.ts index 2d989243..814233b1 100644 --- a/packages/ai/src/agent/tools/get-current-time.ts +++ b/packages/agent/test/utils/get-current-time.ts @@ -1,6 +1,5 @@ import { type Static, Type } from "@sinclair/typebox"; -import type { AgentTool } from "../../agent/index.js"; -import type { AgentToolResult } from "../types.js"; +import type { AgentTool, AgentToolResult } from "../../src/types.js"; export interface GetCurrentTimeResult extends AgentToolResult<{ utcTimestamp: number }> {} diff --git a/packages/ai/CHANGELOG.md b/packages/ai/CHANGELOG.md index 69c4bee9..9e4af665 100644 --- a/packages/ai/CHANGELOG.md +++ b/packages/ai/CHANGELOG.md @@ -2,6 +2,10 @@ ## 
[Unreleased] +### Breaking Changes + +- **Agent API moved**: All agent functionality (`agentLoop`, `agentLoopContinue`, `AgentContext`, `AgentEvent`, `AgentTool`, `AgentToolResult`, etc.) has moved to `@mariozechner/pi-agent-core`. See the [agent-core README](../agent/README.md) for documentation. + ## [0.28.0] - 2025-12-25 ### Breaking Changes diff --git a/packages/ai/README.md b/packages/ai/README.md index 9b03f402..89cdcfc4 100644 --- a/packages/ai/README.md +++ b/packages/ai/README.md @@ -782,276 +782,6 @@ const continuation = await complete(newModel, restored); > **Note**: If the context contains images (encoded as base64 as shown in the Image Input section), those will also be serialized. -## Agent API - -The Agent API provides a higher-level interface for building agents with tools. It handles tool execution, validation, and provides detailed event streaming for interactive applications. - -### Event System - -The Agent API streams events during execution, allowing you to build reactive UIs and track agent progress. The agent processes prompts in **turns**, where each turn consists of: -1. An assistant message (the LLM's response) -2. Optional tool executions if the assistant calls tools -3. Tool result messages that are fed back to the LLM - -This continues until the assistant produces a response without tool calls. - -**Queued messages**: If you provide `getQueuedMessages` in the loop config, the agent checks for queued user messages after each tool call. When queued messages are found, any remaining tool calls from the current assistant message are skipped and returned as error tool results (`isError: true`) with the message "Skipped due to queued user message." The queued user messages are injected before the next assistant response. 
- -### Event Flow Example - -Given a prompt asking to calculate two expressions and sum them: - -```typescript -import { agentLoop, AgentContext, calculateTool } from '@mariozechner/pi-ai'; - -const context: AgentContext = { - systemPrompt: 'You are a helpful math assistant.', - messages: [], - tools: [calculateTool] -}; - -const stream = agentLoop( - { role: 'user', content: 'Calculate 15 * 20 and 30 * 40, then sum the results', timestamp: Date.now() }, - context, - { model: getModel('openai', 'gpt-4o-mini') } -); - -// Expected event sequence: -// 1. agent_start - Agent begins processing -// 2. turn_start - First turn begins -// 3. message_start - User message starts -// 4. message_end - User message ends -// 5. message_start - Assistant message starts -// 6. message_update - Assistant streams response with tool calls -// 7. message_end - Assistant message ends -// 8. tool_execution_start - First calculation (15 * 20) -// 9. tool_execution_update - Streaming progress (for long-running tools) -// 10. tool_execution_end - Result: 300 -// 11. tool_execution_start - Second calculation (30 * 40) -// 12. tool_execution_update - Streaming progress -// 13. tool_execution_end - Result: 1200 -// 12. message_start - Tool result message for first calculation -// 13. message_end - Tool result message ends -// 14. message_start - Tool result message for second calculation -// 15. message_end - Tool result message ends -// 16. turn_end - First turn ends with 2 tool results -// 17. turn_start - Second turn begins -// 18. message_start - Assistant message starts -// 19. message_update - Assistant streams response with sum calculation -// 20. message_end - Assistant message ends -// 21. tool_execution_start - Sum calculation (300 + 1200) -// 22. tool_execution_end - Result: 1500 -// 23. message_start - Tool result message for sum -// 24. message_end - Tool result message ends -// 25. turn_end - Second turn ends with 1 tool result -// 26. turn_start - Third turn begins -// 27. 
message_start - Final assistant message starts -// 28. message_update - Assistant streams final answer -// 29. message_end - Final assistant message ends -// 30. turn_end - Third turn ends with 0 tool results -// 31. agent_end - Agent completes with all messages -``` - -### Handling Events - -```typescript -for await (const event of stream) { - switch (event.type) { - case 'agent_start': - console.log('Agent started'); - break; - - case 'turn_start': - console.log('New turn started'); - break; - - case 'message_start': - console.log(`${event.message.role} message started`); - break; - - case 'message_update': - // Only for assistant messages during streaming - if (event.message.content.some(c => c.type === 'text')) { - console.log('Assistant:', event.message.content); - } - break; - - case 'tool_execution_start': - console.log(`Calling ${event.toolName} with:`, event.args); - break; - - case 'tool_execution_update': - // Streaming progress for long-running tools (e.g., bash output) - console.log(`Progress:`, event.partialResult.content); - break; - - case 'tool_execution_end': - if (event.isError) { - console.error(`Tool failed:`, event.result); - } else { - console.log(`Tool result:`, event.result.content); - } - break; - - case 'turn_end': - console.log(`Turn ended with ${event.toolResults.length} tool calls`); - break; - - case 'agent_end': - console.log(`Agent completed with ${event.messages.length} new messages`); - break; - } -} - -// Get all messages generated during this agent execution -// These include the user message and can be directly appended to context.messages -const messages = await stream.result(); -context.messages.push(...messages); -``` - -### Continuing from Existing Context - -Use `agentLoopContinue` to resume an agent loop without adding a new user message. 
This is useful for: -- Retrying after context overflow (after compaction reduces context size) -- Resuming from tool results that were added manually to the context - -```typescript -import { agentLoopContinue, AgentContext } from '@mariozechner/pi-ai'; - -// Context already has messages - last must be 'user' or 'toolResult' -const context: AgentContext = { - systemPrompt: 'You are helpful.', - messages: [userMessage, assistantMessage, toolResult], - tools: [myTool] -}; - -// Continue processing from the tool result -const stream = agentLoopContinue(context, { model }); - -for await (const event of stream) { - // Same events as agentLoop, but no user message events emitted -} - -const newMessages = await stream.result(); -``` - -**Validation**: Throws if context has no messages or if the last message is an assistant message. - -### Defining Tools with TypeBox - -Tools use TypeBox schemas for runtime validation and type inference: - -```typescript -import { Type, Static, AgentTool, AgentToolResult, StringEnum } from '@mariozechner/pi-ai'; - -const weatherSchema = Type.Object({ - city: Type.String({ minLength: 1 }), - units: StringEnum(['celsius', 'fahrenheit'], { default: 'celsius' }) -}); - -type WeatherParams = Static; - -const weatherTool: AgentTool = { - label: 'Get Weather', - name: 'get_weather', - description: 'Get current weather for a city', - parameters: weatherSchema, - execute: async (toolCallId, args, signal, onUpdate) => { - // args is fully typed: { city: string, units: 'celsius' | 'fahrenheit' } - // signal: AbortSignal for cancellation - // onUpdate: Optional callback for streaming progress (emits tool_execution_update events) - const temp = Math.round(Math.random() * 30); - return { - content: [{ type: 'text', text: `Temperature in ${args.city}: ${temp}°${args.units[0].toUpperCase()}` }], - details: { temp } - }; - } -}; - -// Tools can also return images alongside text -const chartTool: AgentTool = { - label: 'Generate Chart', - name: 
'generate_chart', - description: 'Generate a chart from data', - parameters: Type.Object({ data: Type.Array(Type.Number()) }), - execute: async (toolCallId, args) => { - const chartImage = await generateChartImage(args.data); - return { - content: [ - { type: 'text', text: `Generated chart with ${args.data.length} data points` }, - { type: 'image', data: chartImage.toString('base64'), mimeType: 'image/png' } - ] - }; - } -}; - -// Tools can stream progress via the onUpdate callback (emits tool_execution_update events) -const bashTool: AgentTool = { - label: 'Run Bash', - name: 'bash', - description: 'Execute a bash command', - parameters: Type.Object({ command: Type.String() }), - execute: async (toolCallId, args, signal, onUpdate) => { - let output = ''; - const child = spawn('bash', ['-c', args.command]); - - child.stdout.on('data', (data) => { - output += data.toString(); - // Stream partial output to UI via tool_execution_update events - onUpdate?.({ - content: [{ type: 'text', text: output }], - details: { exitCode: -1 } // Not finished yet - }); - }); - - const exitCode = await new Promise((resolve) => { - child.on('close', resolve); - }); - - return { - content: [{ type: 'text', text: output }], - details: { exitCode } - }; - } -}; -``` - -### Validation and Error Handling - -Tool arguments are automatically validated using AJV with the TypeBox schema. 
Invalid arguments result in detailed error messages: - -```typescript -// If the LLM calls with invalid arguments: -// get_weather({ city: '', units: 'kelvin' }) - -// The tool execution will fail with: -/* -Validation failed for tool "get_weather": - - city: must NOT have fewer than 1 characters - - units: must be equal to one of the allowed values - -Received arguments: -{ - "city": "", - "units": "kelvin" -} -*/ -``` - -### Built-in Example Tools - -The library includes example tools for common operations: - -```typescript -import { calculateTool, getCurrentTimeTool } from '@mariozechner/pi-ai'; - -const context: AgentContext = { - systemPrompt: 'You are a helpful assistant.', - messages: [], - tools: [calculateTool, getCurrentTimeTool] -}; -``` - ## Browser Usage The library supports browser environments. You must pass the API key explicitly since environment variables are not available in browsers: diff --git a/packages/ai/src/agent/index.ts b/packages/ai/src/agent/index.ts deleted file mode 100644 index 16e323d4..00000000 --- a/packages/ai/src/agent/index.ts +++ /dev/null @@ -1,11 +0,0 @@ -export { agentLoop, agentLoopContinue } from "./agent-loop.js"; -export * from "./tools/index.js"; -export type { - AgentContext, - AgentEvent, - AgentLoopConfig, - AgentTool, - AgentToolResult, - AgentToolUpdateCallback, - QueuedMessage, -} from "./types.js"; diff --git a/packages/ai/src/agent/tools/index.ts b/packages/ai/src/agent/tools/index.ts deleted file mode 100644 index 0a9cc7ad..00000000 --- a/packages/ai/src/agent/tools/index.ts +++ /dev/null @@ -1,2 +0,0 @@ -export { calculate, calculateTool } from "./calculate.js"; -export { getCurrentTime, getCurrentTimeTool } from "./get-current-time.js"; diff --git a/packages/ai/src/agent/types.ts b/packages/ai/src/agent/types.ts deleted file mode 100644 index c0cb2df6..00000000 --- a/packages/ai/src/agent/types.ts +++ /dev/null @@ -1,105 +0,0 @@ -import type { Static, TSchema } from "@sinclair/typebox"; -import type { - 
AssistantMessage, - AssistantMessageEvent, - ImageContent, - Message, - Model, - SimpleStreamOptions, - TextContent, - Tool, - ToolResultMessage, -} from "../types.js"; - -export interface AgentToolResult { - // Content blocks supporting text and images - content: (TextContent | ImageContent)[]; - // Details to be displayed in a UI or logged - details: T; -} - -// Callback for streaming tool execution updates -export type AgentToolUpdateCallback = (partialResult: AgentToolResult) => void; - -// AgentTool extends Tool but adds the execute function -export interface AgentTool extends Tool { - // A human-readable label for the tool to be displayed in UI - label: string; - execute: ( - toolCallId: string, - params: Static, - signal?: AbortSignal, - onUpdate?: AgentToolUpdateCallback, - ) => Promise>; -} - -// AgentContext is like Context but uses AgentTool -export interface AgentContext { - systemPrompt: string; - messages: Message[]; - tools?: AgentTool[]; -} - -// Event types -export type AgentEvent = - // Emitted when the agent starts. An agent can emit multiple turns - | { type: "agent_start" } - // Emitted when a turn starts. 
A turn can emit an optional user message (initial prompt), an assistant message (response) and multiple tool result messages - | { type: "turn_start" } - // Emitted when a user, assistant or tool result message starts - | { type: "message_start"; message: Message } - // Emitted when an asssitant messages is updated due to streaming - | { type: "message_update"; assistantMessageEvent: AssistantMessageEvent; message: AssistantMessage } - // Emitted when a user, assistant or tool result message is complete - | { type: "message_end"; message: Message } - // Emitted when a tool execution starts - | { type: "tool_execution_start"; toolCallId: string; toolName: string; args: any } - // Emitted when a tool execution produces output (streaming) - | { - type: "tool_execution_update"; - toolCallId: string; - toolName: string; - args: any; - partialResult: AgentToolResult; - } - // Emitted when a tool execution completes - | { - type: "tool_execution_end"; - toolCallId: string; - toolName: string; - result: AgentToolResult; - isError: boolean; - } - // Emitted when a full turn completes - | { type: "turn_end"; message: AssistantMessage; toolResults: ToolResultMessage[] } - // Emitted when the agent has completed all its turns. All messages from every turn are - // contained in messages, which can be appended to the context - | { type: "agent_end"; messages: AgentContext["messages"] }; - -// Queued message with optional LLM representation -export interface QueuedMessage { - original: TApp; // Original message for UI events - llm?: Message; // Optional transformed message for loop context (undefined if filtered) -} - -// Configuration for agent loop execution -export interface AgentLoopConfig extends SimpleStreamOptions { - model: Model; - - /** - * Optional hook to resolve an API key dynamically for each LLM call. - * - * This is useful for short-lived OAuth tokens (e.g. GitHub Copilot) that may - * expire during long-running tool execution phases. 
- * - * The agent loop will call this before each assistant response and pass the - * returned value as `apiKey` to `streamSimple()` (or a custom `streamFn`). - * - * If it returns `undefined`, the loop falls back to `config.apiKey`, and then - * to `streamSimple()`'s own provider key lookup (setApiKey/env vars). - */ - getApiKey?: (provider: string) => Promise | string | undefined; - - preprocessor?: (messages: AgentContext["messages"], abortSignal?: AbortSignal) => Promise; - getQueuedMessages?: () => Promise[]>; -} diff --git a/packages/ai/src/index.ts b/packages/ai/src/index.ts index e24431d8..5ff971de 100644 --- a/packages/ai/src/index.ts +++ b/packages/ai/src/index.ts @@ -1,4 +1,3 @@ -export * from "./agent/index.js"; export * from "./models.js"; export * from "./providers/anthropic.js"; export * from "./providers/google.js"; @@ -7,6 +6,7 @@ export * from "./providers/openai-completions.js"; export * from "./providers/openai-responses.js"; export * from "./stream.js"; export * from "./types.js"; +export * from "./utils/event-stream.js"; export * from "./utils/oauth/index.js"; export * from "./utils/overflow.js"; export * from "./utils/typebox-helpers.js"; diff --git a/packages/ai/src/models.generated.ts b/packages/ai/src/models.generated.ts index 6cde070c..f7fb6aff 100644 --- a/packages/ai/src/models.generated.ts +++ b/packages/ai/src/models.generated.ts @@ -3325,13 +3325,13 @@ export const MODELS = { reasoning: true, input: ["text"], cost: { - input: 0.224, - output: 0.32, + input: 0.25, + output: 0.38, cacheRead: 0, cacheWrite: 0, }, contextWindow: 163840, - maxTokens: 4096, + maxTokens: 65536, } satisfies Model<"openai-completions">, "deepseek/deepseek-v3.2-exp": { id: "deepseek/deepseek-v3.2-exp", @@ -3892,7 +3892,7 @@ export const MODELS = { cacheWrite: 0, }, contextWindow: 196608, - maxTokens: 131072, + maxTokens: 65536, } satisfies Model<"openai-completions">, "minimax/minimax-m2.1": { id: "minimax/minimax-m2.1", @@ -5371,7 +5371,7 @@ export const 
MODELS = { cacheWrite: 0, }, contextWindow: 131072, - maxTokens: 128000, + maxTokens: 131072, } satisfies Model<"openai-completions">, "openai/gpt-oss-safeguard-20b": { id: "openai/gpt-oss-safeguard-20b", @@ -6249,8 +6249,8 @@ export const MODELS = { reasoning: true, input: ["text"], cost: { - input: 0.3, - output: 1.2, + input: 0.25, + output: 0.85, cacheRead: 0, cacheWrite: 0, }, @@ -6266,8 +6266,8 @@ export const MODELS = { reasoning: true, input: ["text"], cost: { - input: 0.3, - output: 1.2, + input: 0.25, + output: 0.85, cacheRead: 0, cacheWrite: 0, }, @@ -6538,13 +6538,13 @@ export const MODELS = { reasoning: true, input: ["text"], cost: { - input: 0.39, - output: 1.9, + input: 0.35, + output: 1.5, cacheRead: 0, cacheWrite: 0, }, - contextWindow: 204800, - maxTokens: 204800, + contextWindow: 202752, + maxTokens: 65536, } satisfies Model<"openai-completions">, "z-ai/glm-4.6:exacto": { id: "z-ai/glm-4.6:exacto", diff --git a/packages/ai/src/providers/google-gemini-cli.ts b/packages/ai/src/providers/google-gemini-cli.ts index cbeb60a2..85d644dd 100644 --- a/packages/ai/src/providers/google-gemini-cli.ts +++ b/packages/ai/src/providers/google-gemini-cli.ts @@ -4,7 +4,7 @@ * Uses the Cloud Code Assist API endpoint to access Gemini and Claude models. */ -import type { Content, ThinkingConfig, ThinkingLevel } from "@google/genai"; +import type { Content, ThinkingConfig } from "@google/genai"; import { calculateCost } from "../models.js"; import type { Api, @@ -21,6 +21,12 @@ import { AssistantMessageEventStream } from "../utils/event-stream.js"; import { sanitizeSurrogates } from "../utils/sanitize-unicode.js"; import { convertMessages, convertTools, mapStopReasonString, mapToolChoice } from "./google-shared.js"; +/** + * Thinking level for Gemini 3 models. + * Mirrors Google's ThinkingLevel enum values. 
+ */ +export type GoogleThinkingLevel = "THINKING_LEVEL_UNSPECIFIED" | "MINIMAL" | "LOW" | "MEDIUM" | "HIGH"; + export interface GoogleGeminiCliOptions extends StreamOptions { toolChoice?: "auto" | "none" | "any"; /** @@ -35,7 +41,7 @@ export interface GoogleGeminiCliOptions extends StreamOptions { /** Thinking budget in tokens. Use for Gemini 2.x models. */ budgetTokens?: number; /** Thinking level. Use for Gemini 3 models (LOW/HIGH for Pro, MINIMAL/LOW/MEDIUM/HIGH for Flash). */ - level?: ThinkingLevel; + level?: GoogleThinkingLevel; }; projectId?: string; } @@ -436,7 +442,8 @@ function buildRequest( }; // Gemini 3 models use thinkingLevel, older models use thinkingBudget if (options.thinking.level !== undefined) { - generationConfig.thinkingConfig.thinkingLevel = options.thinking.level; + // Cast to any since our GoogleThinkingLevel mirrors Google's ThinkingLevel enum values + generationConfig.thinkingConfig.thinkingLevel = options.thinking.level as any; } else if (options.thinking.budgetTokens !== undefined) { generationConfig.thinkingConfig.thinkingBudget = options.thinking.budgetTokens; } diff --git a/packages/ai/src/providers/google.ts b/packages/ai/src/providers/google.ts index 29370a68..67893eef 100644 --- a/packages/ai/src/providers/google.ts +++ b/packages/ai/src/providers/google.ts @@ -3,7 +3,6 @@ import { type GenerateContentParameters, GoogleGenAI, type ThinkingConfig, - type ThinkingLevel, } from "@google/genai"; import { calculateCost } from "../models.js"; import { getEnvApiKey } from "../stream.js"; @@ -20,6 +19,7 @@ import type { } from "../types.js"; import { AssistantMessageEventStream } from "../utils/event-stream.js"; import { sanitizeSurrogates } from "../utils/sanitize-unicode.js"; +import type { GoogleThinkingLevel } from "./google-gemini-cli.js"; import { convertMessages, convertTools, mapStopReason, mapToolChoice } from "./google-shared.js"; export interface GoogleOptions extends StreamOptions { @@ -27,7 +27,7 @@ export interface 
GoogleOptions extends StreamOptions { thinking?: { enabled: boolean; budgetTokens?: number; // -1 for dynamic, 0 to disable - level?: ThinkingLevel; + level?: GoogleThinkingLevel; }; } @@ -299,7 +299,8 @@ function buildParams( if (options.thinking?.enabled && model.reasoning) { const thinkingConfig: ThinkingConfig = { includeThoughts: true }; if (options.thinking.level !== undefined) { - thinkingConfig.thinkingLevel = options.thinking.level; + // Cast to any since our GoogleThinkingLevel mirrors Google's ThinkingLevel enum values + thinkingConfig.thinkingLevel = options.thinking.level as any; } else if (options.thinking.budgetTokens !== undefined) { thinkingConfig.thinkingBudget = options.thinking.budgetTokens; } diff --git a/packages/ai/src/stream.ts b/packages/ai/src/stream.ts index a71656e9..f68d5f60 100644 --- a/packages/ai/src/stream.ts +++ b/packages/ai/src/stream.ts @@ -1,8 +1,11 @@ -import { ThinkingLevel } from "@google/genai"; import { supportsXhigh } from "./models.js"; import { type AnthropicOptions, streamAnthropic } from "./providers/anthropic.js"; import { type GoogleOptions, streamGoogle } from "./providers/google.js"; -import { type GoogleGeminiCliOptions, streamGoogleGeminiCli } from "./providers/google-gemini-cli.js"; +import { + type GoogleGeminiCliOptions, + type GoogleThinkingLevel, + streamGoogleGeminiCli, +} from "./providers/google-gemini-cli.js"; import { type OpenAICompletionsOptions, streamOpenAICompletions } from "./providers/openai-completions.js"; import { type OpenAIResponsesOptions, streamOpenAIResponses } from "./providers/openai-responses.js"; import type { @@ -30,9 +33,13 @@ export function getEnvApiKey(provider: any): string | undefined { return process.env.COPILOT_GITHUB_TOKEN || process.env.GH_TOKEN || process.env.GITHUB_TOKEN; } + // ANTHROPIC_OAUTH_TOKEN takes precedence over ANTHROPIC_API_KEY + if (provider === "anthropic") { + return process.env.ANTHROPIC_OAUTH_TOKEN || process.env.ANTHROPIC_API_KEY; + } + const envMap: 
Record = { openai: "OPENAI_API_KEY", - anthropic: "ANTHROPIC_API_KEY", google: "GEMINI_API_KEY", groq: "GROQ_API_KEY", cerebras: "CEREBRAS_API_KEY", @@ -252,53 +259,56 @@ function isGemini3FlashModel(model: Model<"google-generative-ai">): boolean { return model.id.includes("3-flash"); } -function getGemini3ThinkingLevel(effort: ClampedReasoningEffort, model: Model<"google-generative-ai">): ThinkingLevel { +function getGemini3ThinkingLevel( + effort: ClampedReasoningEffort, + model: Model<"google-generative-ai">, +): GoogleThinkingLevel { if (isGemini3ProModel(model)) { // Gemini 3 Pro only supports LOW/HIGH (for now) switch (effort) { case "minimal": case "low": - return ThinkingLevel.LOW; + return "LOW"; case "medium": case "high": - return ThinkingLevel.HIGH; + return "HIGH"; } } // Gemini 3 Flash supports all four levels switch (effort) { case "minimal": - return ThinkingLevel.MINIMAL; + return "MINIMAL"; case "low": - return ThinkingLevel.LOW; + return "LOW"; case "medium": - return ThinkingLevel.MEDIUM; + return "MEDIUM"; case "high": - return ThinkingLevel.HIGH; + return "HIGH"; } } -function getGeminiCliThinkingLevel(effort: ClampedReasoningEffort, modelId: string): ThinkingLevel { +function getGeminiCliThinkingLevel(effort: ClampedReasoningEffort, modelId: string): GoogleThinkingLevel { if (modelId.includes("3-pro")) { // Gemini 3 Pro only supports LOW/HIGH (for now) switch (effort) { case "minimal": case "low": - return ThinkingLevel.LOW; + return "LOW"; case "medium": case "high": - return ThinkingLevel.HIGH; + return "HIGH"; } } // Gemini 3 Flash supports all four levels switch (effort) { case "minimal": - return ThinkingLevel.MINIMAL; + return "MINIMAL"; case "low": - return ThinkingLevel.LOW; + return "LOW"; case "medium": - return ThinkingLevel.MEDIUM; + return "MEDIUM"; case "high": - return ThinkingLevel.HIGH; + return "HIGH"; } } diff --git a/packages/ai/src/utils/oauth/anthropic.ts b/packages/ai/src/utils/oauth/anthropic.ts index 
ea772e94..74a2228c 100644 --- a/packages/ai/src/utils/oauth/anthropic.ts +++ b/packages/ai/src/utils/oauth/anthropic.ts @@ -2,25 +2,16 @@ * Anthropic OAuth flow (Claude Pro/Max) */ -import { createHash, randomBytes } from "crypto"; +import { generatePKCE } from "./pkce.js"; import type { OAuthCredentials } from "./types.js"; -const decode = (s: string) => Buffer.from(s, "base64").toString(); +const decode = (s: string) => atob(s); const CLIENT_ID = decode("OWQxYzI1MGEtZTYxYi00NGQ5LTg4ZWQtNTk0NGQxOTYyZjVl"); const AUTHORIZE_URL = "https://claude.ai/oauth/authorize"; const TOKEN_URL = "https://console.anthropic.com/v1/oauth/token"; const REDIRECT_URI = "https://console.anthropic.com/oauth/code/callback"; const SCOPES = "org:create_api_key user:profile user:inference"; -/** - * Generate PKCE code verifier and challenge - */ -function generatePKCE(): { verifier: string; challenge: string } { - const verifier = randomBytes(32).toString("base64url"); - const challenge = createHash("sha256").update(verifier).digest("base64url"); - return { verifier, challenge }; -} - /** * Login with Anthropic OAuth (device code flow) * @@ -31,7 +22,7 @@ export async function loginAnthropic( onAuthUrl: (url: string) => void, onPromptCode: () => Promise, ): Promise { - const { verifier, challenge } = generatePKCE(); + const { verifier, challenge } = await generatePKCE(); // Build authorization URL const authParams = new URLSearchParams({ diff --git a/packages/ai/src/utils/oauth/github-copilot.ts b/packages/ai/src/utils/oauth/github-copilot.ts index e74901d0..744d9f18 100644 --- a/packages/ai/src/utils/oauth/github-copilot.ts +++ b/packages/ai/src/utils/oauth/github-copilot.ts @@ -5,7 +5,7 @@ import { getModels } from "../../models.js"; import type { OAuthCredentials } from "./types.js"; -const decode = (s: string) => Buffer.from(s, "base64").toString(); +const decode = (s: string) => atob(s); const CLIENT_ID = decode("SXYxLmI1MDdhMDhjODdlY2ZlOTg="); const COPILOT_HEADERS = { diff --git 
a/packages/ai/src/utils/oauth/google-antigravity.ts b/packages/ai/src/utils/oauth/google-antigravity.ts index 80deb402..16f45c51 100644 --- a/packages/ai/src/utils/oauth/google-antigravity.ts +++ b/packages/ai/src/utils/oauth/google-antigravity.ts @@ -1,14 +1,17 @@ /** * Antigravity OAuth flow (Gemini 3, Claude, GPT-OSS via Google Cloud) * Uses different OAuth credentials than google-gemini-cli for access to additional models. + * + * NOTE: This module uses Node.js http.createServer for the OAuth callback. + * It is only intended for CLI use, not browser environments. */ -import { createHash, randomBytes } from "crypto"; -import { createServer, type Server } from "http"; +import type { Server } from "http"; +import { generatePKCE } from "./pkce.js"; import type { OAuthCredentials } from "./types.js"; // Antigravity OAuth credentials (different from Gemini CLI) -const decode = (s: string) => Buffer.from(s, "base64").toString(); +const decode = (s: string) => atob(s); const CLIENT_ID = decode( "MTA3MTAwNjA2MDU5MS10bWhzc2luMmgyMWxjcmUyMzV2dG9sb2poNGc0MDNlcC5hcHBzLmdvb2dsZXVzZXJjb250ZW50LmNvbQ==", ); @@ -30,19 +33,15 @@ const TOKEN_URL = "https://oauth2.googleapis.com/token"; // Fallback project ID when discovery fails const DEFAULT_PROJECT_ID = "rising-fact-p41fc"; -/** - * Generate PKCE code verifier and challenge - */ -function generatePKCE(): { verifier: string; challenge: string } { - const verifier = randomBytes(32).toString("base64url"); - const challenge = createHash("sha256").update(verifier).digest("base64url"); - return { verifier, challenge }; -} - /** * Start a local HTTP server to receive the OAuth callback */ -function startCallbackServer(): Promise<{ server: Server; getCode: () => Promise<{ code: string; state: string }> }> { +async function startCallbackServer(): Promise<{ + server: Server; + getCode: () => Promise<{ code: string; state: string }>; +}> { + const { createServer } = await import("http"); + return new Promise((resolve, reject) => { let 
codeResolve: (value: { code: string; state: string }) => void; let codeReject: (error: Error) => void; @@ -232,7 +231,7 @@ export async function loginAntigravity( onAuth: (info: { url: string; instructions?: string }) => void, onProgress?: (message: string) => void, ): Promise { - const { verifier, challenge } = generatePKCE(); + const { verifier, challenge } = await generatePKCE(); // Start local server for callback onProgress?.("Starting local server for OAuth callback..."); diff --git a/packages/ai/src/utils/oauth/google-gemini-cli.ts b/packages/ai/src/utils/oauth/google-gemini-cli.ts index 6c06d375..d3585d15 100644 --- a/packages/ai/src/utils/oauth/google-gemini-cli.ts +++ b/packages/ai/src/utils/oauth/google-gemini-cli.ts @@ -1,13 +1,16 @@ /** * Gemini CLI OAuth flow (Google Cloud Code Assist) * Standard Gemini models only (gemini-2.0-flash, gemini-2.5-*) + * + * NOTE: This module uses Node.js http.createServer for the OAuth callback. + * It is only intended for CLI use, not browser environments. 
*/ -import { createHash, randomBytes } from "crypto"; -import { createServer, type Server } from "http"; +import type { Server } from "http"; +import { generatePKCE } from "./pkce.js"; import type { OAuthCredentials } from "./types.js"; -const decode = (s: string) => Buffer.from(s, "base64").toString(); +const decode = (s: string) => atob(s); const CLIENT_ID = decode( "NjgxMjU1ODA5Mzk1LW9vOGZ0Mm9wcmRybnA5ZTNhcWY2YXYzaG1kaWIxMzVqLmFwcHMuZ29vZ2xldXNlcmNvbnRlbnQuY29t", ); @@ -22,19 +25,15 @@ const AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth"; const TOKEN_URL = "https://oauth2.googleapis.com/token"; const CODE_ASSIST_ENDPOINT = "https://cloudcode-pa.googleapis.com"; -/** - * Generate PKCE code verifier and challenge - */ -function generatePKCE(): { verifier: string; challenge: string } { - const verifier = randomBytes(32).toString("base64url"); - const challenge = createHash("sha256").update(verifier).digest("base64url"); - return { verifier, challenge }; -} - /** * Start a local HTTP server to receive the OAuth callback */ -function startCallbackServer(): Promise<{ server: Server; getCode: () => Promise<{ code: string; state: string }> }> { +async function startCallbackServer(): Promise<{ + server: Server; + getCode: () => Promise<{ code: string; state: string }>; +}> { + const { createServer } = await import("http"); + return new Promise((resolve, reject) => { let codeResolve: (value: { code: string; state: string }) => void; let codeReject: (error: Error) => void; @@ -263,7 +262,7 @@ export async function loginGeminiCli( onAuth: (info: { url: string; instructions?: string }) => void, onProgress?: (message: string) => void, ): Promise { - const { verifier, challenge } = generatePKCE(); + const { verifier, challenge } = await generatePKCE(); // Start local server for callback onProgress?.("Starting local server for OAuth callback..."); diff --git a/packages/ai/src/utils/oauth/pkce.ts b/packages/ai/src/utils/oauth/pkce.ts new file mode 100644 index 
00000000..bf7ac7d5 --- /dev/null +++ b/packages/ai/src/utils/oauth/pkce.ts @@ -0,0 +1,34 @@ +/** + * PKCE utilities using Web Crypto API. + * Works in both Node.js 20+ and browsers. + */ + +/** + * Encode bytes as base64url string. + */ +function base64urlEncode(bytes: Uint8Array): string { + let binary = ""; + for (const byte of bytes) { + binary += String.fromCharCode(byte); + } + return btoa(binary).replace(/\+/g, "-").replace(/\//g, "_").replace(/=/g, ""); +} + +/** + * Generate PKCE code verifier and challenge. + * Uses Web Crypto API for cross-platform compatibility. + */ +export async function generatePKCE(): Promise<{ verifier: string; challenge: string }> { + // Generate random verifier + const verifierBytes = new Uint8Array(32); + crypto.getRandomValues(verifierBytes); + const verifier = base64urlEncode(verifierBytes); + + // Compute SHA-256 challenge + const encoder = new TextEncoder(); + const data = encoder.encode(verifier); + const hashBuffer = await crypto.subtle.digest("SHA-256", data); + const challenge = base64urlEncode(new Uint8Array(hashBuffer)); + + return { verifier, challenge }; +} diff --git a/packages/ai/test/agent-queue-interrupt.test.ts b/packages/ai/test/agent-queue-interrupt.test.ts deleted file mode 100644 index 42a5db45..00000000 --- a/packages/ai/test/agent-queue-interrupt.test.ts +++ /dev/null @@ -1,166 +0,0 @@ -import { Type } from "@sinclair/typebox"; -import { describe, expect, it } from "vitest"; -import { agentLoop } from "../src/agent/agent-loop.js"; -import type { AgentContext, AgentEvent, AgentLoopConfig, AgentTool, QueuedMessage } from "../src/agent/types.js"; -import type { AssistantMessage, Message, Model, UserMessage } from "../src/types.js"; -import { AssistantMessageEventStream } from "../src/utils/event-stream.js"; - -function createUsage() { - return { - input: 0, - output: 0, - cacheRead: 0, - cacheWrite: 0, - totalTokens: 0, - cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, - }; -} - -function 
createModel(): Model<"openai-responses"> { - return { - id: "mock", - name: "mock", - api: "openai-responses", - provider: "openai", - baseUrl: "https://example.invalid", - reasoning: false, - input: ["text"], - cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 }, - contextWindow: 8192, - maxTokens: 2048, - }; -} - -describe("agentLoop queued message interrupt", () => { - it("injects queued messages after a tool call and skips remaining tool calls", async () => { - const toolSchema = Type.Object({ value: Type.String() }); - const executed: string[] = []; - const tool: AgentTool = { - name: "echo", - label: "Echo", - description: "Echo tool", - parameters: toolSchema, - async execute(_toolCallId, params) { - executed.push(params.value); - return { - content: [{ type: "text", text: `ok:${params.value}` }], - details: { value: params.value }, - }; - }, - }; - - const context: AgentContext = { - systemPrompt: "", - messages: [], - tools: [tool], - }; - - const userPrompt: UserMessage = { - role: "user", - content: "start", - timestamp: Date.now(), - }; - - const queuedUserMessage: Message = { - role: "user", - content: "interrupt", - timestamp: Date.now(), - }; - const queuedMessages: QueuedMessage[] = [{ original: queuedUserMessage, llm: queuedUserMessage }]; - - let queuedDelivered = false; - let sawInterruptInContext = false; - let callIndex = 0; - - const streamFn = () => { - const stream = new AssistantMessageEventStream(); - queueMicrotask(() => { - if (callIndex === 0) { - const message: AssistantMessage = { - role: "assistant", - content: [ - { type: "toolCall", id: "tool-1", name: "echo", arguments: { value: "first" } }, - { type: "toolCall", id: "tool-2", name: "echo", arguments: { value: "second" } }, - ], - api: "openai-responses", - provider: "openai", - model: "mock", - usage: createUsage(), - stopReason: "toolUse", - timestamp: Date.now(), - }; - stream.push({ type: "done", reason: "toolUse", message }); - } else { - const message: 
AssistantMessage = { - role: "assistant", - content: [{ type: "text", text: "done" }], - api: "openai-responses", - provider: "openai", - model: "mock", - usage: createUsage(), - stopReason: "stop", - timestamp: Date.now(), - }; - stream.push({ type: "done", reason: "stop", message }); - } - callIndex += 1; - }); - return stream; - }; - - const getQueuedMessages: AgentLoopConfig["getQueuedMessages"] = async () => { - if (executed.length === 1 && !queuedDelivered) { - queuedDelivered = true; - return queuedMessages as QueuedMessage[]; - } - return []; - }; - - const config: AgentLoopConfig = { - model: createModel(), - getQueuedMessages, - }; - - const events: AgentEvent[] = []; - const stream = agentLoop(userPrompt, context, config, undefined, (_model, ctx, _options) => { - if (callIndex === 1) { - sawInterruptInContext = ctx.messages.some( - (m) => m.role === "user" && typeof m.content === "string" && m.content === "interrupt", - ); - } - return streamFn(); - }); - - for await (const event of stream) { - events.push(event); - } - - expect(executed).toEqual(["first"]); - const toolEnds = events.filter( - (event): event is Extract => event.type === "tool_execution_end", - ); - expect(toolEnds.length).toBe(2); - expect(toolEnds[1].isError).toBe(true); - expect(toolEnds[1].result.content[0]?.type).toBe("text"); - if (toolEnds[1].result.content[0]?.type === "text") { - expect(toolEnds[1].result.content[0].text).toContain("Skipped due to queued user message"); - } - - const firstTurnEndIndex = events.findIndex((event) => event.type === "turn_end"); - const queuedMessageIndex = events.findIndex( - (event) => - event.type === "message_start" && - event.message.role === "user" && - typeof event.message.content === "string" && - event.message.content === "interrupt", - ); - const nextAssistantIndex = events.findIndex( - (event, index) => - index > queuedMessageIndex && event.type === "message_start" && event.message.role === "assistant", - ); - - 
expect(queuedMessageIndex).toBeGreaterThan(firstTurnEndIndex); - expect(queuedMessageIndex).toBeLessThan(nextAssistantIndex); - expect(sawInterruptInContext).toBe(true); - }); -}); diff --git a/packages/ai/test/agent.test.ts b/packages/ai/test/agent.test.ts deleted file mode 100644 index 8b7fdd46..00000000 --- a/packages/ai/test/agent.test.ts +++ /dev/null @@ -1,701 +0,0 @@ -import { describe, expect, it } from "vitest"; -import { agentLoop, agentLoopContinue } from "../src/agent/agent-loop.js"; -import { calculateTool } from "../src/agent/tools/calculate.js"; -import type { AgentContext, AgentEvent, AgentLoopConfig } from "../src/agent/types.js"; -import { getModel } from "../src/models.js"; -import type { - Api, - AssistantMessage, - Message, - Model, - OptionsForApi, - ToolResultMessage, - UserMessage, -} from "../src/types.js"; -import { resolveApiKey } from "./oauth.js"; - -// Resolve OAuth tokens at module level (async, runs before tests) -const oauthTokens = await Promise.all([ - resolveApiKey("anthropic"), - resolveApiKey("github-copilot"), - resolveApiKey("google-gemini-cli"), - resolveApiKey("google-antigravity"), -]); -const [anthropicOAuthToken, githubCopilotToken, geminiCliToken, antigravityToken] = oauthTokens; - -async function calculateTest(model: Model, options: OptionsForApi = {}) { - // Create the agent context with the calculator tool - const context: AgentContext = { - systemPrompt: - "You are a helpful assistant that performs mathematical calculations. When asked to calculate multiple expressions, you can use parallel tool calls if the model supports it. 
In your final answer, output ONLY the final sum as a single integer number, nothing else.", - messages: [], - tools: [calculateTool], - }; - - // Create the prompt config - const config: AgentLoopConfig = { - model, - ...options, - }; - - // Create the user prompt asking for multiple calculations - const userPrompt: UserMessage = { - role: "user", - content: `Use the calculator tool to complete the following mulit-step task. -1. Calculate 3485 * 4234 and 88823 * 3482 in parallel -2. Calculate the sum of the two results using the calculator tool -3. Output ONLY the final sum as a single integer number, nothing else.`, - timestamp: Date.now(), - }; - - // Calculate expected results (using integers) - const expectedFirst = 3485 * 4234; // = 14755490 - const expectedSecond = 88823 * 3482; // = 309281786 - const expectedSum = expectedFirst + expectedSecond; // = 324037276 - - // Track events for verification - const events: AgentEvent[] = []; - let turns = 0; - let toolCallCount = 0; - const toolResults: number[] = []; - let finalAnswer: number | undefined; - - // Execute the prompt - const stream = agentLoop(userPrompt, context, config); - - for await (const event of stream) { - events.push(event); - - switch (event.type) { - case "turn_start": - turns++; - console.log(`\n=== Turn ${turns} started ===`); - break; - - case "turn_end": - console.log(`=== Turn ${turns} ended with ${event.toolResults.length} tool results ===`); - console.log(event.message); - break; - - case "tool_execution_end": - if (!event.isError && typeof event.result === "object" && event.result.content) { - const textOutput = event.result.content - .filter((c: any) => c.type === "text") - .map((c: any) => c.text) - .join("\n"); - toolCallCount++; - // Extract number from output like "expression = result" - const match = textOutput.match(/=\s*([\d.]+)/); - if (match) { - const value = parseFloat(match[1]); - toolResults.push(value); - console.log(`Tool ${toolCallCount}: ${textOutput}`); - } - } - 
break; - - case "message_end": - // Just track the message end event, don't extract answer here - break; - } - } - - // Get the final messages - const finalMessages = await stream.result(); - - // Verify the results - expect(finalMessages).toBeDefined(); - expect(finalMessages.length).toBeGreaterThan(0); - - const finalMessage = finalMessages[finalMessages.length - 1]; - expect(finalMessage).toBeDefined(); - expect(finalMessage.role).toBe("assistant"); - if (finalMessage.role !== "assistant") throw new Error("Final message is not from assistant"); - - // Extract the final answer from the last assistant message - const content = finalMessage.content - .filter((c) => c.type === "text") - .map((c) => (c.type === "text" ? c.text : "")) - .join(" "); - - // Look for integers in the response that might be the final answer - const numbers = content.match(/\b\d+\b/g); - if (numbers) { - // Check if any of the numbers matches our expected sum - for (const num of numbers) { - const value = parseInt(num, 10); - if (Math.abs(value - expectedSum) < 10) { - finalAnswer = value; - break; - } - } - // If no exact match, take the last large number as likely the answer - if (finalAnswer === undefined) { - const largeNumbers = numbers.map((n) => parseInt(n, 10)).filter((n) => n > 1000000); - if (largeNumbers.length > 0) { - finalAnswer = largeNumbers[largeNumbers.length - 1]; - } - } - } - - // Should have executed at least 3 tool calls: 2 for the initial calculations, 1 for the sum - // (or possibly 2 if the model calculates the sum itself without a tool) - expect(toolCallCount).toBeGreaterThanOrEqual(2); - - // Must be at least 3 turns: first to calculate the expressions, then to sum them, then give the answer - // Could be 3 turns if model does parallel calls, or 4 turns if sequential calculation of expressions - expect(turns).toBeGreaterThanOrEqual(3); - expect(turns).toBeLessThanOrEqual(4); - - // Verify the individual calculations are in the results - const hasFirstCalc = 
toolResults.some((r) => r === expectedFirst); - const hasSecondCalc = toolResults.some((r) => r === expectedSecond); - expect(hasFirstCalc).toBe(true); - expect(hasSecondCalc).toBe(true); - - // Verify the final sum - if (finalAnswer !== undefined) { - expect(finalAnswer).toBe(expectedSum); - console.log(`Final answer: ${finalAnswer} (expected: ${expectedSum})`); - } else { - // If we couldn't extract the final answer from text, check if it's in the tool results - const hasSum = toolResults.some((r) => r === expectedSum); - expect(hasSum).toBe(true); - } - - // Log summary - console.log(`\nTest completed with ${turns} turns and ${toolCallCount} tool calls`); - if (turns === 3) { - console.log("Model used parallel tool calls for initial calculations"); - } else { - console.log("Model used sequential tool calls"); - } - - return { - turns, - toolCallCount, - toolResults, - finalAnswer, - events, - }; -} - -async function abortTest(model: Model, options: OptionsForApi = {}) { - // Create the agent context with the calculator tool - const context: AgentContext = { - systemPrompt: - "You are a helpful assistant that performs mathematical calculations. 
Always use the calculator tool for each calculation.", - messages: [], - tools: [calculateTool], - }; - - // Create the prompt config - const config: AgentLoopConfig = { - model, - ...options, - }; - - // Create a prompt that will require multiple calculations - const userPrompt: UserMessage = { - role: "user", - content: "Calculate 100 * 200, then 300 * 400, then 500 * 600, then sum all three results.", - timestamp: Date.now(), - }; - - // Create abort controller - const abortController = new AbortController(); - - // Track events for verification - const events: AgentEvent[] = []; - let toolCallCount = 0; - const errorReceived = false; - let finalMessages: Message[] | undefined; - - // Execute the prompt - const stream = agentLoop(userPrompt, context, config, abortController.signal); - - // Abort after first tool execution - (async () => { - for await (const event of stream) { - events.push(event); - - if (event.type === "tool_execution_end" && !event.isError) { - toolCallCount++; - // Abort after first successful tool execution - if (toolCallCount === 1) { - console.log("Aborting after first tool execution"); - abortController.abort(); - } - } - - if (event.type === "agent_end") { - finalMessages = event.messages; - } - } - })(); - - finalMessages = await stream.result(); - - // Verify abort behavior - console.log(`\nAbort test completed with ${toolCallCount} tool calls`); - const assistantMessage = finalMessages[finalMessages.length - 1]; - if (!assistantMessage) throw new Error("No final message received"); - expect(assistantMessage).toBeDefined(); - expect(assistantMessage.role).toBe("assistant"); - if (assistantMessage.role !== "assistant") throw new Error("Final message is not from assistant"); - - // Should have executed 1 tool call before abort - expect(toolCallCount).toBeGreaterThanOrEqual(1); - expect(assistantMessage.stopReason).toBe("aborted"); - - return { - toolCallCount, - events, - errorReceived, - finalMessages, - }; -} - -describe("Agent 
Calculator Tests", () => { - describe.skipIf(!process.env.GEMINI_API_KEY)("Google Provider Agent", () => { - const model = getModel("google", "gemini-2.5-flash"); - - it("should calculate multiple expressions and sum the results", { retry: 3 }, async () => { - const result = await calculateTest(model); - expect(result.toolCallCount).toBeGreaterThanOrEqual(2); - }); - - it("should handle abort during tool execution", { retry: 3 }, async () => { - const result = await abortTest(model); - expect(result.toolCallCount).toBeGreaterThanOrEqual(1); - }); - }); - - describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Completions Provider Agent", () => { - const model = getModel("openai", "gpt-4o-mini"); - - it("should calculate multiple expressions and sum the results", { retry: 3 }, async () => { - const result = await calculateTest(model); - expect(result.toolCallCount).toBeGreaterThanOrEqual(2); - }); - - it("should handle abort during tool execution", { retry: 3 }, async () => { - const result = await abortTest(model); - expect(result.toolCallCount).toBeGreaterThanOrEqual(1); - }); - }); - - describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Responses Provider Agent", () => { - const model = getModel("openai", "gpt-5-mini"); - - it("should calculate multiple expressions and sum the results", { retry: 3 }, async () => { - const result = await calculateTest(model); - expect(result.toolCallCount).toBeGreaterThanOrEqual(2); - }); - - it("should handle abort during tool execution", { retry: 3 }, async () => { - const result = await abortTest(model); - expect(result.toolCallCount).toBeGreaterThanOrEqual(1); - }); - }); - - describe.skipIf(!process.env.ANTHROPIC_API_KEY)("Anthropic Provider Agent", () => { - const model = getModel("anthropic", "claude-haiku-4-5"); - - it("should calculate multiple expressions and sum the results", { retry: 3 }, async () => { - const result = await calculateTest(model); - expect(result.toolCallCount).toBeGreaterThanOrEqual(2); - }); - - 
it("should handle abort during tool execution", { retry: 3 }, async () => { - const result = await abortTest(model); - expect(result.toolCallCount).toBeGreaterThanOrEqual(1); - }); - }); - - describe.skipIf(!process.env.XAI_API_KEY)("xAI Provider Agent", () => { - const model = getModel("xai", "grok-3"); - - it("should calculate multiple expressions and sum the results", { retry: 3 }, async () => { - const result = await calculateTest(model); - expect(result.toolCallCount).toBeGreaterThanOrEqual(2); - }); - - it("should handle abort during tool execution", { retry: 3 }, async () => { - const result = await abortTest(model); - expect(result.toolCallCount).toBeGreaterThanOrEqual(1); - }); - }); - - describe.skipIf(!process.env.GROQ_API_KEY)("Groq Provider Agent", () => { - const model = getModel("groq", "openai/gpt-oss-20b"); - - it("should calculate multiple expressions and sum the results", { retry: 3 }, async () => { - const result = await calculateTest(model); - expect(result.toolCallCount).toBeGreaterThanOrEqual(2); - }); - - it("should handle abort during tool execution", { retry: 3 }, async () => { - const result = await abortTest(model); - expect(result.toolCallCount).toBeGreaterThanOrEqual(1); - }); - }); - - describe.skipIf(!process.env.CEREBRAS_API_KEY)("Cerebras Provider Agent", () => { - const model = getModel("cerebras", "gpt-oss-120b"); - - it("should calculate multiple expressions and sum the results", { retry: 3 }, async () => { - const result = await calculateTest(model); - expect(result.toolCallCount).toBeGreaterThanOrEqual(2); - }); - - it("should handle abort during tool execution", { retry: 3 }, async () => { - const result = await abortTest(model); - expect(result.toolCallCount).toBeGreaterThanOrEqual(1); - }); - }); - - describe.skipIf(!process.env.ZAI_API_KEY)("zAI Provider Agent", () => { - const model = getModel("zai", "glm-4.5-air"); - - it("should calculate multiple expressions and sum the results", { retry: 3 }, async () => { - const 
result = await calculateTest(model); - expect(result.toolCallCount).toBeGreaterThanOrEqual(2); - }); - - it("should handle abort during tool execution", { retry: 3 }, async () => { - const result = await abortTest(model); - expect(result.toolCallCount).toBeGreaterThanOrEqual(1); - }); - }); - - describe.skipIf(!process.env.MISTRAL_API_KEY)("Mistral Provider Agent", () => { - const model = getModel("mistral", "devstral-medium-latest"); - - it("should calculate multiple expressions and sum the results", { retry: 3 }, async () => { - const result = await calculateTest(model); - expect(result.toolCallCount).toBeGreaterThanOrEqual(2); - }); - - it("should handle abort during tool execution", { retry: 3 }, async () => { - const result = await abortTest(model); - expect(result.toolCallCount).toBeGreaterThanOrEqual(1); - }); - }); - - // ========================================================================= - // OAuth-based providers (credentials from ~/.pi/agent/oauth.json) - // ========================================================================= - - describe("Anthropic OAuth Provider Agent", () => { - const model = getModel("anthropic", "claude-haiku-4-5"); - - it.skipIf(!anthropicOAuthToken)( - "should calculate multiple expressions and sum the results", - { retry: 3 }, - async () => { - const result = await calculateTest(model, { apiKey: anthropicOAuthToken }); - expect(result.toolCallCount).toBeGreaterThanOrEqual(2); - }, - ); - - it.skipIf(!anthropicOAuthToken)("should handle abort during tool execution", { retry: 3 }, async () => { - const result = await abortTest(model, { apiKey: anthropicOAuthToken }); - expect(result.toolCallCount).toBeGreaterThanOrEqual(1); - }); - }); - - describe("GitHub Copilot Provider Agent", () => { - it.skipIf(!githubCopilotToken)( - "gpt-4o - should calculate multiple expressions and sum the results", - { retry: 3 }, - async () => { - const model = getModel("github-copilot", "gpt-4o"); - const result = await calculateTest(model, 
{ apiKey: githubCopilotToken }); - expect(result.toolCallCount).toBeGreaterThanOrEqual(2); - }, - ); - - it.skipIf(!githubCopilotToken)("gpt-4o - should handle abort during tool execution", { retry: 3 }, async () => { - const model = getModel("github-copilot", "gpt-4o"); - const result = await abortTest(model, { apiKey: githubCopilotToken }); - expect(result.toolCallCount).toBeGreaterThanOrEqual(1); - }); - - it.skipIf(!githubCopilotToken)( - "claude-sonnet-4 - should calculate multiple expressions and sum the results", - { retry: 3 }, - async () => { - const model = getModel("github-copilot", "claude-sonnet-4"); - const result = await calculateTest(model, { apiKey: githubCopilotToken }); - expect(result.toolCallCount).toBeGreaterThanOrEqual(2); - }, - ); - - it.skipIf(!githubCopilotToken)( - "claude-sonnet-4 - should handle abort during tool execution", - { retry: 3 }, - async () => { - const model = getModel("github-copilot", "claude-sonnet-4"); - const result = await abortTest(model, { apiKey: githubCopilotToken }); - expect(result.toolCallCount).toBeGreaterThanOrEqual(1); - }, - ); - }); - - describe("Google Gemini CLI Provider Agent", () => { - it.skipIf(!geminiCliToken)( - "gemini-2.5-flash - should calculate multiple expressions and sum the results", - { retry: 3 }, - async () => { - const model = getModel("google-gemini-cli", "gemini-2.5-flash"); - const result = await calculateTest(model, { apiKey: geminiCliToken }); - expect(result.toolCallCount).toBeGreaterThanOrEqual(2); - }, - ); - - it.skipIf(!geminiCliToken)( - "gemini-2.5-flash - should handle abort during tool execution", - { retry: 3 }, - async () => { - const model = getModel("google-gemini-cli", "gemini-2.5-flash"); - const result = await abortTest(model, { apiKey: geminiCliToken }); - expect(result.toolCallCount).toBeGreaterThanOrEqual(1); - }, - ); - }); - - describe("Google Antigravity Provider Agent", () => { - it.skipIf(!antigravityToken)( - "gemini-3-flash - should calculate multiple 
expressions and sum the results", - { retry: 3 }, - async () => { - const model = getModel("google-antigravity", "gemini-3-flash"); - const result = await calculateTest(model, { apiKey: antigravityToken }); - expect(result.toolCallCount).toBeGreaterThanOrEqual(2); - }, - ); - - it.skipIf(!antigravityToken)( - "gemini-3-flash - should handle abort during tool execution", - { retry: 3 }, - async () => { - const model = getModel("google-antigravity", "gemini-3-flash"); - const result = await abortTest(model, { apiKey: antigravityToken }); - expect(result.toolCallCount).toBeGreaterThanOrEqual(1); - }, - ); - - it.skipIf(!antigravityToken)( - "claude-sonnet-4-5 - should calculate multiple expressions and sum the results", - { retry: 3 }, - async () => { - const model = getModel("google-antigravity", "claude-sonnet-4-5"); - const result = await calculateTest(model, { apiKey: antigravityToken }); - expect(result.toolCallCount).toBeGreaterThanOrEqual(2); - }, - ); - - it.skipIf(!antigravityToken)( - "claude-sonnet-4-5 - should handle abort during tool execution", - { retry: 3 }, - async () => { - const model = getModel("google-antigravity", "claude-sonnet-4-5"); - const result = await abortTest(model, { apiKey: antigravityToken }); - expect(result.toolCallCount).toBeGreaterThanOrEqual(1); - }, - ); - - it.skipIf(!antigravityToken)( - "gpt-oss-120b-medium - should calculate multiple expressions and sum the results", - { retry: 3 }, - async () => { - const model = getModel("google-antigravity", "gpt-oss-120b-medium"); - const result = await calculateTest(model, { apiKey: antigravityToken }); - expect(result.toolCallCount).toBeGreaterThanOrEqual(2); - }, - ); - - it.skipIf(!antigravityToken)( - "gpt-oss-120b-medium - should handle abort during tool execution", - { retry: 3 }, - async () => { - const model = getModel("google-antigravity", "gpt-oss-120b-medium"); - const result = await abortTest(model, { apiKey: antigravityToken }); - 
expect(result.toolCallCount).toBeGreaterThanOrEqual(1); - }, - ); - }); -}); - -describe("agentLoopContinue", () => { - describe("validation", () => { - const model = getModel("anthropic", "claude-haiku-4-5"); - const baseContext: AgentContext = { - systemPrompt: "You are a helpful assistant.", - messages: [], - tools: [], - }; - const config: AgentLoopConfig = { model }; - - it("should throw when context has no messages", () => { - expect(() => agentLoopContinue(baseContext, config)).toThrow("Cannot continue: no messages in context"); - }); - - it("should throw when last message is an assistant message", () => { - const assistantMessage: AssistantMessage = { - role: "assistant", - content: [{ type: "text", text: "Hello" }], - api: "anthropic-messages", - provider: "anthropic", - model: "claude-haiku-4-5", - usage: { - input: 0, - output: 0, - cacheRead: 0, - cacheWrite: 0, - totalTokens: 0, - cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, - }, - stopReason: "stop", - timestamp: Date.now(), - }; - const context: AgentContext = { - ...baseContext, - messages: [assistantMessage], - }; - expect(() => agentLoopContinue(context, config)).toThrow( - "Cannot continue from message role: assistant. Expected 'user' or 'toolResult'.", - ); - }); - - // Note: "should not throw" tests for valid inputs are covered by the E2E tests below - // which actually consume the stream and verify the output - }); - - describe.skipIf(!process.env.ANTHROPIC_API_KEY)("continue from user message", () => { - const model = getModel("anthropic", "claude-haiku-4-5"); - - it("should continue and get assistant response when last message is user", { retry: 3 }, async () => { - const userMessage: UserMessage = { - role: "user", - content: [{ type: "text", text: "Say exactly: HELLO WORLD" }], - timestamp: Date.now(), - }; - - const context: AgentContext = { - systemPrompt: "You are a helpful assistant. 
Follow instructions exactly.", - messages: [userMessage], - tools: [], - }; - - const config: AgentLoopConfig = { model }; - - const events: AgentEvent[] = []; - const stream = agentLoopContinue(context, config); - - for await (const event of stream) { - events.push(event); - } - - const messages = await stream.result(); - - // Should have gotten an assistant response - expect(messages.length).toBe(1); - expect(messages[0].role).toBe("assistant"); - - // Verify event sequence - no user message events since we're continuing - const eventTypes = events.map((e) => e.type); - expect(eventTypes).toContain("agent_start"); - expect(eventTypes).toContain("turn_start"); - expect(eventTypes).toContain("message_start"); - expect(eventTypes).toContain("message_end"); - expect(eventTypes).toContain("turn_end"); - expect(eventTypes).toContain("agent_end"); - - // Should NOT have user message events (that's the difference from agentLoop) - const messageEndEvents = events.filter((e) => e.type === "message_end"); - expect(messageEndEvents.length).toBe(1); // Only assistant message - expect((messageEndEvents[0] as any).message.role).toBe("assistant"); - }); - }); - - describe.skipIf(!process.env.ANTHROPIC_API_KEY)("continue from tool result", () => { - const model = getModel("anthropic", "claude-haiku-4-5"); - - it("should continue processing after tool results", { retry: 3 }, async () => { - // Simulate a conversation where: - // 1. User asked to calculate something - // 2. Assistant made a tool call - // 3. Tool result is ready - // 4. We continue from here - - const userMessage: UserMessage = { - role: "user", - content: [{ type: "text", text: "What is 5 + 3? Use the calculator." }], - timestamp: Date.now(), - }; - - const assistantMessage: AssistantMessage = { - role: "assistant", - content: [ - { type: "text", text: "Let me calculate that for you." 
}, - { type: "toolCall", id: "calc-1", name: "calculate", arguments: { expression: "5 + 3" } }, - ], - api: "anthropic-messages", - provider: "anthropic", - model: "claude-haiku-4-5", - usage: { - input: 0, - output: 0, - cacheRead: 0, - cacheWrite: 0, - totalTokens: 0, - cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, - }, - stopReason: "toolUse", - timestamp: Date.now(), - }; - - const toolResult: ToolResultMessage = { - role: "toolResult", - toolCallId: "calc-1", - toolName: "calculate", - content: [{ type: "text", text: "5 + 3 = 8" }], - isError: false, - timestamp: Date.now(), - }; - - const context: AgentContext = { - systemPrompt: "You are a helpful assistant. After getting a calculation result, state the answer clearly.", - messages: [userMessage, assistantMessage, toolResult], - tools: [calculateTool], - }; - - const config: AgentLoopConfig = { model }; - - const events: AgentEvent[] = []; - const stream = agentLoopContinue(context, config); - - for await (const event of stream) { - events.push(event); - } - - const messages = await stream.result(); - - // Should have gotten an assistant response - expect(messages.length).toBeGreaterThanOrEqual(1); - const lastMessage = messages[messages.length - 1]; - expect(lastMessage.role).toBe("assistant"); - - // The assistant should mention the result (8) - if (lastMessage.role === "assistant") { - const textContent = lastMessage.content - .filter((c) => c.type === "text") - .map((c) => (c as any).text) - .join(" "); - expect(textContent).toMatch(/8/); - } - }); - }); -}); diff --git a/packages/ai/test/stream.test.ts b/packages/ai/test/stream.test.ts index 29d9111f..02229d66 100644 --- a/packages/ai/test/stream.test.ts +++ b/packages/ai/test/stream.test.ts @@ -752,13 +752,11 @@ describe("Generate E2E Tests", () => { const llm = getModel("google-gemini-cli", "gemini-3-flash-preview"); it.skipIf(!geminiCliToken)("should handle thinking with thinkingLevel", { retry: 3 }, async () => { - const { 
ThinkingLevel } = await import("@google/genai"); - await handleThinking(llm, { apiKey: geminiCliToken, thinking: { enabled: true, level: ThinkingLevel.LOW } }); + await handleThinking(llm, { apiKey: geminiCliToken, thinking: { enabled: true, level: "LOW" } }); }); it.skipIf(!geminiCliToken)("should handle multi-turn with thinking and tools", { retry: 3 }, async () => { - const { ThinkingLevel } = await import("@google/genai"); - await multiTurn(llm, { apiKey: geminiCliToken, thinking: { enabled: true, level: ThinkingLevel.MEDIUM } }); + await multiTurn(llm, { apiKey: geminiCliToken, thinking: { enabled: true, level: "MEDIUM" } }); }); }); @@ -778,17 +776,15 @@ describe("Generate E2E Tests", () => { }); it.skipIf(!antigravityToken)("should handle thinking with thinkingLevel", { retry: 3 }, async () => { - const { ThinkingLevel } = await import("@google/genai"); // gemini-3-flash supports all four levels: MINIMAL, LOW, MEDIUM, HIGH await handleThinking(llm, { apiKey: antigravityToken, - thinking: { enabled: true, level: ThinkingLevel.LOW }, + thinking: { enabled: true, level: "LOW" }, }); }); it.skipIf(!antigravityToken)("should handle multi-turn with thinking and tools", { retry: 3 }, async () => { - const { ThinkingLevel } = await import("@google/genai"); - await multiTurn(llm, { apiKey: antigravityToken, thinking: { enabled: true, level: ThinkingLevel.MEDIUM } }); + await multiTurn(llm, { apiKey: antigravityToken, thinking: { enabled: true, level: "MEDIUM" } }); }); it.skipIf(!antigravityToken)("should handle image input", { retry: 3 }, async () => { @@ -800,11 +796,10 @@ describe("Generate E2E Tests", () => { const llm = getModel("google-antigravity", "gemini-3-pro-high"); it.skipIf(!antigravityToken)("should handle thinking with thinkingLevel HIGH", { retry: 3 }, async () => { - const { ThinkingLevel } = await import("@google/genai"); // gemini-3-pro only supports LOW/HIGH await handleThinking(llm, { apiKey: antigravityToken, - thinking: { enabled: true, level: 
ThinkingLevel.HIGH }, + thinking: { enabled: true, level: "HIGH" }, }); }); }); diff --git a/packages/ai/test/tool-validation.test.ts b/packages/ai/test/tool-validation.test.ts deleted file mode 100644 index 709d0946..00000000 --- a/packages/ai/test/tool-validation.test.ts +++ /dev/null @@ -1,140 +0,0 @@ -import { type Static, Type } from "@sinclair/typebox"; -import AjvModule from "ajv"; -import addFormatsModule from "ajv-formats"; - -// Handle both default and named exports -const Ajv = (AjvModule as any).default || AjvModule; -const addFormats = (addFormatsModule as any).default || addFormatsModule; - -import { describe, expect, it } from "vitest"; -import type { AgentTool } from "../src/agent/types.js"; - -describe("Tool Validation with TypeBox and AJV", () => { - // Define a test tool with TypeBox schema - const testSchema = Type.Object({ - name: Type.String({ minLength: 1 }), - age: Type.Integer({ minimum: 0, maximum: 150 }), - email: Type.String({ format: "email" }), - tags: Type.Optional(Type.Array(Type.String())), - }); - - type TestParams = Static; - - const testTool: AgentTool = { - label: "Test Tool", - name: "test_tool", - description: "A test tool for validation", - parameters: testSchema, - execute: async (_toolCallId, args) => { - return { - content: [{ type: "text", text: `Processed: ${args.name}, ${args.age}, ${args.email}` }], - details: undefined, - }; - }, - }; - - // Create AJV instance for validation - const ajv = new Ajv({ allErrors: true }); - addFormats(ajv); - - it("should validate correct input", () => { - const validInput = { - name: "John Doe", - age: 30, - email: "john@example.com", - tags: ["developer", "typescript"], - }; - - // Validate with AJV - const validate = ajv.compile(testTool.parameters); - const isValid = validate(validInput); - expect(isValid).toBe(true); - }); - - it("should reject invalid email", () => { - const invalidInput = { - name: "John Doe", - age: 30, - email: "not-an-email", - }; - - const validate = 
ajv.compile(testTool.parameters); - const isValid = validate(invalidInput); - expect(isValid).toBe(false); - expect(validate.errors).toBeDefined(); - }); - - it("should reject missing required fields", () => { - const invalidInput = { - age: 30, - email: "john@example.com", - }; - - const validate = ajv.compile(testTool.parameters); - const isValid = validate(invalidInput); - expect(isValid).toBe(false); - expect(validate.errors).toBeDefined(); - }); - - it("should reject invalid age", () => { - const invalidInput = { - name: "John Doe", - age: -5, - email: "john@example.com", - }; - - const validate = ajv.compile(testTool.parameters); - const isValid = validate(invalidInput); - expect(isValid).toBe(false); - expect(validate.errors).toBeDefined(); - }); - - it("should format validation errors nicely", () => { - const invalidInput = { - name: "", - age: 200, - email: "invalid", - }; - - const validate = ajv.compile(testTool.parameters); - const isValid = validate(invalidInput); - expect(isValid).toBe(false); - expect(validate.errors).toBeDefined(); - - if (validate.errors) { - const errors = validate.errors - .map((err: any) => { - const path = err.instancePath ? 
err.instancePath.substring(1) : err.params.missingProperty || "root"; - return ` - ${path}: ${err.message}`; - }) - .join("\n"); - - // AJV error messages are different from Zod - expect(errors).toContain("name: must NOT have fewer than 1 characters"); - expect(errors).toContain("age: must be <= 150"); - expect(errors).toContain('email: must match format "email"'); - } - }); - - it("should have type-safe execute function", async () => { - const validInput = { - name: "John Doe", - age: 30, - email: "john@example.com", - }; - - // Validate and execute - const validate = ajv.compile(testTool.parameters); - const isValid = validate(validInput); - expect(isValid).toBe(true); - - const result = await testTool.execute("test-id", validInput as TestParams); - - const textOutput = result.content - .filter((c: any) => c.type === "text") - .map((c: any) => c.text) - .join("\n"); - expect(textOutput).toBe("Processed: John Doe, 30, john@example.com"); - expect(result.details).toBeUndefined(); - }); -}); diff --git a/packages/coding-agent/CHANGELOG.md b/packages/coding-agent/CHANGELOG.md index 43246cf7..b2c064c8 100644 --- a/packages/coding-agent/CHANGELOG.md +++ b/packages/coding-agent/CHANGELOG.md @@ -2,13 +2,68 @@ ## [Unreleased] +### Breaking Changes + +- **Session tree structure (v2)**: Sessions now store entries as a tree with `id`/`parentId` fields, enabling in-place branching without creating new files. Existing v1 sessions are auto-migrated on load. 
+- **SessionManager API**: + - `saveXXX()` renamed to `appendXXX()` (e.g., `appendMessage`, `appendCompaction`) + - `branchInPlace()` renamed to `branch()` + - `reset()` renamed to `newSession()` + - `createBranchedSessionFromEntries(entries, index)` replaced with `createBranchedSession(leafId)` + - `saveCompaction(entry)` replaced with `appendCompaction(summary, firstKeptEntryId, tokensBefore)` + - `getEntries()` now excludes the session header (use `getHeader()` separately) + - New methods: `getTree()`, `getPath()`, `getLeafUuid()`, `getLeafEntry()`, `getEntry()`, `branchWithSummary()` + - New `appendCustomEntry(customType, data)` for hooks to store custom data (not in LLM context) + - New `appendCustomMessageEntry(customType, content, display, details?)` for hooks to inject messages into LLM context +- **Compaction API**: + - `CompactionEntry` and `CompactionResult` are now generic with optional `details?: T` for hook-specific data + - `compact()` now returns `CompactionResult` (`{ summary, firstKeptEntryId, tokensBefore, details? 
}`) instead of `CompactionEntry` + - `appendCompaction()` now accepts optional `details` parameter + - `CompactionEntry.firstKeptEntryIndex` replaced with `firstKeptEntryId` + - `prepareCompaction()` now returns `firstKeptEntryId` in its result +- **Hook types**: + - `SessionEventBase` no longer has `sessionManager`/`modelRegistry` - access them via `HookEventContext` instead + - `HookEventContext` now has `sessionManager` and `modelRegistry` (moved from events) + - `HookEventContext` no longer has `exec()` - use `pi.exec()` instead + - `HookCommandContext` no longer has `exec()` - use `pi.exec()` instead + - `before_compact` event passes `preparation: CompactionPreparation` and `previousCompactions: CompactionEntry[]` (newest first) + - `before_switch` event now has `targetSessionFile`, `switch` event has `previousSessionFile` + - Removed `resolveApiKey` (use `modelRegistry.getApiKey(model)`) + - Hooks can return `compaction.details` to store custom data (e.g., ArtifactIndex for structured compaction) +- **Hook API**: + - `pi.send(text, attachments?)` replaced with `pi.sendMessage(message, triggerTurn?)` which creates `CustomMessageEntry` instead of user messages + - New `pi.appendEntry(customType, data?)` to persist hook state (does NOT participate in LLM context) + - New `pi.registerCommand(name, options)` to register custom slash commands + - New `pi.registerMessageRenderer(customType, renderer)` to register custom renderers for hook messages + - New `pi.exec(command, args, options?)` to execute shell commands (moved from `HookEventContext`/`HookCommandContext`) + - `HookMessageRenderer` type: `(message: HookMessage, options, theme) => Component | null` + - Renderers return inner content; the TUI wraps it in a styled Box + - New types: `HookMessage`, `RegisteredCommand`, `HookCommandContext` + - Handler types renamed: `SendHandler` → `SendMessageHandler`, new `AppendEntryHandler` +- **SessionManager**: + - `getSessionFile()` now returns `string | undefined` 
(undefined for in-memory sessions) +- **Themes**: Custom themes must add `customMessageBg`, `customMessageText`, `customMessageLabel` color tokens + ### Added - **`enabledModels` setting**: Configure whitelisted models in `settings.json` (same format as `--models` CLI flag). CLI `--models` takes precedence over the setting. +### Changed + +- **Entry IDs**: Session entries now use short 8-character hex IDs instead of full UUIDs +- **API key priority**: `ANTHROPIC_OAUTH_TOKEN` now takes precedence over `ANTHROPIC_API_KEY` +- **New entry types**: `BranchSummaryEntry` for branch context, `CustomEntry` for hook state persistence, `CustomMessageEntry` for hook-injected context messages, `LabelEntry` for user-defined bookmarks +- **Entry labels**: New `getLabel(id)` and `appendLabelChange(targetId, label)` methods for labeling entries. Labels are included in `SessionTreeNode` for UI/export. +- **TUI**: `CustomMessageEntry` renders with purple styling (customMessageBg, customMessageText, customMessageLabel theme colors). Entries with `display: false` are hidden. +- **AgentSession**: New `sendHookMessage(message, triggerTurn?)` method for hooks to inject messages. Handles queuing during streaming, direct append when idle, and optional turn triggering. +- **HookMessage**: New message type with `role: "hookMessage"` for hook-injected messages in agent events. Use `isHookMessage(msg)` type guard to identify them. These are converted to user messages for LLM context via `messageTransformer`. +- **Agent.prompt()**: Now accepts `AppMessage` directly (in addition to `string, attachments?`) for custom message types like `HookMessage`. + ### Fixed - **Edit tool fails on Windows due to CRLF line endings**: Files with CRLF line endings now match correctly when LLMs send LF-only text. Line endings are normalized before matching and restored to original style on write. 
([#355](https://github.com/badlogic/pi-mono/issues/355)) +- **Session file validation**: `findMostRecentSession()` now validates session headers before returning, preventing non-session JSONL files from being loaded +- **Compaction error handling**: `generateSummary()` and `generateTurnPrefixSummary()` now throw on LLM errors instead of returning empty strings ## [0.30.2] - 2025-12-26 diff --git a/packages/coding-agent/docs/session-tree-plan.md b/packages/coding-agent/docs/session-tree-plan.md new file mode 100644 index 00000000..0a3a50ae --- /dev/null +++ b/packages/coding-agent/docs/session-tree-plan.md @@ -0,0 +1,441 @@ +# Session Tree Implementation Plan + +Reference: [session-tree.md](./session-tree.md) + +## Phase 1: SessionManager Core ✅ + +- [x] Update entry types with `id`, `parentId` fields (using SessionEntryBase) +- [x] Add `version` field to `SessionHeader` +- [x] Change `CompactionEntry.firstKeptEntryIndex` → `firstKeptEntryId` +- [x] Add `BranchSummaryEntry` type +- [x] Add `CustomEntry` type for hooks +- [x] Add `byId: Map` index +- [x] Add `leafId: string` tracking +- [x] Implement `getPath(fromId?)` tree traversal +- [x] Implement `getTree()` returning `SessionTreeNode[]` +- [x] Implement `getEntry(id)` lookup +- [x] Implement `getLeafUuid()` and `getLeafEntry()` helpers +- [x] Update `_buildIndex()` to populate `byId` map +- [x] Rename `saveXXX()` to `appendXXX()` (returns id, advances leaf) +- [x] Add `appendCustomEntry(customType, data)` for hooks +- [x] Update `buildSessionContext()` to use `getPath()` traversal + +## Phase 2: Migration ✅ + +- [x] Add `CURRENT_SESSION_VERSION = 2` constant +- [x] Implement `migrateV1ToV2()` with extensible migration chain +- [x] Update `setSessionFile()` to detect version and migrate +- [x] Implement `_rewriteFile()` for post-migration persistence +- [x] Handle `firstKeptEntryIndex` → `firstKeptEntryId` conversion in migration + +## Phase 3: Branching ✅ + +- [x] Implement `branch(id)` - switch leaf pointer +- 
[x] Implement `branchWithSummary(id, summary)` - create summary entry +- [x] Implement `createBranchedSession(leafId)` - extract path to new file +- [x] Update `AgentSession.branch()` to use new API + +## Phase 4: Compaction Integration ✅ + +- [x] Update `compaction.ts` to work with IDs +- [x] Update `prepareCompaction()` to return `firstKeptEntryId` +- [x] Update `compact()` to return `CompactionResult` with `firstKeptEntryId` +- [x] Update `AgentSession` compaction methods +- [x] Add `firstKeptEntryId` to `before_compact` hook event + +## Phase 5: Testing ✅ + +- [x] `migration.test.ts` - v1 to v2 migration, idempotency +- [x] `build-context.test.ts` - context building with tree structure, compaction, branches +- [x] `tree-traversal.test.ts` - append operations, getPath, getTree, branching +- [x] `file-operations.test.ts` - loadEntriesFromFile, findMostRecentSession +- [x] `save-entry.test.ts` - custom entry integration +- [x] Update existing compaction tests for new types + +--- + +## Remaining Work + +### Compaction Refactor + +- [x] Use `CompactionResult` type for hook return value +- [x] Make `CompactionEntry` generic with optional `details?: T` field for hook-specific data +- [x] Make `CompactionResult` generic to match +- [x] Update `SessionEventBase` to pass `sessionManager` and `modelRegistry` instead of derived fields +- [x] Update `before_compact` event: + - Pass `preparation: CompactionPreparation` instead of individual fields + - Pass `previousCompactions: CompactionEntry[]` (newest first) instead of `previousSummary?: string` + - Keep: `customInstructions`, `model`, `signal` + - Drop: `resolveApiKey` (use `modelRegistry.getApiKey()`), `cutPoint`, `entries` +- [x] Update hook example `custom-compaction.ts` to use new API +- [x] Update `getSessionFile()` to return `string | undefined` for in-memory sessions +- [x] Update `before_switch` to have `targetSessionFile`, `switch` to have `previousSessionFile` + +Reference: 
[#314](https://github.com/badlogic/pi-mono/pull/314) - Structured compaction with anchored iterative summarization needs `details` field to store `ArtifactIndex` and version markers. + +### Branch Summary Design ✅ + +Current type: +```typescript +export interface BranchSummaryEntry extends SessionEntryBase { + type: "branch_summary"; + summary: string; + fromId: string; // References the abandoned leaf + fromHook?: boolean; // Whether summary was generated by a hook + details?: unknown; // File tracking: { readFiles, modifiedFiles } +} +``` + +- [x] `fromId` field references the abandoned leaf +- [x] `fromHook` field distinguishes pi-generated vs hook-generated summaries +- [x] `details` field for file tracking +- [x] Branch summarizer implemented with structured output format +- [x] Uses serialization approach (same as compaction) to prevent model confusion +- [x] Tests for `branchWithSummary()` flow + +### Entry Labels ✅ + +- [x] Add `LabelEntry` type with `targetId` and `label` fields +- [x] Add `labelsById: Map` private field +- [x] Build labels map in `_buildIndex()` via linear scan +- [x] Add `getLabel(id)` method +- [x] Add `appendLabelChange(targetId, label)` method (undefined clears) +- [x] Update `createBranchedSession()` to filter out LabelEntry and recreate from resolved map +- [x] `buildSessionContext()` already ignores LabelEntry (only handles message types) +- [x] Add `label?: string` to `SessionTreeNode`, populated by `getTree()` +- [x] Display labels in UI (tree-selector shows labels) +- [x] `/label` command (implemented in tree-selector) + +### CustomMessageEntry + +Hook-injected messages that participate in LLM context. Unlike `CustomEntry` (for hook state only), these are sent to the model. 
+
+```typescript
+export interface CustomMessageEntry<T = unknown> extends SessionEntryBase {
+	type: "custom_message";
+	customType: string; // Hook identifier
+	content: string | (TextContent | ImageContent)[]; // Message content (same as UserMessage)
+	details?: T; // Hook-specific data for state reconstruction on reload
+	display: boolean; // Whether to display in TUI
+}
+```
+
+Behavior:
+- [x] Type definition matching plan
+- [x] `appendCustomMessageEntry(customType, content, display, details?)` in SessionManager
+- [x] `buildSessionContext()` includes custom_message entries as user messages
+- [x] Exported from main index
+- [x] TUI rendering:
+  - `display: false` - hidden entirely
+  - `display: true` - rendered with purple styling (customMessageBg, customMessageText, customMessageLabel theme colors)
+  - [x] `registerCustomMessageRenderer(customType, renderer)` in HookAPI for custom renderers
+  - [x] Renderer returns inner Component, TUI wraps in styled Box
+
+### Hook API Changes ✅
+
+**Renamed:**
+- `renderCustomMessage()` → `registerCustomMessageRenderer()`
+
+**New: `sendMessage()` ✅**
+
+Replaces `send()`. Always creates CustomMessageEntry, never user messages.
+
+```typescript
+type HookMessage = Pick<CustomMessageEntry, 'customType' | 'content' | 'display' | 'details'>;
+
+sendMessage(message: HookMessage, triggerTurn?: boolean): void;
+```
+
+Implementation:
+- Uses agent's queue mechanism with `_hookData` marker on AppMessage
+- `message_end` handler routes based on marker presence
+- `AgentSession.sendHookMessage()` handles three cases:
+  - Streaming: queues via `agent.queueMessage()`, loop processes and emits `message_end`
+  - Not streaming + triggerTurn: direct append + `agent.continue()`
+  - Not streaming + no trigger: direct append only
+- TUI updates via event (streaming) or explicit rebuild (non-streaming)
+
+**New: `appendEntry()` ✅**
+
+For hook state persistence (NOT in LLM context):
+
+```typescript
+appendEntry(customType: string, data?: unknown): void;
+```
+
+Calls `sessionManager.appendCustomEntry()` directly.
+
+**New: `registerCommand()` (types ✅, wiring TODO)**
+
+```typescript
+// HookAPI (the `pi` object) - utilities available to all hooks:
+interface HookAPI {
+	sendMessage(message: HookMessage, triggerTurn?: boolean): void;
+	appendEntry(customType: string, data?: unknown): void;
+	registerCommand(name: string, options: RegisteredCommand): void;
+	registerCustomMessageRenderer(customType: string, renderer: CustomMessageRenderer): void;
+	exec(command: string, args: string[], options?: ExecOptions): Promise<ExecResult>;
+}
+
+// HookEventContext - passed to event handlers, has stable context:
+interface HookEventContext {
+	ui: HookUIContext;
+	hasUI: boolean;
+	cwd: string;
+	sessionManager: SessionManager;
+	modelRegistry: ModelRegistry;
+}
+// Note: exec moved to HookAPI, sessionManager/modelRegistry moved from SessionEventBase
+
+// HookCommandContext - passed to command handlers:
+interface HookCommandContext {
+	args: string; // Everything after /commandname
+	ui: HookUIContext;
+	hasUI: boolean;
+	cwd: string;
+	sessionManager: SessionManager;
+	modelRegistry: ModelRegistry;
+}
+// Note: exec and sendMessage accessed via
`pi` closure
+
+registerCommand(name: string, options: {
+	description?: string;
+	handler: (ctx: HookCommandContext) => Promise<string | void>;
+}): void;
+```
+
+Handler return:
+- `void` - command completed (use `sendMessage()` with `triggerTurn: true` to prompt LLM)
+- `string` - used as the prompt text for a new LLM turn
+
+Wiring (all in AgentSession.prompt()):
+- [x] Add hook commands to autocomplete in interactive-mode
+- [x] `_tryExecuteHookCommand()` in AgentSession handles command execution
+- [x] Build HookCommandContext with ui (from hookRunner), exec, sessionManager, etc.
+- [x] If handler returns string, use as prompt text
+- [x] If handler returns undefined, return early (no LLM call)
+- [x] Works for all modes (interactive, RPC, print) via shared AgentSession
+
+**New: `ui.custom()` ✅**
+
+For arbitrary hook UI with keyboard focus:
+
+```typescript
+interface HookUIContext {
+	// ... existing: select, confirm, input, notify
+
+	/** Show custom component with keyboard focus. Call done() when finished. */
+	custom(component: Component, done: () => void): void;
+}
+```
+
+See also: `CustomEntry` for storing hook state that does NOT participate in context.
+
+**New: `context` event ✅**
+
+Fires before messages are sent to the LLM, allowing hooks to modify context non-destructively.
+ +```typescript +interface ContextEvent { + type: "context"; + /** Messages that will be sent to the LLM */ + messages: Message[]; +} + +interface ContextEventResult { + /** Modified messages to send instead */ + messages?: Message[]; +} + +// In HookAPI: +on(event: "context", handler: HookHandler): void; +``` + +Example use case: **Dynamic Context Pruning** ([discussion #330](https://github.com/badlogic/pi-mono/discussions/330)) + +Non-destructive pruning of tool results to reduce context size: + +```typescript +export default function(pi: HookAPI) { + // Register /prune command + pi.registerCommand("prune", { + description: "Mark tool results for pruning", + handler: async (ctx) => { + // Show UI to select which tool results to prune + // Append custom entry recording pruning decisions: + // { toolResultId, strategy: "summary" | "truncate" | "remove" } + pi.appendEntry("tool-result-pruning", { ... }); + } + }); + + // Intercept context before LLM call + pi.on("context", async (event, ctx) => { + // Find all pruning entries in session + const entries = ctx.sessionManager.getEntries(); + const pruningRules = entries + .filter(e => e.type === "custom" && e.customType === "tool-result-pruning") + .map(e => e.data); + + // Apply pruning rules to messages + const prunedMessages = applyPruning(event.messages, pruningRules); + return { messages: prunedMessages }; + }); +} +``` + +Benefits: +- Original tool results stay intact in session +- Pruning is stored as custom entries, survives session reload +- Works with branching (pruning entries are part of the tree) +- Trade-off: cache busting on first submission after pruning + +### Investigate: `context` event vs `before_agent_start` ✅ + +References: +- [#324](https://github.com/badlogic/pi-mono/issues/324) - `before_agent_start` proposal +- [#330](https://github.com/badlogic/pi-mono/discussions/330) - Dynamic Context Pruning (why `context` was added) + +**Current `context` event:** +- Fires before each LLM call within the 
agent loop +- Receives `AgentMessage[]` (deep copy, safe to modify) +- Returns `Message[]` (inconsistent with input type) +- Modifications are transient (not persisted to session) +- No TUI visibility of what was changed +- Use case: non-destructive pruning, dynamic context manipulation + +**Type inconsistency:** Event receives `AgentMessage[]` but result returns `Message[]`: +```typescript +interface ContextEvent { + messages: AgentMessage[]; // Input +} +interface ContextEventResult { + messages?: Message[]; // Output - different type! +} +``` + +Questions: +- [ ] Should input/output both be `Message[]` (LLM format)? +- [ ] Or both be `AgentMessage[]` with conversion happening after? +- [ ] Where does `AgentMessage[]` → `Message[]` conversion currently happen? + +**Proposed `before_agent_start` event:** +- Fires once when user submits a prompt, before `agent_start` +- Allows hooks to inject additional content that gets **persisted** to session +- Injected content is visible in TUI (observability) +- Does not bust prompt cache (appended after user message, not modifying system prompt) + +**Key difference:** +| Aspect | `context` | `before_agent_start` | +|--------|-----------|---------------------| +| When | Before each LLM call | Once per user prompt | +| Persisted | No | Yes (as SystemMessage) | +| TUI visible | No | Yes (collapsible) | +| Cache impact | Can bust cache | Append-only, cache-safe | +| Use case | Transient manipulation | Persistent context injection | + +**Implementation (completed):** +- Reuses `HookMessage` type (no new message type needed) +- Handler returns `{ message: Pick }` +- Message is appended to agent state AND persisted to session before `agent.prompt()` is called +- Renders using existing `HookMessageComponent` (or custom renderer if registered) +- [ ] How does it interact with compaction? (treated like user messages?) +- [ ] Can hook return multiple messages or just one? 
+ +**Implementation sketch:** +```typescript +interface BeforeAgentStartEvent { + type: "before_agent_start"; + userMessage: UserMessage; // The prompt user just submitted +} + +interface BeforeAgentStartResult { + /** Additional context to inject (persisted as SystemMessage) */ + inject?: { + label: string; // Shown in collapsed TUI state + content: string | (TextContent | ImageContent)[]; + }; +} +``` + +### HTML Export + +- [ ] Add collapsible sidebar showing full tree structure +- [ ] Allow selecting any node in tree to view that path +- [ ] Add "reset to session leaf" button +- [ ] Render full path (no compaction resolution needed) +- [ ] Responsive: collapse sidebar on mobile + +### UI Commands ✅ + +- [x] `/branch` - Creates new session file from current path (uses `createBranchedSession()`) +- [x] `/tree` - In-session tree navigation via tree-selector component + - Shows full tree structure with labels + - Navigate between branches (moves leaf pointer) + - Shows current position + - Generates branch summaries when switching branches + +### Tree Selector Improvements ✅ + +- [x] Active line highlight using `selectedBg` theme color +- [x] Filter modes via `^O` (forward) / `Shift+^O` (backward): + - `default`: hides label/custom entries + - `no-tools`: default minus tool results + - `user-only`: just user messages + - `labeled-only`: just labeled entries + - `all`: everything + +### Documentation + +Review and update all docs: + +- [ ] `docs/hooks.md` - Major update for hook API: + - `pi.send()` → `pi.sendMessage()` with new signature + - New `pi.appendEntry()` for state persistence + - New `pi.registerCommand()` for custom slash commands + - New `pi.registerCustomMessageRenderer()` for custom TUI rendering + - `HookCommandContext` interface and handler patterns + - `HookMessage` type + - Updated event signatures (`SessionEventBase`, `before_compact`, etc.) 
+- [ ] `docs/hooks-v2.md` - Review/merge or remove if obsolete +- [ ] `docs/sdk.md` - Update for: + - `HookMessage` and `isHookMessage()` + - `Agent.prompt(AppMessage)` overload + - Session v2 tree structure + - SessionManager API changes +- [ ] `docs/session.md` - Update for v2 tree structure, new entry types +- [ ] `docs/custom-tools.md` - Check if hook changes affect custom tools +- [ ] `docs/rpc.md` - Check if hook commands work in RPC mode +- [ ] `docs/skills.md` - Review for any hook-related updates +- [ ] `docs/extension-loading.md` - Review +- [x] `docs/theme.md` - Added selectedBg, customMessageBg/Text/Label color tokens (50 total) +- [ ] `README.md` - Update hook examples if any + +### Examples + +Review and update examples: + +- [ ] `examples/hooks/` - Update existing, add new examples: + - [ ] Review `custom-compaction.ts` for new API + - [ ] Add `registerCommand()` example + - [ ] Add `sendMessage()` example + - [ ] Add `registerCustomMessageRenderer()` example +- [ ] `examples/sdk/` - Update for new session/hook APIs +- [ ] `examples/custom-tools/` - Review for compatibility + +--- + +## Before Release + +- [ ] Run full automated test suite: `npm test` +- [ ] Manual testing of tree navigation and branch summarization +- [ ] Verify compaction with file tracking works correctly + +--- + +## Notes + +- All append methods return the new entry's ID +- Migration rewrites file on first load if version < CURRENT_VERSION +- Existing sessions become linear chains after migration (parentId = previous entry) +- Tree features available immediately after migration +- SessionHeader does NOT have id/parentId (it's metadata, not part of tree) +- Session is append-only: entries cannot be modified or deleted, only branching changes the leaf pointer diff --git a/packages/coding-agent/docs/theme.md b/packages/coding-agent/docs/theme.md index 06f674cb..ac3fa237 100644 --- a/packages/coding-agent/docs/theme.md +++ b/packages/coding-agent/docs/theme.md @@ -21,12 +21,16 @@ 
Every theme must define all color tokens. There are no optional colors. | `dim` | Very dimmed text | Less important info, placeholders | | `text` | Default text color | Main content (usually `""`) | -### Backgrounds & Content Text (7 colors) +### Backgrounds & Content Text (11 colors) | Token | Purpose | |-------|---------| +| `selectedBg` | Selected/active line background (e.g., tree selector) | | `userMessageBg` | User message background | | `userMessageText` | User message text color | +| `customMessageBg` | Hook custom message background | +| `customMessageText` | Hook custom message text color | +| `customMessageLabel` | Hook custom message label/type text | | `toolPendingBg` | Tool execution box (pending state) | | `toolSuccessBg` | Tool execution box (success state) | | `toolErrorBg` | Tool execution box (error state) | @@ -95,7 +99,7 @@ These create a visual hierarchy: off → minimal → low → medium → high → |-------|---------| | `bashMode` | Editor border color when in bash mode (! prefix) | -**Total: 46 color tokens** (all required) +**Total: 50 color tokens** (all required) ## Theme Format diff --git a/packages/coding-agent/docs/tree.md b/packages/coding-agent/docs/tree.md new file mode 100644 index 00000000..768a1cc0 --- /dev/null +++ b/packages/coding-agent/docs/tree.md @@ -0,0 +1,197 @@ +# Session Tree Navigation + +The `/tree` command provides tree-based navigation of the session history. + +## Overview + +Sessions are stored as trees where each entry has an `id` and `parentId`. The "leaf" pointer tracks the current position. `/tree` lets you navigate to any point and optionally summarize the branch you're leaving. 
+ +### Comparison with `/branch` + +| Feature | `/branch` | `/tree` | +|---------|-----------|---------| +| View | Flat list of user messages | Full tree structure | +| Action | Extracts path to **new session file** | Changes leaf in **same session** | +| Summary | Never | Optional (user prompted) | +| Events | `session_before_branch` / `session_branch` | `session_before_tree` / `session_tree` | + +## Tree UI + +``` +├─ user: "Hello, can you help..." +│ └─ assistant: "Of course! I can..." +│ ├─ user: "Let's try approach A..." +│ │ └─ assistant: "For approach A..." +│ │ └─ [compaction: 12k tokens] +│ │ └─ user: "That worked..." ← active +│ └─ user: "Actually, approach B..." +│ └─ assistant: "For approach B..." +``` + +### Controls + +| Key | Action | +|-----|--------| +| ↑/↓ | Navigate (depth-first order) | +| Enter | Select node | +| Escape/Ctrl+C | Cancel | +| Ctrl+U | Toggle: user messages only | +| Ctrl+O | Toggle: show all (including custom/label entries) | + +### Display + +- Height: half terminal height +- Current leaf marked with `← active` +- Labels shown inline: `[label-name]` +- Default filter hides `label` and `custom` entries (shown in Ctrl+O mode) +- Children sorted by timestamp (oldest first) + +## Selection Behavior + +### User Message or Custom Message +1. Leaf set to **parent** of selected node (or `null` if root) +2. Message text placed in **editor** for re-submission +3. User edits and submits, creating a new branch + +### Non-User Message (assistant, compaction, etc.) +1. Leaf set to **selected node** +2. Editor stays empty +3. User continues from that point + +### Selecting Root User Message +If user selects the very first message (has no parent): +1. Leaf reset to `null` (empty conversation) +2. Message text placed in editor +3. User effectively restarts from scratch + +## Branch Summarization + +When switching, user is prompted: "Summarize the branch you're leaving?" 
+ +### What Gets Summarized + +Path from old leaf back to common ancestor with target: + +``` +A → B → C → D → E → F ← old leaf + ↘ G → H ← target +``` + +Abandoned path: D → E → F (summarized) + +Summarization stops at: +1. Common ancestor (always) +2. Compaction node (if encountered first) + +### Summary Storage + +Stored as `BranchSummaryEntry`: + +```typescript +interface BranchSummaryEntry { + type: "branch_summary"; + id: string; + parentId: string; // New leaf position + timestamp: string; + fromId: string; // Old leaf we abandoned + summary: string; // LLM-generated summary + details?: unknown; // Optional hook data +} +``` + +## Implementation + +### AgentSession.navigateTree() + +```typescript +async navigateTree( + targetId: string, + options?: { summarize?: boolean; customInstructions?: string } +): Promise<{ editorText?: string; cancelled: boolean }> +``` + +Flow: +1. Validate target, check no-op (target === current leaf) +2. Find common ancestor between old leaf and target +3. Collect entries to summarize (if requested) +4. Fire `session_before_tree` event (hook can cancel or provide summary) +5. Run default summarizer if needed +6. Switch leaf via `branch()` or `branchWithSummary()` +7. Update agent: `agent.replaceMessages(sessionManager.buildSessionContext().messages)` +8. Fire `session_tree` event +9. Notify custom tools via session event +10. Return result with `editorText` if user message was selected + +### SessionManager + +- `getLeafUuid(): string | null` - Current leaf (null if empty) +- `resetLeaf(): void` - Set leaf to null (for root user message navigation) +- `getTree(): SessionTreeNode[]` - Full tree with children sorted by timestamp +- `branch(id)` - Change leaf pointer +- `branchWithSummary(id, summary)` - Change leaf and create summary entry + +### InteractiveMode + +`/tree` command shows `TreeSelectorComponent`, then: +1. Prompt for summarization +2. Call `session.navigateTree()` +3. Clear and re-render chat +4. 
Set editor text if applicable + +## Hook Events + +### `session_before_tree` + +```typescript +interface TreePreparation { + targetId: string; + oldLeafId: string | null; + commonAncestorId: string | null; + entriesToSummarize: SessionEntry[]; + userWantsSummary: boolean; +} + +interface SessionBeforeTreeEvent { + type: "session_before_tree"; + preparation: TreePreparation; + model: Model; + signal: AbortSignal; +} + +interface SessionBeforeTreeResult { + cancel?: boolean; + summary?: { summary: string; details?: unknown }; +} +``` + +### `session_tree` + +```typescript +interface SessionTreeEvent { + type: "session_tree"; + newLeafId: string | null; + oldLeafId: string | null; + summaryEntry?: BranchSummaryEntry; + fromHook?: boolean; +} +``` + +### Example: Custom Summarizer + +```typescript +export default function(pi: HookAPI) { + pi.on("session_before_tree", async (event, ctx) => { + if (!event.preparation.userWantsSummary) return; + if (event.preparation.entriesToSummarize.length === 0) return; + + const summary = await myCustomSummarizer(event.preparation.entriesToSummarize); + return { summary: { summary, details: { custom: true } } }; + }); +} +``` + +## Error Handling + +- Summarization failure: cancels navigation, shows error +- User abort (Escape): cancels navigation +- Hook returns `cancel: true`: cancels navigation silently diff --git a/packages/coding-agent/examples/custom-tools/question/index.ts b/packages/coding-agent/examples/custom-tools/question/index.ts index 23de95b7..76c068ca 100644 --- a/packages/coding-agent/examples/custom-tools/question/index.ts +++ b/packages/coding-agent/examples/custom-tools/question/index.ts @@ -41,7 +41,7 @@ const factory: CustomToolFactory = (pi) => { const answer = await pi.ui.select(params.question, params.options); - if (answer === null) { + if (answer === undefined) { return { content: [{ type: "text", text: "User cancelled the selection" }], details: { question: params.question, options: params.options, answer: 
null }, diff --git a/packages/coding-agent/examples/custom-tools/subagent/index.ts b/packages/coding-agent/examples/custom-tools/subagent/index.ts index 406ee4ae..67a2d526 100644 --- a/packages/coding-agent/examples/custom-tools/subagent/index.ts +++ b/packages/coding-agent/examples/custom-tools/subagent/index.ts @@ -16,7 +16,8 @@ import { spawn } from "node:child_process"; import * as fs from "node:fs"; import * as os from "node:os"; import * as path from "node:path"; -import type { AgentToolResult, Message } from "@mariozechner/pi-ai"; +import type { AgentToolResult } from "@mariozechner/pi-agent-core"; +import type { Message } from "@mariozechner/pi-ai"; import { StringEnum } from "@mariozechner/pi-ai"; import { type CustomAgentTool, diff --git a/packages/coding-agent/examples/hooks/auto-commit-on-exit.ts b/packages/coding-agent/examples/hooks/auto-commit-on-exit.ts index 8d167882..bfcc37f7 100644 --- a/packages/coding-agent/examples/hooks/auto-commit-on-exit.ts +++ b/packages/coding-agent/examples/hooks/auto-commit-on-exit.ts @@ -8,11 +8,9 @@ import type { HookAPI } from "@mariozechner/pi-coding-agent/hooks"; export default function (pi: HookAPI) { - pi.on("session", async (event, ctx) => { - if (event.reason !== "shutdown") return; - + pi.on("session_shutdown", async (_event, ctx) => { // Check for uncommitted changes - const { stdout: status, code } = await ctx.exec("git", ["status", "--porcelain"]); + const { stdout: status, code } = await pi.exec("git", ["status", "--porcelain"]); if (code !== 0 || status.trim().length === 0) { // Not a git repo or no changes @@ -20,9 +18,10 @@ export default function (pi: HookAPI) { } // Find the last assistant message for commit context + const entries = ctx.sessionManager.getEntries(); let lastAssistantText = ""; - for (let i = event.entries.length - 1; i >= 0; i--) { - const entry = event.entries[i]; + for (let i = entries.length - 1; i >= 0; i--) { + const entry = entries[i]; if (entry.type === "message" && 
entry.message.role === "assistant") { const content = entry.message.content; if (Array.isArray(content)) { @@ -40,8 +39,8 @@ export default function (pi: HookAPI) { const commitMessage = `[pi] ${firstLine.slice(0, 50)}${firstLine.length > 50 ? "..." : ""}`; // Stage and commit - await ctx.exec("git", ["add", "-A"]); - const { code: commitCode } = await ctx.exec("git", ["commit", "-m", commitMessage]); + await pi.exec("git", ["add", "-A"]); + const { code: commitCode } = await pi.exec("git", ["commit", "-m", commitMessage]); if (commitCode === 0 && ctx.hasUI) { ctx.ui.notify(`Auto-committed: ${commitMessage}`, "info"); diff --git a/packages/coding-agent/examples/hooks/confirm-destructive.ts b/packages/coding-agent/examples/hooks/confirm-destructive.ts index 61293868..75c5ee0c 100644 --- a/packages/coding-agent/examples/hooks/confirm-destructive.ts +++ b/packages/coding-agent/examples/hooks/confirm-destructive.ts @@ -2,59 +2,57 @@ * Confirm Destructive Actions Hook * * Prompts for confirmation before destructive session actions (clear, switch, branch). - * Demonstrates how to cancel session events using the before_* variants. + * Demonstrates how to cancel session events using the before_* events. 
*/ +import type { SessionMessageEntry } from "@mariozechner/pi-coding-agent"; import type { HookAPI } from "@mariozechner/pi-coding-agent/hooks"; export default function (pi: HookAPI) { - pi.on("session", async (event, ctx) => { - // Only handle before_* events (the ones that can be cancelled) - if (event.reason === "before_new") { - if (!ctx.hasUI) return; + pi.on("session_before_new", async (_event, ctx) => { + if (!ctx.hasUI) return; + const confirmed = await ctx.ui.confirm("Clear session?", "This will delete all messages in the current session."); + + if (!confirmed) { + ctx.ui.notify("Clear cancelled", "info"); + return { cancel: true }; + } + }); + + pi.on("session_before_switch", async (_event, ctx) => { + if (!ctx.hasUI) return; + + // Check if there are unsaved changes (messages since last assistant response) + const entries = ctx.sessionManager.getEntries(); + const hasUnsavedWork = entries.some( + (e): e is SessionMessageEntry => e.type === "message" && e.message.role === "user", + ); + + if (hasUnsavedWork) { const confirmed = await ctx.ui.confirm( - "Clear session?", - "This will delete all messages in the current session.", + "Switch session?", + "You have messages in the current session. Switch anyway?", ); if (!confirmed) { - ctx.ui.notify("Clear cancelled", "info"); - return { cancel: true }; - } - } - - if (event.reason === "before_switch") { - if (!ctx.hasUI) return; - - // Check if there are unsaved changes (messages since last assistant response) - const hasUnsavedWork = event.entries.some((e) => e.type === "message" && e.message.role === "user"); - - if (hasUnsavedWork) { - const confirmed = await ctx.ui.confirm( - "Switch session?", - "You have messages in the current session. 
Switch anyway?", - ); - - if (!confirmed) { - ctx.ui.notify("Switch cancelled", "info"); - return { cancel: true }; - } - } - } - - if (event.reason === "before_branch") { - if (!ctx.hasUI) return; - - const choice = await ctx.ui.select(`Branch from turn ${event.targetTurnIndex}?`, [ - "Yes, create branch", - "No, stay in current session", - ]); - - if (choice !== "Yes, create branch") { - ctx.ui.notify("Branch cancelled", "info"); + ctx.ui.notify("Switch cancelled", "info"); return { cancel: true }; } } }); + + pi.on("session_before_branch", async (event, ctx) => { + if (!ctx.hasUI) return; + + const choice = await ctx.ui.select(`Branch from turn ${event.entryIndex}?`, [ + "Yes, create branch", + "No, stay in current session", + ]); + + if (choice !== "Yes, create branch") { + ctx.ui.notify("Branch cancelled", "info"); + return { cancel: true }; + } + }); } diff --git a/packages/coding-agent/examples/hooks/custom-compaction.ts b/packages/coding-agent/examples/hooks/custom-compaction.ts index 3967b1ce..f912a3d6 100644 --- a/packages/coding-agent/examples/hooks/custom-compaction.ts +++ b/packages/coding-agent/examples/hooks/custom-compaction.ts @@ -14,17 +14,18 @@ */ import { complete, getModel } from "@mariozechner/pi-ai"; -import { messageTransformer } from "@mariozechner/pi-coding-agent"; +import { convertToLlm } from "@mariozechner/pi-coding-agent"; import type { HookAPI } from "@mariozechner/pi-coding-agent/hooks"; export default function (pi: HookAPI) { - pi.on("session", async (event, ctx) => { - if (event.reason !== "before_compact") return; - + pi.on("session_before_compact", async (event, ctx) => { ctx.ui.notify("Custom compaction hook triggered", "info"); - const { messagesToSummarize, messagesToKeep, previousSummary, tokensBefore, resolveApiKey, entries, signal } = - event; + const { preparation, previousCompactions, signal } = event; + const { messagesToSummarize, messagesToKeep, tokensBefore, firstKeptEntryId } = preparation; + + // Get previous 
summary from most recent compaction (if any) + const previousSummary = previousCompactions[0]?.summary; // Use Gemini Flash for summarization (cheaper/faster than most conversation models) const model = getModel("google", "gemini-2.5-flash"); @@ -34,7 +35,7 @@ export default function (pi: HookAPI) { } // Resolve API key for the summarization model - const apiKey = await resolveApiKey(model); + const apiKey = await ctx.modelRegistry.getApiKey(model); if (!apiKey) { ctx.ui.notify(`No API key for ${model.provider}, using default compaction`, "warning"); return; @@ -49,7 +50,7 @@ export default function (pi: HookAPI) { ); // Transform app messages to pi-ai package format - const transformedMessages = messageTransformer(allMessages); + const transformedMessages = convertToLlm(allMessages); // Include previous summary context if available const previousContext = previousSummary ? `\n\nPrevious session summary for context:\n${previousSummary}` : ""; @@ -94,14 +95,12 @@ Format the summary as structured markdown with clear sections.`, return; } - // Return a compaction entry that discards ALL messages - // firstKeptEntryIndex points past all current entries + // Return compaction content - SessionManager adds id/parentId + // Use firstKeptEntryId from preparation to keep recent messages return { - compactionEntry: { - type: "compaction" as const, - timestamp: new Date().toISOString(), + compaction: { summary, - firstKeptEntryIndex: entries.length, + firstKeptEntryId, tokensBefore, }, }; diff --git a/packages/coding-agent/examples/hooks/dirty-repo-guard.ts b/packages/coding-agent/examples/hooks/dirty-repo-guard.ts index 8e6e5d66..134667e3 100644 --- a/packages/coding-agent/examples/hooks/dirty-repo-guard.ts +++ b/packages/coding-agent/examples/hooks/dirty-repo-guard.ts @@ -5,47 +5,55 @@ * Useful to ensure work is committed before switching context. 
*/ -import type { HookAPI } from "@mariozechner/pi-coding-agent/hooks"; +import type { HookAPI, HookEventContext } from "@mariozechner/pi-coding-agent/hooks"; + +async function checkDirtyRepo( + pi: HookAPI, + ctx: HookEventContext, + action: string, +): Promise<{ cancel: boolean } | undefined> { + // Check for uncommitted changes + const { stdout, code } = await pi.exec("git", ["status", "--porcelain"]); + + if (code !== 0) { + // Not a git repo, allow the action + return; + } + + const hasChanges = stdout.trim().length > 0; + if (!hasChanges) { + return; + } + + if (!ctx.hasUI) { + // In non-interactive mode, block by default + return { cancel: true }; + } + + // Count changed files + const changedFiles = stdout.trim().split("\n").filter(Boolean).length; + + const choice = await ctx.ui.select(`You have ${changedFiles} uncommitted file(s). ${action} anyway?`, [ + "Yes, proceed anyway", + "No, let me commit first", + ]); + + if (choice !== "Yes, proceed anyway") { + ctx.ui.notify("Commit your changes first", "warning"); + return { cancel: true }; + } +} export default function (pi: HookAPI) { - pi.on("session", async (event, ctx) => { - // Only guard destructive actions - if (event.reason !== "before_new" && event.reason !== "before_switch" && event.reason !== "before_branch") { - return; - } + pi.on("session_before_new", async (_event, ctx) => { + return checkDirtyRepo(pi, ctx, "new session"); + }); - // Check for uncommitted changes - const { stdout, code } = await ctx.exec("git", ["status", "--porcelain"]); + pi.on("session_before_switch", async (_event, ctx) => { + return checkDirtyRepo(pi, ctx, "switch session"); + }); - if (code !== 0) { - // Not a git repo, allow the action - return; - } - - const hasChanges = stdout.trim().length > 0; - if (!hasChanges) { - return; - } - - if (!ctx.hasUI) { - // In non-interactive mode, block by default - return { cancel: true }; - } - - // Count changed files - const changedFiles = 
stdout.trim().split("\n").filter(Boolean).length; - - const action = - event.reason === "before_new" ? "new session" : event.reason === "before_switch" ? "switch session" : "branch"; - - const choice = await ctx.ui.select(`You have ${changedFiles} uncommitted file(s). ${action} anyway?`, [ - "Yes, proceed anyway", - "No, let me commit first", - ]); - - if (choice !== "Yes, proceed anyway") { - ctx.ui.notify("Commit your changes first", "warning"); - return { cancel: true }; - } + pi.on("session_before_branch", async (_event, ctx) => { + return checkDirtyRepo(pi, ctx, "branch"); }); } diff --git a/packages/coding-agent/examples/hooks/file-trigger.ts b/packages/coding-agent/examples/hooks/file-trigger.ts index df98d2a2..4363bdce 100644 --- a/packages/coding-agent/examples/hooks/file-trigger.ts +++ b/packages/coding-agent/examples/hooks/file-trigger.ts @@ -12,16 +12,21 @@ import * as fs from "node:fs"; import type { HookAPI } from "@mariozechner/pi-coding-agent/hooks"; export default function (pi: HookAPI) { - pi.on("session", async (event, ctx) => { - if (event.reason !== "start") return; - + pi.on("session_start", async (_event, ctx) => { const triggerFile = "/tmp/agent-trigger.txt"; fs.watch(triggerFile, () => { try { const content = fs.readFileSync(triggerFile, "utf-8").trim(); if (content) { - pi.send(`External trigger: ${content}`); + pi.sendMessage( + { + customType: "file-trigger", + content: `External trigger: ${content}`, + display: true, + }, + true, // triggerTurn - get LLM to respond + ); fs.writeFileSync(triggerFile, ""); // Clear after reading } } catch { diff --git a/packages/coding-agent/examples/hooks/git-checkpoint.ts b/packages/coding-agent/examples/hooks/git-checkpoint.ts index f7a73a7f..87c8f0b5 100644 --- a/packages/coding-agent/examples/hooks/git-checkpoint.ts +++ b/packages/coding-agent/examples/hooks/git-checkpoint.ts @@ -10,20 +10,17 @@ import type { HookAPI } from "@mariozechner/pi-coding-agent/hooks"; export default function (pi: HookAPI) 
{ const checkpoints = new Map(); - pi.on("turn_start", async (event, ctx) => { + pi.on("turn_start", async (event) => { // Create a git stash entry before LLM makes changes - const { stdout } = await ctx.exec("git", ["stash", "create"]); + const { stdout } = await pi.exec("git", ["stash", "create"]); const ref = stdout.trim(); if (ref) { checkpoints.set(event.turnIndex, ref); } }); - pi.on("session", async (event, ctx) => { - // Only handle before_branch events - if (event.reason !== "before_branch") return; - - const ref = checkpoints.get(event.targetTurnIndex); + pi.on("session_before_branch", async (event, ctx) => { + const ref = checkpoints.get(event.entryIndex); if (!ref) return; if (!ctx.hasUI) { @@ -37,7 +34,7 @@ export default function (pi: HookAPI) { ]); if (choice?.startsWith("Yes")) { - await ctx.exec("git", ["stash", "apply", ref]); + await pi.exec("git", ["stash", "apply", ref]); ctx.ui.notify("Code restored to checkpoint", "info"); } }); diff --git a/packages/coding-agent/examples/hooks/snake.ts b/packages/coding-agent/examples/hooks/snake.ts new file mode 100644 index 00000000..3b5ffe8d --- /dev/null +++ b/packages/coding-agent/examples/hooks/snake.ts @@ -0,0 +1,345 @@ +/** + * Snake game hook - play snake with /snake command + */ + +import { isArrowDown, isArrowLeft, isArrowRight, isArrowUp, isEscape, visibleWidth } from "@mariozechner/pi-tui"; +import type { HookAPI } from "../../src/core/hooks/types.js"; + +const GAME_WIDTH = 40; +const GAME_HEIGHT = 15; +const TICK_MS = 100; + +type Direction = "up" | "down" | "left" | "right"; +type Point = { x: number; y: number }; + +interface GameState { + snake: Point[]; + food: Point; + direction: Direction; + nextDirection: Direction; + score: number; + gameOver: boolean; + highScore: number; +} + +function createInitialState(): GameState { + const startX = Math.floor(GAME_WIDTH / 2); + const startY = Math.floor(GAME_HEIGHT / 2); + return { + snake: [ + { x: startX, y: startY }, + { x: startX - 1, y: 
startY }, + { x: startX - 2, y: startY }, + ], + food: spawnFood([{ x: startX, y: startY }]), + direction: "right", + nextDirection: "right", + score: 0, + gameOver: false, + highScore: 0, + }; +} + +function spawnFood(snake: Point[]): Point { + let food: Point; + do { + food = { + x: Math.floor(Math.random() * GAME_WIDTH), + y: Math.floor(Math.random() * GAME_HEIGHT), + }; + } while (snake.some((s) => s.x === food.x && s.y === food.y)); + return food; +} + +class SnakeComponent { + private state: GameState; + private interval: ReturnType | null = null; + private onClose: () => void; + private onSave: (state: GameState | null) => void; + private requestRender: () => void; + private cachedLines: string[] = []; + private cachedWidth = 0; + private version = 0; + private cachedVersion = -1; + private paused: boolean; + + constructor( + onClose: () => void, + onSave: (state: GameState | null) => void, + requestRender: () => void, + savedState?: GameState, + ) { + if (savedState && !savedState.gameOver) { + // Resume from saved state, start paused + this.state = savedState; + this.paused = true; + } else { + // New game or saved game was over + this.state = createInitialState(); + if (savedState) { + this.state.highScore = savedState.highScore; + } + this.paused = false; + this.startGame(); + } + this.onClose = onClose; + this.onSave = onSave; + this.requestRender = requestRender; + } + + private startGame(): void { + this.interval = setInterval(() => { + if (!this.state.gameOver) { + this.tick(); + this.version++; + this.requestRender(); + } + }, TICK_MS); + } + + private tick(): void { + // Apply queued direction change + this.state.direction = this.state.nextDirection; + + // Calculate new head position + const head = this.state.snake[0]; + let newHead: Point; + + switch (this.state.direction) { + case "up": + newHead = { x: head.x, y: head.y - 1 }; + break; + case "down": + newHead = { x: head.x, y: head.y + 1 }; + break; + case "left": + newHead = { x: head.x - 1, 
y: head.y }; + break; + case "right": + newHead = { x: head.x + 1, y: head.y }; + break; + } + + // Check wall collision + if (newHead.x < 0 || newHead.x >= GAME_WIDTH || newHead.y < 0 || newHead.y >= GAME_HEIGHT) { + this.state.gameOver = true; + return; + } + + // Check self collision + if (this.state.snake.some((s) => s.x === newHead.x && s.y === newHead.y)) { + this.state.gameOver = true; + return; + } + + // Move snake + this.state.snake.unshift(newHead); + + // Check food collision + if (newHead.x === this.state.food.x && newHead.y === this.state.food.y) { + this.state.score += 10; + if (this.state.score > this.state.highScore) { + this.state.highScore = this.state.score; + } + this.state.food = spawnFood(this.state.snake); + } else { + this.state.snake.pop(); + } + } + + handleInput(data: string): void { + // If paused (resuming), wait for any key + if (this.paused) { + if (isEscape(data) || data === "q" || data === "Q") { + // Quit without clearing save + this.dispose(); + this.onClose(); + return; + } + // Any other key resumes + this.paused = false; + this.startGame(); + return; + } + + // ESC to pause and save + if (isEscape(data)) { + this.dispose(); + this.onSave(this.state); + this.onClose(); + return; + } + + // Q to quit without saving (clears saved state) + if (data === "q" || data === "Q") { + this.dispose(); + this.onSave(null); // Clear saved state + this.onClose(); + return; + } + + // Arrow keys or WASD + if (isArrowUp(data) || data === "w" || data === "W") { + if (this.state.direction !== "down") this.state.nextDirection = "up"; + } else if (isArrowDown(data) || data === "s" || data === "S") { + if (this.state.direction !== "up") this.state.nextDirection = "down"; + } else if (isArrowRight(data) || data === "d" || data === "D") { + if (this.state.direction !== "left") this.state.nextDirection = "right"; + } else if (isArrowLeft(data) || data === "a" || data === "A") { + if (this.state.direction !== "right") this.state.nextDirection = "left"; 
+ } + + // Restart on game over + if (this.state.gameOver && (data === "r" || data === "R" || data === " ")) { + const highScore = this.state.highScore; + this.state = createInitialState(); + this.state.highScore = highScore; + this.onSave(null); // Clear saved state on restart + this.version++; + this.requestRender(); + } + } + + invalidate(): void { + this.cachedWidth = 0; + } + + render(width: number): string[] { + if (width === this.cachedWidth && this.cachedVersion === this.version) { + return this.cachedLines; + } + + const lines: string[] = []; + + // Each game cell is 2 chars wide to appear square (terminal cells are ~2:1 aspect) + const cellWidth = 2; + const effectiveWidth = Math.min(GAME_WIDTH, Math.floor((width - 4) / cellWidth)); + const effectiveHeight = GAME_HEIGHT; + + // Colors + const dim = (s: string) => `\x1b[2m${s}\x1b[22m`; + const green = (s: string) => `\x1b[32m${s}\x1b[0m`; + const red = (s: string) => `\x1b[31m${s}\x1b[0m`; + const yellow = (s: string) => `\x1b[33m${s}\x1b[0m`; + const bold = (s: string) => `\x1b[1m${s}\x1b[22m`; + + const boxWidth = effectiveWidth * cellWidth; + + // Helper to pad content inside box + const boxLine = (content: string) => { + const contentLen = visibleWidth(content); + const padding = Math.max(0, boxWidth - contentLen); + return dim(" │") + content + " ".repeat(padding) + dim("│"); + }; + + // Top border + lines.push(this.padLine(dim(` ╭${"─".repeat(boxWidth)}╮`), width)); + + // Header with score + const scoreText = `Score: ${bold(yellow(String(this.state.score)))}`; + const highText = `High: ${bold(yellow(String(this.state.highScore)))}`; + const title = `${bold(green("SNAKE"))} │ ${scoreText} │ ${highText}`; + lines.push(this.padLine(boxLine(title), width)); + + // Separator + lines.push(this.padLine(dim(` ├${"─".repeat(boxWidth)}┤`), width)); + + // Game grid + for (let y = 0; y < effectiveHeight; y++) { + let row = ""; + for (let x = 0; x < effectiveWidth; x++) { + const isHead = this.state.snake[0].x 
=== x && this.state.snake[0].y === y; + const isBody = this.state.snake.slice(1).some((s) => s.x === x && s.y === y); + const isFood = this.state.food.x === x && this.state.food.y === y; + + if (isHead) { + row += green("██"); // Snake head (2 chars) + } else if (isBody) { + row += green("▓▓"); // Snake body (2 chars) + } else if (isFood) { + row += red("◆ "); // Food (2 chars) + } else { + row += " "; // Empty cell (2 spaces) + } + } + lines.push(this.padLine(dim(" │") + row + dim("│"), width)); + } + + // Separator + lines.push(this.padLine(dim(` ├${"─".repeat(boxWidth)}┤`), width)); + + // Footer + let footer: string; + if (this.paused) { + footer = `${yellow(bold("PAUSED"))} Press any key to continue, ${bold("Q")} to quit`; + } else if (this.state.gameOver) { + footer = `${red(bold("GAME OVER!"))} Press ${bold("R")} to restart, ${bold("Q")} to quit`; + } else { + footer = `↑↓←→ or WASD to move, ${bold("ESC")} pause, ${bold("Q")} quit`; + } + lines.push(this.padLine(boxLine(footer), width)); + + // Bottom border + lines.push(this.padLine(dim(` ╰${"─".repeat(boxWidth)}╯`), width)); + + this.cachedLines = lines; + this.cachedWidth = width; + this.cachedVersion = this.version; + + return lines; + } + + private padLine(line: string, width: number): string { + // Calculate visible length (strip ANSI codes) + const visibleLen = line.replace(/\x1b\[[0-9;]*m/g, "").length; + const padding = Math.max(0, width - visibleLen); + return line + " ".repeat(padding); + } + + dispose(): void { + if (this.interval) { + clearInterval(this.interval); + this.interval = null; + } + } +} + +const SNAKE_SAVE_TYPE = "snake-save"; + +export default function (pi: HookAPI) { + pi.registerCommand("snake", { + description: "Play Snake!", + + handler: async (ctx) => { + if (!ctx.hasUI) { + ctx.ui.notify("Snake requires interactive mode", "error"); + return; + } + + // Load saved state from session + const entries = ctx.sessionManager.getEntries(); + let savedState: GameState | undefined; + 
for (let i = entries.length - 1; i >= 0; i--) { + const entry = entries[i]; + if (entry.type === "custom" && entry.customType === SNAKE_SAVE_TYPE) { + savedState = entry.data as GameState; + break; + } + } + + let ui: { close: () => void; requestRender: () => void } | null = null; + + const component = new SnakeComponent( + () => ui?.close(), + (state) => { + // Save or clear state + pi.appendEntry(SNAKE_SAVE_TYPE, state); + }, + () => ui?.requestRender(), + savedState, + ); + + ui = ctx.ui.custom(component); + }, + }); +} diff --git a/packages/coding-agent/src/cli/file-processor.ts b/packages/coding-agent/src/cli/file-processor.ts index 3afce9c7..7f82d796 100644 --- a/packages/coding-agent/src/cli/file-processor.ts +++ b/packages/coding-agent/src/cli/file-processor.ts @@ -3,21 +3,21 @@ */ import { access, readFile, stat } from "node:fs/promises"; -import type { Attachment } from "@mariozechner/pi-agent-core"; +import type { ImageContent } from "@mariozechner/pi-ai"; import chalk from "chalk"; import { resolve } from "path"; import { resolveReadPath } from "../core/tools/path-utils.js"; import { detectSupportedImageMimeTypeFromFile } from "../utils/mime.js"; export interface ProcessedFiles { - textContent: string; - imageAttachments: Attachment[]; + text: string; + images: ImageContent[]; } /** Process @file arguments into text content and image attachments */ export async function processFileArguments(fileArgs: string[]): Promise { - let textContent = ""; - const imageAttachments: Attachment[] = []; + let text = ""; + const images: ImageContent[] = []; for (const fileArg of fileArgs) { // Expand and resolve path (handles ~ expansion and macOS screenshot Unicode spaces) @@ -45,24 +45,21 @@ export async function processFileArguments(fileArgs: string[]): Promise\n`; + text += `\n`; } else { // Handle text file try { const content = await readFile(absolutePath, "utf-8"); - textContent += `\n${content}\n\n`; + text += `\n${content}\n\n`; } catch (error: unknown) { 
const message = error instanceof Error ? error.message : String(error); console.error(chalk.red(`Error: Could not read file ${absolutePath}: ${message}`)); @@ -71,5 +68,5 @@ export async function processFileArguments(fileArgs: string[]): Promise; @@ -64,8 +83,8 @@ export interface AgentSessionConfig { export interface PromptOptions { /** Whether to expand file-based slash commands (default: true) */ expandSlashCommands?: boolean; - /** Image/file attachments */ - attachments?: Attachment[]; + /** Image attachments */ + images?: ImageContent[]; } /** Result from cycleModel() */ @@ -76,15 +95,9 @@ export interface ModelCycleResult { isScoped: boolean; } -/** Result from compact() or checkAutoCompaction() */ -export interface CompactionResult { - tokensBefore: number; - summary: string; -} - /** Session statistics for /session command */ export interface SessionStats { - sessionFile: string | null; + sessionFile: string | undefined; sessionId: string; userMessages: number; assistantMessages: number; @@ -101,6 +114,7 @@ export interface SessionStats { cost: number; } +/** Internal marker for hook messages queued through the agent loop */ // ============================================================================ // Constants // ============================================================================ @@ -131,21 +145,24 @@ export class AgentSession { private _queuedMessages: string[] = []; // Compaction state - private _compactionAbortController: AbortController | null = null; - private _autoCompactionAbortController: AbortController | null = null; + private _compactionAbortController: AbortController | undefined = undefined; + private _autoCompactionAbortController: AbortController | undefined = undefined; + + // Branch summarization state + private _branchSummaryAbortController: AbortController | undefined = undefined; // Retry state - private _retryAbortController: AbortController | null = null; + private _retryAbortController: AbortController | undefined = 
undefined; private _retryAttempt = 0; - private _retryPromise: Promise | null = null; - private _retryResolve: (() => void) | null = null; + private _retryPromise: Promise | undefined = undefined; + private _retryResolve: (() => void) | undefined = undefined; // Bash execution state - private _bashAbortController: AbortController | null = null; + private _bashAbortController: AbortController | undefined = undefined; private _pendingBashMessages: BashExecutionMessage[] = []; // Hook system - private _hookRunner: HookRunner | null = null; + private _hookRunner: HookRunner | undefined = undefined; private _turnIndex = 0; // Custom tools for session lifecycle @@ -162,10 +179,14 @@ export class AgentSession { this.settingsManager = config.settingsManager; this._scopedModels = config.scopedModels ?? []; this._fileCommands = config.fileCommands ?? []; - this._hookRunner = config.hookRunner ?? null; + this._hookRunner = config.hookRunner; this._customTools = config.customTools ?? []; this._skillsSettings = config.skillsSettings; this._modelRegistry = config.modelRegistry; + + // Always subscribe to agent events for internal handling + // (session persistence, hooks, auto-compaction, retry logic) + this._unsubscribeAgent = this.agent.subscribe(this._handleAgentEvent); } /** Model registry for API key resolution and model discovery */ @@ -185,7 +206,7 @@ export class AgentSession { } // Track last assistant message for auto-compaction check - private _lastAssistantMessage: AssistantMessage | null = null; + private _lastAssistantMessage: AssistantMessage | undefined = undefined; /** Internal handler for agent events - shared by subscribe and reconnect */ private _handleAgentEvent = async (event: AgentEvent): Promise => { @@ -211,7 +232,24 @@ export class AgentSession { // Handle session persistence if (event.type === "message_end") { - this.sessionManager.saveMessage(event.message); + // Check if this is a hook message + if (event.message.role === "hookMessage") { + // 
Persist as CustomMessageEntry + this.sessionManager.appendCustomMessageEntry( + event.message.customType, + event.message.content, + event.message.display, + event.message.details, + ); + } else if ( + event.message.role === "user" || + event.message.role === "assistant" || + event.message.role === "toolResult" + ) { + // Regular LLM message - persist as SessionMessageEntry + this.sessionManager.appendMessage(event.message); + } + // Other message types (bashExecution, compactionSummary, branchSummary) are persisted elsewhere // Track assistant message for auto-compaction (checked on agent_end) if (event.message.role === "assistant") { @@ -222,7 +260,7 @@ export class AgentSession { // Check auto-retry and auto-compaction after agent completes if (event.type === "agent_end" && this._lastAssistantMessage) { const msg = this._lastAssistantMessage; - this._lastAssistantMessage = null; + this._lastAssistantMessage = undefined; // Check for retryable errors first (overloaded, rate limit, server errors) if (this._isRetryableError(msg)) { @@ -248,8 +286,8 @@ export class AgentSession { private _resolveRetry(): void { if (this._retryResolve) { this._retryResolve(); - this._retryResolve = null; - this._retryPromise = null; + this._retryResolve = undefined; + this._retryPromise = undefined; } } @@ -263,7 +301,7 @@ export class AgentSession { } /** Find the last assistant message in agent state (including aborted ones) */ - private _findLastAssistantMessage(): AssistantMessage | null { + private _findLastAssistantMessage(): AssistantMessage | undefined { const messages = this.agent.state.messages; for (let i = messages.length - 1; i >= 0; i--) { const msg = messages[i]; @@ -271,7 +309,7 @@ export class AgentSession { return msg as AssistantMessage; } } - return null; + return undefined; } /** Emit hook events based on agent events */ @@ -310,11 +348,6 @@ export class AgentSession { subscribe(listener: AgentSessionEventListener): () => void { 
this._eventListeners.push(listener); - // Set up agent subscription if not already done - if (!this._unsubscribeAgent) { - this._unsubscribeAgent = this.agent.subscribe(this._handleAgentEvent); - } - // Return unsubscribe function for this specific listener return () => { const index = this._eventListeners.indexOf(listener); @@ -363,8 +396,8 @@ export class AgentSession { return this.agent.state; } - /** Current model (may be null if not yet selected) */ - get model(): Model | null { + /** Current model (may be undefined if not yet selected) */ + get model(): Model | undefined { return this.agent.state.model; } @@ -380,11 +413,11 @@ export class AgentSession { /** Whether auto-compaction is currently running */ get isCompacting(): boolean { - return this._autoCompactionAbortController !== null || this._compactionAbortController !== null; + return this._autoCompactionAbortController !== undefined || this._compactionAbortController !== undefined; } /** All messages including custom types like BashExecutionMessage */ - get messages(): AppMessage[] { + get messages(): AgentMessage[] { return this.agent.state.messages; } @@ -393,9 +426,9 @@ export class AgentSession { return this.agent.getQueueMode(); } - /** Current session file path, or null if sessions are disabled */ - get sessionFile(): string | null { - return this.sessionManager.isPersisted() ? this.sessionManager.getSessionFile() : null; + /** Current session file path, or undefined if sessions are disabled */ + get sessionFile(): string | undefined { + return this.sessionManager.getSessionFile(); } /** Current session ID */ @@ -420,6 +453,7 @@ export class AgentSession { /** * Send a prompt to the agent. 
* - Validates model and API key before sending + * - Handles hook commands (registered via pi.registerCommand) * - Expands file-based slash commands by default * @throws Error if no model selected or no API key available */ @@ -429,6 +463,15 @@ export class AgentSession { const expandCommands = options?.expandSlashCommands ?? true; + // Handle hook commands first (if enabled and text is a slash command) + if (expandCommands && text.startsWith("/")) { + const handled = await this._tryExecuteHookCommand(text); + if (handled) { + // Hook command executed, no prompt to send + return; + } + } + // Validate model if (!this.model) { throw new Error( @@ -453,13 +496,85 @@ export class AgentSession { await this._checkCompaction(lastAssistant, false); } - // Expand slash commands if requested + // Expand file-based slash commands if requested const expandedText = expandCommands ? expandSlashCommand(text, [...this._fileCommands]) : text; - await this.agent.prompt(expandedText, options?.attachments); + // Build messages array (hook message if any, then user message) + const messages: AgentMessage[] = []; + + // Add user message + const userContent: (TextContent | ImageContent)[] = [{ type: "text", text: expandedText }]; + if (options?.images) { + userContent.push(...options.images); + } + messages.push({ + role: "user", + content: userContent, + timestamp: Date.now(), + }); + + // Emit before_agent_start hook event + if (this._hookRunner) { + const result = await this._hookRunner.emitBeforeAgentStart(expandedText, options?.images); + if (result?.message) { + messages.push({ + role: "hookMessage", + customType: result.message.customType, + content: result.message.content, + display: result.message.display, + details: result.message.details, + timestamp: Date.now(), + }); + } + } + + await this.agent.prompt(messages); await this.waitForRetry(); } + /** + * Try to execute a hook command. Returns true if command was found and executed. 
+ */ + private async _tryExecuteHookCommand(text: string): Promise { + if (!this._hookRunner) return false; + + // Parse command name and args + const spaceIndex = text.indexOf(" "); + const commandName = spaceIndex === -1 ? text.slice(1) : text.slice(1, spaceIndex); + const args = spaceIndex === -1 ? "" : text.slice(spaceIndex + 1); + + const command = this._hookRunner.getCommand(commandName); + if (!command) return false; + + // Get UI context from hook runner (set by mode) + const uiContext = this._hookRunner.getUIContext(); + if (!uiContext) return false; + + // Build command context + const cwd = process.cwd(); + const ctx: HookCommandContext = { + args, + ui: uiContext, + hasUI: this._hookRunner.getHasUI(), + cwd, + sessionManager: this.sessionManager, + modelRegistry: this._modelRegistry, + }; + + try { + await command.handler(ctx); + return true; + } catch (err) { + // Emit error via hook runner + this._hookRunner.emitError({ + hookPath: `command:${commandName}`, + event: "command", + error: err instanceof Error ? err.message : String(err), + }); + return true; + } + } + /** * Queue a message to be sent after the current response completes. * Use when agent is currently streaming. @@ -473,6 +588,47 @@ export class AgentSession { }); } + /** + * Send a hook message to the session. Creates a CustomMessageEntry. 
+ * + * Handles three cases: + * - Streaming: queues message, processed when loop pulls from queue + * - Not streaming + triggerTurn: appends to state/session, starts new turn + * - Not streaming + no trigger: appends to state/session, no turn + * + * @param message Hook message with customType, content, display, details + * @param triggerTurn If true and not streaming, triggers a new LLM turn + */ + async sendHookMessage( + message: Pick, "customType" | "content" | "display" | "details">, + triggerTurn?: boolean, + ): Promise { + const appMessage = { + role: "hookMessage" as const, + customType: message.customType, + content: message.content, + display: message.display, + details: message.details, + timestamp: Date.now(), + } satisfies HookMessage; + if (this.isStreaming) { + // Queue for processing by agent loop + await this.agent.queueMessage(appMessage); + } else if (triggerTurn) { + // Send as prompt - agent loop will emit message events + await this.agent.prompt(appMessage); + } else { + // Just append to agent state and session, no turn + this.agent.appendMessage(appMessage); + this.sessionManager.appendCustomMessageEntry( + message.customType, + message.content, + message.display, + message.details, + ); + } + } + /** * Clear queued messages and return them. * Useful for restoring to editor when user aborts. 
@@ -515,17 +671,12 @@ export class AgentSession { */ async reset(): Promise { const previousSessionFile = this.sessionFile; - const entries = this.sessionManager.getEntries(); - // Emit before_new event (can be cancelled) - if (this._hookRunner?.hasHandlers("session")) { + // Emit session_before_new event (can be cancelled) + if (this._hookRunner?.hasHandlers("session_before_new")) { const result = (await this._hookRunner.emit({ - type: "session", - entries, - sessionFile: this.sessionFile, - previousSessionFile: null, - reason: "before_new", - })) as SessionEventResult | undefined; + type: "session_before_new", + })) as SessionBeforeNewResult | undefined; if (result?.cancel) { return false; @@ -535,19 +686,14 @@ export class AgentSession { this._disconnectFromAgent(); await this.abort(); this.agent.reset(); - this.sessionManager.reset(); + this.sessionManager.newSession(); this._queuedMessages = []; this._reconnectToAgent(); - // Emit session event with reason "new" to hooks + // Emit session_new event to hooks if (this._hookRunner) { - this._hookRunner.setSessionFile(this.sessionFile); await this._hookRunner.emit({ - type: "session", - entries: [], - sessionFile: this.sessionFile, - previousSessionFile, - reason: "new", + type: "session_new", }); } @@ -572,7 +718,7 @@ export class AgentSession { } this.agent.setModel(model); - this.sessionManager.saveModelChange(model.provider, model.id); + this.sessionManager.appendModelChange(model.provider, model.id); this.settingsManager.setDefaultModelAndProvider(model.provider, model.id); // Re-clamp thinking level for new model's capabilities @@ -583,17 +729,17 @@ export class AgentSession { * Cycle to next/previous model. * Uses scoped models (from --models flag) if available, otherwise all available models. 
* @param direction - "forward" (default) or "backward" - * @returns The new model info, or null if only one model available + * @returns The new model info, or undefined if only one model available */ - async cycleModel(direction: "forward" | "backward" = "forward"): Promise { + async cycleModel(direction: "forward" | "backward" = "forward"): Promise { if (this._scopedModels.length > 0) { return this._cycleScopedModel(direction); } return this._cycleAvailableModel(direction); } - private async _cycleScopedModel(direction: "forward" | "backward"): Promise { - if (this._scopedModels.length <= 1) return null; + private async _cycleScopedModel(direction: "forward" | "backward"): Promise { + if (this._scopedModels.length <= 1) return undefined; const currentModel = this.model; let currentIndex = this._scopedModels.findIndex((sm) => modelsAreEqual(sm.model, currentModel)); @@ -611,7 +757,7 @@ export class AgentSession { // Apply model this.agent.setModel(next.model); - this.sessionManager.saveModelChange(next.model.provider, next.model.id); + this.sessionManager.appendModelChange(next.model.provider, next.model.id); this.settingsManager.setDefaultModelAndProvider(next.model.provider, next.model.id); // Apply thinking level (setThinkingLevel clamps to model capabilities) @@ -620,9 +766,9 @@ export class AgentSession { return { model: next.model, thinkingLevel: this.thinkingLevel, isScoped: true }; } - private async _cycleAvailableModel(direction: "forward" | "backward"): Promise { + private async _cycleAvailableModel(direction: "forward" | "backward"): Promise { const availableModels = await this._modelRegistry.getAvailable(); - if (availableModels.length <= 1) return null; + if (availableModels.length <= 1) return undefined; const currentModel = this.model; let currentIndex = availableModels.findIndex((m) => modelsAreEqual(m, currentModel)); @@ -638,7 +784,7 @@ export class AgentSession { } this.agent.setModel(nextModel); - 
this.sessionManager.saveModelChange(nextModel.provider, nextModel.id); + this.sessionManager.appendModelChange(nextModel.provider, nextModel.id); this.settingsManager.setDefaultModelAndProvider(nextModel.provider, nextModel.id); // Re-clamp thinking level for new model's capabilities @@ -671,16 +817,16 @@ export class AgentSession { effectiveLevel = "high"; } this.agent.setThinkingLevel(effectiveLevel); - this.sessionManager.saveThinkingLevelChange(effectiveLevel); + this.sessionManager.appendThinkingLevelChange(effectiveLevel); this.settingsManager.setDefaultThinkingLevel(effectiveLevel); } /** * Cycle to next thinking level. - * @returns New level, or null if model doesn't support thinking + * @returns New level, or undefined if model doesn't support thinking */ - cycleThinkingLevel(): ThinkingLevel | null { - if (!this.supportsThinking()) return null; + cycleThinkingLevel(): ThinkingLevel | undefined { + if (!this.supportsThinking()) return undefined; const levels = this.getAvailableThinkingLevels(); const currentIndex = levels.indexOf(this.thinkingLevel); @@ -754,51 +900,54 @@ export class AgentSession { const preparation = prepareCompaction(entries, settings); if (!preparation) { - throw new Error("Already compacted"); - } - - // Find previous compaction summary if any - let previousSummary: string | undefined; - for (let i = entries.length - 1; i >= 0; i--) { - if (entries[i].type === "compaction") { - previousSummary = (entries[i] as CompactionEntry).summary; - break; + // Check why we can't compact + const lastEntry = entries[entries.length - 1]; + if (lastEntry?.type === "compaction") { + throw new Error("Already compacted"); } + throw new Error("Nothing to compact (session too small)"); } - let compactionEntry: CompactionEntry | undefined; + let hookCompaction: CompactionResult | undefined; let fromHook = false; - if (this._hookRunner?.hasHandlers("session")) { + if (this._hookRunner?.hasHandlers("session_before_compact")) { + // Get previous compactions, 
newest first + const previousCompactions = entries.filter((e): e is CompactionEntry => e.type === "compaction").reverse(); + const result = (await this._hookRunner.emit({ - type: "session", - entries, - sessionFile: this.sessionFile, - previousSessionFile: null, - reason: "before_compact", - cutPoint: preparation.cutPoint, - previousSummary, - messagesToSummarize: [...preparation.messagesToSummarize], - messagesToKeep: [...preparation.messagesToKeep], - tokensBefore: preparation.tokensBefore, + type: "session_before_compact", + preparation, + previousCompactions, customInstructions, model: this.model, - resolveApiKey: async (m: Model) => (await this._modelRegistry.getApiKey(m)) ?? undefined, signal: this._compactionAbortController.signal, - })) as SessionEventResult | undefined; + })) as SessionBeforeCompactResult | undefined; if (result?.cancel) { throw new Error("Compaction cancelled"); } - if (result?.compactionEntry) { - compactionEntry = result.compactionEntry; + if (result?.compaction) { + hookCompaction = result.compaction; fromHook = true; } } - if (!compactionEntry) { - compactionEntry = await compact( + let summary: string; + let firstKeptEntryId: string; + let tokensBefore: number; + let details: unknown; + + if (hookCompaction) { + // Hook provided compaction content + summary = hookCompaction.summary; + firstKeptEntryId = hookCompaction.firstKeptEntryId; + tokensBefore = hookCompaction.tokensBefore; + details = hookCompaction.details; + } else { + // Generate compaction result + const result = await compact( entries, this.model, settings, @@ -806,36 +955,42 @@ export class AgentSession { this._compactionAbortController.signal, customInstructions, ); + summary = result.summary; + firstKeptEntryId = result.firstKeptEntryId; + tokensBefore = result.tokensBefore; + details = result.details; } if (this._compactionAbortController.signal.aborted) { throw new Error("Compaction cancelled"); } - this.sessionManager.saveCompaction(compactionEntry); + 
this.sessionManager.appendCompaction(summary, firstKeptEntryId, tokensBefore, details, fromHook); const newEntries = this.sessionManager.getEntries(); const sessionContext = this.sessionManager.buildSessionContext(); this.agent.replaceMessages(sessionContext.messages); - if (this._hookRunner) { + // Get the saved compaction entry for the hook + const savedCompactionEntry = newEntries.find((e) => e.type === "compaction" && e.summary === summary) as + | CompactionEntry + | undefined; + + if (this._hookRunner && savedCompactionEntry) { await this._hookRunner.emit({ - type: "session", - entries: newEntries, - sessionFile: this.sessionFile, - previousSessionFile: null, - reason: "compact", - compactionEntry, - tokensBefore: compactionEntry.tokensBefore, + type: "session_compact", + compactionEntry: savedCompactionEntry, fromHook, }); } return { - tokensBefore: compactionEntry.tokensBefore, - summary: compactionEntry.summary, + summary, + firstKeptEntryId, + tokensBefore, + details, }; } finally { - this._compactionAbortController = null; + this._compactionAbortController = undefined; this._reconnectToAgent(); } } @@ -848,6 +1003,13 @@ export class AgentSession { this._autoCompactionAbortController?.abort(); } + /** + * Cancel in-progress branch summarization. + */ + abortBranchSummary(): void { + this._branchSummaryAbortController?.abort(); + } + /** * Check if compaction is needed and run it. * Called after agent_end and before prompt submission. 
@@ -901,13 +1063,13 @@ export class AgentSession { try { if (!this.model) { - this._emit({ type: "auto_compaction_end", result: null, aborted: false, willRetry: false }); + this._emit({ type: "auto_compaction_end", result: undefined, aborted: false, willRetry: false }); return; } const apiKey = await this._modelRegistry.getApiKey(this.model); if (!apiKey) { - this._emit({ type: "auto_compaction_end", result: null, aborted: false, willRetry: false }); + this._emit({ type: "auto_compaction_end", result: undefined, aborted: false, willRetry: false }); return; } @@ -915,87 +1077,91 @@ export class AgentSession { const preparation = prepareCompaction(entries, settings); if (!preparation) { - this._emit({ type: "auto_compaction_end", result: null, aborted: false, willRetry: false }); + this._emit({ type: "auto_compaction_end", result: undefined, aborted: false, willRetry: false }); return; } - // Find previous compaction summary if any - let previousSummary: string | undefined; - for (let i = entries.length - 1; i >= 0; i--) { - if (entries[i].type === "compaction") { - previousSummary = (entries[i] as CompactionEntry).summary; - break; - } - } - - let compactionEntry: CompactionEntry | undefined; + let hookCompaction: CompactionResult | undefined; let fromHook = false; - if (this._hookRunner?.hasHandlers("session")) { + if (this._hookRunner?.hasHandlers("session_before_compact")) { + // Get previous compactions, newest first + const previousCompactions = entries.filter((e): e is CompactionEntry => e.type === "compaction").reverse(); + const hookResult = (await this._hookRunner.emit({ - type: "session", - entries, - sessionFile: this.sessionFile, - previousSessionFile: null, - reason: "before_compact", - cutPoint: preparation.cutPoint, - previousSummary, - messagesToSummarize: [...preparation.messagesToSummarize], - messagesToKeep: [...preparation.messagesToKeep], - tokensBefore: preparation.tokensBefore, + type: "session_before_compact", + preparation, + 
previousCompactions, customInstructions: undefined, model: this.model, - resolveApiKey: async (m: Model) => (await this._modelRegistry.getApiKey(m)) ?? undefined, signal: this._autoCompactionAbortController.signal, - })) as SessionEventResult | undefined; + })) as SessionBeforeCompactResult | undefined; if (hookResult?.cancel) { - this._emit({ type: "auto_compaction_end", result: null, aborted: true, willRetry: false }); + this._emit({ type: "auto_compaction_end", result: undefined, aborted: true, willRetry: false }); return; } - if (hookResult?.compactionEntry) { - compactionEntry = hookResult.compactionEntry; + if (hookResult?.compaction) { + hookCompaction = hookResult.compaction; fromHook = true; } } - if (!compactionEntry) { - compactionEntry = await compact( + let summary: string; + let firstKeptEntryId: string; + let tokensBefore: number; + let details: unknown; + + if (hookCompaction) { + // Hook provided compaction content + summary = hookCompaction.summary; + firstKeptEntryId = hookCompaction.firstKeptEntryId; + tokensBefore = hookCompaction.tokensBefore; + details = hookCompaction.details; + } else { + // Generate compaction result + const compactResult = await compact( entries, this.model, settings, apiKey, this._autoCompactionAbortController.signal, ); + summary = compactResult.summary; + firstKeptEntryId = compactResult.firstKeptEntryId; + tokensBefore = compactResult.tokensBefore; + details = compactResult.details; } if (this._autoCompactionAbortController.signal.aborted) { - this._emit({ type: "auto_compaction_end", result: null, aborted: true, willRetry: false }); + this._emit({ type: "auto_compaction_end", result: undefined, aborted: true, willRetry: false }); return; } - this.sessionManager.saveCompaction(compactionEntry); + this.sessionManager.appendCompaction(summary, firstKeptEntryId, tokensBefore, details, fromHook); const newEntries = this.sessionManager.getEntries(); const sessionContext = this.sessionManager.buildSessionContext(); 
this.agent.replaceMessages(sessionContext.messages); - if (this._hookRunner) { + // Get the saved compaction entry for the hook + const savedCompactionEntry = newEntries.find((e) => e.type === "compaction" && e.summary === summary) as + | CompactionEntry + | undefined; + + if (this._hookRunner && savedCompactionEntry) { await this._hookRunner.emit({ - type: "session", - entries: newEntries, - sessionFile: this.sessionFile, - previousSessionFile: null, - reason: "compact", - compactionEntry, - tokensBefore: compactionEntry.tokensBefore, + type: "session_compact", + compactionEntry: savedCompactionEntry, fromHook, }); } const result: CompactionResult = { - tokensBefore: compactionEntry.tokensBefore, - summary: compactionEntry.summary, + summary, + firstKeptEntryId, + tokensBefore, + details, }; this._emit({ type: "auto_compaction_end", result, aborted: false, willRetry }); @@ -1011,7 +1177,7 @@ export class AgentSession { }, 100); } } catch (error) { - this._emit({ type: "auto_compaction_end", result: null, aborted: false, willRetry: false }); + this._emit({ type: "auto_compaction_end", result: undefined, aborted: false, willRetry: false }); if (reason === "overflow") { throw new Error( @@ -1019,7 +1185,7 @@ export class AgentSession { ); } } finally { - this._autoCompactionAbortController = null; + this._autoCompactionAbortController = undefined; } } @@ -1111,7 +1277,7 @@ export class AgentSession { // Aborted during sleep - emit end event so UI can clean up const attempt = this._retryAttempt; this._retryAttempt = 0; - this._retryAbortController = null; + this._retryAbortController = undefined; this._emit({ type: "auto_retry_end", success: false, @@ -1121,7 +1287,7 @@ export class AgentSession { this._resolveRetry(); return false; } - this._retryAbortController = null; + this._retryAbortController = undefined; // Retry via continue() - use setTimeout to break out of event handler chain setTimeout(() => { @@ -1173,7 +1339,7 @@ export class AgentSession { /** Whether 
auto-retry is currently in progress */ get isRetrying(): boolean { - return this._retryPromise !== null; + return this._retryPromise !== undefined; } /** Whether auto-retry is enabled */ @@ -1228,12 +1394,12 @@ export class AgentSession { this.agent.appendMessage(bashMessage); // Save to session - this.sessionManager.saveMessage(bashMessage); + this.sessionManager.appendMessage(bashMessage); } return result; } finally { - this._bashAbortController = null; + this._bashAbortController = undefined; } } @@ -1246,7 +1412,7 @@ export class AgentSession { /** Whether a bash command is currently running */ get isBashRunning(): boolean { - return this._bashAbortController !== null; + return this._bashAbortController !== undefined; } /** Whether there are pending bash messages waiting to be flushed */ @@ -1266,7 +1432,7 @@ export class AgentSession { this.agent.appendMessage(bashMessage); // Save to session - this.sessionManager.saveMessage(bashMessage); + this.sessionManager.appendMessage(bashMessage); } this._pendingBashMessages = []; @@ -1283,18 +1449,14 @@ export class AgentSession { * @returns true if switch completed, false if cancelled by hook */ async switchSession(sessionPath: string): Promise { - const previousSessionFile = this.sessionFile; - const oldEntries = this.sessionManager.getEntries(); + const previousSessionFile = this.sessionManager.getSessionFile(); - // Emit before_switch event (can be cancelled) - if (this._hookRunner?.hasHandlers("session")) { + // Emit session_before_switch event (can be cancelled) + if (this._hookRunner?.hasHandlers("session_before_switch")) { const result = (await this._hookRunner.emit({ - type: "session", - entries: oldEntries, - sessionFile: this.sessionFile, - previousSessionFile: null, - reason: "before_switch", - })) as SessionEventResult | undefined; + type: "session_before_switch", + targetSessionFile: sessionPath, + })) as SessionBeforeSwitchResult | undefined; if (result?.cancel) { return false; @@ -1309,18 +1471,13 @@ 
export class AgentSession { this.sessionManager.setSessionFile(sessionPath); // Reload messages - const entries = this.sessionManager.getEntries(); const sessionContext = this.sessionManager.buildSessionContext(); - // Emit session event to hooks + // Emit session_switch event to hooks if (this._hookRunner) { - this._hookRunner.setSessionFile(sessionPath); await this._hookRunner.emit({ - type: "session", - entries, - sessionFile: sessionPath, + type: "session_switch", previousSessionFile, - reason: "switch", }); } @@ -1371,16 +1528,12 @@ export class AgentSession { let skipConversationRestore = false; - // Emit before_branch event (can be cancelled) - if (this._hookRunner?.hasHandlers("session")) { + // Emit session_before_branch event (can be cancelled) + if (this._hookRunner?.hasHandlers("session_before_branch")) { const result = (await this._hookRunner.emit({ - type: "session", - entries, - sessionFile: this.sessionFile, - previousSessionFile: null, - reason: "before_branch", - targetTurnIndex: entryIndex, - })) as SessionEventResult | undefined; + type: "session_before_branch", + entryIndex: entryIndex, + })) as SessionBeforeBranchResult | undefined; if (result?.cancel) { return { selectedText, cancelled: true }; @@ -1388,28 +1541,20 @@ export class AgentSession { skipConversationRestore = result?.skipConversationRestore ?? 
false; } - // Create branched session (returns null in --no-session mode) - const newSessionFile = this.sessionManager.createBranchedSessionFromEntries(entries, entryIndex); - - // Update session file if we have one (file-based mode) - if (newSessionFile !== null) { - this.sessionManager.setSessionFile(newSessionFile); + if (!selectedEntry.parentId) { + this.sessionManager.newSession(); + } else { + this.sessionManager.createBranchedSession(selectedEntry.parentId); } // Reload messages from entries (works for both file and in-memory mode) - const newEntries = this.sessionManager.getEntries(); const sessionContext = this.sessionManager.buildSessionContext(); - // Emit branch event to hooks (after branch completes) + // Emit session_branch event to hooks (after branch completes) if (this._hookRunner) { - this._hookRunner.setSessionFile(newSessionFile); await this._hookRunner.emit({ - type: "session", - entries: newEntries, - sessionFile: newSessionFile, + type: "session_branch", previousSessionFile, - reason: "branch", - targetTurnIndex: entryIndex, }); } @@ -1423,6 +1568,174 @@ export class AgentSession { return { selectedText, cancelled: false }; } + // ========================================================================= + // Tree Navigation + // ========================================================================= + + /** + * Navigate to a different node in the session tree. + * Unlike branch() which creates a new session file, this stays in the same file. 
+ * + * @param targetId The entry ID to navigate to + * @param options.summarize Whether user wants to summarize abandoned branch + * @param options.customInstructions Custom instructions for summarizer + * @returns Result with editorText (if user message) and cancelled status + */ + async navigateTree( + targetId: string, + options: { summarize?: boolean; customInstructions?: string } = {}, + ): Promise<{ editorText?: string; cancelled: boolean; aborted?: boolean; summaryEntry?: BranchSummaryEntry }> { + const oldLeafId = this.sessionManager.getLeafId(); + + // No-op if already at target + if (targetId === oldLeafId) { + return { cancelled: false }; + } + + // Model required for summarization + if (options.summarize && !this.model) { + throw new Error("No model available for summarization"); + } + + const targetEntry = this.sessionManager.getEntry(targetId); + if (!targetEntry) { + throw new Error(`Entry ${targetId} not found`); + } + + // Collect entries to summarize (from old leaf to common ancestor) + const { entries: entriesToSummarize, commonAncestorId } = collectEntriesForBranchSummary( + this.sessionManager, + oldLeafId, + targetId, + ); + + // Prepare event data + const preparation: TreePreparation = { + targetId, + oldLeafId, + commonAncestorId, + entriesToSummarize, + userWantsSummary: options.summarize ?? 
false, + }; + + // Set up abort controller for summarization + this._branchSummaryAbortController = new AbortController(); + let hookSummary: { summary: string; details?: unknown } | undefined; + let fromHook = false; + + // Emit session_before_tree event + if (this._hookRunner?.hasHandlers("session_before_tree")) { + const result = (await this._hookRunner.emit({ + type: "session_before_tree", + preparation, + model: this.model!, // Checked above if summarize is true + signal: this._branchSummaryAbortController.signal, + })) as SessionBeforeTreeResult | undefined; + + if (result?.cancel) { + return { cancelled: true }; + } + + if (result?.summary && options.summarize) { + hookSummary = result.summary; + fromHook = true; + } + } + + // Run default summarizer if needed + let summaryText: string | undefined; + let summaryDetails: unknown; + if (options.summarize && entriesToSummarize.length > 0 && !hookSummary) { + const model = this.model!; + const apiKey = await this._modelRegistry.getApiKey(model); + if (!apiKey) { + throw new Error(`No API key for ${model.provider}`); + } + const branchSummarySettings = this.settingsManager.getBranchSummarySettings(); + const result = await generateBranchSummary(entriesToSummarize, { + model, + apiKey, + signal: this._branchSummaryAbortController.signal, + customInstructions: options.customInstructions, + reserveTokens: branchSummarySettings.reserveTokens, + }); + this._branchSummaryAbortController = undefined; + if (result.aborted) { + return { cancelled: true, aborted: true }; + } + if (result.error) { + throw new Error(result.error); + } + summaryText = result.summary; + summaryDetails = { + readFiles: result.readFiles || [], + modifiedFiles: result.modifiedFiles || [], + }; + } else if (hookSummary) { + summaryText = hookSummary.summary; + summaryDetails = hookSummary.details; + } + + // Determine the new leaf position based on target type + let newLeafId: string | null; + let editorText: string | undefined; + + if 
(targetEntry.type === "message" && targetEntry.message.role === "user") { + // User message: leaf = parent (null if root), text goes to editor + newLeafId = targetEntry.parentId; + editorText = this._extractUserMessageText(targetEntry.message.content); + } else if (targetEntry.type === "custom_message") { + // Custom message: leaf = parent (null if root), text goes to editor + newLeafId = targetEntry.parentId; + editorText = + typeof targetEntry.content === "string" + ? targetEntry.content + : targetEntry.content + .filter((c): c is { type: "text"; text: string } => c.type === "text") + .map((c) => c.text) + .join(""); + } else { + // Non-user message: leaf = selected node + newLeafId = targetId; + } + + // Switch leaf (with or without summary) + // Summary is attached at the navigation target position (newLeafId), not the old branch + let summaryEntry: BranchSummaryEntry | undefined; + if (summaryText) { + // Create summary at target position (can be null for root) + const summaryId = this.sessionManager.branchWithSummary(newLeafId, summaryText, summaryDetails, fromHook); + summaryEntry = this.sessionManager.getEntry(summaryId) as BranchSummaryEntry; + } else if (newLeafId === null) { + // No summary, navigating to root - reset leaf + this.sessionManager.resetLeaf(); + } else { + // No summary, navigating to non-root + this.sessionManager.branch(newLeafId); + } + + // Update agent state + const sessionContext = this.sessionManager.buildSessionContext(); + this.agent.replaceMessages(sessionContext.messages); + + // Emit session_tree event + if (this._hookRunner) { + await this._hookRunner.emit({ + type: "session_tree", + newLeafId: this.sessionManager.getLeafId(), + oldLeafId, + summaryEntry, + fromHook: summaryText ? 
fromHook : undefined, + }); + } + + // Emit to custom tools + await this._emitToolSessionEvent("tree", this.sessionFile); + + this._branchSummaryAbortController = undefined; + return { editorText, cancelled: false, summaryEntry }; + } + /** * Get all user messages from session for branch selector. */ @@ -1519,9 +1832,9 @@ export class AgentSession { /** * Get text content of last assistant message. * Useful for /copy command. - * @returns Text content, or null if no assistant message exists + * @returns Text content, or undefined if no assistant message exists */ - getLastAssistantText(): string | null { + getLastAssistantText(): string | undefined { const lastAssistant = this.messages .slice() .reverse() @@ -1533,7 +1846,7 @@ export class AgentSession { return true; }); - if (!lastAssistant) return null; + if (!lastAssistant) return undefined; let text = ""; for (const content of (lastAssistant as AssistantMessage).content) { @@ -1542,7 +1855,7 @@ export class AgentSession { } } - return text.trim() || null; + return text.trim() || undefined; } // ========================================================================= @@ -1559,7 +1872,7 @@ export class AgentSession { /** * Get the hook runner (for setting UI context and error handlers). 
*/ - get hookRunner(): HookRunner | null { + get hookRunner(): HookRunner | undefined { return this._hookRunner; } @@ -1576,7 +1889,7 @@ export class AgentSession { */ private async _emitToolSessionEvent( reason: ToolSessionEvent["reason"], - previousSessionFile: string | null, + previousSessionFile: string | undefined, ): Promise { const event: ToolSessionEvent = { entries: this.sessionManager.getEntries(), diff --git a/packages/coding-agent/src/core/auth-storage.ts b/packages/coding-agent/src/core/auth-storage.ts index afc5e076..0b0c23a5 100644 --- a/packages/coding-agent/src/core/auth-storage.ts +++ b/packages/coding-agent/src/core/auth-storage.ts @@ -94,8 +94,8 @@ export class AuthStorage { /** * Get credential for a provider. */ - get(provider: string): AuthCredential | null { - return this.data[provider] ?? null; + get(provider: string): AuthCredential | undefined { + return this.data[provider] ?? undefined; } /** @@ -191,7 +191,7 @@ export class AuthStorage { * 4. Environment variable * 5. Fallback resolver (models.json custom providers) */ - async getApiKey(provider: string): Promise { + async getApiKey(provider: string): Promise { // Runtime override takes highest priority const runtimeKey = this.runtimeOverrides.get(provider); if (runtimeKey) { @@ -230,6 +230,6 @@ export class AuthStorage { if (envKey) return envKey; // Fall back to custom resolver (e.g., models.json custom providers) - return this.fallbackResolver?.(provider) ?? null; + return this.fallbackResolver?.(provider) ?? 
undefined; } } diff --git a/packages/coding-agent/src/core/bash-executor.ts b/packages/coding-agent/src/core/bash-executor.ts index d83f63b5..0a83d2e5 100644 --- a/packages/coding-agent/src/core/bash-executor.ts +++ b/packages/coding-agent/src/core/bash-executor.ts @@ -29,8 +29,8 @@ export interface BashExecutorOptions { export interface BashResult { /** Combined stdout + stderr output (sanitized, possibly truncated) */ output: string; - /** Process exit code (null if killed/cancelled) */ - exitCode: number | null; + /** Process exit code (undefined if killed/cancelled) */ + exitCode: number | undefined; /** Whether the command was cancelled via signal */ cancelled: boolean; /** Whether the output was truncated */ @@ -88,7 +88,7 @@ export function executeBash(command: string, options?: BashExecutorOptions): Pro child.kill(); resolve({ output: "", - exitCode: null, + exitCode: undefined, cancelled: true, truncated: false, }); @@ -154,7 +154,7 @@ export function executeBash(command: string, options?: BashExecutorOptions): Pro resolve({ output: truncationResult.truncated ? truncationResult.content : fullOutput, - exitCode: code, + exitCode: cancelled ? undefined : code, cancelled, truncated: truncationResult.truncated, fullOutputPath: tempFilePath, diff --git a/packages/coding-agent/src/core/compaction.ts b/packages/coding-agent/src/core/compaction.ts deleted file mode 100644 index af4edfb8..00000000 --- a/packages/coding-agent/src/core/compaction.ts +++ /dev/null @@ -1,530 +0,0 @@ -/** - * Context compaction for long sessions. - * - * Pure functions for compaction logic. The session manager handles I/O, - * and after compaction the session is reloaded. 
- */ - -import type { AppMessage } from "@mariozechner/pi-agent-core"; -import type { AssistantMessage, Model, Usage } from "@mariozechner/pi-ai"; -import { complete } from "@mariozechner/pi-ai"; -import { messageTransformer } from "./messages.js"; -import type { CompactionEntry, SessionEntry } from "./session-manager.js"; - -// ============================================================================ -// Types -// ============================================================================ - -export interface CompactionSettings { - enabled: boolean; - reserveTokens: number; - keepRecentTokens: number; -} - -export const DEFAULT_COMPACTION_SETTINGS: CompactionSettings = { - enabled: true, - reserveTokens: 16384, - keepRecentTokens: 20000, -}; - -// ============================================================================ -// Token calculation -// ============================================================================ - -/** - * Calculate total context tokens from usage. - * Uses the native totalTokens field when available, falls back to computing from components. - */ -export function calculateContextTokens(usage: Usage): number { - return usage.totalTokens || usage.input + usage.output + usage.cacheRead + usage.cacheWrite; -} - -/** - * Get usage from an assistant message if available. - * Skips aborted and error messages as they don't have valid usage data. - */ -function getAssistantUsage(msg: AppMessage): Usage | null { - if (msg.role === "assistant" && "usage" in msg) { - const assistantMsg = msg as AssistantMessage; - if (assistantMsg.stopReason !== "aborted" && assistantMsg.stopReason !== "error" && assistantMsg.usage) { - return assistantMsg.usage; - } - } - return null; -} - -/** - * Find the last non-aborted assistant message usage from session entries. 
- */ -export function getLastAssistantUsage(entries: SessionEntry[]): Usage | null { - for (let i = entries.length - 1; i >= 0; i--) { - const entry = entries[i]; - if (entry.type === "message") { - const usage = getAssistantUsage(entry.message); - if (usage) return usage; - } - } - return null; -} - -/** - * Check if compaction should trigger based on context usage. - */ -export function shouldCompact(contextTokens: number, contextWindow: number, settings: CompactionSettings): boolean { - if (!settings.enabled) return false; - return contextTokens > contextWindow - settings.reserveTokens; -} - -// ============================================================================ -// Cut point detection -// ============================================================================ - -/** - * Estimate token count for a message using chars/4 heuristic. - * This is conservative (overestimates tokens). - */ -export function estimateTokens(message: AppMessage): number { - let chars = 0; - - // Handle bashExecution messages - if (message.role === "bashExecution") { - const bash = message as unknown as { command: string; output: string }; - chars = bash.command.length + bash.output.length; - return Math.ceil(chars / 4); - } - - // Handle user messages - if (message.role === "user") { - const content = (message as { content: string | Array<{ type: string; text?: string }> }).content; - if (typeof content === "string") { - chars = content.length; - } else if (Array.isArray(content)) { - for (const block of content) { - if (block.type === "text" && block.text) { - chars += block.text.length; - } - } - } - return Math.ceil(chars / 4); - } - - // Handle assistant messages - if (message.role === "assistant") { - const assistant = message as AssistantMessage; - for (const block of assistant.content) { - if (block.type === "text") { - chars += block.text.length; - } else if (block.type === "thinking") { - chars += block.thinking.length; - } else if (block.type === "toolCall") { - 
chars += block.name.length + JSON.stringify(block.arguments).length; - } - } - return Math.ceil(chars / 4); - } - - // Handle tool results - if (message.role === "toolResult") { - const toolResult = message as { content: Array<{ type: string; text?: string }> }; - for (const block of toolResult.content) { - if (block.type === "text" && block.text) { - chars += block.text.length; - } - } - return Math.ceil(chars / 4); - } - - return 0; -} - -/** - * Find valid cut points: indices of user, assistant, or bashExecution messages. - * Never cut at tool results (they must follow their tool call). - * When we cut at an assistant message with tool calls, its tool results follow it - * and will be kept. - * BashExecutionMessage is treated like a user message (user-initiated context). - */ -function findValidCutPoints(entries: SessionEntry[], startIndex: number, endIndex: number): number[] { - const cutPoints: number[] = []; - for (let i = startIndex; i < endIndex; i++) { - const entry = entries[i]; - if (entry.type === "message") { - const role = entry.message.role; - // user, assistant, and bashExecution are valid cut points - // toolResult must stay with its preceding tool call - if (role === "user" || role === "assistant" || role === "bashExecution") { - cutPoints.push(i); - } - } - } - return cutPoints; -} - -/** - * Find the user message (or bashExecution) that starts the turn containing the given entry index. - * Returns -1 if no turn start found before the index. - * BashExecutionMessage is treated like a user message for turn boundaries. 
- */ -export function findTurnStartIndex(entries: SessionEntry[], entryIndex: number, startIndex: number): number { - for (let i = entryIndex; i >= startIndex; i--) { - const entry = entries[i]; - if (entry.type === "message") { - const role = entry.message.role; - if (role === "user" || role === "bashExecution") { - return i; - } - } - } - return -1; -} - -export interface CutPointResult { - /** Index of first entry to keep */ - firstKeptEntryIndex: number; - /** Index of user message that starts the turn being split, or -1 if not splitting */ - turnStartIndex: number; - /** Whether this cut splits a turn (cut point is not a user message) */ - isSplitTurn: boolean; -} - -/** - * Find the cut point in session entries that keeps approximately `keepRecentTokens`. - * - * Algorithm: Walk backwards from newest, accumulating estimated message sizes. - * Stop when we've accumulated >= keepRecentTokens. Cut at that point. - * - * Can cut at user OR assistant messages (never tool results). When cutting at an - * assistant message with tool calls, its tool results come after and will be kept. - * - * Returns CutPointResult with: - * - firstKeptEntryIndex: the entry index to start keeping from - * - turnStartIndex: if cutting mid-turn, the user message that started that turn - * - isSplitTurn: whether we're cutting in the middle of a turn - * - * Only considers entries between `startIndex` and `endIndex` (exclusive). 
- */ -export function findCutPoint( - entries: SessionEntry[], - startIndex: number, - endIndex: number, - keepRecentTokens: number, -): CutPointResult { - const cutPoints = findValidCutPoints(entries, startIndex, endIndex); - - if (cutPoints.length === 0) { - return { firstKeptEntryIndex: startIndex, turnStartIndex: -1, isSplitTurn: false }; - } - - // Walk backwards from newest, accumulating estimated message sizes - let accumulatedTokens = 0; - let cutIndex = startIndex; // Default: keep everything in range - - for (let i = endIndex - 1; i >= startIndex; i--) { - const entry = entries[i]; - if (entry.type !== "message") continue; - - // Estimate this message's size - const messageTokens = estimateTokens(entry.message); - accumulatedTokens += messageTokens; - - // Check if we've exceeded the budget - if (accumulatedTokens >= keepRecentTokens) { - // Find the closest valid cut point at or after this entry - for (let c = 0; c < cutPoints.length; c++) { - if (cutPoints[c] >= i) { - cutIndex = cutPoints[c]; - break; - } - } - break; - } - } - - // Scan backwards from cutIndex to include any non-message entries (bash, settings, etc.) - while (cutIndex > startIndex) { - const prevEntry = entries[cutIndex - 1]; - // Stop at compaction boundaries - if (prevEntry.type === "compaction") { - break; - } - if (prevEntry.type === "message") { - // Stop if we hit any message - break; - } - // Include this non-message entry (bash, settings change, etc.) - cutIndex--; - } - - // Determine if this is a split turn - const cutEntry = entries[cutIndex]; - const isUserMessage = cutEntry.type === "message" && cutEntry.message.role === "user"; - const turnStartIndex = isUserMessage ? 
-1 : findTurnStartIndex(entries, cutIndex, startIndex); - - return { - firstKeptEntryIndex: cutIndex, - turnStartIndex, - isSplitTurn: !isUserMessage && turnStartIndex !== -1, - }; -} - -// ============================================================================ -// Summarization -// ============================================================================ - -const SUMMARIZATION_PROMPT = `You are performing a CONTEXT CHECKPOINT COMPACTION. Create a handoff summary for another LLM that will resume the task. - -Include: -- Current progress and key decisions made -- Important context, constraints, or user preferences -- Absolute file paths of any relevant files that were read or modified -- What remains to be done (clear next steps) -- Any critical data, examples, or references needed to continue - -Be concise, structured, and focused on helping the next LLM seamlessly continue the work.`; - -/** - * Generate a summary of the conversation using the LLM. - */ -export async function generateSummary( - currentMessages: AppMessage[], - model: Model, - reserveTokens: number, - apiKey: string, - signal?: AbortSignal, - customInstructions?: string, -): Promise { - const maxTokens = Math.floor(0.8 * reserveTokens); - - const prompt = customInstructions - ? 
`${SUMMARIZATION_PROMPT}\n\nAdditional focus: ${customInstructions}` - : SUMMARIZATION_PROMPT; - - // Transform custom messages (like bashExecution) to LLM-compatible messages - const transformedMessages = messageTransformer(currentMessages); - - const summarizationMessages = [ - ...transformedMessages, - { - role: "user" as const, - content: [{ type: "text" as const, text: prompt }], - timestamp: Date.now(), - }, - ]; - - const response = await complete(model, { messages: summarizationMessages }, { maxTokens, signal, apiKey }); - - const textContent = response.content - .filter((c): c is { type: "text"; text: string } => c.type === "text") - .map((c) => c.text) - .join("\n"); - - return textContent; -} - -// ============================================================================ -// Compaction Preparation (for hooks) -// ============================================================================ - -export interface CompactionPreparation { - cutPoint: CutPointResult; - /** Messages that will be summarized and discarded */ - messagesToSummarize: AppMessage[]; - /** Messages that will be kept after the summary (recent turns) */ - messagesToKeep: AppMessage[]; - tokensBefore: number; - boundaryStart: number; -} - -export function prepareCompaction(entries: SessionEntry[], settings: CompactionSettings): CompactionPreparation | null { - if (entries.length > 0 && entries[entries.length - 1].type === "compaction") { - return null; - } - - let prevCompactionIndex = -1; - for (let i = entries.length - 1; i >= 0; i--) { - if (entries[i].type === "compaction") { - prevCompactionIndex = i; - break; - } - } - const boundaryStart = prevCompactionIndex + 1; - const boundaryEnd = entries.length; - - const lastUsage = getLastAssistantUsage(entries); - const tokensBefore = lastUsage ? calculateContextTokens(lastUsage) : 0; - - const cutPoint = findCutPoint(entries, boundaryStart, boundaryEnd, settings.keepRecentTokens); - - const historyEnd = cutPoint.isSplitTurn ? 
cutPoint.turnStartIndex : cutPoint.firstKeptEntryIndex; - - // Messages to summarize (will be discarded after summary) - const messagesToSummarize: AppMessage[] = []; - for (let i = boundaryStart; i < historyEnd; i++) { - const entry = entries[i]; - if (entry.type === "message") { - messagesToSummarize.push(entry.message); - } - } - - // Messages to keep (recent turns, kept after summary) - const messagesToKeep: AppMessage[] = []; - for (let i = cutPoint.firstKeptEntryIndex; i < boundaryEnd; i++) { - const entry = entries[i]; - if (entry.type === "message") { - messagesToKeep.push(entry.message); - } - } - - return { cutPoint, messagesToSummarize, messagesToKeep, tokensBefore, boundaryStart }; -} - -// ============================================================================ -// Main compaction function -// ============================================================================ - -const TURN_PREFIX_SUMMARIZATION_PROMPT = `You are performing a CONTEXT CHECKPOINT COMPACTION for a split turn. -This is the PREFIX of a turn that was too large to keep in full. The SUFFIX (recent work) is being kept. - -Create a handoff summary that captures: -- What the user originally asked for in this turn -- Key decisions and progress made early in this turn -- Important context needed to understand the kept suffix - -Be concise. Focus on information needed to understand the retained recent work.`; - -/** - * Calculate compaction and generate summary. - * Returns the CompactionEntry to append to the session file. 
- * - * @param entries - All session entries - * @param model - Model to use for summarization - * @param settings - Compaction settings - * @param apiKey - API key for LLM - * @param signal - Optional abort signal - * @param customInstructions - Optional custom focus for the summary - */ -export async function compact( - entries: SessionEntry[], - model: Model, - settings: CompactionSettings, - apiKey: string, - signal?: AbortSignal, - customInstructions?: string, -): Promise { - // Don't compact if the last entry is already a compaction - if (entries.length > 0 && entries[entries.length - 1].type === "compaction") { - throw new Error("Already compacted"); - } - - // Find previous compaction boundary - let prevCompactionIndex = -1; - for (let i = entries.length - 1; i >= 0; i--) { - if (entries[i].type === "compaction") { - prevCompactionIndex = i; - break; - } - } - const boundaryStart = prevCompactionIndex + 1; - const boundaryEnd = entries.length; - - // Get token count before compaction - const lastUsage = getLastAssistantUsage(entries); - const tokensBefore = lastUsage ? calculateContextTokens(lastUsage) : 0; - - // Find cut point (entry index) within the valid range - const cutResult = findCutPoint(entries, boundaryStart, boundaryEnd, settings.keepRecentTokens); - - // Extract messages for history summary (before the turn that contains the cut point) - const historyEnd = cutResult.isSplitTurn ? 
cutResult.turnStartIndex : cutResult.firstKeptEntryIndex; - const historyMessages: AppMessage[] = []; - for (let i = boundaryStart; i < historyEnd; i++) { - const entry = entries[i]; - if (entry.type === "message") { - historyMessages.push(entry.message); - } - } - - // Include previous summary if there was a compaction - if (prevCompactionIndex >= 0) { - const prevCompaction = entries[prevCompactionIndex] as CompactionEntry; - historyMessages.unshift({ - role: "user", - content: `Previous session summary:\n${prevCompaction.summary}`, - timestamp: Date.now(), - }); - } - - // Extract messages for turn prefix summary (if splitting a turn) - const turnPrefixMessages: AppMessage[] = []; - if (cutResult.isSplitTurn) { - for (let i = cutResult.turnStartIndex; i < cutResult.firstKeptEntryIndex; i++) { - const entry = entries[i]; - if (entry.type === "message") { - turnPrefixMessages.push(entry.message); - } - } - } - - // Generate summaries (can be parallel if both needed) and merge into one - let summary: string; - - if (cutResult.isSplitTurn && turnPrefixMessages.length > 0) { - // Generate both summaries in parallel - const [historyResult, turnPrefixResult] = await Promise.all([ - historyMessages.length > 0 - ? 
generateSummary(historyMessages, model, settings.reserveTokens, apiKey, signal, customInstructions) - : Promise.resolve("No prior history."), - generateTurnPrefixSummary(turnPrefixMessages, model, settings.reserveTokens, apiKey, signal), - ]); - // Merge into single summary - summary = `${historyResult}\n\n---\n\n**Turn Context (split turn):**\n\n${turnPrefixResult}`; - } else { - // Just generate history summary - summary = await generateSummary( - historyMessages, - model, - settings.reserveTokens, - apiKey, - signal, - customInstructions, - ); - } - - return { - type: "compaction", - timestamp: new Date().toISOString(), - summary, - firstKeptEntryIndex: cutResult.firstKeptEntryIndex, - tokensBefore, - }; -} - -/** - * Generate a summary for a turn prefix (when splitting a turn). - */ -async function generateTurnPrefixSummary( - messages: AppMessage[], - model: Model, - reserveTokens: number, - apiKey: string, - signal?: AbortSignal, -): Promise { - const maxTokens = Math.floor(0.5 * reserveTokens); // Smaller budget for turn prefix - - const transformedMessages = messageTransformer(messages); - const summarizationMessages = [ - ...transformedMessages, - { - role: "user" as const, - content: [{ type: "text" as const, text: TURN_PREFIX_SUMMARIZATION_PROMPT }], - timestamp: Date.now(), - }, - ]; - - const response = await complete(model, { messages: summarizationMessages }, { maxTokens, signal, apiKey }); - - return response.content - .filter((c): c is { type: "text"; text: string } => c.type === "text") - .map((c) => c.text) - .join("\n"); -} diff --git a/packages/coding-agent/src/core/compaction/branch-summarization.ts b/packages/coding-agent/src/core/compaction/branch-summarization.ts new file mode 100644 index 00000000..8bca45ff --- /dev/null +++ b/packages/coding-agent/src/core/compaction/branch-summarization.ts @@ -0,0 +1,343 @@ +/** + * Branch summarization for tree navigation. 
+ * + * When navigating to a different point in the session tree, this generates + * a summary of the branch being left so context isn't lost. + */ + +import type { AgentMessage } from "@mariozechner/pi-agent-core"; +import type { Model } from "@mariozechner/pi-ai"; +import { completeSimple } from "@mariozechner/pi-ai"; +import { + convertToLlm, + createBranchSummaryMessage, + createCompactionSummaryMessage, + createHookMessage, +} from "../messages.js"; +import type { ReadonlySessionManager, SessionEntry } from "../session-manager.js"; +import { estimateTokens } from "./compaction.js"; +import { + computeFileLists, + createFileOps, + extractFileOpsFromMessage, + type FileOperations, + formatFileOperations, + SUMMARIZATION_SYSTEM_PROMPT, + serializeConversation, +} from "./utils.js"; + +// ============================================================================ +// Types +// ============================================================================ + +export interface BranchSummaryResult { + summary?: string; + readFiles?: string[]; + modifiedFiles?: string[]; + aborted?: boolean; + error?: string; +} + +/** Details stored in BranchSummaryEntry.details for file tracking */ +export interface BranchSummaryDetails { + readFiles: string[]; + modifiedFiles: string[]; +} + +export type { FileOperations } from "./utils.js"; + +export interface BranchPreparation { + /** Messages extracted for summarization, in chronological order */ + messages: AgentMessage[]; + /** File operations extracted from tool calls */ + fileOps: FileOperations; + /** Total estimated tokens in messages */ + totalTokens: number; +} + +export interface CollectEntriesResult { + /** Entries to summarize, in chronological order */ + entries: SessionEntry[]; + /** Common ancestor between old and new position, if any */ + commonAncestorId: string | null; +} + +export interface GenerateBranchSummaryOptions { + /** Model to use for summarization */ + model: Model; + /** API key for the model */ + 
apiKey: string; + /** Abort signal for cancellation */ + signal: AbortSignal; + /** Optional custom instructions for summarization */ + customInstructions?: string; + /** Tokens reserved for prompt + LLM response (default 16384) */ + reserveTokens?: number; +} + +// ============================================================================ +// Entry Collection +// ============================================================================ + +/** + * Collect entries that should be summarized when navigating from one position to another. + * + * Walks from oldLeafId back to the common ancestor with targetId, collecting entries + * along the way. Does NOT stop at compaction boundaries - those are included and their + * summaries become context. + * + * @param session - Session manager (read-only access) + * @param oldLeafId - Current position (where we're navigating from) + * @param targetId - Target position (where we're navigating to) + * @returns Entries to summarize and the common ancestor + */ +export function collectEntriesForBranchSummary( + session: ReadonlySessionManager, + oldLeafId: string | null, + targetId: string, +): CollectEntriesResult { + // If no old position, nothing to summarize + if (!oldLeafId) { + return { entries: [], commonAncestorId: null }; + } + + // Find common ancestor (deepest node that's on both paths) + const oldPath = new Set(session.getPath(oldLeafId).map((e) => e.id)); + const targetPath = session.getPath(targetId); + + // targetPath is root-first, so iterate backwards to find deepest common ancestor + let commonAncestorId: string | null = null; + for (let i = targetPath.length - 1; i >= 0; i--) { + if (oldPath.has(targetPath[i].id)) { + commonAncestorId = targetPath[i].id; + break; + } + } + + // Collect entries from old leaf back to common ancestor + const entries: SessionEntry[] = []; + let current: string | null = oldLeafId; + + while (current && current !== commonAncestorId) { + const entry = session.getEntry(current); + if 
(!entry) break; + entries.push(entry); + current = entry.parentId; + } + + // Reverse to get chronological order + entries.reverse(); + + return { entries, commonAncestorId }; +} + +// ============================================================================ +// Entry to Message Conversion +// ============================================================================ + +/** + * Extract AgentMessage from a session entry. + * Similar to getMessageFromEntry in compaction.ts but also handles compaction entries. + */ +function getMessageFromEntry(entry: SessionEntry): AgentMessage | undefined { + switch (entry.type) { + case "message": + // Skip tool results - context is in assistant's tool call + if (entry.message.role === "toolResult") return undefined; + return entry.message; + + case "custom_message": + return createHookMessage(entry.customType, entry.content, entry.display, entry.details, entry.timestamp); + + case "branch_summary": + return createBranchSummaryMessage(entry.summary, entry.fromId, entry.timestamp); + + case "compaction": + return createCompactionSummaryMessage(entry.summary, entry.tokensBefore, entry.timestamp); + + // These don't contribute to conversation content + case "thinking_level_change": + case "model_change": + case "custom": + case "label": + return undefined; + } +} + +/** + * Prepare entries for summarization with token budget. + * + * Walks entries from NEWEST to OLDEST, adding messages until we hit the token budget. + * This ensures we keep the most recent context when the branch is too long. 
+ * + * Also collects file operations from: + * - Tool calls in assistant messages + * - Existing branch_summary entries' details (for cumulative tracking) + * + * @param entries - Entries in chronological order + * @param tokenBudget - Maximum tokens to include (0 = no limit) + */ +export function prepareBranchEntries(entries: SessionEntry[], tokenBudget: number = 0): BranchPreparation { + const messages: AgentMessage[] = []; + const fileOps = createFileOps(); + let totalTokens = 0; + + // First pass: collect file ops from ALL entries (even if they don't fit in token budget) + // This ensures we capture cumulative file tracking from nested branch summaries + // Only extract from pi-generated summaries (fromHook !== true), not hook-generated ones + for (const entry of entries) { + if (entry.type === "branch_summary" && !entry.fromHook && entry.details) { + const details = entry.details as BranchSummaryDetails; + if (Array.isArray(details.readFiles)) { + for (const f of details.readFiles) fileOps.read.add(f); + } + if (Array.isArray(details.modifiedFiles)) { + // Modified files go into both edited and written for proper deduplication + for (const f of details.modifiedFiles) { + fileOps.edited.add(f); + } + } + } + } + + // Second pass: walk from newest to oldest, adding messages until token budget + for (let i = entries.length - 1; i >= 0; i--) { + const entry = entries[i]; + const message = getMessageFromEntry(entry); + if (!message) continue; + + // Extract file ops from assistant messages (tool calls) + extractFileOpsFromMessage(message, fileOps); + + const tokens = estimateTokens(message); + + // Check budget before adding + if (tokenBudget > 0 && totalTokens + tokens > tokenBudget) { + // If this is a summary entry, try to fit it anyway as it's important context + if (entry.type === "compaction" || entry.type === "branch_summary") { + if (totalTokens < tokenBudget * 0.9) { + messages.unshift(message); + totalTokens += tokens; + } + } + // Stop - we've hit the 
budget + break; + } + + messages.unshift(message); + totalTokens += tokens; + } + + return { messages, fileOps, totalTokens }; +} + +// ============================================================================ +// Summary Generation +// ============================================================================ + +const BRANCH_SUMMARY_PREAMBLE = `The user explored a different conversation branch before returning here. +Summary of that exploration: + +`; + +const BRANCH_SUMMARY_PROMPT = `Create a structured summary of this conversation branch for context when returning later. + +Use this EXACT format: + +## Goal +[What was the user trying to accomplish in this branch?] + +## Constraints & Preferences +- [Any constraints, preferences, or requirements mentioned] +- [Or "(none)" if none were mentioned] + +## Progress +### Done +- [x] [Completed tasks/changes] + +### In Progress +- [ ] [Work that was started but not finished] + +### Blocked +- [Issues preventing progress, if any] + +## Key Decisions +- **[Decision]**: [Brief rationale] + +## Next Steps +1. [What should happen next to continue this work] + +Keep each section concise. Preserve exact file paths, function names, and error messages.`; + +/** + * Generate a summary of abandoned branch entries. 
+ * + * @param entries - Session entries to summarize (chronological order) + * @param options - Generation options + */ +export async function generateBranchSummary( + entries: SessionEntry[], + options: GenerateBranchSummaryOptions, +): Promise { + const { model, apiKey, signal, customInstructions, reserveTokens = 16384 } = options; + + // Token budget = context window minus reserved space for prompt + response + const contextWindow = model.contextWindow || 128000; + const tokenBudget = contextWindow - reserveTokens; + + const { messages, fileOps } = prepareBranchEntries(entries, tokenBudget); + + if (messages.length === 0) { + return { summary: "No content to summarize" }; + } + + // Transform to LLM-compatible messages, then serialize to text + // Serialization prevents the model from treating it as a conversation to continue + const llmMessages = convertToLlm(messages); + const conversationText = serializeConversation(llmMessages); + + // Build prompt + const instructions = customInstructions || BRANCH_SUMMARY_PROMPT; + const promptText = `\n${conversationText}\n\n\n${instructions}`; + + const summarizationMessages = [ + { + role: "user" as const, + content: [{ type: "text" as const, text: promptText }], + timestamp: Date.now(), + }, + ]; + + // Call LLM for summarization + const response = await completeSimple( + model, + { systemPrompt: SUMMARIZATION_SYSTEM_PROMPT, messages: summarizationMessages }, + { apiKey, signal, maxTokens: 2048 }, + ); + + // Check if aborted or errored + if (response.stopReason === "aborted") { + return { aborted: true }; + } + if (response.stopReason === "error") { + return { error: response.errorMessage || "Summarization failed" }; + } + + let summary = response.content + .filter((c): c is { type: "text"; text: string } => c.type === "text") + .map((c) => c.text) + .join("\n"); + + // Prepend preamble to provide context about the branch summary + summary = BRANCH_SUMMARY_PREAMBLE + summary; + + // Compute file lists and append to 
summary + const { readFiles, modifiedFiles } = computeFileLists(fileOps); + summary += formatFileOperations(readFiles, modifiedFiles); + + return { + summary: summary || "No summary generated", + readFiles, + modifiedFiles, + }; +} diff --git a/packages/coding-agent/src/core/compaction/compaction.ts b/packages/coding-agent/src/core/compaction/compaction.ts new file mode 100644 index 00000000..afa99152 --- /dev/null +++ b/packages/coding-agent/src/core/compaction/compaction.ts @@ -0,0 +1,759 @@ +/** + * Context compaction for long sessions. + * + * Pure functions for compaction logic. The session manager handles I/O, + * and after compaction the session is reloaded. + */ + +import type { AgentMessage } from "@mariozechner/pi-agent-core"; +import type { AssistantMessage, Model, Usage } from "@mariozechner/pi-ai"; +import { complete, completeSimple } from "@mariozechner/pi-ai"; +import { convertToLlm, createBranchSummaryMessage, createHookMessage } from "../messages.js"; +import type { CompactionEntry, SessionEntry } from "../session-manager.js"; +import { + computeFileLists, + createFileOps, + extractFileOpsFromMessage, + type FileOperations, + formatFileOperations, + SUMMARIZATION_SYSTEM_PROMPT, + serializeConversation, +} from "./utils.js"; + +// ============================================================================ +// File Operation Tracking +// ============================================================================ + +/** Details stored in CompactionEntry.details for file tracking */ +export interface CompactionDetails { + readFiles: string[]; + modifiedFiles: string[]; +} + +/** + * Extract file operations from messages and previous compaction entries. 
+ */ +function extractFileOperations( + messages: AgentMessage[], + entries: SessionEntry[], + prevCompactionIndex: number, +): FileOperations { + const fileOps = createFileOps(); + + // Collect from previous compaction's details (if pi-generated) + if (prevCompactionIndex >= 0) { + const prevCompaction = entries[prevCompactionIndex] as CompactionEntry; + if (!prevCompaction.fromHook && prevCompaction.details) { + const details = prevCompaction.details as CompactionDetails; + if (Array.isArray(details.readFiles)) { + for (const f of details.readFiles) fileOps.read.add(f); + } + if (Array.isArray(details.modifiedFiles)) { + for (const f of details.modifiedFiles) fileOps.edited.add(f); + } + } + } + + // Extract from tool calls in messages + for (const msg of messages) { + extractFileOpsFromMessage(msg, fileOps); + } + + return fileOps; +} + +// ============================================================================ +// Message Extraction +// ============================================================================ + +/** + * Extract AgentMessage from an entry if it produces one. + * Returns undefined for entries that don't contribute to LLM context. 
+ */ +function getMessageFromEntry(entry: SessionEntry): AgentMessage | undefined { + if (entry.type === "message") { + return entry.message; + } + if (entry.type === "custom_message") { + return createHookMessage(entry.customType, entry.content, entry.display, entry.details, entry.timestamp); + } + if (entry.type === "branch_summary") { + return createBranchSummaryMessage(entry.summary, entry.fromId, entry.timestamp); + } + return undefined; +} + +/** Result from compact() - SessionManager adds uuid/parentUuid when saving */ +export interface CompactionResult { + summary: string; + firstKeptEntryId: string; + tokensBefore: number; + /** Hook-specific data (e.g., ArtifactIndex, version markers for structured compaction) */ + details?: T; +} + +// ============================================================================ +// Types +// ============================================================================ + +export interface CompactionSettings { + enabled: boolean; + reserveTokens: number; + keepRecentTokens: number; +} + +export const DEFAULT_COMPACTION_SETTINGS: CompactionSettings = { + enabled: true, + reserveTokens: 16384, + keepRecentTokens: 20000, +}; + +// ============================================================================ +// Token calculation +// ============================================================================ + +/** + * Calculate total context tokens from usage. + * Uses the native totalTokens field when available, falls back to computing from components. + */ +export function calculateContextTokens(usage: Usage): number { + return usage.totalTokens || usage.input + usage.output + usage.cacheRead + usage.cacheWrite; +} + +/** + * Get usage from an assistant message if available. + * Skips aborted and error messages as they don't have valid usage data. 
+ */ +function getAssistantUsage(msg: AgentMessage): Usage | undefined { + if (msg.role === "assistant" && "usage" in msg) { + const assistantMsg = msg as AssistantMessage; + if (assistantMsg.stopReason !== "aborted" && assistantMsg.stopReason !== "error" && assistantMsg.usage) { + return assistantMsg.usage; + } + } + return undefined; +} + +/** + * Find the last non-aborted assistant message usage from session entries. + */ +export function getLastAssistantUsage(entries: SessionEntry[]): Usage | undefined { + for (let i = entries.length - 1; i >= 0; i--) { + const entry = entries[i]; + if (entry.type === "message") { + const usage = getAssistantUsage(entry.message); + if (usage) return usage; + } + } + return undefined; +} + +/** + * Check if compaction should trigger based on context usage. + */ +export function shouldCompact(contextTokens: number, contextWindow: number, settings: CompactionSettings): boolean { + if (!settings.enabled) return false; + return contextTokens > contextWindow - settings.reserveTokens; +} + +// ============================================================================ +// Cut point detection +// ============================================================================ + +/** + * Estimate token count for a message using chars/4 heuristic. + * This is conservative (overestimates tokens). 
+ */ +export function estimateTokens(message: AgentMessage): number { + let chars = 0; + + switch (message.role) { + case "user": { + const content = (message as { content: string | Array<{ type: string; text?: string }> }).content; + if (typeof content === "string") { + chars = content.length; + } else if (Array.isArray(content)) { + for (const block of content) { + if (block.type === "text" && block.text) { + chars += block.text.length; + } + } + } + return Math.ceil(chars / 4); + } + case "assistant": { + const assistant = message as AssistantMessage; + for (const block of assistant.content) { + if (block.type === "text") { + chars += block.text.length; + } else if (block.type === "thinking") { + chars += block.thinking.length; + } else if (block.type === "toolCall") { + chars += block.name.length + JSON.stringify(block.arguments).length; + } + } + return Math.ceil(chars / 4); + } + case "hookMessage": + case "toolResult": { + if (typeof message.content === "string") { + chars = message.content.length; + } else { + for (const block of message.content) { + if (block.type === "text" && block.text) { + chars += block.text.length; + } + if (block.type === "image") { + chars += 4800; // Estimate images as 4000 chars, or 1200 tokens + } + } + } + return Math.ceil(chars / 4); + } + case "bashExecution": { + chars = message.command.length + message.output.length; + return Math.ceil(chars / 4); + } + case "branchSummary": + case "compactionSummary": { + chars = message.summary.length; + return Math.ceil(chars / 4); + } + } + + return 0; +} + +/** + * Find valid cut points: indices of user, assistant, custom, or bashExecution messages. + * Never cut at tool results (they must follow their tool call). + * When we cut at an assistant message with tool calls, its tool results follow it + * and will be kept. + * BashExecutionMessage is treated like a user message (user-initiated context). 
+ */ +function findValidCutPoints(entries: SessionEntry[], startIndex: number, endIndex: number): number[] { + const cutPoints: number[] = []; + for (let i = startIndex; i < endIndex; i++) { + const entry = entries[i]; + switch (entry.type) { + case "message": { + const role = entry.message.role; + switch (role) { + case "bashExecution": + case "hookMessage": + case "branchSummary": + case "compactionSummary": + case "user": + case "assistant": + cutPoints.push(i); + break; + case "toolResult": + break; + } + break; + } + case "thinking_level_change": + case "model_change": + case "compaction": + case "branch_summary": + case "custom": + case "custom_message": + case "label": + } + // branch_summary and custom_message are user-role messages, valid cut points + if (entry.type === "branch_summary" || entry.type === "custom_message") { + cutPoints.push(i); + } + } + return cutPoints; +} + +/** + * Find the user message (or bashExecution) that starts the turn containing the given entry index. + * Returns -1 if no turn start found before the index. + * BashExecutionMessage is treated like a user message for turn boundaries. 
+ */ +export function findTurnStartIndex(entries: SessionEntry[], entryIndex: number, startIndex: number): number { + for (let i = entryIndex; i >= startIndex; i--) { + const entry = entries[i]; + // branch_summary and custom_message are user-role messages, can start a turn + if (entry.type === "branch_summary" || entry.type === "custom_message") { + return i; + } + if (entry.type === "message") { + const role = entry.message.role; + if (role === "user" || role === "bashExecution") { + return i; + } + } + } + return -1; +} + +export interface CutPointResult { + /** Index of first entry to keep */ + firstKeptEntryIndex: number; + /** Index of user message that starts the turn being split, or -1 if not splitting */ + turnStartIndex: number; + /** Whether this cut splits a turn (cut point is not a user message) */ + isSplitTurn: boolean; +} + +/** + * Find the cut point in session entries that keeps approximately `keepRecentTokens`. + * + * Algorithm: Walk backwards from newest, accumulating estimated message sizes. + * Stop when we've accumulated >= keepRecentTokens. Cut at that point. + * + * Can cut at user OR assistant messages (never tool results). When cutting at an + * assistant message with tool calls, its tool results come after and will be kept. + * + * Returns CutPointResult with: + * - firstKeptEntryIndex: the entry index to start keeping from + * - turnStartIndex: if cutting mid-turn, the user message that started that turn + * - isSplitTurn: whether we're cutting in the middle of a turn + * + * Only considers entries between `startIndex` and `endIndex` (exclusive). 
+ */ +export function findCutPoint( + entries: SessionEntry[], + startIndex: number, + endIndex: number, + keepRecentTokens: number, +): CutPointResult { + const cutPoints = findValidCutPoints(entries, startIndex, endIndex); + + if (cutPoints.length === 0) { + return { firstKeptEntryIndex: startIndex, turnStartIndex: -1, isSplitTurn: false }; + } + + // Walk backwards from newest, accumulating estimated message sizes + let accumulatedTokens = 0; + let cutIndex = cutPoints[0]; // Default: keep from first message (not header) + + for (let i = endIndex - 1; i >= startIndex; i--) { + const entry = entries[i]; + if (entry.type !== "message") continue; + + // Estimate this message's size + const messageTokens = estimateTokens(entry.message); + accumulatedTokens += messageTokens; + + // Check if we've exceeded the budget + if (accumulatedTokens >= keepRecentTokens) { + // Find the closest valid cut point at or after this entry + for (let c = 0; c < cutPoints.length; c++) { + if (cutPoints[c] >= i) { + cutIndex = cutPoints[c]; + break; + } + } + break; + } + } + + // Scan backwards from cutIndex to include any non-message entries (bash, settings, etc.) + while (cutIndex > startIndex) { + const prevEntry = entries[cutIndex - 1]; + // Stop at session header or compaction boundaries + if (prevEntry.type === "compaction") { + break; + } + if (prevEntry.type === "message") { + // Stop if we hit any message + break; + } + // Include this non-message entry (bash, settings change, etc.) + cutIndex--; + } + + // Determine if this is a split turn + const cutEntry = entries[cutIndex]; + const isUserMessage = cutEntry.type === "message" && cutEntry.message.role === "user"; + const turnStartIndex = isUserMessage ? 
-1 : findTurnStartIndex(entries, cutIndex, startIndex); + + return { + firstKeptEntryIndex: cutIndex, + turnStartIndex, + isSplitTurn: !isUserMessage && turnStartIndex !== -1, + }; +} + +// ============================================================================ +// Summarization +// ============================================================================ + +const SUMMARIZATION_PROMPT = `The messages above are a conversation to summarize. Create a structured context checkpoint summary that another LLM will use to continue the work. + +Use this EXACT format: + +## Goal +[What is the user trying to accomplish? Can be multiple items if the session covers different tasks.] + +## Constraints & Preferences +- [Any constraints, preferences, or requirements mentioned by user] +- [Or "(none)" if none were mentioned] + +## Progress +### Done +- [x] [Completed tasks/changes] + +### In Progress +- [ ] [Current work] + +### Blocked +- [Issues preventing progress, if any] + +## Key Decisions +- **[Decision]**: [Brief rationale] + +## Next Steps +1. [Ordered list of what should happen next] + +## Critical Context +- [Any data, examples, or references needed to continue] +- [Or "(none)" if not applicable] + +Keep each section concise. Preserve exact file paths, function names, and error messages.`; + +const UPDATE_SUMMARIZATION_PROMPT = `The messages above are NEW conversation messages to incorporate into the existing summary provided in tags. + +Update the existing structured summary with new information. 
RULES: +- PRESERVE all existing information from the previous summary +- ADD new progress, decisions, and context from the new messages +- UPDATE the Progress section: move items from "In Progress" to "Done" when completed +- UPDATE "Next Steps" based on what was accomplished +- PRESERVE exact file paths, function names, and error messages +- If something is no longer relevant, you may remove it + +Use this EXACT format: + +## Goal +[Preserve existing goals, add new ones if the task expanded] + +## Constraints & Preferences +- [Preserve existing, add new ones discovered] + +## Progress +### Done +- [x] [Include previously done items AND newly completed items] + +### In Progress +- [ ] [Current work - update based on progress] + +### Blocked +- [Current blockers - remove if resolved] + +## Key Decisions +- **[Decision]**: [Brief rationale] (preserve all previous, add new) + +## Next Steps +1. [Update based on current state] + +## Critical Context +- [Preserve important context, add new if needed] + +Keep each section concise. Preserve exact file paths, function names, and error messages.`; + +/** + * Generate a summary of the conversation using the LLM. + * If previousSummary is provided, uses the update prompt to merge. + */ +export async function generateSummary( + currentMessages: AgentMessage[], + model: Model, + reserveTokens: number, + apiKey: string, + signal?: AbortSignal, + customInstructions?: string, + previousSummary?: string, +): Promise { + const maxTokens = Math.floor(0.8 * reserveTokens); + + // Use update prompt if we have a previous summary, otherwise initial prompt + let basePrompt = previousSummary ? UPDATE_SUMMARIZATION_PROMPT : SUMMARIZATION_PROMPT; + if (customInstructions) { + basePrompt = `${basePrompt}\n\nAdditional focus: ${customInstructions}`; + } + + // Serialize conversation to text so model doesn't try to continue it + // Convert to LLM messages first (handles custom types like bashExecution, hookMessage, etc.) 
+ const llmMessages = convertToLlm(currentMessages); + const conversationText = serializeConversation(llmMessages); + + // Build the prompt with conversation wrapped in tags + let promptText = `\n${conversationText}\n\n\n`; + if (previousSummary) { + promptText += `\n${previousSummary}\n\n\n`; + } + promptText += basePrompt; + + const summarizationMessages = [ + { + role: "user" as const, + content: [{ type: "text" as const, text: promptText }], + timestamp: Date.now(), + }, + ]; + + const response = await completeSimple( + model, + { systemPrompt: SUMMARIZATION_SYSTEM_PROMPT, messages: summarizationMessages }, + { maxTokens, signal, apiKey, reasoning: "high" }, + ); + + if (response.stopReason === "error") { + throw new Error(`Summarization failed: ${response.errorMessage || "Unknown error"}`); + } + + const textContent = response.content + .filter((c): c is { type: "text"; text: string } => c.type === "text") + .map((c) => c.text) + .join("\n"); + + return textContent; +} + +// ============================================================================ +// Compaction Preparation (for hooks) +// ============================================================================ + +export interface CompactionPreparation { + cutPoint: CutPointResult; + /** UUID of first entry to keep */ + firstKeptEntryId: string; + /** Messages that will be summarized and discarded */ + messagesToSummarize: AgentMessage[]; + /** Messages that will be kept after the summary (recent turns) */ + messagesToKeep: AgentMessage[]; + tokensBefore: number; + boundaryStart: number; +} + +export function prepareCompaction( + entries: SessionEntry[], + settings: CompactionSettings, +): CompactionPreparation | undefined { + if (entries.length > 0 && entries[entries.length - 1].type === "compaction") { + return undefined; + } + + let prevCompactionIndex = -1; + for (let i = entries.length - 1; i >= 0; i--) { + if (entries[i].type === "compaction") { + prevCompactionIndex = i; + break; + } + } + const 
boundaryStart = prevCompactionIndex + 1; + const boundaryEnd = entries.length; + + const lastUsage = getLastAssistantUsage(entries); + const tokensBefore = lastUsage ? calculateContextTokens(lastUsage) : 0; + + const cutPoint = findCutPoint(entries, boundaryStart, boundaryEnd, settings.keepRecentTokens); + + // Get UUID of first kept entry + const firstKeptEntry = entries[cutPoint.firstKeptEntryIndex]; + if (!firstKeptEntry?.id) { + return undefined; // Session needs migration + } + const firstKeptEntryId = firstKeptEntry.id; + + const historyEnd = cutPoint.isSplitTurn ? cutPoint.turnStartIndex : cutPoint.firstKeptEntryIndex; + + // Messages to summarize (will be discarded after summary) + const messagesToSummarize: AgentMessage[] = []; + for (let i = boundaryStart; i < historyEnd; i++) { + const msg = getMessageFromEntry(entries[i]); + if (msg) messagesToSummarize.push(msg); + } + + // Messages to keep (recent turns, kept after summary) + const messagesToKeep: AgentMessage[] = []; + for (let i = cutPoint.firstKeptEntryIndex; i < boundaryEnd; i++) { + const msg = getMessageFromEntry(entries[i]); + if (msg) messagesToKeep.push(msg); + } + + return { cutPoint, firstKeptEntryId, messagesToSummarize, messagesToKeep, tokensBefore, boundaryStart }; +} + +// ============================================================================ +// Main compaction function +// ============================================================================ + +const TURN_PREFIX_SUMMARIZATION_PROMPT = `This is the PREFIX of a turn that was too large to keep. The SUFFIX (recent work) is retained. + +Summarize the prefix to provide context for the retained suffix: + +## Original Request +[What did the user ask for in this turn?] + +## Early Progress +- [Key decisions and work done in the prefix] + +## Context for Suffix +- [Information needed to understand the retained recent work] + +Be concise. 
Focus on what's needed to understand the kept suffix.`; + +/** + * Calculate compaction and generate summary. + * Returns CompactionResult - SessionManager adds uuid/parentUuid when saving. + * + * @param entries - All session entries (must have uuid fields for v2) + * @param model - Model to use for summarization + * @param settings - Compaction settings + * @param apiKey - API key for LLM + * @param signal - Optional abort signal + * @param customInstructions - Optional custom focus for the summary + */ +export async function compact( + entries: SessionEntry[], + model: Model, + settings: CompactionSettings, + apiKey: string, + signal?: AbortSignal, + customInstructions?: string, +): Promise { + // Don't compact if the last entry is already a compaction + if (entries.length > 0 && entries[entries.length - 1].type === "compaction") { + throw new Error("Already compacted"); + } + + // Find previous compaction boundary + let prevCompactionIndex = -1; + for (let i = entries.length - 1; i >= 0; i--) { + if (entries[i].type === "compaction") { + prevCompactionIndex = i; + break; + } + } + const boundaryStart = prevCompactionIndex + 1; + const boundaryEnd = entries.length; + + // Get token count before compaction + const lastUsage = getLastAssistantUsage(entries); + const tokensBefore = lastUsage ? calculateContextTokens(lastUsage) : 0; + + // Find cut point (entry index) within the valid range + const cutResult = findCutPoint(entries, boundaryStart, boundaryEnd, settings.keepRecentTokens); + + // Extract messages for history summary (before the turn that contains the cut point) + const historyEnd = cutResult.isSplitTurn ? 
cutResult.turnStartIndex : cutResult.firstKeptEntryIndex; + const historyMessages: AgentMessage[] = []; + for (let i = boundaryStart; i < historyEnd; i++) { + const msg = getMessageFromEntry(entries[i]); + if (msg) historyMessages.push(msg); + } + + // Get previous summary for iterative update (if not from hook) + let previousSummary: string | undefined; + if (prevCompactionIndex >= 0) { + const prevCompaction = entries[prevCompactionIndex] as CompactionEntry; + previousSummary = prevCompaction.summary; + } + + // Extract file operations from messages and previous compaction + const fileOps = extractFileOperations(historyMessages, entries, prevCompactionIndex); + + // Extract messages for turn prefix summary (if splitting a turn) + const turnPrefixMessages: AgentMessage[] = []; + if (cutResult.isSplitTurn) { + for (let i = cutResult.turnStartIndex; i < cutResult.firstKeptEntryIndex; i++) { + const msg = getMessageFromEntry(entries[i]); + if (msg) turnPrefixMessages.push(msg); + } + // Also extract file ops from turn prefix + for (const msg of turnPrefixMessages) { + extractFileOpsFromMessage(msg, fileOps); + } + } + + // Generate summaries (can be parallel if both needed) and merge into one + let summary: string; + + if (cutResult.isSplitTurn && turnPrefixMessages.length > 0) { + // Generate both summaries in parallel + const [historyResult, turnPrefixResult] = await Promise.all([ + historyMessages.length > 0 + ? 
generateSummary( + historyMessages, + model, + settings.reserveTokens, + apiKey, + signal, + customInstructions, + previousSummary, + ) + : Promise.resolve("No prior history."), + generateTurnPrefixSummary(turnPrefixMessages, model, settings.reserveTokens, apiKey, signal), + ]); + // Merge into single summary + summary = `${historyResult}\n\n---\n\n**Turn Context (split turn):**\n\n${turnPrefixResult}`; + } else { + // Just generate history summary + summary = await generateSummary( + historyMessages, + model, + settings.reserveTokens, + apiKey, + signal, + customInstructions, + previousSummary, + ); + } + + // Compute file lists and append to summary + const { readFiles, modifiedFiles } = computeFileLists(fileOps); + summary += formatFileOperations(readFiles, modifiedFiles); + + // Get UUID of first kept entry + const firstKeptEntry = entries[cutResult.firstKeptEntryIndex]; + const firstKeptEntryId = firstKeptEntry.id; + if (!firstKeptEntryId) { + throw new Error("First kept entry has no UUID - session may need migration"); + } + + return { + summary, + firstKeptEntryId, + tokensBefore, + details: { readFiles, modifiedFiles } as CompactionDetails, + }; +} + +/** + * Generate a summary for a turn prefix (when splitting a turn). 
+ */ +async function generateTurnPrefixSummary( + messages: AgentMessage[], + model: Model, + reserveTokens: number, + apiKey: string, + signal?: AbortSignal, +): Promise { + const maxTokens = Math.floor(0.5 * reserveTokens); // Smaller budget for turn prefix + + const transformedMessages = convertToLlm(messages); + const summarizationMessages = [ + ...transformedMessages, + { + role: "user" as const, + content: [{ type: "text" as const, text: TURN_PREFIX_SUMMARIZATION_PROMPT }], + timestamp: Date.now(), + }, + ]; + + const response = await complete(model, { messages: summarizationMessages }, { maxTokens, signal, apiKey }); + + if (response.stopReason === "error") { + throw new Error(`Turn prefix summarization failed: ${response.errorMessage || "Unknown error"}`); + } + + return response.content + .filter((c): c is { type: "text"; text: string } => c.type === "text") + .map((c) => c.text) + .join("\n"); +} diff --git a/packages/coding-agent/src/core/compaction/index.ts b/packages/coding-agent/src/core/compaction/index.ts new file mode 100644 index 00000000..d8c92a67 --- /dev/null +++ b/packages/coding-agent/src/core/compaction/index.ts @@ -0,0 +1,7 @@ +/** + * Compaction and summarization utilities. + */ + +export * from "./branch-summarization.js"; +export * from "./compaction.js"; +export * from "./utils.js"; diff --git a/packages/coding-agent/src/core/compaction/utils.ts b/packages/coding-agent/src/core/compaction/utils.ts new file mode 100644 index 00000000..9c8f46bc --- /dev/null +++ b/packages/coding-agent/src/core/compaction/utils.ts @@ -0,0 +1,154 @@ +/** + * Shared utilities for compaction and branch summarization. 
+ */ + +import type { AgentMessage } from "@mariozechner/pi-agent-core"; +import type { Message } from "@mariozechner/pi-ai"; + +// ============================================================================ +// File Operation Tracking +// ============================================================================ + +export interface FileOperations { + read: Set; + written: Set; + edited: Set; +} + +export function createFileOps(): FileOperations { + return { + read: new Set(), + written: new Set(), + edited: new Set(), + }; +} + +/** + * Extract file operations from tool calls in an assistant message. + */ +export function extractFileOpsFromMessage(message: AgentMessage, fileOps: FileOperations): void { + if (message.role !== "assistant") return; + if (!("content" in message) || !Array.isArray(message.content)) return; + + for (const block of message.content) { + if (typeof block !== "object" || block === null) continue; + if (!("type" in block) || block.type !== "toolCall") continue; + if (!("arguments" in block) || !("name" in block)) continue; + + const args = block.arguments as Record | undefined; + if (!args) continue; + + const path = typeof args.path === "string" ? args.path : undefined; + if (!path) continue; + + switch (block.name) { + case "read": + fileOps.read.add(path); + break; + case "write": + fileOps.written.add(path); + break; + case "edit": + fileOps.edited.add(path); + break; + } + } +} + +/** + * Compute final file lists from file operations. + * Returns readFiles (files only read, not modified) and modifiedFiles. + */ +export function computeFileLists(fileOps: FileOperations): { readFiles: string[]; modifiedFiles: string[] } { + const modified = new Set([...fileOps.edited, ...fileOps.written]); + const readOnly = [...fileOps.read].filter((f) => !modified.has(f)).sort(); + const modifiedFiles = [...modified].sort(); + return { readFiles: readOnly, modifiedFiles }; +} + +/** + * Format file operations as XML tags for summary. 
+ */ +export function formatFileOperations(readFiles: string[], modifiedFiles: string[]): string { + const sections: string[] = []; + if (readFiles.length > 0) { + sections.push(`\n${readFiles.join("\n")}\n`); + } + if (modifiedFiles.length > 0) { + sections.push(`\n${modifiedFiles.join("\n")}\n`); + } + if (sections.length === 0) return ""; + return `\n\n${sections.join("\n\n")}`; +} + +// ============================================================================ +// Message Serialization +// ============================================================================ + +/** + * Serialize LLM messages to text for summarization. + * This prevents the model from treating it as a conversation to continue. + * Call convertToLlm() first to handle custom message types. + */ +export function serializeConversation(messages: Message[]): string { + const parts: string[] = []; + + for (const msg of messages) { + if (msg.role === "user") { + const content = + typeof msg.content === "string" + ? msg.content + : msg.content + .filter((c): c is { type: "text"; text: string } => c.type === "text") + .map((c) => c.text) + .join(""); + if (content) parts.push(`[User]: ${content}`); + } else if (msg.role === "assistant") { + const textParts: string[] = []; + const thinkingParts: string[] = []; + const toolCalls: string[] = []; + + for (const block of msg.content) { + if (block.type === "text") { + textParts.push(block.text); + } else if (block.type === "thinking") { + thinkingParts.push(block.thinking); + } else if (block.type === "toolCall") { + const args = block.arguments as Record; + const argsStr = Object.entries(args) + .map(([k, v]) => `${k}=${JSON.stringify(v)}`) + .join(", "); + toolCalls.push(`${block.name}(${argsStr})`); + } + } + + if (thinkingParts.length > 0) { + parts.push(`[Assistant thinking]: ${thinkingParts.join("\n")}`); + } + if (textParts.length > 0) { + parts.push(`[Assistant]: ${textParts.join("\n")}`); + } + if (toolCalls.length > 0) { + 
parts.push(`[Assistant tool calls]: ${toolCalls.join("; ")}`); + } + } else if (msg.role === "toolResult") { + const content = msg.content + .filter((c): c is { type: "text"; text: string } => c.type === "text") + .map((c) => c.text) + .join(""); + if (content) { + parts.push(`[Tool result]: ${content}`); + } + } + } + + return parts.join("\n\n"); +} + +// ============================================================================ +// Summarization System Prompt +// ============================================================================ + +export const SUMMARIZATION_SYSTEM_PROMPT = `You are a context summarization assistant. Your task is to read a conversation between a user and an AI coding assistant, then produce a structured summary following the exact format specified. + +Do NOT continue the conversation. Do NOT respond to any questions in the conversation. ONLY output the structured summary.`; diff --git a/packages/coding-agent/src/core/custom-tools/loader.ts b/packages/coding-agent/src/core/custom-tools/loader.ts index 095f41c8..f4480611 100644 --- a/packages/coding-agent/src/core/custom-tools/loader.ts +++ b/packages/coding-agent/src/core/custom-tools/loader.ts @@ -7,7 +7,6 @@ * for custom tools that depend on pi packages. 
*/ -import { spawn } from "node:child_process"; import * as fs from "node:fs"; import { createRequire } from "node:module"; import * as os from "node:os"; @@ -15,15 +14,10 @@ import * as path from "node:path"; import { fileURLToPath } from "node:url"; import { createJiti } from "jiti"; import { getAgentDir, isBunBinary } from "../../config.js"; +import type { ExecOptions } from "../exec.js"; +import { execCommand } from "../exec.js"; import type { HookUIContext } from "../hooks/types.js"; -import type { - CustomToolFactory, - CustomToolsLoadResult, - ExecOptions, - ExecResult, - LoadedCustomTool, - ToolAPI, -} from "./types.js"; +import type { CustomToolFactory, CustomToolsLoadResult, LoadedCustomTool, ToolAPI } from "./types.js"; // Create require function to resolve module paths at runtime const require = createRequire(import.meta.url); @@ -87,97 +81,16 @@ function resolveToolPath(toolPath: string, cwd: string): string { return path.resolve(cwd, expanded); } -/** - * Execute a command and return stdout/stderr/code. - * Supports cancellation via AbortSignal and timeout. 
- */ -async function execCommand(command: string, args: string[], cwd: string, options?: ExecOptions): Promise { - return new Promise((resolve) => { - const proc = spawn(command, args, { - cwd, - shell: false, - stdio: ["ignore", "pipe", "pipe"], - }); - - let stdout = ""; - let stderr = ""; - let killed = false; - let timeoutId: NodeJS.Timeout | undefined; - - const killProcess = () => { - if (!killed) { - killed = true; - proc.kill("SIGTERM"); - // Force kill after 5 seconds if SIGTERM doesn't work - setTimeout(() => { - if (!proc.killed) { - proc.kill("SIGKILL"); - } - }, 5000); - } - }; - - // Handle abort signal - if (options?.signal) { - if (options.signal.aborted) { - killProcess(); - } else { - options.signal.addEventListener("abort", killProcess, { once: true }); - } - } - - // Handle timeout - if (options?.timeout && options.timeout > 0) { - timeoutId = setTimeout(() => { - killProcess(); - }, options.timeout); - } - - proc.stdout.on("data", (data) => { - stdout += data.toString(); - }); - - proc.stderr.on("data", (data) => { - stderr += data.toString(); - }); - - proc.on("close", (code) => { - if (timeoutId) clearTimeout(timeoutId); - if (options?.signal) { - options.signal.removeEventListener("abort", killProcess); - } - resolve({ - stdout, - stderr, - code: code ?? 0, - killed, - }); - }); - - proc.on("error", (err) => { - if (timeoutId) clearTimeout(timeoutId); - if (options?.signal) { - options.signal.removeEventListener("abort", killProcess); - } - resolve({ - stdout, - stderr: stderr || err.message, - code: 1, - killed, - }); - }); - }); -} - /** * Create a no-op UI context for headless modes. 
*/ function createNoOpUIContext(): HookUIContext { return { - select: async () => null, + select: async () => undefined, confirm: async () => false, - input: async () => null, + input: async () => undefined, notify: () => {}, + custom: () => ({ close: () => {}, requestRender: () => {} }), }; } @@ -298,7 +211,8 @@ export async function loadCustomTools( // Shared API object - all tools get the same instance const sharedApi: ToolAPI = { cwd, - exec: (command: string, args: string[], options?: ExecOptions) => execCommand(command, args, cwd, options), + exec: (command: string, args: string[], options?: ExecOptions) => + execCommand(command, args, options?.cwd ?? cwd, options), ui: createNoOpUIContext(), hasUI: false, }; diff --git a/packages/coding-agent/src/core/custom-tools/types.ts b/packages/coding-agent/src/core/custom-tools/types.ts index b5ccc591..28192b05 100644 --- a/packages/coding-agent/src/core/custom-tools/types.ts +++ b/packages/coding-agent/src/core/custom-tools/types.ts @@ -5,10 +5,11 @@ * They can provide custom rendering for tool calls and results in the TUI. 
*/ -import type { AgentTool, AgentToolResult, AgentToolUpdateCallback } from "@mariozechner/pi-ai"; +import type { AgentTool, AgentToolResult, AgentToolUpdateCallback } from "@mariozechner/pi-agent-core"; import type { Component } from "@mariozechner/pi-tui"; import type { Static, TSchema } from "@sinclair/typebox"; import type { Theme } from "../../modes/interactive/theme/theme.js"; +import type { ExecOptions, ExecResult } from "../exec.js"; import type { HookUIContext } from "../hooks/types.js"; import type { SessionEntry } from "../session-manager.js"; @@ -18,20 +19,8 @@ export type ToolUIContext = HookUIContext; /** Re-export for custom tools to use in execute signature */ export type { AgentToolUpdateCallback }; -export interface ExecResult { - stdout: string; - stderr: string; - code: number; - /** True if the process was killed due to signal or timeout */ - killed?: boolean; -} - -export interface ExecOptions { - /** AbortSignal to cancel the process */ - signal?: AbortSignal; - /** Timeout in milliseconds */ - timeout?: number; -} +// Re-export for backward compatibility +export type { ExecOptions, ExecResult } from "../exec.js"; /** API passed to custom tool factory (stable across session changes) */ export interface ToolAPI { @@ -49,12 +38,12 @@ export interface ToolAPI { export interface SessionEvent { /** All session entries (including pre-compaction history) */ entries: SessionEntry[]; - /** Current session file path, or null in --no-session mode */ - sessionFile: string | null; - /** Previous session file path, or null for "start" and "new" */ - previousSessionFile: string | null; + /** Current session file path, or undefined in --no-session mode */ + sessionFile: string | undefined; + /** Previous session file path, or undefined for "start" and "new" */ + previousSessionFile: string | undefined; /** Reason for the session event */ - reason: "start" | "switch" | "branch" | "new"; + reason: "start" | "switch" | "branch" | "new" | "tree"; } /** 
Rendering options passed to renderResult */ diff --git a/packages/coding-agent/src/core/exec.ts b/packages/coding-agent/src/core/exec.ts new file mode 100644 index 00000000..fccf5504 --- /dev/null +++ b/packages/coding-agent/src/core/exec.ts @@ -0,0 +1,104 @@ +/** + * Shared command execution utilities for hooks and custom tools. + */ + +import { spawn } from "node:child_process"; + +/** + * Options for executing shell commands. + */ +export interface ExecOptions { + /** AbortSignal to cancel the command */ + signal?: AbortSignal; + /** Timeout in milliseconds */ + timeout?: number; + /** Working directory */ + cwd?: string; +} + +/** + * Result of executing a shell command. + */ +export interface ExecResult { + stdout: string; + stderr: string; + code: number; + killed: boolean; +} + +/** + * Execute a shell command and return stdout/stderr/code. + * Supports timeout and abort signal. + */ +export async function execCommand( + command: string, + args: string[], + cwd: string, + options?: ExecOptions, +): Promise { + return new Promise((resolve) => { + const proc = spawn(command, args, { + cwd, + shell: false, + stdio: ["ignore", "pipe", "pipe"], + }); + + let stdout = ""; + let stderr = ""; + let killed = false; + let timeoutId: NodeJS.Timeout | undefined; + + const killProcess = () => { + if (!killed) { + killed = true; + proc.kill("SIGTERM"); + // Force kill after 5 seconds if SIGTERM doesn't work + setTimeout(() => { + if (!proc.killed) { + proc.kill("SIGKILL"); + } + }, 5000); + } + }; + + // Handle abort signal + if (options?.signal) { + if (options.signal.aborted) { + killProcess(); + } else { + options.signal.addEventListener("abort", killProcess, { once: true }); + } + } + + // Handle timeout + if (options?.timeout && options.timeout > 0) { + timeoutId = setTimeout(() => { + killProcess(); + }, options.timeout); + } + + proc.stdout?.on("data", (data) => { + stdout += data.toString(); + }); + + proc.stderr?.on("data", (data) => { + stderr += 
data.toString(); + }); + + proc.on("close", (code) => { + if (timeoutId) clearTimeout(timeoutId); + if (options?.signal) { + options.signal.removeEventListener("abort", killProcess); + } + resolve({ stdout, stderr, code: code ?? 0, killed }); + }); + + proc.on("error", (_err) => { + if (timeoutId) clearTimeout(timeoutId); + if (options?.signal) { + options.signal.removeEventListener("abort", killProcess); + } + resolve({ stdout, stderr, code: 1, killed }); + }); + }); +} diff --git a/packages/coding-agent/src/core/export-html.ts b/packages/coding-agent/src/core/export-html.ts index bff53487..c6f3139f 100644 --- a/packages/coding-agent/src/core/export-html.ts +++ b/packages/coding-agent/src/core/export-html.ts @@ -1,4 +1,4 @@ -import type { AgentState } from "@mariozechner/pi-agent-core"; +import type { AgentMessage, AgentState } from "@mariozechner/pi-agent-core"; import type { AssistantMessage, ImageContent, Message, ToolResultMessage, UserMessage } from "@mariozechner/pi-ai"; import { existsSync, readFileSync, writeFileSync } from "fs"; import hljs from "highlight.js"; @@ -7,7 +7,6 @@ import { homedir } from "os"; import * as path from "path"; import { basename } from "path"; import { APP_NAME, getCustomThemesDir, getThemesDir, VERSION } from "../config.js"; -import { type BashExecutionMessage, isBashExecutionMessage } from "./messages.js"; import type { SessionManager } from "./session-manager.js"; // ============================================================================ @@ -122,7 +121,7 @@ function resolveColorValue( } /** Load theme JSON from built-in or custom themes directory. 
*/ -function loadThemeJson(name: string): ThemeJson | null { +function loadThemeJson(name: string): ThemeJson | undefined { // Try built-in themes first const themesDir = getThemesDir(); const builtinPath = path.join(themesDir, `${name}.json`); @@ -130,7 +129,7 @@ function loadThemeJson(name: string): ThemeJson | null { try { return JSON.parse(readFileSync(builtinPath, "utf-8")) as ThemeJson; } catch { - return null; + return undefined; } } @@ -141,11 +140,11 @@ function loadThemeJson(name: string): ThemeJson | null { try { return JSON.parse(readFileSync(customPath, "utf-8")) as ThemeJson; } catch { - return null; + return undefined; } } - return null; + return undefined; } /** Build complete theme colors object, resolving theme JSON values against defaults. */ @@ -821,110 +820,138 @@ function formatToolExecution( return { html, bgColor }; } -function formatMessage(message: Message, toolResultsMap: Map, colors: ThemeColors): string { +function formatMessage( + message: AgentMessage, + toolResultsMap: Map, + colors: ThemeColors, +): string { let html = ""; const timestamp = (message as { timestamp?: number }).timestamp; const timestampHtml = timestamp ? `
${formatTimestamp(timestamp)}
` : ""; - // Handle bash execution messages (user-executed via ! command) - if (isBashExecutionMessage(message)) { - const bashMsg = message as unknown as BashExecutionMessage; - const isError = bashMsg.cancelled || (bashMsg.exitCode !== 0 && bashMsg.exitCode !== null); + switch (message.role) { + case "bashExecution": { + const isError = + message.cancelled || + (message.exitCode !== 0 && message.exitCode !== null && message.exitCode !== undefined); - html += `
`; - html += timestampHtml; - html += `
$ ${escapeHtml(bashMsg.command)}
`; + html += `
`; + html += timestampHtml; + html += `
$ ${escapeHtml(message.command)}
`; - if (bashMsg.output) { - const lines = bashMsg.output.split("\n"); - html += formatExpandableOutput(lines, 10); + if (message.output) { + const lines = message.output.split("\n"); + html += formatExpandableOutput(lines, 10); + } + + if (message.cancelled) { + html += `
(cancelled)
`; + } else if (message.exitCode !== 0 && message.exitCode !== null && message.exitCode !== undefined) { + html += `
(exit ${message.exitCode})
`; + } + + if (message.truncated && message.fullOutputPath) { + html += `
Output truncated. Full output: ${escapeHtml(message.fullOutputPath)}
`; + } + + html += `
`; + break; } + case "user": { + const userMsg = message as UserMessage; + let textContent = ""; + const images: ImageContent[] = []; - if (bashMsg.cancelled) { - html += `
(cancelled)
`; - } else if (bashMsg.exitCode !== 0 && bashMsg.exitCode !== null) { - html += `
(exit ${bashMsg.exitCode})
`; - } - - if (bashMsg.truncated && bashMsg.fullOutputPath) { - html += `
Output truncated. Full output: ${escapeHtml(bashMsg.fullOutputPath)}
`; - } - - html += `
`; - return html; - } - - if (message.role === "user") { - const userMsg = message as UserMessage; - let textContent = ""; - const images: ImageContent[] = []; - - if (typeof userMsg.content === "string") { - textContent = userMsg.content; - } else { - for (const block of userMsg.content) { - if (block.type === "text") { - textContent += block.text; - } else if (block.type === "image") { - images.push(block as ImageContent); + if (typeof userMsg.content === "string") { + textContent = userMsg.content; + } else { + for (const block of userMsg.content) { + if (block.type === "text") { + textContent += block.text; + } else if (block.type === "image") { + images.push(block as ImageContent); + } } } - } - html += `
${timestampHtml}`; + html += `
${timestampHtml}`; - // Render images first - if (images.length > 0) { - html += `
`; - for (const img of images) { - html += `User uploaded image`; + // Render images first + if (images.length > 0) { + html += `
`; + for (const img of images) { + html += `User uploaded image`; + } + html += `
`; } + + // Render text as markdown (server-side) + if (textContent.trim()) { + html += `
${renderMarkdown(textContent)}
`; + } + html += `
`; + break; } + case "assistant": { + html += timestampHtml ? `
${timestampHtml}` : ""; - // Render text as markdown (server-side) - if (textContent.trim()) { - html += `
${renderMarkdown(textContent)}
`; - } - - html += `
`; - } else if (message.role === "assistant") { - const assistantMsg = message as AssistantMessage; - html += timestampHtml ? `
${timestampHtml}` : ""; - - for (const content of assistantMsg.content) { - if (content.type === "text" && content.text.trim()) { - // Render markdown server-side - html += `
${renderMarkdown(content.text)}
`; - } else if (content.type === "thinking" && content.thinking.trim()) { - html += `
${escapeHtml(content.thinking.trim()).replace(/\n/g, "
")}
`; + for (const content of message.content) { + if (content.type === "text" && content.text.trim()) { + // Render markdown server-side + html += `
${renderMarkdown(content.text)}
`; + } else if (content.type === "thinking" && content.thinking.trim()) { + html += `
${escapeHtml(content.thinking.trim()).replace(/\n/g, "
")}
`; + } } - } - for (const content of assistantMsg.content) { - if (content.type === "toolCall") { - const toolResult = toolResultsMap.get(content.id); - const { html: toolHtml, bgColor } = formatToolExecution( - content.name, - content.arguments as Record, - toolResult, - colors, - ); - html += `
${toolHtml}
`; + for (const content of message.content) { + if (content.type === "toolCall") { + const toolResult = toolResultsMap.get(content.id); + const { html: toolHtml, bgColor } = formatToolExecution( + content.name, + content.arguments as Record, + toolResult, + colors, + ); + html += `
${toolHtml}
`; + } } - } - const hasToolCalls = assistantMsg.content.some((c) => c.type === "toolCall"); - if (!hasToolCalls) { - if (assistantMsg.stopReason === "aborted") { - html += '
Aborted
'; - } else if (assistantMsg.stopReason === "error") { - html += `
Error: ${escapeHtml(assistantMsg.errorMessage || "Unknown error")}
`; + const hasToolCalls = message.content.some((c) => c.type === "toolCall"); + if (!hasToolCalls) { + if (message.stopReason === "aborted") { + html += '
Aborted
'; + } else if (message.stopReason === "error") { + html += `
Error: ${escapeHtml(message.errorMessage || "Unknown error")}
`; + } } - } - if (timestampHtml) { - html += "
"; + if (timestampHtml) { + html += "
"; + } + break; + } + case "toolResult": + // Tool results are rendered inline with tool calls + break; + case "hookMessage": + // Hook messages with display:true shown as info boxes + if (message.display) { + const content = typeof message.content === "string" ? message.content : JSON.stringify(message.content); + html += `
${timestampHtml}
[${escapeHtml(message.customType)}]
${renderMarkdown(content)}
`; + } + break; + case "compactionSummary": + // Rendered separately via formatCompaction + break; + case "branchSummary": + // Rendered as compaction-like summary + html += `
Branch Summary
${escapeHtml(message.summary).replace(/\n/g, "
")}
`; + break; + default: { + // Exhaustive check + const _exhaustive: never = message; } } @@ -995,7 +1022,7 @@ function generateHtml(data: ParsedSessionData, filename: string, colors: ThemeCo const lastModelInfo = lastProvider ? `${lastProvider}/${lastModel}` : lastModel; const contextWindow = data.contextWindow || 0; - const contextPercent = contextWindow > 0 ? ((contextTokens / contextWindow) * 100).toFixed(1) : null; + const contextPercent = contextWindow > 0 ? ((contextTokens / contextWindow) * 100).toFixed(1) : undefined; let messagesHtml = ""; for (const event of data.sessionEvents) { @@ -1343,6 +1370,9 @@ export function exportSessionToHtml( const opts: ExportOptions = typeof options === "string" ? { outputPath: options } : options || {}; const sessionFile = sessionManager.getSessionFile(); + if (!sessionFile) { + throw new Error("Cannot export in-memory session to HTML"); + } const content = readFileSync(sessionFile, "utf8"); const data = parseSessionFile(content); diff --git a/packages/coding-agent/src/core/hooks/index.ts b/packages/coding-agent/src/core/hooks/index.ts index db222989..cb0806ed 100644 --- a/packages/coding-agent/src/core/hooks/index.ts +++ b/packages/coding-agent/src/core/hooks/index.ts @@ -1,39 +1,13 @@ -export { discoverAndLoadHooks, type LoadedHook, type LoadHooksResult, loadHooks, type SendHandler } from "./loader.js"; -export { type HookErrorListener, HookRunner } from "./runner.js"; -export { wrapToolsWithHooks, wrapToolWithHooks } from "./tool-wrapper.js"; -export type { - AgentEndEvent, - AgentStartEvent, - BashToolResultEvent, - CustomToolResultEvent, - EditToolResultEvent, - ExecResult, - FindToolResultEvent, - GrepToolResultEvent, - HookAPI, - HookError, - HookEvent, - HookEventContext, - HookFactory, - HookUIContext, - LsToolResultEvent, - ReadToolResultEvent, - SessionEvent, - SessionEventResult, - ToolCallEvent, - ToolCallEventResult, - ToolResultEvent, - ToolResultEventResult, - TurnEndEvent, - TurnStartEvent, - 
WriteToolResultEvent, -} from "./types.js"; +// biome-ignore assist/source/organizeImports: biome is not smart export { - isBashToolResult, - isEditToolResult, - isFindToolResult, - isGrepToolResult, - isLsToolResult, - isReadToolResult, - isWriteToolResult, -} from "./types.js"; + discoverAndLoadHooks, + loadHooks, + type AppendEntryHandler, + type LoadedHook, + type LoadHooksResult, + type SendMessageHandler, +} from "./loader.js"; +export { execCommand, HookRunner, type HookErrorListener } from "./runner.js"; +export { wrapToolsWithHooks, wrapToolWithHooks } from "./tool-wrapper.js"; +export type * from "./types.js"; +export type { ReadonlySessionManager } from "../session-manager.js"; diff --git a/packages/coding-agent/src/core/hooks/loader.ts b/packages/coding-agent/src/core/hooks/loader.ts index 18e09c19..3ac44b27 100644 --- a/packages/coding-agent/src/core/hooks/loader.ts +++ b/packages/coding-agent/src/core/hooks/loader.ts @@ -7,10 +7,11 @@ import { createRequire } from "node:module"; import * as os from "node:os"; import * as path from "node:path"; import { fileURLToPath } from "node:url"; -import type { Attachment } from "@mariozechner/pi-agent-core"; import { createJiti } from "jiti"; import { getAgentDir } from "../../config.js"; -import type { HookAPI, HookFactory } from "./types.js"; +import type { HookMessage } from "../messages.js"; +import { execCommand } from "./runner.js"; +import type { ExecOptions, HookAPI, HookFactory, HookMessageRenderer, RegisteredCommand } from "./types.js"; // Create require function to resolve module paths at runtime const require = createRequire(import.meta.url); @@ -47,9 +48,17 @@ function getAliases(): Record { type HandlerFn = (...args: unknown[]) => Promise; /** - * Send handler type for pi.send(). + * Send message handler type for pi.sendMessage(). 
*/ -export type SendHandler = (text: string, attachments?: Attachment[]) => void; +export type SendMessageHandler = ( + message: Pick, "customType" | "content" | "display" | "details">, + triggerTurn?: boolean, +) => void; + +/** + * Append entry handler type for pi.appendEntry(). + */ +export type AppendEntryHandler = (customType: string, data?: T) => void; /** * Registered handlers for a loaded hook. @@ -61,8 +70,14 @@ export interface LoadedHook { resolvedPath: string; /** Map of event type to handler functions */ handlers: Map; - /** Set the send handler for this hook's pi.send() */ - setSendHandler: (handler: SendHandler) => void; + /** Map of customType to hook message renderer */ + messageRenderers: Map; + /** Map of command name to registered command */ + commands: Map; + /** Set the send message handler for this hook's pi.sendMessage() */ + setSendMessageHandler: (handler: SendMessageHandler) => void; + /** Set the append entry handler for this hook's pi.appendEntry() */ + setAppendEntryHandler: (handler: AppendEntryHandler) => void; } /** @@ -110,32 +125,62 @@ function resolveHookPath(hookPath: string, cwd: string): string { } /** - * Create a HookAPI instance that collects handlers. - * Returns the API and a function to set the send handler later. + * Create a HookAPI instance that collects handlers, renderers, and commands. + * Returns the API, maps, and a function to set the send message handler later. 
*/ -function createHookAPI(handlers: Map): { +function createHookAPI( + handlers: Map, + cwd: string, +): { api: HookAPI; - setSendHandler: (handler: SendHandler) => void; + messageRenderers: Map; + commands: Map; + setSendMessageHandler: (handler: SendMessageHandler) => void; + setAppendEntryHandler: (handler: AppendEntryHandler) => void; } { - let sendHandler: SendHandler = () => { + let sendMessageHandler: SendMessageHandler = () => { // Default no-op until mode sets the handler }; + let appendEntryHandler: AppendEntryHandler = () => { + // Default no-op until mode sets the handler + }; + const messageRenderers = new Map(); + const commands = new Map(); - const api: HookAPI = { + // Cast to HookAPI - the implementation is more general (string event names) + // but the interface has specific overloads for type safety in hooks + const api = { on(event: string, handler: HandlerFn): void { const list = handlers.get(event) ?? []; list.push(handler); handlers.set(event, list); }, - send(text: string, attachments?: Attachment[]): void { - sendHandler(text, attachments); + sendMessage(message: HookMessage, triggerTurn?: boolean): void { + sendMessageHandler(message, triggerTurn); + }, + appendEntry(customType: string, data?: T): void { + appendEntryHandler(customType, data); + }, + registerMessageRenderer(customType: string, renderer: HookMessageRenderer): void { + messageRenderers.set(customType, renderer as HookMessageRenderer); + }, + registerCommand(name: string, options: { description?: string; handler: RegisteredCommand["handler"] }): void { + commands.set(name, { name, ...options }); + }, + exec(command: string, args: string[], options?: ExecOptions) { + return execCommand(command, args, options?.cwd ?? 
cwd, options); }, } as HookAPI; return { api, - setSendHandler: (handler: SendHandler) => { - sendHandler = handler; + messageRenderers, + commands, + setSendMessageHandler: (handler: SendMessageHandler) => { + sendMessageHandler = handler; + }, + setAppendEntryHandler: (handler: AppendEntryHandler) => { + appendEntryHandler = handler; }, }; } @@ -164,13 +209,24 @@ async function loadHook(hookPath: string, cwd: string): Promise<{ hook: LoadedHo // Create handlers map and API const handlers = new Map(); - const { api, setSendHandler } = createHookAPI(handlers); + const { api, messageRenderers, commands, setSendMessageHandler, setAppendEntryHandler } = createHookAPI( + handlers, + cwd, + ); // Call factory to register handlers factory(api); return { - hook: { path: hookPath, resolvedPath, handlers, setSendHandler }, + hook: { + path: hookPath, + resolvedPath, + handlers, + messageRenderers, + commands, + setSendMessageHandler, + setAppendEntryHandler, + }, error: null, }; } catch (err) { diff --git a/packages/coding-agent/src/core/hooks/runner.ts b/packages/coding-agent/src/core/hooks/runner.ts index e7646a19..fb908cf8 100644 --- a/packages/coding-agent/src/core/hooks/runner.ts +++ b/packages/coding-agent/src/core/hooks/runner.ts @@ -2,17 +2,23 @@ * Hook runner - executes hooks and manages their lifecycle. 
*/ -import { spawn } from "node:child_process"; -import type { LoadedHook, SendHandler } from "./loader.js"; +import type { AgentMessage } from "@mariozechner/pi-agent-core"; +import type { ModelRegistry } from "../model-registry.js"; +import type { SessionManager } from "../session-manager.js"; +import type { AppendEntryHandler, LoadedHook, SendMessageHandler } from "./loader.js"; import type { - ExecOptions, - ExecResult, + BeforeAgentStartEvent, + BeforeAgentStartEventResult, + ContextEvent, + ContextEventResult, HookError, HookEvent, HookEventContext, + HookMessageRenderer, HookUIContext, - SessionEvent, - SessionEventResult, + RegisteredCommand, + SessionBeforeCompactResult, + SessionBeforeTreeResult, ToolCallEvent, ToolCallEventResult, ToolResultEventResult, @@ -28,73 +34,8 @@ const DEFAULT_TIMEOUT = 30000; */ export type HookErrorListener = (error: HookError) => void; -/** - * Execute a command and return stdout/stderr/code. - * Supports cancellation via AbortSignal and timeout. - */ -async function exec(command: string, args: string[], cwd: string, options?: ExecOptions): Promise { - return new Promise((resolve) => { - const proc = spawn(command, args, { cwd, shell: false }); - - let stdout = ""; - let stderr = ""; - let killed = false; - let timeoutId: NodeJS.Timeout | undefined; - - const killProcess = () => { - if (!killed) { - killed = true; - proc.kill("SIGTERM"); - // Force kill after 5 seconds if SIGTERM doesn't work - setTimeout(() => { - if (!proc.killed) { - proc.kill("SIGKILL"); - } - }, 5000); - } - }; - - // Handle abort signal - if (options?.signal) { - if (options.signal.aborted) { - killProcess(); - } else { - options.signal.addEventListener("abort", killProcess, { once: true }); - } - } - - // Handle timeout - if (options?.timeout && options.timeout > 0) { - timeoutId = setTimeout(() => { - killProcess(); - }, options.timeout); - } - - proc.stdout?.on("data", (data) => { - stdout += data.toString(); - }); - - proc.stderr?.on("data", (data) 
=> { - stderr += data.toString(); - }); - - proc.on("close", (code) => { - if (timeoutId) clearTimeout(timeoutId); - if (options?.signal) { - options.signal.removeEventListener("abort", killProcess); - } - resolve({ stdout, stderr, code: code ?? 0, killed }); - }); - - proc.on("error", (_err) => { - if (timeoutId) clearTimeout(timeoutId); - if (options?.signal) { - options.signal.removeEventListener("abort", killProcess); - } - resolve({ stdout, stderr, code: 1, killed }); - }); - }); -} +// Re-export execCommand for backward compatibility +export { execCommand } from "../exec.js"; /** * Create a promise that rejects after a timeout. @@ -112,10 +53,11 @@ function createTimeout(ms: number): { promise: Promise; clear: () => void /** No-op UI context used when no UI is available */ const noOpUIContext: HookUIContext = { - select: async () => null, + select: async () => undefined, confirm: async () => false, - input: async () => null, + input: async () => undefined, notify: () => {}, + custom: () => ({ close: () => {}, requestRender: () => {} }), }; /** @@ -126,16 +68,24 @@ export class HookRunner { private uiContext: HookUIContext; private hasUI: boolean; private cwd: string; - private sessionFile: string | null; + private sessionManager: SessionManager; + private modelRegistry: ModelRegistry; private timeout: number; private errorListeners: Set = new Set(); - constructor(hooks: LoadedHook[], cwd: string, timeout: number = DEFAULT_TIMEOUT) { + constructor( + hooks: LoadedHook[], + cwd: string, + sessionManager: SessionManager, + modelRegistry: ModelRegistry, + timeout: number = DEFAULT_TIMEOUT, + ) { this.hooks = hooks; this.uiContext = noOpUIContext; this.hasUI = false; this.cwd = cwd; - this.sessionFile = null; + this.sessionManager = sessionManager; + this.modelRegistry = modelRegistry; this.timeout = timeout; } @@ -148,6 +98,20 @@ export class HookRunner { this.hasUI = hasUI; } + /** + * Get the UI context (set by mode). 
+ */ + getUIContext(): HookUIContext | null { + return this.uiContext; + } + + /** + * Get whether UI is available. + */ + getHasUI(): boolean { + return this.hasUI; + } + /** * Get the paths of all loaded hooks. */ @@ -156,19 +120,22 @@ export class HookRunner { } /** - * Set the session file path. + * Set the send message handler for all hooks' pi.sendMessage(). + * Call this when the mode initializes. */ - setSessionFile(sessionFile: string | null): void { - this.sessionFile = sessionFile; + setSendMessageHandler(handler: SendMessageHandler): void { + for (const hook of this.hooks) { + hook.setSendMessageHandler(handler); + } } /** - * Set the send handler for all hooks' pi.send(). + * Set the append entry handler for all hooks' pi.appendEntry(). * Call this when the mode initializes. */ - setSendHandler(handler: SendHandler): void { + setAppendEntryHandler(handler: AppendEntryHandler): void { for (const hook of this.hooks) { - hook.setSendHandler(handler); + hook.setAppendEntryHandler(handler); } } @@ -184,7 +151,10 @@ export class HookRunner { /** * Emit an error to all listeners. */ - private emitError(error: HookError): void { + /** + * Emit an error to all error listeners. + */ + emitError(error: HookError): void { for (const listener of this.errorListeners) { listener(error); } @@ -203,26 +173,89 @@ export class HookRunner { return false; } + /** + * Get a message renderer for the given customType. + * Returns the first renderer found across all hooks, or undefined if none. + */ + getMessageRenderer(customType: string): HookMessageRenderer | undefined { + for (const hook of this.hooks) { + const renderer = hook.messageRenderers.get(customType); + if (renderer) { + return renderer; + } + } + return undefined; + } + + /** + * Get all registered commands from all hooks. 
+ */ + getRegisteredCommands(): RegisteredCommand[] { + const commands: RegisteredCommand[] = []; + for (const hook of this.hooks) { + for (const command of hook.commands.values()) { + commands.push(command); + } + } + return commands; + } + + /** + * Get a registered command by name. + * Returns the first command found across all hooks, or undefined if none. + */ + getCommand(name: string): RegisteredCommand | undefined { + for (const hook of this.hooks) { + const command = hook.commands.get(name); + if (command) { + return command; + } + } + return undefined; + } + /** * Create the event context for handlers. */ private createContext(): HookEventContext { return { - exec: (command: string, args: string[], options?: ExecOptions) => exec(command, args, this.cwd, options), ui: this.uiContext, hasUI: this.hasUI, cwd: this.cwd, - sessionFile: this.sessionFile, + sessionManager: this.sessionManager, + modelRegistry: this.modelRegistry, }; } /** - * Emit an event to all hooks. - * Returns the result from session/tool_result events (if any handler returns one). + * Check if event type is a session "before_*" event that can be cancelled. */ - async emit(event: HookEvent): Promise { + private isSessionBeforeEvent( + type: string, + ): type is + | "session_before_switch" + | "session_before_new" + | "session_before_branch" + | "session_before_compact" + | "session_before_tree" { + return ( + type === "session_before_switch" || + type === "session_before_new" || + type === "session_before_branch" || + type === "session_before_compact" || + type === "session_before_tree" + ); + } + + /** + * Emit an event to all hooks. + * Returns the result from session before_* / tool_result events (if any handler returns one). 
+ */ + async emit( + event: HookEvent, + ): Promise { const ctx = this.createContext(); - let result: SessionEventResult | ToolResultEventResult | undefined; + let result: SessionBeforeCompactResult | SessionBeforeTreeResult | ToolResultEventResult | undefined; for (const hook of this.hooks) { const handlers = hook.handlers.get(event.type); @@ -230,11 +263,10 @@ export class HookRunner { for (const handler of handlers) { try { - // No timeout for before_compact events (like tool_call, they may take a while) - const isBeforeCompact = event.type === "session" && (event as SessionEvent).reason === "before_compact"; + // No timeout for session_before_compact events (like tool_call, they may take a while) let handlerResult: unknown; - if (isBeforeCompact) { + if (event.type === "session_before_compact") { handlerResult = await handler(event, ctx); } else { const timeout = createTimeout(this.timeout); @@ -242,9 +274,9 @@ export class HookRunner { timeout.clear(); } - // For session events, capture the result (for before_* cancellation) - if (event.type === "session" && handlerResult) { - result = handlerResult as SessionEventResult; + // For session before_* events, capture the result (for cancellation) + if (this.isSessionBeforeEvent(event.type) && handlerResult) { + result = handlerResult as SessionBeforeCompactResult | SessionBeforeTreeResult; // If cancelled, stop processing further hooks if (result.cancel) { return result; @@ -298,4 +330,83 @@ export class HookRunner { return result; } + + /** + * Emit a context event to all hooks. + * Handlers are chained - each gets the previous handler's output (if any). + * Returns the final modified messages, or the original if no modifications. + * + * Note: Messages are already deep-copied by the caller (pi-ai preprocessor). 
+ */ + async emitContext(messages: AgentMessage[]): Promise { + const ctx = this.createContext(); + let currentMessages = messages; + + for (const hook of this.hooks) { + const handlers = hook.handlers.get("context"); + if (!handlers || handlers.length === 0) continue; + + for (const handler of handlers) { + try { + const event: ContextEvent = { type: "context", messages: currentMessages }; + const timeout = createTimeout(this.timeout); + const handlerResult = await Promise.race([handler(event, ctx), timeout.promise]); + timeout.clear(); + + if (handlerResult && (handlerResult as ContextEventResult).messages) { + currentMessages = (handlerResult as ContextEventResult).messages!; + } + } catch (err) { + const message = err instanceof Error ? err.message : String(err); + this.emitError({ + hookPath: hook.path, + event: "context", + error: message, + }); + } + } + } + + return currentMessages; + } + + /** + * Emit before_agent_start event to all hooks. + * Returns the first message to inject (if any handler returns one). + */ + async emitBeforeAgentStart( + prompt: string, + images?: import("@mariozechner/pi-ai").ImageContent[], + ): Promise { + const ctx = this.createContext(); + let result: BeforeAgentStartEventResult | undefined; + + for (const hook of this.hooks) { + const handlers = hook.handlers.get("before_agent_start"); + if (!handlers || handlers.length === 0) continue; + + for (const handler of handlers) { + try { + const event: BeforeAgentStartEvent = { type: "before_agent_start", prompt, images }; + const timeout = createTimeout(this.timeout); + const handlerResult = await Promise.race([handler(event, ctx), timeout.promise]); + timeout.clear(); + + // Take the first message returned + if (handlerResult && (handlerResult as BeforeAgentStartEventResult).message && !result) { + result = handlerResult as BeforeAgentStartEventResult; + } + } catch (err) { + const message = err instanceof Error ? 
err.message : String(err); + this.emitError({ + hookPath: hook.path, + event: "before_agent_start", + error: message, + }); + } + } + } + + return result; + } } diff --git a/packages/coding-agent/src/core/hooks/tool-wrapper.ts b/packages/coding-agent/src/core/hooks/tool-wrapper.ts index b9e518d2..c3499d9f 100644 --- a/packages/coding-agent/src/core/hooks/tool-wrapper.ts +++ b/packages/coding-agent/src/core/hooks/tool-wrapper.ts @@ -2,7 +2,7 @@ * Tool wrapper - wraps tools with hook callbacks for interception. */ -import type { AgentTool, AgentToolUpdateCallback } from "@mariozechner/pi-ai"; +import type { AgentTool, AgentToolUpdateCallback } from "@mariozechner/pi-agent-core"; import type { HookRunner } from "./runner.js"; import type { ToolCallEventResult, ToolResultEventResult } from "./types.js"; diff --git a/packages/coding-agent/src/core/hooks/types.ts b/packages/coding-agent/src/core/hooks/types.ts index a60d009b..e8dd689e 100644 --- a/packages/coding-agent/src/core/hooks/types.ts +++ b/packages/coding-agent/src/core/hooks/types.ts @@ -5,10 +5,37 @@ * and interact with the user via UI primitives. 
*/ -import type { AppMessage, Attachment } from "@mariozechner/pi-agent-core"; -import type { ImageContent, Model, TextContent, ToolResultMessage } from "@mariozechner/pi-ai"; -import type { CutPointResult } from "../compaction.js"; -import type { CompactionEntry, SessionEntry } from "../session-manager.js"; +import type { AgentMessage } from "@mariozechner/pi-agent-core"; +import type { ImageContent, Message, Model, TextContent, ToolResultMessage } from "@mariozechner/pi-ai"; +import type { Component } from "@mariozechner/pi-tui"; +import type { Theme } from "../../modes/interactive/theme/theme.js"; +import type { CompactionPreparation, CompactionResult } from "../compaction/index.js"; +import type { ExecOptions, ExecResult } from "../exec.js"; +import type { HookMessage } from "../messages.js"; +import type { ModelRegistry } from "../model-registry.js"; +import type { BranchSummaryEntry, CompactionEntry, SessionEntry, SessionManager } from "../session-manager.js"; + +/** + * Read-only view of SessionManager for hooks. + * Hooks should use pi.sendMessage() and pi.appendEntry() for writes. 
+ */ +export type ReadonlySessionManager = Pick< + SessionManager, + | "getCwd" + | "getSessionDir" + | "getSessionId" + | "getSessionFile" + | "getLeafId" + | "getLeafEntry" + | "getEntry" + | "getLabel" + | "getPath" + | "getHeader" + | "getEntries" + | "getTree" +>; + +import type { EditToolDetails } from "../tools/edit.js"; import type { BashToolDetails, FindToolDetails, @@ -17,27 +44,8 @@ import type { ReadToolDetails, } from "../tools/index.js"; -// ============================================================================ -// Execution Context -// ============================================================================ - -/** - * Result of executing a command via ctx.exec() - */ -export interface ExecResult { - stdout: string; - stderr: string; - code: number; - /** True if the process was killed due to signal or timeout */ - killed?: boolean; -} - -export interface ExecOptions { - /** AbortSignal to cancel the process */ - signal?: AbortSignal; - /** Timeout in milliseconds */ - timeout?: number; -} +// Re-export for backward compatibility +export type { ExecOptions, ExecResult } from "../exec.js"; /** * UI context for hooks to request interactive UI from the harness. @@ -50,7 +58,7 @@ export interface HookUIContext { * @param options - Array of string options * @returns Selected option string, or null if cancelled */ - select(title: string, options: string[]): Promise; + select(title: string, options: string[]): Promise; /** * Show a confirmation dialog. @@ -60,97 +68,191 @@ export interface HookUIContext { /** * Show a text input dialog. - * @returns User input, or null if cancelled + * @returns User input, or undefined if cancelled */ - input(title: string, placeholder?: string): Promise; + input(title: string, placeholder?: string): Promise; /** * Show a notification to the user. */ notify(message: string, type?: "info" | "warning" | "error"): void; + + /** + * Show a custom component with keyboard focus. 
+ * The component receives keyboard input via handleInput() if implemented. + * + * @param component - Component to display (implement handleInput for keyboard, dispose for cleanup) + * @returns Object with close() to restore normal UI and requestRender() to trigger redraw + */ + custom(component: Component & { dispose?(): void }): { close: () => void; requestRender: () => void }; } /** * Context passed to hook event handlers. */ export interface HookEventContext { - /** Execute a command and return stdout/stderr/code */ - exec(command: string, args: string[], options?: ExecOptions): Promise; /** UI methods for user interaction */ ui: HookUIContext; /** Whether UI is available (false in print mode) */ hasUI: boolean; /** Current working directory */ cwd: string; - /** Path to session file, or null if --no-session */ - sessionFile: string | null; + /** Session manager (read-only) - use pi.sendMessage()/pi.appendEntry() for writes */ + sessionManager: ReadonlySessionManager; + /** Model registry - use for API key resolution and model retrieval */ + modelRegistry: ModelRegistry; } // ============================================================================ -// Events +// Session Events // ============================================================================ -/** - * Base fields shared by all session events. - */ -interface SessionEventBase { - type: "session"; - /** All session entries (including pre-compaction history) */ - entries: SessionEntry[]; - /** Current session file path, or null in --no-session mode */ - sessionFile: string | null; - /** Previous session file path, or null for "start" and "new" */ - previousSessionFile: string | null; +/** Fired on initial session load */ +export interface SessionStartEvent { + type: "session_start"; } -/** - * Event data for session events. - * Discriminated union based on reason. 
- * - * Lifecycle: - * - start: Initial session load - * - before_switch / switch: Session switch (e.g., /resume command) - * - before_new / new: New session (e.g., /new command) - * - before_branch / branch: Session branch (e.g., /branch command) - * - before_compact / compact: Before/after context compaction - * - shutdown: Process exit (SIGINT/SIGTERM) - * - * "before_*" events fire before the action and can be cancelled via SessionEventResult. - * Other events fire after the action completes. - */ +/** Fired before switching to another session (can be cancelled) */ +export interface SessionBeforeSwitchEvent { + type: "session_before_switch"; + /** Session file we're switching to */ + targetSessionFile: string; +} + +/** Fired after switching to another session */ +export interface SessionSwitchEvent { + type: "session_switch"; + /** Session file we came from */ + previousSessionFile: string | undefined; +} + +/** Fired before creating a new session (can be cancelled) */ +export interface SessionBeforeNewEvent { + type: "session_before_new"; +} + +/** Fired after creating a new session */ +export interface SessionNewEvent { + type: "session_new"; +} + +/** Fired before branching a session (can be cancelled) */ +export interface SessionBeforeBranchEvent { + type: "session_before_branch"; + /** Index of the entry in the session (SessionManager.getEntries()) to branch from */ + entryIndex: number; +} + +/** Fired after branching a session */ +export interface SessionBranchEvent { + type: "session_branch"; + previousSessionFile: string | undefined; +} + +/** Fired before context compaction (can be cancelled or customized) */ +export interface SessionBeforeCompactEvent { + type: "session_before_compact"; + /** Compaction preparation with cut point, messages to summarize/keep, etc. */ + preparation: CompactionPreparation; + /** Previous compaction entries, newest first. Use for iterative summarization. 
*/ + previousCompactions: CompactionEntry[]; + /** Optional user-provided instructions for the summary */ + customInstructions?: string; + /** Current model */ + model: Model; + /** Abort signal - hooks should pass this to LLM calls and check it periodically */ + signal: AbortSignal; +} + +/** Fired after context compaction */ +export interface SessionCompactEvent { + type: "session_compact"; + compactionEntry: CompactionEntry; + /** Whether the compaction entry was provided by a hook */ + fromHook: boolean; +} + +/** Fired on process exit (SIGINT/SIGTERM) */ +export interface SessionShutdownEvent { + type: "session_shutdown"; +} + +/** Preparation data for tree navigation (used by session_before_tree event) */ +export interface TreePreparation { + /** Node being switched to */ + targetId: string; + /** Current active leaf (being abandoned), null if no current position */ + oldLeafId: string | null; + /** Common ancestor of target and old leaf, null if no common ancestor */ + commonAncestorId: string | null; + /** Entries to summarize (old leaf back to common ancestor or compaction) */ + entriesToSummarize: SessionEntry[]; + /** Whether user chose to summarize */ + userWantsSummary: boolean; +} + +/** Fired before navigating to a different node in the session tree (can be cancelled) */ +export interface SessionBeforeTreeEvent { + type: "session_before_tree"; + /** Preparation data for the navigation */ + preparation: TreePreparation; + /** Model to use for summarization (conversation model) */ + model: Model; + /** Abort signal - honors Escape during summarization */ + signal: AbortSignal; +} + +/** Fired after navigating to a different node in the session tree */ +export interface SessionTreeEvent { + type: "session_tree"; + /** The new active leaf, null if navigated to before first entry */ + newLeafId: string | null; + /** Previous active leaf, null if there was no position */ + oldLeafId: string | null; + /** Branch summary entry if one was created */ + 
summaryEntry?: BranchSummaryEntry; + /** Whether summary came from hook */ + fromHook?: boolean; +} + +/** Union of all session event types */ export type SessionEvent = - | (SessionEventBase & { - reason: "start" | "switch" | "new" | "before_switch" | "before_new" | "shutdown"; - }) - | (SessionEventBase & { - reason: "branch" | "before_branch"; - /** Index of the turn to branch from */ - targetTurnIndex: number; - }) - | (SessionEventBase & { - reason: "before_compact"; - cutPoint: CutPointResult; - /** Summary from previous compaction, if any. Include this in your summary to preserve context. */ - previousSummary?: string; - /** Messages that will be summarized and discarded */ - messagesToSummarize: AppMessage[]; - /** Messages that will be kept after the summary (recent turns) */ - messagesToKeep: AppMessage[]; - tokensBefore: number; - customInstructions?: string; - model: Model; - /** Resolve API key for any model (checks settings, OAuth, env vars) */ - resolveApiKey: (model: Model) => Promise; - /** Abort signal - hooks should pass this to LLM calls and check it periodically */ - signal: AbortSignal; - }) - | (SessionEventBase & { - reason: "compact"; - compactionEntry: CompactionEntry; - tokensBefore: number; - /** Whether the compaction entry was provided by a hook */ - fromHook: boolean; - }); + | SessionStartEvent + | SessionBeforeSwitchEvent + | SessionSwitchEvent + | SessionBeforeNewEvent + | SessionNewEvent + | SessionBeforeBranchEvent + | SessionBranchEvent + | SessionBeforeCompactEvent + | SessionCompactEvent + | SessionShutdownEvent + | SessionBeforeTreeEvent + | SessionTreeEvent; + +/** + * Event data for context event. + * Fired before each LLM call, allowing hooks to modify context non-destructively. + * Original session messages are NOT modified - only the messages sent to the LLM are affected. 
+ */ +export interface ContextEvent { + type: "context"; + /** Messages about to be sent to the LLM (deep copy, safe to modify) */ + messages: AgentMessage[]; +} + +/** + * Event data for before_agent_start event. + * Fired after user submits a prompt but before the agent loop starts. + * Allows hooks to inject context that will be persisted and visible in TUI. + */ +export interface BeforeAgentStartEvent { + type: "before_agent_start"; + /** The user's prompt text */ + prompt: string; + /** Any images attached to the prompt */ + images?: ImageContent[]; +} /** * Event data for agent_start event. @@ -165,7 +267,7 @@ export interface AgentStartEvent { */ export interface AgentEndEvent { type: "agent_end"; - messages: AppMessage[]; + messages: AgentMessage[]; } /** @@ -183,7 +285,7 @@ export interface TurnStartEvent { export interface TurnEndEvent { type: "turn_end"; turnIndex: number; - message: AppMessage; + message: AgentMessage; toolResults: ToolResultMessage[]; } @@ -231,7 +333,7 @@ export interface ReadToolResultEvent extends ToolResultEventBase { /** Tool result event for edit tool */ export interface EditToolResultEvent extends ToolResultEventBase { toolName: "edit"; - details: undefined; + details: EditToolDetails | undefined; } /** Tool result event for write tool */ @@ -307,6 +409,8 @@ export function isLsToolResult(e: ToolResultEvent): e is LsToolResultEvent { */ export type HookEvent = | SessionEvent + | ContextEvent + | BeforeAgentStartEvent | AgentStartEvent | AgentEndEvent | TurnStartEvent @@ -318,6 +422,15 @@ export type HookEvent = // Event Results // ============================================================================ +/** + * Return type for context event handlers. + * Allows hooks to modify messages before they're sent to the LLM. + */ +export interface ContextEventResult { + /** Modified messages to send instead of the original */ + messages?: Message[]; +} + /** * Return type for tool_call event handlers. 
* Allows hooks to block tool execution. @@ -343,16 +456,68 @@ export interface ToolResultEventResult { } /** - * Return type for session event handlers. - * Allows hooks to cancel "before_*" actions. + * Return type for before_agent_start event handlers. + * Allows hooks to inject context before the agent runs. */ -export interface SessionEventResult { - /** If true, cancel the pending action (switch, clear, or branch) */ +export interface BeforeAgentStartEventResult { + /** Message to inject into context (persisted to session, visible in TUI) */ + message?: Pick; +} + +/** Return type for session_before_switch handlers */ +export interface SessionBeforeSwitchResult { + /** If true, cancel the switch */ cancel?: boolean; - /** If true (for before_branch only), skip restoring conversation to branch point while still creating the branched session file */ +} + +/** Return type for session_before_new handlers */ +export interface SessionBeforeNewResult { + /** If true, cancel the new session */ + cancel?: boolean; +} + +/** Return type for session_before_branch handlers */ +export interface SessionBeforeBranchResult { + /** + * If true, abort the branch entirely. No new session file is created, + * conversation stays unchanged. + */ + cancel?: boolean; + /** + * If true, the branch proceeds (new session file created, session state updated) + * but the in-memory conversation is NOT rewound to the branch point. + * + * Use case: git-checkpoint hook that restores code state separately. + * The hook handles state restoration itself, so it doesn't want the + * agent's conversation to be rewound (which would lose recent context). 
+ * + * - `cancel: true` → nothing happens, user stays in current session + * - `skipConversationRestore: true` → branch happens, but messages stay as-is + * - neither → branch happens AND messages rewind to branch point (default) + */ skipConversationRestore?: boolean; - /** Custom compaction entry (for before_compact event) */ - compactionEntry?: CompactionEntry; +} + +/** Return type for session_before_compact handlers */ +export interface SessionBeforeCompactResult { + /** If true, cancel the compaction */ + cancel?: boolean; + /** Custom compaction result - SessionManager adds id/parentId */ + compaction?: CompactionResult; +} + +/** Return type for session_before_tree handlers */ +export interface SessionBeforeTreeResult { + /** If true, cancel the navigation entirely */ + cancel?: boolean; + /** + * Custom summary (skips default summarizer). + * Only used if preparation.userWantsSummary is true. + */ + summary?: { + summary: string; + details?: unknown; + }; } // ============================================================================ @@ -361,29 +526,152 @@ export interface SessionEventResult { /** * Handler function type for each event. + * Handlers can return R, undefined, or void (bare return statements). */ -export type HookHandler = (event: E, ctx: HookEventContext) => Promise; +// biome-ignore lint/suspicious/noConfusingVoidType: void allows bare return statements in handlers +export type HookHandler = (event: E, ctx: HookEventContext) => Promise | R | void; + +export interface HookMessageRenderOptions { + /** Whether the view is expanded */ + expanded: boolean; +} + +/** + * Renderer for hook messages. + * Hooks register these to provide custom TUI rendering for their message types. + */ +export type HookMessageRenderer = ( + message: HookMessage, + options: HookMessageRenderOptions, + theme: Theme, +) => Component | undefined; + +/** + * Context passed to hook command handlers. 
+ */ +export interface HookCommandContext { + /** Arguments after the command name */ + args: string; + /** UI methods for user interaction */ + ui: HookUIContext; + /** Whether UI is available (false in print mode) */ + hasUI: boolean; + /** Current working directory */ + cwd: string; + /** Session manager (read-only) - use pi.sendMessage()/pi.appendEntry() for writes */ + sessionManager: ReadonlySessionManager; + /** Model registry for API keys */ + modelRegistry: ModelRegistry; +} + +/** + * Command registration options. + */ +export interface RegisteredCommand { + name: string; + description?: string; + + handler: (ctx: HookCommandContext) => Promise; +} /** * HookAPI passed to hook factory functions. - * Hooks use pi.on() to subscribe to events and pi.send() to inject messages. + * Hooks use pi.on() to subscribe to events and pi.sendMessage() to inject messages. */ export interface HookAPI { - // biome-ignore lint/suspicious/noConfusingVoidType: void allows handlers to not return anything - on(event: "session", handler: HookHandler): void; + // Session events + on(event: "session_start", handler: HookHandler): void; + on(event: "session_before_switch", handler: HookHandler): void; + on(event: "session_switch", handler: HookHandler): void; + on(event: "session_before_new", handler: HookHandler): void; + on(event: "session_new", handler: HookHandler): void; + on(event: "session_before_branch", handler: HookHandler): void; + on(event: "session_branch", handler: HookHandler): void; + on( + event: "session_before_compact", + handler: HookHandler, + ): void; + on(event: "session_compact", handler: HookHandler): void; + on(event: "session_shutdown", handler: HookHandler): void; + on(event: "session_before_tree", handler: HookHandler): void; + on(event: "session_tree", handler: HookHandler): void; + + // Context and agent events + on(event: "context", handler: HookHandler): void; + on(event: "before_agent_start", handler: HookHandler): void; on(event: "agent_start", 
handler: HookHandler): void; on(event: "agent_end", handler: HookHandler): void; on(event: "turn_start", handler: HookHandler): void; on(event: "turn_end", handler: HookHandler): void; - on(event: "tool_call", handler: HookHandler): void; - on(event: "tool_result", handler: HookHandler): void; + on(event: "tool_call", handler: HookHandler): void; + on(event: "tool_result", handler: HookHandler): void; /** - * Send a message to the agent. - * If the agent is streaming, the message is queued. - * If the agent is idle, a new agent loop is started. + * Send a custom message to the session. Creates a CustomMessageEntry that + * participates in LLM context and can be displayed in the TUI. + * + * Use this when you want the LLM to see the message content. + * For hook state that should NOT be sent to the LLM, use appendEntry() instead. + * + * @param message - The message to send + * @param message.customType - Identifier for your hook (used for filtering on reload) + * @param message.content - Message content (string or TextContent/ImageContent array) + * @param message.display - Whether to show in TUI (true = styled display, false = hidden) + * @param message.details - Optional hook-specific metadata (not sent to LLM) + * @param triggerTurn - If true and agent is idle, triggers a new LLM turn. Default: false. + * If agent is streaming, message is queued and triggerTurn is ignored. */ - send(text: string, attachments?: Attachment[]): void; + sendMessage( + message: Pick, "customType" | "content" | "display" | "details">, + triggerTurn?: boolean, + ): void; + + /** + * Append a custom entry to the session for hook state persistence. + * Creates a CustomEntry that does NOT participate in LLM context. + * + * Use this to store hook-specific data that should persist across session reloads + * but should NOT be sent to the LLM. On reload, scan session entries for your + * customType to reconstruct hook state. 
+ *
+ * For messages that SHOULD be sent to the LLM, use sendMessage() instead.
+ *
+ * @param customType - Identifier for your hook (used for filtering on reload)
+ * @param data - Hook-specific data to persist (must be JSON-serializable)
+ *
+ * @example
+ * // Store permission state
+ * pi.appendEntry("permissions", { level: "full", grantedAt: Date.now() });
+ *
+ * // On reload, reconstruct state from entries
+ * pi.on("session_start", async (_event, ctx) => {
+ * const entries = ctx.sessionManager.getEntries();
+ * const myEntries = entries.filter(
+ * e => e.type === "custom" && e.customType === "permissions",
+ * );
+ * // Reconstruct state from myEntries...
+ * });
+ */
+ appendEntry(customType: string, data?: T): void;
+
+ /**
+ * Register a custom renderer for CustomMessageEntry with a specific customType.
+ * The renderer is called when rendering the entry in the TUI.
+ * Return nothing to use the default renderer.
+ */
+ registerMessageRenderer(customType: string, renderer: HookMessageRenderer): void;
+
+ /**
+ * Register a custom slash command.
+ * Handler receives HookCommandContext.
+ */
+ registerCommand(name: string, options: { description?: string; handler: RegisteredCommand["handler"] }): void;
+
+ /**
+ * Execute a shell command and return stdout/stderr/code.
+ * Supports timeout and abort signal.
+ */ + exec(command: string, args: string[], options?: ExecOptions): Promise; } /** diff --git a/packages/coding-agent/src/core/index.ts b/packages/coding-agent/src/core/index.ts index 117d96b6..b25df418 100644 --- a/packages/coding-agent/src/core/index.ts +++ b/packages/coding-agent/src/core/index.ts @@ -7,12 +7,12 @@ export { type AgentSessionConfig, type AgentSessionEvent, type AgentSessionEventListener, - type CompactionResult, type ModelCycleResult, type PromptOptions, type SessionStats, } from "./agent-session.js"; export { type BashExecutorOptions, type BashResult, executeBash } from "./bash-executor.js"; +export type { CompactionResult } from "./compaction/index.js"; export { type CustomAgentTool, type CustomToolFactory, diff --git a/packages/coding-agent/src/core/messages.ts b/packages/coding-agent/src/core/messages.ts index 947801f8..2726903d 100644 --- a/packages/coding-agent/src/core/messages.ts +++ b/packages/coding-agent/src/core/messages.ts @@ -1,16 +1,27 @@ /** * Custom message types and transformers for the coding agent. * - * Extends the base AppMessage type with coding-agent specific message types, + * Extends the base AgentMessage type with coding-agent specific message types, * and provides a transformer to convert them to LLM-compatible messages. 
*/ -import type { AppMessage } from "@mariozechner/pi-agent-core"; -import type { Message } from "@mariozechner/pi-ai"; +import type { AgentMessage } from "@mariozechner/pi-agent-core"; +import type { ImageContent, Message, TextContent } from "@mariozechner/pi-ai"; -// ============================================================================ -// Custom Message Types -// ============================================================================ +export const COMPACTION_SUMMARY_PREFIX = `The conversation history before this point was compacted into the following summary: + + +`; + +export const COMPACTION_SUMMARY_SUFFIX = ` +`; + +export const BRANCH_SUMMARY_PREFIX = `The following is a summary of a branch that this conversation came back from: + + +`; + +export const BRANCH_SUMMARY_SUFFIX = ``; /** * Message type for bash executions via the ! command. @@ -19,35 +30,50 @@ export interface BashExecutionMessage { role: "bashExecution"; command: string; output: string; - exitCode: number | null; + exitCode: number | undefined; cancelled: boolean; truncated: boolean; fullOutputPath?: string; timestamp: number; } -// Extend CustomMessages via declaration merging +/** + * Message type for hook-injected messages via sendMessage(). + * These are custom messages that hooks can inject into the conversation. 
+ */ +export interface HookMessage { + role: "hookMessage"; + customType: string; + content: string | (TextContent | ImageContent)[]; + display: boolean; + details?: T; + timestamp: number; +} + +export interface BranchSummaryMessage { + role: "branchSummary"; + summary: string; + fromId: string; + timestamp: number; +} + +export interface CompactionSummaryMessage { + role: "compactionSummary"; + summary: string; + tokensBefore: number; + timestamp: number; +} + +// Extend CustomAgentMessages via declaration merging declare module "@mariozechner/pi-agent-core" { - interface CustomMessages { + interface CustomAgentMessages { bashExecution: BashExecutionMessage; + hookMessage: HookMessage; + branchSummary: BranchSummaryMessage; + compactionSummary: CompactionSummaryMessage; } } -// ============================================================================ -// Type Guards -// ============================================================================ - -/** - * Type guard for BashExecutionMessage. - */ -export function isBashExecutionMessage(msg: AppMessage | Message): msg is BashExecutionMessage { - return (msg as BashExecutionMessage).role === "bashExecution"; -} - -// ============================================================================ -// Message Formatting -// ============================================================================ - /** * Convert a BashExecutionMessage to user message text for LLM context. 
*/ @@ -60,7 +86,7 @@ export function bashExecutionToText(msg: BashExecutionMessage): string { } if (msg.cancelled) { text += "\n\n(command cancelled)"; - } else if (msg.exitCode !== null && msg.exitCode !== 0) { + } else if (msg.exitCode !== null && msg.exitCode !== undefined && msg.exitCode !== 0) { text += `\n\nCommand exited with code ${msg.exitCode}`; } if (msg.truncated && msg.fullOutputPath) { @@ -69,34 +95,95 @@ export function bashExecutionToText(msg: BashExecutionMessage): string { return text; } -// ============================================================================ -// Message Transformer -// ============================================================================ +export function createBranchSummaryMessage(summary: string, fromId: string, timestamp: string): BranchSummaryMessage { + return { + role: "branchSummary", + summary, + fromId, + timestamp: new Date(timestamp).getTime(), + }; +} + +export function createCompactionSummaryMessage( + summary: string, + tokensBefore: number, + timestamp: string, +): CompactionSummaryMessage { + return { + role: "compactionSummary", + summary: summary, + tokensBefore, + timestamp: new Date(timestamp).getTime(), + }; +} + +/** Convert CustomMessageEntry to AgentMessage format */ +export function createHookMessage( + customType: string, + content: string | (TextContent | ImageContent)[], + display: boolean, + details: unknown | undefined, + timestamp: string, +): HookMessage { + return { + role: "hookMessage", + customType, + content, + display, + details, + timestamp: new Date(timestamp).getTime(), + }; +} /** - * Transform AppMessages (including custom types) to LLM-compatible Messages. + * Transform AgentMessages (including custom types) to LLM-compatible Messages. 
* * This is used by: - * - Agent's messageTransformer option (for prompt calls) + * - Agent's transormToLlm option (for prompt calls and queued messages) * - Compaction's generateSummary (for summarization) + * - Custom hooks and tools */ -export function messageTransformer(messages: AppMessage[]): Message[] { +export function convertToLlm(messages: AgentMessage[]): Message[] { return messages - .map((m): Message | null => { - if (isBashExecutionMessage(m)) { - // Convert bash execution to user message - return { - role: "user", - content: [{ type: "text", text: bashExecutionToText(m) }], - timestamp: m.timestamp, - }; + .map((m): Message | undefined => { + switch (m.role) { + case "bashExecution": + return { + role: "user", + content: [{ type: "text", text: bashExecutionToText(m) }], + timestamp: m.timestamp, + }; + case "hookMessage": { + const content = typeof m.content === "string" ? [{ type: "text" as const, text: m.content }] : m.content; + return { + role: "user", + content, + timestamp: m.timestamp, + }; + } + case "branchSummary": + return { + role: "user", + content: [{ type: "text" as const, text: BRANCH_SUMMARY_PREFIX + m.summary + BRANCH_SUMMARY_SUFFIX }], + timestamp: m.timestamp, + }; + case "compactionSummary": + return { + role: "user", + content: [ + { type: "text" as const, text: COMPACTION_SUMMARY_PREFIX + m.summary + COMPACTION_SUMMARY_SUFFIX }, + ], + timestamp: m.timestamp, + }; + case "user": + case "assistant": + case "toolResult": + return m; + default: + // biome-ignore lint/correctness/noSwitchDeclarations: fine + const _exhaustiveCheck: never = m; + return undefined; } - // Pass through standard LLM roles - if (m.role === "user" || m.role === "assistant" || m.role === "toolResult") { - return m as Message; - } - // Filter out unknown message types - return null; }) - .filter((m): m is Message => m !== null); + .filter((m) => m !== undefined); } diff --git a/packages/coding-agent/src/core/model-registry.ts 
b/packages/coding-agent/src/core/model-registry.ts index 91aa0f8c..94a839f8 100644 --- a/packages/coding-agent/src/core/model-registry.ts +++ b/packages/coding-agent/src/core/model-registry.ts @@ -90,11 +90,11 @@ function resolveApiKeyConfig(keyConfig: string): string | undefined { export class ModelRegistry { private models: Model[] = []; private customProviderApiKeys: Map = new Map(); - private loadError: string | null = null; + private loadError: string | undefined = undefined; constructor( readonly authStorage: AuthStorage, - private modelsJsonPath: string | null = null, + private modelsJsonPath: string | undefined = undefined, ) { // Set up fallback resolver for custom provider API keys this.authStorage.setFallbackResolver((provider) => { @@ -114,14 +114,14 @@ export class ModelRegistry { */ refresh(): void { this.customProviderApiKeys.clear(); - this.loadError = null; + this.loadError = undefined; this.loadModels(); } /** - * Get any error from loading models.json (null if no error). + * Get any error from loading models.json (undefined if no error). */ - getError(): string | null { + getError(): string | undefined { return this.loadError; } @@ -160,9 +160,9 @@ export class ModelRegistry { } } - private loadCustomModels(modelsJsonPath: string): { models: Model[]; error: string | null } { + private loadCustomModels(modelsJsonPath: string): { models: Model[]; error: string | undefined } { if (!existsSync(modelsJsonPath)) { - return { models: [], error: null }; + return { models: [], error: undefined }; } try { @@ -186,7 +186,7 @@ export class ModelRegistry { this.validateConfig(config); // Parse models - return { models: this.parseModels(config), error: null }; + return { models: this.parseModels(config), error: undefined }; } catch (error) { if (error instanceof SyntaxError) { return { @@ -294,14 +294,14 @@ export class ModelRegistry { /** * Find a model by provider and ID. 
*/ - find(provider: string, modelId: string): Model | null { - return this.models.find((m) => m.provider === provider && m.id === modelId) ?? null; + find(provider: string, modelId: string): Model | undefined { + return this.models.find((m) => m.provider === provider && m.id === modelId) ?? undefined; } /** * Get API key for a model. */ - async getApiKey(model: Model): Promise { + async getApiKey(model: Model): Promise { return this.authStorage.getApiKey(model.provider); } diff --git a/packages/coding-agent/src/core/model-resolver.ts b/packages/coding-agent/src/core/model-resolver.ts index 1734c015..981f11f2 100644 --- a/packages/coding-agent/src/core/model-resolver.ts +++ b/packages/coding-agent/src/core/model-resolver.ts @@ -44,9 +44,9 @@ function isAlias(id: string): boolean { /** * Try to match a pattern to a model from the available models list. - * Returns the matched model or null if no match found. + * Returns the matched model or undefined if no match found. */ -function tryMatchModel(modelPattern: string, availableModels: Model[]): Model | null { +function tryMatchModel(modelPattern: string, availableModels: Model[]): Model | undefined { // Check for provider/modelId format (provider is everything before the first /) const slashIndex = modelPattern.indexOf("/"); if (slashIndex !== -1) { @@ -75,7 +75,7 @@ function tryMatchModel(modelPattern: string, availableModels: Model[]): Mod ); if (matches.length === 0) { - return null; + return undefined; } // Separate into aliases and dated versions @@ -94,9 +94,9 @@ function tryMatchModel(modelPattern: string, availableModels: Model[]): Mod } export interface ParsedModelResult { - model: Model | null; + model: Model | undefined; thinkingLevel: ThinkingLevel; - warning: string | null; + warning: string | undefined; } /** @@ -116,14 +116,14 @@ export function parseModelPattern(pattern: string, availableModels: Model[] // Try exact match first const exactMatch = tryMatchModel(pattern, availableModels); if (exactMatch) 
{ - return { model: exactMatch, thinkingLevel: "off", warning: null }; + return { model: exactMatch, thinkingLevel: "off", warning: undefined }; } // No match - try splitting on last colon if present const lastColonIndex = pattern.lastIndexOf(":"); if (lastColonIndex === -1) { // No colons, pattern simply doesn't match any model - return { model: null, thinkingLevel: "off", warning: null }; + return { model: undefined, thinkingLevel: "off", warning: undefined }; } const prefix = pattern.substring(0, lastColonIndex); @@ -193,9 +193,9 @@ export async function resolveModelScope(patterns: string[], modelRegistry: Model } export interface InitialModelResult { - model: Model | null; + model: Model | undefined; thinkingLevel: ThinkingLevel; - fallbackMessage: string | null; + fallbackMessage: string | undefined; } /** @@ -227,7 +227,7 @@ export async function findInitialModel(options: { modelRegistry, } = options; - let model: Model | null = null; + let model: Model | undefined; let thinkingLevel: ThinkingLevel = "off"; // 1. CLI args take priority @@ -237,7 +237,7 @@ export async function findInitialModel(options: { console.error(chalk.red(`Model ${cliProvider}/${cliModel} not found`)); process.exit(1); } - return { model: found, thinkingLevel: "off", fallbackMessage: null }; + return { model: found, thinkingLevel: "off", fallbackMessage: undefined }; } // 2. 
Use first model from scoped models (skip if continuing/resuming) @@ -245,7 +245,7 @@ export async function findInitialModel(options: { return { model: scopedModels[0].model, thinkingLevel: scopedModels[0].thinkingLevel, - fallbackMessage: null, + fallbackMessage: undefined, }; } @@ -257,7 +257,7 @@ export async function findInitialModel(options: { if (defaultThinkingLevel) { thinkingLevel = defaultThinkingLevel; } - return { model, thinkingLevel, fallbackMessage: null }; + return { model, thinkingLevel, fallbackMessage: undefined }; } } @@ -270,16 +270,16 @@ export async function findInitialModel(options: { const defaultId = defaultModelPerProvider[provider]; const match = availableModels.find((m) => m.provider === provider && m.id === defaultId); if (match) { - return { model: match, thinkingLevel: "off", fallbackMessage: null }; + return { model: match, thinkingLevel: "off", fallbackMessage: undefined }; } } // If no default found, use first available - return { model: availableModels[0], thinkingLevel: "off", fallbackMessage: null }; + return { model: availableModels[0], thinkingLevel: "off", fallbackMessage: undefined }; } // 5. 
No model found - return { model: null, thinkingLevel: "off", fallbackMessage: null }; + return { model: undefined, thinkingLevel: "off", fallbackMessage: undefined }; } /** @@ -288,10 +288,10 @@ export async function findInitialModel(options: { export async function restoreModelFromSession( savedProvider: string, savedModelId: string, - currentModel: Model | null, + currentModel: Model | undefined, shouldPrintMessages: boolean, modelRegistry: ModelRegistry, -): Promise<{ model: Model | null; fallbackMessage: string | null }> { +): Promise<{ model: Model | undefined; fallbackMessage: string | undefined }> { const restoredModel = modelRegistry.find(savedProvider, savedModelId); // Check if restored model exists and has a valid API key @@ -301,7 +301,7 @@ export async function restoreModelFromSession( if (shouldPrintMessages) { console.log(chalk.dim(`Restored model: ${savedProvider}/${savedModelId}`)); } - return { model: restoredModel, fallbackMessage: null }; + return { model: restoredModel, fallbackMessage: undefined }; } // Model not found or no API key - fall back @@ -327,7 +327,7 @@ export async function restoreModelFromSession( if (availableModels.length > 0) { // Try to find a default model from known providers - let fallbackModel: Model | null = null; + let fallbackModel: Model | undefined; for (const provider of Object.keys(defaultModelPerProvider) as KnownProvider[]) { const defaultId = defaultModelPerProvider[provider]; const match = availableModels.find((m) => m.provider === provider && m.id === defaultId); @@ -353,5 +353,5 @@ export async function restoreModelFromSession( } // No models available - return { model: null, fallbackMessage: null }; + return { model: undefined, fallbackMessage: undefined }; } diff --git a/packages/coding-agent/src/core/sdk.ts b/packages/coding-agent/src/core/sdk.ts index e245f233..089691d3 100644 --- a/packages/coding-agent/src/core/sdk.ts +++ b/packages/coding-agent/src/core/sdk.ts @@ -29,7 +29,7 @@ * ``` */ -import { Agent, 
ProviderTransport, type ThinkingLevel } from "@mariozechner/pi-agent-core"; +import { Agent, type ThinkingLevel } from "@mariozechner/pi-agent-core"; import type { Model } from "@mariozechner/pi-ai"; import { join } from "path"; import { getAgentDir } from "../config.js"; @@ -39,7 +39,7 @@ import { discoverAndLoadCustomTools, type LoadedCustomTool } from "./custom-tool import type { CustomAgentTool } from "./custom-tools/types.js"; import { discoverAndLoadHooks, HookRunner, type LoadedHook, wrapToolsWithHooks } from "./hooks/index.js"; import type { HookFactory } from "./hooks/types.js"; -import { messageTransformer } from "./messages.js"; +import { convertToLlm } from "./messages.js"; import { ModelRegistry } from "./model-registry.js"; import { SessionManager } from "./session-manager.js"; import { type Settings, SettingsManager, type SkillsSettings } from "./settings-manager.js"; @@ -340,7 +340,10 @@ function createFactoryFromLoadedHook(loaded: LoadedHook): HookFactory { function createLoadedHooksFromDefinitions(definitions: Array<{ path?: string; factory: HookFactory }>): LoadedHook[] { return definitions.map((def) => { const handlers = new Map Promise>>(); - let sendHandler: (text: string, attachments?: any[]) => void = () => {}; + const messageRenderers = new Map(); + const commands = new Map(); + let sendMessageHandler: (message: any, triggerTurn?: boolean) => void = () => {}; + let appendEntryHandler: (customType: string, data?: any) => void = () => {}; const api = { on: (event: string, handler: (...args: unknown[]) => Promise) => { @@ -348,8 +351,17 @@ function createLoadedHooksFromDefinitions(definitions: Array<{ path?: string; fa list.push(handler); handlers.set(event, list); }, - send: (text: string, attachments?: any[]) => { - sendHandler(text, attachments); + sendMessage: (message: any, triggerTurn?: boolean) => { + sendMessageHandler(message, triggerTurn); + }, + appendEntry: (customType: string, data?: any) => { + appendEntryHandler(customType, 
data); + }, + registerMessageRenderer: (customType: string, renderer: any) => { + messageRenderers.set(customType, renderer); + }, + registerCommand: (name: string, options: any) => { + commands.set(name, { name, ...options }); }, }; @@ -359,8 +371,13 @@ function createLoadedHooksFromDefinitions(definitions: Array<{ path?: string; fa path: def.path ?? "", resolvedPath: def.path ?? "", handlers, - setSendHandler: (handler: (text: string, attachments?: any[]) => void) => { - sendHandler = handler; + messageRenderers, + commands, + setSendMessageHandler: (handler: (message: any, triggerTurn?: boolean) => void) => { + sendMessageHandler = handler; + }, + setAppendEntryHandler: (handler: (customType: string, data?: any) => void) => { + appendEntryHandler = handler; }, }; }); @@ -513,11 +530,11 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {} customToolsResult = result; } - let hookRunner: HookRunner | null = null; + let hookRunner: HookRunner | undefined; if (options.hooks !== undefined) { if (options.hooks.length > 0) { const loadedHooks = createLoadedHooksFromDefinitions(options.hooks); - hookRunner = new HookRunner(loadedHooks, cwd, settingsManager.getHookTimeout()); + hookRunner = new HookRunner(loadedHooks, cwd, sessionManager, modelRegistry, settingsManager.getHookTimeout()); } } else { // Discover hooks, merging with additional paths @@ -528,7 +545,7 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {} console.error(`Failed to load hook "${path}": ${error}`); } if (hooks.length > 0) { - hookRunner = new HookRunner(hooks, cwd, settingsManager.getHookTimeout()); + hookRunner = new HookRunner(hooks, cwd, sessionManager, modelRegistry, settingsManager.getHookTimeout()); } } @@ -571,21 +588,24 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {} thinkingLevel, tools: allToolsArray, }, - messageTransformer, + convertToLlm, + transformContext: hookRunner + ? 
async (messages) => { + return hookRunner.emitContext(messages); + } + : undefined, queueMode: settingsManager.getQueueMode(), - transport: new ProviderTransport({ - getApiKey: async () => { - const currentModel = agent.state.model; - if (!currentModel) { - throw new Error("No model selected"); - } - const key = await modelRegistry.getApiKey(currentModel); - if (!key) { - throw new Error(`No API key found for provider "${currentModel.provider}"`); - } - return key; - }, - }), + getApiKey: async () => { + const currentModel = agent.state.model; + if (!currentModel) { + throw new Error("No model selected"); + } + const key = await modelRegistry.getApiKey(currentModel); + if (!key) { + throw new Error(`No API key found for provider "${currentModel.provider}"`); + } + return key; + }, }); time("createAgent"); diff --git a/packages/coding-agent/src/core/session-manager.ts b/packages/coding-agent/src/core/session-manager.ts index dd118afd..c3e17714 100644 --- a/packages/coding-agent/src/core/session-manager.ts +++ b/packages/coding-agent/src/core/session-manager.ts @@ -1,61 +1,150 @@ -import type { AppMessage } from "@mariozechner/pi-agent-core"; -import { randomBytes } from "crypto"; -import { appendFileSync, existsSync, mkdirSync, readdirSync, readFileSync, statSync } from "fs"; +import type { AgentMessage } from "@mariozechner/pi-agent-core"; +import type { ImageContent, Message, TextContent } from "@mariozechner/pi-ai"; +import { randomUUID } from "crypto"; +import { + appendFileSync, + closeSync, + existsSync, + mkdirSync, + openSync, + readdirSync, + readFileSync, + readSync, + statSync, + writeFileSync, +} from "fs"; import { join, resolve } from "path"; import { getAgentDir as getDefaultAgentDir } from "../config.js"; +import { + type BashExecutionMessage, + createBranchSummaryMessage, + createCompactionSummaryMessage, + createHookMessage, + type HookMessage, +} from "./messages.js"; -function uuidv4(): string { - const bytes = randomBytes(16); - bytes[6] = 
(bytes[6] & 0x0f) | 0x40; - bytes[8] = (bytes[8] & 0x3f) | 0x80; - const hex = bytes.toString("hex"); - return `${hex.slice(0, 8)}-${hex.slice(8, 12)}-${hex.slice(12, 16)}-${hex.slice(16, 20)}-${hex.slice(20, 32)}`; -} +export const CURRENT_SESSION_VERSION = 2; export interface SessionHeader { type: "session"; + version?: number; // v1 sessions don't have this id: string; timestamp: string; cwd: string; branchedFrom?: string; } -export interface SessionMessageEntry { - type: "message"; +export interface SessionEntryBase { + type: string; + id: string; + parentId: string | null; timestamp: string; - message: AppMessage; } -export interface ThinkingLevelChangeEntry { +export interface SessionMessageEntry extends SessionEntryBase { + type: "message"; + message: AgentMessage; +} + +export interface ThinkingLevelChangeEntry extends SessionEntryBase { type: "thinking_level_change"; - timestamp: string; thinkingLevel: string; } -export interface ModelChangeEntry { +export interface ModelChangeEntry extends SessionEntryBase { type: "model_change"; - timestamp: string; provider: string; modelId: string; } -export interface CompactionEntry { +export interface CompactionEntry extends SessionEntryBase { type: "compaction"; - timestamp: string; summary: string; - firstKeptEntryIndex: number; + firstKeptEntryId: string; tokensBefore: number; + /** Hook-specific data (e.g., ArtifactIndex, version markers for structured compaction) */ + details?: T; + /** True if generated by a hook, undefined/false if pi-generated (backward compatible) */ + fromHook?: boolean; } +export interface BranchSummaryEntry extends SessionEntryBase { + type: "branch_summary"; + fromId: string; + summary: string; + /** Hook-specific data (not sent to LLM) */ + details?: T; + /** True if generated by a hook, false if pi-generated */ + fromHook?: boolean; +} + +/** + * Custom entry for hooks to store hook-specific data in the session. + * Use customType to identify your hook's entries. 
+ * + * Purpose: Persist hook state across session reloads. On reload, hooks can + * scan entries for their customType and reconstruct internal state. + * + * Does NOT participate in LLM context (ignored by buildSessionContext). + * For injecting content into context, see CustomMessageEntry. + */ +export interface CustomEntry extends SessionEntryBase { + type: "custom"; + customType: string; + data?: T; +} + +/** Label entry for user-defined bookmarks/markers on entries. */ +export interface LabelEntry extends SessionEntryBase { + type: "label"; + targetId: string; + label: string | undefined; +} + +/** + * Custom message entry for hooks to inject messages into LLM context. + * Use customType to identify your hook's entries. + * + * Unlike CustomEntry, this DOES participate in LLM context. + * The content is converted to a user message in buildSessionContext(). + * Use details for hook-specific metadata (not sent to LLM). + * + * display controls TUI rendering: + * - false: hidden entirely + * - true: rendered with distinct styling (different from user messages) + */ +export interface CustomMessageEntry extends SessionEntryBase { + type: "custom_message"; + customType: string; + content: string | (TextContent | ImageContent)[]; + details?: T; + display: boolean; +} + +/** Session entry - has id/parentId for tree structure (returned by "read" methods in SessionManager) */ export type SessionEntry = - | SessionHeader | SessionMessageEntry | ThinkingLevelChangeEntry | ModelChangeEntry - | CompactionEntry; + | CompactionEntry + | BranchSummaryEntry + | CustomEntry + | CustomMessageEntry + | LabelEntry; + +/** Raw file entry (includes header) */ +export type FileEntry = SessionHeader | SessionEntry; + +/** Tree node for getTree() - defensive copy of session structure */ +export interface SessionTreeNode { + entry: SessionEntry; + children: SessionTreeNode[]; + /** Resolved label for this entry, if any */ + label?: string; +} export interface SessionContext { - messages: 
AppMessage[]; + messages: AgentMessage[]; thinkingLevel: string; model: { provider: string; modelId: string } | null; } @@ -70,32 +159,92 @@ export interface SessionInfo { allMessagesText: string; } -export const SUMMARY_PREFIX = `The conversation history before this point was compacted into the following summary: +/** + * Read-only interface for SessionManager. + * Used by compaction/summarization utilities that only need to read session data. + */ +export interface ReadonlySessionManager { + getLeafId(): string | null; + getEntry(id: string): SessionEntry | undefined; + getPath(fromId?: string): SessionEntry[]; + getEntries(): SessionEntry[]; + getChildren(parentId: string): SessionEntry[]; + getTree(): SessionTreeNode[]; + getLabel(id: string): string | undefined; +} - -`; +/** Generate a unique short ID (8 hex chars, collision-checked) */ +function generateId(byId: { has(id: string): boolean }): string { + for (let i = 0; i < 100; i++) { + const id = randomUUID().slice(0, 8); + if (!byId.has(id)) return id; + } + // Fallback to full UUID if somehow we have collisions + return randomUUID(); +} -export const SUMMARY_SUFFIX = ` -`; +/** Migrate v1 → v2: add id/parentId tree structure. Mutates in place. 
*/ +function migrateV1ToV2(entries: FileEntry[]): void { + const ids = new Set(); + let prevId: string | null = null; -/** Exported for compaction.test.ts */ -export function createSummaryMessage(summary: string): AppMessage { - return { - role: "user", - content: SUMMARY_PREFIX + summary + SUMMARY_SUFFIX, - timestamp: Date.now(), - }; + for (const entry of entries) { + if (entry.type === "session") { + entry.version = 2; + continue; + } + + entry.id = generateId(ids); + entry.parentId = prevId; + prevId = entry.id; + + // Convert firstKeptEntryIndex to firstKeptEntryId for compaction + if (entry.type === "compaction") { + const comp = entry as CompactionEntry & { firstKeptEntryIndex?: number }; + if (typeof comp.firstKeptEntryIndex === "number") { + const targetEntry = entries[comp.firstKeptEntryIndex]; + if (targetEntry && targetEntry.type !== "session") { + comp.firstKeptEntryId = targetEntry.id; + } + delete comp.firstKeptEntryIndex; + } + } + } +} + +// Add future migrations here: +// function migrateV2ToV3(entries: FileEntry[]): void { ... } + +/** + * Run all necessary migrations to bring entries to current version. + * Mutates entries in place. Returns true if any migration was applied. + */ +function migrateToCurrentVersion(entries: FileEntry[]): boolean { + const header = entries.find((e) => e.type === "session") as SessionHeader | undefined; + const version = header?.version ?? 
1; + + if (version >= CURRENT_SESSION_VERSION) return false; + + if (version < 2) migrateV1ToV2(entries); + // if (version < 3) migrateV2ToV3(entries); + + return true; +} + +/** Exported for testing */ +export function migrateSessionEntries(entries: FileEntry[]): void { + migrateToCurrentVersion(entries); } /** Exported for compaction.test.ts */ -export function parseSessionEntries(content: string): SessionEntry[] { - const entries: SessionEntry[] = []; +export function parseSessionEntries(content: string): FileEntry[] { + const entries: FileEntry[] = []; const lines = content.trim().split("\n"); for (const line of lines) { if (!line.trim()) continue; try { - const entry = JSON.parse(line) as SessionEntry; + const entry = JSON.parse(line) as FileEntry; entries.push(entry); } catch { // Skip malformed lines @@ -115,59 +264,116 @@ export function getLatestCompactionEntry(entries: SessionEntry[]): CompactionEnt } /** - * Build the session context from entries. This is what gets sent to the LLM. - * - * If there's a compaction entry, returns the summary message plus messages - * from `firstKeptEntryIndex` onwards. Otherwise returns all messages. - * - * Also extracts the current thinking level and model from the entries. + * Build the session context from entries using tree traversal. + * If leafId is provided, walks from that entry to root. + * Handles compaction and branch summaries along the path. 
*/ -export function buildSessionContext(entries: SessionEntry[]): SessionContext { +export function buildSessionContext( + entries: SessionEntry[], + leafId?: string | null, + byId?: Map, +): SessionContext { + // Build uuid index if not available + if (!byId) { + byId = new Map(); + for (const entry of entries) { + byId.set(entry.id, entry); + } + } + + // Find leaf + let leaf: SessionEntry | undefined; + if (leafId === null) { + // Explicitly null - return no messages (navigated to before first entry) + return { messages: [], thinkingLevel: "off", model: null }; + } + if (leafId) { + leaf = byId.get(leafId); + } + if (!leaf) { + // Fallback to last entry (when leafId is undefined) + leaf = entries[entries.length - 1]; + } + + if (!leaf) { + return { messages: [], thinkingLevel: "off", model: null }; + } + + // Walk from leaf to root, collecting path + const path: SessionEntry[] = []; + let current: SessionEntry | undefined = leaf; + while (current) { + path.unshift(current); + current = current.parentId ? 
byId.get(current.parentId) : undefined; + } + + // Extract settings and find compaction let thinkingLevel = "off"; let model: { provider: string; modelId: string } | null = null; + let compaction: CompactionEntry | null = null; - for (const entry of entries) { + for (const entry of path) { if (entry.type === "thinking_level_change") { thinkingLevel = entry.thinkingLevel; } else if (entry.type === "model_change") { model = { provider: entry.provider, modelId: entry.modelId }; } else if (entry.type === "message" && entry.message.role === "assistant") { model = { provider: entry.message.provider, modelId: entry.message.model }; + } else if (entry.type === "compaction") { + compaction = entry; } } - let latestCompactionIndex = -1; - for (let i = entries.length - 1; i >= 0; i--) { - if (entries[i].type === "compaction") { - latestCompactionIndex = i; - break; - } - } + // Build messages and collect corresponding entries + // When there's a compaction, we need to: + // 1. Emit summary first (entry = compaction) + // 2. Emit kept messages (from firstKeptEntryId up to compaction) + // 3. 
Emit messages after compaction + const messages: AgentMessage[] = []; - if (latestCompactionIndex === -1) { - const messages: AppMessage[] = []; - for (const entry of entries) { - if (entry.type === "message") { - messages.push(entry.message); + const appendMessage = (entry: SessionEntry) => { + if (entry.type === "message") { + messages.push(entry.message); + } else if (entry.type === "custom_message") { + messages.push( + createHookMessage(entry.customType, entry.content, entry.display, entry.details, entry.timestamp), + ); + } else if (entry.type === "branch_summary" && entry.summary) { + messages.push(createBranchSummaryMessage(entry.summary, entry.fromId, entry.timestamp)); + } + }; + + if (compaction) { + // Emit summary first + messages.push(createCompactionSummaryMessage(compaction.summary, compaction.tokensBefore, compaction.timestamp)); + + // Find compaction index in path + const compactionIdx = path.findIndex((e) => e.type === "compaction" && e.id === compaction.id); + + // Emit kept messages (before compaction, starting from firstKeptEntryId) + let foundFirstKept = false; + for (let i = 0; i < compactionIdx; i++) { + const entry = path[i]; + if (entry.id === compaction.firstKeptEntryId) { + foundFirstKept = true; + } + if (foundFirstKept) { + appendMessage(entry); } } - return { messages, thinkingLevel, model }; - } - const compactionEvent = entries[latestCompactionIndex] as CompactionEntry; - - const keptMessages: AppMessage[] = []; - for (let i = compactionEvent.firstKeptEntryIndex; i < entries.length; i++) { - const entry = entries[i]; - if (entry.type === "message") { - keptMessages.push(entry.message); + // Emit messages after compaction + for (let i = compactionIdx + 1; i < path.length; i++) { + const entry = path[i]; + appendMessage(entry); + } + } else { + // No compaction - emit all messages, handle branch summaries and custom messages + for (const entry of path) { + appendMessage(entry); } } - const messages: AppMessage[] = []; - 
messages.push(createSummaryMessage(compactionEvent.summary)); - messages.push(...keptMessages); - return { messages, thinkingLevel, model }; } @@ -184,34 +390,57 @@ function getDefaultSessionDir(cwd: string): string { return sessionDir; } -function loadEntriesFromFile(filePath: string): SessionEntry[] { +/** Exported for testing */ +export function loadEntriesFromFile(filePath: string): FileEntry[] { if (!existsSync(filePath)) return []; const content = readFileSync(filePath, "utf8"); - const entries: SessionEntry[] = []; + const entries: FileEntry[] = []; const lines = content.trim().split("\n"); for (const line of lines) { if (!line.trim()) continue; try { - const entry = JSON.parse(line) as SessionEntry; + const entry = JSON.parse(line) as FileEntry; entries.push(entry); } catch { // Skip malformed lines } } + // Validate session header + if (entries.length === 0) return entries; + const header = entries[0]; + if (header.type !== "session" || typeof (header as any).id !== "string") { + return []; + } + return entries; } -function findMostRecentSession(sessionDir: string): string | null { +function isValidSessionFile(filePath: string): boolean { + try { + const fd = openSync(filePath, "r"); + const buffer = Buffer.alloc(512); + const bytesRead = readSync(fd, buffer, 0, 512, 0); + closeSync(fd); + const firstLine = buffer.toString("utf8", 0, bytesRead).split("\n")[0]; + if (!firstLine) return false; + const header = JSON.parse(firstLine); + return header.type === "session" && typeof header.id === "string"; + } catch { + return false; + } +} + +/** Exported for testing */ +export function findMostRecentSession(sessionDir: string): string | null { try { const files = readdirSync(sessionDir) .filter((f) => f.endsWith(".jsonl")) - .map((f) => ({ - path: join(sessionDir, f), - mtime: statSync(join(sessionDir, f)).mtime, - })) + .map((f) => join(sessionDir, f)) + .filter(isValidSessionFile) + .map((path) => ({ path, mtime: statSync(path).mtime })) .sort((a, b) => 
b.mtime.getTime() - a.mtime.getTime()); return files[0]?.path || null; @@ -220,30 +449,41 @@ function findMostRecentSession(sessionDir: string): string | null { } } +/** + * Manages conversation sessions as append-only trees stored in JSONL files. + * + * Each session entry has an id and parentId forming a tree structure. The "leaf" + * pointer tracks the current position. Appending creates a child of the current leaf. + * Branching moves the leaf to an earlier entry, allowing new branches without + * modifying history. + * + * Use buildSessionContext() to get the resolved message list for the LLM, which + * handles compaction summaries and follows the path from root to current leaf. + */ export class SessionManager { private sessionId: string = ""; - private sessionFile: string = ""; + private sessionFile: string | undefined; private sessionDir: string; private cwd: string; private persist: boolean; private flushed: boolean = false; - private inMemoryEntries: SessionEntry[] = []; + private fileEntries: FileEntry[] = []; + private byId: Map = new Map(); + private labelsById: Map = new Map(); + private leafId: string | null = null; - private constructor(cwd: string, sessionDir: string, sessionFile: string | null, persist: boolean) { + private constructor(cwd: string, sessionDir: string, sessionFile: string | undefined, persist: boolean) { this.cwd = cwd; this.sessionDir = sessionDir; + this.persist = persist; if (persist && sessionDir && !existsSync(sessionDir)) { mkdirSync(sessionDir, { recursive: true }); } - this.persist = persist; if (sessionFile) { this.setSessionFile(sessionFile); } else { - this.sessionId = uuidv4(); - const timestamp = new Date().toISOString().replace(/[:.]/g, "-"); - const sessionFile = join(this.getSessionDir(), `${timestamp}_${this.sessionId}.jsonl`); - this.setSessionFile(sessionFile); + this.newSession(); } } @@ -251,24 +491,67 @@ export class SessionManager { setSessionFile(sessionFile: string): void { this.sessionFile = 
resolve(sessionFile); if (existsSync(this.sessionFile)) { - this.inMemoryEntries = loadEntriesFromFile(this.sessionFile); - const header = this.inMemoryEntries.find((e) => e.type === "session"); - this.sessionId = header ? (header as SessionHeader).id : uuidv4(); + this.fileEntries = loadEntriesFromFile(this.sessionFile); + const header = this.fileEntries.find((e) => e.type === "session") as SessionHeader | undefined; + this.sessionId = header?.id ?? randomUUID(); + + if (migrateToCurrentVersion(this.fileEntries)) { + this._rewriteFile(); + } + + this._buildIndex(); this.flushed = true; } else { - this.sessionId = uuidv4(); - this.inMemoryEntries = []; - this.flushed = false; - const entry: SessionHeader = { - type: "session", - id: this.sessionId, - timestamp: new Date().toISOString(), - cwd: this.cwd, - }; - this.inMemoryEntries.push(entry); + this.newSession(); } } + newSession(): string | undefined { + this.sessionId = randomUUID(); + const timestamp = new Date().toISOString(); + const header: SessionHeader = { + type: "session", + version: CURRENT_SESSION_VERSION, + id: this.sessionId, + timestamp, + cwd: this.cwd, + }; + this.fileEntries = [header]; + this.byId.clear(); + this.leafId = null; + this.flushed = false; + // Only generate filename if persisting and not already set (e.g., via --session flag) + if (this.persist && !this.sessionFile) { + const fileTimestamp = timestamp.replace(/[:.]/g, "-"); + this.sessionFile = join(this.getSessionDir(), `${fileTimestamp}_${this.sessionId}.jsonl`); + } + return this.sessionFile; + } + + private _buildIndex(): void { + this.byId.clear(); + this.labelsById.clear(); + this.leafId = null; + for (const entry of this.fileEntries) { + if (entry.type === "session") continue; + this.byId.set(entry.id, entry); + this.leafId = entry.id; + if (entry.type === "label") { + if (entry.label) { + this.labelsById.set(entry.targetId, entry.label); + } else { + this.labelsById.delete(entry.targetId); + } + } + } + } + + private 
_rewriteFile(): void { + if (!this.persist || !this.sessionFile) return; + const content = `${this.fileEntries.map((e) => JSON.stringify(e)).join("\n")}\n`; + writeFileSync(this.sessionFile, content); + } + isPersisted(): boolean { return this.persist; } @@ -285,33 +568,18 @@ export class SessionManager { return this.sessionId; } - getSessionFile(): string { + getSessionFile(): string | undefined { return this.sessionFile; } - reset(): void { - this.sessionId = uuidv4(); - this.flushed = false; - const timestamp = new Date().toISOString().replace(/[:.]/g, "-"); - this.sessionFile = join(this.getSessionDir(), `${timestamp}_${this.sessionId}.jsonl`); - this.inMemoryEntries = [ - { - type: "session", - id: this.sessionId, - timestamp: new Date().toISOString(), - cwd: this.cwd, - }, - ]; - } - _persist(entry: SessionEntry): void { - if (!this.persist) return; + if (!this.persist || !this.sessionFile) return; - const hasAssistant = this.inMemoryEntries.some((e) => e.type === "message" && e.message.role === "assistant"); + const hasAssistant = this.fileEntries.some((e) => e.type === "message" && e.message.role === "assistant"); if (!hasAssistant) return; if (!this.flushed) { - for (const e of this.inMemoryEntries) { + for (const e of this.fileEntries) { appendFileSync(this.sessionFile, `${JSON.stringify(e)}\n`); } this.flushed = true; @@ -320,89 +588,405 @@ export class SessionManager { } } - saveMessage(message: AppMessage): void { + private _appendEntry(entry: SessionEntry): void { + this.fileEntries.push(entry); + this.byId.set(entry.id, entry); + this.leafId = entry.id; + this._persist(entry); + } + + /** Append a message as child of current leaf, then advance leaf. Returns entry id. + * Does not allow writing CompactionSummaryMessage and BranchSummaryMessage directly. + * Reason: we want these to be top-level entries in the session, not message session entries, + * so it is easier to find them. 
+ * These need to be appended via appendCompaction() and appendBranchSummary() methods. + */ + appendMessage(message: Message | HookMessage | BashExecutionMessage): string { const entry: SessionMessageEntry = { type: "message", + id: generateId(this.byId), + parentId: this.leafId, timestamp: new Date().toISOString(), message, }; - this.inMemoryEntries.push(entry); - this._persist(entry); + this._appendEntry(entry); + return entry.id; } - saveThinkingLevelChange(thinkingLevel: string): void { + /** Append a thinking level change as child of current leaf, then advance leaf. Returns entry id. */ + appendThinkingLevelChange(thinkingLevel: string): string { const entry: ThinkingLevelChangeEntry = { type: "thinking_level_change", + id: generateId(this.byId), + parentId: this.leafId, timestamp: new Date().toISOString(), thinkingLevel, }; - this.inMemoryEntries.push(entry); - this._persist(entry); + this._appendEntry(entry); + return entry.id; } - saveModelChange(provider: string, modelId: string): void { + /** Append a model change as child of current leaf, then advance leaf. Returns entry id. */ + appendModelChange(provider: string, modelId: string): string { const entry: ModelChangeEntry = { type: "model_change", + id: generateId(this.byId), + parentId: this.leafId, timestamp: new Date().toISOString(), provider, modelId, }; - this.inMemoryEntries.push(entry); - this._persist(entry); + this._appendEntry(entry); + return entry.id; } - saveCompaction(entry: CompactionEntry): void { - this.inMemoryEntries.push(entry); - this._persist(entry); + /** Append a compaction summary as child of current leaf, then advance leaf. Returns entry id. 
*/ + appendCompaction( + summary: string, + firstKeptEntryId: string, + tokensBefore: number, + details?: T, + fromHook?: boolean, + ): string { + const entry: CompactionEntry = { + type: "compaction", + id: generateId(this.byId), + parentId: this.leafId, + timestamp: new Date().toISOString(), + summary, + firstKeptEntryId, + tokensBefore, + details, + fromHook, + }; + this._appendEntry(entry); + return entry.id; + } + + /** Append a custom entry (for hooks) as child of current leaf, then advance leaf. Returns entry id. */ + appendCustomEntry(customType: string, data?: unknown): string { + const entry: CustomEntry = { + type: "custom", + customType, + data, + id: generateId(this.byId), + parentId: this.leafId, + timestamp: new Date().toISOString(), + }; + this._appendEntry(entry); + return entry.id; + } + + /** + * Append a custom message entry (for hooks) that participates in LLM context. + * @param customType Hook identifier for filtering on reload + * @param content Message content (string or TextContent/ImageContent array) + * @param display Whether to show in TUI (true = styled display, false = hidden) + * @param details Optional hook-specific metadata (not sent to LLM) + * @returns Entry id + */ + appendCustomMessageEntry( + customType: string, + content: string | (TextContent | ImageContent)[], + display: boolean, + details?: T, + ): string { + const entry: CustomMessageEntry = { + type: "custom_message", + customType, + content, + display, + details, + id: generateId(this.byId), + parentId: this.leafId, + timestamp: new Date().toISOString(), + }; + this._appendEntry(entry); + return entry.id; + } + + // ========================================================================= + // Tree Traversal + // ========================================================================= + + getLeafId(): string | null { + return this.leafId; + } + + getLeafEntry(): SessionEntry | undefined { + return this.leafId ? 
this.byId.get(this.leafId) : undefined; + } + + getEntry(id: string): SessionEntry | undefined { + return this.byId.get(id); + } + + /** + * Get all direct children of an entry. + */ + getChildren(parentId: string): SessionEntry[] { + const children: SessionEntry[] = []; + for (const entry of this.byId.values()) { + if (entry.parentId === parentId) { + children.push(entry); + } + } + return children; + } + + /** + * Get the label for an entry, if any. + */ + getLabel(id: string): string | undefined { + return this.labelsById.get(id); + } + + /** + * Set or clear a label on an entry. + * Labels are user-defined markers for bookmarking/navigation. + * Pass undefined or empty string to clear the label. + */ + appendLabelChange(targetId: string, label: string | undefined): string { + if (!this.byId.has(targetId)) { + throw new Error(`Entry ${targetId} not found`); + } + const entry: LabelEntry = { + type: "label", + id: generateId(this.byId), + parentId: this.leafId, + timestamp: new Date().toISOString(), + targetId, + label, + }; + this._appendEntry(entry); + if (label) { + this.labelsById.set(targetId, label); + } else { + this.labelsById.delete(targetId); + } + return entry.id; + } + + /** + * Walk from entry to root, returning all entries in path order. + * Includes all entry types (messages, compaction, model changes, etc.). + * Use buildSessionContext() to get the resolved messages for the LLM. + */ + getPath(fromId?: string): SessionEntry[] { + const path: SessionEntry[] = []; + const startId = fromId ?? this.leafId; + let current = startId ? this.byId.get(startId) : undefined; + while (current) { + path.unshift(current); + current = current.parentId ? this.byId.get(current.parentId) : undefined; + } + return path; } /** * Build the session context (what gets sent to the LLM). - * If compacted, returns summary + kept messages. Otherwise all messages. - * Includes thinking level and model. + * Uses tree traversal from current leaf. 
*/ buildSessionContext(): SessionContext { - return buildSessionContext(this.getEntries()); + return buildSessionContext(this.getEntries(), this.leafId, this.byId); } /** - * Get all session entries. Returns a defensive copy. - * Use buildSessionContext() if you need the messages for the LLM. + * Get session header. */ - getEntries(): SessionEntry[] { - return [...this.inMemoryEntries]; + getHeader(): SessionHeader | null { + const h = this.fileEntries.find((e) => e.type === "session"); + return h ? (h as SessionHeader) : null; } - createBranchedSessionFromEntries(entries: SessionEntry[], branchBeforeIndex: number): string | null { - const newSessionId = uuidv4(); - const timestamp = new Date().toISOString().replace(/[:.]/g, "-"); - const newSessionFile = join(this.getSessionDir(), `${timestamp}_${newSessionId}.jsonl`); + /** + * Get all session entries (excludes header). Returns a shallow copy. + * The session is append-only: use appendXXX() to add entries, branch() to + * change the leaf pointer. Entries cannot be modified or deleted. + */ + getEntries(): SessionEntry[] { + return this.fileEntries.filter((e): e is SessionEntry => e.type !== "session"); + } - const newEntries: SessionEntry[] = []; - for (let i = 0; i < branchBeforeIndex; i++) { - const entry = entries[i]; + /** + * Get the session as a tree structure. Returns a shallow defensive copy of all entries. + * A well-formed session has exactly one root (first entry with parentId === null). + * Orphaned entries (broken parent chain) are also returned as roots. + */ + getTree(): SessionTreeNode[] { + const entries = this.getEntries(); + const nodeMap = new Map(); + const roots: SessionTreeNode[] = []; - if (entry.type === "session") { - newEntries.push({ - ...entry, - id: newSessionId, - timestamp: new Date().toISOString(), - branchedFrom: this.persist ? 
this.sessionFile : undefined, - }); + // Create nodes with resolved labels + for (const entry of entries) { + const label = this.labelsById.get(entry.id); + nodeMap.set(entry.id, { entry, children: [], label }); + } + + // Build tree + for (const entry of entries) { + const node = nodeMap.get(entry.id)!; + if (entry.parentId === null || entry.parentId === entry.id) { + roots.push(node); } else { - newEntries.push(entry); + const parent = nodeMap.get(entry.parentId); + if (parent) { + parent.children.push(node); + } else { + // Orphan - treat as root + roots.push(node); + } + } + } + + // Sort children by timestamp (oldest first, newest at bottom) + // Use iterative approach to avoid stack overflow on deep trees + const stack: SessionTreeNode[] = [...roots]; + while (stack.length > 0) { + const node = stack.pop()!; + node.children.sort((a, b) => new Date(a.entry.timestamp).getTime() - new Date(b.entry.timestamp).getTime()); + stack.push(...node.children); + } + + return roots; + } + + // ========================================================================= + // Branching + // ========================================================================= + + /** + * Start a new branch from an earlier entry. + * Moves the leaf pointer to the specified entry. The next appendXXX() call + * will create a child of that entry, forming a new branch. Existing entries + * are not modified or deleted. + */ + branch(branchFromId: string): void { + if (!this.byId.has(branchFromId)) { + throw new Error(`Entry ${branchFromId} not found`); + } + this.leafId = branchFromId; + } + + /** + * Reset the leaf pointer to null (before any entries). + * The next appendXXX() call will create a new root entry (parentId = null). + * Use this when navigating to re-edit the first user message. + */ + resetLeaf(): void { + this.leafId = null; + } + + /** + * Start a new branch with a summary of the abandoned path. 
+ * Same as branch(), but also appends a branch_summary entry that captures + * context from the abandoned conversation path. + */ + branchWithSummary(branchFromId: string | null, summary: string, details?: unknown, fromHook?: boolean): string { + if (branchFromId !== null && !this.byId.has(branchFromId)) { + throw new Error(`Entry ${branchFromId} not found`); + } + this.leafId = branchFromId; + const entry: BranchSummaryEntry = { + type: "branch_summary", + id: generateId(this.byId), + parentId: branchFromId, + timestamp: new Date().toISOString(), + fromId: branchFromId ?? "root", + summary, + details, + fromHook, + }; + this._appendEntry(entry); + return entry.id; + } + + /** + * Create a new session file containing only the path from root to the specified leaf. + * Useful for extracting a single conversation path from a branched session. + * Returns the new session file path, or undefined if not persisting. + */ + createBranchedSession(leafId: string): string | undefined { + const path = this.getPath(leafId); + if (path.length === 0) { + throw new Error(`Entry ${leafId} not found`); + } + + // Filter out LabelEntry from path - we'll recreate them from the resolved map + const pathWithoutLabels = path.filter((e) => e.type !== "label"); + + const newSessionId = randomUUID(); + const timestamp = new Date().toISOString(); + const fileTimestamp = timestamp.replace(/[:.]/g, "-"); + const newSessionFile = join(this.getSessionDir(), `${fileTimestamp}_${newSessionId}.jsonl`); + + const header: SessionHeader = { + type: "session", + version: CURRENT_SESSION_VERSION, + id: newSessionId, + timestamp, + cwd: this.cwd, + branchedFrom: this.persist ? 
this.sessionFile : undefined, + }; + + // Collect labels for entries in the path + const pathEntryIds = new Set(pathWithoutLabels.map((e) => e.id)); + const labelsToWrite: Array<{ targetId: string; label: string }> = []; + for (const [targetId, label] of this.labelsById) { + if (pathEntryIds.has(targetId)) { + labelsToWrite.push({ targetId, label }); } } if (this.persist) { - for (const entry of newEntries) { + appendFileSync(newSessionFile, `${JSON.stringify(header)}\n`); + for (const entry of pathWithoutLabels) { appendFileSync(newSessionFile, `${JSON.stringify(entry)}\n`); } + // Write fresh label entries at the end + const lastEntryId = pathWithoutLabels[pathWithoutLabels.length - 1]?.id || null; + let parentId = lastEntryId; + const labelEntries: LabelEntry[] = []; + for (const { targetId, label } of labelsToWrite) { + const labelEntry: LabelEntry = { + type: "label", + id: generateId(new Set(pathEntryIds)), + parentId, + timestamp: new Date().toISOString(), + targetId, + label, + }; + appendFileSync(newSessionFile, `${JSON.stringify(labelEntry)}\n`); + pathEntryIds.add(labelEntry.id); + labelEntries.push(labelEntry); + parentId = labelEntry.id; + } + this.fileEntries = [header, ...pathWithoutLabels, ...labelEntries]; + this.sessionId = newSessionId; + this._buildIndex(); return newSessionFile; } - this.inMemoryEntries = newEntries; + + // In-memory mode: replace current session with the path + labels + const labelEntries: LabelEntry[] = []; + let parentId = pathWithoutLabels[pathWithoutLabels.length - 1]?.id || null; + for (const { targetId, label } of labelsToWrite) { + const labelEntry: LabelEntry = { + type: "label", + id: generateId(new Set([...pathEntryIds, ...labelEntries.map((e) => e.id)])), + parentId, + timestamp: new Date().toISOString(), + targetId, + label, + }; + labelEntries.push(labelEntry); + parentId = labelEntry.id; + } + this.fileEntries = [header, ...pathWithoutLabels, ...labelEntries]; this.sessionId = newSessionId; - return null; + 
this._buildIndex(); + return undefined; } /** @@ -412,7 +996,7 @@ export class SessionManager { */ static create(cwd: string, sessionDir?: string): SessionManager { const dir = sessionDir ?? getDefaultSessionDir(cwd); - return new SessionManager(cwd, dir, null, true); + return new SessionManager(cwd, dir, undefined, true); } /** @@ -441,12 +1025,12 @@ export class SessionManager { if (mostRecent) { return new SessionManager(cwd, dir, mostRecent, true); } - return new SessionManager(cwd, dir, null, true); + return new SessionManager(cwd, dir, undefined, true); } /** Create an in-memory session (no file persistence) */ static inMemory(cwd: string = process.cwd()): SessionManager { - return new SessionManager(cwd, "", null, false); + return new SessionManager(cwd, "", undefined, false); } /** diff --git a/packages/coding-agent/src/core/settings-manager.ts b/packages/coding-agent/src/core/settings-manager.ts index 1e922c08..3a116d63 100644 --- a/packages/coding-agent/src/core/settings-manager.ts +++ b/packages/coding-agent/src/core/settings-manager.ts @@ -8,6 +8,10 @@ export interface CompactionSettings { keepRecentTokens?: number; // default: 20000 } +export interface BranchSummarySettings { + reserveTokens?: number; // default: 16384 (tokens reserved for prompt + LLM response) +} + export interface RetrySettings { enabled?: boolean; // default: true maxRetries?: number; // default: 3 @@ -38,6 +42,7 @@ export interface Settings { queueMode?: "all" | "one-at-a-time"; theme?: string; compaction?: CompactionSettings; + branchSummary?: BranchSummarySettings; retry?: RetrySettings; hideThinkingBlock?: boolean; shellPath?: string; // Custom shell path (e.g., for Cygwin users on Windows) @@ -255,6 +260,12 @@ export class SettingsManager { }; } + getBranchSummarySettings(): { reserveTokens: number } { + return { + reserveTokens: this.settings.branchSummary?.reserveTokens ?? 16384, + }; + } + getRetryEnabled(): boolean { return this.settings.retry?.enabled ?? 
true; } @@ -303,7 +314,7 @@ export class SettingsManager { } getHookPaths(): string[] { - return this.settings.hooks ?? []; + return [...(this.settings.hooks ?? [])]; } setHookPaths(paths: string[]): void { @@ -321,7 +332,7 @@ export class SettingsManager { } getCustomToolPaths(): string[] { - return this.settings.customTools ?? []; + return [...(this.settings.customTools ?? [])]; } setCustomToolPaths(paths: string[]): void { @@ -349,9 +360,9 @@ export class SettingsManager { enableClaudeProject: this.settings.skills?.enableClaudeProject ?? true, enablePiUser: this.settings.skills?.enablePiUser ?? true, enablePiProject: this.settings.skills?.enablePiProject ?? true, - customDirectories: this.settings.skills?.customDirectories ?? [], - ignoredSkills: this.settings.skills?.ignoredSkills ?? [], - includeSkills: this.settings.skills?.includeSkills ?? [], + customDirectories: [...(this.settings.skills?.customDirectories ?? [])], + ignoredSkills: [...(this.settings.skills?.ignoredSkills ?? [])], + includeSkills: [...(this.settings.skills?.includeSkills ?? 
[])], }; } diff --git a/packages/coding-agent/src/core/tools/bash.ts b/packages/coding-agent/src/core/tools/bash.ts index f11a677f..9d851922 100644 --- a/packages/coding-agent/src/core/tools/bash.ts +++ b/packages/coding-agent/src/core/tools/bash.ts @@ -2,7 +2,7 @@ import { randomBytes } from "node:crypto"; import { createWriteStream } from "node:fs"; import { tmpdir } from "node:os"; import { join } from "node:path"; -import type { AgentTool } from "@mariozechner/pi-ai"; +import type { AgentTool } from "@mariozechner/pi-agent-core"; import { Type } from "@sinclair/typebox"; import { spawn } from "child_process"; import { getShellConfig, killProcessTree } from "../../utils/shell.js"; diff --git a/packages/coding-agent/src/core/tools/edit.ts b/packages/coding-agent/src/core/tools/edit.ts index ff040091..a46a209d 100644 --- a/packages/coding-agent/src/core/tools/edit.ts +++ b/packages/coding-agent/src/core/tools/edit.ts @@ -1,4 +1,4 @@ -import type { AgentTool } from "@mariozechner/pi-ai"; +import type { AgentTool } from "@mariozechner/pi-agent-core"; import { Type } from "@sinclair/typebox"; import * as Diff from "diff"; import { constants } from "fs"; @@ -23,8 +23,13 @@ function restoreLineEndings(text: string, ending: "\r\n" | "\n"): string { /** * Generate a unified diff string with line numbers and context + * Returns both the diff string and the first changed line number (in the new file) */ -function generateDiffString(oldContent: string, newContent: string, contextLines = 4): string { +function generateDiffString( + oldContent: string, + newContent: string, + contextLines = 4, +): { diff: string; firstChangedLine: number | undefined } { const parts = Diff.diffLines(oldContent, newContent); const output: string[] = []; @@ -36,6 +41,7 @@ function generateDiffString(oldContent: string, newContent: string, contextLines let oldLineNum = 1; let newLineNum = 1; let lastWasChange = false; + let firstChangedLine: number | undefined; for (let i = 0; i < parts.length; 
i++) { const part = parts[i]; @@ -45,6 +51,11 @@ function generateDiffString(oldContent: string, newContent: string, contextLines } if (part.added || part.removed) { + // Capture the first changed line (in the new file) + if (firstChangedLine === undefined) { + firstChangedLine = newLineNum; + } + // Show the change for (const line of raw) { if (part.added) { @@ -113,7 +124,7 @@ function generateDiffString(oldContent: string, newContent: string, contextLines } } - return output.join("\n"); + return { diff: output.join("\n"), firstChangedLine }; } const editSchema = Type.Object({ @@ -122,6 +133,13 @@ const editSchema = Type.Object({ newText: Type.String({ description: "New text to replace the old text with" }), }); +export interface EditToolDetails { + /** Unified diff of the changes made */ + diff: string; + /** Line number of the first change in the new file (for editor navigation) */ + firstChangedLine?: number; +} + export function createEditTool(cwd: string): AgentTool { return { name: "edit", @@ -138,7 +156,7 @@ export function createEditTool(cwd: string): AgentTool { return new Promise<{ content: Array<{ type: "text"; text: string }>; - details: { diff: string } | undefined; + details: EditToolDetails | undefined; }>((resolve, reject) => { // Check if already aborted if (signal?.aborted) { @@ -257,6 +275,7 @@ export function createEditTool(cwd: string): AgentTool { signal.removeEventListener("abort", onAbort); } + const diffResult = generateDiffString(normalizedContent, normalizedNewContent); resolve({ content: [ { @@ -264,7 +283,7 @@ export function createEditTool(cwd: string): AgentTool { text: `Successfully replaced text in ${path}.`, }, ], - details: { diff: generateDiffString(normalizedContent, normalizedNewContent) }, + details: { diff: diffResult.diff, firstChangedLine: diffResult.firstChangedLine }, }); } catch (error: any) { // Clean up abort handler diff --git a/packages/coding-agent/src/core/tools/find.ts 
b/packages/coding-agent/src/core/tools/find.ts index 174a46d3..07c7694b 100644 --- a/packages/coding-agent/src/core/tools/find.ts +++ b/packages/coding-agent/src/core/tools/find.ts @@ -1,4 +1,4 @@ -import type { AgentTool } from "@mariozechner/pi-ai"; +import type { AgentTool } from "@mariozechner/pi-agent-core"; import { Type } from "@sinclair/typebox"; import { spawnSync } from "child_process"; import { existsSync } from "fs"; diff --git a/packages/coding-agent/src/core/tools/grep.ts b/packages/coding-agent/src/core/tools/grep.ts index 80996c9f..5402bd83 100644 --- a/packages/coding-agent/src/core/tools/grep.ts +++ b/packages/coding-agent/src/core/tools/grep.ts @@ -1,5 +1,5 @@ import { createInterface } from "node:readline"; -import type { AgentTool } from "@mariozechner/pi-ai"; +import type { AgentTool } from "@mariozechner/pi-agent-core"; import { Type } from "@sinclair/typebox"; import { spawn } from "child_process"; import { readFileSync, type Stats, statSync } from "fs"; diff --git a/packages/coding-agent/src/core/tools/index.ts b/packages/coding-agent/src/core/tools/index.ts index 539395fd..74701568 100644 --- a/packages/coding-agent/src/core/tools/index.ts +++ b/packages/coding-agent/src/core/tools/index.ts @@ -1,5 +1,3 @@ -import type { AgentTool } from "@mariozechner/pi-ai"; - export { type BashToolDetails, bashTool, createBashTool } from "./bash.js"; export { createEditTool, editTool } from "./edit.js"; export { createFindTool, type FindToolDetails, findTool } from "./find.js"; @@ -9,6 +7,7 @@ export { createReadTool, type ReadToolDetails, readTool } from "./read.js"; export type { TruncationResult } from "./truncate.js"; export { createWriteTool, writeTool } from "./write.js"; +import type { AgentTool } from "@mariozechner/pi-agent-core"; import { bashTool, createBashTool } from "./bash.js"; import { createEditTool, editTool } from "./edit.js"; import { createFindTool, findTool } from "./find.js"; diff --git a/packages/coding-agent/src/core/tools/ls.ts 
b/packages/coding-agent/src/core/tools/ls.ts index 4ebe26ed..ca27bfe4 100644 --- a/packages/coding-agent/src/core/tools/ls.ts +++ b/packages/coding-agent/src/core/tools/ls.ts @@ -1,4 +1,4 @@ -import type { AgentTool } from "@mariozechner/pi-ai"; +import type { AgentTool } from "@mariozechner/pi-agent-core"; import { Type } from "@sinclair/typebox"; import { existsSync, readdirSync, statSync } from "fs"; import nodePath from "path"; diff --git a/packages/coding-agent/src/core/tools/read.ts b/packages/coding-agent/src/core/tools/read.ts index 36d75818..da3a62b6 100644 --- a/packages/coding-agent/src/core/tools/read.ts +++ b/packages/coding-agent/src/core/tools/read.ts @@ -1,4 +1,5 @@ -import type { AgentTool, ImageContent, TextContent } from "@mariozechner/pi-ai"; +import type { AgentTool } from "@mariozechner/pi-agent-core"; +import type { ImageContent, TextContent } from "@mariozechner/pi-ai"; import { Type } from "@sinclair/typebox"; import { constants } from "fs"; import { access, readFile } from "fs/promises"; diff --git a/packages/coding-agent/src/core/tools/write.ts b/packages/coding-agent/src/core/tools/write.ts index 5aa2b336..02317b70 100644 --- a/packages/coding-agent/src/core/tools/write.ts +++ b/packages/coding-agent/src/core/tools/write.ts @@ -1,4 +1,4 @@ -import type { AgentTool } from "@mariozechner/pi-ai"; +import type { AgentTool } from "@mariozechner/pi-agent-core"; import { Type } from "@sinclair/typebox"; import { mkdir, writeFile } from "fs/promises"; import { dirname } from "path"; diff --git a/packages/coding-agent/src/index.ts b/packages/coding-agent/src/index.ts index 2683765e..1dae4688 100644 --- a/packages/coding-agent/src/index.ts +++ b/packages/coding-agent/src/index.ts @@ -4,7 +4,6 @@ export { type AgentSessionConfig, type AgentSessionEvent, type AgentSessionEventListener, - type CompactionResult, type ModelCycleResult, type PromptOptions, type SessionStats, @@ -13,17 +12,26 @@ export { export { type ApiKeyCredential, type 
AuthCredential, AuthStorage, type OAuthCredential } from "./core/auth-storage.js"; // Compaction export { + type BranchPreparation, + type BranchSummaryResult, + type CollectEntriesResult, + type CompactionResult, type CutPointResult, calculateContextTokens, + collectEntriesForBranchSummary, compact, DEFAULT_COMPACTION_SETTINGS, estimateTokens, + type FileOperations, findCutPoint, findTurnStartIndex, + type GenerateBranchSummaryOptions, + generateBranchSummary, generateSummary, getLastAssistantUsage, + prepareBranchEntries, shouldCompact, -} from "./core/compaction.js"; +} from "./core/compaction/index.js"; // Custom tools export type { AgentToolUpdateCallback, @@ -38,31 +46,7 @@ export type { ToolUIContext, } from "./core/custom-tools/index.js"; export { discoverAndLoadCustomTools, loadCustomTools } from "./core/custom-tools/index.js"; -export type { - AgentEndEvent, - AgentStartEvent, - BashToolResultEvent, - CustomToolResultEvent, - EditToolResultEvent, - FindToolResultEvent, - GrepToolResultEvent, - HookAPI, - HookEvent, - HookEventContext, - HookFactory, - HookUIContext, - LsToolResultEvent, - ReadToolResultEvent, - SessionEvent, - SessionEventResult, - ToolCallEvent, - ToolCallEventResult, - ToolResultEvent, - ToolResultEventResult, - TurnEndEvent, - TurnStartEvent, - WriteToolResultEvent, -} from "./core/hooks/index.js"; +export type * from "./core/hooks/index.js"; // Hook system types and type guards export { isBashToolResult, @@ -73,7 +57,7 @@ export { isReadToolResult, isWriteToolResult, } from "./core/hooks/index.js"; -export { messageTransformer } from "./core/messages.js"; +export { convertToLlm } from "./core/messages.js"; export { ModelRegistry } from "./core/model-registry.js"; // SDK for programmatic usage export { @@ -107,20 +91,24 @@ export { readOnlyTools, } from "./core/sdk.js"; export { + type BranchSummaryEntry, buildSessionContext, type CompactionEntry, - createSummaryMessage, + CURRENT_SESSION_VERSION, + type CustomEntry, + type 
CustomMessageEntry, + type FileEntry, getLatestCompactionEntry, type ModelChangeEntry, + migrateSessionEntries, parseSessionEntries, - type SessionContext as LoadedSession, + type SessionContext, type SessionEntry, + type SessionEntryBase, type SessionHeader, type SessionInfo, SessionManager, type SessionMessageEntry, - SUMMARY_PREFIX, - SUMMARY_SUFFIX, type ThinkingLevelChangeEntry, } from "./core/session-manager.js"; export { diff --git a/packages/coding-agent/src/main.ts b/packages/coding-agent/src/main.ts index 5edba853..9b67d748 100644 --- a/packages/coding-agent/src/main.ts +++ b/packages/coding-agent/src/main.ts @@ -5,8 +5,7 @@ * createAgentSession() options. The SDK does the heavy lifting. */ -import type { Attachment } from "@mariozechner/pi-agent-core"; -import { supportsXhigh } from "@mariozechner/pi-ai"; +import { type ImageContent, supportsXhigh } from "@mariozechner/pi-ai"; import chalk from "chalk"; import { existsSync } from "fs"; import { join } from "path"; @@ -34,10 +33,10 @@ import { initTheme, stopThemeWatcher } from "./modes/interactive/theme/theme.js" import { getChangelogPath, getNewEntries, parseChangelog } from "./utils/changelog.js"; import { ensureTool } from "./utils/tools-manager.js"; -async function checkForNewVersion(currentVersion: string): Promise { +async function checkForNewVersion(currentVersion: string): Promise { try { const response = await fetch("https://registry.npmjs.org/@mariozechner/pi -coding-agent/latest"); - if (!response.ok) return null; + if (!response.ok) return undefined; const data = (await response.json()) as { version?: string }; const latestVersion = data.version; @@ -46,26 +45,26 @@ async function checkForNewVersion(currentVersion: string): Promise, + versionCheckPromise: Promise, initialMessages: string[], customTools: LoadedCustomTool[], setToolUIContext: (uiContext: HookUIContext, hasUI: boolean) => void, initialMessage?: string, - initialAttachments?: Attachment[], - fdPath: string | null = null, + 
initialImages?: ImageContent[], + fdPath: string | undefined = undefined, ): Promise { const mode = new InteractiveMode(session, version, changelogMarkdown, customTools, setToolUIContext, fdPath); @@ -77,7 +76,7 @@ async function runInteractiveMode( } }); - mode.renderInitialMessages(session.state); + mode.renderInitialMessages(); if (migratedProviders.length > 0) { mode.showWarning(`Migrated credentials to auth.json: ${migratedProviders.join(", ")}`); @@ -93,7 +92,7 @@ async function runInteractiveMode( if (initialMessage) { try { - await session.prompt(initialMessage, { attachments: initialAttachments }); + await session.prompt(initialMessage, { images: initialImages }); } catch (error: unknown) { const errorMessage = error instanceof Error ? error.message : "Unknown error occurred"; mode.showError(errorMessage); @@ -122,31 +121,31 @@ async function runInteractiveMode( async function prepareInitialMessage(parsed: Args): Promise<{ initialMessage?: string; - initialAttachments?: Attachment[]; + initialImages?: ImageContent[]; }> { if (parsed.fileArgs.length === 0) { return {}; } - const { textContent, imageAttachments } = await processFileArguments(parsed.fileArgs); + const { text, images } = await processFileArguments(parsed.fileArgs); let initialMessage: string; if (parsed.messages.length > 0) { - initialMessage = textContent + parsed.messages[0]; + initialMessage = text + parsed.messages[0]; parsed.messages.shift(); } else { - initialMessage = textContent; + initialMessage = text; } return { initialMessage, - initialAttachments: imageAttachments.length > 0 ? imageAttachments : undefined, + initialImages: images.length > 0 ? 
images : undefined, }; } -function getChangelogForDisplay(parsed: Args, settingsManager: SettingsManager): string | null { +function getChangelogForDisplay(parsed: Args, settingsManager: SettingsManager): string | undefined { if (parsed.continue || parsed.resume) { - return null; + return undefined; } const lastVersion = settingsManager.getLastChangelogVersion(); @@ -166,10 +165,10 @@ function getChangelogForDisplay(parsed: Args, settingsManager: SettingsManager): } } - return null; + return undefined; } -function createSessionManager(parsed: Args, cwd: string): SessionManager | null { +function createSessionManager(parsed: Args, cwd: string): SessionManager | undefined { if (parsed.noSession) { return SessionManager.inMemory(); } @@ -184,8 +183,8 @@ function createSessionManager(parsed: Args, cwd: string): SessionManager | null if (parsed.sessionDir) { return SessionManager.create(cwd, parsed.sessionDir); } - // Default case (new session) returns null, SDK will create one - return null; + // Default case (new session) returns undefined, SDK will create one + return undefined; } /** Discover SYSTEM.md file if no CLI system prompt was provided */ @@ -208,7 +207,7 @@ function discoverSystemPromptFile(): string | undefined { function buildSessionOptions( parsed: Args, scopedModels: ScopedModel[], - sessionManager: SessionManager | null, + sessionManager: SessionManager | undefined, modelRegistry: ModelRegistry, ): CreateAgentSessionOptions { const options: CreateAgentSessionOptions = {}; @@ -330,7 +329,7 @@ export async function main(args: string[]) { } const cwd = process.cwd(); - const { initialMessage, initialAttachments } = await prepareInitialMessage(parsed); + const { initialMessage, initialImages } = await prepareInitialMessage(parsed); time("prepareInitialMessage"); const isInteractive = !parsed.print && parsed.mode === undefined; const mode = parsed.mode || "text"; @@ -409,7 +408,7 @@ export async function main(args: string[]) { if (mode === "rpc") { await 
runRpcMode(session); } else if (isInteractive) { - const versionCheckPromise = checkForNewVersion(VERSION).catch(() => null); + const versionCheckPromise = checkForNewVersion(VERSION).catch(() => undefined); const changelogMarkdown = getChangelogForDisplay(parsed, settingsManager); if (scopedModels.length > 0) { @@ -438,11 +437,11 @@ export async function main(args: string[]) { customToolsResult.tools, customToolsResult.setUIContext, initialMessage, - initialAttachments, + initialImages, fdPath, ); } else { - await runPrintMode(session, mode, parsed.messages, initialMessage, initialAttachments); + await runPrintMode(session, mode, parsed.messages, initialMessage, initialImages); stopThemeWatcher(); if (process.stdout.writableLength > 0) { await new Promise((resolve) => process.stdout.once("drain", resolve)); diff --git a/packages/coding-agent/src/modes/interactive/components/bash-execution.ts b/packages/coding-agent/src/modes/interactive/components/bash-execution.ts index 64a32af7..2e914fa9 100644 --- a/packages/coding-agent/src/modes/interactive/components/bash-execution.ts +++ b/packages/coding-agent/src/modes/interactive/components/bash-execution.ts @@ -21,7 +21,7 @@ export class BashExecutionComponent extends Container { private command: string; private outputLines: string[] = []; private status: "running" | "complete" | "cancelled" | "error" = "running"; - private exitCode: number | null = null; + private exitCode: number | undefined = undefined; private loader: Loader; private truncationResult?: TruncationResult; private fullOutputPath?: string; @@ -90,13 +90,17 @@ export class BashExecutionComponent extends Container { } setComplete( - exitCode: number | null, + exitCode: number | undefined, cancelled: boolean, truncationResult?: TruncationResult, fullOutputPath?: string, ): void { this.exitCode = exitCode; - this.status = cancelled ? "cancelled" : exitCode !== 0 && exitCode !== null ? "error" : "complete"; + this.status = cancelled + ? 
"cancelled" + : exitCode !== 0 && exitCode !== undefined && exitCode !== null + ? "error" + : "complete"; this.truncationResult = truncationResult; this.fullOutputPath = fullOutputPath; diff --git a/packages/coding-agent/src/modes/interactive/components/branch-summary-message.ts b/packages/coding-agent/src/modes/interactive/components/branch-summary-message.ts new file mode 100644 index 00000000..d46b2bc5 --- /dev/null +++ b/packages/coding-agent/src/modes/interactive/components/branch-summary-message.ts @@ -0,0 +1,42 @@ +import { Box, Markdown, Spacer, Text } from "@mariozechner/pi-tui"; +import type { BranchSummaryMessage } from "../../../core/messages.js"; +import { getMarkdownTheme, theme } from "../theme/theme.js"; + +/** + * Component that renders a branch summary message with collapsed/expanded state. + * Uses same background color as hook messages for visual consistency. + */ +export class BranchSummaryMessageComponent extends Box { + private expanded = false; + private message: BranchSummaryMessage; + + constructor(message: BranchSummaryMessage) { + super(1, 1, (t) => theme.bg("customMessageBg", t)); + this.message = message; + this.updateDisplay(); + } + + setExpanded(expanded: boolean): void { + this.expanded = expanded; + this.updateDisplay(); + } + + private updateDisplay(): void { + this.clear(); + + const label = theme.fg("customMessageLabel", `\x1b[1m[branch]\x1b[22m`); + this.addChild(new Text(label, 0, 0)); + this.addChild(new Spacer(1)); + + if (this.expanded) { + const header = "**Branch Summary**\n\n"; + this.addChild( + new Markdown(header + this.message.summary, 0, 0, getMarkdownTheme(), { + color: (text: string) => theme.fg("customMessageText", text), + }), + ); + } else { + this.addChild(new Text(theme.fg("customMessageText", "Branch summary (ctrl+o to expand)"), 0, 0)); + } + } +} diff --git a/packages/coding-agent/src/modes/interactive/components/compaction-summary-message.ts 
b/packages/coding-agent/src/modes/interactive/components/compaction-summary-message.ts new file mode 100644 index 00000000..dc07d3b5 --- /dev/null +++ b/packages/coding-agent/src/modes/interactive/components/compaction-summary-message.ts @@ -0,0 +1,45 @@ +import { Box, Markdown, Spacer, Text } from "@mariozechner/pi-tui"; +import type { CompactionSummaryMessage } from "../../../core/messages.js"; +import { getMarkdownTheme, theme } from "../theme/theme.js"; + +/** + * Component that renders a compaction message with collapsed/expanded state. + * Uses same background color as hook messages for visual consistency. + */ +export class CompactionSummaryMessageComponent extends Box { + private expanded = false; + private message: CompactionSummaryMessage; + + constructor(message: CompactionSummaryMessage) { + super(1, 1, (t) => theme.bg("customMessageBg", t)); + this.message = message; + this.updateDisplay(); + } + + setExpanded(expanded: boolean): void { + this.expanded = expanded; + this.updateDisplay(); + } + + private updateDisplay(): void { + this.clear(); + + const tokenStr = this.message.tokensBefore.toLocaleString(); + const label = theme.fg("customMessageLabel", `\x1b[1m[compaction]\x1b[22m`); + this.addChild(new Text(label, 0, 0)); + this.addChild(new Spacer(1)); + + if (this.expanded) { + const header = `**Compacted from ${tokenStr} tokens**\n\n`; + this.addChild( + new Markdown(header + this.message.summary, 0, 0, getMarkdownTheme(), { + color: (text: string) => theme.fg("customMessageText", text), + }), + ); + } else { + this.addChild( + new Text(theme.fg("customMessageText", `Compacted from ${tokenStr} tokens (ctrl+o to expand)`), 0, 0), + ); + } + } +} diff --git a/packages/coding-agent/src/modes/interactive/components/compaction.ts b/packages/coding-agent/src/modes/interactive/components/compaction.ts deleted file mode 100644 index f2835ee7..00000000 --- a/packages/coding-agent/src/modes/interactive/components/compaction.ts +++ /dev/null @@ -1,52 +0,0 @@ 
-import { Container, Markdown, Spacer, Text } from "@mariozechner/pi-tui"; -import { getMarkdownTheme, theme } from "../theme/theme.js"; - -/** - * Component that renders a compaction indicator with collapsed/expanded state. - * Collapsed: shows "Context compacted from X tokens" - * Expanded: shows the full summary rendered as markdown (like a user message) - */ -export class CompactionComponent extends Container { - private expanded = false; - private tokensBefore: number; - private summary: string; - - constructor(tokensBefore: number, summary: string) { - super(); - this.tokensBefore = tokensBefore; - this.summary = summary; - this.updateDisplay(); - } - - setExpanded(expanded: boolean): void { - this.expanded = expanded; - this.updateDisplay(); - } - - private updateDisplay(): void { - this.clear(); - - if (this.expanded) { - // Show header + summary as markdown (like user message) - this.addChild(new Spacer(1)); - const header = `**Context compacted from ${this.tokensBefore.toLocaleString()} tokens**\n\n`; - this.addChild( - new Markdown(header + this.summary, 1, 1, getMarkdownTheme(), { - bgColor: (text: string) => theme.bg("userMessageBg", text), - color: (text: string) => theme.fg("userMessageText", text), - }), - ); - this.addChild(new Spacer(1)); - } else { - // Collapsed: simple text in warning color with token count - const tokenStr = this.tokensBefore.toLocaleString(); - this.addChild( - new Text( - theme.fg("warning", `Earlier messages compacted from ${tokenStr} tokens (ctrl+o to expand)`), - 1, - 1, - ), - ); - } - } -} diff --git a/packages/coding-agent/src/modes/interactive/components/hook-message.ts b/packages/coding-agent/src/modes/interactive/components/hook-message.ts new file mode 100644 index 00000000..186e902f --- /dev/null +++ b/packages/coding-agent/src/modes/interactive/components/hook-message.ts @@ -0,0 +1,96 @@ +import type { TextContent } from "@mariozechner/pi-ai"; +import type { Component } from "@mariozechner/pi-tui"; +import { Box, 
Container, Markdown, Spacer, Text } from "@mariozechner/pi-tui"; +import type { HookMessageRenderer } from "../../../core/hooks/types.js"; +import type { HookMessage } from "../../../core/messages.js"; +import { getMarkdownTheme, theme } from "../theme/theme.js"; + +/** + * Component that renders a custom message entry from hooks. + * Uses distinct styling to differentiate from user messages. + */ +export class HookMessageComponent extends Container { + private message: HookMessage; + private customRenderer?: HookMessageRenderer; + private box: Box; + private customComponent?: Component; + private _expanded = false; + + constructor(message: HookMessage, customRenderer?: HookMessageRenderer) { + super(); + this.message = message; + this.customRenderer = customRenderer; + + this.addChild(new Spacer(1)); + + // Create box with purple background (used for default rendering) + this.box = new Box(1, 1, (t) => theme.bg("customMessageBg", t)); + + this.rebuild(); + } + + setExpanded(expanded: boolean): void { + if (this._expanded !== expanded) { + this._expanded = expanded; + this.rebuild(); + } + } + + private rebuild(): void { + // Remove previous content component + if (this.customComponent) { + this.removeChild(this.customComponent); + this.customComponent = undefined; + } + this.removeChild(this.box); + + // Try custom renderer first - it handles its own styling + if (this.customRenderer) { + try { + const component = this.customRenderer(this.message, { expanded: this._expanded }, theme); + if (component) { + // Custom renderer provides its own styled component + this.customComponent = component; + this.addChild(component); + return; + } + } catch { + // Fall through to default rendering + } + } + + // Default rendering uses our box + this.addChild(this.box); + this.box.clear(); + + // Default rendering: label + content + const label = theme.fg("customMessageLabel", `\x1b[1m[${this.message.customType}]\x1b[22m`); + this.box.addChild(new Text(label, 0, 0)); + 
this.box.addChild(new Spacer(1)); + + // Extract text content + let text: string; + if (typeof this.message.content === "string") { + text = this.message.content; + } else { + text = this.message.content + .filter((c): c is TextContent => c.type === "text") + .map((c) => c.text) + .join("\n"); + } + + // Limit lines when collapsed + if (!this._expanded) { + const lines = text.split("\n"); + if (lines.length > 5) { + text = `${lines.slice(0, 5).join("\n")}\n...`; + } + } + + this.box.addChild( + new Markdown(text, 0, 0, getMarkdownTheme(), { + color: (text: string) => theme.fg("customMessageText", text), + }), + ); + } +} diff --git a/packages/coding-agent/src/modes/interactive/components/model-selector.ts b/packages/coding-agent/src/modes/interactive/components/model-selector.ts index 40236899..bbf77f2a 100644 --- a/packages/coding-agent/src/modes/interactive/components/model-selector.ts +++ b/packages/coding-agent/src/modes/interactive/components/model-selector.ts @@ -36,18 +36,18 @@ export class ModelSelectorComponent extends Container { private allModels: ModelItem[] = []; private filteredModels: ModelItem[] = []; private selectedIndex: number = 0; - private currentModel: Model | null; + private currentModel?: Model; private settingsManager: SettingsManager; private modelRegistry: ModelRegistry; private onSelectCallback: (model: Model) => void; private onCancelCallback: () => void; - private errorMessage: string | null = null; + private errorMessage?: string; private tui: TUI; private scopedModels: ReadonlyArray; constructor( tui: TUI, - currentModel: Model | null, + currentModel: Model | undefined, settingsManager: SettingsManager, modelRegistry: ModelRegistry, scopedModels: ReadonlyArray, diff --git a/packages/coding-agent/src/modes/interactive/components/tool-execution.ts b/packages/coding-agent/src/modes/interactive/components/tool-execution.ts index fe32a3b4..7124c84b 100644 --- a/packages/coding-agent/src/modes/interactive/components/tool-execution.ts +++ 
b/packages/coding-agent/src/modes/interactive/components/tool-execution.ts @@ -415,10 +415,14 @@ export class ToolExecutionComponent extends Container { } else if (this.toolName === "edit") { const rawPath = this.args?.file_path || this.args?.path || ""; const path = shortenPath(rawPath); - text = - theme.fg("toolTitle", theme.bold("edit")) + - " " + - (path ? theme.fg("accent", path) : theme.fg("toolOutput", "...")); + + // Build path display, appending :line if we have a successful result with line info + let pathDisplay = path ? theme.fg("accent", path) : theme.fg("toolOutput", "..."); + if (this.result && !this.result.isError && this.result.details?.firstChangedLine) { + pathDisplay += theme.fg("warning", `:${this.result.details.firstChangedLine}`); + } + + text = `${theme.fg("toolTitle", theme.bold("edit"))} ${pathDisplay}`; if (this.result) { if (this.result.isError) { diff --git a/packages/coding-agent/src/modes/interactive/components/tree-selector.ts b/packages/coding-agent/src/modes/interactive/components/tree-selector.ts new file mode 100644 index 00000000..bdb61e42 --- /dev/null +++ b/packages/coding-agent/src/modes/interactive/components/tree-selector.ts @@ -0,0 +1,866 @@ +import { + type Component, + Container, + Input, + isArrowDown, + isArrowLeft, + isArrowRight, + isArrowUp, + isBackspace, + isCtrlC, + isCtrlO, + isEnter, + isEscape, + isShiftCtrlO, + Spacer, + Text, + TruncatedText, + truncateToWidth, +} from "@mariozechner/pi-tui"; +import type { SessionTreeNode } from "../../../core/session-manager.js"; +import { theme } from "../theme/theme.js"; +import { DynamicBorder } from "./dynamic-border.js"; + +/** Gutter info: position (displayIndent where connector was) and whether to show │ */ +interface GutterInfo { + position: number; // displayIndent level where the connector was shown + show: boolean; // true = show │, false = show spaces +} + +/** Flattened tree node for navigation */ +interface FlatNode { + node: SessionTreeNode; + /** 
Indentation level (each level = 3 chars) */ + indent: number; + /** Whether to show connector (├─ or └─) - true if parent has multiple children */ + showConnector: boolean; + /** If showConnector, true = last sibling (└─), false = not last (├─) */ + isLast: boolean; + /** Gutter info for each ancestor branch point */ + gutters: GutterInfo[]; + /** True if this node is a root under a virtual branching root (multiple roots) */ + isVirtualRootChild: boolean; +} + +/** Filter mode for tree display */ +type FilterMode = "default" | "no-tools" | "user-only" | "labeled-only" | "all"; + +/** + * Tree list component with selection and ASCII art visualization + */ +/** Tool call info for lookup */ +interface ToolCallInfo { + name: string; + arguments: Record; +} + +class TreeList implements Component { + private flatNodes: FlatNode[] = []; + private filteredNodes: FlatNode[] = []; + private selectedIndex = 0; + private currentLeafId: string | null; + private maxVisibleLines: number; + private filterMode: FilterMode = "default"; + private searchQuery = ""; + private toolCallMap: Map = new Map(); + private multipleRoots = false; + private activePathIds: Set = new Set(); + + public onSelect?: (entryId: string) => void; + public onCancel?: () => void; + public onLabelEdit?: (entryId: string, currentLabel: string | undefined) => void; + + constructor(tree: SessionTreeNode[], currentLeafId: string | null, maxVisibleLines: number) { + this.currentLeafId = currentLeafId; + this.maxVisibleLines = maxVisibleLines; + this.multipleRoots = tree.length > 1; + this.flatNodes = this.flattenTree(tree); + this.buildActivePath(); + this.applyFilter(); + + // Start with current leaf selected + const leafIndex = this.filteredNodes.findIndex((n) => n.node.entry.id === currentLeafId); + if (leafIndex !== -1) { + this.selectedIndex = leafIndex; + } else { + this.selectedIndex = Math.max(0, this.filteredNodes.length - 1); + } + } + + /** Build the set of entry IDs on the path from root to current 
leaf */ + private buildActivePath(): void { + this.activePathIds.clear(); + if (!this.currentLeafId) return; + + // Build a map of id -> entry for parent lookup + const entryMap = new Map(); + for (const flatNode of this.flatNodes) { + entryMap.set(flatNode.node.entry.id, flatNode); + } + + // Walk from leaf to root + let currentId: string | null = this.currentLeafId; + while (currentId) { + this.activePathIds.add(currentId); + const node = entryMap.get(currentId); + if (!node) break; + currentId = node.node.entry.parentId ?? null; + } + } + + private flattenTree(roots: SessionTreeNode[]): FlatNode[] { + const result: FlatNode[] = []; + this.toolCallMap.clear(); + + // Indentation rules: + // - At indent 0: stay at 0 unless parent has >1 children (then +1) + // - At indent 1: children always go to indent 2 (visual grouping of subtree) + // - At indent 2+: stay flat for single-child chains, +1 only if parent branches + + // Stack items: [node, indent, justBranched, showConnector, isLast, gutters, isVirtualRootChild] + type StackItem = [SessionTreeNode, number, boolean, boolean, boolean, GutterInfo[], boolean]; + const stack: StackItem[] = []; + + // Determine which subtrees contain the active leaf (to sort current branch first) + // Use iterative post-order traversal to avoid stack overflow + const containsActive = new Map(); + const leafId = this.currentLeafId; + { + // Build list in pre-order, then process in reverse for post-order effect + const allNodes: SessionTreeNode[] = []; + const preOrderStack: SessionTreeNode[] = [...roots]; + while (preOrderStack.length > 0) { + const node = preOrderStack.pop()!; + allNodes.push(node); + // Push children in reverse so they're processed left-to-right + for (let i = node.children.length - 1; i >= 0; i--) { + preOrderStack.push(node.children[i]); + } + } + // Process in reverse (post-order): children before parents + for (let i = allNodes.length - 1; i >= 0; i--) { + const node = allNodes[i]; + let has = leafId !== null && 
node.entry.id === leafId; + for (const child of node.children) { + if (containsActive.get(child)) { + has = true; + } + } + containsActive.set(node, has); + } + } + + // Add roots in reverse order, prioritizing the one containing the active leaf + // If multiple roots, treat them as children of a virtual root that branches + const multipleRoots = roots.length > 1; + const orderedRoots = [...roots].sort((a, b) => Number(containsActive.get(b)) - Number(containsActive.get(a))); + for (let i = orderedRoots.length - 1; i >= 0; i--) { + const isLast = i === orderedRoots.length - 1; + stack.push([orderedRoots[i], multipleRoots ? 1 : 0, multipleRoots, multipleRoots, isLast, [], multipleRoots]); + } + + while (stack.length > 0) { + const [node, indent, justBranched, showConnector, isLast, gutters, isVirtualRootChild] = stack.pop()!; + + // Extract tool calls from assistant messages for later lookup + const entry = node.entry; + if (entry.type === "message" && entry.message.role === "assistant") { + const content = (entry.message as { content?: unknown }).content; + if (Array.isArray(content)) { + for (const block of content) { + if (typeof block === "object" && block !== null && "type" in block && block.type === "toolCall") { + const tc = block as { id: string; name: string; arguments: Record }; + this.toolCallMap.set(tc.id, { name: tc.name, arguments: tc.arguments }); + } + } + } + } + + result.push({ node, indent, showConnector, isLast, gutters, isVirtualRootChild }); + + const children = node.children; + const multipleChildren = children.length > 1; + + // Order children so the branch containing the active leaf comes first + const orderedChildren = (() => { + const prioritized: SessionTreeNode[] = []; + const rest: SessionTreeNode[] = []; + for (const child of children) { + if (containsActive.get(child)) { + prioritized.push(child); + } else { + rest.push(child); + } + } + return [...prioritized, ...rest]; + })(); + + // Calculate child indent + let childIndent: number; 
+ if (multipleChildren) { + // Parent branches: children get +1 + childIndent = indent + 1; + } else if (justBranched && indent > 0) { + // First generation after a branch: +1 for visual grouping + childIndent = indent + 1; + } else { + // Single-child chain: stay flat + childIndent = indent; + } + + // Build gutters for children + // If this node showed a connector, add a gutter entry for descendants + // Only add gutter if connector is actually displayed (not suppressed for virtual root children) + const connectorDisplayed = showConnector && !isVirtualRootChild; + // When connector is displayed, add a gutter entry at the connector's position + // Connector is at position (displayIndent - 1), so gutter should be there too + const currentDisplayIndent = this.multipleRoots ? Math.max(0, indent - 1) : indent; + const connectorPosition = Math.max(0, currentDisplayIndent - 1); + const childGutters: GutterInfo[] = connectorDisplayed + ? [...gutters, { position: connectorPosition, show: !isLast }] + : gutters; + + // Add children in reverse order + for (let i = orderedChildren.length - 1; i >= 0; i--) { + const childIsLast = i === orderedChildren.length - 1; + stack.push([ + orderedChildren[i], + childIndent, + multipleChildren, + multipleChildren, + childIsLast, + childGutters, + false, + ]); + } + } + + return result; + } + + private applyFilter(): void { + // Remember currently selected node to preserve cursor position + const previouslySelectedId = this.filteredNodes[this.selectedIndex]?.node.entry.id; + + const searchTokens = this.searchQuery.toLowerCase().split(/\s+/).filter(Boolean); + + this.filteredNodes = this.flatNodes.filter((flatNode) => { + const entry = flatNode.node.entry; + const isCurrentLeaf = entry.id === this.currentLeafId; + + // Skip assistant messages with only tool calls (no text) unless error/aborted + // Always show current leaf so active position is visible + if (entry.type === "message" && entry.message.role === "assistant" && !isCurrentLeaf) 
{ + const msg = entry.message as { stopReason?: string; content?: unknown }; + const hasText = this.hasTextContent(msg.content); + const isErrorOrAborted = msg.stopReason && msg.stopReason !== "stop" && msg.stopReason !== "toolUse"; + // Only hide if no text AND not an error/aborted message + if (!hasText && !isErrorOrAborted) { + return false; + } + } + + // Apply filter mode + let passesFilter = true; + // Entry types hidden in default view (settings/bookkeeping) + const isSettingsEntry = + entry.type === "label" || + entry.type === "custom" || + entry.type === "model_change" || + entry.type === "thinking_level_change"; + + switch (this.filterMode) { + case "user-only": + // Just user messages + passesFilter = entry.type === "message" && entry.message.role === "user"; + break; + case "no-tools": + // Default minus tool results + passesFilter = !isSettingsEntry && !(entry.type === "message" && entry.message.role === "toolResult"); + break; + case "labeled-only": + // Just labeled entries + passesFilter = flatNode.node.label !== undefined; + break; + case "all": + // Show everything + passesFilter = true; + break; + default: + // Default mode: hide settings/bookkeeping entries + passesFilter = !isSettingsEntry; + break; + } + + if (!passesFilter) return false; + + // Apply search filter + if (searchTokens.length > 0) { + const nodeText = this.getSearchableText(flatNode.node).toLowerCase(); + return searchTokens.every((token) => nodeText.includes(token)); + } + + return true; + }); + + // Try to preserve cursor on the same node after filtering + if (previouslySelectedId) { + const newIndex = this.filteredNodes.findIndex((n) => n.node.entry.id === previouslySelectedId); + if (newIndex !== -1) { + this.selectedIndex = newIndex; + return; + } + } + + // Fall back: clamp index if out of bounds + if (this.selectedIndex >= this.filteredNodes.length) { + this.selectedIndex = Math.max(0, this.filteredNodes.length - 1); + } + } + + /** Get searchable text content from a node 
*/ + private getSearchableText(node: SessionTreeNode): string { + const entry = node.entry; + const parts: string[] = []; + + if (node.label) { + parts.push(node.label); + } + + switch (entry.type) { + case "message": { + const msg = entry.message; + parts.push(msg.role); + if ("content" in msg && msg.content) { + parts.push(this.extractContent(msg.content)); + } + if (msg.role === "bashExecution") { + const bashMsg = msg as { command?: string }; + if (bashMsg.command) parts.push(bashMsg.command); + } + break; + } + case "custom_message": { + parts.push(entry.customType); + if (typeof entry.content === "string") { + parts.push(entry.content); + } else { + parts.push(this.extractContent(entry.content)); + } + break; + } + case "compaction": + parts.push("compaction"); + break; + case "branch_summary": + parts.push("branch summary", entry.summary); + break; + case "model_change": + parts.push("model", entry.modelId); + break; + case "thinking_level_change": + parts.push("thinking", entry.thinkingLevel); + break; + case "custom": + parts.push("custom", entry.customType); + break; + case "label": + parts.push("label", entry.label ?? 
""); + break; + } + + return parts.join(" "); + } + + invalidate(): void {} + + getSearchQuery(): string { + return this.searchQuery; + } + + getSelectedNode(): SessionTreeNode | undefined { + return this.filteredNodes[this.selectedIndex]?.node; + } + + updateNodeLabel(entryId: string, label: string | undefined): void { + for (const flatNode of this.flatNodes) { + if (flatNode.node.entry.id === entryId) { + flatNode.node.label = label; + break; + } + } + } + + private getFilterLabel(): string { + switch (this.filterMode) { + case "no-tools": + return " [no-tools]"; + case "user-only": + return " [user]"; + case "labeled-only": + return " [labeled]"; + case "all": + return " [all]"; + default: + return ""; + } + } + + render(width: number): string[] { + const lines: string[] = []; + + if (this.filteredNodes.length === 0) { + lines.push(truncateToWidth(theme.fg("muted", " No entries found"), width)); + lines.push(truncateToWidth(theme.fg("muted", ` (0/0)${this.getFilterLabel()}`), width)); + return lines; + } + + const startIndex = Math.max( + 0, + Math.min( + this.selectedIndex - Math.floor(this.maxVisibleLines / 2), + this.filteredNodes.length - this.maxVisibleLines, + ), + ); + const endIndex = Math.min(startIndex + this.maxVisibleLines, this.filteredNodes.length); + + for (let i = startIndex; i < endIndex; i++) { + const flatNode = this.filteredNodes[i]; + const entry = flatNode.node.entry; + const isSelected = i === this.selectedIndex; + + // Build line: cursor + prefix + path marker + label + content + const cursor = isSelected ? theme.fg("accent", "› ") : " "; + + // If multiple roots, shift display (roots at 0, not 1) + const displayIndent = this.multipleRoots ? Math.max(0, flatNode.indent - 1) : flatNode.indent; + + // Build prefix with gutters at their correct positions + // Each gutter has a position (displayIndent where its connector was shown) + const connector = + flatNode.showConnector && !flatNode.isVirtualRootChild ? (flatNode.isLast ? 
"└─ " : "├─ ") : ""; + const connectorPosition = connector ? displayIndent - 1 : -1; + + // Build prefix char by char, placing gutters and connector at their positions + const totalChars = displayIndent * 3; + const prefixChars: string[] = []; + for (let i = 0; i < totalChars; i++) { + const level = Math.floor(i / 3); + const posInLevel = i % 3; + + // Check if there's a gutter at this level + const gutter = flatNode.gutters.find((g) => g.position === level); + if (gutter) { + if (posInLevel === 0) { + prefixChars.push(gutter.show ? "│" : " "); + } else { + prefixChars.push(" "); + } + } else if (connector && level === connectorPosition) { + // Connector at this level + if (posInLevel === 0) { + prefixChars.push(flatNode.isLast ? "└" : "├"); + } else if (posInLevel === 1) { + prefixChars.push("─"); + } else { + prefixChars.push(" "); + } + } else { + prefixChars.push(" "); + } + } + const prefix = prefixChars.join(""); + + // Active path marker - shown right before the entry text + const isOnActivePath = this.activePathIds.has(entry.id); + const pathMarker = isOnActivePath ? theme.fg("accent", "• ") : ""; + + const label = flatNode.node.label ? 
theme.fg("warning", `[${flatNode.node.label}] `) : ""; + const content = this.getEntryDisplayText(flatNode.node, isSelected); + + let line = cursor + theme.fg("dim", prefix) + pathMarker + label + content; + if (isSelected) { + line = theme.bg("selectedBg", line); + } + lines.push(truncateToWidth(line, width)); + } + + lines.push( + truncateToWidth( + theme.fg("muted", ` (${this.selectedIndex + 1}/${this.filteredNodes.length})${this.getFilterLabel()}`), + width, + ), + ); + + return lines; + } + + private getEntryDisplayText(node: SessionTreeNode, isSelected: boolean): string { + const entry = node.entry; + let result: string; + + const normalize = (s: string) => s.replace(/[\n\t]/g, " ").trim(); + + switch (entry.type) { + case "message": { + const msg = entry.message; + const role = msg.role; + if (role === "user") { + const msgWithContent = msg as { content?: unknown }; + const content = normalize(this.extractContent(msgWithContent.content)); + result = theme.fg("accent", "user: ") + content; + } else if (role === "assistant") { + const msgWithContent = msg as { content?: unknown; stopReason?: string; errorMessage?: string }; + const textContent = normalize(this.extractContent(msgWithContent.content)); + if (textContent) { + result = theme.fg("success", "assistant: ") + textContent; + } else if (msgWithContent.stopReason === "aborted") { + result = theme.fg("success", "assistant: ") + theme.fg("muted", "(aborted)"); + } else if (msgWithContent.errorMessage) { + const errMsg = normalize(msgWithContent.errorMessage).slice(0, 80); + result = theme.fg("success", "assistant: ") + theme.fg("error", errMsg); + } else { + result = theme.fg("success", "assistant: ") + theme.fg("muted", "(no content)"); + } + } else if (role === "toolResult") { + const toolMsg = msg as { toolCallId?: string; toolName?: string }; + const toolCall = toolMsg.toolCallId ? 
this.toolCallMap.get(toolMsg.toolCallId) : undefined; + if (toolCall) { + result = theme.fg("muted", this.formatToolCall(toolCall.name, toolCall.arguments)); + } else { + result = theme.fg("muted", `[${toolMsg.toolName ?? "tool"}]`); + } + } else if (role === "bashExecution") { + const bashMsg = msg as { command?: string }; + result = theme.fg("dim", `[bash]: ${normalize(bashMsg.command ?? "")}`); + } else { + result = theme.fg("dim", `[${role}]`); + } + break; + } + case "custom_message": { + const content = + typeof entry.content === "string" + ? entry.content + : entry.content + .filter((c): c is { type: "text"; text: string } => c.type === "text") + .map((c) => c.text) + .join(""); + result = theme.fg("customMessageLabel", `[${entry.customType}]: `) + normalize(content); + break; + } + case "compaction": { + const tokens = Math.round(entry.tokensBefore / 1000); + result = theme.fg("borderAccent", `[compaction: ${tokens}k tokens]`); + break; + } + case "branch_summary": + result = theme.fg("warning", `[branch summary]: `) + normalize(entry.summary); + break; + case "model_change": + result = theme.fg("dim", `[model: ${entry.modelId}]`); + break; + case "thinking_level_change": + result = theme.fg("dim", `[thinking: ${entry.thinkingLevel}]`); + break; + case "custom": + result = theme.fg("dim", `[custom: ${entry.customType}]`); + break; + case "label": + result = theme.fg("dim", `[label: ${entry.label ?? "(cleared)"}]`); + break; + default: + result = ""; + } + + return isSelected ? 
theme.bold(result) : result; + } + + private extractContent(content: unknown): string { + const maxLen = 200; + if (typeof content === "string") return content.slice(0, maxLen); + if (Array.isArray(content)) { + let result = ""; + for (const c of content) { + if (typeof c === "object" && c !== null && "type" in c && c.type === "text") { + result += (c as { text: string }).text; + if (result.length >= maxLen) return result.slice(0, maxLen); + } + } + return result; + } + return ""; + } + + private hasTextContent(content: unknown): boolean { + if (typeof content === "string") return content.trim().length > 0; + if (Array.isArray(content)) { + for (const c of content) { + if (typeof c === "object" && c !== null && "type" in c && c.type === "text") { + const text = (c as { text?: string }).text; + if (text && text.trim().length > 0) return true; + } + } + } + return false; + } + + private formatToolCall(name: string, args: Record): string { + const shortenPath = (p: string): string => { + const home = process.env.HOME || process.env.USERPROFILE || ""; + if (home && p.startsWith(home)) return `~${p.slice(home.length)}`; + return p; + }; + + switch (name) { + case "read": { + const path = shortenPath(String(args.path || args.file_path || "")); + const offset = args.offset as number | undefined; + const limit = args.limit as number | undefined; + let display = path; + if (offset !== undefined || limit !== undefined) { + const start = offset ?? 1; + const end = limit !== undefined ? start + limit - 1 : ""; + display += `:${start}${end ? 
`-${end}` : ""}`; + } + return `[read: ${display}]`; + } + case "write": { + const path = shortenPath(String(args.path || args.file_path || "")); + return `[write: ${path}]`; + } + case "edit": { + const path = shortenPath(String(args.path || args.file_path || "")); + return `[edit: ${path}]`; + } + case "bash": { + const rawCmd = String(args.command || ""); + const cmd = rawCmd + .replace(/[\n\t]/g, " ") + .trim() + .slice(0, 50); + return `[bash: ${cmd}${rawCmd.length > 50 ? "..." : ""}]`; + } + case "grep": { + const pattern = String(args.pattern || ""); + const path = shortenPath(String(args.path || ".")); + return `[grep: /${pattern}/ in ${path}]`; + } + case "find": { + const pattern = String(args.pattern || ""); + const path = shortenPath(String(args.path || ".")); + return `[find: ${pattern} in ${path}]`; + } + case "ls": { + const path = shortenPath(String(args.path || ".")); + return `[ls: ${path}]`; + } + default: { + // Custom tool - show name and truncated JSON args + const argsStr = JSON.stringify(args).slice(0, 40); + return `[${name}: ${argsStr}${JSON.stringify(args).length > 40 ? "..." : ""}]`; + } + } + } + + handleInput(keyData: string): void { + if (isArrowUp(keyData)) { + this.selectedIndex = this.selectedIndex === 0 ? this.filteredNodes.length - 1 : this.selectedIndex - 1; + } else if (isArrowDown(keyData)) { + this.selectedIndex = this.selectedIndex === this.filteredNodes.length - 1 ? 
0 : this.selectedIndex + 1; + } else if (isArrowLeft(keyData)) { + // Page up + this.selectedIndex = Math.max(0, this.selectedIndex - this.maxVisibleLines); + } else if (isArrowRight(keyData)) { + // Page down + this.selectedIndex = Math.min(this.filteredNodes.length - 1, this.selectedIndex + this.maxVisibleLines); + } else if (isEnter(keyData)) { + const selected = this.filteredNodes[this.selectedIndex]; + if (selected && this.onSelect) { + this.onSelect(selected.node.entry.id); + } + } else if (isEscape(keyData)) { + if (this.searchQuery) { + this.searchQuery = ""; + this.applyFilter(); + } else { + this.onCancel?.(); + } + } else if (isCtrlC(keyData)) { + this.onCancel?.(); + } else if (isShiftCtrlO(keyData)) { + // Cycle filter backwards + const modes: FilterMode[] = ["default", "no-tools", "user-only", "labeled-only", "all"]; + const currentIndex = modes.indexOf(this.filterMode); + this.filterMode = modes[(currentIndex - 1 + modes.length) % modes.length]; + this.applyFilter(); + } else if (isCtrlO(keyData)) { + // Cycle filter forwards: default → no-tools → user-only → labeled-only → all → default + const modes: FilterMode[] = ["default", "no-tools", "user-only", "labeled-only", "all"]; + const currentIndex = modes.indexOf(this.filterMode); + this.filterMode = modes[(currentIndex + 1) % modes.length]; + this.applyFilter(); + } else if (isBackspace(keyData)) { + if (this.searchQuery.length > 0) { + this.searchQuery = this.searchQuery.slice(0, -1); + this.applyFilter(); + } + } else if (keyData === "l" && !this.searchQuery) { + const selected = this.filteredNodes[this.selectedIndex]; + if (selected && this.onLabelEdit) { + this.onLabelEdit(selected.node.entry.id, selected.node.label); + } + } else { + const hasControlChars = [...keyData].some((ch) => { + const code = ch.charCodeAt(0); + return code < 32 || code === 0x7f || (code >= 0x80 && code <= 0x9f); + }); + if (!hasControlChars && keyData.length > 0) { + this.searchQuery += keyData; + this.applyFilter(); + 
} + } + } +} + +/** Component that displays the current search query */ +class SearchLine implements Component { + constructor(private treeList: TreeList) {} + + invalidate(): void {} + + render(width: number): string[] { + const query = this.treeList.getSearchQuery(); + if (query) { + return [truncateToWidth(` ${theme.fg("muted", "Search:")} ${theme.fg("accent", query)}`, width)]; + } + return [truncateToWidth(` ${theme.fg("muted", "Search:")}`, width)]; + } + + handleInput(_keyData: string): void {} +} + +/** Label input component shown when editing a label */ +class LabelInput implements Component { + private input: Input; + private entryId: string; + public onSubmit?: (entryId: string, label: string | undefined) => void; + public onCancel?: () => void; + + constructor(entryId: string, currentLabel: string | undefined) { + this.entryId = entryId; + this.input = new Input(); + if (currentLabel) { + this.input.setValue(currentLabel); + } + } + + invalidate(): void {} + + render(width: number): string[] { + const lines: string[] = []; + const indent = " "; + const availableWidth = width - indent.length; + lines.push(truncateToWidth(`${indent}${theme.fg("muted", "Label (empty to remove):")}`, width)); + lines.push(...this.input.render(availableWidth).map((line) => truncateToWidth(`${indent}${line}`, width))); + lines.push(truncateToWidth(`${indent}${theme.fg("dim", "enter: save esc: cancel")}`, width)); + return lines; + } + + handleInput(keyData: string): void { + if (isEnter(keyData)) { + const value = this.input.getValue().trim(); + this.onSubmit?.(this.entryId, value || undefined); + } else if (isEscape(keyData)) { + this.onCancel?.(); + } else { + this.input.handleInput(keyData); + } + } +} + +/** + * Component that renders a session tree selector for navigation + */ +export class TreeSelectorComponent extends Container { + private treeList: TreeList; + private labelInput: LabelInput | null = null; + private labelInputContainer: Container; + private 
treeContainer: Container; + private onLabelChangeCallback?: (entryId: string, label: string | undefined) => void; + + constructor( + tree: SessionTreeNode[], + currentLeafId: string | null, + terminalHeight: number, + onSelect: (entryId: string) => void, + onCancel: () => void, + onLabelChange?: (entryId: string, label: string | undefined) => void, + ) { + super(); + + this.onLabelChangeCallback = onLabelChange; + const maxVisibleLines = Math.max(5, Math.floor(terminalHeight / 2)); + + this.treeList = new TreeList(tree, currentLeafId, maxVisibleLines); + this.treeList.onSelect = onSelect; + this.treeList.onCancel = onCancel; + this.treeList.onLabelEdit = (entryId, currentLabel) => this.showLabelInput(entryId, currentLabel); + + this.treeContainer = new Container(); + this.treeContainer.addChild(this.treeList); + + this.labelInputContainer = new Container(); + + this.addChild(new Spacer(1)); + this.addChild(new DynamicBorder()); + this.addChild(new Text(theme.bold(" Session Tree"), 1, 0)); + this.addChild( + new TruncatedText(theme.fg("muted", " ↑/↓: move. ←/→: page. l: label. ^O/⇧^O: filter. 
Type to search"), 0, 0), + ); + this.addChild(new SearchLine(this.treeList)); + this.addChild(new DynamicBorder()); + this.addChild(new Spacer(1)); + this.addChild(this.treeContainer); + this.addChild(this.labelInputContainer); + this.addChild(new Spacer(1)); + this.addChild(new DynamicBorder()); + + if (tree.length === 0) { + setTimeout(() => onCancel(), 100); + } + } + + private showLabelInput(entryId: string, currentLabel: string | undefined): void { + this.labelInput = new LabelInput(entryId, currentLabel); + this.labelInput.onSubmit = (id, label) => { + this.treeList.updateNodeLabel(id, label); + this.onLabelChangeCallback?.(id, label); + this.hideLabelInput(); + }; + this.labelInput.onCancel = () => this.hideLabelInput(); + + this.treeContainer.clear(); + this.labelInputContainer.clear(); + this.labelInputContainer.addChild(this.labelInput); + } + + private hideLabelInput(): void { + this.labelInput = null; + this.labelInputContainer.clear(); + this.treeContainer.clear(); + this.treeContainer.addChild(this.treeList); + } + + handleInput(keyData: string): void { + if (this.labelInput) { + this.labelInput.handleInput(keyData); + } else { + this.treeList.handleInput(keyData); + } + } + + getTreeList(): TreeList { + return this.treeList; + } +} diff --git a/packages/coding-agent/src/modes/interactive/components/user-message.ts b/packages/coding-agent/src/modes/interactive/components/user-message.ts index dfeee875..8b95a3b2 100644 --- a/packages/coding-agent/src/modes/interactive/components/user-message.ts +++ b/packages/coding-agent/src/modes/interactive/components/user-message.ts @@ -5,13 +5,9 @@ import { getMarkdownTheme, theme } from "../theme/theme.js"; * Component that renders a user message */ export class UserMessageComponent extends Container { - constructor(text: string, isFirst: boolean) { + constructor(text: string) { super(); - - // Add spacer before user message (except first one) - if (!isFirst) { - this.addChild(new Spacer(1)); - } + 
this.addChild(new Spacer(1)); this.addChild( new Markdown(text, 1, 1, getMarkdownTheme(), { bgColor: (text: string) => theme.bg("userMessageBg", text), diff --git a/packages/coding-agent/src/modes/interactive/interactive-mode.ts b/packages/coding-agent/src/modes/interactive/interactive-mode.ts index f0e1c01c..17fef8c0 100644 --- a/packages/coding-agent/src/modes/interactive/interactive-mode.ts +++ b/packages/coding-agent/src/modes/interactive/interactive-mode.ts @@ -6,7 +6,7 @@ import * as fs from "node:fs"; import * as os from "node:os"; import * as path from "node:path"; -import type { AgentState, AppMessage, Attachment } from "@mariozechner/pi-agent-core"; +import type { AgentMessage, AgentState } from "@mariozechner/pi-agent-core"; import type { AssistantMessage, Message, OAuthProvider } from "@mariozechner/pi-ai"; import type { SlashCommand } from "@mariozechner/pi-tui"; import { @@ -28,13 +28,8 @@ import { APP_NAME, getAuthPath, getDebugLogPath } from "../../config.js"; import type { AgentSession, AgentSessionEvent } from "../../core/agent-session.js"; import type { LoadedCustomTool, SessionEvent as ToolSessionEvent } from "../../core/custom-tools/index.js"; import type { HookUIContext } from "../../core/hooks/index.js"; -import { isBashExecutionMessage } from "../../core/messages.js"; -import { - getLatestCompactionEntry, - SessionManager, - SUMMARY_PREFIX, - SUMMARY_SUFFIX, -} from "../../core/session-manager.js"; +import { createCompactionSummaryMessage } from "../../core/messages.js"; +import { type SessionContext, SessionManager } from "../../core/session-manager.js"; import { loadSkills } from "../../core/skills.js"; import { loadProjectContextFiles } from "../../core/system-prompt.js"; import type { TruncationResult } from "../../core/tools/truncate.js"; @@ -43,21 +38,33 @@ import { copyToClipboard } from "../../utils/clipboard.js"; import { ArminComponent } from "./components/armin.js"; import { AssistantMessageComponent } from 
"./components/assistant-message.js"; import { BashExecutionComponent } from "./components/bash-execution.js"; -import { CompactionComponent } from "./components/compaction.js"; +import { BranchSummaryMessageComponent } from "./components/branch-summary-message.js"; +import { CompactionSummaryMessageComponent } from "./components/compaction-summary-message.js"; import { CustomEditor } from "./components/custom-editor.js"; import { DynamicBorder } from "./components/dynamic-border.js"; import { FooterComponent } from "./components/footer.js"; import { HookInputComponent } from "./components/hook-input.js"; +import { HookMessageComponent } from "./components/hook-message.js"; import { HookSelectorComponent } from "./components/hook-selector.js"; import { ModelSelectorComponent } from "./components/model-selector.js"; import { OAuthSelectorComponent } from "./components/oauth-selector.js"; import { SessionSelectorComponent } from "./components/session-selector.js"; import { SettingsSelectorComponent } from "./components/settings-selector.js"; import { ToolExecutionComponent } from "./components/tool-execution.js"; +import { TreeSelectorComponent } from "./components/tree-selector.js"; import { UserMessageComponent } from "./components/user-message.js"; import { UserMessageSelectorComponent } from "./components/user-message-selector.js"; import { getAvailableThemes, getEditorTheme, getMarkdownTheme, onThemeChange, setTheme, theme } from "./theme/theme.js"; +/** Interface for components that can be expanded/collapsed */ +interface Expandable { + setExpanded(expanded: boolean): void; +} + +function isExpandable(obj: unknown): obj is Expandable { + return typeof obj === "object" && obj !== null && "setExpanded" in obj && typeof obj.setExpanded === "function"; +} + export class InteractiveMode { private session: AgentSession; private ui: TUI; @@ -70,21 +77,18 @@ export class InteractiveMode { private version: string; private isInitialized = false; private onInputCallback?: 
(text: string) => void; - private loadingAnimation: Loader | null = null; + private loadingAnimation: Loader | undefined = undefined; private lastSigintTime = 0; private lastEscapeTime = 0; - private changelogMarkdown: string | null = null; + private changelogMarkdown: string | undefined = undefined; // Streaming message tracking - private streamingComponent: AssistantMessageComponent | null = null; + private streamingComponent: AssistantMessageComponent | undefined = undefined; // Tool execution tracking: toolCallId -> component private pendingTools = new Map(); - // Track if this is the first user message (to skip spacer) - private isFirstUserMessage = true; - // Tool output expansion state private toolOutputExpanded = false; @@ -98,22 +102,22 @@ export class InteractiveMode { private isBashMode = false; // Track current bash execution component - private bashComponent: BashExecutionComponent | null = null; + private bashComponent: BashExecutionComponent | undefined = undefined; // Track pending bash components (shown in pending area, moved to chat on submit) private pendingBashComponents: BashExecutionComponent[] = []; // Auto-compaction state - private autoCompactionLoader: Loader | null = null; + private autoCompactionLoader: Loader | undefined = undefined; private autoCompactionEscapeHandler?: () => void; // Auto-retry state - private retryLoader: Loader | null = null; + private retryLoader: Loader | undefined = undefined; private retryEscapeHandler?: () => void; // Hook UI state - private hookSelector: HookSelectorComponent | null = null; - private hookInput: HookInputComponent | null = null; + private hookSelector: HookSelectorComponent | undefined = undefined; + private hookInput: HookInputComponent | undefined = undefined; // Custom tools for custom rendering private customTools: Map; @@ -132,10 +136,10 @@ export class InteractiveMode { constructor( session: AgentSession, version: string, - changelogMarkdown: string | null = null, + changelogMarkdown: 
string | undefined = undefined, customTools: LoadedCustomTool[] = [], private setToolUIContext: (uiContext: HookUIContext, hasUI: boolean) => void = () => {}, - fdPath: string | null = null, + fdPath: string | undefined = undefined, ) { this.session = session; this.version = version; @@ -161,6 +165,7 @@ export class InteractiveMode { { name: "changelog", description: "Show changelog entries" }, { name: "hotkeys", description: "Show all keyboard shortcuts" }, { name: "branch", description: "Create a new branch from a previous message" }, + { name: "tree", description: "Navigate session tree (switch branches)" }, { name: "login", description: "Login with OAuth provider" }, { name: "logout", description: "Logout from OAuth provider" }, { name: "new", description: "Start a new session" }, @@ -177,9 +182,15 @@ export class InteractiveMode { description: cmd.description, })); + // Convert hook commands to SlashCommand format + const hookCommands: SlashCommand[] = (this.session.hookRunner?.getRegisteredCommands() ?? []).map((cmd) => ({ + name: cmd.name, + description: cmd.description ?? 
"(hook command)", + })); + // Setup autocomplete const autocompleteProvider = new CombinedAutocompleteProvider( - [...slashCommands, ...fileSlashCommands], + [...slashCommands, ...fileSlashCommands, ...hookCommands], process.cwd(), fdPath, ); @@ -350,7 +361,7 @@ export class InteractiveMode { await this.emitToolSessionEvent({ entries, sessionFile: this.session.sessionFile, - previousSessionFile: null, + previousSessionFile: undefined, reason: "start", }); @@ -361,16 +372,30 @@ export class InteractiveMode { // Set UI context on hook runner hookRunner.setUIContext(uiContext, true); - hookRunner.setSessionFile(this.session.sessionFile); // Subscribe to hook errors hookRunner.onError((error) => { this.showHookError(error.hookPath, error.error); }); - // Set up send handler for pi.send() - hookRunner.setSendHandler((text, attachments) => { - this.handleHookSend(text, attachments); + // Set up handlers for pi.sendMessage() and pi.appendEntry() + hookRunner.setSendMessageHandler((message, triggerTurn) => { + const wasStreaming = this.session.isStreaming; + this.session + .sendHookMessage(message, triggerTurn) + .then(() => { + // For non-streaming cases with display=true, update UI + // (streaming cases update via message_end event) + if (!wasStreaming && message.display) { + this.rebuildChatFromMessages(); + } + }) + .catch((err) => { + this.showError(`Hook sendMessage failed: ${err instanceof Error ? 
err.message : String(err)}`); + }); + }); + hookRunner.setAppendEntryHandler((customType, data) => { + this.sessionManager.appendCustomEntry(customType, data); }); // Show loaded hooks @@ -381,13 +406,9 @@ export class InteractiveMode { this.chatContainer.addChild(new Spacer(1)); } - // Emit session event + // Emit session_start event await hookRunner.emit({ - type: "session", - entries, - sessionFile: this.session.sessionFile, - previousSessionFile: null, - reason: "start", + type: "session_start", }); } @@ -424,13 +445,14 @@ export class InteractiveMode { confirm: (title, message) => this.showHookConfirm(title, message), input: (title, placeholder) => this.showHookInput(title, placeholder), notify: (message, type) => this.showHookNotify(message, type), + custom: (component) => this.showHookCustom(component), }; } /** * Show a selector for hooks. */ - private showHookSelector(title: string, options: string[]): Promise { + private showHookSelector(title: string, options: string[]): Promise { return new Promise((resolve) => { this.hookSelector = new HookSelectorComponent( title, @@ -441,7 +463,7 @@ export class InteractiveMode { }, () => { this.hideHookSelector(); - resolve(null); + resolve(undefined); }, ); @@ -458,7 +480,7 @@ export class InteractiveMode { private hideHookSelector(): void { this.editorContainer.clear(); this.editorContainer.addChild(this.editor); - this.hookSelector = null; + this.hookSelector = undefined; this.ui.setFocus(this.editor); this.ui.requestRender(); } @@ -474,7 +496,7 @@ export class InteractiveMode { /** * Show a text input for hooks. 
*/ - private showHookInput(title: string, placeholder?: string): Promise { + private showHookInput(title: string, placeholder?: string): Promise { return new Promise((resolve) => { this.hookInput = new HookInputComponent( title, @@ -485,7 +507,7 @@ export class InteractiveMode { }, () => { this.hideHookInput(); - resolve(null); + resolve(undefined); }, ); @@ -502,7 +524,7 @@ export class InteractiveMode { private hideHookInput(): void { this.editorContainer.clear(); this.editorContainer.addChild(this.editor); - this.hookInput = null; + this.hookInput = undefined; this.ui.setFocus(this.editor); this.ui.requestRender(); } @@ -520,6 +542,42 @@ export class InteractiveMode { } } + /** + * Show a custom component with keyboard focus. + * Returns a function to call when done. + */ + private showHookCustom(component: Component & { dispose?(): void }): { + close: () => void; + requestRender: () => void; + } { + // Store current editor content + const savedText = this.editor.getText(); + + // Replace editor with custom component + this.editorContainer.clear(); + this.editorContainer.addChild(component); + this.ui.setFocus(component); + this.ui.requestRender(); + + // Return control object + return { + close: () => { + // Call dispose if available + component.dispose?.(); + + // Restore editor + this.editorContainer.clear(); + this.editorContainer.addChild(this.editor); + this.editor.setText(savedText); + this.ui.setFocus(this.editor); + this.ui.requestRender(); + }, + requestRender: () => { + this.ui.requestRender(); + }, + }; + } + /** * Show a hook error in the UI. */ @@ -533,19 +591,6 @@ export class InteractiveMode { * Handle pi.send() from hooks. * If streaming, queue the message. Otherwise, start a new agent loop. 
*/ - private handleHookSend(text: string, attachments?: Attachment[]): void { - if (this.session.isStreaming) { - // Queue the message for later (note: attachments are lost when queuing) - this.session.queueMessage(text); - this.updatePendingMessagesDisplay(); - } else { - // Start a new agent loop immediately - this.session.prompt(text, { attachments }).catch((err) => { - this.showError(err instanceof Error ? err.message : String(err)); - }); - } - } - // ========================================================================= // Key Handlers // ========================================================================= @@ -585,6 +630,9 @@ export class InteractiveMode { this.editor.onShiftTab = () => this.cycleThinkingLevel(); this.editor.onCtrlP = () => this.cycleModel("forward"); this.editor.onShiftCtrlP = () => this.cycleModel("backward"); + + // Global debug handler on TUI (works regardless of focus) + this.ui.onDebug = () => this.handleDebugCommand(); this.editor.onCtrlL = () => this.showModelSelector(); this.editor.onCtrlO = () => this.toggleToolOutputExpansion(); this.editor.onCtrlT = () => this.toggleThinkingBlockVisibility(); @@ -645,6 +693,11 @@ export class InteractiveMode { this.editor.setText(""); return; } + if (text === "/tree") { + this.showTreeSelector(); + this.editor.setText(""); + return; + } if (text === "/login") { this.showOAuthSelector("login"); this.editor.setText(""); @@ -709,7 +762,21 @@ export class InteractiveMode { return; } - // Queue message if agent is streaming + // Hook commands always run immediately, even during streaming + // (if they need to interact with LLM, they use pi.sendMessage which handles queueing) + if (text.startsWith("/") && this.session.hookRunner) { + const spaceIndex = text.indexOf(" "); + const commandName = spaceIndex === -1 ? 
text.slice(1) : text.slice(1, spaceIndex); + const command = this.session.hookRunner.getCommand(commandName); + if (command) { + this.editor.addToHistory(text); + this.editor.setText(""); + await this.session.prompt(text); + return; + } + } + + // Queue regular messages if agent is streaming if (this.session.isStreaming) { await this.session.queueMessage(text); this.updatePendingMessagesDisplay(); @@ -760,7 +827,10 @@ export class InteractiveMode { break; case "message_start": - if (event.message.role === "user") { + if (event.message.role === "hookMessage") { + this.addMessageToChat(event.message); + this.ui.requestRender(); + } else if (event.message.role === "user") { this.addMessageToChat(event.message); this.editor.setText(""); this.updatePendingMessagesDisplay(); @@ -768,7 +838,7 @@ export class InteractiveMode { } else if (event.message.role === "assistant") { this.streamingComponent = new AssistantMessageComponent(undefined, this.hideThinkingBlock); this.chatContainer.addChild(this.streamingComponent); - this.streamingComponent.updateContent(event.message as AssistantMessage); + this.streamingComponent.updateContent(event.message); this.ui.requestRender(); } break; @@ -822,7 +892,7 @@ export class InteractiveMode { } this.pendingTools.clear(); } - this.streamingComponent = null; + this.streamingComponent = undefined; this.footer.invalidate(); } this.ui.requestRender(); @@ -868,12 +938,12 @@ export class InteractiveMode { case "agent_end": if (this.loadingAnimation) { this.loadingAnimation.stop(); - this.loadingAnimation = null; + this.loadingAnimation = undefined; this.statusContainer.clear(); } if (this.streamingComponent) { this.chatContainer.removeChild(this.streamingComponent); - this.streamingComponent = null; + this.streamingComponent = undefined; } this.pendingTools.clear(); this.ui.requestRender(); @@ -912,7 +982,7 @@ export class InteractiveMode { // Stop loader if (this.autoCompactionLoader) { this.autoCompactionLoader.stop(); - 
this.autoCompactionLoader = null; + this.autoCompactionLoader = undefined; this.statusContainer.clear(); } // Handle result @@ -922,10 +992,13 @@ export class InteractiveMode { // Rebuild chat to show compacted state this.chatContainer.clear(); this.rebuildChatFromMessages(); - // Add compaction component (same as manual /compact) - const compactionComponent = new CompactionComponent(event.result.tokensBefore, event.result.summary); - compactionComponent.setExpanded(this.toolOutputExpanded); - this.chatContainer.addChild(compactionComponent); + // Add compaction component at bottom so user sees it without scrolling + this.addMessageToChat({ + role: "compactionSummary", + tokensBefore: event.result.tokensBefore, + summary: event.result.summary, + timestamp: Date.now(), + }); this.footer.updateState(this.session.state); } this.ui.requestRender(); @@ -961,7 +1034,7 @@ export class InteractiveMode { // Stop loader if (this.retryLoader) { this.retryLoader.stop(); - this.retryLoader = null; + this.retryLoader = undefined; this.statusContainer.clear(); } // Show error only on final failure (success shows normal response) @@ -991,46 +1064,79 @@ export class InteractiveMode { this.ui.requestRender(); } - private addMessageToChat(message: Message | AppMessage): void { - if (isBashExecutionMessage(message)) { - const component = new BashExecutionComponent(message.command, this.ui); - if (message.output) { - component.appendOutput(message.output); + private addMessageToChat(message: AgentMessage, options?: { populateHistory?: boolean }): void { + switch (message.role) { + case "bashExecution": { + const component = new BashExecutionComponent(message.command, this.ui); + if (message.output) { + component.appendOutput(message.output); + } + component.setComplete( + message.exitCode, + message.cancelled, + message.truncated ? 
({ truncated: true } as TruncationResult) : undefined, + message.fullOutputPath, + ); + this.chatContainer.addChild(component); + break; } - component.setComplete( - message.exitCode, - message.cancelled, - message.truncated ? ({ truncated: true } as TruncationResult) : undefined, - message.fullOutputPath, - ); - this.chatContainer.addChild(component); - return; - } - - if (message.role === "user") { - const textContent = this.getUserMessageText(message); - if (textContent) { - const userComponent = new UserMessageComponent(textContent, this.isFirstUserMessage); - this.chatContainer.addChild(userComponent); - this.isFirstUserMessage = false; + case "hookMessage": { + if (message.display) { + const renderer = this.session.hookRunner?.getMessageRenderer(message.customType); + this.chatContainer.addChild(new HookMessageComponent(message, renderer)); + } + break; + } + case "compactionSummary": { + this.chatContainer.addChild(new Spacer(1)); + const component = new CompactionSummaryMessageComponent(message); + component.setExpanded(this.toolOutputExpanded); + this.chatContainer.addChild(component); + break; + } + case "branchSummary": { + this.chatContainer.addChild(new Spacer(1)); + const component = new BranchSummaryMessageComponent(message); + component.setExpanded(this.toolOutputExpanded); + this.chatContainer.addChild(component); + break; + } + case "user": { + const textContent = this.getUserMessageText(message); + if (textContent) { + const userComponent = new UserMessageComponent(textContent); + this.chatContainer.addChild(userComponent); + if (options?.populateHistory) { + this.editor.addToHistory(textContent); + } + } + break; + } + case "assistant": { + const assistantComponent = new AssistantMessageComponent(message, this.hideThinkingBlock); + this.chatContainer.addChild(assistantComponent); + break; + } + case "toolResult": { + // Tool results are rendered inline with tool calls, handled separately + break; + } + default: { + const _exhaustive: never = 
message; } - } else if (message.role === "assistant") { - const assistantComponent = new AssistantMessageComponent(message as AssistantMessage, this.hideThinkingBlock); - this.chatContainer.addChild(assistantComponent); } } /** - * Render messages to chat. Used for initial load and rebuild after compaction. - * @param messages Messages to render + * Render session context to chat. Used for initial load and rebuild after compaction. + * @param sessionContext Session context to render * @param options.updateFooter Update footer state * @param options.populateHistory Add user messages to editor history */ - private renderMessages( - messages: readonly (Message | AppMessage)[], + private renderSessionContext( + sessionContext: SessionContext, options: { updateFooter?: boolean; populateHistory?: boolean } = {}, ): void { - this.isFirstUserMessage = true; this.pendingTools.clear(); if (options.updateFooter) { @@ -1038,54 +1144,25 @@ export class InteractiveMode { this.updateEditorBorderColor(); } - const compactionEntry = getLatestCompactionEntry(this.sessionManager.getEntries()); - - for (const message of messages) { - if (isBashExecutionMessage(message)) { + for (const message of sessionContext.messages) { + // Assistant messages need special handling for tool calls + if (message.role === "assistant") { this.addMessageToChat(message); - continue; - } - - if (message.role === "user") { - const textContent = this.getUserMessageText(message); - if (textContent) { - if (textContent.startsWith(SUMMARY_PREFIX) && compactionEntry) { - const summary = textContent.slice(SUMMARY_PREFIX.length, -SUMMARY_SUFFIX.length); - const component = new CompactionComponent(compactionEntry.tokensBefore, summary); - component.setExpanded(this.toolOutputExpanded); - this.chatContainer.addChild(component); - } else { - const userComponent = new UserMessageComponent(textContent, this.isFirstUserMessage); - this.chatContainer.addChild(userComponent); - this.isFirstUserMessage = false; - if 
(options.populateHistory) { - this.editor.addToHistory(textContent); - } - } - } - } else if (message.role === "assistant") { - const assistantMsg = message as AssistantMessage; - const assistantComponent = new AssistantMessageComponent(assistantMsg, this.hideThinkingBlock); - this.chatContainer.addChild(assistantComponent); - - for (const content of assistantMsg.content) { + // Render tool call components + for (const content of message.content) { if (content.type === "toolCall") { const component = new ToolExecutionComponent( content.name, content.arguments, - { - showImages: this.settingsManager.getShowImages(), - }, + { showImages: this.settingsManager.getShowImages() }, this.customTools.get(content.name)?.tool, this.ui, ); this.chatContainer.addChild(component); - if (assistantMsg.stopReason === "aborted" || assistantMsg.stopReason === "error") { + if (message.stopReason === "aborted" || message.stopReason === "error") { const errorMessage = - assistantMsg.stopReason === "aborted" - ? "Operation aborted" - : assistantMsg.errorMessage || "Error"; + message.stopReason === "aborted" ? 
"Operation aborted" : message.errorMessage || "Error"; component.updateResult({ content: [{ type: "text", text: errorMessage }], isError: true }); } else { this.pendingTools.set(content.id, component); @@ -1093,23 +1170,33 @@ export class InteractiveMode { } } } else if (message.role === "toolResult") { + // Match tool results to pending tool components const component = this.pendingTools.get(message.toolCallId); if (component) { component.updateResult(message); this.pendingTools.delete(message.toolCallId); } + } else { + // All other messages use standard rendering + this.addMessageToChat(message, options); } } + this.pendingTools.clear(); this.ui.requestRender(); } - renderInitialMessages(state: AgentState): void { - this.renderMessages(state.messages, { updateFooter: true, populateHistory: true }); + renderInitialMessages(): void { + // Get aligned messages and entries from session context + const context = this.sessionManager.buildSessionContext(); + this.renderSessionContext(context, { + updateFooter: true, + populateHistory: true, + }); // Show compaction info if session was compacted - const entries = this.sessionManager.getEntries(); - const compactionCount = entries.filter((e) => e.type === "compaction").length; + const allEntries = this.sessionManager.getEntries(); + const compactionCount = allEntries.filter((e) => e.type === "compaction").length; if (compactionCount > 0) { const times = compactionCount === 1 ? 
"1 time" : `${compactionCount} times`; this.showStatus(`Session compacted ${times}`); @@ -1126,7 +1213,9 @@ export class InteractiveMode { } private rebuildChatFromMessages(): void { - this.renderMessages(this.session.messages); + this.chatContainer.clear(); + const context = this.sessionManager.buildSessionContext(); + this.renderSessionContext(context); } // ========================================================================= @@ -1155,14 +1244,9 @@ export class InteractiveMode { private async shutdown(): Promise { // Emit shutdown event to hooks const hookRunner = this.session.hookRunner; - if (hookRunner?.hasHandlers("session")) { - const entries = this.sessionManager.getEntries(); + if (hookRunner?.hasHandlers("session_shutdown")) { await hookRunner.emit({ - type: "session", - entries, - sessionFile: this.session.sessionFile, - previousSessionFile: null, - reason: "shutdown", + type: "session_shutdown", }); } @@ -1196,7 +1280,7 @@ export class InteractiveMode { private cycleThinkingLevel(): void { const newLevel = this.session.cycleThinkingLevel(); - if (newLevel === null) { + if (newLevel === undefined) { this.showStatus("Current model does not support thinking"); } else { this.footer.updateState(this.session.state); @@ -1208,7 +1292,7 @@ export class InteractiveMode { private async cycleModel(direction: "forward" | "backward"): Promise { try { const result = await this.session.cycleModel(direction); - if (result === null) { + if (result === undefined) { const msg = this.session.scopedModels.length > 0 ? 
"Only one model in scope" : "Only one model available"; this.showStatus(msg); } else { @@ -1226,11 +1310,7 @@ export class InteractiveMode { private toggleToolOutputExpansion(): void { this.toolOutputExpanded = !this.toolOutputExpanded; for (const child of this.chatContainer.children) { - if (child instanceof ToolExecutionComponent) { - child.setExpanded(this.toolOutputExpanded); - } else if (child instanceof CompactionComponent) { - child.setExpanded(this.toolOutputExpanded); - } else if (child instanceof BashExecutionComponent) { + if (isExpandable(child)) { child.setExpanded(this.toolOutputExpanded); } } @@ -1502,8 +1582,7 @@ export class InteractiveMode { } this.chatContainer.clear(); - this.isFirstUserMessage = true; - this.renderInitialMessages(this.session.state); + this.renderInitialMessages(); this.editor.setText(result.selectedText); done(); this.showStatus("Branched to new session"); @@ -1517,6 +1596,108 @@ export class InteractiveMode { }); } + private showTreeSelector(): void { + const tree = this.sessionManager.getTree(); + const realLeafId = this.sessionManager.getLeafId(); + + // Find the visible leaf for display (skip metadata entries like labels) + let visibleLeafId = realLeafId; + while (visibleLeafId) { + const entry = this.sessionManager.getEntry(visibleLeafId); + if (!entry) break; + if (entry.type !== "label" && entry.type !== "custom") break; + visibleLeafId = entry.parentId ?? 
null; + } + + if (tree.length === 0) { + this.showStatus("No entries in session"); + return; + } + + this.showSelector((done) => { + const selector = new TreeSelectorComponent( + tree, + visibleLeafId, + this.ui.terminal.rows, + async (entryId) => { + // Selecting the visible leaf is a no-op (already there) + if (entryId === visibleLeafId) { + done(); + this.showStatus("Already at this point"); + return; + } + + // Ask about summarization + done(); // Close selector first + + const wantsSummary = await this.showHookConfirm( + "Summarize branch?", + "Create a summary of the branch you're leaving?", + ); + + // Set up escape handler and loader if summarizing + let summaryLoader: Loader | undefined; + const originalOnEscape = this.editor.onEscape; + + if (wantsSummary) { + this.editor.onEscape = () => { + this.session.abortBranchSummary(); + }; + this.chatContainer.addChild(new Spacer(1)); + summaryLoader = new Loader( + this.ui, + (spinner) => theme.fg("accent", spinner), + (text) => theme.fg("muted", text), + "Summarizing branch... (esc to cancel)", + ); + this.statusContainer.addChild(summaryLoader); + this.ui.requestRender(); + } + + try { + const result = await this.session.navigateTree(entryId, { summarize: wantsSummary }); + + if (result.aborted) { + // Summarization aborted - re-show tree selector + this.showStatus("Branch summarization cancelled"); + this.showTreeSelector(); + return; + } + if (result.cancelled) { + this.showStatus("Navigation cancelled"); + return; + } + + // Update UI + this.chatContainer.clear(); + this.renderInitialMessages(); + if (result.editorText) { + this.editor.setText(result.editorText); + } + this.showStatus("Navigated to selected point"); + } catch (error) { + this.showError(error instanceof Error ? 
error.message : String(error)); + } finally { + if (summaryLoader) { + summaryLoader.stop(); + this.statusContainer.clear(); + } + this.editor.onEscape = originalOnEscape; + } + }, + () => { + done(); + this.ui.requestRender(); + }, + (entryId, label) => { + this.sessionManager.appendLabelChange(entryId, label); + this.ui.requestRender(); + }, + ); + return { component: selector, focus: selector }; + }); + } + private showSessionSelector(): void { this.showSelector((done) => { const sessions = SessionManager.list(this.sessionManager.getCwd(), this.sessionManager.getSessionDir()); @@ -1542,13 +1723,13 @@ export class InteractiveMode { // Stop loading animation if (this.loadingAnimation) { this.loadingAnimation.stop(); - this.loadingAnimation = null; + this.loadingAnimation = undefined; } this.statusContainer.clear(); // Clear UI state this.pendingMessagesContainer.clear(); - this.streamingComponent = null; + this.streamingComponent = undefined; this.pendingTools.clear(); // Switch session via AgentSession (emits hook and tool session events) @@ -1556,8 +1737,7 @@ export class InteractiveMode { // Clear and re-render the chat this.chatContainer.clear(); - this.isFirstUserMessage = true; - this.renderInitialMessages(this.session.state); + this.renderInitialMessages(); this.showStatus("Resumed session"); } @@ -1805,7 +1985,7 @@ export class InteractiveMode { // Stop loading animation if (this.loadingAnimation) { this.loadingAnimation.stop(); - this.loadingAnimation = null; + this.loadingAnimation = undefined; } this.statusContainer.clear(); @@ -1815,9 +1995,8 @@ export class InteractiveMode { // Clear UI state this.chatContainer.clear(); this.pendingMessagesContainer.clear(); - this.streamingComponent = null; + this.streamingComponent = undefined; this.pendingTools.clear(); - this.isFirstUserMessage = true; this.chatContainer.addChild(new Spacer(1)); this.chatContainer.addChild(new Text(`${theme.fg("accent", "✓ New session started")}`, 1, 1)); @@ -1894,12 +2073,12 @@ 
export class InteractiveMode { } } catch (error) { if (this.bashComponent) { - this.bashComponent.setComplete(null, false); + this.bashComponent.setComplete(undefined, false); } this.showError(`Bash command failed: ${error instanceof Error ? error.message : "Unknown error"}`); } - this.bashComponent = null; + this.bashComponent = undefined; this.ui.requestRender(); } @@ -1919,7 +2098,7 @@ export class InteractiveMode { // Stop loading animation if (this.loadingAnimation) { this.loadingAnimation.stop(); - this.loadingAnimation = null; + this.loadingAnimation = undefined; } this.statusContainer.clear(); @@ -1945,13 +2124,11 @@ export class InteractiveMode { const result = await this.session.compact(customInstructions); // Rebuild UI - this.chatContainer.clear(); this.rebuildChatFromMessages(); - // Add compaction component - const compactionComponent = new CompactionComponent(result.tokensBefore, result.summary); - compactionComponent.setExpanded(this.toolOutputExpanded); - this.chatContainer.addChild(compactionComponent); + // Add compaction component at bottom so user sees it without scrolling + const msg = createCompactionSummaryMessage(result.summary, result.tokensBefore, new Date().toISOString()); + this.addMessageToChat(msg); this.footer.updateState(this.session.state); } catch (error) { @@ -1971,7 +2148,7 @@ export class InteractiveMode { stop(): void { if (this.loadingAnimation) { this.loadingAnimation.stop(); - this.loadingAnimation = null; + this.loadingAnimation = undefined; } this.footer.dispose(); if (this.unsubscribe) { diff --git a/packages/coding-agent/src/modes/interactive/theme/dark.json b/packages/coding-agent/src/modes/interactive/theme/dark.json index 51ad7749..f55be9f7 100644 --- a/packages/coding-agent/src/modes/interactive/theme/dark.json +++ b/packages/coding-agent/src/modes/interactive/theme/dark.json @@ -11,10 +11,12 @@ "dimGray": "#666666", "darkGray": "#505050", "accent": "#8abeb7", + "selectedBg": "#3a3a4a", "userMsgBg": "#343541", 
"toolPendingBg": "#282832", "toolSuccessBg": "#283228", - "toolErrorBg": "#3c2828" + "toolErrorBg": "#3c2828", + "customMsgBg": "#2d2838" }, "colors": { "accent": "accent", @@ -28,8 +30,12 @@ "dim": "dimGray", "text": "", + "selectedBg": "selectedBg", "userMessageBg": "userMsgBg", "userMessageText": "", + "customMessageBg": "customMsgBg", + "customMessageText": "", + "customMessageLabel": "#9575cd", "toolPendingBg": "toolPendingBg", "toolSuccessBg": "toolSuccessBg", "toolErrorBg": "toolErrorBg", diff --git a/packages/coding-agent/src/modes/interactive/theme/light.json b/packages/coding-agent/src/modes/interactive/theme/light.json index 57eb2643..a4276853 100644 --- a/packages/coding-agent/src/modes/interactive/theme/light.json +++ b/packages/coding-agent/src/modes/interactive/theme/light.json @@ -10,10 +10,12 @@ "mediumGray": "#6c6c6c", "dimGray": "#8a8a8a", "lightGray": "#b0b0b0", + "selectedBg": "#d0d0e0", "userMsgBg": "#e8e8e8", "toolPendingBg": "#e8e8f0", "toolSuccessBg": "#e8f0e8", - "toolErrorBg": "#f0e8e8" + "toolErrorBg": "#f0e8e8", + "customMsgBg": "#ede7f6" }, "colors": { "accent": "teal", @@ -27,8 +29,12 @@ "dim": "dimGray", "text": "", + "selectedBg": "selectedBg", "userMessageBg": "userMsgBg", "userMessageText": "", + "customMessageBg": "customMsgBg", + "customMessageText": "", + "customMessageLabel": "#7e57c2", "toolPendingBg": "toolPendingBg", "toolSuccessBg": "toolSuccessBg", "toolErrorBg": "toolErrorBg", diff --git a/packages/coding-agent/src/modes/interactive/theme/theme-schema.json b/packages/coding-agent/src/modes/interactive/theme/theme-schema.json index 7f060d23..f561ecb1 100644 --- a/packages/coding-agent/src/modes/interactive/theme/theme-schema.json +++ b/packages/coding-agent/src/modes/interactive/theme/theme-schema.json @@ -47,6 +47,9 @@ "text", "userMessageBg", "userMessageText", + "customMessageBg", + "customMessageText", + "customMessageLabel", "toolPendingBg", "toolSuccessBg", "toolErrorBg", @@ -122,6 +125,18 @@ "$ref": 
"#/$defs/colorValue", "description": "User message text color" }, + "customMessageBg": { + "$ref": "#/$defs/colorValue", + "description": "Custom message background (hook-injected messages)" + }, + "customMessageText": { + "$ref": "#/$defs/colorValue", + "description": "Custom message text color" + }, + "customMessageLabel": { + "$ref": "#/$defs/colorValue", + "description": "Custom message type label color" + }, "toolPendingBg": { "$ref": "#/$defs/colorValue", "description": "Tool execution box (pending state)" diff --git a/packages/coding-agent/src/modes/interactive/theme/theme.ts b/packages/coding-agent/src/modes/interactive/theme/theme.ts index 8f56b8e6..c0d8bf66 100644 --- a/packages/coding-agent/src/modes/interactive/theme/theme.ts +++ b/packages/coding-agent/src/modes/interactive/theme/theme.ts @@ -34,9 +34,13 @@ const ThemeJsonSchema = Type.Object({ muted: ColorValueSchema, dim: ColorValueSchema, text: ColorValueSchema, - // Backgrounds & Content Text (7 colors) + // Backgrounds & Content Text (11 colors) + selectedBg: ColorValueSchema, userMessageBg: ColorValueSchema, userMessageText: ColorValueSchema, + customMessageBg: ColorValueSchema, + customMessageText: ColorValueSchema, + customMessageLabel: ColorValueSchema, toolPendingBg: ColorValueSchema, toolSuccessBg: ColorValueSchema, toolErrorBg: ColorValueSchema, @@ -95,6 +99,8 @@ export type ThemeColor = | "dim" | "text" | "userMessageText" + | "customMessageText" + | "customMessageLabel" | "toolTitle" | "toolOutput" | "mdHeading" @@ -127,7 +133,13 @@ export type ThemeColor = | "thinkingXhigh" | "bashMode"; -export type ThemeBg = "userMessageBg" | "toolPendingBg" | "toolSuccessBg" | "toolErrorBg"; +export type ThemeBg = + | "selectedBg" + | "userMessageBg" + | "customMessageBg" + | "toolPendingBg" + | "toolSuccessBg" + | "toolErrorBg"; type ColorMode = "truecolor" | "256color"; @@ -482,7 +494,14 @@ function createTheme(themeJson: ThemeJson, mode?: ColorMode): Theme { const resolvedColors = 
resolveThemeColors(themeJson.colors, themeJson.vars); const fgColors: Record = {} as Record; const bgColors: Record = {} as Record; - const bgColorKeys: Set = new Set(["userMessageBg", "toolPendingBg", "toolSuccessBg", "toolErrorBg"]); + const bgColorKeys: Set = new Set([ + "selectedBg", + "userMessageBg", + "customMessageBg", + "toolPendingBg", + "toolSuccessBg", + "toolErrorBg", + ]); for (const [key, value] of Object.entries(resolvedColors)) { if (bgColorKeys.has(key)) { bgColors[key as ThemeBg] = value; diff --git a/packages/coding-agent/src/modes/print-mode.ts b/packages/coding-agent/src/modes/print-mode.ts index 46724701..fbf3037a 100644 --- a/packages/coding-agent/src/modes/print-mode.ts +++ b/packages/coding-agent/src/modes/print-mode.ts @@ -6,8 +6,7 @@ * - `pi --mode json "prompt"` - JSON event stream */ -import type { Attachment } from "@mariozechner/pi-agent-core"; -import type { AssistantMessage } from "@mariozechner/pi-ai"; +import type { AssistantMessage, ImageContent } from "@mariozechner/pi-ai"; import type { AgentSession } from "../core/agent-session.js"; /** @@ -18,14 +17,14 @@ import type { AgentSession } from "../core/agent-session.js"; * @param mode Output mode: "text" for final response only, "json" for all events * @param messages Array of prompts to send * @param initialMessage Optional first message (may contain @file content) - * @param initialAttachments Optional attachments for the initial message + * @param initialImages Optional images for the initial message */ export async function runPrintMode( session: AgentSession, mode: "text" | "json", messages: string[], initialMessage?: string, - initialAttachments?: Attachment[], + initialImages?: ImageContent[], ): Promise { // Load entries once for session start events const entries = session.sessionManager.getEntries(); @@ -34,22 +33,21 @@ export async function runPrintMode( // Set up hooks for print mode (no UI) const hookRunner = session.hookRunner; if (hookRunner) { - // Use actual 
session file if configured (via --session), otherwise null - hookRunner.setSessionFile(session.sessionFile); hookRunner.onError((err) => { console.error(`Hook error (${err.hookPath}): ${err.error}`); }); - // No-op send handler for print mode (single-shot, no async messages) - hookRunner.setSendHandler(() => { - console.error("Warning: pi.send() is not supported in print mode"); + // Set up handlers - sendHookMessage handles queuing/direct append as needed + hookRunner.setSendMessageHandler((message, triggerTurn) => { + session.sendHookMessage(message, triggerTurn).catch((e) => { + console.error(`Hook sendMessage failed: ${e instanceof Error ? e.message : String(e)}`); + }); }); - // Emit session event + hookRunner.setAppendEntryHandler((customType, data) => { + session.sessionManager.appendCustomEntry(customType, data); + }); + // Emit session_start event await hookRunner.emit({ - type: "session", - entries, - sessionFile: session.sessionFile, - previousSessionFile: null, - reason: "start", + type: "session_start", }); } @@ -60,7 +58,7 @@ export async function runPrintMode( await tool.onSession({ entries, sessionFile: session.sessionFile, - previousSessionFile: null, + previousSessionFile: undefined, reason: "start", }); } catch (_err) { @@ -79,7 +77,7 @@ export async function runPrintMode( // Send initial message with attachments if (initialMessage) { - await session.prompt(initialMessage, { attachments: initialAttachments }); + await session.prompt(initialMessage, { images: initialImages }); } // Send remaining messages diff --git a/packages/coding-agent/src/modes/rpc/rpc-client.ts b/packages/coding-agent/src/modes/rpc/rpc-client.ts index 4b79b946..0249ca11 100644 --- a/packages/coding-agent/src/modes/rpc/rpc-client.ts +++ b/packages/coding-agent/src/modes/rpc/rpc-client.ts @@ -6,9 +6,11 @@ import { type ChildProcess, spawn } from "node:child_process"; import * as readline from "node:readline"; -import type { AgentEvent, AppMessage, Attachment, ThinkingLevel } 
from "@mariozechner/pi-agent-core"; -import type { CompactionResult, SessionStats } from "../../core/agent-session.js"; +import type { AgentEvent, AgentMessage, ThinkingLevel } from "@mariozechner/pi-agent-core"; +import type { ImageContent } from "@mariozechner/pi-ai"; +import type { SessionStats } from "../../core/agent-session.js"; import type { BashResult } from "../../core/bash-executor.js"; +import type { CompactionResult } from "../../core/compaction/index.js"; import type { RpcCommand, RpcResponse, RpcSessionState } from "./rpc-types.js"; // ============================================================================ @@ -166,8 +168,8 @@ export class RpcClient { * Returns immediately after sending; use onEvent() to receive streaming events. * Use waitForIdle() to wait for completion. */ - async prompt(message: string, attachments?: Attachment[]): Promise { - await this.send({ type: "prompt", message, attachments }); + async prompt(message: string, images?: ImageContent[]): Promise { + await this.send({ type: "prompt", message, images }); } /** @@ -348,9 +350,9 @@ export class RpcClient { /** * Get all messages in the session. */ - async getMessages(): Promise { + async getMessages(): Promise { const response = await this.send({ type: "get_messages" }); - return this.getData<{ messages: AppMessage[] }>(response).messages; + return this.getData<{ messages: AgentMessage[] }>(response).messages; } // ========================================================================= @@ -403,9 +405,9 @@ export class RpcClient { /** * Send prompt and wait for completion, returning all events. 
*/ - async promptAndWait(message: string, attachments?: Attachment[], timeout = 60000): Promise { + async promptAndWait(message: string, images?: ImageContent[], timeout = 60000): Promise { const eventsPromise = this.collectEvents(timeout); - await this.prompt(message, attachments); + await this.prompt(message, images); return eventsPromise; } diff --git a/packages/coding-agent/src/modes/rpc/rpc-mode.ts b/packages/coding-agent/src/modes/rpc/rpc-mode.ts index 0c67969b..de378612 100644 --- a/packages/coding-agent/src/modes/rpc/rpc-mode.ts +++ b/packages/coding-agent/src/modes/rpc/rpc-mode.ts @@ -51,17 +51,17 @@ export async function runRpcMode(session: AgentSession): Promise { * Create a hook UI context that uses the RPC protocol. */ const createHookUIContext = (): HookUIContext => ({ - async select(title: string, options: string[]): Promise { + async select(title: string, options: string[]): Promise { const id = crypto.randomUUID(); return new Promise((resolve, reject) => { pendingHookRequests.set(id, { resolve: (response: RpcHookUIResponse) => { if ("cancelled" in response && response.cancelled) { - resolve(null); + resolve(undefined); } else if ("value" in response) { resolve(response.value); } else { - resolve(null); + resolve(undefined); } }, reject, @@ -89,17 +89,17 @@ export async function runRpcMode(session: AgentSession): Promise { }); }, - async input(title: string, placeholder?: string): Promise { + async input(title: string, placeholder?: string): Promise { const id = crypto.randomUUID(); return new Promise((resolve, reject) => { pendingHookRequests.set(id, { resolve: (response: RpcHookUIResponse) => { if ("cancelled" in response && response.cancelled) { - resolve(null); + resolve(undefined); } else if ("value" in response) { resolve(response.value); } else { - resolve(null); + resolve(undefined); } }, reject, @@ -118,6 +118,11 @@ export async function runRpcMode(session: AgentSession): Promise { notifyType: type, } as RpcHookUIRequest); }, + + custom() { 
+ // Custom UI not supported in RPC mode + return { close: () => {}, requestRender: () => {} }; + }, }); // Load entries once for session start events @@ -127,28 +132,21 @@ export async function runRpcMode(session: AgentSession): Promise { const hookRunner = session.hookRunner; if (hookRunner) { hookRunner.setUIContext(createHookUIContext(), false); - hookRunner.setSessionFile(session.sessionFile); hookRunner.onError((err) => { output({ type: "hook_error", hookPath: err.hookPath, event: err.event, error: err.error }); }); - // Set up send handler for pi.send() - hookRunner.setSendHandler((text, attachments) => { - // In RPC mode, just queue or prompt based on streaming state - if (session.isStreaming) { - session.queueMessage(text); - } else { - session.prompt(text, { attachments }).catch((e) => { - output(error(undefined, "hook_send", e.message)); - }); - } + // Set up handlers for pi.sendMessage() and pi.appendEntry() + hookRunner.setSendMessageHandler((message, triggerTurn) => { + session.sendHookMessage(message, triggerTurn).catch((e) => { + output(error(undefined, "hook_send", e.message)); + }); }); - // Emit session event + hookRunner.setAppendEntryHandler((customType, data) => { + session.sessionManager.appendCustomEntry(customType, data); + }); + // Emit session_start event await hookRunner.emit({ - type: "session", - entries, - sessionFile: session.sessionFile, - previousSessionFile: null, - reason: "start", + type: "session_start", }); } @@ -160,7 +158,7 @@ export async function runRpcMode(session: AgentSession): Promise { await tool.onSession({ entries, sessionFile: session.sessionFile, - previousSessionFile: null, + previousSessionFile: undefined, reason: "start", }); } catch (_err) { @@ -185,10 +183,10 @@ export async function runRpcMode(session: AgentSession): Promise { case "prompt": { // Don't await - events will stream + // Hook commands and file slash commands are handled in session.prompt() session .prompt(command.message, { - attachments: 
command.attachments, - expandSlashCommands: false, + images: command.images, }) .catch((e) => output(error(id, "prompt", e.message))); return success(id, "prompt"); diff --git a/packages/coding-agent/src/modes/rpc/rpc-types.ts b/packages/coding-agent/src/modes/rpc/rpc-types.ts index ab4f0b61..5feead90 100644 --- a/packages/coding-agent/src/modes/rpc/rpc-types.ts +++ b/packages/coding-agent/src/modes/rpc/rpc-types.ts @@ -5,10 +5,11 @@ * Responses and events are emitted as JSON lines on stdout. */ -import type { AppMessage, Attachment, ThinkingLevel } from "@mariozechner/pi-agent-core"; -import type { Model } from "@mariozechner/pi-ai"; -import type { CompactionResult, SessionStats } from "../../core/agent-session.js"; +import type { AgentMessage, ThinkingLevel } from "@mariozechner/pi-agent-core"; +import type { ImageContent, Model } from "@mariozechner/pi-ai"; +import type { SessionStats } from "../../core/agent-session.js"; import type { BashResult } from "../../core/bash-executor.js"; +import type { CompactionResult } from "../../core/compaction/index.js"; // ============================================================================ // RPC Commands (stdin) @@ -16,7 +17,7 @@ import type { BashResult } from "../../core/bash-executor.js"; export type RpcCommand = // Prompting - | { id?: string; type: "prompt"; message: string; attachments?: Attachment[] } + | { id?: string; type: "prompt"; message: string; images?: ImageContent[] } | { id?: string; type: "queue_message"; message: string } | { id?: string; type: "abort" } | { id?: string; type: "reset" } @@ -64,12 +65,12 @@ export type RpcCommand = // ============================================================================ export interface RpcSessionState { - model: Model | null; + model?: Model; thinkingLevel: ThinkingLevel; isStreaming: boolean; isCompacting: boolean; queueMode: "all" | "one-at-a-time"; - sessionFile: string | null; + sessionFile?: string; sessionId: string; autoCompactionEnabled: boolean; 
messageCount: number; @@ -160,7 +161,7 @@ export type RpcResponse = } // Messages - | { id?: string; type: "response"; command: "get_messages"; success: true; data: { messages: AppMessage[] } } + | { id?: string; type: "response"; command: "get_messages"; success: true; data: { messages: AgentMessage[] } } // Error response (any command can fail) | { id?: string; type: "response"; command: string; success: false; error: string }; diff --git a/packages/coding-agent/src/utils/tools-manager.ts b/packages/coding-agent/src/utils/tools-manager.ts index 8111b3d2..e969950c 100644 --- a/packages/coding-agent/src/utils/tools-manager.ts +++ b/packages/coding-agent/src/utils/tools-manager.ts @@ -184,14 +184,14 @@ async function downloadTool(tool: "fd" | "rg"): Promise { // Ensure a tool is available, downloading if necessary // Returns the path to the tool, or null if unavailable -export async function ensureTool(tool: "fd" | "rg", silent: boolean = false): Promise { +export async function ensureTool(tool: "fd" | "rg", silent: boolean = false): Promise { const existingPath = getToolPath(tool); if (existingPath) { return existingPath; } const config = TOOLS[tool]; - if (!config) return null; + if (!config) return undefined; // Tool not found - download it if (!silent) { @@ -208,6 +208,6 @@ export async function ensureTool(tool: "fd" | "rg", silent: boolean = false): Pr if (!silent) { console.log(chalk.yellow(`Failed to download ${config.name}: ${e instanceof Error ? 
e.message : e}`)); } - return null; + return undefined; } } diff --git a/packages/coding-agent/test/agent-session-branching.test.ts b/packages/coding-agent/test/agent-session-branching.test.ts index d6f2db1c..71b78ab7 100644 --- a/packages/coding-agent/test/agent-session-branching.test.ts +++ b/packages/coding-agent/test/agent-session-branching.test.ts @@ -10,7 +10,7 @@ import { existsSync, mkdirSync, rmSync } from "node:fs"; import { tmpdir } from "node:os"; import { join } from "node:path"; -import { Agent, ProviderTransport } from "@mariozechner/pi-agent-core"; +import { Agent } from "@mariozechner/pi-agent-core"; import { getModel } from "@mariozechner/pi-ai"; import { afterEach, beforeEach, describe, expect, it } from "vitest"; import { AgentSession } from "../src/core/agent-session.js"; @@ -19,8 +19,7 @@ import { ModelRegistry } from "../src/core/model-registry.js"; import { SessionManager } from "../src/core/session-manager.js"; import { SettingsManager } from "../src/core/settings-manager.js"; import { codingTools } from "../src/core/tools/index.js"; - -const API_KEY = process.env.ANTHROPIC_API_KEY || process.env.ANTHROPIC_OAUTH_TOKEN; +import { API_KEY } from "./utilities.js"; describe.skipIf(!API_KEY)("AgentSession branching", () => { let session: AgentSession; @@ -44,13 +43,8 @@ describe.skipIf(!API_KEY)("AgentSession branching", () => { function createSession(noSession: boolean = false) { const model = getModel("anthropic", "claude-sonnet-4-5")!; - - const transport = new ProviderTransport({ - getApiKey: () => API_KEY, - }); - const agent = new Agent({ - transport, + getApiKey: () => API_KEY, initialState: { model, systemPrompt: "You are a helpful assistant. 
Be extremely concise, reply with just a few words.", @@ -105,7 +99,7 @@ describe.skipIf(!API_KEY)("AgentSession branching", () => { createSession(true); // Verify sessions are disabled - expect(session.sessionFile).toBeNull(); + expect(session.sessionFile).toBeUndefined(); // Send one message await session.prompt("Say hi"); @@ -126,8 +120,8 @@ describe.skipIf(!API_KEY)("AgentSession branching", () => { // After branching, conversation should be empty expect(session.messages.length).toBe(0); - // Session file should still be null (no file created) - expect(session.sessionFile).toBeNull(); + // Session file should still be undefined (no file created) + expect(session.sessionFile).toBeUndefined(); }); it("should branch from middle of conversation", async () => { diff --git a/packages/coding-agent/test/agent-session-compaction.test.ts b/packages/coding-agent/test/agent-session-compaction.test.ts index ff519eb5..14a664bb 100644 --- a/packages/coding-agent/test/agent-session-compaction.test.ts +++ b/packages/coding-agent/test/agent-session-compaction.test.ts @@ -10,7 +10,7 @@ import { existsSync, mkdirSync, rmSync } from "node:fs"; import { tmpdir } from "node:os"; import { join } from "node:path"; -import { Agent, ProviderTransport } from "@mariozechner/pi-agent-core"; +import { Agent } from "@mariozechner/pi-agent-core"; import { getModel } from "@mariozechner/pi-ai"; import { afterEach, beforeEach, describe, expect, it } from "vitest"; import { AgentSession, type AgentSessionEvent } from "../src/core/agent-session.js"; @@ -19,8 +19,7 @@ import { ModelRegistry } from "../src/core/model-registry.js"; import { SessionManager } from "../src/core/session-manager.js"; import { SettingsManager } from "../src/core/settings-manager.js"; import { codingTools } from "../src/core/tools/index.js"; - -const API_KEY = process.env.ANTHROPIC_API_KEY || process.env.ANTHROPIC_OAUTH_TOKEN; +import { API_KEY } from "./utilities.js"; describe.skipIf(!API_KEY)("AgentSession compaction e2e", 
() => { let session: AgentSession; @@ -46,15 +45,10 @@ describe.skipIf(!API_KEY)("AgentSession compaction e2e", () => { } }); - function createSession() { + function createSession(inMemory = false) { const model = getModel("anthropic", "claude-sonnet-4-5")!; - - const transport = new ProviderTransport({ - getApiKey: () => API_KEY, - }); - const agent = new Agent({ - transport, + getApiKey: () => API_KEY, initialState: { model, systemPrompt: "You are a helpful assistant. Be concise.", @@ -62,8 +56,10 @@ describe.skipIf(!API_KEY)("AgentSession compaction e2e", () => { }, }); - sessionManager = SessionManager.create(tempDir); + sessionManager = inMemory ? SessionManager.inMemory() : SessionManager.create(tempDir); const settingsManager = SettingsManager.create(tempDir, tempDir); + // Use minimal keepRecentTokens so small test conversations have something to summarize + settingsManager.applyOverrides({ compaction: { keepRecentTokens: 1 } }); const authStorage = new AuthStorage(join(tempDir, "auth.json")); const modelRegistry = new ModelRegistry(authStorage); @@ -105,7 +101,7 @@ describe.skipIf(!API_KEY)("AgentSession compaction e2e", () => { // First message should be the summary (a user message with summary content) const firstMsg = messages[0]; - expect(firstMsg.role).toBe("user"); + expect(firstMsg.role).toBe("compactionSummary"); }, 120000); it("should maintain valid session state after compaction", async () => { @@ -156,64 +152,31 @@ describe.skipIf(!API_KEY)("AgentSession compaction e2e", () => { expect(compaction.type).toBe("compaction"); if (compaction.type === "compaction") { expect(compaction.summary.length).toBeGreaterThan(0); - // firstKeptEntryIndex can be 0 if all messages fit within keepRecentTokens - // (which is the case for small conversations) - expect(compaction.firstKeptEntryIndex).toBeGreaterThanOrEqual(0); + expect(typeof compaction.firstKeptEntryId).toBe("string"); expect(compaction.tokensBefore).toBeGreaterThan(0); } }, 120000); it("should work 
with --no-session mode (in-memory only)", async () => { - const model = getModel("anthropic", "claude-sonnet-4-5")!; + createSession(true); // in-memory mode - const transport = new ProviderTransport({ - getApiKey: () => API_KEY, - }); + // Send prompts + await session.prompt("What is 2+2? Reply with just the number."); + await session.agent.waitForIdle(); - const agent = new Agent({ - transport, - initialState: { - model, - systemPrompt: "You are a helpful assistant. Be concise.", - tools: codingTools, - }, - }); + await session.prompt("What is 3+3? Reply with just the number."); + await session.agent.waitForIdle(); - // Create in-memory session manager - const noSessionManager = SessionManager.inMemory(); + // Compact should work even without file persistence + const result = await session.compact(); - const settingsManager = SettingsManager.create(tempDir, tempDir); - const authStorage = new AuthStorage(join(tempDir, "auth.json")); - const modelRegistry = new ModelRegistry(authStorage); + expect(result.summary).toBeDefined(); + expect(result.summary.length).toBeGreaterThan(0); - const noSessionSession = new AgentSession({ - agent, - sessionManager: noSessionManager, - settingsManager, - modelRegistry, - }); - - try { - // Send prompts - await noSessionSession.prompt("What is 2+2? Reply with just the number."); - await noSessionSession.agent.waitForIdle(); - - await noSessionSession.prompt("What is 3+3? 
Reply with just the number."); - await noSessionSession.agent.waitForIdle(); - - // Compact should work even without file persistence - const result = await noSessionSession.compact(); - - expect(result.summary).toBeDefined(); - expect(result.summary.length).toBeGreaterThan(0); - - // In-memory entries should have the compaction - const entries = noSessionManager.getEntries(); - const compactionEntries = entries.filter((e) => e.type === "compaction"); - expect(compactionEntries.length).toBe(1); - } finally { - noSessionSession.dispose(); - } + // In-memory entries should have the compaction + const entries = sessionManager.getEntries(); + const compactionEntries = entries.filter((e) => e.type === "compaction"); + expect(compactionEntries.length).toBe(1); }, 120000); it("should emit correct events during auto-compaction", async () => { diff --git a/packages/coding-agent/test/agent-session-tree-navigation.test.ts b/packages/coding-agent/test/agent-session-tree-navigation.test.ts new file mode 100644 index 00000000..82e82439 --- /dev/null +++ b/packages/coding-agent/test/agent-session-tree-navigation.test.ts @@ -0,0 +1,318 @@ +/** + * E2E tests for AgentSession tree navigation with branch summarization. + * + * These tests verify: + * - Navigation to user messages (root and non-root) + * - Navigation to non-user messages + * - Branch summarization during navigation + * - Summary attachment at correct position in tree + * - Abort handling during summarization + */ + +import { afterEach, beforeEach, describe, expect, it } from "vitest"; +import { API_KEY, createTestSession, type TestSessionContext } from "./utilities.js"; + +describe.skipIf(!API_KEY)("AgentSession tree navigation e2e", () => { + let ctx: TestSessionContext; + + beforeEach(() => { + ctx = createTestSession({ + systemPrompt: "You are a helpful assistant. 
Reply with just a few words.", + settingsOverrides: { compaction: { keepRecentTokens: 1 } }, + }); + }); + + afterEach(() => { + ctx.cleanup(); + }); + + it("should navigate to user message and put text in editor", async () => { + const { session } = ctx; + + // Build conversation: u1 -> a1 -> u2 -> a2 + await session.prompt("First message"); + await session.agent.waitForIdle(); + await session.prompt("Second message"); + await session.agent.waitForIdle(); + + // Get tree entries + const tree = session.sessionManager.getTree(); + expect(tree.length).toBe(1); + + // Find the first user entry (u1) + const rootNode = tree[0]; + expect(rootNode.entry.type).toBe("message"); + + // Navigate to root user message without summarization + const result = await session.navigateTree(rootNode.entry.id, { summarize: false }); + + expect(result.cancelled).toBe(false); + expect(result.editorText).toBe("First message"); + + // After navigating to root user message, leaf should be null (empty conversation) + expect(session.sessionManager.getLeafId()).toBeNull(); + }, 60000); + + it("should navigate to non-user message without editor text", async () => { + const { session, sessionManager } = ctx; + + // Build conversation + await session.prompt("Hello"); + await session.agent.waitForIdle(); + + // Get the assistant message + const entries = sessionManager.getEntries(); + const assistantEntry = entries.find((e) => e.type === "message" && e.message.role === "assistant"); + expect(assistantEntry).toBeDefined(); + + // Navigate to assistant message + const result = await session.navigateTree(assistantEntry!.id, { summarize: false }); + + expect(result.cancelled).toBe(false); + expect(result.editorText).toBeUndefined(); + + // Leaf should be the assistant entry + expect(sessionManager.getLeafId()).toBe(assistantEntry!.id); + }, 60000); + + it("should create branch summary when navigating with summarize=true", async () => { + const { session, sessionManager } = ctx; + + // Build 
conversation: u1 -> a1 -> u2 -> a2 + await session.prompt("What is 2+2?"); + await session.agent.waitForIdle(); + await session.prompt("What is 3+3?"); + await session.agent.waitForIdle(); + + // Get tree and find first user message + const tree = sessionManager.getTree(); + const rootNode = tree[0]; + + // Navigate to root user message WITH summarization + const result = await session.navigateTree(rootNode.entry.id, { summarize: true }); + + expect(result.cancelled).toBe(false); + expect(result.editorText).toBe("What is 2+2?"); + expect(result.summaryEntry).toBeDefined(); + expect(result.summaryEntry?.type).toBe("branch_summary"); + expect(result.summaryEntry?.summary).toBeTruthy(); + expect(result.summaryEntry?.summary.length).toBeGreaterThan(0); + + // Summary should be a root entry (parentId = null) since we navigated to root user + expect(result.summaryEntry?.parentId).toBeNull(); + + // Leaf should be the summary entry + expect(sessionManager.getLeafId()).toBe(result.summaryEntry?.id); + }, 120000); + + it("should attach summary to correct parent when navigating to nested user message", async () => { + const { session, sessionManager } = ctx; + + // Build conversation: u1 -> a1 -> u2 -> a2 -> u3 -> a3 + await session.prompt("Message one"); + await session.agent.waitForIdle(); + await session.prompt("Message two"); + await session.agent.waitForIdle(); + await session.prompt("Message three"); + await session.agent.waitForIdle(); + + // Get the second user message (u2) + const entries = sessionManager.getEntries(); + const userEntries = entries.filter((e) => e.type === "message" && e.message.role === "user"); + expect(userEntries.length).toBe(3); + + const u2 = userEntries[1]; + const a1 = entries.find((e) => e.id === u2.parentId); // a1 is parent of u2 + + // Navigate to u2 with summarization + const result = await session.navigateTree(u2.id, { summarize: true }); + + expect(result.cancelled).toBe(false); + expect(result.editorText).toBe("Message two"); + 
expect(result.summaryEntry).toBeDefined(); + + // Summary should be attached to a1 (parent of u2) + // So a1 now has two children: u2 and the summary + expect(result.summaryEntry?.parentId).toBe(a1?.id); + + // Verify tree structure + const children = sessionManager.getChildren(a1!.id); + expect(children.length).toBe(2); + + const childTypes = children.map((c) => c.type).sort(); + expect(childTypes).toContain("branch_summary"); + expect(childTypes).toContain("message"); + }, 120000); + + it("should attach summary to selected node when navigating to assistant message", async () => { + const { session, sessionManager } = ctx; + + // Build conversation: u1 -> a1 -> u2 -> a2 + await session.prompt("Hello"); + await session.agent.waitForIdle(); + await session.prompt("Goodbye"); + await session.agent.waitForIdle(); + + // Get the first assistant message (a1) + const entries = sessionManager.getEntries(); + const assistantEntries = entries.filter((e) => e.type === "message" && e.message.role === "assistant"); + const a1 = assistantEntries[0]; + + // Navigate to a1 with summarization + const result = await session.navigateTree(a1.id, { summarize: true }); + + expect(result.cancelled).toBe(false); + expect(result.editorText).toBeUndefined(); // No editor text for assistant messages + expect(result.summaryEntry).toBeDefined(); + + // Summary should be attached to a1 (the selected node) + expect(result.summaryEntry?.parentId).toBe(a1.id); + + // Leaf should be the summary entry + expect(sessionManager.getLeafId()).toBe(result.summaryEntry?.id); + }, 120000); + + it("should handle abort during summarization", async () => { + const { session, sessionManager } = ctx; + + // Build conversation + await session.prompt("Tell me about something"); + await session.agent.waitForIdle(); + await session.prompt("Continue"); + await session.agent.waitForIdle(); + + const entriesBefore = sessionManager.getEntries(); + const leafBefore = sessionManager.getLeafId(); + + // Get root user 
message + const tree = sessionManager.getTree(); + const rootNode = tree[0]; + + // Start navigation with summarization but abort immediately + const navigationPromise = session.navigateTree(rootNode.entry.id, { summarize: true }); + + // Abort after a short delay (let the LLM call start) + await new Promise((resolve) => setTimeout(resolve, 100)); + session.abortBranchSummary(); + + const result = await navigationPromise; + + expect(result.cancelled).toBe(true); + expect(result.aborted).toBe(true); + expect(result.summaryEntry).toBeUndefined(); + + // Session should be unchanged + const entriesAfter = sessionManager.getEntries(); + expect(entriesAfter.length).toBe(entriesBefore.length); + expect(sessionManager.getLeafId()).toBe(leafBefore); + }, 60000); + + it("should not create summary when navigating without summarize option", async () => { + const { session, sessionManager } = ctx; + + // Build conversation + await session.prompt("First"); + await session.agent.waitForIdle(); + await session.prompt("Second"); + await session.agent.waitForIdle(); + + const entriesBefore = sessionManager.getEntries().length; + + // Navigate without summarization + const tree = sessionManager.getTree(); + await session.navigateTree(tree[0].entry.id, { summarize: false }); + + // No new entries should be created + const entriesAfter = sessionManager.getEntries().length; + expect(entriesAfter).toBe(entriesBefore); + + // No branch_summary entries + const summaries = sessionManager.getEntries().filter((e) => e.type === "branch_summary"); + expect(summaries.length).toBe(0); + }, 60000); + + it("should handle navigation to same position (no-op)", async () => { + const { session, sessionManager } = ctx; + + // Build conversation + await session.prompt("Hello"); + await session.agent.waitForIdle(); + + const leafBefore = sessionManager.getLeafId(); + expect(leafBefore).toBeTruthy(); + const entriesBefore = sessionManager.getEntries().length; + + // Navigate to current leaf + const result 
= await session.navigateTree(leafBefore!, { summarize: false }); + + expect(result.cancelled).toBe(false); + expect(sessionManager.getLeafId()).toBe(leafBefore); + expect(sessionManager.getEntries().length).toBe(entriesBefore); + }, 60000); + + it("should support custom summarization instructions", async () => { + const { session, sessionManager } = ctx; + + // Build conversation + await session.prompt("What is TypeScript?"); + await session.agent.waitForIdle(); + + // Navigate with custom instructions + const tree = sessionManager.getTree(); + const result = await session.navigateTree(tree[0].entry.id, { + summarize: true, + customInstructions: "Summarize in exactly 3 words.", + }); + + expect(result.summaryEntry).toBeDefined(); + expect(result.summaryEntry?.summary).toBeTruthy(); + // Can't reliably test 3 words exactly, but summary should be short + expect(result.summaryEntry?.summary.split(/\s+/).length).toBeLessThan(20); + }, 120000); +}); + +describe.skipIf(!API_KEY)("AgentSession tree navigation - branch scenarios", () => { + let ctx: TestSessionContext; + + beforeEach(() => { + ctx = createTestSession({ + systemPrompt: "You are a helpful assistant. 
Reply with just a few words.", + }); + }); + + afterEach(() => { + ctx.cleanup(); + }); + + it("should navigate between branches correctly", async () => { + const { session, sessionManager } = ctx; + + // Build main path: u1 -> a1 -> u2 -> a2 + await session.prompt("Main branch start"); + await session.agent.waitForIdle(); + await session.prompt("Main branch continue"); + await session.agent.waitForIdle(); + + // Get a1 id for branching + const entries = sessionManager.getEntries(); + const a1 = entries.find((e) => e.type === "message" && e.message.role === "assistant"); + + // Create a branch from a1: a1 -> u3 -> a3 + sessionManager.branch(a1!.id); + await session.prompt("Branch path"); + await session.agent.waitForIdle(); + + // Now navigate back to u2 (on main branch) with summarization + const userEntries = entries.filter((e) => e.type === "message" && e.message.role === "user"); + const u2 = userEntries[1]; // "Main branch continue" + + const result = await session.navigateTree(u2.id, { summarize: true }); + + expect(result.cancelled).toBe(false); + expect(result.editorText).toBe("Main branch continue"); + expect(result.summaryEntry).toBeDefined(); + + // Summary captures the branch we're leaving (the "Branch path" conversation) + expect(result.summaryEntry?.summary.length).toBeGreaterThan(0); + }, 180000); +}); diff --git a/packages/coding-agent/test/compaction-hooks-example.test.ts b/packages/coding-agent/test/compaction-hooks-example.test.ts index 476d40fb..fbf1dadd 100644 --- a/packages/coding-agent/test/compaction-hooks-example.test.ts +++ b/packages/coding-agent/test/compaction-hooks-example.test.ts @@ -3,46 +3,45 @@ */ import { describe, expect, it } from "vitest"; -import type { HookAPI } from "../src/core/hooks/index.js"; -import type { CompactionEntry } from "../src/core/session-manager.js"; +import type { HookAPI, SessionBeforeCompactEvent, SessionCompactEvent } from "../src/core/hooks/index.js"; describe("Documentation example", () => { it("custom 
compaction example should type-check correctly", () => { // This is the example from hooks.md - verify it compiles const exampleHook = (pi: HookAPI) => { - pi.on("session", async (event, _ctx) => { - if (event.reason !== "before_compact") return; + pi.on("session_before_compact", async (event: SessionBeforeCompactEvent, ctx) => { + // All these should be accessible on the event + const { preparation, previousCompactions, model } = event; + // sessionManager and modelRegistry come from ctx, not event + const { sessionManager, modelRegistry } = ctx; + const { messagesToSummarize, messagesToKeep, tokensBefore, firstKeptEntryId, cutPoint } = preparation; - // After narrowing, these should all be accessible - const messages = event.messagesToSummarize; - const messagesToKeep = event.messagesToKeep; - const cutPoint = event.cutPoint; - const tokensBefore = event.tokensBefore; - const model = event.model; - const resolveApiKey = event.resolveApiKey; + // Get previous summary from most recent compaction + const _previousSummary = previousCompactions[0]?.summary; // Verify types - expect(Array.isArray(messages)).toBe(true); + expect(Array.isArray(messagesToSummarize)).toBe(true); expect(Array.isArray(messagesToKeep)).toBe(true); expect(typeof cutPoint.firstKeptEntryIndex).toBe("number"); expect(typeof tokensBefore).toBe("number"); expect(model).toBeDefined(); - expect(typeof resolveApiKey).toBe("function"); + expect(typeof sessionManager.getEntries).toBe("function"); + expect(typeof modelRegistry.getApiKey).toBe("function"); + expect(typeof firstKeptEntryId).toBe("string"); - const summary = messages + const summary = messagesToSummarize .filter((m) => m.role === "user") .map((m) => `- ${typeof m.content === "string" ? 
m.content.slice(0, 100) : "[complex]"}`) .join("\n"); - const compactionEntry: CompactionEntry = { - type: "compaction", - timestamp: new Date().toISOString(), - summary: `User requests:\n${summary}`, - firstKeptEntryIndex: event.cutPoint.firstKeptEntryIndex, - tokensBefore: event.tokensBefore, + // Hooks return compaction content - SessionManager adds id/parentId + return { + compaction: { + summary: `User requests:\n${summary}`, + firstKeptEntryId, + tokensBefore, + }, }; - - return { compactionEntry }; }); }; @@ -50,19 +49,16 @@ describe("Documentation example", () => { expect(typeof exampleHook).toBe("function"); }); - it("compact event should have correct fields after narrowing", () => { + it("compact event should have correct fields", () => { const checkCompactEvent = (pi: HookAPI) => { - pi.on("session", async (event, _ctx) => { - if (event.reason !== "compact") return; - - // After narrowing, these should all be accessible + pi.on("session_compact", async (event: SessionCompactEvent) => { + // These should all be accessible const entry = event.compactionEntry; - const tokensBefore = event.tokensBefore; const fromHook = event.fromHook; expect(entry.type).toBe("compaction"); expect(typeof entry.summary).toBe("string"); - expect(typeof tokensBefore).toBe("number"); + expect(typeof entry.tokensBefore).toBe("number"); expect(typeof fromHook).toBe("boolean"); }); }; diff --git a/packages/coding-agent/test/compaction-hooks.test.ts b/packages/coding-agent/test/compaction-hooks.test.ts index e3773742..9ca29281 100644 --- a/packages/coding-agent/test/compaction-hooks.test.ts +++ b/packages/coding-agent/test/compaction-hooks.test.ts @@ -5,18 +5,24 @@ import { existsSync, mkdirSync, rmSync } from "node:fs"; import { tmpdir } from "node:os"; import { join } from "node:path"; -import { Agent, ProviderTransport } from "@mariozechner/pi-agent-core"; +import { Agent } from "@mariozechner/pi-agent-core"; import { getModel } from "@mariozechner/pi-ai"; import { afterEach, 
beforeEach, describe, expect, it } from "vitest"; import { AgentSession } from "../src/core/agent-session.js"; import { AuthStorage } from "../src/core/auth-storage.js"; -import { HookRunner, type LoadedHook, type SessionEvent } from "../src/core/hooks/index.js"; +import { + HookRunner, + type LoadedHook, + type SessionBeforeCompactEvent, + type SessionCompactEvent, + type SessionEvent, +} from "../src/core/hooks/index.js"; import { ModelRegistry } from "../src/core/model-registry.js"; import { SessionManager } from "../src/core/session-manager.js"; import { SettingsManager } from "../src/core/settings-manager.js"; import { codingTools } from "../src/core/tools/index.js"; -const API_KEY = process.env.ANTHROPIC_API_KEY || process.env.ANTHROPIC_OAUTH_TOKEN; +const API_KEY = process.env.ANTHROPIC_OAUTH_TOKEN || process.env.ANTHROPIC_API_KEY; describe.skipIf(!API_KEY)("Compaction hooks", () => { let session: AgentSession; @@ -40,19 +46,25 @@ describe.skipIf(!API_KEY)("Compaction hooks", () => { }); function createHook( - onBeforeCompact?: (event: SessionEvent) => { cancel?: boolean; compactionEntry?: any } | undefined, - onCompact?: (event: SessionEvent) => void, + onBeforeCompact?: (event: SessionBeforeCompactEvent) => { cancel?: boolean; compaction?: any } | undefined, + onCompact?: (event: SessionCompactEvent) => void, ): LoadedHook { const handlers = new Map Promise)[]>(); - handlers.set("session", [ - async (event: SessionEvent) => { + handlers.set("session_before_compact", [ + async (event: SessionBeforeCompactEvent) => { capturedEvents.push(event); - - if (event.reason === "before_compact" && onBeforeCompact) { + if (onBeforeCompact) { return onBeforeCompact(event); } - if (event.reason === "compact" && onCompact) { + return undefined; + }, + ]); + + handlers.set("session_compact", [ + async (event: SessionCompactEvent) => { + capturedEvents.push(event); + if (onCompact) { onCompact(event); } return undefined; @@ -63,19 +75,17 @@ 
describe.skipIf(!API_KEY)("Compaction hooks", () => { path: "test-hook", resolvedPath: "/test/test-hook.ts", handlers, - setSendHandler: () => {}, + messageRenderers: new Map(), + commands: new Map(), + setSendMessageHandler: () => {}, + setAppendEntryHandler: () => {}, }; } function createSession(hooks: LoadedHook[]) { const model = getModel("anthropic", "claude-sonnet-4-5")!; - - const transport = new ProviderTransport({ - getApiKey: () => API_KEY, - }); - const agent = new Agent({ - transport, + getApiKey: () => API_KEY, initialState: { model, systemPrompt: "You are a helpful assistant. Be concise.", @@ -88,17 +98,17 @@ describe.skipIf(!API_KEY)("Compaction hooks", () => { const authStorage = new AuthStorage(join(tempDir, "auth.json")); const modelRegistry = new ModelRegistry(authStorage); - hookRunner = new HookRunner(hooks, tempDir); + hookRunner = new HookRunner(hooks, tempDir, sessionManager, modelRegistry); hookRunner.setUIContext( { - select: async () => null, + select: async () => undefined, confirm: async () => false, - input: async () => null, + input: async () => undefined, notify: () => {}, + custom: () => ({ close: () => {}, requestRender: () => {} }), }, false, ); - hookRunner.setSessionFile(sessionManager.getSessionFile()); session = new AgentSession({ agent, @@ -123,30 +133,28 @@ describe.skipIf(!API_KEY)("Compaction hooks", () => { await session.compact(); - const beforeCompactEvents = capturedEvents.filter((e) => e.reason === "before_compact"); - const compactEvents = capturedEvents.filter((e) => e.reason === "compact"); + const beforeCompactEvents = capturedEvents.filter( + (e): e is SessionBeforeCompactEvent => e.type === "session_before_compact", + ); + const compactEvents = capturedEvents.filter((e): e is SessionCompactEvent => e.type === "session_compact"); expect(beforeCompactEvents.length).toBe(1); expect(compactEvents.length).toBe(1); const beforeEvent = beforeCompactEvents[0]; - if (beforeEvent.reason === "before_compact") { - 
expect(beforeEvent.cutPoint).toBeDefined(); - expect(beforeEvent.cutPoint.firstKeptEntryIndex).toBeGreaterThanOrEqual(0); - expect(beforeEvent.messagesToSummarize).toBeDefined(); - expect(beforeEvent.messagesToKeep).toBeDefined(); - expect(beforeEvent.tokensBefore).toBeGreaterThanOrEqual(0); - expect(beforeEvent.model).toBeDefined(); - expect(beforeEvent.resolveApiKey).toBeDefined(); - } + expect(beforeEvent.preparation).toBeDefined(); + expect(beforeEvent.preparation.cutPoint.firstKeptEntryIndex).toBeGreaterThanOrEqual(0); + expect(beforeEvent.preparation.messagesToSummarize).toBeDefined(); + expect(beforeEvent.preparation.messagesToKeep).toBeDefined(); + expect(beforeEvent.preparation.tokensBefore).toBeGreaterThanOrEqual(0); + expect(beforeEvent.model).toBeDefined(); + // sessionManager and modelRegistry are now on ctx, not event const afterEvent = compactEvents[0]; - if (afterEvent.reason === "compact") { - expect(afterEvent.compactionEntry).toBeDefined(); - expect(afterEvent.compactionEntry.summary.length).toBeGreaterThan(0); - expect(afterEvent.tokensBefore).toBeGreaterThanOrEqual(0); - expect(afterEvent.fromHook).toBe(false); - } + expect(afterEvent.compactionEntry).toBeDefined(); + expect(afterEvent.compactionEntry.summary.length).toBeGreaterThan(0); + expect(afterEvent.compactionEntry.tokensBefore).toBeGreaterThanOrEqual(0); + expect(afterEvent.fromHook).toBe(false); }, 120000); it("should allow hooks to cancel compaction", async () => { @@ -158,22 +166,20 @@ describe.skipIf(!API_KEY)("Compaction hooks", () => { await expect(session.compact()).rejects.toThrow("Compaction cancelled"); - const compactEvents = capturedEvents.filter((e) => e.reason === "compact"); + const compactEvents = capturedEvents.filter((e) => e.type === "session_compact"); expect(compactEvents.length).toBe(0); }, 120000); - it("should allow hooks to provide custom compactionEntry", async () => { + it("should allow hooks to provide custom compaction", async () => { const customSummary = 
"Custom summary from hook"; const hook = createHook((event) => { - if (event.reason === "before_compact") { + if (event.type === "session_before_compact") { return { - compactionEntry: { - type: "compaction" as const, - timestamp: new Date().toISOString(), + compaction: { summary: customSummary, - firstKeptEntryIndex: event.cutPoint.firstKeptEntryIndex, - tokensBefore: event.tokensBefore, + firstKeptEntryId: event.preparation.firstKeptEntryId, + tokensBefore: event.preparation.tokensBefore, }, }; } @@ -191,11 +197,11 @@ describe.skipIf(!API_KEY)("Compaction hooks", () => { expect(result.summary).toBe(customSummary); - const compactEvents = capturedEvents.filter((e) => e.reason === "compact"); + const compactEvents = capturedEvents.filter((e) => e.type === "session_compact"); expect(compactEvents.length).toBe(1); const afterEvent = compactEvents[0]; - if (afterEvent.reason === "compact") { + if (afterEvent.type === "session_compact") { expect(afterEvent.compactionEntry.summary).toBe(customSummary); expect(afterEvent.fromHook).toBe(true); } @@ -210,12 +216,14 @@ describe.skipIf(!API_KEY)("Compaction hooks", () => { await session.compact(); - const compactEvents = capturedEvents.filter((e) => e.reason === "compact"); + const compactEvents = capturedEvents.filter((e) => e.type === "session_compact"); expect(compactEvents.length).toBe(1); const afterEvent = compactEvents[0]; - if (afterEvent.reason === "compact") { - const hasCompactionEntry = afterEvent.entries.some((e) => e.type === "compaction"); + if (afterEvent.type === "session_compact") { + // sessionManager is now on ctx, use session.sessionManager directly + const entries = session.sessionManager.getEntries(); + const hasCompactionEntry = entries.some((e: { type: string }) => e.type === "compaction"); expect(hasCompactionEntry).toBe(true); } }, 120000); @@ -226,19 +234,28 @@ describe.skipIf(!API_KEY)("Compaction hooks", () => { resolvedPath: "/test/throwing-hook.ts", handlers: new Map Promise)[]>([ [ - 
"session", + "session_before_compact", [ - async (event: SessionEvent) => { + async (event: SessionBeforeCompactEvent) => { + capturedEvents.push(event); + throw new Error("Hook intentionally throws"); + }, + ], + ], + [ + "session_compact", + [ + async (event: SessionCompactEvent) => { capturedEvents.push(event); - if (event.reason === "before_compact") { - throw new Error("Hook intentionally failed"); - } return undefined; }, ], ], ]), - setSendHandler: () => {}, + messageRenderers: new Map(), + commands: new Map(), + setSendMessageHandler: () => {}, + setAppendEntryHandler: () => {}, }; createSession([throwingHook]); @@ -251,12 +268,9 @@ describe.skipIf(!API_KEY)("Compaction hooks", () => { expect(result.summary).toBeDefined(); expect(result.summary.length).toBeGreaterThan(0); - const compactEvents = capturedEvents.filter((e) => e.reason === "compact"); + const compactEvents = capturedEvents.filter((e): e is SessionCompactEvent => e.type === "session_compact"); expect(compactEvents.length).toBe(1); - - if (compactEvents[0].reason === "compact") { - expect(compactEvents[0].fromHook).toBe(false); - } + expect(compactEvents[0].fromHook).toBe(false); }, 120000); it("should call multiple hooks in order", async () => { @@ -267,21 +281,28 @@ describe.skipIf(!API_KEY)("Compaction hooks", () => { resolvedPath: "/test/hook1.ts", handlers: new Map Promise)[]>([ [ - "session", + "session_before_compact", [ - async (event: SessionEvent) => { - if (event.reason === "before_compact") { - callOrder.push("hook1-before"); - } - if (event.reason === "compact") { - callOrder.push("hook1-after"); - } + async () => { + callOrder.push("hook1-before"); + return undefined; + }, + ], + ], + [ + "session_compact", + [ + async () => { + callOrder.push("hook1-after"); return undefined; }, ], ], ]), - setSendHandler: () => {}, + messageRenderers: new Map(), + commands: new Map(), + setSendMessageHandler: () => {}, + setAppendEntryHandler: () => {}, }; const hook2: LoadedHook = { @@ -289,21 
+310,28 @@ describe.skipIf(!API_KEY)("Compaction hooks", () => { resolvedPath: "/test/hook2.ts", handlers: new Map Promise)[]>([ [ - "session", + "session_before_compact", [ - async (event: SessionEvent) => { - if (event.reason === "before_compact") { - callOrder.push("hook2-before"); - } - if (event.reason === "compact") { - callOrder.push("hook2-after"); - } + async () => { + callOrder.push("hook2-before"); + return undefined; + }, + ], + ], + [ + "session_compact", + [ + async () => { + callOrder.push("hook2-after"); return undefined; }, ], ], ]), - setSendHandler: () => {}, + messageRenderers: new Map(), + commands: new Map(), + setSendMessageHandler: () => {}, + setAppendEntryHandler: () => {}, }; createSession([hook1, hook2]); @@ -317,12 +345,10 @@ describe.skipIf(!API_KEY)("Compaction hooks", () => { }, 120000); it("should pass correct data in before_compact event", async () => { - let capturedBeforeEvent: (SessionEvent & { reason: "before_compact" }) | null = null; + let capturedBeforeEvent: SessionBeforeCompactEvent | null = null; const hook = createHook((event) => { - if (event.reason === "before_compact") { - capturedBeforeEvent = event; - } + capturedBeforeEvent = event; return undefined; }); createSession([hook]); @@ -337,35 +363,37 @@ describe.skipIf(!API_KEY)("Compaction hooks", () => { expect(capturedBeforeEvent).not.toBeNull(); const event = capturedBeforeEvent!; - expect(event.cutPoint).toHaveProperty("firstKeptEntryIndex"); - expect(event.cutPoint).toHaveProperty("isSplitTurn"); - expect(event.cutPoint).toHaveProperty("turnStartIndex"); + expect(event.preparation.cutPoint).toHaveProperty("firstKeptEntryIndex"); + expect(event.preparation.cutPoint).toHaveProperty("isSplitTurn"); + expect(event.preparation.cutPoint).toHaveProperty("turnStartIndex"); - expect(Array.isArray(event.messagesToSummarize)).toBe(true); - expect(Array.isArray(event.messagesToKeep)).toBe(true); + expect(Array.isArray(event.preparation.messagesToSummarize)).toBe(true); + 
expect(Array.isArray(event.preparation.messagesToKeep)).toBe(true); - expect(typeof event.tokensBefore).toBe("number"); + expect(typeof event.preparation.tokensBefore).toBe("number"); expect(event.model).toHaveProperty("provider"); expect(event.model).toHaveProperty("id"); - expect(typeof event.resolveApiKey).toBe("function"); + // sessionManager and modelRegistry are now on ctx, not event + // Verify they're accessible via session + expect(typeof session.sessionManager.getEntries).toBe("function"); + expect(typeof session.modelRegistry.getApiKey).toBe("function"); - expect(Array.isArray(event.entries)).toBe(true); - expect(event.entries.length).toBeGreaterThan(0); + const entries = session.sessionManager.getEntries(); + expect(Array.isArray(entries)).toBe(true); + expect(entries.length).toBeGreaterThan(0); }, 120000); - it("should use hook compactionEntry even with different firstKeptEntryIndex", async () => { - const customSummary = "Custom summary with modified index"; + it("should use hook compaction even with different values", async () => { + const customSummary = "Custom summary with modified values"; const hook = createHook((event) => { - if (event.reason === "before_compact") { + if (event.type === "session_before_compact") { return { - compactionEntry: { - type: "compaction" as const, - timestamp: new Date().toISOString(), + compaction: { summary: customSummary, - firstKeptEntryIndex: 0, + firstKeptEntryId: event.preparation.firstKeptEntryId, tokensBefore: 999, }, }; diff --git a/packages/coding-agent/test/compaction.test.ts b/packages/coding-agent/test/compaction.test.ts index 24c7e89d..0468a582 100644 --- a/packages/coding-agent/test/compaction.test.ts +++ b/packages/coding-agent/test/compaction.test.ts @@ -1,9 +1,9 @@ -import type { AppMessage } from "@mariozechner/pi-agent-core"; +import type { AgentMessage } from "@mariozechner/pi-agent-core"; import type { AssistantMessage, Usage } from "@mariozechner/pi-ai"; import { getModel } from 
"@mariozechner/pi-ai"; import { readFileSync } from "fs"; import { join } from "path"; -import { describe, expect, it } from "vitest"; +import { beforeEach, describe, expect, it } from "vitest"; import { type CompactionSettings, calculateContextTokens, @@ -12,14 +12,16 @@ import { findCutPoint, getLastAssistantUsage, shouldCompact, -} from "../src/core/compaction.js"; +} from "../src/core/compaction/index.js"; import { buildSessionContext, type CompactionEntry, - createSummaryMessage, + type ModelChangeEntry, + migrateSessionEntries, parseSessionEntries, type SessionEntry, type SessionMessageEntry, + type ThinkingLevelChangeEntry, } from "../src/core/session-manager.js"; // ============================================================================ @@ -29,7 +31,9 @@ import { function loadLargeSessionEntries(): SessionEntry[] { const sessionPath = join(__dirname, "fixtures/large-session.jsonl"); const content = readFileSync(sessionPath, "utf-8"); - return parseSessionEntries(content); + const entries = parseSessionEntries(content); + migrateSessionEntries(entries); // Add id/parentId for v1 fixtures + return entries.filter((e): e is SessionEntry => e.type !== "session"); } function createMockUsage(input: number, output: number, cacheRead = 0, cacheWrite = 0): Usage { @@ -43,7 +47,7 @@ function createMockUsage(input: number, output: number, cacheRead = 0, cacheWrit }; } -function createUserMessage(text: string): AppMessage { +function createUserMessage(text: string): AgentMessage { return { role: "user", content: text, timestamp: Date.now() }; } @@ -60,18 +64,72 @@ function createAssistantMessage(text: string, usage?: Usage): AssistantMessage { }; } -function createMessageEntry(message: AppMessage): SessionMessageEntry { - return { type: "message", timestamp: new Date().toISOString(), message }; +let entryCounter = 0; +let lastId: string | null = null; + +function resetEntryCounter() { + entryCounter = 0; + lastId = null; } -function createCompactionEntry(summary: 
string, firstKeptEntryIndex: number): CompactionEntry { - return { +// Reset counter before each test to get predictable IDs +beforeEach(() => { + resetEntryCounter(); +}); + +function createMessageEntry(message: AgentMessage): SessionMessageEntry { + const id = `test-id-${entryCounter++}`; + const entry: SessionMessageEntry = { + type: "message", + id, + parentId: lastId, + timestamp: new Date().toISOString(), + message, + }; + lastId = id; + return entry; +} + +function createCompactionEntry(summary: string, firstKeptEntryId: string): CompactionEntry { + const id = `test-id-${entryCounter++}`; + const entry: CompactionEntry = { type: "compaction", + id, + parentId: lastId, timestamp: new Date().toISOString(), summary, - firstKeptEntryIndex, + firstKeptEntryId, tokensBefore: 10000, }; + lastId = id; + return entry; +} + +function createModelChangeEntry(provider: string, modelId: string): ModelChangeEntry { + const id = `test-id-${entryCounter++}`; + const entry: ModelChangeEntry = { + type: "model_change", + id, + parentId: lastId, + timestamp: new Date().toISOString(), + provider, + modelId, + }; + lastId = id; + return entry; +} + +function createThinkingLevelEntry(thinkingLevel: string): ThinkingLevelChangeEntry { + const id = `test-id-${entryCounter++}`; + const entry: ThinkingLevelChangeEntry = { + type: "thinking_level_change", + id, + parentId: lastId, + timestamp: new Date().toISOString(), + thinkingLevel, + }; + lastId = id; + return entry; } // ============================================================================ @@ -122,9 +180,9 @@ describe("getLastAssistantUsage", () => { expect(usage!.input).toBe(100); }); - it("should return null if no assistant messages", () => { + it("should return undefined if no assistant messages", () => { const entries: SessionEntry[] = [createMessageEntry(createUserMessage("Hello"))]; - expect(getLastAssistantUsage(entries)).toBeNull(); + expect(getLastAssistantUsage(entries)).toBeUndefined(); }); }); @@ -213,28 +271,9 
@@ describe("findCutPoint", () => { }); }); -describe("createSummaryMessage", () => { - it("should create user message with prefix", () => { - const msg = createSummaryMessage("This is the summary"); - expect(msg.role).toBe("user"); - if (msg.role === "user") { - expect(msg.content).toContain( - "The conversation history before this point was compacted into the following summary:", - ); - expect(msg.content).toContain("This is the summary"); - } - }); -}); - describe("buildSessionContext", () => { it("should load all messages when no compaction", () => { const entries: SessionEntry[] = [ - { - type: "session", - id: "1", - timestamp: "", - cwd: "", - }, createMessageEntry(createUserMessage("1")), createMessageEntry(createAssistantMessage("a")), createMessageEntry(createUserMessage("2")), @@ -248,92 +287,67 @@ describe("buildSessionContext", () => { }); it("should handle single compaction", () => { - // indices: 0=session, 1=u1, 2=a1, 3=u2, 4=a2, 5=compaction, 6=u3, 7=a3 - const entries: SessionEntry[] = [ - { - type: "session", - id: "1", - timestamp: "", - cwd: "", - }, - createMessageEntry(createUserMessage("1")), - createMessageEntry(createAssistantMessage("a")), - createMessageEntry(createUserMessage("2")), - createMessageEntry(createAssistantMessage("b")), - createCompactionEntry("Summary of 1,a,2,b", 3), // keep from index 3 (u2) onwards - createMessageEntry(createUserMessage("3")), - createMessageEntry(createAssistantMessage("c")), - ]; + // IDs: u1=test-id-0, a1=test-id-1, u2=test-id-2, a2=test-id-3, compaction=test-id-4, u3=test-id-5, a3=test-id-6 + const u1 = createMessageEntry(createUserMessage("1")); + const a1 = createMessageEntry(createAssistantMessage("a")); + const u2 = createMessageEntry(createUserMessage("2")); + const a2 = createMessageEntry(createAssistantMessage("b")); + const compaction = createCompactionEntry("Summary of 1,a,2,b", u2.id); // keep from u2 onwards + const u3 = createMessageEntry(createUserMessage("3")); + const a3 = 
createMessageEntry(createAssistantMessage("c")); + + const entries: SessionEntry[] = [u1, a1, u2, a2, compaction, u3, a3]; const loaded = buildSessionContext(entries); - // summary + kept (u2,a2 from idx 3-4) + after (u3,a3 from idx 6-7) = 5 + // summary + kept (u2, a2) + after (u3, a3) = 5 expect(loaded.messages.length).toBe(5); - expect(loaded.messages[0].role).toBe("user"); - expect((loaded.messages[0] as any).content).toContain("Summary of 1,a,2,b"); + expect(loaded.messages[0].role).toBe("compactionSummary"); + expect((loaded.messages[0] as any).summary).toContain("Summary of 1,a,2,b"); }); it("should handle multiple compactions (only latest matters)", () => { - // indices: 0=session, 1=u1, 2=a1, 3=compact1, 4=u2, 5=b, 6=u3, 7=c, 8=compact2, 9=u4, 10=d - const entries: SessionEntry[] = [ - { - type: "session", - id: "1", - timestamp: "", - cwd: "", - }, - createMessageEntry(createUserMessage("1")), - createMessageEntry(createAssistantMessage("a")), - createCompactionEntry("First summary", 1), // keep from index 1 - createMessageEntry(createUserMessage("2")), - createMessageEntry(createAssistantMessage("b")), - createMessageEntry(createUserMessage("3")), - createMessageEntry(createAssistantMessage("c")), - createCompactionEntry("Second summary", 6), // keep from index 6 (u3) onwards - createMessageEntry(createUserMessage("4")), - createMessageEntry(createAssistantMessage("d")), - ]; + // First batch + const u1 = createMessageEntry(createUserMessage("1")); + const a1 = createMessageEntry(createAssistantMessage("a")); + const compact1 = createCompactionEntry("First summary", u1.id); + // Second batch + const u2 = createMessageEntry(createUserMessage("2")); + const b = createMessageEntry(createAssistantMessage("b")); + const u3 = createMessageEntry(createUserMessage("3")); + const c = createMessageEntry(createAssistantMessage("c")); + const compact2 = createCompactionEntry("Second summary", u3.id); // keep from u3 onwards + // After second compaction + const u4 = 
createMessageEntry(createUserMessage("4")); + const d = createMessageEntry(createAssistantMessage("d")); + + const entries: SessionEntry[] = [u1, a1, compact1, u2, b, u3, c, compact2, u4, d]; const loaded = buildSessionContext(entries); - // summary + kept from idx 6 (u3,c) + after (u4,d) = 5 + // summary + kept from u3 (u3, c) + after (u4, d) = 5 expect(loaded.messages.length).toBe(5); - expect((loaded.messages[0] as any).content).toContain("Second summary"); + expect((loaded.messages[0] as any).summary).toContain("Second summary"); }); - it("should clamp firstKeptEntryIndex to valid range", () => { - // indices: 0=session, 1=u1, 2=a1, 3=compact1, 4=u2, 5=b, 6=compact2 - const entries: SessionEntry[] = [ - { - type: "session", - id: "1", - timestamp: "", - cwd: "", - }, - createMessageEntry(createUserMessage("1")), - createMessageEntry(createAssistantMessage("a")), - createCompactionEntry("First summary", 1), - createMessageEntry(createUserMessage("2")), - createMessageEntry(createAssistantMessage("b")), - createCompactionEntry("Second summary", 0), // index 0 is before compaction1, should still work - ]; + it("should keep all messages when firstKeptEntryId is first entry", () => { + const u1 = createMessageEntry(createUserMessage("1")); + const a1 = createMessageEntry(createAssistantMessage("a")); + const compact1 = createCompactionEntry("First summary", u1.id); // keep from first entry + const u2 = createMessageEntry(createUserMessage("2")); + const b = createMessageEntry(createAssistantMessage("b")); + + const entries: SessionEntry[] = [u1, a1, compact1, u2, b]; const loaded = buildSessionContext(entries); - // Keeps from index 0, but compaction entries are skipped, so u1,a1,u2,b = 4 + summary = 5 - // Actually index 0 is session header, so messages are u1,a1,u2,b - expect(loaded.messages.length).toBe(5); // summary + 4 messages + // summary + all messages (u1, a1, u2, b) = 5 + expect(loaded.messages.length).toBe(5); }); it("should track model and thinking 
level changes", () => { const entries: SessionEntry[] = [ - { - type: "session", - id: "1", - timestamp: "", - cwd: "", - }, createMessageEntry(createUserMessage("1")), - { type: "model_change", timestamp: "", provider: "openai", modelId: "gpt-4" }, + createModelChangeEntry("openai", "gpt-4"), createMessageEntry(createAssistantMessage("a")), - { type: "thinking_level_change", timestamp: "", thinkingLevel: "high" }, + createThinkingLevelEntry("high"), ]; const loaded = buildSessionContext(entries); @@ -380,27 +394,26 @@ describe("Large session fixture", () => { // ============================================================================ describe.skipIf(!process.env.ANTHROPIC_OAUTH_TOKEN)("LLM summarization", () => { - it("should generate a compaction event for the large session", async () => { + it("should generate a compaction result for the large session", async () => { const entries = loadLargeSessionEntries(); const model = getModel("anthropic", "claude-sonnet-4-5")!; - const compactionEvent = await compact( + const compactionResult = await compact( entries, model, DEFAULT_COMPACTION_SETTINGS, process.env.ANTHROPIC_OAUTH_TOKEN!, ); - expect(compactionEvent.type).toBe("compaction"); - expect(compactionEvent.summary.length).toBeGreaterThan(100); - expect(compactionEvent.firstKeptEntryIndex).toBeGreaterThan(0); - expect(compactionEvent.tokensBefore).toBeGreaterThan(0); + expect(compactionResult.summary.length).toBeGreaterThan(100); + expect(compactionResult.firstKeptEntryId).toBeTruthy(); + expect(compactionResult.tokensBefore).toBeGreaterThan(0); - console.log("Summary length:", compactionEvent.summary.length); - console.log("First kept entry index:", compactionEvent.firstKeptEntryIndex); - console.log("Tokens before:", compactionEvent.tokensBefore); + console.log("Summary length:", compactionResult.summary.length); + console.log("First kept entry ID:", compactionResult.firstKeptEntryId); + console.log("Tokens before:", compactionResult.tokensBefore); 
console.log("\n--- SUMMARY ---\n"); - console.log(compactionEvent.summary); + console.log(compactionResult.summary); }, 60000); it("should produce valid session after compaction", async () => { @@ -408,21 +421,30 @@ describe.skipIf(!process.env.ANTHROPIC_OAUTH_TOKEN)("LLM summarization", () => { const loaded = buildSessionContext(entries); const model = getModel("anthropic", "claude-sonnet-4-5")!; - const compactionEvent = await compact( + const compactionResult = await compact( entries, model, DEFAULT_COMPACTION_SETTINGS, process.env.ANTHROPIC_OAUTH_TOKEN!, ); - // Simulate appending compaction to entries - const newEntries = [...entries, compactionEvent]; + // Simulate appending compaction to entries by creating a proper entry + const lastEntry = entries[entries.length - 1]; + const parentId = lastEntry.id; + const compactionEntry: CompactionEntry = { + type: "compaction", + id: "compaction-test-id", + parentId, + timestamp: new Date().toISOString(), + ...compactionResult, + }; + const newEntries = [...entries, compactionEntry]; const reloaded = buildSessionContext(newEntries); // Should have summary + kept messages expect(reloaded.messages.length).toBeLessThan(loaded.messages.length); - expect(reloaded.messages[0].role).toBe("user"); - expect((reloaded.messages[0] as any).content).toContain(compactionEvent.summary); + expect(reloaded.messages[0].role).toBe("compactionSummary"); + expect((reloaded.messages[0] as any).summary).toContain(compactionResult.summary); console.log("Original messages:", loaded.messages.length); console.log("After compaction:", reloaded.messages.length); diff --git a/packages/coding-agent/test/model-resolver.test.ts b/packages/coding-agent/test/model-resolver.test.ts index 59162d6f..0b7b47ca 100644 --- a/packages/coding-agent/test/model-resolver.test.ts +++ b/packages/coding-agent/test/model-resolver.test.ts @@ -66,21 +66,21 @@ describe("parseModelPattern", () => { const result = parseModelPattern("claude-sonnet-4-5", allModels); 
expect(result.model?.id).toBe("claude-sonnet-4-5"); expect(result.thinkingLevel).toBe("off"); - expect(result.warning).toBeNull(); + expect(result.warning).toBeUndefined(); }); test("partial match returns best model", () => { const result = parseModelPattern("sonnet", allModels); expect(result.model?.id).toBe("claude-sonnet-4-5"); expect(result.thinkingLevel).toBe("off"); - expect(result.warning).toBeNull(); + expect(result.warning).toBeUndefined(); }); test("no match returns null model", () => { const result = parseModelPattern("nonexistent", allModels); - expect(result.model).toBeNull(); + expect(result.model).toBeUndefined(); expect(result.thinkingLevel).toBe("off"); - expect(result.warning).toBeNull(); + expect(result.warning).toBeUndefined(); }); }); @@ -89,14 +89,14 @@ describe("parseModelPattern", () => { const result = parseModelPattern("sonnet:high", allModels); expect(result.model?.id).toBe("claude-sonnet-4-5"); expect(result.thinkingLevel).toBe("high"); - expect(result.warning).toBeNull(); + expect(result.warning).toBeUndefined(); }); test("gpt-4o:medium returns gpt-4o with medium thinking level", () => { const result = parseModelPattern("gpt-4o:medium", allModels); expect(result.model?.id).toBe("gpt-4o"); expect(result.thinkingLevel).toBe("medium"); - expect(result.warning).toBeNull(); + expect(result.warning).toBeUndefined(); }); test("all valid thinking levels work", () => { @@ -104,7 +104,7 @@ describe("parseModelPattern", () => { const result = parseModelPattern(`sonnet:${level}`, allModels); expect(result.model?.id).toBe("claude-sonnet-4-5"); expect(result.thinkingLevel).toBe(level); - expect(result.warning).toBeNull(); + expect(result.warning).toBeUndefined(); } }); }); @@ -131,7 +131,7 @@ describe("parseModelPattern", () => { const result = parseModelPattern("qwen/qwen3-coder:exacto", allModels); expect(result.model?.id).toBe("qwen/qwen3-coder:exacto"); expect(result.thinkingLevel).toBe("off"); - expect(result.warning).toBeNull(); + 
expect(result.warning).toBeUndefined(); }); test("openrouter/qwen/qwen3-coder:exacto matches with provider prefix", () => { @@ -139,14 +139,14 @@ describe("parseModelPattern", () => { expect(result.model?.id).toBe("qwen/qwen3-coder:exacto"); expect(result.model?.provider).toBe("openrouter"); expect(result.thinkingLevel).toBe("off"); - expect(result.warning).toBeNull(); + expect(result.warning).toBeUndefined(); }); test("qwen3-coder:exacto:high matches model with high thinking level", () => { const result = parseModelPattern("qwen/qwen3-coder:exacto:high", allModels); expect(result.model?.id).toBe("qwen/qwen3-coder:exacto"); expect(result.thinkingLevel).toBe("high"); - expect(result.warning).toBeNull(); + expect(result.warning).toBeUndefined(); }); test("openrouter/qwen/qwen3-coder:exacto:high matches with provider and thinking level", () => { @@ -154,14 +154,14 @@ describe("parseModelPattern", () => { expect(result.model?.id).toBe("qwen/qwen3-coder:exacto"); expect(result.model?.provider).toBe("openrouter"); expect(result.thinkingLevel).toBe("high"); - expect(result.warning).toBeNull(); + expect(result.warning).toBeUndefined(); }); test("gpt-4o:extended matches the extended model", () => { const result = parseModelPattern("openai/gpt-4o:extended", allModels); expect(result.model?.id).toBe("openai/gpt-4o:extended"); expect(result.thinkingLevel).toBe("off"); - expect(result.warning).toBeNull(); + expect(result.warning).toBeUndefined(); }); }); diff --git a/packages/coding-agent/test/rpc.test.ts b/packages/coding-agent/test/rpc.test.ts index 8c0b3aac..e0d6edac 100644 --- a/packages/coding-agent/test/rpc.test.ts +++ b/packages/coding-agent/test/rpc.test.ts @@ -273,7 +273,7 @@ describe.skipIf(!process.env.ANTHROPIC_API_KEY && !process.env.ANTHROPIC_OAUTH_T // Initially null let text = await client.getLastAssistantText(); - expect(text).toBeNull(); + expect(text).toBeUndefined(); // Send prompt await client.promptAndWait("Reply with just: test123"); diff --git 
a/packages/coding-agent/test/session-manager/build-context.test.ts b/packages/coding-agent/test/session-manager/build-context.test.ts new file mode 100644 index 00000000..5bc336a6 --- /dev/null +++ b/packages/coding-agent/test/session-manager/build-context.test.ts @@ -0,0 +1,268 @@ +import { describe, expect, it } from "vitest"; +import { + type BranchSummaryEntry, + buildSessionContext, + type CompactionEntry, + type ModelChangeEntry, + type SessionEntry, + type SessionMessageEntry, + type ThinkingLevelChangeEntry, +} from "../../src/core/session-manager.js"; + +function msg(id: string, parentId: string | null, role: "user" | "assistant", text: string): SessionMessageEntry { + const base = { type: "message" as const, id, parentId, timestamp: "2025-01-01T00:00:00Z" }; + if (role === "user") { + return { ...base, message: { role, content: text, timestamp: 1 } }; + } + return { + ...base, + message: { + role, + content: [{ type: "text", text }], + api: "anthropic-messages", + provider: "anthropic", + model: "claude-test", + usage: { + input: 1, + output: 1, + cacheRead: 0, + cacheWrite: 0, + totalTokens: 2, + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, + }, + stopReason: "stop", + timestamp: 1, + }, + }; +} + +function compaction(id: string, parentId: string | null, summary: string, firstKeptEntryId: string): CompactionEntry { + return { + type: "compaction", + id, + parentId, + timestamp: "2025-01-01T00:00:00Z", + summary, + firstKeptEntryId, + tokensBefore: 1000, + }; +} + +function branchSummary(id: string, parentId: string | null, summary: string, fromId: string): BranchSummaryEntry { + return { type: "branch_summary", id, parentId, timestamp: "2025-01-01T00:00:00Z", summary, fromId }; +} + +function thinkingLevel(id: string, parentId: string | null, level: string): ThinkingLevelChangeEntry { + return { type: "thinking_level_change", id, parentId, timestamp: "2025-01-01T00:00:00Z", thinkingLevel: level }; +} + +function modelChange(id: 
string, parentId: string | null, provider: string, modelId: string): ModelChangeEntry { + return { type: "model_change", id, parentId, timestamp: "2025-01-01T00:00:00Z", provider, modelId }; +} + +describe("buildSessionContext", () => { + describe("trivial cases", () => { + it("empty entries returns empty context", () => { + const ctx = buildSessionContext([]); + expect(ctx.messages).toEqual([]); + expect(ctx.thinkingLevel).toBe("off"); + expect(ctx.model).toBeNull(); + }); + + it("single user message", () => { + const entries: SessionEntry[] = [msg("1", null, "user", "hello")]; + const ctx = buildSessionContext(entries); + expect(ctx.messages).toHaveLength(1); + expect(ctx.messages[0].role).toBe("user"); + }); + + it("simple conversation", () => { + const entries: SessionEntry[] = [ + msg("1", null, "user", "hello"), + msg("2", "1", "assistant", "hi there"), + msg("3", "2", "user", "how are you"), + msg("4", "3", "assistant", "great"), + ]; + const ctx = buildSessionContext(entries); + expect(ctx.messages).toHaveLength(4); + expect(ctx.messages.map((m) => m.role)).toEqual(["user", "assistant", "user", "assistant"]); + }); + + it("tracks thinking level changes", () => { + const entries: SessionEntry[] = [ + msg("1", null, "user", "hello"), + thinkingLevel("2", "1", "high"), + msg("3", "2", "assistant", "thinking hard"), + ]; + const ctx = buildSessionContext(entries); + expect(ctx.thinkingLevel).toBe("high"); + expect(ctx.messages).toHaveLength(2); + }); + + it("tracks model from assistant message", () => { + const entries: SessionEntry[] = [msg("1", null, "user", "hello"), msg("2", "1", "assistant", "hi")]; + const ctx = buildSessionContext(entries); + expect(ctx.model).toEqual({ provider: "anthropic", modelId: "claude-test" }); + }); + + it("tracks model from model change entry", () => { + const entries: SessionEntry[] = [ + msg("1", null, "user", "hello"), + modelChange("2", "1", "openai", "gpt-4"), + msg("3", "2", "assistant", "hi"), + ]; + const ctx = 
buildSessionContext(entries); + // Assistant message overwrites model change + expect(ctx.model).toEqual({ provider: "anthropic", modelId: "claude-test" }); + }); + }); + + describe("with compaction", () => { + it("includes summary before kept messages", () => { + const entries: SessionEntry[] = [ + msg("1", null, "user", "first"), + msg("2", "1", "assistant", "response1"), + msg("3", "2", "user", "second"), + msg("4", "3", "assistant", "response2"), + compaction("5", "4", "Summary of first two turns", "3"), + msg("6", "5", "user", "third"), + msg("7", "6", "assistant", "response3"), + ]; + const ctx = buildSessionContext(entries); + + // Should have: summary + kept (3,4) + after (6,7) = 5 messages + expect(ctx.messages).toHaveLength(5); + expect((ctx.messages[0] as any).summary).toContain("Summary of first two turns"); + expect((ctx.messages[1] as any).content).toBe("second"); + expect((ctx.messages[2] as any).content[0].text).toBe("response2"); + expect((ctx.messages[3] as any).content).toBe("third"); + expect((ctx.messages[4] as any).content[0].text).toBe("response3"); + }); + + it("handles compaction keeping from first message", () => { + const entries: SessionEntry[] = [ + msg("1", null, "user", "first"), + msg("2", "1", "assistant", "response"), + compaction("3", "2", "Empty summary", "1"), + msg("4", "3", "user", "second"), + ]; + const ctx = buildSessionContext(entries); + + // Summary + all messages (1,2,4) + expect(ctx.messages).toHaveLength(4); + expect((ctx.messages[0] as any).summary).toContain("Empty summary"); + }); + + it("multiple compactions uses latest", () => { + const entries: SessionEntry[] = [ + msg("1", null, "user", "a"), + msg("2", "1", "assistant", "b"), + compaction("3", "2", "First summary", "1"), + msg("4", "3", "user", "c"), + msg("5", "4", "assistant", "d"), + compaction("6", "5", "Second summary", "4"), + msg("7", "6", "user", "e"), + ]; + const ctx = buildSessionContext(entries); + + // Should use second summary, keep from 4 + 
expect(ctx.messages).toHaveLength(4); + expect((ctx.messages[0] as any).summary).toContain("Second summary"); + }); + }); + + describe("with branches", () => { + it("follows path to specified leaf", () => { + // Tree: + // 1 -> 2 -> 3 (branch A) + // \-> 4 (branch B) + const entries: SessionEntry[] = [ + msg("1", null, "user", "start"), + msg("2", "1", "assistant", "response"), + msg("3", "2", "user", "branch A"), + msg("4", "2", "user", "branch B"), + ]; + + const ctxA = buildSessionContext(entries, "3"); + expect(ctxA.messages).toHaveLength(3); + expect((ctxA.messages[2] as any).content).toBe("branch A"); + + const ctxB = buildSessionContext(entries, "4"); + expect(ctxB.messages).toHaveLength(3); + expect((ctxB.messages[2] as any).content).toBe("branch B"); + }); + + it("includes branch summary in path", () => { + const entries: SessionEntry[] = [ + msg("1", null, "user", "start"), + msg("2", "1", "assistant", "response"), + msg("3", "2", "user", "abandoned path"), + branchSummary("4", "2", "Summary of abandoned work", "3"), + msg("5", "4", "user", "new direction"), + ]; + const ctx = buildSessionContext(entries, "5"); + + expect(ctx.messages).toHaveLength(4); + expect((ctx.messages[2] as any).summary).toContain("Summary of abandoned work"); + expect((ctx.messages[3] as any).content).toBe("new direction"); + }); + + it("complex tree with multiple branches and compaction", () => { + // Tree: + // 1 -> 2 -> 3 -> 4 -> compaction(5) -> 6 -> 7 (main path) + // \-> 8 -> 9 (abandoned branch) + // \-> branchSummary(10) -> 11 (resumed from 3) + const entries: SessionEntry[] = [ + msg("1", null, "user", "start"), + msg("2", "1", "assistant", "r1"), + msg("3", "2", "user", "q2"), + msg("4", "3", "assistant", "r2"), + compaction("5", "4", "Compacted history", "3"), + msg("6", "5", "user", "q3"), + msg("7", "6", "assistant", "r3"), + // Abandoned branch from 3 + msg("8", "3", "user", "wrong path"), + msg("9", "8", "assistant", "wrong response"), + // Branch summary resuming 
from 3 + branchSummary("10", "3", "Tried wrong approach", "9"), + msg("11", "10", "user", "better approach"), + ]; + + // Main path to 7: summary + kept(3,4) + after(6,7) + const ctxMain = buildSessionContext(entries, "7"); + expect(ctxMain.messages).toHaveLength(5); + expect((ctxMain.messages[0] as any).summary).toContain("Compacted history"); + expect((ctxMain.messages[1] as any).content).toBe("q2"); + expect((ctxMain.messages[2] as any).content[0].text).toBe("r2"); + expect((ctxMain.messages[3] as any).content).toBe("q3"); + expect((ctxMain.messages[4] as any).content[0].text).toBe("r3"); + + // Branch path to 11: 1,2,3 + branch_summary + 11 + const ctxBranch = buildSessionContext(entries, "11"); + expect(ctxBranch.messages).toHaveLength(5); + expect((ctxBranch.messages[0] as any).content).toBe("start"); + expect((ctxBranch.messages[1] as any).content[0].text).toBe("r1"); + expect((ctxBranch.messages[2] as any).content).toBe("q2"); + expect((ctxBranch.messages[3] as any).summary).toContain("Tried wrong approach"); + expect((ctxBranch.messages[4] as any).content).toBe("better approach"); + }); + }); + + describe("edge cases", () => { + it("uses last entry when leafId not found", () => { + const entries: SessionEntry[] = [msg("1", null, "user", "hello"), msg("2", "1", "assistant", "hi")]; + const ctx = buildSessionContext(entries, "nonexistent"); + expect(ctx.messages).toHaveLength(2); + }); + + it("handles orphaned entries gracefully", () => { + const entries: SessionEntry[] = [ + msg("1", null, "user", "hello"), + msg("2", "missing", "assistant", "orphan"), // parent doesn't exist + ]; + const ctx = buildSessionContext(entries, "2"); + // Should only get the orphan since parent chain is broken + expect(ctx.messages).toHaveLength(1); + }); + }); +}); diff --git a/packages/coding-agent/test/session-manager/file-operations.test.ts b/packages/coding-agent/test/session-manager/file-operations.test.ts new file mode 100644 index 00000000..4b598578 --- /dev/null +++ 
b/packages/coding-agent/test/session-manager/file-operations.test.ts @@ -0,0 +1,127 @@ +import { mkdirSync, rmSync, writeFileSync } from "fs"; +import { tmpdir } from "os"; +import { join } from "path"; +import { afterEach, beforeEach, describe, expect, it } from "vitest"; +import { findMostRecentSession, loadEntriesFromFile } from "../../src/core/session-manager.js"; + +describe("loadEntriesFromFile", () => { + let tempDir: string; + + beforeEach(() => { + tempDir = join(tmpdir(), `session-test-${Date.now()}`); + mkdirSync(tempDir, { recursive: true }); + }); + + afterEach(() => { + rmSync(tempDir, { recursive: true, force: true }); + }); + + it("returns empty array for non-existent file", () => { + const entries = loadEntriesFromFile(join(tempDir, "nonexistent.jsonl")); + expect(entries).toEqual([]); + }); + + it("returns empty array for empty file", () => { + const file = join(tempDir, "empty.jsonl"); + writeFileSync(file, ""); + expect(loadEntriesFromFile(file)).toEqual([]); + }); + + it("returns empty array for file without valid session header", () => { + const file = join(tempDir, "no-header.jsonl"); + writeFileSync(file, '{"type":"message","id":"1"}\n'); + expect(loadEntriesFromFile(file)).toEqual([]); + }); + + it("returns empty array for malformed JSON", () => { + const file = join(tempDir, "malformed.jsonl"); + writeFileSync(file, "not json\n"); + expect(loadEntriesFromFile(file)).toEqual([]); + }); + + it("loads valid session file", () => { + const file = join(tempDir, "valid.jsonl"); + writeFileSync( + file, + '{"type":"session","id":"abc","timestamp":"2025-01-01T00:00:00Z","cwd":"/tmp"}\n' + + '{"type":"message","id":"1","parentId":null,"timestamp":"2025-01-01T00:00:01Z","message":{"role":"user","content":"hi","timestamp":1}}\n', + ); + const entries = loadEntriesFromFile(file); + expect(entries).toHaveLength(2); + expect(entries[0].type).toBe("session"); + expect(entries[1].type).toBe("message"); + }); + + it("skips malformed lines but keeps valid 
ones", () => { + const file = join(tempDir, "mixed.jsonl"); + writeFileSync( + file, + '{"type":"session","id":"abc","timestamp":"2025-01-01T00:00:00Z","cwd":"/tmp"}\n' + + "not valid json\n" + + '{"type":"message","id":"1","parentId":null,"timestamp":"2025-01-01T00:00:01Z","message":{"role":"user","content":"hi","timestamp":1}}\n', + ); + const entries = loadEntriesFromFile(file); + expect(entries).toHaveLength(2); + }); +}); + +describe("findMostRecentSession", () => { + let tempDir: string; + + beforeEach(() => { + tempDir = join(tmpdir(), `session-test-${Date.now()}`); + mkdirSync(tempDir, { recursive: true }); + }); + + afterEach(() => { + rmSync(tempDir, { recursive: true, force: true }); + }); + + it("returns null for empty directory", () => { + expect(findMostRecentSession(tempDir)).toBeNull(); + }); + + it("returns null for non-existent directory", () => { + expect(findMostRecentSession(join(tempDir, "nonexistent"))).toBeNull(); + }); + + it("ignores non-jsonl files", () => { + writeFileSync(join(tempDir, "file.txt"), "hello"); + writeFileSync(join(tempDir, "file.json"), "{}"); + expect(findMostRecentSession(tempDir)).toBeNull(); + }); + + it("ignores jsonl files without valid session header", () => { + writeFileSync(join(tempDir, "invalid.jsonl"), '{"type":"message"}\n'); + expect(findMostRecentSession(tempDir)).toBeNull(); + }); + + it("returns single valid session file", () => { + const file = join(tempDir, "session.jsonl"); + writeFileSync(file, '{"type":"session","id":"abc","timestamp":"2025-01-01T00:00:00Z","cwd":"/tmp"}\n'); + expect(findMostRecentSession(tempDir)).toBe(file); + }); + + it("returns most recently modified session", async () => { + const file1 = join(tempDir, "older.jsonl"); + const file2 = join(tempDir, "newer.jsonl"); + + writeFileSync(file1, '{"type":"session","id":"old","timestamp":"2025-01-01T00:00:00Z","cwd":"/tmp"}\n'); + // Small delay to ensure different mtime + await new Promise((r) => setTimeout(r, 10)); + 
writeFileSync(file2, '{"type":"session","id":"new","timestamp":"2025-01-01T00:00:00Z","cwd":"/tmp"}\n'); + + expect(findMostRecentSession(tempDir)).toBe(file2); + }); + + it("skips invalid files and returns valid one", async () => { + const invalid = join(tempDir, "invalid.jsonl"); + const valid = join(tempDir, "valid.jsonl"); + + writeFileSync(invalid, '{"type":"not-session"}\n'); + await new Promise((r) => setTimeout(r, 10)); + writeFileSync(valid, '{"type":"session","id":"abc","timestamp":"2025-01-01T00:00:00Z","cwd":"/tmp"}\n'); + + expect(findMostRecentSession(tempDir)).toBe(valid); + }); +}); diff --git a/packages/coding-agent/test/session-manager/labels.test.ts b/packages/coding-agent/test/session-manager/labels.test.ts new file mode 100644 index 00000000..e349aa10 --- /dev/null +++ b/packages/coding-agent/test/session-manager/labels.test.ts @@ -0,0 +1,178 @@ +import { describe, expect, it } from "vitest"; +import { type LabelEntry, SessionManager } from "../../src/core/session-manager.js"; + +describe("SessionManager labels", () => { + it("sets and gets labels", () => { + const session = SessionManager.inMemory(); + + const msgId = session.appendMessage({ role: "user", content: "hello", timestamp: 1 }); + + // No label initially + expect(session.getLabel(msgId)).toBeUndefined(); + + // Set a label + const labelId = session.appendLabelChange(msgId, "checkpoint"); + expect(session.getLabel(msgId)).toBe("checkpoint"); + + // Label entry should be in entries + const entries = session.getEntries(); + const labelEntry = entries.find((e) => e.type === "label") as LabelEntry; + expect(labelEntry).toBeDefined(); + expect(labelEntry.id).toBe(labelId); + expect(labelEntry.targetId).toBe(msgId); + expect(labelEntry.label).toBe("checkpoint"); + }); + + it("clears labels with undefined", () => { + const session = SessionManager.inMemory(); + + const msgId = session.appendMessage({ role: "user", content: "hello", timestamp: 1 }); + + session.appendLabelChange(msgId, 
"checkpoint"); + expect(session.getLabel(msgId)).toBe("checkpoint"); + + // Clear the label + session.appendLabelChange(msgId, undefined); + expect(session.getLabel(msgId)).toBeUndefined(); + }); + + it("last label wins", () => { + const session = SessionManager.inMemory(); + + const msgId = session.appendMessage({ role: "user", content: "hello", timestamp: 1 }); + + session.appendLabelChange(msgId, "first"); + session.appendLabelChange(msgId, "second"); + session.appendLabelChange(msgId, "third"); + + expect(session.getLabel(msgId)).toBe("third"); + }); + + it("labels are included in tree nodes", () => { + const session = SessionManager.inMemory(); + + const msg1Id = session.appendMessage({ role: "user", content: "hello", timestamp: 1 }); + const msg2Id = session.appendMessage({ + role: "assistant", + content: [{ type: "text", text: "hi" }], + api: "anthropic-messages", + provider: "anthropic", + model: "test", + usage: { + input: 1, + output: 1, + cacheRead: 0, + cacheWrite: 0, + totalTokens: 2, + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, + }, + stopReason: "stop", + timestamp: 2, + }); + + session.appendLabelChange(msg1Id, "start"); + session.appendLabelChange(msg2Id, "response"); + + const tree = session.getTree(); + + // Find the message nodes (skip label entries) + const msg1Node = tree.find((n) => n.entry.id === msg1Id); + expect(msg1Node?.label).toBe("start"); + + // msg2 is a child of msg1 + const msg2Node = msg1Node?.children.find((n) => n.entry.id === msg2Id); + expect(msg2Node?.label).toBe("response"); + }); + + it("labels are preserved in createBranchedSession", () => { + const session = SessionManager.inMemory(); + + const msg1Id = session.appendMessage({ role: "user", content: "hello", timestamp: 1 }); + const msg2Id = session.appendMessage({ + role: "assistant", + content: [{ type: "text", text: "hi" }], + api: "anthropic-messages", + provider: "anthropic", + model: "test", + usage: { + input: 1, + output: 1, + cacheRead: 
0, + cacheWrite: 0, + totalTokens: 2, + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, + }, + stopReason: "stop", + timestamp: 2, + }); + + session.appendLabelChange(msg1Id, "important"); + session.appendLabelChange(msg2Id, "also-important"); + + // Branch from msg2 (in-memory mode returns null, but updates internal state) + session.createBranchedSession(msg2Id); + + // Labels should be preserved + expect(session.getLabel(msg1Id)).toBe("important"); + expect(session.getLabel(msg2Id)).toBe("also-important"); + + // New label entries should exist + const entries = session.getEntries(); + const labelEntries = entries.filter((e) => e.type === "label") as LabelEntry[]; + expect(labelEntries).toHaveLength(2); + }); + + it("labels not on path are not preserved in createBranchedSession", () => { + const session = SessionManager.inMemory(); + + const msg1Id = session.appendMessage({ role: "user", content: "hello", timestamp: 1 }); + const msg2Id = session.appendMessage({ + role: "assistant", + content: [{ type: "text", text: "hi" }], + api: "anthropic-messages", + provider: "anthropic", + model: "test", + usage: { + input: 1, + output: 1, + cacheRead: 0, + cacheWrite: 0, + totalTokens: 2, + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, + }, + stopReason: "stop", + timestamp: 2, + }); + const msg3Id = session.appendMessage({ role: "user", content: "followup", timestamp: 3 }); + + // Label all messages + session.appendLabelChange(msg1Id, "first"); + session.appendLabelChange(msg2Id, "second"); + session.appendLabelChange(msg3Id, "third"); + + // Branch from msg2 (excludes msg3) + session.createBranchedSession(msg2Id); + + // Only labels for msg1 and msg2 should be preserved + expect(session.getLabel(msg1Id)).toBe("first"); + expect(session.getLabel(msg2Id)).toBe("second"); + expect(session.getLabel(msg3Id)).toBeUndefined(); + }); + + it("labels are not included in buildSessionContext", () => { + const session = 
SessionManager.inMemory(); + + const msgId = session.appendMessage({ role: "user", content: "hello", timestamp: 1 }); + session.appendLabelChange(msgId, "checkpoint"); + + const ctx = session.buildSessionContext(); + expect(ctx.messages).toHaveLength(1); + expect(ctx.messages[0].role).toBe("user"); + }); + + it("throws when labeling non-existent entry", () => { + const session = SessionManager.inMemory(); + + expect(() => session.appendLabelChange("non-existent", "label")).toThrow("Entry non-existent not found"); + }); +}); diff --git a/packages/coding-agent/test/session-manager/migration.test.ts b/packages/coding-agent/test/session-manager/migration.test.ts new file mode 100644 index 00000000..129ba87d --- /dev/null +++ b/packages/coding-agent/test/session-manager/migration.test.ts @@ -0,0 +1,78 @@ +import { describe, expect, it } from "vitest"; +import { type FileEntry, migrateSessionEntries } from "../../src/core/session-manager.js"; + +describe("migrateSessionEntries", () => { + it("should add id/parentId to v1 entries", () => { + const entries: FileEntry[] = [ + { type: "session", id: "sess-1", timestamp: "2025-01-01T00:00:00Z", cwd: "/tmp" }, + { type: "message", timestamp: "2025-01-01T00:00:01Z", message: { role: "user", content: "hi", timestamp: 1 } }, + { + type: "message", + timestamp: "2025-01-01T00:00:02Z", + message: { + role: "assistant", + content: [{ type: "text", text: "hello" }], + api: "test", + provider: "test", + model: "test", + usage: { input: 1, output: 1, cacheRead: 0, cacheWrite: 0 }, + stopReason: "stop", + timestamp: 2, + }, + }, + ] as FileEntry[]; + + migrateSessionEntries(entries); + + // Header should have version set + expect((entries[0] as any).version).toBe(2); + + // Entries should have id/parentId + const msg1 = entries[1] as any; + const msg2 = entries[2] as any; + + expect(msg1.id).toBeDefined(); + expect(msg1.id.length).toBe(8); + expect(msg1.parentId).toBeNull(); + + expect(msg2.id).toBeDefined(); + 
expect(msg2.id.length).toBe(8); + expect(msg2.parentId).toBe(msg1.id); + }); + + it("should be idempotent (skip already migrated)", () => { + const entries: FileEntry[] = [ + { type: "session", id: "sess-1", version: 2, timestamp: "2025-01-01T00:00:00Z", cwd: "/tmp" }, + { + type: "message", + id: "abc12345", + parentId: null, + timestamp: "2025-01-01T00:00:01Z", + message: { role: "user", content: "hi", timestamp: 1 }, + }, + { + type: "message", + id: "def67890", + parentId: "abc12345", + timestamp: "2025-01-01T00:00:02Z", + message: { + role: "assistant", + content: [{ type: "text", text: "hello" }], + api: "test", + provider: "test", + model: "test", + usage: { input: 1, output: 1, cacheRead: 0, cacheWrite: 0 }, + stopReason: "stop", + timestamp: 2, + }, + }, + ] as FileEntry[]; + + migrateSessionEntries(entries); + + // IDs should be unchanged + expect((entries[1] as any).id).toBe("abc12345"); + expect((entries[2] as any).id).toBe("def67890"); + expect((entries[2] as any).parentId).toBe("abc12345"); + }); +}); diff --git a/packages/coding-agent/test/session-manager/save-entry.test.ts b/packages/coding-agent/test/session-manager/save-entry.test.ts new file mode 100644 index 00000000..2a618986 --- /dev/null +++ b/packages/coding-agent/test/session-manager/save-entry.test.ts @@ -0,0 +1,55 @@ +import { describe, expect, it } from "vitest"; +import { type CustomEntry, SessionManager } from "../../src/core/session-manager.js"; + +describe("SessionManager.saveCustomEntry", () => { + it("saves custom entries and includes them in tree traversal", () => { + const session = SessionManager.inMemory(); + + // Save a message + const msgId = session.appendMessage({ role: "user", content: "hello", timestamp: 1 }); + + // Save a custom entry + const customId = session.appendCustomEntry("my_hook", { foo: "bar" }); + + // Save another message + const msg2Id = session.appendMessage({ + role: "assistant", + content: [{ type: "text", text: "hi" }], + api: "anthropic-messages", + 
provider: "anthropic", + model: "test", + usage: { + input: 1, + output: 1, + cacheRead: 0, + cacheWrite: 0, + totalTokens: 2, + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, + }, + stopReason: "stop", + timestamp: 2, + }); + + // Custom entry should be in entries + const entries = session.getEntries(); + expect(entries).toHaveLength(3); + + const customEntry = entries.find((e) => e.type === "custom") as CustomEntry; + expect(customEntry).toBeDefined(); + expect(customEntry.customType).toBe("my_hook"); + expect(customEntry.data).toEqual({ foo: "bar" }); + expect(customEntry.id).toBe(customId); + expect(customEntry.parentId).toBe(msgId); + + // Tree structure should be correct + const path = session.getPath(); + expect(path).toHaveLength(3); + expect(path[0].id).toBe(msgId); + expect(path[1].id).toBe(customId); + expect(path[2].id).toBe(msg2Id); + + // buildSessionContext should work (custom entries skipped in messages) + const ctx = session.buildSessionContext(); + expect(ctx.messages).toHaveLength(2); // only message entries + }); +}); diff --git a/packages/coding-agent/test/session-manager/tree-traversal.test.ts b/packages/coding-agent/test/session-manager/tree-traversal.test.ts new file mode 100644 index 00000000..5fe7610a --- /dev/null +++ b/packages/coding-agent/test/session-manager/tree-traversal.test.ts @@ -0,0 +1,460 @@ +import { describe, expect, it } from "vitest"; +import { type CustomEntry, SessionManager } from "../../src/core/session-manager.js"; +import { assistantMsg, userMsg } from "../utilities.js"; + +describe("SessionManager append and tree traversal", () => { + describe("append operations", () => { + it("appendMessage creates entry with correct parentId chain", () => { + const session = SessionManager.inMemory(); + + const id1 = session.appendMessage(userMsg("first")); + const id2 = session.appendMessage(assistantMsg("second")); + const id3 = session.appendMessage(userMsg("third")); + + const entries = 
session.getEntries(); + expect(entries).toHaveLength(3); + + expect(entries[0].id).toBe(id1); + expect(entries[0].parentId).toBeNull(); + expect(entries[0].type).toBe("message"); + + expect(entries[1].id).toBe(id2); + expect(entries[1].parentId).toBe(id1); + + expect(entries[2].id).toBe(id3); + expect(entries[2].parentId).toBe(id2); + }); + + it("appendThinkingLevelChange integrates into tree", () => { + const session = SessionManager.inMemory(); + + const msgId = session.appendMessage(userMsg("hello")); + const thinkingId = session.appendThinkingLevelChange("high"); + const _msg2Id = session.appendMessage(assistantMsg("response")); + + const entries = session.getEntries(); + expect(entries).toHaveLength(3); + + const thinkingEntry = entries.find((e) => e.type === "thinking_level_change"); + expect(thinkingEntry).toBeDefined(); + expect(thinkingEntry!.id).toBe(thinkingId); + expect(thinkingEntry!.parentId).toBe(msgId); + + expect(entries[2].parentId).toBe(thinkingId); + }); + + it("appendModelChange integrates into tree", () => { + const session = SessionManager.inMemory(); + + const msgId = session.appendMessage(userMsg("hello")); + const modelId = session.appendModelChange("openai", "gpt-4"); + const _msg2Id = session.appendMessage(assistantMsg("response")); + + const entries = session.getEntries(); + const modelEntry = entries.find((e) => e.type === "model_change"); + expect(modelEntry).toBeDefined(); + expect(modelEntry?.id).toBe(modelId); + expect(modelEntry?.parentId).toBe(msgId); + if (modelEntry?.type === "model_change") { + expect(modelEntry.provider).toBe("openai"); + expect(modelEntry.modelId).toBe("gpt-4"); + } + + expect(entries[2].parentId).toBe(modelId); + }); + + it("appendCompaction integrates into tree", () => { + const session = SessionManager.inMemory(); + + const id1 = session.appendMessage(userMsg("1")); + const id2 = session.appendMessage(assistantMsg("2")); + const compactionId = session.appendCompaction("summary", id1, 1000); + const _id3 = 
session.appendMessage(userMsg("3")); + + const entries = session.getEntries(); + const compactionEntry = entries.find((e) => e.type === "compaction"); + expect(compactionEntry).toBeDefined(); + expect(compactionEntry?.id).toBe(compactionId); + expect(compactionEntry?.parentId).toBe(id2); + if (compactionEntry?.type === "compaction") { + expect(compactionEntry.summary).toBe("summary"); + expect(compactionEntry.firstKeptEntryId).toBe(id1); + expect(compactionEntry.tokensBefore).toBe(1000); + } + + expect(entries[3].parentId).toBe(compactionId); + }); + + it("appendCustomEntry integrates into tree", () => { + const session = SessionManager.inMemory(); + + const msgId = session.appendMessage(userMsg("hello")); + const customId = session.appendCustomEntry("my_hook", { key: "value" }); + const _msg2Id = session.appendMessage(assistantMsg("response")); + + const entries = session.getEntries(); + const customEntry = entries.find((e) => e.type === "custom") as CustomEntry; + expect(customEntry).toBeDefined(); + expect(customEntry.id).toBe(customId); + expect(customEntry.parentId).toBe(msgId); + expect(customEntry.customType).toBe("my_hook"); + expect(customEntry.data).toEqual({ key: "value" }); + + expect(entries[2].parentId).toBe(customId); + }); + + it("leaf pointer advances after each append", () => { + const session = SessionManager.inMemory(); + + expect(session.getLeafId()).toBeNull(); + + const id1 = session.appendMessage(userMsg("1")); + expect(session.getLeafId()).toBe(id1); + + const id2 = session.appendMessage(assistantMsg("2")); + expect(session.getLeafId()).toBe(id2); + + const id3 = session.appendThinkingLevelChange("high"); + expect(session.getLeafId()).toBe(id3); + }); + }); + + describe("getPath", () => { + it("returns empty array for empty session", () => { + const session = SessionManager.inMemory(); + expect(session.getPath()).toEqual([]); + }); + + it("returns single entry path", () => { + const session = SessionManager.inMemory(); + const id = 
session.appendMessage(userMsg("hello")); + + const path = session.getPath(); + expect(path).toHaveLength(1); + expect(path[0].id).toBe(id); + }); + + it("returns full path from root to leaf", () => { + const session = SessionManager.inMemory(); + + const id1 = session.appendMessage(userMsg("1")); + const id2 = session.appendMessage(assistantMsg("2")); + const id3 = session.appendThinkingLevelChange("high"); + const id4 = session.appendMessage(userMsg("3")); + + const path = session.getPath(); + expect(path).toHaveLength(4); + expect(path.map((e) => e.id)).toEqual([id1, id2, id3, id4]); + }); + + it("returns path from specified entry to root", () => { + const session = SessionManager.inMemory(); + + const id1 = session.appendMessage(userMsg("1")); + const id2 = session.appendMessage(assistantMsg("2")); + const _id3 = session.appendMessage(userMsg("3")); + const _id4 = session.appendMessage(assistantMsg("4")); + + const path = session.getPath(id2); + expect(path).toHaveLength(2); + expect(path.map((e) => e.id)).toEqual([id1, id2]); + }); + }); + + describe("getTree", () => { + it("returns empty array for empty session", () => { + const session = SessionManager.inMemory(); + expect(session.getTree()).toEqual([]); + }); + + it("returns single root for linear session", () => { + const session = SessionManager.inMemory(); + + const id1 = session.appendMessage(userMsg("1")); + const id2 = session.appendMessage(assistantMsg("2")); + const id3 = session.appendMessage(userMsg("3")); + + const tree = session.getTree(); + expect(tree).toHaveLength(1); + + const root = tree[0]; + expect(root.entry.id).toBe(id1); + expect(root.children).toHaveLength(1); + expect(root.children[0].entry.id).toBe(id2); + expect(root.children[0].children).toHaveLength(1); + expect(root.children[0].children[0].entry.id).toBe(id3); + expect(root.children[0].children[0].children).toHaveLength(0); + }); + + it("returns tree with branches after branch", () => { + const session = 
SessionManager.inMemory(); + + // Build: 1 -> 2 -> 3 + const id1 = session.appendMessage(userMsg("1")); + const id2 = session.appendMessage(assistantMsg("2")); + const id3 = session.appendMessage(userMsg("3")); + + // Branch from id2, add new path: 2 -> 4 + session.branch(id2); + const id4 = session.appendMessage(userMsg("4-branch")); + + const tree = session.getTree(); + expect(tree).toHaveLength(1); + + const root = tree[0]; + expect(root.entry.id).toBe(id1); + expect(root.children).toHaveLength(1); + + const node2 = root.children[0]; + expect(node2.entry.id).toBe(id2); + expect(node2.children).toHaveLength(2); // id3 and id4 are siblings + + const childIds = node2.children.map((c) => c.entry.id).sort(); + expect(childIds).toEqual([id3, id4].sort()); + }); + + it("handles multiple branches at same point", () => { + const session = SessionManager.inMemory(); + + const _id1 = session.appendMessage(userMsg("root")); + const id2 = session.appendMessage(assistantMsg("response")); + + // Branch A + session.branch(id2); + const idA = session.appendMessage(userMsg("branch-A")); + + // Branch B + session.branch(id2); + const idB = session.appendMessage(userMsg("branch-B")); + + // Branch C + session.branch(id2); + const idC = session.appendMessage(userMsg("branch-C")); + + const tree = session.getTree(); + const node2 = tree[0].children[0]; + expect(node2.entry.id).toBe(id2); + expect(node2.children).toHaveLength(3); + + const branchIds = node2.children.map((c) => c.entry.id).sort(); + expect(branchIds).toEqual([idA, idB, idC].sort()); + }); + + it("handles deep branching", () => { + const session = SessionManager.inMemory(); + + // Main path: 1 -> 2 -> 3 -> 4 + const _id1 = session.appendMessage(userMsg("1")); + const id2 = session.appendMessage(assistantMsg("2")); + const id3 = session.appendMessage(userMsg("3")); + const _id4 = session.appendMessage(assistantMsg("4")); + + // Branch from 2: 2 -> 5 -> 6 + session.branch(id2); + const id5 = 
session.appendMessage(userMsg("5")); + const _id6 = session.appendMessage(assistantMsg("6")); + + // Branch from 5: 5 -> 7 + session.branch(id5); + const _id7 = session.appendMessage(userMsg("7")); + + const tree = session.getTree(); + + // Verify structure + const node2 = tree[0].children[0]; + expect(node2.children).toHaveLength(2); // id3 and id5 + + const node5 = node2.children.find((c) => c.entry.id === id5)!; + expect(node5.children).toHaveLength(2); // id6 and id7 + + const node3 = node2.children.find((c) => c.entry.id === id3)!; + expect(node3.children).toHaveLength(1); // id4 + }); + }); + + describe("branch", () => { + it("moves leaf pointer to specified entry", () => { + const session = SessionManager.inMemory(); + + const id1 = session.appendMessage(userMsg("1")); + const _id2 = session.appendMessage(assistantMsg("2")); + const id3 = session.appendMessage(userMsg("3")); + + expect(session.getLeafId()).toBe(id3); + + session.branch(id1); + expect(session.getLeafId()).toBe(id1); + }); + + it("throws for non-existent entry", () => { + const session = SessionManager.inMemory(); + session.appendMessage(userMsg("hello")); + + expect(() => session.branch("nonexistent")).toThrow("Entry nonexistent not found"); + }); + + it("new appends become children of branch point", () => { + const session = SessionManager.inMemory(); + + const id1 = session.appendMessage(userMsg("1")); + const _id2 = session.appendMessage(assistantMsg("2")); + + session.branch(id1); + const id3 = session.appendMessage(userMsg("branched")); + + const entries = session.getEntries(); + const branchedEntry = entries.find((e) => e.id === id3)!; + expect(branchedEntry.parentId).toBe(id1); // sibling of id2 + }); + }); + + describe("branchWithSummary", () => { + it("inserts branch summary and advances leaf", () => { + const session = SessionManager.inMemory(); + + const id1 = session.appendMessage(userMsg("1")); + const _id2 = session.appendMessage(assistantMsg("2")); + const _id3 = 
session.appendMessage(userMsg("3")); + + const summaryId = session.branchWithSummary(id1, "Summary of abandoned work"); + + expect(session.getLeafId()).toBe(summaryId); + + const entries = session.getEntries(); + const summaryEntry = entries.find((e) => e.type === "branch_summary"); + expect(summaryEntry).toBeDefined(); + expect(summaryEntry?.parentId).toBe(id1); + if (summaryEntry?.type === "branch_summary") { + expect(summaryEntry.summary).toBe("Summary of abandoned work"); + } + }); + + it("throws for non-existent entry", () => { + const session = SessionManager.inMemory(); + session.appendMessage(userMsg("hello")); + + expect(() => session.branchWithSummary("nonexistent", "summary")).toThrow("Entry nonexistent not found"); + }); + }); + + describe("getLeafEntry", () => { + it("returns undefined for empty session", () => { + const session = SessionManager.inMemory(); + expect(session.getLeafEntry()).toBeUndefined(); + }); + + it("returns current leaf entry", () => { + const session = SessionManager.inMemory(); + + session.appendMessage(userMsg("1")); + const id2 = session.appendMessage(assistantMsg("2")); + + const leaf = session.getLeafEntry(); + expect(leaf).toBeDefined(); + expect(leaf!.id).toBe(id2); + }); + }); + + describe("getEntry", () => { + it("returns undefined for non-existent id", () => { + const session = SessionManager.inMemory(); + expect(session.getEntry("nonexistent")).toBeUndefined(); + }); + + it("returns entry by id", () => { + const session = SessionManager.inMemory(); + + const id1 = session.appendMessage(userMsg("first")); + const id2 = session.appendMessage(assistantMsg("second")); + + const entry1 = session.getEntry(id1); + expect(entry1).toBeDefined(); + expect(entry1?.type).toBe("message"); + if (entry1?.type === "message" && entry1.message.role === "user") { + expect(entry1.message.content).toBe("first"); + } + + const entry2 = session.getEntry(id2); + expect(entry2).toBeDefined(); + if (entry2?.type === "message" && 
entry2.message.role === "assistant") { + expect((entry2.message.content as any)[0].text).toBe("second"); + } + }); + }); + + describe("buildSessionContext with branches", () => { + it("returns messages from current branch only", () => { + const session = SessionManager.inMemory(); + + // Main: 1 -> 2 -> 3 + session.appendMessage(userMsg("msg1")); + const id2 = session.appendMessage(assistantMsg("msg2")); + session.appendMessage(userMsg("msg3")); + + // Branch from 2: 2 -> 4 + session.branch(id2); + session.appendMessage(assistantMsg("msg4-branch")); + + const ctx = session.buildSessionContext(); + expect(ctx.messages).toHaveLength(3); // msg1, msg2, msg4-branch (not msg3) + + expect((ctx.messages[0] as any).content).toBe("msg1"); + expect((ctx.messages[1] as any).content[0].text).toBe("msg2"); + expect((ctx.messages[2] as any).content[0].text).toBe("msg4-branch"); + }); + }); +}); + +describe("createBranchedSession", () => { + it("throws for non-existent entry", () => { + const session = SessionManager.inMemory(); + session.appendMessage(userMsg("hello")); + + expect(() => session.createBranchedSession("nonexistent")).toThrow("Entry nonexistent not found"); + }); + + it("creates new session with path to specified leaf (in-memory)", () => { + const session = SessionManager.inMemory(); + + // Build: 1 -> 2 -> 3 -> 4 + const id1 = session.appendMessage(userMsg("1")); + const id2 = session.appendMessage(assistantMsg("2")); + const id3 = session.appendMessage(userMsg("3")); + session.appendMessage(assistantMsg("4")); + + // Branch from 3: 3 -> 5 + session.branch(id3); + const _id5 = session.appendMessage(userMsg("5")); + + // Create branched session from id2 (should only have 1 -> 2) + const result = session.createBranchedSession(id2); + expect(result).toBeUndefined(); // in-memory returns null + + // Session should now only have entries 1 and 2 + const entries = session.getEntries(); + expect(entries).toHaveLength(2); + expect(entries[0].id).toBe(id1); + 
expect(entries[1].id).toBe(id2); + }); + + it("extracts correct path from branched tree", () => { + const session = SessionManager.inMemory(); + + // Build: 1 -> 2 -> 3 + const id1 = session.appendMessage(userMsg("1")); + const id2 = session.appendMessage(assistantMsg("2")); + session.appendMessage(userMsg("3")); + + // Branch from 2: 2 -> 4 -> 5 + session.branch(id2); + const id4 = session.appendMessage(userMsg("4")); + const id5 = session.appendMessage(assistantMsg("5")); + + // Create branched session from id5 (should have 1 -> 2 -> 4 -> 5) + session.createBranchedSession(id5); + + const entries = session.getEntries(); + expect(entries).toHaveLength(4); + expect(entries.map((e) => e.id)).toEqual([id1, id2, id4, id5]); + }); +}); diff --git a/packages/coding-agent/test/utilities.ts b/packages/coding-agent/test/utilities.ts new file mode 100644 index 00000000..3fe5d4cc --- /dev/null +++ b/packages/coding-agent/test/utilities.ts @@ -0,0 +1,158 @@ +/** + * Shared test utilities for coding-agent tests. + */ + +import { existsSync, mkdirSync, rmSync } from "node:fs"; +import { tmpdir } from "node:os"; +import { join } from "node:path"; +import { Agent } from "@mariozechner/pi-agent-core"; +import { getModel } from "@mariozechner/pi-ai"; +import { AgentSession } from "../src/core/agent-session.js"; +import { AuthStorage } from "../src/core/auth-storage.js"; +import { ModelRegistry } from "../src/core/model-registry.js"; +import { SessionManager } from "../src/core/session-manager.js"; +import { SettingsManager } from "../src/core/settings-manager.js"; +import { codingTools } from "../src/core/tools/index.js"; + +/** + * API key for authenticated tests. Tests using this should be wrapped in + * describe.skipIf(!API_KEY) + */ +export const API_KEY = process.env.ANTHROPIC_OAUTH_TOKEN || process.env.ANTHROPIC_API_KEY; + +/** + * Create a minimal user message for testing. 
+ */ +export function userMsg(text: string) { + return { role: "user" as const, content: text, timestamp: Date.now() }; +} + +/** + * Create a minimal assistant message for testing. + */ +export function assistantMsg(text: string) { + return { + role: "assistant" as const, + content: [{ type: "text" as const, text }], + api: "anthropic-messages" as const, + provider: "anthropic", + model: "test", + usage: { + input: 1, + output: 1, + cacheRead: 0, + cacheWrite: 0, + totalTokens: 2, + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, + }, + stopReason: "stop" as const, + timestamp: Date.now(), + }; +} + +/** + * Options for creating a test session. + */ +export interface TestSessionOptions { + /** Use in-memory session (no file persistence) */ + inMemory?: boolean; + /** Custom system prompt */ + systemPrompt?: string; + /** Custom settings overrides */ + settingsOverrides?: Record<string, unknown>; +} + +/** + * Resources returned by createTestSession that need cleanup. + */ +export interface TestSessionContext { + session: AgentSession; + sessionManager: SessionManager; + tempDir: string; + cleanup: () => void; +} + +/** + * Create an AgentSession for testing with proper setup and cleanup. + * Use this for e2e tests that need real LLM calls. + */ +export function createTestSession(options: TestSessionOptions = {}): TestSessionContext { + const tempDir = join(tmpdir(), `pi-test-${Date.now()}-${Math.random().toString(36).slice(2)}`); + mkdirSync(tempDir, { recursive: true }); + + const model = getModel("anthropic", "claude-sonnet-4-5")!; + const agent = new Agent({ + getApiKey: () => API_KEY, + initialState: { + model, + systemPrompt: options.systemPrompt ?? "You are a helpful assistant. Be extremely concise.", + tools: codingTools, + }, + }); + + const sessionManager = options.inMemory ? 
SessionManager.inMemory() : SessionManager.create(tempDir); + const settingsManager = SettingsManager.create(tempDir, tempDir); + + if (options.settingsOverrides) { + settingsManager.applyOverrides(options.settingsOverrides); + } + + const authStorage = new AuthStorage(join(tempDir, "auth.json")); + const modelRegistry = new ModelRegistry(authStorage, tempDir); + + const session = new AgentSession({ + agent, + sessionManager, + settingsManager, + modelRegistry, + }); + + // Must subscribe to enable session persistence + session.subscribe(() => {}); + + const cleanup = () => { + session.dispose(); + if (tempDir && existsSync(tempDir)) { + rmSync(tempDir, { recursive: true }); + } + }; + + return { session, sessionManager, tempDir, cleanup }; +} + +/** + * Build a session tree for testing using SessionManager. + * Returns the IDs of all created entries. + * + * Example tree structure: + * ``` + * u1 -> a1 -> u2 -> a2 + * -> u3 -> a3 (branch from a1) + * u4 -> a4 (another root) + * ``` + */ +export function buildTestTree( + session: SessionManager, + structure: { + messages: Array<{ role: "user" | "assistant"; text: string; branchFrom?: string }>; + }, +): Map<string, string> { + const ids = new Map<string, string>(); + + for (const msg of structure.messages) { + if (msg.branchFrom) { + const branchFromId = ids.get(msg.branchFrom); + if (!branchFromId) { + throw new Error(`Cannot branch from unknown entry: ${msg.branchFrom}`); + } + session.branch(branchFromId); + } + + const id = + msg.role === "user" ? 
session.appendMessage(userMsg(msg.text)) : session.appendMessage(assistantMsg(msg.text)); + + ids.set(msg.text, id); + } + + return ids; +} diff --git a/packages/mom/src/agent.ts b/packages/mom/src/agent.ts index b5cb6333..2be0bf3b 100644 --- a/packages/mom/src/agent.ts +++ b/packages/mom/src/agent.ts @@ -1,15 +1,15 @@ -import { Agent, type AgentEvent, type Attachment, ProviderTransport } from "@mariozechner/pi-agent-core"; -import { getModel } from "@mariozechner/pi-ai"; +import { Agent, type AgentEvent } from "@mariozechner/pi-agent-core"; +import { getModel, type ImageContent } from "@mariozechner/pi-ai"; import { AgentSession, AuthStorage, + convertToLlm, formatSkillsForPrompt, loadSkillsFromDir, ModelRegistry, - messageTransformer, type Skill, } from "@mariozechner/pi-coding-agent"; -import { existsSync, readFileSync, statSync } from "fs"; +import { existsSync, readFileSync } from "fs"; import { mkdir, writeFile } from "fs/promises"; import { homedir } from "os"; import { join } from "path"; @@ -434,14 +434,12 @@ function createRunner(sandboxConfig: SandboxConfig, channelId: string, channelDi thinkingLevel: "off", tools, }, - messageTransformer, - transport: new ProviderTransport({ - getApiKey: async () => getAnthropicApiKey(authStorage), - }), + convertToLlm, + getApiKey: async () => getAnthropicApiKey(authStorage), }); // Load existing messages - const loadedSession = sessionManager.loadSession(); + const loadedSession = sessionManager.buildSessionContex(); if (loadedSession.messages.length > 0) { agent.replaceMessages(loadedSession.messages); log.logInfo(`[${channelId}] Loaded ${loadedSession.messages.length} messages from context.jsonl`); @@ -628,7 +626,7 @@ function createRunner(sandboxConfig: SandboxConfig, channelId: string, channelDi // Reload messages from context.jsonl // This picks up any messages synced from log.jsonl before this run - const reloadedSession = sessionManager.loadSession(); + const reloadedSession = 
sessionManager.buildSessionContex(); if (reloadedSession.messages.length > 0) { agent.replaceMessages(reloadedSession.messages); log.logInfo(`[${channelId}] Reloaded ${reloadedSession.messages.length} messages from context`); @@ -716,7 +714,7 @@ function createRunner(sandboxConfig: SandboxConfig, channelId: string, channelDi const timestamp = `${now.getFullYear()}-${pad(now.getMonth() + 1)}-${pad(now.getDate())} ${pad(now.getHours())}:${pad(now.getMinutes())}:${pad(now.getSeconds())}${offsetSign}${offsetHours}:${offsetMins}`; let userMessage = `[${timestamp}] [${ctx.message.userName || "unknown"}]: ${ctx.message.text}`; - const imageAttachments: Attachment[] = []; + const imageAttachments: ImageContent[] = []; const nonImagePaths: string[] = []; for (const a of ctx.message.attachments || []) { @@ -725,14 +723,10 @@ function createRunner(sandboxConfig: SandboxConfig, channelId: string, channelDi if (mimeType && existsSync(fullPath)) { try { - const stats = statSync(fullPath); imageAttachments.push({ - id: a.local, type: "image", - fileName: a.local.split("/").pop() || a.local, mimeType, - size: stats.size, - content: readFileSync(fullPath).toString("base64"), + data: readFileSync(fullPath).toString("base64"), }); } catch { nonImagePaths.push(fullPath); @@ -755,7 +749,7 @@ function createRunner(sandboxConfig: SandboxConfig, channelId: string, channelDi }; await writeFile(join(channelDir, "last_prompt.jsonl"), JSON.stringify(debugContext, null, 2)); - await session.prompt(userMessage, imageAttachments.length > 0 ? { attachments: imageAttachments } : undefined); + await session.prompt(userMessage, imageAttachments.length > 0 ? 
{ images: imageAttachments } : undefined); // Wait for queued messages await queueChain; diff --git a/packages/mom/src/context.ts b/packages/mom/src/context.ts index 00920ae4..11f8a69c 100644 --- a/packages/mom/src/context.ts +++ b/packages/mom/src/context.ts @@ -10,13 +10,15 @@ * - MomSettingsManager: Simple settings for mom (compaction, retry, model preferences) */ -import type { AppMessage } from "@mariozechner/pi-agent-core"; +import type { AgentMessage } from "@mariozechner/pi-agent-core"; import { buildSessionContext, type CompactionEntry, - type LoadedSession, + type FileEntry, type ModelChangeEntry, + type SessionContext, type SessionEntry, + type SessionEntryBase, type SessionMessageEntry, type ThinkingLevelChangeEntry, } from "@mariozechner/pi-coding-agent"; @@ -48,7 +50,8 @@ export class MomSessionManager { private logFile: string; private channelDir: string; private flushed: boolean = false; - private inMemoryEntries: SessionEntry[] = []; + private inMemoryEntries: FileEntry[] = []; + private leafId: string | null = null; constructor(channelDir: string) { this.channelDir = channelDir; @@ -64,12 +67,14 @@ export class MomSessionManager { if (existsSync(this.contextFile)) { this.inMemoryEntries = this.loadEntriesFromFile(); this.sessionId = this.extractSessionId() || uuidv4(); + this._updateLeafId(); this.flushed = true; } else { this.sessionId = uuidv4(); this.inMemoryEntries = [ { type: "session", + version: 2, id: this.sessionId, timestamp: new Date().toISOString(), cwd: this.channelDir, @@ -79,6 +84,28 @@ export class MomSessionManager { // Note: syncFromLog() is called explicitly from agent.ts with excludeTimestamp } + private _updateLeafId(): void { + for (let i = this.inMemoryEntries.length - 1; i >= 0; i--) { + const entry = this.inMemoryEntries[i]; + if (entry.type !== "session") { + this.leafId = entry.id; + return; + } + } + this.leafId = null; + } + + private _createEntryBase(): Omit { + const id = uuidv4(); + const base = { + id, + parentId: 
this.leafId, + timestamp: new Date().toISOString(), + }; + this.leafId = id; + return base; + } + private _persist(entry: SessionEntry): void { const hasAssistant = this.inMemoryEntries.some((e) => e.type === "message" && e.message.role === "assistant"); if (!hasAssistant) return; @@ -126,7 +153,7 @@ export class MomSessionManager { contextSlackTimestamps.add(entry.timestamp); // Also store message text to catch duplicates added via prompt() - // AppMessage has different shapes, check for content property + // AgentMessage has different shapes, check for content property const msg = msgEntry.message as { role: string; content?: unknown }; if (msg.role === "user" && msg.content !== undefined) { const content = msg.content; @@ -162,7 +189,7 @@ export class MomSessionManager { isBot?: boolean; } - const newMessages: Array<{ timestamp: string; slackTs: string; message: AppMessage }> = []; + const newMessages: Array<{ timestamp: string; slackTs: string; message: AgentMessage }> = []; for (const line of logLines) { try { @@ -188,7 +215,7 @@ export class MomSessionManager { if (contextMessageTexts.has(messageText)) continue; const msgTime = new Date(date).getTime() || Date.now(); - const userMessage: AppMessage = { + const userMessage: AgentMessage = { role: "user", content: messageText, timestamp: msgTime, @@ -206,11 +233,15 @@ export class MomSessionManager { newMessages.sort((a, b) => new Date(a.timestamp).getTime() - new Date(b.timestamp).getTime()); for (const { timestamp, message } of newMessages) { + const id = uuidv4(); const entry: SessionMessageEntry = { type: "message", + id, + parentId: this.leafId, timestamp, // Use log date as entry timestamp for consistent deduplication message, }; + this.leafId = id; this.inMemoryEntries.push(entry); appendFileSync(this.contextFile, `${JSON.stringify(entry)}\n`); @@ -226,17 +257,17 @@ export class MomSessionManager { return null; } - private loadEntriesFromFile(): SessionEntry[] { + private loadEntriesFromFile(): 
FileEntry[] { if (!existsSync(this.contextFile)) return []; const content = readFileSync(this.contextFile, "utf8"); - const entries: SessionEntry[] = []; + const entries: FileEntry[] = []; const lines = content.trim().split("\n"); for (const line of lines) { if (!line.trim()) continue; try { - const entry = JSON.parse(line) as SessionEntry; + const entry = JSON.parse(line) as FileEntry; entries.push(entry); } catch { // Skip malformed lines @@ -246,20 +277,16 @@ export class MomSessionManager { return entries; } - saveMessage(message: AppMessage): void { - const entry: SessionMessageEntry = { - type: "message", - timestamp: new Date().toISOString(), - message, - }; + saveMessage(message: AgentMessage): void { + const entry: SessionMessageEntry = { ...this._createEntryBase(), type: "message", message }; this.inMemoryEntries.push(entry); this._persist(entry); } saveThinkingLevelChange(thinkingLevel: string): void { const entry: ThinkingLevelChangeEntry = { + ...this._createEntryBase(), type: "thinking_level_change", - timestamp: new Date().toISOString(), thinkingLevel, }; this.inMemoryEntries.push(entry); @@ -267,12 +294,7 @@ export class MomSessionManager { } saveModelChange(provider: string, modelId: string): void { - const entry: ModelChangeEntry = { - type: "model_change", - timestamp: new Date().toISOString(), - provider, - modelId, - }; + const entry: ModelChangeEntry = { ...this._createEntryBase(), type: "model_change", provider, modelId }; this.inMemoryEntries.push(entry); this._persist(entry); } @@ -283,17 +305,15 @@ export class MomSessionManager { } /** Load session with compaction support */ - loadSession(): LoadedSession { + buildSessionContex(): SessionContext { const entries = this.loadEntries(); return buildSessionContext(entries); } loadEntries(): SessionEntry[] { // Re-read from file to get latest state - if (existsSync(this.contextFile)) { - return this.loadEntriesFromFile(); - } - return [...this.inMemoryEntries]; + const entries = 
existsSync(this.contextFile) ? this.loadEntriesFromFile() : this.inMemoryEntries; + return entries.filter((e): e is SessionEntry => e.type !== "session"); } getSessionId(): string { @@ -332,15 +352,15 @@ export class MomSessionManager { } loadModel(): { provider: string; modelId: string } | null { - return this.loadSession().model; + return this.buildSessionContex().model; } loadThinkingLevel(): string { - return this.loadSession().thinkingLevel; + return this.buildSessionContex().thinkingLevel; } /** Not used by mom but required by AgentSession interface */ - createBranchedSessionFromEntries(_entries: SessionEntry[], _branchBeforeIndex: number): string | null { + createBranchedSession(_leafId: string): string | null { return null; // Mom doesn't support branching } } diff --git a/packages/mom/src/tools/attach.ts b/packages/mom/src/tools/attach.ts index 174faf02..fae9e8db 100644 --- a/packages/mom/src/tools/attach.ts +++ b/packages/mom/src/tools/attach.ts @@ -1,4 +1,4 @@ -import type { AgentTool } from "@mariozechner/pi-ai"; +import type { AgentTool } from "@mariozechner/pi-agent-core"; import { Type } from "@sinclair/typebox"; import { basename, resolve as resolvePath } from "path"; diff --git a/packages/mom/src/tools/bash.ts b/packages/mom/src/tools/bash.ts index dbda5e43..82e9dacd 100644 --- a/packages/mom/src/tools/bash.ts +++ b/packages/mom/src/tools/bash.ts @@ -2,7 +2,7 @@ import { randomBytes } from "node:crypto"; import { createWriteStream } from "node:fs"; import { tmpdir } from "node:os"; import { join } from "node:path"; -import type { AgentTool } from "@mariozechner/pi-ai"; +import type { AgentTool } from "@mariozechner/pi-agent-core"; import { Type } from "@sinclair/typebox"; import type { Executor } from "../sandbox.js"; import { DEFAULT_MAX_BYTES, DEFAULT_MAX_LINES, formatSize, type TruncationResult, truncateTail } from "./truncate.js"; diff --git a/packages/mom/src/tools/edit.ts b/packages/mom/src/tools/edit.ts index 3fce6146..5ee678e8 100644 --- 
a/packages/mom/src/tools/edit.ts +++ b/packages/mom/src/tools/edit.ts @@ -1,4 +1,4 @@ -import type { AgentTool } from "@mariozechner/pi-ai"; +import type { AgentTool } from "@mariozechner/pi-agent-core"; import { Type } from "@sinclair/typebox"; import * as Diff from "diff"; import type { Executor } from "../sandbox.js"; diff --git a/packages/mom/src/tools/index.ts b/packages/mom/src/tools/index.ts index 607e2e83..ff21ad0a 100644 --- a/packages/mom/src/tools/index.ts +++ b/packages/mom/src/tools/index.ts @@ -1,4 +1,4 @@ -import type { AgentTool } from "@mariozechner/pi-ai"; +import type { AgentTool } from "@mariozechner/pi-agent-core"; import type { Executor } from "../sandbox.js"; import { attachTool } from "./attach.js"; import { createBashTool } from "./bash.js"; diff --git a/packages/mom/src/tools/read.ts b/packages/mom/src/tools/read.ts index db36d615..4f284d70 100644 --- a/packages/mom/src/tools/read.ts +++ b/packages/mom/src/tools/read.ts @@ -1,4 +1,5 @@ -import type { AgentTool, ImageContent, TextContent } from "@mariozechner/pi-ai"; +import type { AgentTool } from "@mariozechner/pi-agent-core"; +import type { ImageContent, TextContent } from "@mariozechner/pi-ai"; import { Type } from "@sinclair/typebox"; import { extname } from "path"; import type { Executor } from "../sandbox.js"; diff --git a/packages/mom/src/tools/write.ts b/packages/mom/src/tools/write.ts index 22bdb1e5..ebd0735b 100644 --- a/packages/mom/src/tools/write.ts +++ b/packages/mom/src/tools/write.ts @@ -1,4 +1,4 @@ -import type { AgentTool } from "@mariozechner/pi-ai"; +import type { AgentTool } from "@mariozechner/pi-agent-core"; import { Type } from "@sinclair/typebox"; import type { Executor } from "../sandbox.js"; diff --git a/packages/pods/src/cli.ts b/packages/pods/src/cli.ts index 150e5dd2..e6a25f55 100644 --- a/packages/pods/src/cli.ts +++ b/packages/pods/src/cli.ts @@ -4,7 +4,7 @@ import { spawn } from "child_process"; import { readFileSync } from "fs"; import { dirname, join } 
from "path"; import { fileURLToPath } from "url"; -import { listModels, startModel, stopModel, viewLogs } from "./commands/models.js"; +import { listModels, showKnownModels, startModel, stopAllModels, stopModel, viewLogs } from "./commands/models.js"; import { listPods, removePodCommand, setupPod, switchActivePod } from "./commands/pods.js"; import { promptModel } from "./commands/prompt.js"; import { getActivePod, loadConfig } from "./config.js"; @@ -241,7 +241,6 @@ try { const modelId = args[1]; if (!modelId) { // Show available models - const { showKnownModels } = await import("./commands/models.js"); await showKnownModels(); process.exit(0); } @@ -306,7 +305,6 @@ try { const name = args[1]; if (!name) { // Stop all models on the active pod - const { stopAllModels } = await import("./commands/models.js"); await stopAllModels({ pod: podOverride }); } else { await stopModel(name, { pod: podOverride }); diff --git a/packages/proxy/README.md b/packages/proxy/README.md deleted file mode 100644 index 61b721f6..00000000 --- a/packages/proxy/README.md +++ /dev/null @@ -1,67 +0,0 @@ -# @mariozechner/pi-proxy - -CORS and authentication proxy for pi-ai. Enables browser clients to access OAuth-protected endpoints. - -## Usage - -### CORS Proxy - -Zero-config CORS proxy for development: - -```bash -# Run directly with tsx -npx tsx packages/proxy/src/cors-proxy.ts 3001 - -# Or use npm script -npm run dev -w @mariozechner/pi-proxy - -# Or install globally and use CLI -npm install -g @mariozechner/pi-proxy -pi-proxy 3001 -``` - -The proxy will forward requests to any URL: - -```javascript -// Instead of: -fetch('https://api.anthropic.com/v1/messages', { ... }) - -// Use: -fetch('http://localhost:3001?url=https://api.anthropic.com/v1/messages', { ... 
}) -``` - -### OAuth Integration - -For Anthropic OAuth tokens, configure your client to use the proxy: - -```typescript -import Anthropic from '@anthropic-ai/sdk'; - -const client = new Anthropic({ - apiKey: 'oauth_token_here', - baseURL: 'http://localhost:3001?url=https://api.anthropic.com' -}); -``` - -## Future Proxy Types - -- **BunnyCDN Edge Function**: Deploy as edge function -- **Managed Proxy**: Self-hosted with provider key management and credential auth -- **Cloudflare Worker**: Deploy as CF worker - -## Architecture - -The proxy: -1. Accepts requests with `?url=` query parameter -2. Forwards all headers (except `host`, `origin`) -3. Forwards request body for non-GET/HEAD requests -4. Returns response with CORS headers enabled -5. Strips CORS headers from upstream response - -## Development - -```bash -npm install -npm run build -npm run check -``` diff --git a/packages/proxy/package.json b/packages/proxy/package.json deleted file mode 100644 index b6d9c515..00000000 --- a/packages/proxy/package.json +++ /dev/null @@ -1,26 +0,0 @@ -{ - "name": "@mariozechner/pi-proxy", - "version": "0.30.2", - "type": "module", - "description": "CORS and authentication proxy for pi-ai", - "main": "dist/index.js", - "types": "dist/index.d.ts", - "bin": { - "pi-proxy": "dist/cli.js" - }, - "scripts": { - "clean": "rm -rf dist", - "build": "tsc", - "typecheck": "tsgo --noEmit", - "dev": "tsx src/cors-proxy.ts 3001" - }, - "dependencies": { - "@hono/node-server": "^1.14.0", - "hono": "^4.6.16" - }, - "devDependencies": { - "@types/node": "^22.10.5", - "tsx": "^4.19.2", - "typescript": "^5.7.3" - } -} diff --git a/packages/proxy/src/cli.ts b/packages/proxy/src/cli.ts deleted file mode 100644 index 29ac06b3..00000000 --- a/packages/proxy/src/cli.ts +++ /dev/null @@ -1,16 +0,0 @@ -#!/usr/bin/env node -import { spawn } from "node:child_process"; -import path from "node:path"; -import { fileURLToPath } from "node:url"; - -const __dirname = 
path.dirname(fileURLToPath(import.meta.url)); -const port = process.argv[2] || "3001"; - -// Run the CORS proxy -const child = spawn("node", [path.join(__dirname, "cors-proxy.js"), port], { - stdio: "inherit", -}); - -child.on("exit", (code) => { - process.exit(code || 0); -}); diff --git a/packages/proxy/src/cors-proxy.ts b/packages/proxy/src/cors-proxy.ts deleted file mode 100644 index 9fbeb387..00000000 --- a/packages/proxy/src/cors-proxy.ts +++ /dev/null @@ -1,73 +0,0 @@ -#!/usr/bin/env node -import { serve } from "@hono/node-server"; -import { Hono } from "hono"; -import { cors } from "hono/cors"; - -export function createCorsProxy() { - const app = new Hono(); - - // Enable CORS for all origins - app.use("*", cors()); - - // Proxy all requests - app.all("*", async (c) => { - const url = new URL(c.req.url); - const targetUrl = url.searchParams.get("url"); - - if (!targetUrl) { - return c.json({ error: "Missing 'url' query parameter" }, 400); - } - - try { - // Forward the request - const headers = new Headers(); - c.req.raw.headers.forEach((value, key) => { - // Skip host and origin headers - if (key.toLowerCase() !== "host" && key.toLowerCase() !== "origin") { - headers.set(key, value); - } - }); - - const response = await fetch(targetUrl, { - method: c.req.method, - headers, - body: c.req.method !== "GET" && c.req.method !== "HEAD" ? await c.req.raw.clone().arrayBuffer() : undefined, - }); - - // Forward response headers - const responseHeaders = new Headers(); - response.headers.forEach((value, key) => { - // Skip CORS headers (we handle them) - if (!key.toLowerCase().startsWith("access-control-")) { - responseHeaders.set(key, value); - } - }); - - // Return proxied response - return new Response(response.body, { - status: response.status, - statusText: response.statusText, - headers: responseHeaders, - }); - } catch (error) { - console.error("Proxy error:", error); - return c.json({ error: error instanceof Error ? 
error.message : "Proxy request failed" }, 502); - } - }); - - return app; -} - -// CLI entry point -if (import.meta.url === `file://${process.argv[1]}`) { - const app = createCorsProxy(); - const port = Number.parseInt(process.argv[2] || "3001", 10); - - console.log(`🔌 CORS proxy running on http://localhost:${port}`); - console.log(`Usage: http://localhost:${port}?url=`); - - serve({ - fetch: app.fetch, - port, - }); -} diff --git a/packages/proxy/src/index.ts b/packages/proxy/src/index.ts deleted file mode 100644 index 4e1bcaee..00000000 --- a/packages/proxy/src/index.ts +++ /dev/null @@ -1 +0,0 @@ -export { createCorsProxy } from "./cors-proxy.js"; diff --git a/packages/proxy/tsconfig.json b/packages/proxy/tsconfig.json deleted file mode 100644 index 3502c876..00000000 --- a/packages/proxy/tsconfig.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "extends": "../../tsconfig.base.json", - "compilerOptions": { - "outDir": "dist", - "rootDir": "src" - }, - "include": ["src/**/*"] -} diff --git a/packages/tui/CHANGELOG.md b/packages/tui/CHANGELOG.md index 298f5c39..e1c04859 100644 --- a/packages/tui/CHANGELOG.md +++ b/packages/tui/CHANGELOG.md @@ -2,6 +2,10 @@ ## [Unreleased] +### Fixed + +- Markdown component now renders HTML tags as plain text instead of silently dropping them ([#359](https://github.com/badlogic/pi-mono/issues/359)) + ## [0.29.0] - 2025-12-25 ### Added diff --git a/packages/tui/src/components/markdown.ts b/packages/tui/src/components/markdown.ts index c43f8f2f..857af283 100644 --- a/packages/tui/src/components/markdown.ts +++ b/packages/tui/src/components/markdown.ts @@ -317,7 +317,10 @@ export class Markdown implements Component { break; case "html": - // Skip HTML for terminal output + // Render HTML as plain text (escaped for terminal) + if ("raw" in token && typeof token.raw === "string") { + lines.push(this.applyDefaultStyle(token.raw.trim())); + } break; case "space": @@ -394,6 +397,13 @@ export class Markdown implements Component { break; } + case 
"html": + // Render inline HTML as plain text + if ("raw" in token && typeof token.raw === "string") { + result += this.applyDefaultStyle(token.raw); + } + break; + default: // Handle any other inline token types as plain text if ("text" in token && typeof token.text === "string") { diff --git a/packages/tui/src/index.ts b/packages/tui/src/index.ts index 71f0eca3..4d4d895c 100644 --- a/packages/tui/src/index.ts +++ b/packages/tui/src/index.ts @@ -50,6 +50,8 @@ export { isEnter, isEscape, isHome, + isShiftCtrlD, + isShiftCtrlO, isShiftCtrlP, isShiftEnter, isShiftTab, diff --git a/packages/tui/src/keys.ts b/packages/tui/src/keys.ts index edd2c06d..da8cf699 100644 --- a/packages/tui/src/keys.ts +++ b/packages/tui/src/keys.ts @@ -304,6 +304,14 @@ export function isCtrlO(data: string): boolean { return data === RAW.CTRL_O || data === Keys.CTRL_O || matchesKittySequence(data, CODEPOINTS.o, MODIFIERS.ctrl); } +/** + * Check if input matches Shift+Ctrl+O (Kitty protocol only). + * Ignores lock key bits. + */ +export function isShiftCtrlO(data: string): boolean { + return matchesKittySequence(data, CODEPOINTS.o, MODIFIERS.shift + MODIFIERS.ctrl); +} + /** * Check if input matches Ctrl+P (raw byte or Kitty protocol). * Ignores lock key bits. @@ -320,6 +328,14 @@ export function isShiftCtrlP(data: string): boolean { return matchesKittySequence(data, CODEPOINTS.p, MODIFIERS.shift + MODIFIERS.ctrl); } +/** + * Check if input matches Shift+Ctrl+D (Kitty protocol only, for debug). + * Ignores lock key bits. + */ +export function isShiftCtrlD(data: string): boolean { + return matchesKittySequence(data, CODEPOINTS.d, MODIFIERS.shift + MODIFIERS.ctrl); +} + /** * Check if input matches Ctrl+T (raw byte or Kitty protocol). * Ignores lock key bits. 
diff --git a/packages/tui/src/tui.ts b/packages/tui/src/tui.ts index 1baf5e4d..583f099e 100644 --- a/packages/tui/src/tui.ts +++ b/packages/tui/src/tui.ts @@ -5,6 +5,7 @@ import * as fs from "node:fs"; import * as os from "node:os"; import * as path from "node:path"; +import { isShiftCtrlD } from "./keys.js"; import type { Terminal } from "./terminal.js"; import { getCapabilities, setCellDimensions } from "./terminal-image.js"; import { visibleWidth } from "./utils.js"; @@ -78,6 +79,9 @@ export class TUI extends Container { private previousLines: string[] = []; private previousWidth = 0; private focusedComponent: Component | null = null; + + /** Global callback for debug key (Shift+Ctrl+D). Called before input is forwarded to focused component. */ + public onDebug?: () => void; private renderRequested = false; private cursorRow = 0; // Track where cursor is (0-indexed, relative to our first line) private inputBuffer = ""; // Buffer for parsing terminal responses @@ -141,6 +145,12 @@ export class TUI extends Container { data = filtered; } + // Global debug key handler (Shift+Ctrl+D) + if (isShiftCtrlD(data) && this.onDebug) { + this.onDebug(); + return; + } + // Pass input to focused component (including Ctrl+C) // The focused component can decide how to handle Ctrl+C if (this.focusedComponent?.handleInput) { diff --git a/packages/web-ui/CHANGELOG.md b/packages/web-ui/CHANGELOG.md new file mode 100644 index 00000000..517fdef1 --- /dev/null +++ b/packages/web-ui/CHANGELOG.md @@ -0,0 +1,83 @@ +# Changelog + +## [Unreleased] + +### Breaking Changes + +- **Agent class moved to `@mariozechner/pi-agent-core`**: The `Agent` class, `AgentState`, and related types are no longer exported from this package. Import them from `@mariozechner/pi-agent-core` instead. + +- **Transport abstraction removed**: `ProviderTransport`, `AppTransport`, `AgentTransport` interface, and related types have been removed. The `Agent` class now uses `streamFn` for custom streaming. 
+ +- **`AppMessage` renamed to `AgentMessage`**: Now imported from `@mariozechner/pi-agent-core`. Custom message types use declaration merging on `CustomAgentMessages` interface. + +- **`UserMessageWithAttachments` is now a custom message type**: Has `role: "user-with-attachments"` instead of `role: "user"`. Use `isUserMessageWithAttachments()` type guard. + +- **`CustomMessages` interface removed**: Use declaration merging on `CustomAgentMessages` from `@mariozechner/pi-agent-core` instead. + +### Added + +- **`defaultConvertToLlm`**: Default message transformer that handles `UserMessageWithAttachments` and `ArtifactMessage`. Apps can extend this for custom message types. + +- **`convertAttachments`**: Utility to convert `Attachment[]` to LLM content blocks (images and extracted document text). + +- **`isUserMessageWithAttachments` / `isArtifactMessage`**: Type guard functions for custom message types. + +- **`createStreamFn`**: Creates a stream function with CORS proxy support. Reads proxy settings on each call for dynamic configuration. + +- **Default `streamFn` and `getApiKey`**: `AgentInterface` now sets sensible defaults if not provided: + - `streamFn`: Uses `createStreamFn` with proxy settings from storage + - `getApiKey`: Reads from `providerKeys` storage + +- **Proxy utilities exported**: `applyProxyIfNeeded`, `shouldUseProxyForProvider`, `isCorsError`, `createStreamFn` + +### Removed + +- `Agent` class (moved to `@mariozechner/pi-agent-core`) +- `ProviderTransport` class +- `AppTransport` class +- `AgentTransport` interface +- `AgentRunConfig` type +- `ProxyAssistantMessageEvent` type + +### Migration Guide + +**Before (0.30.x):** +```typescript +import { Agent, ProviderTransport, type AppMessage } from '@mariozechner/pi-web-ui'; + +const agent = new Agent({ + transport: new ProviderTransport(), + messageTransformer: (messages: AppMessage[]) => messages.filter(...) 
+}); +``` + +**After:** +```typescript +import { Agent, type AgentMessage } from '@mariozechner/pi-agent-core'; +import { defaultConvertToLlm } from '@mariozechner/pi-web-ui'; + +const agent = new Agent({ + convertToLlm: (messages: AgentMessage[]) => { + // Extend defaultConvertToLlm for custom types + return defaultConvertToLlm(messages); + } +}); +// AgentInterface will set streamFn and getApiKey defaults automatically +``` + +**Custom message types:** +```typescript +// Before: declaration merging on CustomMessages +declare module "@mariozechner/pi-web-ui" { + interface CustomMessages { + "my-message": MyMessage; + } +} + +// After: declaration merging on CustomAgentMessages +declare module "@mariozechner/pi-agent-core" { + interface CustomAgentMessages { + "my-message": MyMessage; + } +} +``` diff --git a/packages/web-ui/example/src/custom-messages.ts b/packages/web-ui/example/src/custom-messages.ts index 8b02b13f..d5700b60 100644 --- a/packages/web-ui/example/src/custom-messages.ts +++ b/packages/web-ui/example/src/custom-messages.ts @@ -1,7 +1,7 @@ import { Alert } from "@mariozechner/mini-lit/dist/Alert.js"; import type { Message } from "@mariozechner/pi-ai"; -import type { AppMessage, MessageRenderer } from "@mariozechner/pi-web-ui"; -import { registerMessageRenderer } from "@mariozechner/pi-web-ui"; +import type { AgentMessage, MessageRenderer } from "@mariozechner/pi-web-ui"; +import { defaultConvertToLlm, registerMessageRenderer } from "@mariozechner/pi-web-ui"; import { html } from "lit"; // ============================================================================ @@ -16,9 +16,10 @@ export interface SystemNotificationMessage { timestamp: string; } -// Extend CustomMessages interface via declaration merging -declare module "@mariozechner/pi-web-ui" { - interface CustomMessages { +// Extend CustomAgentMessages interface via declaration merging +// This must target pi-agent-core where CustomAgentMessages is defined +declare module 
"@mariozechner/pi-agent-core" { + interface CustomAgentMessages { "system-notification": SystemNotificationMessage; } } @@ -74,36 +75,25 @@ export function createSystemNotification( // 5. CUSTOM MESSAGE TRANSFORMER // ============================================================================ -// Transform custom messages to user messages with tags so LLM can see them -export function customMessageTransformer(messages: AppMessage[]): Message[] { - return messages - .filter((m) => { - // Filter out artifact messages - they're for session reconstruction only - if (m.role === "artifact") { - return false; - } +/** + * Custom message transformer that extends defaultConvertToLlm. + * Handles system-notification messages by converting them to user messages. + */ +export function customConvertToLlm(messages: AgentMessage[]): Message[] { + // First, handle our custom system-notification type + const processed = messages.map((m): AgentMessage => { + if (m.role === "system-notification") { + const notification = m as SystemNotificationMessage; + // Convert to user message with tags + return { + role: "user", + content: `${notification.message}`, + timestamp: Date.now(), + }; + } + return m; + }); - // Keep LLM-compatible messages + custom messages - return ( - m.role === "user" || m.role === "assistant" || m.role === "toolResult" || m.role === "system-notification" - ); - }) - .map((m) => { - // Transform system notifications to user messages - if (m.role === "system-notification") { - const notification = m as SystemNotificationMessage; - return { - role: "user", - content: `${notification.message}`, - } as Message; - } - - // Strip attachments from user messages - if (m.role === "user") { - const { attachments: _, ...rest } = m as any; - return rest as Message; - } - - return m as Message; - }); + // Then use defaultConvertToLlm for standard handling + return defaultConvertToLlm(processed); } diff --git a/packages/web-ui/example/src/main.ts 
b/packages/web-ui/example/src/main.ts index ef49faba..4a93f409 100644 --- a/packages/web-ui/example/src/main.ts +++ b/packages/web-ui/example/src/main.ts @@ -1,10 +1,9 @@ import "@mariozechner/mini-lit/dist/ThemeToggle.js"; +import { Agent, type AgentMessage } from "@mariozechner/pi-agent-core"; import { getModel } from "@mariozechner/pi-ai"; import { - Agent, type AgentState, ApiKeyPromptDialog, - type AppMessage, AppStorage, ChatPanel, CustomProvidersStore, @@ -13,7 +12,6 @@ import { // PersistentStorageDialog, // TODO: Fix - currently broken ProviderKeysStore, ProvidersModelsTab, - ProviderTransport, ProxyTab, SessionListDialog, SessionsStore, @@ -27,11 +25,7 @@ import "./app.css"; import { icon } from "@mariozechner/mini-lit"; import { Button } from "@mariozechner/mini-lit/dist/Button.js"; import { Input } from "@mariozechner/mini-lit/dist/Input.js"; -import { - createSystemNotification, - customMessageTransformer, - registerCustomMessageRenderers, -} from "./custom-messages.js"; +import { createSystemNotification, customConvertToLlm, registerCustomMessageRenderers } from "./custom-messages.js"; // Register custom message renderers registerCustomMessageRenderers(); @@ -75,9 +69,9 @@ let agent: Agent; let chatPanel: ChatPanel; let agentUnsubscribe: (() => void) | undefined; -const generateTitle = (messages: AppMessage[]): string => { - const firstUserMsg = messages.find((m) => m.role === "user"); - if (!firstUserMsg || firstUserMsg.role !== "user") return ""; +const generateTitle = (messages: AgentMessage[]): string => { + const firstUserMsg = messages.find((m) => m.role === "user" || m.role === "user-with-attachments"); + if (!firstUserMsg || (firstUserMsg.role !== "user" && firstUserMsg.role !== "user-with-attachments")) return ""; let text = ""; const content = firstUserMsg.content; @@ -99,8 +93,8 @@ const generateTitle = (messages: AppMessage[]): string => { return text.length <= 50 ? 
text : `${text.substring(0, 47)}...`; }; -const shouldSaveSession = (messages: AppMessage[]): boolean => { - const hasUserMsg = messages.some((m: any) => m.role === "user"); +const shouldSaveSession = (messages: AgentMessage[]): boolean => { + const hasUserMsg = messages.some((m: any) => m.role === "user" || m.role === "user-with-attachments"); const hasAssistantMsg = messages.some((m: any) => m.role === "assistant"); return hasUserMsg && hasAssistantMsg; }; @@ -166,8 +160,6 @@ const createAgent = async (initialState?: Partial) => { agentUnsubscribe(); } - const transport = new ProviderTransport(); - agent = new Agent({ initialState: initialState || { systemPrompt: `You are a helpful AI assistant with access to various tools. @@ -182,9 +174,8 @@ Feel free to use these tools when needed to provide accurate and helpful respons messages: [], tools: [], }, - transport, - // Custom transformer: convert system notifications to user messages with tags - messageTransformer: customMessageTransformer, + // Custom transformer: convert custom messages to LLM-compatible format + convertToLlm: customConvertToLlm, }); agentUnsubscribe = agent.subscribe((event: any) => { @@ -353,9 +344,9 @@ const renderApp = () => { size: "sm", children: icon(Bell, "sm"), onClick: () => { - // Demo: Inject custom message + // Demo: Inject custom message (will appear on next agent run) if (agent) { - agent.appendMessage( + agent.queueMessage( createSystemNotification( "This is a custom message! 
It appears in the UI but is never sent to the LLM.", ), diff --git a/packages/web-ui/example/src/test-sessions.ts b/packages/web-ui/example/src/test-sessions.ts deleted file mode 100644 index bb4e008e..00000000 --- a/packages/web-ui/example/src/test-sessions.ts +++ /dev/null @@ -1,104 +0,0 @@ -const simpleHtml = { - "systemPrompt": "You are a helpful AI assistant with access to various tools.\n\nAvailable tools:\n- Calculator: Perform mathematical calculations\n- Current Time: Get the current date and time in any timezone\n- Artifacts: Create interactive HTML, SVG, Markdown, and text artifacts\n\nFeel free to use these tools when needed to provide accurate and helpful responses.", - "model": { - "id": "claude-sonnet-4-5-20250929", - "name": "Claude Sonnet 4.5", - "api": "anthropic-messages", - "provider": "anthropic", - "baseUrl": "https://api.anthropic.com", - "reasoning": true, - "input": [ - "text", - "image" - ], - "cost": { - "input": 3, - "output": 15, - "cacheRead": 0.3, - "cacheWrite": 3.75 - }, - "contextWindow": 200000, - "maxTokens": 64000 - }, - "messages": [ - { - "role": "user", - "content": [ - { - "type": "text", - "text": "simple html" - } - ] - }, - { - "role": "assistant", - "content": [ - { - "type": "text", - "text": "I'll create a simple HTML page for you!" - }, - { - "type": "toolCall", - "id": "toolu_01JXaq9rMAUt1goJ5tMQXgLh", - "name": "artifacts", - "arguments": { - "command": "create", - "filename": "index.html", - "title": "Simple HTML Page", - "content": "\n\n\n \n \n Simple HTML Page\n \n\n\n
\n

Welcome to My Simple Page

\n
\n \n
\n

Hello!

\n

This is a simple HTML page with some basic styling.

\n \n

Features:

\n
    \n
  • Clean design
  • \n
  • Responsive layout
  • \n
  • Interactive button
  • \n
\n \n \n

\n
\n \n \n\n" - } - } - ], - "api": "anthropic-messages", - "provider": "anthropic", - "model": "claude-sonnet-4-5-20250929", - "usage": { - "input": 2108, - "output": 666, - "cacheRead": 0, - "cacheWrite": 0, - "cost": { - "input": 0.006324, - "output": 0.00999, - "cacheRead": 0, - "cacheWrite": 0, - "total": 0.016314000000000002 - } - }, - "stopReason": "toolUse" - }, - { - "role": "toolResult", - "toolCallId": "toolu_01JXaq9rMAUt1goJ5tMQXgLh", - "toolName": "artifacts", - "output": "Created file index.html\n\nExecution timed out. Partial logs:\nKeine Logs für index.html", - "isError": false - }, - { - "role": "assistant", - "content": [ - { - "type": "text", - "text": "I've created a simple HTML page for you! It includes:\n\n- **Clean header** with a green background\n- **Main content area** with a white card-style design\n- **A list** showing some features\n- **An interactive button** that displays a message when clicked\n- **Responsive styling** that looks good on different screen sizes\n\nThe page has a light gray background and uses simple, clean styling. Try clicking the button to see it in action! \n\nWould you like me to modify anything or add more features?" 
- } - ], - "api": "anthropic-messages", - "provider": "anthropic", - "model": "claude-sonnet-4-5-20250929", - "usage": { - "input": 2811, - "output": 115, - "cacheRead": 0, - "cacheWrite": 0, - "cost": { - "input": 0.008433, - "output": 0.001725, - "cacheRead": 0, - "cacheWrite": 0, - "total": 0.010158 - } - }, - "stopReason": "stop" - } - ] -}; \ No newline at end of file diff --git a/packages/web-ui/example/tsconfig.json b/packages/web-ui/example/tsconfig.json index 340e2439..e095a279 100644 --- a/packages/web-ui/example/tsconfig.json +++ b/packages/web-ui/example/tsconfig.json @@ -6,6 +6,7 @@ "moduleResolution": "bundler", "paths": { "*": ["./*"], + "@mariozechner/pi-agent-core": ["../../agent/dist/index.d.ts"], "@mariozechner/pi-ai": ["../../ai/dist/index.d.ts"], "@mariozechner/pi-tui": ["../../tui/dist/index.d.ts"], "@mariozechner/pi-web-ui": ["../dist/index.d.ts"] diff --git a/packages/web-ui/package.json b/packages/web-ui/package.json index 234271c6..c732b97a 100644 --- a/packages/web-ui/package.json +++ b/packages/web-ui/package.json @@ -11,7 +11,7 @@ }, "scripts": { "clean": "rm -rf dist", - "build": "tsc -p tsconfig.build.json && tailwindcss -i ./src/app.css -o ./dist/app.css --minify", + "build": "tsgo -p tsconfig.build.json && tailwindcss -i ./src/app.css -o ./dist/app.css --minify", "dev": "concurrently --names \"build,example\" --prefix-colors \"cyan,green\" \"tsc -p tsconfig.build.json --watch --preserveWatchOutput\" \"tailwindcss -i ./src/app.css -o ./dist/app.css --watch\" \"npm run dev --prefix example\"", "dev:tsc": "concurrently --names \"build\" --prefix-colors \"cyan\" \"tsc -p tsconfig.build.json --watch --preserveWatchOutput\" \"tailwindcss -i ./src/app.css -o ./dist/app.css --watch\"", "check": "biome check --write . && tsc --noEmit && cd example && biome check --write . 
&& tsc --noEmit" diff --git a/packages/web-ui/src/ChatPanel.ts b/packages/web-ui/src/ChatPanel.ts index ed2d152f..e055c776 100644 --- a/packages/web-ui/src/ChatPanel.ts +++ b/packages/web-ui/src/ChatPanel.ts @@ -1,9 +1,8 @@ import { Badge } from "@mariozechner/mini-lit/dist/Badge.js"; import { html, LitElement } from "lit"; import { customElement, state } from "lit/decorators.js"; -import type { Agent } from "./agent/agent.js"; import "./components/AgentInterface.js"; -import type { AgentTool } from "@mariozechner/pi-ai"; +import type { Agent, AgentTool } from "@mariozechner/pi-agent-core"; import type { AgentInterface } from "./components/AgentInterface.js"; import { ArtifactsRuntimeProvider } from "./components/sandbox/ArtifactsRuntimeProvider.js"; import { AttachmentsRuntimeProvider } from "./components/sandbox/AttachmentsRuntimeProvider.js"; @@ -95,7 +94,7 @@ export class ChatPanel extends LitElement { const runtimeProvidersFactory = () => { const attachments: Attachment[] = []; for (const message of this.agent!.state.messages) { - if (message.role === "user") { + if (message.role === "user-with-attachments") { message.attachments?.forEach((a) => { attachments.push(a); }); diff --git a/packages/web-ui/src/agent/agent.ts b/packages/web-ui/src/agent/agent.ts deleted file mode 100644 index f751e2e1..00000000 --- a/packages/web-ui/src/agent/agent.ts +++ /dev/null @@ -1,341 +0,0 @@ -import type { Context, QueuedMessage } from "@mariozechner/pi-ai"; -import { - type AgentTool, - type AssistantMessage as AssistantMessageType, - getModel, - type ImageContent, - type Message, - type Model, - type TextContent, -} from "@mariozechner/pi-ai"; -import type { AppMessage } from "../components/Messages.js"; -import type { Attachment } from "../utils/attachment-utils.js"; -import type { AgentRunConfig, AgentTransport } from "./transports/types.js"; -import type { DebugLogEntry } from "./types.js"; - -// Default transformer: Keep only LLM-compatible messages, strip app-specific 
fields -function defaultMessageTransformer(messages: AppMessage[]): Message[] { - return messages - .filter((m) => { - // Only keep standard LLM message roles - return m.role === "user" || m.role === "assistant" || m.role === "toolResult"; - }) - .map((m) => { - if (m.role === "user") { - // Strip attachments field (app-specific) - - // biome-ignore lint/correctness/noUnusedVariables: fine here - const { attachments, ...rest } = m as any; - return rest as Message; - } - return m as Message; - }); -} - -export type ThinkingLevel = "off" | "minimal" | "low" | "medium" | "high"; - -export interface AgentState { - systemPrompt: string; - model: Model; - thinkingLevel: ThinkingLevel; - tools: AgentTool[]; - messages: AppMessage[]; - isStreaming: boolean; - streamMessage: Message | null; - pendingToolCalls: Set; - error?: string; -} - -export type AgentEvent = - | { type: "state-update"; state: AgentState } - | { type: "error-no-model" } - | { type: "error-no-api-key"; provider: string } - | { type: "started" } - | { type: "completed" }; - -export interface AgentOptions { - initialState?: Partial; - debugListener?: (entry: DebugLogEntry) => void; - transport: AgentTransport; - // Transform app messages to LLM-compatible messages before sending to transport - messageTransformer?: (messages: AppMessage[]) => Message[] | Promise; -} - -export class Agent { - private _state: AgentState = { - systemPrompt: "", - model: getModel("google", "gemini-2.5-flash-lite-preview-06-17"), - thinkingLevel: "off", - tools: [], - messages: [], - isStreaming: false, - streamMessage: null, - pendingToolCalls: new Set(), - error: undefined, - }; - private listeners = new Set<(e: AgentEvent) => void>(); - private abortController?: AbortController; - private transport: AgentTransport; - private debugListener?: (entry: DebugLogEntry) => void; - private messageTransformer: (messages: AppMessage[]) => Message[] | Promise; - private messageQueue: Array> = []; - - constructor(opts: AgentOptions) { - 
this._state = { ...this._state, ...opts.initialState }; - this.debugListener = opts.debugListener; - this.transport = opts.transport; - this.messageTransformer = opts.messageTransformer || defaultMessageTransformer; - } - - get state(): AgentState { - return this._state; - } - - subscribe(fn: (e: AgentEvent) => void): () => void { - this.listeners.add(fn); - fn({ type: "state-update", state: this._state }); - return () => this.listeners.delete(fn); - } - - // Mutators - setSystemPrompt(v: string) { - this.patch({ systemPrompt: v }); - } - setModel(m: Model) { - this.patch({ model: m }); - } - setThinkingLevel(l: ThinkingLevel) { - this.patch({ thinkingLevel: l }); - } - setTools(t: AgentTool[]) { - this.patch({ tools: t }); - } - replaceMessages(ms: AppMessage[]) { - this.patch({ messages: ms.slice() }); - } - appendMessage(m: AppMessage) { - this.patch({ messages: [...this._state.messages, m] }); - } - async queueMessage(m: AppMessage) { - // Transform message and queue it for injection at next turn - const transformed = await this.messageTransformer([m]); - this.messageQueue.push({ - original: m, - llm: transformed[0], // undefined if filtered out - }); - } - clearMessages() { - this.patch({ messages: [] }); - } - - abort() { - this.abortController?.abort(); - } - - private logState(message: string) { - const { systemPrompt, model, messages } = this._state; - console.log(message, { systemPrompt, model, messages }); - } - - async prompt(input: string, attachments?: Attachment[]) { - const model = this._state.model; - if (!model) { - this.emit({ type: "error-no-model" }); - return; - } - - // Build user message with attachments - const content: Array = [{ type: "text", text: input }]; - if (attachments?.length) { - for (const a of attachments) { - if (a.type === "image") { - content.push({ type: "image", data: a.content, mimeType: a.mimeType }); - } else if (a.type === "document" && a.extractedText) { - content.push({ - type: "text", - text: `\n\n[Document: 
${a.fileName}]\n${a.extractedText}`, - isDocument: true, - } as TextContent); - } - } - } - - const userMessage: AppMessage = { - role: "user", - content, - attachments: attachments?.length ? attachments : undefined, - timestamp: Date.now(), - }; - - this.abortController = new AbortController(); - this.patch({ isStreaming: true, streamMessage: null, error: undefined }); - this.emit({ type: "started" }); - - const reasoning = - this._state.thinkingLevel === "off" - ? undefined - : this._state.thinkingLevel === "minimal" - ? "low" - : this._state.thinkingLevel; - const cfg: AgentRunConfig = { - systemPrompt: this._state.systemPrompt, - tools: this._state.tools, - model, - reasoning, - getQueuedMessages: async () => { - // Return queued messages (they'll be added to state via message_end event) - const queued = this.messageQueue.slice(); - this.messageQueue = []; - return queued as QueuedMessage[]; - }, - }; - - try { - let partial: Message | null = null; - let turnDebug: DebugLogEntry | null = null; - let turnStart = 0; - - this.logState("prompt started, current state:"); - - // Transform app messages to LLM-compatible messages (initial set) - const llmMessages = await this.messageTransformer(this._state.messages); - - console.log("transformed messages:", llmMessages); - for await (const ev of this.transport.run( - llmMessages, - userMessage as Message, - cfg, - this.abortController.signal, - )) { - switch (ev.type) { - case "turn_start": { - turnStart = performance.now(); - // Build request context snapshot (use transformed messages) - const ctx: Context = { - systemPrompt: this._state.systemPrompt, - messages: [...llmMessages], - tools: this._state.tools, - }; - turnDebug = { - timestamp: new Date().toISOString(), - request: { - provider: cfg.model.provider, - model: cfg.model.id, - context: { ...ctx }, - }, - sseEvents: [], - }; - break; - } - case "message_start": - case "message_update": { - partial = ev.message; - // Collect SSE-like events for debug (drop 
heavy partial) - if (ev.type === "message_update" && ev.assistantMessageEvent && turnDebug) { - const copy: any = { ...ev.assistantMessageEvent }; - if (copy && "partial" in copy) delete copy.partial; - turnDebug.sseEvents.push(JSON.stringify(copy)); - if (!turnDebug.ttft) turnDebug.ttft = performance.now() - turnStart; - } - this.patch({ streamMessage: ev.message }); - break; - } - case "message_end": { - partial = null; - this.appendMessage(ev.message as AppMessage); - this.patch({ streamMessage: null }); - if (turnDebug) { - if (ev.message.role !== "assistant" && ev.message.role !== "toolResult") { - turnDebug.request.context.messages.push(ev.message); - } - if (ev.message.role === "assistant") turnDebug.response = ev.message as any; - } - break; - } - case "tool_execution_start": { - const s = new Set(this._state.pendingToolCalls); - s.add(ev.toolCallId); - this.patch({ pendingToolCalls: s }); - break; - } - case "tool_execution_end": { - const s = new Set(this._state.pendingToolCalls); - s.delete(ev.toolCallId); - this.patch({ pendingToolCalls: s }); - break; - } - case "turn_end": { - // finalize current turn - if (turnDebug) { - turnDebug.totalTime = performance.now() - turnStart; - this.debugListener?.(turnDebug); - turnDebug = null; - } - break; - } - case "agent_end": { - this.patch({ streamMessage: null }); - break; - } - } - } - - if (partial && partial.role === "assistant" && partial.content.length > 0) { - const onlyEmpty = !partial.content.some( - (c) => - (c.type === "thinking" && c.thinking.trim().length > 0) || - (c.type === "text" && c.text.trim().length > 0) || - (c.type === "toolCall" && c.name.trim().length > 0), - ); - if (!onlyEmpty) { - this.appendMessage(partial as AppMessage); - } else { - if (this.abortController?.signal.aborted) { - throw new Error("Request was aborted"); - } - } - } - } catch (err: any) { - if (String(err?.message || err) === "no-api-key") { - this.emit({ type: "error-no-api-key", provider: model.provider }); - } else 
{ - const msg: AssistantMessageType = { - role: "assistant", - content: [{ type: "text", text: "" }], - api: model.api, - provider: model.provider, - model: model.id, - usage: { - input: 0, - output: 0, - cacheRead: 0, - cacheWrite: 0, - totalTokens: 0, - cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, - }, - stopReason: this.abortController?.signal.aborted ? "aborted" : "error", - errorMessage: err?.message || String(err), - timestamp: Date.now(), - }; - this.appendMessage(msg as AppMessage); - this.patch({ error: err?.message || String(err) }); - } - } finally { - this.patch({ isStreaming: false, streamMessage: null, pendingToolCalls: new Set() }); - this.abortController = undefined; - this.emit({ type: "completed" }); - } - this.logState("final state:"); - } - - private patch(p: Partial): void { - this._state = { ...this._state, ...p }; - this.emit({ type: "state-update", state: this._state }); - } - - private emit(e: AgentEvent) { - for (const listener of this.listeners) { - listener(e); - } - } -} diff --git a/packages/web-ui/src/agent/transports/AppTransport.ts b/packages/web-ui/src/agent/transports/AppTransport.ts deleted file mode 100644 index 90525a7b..00000000 --- a/packages/web-ui/src/agent/transports/AppTransport.ts +++ /dev/null @@ -1,371 +0,0 @@ -import type { - AgentContext, - AgentLoopConfig, - Api, - AssistantMessage, - AssistantMessageEvent, - Context, - Message, - Model, - SimpleStreamOptions, - ToolCall, - UserMessage, -} from "@mariozechner/pi-ai"; -import { agentLoop, agentLoopContinue } from "@mariozechner/pi-ai"; -import { AssistantMessageEventStream } from "@mariozechner/pi-ai/dist/utils/event-stream.js"; -import { parseStreamingJson } from "@mariozechner/pi-ai/dist/utils/json-parse.js"; -import { clearAuthToken, getAuthToken } from "../../utils/auth-token.js"; -import { i18n } from "../../utils/i18n.js"; -import type { ProxyAssistantMessageEvent } from "./proxy-types.js"; -import type { AgentRunConfig, AgentTransport 
} from "./types.js"; - -/** - * Stream function that proxies through a server instead of calling providers directly. - * The server strips the partial field from delta events to reduce bandwidth. - * We reconstruct the partial message client-side. - */ -function streamSimpleProxy( - model: Model, - context: Context, - options: SimpleStreamOptions & { authToken: string }, - proxyUrl: string, -): AssistantMessageEventStream { - const stream = new AssistantMessageEventStream(); - - (async () => { - // Initialize the partial message that we'll build up from events - const partial: AssistantMessage = { - role: "assistant", - stopReason: "stop", - content: [], - api: model.api, - provider: model.provider, - model: model.id, - usage: { - input: 0, - output: 0, - cacheRead: 0, - cacheWrite: 0, - totalTokens: 0, - cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, - }, - timestamp: Date.now(), - }; - - let reader: ReadableStreamDefaultReader | undefined; - - // Set up abort handler to cancel the reader - const abortHandler = () => { - if (reader) { - reader.cancel("Request aborted by user").catch(() => {}); - } - }; - - if (options.signal) { - options.signal.addEventListener("abort", abortHandler); - } - - try { - const response = await fetch(`${proxyUrl}/api/stream`, { - method: "POST", - headers: { - Authorization: `Bearer ${options.authToken}`, - "Content-Type": "application/json", - }, - body: JSON.stringify({ - model, - context, - options: { - temperature: options.temperature, - maxTokens: options.maxTokens, - reasoning: options.reasoning, - // Don't send apiKey or signal - those are added server-side - }, - }), - signal: options.signal, - }); - - if (!response.ok) { - let errorMessage = `Proxy error: ${response.status} ${response.statusText}`; - try { - const errorData = await response.json(); - if (errorData.error) { - errorMessage = `Proxy error: ${errorData.error}`; - } - } catch { - // Couldn't parse error response, use default message - } - 
throw new Error(errorMessage); - } - - // Parse SSE stream - reader = response.body!.getReader(); - const decoder = new TextDecoder(); - let buffer = ""; - - while (true) { - const { done, value } = await reader.read(); - if (done) break; - - // Check if aborted after reading - if (options.signal?.aborted) { - throw new Error("Request aborted by user"); - } - - buffer += decoder.decode(value, { stream: true }); - const lines = buffer.split("\n"); - buffer = lines.pop() || ""; - - for (const line of lines) { - if (line.startsWith("data: ")) { - const data = line.slice(6).trim(); - if (data) { - const proxyEvent = JSON.parse(data) as ProxyAssistantMessageEvent; - let event: AssistantMessageEvent | undefined; - - // Handle different event types - // Server sends events with partial for non-delta events, - // and without partial for delta events - switch (proxyEvent.type) { - case "start": - event = { type: "start", partial }; - break; - - case "text_start": - partial.content[proxyEvent.contentIndex] = { - type: "text", - text: "", - }; - event = { type: "text_start", contentIndex: proxyEvent.contentIndex, partial }; - break; - - case "text_delta": { - const content = partial.content[proxyEvent.contentIndex]; - if (content?.type === "text") { - content.text += proxyEvent.delta; - event = { - type: "text_delta", - contentIndex: proxyEvent.contentIndex, - delta: proxyEvent.delta, - partial, - }; - } else { - throw new Error("Received text_delta for non-text content"); - } - break; - } - case "text_end": { - const content = partial.content[proxyEvent.contentIndex]; - if (content?.type === "text") { - content.textSignature = proxyEvent.contentSignature; - event = { - type: "text_end", - contentIndex: proxyEvent.contentIndex, - content: content.text, - partial, - }; - } else { - throw new Error("Received text_end for non-text content"); - } - break; - } - - case "thinking_start": - partial.content[proxyEvent.contentIndex] = { - type: "thinking", - thinking: "", - }; - event 
= { type: "thinking_start", contentIndex: proxyEvent.contentIndex, partial }; - break; - - case "thinking_delta": { - const content = partial.content[proxyEvent.contentIndex]; - if (content?.type === "thinking") { - content.thinking += proxyEvent.delta; - event = { - type: "thinking_delta", - contentIndex: proxyEvent.contentIndex, - delta: proxyEvent.delta, - partial, - }; - } else { - throw new Error("Received thinking_delta for non-thinking content"); - } - break; - } - - case "thinking_end": { - const content = partial.content[proxyEvent.contentIndex]; - if (content?.type === "thinking") { - content.thinkingSignature = proxyEvent.contentSignature; - event = { - type: "thinking_end", - contentIndex: proxyEvent.contentIndex, - content: content.thinking, - partial, - }; - } else { - throw new Error("Received thinking_end for non-thinking content"); - } - break; - } - - case "toolcall_start": - partial.content[proxyEvent.contentIndex] = { - type: "toolCall", - id: proxyEvent.id, - name: proxyEvent.toolName, - arguments: {}, - partialJson: "", - } satisfies ToolCall & { partialJson: string } as ToolCall; - event = { type: "toolcall_start", contentIndex: proxyEvent.contentIndex, partial }; - break; - - case "toolcall_delta": { - const content = partial.content[proxyEvent.contentIndex]; - if (content?.type === "toolCall") { - (content as any).partialJson += proxyEvent.delta; - content.arguments = parseStreamingJson((content as any).partialJson) || {}; - event = { - type: "toolcall_delta", - contentIndex: proxyEvent.contentIndex, - delta: proxyEvent.delta, - partial, - }; - partial.content[proxyEvent.contentIndex] = { ...content }; // Trigger reactivity - } else { - throw new Error("Received toolcall_delta for non-toolCall content"); - } - break; - } - - case "toolcall_end": { - const content = partial.content[proxyEvent.contentIndex]; - if (content?.type === "toolCall") { - delete (content as any).partialJson; - event = { - type: "toolcall_end", - contentIndex: 
proxyEvent.contentIndex, - toolCall: content, - partial, - }; - } - break; - } - - case "done": - partial.stopReason = proxyEvent.reason; - partial.usage = proxyEvent.usage; - event = { type: "done", reason: proxyEvent.reason, message: partial }; - break; - - case "error": - partial.stopReason = proxyEvent.reason; - partial.errorMessage = proxyEvent.errorMessage; - partial.usage = proxyEvent.usage; - event = { type: "error", reason: proxyEvent.reason, error: partial }; - break; - - default: { - // Exhaustive check - const _exhaustiveCheck: never = proxyEvent; - console.warn(`Unhandled event type: ${(proxyEvent as any).type}`); - break; - } - } - - // Push the event to stream - if (event) { - stream.push(event); - } else { - throw new Error("Failed to create event from proxy event"); - } - } - } - } - } - - // Check if aborted after reading - if (options.signal?.aborted) { - throw new Error("Request aborted by user"); - } - - stream.end(); - } catch (error) { - const errorMessage = error instanceof Error ? error.message : String(error); - if (errorMessage.toLowerCase().includes("proxy") && errorMessage.includes("Unauthorized")) { - clearAuthToken(); - } - partial.stopReason = options.signal?.aborted ? "aborted" : "error"; - partial.errorMessage = errorMessage; - stream.push({ - type: "error", - reason: partial.stopReason, - error: partial, - } satisfies AssistantMessageEvent); - stream.end(); - } finally { - // Clean up abort handler - if (options.signal) { - options.signal.removeEventListener("abort", abortHandler); - } - } - })(); - - return stream; -} - -/** - * Transport that uses an app server with user authentication tokens. - * The server manages user accounts and proxies requests to LLM providers. 
- */ -export class AppTransport implements AgentTransport { - private readonly proxyUrl = "https://genai.mariozechner.at"; - - private async getStreamFn() { - const authToken = await getAuthToken(); - if (!authToken) { - throw new Error(i18n("Auth token is required for proxy transport")); - } - - return (model: Model, context: Context, options?: SimpleStreamOptions) => { - return streamSimpleProxy(model, context, { ...options, authToken }, this.proxyUrl); - }; - } - - private buildContext(messages: Message[], cfg: AgentRunConfig): AgentContext { - return { - systemPrompt: cfg.systemPrompt, - messages, - tools: cfg.tools, - }; - } - - private buildLoopConfig(cfg: AgentRunConfig): AgentLoopConfig { - return { - model: cfg.model, - reasoning: cfg.reasoning, - getQueuedMessages: cfg.getQueuedMessages, - }; - } - - async *run(messages: Message[], userMessage: Message, cfg: AgentRunConfig, signal?: AbortSignal) { - const streamFn = await this.getStreamFn(); - const context = this.buildContext(messages, cfg); - const pc = this.buildLoopConfig(cfg); - - for await (const ev of agentLoop(userMessage as unknown as UserMessage, context, pc, signal, streamFn as any)) { - yield ev; - } - } - - async *continue(messages: Message[], cfg: AgentRunConfig, signal?: AbortSignal) { - const streamFn = await this.getStreamFn(); - const context = this.buildContext(messages, cfg); - const pc = this.buildLoopConfig(cfg); - - for await (const ev of agentLoopContinue(context, pc, signal, streamFn as any)) { - yield ev; - } - } -} diff --git a/packages/web-ui/src/agent/transports/ProviderTransport.ts b/packages/web-ui/src/agent/transports/ProviderTransport.ts deleted file mode 100644 index b16991d1..00000000 --- a/packages/web-ui/src/agent/transports/ProviderTransport.ts +++ /dev/null @@ -1,71 +0,0 @@ -import { - type AgentContext, - type AgentLoopConfig, - agentLoop, - agentLoopContinue, - type Message, - type UserMessage, -} from "@mariozechner/pi-ai"; -import { getAppStorage } from 
"../../storage/app-storage.js"; -import { applyProxyIfNeeded } from "../../utils/proxy-utils.js"; -import type { AgentRunConfig, AgentTransport } from "./types.js"; - -/** - * Transport that calls LLM providers directly. - * Uses CORS proxy only for providers that require it (Anthropic OAuth, Z-AI). - */ -export class ProviderTransport implements AgentTransport { - private async getModel(cfg: AgentRunConfig) { - const apiKey = await getAppStorage().providerKeys.get(cfg.model.provider); - if (!apiKey) { - throw new Error("no-api-key"); - } - - const proxyEnabled = await getAppStorage().settings.get("proxy.enabled"); - const proxyUrl = await getAppStorage().settings.get("proxy.url"); - const model = applyProxyIfNeeded(cfg.model, apiKey, proxyEnabled ? proxyUrl || undefined : undefined); - - return model; - } - - private buildContext(messages: Message[], cfg: AgentRunConfig): AgentContext { - return { - systemPrompt: cfg.systemPrompt, - messages, - tools: cfg.tools, - }; - } - - private buildLoopConfig(model: AgentRunConfig["model"], cfg: AgentRunConfig): AgentLoopConfig { - return { - model, - reasoning: cfg.reasoning, - // Resolve API key per assistant response (important for expiring OAuth tokens) - getApiKey: async (provider: string) => { - const key = await getAppStorage().providerKeys.get(provider); - return key ?? 
undefined; // Convert null to undefined for type compatibility - }, - getQueuedMessages: cfg.getQueuedMessages, - }; - } - - async *run(messages: Message[], userMessage: Message, cfg: AgentRunConfig, signal?: AbortSignal) { - const model = await this.getModel(cfg); - const context = this.buildContext(messages, cfg); - const pc = this.buildLoopConfig(model, cfg); - - for await (const ev of agentLoop(userMessage as unknown as UserMessage, context, pc, signal)) { - yield ev; - } - } - - async *continue(messages: Message[], cfg: AgentRunConfig, signal?: AbortSignal) { - const model = await this.getModel(cfg); - const context = this.buildContext(messages, cfg); - const pc = this.buildLoopConfig(model, cfg); - - for await (const ev of agentLoopContinue(context, pc, signal)) { - yield ev; - } - } -} diff --git a/packages/web-ui/src/agent/transports/index.ts b/packages/web-ui/src/agent/transports/index.ts deleted file mode 100644 index 8dd56057..00000000 --- a/packages/web-ui/src/agent/transports/index.ts +++ /dev/null @@ -1,3 +0,0 @@ -export * from "./AppTransport.js"; -export * from "./ProviderTransport.js"; -export * from "./types.js"; diff --git a/packages/web-ui/src/agent/transports/proxy-types.ts b/packages/web-ui/src/agent/transports/proxy-types.ts deleted file mode 100644 index 94d4dbf9..00000000 --- a/packages/web-ui/src/agent/transports/proxy-types.ts +++ /dev/null @@ -1,15 +0,0 @@ -import type { StopReason, Usage } from "@mariozechner/pi-ai"; - -export type ProxyAssistantMessageEvent = - | { type: "start" } - | { type: "text_start"; contentIndex: number } - | { type: "text_delta"; contentIndex: number; delta: string } - | { type: "text_end"; contentIndex: number; contentSignature?: string } - | { type: "thinking_start"; contentIndex: number } - | { type: "thinking_delta"; contentIndex: number; delta: string } - | { type: "thinking_end"; contentIndex: number; contentSignature?: string } - | { type: "toolcall_start"; contentIndex: number; id: string; toolName: 
string } - | { type: "toolcall_delta"; contentIndex: number; delta: string } - | { type: "toolcall_end"; contentIndex: number } - | { type: "done"; reason: Extract; usage: Usage } - | { type: "error"; reason: Extract; errorMessage: string; usage: Usage }; diff --git a/packages/web-ui/src/agent/transports/types.ts b/packages/web-ui/src/agent/transports/types.ts deleted file mode 100644 index 74d28628..00000000 --- a/packages/web-ui/src/agent/transports/types.ts +++ /dev/null @@ -1,26 +0,0 @@ -import type { AgentEvent, AgentTool, Message, Model, QueuedMessage } from "@mariozechner/pi-ai"; - -// The minimal configuration needed to run a turn. -export interface AgentRunConfig { - systemPrompt: string; - tools: AgentTool[]; - model: Model; - reasoning?: "low" | "medium" | "high"; - getQueuedMessages?: () => Promise[]>; -} - -// Events yielded by transports must match the @mariozechner/pi-ai prompt() events. -// We re-export the Message type above; consumers should use the upstream AgentEvent type. 
- -export interface AgentTransport { - /** Run with a new user message */ - run( - messages: Message[], - userMessage: Message, - config: AgentRunConfig, - signal?: AbortSignal, - ): AsyncIterable; - - /** Continue from current context (no new user message) */ - continue(messages: Message[], config: AgentRunConfig, signal?: AbortSignal): AsyncIterable; -} diff --git a/packages/web-ui/src/agent/types.ts b/packages/web-ui/src/agent/types.ts deleted file mode 100644 index c5513941..00000000 --- a/packages/web-ui/src/agent/types.ts +++ /dev/null @@ -1,11 +0,0 @@ -import type { AssistantMessage, Context } from "@mariozechner/pi-ai"; - -export interface DebugLogEntry { - timestamp: string; - request: { provider: string; model: string; context: Context }; - response?: AssistantMessage; - error?: unknown; - sseEvents: string[]; - ttft?: number; - totalTime?: number; -} diff --git a/packages/web-ui/src/components/AgentInterface.ts b/packages/web-ui/src/components/AgentInterface.ts index 3d44faa3..fd4872a5 100644 --- a/packages/web-ui/src/components/AgentInterface.ts +++ b/packages/web-ui/src/components/AgentInterface.ts @@ -1,4 +1,4 @@ -import type { ToolResultMessage, Usage } from "@mariozechner/pi-ai"; +import { streamSimple, type ToolResultMessage, type Usage } from "@mariozechner/pi-ai"; import { html, LitElement } from "lit"; import { customElement, property, query } from "lit/decorators.js"; import { ModelSelector } from "../dialogs/ModelSelector.js"; @@ -6,12 +6,14 @@ import type { MessageEditor } from "./MessageEditor.js"; import "./MessageEditor.js"; import "./MessageList.js"; import "./Messages.js"; // Import for side effects to register the custom elements -import type { Agent, AgentEvent } from "../agent/agent.js"; import { getAppStorage } from "../storage/app-storage.js"; import "./StreamingMessageContainer.js"; +import type { Agent, AgentEvent } from "@mariozechner/pi-agent-core"; import type { Attachment } from "../utils/attachment-utils.js"; import { 
formatUsage } from "../utils/format.js"; import { i18n } from "../utils/i18n.js"; +import { createStreamFn } from "../utils/proxy-utils.js"; +import type { UserMessageWithAttachments } from "./Messages.js"; import type { StreamingMessageContainer } from "./StreamingMessageContainer.js"; @customElement("agent-interface") @@ -129,17 +131,48 @@ export class AgentInterface extends LitElement { this._unsubscribeSession = undefined; } if (!this.session) return; + + // Set default streamFn with proxy support if not already set + if (this.session.streamFn === streamSimple) { + this.session.streamFn = createStreamFn(async () => { + const enabled = await getAppStorage().settings.get("proxy.enabled"); + return enabled ? (await getAppStorage().settings.get("proxy.url")) || undefined : undefined; + }); + } + + // Set default getApiKey if not already set + if (!this.session.getApiKey) { + this.session.getApiKey = async (provider: string) => { + const key = await getAppStorage().providerKeys.get(provider); + return key ?? 
undefined; + }; + } + this._unsubscribeSession = this.session.subscribe(async (ev: AgentEvent) => { - if (ev.type === "state-update") { - if (this._streamingContainer) { - this._streamingContainer.isStreaming = ev.state.isStreaming; - this._streamingContainer.setMessage(ev.state.streamMessage, !ev.state.isStreaming); - } - this.requestUpdate(); - } else if (ev.type === "error-no-model") { - // TODO show some UI feedback - } else if (ev.type === "error-no-api-key") { - // Handled by onApiKeyRequired callback + switch (ev.type) { + case "message_start": + case "message_end": + case "turn_start": + case "turn_end": + case "agent_start": + this.requestUpdate(); + break; + case "agent_end": + // Clear streaming container when agent finishes + if (this._streamingContainer) { + this._streamingContainer.isStreaming = false; + this._streamingContainer.setMessage(null, true); + } + this.requestUpdate(); + break; + case "message_update": + if (this._streamingContainer) { + const isStreaming = this.session?.state.isStreaming || false; + this._streamingContainer.isStreaming = isStreaming; + this._streamingContainer.setMessage(ev.message, !isStreaming); + } + this.requestUpdate(); + break; } }); } @@ -205,7 +238,18 @@ export class AgentInterface extends LitElement { this._messageEditor.attachments = []; this._autoScroll = true; // Enable auto-scroll when sending a message - await this.session?.prompt(input, attachments); + // Compose message with attachments if any + if (attachments && attachments.length > 0) { + const message: UserMessageWithAttachments = { + role: "user-with-attachments", + content: input, + attachments, + timestamp: Date.now(), + }; + await this.session?.prompt(message); + } else { + await this.session?.prompt(input); + } } private renderMessages() { diff --git a/packages/web-ui/src/components/MessageEditor.ts b/packages/web-ui/src/components/MessageEditor.ts index ae286d05..78e44230 100644 --- a/packages/web-ui/src/components/MessageEditor.ts +++ 
b/packages/web-ui/src/components/MessageEditor.ts @@ -9,6 +9,7 @@ import { Brain, Loader2, Paperclip, Send, Sparkles, Square } from "lucide"; import { type Attachment, loadAttachment } from "../utils/attachment-utils.js"; import { i18n } from "../utils/i18n.js"; import "./AttachmentTile.js"; +import type { ThinkingLevel } from "@mariozechner/pi-agent-core"; @customElement("message-editor") export class MessageEditor extends LitElement { @@ -28,7 +29,7 @@ export class MessageEditor extends LitElement { @property() isStreaming = false; @property() currentModel?: Model; - @property() thinkingLevel: "off" | "minimal" | "low" | "medium" | "high" = "off"; + @property() thinkingLevel: ThinkingLevel = "off"; @property() showAttachmentButton = true; @property() showModelSelector = true; @property() showThinkingSelector = true; diff --git a/packages/web-ui/src/components/MessageList.ts b/packages/web-ui/src/components/MessageList.ts index 0cad8bd0..97670ba3 100644 --- a/packages/web-ui/src/components/MessageList.ts +++ b/packages/web-ui/src/components/MessageList.ts @@ -1,16 +1,15 @@ +import type { AgentMessage, AgentTool } from "@mariozechner/pi-agent-core"; import type { - AgentTool, AssistantMessage as AssistantMessageType, ToolResultMessage as ToolResultMessageType, } from "@mariozechner/pi-ai"; import { html, LitElement, type TemplateResult } from "lit"; import { property } from "lit/decorators.js"; import { repeat } from "lit/directives/repeat.js"; -import type { AppMessage } from "./Messages.js"; import { renderMessage } from "./message-renderer-registry.js"; export class MessageList extends LitElement { - @property({ type: Array }) messages: AppMessage[] = []; + @property({ type: Array }) messages: AgentMessage[] = []; @property({ type: Array }) tools: AgentTool[] = []; @property({ type: Object }) pendingToolCalls?: Set; @property({ type: Boolean }) isStreaming: boolean = false; @@ -51,7 +50,7 @@ export class MessageList extends LitElement { } // Fall back to 
built-in renderers - if (msg.role === "user") { + if (msg.role === "user" || msg.role === "user-with-attachments") { items.push({ key: `msg:${index}`, template: html``, diff --git a/packages/web-ui/src/components/Messages.ts b/packages/web-ui/src/components/Messages.ts index 73e9a3e6..7d68ff4a 100644 --- a/packages/web-ui/src/components/Messages.ts +++ b/packages/web-ui/src/components/Messages.ts @@ -1,6 +1,7 @@ import type { - AgentTool, AssistantMessage as AssistantMessageType, + ImageContent, + TextContent, ToolCall, ToolResultMessage as ToolResultMessageType, UserMessage as UserMessageType, @@ -12,8 +13,14 @@ import type { Attachment } from "../utils/attachment-utils.js"; import { formatUsage } from "../utils/format.js"; import { i18n } from "../utils/i18n.js"; import "./ThinkingBlock.js"; +import type { AgentTool } from "@mariozechner/pi-agent-core"; -export type UserMessageWithAttachments = UserMessageType & { attachments?: Attachment[] }; +export type UserMessageWithAttachments = { + role: "user-with-attachments"; + content: string | (TextContent | ImageContent)[]; + timestamp: number; + attachments?: Attachment[]; +}; // Artifact message type for session persistence export interface ArtifactMessage { @@ -25,26 +32,16 @@ export interface ArtifactMessage { timestamp: string; } -// Base message union -type BaseMessage = AssistantMessageType | UserMessageWithAttachments | ToolResultMessageType | ArtifactMessage; - -// Extensible interface - apps can extend via declaration merging -// Example: -// declare module "@mariozechner/pi-web-ui" { -// interface CustomMessages { -// "system-notification": SystemNotificationMessage; -// } -// } -export interface CustomMessages { - // Empty by default - apps extend via declaration merging +declare module "@mariozechner/pi-agent-core" { + interface CustomAgentMessages { + "user-with-attachments": UserMessageWithAttachments; + artifact: ArtifactMessage; + } } -// AppMessage is union of base messages + custom messages -export 
type AppMessage = BaseMessage | CustomMessages[keyof CustomMessages]; - @customElement("user-message") export class UserMessage extends LitElement { - @property({ type: Object }) message!: UserMessageWithAttachments; + @property({ type: Object }) message!: UserMessageWithAttachments | UserMessageType; protected override createRenderRoot(): HTMLElement | DocumentFragment { return this; @@ -66,7 +63,9 @@ export class UserMessage extends LitElement {
${ - this.message.attachments && this.message.attachments.length > 0 + this.message.role === "user-with-attachments" && + this.message.attachments && + this.message.attachments.length > 0 ? html`
${this.message.attachments.map( @@ -286,3 +285,93 @@ export class AbortedMessage extends LitElement { return html`${i18n("Request aborted")}`; } } + +// ============================================================================ +// Default Message Transformer +// ============================================================================ + +import type { AgentMessage } from "@mariozechner/pi-agent-core"; +import type { Message } from "@mariozechner/pi-ai"; + +/** + * Convert attachments to content blocks for LLM. + * - Images become ImageContent blocks + * - Documents with extractedText become TextContent blocks with filename header + */ +export function convertAttachments(attachments: Attachment[]): (TextContent | ImageContent)[] { + const content: (TextContent | ImageContent)[] = []; + for (const attachment of attachments) { + if (attachment.type === "image") { + content.push({ + type: "image", + data: attachment.content, + mimeType: attachment.mimeType, + } as ImageContent); + } else if (attachment.type === "document" && attachment.extractedText) { + content.push({ + type: "text", + text: `\n\n[Document: ${attachment.fileName}]\n${attachment.extractedText}`, + } as TextContent); + } + } + return content; +} + +/** + * Check if a message is a UserMessageWithAttachments. + */ +export function isUserMessageWithAttachments(msg: AgentMessage): msg is UserMessageWithAttachments { + return (msg as UserMessageWithAttachments).role === "user-with-attachments"; +} + +/** + * Check if a message is an ArtifactMessage. + */ +export function isArtifactMessage(msg: AgentMessage): msg is ArtifactMessage { + return (msg as ArtifactMessage).role === "artifact"; +} + +/** + * Default convertToLlm for web-ui apps. 
+ * + * Handles: + * - UserMessageWithAttachments: converts to user message with content blocks + * - ArtifactMessage: filtered out (UI-only, for session reconstruction) + * - Standard LLM messages (user, assistant, toolResult): passed through + */ +export function defaultConvertToLlm(messages: AgentMessage[]): Message[] { + return messages + .filter((m) => { + // Filter out artifact messages - they're for session reconstruction only + if (isArtifactMessage(m)) { + return false; + } + return true; + }) + .map((m): Message | null => { + // Convert user-with-attachments to user message with content blocks + if (isUserMessageWithAttachments(m)) { + const textContent: (TextContent | ImageContent)[] = + typeof m.content === "string" ? [{ type: "text", text: m.content }] : [...m.content]; + + if (m.attachments) { + textContent.push(...convertAttachments(m.attachments)); + } + + return { + role: "user", + content: textContent, + timestamp: m.timestamp, + } as Message; + } + + // Pass through standard LLM roles + if (m.role === "user" || m.role === "assistant" || m.role === "toolResult") { + return m as Message; + } + + // Filter out unknown message types + return null; + }) + .filter((m): m is Message => m !== null); +} diff --git a/packages/web-ui/src/components/StreamingMessageContainer.ts b/packages/web-ui/src/components/StreamingMessageContainer.ts index 3b5790ea..3d269315 100644 --- a/packages/web-ui/src/components/StreamingMessageContainer.ts +++ b/packages/web-ui/src/components/StreamingMessageContainer.ts @@ -1,4 +1,5 @@ -import type { AgentTool, Message, ToolResultMessage } from "@mariozechner/pi-ai"; +import type { AgentMessage, AgentTool } from "@mariozechner/pi-agent-core"; +import type { ToolResultMessage } from "@mariozechner/pi-ai"; import { html, LitElement } from "lit"; import { property, state } from "lit/decorators.js"; @@ -9,8 +10,8 @@ export class StreamingMessageContainer extends LitElement { @property({ type: Object }) toolResultsById?: Map; 
@property({ attribute: false }) onCostClick?: () => void; - @state() private _message: Message | null = null; - private _pendingMessage: Message | null = null; + @state() private _message: AgentMessage | null = null; + private _pendingMessage: AgentMessage | null = null; private _updateScheduled = false; private _immediateUpdate = false; @@ -24,7 +25,7 @@ export class StreamingMessageContainer extends LitElement { } // Public method to update the message with batching for performance - public setMessage(message: Message | null, immediate = false) { + public setMessage(message: AgentMessage | null, immediate = false) { // Store the latest message this._pendingMessage = message; @@ -73,7 +74,7 @@ export class StreamingMessageContainer extends LitElement { if (msg.role === "toolResult") { // Skip standalone tool result in streaming; the stable list will render paired tool-message return html``; - } else if (msg.role === "user") { + } else if (msg.role === "user" || msg.role === "user-with-attachments") { // Skip standalone tool result in streaming; the stable list will render it immediiately return html``; } else if (msg.role === "assistant") { diff --git a/packages/web-ui/src/components/message-renderer-registry.ts b/packages/web-ui/src/components/message-renderer-registry.ts index eac4689e..51f84a48 100644 --- a/packages/web-ui/src/components/message-renderer-registry.ts +++ b/packages/web-ui/src/components/message-renderer-registry.ts @@ -1,11 +1,11 @@ +import type { AgentMessage } from "@mariozechner/pi-agent-core"; import type { TemplateResult } from "lit"; -import type { AppMessage } from "./Messages.js"; // Extract role type from AppMessage union -export type MessageRole = AppMessage["role"]; +export type MessageRole = AgentMessage["role"]; // Generic message renderer typed to specific message type -export interface MessageRenderer { +export interface MessageRenderer { render(message: TMessage): TemplateResult; } @@ -14,7 +14,7 @@ const messageRenderers = new 
Map>(); export function registerMessageRenderer( role: TRole, - renderer: MessageRenderer>, + renderer: MessageRenderer>, ): void { messageRenderers.set(role, renderer); } @@ -23,6 +23,6 @@ export function getMessageRenderer(role: MessageRole): MessageRenderer | undefin return messageRenderers.get(role); } -export function renderMessage(message: AppMessage): TemplateResult | undefined { +export function renderMessage(message: AgentMessage): TemplateResult | undefined { return messageRenderers.get(message.role)?.render(message); } diff --git a/packages/web-ui/src/index.ts b/packages/web-ui/src/index.ts index 547cb959..cff990b0 100644 --- a/packages/web-ui/src/index.ts +++ b/packages/web-ui/src/index.ts @@ -1,13 +1,7 @@ // Main chat interface -export type { AgentState, ThinkingLevel } from "./agent/agent.js"; -// State management -export { Agent } from "./agent/agent.js"; -// Transports -export { AppTransport } from "./agent/transports/AppTransport.js"; -export { ProviderTransport } from "./agent/transports/ProviderTransport.js"; -export type { ProxyAssistantMessageEvent } from "./agent/transports/proxy-types.js"; -export type { AgentRunConfig, AgentTransport } from "./agent/transports/types.js"; +export type { Agent, AgentMessage, AgentState, ThinkingLevel } from "@mariozechner/pi-agent-core"; +export type { Model } from "@mariozechner/pi-ai"; export { ChatPanel } from "./ChatPanel.js"; // Components export { AgentInterface } from "./components/AgentInterface.js"; @@ -18,8 +12,16 @@ export { Input } from "./components/Input.js"; export { MessageEditor } from "./components/MessageEditor.js"; export { MessageList } from "./components/MessageList.js"; // Message components -export type { AppMessage, CustomMessages, UserMessageWithAttachments } from "./components/Messages.js"; -export { AssistantMessage, ToolMessage, UserMessage } from "./components/Messages.js"; +export type { ArtifactMessage, UserMessageWithAttachments } from "./components/Messages.js"; +export { + 
AssistantMessage, + convertAttachments, + defaultConvertToLlm, + isArtifactMessage, + isUserMessageWithAttachments, + ToolMessage, + UserMessage, +} from "./components/Messages.js"; // Message renderer registry export { getMessageRenderer, @@ -110,3 +112,4 @@ export { loadAttachment } from "./utils/attachment-utils.js"; export { clearAuthToken, getAuthToken } from "./utils/auth-token.js"; export { formatCost, formatModelCost, formatTokenCount, formatUsage } from "./utils/format.js"; export { i18n, setLanguage, translations } from "./utils/i18n.js"; +export { applyProxyIfNeeded, createStreamFn, isCorsError, shouldUseProxyForProvider } from "./utils/proxy-utils.js"; diff --git a/packages/web-ui/src/storage/stores/sessions-store.ts b/packages/web-ui/src/storage/stores/sessions-store.ts index 40a34edb..b0084fdd 100644 --- a/packages/web-ui/src/storage/stores/sessions-store.ts +++ b/packages/web-ui/src/storage/stores/sessions-store.ts @@ -1,4 +1,4 @@ -import type { AgentState } from "../../agent/agent.js"; +import type { AgentState } from "@mariozechner/pi-agent-core"; import { Store } from "../store.js"; import type { SessionData, SessionMetadata, StoreConfig } from "../types.js"; diff --git a/packages/web-ui/src/storage/types.ts b/packages/web-ui/src/storage/types.ts index 038f9657..3bbcf602 100644 --- a/packages/web-ui/src/storage/types.ts +++ b/packages/web-ui/src/storage/types.ts @@ -1,6 +1,5 @@ +import type { AgentMessage, ThinkingLevel } from "@mariozechner/pi-agent-core"; import type { Model } from "@mariozechner/pi-ai"; -import type { ThinkingLevel } from "../agent/agent.js"; -import type { AppMessage } from "../components/Messages.js"; /** * Transaction interface for atomic operations across stores. 
@@ -159,7 +158,7 @@ export interface SessionData { thinkingLevel: ThinkingLevel; /** Full conversation history (with attachments inline) */ - messages: AppMessage[]; + messages: AgentMessage[]; /** ISO 8601 UTC timestamp of creation */ createdAt: string; diff --git a/packages/web-ui/src/tools/artifacts/artifacts.ts b/packages/web-ui/src/tools/artifacts/artifacts.ts index 5aac9066..0a5474a7 100644 --- a/packages/web-ui/src/tools/artifacts/artifacts.ts +++ b/packages/web-ui/src/tools/artifacts/artifacts.ts @@ -1,13 +1,13 @@ import { icon } from "@mariozechner/mini-lit"; import "@mariozechner/mini-lit/dist/MarkdownBlock.js"; import { Button } from "@mariozechner/mini-lit/dist/Button.js"; -import { type AgentTool, type Message, StringEnum, type ToolCall } from "@mariozechner/pi-ai"; +import type { Agent, AgentMessage, AgentTool } from "@mariozechner/pi-agent-core"; +import { StringEnum, type ToolCall } from "@mariozechner/pi-ai"; import { type Static, Type } from "@sinclair/typebox"; import { html, LitElement, type TemplateResult } from "lit"; import { customElement, property, state } from "lit/decorators.js"; import { createRef, type Ref, ref } from "lit/directives/ref.js"; import { X } from "lucide"; -import type { Agent } from "../../agent/agent.js"; import type { ArtifactMessage } from "../../components/Messages.js"; import { ArtifactsRuntimeProvider } from "../../components/sandbox/ArtifactsRuntimeProvider.js"; import { AttachmentsRuntimeProvider } from "../../components/sandbox/AttachmentsRuntimeProvider.js"; @@ -85,7 +85,7 @@ export class ArtifactsPanel extends LitElement { if (this.agent) { const attachments: Attachment[] = []; for (const message of this.agent.state.messages) { - if (message.role === "user" && message.attachments) { + if (message.role === "user-with-attachments" && message.attachments) { attachments.push(...message.attachments); } } @@ -292,7 +292,7 @@ export class ArtifactsPanel extends LitElement { // Re-apply artifacts by scanning a message 
list (optional utility) public async reconstructFromMessages( - messages: Array, + messages: Array, ): Promise { const toolCalls = new Map(); const artifactToolName = "artifacts"; diff --git a/packages/web-ui/src/tools/extract-document.ts b/packages/web-ui/src/tools/extract-document.ts index 73eddcb4..b733c7ee 100644 --- a/packages/web-ui/src/tools/extract-document.ts +++ b/packages/web-ui/src/tools/extract-document.ts @@ -1,4 +1,5 @@ -import type { AgentTool, ToolResultMessage } from "@mariozechner/pi-ai"; +import type { AgentTool } from "@mariozechner/pi-agent-core"; +import type { ToolResultMessage } from "@mariozechner/pi-ai"; import { type Static, Type } from "@sinclair/typebox"; import { html } from "lit"; import { createRef, ref } from "lit/directives/ref.js"; diff --git a/packages/web-ui/src/tools/javascript-repl.ts b/packages/web-ui/src/tools/javascript-repl.ts index 11e9233c..c42ed9e7 100644 --- a/packages/web-ui/src/tools/javascript-repl.ts +++ b/packages/web-ui/src/tools/javascript-repl.ts @@ -1,5 +1,6 @@ import { i18n } from "@mariozechner/mini-lit"; -import type { AgentTool, ToolResultMessage } from "@mariozechner/pi-ai"; +import type { AgentTool } from "@mariozechner/pi-agent-core"; +import type { ToolResultMessage } from "@mariozechner/pi-ai"; import { type Static, Type } from "@sinclair/typebox"; import { html } from "lit"; import { createRef, ref } from "lit/directives/ref.js"; diff --git a/packages/web-ui/src/utils/proxy-utils.ts b/packages/web-ui/src/utils/proxy-utils.ts index 65c68e57..27816b18 100644 --- a/packages/web-ui/src/utils/proxy-utils.ts +++ b/packages/web-ui/src/utils/proxy-utils.ts @@ -1,4 +1,5 @@ -import type { Api, Model } from "@mariozechner/pi-ai"; +import type { Api, Context, Model, SimpleStreamOptions } from "@mariozechner/pi-ai"; +import { streamSimple } from "@mariozechner/pi-ai"; /** * Centralized proxy decision logic. 
@@ -110,3 +111,24 @@ export function isCorsError(error: unknown): boolean { return false; } + +/** + * Create a streamFn that applies CORS proxy when needed. + * Reads proxy settings from storage on each call. + * + * @param getProxyUrl - Async function to get current proxy URL (or undefined if disabled) + * @returns A streamFn compatible with Agent's streamFn option + */ +export function createStreamFn(getProxyUrl: () => Promise) { + return async (model: Model, context: Context, options?: SimpleStreamOptions) => { + const apiKey = options?.apiKey; + const proxyUrl = await getProxyUrl(); + + if (!apiKey || !proxyUrl) { + return streamSimple(model, context, options); + } + + const proxiedModel = applyProxyIfNeeded(model, apiKey, proxyUrl); + return streamSimple(proxiedModel, context, options); + }; +} diff --git a/scripts/release.mjs b/scripts/release.mjs new file mode 100755 index 00000000..12b33d04 --- /dev/null +++ b/scripts/release.mjs @@ -0,0 +1,145 @@ +#!/usr/bin/env node +/** + * Release script for pi-mono + * + * Usage: node scripts/release.mjs + * + * Steps: + * 1. Check for uncommitted changes + * 2. Bump version via npm run version:xxx + * 3. Update CHANGELOG.md files: [Unreleased] -> [version] - date + * 4. Commit and tag + * 5. Publish to npm + * 6. Add new [Unreleased] section to changelogs + * 7. Commit + */ + +import { execSync } from "child_process"; +import { readFileSync, writeFileSync, readdirSync, existsSync } from "fs"; +import { join } from "path"; + +const BUMP_TYPE = process.argv[2]; + +if (!["major", "minor", "patch"].includes(BUMP_TYPE)) { + console.error("Usage: node scripts/release.mjs "); + process.exit(1); +} + +function run(cmd, options = {}) { + console.log(`$ ${cmd}`); + try { + return execSync(cmd, { encoding: "utf-8", stdio: options.silent ? 
"pipe" : "inherit", ...options }); + } catch (e) { + if (!options.ignoreError) { + console.error(`Command failed: ${cmd}`); + process.exit(1); + } + return null; + } +} + +function getVersion() { + const pkg = JSON.parse(readFileSync("packages/ai/package.json", "utf-8")); + return pkg.version; +} + +function getChangelogs() { + const packagesDir = "packages"; + const packages = readdirSync(packagesDir); + return packages + .map((pkg) => join(packagesDir, pkg, "CHANGELOG.md")) + .filter((path) => existsSync(path)); +} + +function updateChangelogsForRelease(version) { + const date = new Date().toISOString().split("T")[0]; + const changelogs = getChangelogs(); + + for (const changelog of changelogs) { + const content = readFileSync(changelog, "utf-8"); + + if (!content.includes("## [Unreleased]")) { + console.log(` Skipping ${changelog}: no [Unreleased] section`); + continue; + } + + const updated = content.replace( + "## [Unreleased]", + `## [${version}] - ${date}` + ); + writeFileSync(changelog, updated); + console.log(` Updated ${changelog}`); + } +} + +function addUnreleasedSection() { + const changelogs = getChangelogs(); + const unreleasedSection = "## [Unreleased]\n\n"; + + for (const changelog of changelogs) { + const content = readFileSync(changelog, "utf-8"); + + // Insert after "# Changelog\n\n" + const updated = content.replace( + /^(# Changelog\n\n)/, + `$1${unreleasedSection}` + ); + writeFileSync(changelog, updated); + console.log(` Added [Unreleased] to ${changelog}`); + } +} + +// Main flow +console.log("\n=== Release Script ===\n"); + +// 1. Check for uncommitted changes +console.log("Checking for uncommitted changes..."); +const status = run("git status --porcelain", { silent: true }); +if (status && status.trim()) { + console.error("Error: Uncommitted changes detected. Commit or stash first."); + console.error(status); + process.exit(1); +} +console.log(" Working directory clean\n"); + +// 2. 
Bump version +console.log(`Bumping version (${BUMP_TYPE})...`); +run(`npm run version:${BUMP_TYPE}`); +const version = getVersion(); +console.log(` New version: ${version}\n`); + +// 3. Update changelogs +console.log("Updating CHANGELOG.md files..."); +updateChangelogsForRelease(version); +console.log(); + +// 4. Commit and tag +console.log("Committing and tagging..."); +run("git add ."); +run(`git commit -m "Release v${version}"`); +run(`git tag v${version}`); +console.log(); + +// 5. Publish +console.log("Publishing to npm..."); +run("npm run publish"); +console.log(); + +// 6. Add new [Unreleased] sections +console.log("Adding [Unreleased] sections for next cycle..."); +addUnreleasedSection(); +console.log(); + +// 7. Commit +console.log("Committing changelog updates..."); +run("git add ."); +run(`git commit -m "Add [Unreleased] section for next cycle"`); +console.log(); + +// 8. Push +console.log("Pushing to remote..."); +run("git push origin main"); +run(`git push origin v${version}`); +console.log(); + +console.log(`=== Released v${version} ===`); diff --git a/tsconfig.json b/tsconfig.json index 2f00b8ea..62eb337a 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -17,8 +17,6 @@ "@mariozechner/pi-mom/*": ["./packages/mom/src/*"], "@mariozechner/pi": ["./packages/pods/src/index.ts"], "@mariozechner/pi/*": ["./packages/pods/src/*"], - "@mariozechner/pi-proxy": ["./packages/proxy/src/index.ts"], - "@mariozechner/pi-proxy/*": ["./packages/proxy/src/*"], "@mariozechner/pi-tui": ["./packages/tui/src/index.ts"], "@mariozechner/pi-tui/*": ["./packages/tui/src/*"], "@mariozechner/pi-web-ui": ["./packages/web-ui/src/index.ts"],