Merge session-tree: tree structure with branching, compaction, and hook API improvements

This commit is contained in:
Mario Zechner 2025-12-30 22:45:57 +01:00
commit 1f3f851185
174 changed files with 10978 additions and 6295 deletions

1
.gitignore vendored
View file

@ -29,3 +29,4 @@ compaction-results/
.opencode/
syntax.jsonl
out.jsonl
pi-*.html

12
.pi/commands/review.md Normal file
View file

@ -0,0 +1,12 @@
---
description: Review a file for issues
---
Please review the following file for potential issues, bugs, or improvements:
$1
Focus on:
- Logic errors
- Edge cases
- Code style
- Performance concerns

86
.pi/hooks/test-command.ts Normal file
View file

@ -0,0 +1,86 @@
/**
* Test hook demonstrating custom commands, message rendering, and before_agent_start.
*/
import type { BeforeAgentStartEvent, HookAPI } from "@mariozechner/pi-coding-agent";
import { Box, Spacer, Text } from "@mariozechner/pi-tui";
/**
 * Hook entry point. Demonstrates three extension surfaces of the hook API:
 * a custom message renderer, three slash commands, and a `before_agent_start`
 * event handler that optionally injects context ahead of the user's prompt.
 *
 * @param pi - Hook API handle used to register renderers, commands, and events.
 */
export default function (pi: HookAPI) {
  // Toggled by /test-inject; gates the before_agent_start injection below.
  let injectEnabled = false;

  // Custom renderer for messages carrying customType "test-info".
  pi.registerMessageRenderer("test-info", (message, options, theme) => {
    const box = new Box(1, 1, (t) => theme.bg("customMessageBg", t));
    const label = theme.fg("success", "[TEST INFO]");
    box.addChild(new Text(label, 0, 0));
    box.addChild(new Spacer(1));
    // message.content is either a plain string or an array of content blocks;
    // non-text blocks are rendered as an "[image]" placeholder.
    const content =
      typeof message.content === "string"
        ? message.content
        : message.content.map((c) => (c.type === "text" ? c.text : "[image]")).join("");
    box.addChild(new Text(theme.fg("text", content), 0, 0));
    // Details are only shown when the message is expanded in the UI.
    if (options.expanded && message.details) {
      box.addChild(new Spacer(1));
      box.addChild(new Text(theme.fg("dim", `Details: ${JSON.stringify(message.details)}`), 0, 0));
    }
    return box;
  });

  // /test-msg: send a visible custom message and start an agent run.
  pi.registerCommand("test-msg", {
    description: "Send a test custom message",
    handler: async () => {
      pi.sendMessage(
        {
          customType: "test-info",
          content: "This is a test message with custom rendering!",
          display: true,
          details: { timestamp: Date.now(), source: "test-command hook" },
        },
        true, // triggerTurn: start agent run
      );
    },
  });

  // /test-hidden: send a message that enters context but is not displayed.
  pi.registerCommand("test-hidden", {
    description: "Send a hidden message (display: false)",
    handler: async (ctx) => {
      pi.sendMessage({
        customType: "test-info",
        content: "This message is in context but not displayed",
        display: false,
      });
      ctx.ui.notify("Sent hidden message (check session file)");
    },
  });

  // /test-inject: toggle the before_agent_start context injection.
  pi.registerCommand("test-inject", {
    description: "Toggle context injection before agent starts",
    handler: async (ctx) => {
      injectEnabled = !injectEnabled;
      ctx.ui.notify(`Context injection ${injectEnabled ? "enabled" : "disabled"}`);
    },
  });

  // before_agent_start: when enabled, return a message to inject before the
  // user's prompt. Returning undefined injects nothing.
  pi.on("before_agent_start", async (event: BeforeAgentStartEvent) => {
    if (!injectEnabled) return;
    // Fix: only append an ellipsis when the prompt was actually truncated,
    // so short prompts are not shown with a misleading "..." suffix.
    const preview = event.prompt.length > 50 ? `${event.prompt.slice(0, 50)}...` : event.prompt;
    return {
      message: {
        customType: "test-info",
        content: `[Injected context for prompt: "${preview}"]`,
        display: true,
        details: { injectedAt: Date.now() },
      },
    };
  });
}

View file

@ -30,7 +30,7 @@ When reading issues:
When creating issues:
- Add `pkg:*` labels to indicate which package(s) the issue affects
- Available labels: `pkg:agent`, `pkg:ai`, `pkg:coding-agent`, `pkg:mom`, `pkg:pods`, `pkg:proxy`, `pkg:tui`, `pkg:web-ui`
- Available labels: `pkg:agent`, `pkg:ai`, `pkg:coding-agent`, `pkg:mom`, `pkg:pods`, `pkg:tui`, `pkg:web-ui`
- If an issue spans multiple packages, add all relevant labels
When closing issues via commit:
@ -39,7 +39,7 @@ When closing issues via commit:
## Tools
- GitHub CLI for issues/PRs
- Add package labels to issues/PRs: pkg:agent, pkg:ai, pkg:coding-agent, pkg:mom, pkg:pods, pkg:proxy, pkg:tui, pkg:web-ui
- Add package labels to issues/PRs: pkg:agent, pkg:ai, pkg:coding-agent, pkg:mom, pkg:pods, pkg:tui, pkg:web-ui
- TUI interaction: use tmux
## Style
@ -49,43 +49,37 @@ When closing issues via commit:
- Technical prose only, be kind but direct (e.g., "Thanks @user" not "Thanks so much @user!")
## Changelog
Location: `packages/coding-agent/CHANGELOG.md`, `packages/ai/CHANGELOG.md`, `packages/tui/CHANGELOG.md`, pick the one relevant to the changes or ask user.
Location: `packages/*/CHANGELOG.md` (each package has its own)
### Format
Use these sections under `## [Unreleased]`:
- `### Breaking Changes` - API changes requiring migration
- `### Added` - New features
- `### Changed` - Changes to existing functionality
- `### Fixed` - Bug fixes
- `### Removed` - Removed features
### Rules
- New entries ALWAYS go under `## [Unreleased]` section
- NEVER modify already-released version sections (e.g., `## [0.12.2]`)
- Each version section is immutable once released
- When releasing: rename `[Unreleased]` to the new version, then add a fresh empty `[Unreleased]` section
### Attribution format
- **Internal changes (from issues)**: Reference issue only
- Example: `Fixed foo bar ([#123](https://github.com/badlogic/pi-mono/issues/123))`
- **External contributions (PRs from others)**: Reference PR and credit the contributor
- Example: `Added feature X ([#456](https://github.com/badlogic/pi-mono/pull/456) by [@username](https://github.com/username))`
- If a PR addresses an issue, reference both: `([#123](...issues/123), [#456](...pull/456) by [@user](...))` or just the PR if the issue context is clear from the description
### Attribution
- **Internal changes (from issues)**: `Fixed foo bar ([#123](https://github.com/badlogic/pi-mono/issues/123))`
- **External contributions**: `Added feature X ([#456](https://github.com/badlogic/pi-mono/pull/456) by [@username](https://github.com/username))`
## Releasing
1. **Bump version** (all packages use lockstep versioning):
1. **Update CHANGELOGs**: Ensure all changes since last release are documented in the `[Unreleased]` section of each affected package's CHANGELOG.md
2. **Run release script**:
```bash
npm run version:patch # For bug fixes
npm run version:minor # For new features
npm run version:major # For breaking changes
npm run release:patch # Bug fixes
npm run release:minor # New features
npm run release:major # Breaking changes
```
2. **Finalize CHANGELOG.md**: Change `[Unreleased]` to the new version with today's date (e.g., `## [0.12.12] - 2025-12-05`)
3. **Commit and tag**:
```bash
git add .
git commit -m "Release v0.12.12"
git tag v0.12.12
git push origin main
git push origin v0.12.12
```
4. **Publish to npm**:
```bash
npm run publish
```
5. **Add new [Unreleased] section** at top of CHANGELOG.md for next cycle, commit it
The script handles: version bump, CHANGELOG finalization, commit, tag, publish, and adding new `[Unreleased]` sections.
### Tool Usage
**CRITICAL**: NEVER use sed/cat to read a file or a range of a file. Always use the read tool (use offset + limit for ranged reads).

View file

@ -7,12 +7,11 @@ Tools for building AI agents and managing LLM deployments.
| Package | Description |
|---------|-------------|
| **[@mariozechner/pi-ai](packages/ai)** | Unified multi-provider LLM API (OpenAI, Anthropic, Google, etc.) |
| **[@mariozechner/pi-agent](packages/agent)** | Agent runtime with tool calling and state management |
| **[@mariozechner/pi-agent-core](packages/agent)** | Agent runtime with tool calling and state management |
| **[@mariozechner/pi-coding-agent](packages/coding-agent)** | Interactive coding agent CLI |
| **[@mariozechner/pi-mom](packages/mom)** | Slack bot that delegates messages to the pi coding agent |
| **[@mariozechner/pi-tui](packages/tui)** | Terminal UI library with differential rendering |
| **[@mariozechner/pi-web-ui](packages/web-ui)** | Web components for AI chat interfaces |
| **[@mariozechner/pi-proxy](packages/proxy)** | CORS proxy for browser-based LLM API calls |
| **[@mariozechner/pi-pods](packages/pods)** | CLI for managing vLLM deployments on GPU pods |
## Development
@ -71,55 +70,18 @@ These commands:
### Publishing
Complete release process:
```bash
npm run release:patch # Bug fixes
npm run release:minor # New features
npm run release:major # Breaking changes
```
1. **Add changes to CHANGELOG.md** (if changes affect coding-agent):
```bash
# Add your changes to the [Unreleased] section in packages/coding-agent/CHANGELOG.md
# Always add new entries under [Unreleased], never under already-released versions
```
This handles version bump, CHANGELOG updates, commit, tag, publish, and push.
2. **Bump version** (all packages):
```bash
npm run version:patch # For bug fixes
npm run version:minor # For new features
npm run version:major # For breaking changes
```
3. **Finalize CHANGELOG.md for release** (if changes affect coding-agent):
```bash
# Change [Unreleased] to the new version number with today's date
# e.g., ## [0.7.16] - 2025-11-17
# NEVER add entries to already-released version sections
# Each version section is immutable once released
```
4. **Commit and tag**:
```bash
git add .
git commit -m "Release v0.7.16"
git tag v0.7.16
git push origin main
git push origin v0.7.16
```
5. **Publish to npm**:
```bash
npm run publish # Publish all packages to npm
```
**NPM Token Setup**: Publishing requires a granular access token with "Bypass 2FA on publish" enabled.
- Go to https://www.npmjs.com/settings/badlogic/tokens/
- Create a new "Granular Access Token"
- Select "Bypass 2FA on publish"
- Tokens expire after 90 days, so regenerate when needed
- Set the token: `npm config set //registry.npmjs.org/:_authToken=YOUR_TOKEN`
6. **Add new [Unreleased] section** (for next development cycle):
```bash
# Add a new [Unreleased] section at the top of CHANGELOG.md
# Commit: git commit -am "Add [Unreleased] section"
```
**NPM Token Setup**: Requires a granular access token with "Bypass 2FA on publish" enabled.
- Go to https://www.npmjs.com/settings/badlogic/tokens/
- Create a new "Granular Access Token" with "Bypass 2FA on publish"
- Set the token: `npm config set //registry.npmjs.org/:_authToken=YOUR_TOKEN`
## License

149
package-lock.json generated
View file

@ -12,6 +12,7 @@
"packages/web-ui/example"
],
"dependencies": {
"@mariozechner/pi-coding-agent": "^0.30.2",
"get-east-asian-width": "^1.4.0"
},
"devDependencies": {
@ -695,18 +696,6 @@
}
}
},
"node_modules/@hono/node-server": {
"version": "1.19.7",
"resolved": "https://registry.npmjs.org/@hono/node-server/-/node-server-1.19.7.tgz",
"integrity": "sha512-vUcD0uauS7EU2caukW8z5lJKtoGMokxNbJtBiwHgpqxEXokaHCBkQUmCHhjFB1VUTWdqj25QoMkMKzgjq+uhrw==",
"license": "MIT",
"engines": {
"node": ">=18.14.1"
},
"peerDependencies": {
"hono": "^4"
}
},
"node_modules/@isaacs/balanced-match": {
"version": "4.0.1",
"resolved": "https://registry.npmjs.org/@isaacs/balanced-match/-/balanced-match-4.0.1.tgz",
@ -964,10 +953,6 @@
"resolved": "packages/mom",
"link": true
},
"node_modules/@mariozechner/pi-proxy": {
"resolved": "packages/proxy",
"link": true
},
"node_modules/@mariozechner/pi-tui": {
"resolved": "packages/tui",
"link": true
@ -995,9 +980,9 @@
}
},
"node_modules/@napi-rs/canvas": {
"version": "0.1.86",
"resolved": "https://registry.npmjs.org/@napi-rs/canvas/-/canvas-0.1.86.tgz",
"integrity": "sha512-hOkywnrkdFdVpsuaNsZWfEY7kc96eROV2DuMTTvGF15AZfwobzdG2w0eDlU5UBx3Lg/XlWUnqVT5zLUWyo5h6A==",
"version": "0.1.88",
"resolved": "https://registry.npmjs.org/@napi-rs/canvas/-/canvas-0.1.88.tgz",
"integrity": "sha512-/p08f93LEbsL5mDZFQ3DBxcPv/I4QG9EDYRRq1WNlCOXVfAHBTHMSVMwxlqG/AtnSfUr9+vgfN7MKiyDo0+Weg==",
"license": "MIT",
"optional": true,
"workspaces": [
@ -1011,23 +996,23 @@
"url": "https://github.com/sponsors/Brooooooklyn"
},
"optionalDependencies": {
"@napi-rs/canvas-android-arm64": "0.1.86",
"@napi-rs/canvas-darwin-arm64": "0.1.86",
"@napi-rs/canvas-darwin-x64": "0.1.86",
"@napi-rs/canvas-linux-arm-gnueabihf": "0.1.86",
"@napi-rs/canvas-linux-arm64-gnu": "0.1.86",
"@napi-rs/canvas-linux-arm64-musl": "0.1.86",
"@napi-rs/canvas-linux-riscv64-gnu": "0.1.86",
"@napi-rs/canvas-linux-x64-gnu": "0.1.86",
"@napi-rs/canvas-linux-x64-musl": "0.1.86",
"@napi-rs/canvas-win32-arm64-msvc": "0.1.86",
"@napi-rs/canvas-win32-x64-msvc": "0.1.86"
"@napi-rs/canvas-android-arm64": "0.1.88",
"@napi-rs/canvas-darwin-arm64": "0.1.88",
"@napi-rs/canvas-darwin-x64": "0.1.88",
"@napi-rs/canvas-linux-arm-gnueabihf": "0.1.88",
"@napi-rs/canvas-linux-arm64-gnu": "0.1.88",
"@napi-rs/canvas-linux-arm64-musl": "0.1.88",
"@napi-rs/canvas-linux-riscv64-gnu": "0.1.88",
"@napi-rs/canvas-linux-x64-gnu": "0.1.88",
"@napi-rs/canvas-linux-x64-musl": "0.1.88",
"@napi-rs/canvas-win32-arm64-msvc": "0.1.88",
"@napi-rs/canvas-win32-x64-msvc": "0.1.88"
}
},
"node_modules/@napi-rs/canvas-android-arm64": {
"version": "0.1.86",
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-android-arm64/-/canvas-android-arm64-0.1.86.tgz",
"integrity": "sha512-IjkZFKUr6GzMzzrawJaN3v+yY3Fvpa71e0DcbePfxWelFKnESIir+XUcdAbim29JOd0JE0/hQJdfUCb5t/Fjrw==",
"version": "0.1.88",
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-android-arm64/-/canvas-android-arm64-0.1.88.tgz",
"integrity": "sha512-KEaClPnZuVxJ8smUWjV1wWFkByBO/D+vy4lN+Dm5DFH514oqwukxKGeck9xcKJhaWJGjfruGmYGiwRe//+/zQQ==",
"cpu": [
"arm64"
],
@ -1045,9 +1030,9 @@
}
},
"node_modules/@napi-rs/canvas-darwin-arm64": {
"version": "0.1.86",
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-darwin-arm64/-/canvas-darwin-arm64-0.1.86.tgz",
"integrity": "sha512-PUCxDq0wSSJbtaOqoKj3+t5tyDbtxWumziOTykdn3T839hu6koMaBFpGk9lXpsGaPNgyFpPqjxhtsPljBGnDHg==",
"version": "0.1.88",
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-darwin-arm64/-/canvas-darwin-arm64-0.1.88.tgz",
"integrity": "sha512-Xgywz0dDxOKSgx3eZnK85WgGMmGrQEW7ZLA/E7raZdlEE+xXCozobgqz2ZvYigpB6DJFYkqnwHjqCOTSDGlFdg==",
"cpu": [
"arm64"
],
@ -1065,9 +1050,9 @@
}
},
"node_modules/@napi-rs/canvas-darwin-x64": {
"version": "0.1.86",
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-darwin-x64/-/canvas-darwin-x64-0.1.86.tgz",
"integrity": "sha512-rlCFLv4Rrg45qFZq7mysrKnsUbMhwdNg3YPuVfo9u4RkOqm7ooAJvdyDFxiqfSsJJTqupYqa9VQCUt8WKxKhNQ==",
"version": "0.1.88",
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-darwin-x64/-/canvas-darwin-x64-0.1.88.tgz",
"integrity": "sha512-Yz4wSCIQOUgNucgk+8NFtQxQxZV5NO8VKRl9ePKE6XoNyNVC8JDqtvhh3b3TPqKK8W5p2EQpAr1rjjm0mfBxdg==",
"cpu": [
"x64"
],
@ -1085,9 +1070,9 @@
}
},
"node_modules/@napi-rs/canvas-linux-arm-gnueabihf": {
"version": "0.1.86",
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-arm-gnueabihf/-/canvas-linux-arm-gnueabihf-0.1.86.tgz",
"integrity": "sha512-6xWwyMc9BlDBt+9XHN/GzUo3MozHta/2fxQHMb80x0K2zpZuAdDKUYHmYzx9dFWDY3SbPYnx6iRlQl6wxnwS1w==",
"version": "0.1.88",
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-arm-gnueabihf/-/canvas-linux-arm-gnueabihf-0.1.88.tgz",
"integrity": "sha512-9gQM2SlTo76hYhxHi2XxWTAqpTOb+JtxMPEIr+H5nAhHhyEtNmTSDRtz93SP7mGd2G3Ojf2oF5tP9OdgtgXyKg==",
"cpu": [
"arm"
],
@ -1105,9 +1090,9 @@
}
},
"node_modules/@napi-rs/canvas-linux-arm64-gnu": {
"version": "0.1.86",
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-arm64-gnu/-/canvas-linux-arm64-gnu-0.1.86.tgz",
"integrity": "sha512-r2OX3w50xHxrToTovOSQWwkVfSq752CUzH9dzlVXyr8UDKFV8dMjfa9hePXvAJhN3NBp4TkHcGx15QCdaCIwnA==",
"version": "0.1.88",
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-arm64-gnu/-/canvas-linux-arm64-gnu-0.1.88.tgz",
"integrity": "sha512-7qgaOBMXuVRk9Fzztzr3BchQKXDxGbY+nwsovD3I/Sx81e+sX0ReEDYHTItNb0Je4NHbAl7D0MKyd4SvUc04sg==",
"cpu": [
"arm64"
],
@ -1125,9 +1110,9 @@
}
},
"node_modules/@napi-rs/canvas-linux-arm64-musl": {
"version": "0.1.86",
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-arm64-musl/-/canvas-linux-arm64-musl-0.1.86.tgz",
"integrity": "sha512-jbXuh8zVFUPw6a9SGpgc6EC+fRbGGyP1NFfeQiVqGLs6bN93ROtPLPL6MH9Bp6yt0CXUFallk2vgKdWDbmW+bw==",
"version": "0.1.88",
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-arm64-musl/-/canvas-linux-arm64-musl-0.1.88.tgz",
"integrity": "sha512-kYyNrUsHLkoGHBc77u4Unh067GrfiCUMbGHC2+OTxbeWfZkPt2o32UOQkhnSswKd9Fko/wSqqGkY956bIUzruA==",
"cpu": [
"arm64"
],
@ -1145,9 +1130,9 @@
}
},
"node_modules/@napi-rs/canvas-linux-riscv64-gnu": {
"version": "0.1.86",
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-riscv64-gnu/-/canvas-linux-riscv64-gnu-0.1.86.tgz",
"integrity": "sha512-9IwHR2qbq2HceM9fgwyL7x37Jy3ptt1uxvikQEuWR0FisIx9QEdt7F3huljCky76aoouF2vSd0R2fHo3ESRoPw==",
"version": "0.1.88",
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-riscv64-gnu/-/canvas-linux-riscv64-gnu-0.1.88.tgz",
"integrity": "sha512-HVuH7QgzB0yavYdNZDRyAsn/ejoXB0hn8twwFnOqUbCCdkV+REna7RXjSR7+PdfW0qMQ2YYWsLvVBT5iL/mGpw==",
"cpu": [
"riscv64"
],
@ -1165,9 +1150,9 @@
}
},
"node_modules/@napi-rs/canvas-linux-x64-gnu": {
"version": "0.1.86",
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-x64-gnu/-/canvas-linux-x64-gnu-0.1.86.tgz",
"integrity": "sha512-Jor+rhRN6ubix+D2QkNn9XlPPVAYl+2qFrkZ4oZN9UgtqIUZ+n+HljxhlkkDFRaX1mlxXOXPQjxaZg17zDSFcQ==",
"version": "0.1.88",
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-x64-gnu/-/canvas-linux-x64-gnu-0.1.88.tgz",
"integrity": "sha512-hvcvKIcPEQrvvJtJnwD35B3qk6umFJ8dFIr8bSymfrSMem0EQsfn1ztys8ETIFndTwdNWJKWluvxztA41ivsEw==",
"cpu": [
"x64"
],
@ -1185,9 +1170,9 @@
}
},
"node_modules/@napi-rs/canvas-linux-x64-musl": {
"version": "0.1.86",
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-x64-musl/-/canvas-linux-x64-musl-0.1.86.tgz",
"integrity": "sha512-A28VTy91DbclopSGZ2tIon3p8hcVI1JhnNpDpJ5N9rYlUnVz1WQo4waEMh+FICTZF07O3coxBNZc4Vu4doFw7A==",
"version": "0.1.88",
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-x64-musl/-/canvas-linux-x64-musl-0.1.88.tgz",
"integrity": "sha512-eSMpGYY2xnZSQ6UxYJ6plDboxq4KeJ4zT5HaVkUnbObNN6DlbJe0Mclh3wifAmquXfrlgTZt6zhHsUgz++AK6g==",
"cpu": [
"x64"
],
@ -1205,9 +1190,9 @@
}
},
"node_modules/@napi-rs/canvas-win32-arm64-msvc": {
"version": "0.1.86",
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-win32-arm64-msvc/-/canvas-win32-arm64-msvc-0.1.86.tgz",
"integrity": "sha512-q6G1YXUt3gBCAS2bcDMCaBL4y20di8eVVBi1XhjUqZSVyZZxxwIuRQHy31NlPJUCMiyNiMuc6zeI0uqgkWwAmA==",
"version": "0.1.88",
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-win32-arm64-msvc/-/canvas-win32-arm64-msvc-0.1.88.tgz",
"integrity": "sha512-qcIFfEgHrchyYqRrxsCeTQgpJZ/GqHiqPcU/Fvw/ARVlQeDX1VyFH+X+0gCR2tca6UJrq96vnW+5o7buCq+erA==",
"cpu": [
"arm64"
],
@ -1225,9 +1210,9 @@
}
},
"node_modules/@napi-rs/canvas-win32-x64-msvc": {
"version": "0.1.86",
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-win32-x64-msvc/-/canvas-win32-x64-msvc-0.1.86.tgz",
"integrity": "sha512-X0g46uRVgnvCM1cOjRXAOSFSG63ktUFIf/TIfbKCUc7QpmYUcHmSP9iR6DGOYfk+SggLsXoJCIhPTotYeZEAmg==",
"version": "0.1.88",
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-win32-x64-msvc/-/canvas-win32-x64-msvc-0.1.88.tgz",
"integrity": "sha512-ROVqbfS4QyZxYkqmaIBBpbz/BQvAR+05FXM5PAtTYVc0uyY8Y4BHJSMdGAaMf6TdIVRsQsiq+FG/dH9XhvWCFQ==",
"cpu": [
"x64"
],
@ -4008,16 +3993,6 @@
"node": "*"
}
},
"node_modules/hono": {
"version": "4.11.2",
"resolved": "https://registry.npmjs.org/hono/-/hono-4.11.2.tgz",
"integrity": "sha512-o+avdUAD1v94oHkjGBhiMhBV4WBHxhbu0+CUVH78hhphKy/OKQLxtKjkmmNcrMlbYAhAbsM/9F+l3KnYxyD3Lg==",
"license": "MIT",
"peer": true,
"engines": {
"node": ">=16.9.0"
}
},
"node_modules/html-parse-string": {
"version": "0.0.9",
"resolved": "https://registry.npmjs.org/html-parse-string/-/html-parse-string-0.0.9.tgz",
@ -4583,7 +4558,6 @@
"resolved": "https://registry.npmjs.org/lit/-/lit-3.3.2.tgz",
"integrity": "sha512-NF9zbsP79l4ao2SNrH3NkfmFgN/hBYSQo90saIVI1o5GpjAdCPVstVzO1MrLOakHoEhYkrtRjPK6Ob521aoYWQ==",
"license": "BSD-3-Clause",
"peer": true,
"dependencies": {
"@lit/reactive-element": "^2.1.0",
"lit-element": "^4.2.0",
@ -5704,7 +5678,6 @@
"resolved": "https://registry.npmjs.org/tailwind-merge/-/tailwind-merge-3.4.0.tgz",
"integrity": "sha512-uSaO4gnW+b3Y2aWoWfFpX62vn2sR3skfhbjsEnaBI81WD1wBLlHZe5sWf0AqjksNdYTbGBEd0UasQMT3SNV15g==",
"license": "MIT",
"peer": true,
"funding": {
"type": "github",
"url": "https://github.com/sponsors/dcastil"
@ -5733,8 +5706,7 @@
"version": "4.1.18",
"resolved": "https://registry.npmjs.org/tailwindcss/-/tailwindcss-4.1.18.tgz",
"integrity": "sha512-4+Z+0yiYyEtUVCScyfHCxOYP06L5Ne+JiHhY2IjR2KWMIWhJOYZKLSGZaP5HkZ8+bY0cxfzwDE5uOmzFXyIwxw==",
"license": "MIT",
"peer": true
"license": "MIT"
},
"node_modules/tapable": {
"version": "2.3.0",
@ -5999,7 +5971,6 @@
"resolved": "https://registry.npmjs.org/vite/-/vite-7.3.0.tgz",
"integrity": "sha512-dZwN5L1VlUBewiP6H9s2+B3e3Jg96D0vzN+Ry73sOefebhYr9f94wwkMNN/9ouoU8pV1BqA1d1zGk8928cx0rg==",
"license": "MIT",
"peer": true,
"dependencies": {
"esbuild": "^0.27.0",
"fdir": "^6.5.0",
@ -6442,9 +6413,9 @@
}
},
"node_modules/zod-to-json-schema": {
"version": "3.25.0",
"resolved": "https://registry.npmjs.org/zod-to-json-schema/-/zod-to-json-schema-3.25.0.tgz",
"integrity": "sha512-HvWtU2UG41LALjajJrML6uQejQhNJx+JBO9IflpSja4R03iNWfKXrj6W2h7ljuLyc1nKS+9yDyL/9tD1U/yBnQ==",
"version": "3.25.1",
"resolved": "https://registry.npmjs.org/zod-to-json-schema/-/zod-to-json-schema-3.25.1.tgz",
"integrity": "sha512-pM/SU9d3YAggzi6MtR4h7ruuQlqKtad8e9S0fmxcMi+ueAK5Korys/aWcV9LIIHTVbj01NdzxcnXSN+O74ZIVA==",
"license": "ISC",
"peerDependencies": {
"zod": "^3.25 || ^4"
@ -6636,22 +6607,6 @@
"node": ">=20.0.0"
}
},
"packages/proxy": {
"name": "@mariozechner/pi-proxy",
"version": "0.30.2",
"dependencies": {
"@hono/node-server": "^1.14.0",
"hono": "^4.6.16"
},
"bin": {
"pi-proxy": "dist/cli.js"
},
"devDependencies": {
"@types/node": "^22.10.5",
"tsx": "^4.19.2",
"typescript": "^5.7.3"
}
},
"packages/tui": {
"name": "@mariozechner/pi-tui",
"version": "0.30.2",

View file

@ -8,8 +8,8 @@
],
"scripts": {
"clean": "npm run clean --workspaces",
"build": "npm run build -w @mariozechner/pi-tui && npm run build -w @mariozechner/pi-ai && npm run build -w @mariozechner/pi-agent-core && npm run build -w @mariozechner/pi-coding-agent && npm run build -w @mariozechner/pi-mom && npm run build -w @mariozechner/pi-web-ui && npm run build -w @mariozechner/pi-proxy && npm run build -w @mariozechner/pi",
"dev": "concurrently --names \"ai,agent,coding-agent,mom,web-ui,tui,proxy\" --prefix-colors \"cyan,yellow,red,white,green,magenta,blue\" \"npm run dev -w @mariozechner/pi-ai\" \"npm run dev -w @mariozechner/pi-agent-core\" \"npm run dev -w @mariozechner/pi-coding-agent\" \"npm run dev -w @mariozechner/pi-mom\" \"npm run dev -w @mariozechner/pi-web-ui\" \"npm run dev -w @mariozechner/pi-tui\" \"npm run dev -w @mariozechner/pi-proxy\"",
"build": "npm run build -w @mariozechner/pi-tui && npm run build -w @mariozechner/pi-ai && npm run build -w @mariozechner/pi-agent-core && npm run build -w @mariozechner/pi-coding-agent && npm run build -w @mariozechner/pi-mom && npm run build -w @mariozechner/pi-web-ui && npm run build -w @mariozechner/pi",
"dev": "concurrently --names \"ai,agent,coding-agent,mom,web-ui,tui\" --prefix-colors \"cyan,yellow,red,white,green,magenta\" \"npm run dev -w @mariozechner/pi-ai\" \"npm run dev -w @mariozechner/pi-agent-core\" \"npm run dev -w @mariozechner/pi-coding-agent\" \"npm run dev -w @mariozechner/pi-mom\" \"npm run dev -w @mariozechner/pi-web-ui\" \"npm run dev -w @mariozechner/pi-tui\"",
"dev:tsc": "concurrently --names \"ai,web-ui\" --prefix-colors \"cyan,green\" \"npm run dev:tsc -w @mariozechner/pi-ai\" \"npm run dev:tsc -w @mariozechner/pi-web-ui\"",
"check": "biome check --write . && tsgo --noEmit && npm run check -w @mariozechner/pi-web-ui",
"test": "npm run test --workspaces --if-present",
@ -20,6 +20,9 @@
"prepublishOnly": "npm run clean && npm run build && npm run check",
"publish": "npm run prepublishOnly && npm publish -ws --access public",
"publish:dry": "npm run prepublishOnly && npm publish -ws --access public --dry-run",
"release:patch": "node scripts/release.mjs patch",
"release:minor": "node scripts/release.mjs minor",
"release:major": "node scripts/release.mjs major",
"prepare": "husky"
},
"devDependencies": {
@ -36,6 +39,7 @@
},
"version": "0.0.3",
"dependencies": {
"@mariozechner/pi-coding-agent": "^0.30.2",
"get-east-asian-width": "^1.4.0"
}
}

View file

@ -0,0 +1,69 @@
# Changelog
## [Unreleased]
### Breaking Changes
- **Transport abstraction removed**: `ProviderTransport`, `AppTransport`, and `AgentTransport` interface have been removed. The `Agent` class now takes a `streamFn` option directly for custom streaming implementations.
- **Agent options renamed**:
- `transport` → removed (use `streamFn` instead)
- `messageTransformer``convertToLlm` (converts `AgentMessage[]` to LLM-compatible `Message[]`)
- `preprocessor``transformContext` (transforms `AgentMessage[]` before `convertToLlm`)
- **AppMessage renamed to AgentMessage**: All references to `AppMessage` have been renamed to `AgentMessage` for consistency.
- **Agent loop moved from pi-ai**: The `agentLoop`, `agentLoopContinue`, and related types (`AgentContext`, `AgentEvent`, `AgentTool`, `AgentToolResult`, `AgentToolUpdateCallback`, `AgentLoopConfig`) have moved from `@mariozechner/pi-ai` to this package.
### Added
- **`streamFn` option**: Pass a custom stream function to the Agent for proxy backends or custom implementations. Default uses `streamSimple` from pi-ai.
- **`streamProxy` utility**: New helper function for browser apps that need to proxy through a backend server. Replaces `AppTransport`.
- **`getApiKey` option**: Dynamic API key resolution for expiring OAuth tokens (e.g., GitHub Copilot).
- **`AgentLoopContext` and `AgentLoopConfig`**: Exported types for the low-level agent loop API.
- **`agentLoop` and `agentLoopContinue`**: Low-level functions for running the agent loop directly without the `Agent` class wrapper.
### Migration Guide
**Before (0.30.x):**
```typescript
import { Agent, ProviderTransport } from '@mariozechner/pi-agent-core';
const agent = new Agent({
transport: new ProviderTransport({ apiKey: '...' }),
messageTransformer: (messages) => messages.filter(...),
preprocessor: async (messages) => compactMessages(messages)
});
```
**After:**
```typescript
import { Agent } from '@mariozechner/pi-agent-core';
import { streamSimple } from '@mariozechner/pi-ai';
const agent = new Agent({
streamFn: streamSimple, // or omit for default
convertToLlm: (messages) => messages.filter(...),
transformContext: async (messages) => compactMessages(messages),
getApiKey: async (provider) => resolveApiKey(provider)
});
```
**For proxy usage (replaces AppTransport):**
```typescript
import { Agent, streamProxy } from '@mariozechner/pi-agent-core';
const agent = new Agent({
streamFn: (model, context, options) => streamProxy(
'/api/agent',
model,
context,
options,
{ 'Authorization': 'Bearer ...' }
)
});
```

View file

@ -1,6 +1,6 @@
# @mariozechner/pi-agent-core
Stateful agent abstraction with transport layer for LLM interactions. Provides a reactive `Agent` class that manages conversation state, emits granular events, and supports pluggable transports for different deployment scenarios.
Stateful agent with tool execution, event streaming, and extensible message types. Built on `@mariozechner/pi-ai`.
## Installation
@ -11,12 +11,10 @@ npm install @mariozechner/pi-agent-core
## Quick Start
```typescript
import { Agent, ProviderTransport } from '@mariozechner/pi-agent-core';
import { Agent } from '@mariozechner/pi-agent-core';
import { getModel } from '@mariozechner/pi-ai';
// Create agent with direct provider transport
const agent = new Agent({
transport: new ProviderTransport(),
initialState: {
systemPrompt: 'You are a helpful assistant.',
model: getModel('anthropic', 'claude-sonnet-4-20250514'),
@ -28,38 +26,105 @@ const agent = new Agent({
// Subscribe to events for reactive UI updates
agent.subscribe((event) => {
switch (event.type) {
case 'message_start':
console.log(`${event.message.role} message started`);
break;
case 'message_update':
// Stream text to UI
const content = event.message.content;
for (const block of content) {
if (block.type === 'text') console.log(block.text);
// Only emitted for assistant messages during streaming
// event.message is partial - may have incomplete content
for (const block of event.message.content) {
if (block.type === 'text') process.stdout.write(block.text);
}
break;
case 'message_end':
console.log(`${event.message.role} message complete`);
break;
case 'tool_execution_start':
console.log(`Calling ${event.toolName}...`);
break;
case 'tool_execution_update':
// Stream tool output (e.g., bash stdout)
console.log('Progress:', event.partialResult.content);
break;
case 'tool_execution_end':
console.log(`Result:`, event.result.content);
break;
}
});
// Send a prompt
await agent.prompt('Hello, world!');
// Access conversation state
console.log(agent.state.messages);
```
## Core Concepts
## AgentMessage vs LLM Message
### Agent State
The agent internally works with `AgentMessage`, a flexible type that can include:
- Standard LLM messages (`user`, `assistant`, `toolResult`)
- Custom app-specific message types (via declaration merging)
The `Agent` maintains reactive state:
LLMs only understand a subset: `user`, `assistant`, and `toolResult` messages with specific content formats. The `convertToLlm` function bridges this gap.
### Why This Separation?
1. **Rich UI state**: Store UI-specific data (attachments metadata, custom message types) alongside the conversation
2. **Session persistence**: Save the full conversation state including app-specific messages
3. **Context manipulation**: Transform messages before sending to LLM (compaction, injection, filtering)
### The Conversion Flow
```
AgentMessage[] → transformContext() → AgentMessage[] → convertToLlm() → Message[] → LLM
↑ (optional) (required)
|
App state with custom types,
attachments, UI metadata
```
### Constraints
**Messages passed to `prompt()` or queued via `queueMessage()` must convert to LLM messages with `role: "user"` or `role: "toolResult"`.**
When calling `continue()`, the last message in the context must also convert to `user` or `toolResult`. The LLM expects to respond to a user or tool result, not to its own assistant message.
```typescript
// OK: Standard user message
await agent.prompt('Hello');
// OK: Custom type that converts to user message
await agent.prompt({ role: 'hookMessage', content: 'System notification', timestamp: Date.now() });
// But convertToLlm must handle this:
convertToLlm: (messages) => messages.map(m => {
if (m.role === 'hookMessage') {
return { role: 'user', content: m.content, timestamp: m.timestamp };
}
return m;
})
// ERROR: Cannot prompt with assistant message
await agent.prompt({ role: 'assistant', content: [...], ... }); // Will fail at LLM
```
## Agent Options
```typescript
interface AgentOptions {
initialState?: Partial<AgentState>;
// Converts AgentMessage[] to LLM-compatible Message[] before each LLM call.
// Default: filters to user/assistant/toolResult and converts image attachments.
convertToLlm?: (messages: AgentMessage[]) => Message[] | Promise<Message[]>;
// Transform context before convertToLlm (for pruning, compaction, injecting context)
transformContext?: (messages: AgentMessage[], signal?: AbortSignal) => Promise<AgentMessage[]>;
// Queue mode: 'all' sends all queued messages, 'one-at-a-time' sends one per turn
queueMode?: 'all' | 'one-at-a-time';
// Custom stream function (for proxy backends). Default: streamSimple from pi-ai
streamFn?: StreamFn;
// Dynamic API key resolution (useful for expiring OAuth tokens)
getApiKey?: (provider: string) => Promise<string | undefined> | string | undefined;
}
```
## Agent State
```typescript
interface AgentState {
@ -67,17 +132,19 @@ interface AgentState {
model: Model<any>;
thinkingLevel: ThinkingLevel; // 'off' | 'minimal' | 'low' | 'medium' | 'high' | 'xhigh'
tools: AgentTool<any>[];
messages: AppMessage[];
messages: AgentMessage[]; // Full conversation including custom types
isStreaming: boolean;
streamMessage: Message | null;
streamMessage: AgentMessage | null; // Current partial message during streaming
pendingToolCalls: Set<string>;
error?: string;
}
```
### Events
## Events
Events provide fine-grained lifecycle information:
Events provide fine-grained lifecycle information for building reactive UIs.
### Event Types
| Event | Description |
|-------|-------------|
@ -86,33 +153,118 @@ Events provide fine-grained lifecycle information:
| `turn_start` | New turn begins (one LLM response + tool executions) |
| `turn_end` | Turn completes with assistant message and tool results |
| `message_start` | Message begins (user, assistant, or toolResult) |
| `message_update` | Assistant message streaming update |
| `message_update` | **Assistant messages only.** Partial message during streaming |
| `message_end` | Message completes |
| `tool_execution_start` | Tool begins execution |
| `tool_execution_update` | Tool streams progress (e.g., bash output) |
| `tool_execution_update` | Tool streams progress |
| `tool_execution_end` | Tool completes with result |
### Transports
### Message Events for prompt() and queueMessage()
Transports abstract LLM communication:
When you call `prompt(message)`, the agent emits `message_start` and `message_end` events for that message before the assistant responds:
- **`ProviderTransport`**: Direct API calls using `@mariozechner/pi-ai`
- **`AppTransport`**: Proxy through a backend server (for browser apps)
```
prompt(userMessage)
→ agent_start
→ turn_start
→ message_start { message: userMessage }
→ message_end { message: userMessage }
→ message_start { message: assistantMessage } // LLM starts responding
→ message_update { message: partialAssistant } // streaming...
→ message_end { message: assistantMessage }
...
```
Queued messages (via `queueMessage()`) emit the same events when injected:
```
// During tool execution, a message is queued
agent.queueMessage(interruptMessage)
// After tool completes, before next LLM call:
→ message_start { message: interruptMessage }
→ message_end { message: interruptMessage }
→ message_start { message: assistantMessage } // LLM responds to interrupt
...
```
### Handling Partial Messages in Reactive UIs
`message_update` events contain partial assistant messages during streaming. The `event.message` may have:
- Incomplete text (truncated mid-word)
- Partial tool call arguments
- Missing content blocks that haven't started streaming yet
**Pattern for reactive UIs:**
```typescript
// Direct provider access (Node.js)
const agent = new Agent({
transport: new ProviderTransport({
apiKey: process.env.ANTHROPIC_API_KEY
})
});
agent.subscribe((event) => {
switch (event.type) {
case 'message_start':
if (event.message.role === 'assistant') {
// Create placeholder in UI
ui.addMessage({ id: tempId, role: 'assistant', content: [] });
}
break;
// Via proxy (browser)
case 'message_update':
// Replace placeholder content with partial content
// This is only emitted for assistant messages
ui.updateMessage(tempId, event.message.content);
break;
case 'message_end':
if (event.message.role === 'assistant') {
// Finalize with complete message
ui.finalizeMessage(tempId, event.message);
}
break;
}
});
```
**Accessing the current partial message:**
During streaming, `agent.state.streamMessage` contains the current partial message. This is useful for rendering outside the event handler:
```typescript
// In a render loop or reactive binding
if (agent.state.isStreaming && agent.state.streamMessage) {
renderPartialMessage(agent.state.streamMessage);
}
```
## Custom Message Types
Extend `AgentMessage` for app-specific messages via declaration merging:
```typescript
declare module '@mariozechner/pi-agent-core' {
interface CustomAgentMessages {
artifact: { role: 'artifact'; code: string; language: string; timestamp: number };
notification: { role: 'notification'; text: string; timestamp: number };
}
}
// AgentMessage now includes your custom types
const msg: AgentMessage = { role: 'artifact', code: '...', language: 'typescript', timestamp: Date.now() };
```
Custom messages are stored in state but filtered out by the default `convertToLlm`. Provide your own converter to handle them:
```typescript
const agent = new Agent({
transport: new AppTransport({
endpoint: '/api/agent',
headers: { 'Authorization': 'Bearer ...' }
})
convertToLlm: (messages) => {
return messages
.filter(m => m.role !== 'notification') // Filter out UI-only messages
.map(m => {
if (m.role === 'artifact') {
// Convert to user message so LLM sees the artifact
return { role: 'user', content: `[Artifact: ${m.language}]\n${m.code}`, timestamp: m.timestamp };
}
return m;
});
}
});
```
@ -121,45 +273,76 @@ const agent = new Agent({
Queue messages to inject at the next turn:
```typescript
// Queue mode: 'all' or 'one-at-a-time'
agent.setQueueMode('one-at-a-time');
// Queue a message while agent is streaming
await agent.queueMessage({
// Queue while agent is streaming
agent.queueMessage({
role: 'user',
content: 'Additional context...',
content: 'Stop what you are doing and focus on this instead.',
timestamp: Date.now()
});
```
## Attachments
When queued messages are detected after a tool call, remaining tool calls are skipped with error results ("Skipped due to queued user message"). The queued message is then injected before the next assistant response.
User messages can include attachments:
## Images
User messages can include images:
```typescript
await agent.prompt('What is in this image?', [{
id: 'img1',
type: 'image',
fileName: 'photo.jpg',
mimeType: 'image/jpeg',
size: 102400,
content: base64ImageData
}]);
await agent.prompt('What is in this image?', [
{ type: 'image', data: base64ImageData, mimeType: 'image/jpeg' }
]);
```
## Custom Message Types
## Proxy Usage
Extend `AppMessage` for app-specific messages via declaration merging:
For browser apps that need to proxy through a backend, use `streamProxy`:
```typescript
declare module '@mariozechner/pi-agent-core' {
interface CustomMessages {
artifact: { role: 'artifact'; code: string; language: string };
}
import { Agent, streamProxy } from '@mariozechner/pi-agent-core';
const agent = new Agent({
streamFn: (model, context, options) => streamProxy(
'/api/agent',
model,
context,
options,
{ 'Authorization': 'Bearer ...' }
)
});
```
## Low-Level API
For more control, use `agentLoop` and `agentLoopContinue` directly:
```typescript
import { agentLoop, agentLoopContinue, AgentLoopContext, AgentLoopConfig } from '@mariozechner/pi-agent-core';
import { getModel, streamSimple } from '@mariozechner/pi-ai';
const context: AgentLoopContext = {
systemPrompt: 'You are helpful.',
messages: [],
tools: [myTool]
};
const config: AgentLoopConfig = {
model: getModel('openai', 'gpt-4o-mini'),
convertToLlm: (msgs) => msgs.filter(m => ['user', 'assistant', 'toolResult'].includes(m.role))
};
const userMessage = { role: 'user', content: 'Hello', timestamp: Date.now() };
for await (const event of agentLoop(userMessage, context, config, undefined, streamSimple)) {
console.log(event.type);
}
// Now AppMessage includes your custom type
const msg: AppMessage = { role: 'artifact', code: '...', language: 'typescript' };
// Continue from existing context (e.g., after overflow recovery)
// Last message in context must convert to 'user' or 'toolResult'
for await (const event of agentLoopContinue(context, config, undefined, streamSimple)) {
console.log(event.type);
}
```
## API Reference
@ -168,13 +351,14 @@ const msg: AppMessage = { role: 'artifact', code: '...', language: 'typescript'
| Method | Description |
|--------|-------------|
| `prompt(text, attachments?)` | Send a user prompt |
| `continue()` | Continue from current context (for retry after overflow) |
| `prompt(text, images?)` | Send a user prompt with optional images |
| `prompt(message)` | Send an AgentMessage directly (must convert to user/toolResult) |
| `continue()` | Continue from current context (last message must convert to user/toolResult) |
| `abort()` | Abort current operation |
| `waitForIdle()` | Returns promise that resolves when agent is idle |
| `waitForIdle()` | Promise that resolves when agent is idle |
| `reset()` | Clear all messages and state |
| `subscribe(fn)` | Subscribe to events, returns unsubscribe function |
| `queueMessage(msg)` | Queue message for next turn |
| `queueMessage(msg)` | Queue message for next turn (must convert to user/toolResult) |
| `clearMessageQueue()` | Clear queued messages |
### State Mutators
@ -184,7 +368,7 @@ const msg: AppMessage = { role: 'artifact', code: '...', language: 'typescript'
| `setSystemPrompt(v)` | Update system prompt |
| `setModel(m)` | Switch model |
| `setThinkingLevel(l)` | Set reasoning level |
| `setQueueMode(m)` | Set queue mode ('all' or 'one-at-a-time') |
| `setQueueMode(m)` | Set queue mode |
| `setTools(t)` | Update available tools |
| `replaceMessages(ms)` | Replace all messages |
| `appendMessage(m)` | Append a message |

View file

@ -1,33 +1,52 @@
import { streamSimple } from "../stream.js";
import type { AssistantMessage, Context, Message, ToolResultMessage, UserMessage } from "../types.js";
import { EventStream } from "../utils/event-stream.js";
import { validateToolArguments } from "../utils/validation.js";
import type { AgentContext, AgentEvent, AgentLoopConfig, AgentTool, AgentToolResult, QueuedMessage } from "./types.js";
/**
* Agent loop that works with AgentMessage throughout.
* Transforms to Message[] only at the LLM call boundary.
*/
import {
type AssistantMessage,
type Context,
EventStream,
streamSimple,
type ToolResultMessage,
validateToolArguments,
} from "@mariozechner/pi-ai";
import type {
AgentContext,
AgentEvent,
AgentLoopConfig,
AgentMessage,
AgentTool,
AgentToolResult,
StreamFn,
} from "./types.js";
/**
* Start an agent loop with a new user message.
 * Start an agent loop with one or more new prompt messages.
 * The prompts are added to the context and message events are emitted for each of them.
*/
export function agentLoop(
prompt: UserMessage,
prompts: AgentMessage[],
context: AgentContext,
config: AgentLoopConfig,
signal?: AbortSignal,
streamFn?: typeof streamSimple,
): EventStream<AgentEvent, AgentContext["messages"]> {
streamFn?: StreamFn,
): EventStream<AgentEvent, AgentMessage[]> {
const stream = createAgentStream();
(async () => {
const newMessages: AgentContext["messages"] = [prompt];
const newMessages: AgentMessage[] = [...prompts];
const currentContext: AgentContext = {
...context,
messages: [...context.messages, prompt],
messages: [...context.messages, ...prompts],
};
stream.push({ type: "agent_start" });
stream.push({ type: "turn_start" });
stream.push({ type: "message_start", message: prompt });
stream.push({ type: "message_end", message: prompt });
for (const prompt of prompts) {
stream.push({ type: "message_start", message: prompt });
stream.push({ type: "message_end", message: prompt });
}
await runLoop(currentContext, newMessages, config, signal, stream, streamFn);
})();
@ -37,33 +56,34 @@ export function agentLoop(
/**
* Continue an agent loop from the current context without adding a new message.
* Used for retry after overflow - context already has user message or tool results.
* Throws if the last message is not a user message or tool result.
* Used for retries - context already has user message or tool results.
*
* **Important:** The last message in context must convert to a `user` or `toolResult` message
* via `convertToLlm`. If it doesn't, the LLM provider will reject the request.
* This cannot be validated here since `convertToLlm` is only called once per turn.
*/
export function agentLoopContinue(
context: AgentContext,
config: AgentLoopConfig,
signal?: AbortSignal,
streamFn?: typeof streamSimple,
): EventStream<AgentEvent, AgentContext["messages"]> {
// Validate that we can continue from this context
const lastMessage = context.messages[context.messages.length - 1];
if (!lastMessage) {
streamFn?: StreamFn,
): EventStream<AgentEvent, AgentMessage[]> {
if (context.messages.length === 0) {
throw new Error("Cannot continue: no messages in context");
}
if (lastMessage.role !== "user" && lastMessage.role !== "toolResult") {
throw new Error(`Cannot continue from message role: ${lastMessage.role}. Expected 'user' or 'toolResult'.`);
if (context.messages[context.messages.length - 1].role === "assistant") {
throw new Error("Cannot continue from message role: assistant");
}
const stream = createAgentStream();
(async () => {
const newMessages: AgentContext["messages"] = [];
const newMessages: AgentMessage[] = [];
const currentContext: AgentContext = { ...context };
stream.push({ type: "agent_start" });
stream.push({ type: "turn_start" });
// No user message events - we're continuing from existing context
await runLoop(currentContext, newMessages, config, signal, stream, streamFn);
})();
@ -71,28 +91,28 @@ export function agentLoopContinue(
return stream;
}
function createAgentStream(): EventStream<AgentEvent, AgentContext["messages"]> {
return new EventStream<AgentEvent, AgentContext["messages"]>(
function createAgentStream(): EventStream<AgentEvent, AgentMessage[]> {
	// The stream is complete once "agent_end" arrives; that event also carries
	// the final value (the messages generated during the run).
	const isFinalEvent = (event: AgentEvent) => event.type === "agent_end";
	const extractResult = (event: AgentEvent) => (event.type === "agent_end" ? event.messages : []);
	return new EventStream<AgentEvent, AgentMessage[]>(isFinalEvent, extractResult);
}
/**
* Shared loop logic for both agentLoop and agentLoopContinue.
* Main loop logic shared by agentLoop and agentLoopContinue.
*/
async function runLoop(
currentContext: AgentContext,
newMessages: AgentContext["messages"],
newMessages: AgentMessage[],
config: AgentLoopConfig,
signal: AbortSignal | undefined,
stream: EventStream<AgentEvent, AgentContext["messages"]>,
streamFn?: typeof streamSimple,
stream: EventStream<AgentEvent, AgentMessage[]>,
streamFn?: StreamFn,
): Promise<void> {
let hasMoreToolCalls = true;
let firstTurn = true;
let queuedMessages: QueuedMessage<any>[] = (await config.getQueuedMessages?.()) || [];
let queuedAfterTools: QueuedMessage<any>[] | null = null;
let queuedMessages: AgentMessage[] = (await config.getQueuedMessages?.()) || [];
let queuedAfterTools: AgentMessage[] | null = null;
while (hasMoreToolCalls || queuedMessages.length > 0) {
if (!firstTurn) {
@ -101,15 +121,13 @@ async function runLoop(
firstTurn = false;
}
// Process queued messages first (inject before next assistant response)
// Process queued messages (inject before next assistant response)
if (queuedMessages.length > 0) {
for (const { original, llm } of queuedMessages) {
stream.push({ type: "message_start", message: original });
stream.push({ type: "message_end", message: original });
if (llm) {
currentContext.messages.push(llm);
newMessages.push(llm);
}
for (const message of queuedMessages) {
stream.push({ type: "message_start", message });
stream.push({ type: "message_end", message });
currentContext.messages.push(message);
newMessages.push(message);
}
queuedMessages = [];
}
@ -119,7 +137,6 @@ async function runLoop(
newMessages.push(message);
if (message.stopReason === "error" || message.stopReason === "aborted") {
// Stop the loop on error or abort
stream.push({ type: "turn_end", message, toolResults: [] });
stream.push({ type: "agent_end", messages: newMessages });
stream.end(newMessages);
@ -132,7 +149,6 @@ async function runLoop(
const toolResults: ToolResultMessage[] = [];
if (hasMoreToolCalls) {
// Execute tool calls
const toolExecution = await executeToolCalls(
currentContext.tools,
message,
@ -142,10 +158,14 @@ async function runLoop(
);
toolResults.push(...toolExecution.toolResults);
queuedAfterTools = toolExecution.queuedMessages ?? null;
currentContext.messages.push(...toolResults);
newMessages.push(...toolResults);
for (const result of toolResults) {
currentContext.messages.push(result);
newMessages.push(result);
}
}
stream.push({ type: "turn_end", message, toolResults: toolResults });
stream.push({ type: "turn_end", message, toolResults });
// Get queued messages after turn completes
if (queuedAfterTools && queuedAfterTools.length > 0) {
@ -160,41 +180,44 @@ async function runLoop(
stream.end(newMessages);
}
// Helper functions
/**
* Stream an assistant response from the LLM.
* This is where AgentMessage[] gets transformed to Message[] for the LLM.
*/
async function streamAssistantResponse(
context: AgentContext,
config: AgentLoopConfig,
signal: AbortSignal | undefined,
stream: EventStream<AgentEvent, AgentContext["messages"]>,
streamFn?: typeof streamSimple,
stream: EventStream<AgentEvent, AgentMessage[]>,
streamFn?: StreamFn,
): Promise<AssistantMessage> {
// Convert AgentContext to Context for streamSimple
// Use a copy of messages to avoid mutating the original context
const processedMessages = config.preprocessor
? await config.preprocessor(context.messages, signal)
: [...context.messages];
const processedContext: Context = {
// Apply context transform if configured (AgentMessage[] → AgentMessage[])
let messages = context.messages;
if (config.transformContext) {
messages = await config.transformContext(messages, signal);
}
// Convert to LLM-compatible messages (AgentMessage[] → Message[])
const llmMessages = await config.convertToLlm(messages);
// Build LLM context
const llmContext: Context = {
systemPrompt: context.systemPrompt,
messages: [...processedMessages].map((m) => {
if (m.role === "toolResult") {
// biome-ignore lint/correctness/noUnusedVariables: fine here
const { details, ...rest } = m;
return rest;
} else {
return m;
}
}),
tools: context.tools, // AgentTool extends Tool, so this works
messages: llmMessages,
tools: context.tools,
};
// Use custom stream function if provided, otherwise use default streamSimple
const streamFunction = streamFn || streamSimple;
// Resolve API key for every assistant response (important for expiring tokens)
// Resolve API key (important for expiring tokens)
const resolvedApiKey =
(config.getApiKey ? await config.getApiKey(config.model.provider) : undefined) || config.apiKey;
const response = await streamFunction(config.model, processedContext, { ...config, apiKey: resolvedApiKey, signal });
const response = await streamFunction(config.model, llmContext, {
...config,
apiKey: resolvedApiKey,
signal,
});
let partialMessage: AssistantMessage | null = null;
let addedPartial = false;
@ -220,7 +243,11 @@ async function streamAssistantResponse(
if (partialMessage) {
partialMessage = event.partial;
context.messages[context.messages.length - 1] = partialMessage;
stream.push({ type: "message_update", assistantMessageEvent: event, message: { ...partialMessage } });
stream.push({
type: "message_update",
assistantMessageEvent: event,
message: { ...partialMessage },
});
}
break;
@ -244,16 +271,19 @@ async function streamAssistantResponse(
return await response.result();
}
async function executeToolCalls<T>(
tools: AgentTool<any, T>[] | undefined,
/**
* Execute tool calls from an assistant message.
*/
async function executeToolCalls(
tools: AgentTool<any>[] | undefined,
assistantMessage: AssistantMessage,
signal: AbortSignal | undefined,
stream: EventStream<AgentEvent, Message[]>,
stream: EventStream<AgentEvent, AgentMessage[]>,
getQueuedMessages?: AgentLoopConfig["getQueuedMessages"],
): Promise<{ toolResults: ToolResultMessage<T>[]; queuedMessages?: QueuedMessage<any>[] }> {
): Promise<{ toolResults: ToolResultMessage[]; queuedMessages?: AgentMessage[] }> {
const toolCalls = assistantMessage.content.filter((c) => c.type === "toolCall");
const results: ToolResultMessage<any>[] = [];
let queuedMessages: QueuedMessage<any>[] | undefined;
const results: ToolResultMessage[] = [];
let queuedMessages: AgentMessage[] | undefined;
for (let index = 0; index < toolCalls.length; index++) {
const toolCall = toolCalls[index];
@ -266,16 +296,14 @@ async function executeToolCalls<T>(
args: toolCall.arguments,
});
let result: AgentToolResult<T>;
let result: AgentToolResult<any>;
let isError = false;
try {
if (!tool) throw new Error(`Tool ${toolCall.name} not found`);
// Validate arguments using shared validation function
const validatedArgs = validateToolArguments(tool, toolCall);
// Execute with validated, typed arguments, passing update callback
result = await tool.execute(toolCall.id, validatedArgs, signal, (partialResult) => {
stream.push({
type: "tool_execution_update",
@ -288,7 +316,7 @@ async function executeToolCalls<T>(
} catch (e) {
result = {
content: [{ type: "text", text: e instanceof Error ? e.message : String(e) }],
details: {} as T,
details: {},
};
isError = true;
}
@ -301,7 +329,7 @@ async function executeToolCalls<T>(
isError,
});
const toolResultMessage: ToolResultMessage<T> = {
const toolResultMessage: ToolResultMessage = {
role: "toolResult",
toolCallId: toolCall.id,
toolName: toolCall.name,
@ -315,6 +343,7 @@ async function executeToolCalls<T>(
stream.push({ type: "message_start", message: toolResultMessage });
stream.push({ type: "message_end", message: toolResultMessage });
// Check for queued messages - skip remaining tools if user interrupted
if (getQueuedMessages) {
const queued = await getQueuedMessages();
if (queued.length > 0) {
@ -331,13 +360,13 @@ async function executeToolCalls<T>(
return { toolResults: results, queuedMessages };
}
function skipToolCall<T>(
function skipToolCall(
toolCall: Extract<AssistantMessage["content"][number], { type: "toolCall" }>,
stream: EventStream<AgentEvent, Message[]>,
): ToolResultMessage<T> {
const result: AgentToolResult<T> = {
stream: EventStream<AgentEvent, AgentMessage[]>,
): ToolResultMessage {
const result: AgentToolResult<any> = {
content: [{ type: "text", text: "Skipped due to queued user message." }],
details: {} as T,
details: {},
};
stream.push({
@ -354,12 +383,12 @@ function skipToolCall<T>(
isError: true,
});
const toolResultMessage: ToolResultMessage<T> = {
const toolResultMessage: ToolResultMessage = {
role: "toolResult",
toolCallId: toolCall.id,
toolName: toolCall.name,
content: result.content,
details: result.details,
details: {},
isError: true,
timestamp: Date.now(),
};

View file

@ -1,62 +1,66 @@
import type { ImageContent, Message, QueuedMessage, ReasoningEffort, TextContent } from "@mariozechner/pi-ai";
import { getModel } from "@mariozechner/pi-ai";
import type { AgentTransport } from "./transports/types.js";
import type { AgentEvent, AgentState, AppMessage, Attachment, ThinkingLevel } from "./types.js";
/**
* Agent class that uses the agent-loop directly.
* No transport abstraction - calls streamSimple via the loop.
*/
import {
getModel,
type ImageContent,
type Message,
type Model,
type ReasoningEffort,
streamSimple,
type TextContent,
} from "@mariozechner/pi-ai";
import { agentLoop, agentLoopContinue } from "./agent-loop.js";
import type {
AgentContext,
AgentEvent,
AgentLoopConfig,
AgentMessage,
AgentState,
AgentTool,
StreamFn,
ThinkingLevel,
} from "./types.js";
/**
* Default message transformer: Keep only LLM-compatible messages, strip app-specific fields.
* Converts attachments to proper content blocks (images ImageContent, documents TextContent).
 * Default convertToLlm: Keep only LLM-compatible messages (user/assistant/toolResult).
*/
function defaultMessageTransformer(messages: AppMessage[]): Message[] {
return messages
.filter((m) => {
// Only keep standard LLM message roles
return m.role === "user" || m.role === "assistant" || m.role === "toolResult";
})
.map((m) => {
if (m.role === "user") {
const { attachments, ...rest } = m as any;
// If no attachments, return as-is
if (!attachments || attachments.length === 0) {
return rest as Message;
}
// Convert attachments to content blocks
const content = Array.isArray(rest.content) ? [...rest.content] : [{ type: "text", text: rest.content }];
for (const attachment of attachments as Attachment[]) {
// Add image blocks for image attachments
if (attachment.type === "image") {
content.push({
type: "image",
data: attachment.content,
mimeType: attachment.mimeType,
} as ImageContent);
}
// Add text blocks for documents with extracted text
else if (attachment.type === "document" && attachment.extractedText) {
content.push({
type: "text",
text: `\n\n[Document: ${attachment.fileName}]\n${attachment.extractedText}`,
isDocument: true,
} as TextContent);
}
}
return { ...rest, content } as Message;
}
return m as Message;
});
/**
 * Default convertToLlm: keep only messages the LLM protocol understands
 * (user / assistant / toolResult). Custom app-specific roles are dropped;
 * supply your own converter if they should be translated instead.
 */
function defaultConvertToLlm(messages: AgentMessage[]): Message[] {
	const isLlmMessage = (m: AgentMessage) => m.role === "user" || m.role === "assistant" || m.role === "toolResult";
	return messages.filter(isLlmMessage);
}
export interface AgentOptions {
initialState?: Partial<AgentState>;
transport: AgentTransport;
// Transform app messages to LLM-compatible messages before sending to transport
messageTransformer?: (messages: AppMessage[]) => Message[] | Promise<Message[]>;
// Queue mode: "all" = send all queued messages at once, "one-at-a-time" = send one queued message per turn
/**
* Converts AgentMessage[] to LLM-compatible Message[] before each LLM call.
	 * Default filters to user/assistant/toolResult; custom roles are dropped.
*/
convertToLlm?: (messages: AgentMessage[]) => Message[] | Promise<Message[]>;
/**
* Optional transform applied to context before convertToLlm.
* Use for context pruning, injecting external context, etc.
*/
transformContext?: (messages: AgentMessage[], signal?: AbortSignal) => Promise<AgentMessage[]>;
/**
* Queue mode: "all" = send all queued messages at once, "one-at-a-time" = one per turn
*/
queueMode?: "all" | "one-at-a-time";
/**
* Custom stream function (for proxy backends, etc.). Default uses streamSimple.
*/
streamFn?: StreamFn;
/**
* Resolves an API key dynamically for each LLM call.
* Useful for expiring tokens (e.g., GitHub Copilot OAuth).
*/
getApiKey?: (provider: string) => Promise<string | undefined> | string | undefined;
}
export class Agent {
@ -71,20 +75,25 @@ export class Agent {
pendingToolCalls: new Set<string>(),
error: undefined,
};
private listeners = new Set<(e: AgentEvent) => void>();
private abortController?: AbortController;
private transport: AgentTransport;
private messageTransformer: (messages: AppMessage[]) => Message[] | Promise<Message[]>;
private messageQueue: Array<QueuedMessage<AppMessage>> = [];
private convertToLlm: (messages: AgentMessage[]) => Message[] | Promise<Message[]>;
private transformContext?: (messages: AgentMessage[], signal?: AbortSignal) => Promise<AgentMessage[]>;
private messageQueue: AgentMessage[] = [];
private queueMode: "all" | "one-at-a-time";
public streamFn: StreamFn;
public getApiKey?: (provider: string) => Promise<string | undefined> | string | undefined;
private runningPrompt?: Promise<void>;
private resolveRunningPrompt?: () => void;
constructor(opts: AgentOptions) {
constructor(opts: AgentOptions = {}) {
this._state = { ...this._state, ...opts.initialState };
this.transport = opts.transport;
this.messageTransformer = opts.messageTransformer || defaultMessageTransformer;
this.convertToLlm = opts.convertToLlm || defaultConvertToLlm;
this.transformContext = opts.transformContext;
this.queueMode = opts.queueMode || "one-at-a-time";
this.streamFn = opts.streamFn || streamSimple;
this.getApiKey = opts.getApiKey;
}
get state(): AgentState {
@ -96,12 +105,12 @@ export class Agent {
return () => this.listeners.delete(fn);
}
// State mutators - update internal state without emitting events
// State mutators
	/** Replace the system prompt used for subsequent LLM calls. */
	setSystemPrompt(v: string) {
		this._state.systemPrompt = v;
	}
setModel(m: typeof this._state.model) {
	/** Switch the model used for subsequent LLM calls. */
	setModel(m: Model<any>) {
		this._state.model = m;
	}
@ -117,25 +126,20 @@ export class Agent {
return this.queueMode;
}
setTools(t: typeof this._state.tools) {
	/** Replace the set of tools available to the agent. */
	setTools(t: AgentTool<any>[]) {
		this._state.tools = t;
	}
replaceMessages(ms: AppMessage[]) {
	/** Replace the entire conversation history (a shallow copy is stored). */
	replaceMessages(ms: AgentMessage[]) {
		this._state.messages = ms.slice();
	}
appendMessage(m: AppMessage) {
	/** Append a single message to the conversation history. */
	appendMessage(m: AgentMessage) {
		this._state.messages = [...this._state.messages, m];
	}
async queueMessage(m: AppMessage) {
// Transform message and queue it for injection at next turn
const transformed = await this.messageTransformer([m]);
this.messageQueue.push({
original: m,
llm: transformed[0], // undefined if filtered out
});
	/**
	 * Queue a message for injection before the next assistant response.
	 * The queue is drained via getQueuedMessages according to the configured
	 * queue mode. NOTE(review): the message must convert to a user/toolResult
	 * message via convertToLlm, or the LLM provider will reject the request.
	 */
	queueMessage(m: AgentMessage) {
		this.messageQueue.push(m);
	}
clearMessageQueue() {
@ -150,17 +154,10 @@ export class Agent {
this.abortController?.abort();
}
/**
* Returns a promise that resolves when the current prompt completes.
* Returns immediately resolved promise if no prompt is running.
*/
	/**
	 * Returns a promise that resolves when the current prompt completes.
	 * Resolves immediately if no prompt is running.
	 */
	waitForIdle(): Promise<void> {
		return this.runningPrompt ?? Promise.resolve();
	}
/**
* Clear all messages and state. Call abort() first if a prompt is in flight.
*/
reset() {
this._state.messages = [];
this._state.isStreaming = false;
@ -170,86 +167,57 @@ export class Agent {
this.messageQueue = [];
}
async prompt(input: string, attachments?: Attachment[]) {
/** Send a prompt with an AgentMessage */
async prompt(message: AgentMessage | AgentMessage[]): Promise<void>;
async prompt(input: string, images?: ImageContent[]): Promise<void>;
async prompt(input: string | AgentMessage | AgentMessage[], images?: ImageContent[]) {
const model = this._state.model;
if (!model) {
throw new Error("No model configured");
}
if (!model) throw new Error("No model configured");
// Build user message with attachments
const content: Array<TextContent | ImageContent> = [{ type: "text", text: input }];
if (attachments?.length) {
for (const a of attachments) {
if (a.type === "image") {
content.push({ type: "image", data: a.content, mimeType: a.mimeType });
} else if (a.type === "document" && a.extractedText) {
content.push({
type: "text",
text: `\n\n[Document: ${a.fileName}]\n${a.extractedText}`,
isDocument: true,
} as TextContent);
}
let msgs: AgentMessage[];
if (Array.isArray(input)) {
msgs = input;
} else if (typeof input === "string") {
const content: Array<TextContent | ImageContent> = [{ type: "text", text: input }];
if (images && images.length > 0) {
content.push(...images);
}
msgs = [
{
role: "user",
content,
timestamp: Date.now(),
},
];
} else {
msgs = [input];
}
const userMessage: AppMessage = {
role: "user",
content,
attachments: attachments?.length ? attachments : undefined,
timestamp: Date.now(),
};
await this._runAgentLoop(userMessage);
await this._runLoop(msgs);
}
/**
* Continue from the current context without adding a new user message.
* Used for retry after overflow recovery when context already has user message or tool results.
*/
/** Continue from current context (for retry after overflow) */
async continue() {
const messages = this._state.messages;
if (messages.length === 0) {
throw new Error("No messages to continue from");
}
const lastMessage = messages[messages.length - 1];
if (lastMessage.role !== "user" && lastMessage.role !== "toolResult") {
throw new Error(`Cannot continue from message role: ${lastMessage.role}`);
if (messages[messages.length - 1].role === "assistant") {
throw new Error("Cannot continue from message role: assistant");
}
await this._runAgentLoopContinue();
await this._runLoop(undefined);
}
/**
* Internal: Run the agent loop with a new user message.
* Run the agent loop.
* If messages are provided, starts a new conversation turn with those messages.
* Otherwise, continues from existing context.
*/
private async _runAgentLoop(userMessage: AppMessage) {
const { llmMessages, cfg } = await this._prepareRun();
const events = this.transport.run(llmMessages, userMessage as Message, cfg, this.abortController!.signal);
await this._processEvents(events);
}
/**
* Internal: Continue the agent loop from current context.
*/
private async _runAgentLoopContinue() {
const { llmMessages, cfg } = await this._prepareRun();
const events = this.transport.continue(llmMessages, cfg, this.abortController!.signal);
await this._processEvents(events);
}
/**
* Prepare for running the agent loop.
*/
private async _prepareRun() {
private async _runLoop(messages?: AgentMessage[]) {
const model = this._state.model;
if (!model) {
throw new Error("No model configured");
}
if (!model) throw new Error("No model configured");
this.runningPrompt = new Promise<void>((resolve) => {
this.resolveRunningPrompt = resolve;
@ -265,87 +233,90 @@ export class Agent {
? undefined
: this._state.thinkingLevel === "minimal"
? "low"
: this._state.thinkingLevel;
: (this._state.thinkingLevel as ReasoningEffort);
const cfg = {
const context: AgentContext = {
systemPrompt: this._state.systemPrompt,
messages: this._state.messages.slice(),
tools: this._state.tools,
};
const config: AgentLoopConfig = {
model,
reasoning,
getQueuedMessages: async <T>() => {
convertToLlm: this.convertToLlm,
transformContext: this.transformContext,
getApiKey: this.getApiKey,
getQueuedMessages: async () => {
if (this.queueMode === "one-at-a-time") {
if (this.messageQueue.length > 0) {
const first = this.messageQueue[0];
this.messageQueue = this.messageQueue.slice(1);
return [first] as QueuedMessage<T>[];
return [first];
}
return [];
} else {
const queued = this.messageQueue.slice();
this.messageQueue = [];
return queued as QueuedMessage<T>[];
return queued;
}
},
};
const llmMessages = await this.messageTransformer(this._state.messages);
return { llmMessages, cfg, model };
}
/**
* Process events from the transport.
*/
private async _processEvents(events: AsyncIterable<AgentEvent>) {
const model = this._state.model!;
const generatedMessages: AppMessage[] = [];
let partial: AppMessage | null = null;
let partial: AgentMessage | null = null;
try {
for await (const ev of events) {
switch (ev.type) {
case "message_start": {
partial = ev.message as AppMessage;
this._state.streamMessage = ev.message as Message;
const stream = messages
? agentLoop(messages, context, config, this.abortController.signal, this.streamFn)
: agentLoopContinue(context, config, this.abortController.signal, this.streamFn);
for await (const event of stream) {
// Update internal state based on events
switch (event.type) {
case "message_start":
partial = event.message;
this._state.streamMessage = event.message;
break;
}
case "message_update": {
partial = ev.message as AppMessage;
this._state.streamMessage = ev.message as Message;
case "message_update":
partial = event.message;
this._state.streamMessage = event.message;
break;
}
case "message_end": {
case "message_end":
partial = null;
this._state.streamMessage = null;
this.appendMessage(ev.message as AppMessage);
generatedMessages.push(ev.message as AppMessage);
this.appendMessage(event.message);
break;
}
case "tool_execution_start": {
const s = new Set(this._state.pendingToolCalls);
s.add(ev.toolCallId);
s.add(event.toolCallId);
this._state.pendingToolCalls = s;
break;
}
case "tool_execution_end": {
const s = new Set(this._state.pendingToolCalls);
s.delete(ev.toolCallId);
s.delete(event.toolCallId);
this._state.pendingToolCalls = s;
break;
}
case "turn_end": {
if (ev.message.role === "assistant" && ev.message.errorMessage) {
this._state.error = ev.message.errorMessage;
case "turn_end":
if (event.message.role === "assistant" && (event.message as any).errorMessage) {
this._state.error = (event.message as any).errorMessage;
}
break;
}
case "agent_end": {
case "agent_end":
this._state.isStreaming = false;
this._state.streamMessage = null;
break;
}
}
this.emit(ev as AgentEvent);
// Emit to listeners
this.emit(event);
}
// Handle any remaining partial message
@ -357,8 +328,7 @@ export class Agent {
(c.type === "toolCall" && c.name.trim().length > 0),
);
if (!onlyEmpty) {
this.appendMessage(partial as AppMessage);
generatedMessages.push(partial as AppMessage);
this.appendMessage(partial);
} else {
if (this.abortController?.signal.aborted) {
throw new Error("Request was aborted");
@ -366,7 +336,7 @@ export class Agent {
}
}
} catch (err: any) {
const msg: Message = {
const errorMsg: AgentMessage = {
role: "assistant",
content: [{ type: "text", text: "" }],
api: model.api,
@ -383,10 +353,11 @@ export class Agent {
stopReason: this.abortController?.signal.aborted ? "aborted" : "error",
errorMessage: err?.message || String(err),
timestamp: Date.now(),
};
this.appendMessage(msg as AppMessage);
generatedMessages.push(msg as AppMessage);
} as AgentMessage;
this.appendMessage(errorMsg);
this._state.error = err?.message || String(err);
this.emit({ type: "agent_end", messages: [errorMsg] });
} finally {
this._state.isStreaming = false;
this._state.streamMessage = null;

View file

@ -1,22 +1,6 @@
// Core Agent
export { Agent, type AgentOptions } from "./agent.js";
// Transports
export {
type AgentRunConfig,
type AgentTransport,
AppTransport,
type AppTransportOptions,
ProviderTransport,
type ProviderTransportOptions,
type ProxyAssistantMessageEvent,
} from "./transports/index.js";
export * from "./agent.js";
// Loop functions
export * from "./agent-loop.js";
// Types
export type {
AgentEvent,
AgentState,
AppMessage,
Attachment,
CustomMessages,
ThinkingLevel,
UserMessageWithAttachments,
} from "./types.js";
export * from "./types.js";

340
packages/agent/src/proxy.ts Normal file
View file

@ -0,0 +1,340 @@
/**
* Proxy stream function for apps that route LLM calls through a server.
* The server manages auth and proxies requests to LLM providers.
*/
import {
type AssistantMessage,
type AssistantMessageEvent,
type Context,
EventStream,
type Model,
type SimpleStreamOptions,
type StopReason,
type ToolCall,
} from "@mariozechner/pi-ai";
// Internal import for JSON parsing utility
import { parseStreamingJson } from "@mariozechner/pi-ai/dist/utils/json-parse.js";
/**
 * Event stream of assistant-message events reconstructed from proxy events.
 * Terminates on the first "done" or "error" event and resolves to the final
 * AssistantMessage carried by that terminal event.
 */
class ProxyMessageEventStream extends EventStream<AssistantMessageEvent, AssistantMessage> {
	constructor() {
		super(
			// Completion predicate: the stream ends on a terminal event.
			(event) => event.type === "done" || event.type === "error",
			// Result extractor: pull the final message out of the terminal event.
			(event) => {
				if (event.type === "done") return event.message;
				if (event.type === "error") return event.error;
				throw new Error("Unexpected event type");
			},
		);
	}
}
/**
 * Proxy event types - server sends these with partial field stripped to reduce bandwidth.
 * `contentIndex` addresses a block in the reconstructed message's `content` array.
 */
export type ProxyAssistantMessageEvent =
	| { type: "start" }
	// Text content block lifecycle
	| { type: "text_start"; contentIndex: number }
	| { type: "text_delta"; contentIndex: number; delta: string }
	| { type: "text_end"; contentIndex: number; contentSignature?: string }
	// Thinking (reasoning) content block lifecycle
	| { type: "thinking_start"; contentIndex: number }
	| { type: "thinking_delta"; contentIndex: number; delta: string }
	| { type: "thinking_end"; contentIndex: number; contentSignature?: string }
	// Tool call lifecycle; deltas carry raw JSON argument fragments
	| { type: "toolcall_start"; contentIndex: number; id: string; toolName: string }
	| { type: "toolcall_delta"; contentIndex: number; delta: string }
	| { type: "toolcall_end"; contentIndex: number }
	// Terminal events: exactly one of "done" or "error" ends the stream
	| {
			type: "done";
			reason: Extract<StopReason, "stop" | "length" | "toolUse">;
			usage: AssistantMessage["usage"];
	  }
	| {
			type: "error";
			reason: Extract<StopReason, "aborted" | "error">;
			errorMessage?: string;
			usage: AssistantMessage["usage"];
	  };
export interface ProxyStreamOptions extends SimpleStreamOptions {
	/** Auth token for the proxy server (sent as a Bearer token) */
	authToken: string;
	/** Proxy server URL (e.g., "https://genai.example.com") */
	proxyUrl: string;
}
/**
 * Stream function that proxies through a server instead of calling LLM providers directly.
 * The server strips the partial field from delta events to reduce bandwidth.
 * We reconstruct the partial message client-side.
 *
 * Use this as the `streamFn` option when creating an Agent that needs to go through a proxy.
 *
 * @param model Model descriptor forwarded verbatim to the proxy.
 * @param context Conversation context forwarded verbatim to the proxy.
 * @param options Stream options plus proxy URL and auth token.
 * @returns An event stream that resolves to the final AssistantMessage.
 *
 * @example
 * ```typescript
 * const agent = new Agent({
 *   streamFn: (model, context, options) =>
 *     streamProxy(model, context, {
 *       ...options,
 *       authToken: await getAuthToken(),
 *       proxyUrl: "https://genai.example.com",
 *     }),
 * });
 * ```
 */
export function streamProxy(model: Model<any>, context: Context, options: ProxyStreamOptions): ProxyMessageEventStream {
	const stream = new ProxyMessageEventStream();
	(async () => {
		// Initialize the partial message that we'll build up from events.
		const partial: AssistantMessage = {
			role: "assistant",
			stopReason: "stop",
			content: [],
			api: model.api,
			provider: model.provider,
			model: model.id,
			usage: {
				input: 0,
				output: 0,
				cacheRead: 0,
				cacheWrite: 0,
				totalTokens: 0,
				cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
			},
			timestamp: Date.now(),
		};
		let reader: ReadableStreamDefaultReader<Uint8Array> | undefined;
		// Cancel the body reader when the caller aborts, so the fetch unwinds promptly.
		const abortHandler = () => {
			if (reader) {
				reader.cancel("Request aborted by user").catch(() => {});
			}
		};
		if (options.signal) {
			options.signal.addEventListener("abort", abortHandler);
		}
		// Handle a single SSE line: parse "data: ..." payloads into proxy events
		// and push the reconstructed assistant-message event onto the stream.
		const handleLine = (line: string) => {
			if (!line.startsWith("data: ")) return;
			const data = line.slice(6).trim();
			if (!data) return;
			const proxyEvent = JSON.parse(data) as ProxyAssistantMessageEvent;
			const event = processProxyEvent(proxyEvent, partial);
			if (event) {
				stream.push(event);
			}
		};
		try {
			const response = await fetch(`${options.proxyUrl}/api/stream`, {
				method: "POST",
				headers: {
					Authorization: `Bearer ${options.authToken}`,
					"Content-Type": "application/json",
				},
				body: JSON.stringify({
					model,
					context,
					options: {
						temperature: options.temperature,
						maxTokens: options.maxTokens,
						reasoning: options.reasoning,
					},
				}),
				signal: options.signal,
			});
			if (!response.ok) {
				let errorMessage = `Proxy error: ${response.status} ${response.statusText}`;
				try {
					const errorData = (await response.json()) as { error?: string };
					if (errorData.error) {
						errorMessage = `Proxy error: ${errorData.error}`;
					}
				} catch {
					// Couldn't parse error response
				}
				throw new Error(errorMessage);
			}
			if (!response.body) {
				throw new Error("Proxy response has no body");
			}
			reader = response.body.getReader();
			const decoder = new TextDecoder();
			let buffer = "";
			while (true) {
				const { done, value } = await reader.read();
				if (done) break;
				if (options.signal?.aborted) {
					throw new Error("Request aborted by user");
				}
				buffer += decoder.decode(value, { stream: true });
				const lines = buffer.split("\n");
				// Keep the (possibly incomplete) final line for the next chunk.
				buffer = lines.pop() || "";
				for (const line of lines) {
					handleLine(line);
				}
			}
			// Flush any bytes the decoder still holds and process a trailing line
			// that was not newline-terminated (previously this data was dropped).
			buffer += decoder.decode();
			if (buffer) {
				handleLine(buffer);
			}
			if (options.signal?.aborted) {
				throw new Error("Request aborted by user");
			}
			stream.end();
		} catch (error) {
			// Surface any failure (network, parse, abort) as a terminal error event.
			const errorMessage = error instanceof Error ? error.message : String(error);
			const reason = options.signal?.aborted ? "aborted" : "error";
			partial.stopReason = reason;
			partial.errorMessage = errorMessage;
			stream.push({
				type: "error",
				reason,
				error: partial,
			});
			stream.end();
		} finally {
			if (options.signal) {
				options.signal.removeEventListener("abort", abortHandler);
			}
		}
	})();
	return stream;
}
/**
 * Apply a single proxy event to the partially reconstructed assistant message
 * and translate it into a full AssistantMessageEvent (which carries `partial`).
 *
 * Mutates `partial` in place: content blocks are created by `*_start` events,
 * extended by `*_delta` events, finalized by `*_end` events, and the terminal
 * "done"/"error" events set `stopReason`/`usage`.
 *
 * @param proxyEvent Bandwidth-reduced event received from the proxy server.
 * @param partial Accumulated assistant message; updated in place.
 * @returns The reconstructed event, or undefined for unrecognized event types.
 * @throws Error when an event addresses a content block of the wrong type,
 *         which indicates a corrupt or out-of-order stream.
 */
function processProxyEvent(
	proxyEvent: ProxyAssistantMessageEvent,
	partial: AssistantMessage,
): AssistantMessageEvent | undefined {
	switch (proxyEvent.type) {
		case "start":
			return { type: "start", partial };
		case "text_start":
			partial.content[proxyEvent.contentIndex] = { type: "text", text: "" };
			return { type: "text_start", contentIndex: proxyEvent.contentIndex, partial };
		case "text_delta": {
			const content = partial.content[proxyEvent.contentIndex];
			if (content?.type === "text") {
				content.text += proxyEvent.delta;
				return {
					type: "text_delta",
					contentIndex: proxyEvent.contentIndex,
					delta: proxyEvent.delta,
					partial,
				};
			}
			throw new Error("Received text_delta for non-text content");
		}
		case "text_end": {
			const content = partial.content[proxyEvent.contentIndex];
			if (content?.type === "text") {
				content.textSignature = proxyEvent.contentSignature;
				return {
					type: "text_end",
					contentIndex: proxyEvent.contentIndex,
					content: content.text,
					partial,
				};
			}
			throw new Error("Received text_end for non-text content");
		}
		case "thinking_start":
			partial.content[proxyEvent.contentIndex] = { type: "thinking", thinking: "" };
			return { type: "thinking_start", contentIndex: proxyEvent.contentIndex, partial };
		case "thinking_delta": {
			const content = partial.content[proxyEvent.contentIndex];
			if (content?.type === "thinking") {
				content.thinking += proxyEvent.delta;
				return {
					type: "thinking_delta",
					contentIndex: proxyEvent.contentIndex,
					delta: proxyEvent.delta,
					partial,
				};
			}
			throw new Error("Received thinking_delta for non-thinking content");
		}
		case "thinking_end": {
			const content = partial.content[proxyEvent.contentIndex];
			if (content?.type === "thinking") {
				content.thinkingSignature = proxyEvent.contentSignature;
				return {
					type: "thinking_end",
					contentIndex: proxyEvent.contentIndex,
					content: content.thinking,
					partial,
				};
			}
			throw new Error("Received thinking_end for non-thinking content");
		}
		case "toolcall_start":
			// partialJson accumulates the raw argument fragments until toolcall_end.
			partial.content[proxyEvent.contentIndex] = {
				type: "toolCall",
				id: proxyEvent.id,
				name: proxyEvent.toolName,
				arguments: {},
				partialJson: "",
			} satisfies ToolCall & { partialJson: string } as ToolCall;
			return { type: "toolcall_start", contentIndex: proxyEvent.contentIndex, partial };
		case "toolcall_delta": {
			const content = partial.content[proxyEvent.contentIndex];
			if (content?.type === "toolCall") {
				(content as any).partialJson += proxyEvent.delta;
				content.arguments = parseStreamingJson((content as any).partialJson) || {};
				partial.content[proxyEvent.contentIndex] = { ...content }; // Trigger reactivity
				return {
					type: "toolcall_delta",
					contentIndex: proxyEvent.contentIndex,
					delta: proxyEvent.delta,
					partial,
				};
			}
			throw new Error("Received toolcall_delta for non-toolCall content");
		}
		case "toolcall_end": {
			const content = partial.content[proxyEvent.contentIndex];
			if (content?.type === "toolCall") {
				// Streaming is finished; drop the internal JSON accumulator.
				delete (content as any).partialJson;
				return {
					type: "toolcall_end",
					contentIndex: proxyEvent.contentIndex,
					toolCall: content,
					partial,
				};
			}
			// Consistent with the sibling branches (and the behavior of the
			// transport this replaced): a mismatched block means the stream is
			// corrupt, so fail loudly instead of silently dropping the event.
			throw new Error("Received toolcall_end for non-toolCall content");
		}
		case "done":
			partial.stopReason = proxyEvent.reason;
			partial.usage = proxyEvent.usage;
			return { type: "done", reason: proxyEvent.reason, message: partial };
		case "error":
			partial.stopReason = proxyEvent.reason;
			partial.errorMessage = proxyEvent.errorMessage;
			partial.usage = proxyEvent.usage;
			return { type: "error", reason: proxyEvent.reason, error: partial };
		default: {
			const _exhaustiveCheck: never = proxyEvent;
			console.warn(`Unhandled proxy event type: ${(proxyEvent as any).type}`);
			return undefined;
		}
	}
}

View file

@ -1,397 +0,0 @@
import type {
AgentContext,
AgentLoopConfig,
Api,
AssistantMessage,
AssistantMessageEvent,
Context,
Message,
Model,
SimpleStreamOptions,
ToolCall,
UserMessage,
} from "@mariozechner/pi-ai";
import { agentLoop, agentLoopContinue } from "@mariozechner/pi-ai";
import { AssistantMessageEventStream } from "@mariozechner/pi-ai/dist/utils/event-stream.js";
import { parseStreamingJson } from "@mariozechner/pi-ai/dist/utils/json-parse.js";
import type { ProxyAssistantMessageEvent } from "./proxy-types.js";
import type { AgentRunConfig, AgentTransport } from "./types.js";
/**
 * Stream function that proxies through a server instead of calling providers directly.
 * The server strips the partial field from delta events to reduce bandwidth.
 * We reconstruct the partial message client-side.
 */
function streamSimpleProxy(
	model: Model<any>,
	context: Context,
	options: SimpleStreamOptions & { authToken: string },
	proxyUrl: string,
): AssistantMessageEventStream {
	const stream = new AssistantMessageEventStream();
	(async () => {
		// Initialize the partial message that we'll build up from events
		const partial: AssistantMessage = {
			role: "assistant",
			stopReason: "stop",
			content: [],
			api: model.api,
			provider: model.provider,
			model: model.id,
			usage: {
				input: 0,
				output: 0,
				cacheRead: 0,
				cacheWrite: 0,
				totalTokens: 0,
				cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
			},
			timestamp: Date.now(),
		};
		let reader: ReadableStreamDefaultReader<Uint8Array> | undefined;
		// Set up abort handler to cancel the reader
		const abortHandler = () => {
			if (reader) {
				reader.cancel("Request aborted by user").catch(() => {});
			}
		};
		if (options.signal) {
			options.signal.addEventListener("abort", abortHandler);
		}
		try {
			const response = await fetch(`${proxyUrl}/api/stream`, {
				method: "POST",
				headers: {
					Authorization: `Bearer ${options.authToken}`,
					"Content-Type": "application/json",
				},
				body: JSON.stringify({
					model,
					context,
					options: {
						temperature: options.temperature,
						maxTokens: options.maxTokens,
						reasoning: options.reasoning,
						// Don't send apiKey or signal - those are added server-side
					},
				}),
				signal: options.signal,
			});
			if (!response.ok) {
				let errorMessage = `Proxy error: ${response.status} ${response.statusText}`;
				try {
					const errorData = (await response.json()) as { error?: string };
					if (errorData.error) {
						errorMessage = `Proxy error: ${errorData.error}`;
					}
				} catch {
					// Couldn't parse error response, use default message
				}
				throw new Error(errorMessage);
			}
			// Parse SSE stream
			reader = response.body!.getReader();
			const decoder = new TextDecoder();
			let buffer = "";
			while (true) {
				const { done, value } = await reader.read();
				if (done) break;
				// Check if aborted after reading
				if (options.signal?.aborted) {
					throw new Error("Request aborted by user");
				}
				buffer += decoder.decode(value, { stream: true });
				const lines = buffer.split("\n");
				// Keep the (possibly incomplete) final line for the next chunk.
				// NOTE(review): a final "data:" line without a trailing newline stays
				// in buffer and is never processed — confirm the server always
				// newline-terminates its SSE events.
				buffer = lines.pop() || "";
				for (const line of lines) {
					if (line.startsWith("data: ")) {
						const data = line.slice(6).trim();
						if (data) {
							const proxyEvent = JSON.parse(data) as ProxyAssistantMessageEvent;
							let event: AssistantMessageEvent | undefined;
							// Handle different event types
							// Server sends events with partial for non-delta events,
							// and without partial for delta events
							switch (proxyEvent.type) {
								case "start":
									event = { type: "start", partial };
									break;
								case "text_start":
									partial.content[proxyEvent.contentIndex] = {
										type: "text",
										text: "",
									};
									event = { type: "text_start", contentIndex: proxyEvent.contentIndex, partial };
									break;
								case "text_delta": {
									const content = partial.content[proxyEvent.contentIndex];
									if (content?.type === "text") {
										content.text += proxyEvent.delta;
										event = {
											type: "text_delta",
											contentIndex: proxyEvent.contentIndex,
											delta: proxyEvent.delta,
											partial,
										};
									} else {
										throw new Error("Received text_delta for non-text content");
									}
									break;
								}
								case "text_end": {
									const content = partial.content[proxyEvent.contentIndex];
									if (content?.type === "text") {
										content.textSignature = proxyEvent.contentSignature;
										event = {
											type: "text_end",
											contentIndex: proxyEvent.contentIndex,
											content: content.text,
											partial,
										};
									} else {
										throw new Error("Received text_end for non-text content");
									}
									break;
								}
								case "thinking_start":
									partial.content[proxyEvent.contentIndex] = {
										type: "thinking",
										thinking: "",
									};
									event = { type: "thinking_start", contentIndex: proxyEvent.contentIndex, partial };
									break;
								case "thinking_delta": {
									const content = partial.content[proxyEvent.contentIndex];
									if (content?.type === "thinking") {
										content.thinking += proxyEvent.delta;
										event = {
											type: "thinking_delta",
											contentIndex: proxyEvent.contentIndex,
											delta: proxyEvent.delta,
											partial,
										};
									} else {
										throw new Error("Received thinking_delta for non-thinking content");
									}
									break;
								}
								case "thinking_end": {
									const content = partial.content[proxyEvent.contentIndex];
									if (content?.type === "thinking") {
										content.thinkingSignature = proxyEvent.contentSignature;
										event = {
											type: "thinking_end",
											contentIndex: proxyEvent.contentIndex,
											content: content.thinking,
											partial,
										};
									} else {
										throw new Error("Received thinking_end for non-thinking content");
									}
									break;
								}
								case "toolcall_start":
									// partialJson accumulates raw argument fragments until toolcall_end
									partial.content[proxyEvent.contentIndex] = {
										type: "toolCall",
										id: proxyEvent.id,
										name: proxyEvent.toolName,
										arguments: {},
										partialJson: "",
									} satisfies ToolCall & { partialJson: string } as ToolCall;
									event = { type: "toolcall_start", contentIndex: proxyEvent.contentIndex, partial };
									break;
								case "toolcall_delta": {
									const content = partial.content[proxyEvent.contentIndex];
									if (content?.type === "toolCall") {
										(content as any).partialJson += proxyEvent.delta;
										content.arguments = parseStreamingJson((content as any).partialJson) || {};
										event = {
											type: "toolcall_delta",
											contentIndex: proxyEvent.contentIndex,
											delta: proxyEvent.delta,
											partial,
										};
										partial.content[proxyEvent.contentIndex] = { ...content }; // Trigger reactivity
									} else {
										throw new Error("Received toolcall_delta for non-toolCall content");
									}
									break;
								}
								case "toolcall_end": {
									const content = partial.content[proxyEvent.contentIndex];
									if (content?.type === "toolCall") {
										delete (content as any).partialJson;
										event = {
											type: "toolcall_end",
											contentIndex: proxyEvent.contentIndex,
											toolCall: content,
											partial,
										};
									}
									break;
								}
								case "done":
									partial.stopReason = proxyEvent.reason;
									partial.usage = proxyEvent.usage;
									event = { type: "done", reason: proxyEvent.reason, message: partial };
									break;
								case "error":
									partial.stopReason = proxyEvent.reason;
									partial.errorMessage = proxyEvent.errorMessage;
									partial.usage = proxyEvent.usage;
									event = { type: "error", reason: proxyEvent.reason, error: partial };
									break;
								default: {
									// Exhaustive check
									const _exhaustiveCheck: never = proxyEvent;
									console.warn(`Unhandled event type: ${(proxyEvent as any).type}`);
									break;
								}
							}
							// Push the event to stream
							if (event) {
								stream.push(event);
							} else {
								throw new Error("Failed to create event from proxy event");
							}
						}
					}
				}
			}
			// Check if aborted after reading
			if (options.signal?.aborted) {
				throw new Error("Request aborted by user");
			}
			stream.end();
		} catch (error) {
			// Surface any failure (network, parse, abort) as a terminal error event
			const errorMessage = error instanceof Error ? error.message : String(error);
			partial.stopReason = options.signal?.aborted ? "aborted" : "error";
			partial.errorMessage = errorMessage;
			stream.push({
				type: "error",
				reason: partial.stopReason,
				error: partial,
			} satisfies AssistantMessageEvent);
			stream.end();
		} finally {
			// Clean up abort handler
			if (options.signal) {
				options.signal.removeEventListener("abort", abortHandler);
			}
		}
	})();
	return stream;
}
export interface AppTransportOptions {
	/**
	 * Proxy server URL. The server manages user accounts and proxies requests to LLM providers.
	 * Example: "https://genai.mariozechner.at"
	 */
	proxyUrl: string;
	/**
	 * Function to retrieve auth token for the proxy server.
	 * The token is used for user authentication and authorization.
	 * Called once per run/continue so short-lived tokens can be refreshed.
	 */
	getAuthToken: () => Promise<string> | string;
}
/**
 * Transport that uses an app server with user authentication tokens.
 * The server manages user accounts and proxies requests to LLM providers.
 */
export class AppTransport implements AgentTransport {
	private options: AppTransportOptions;

	constructor(options: AppTransportOptions) {
		this.options = options;
	}

	/** Resolve the user's auth token, failing fast when none is available. */
	private async requireAuthToken(): Promise<string> {
		const authToken = await this.options.getAuthToken();
		if (!authToken) {
			throw new Error("Auth token is required for AppTransport");
		}
		return authToken;
	}

	/** Build a stream function that routes every LLM call through the proxy. */
	private makeStreamFn(authToken: string) {
		return <TApi extends Api>(model: Model<TApi>, context: Context, options?: SimpleStreamOptions) =>
			streamSimpleProxy(model, context, { ...options, authToken }, this.options.proxyUrl);
	}

	/** Assemble the agent context from the message history and run config. */
	private makeContext(messages: Message[], cfg: AgentRunConfig): AgentContext {
		return {
			systemPrompt: cfg.systemPrompt,
			messages,
			tools: cfg.tools,
		};
	}

	/** Assemble the loop config from the run config. */
	private makeLoopConfig(cfg: AgentRunConfig): AgentLoopConfig {
		return {
			model: cfg.model,
			reasoning: cfg.reasoning,
			getQueuedMessages: cfg.getQueuedMessages,
		};
	}

	/** Run a turn with a new user message, streaming events through the proxy. */
	async *run(messages: Message[], userMessage: Message, cfg: AgentRunConfig, signal?: AbortSignal) {
		const streamFn = this.makeStreamFn(await this.requireAuthToken());
		yield* agentLoop(
			userMessage as unknown as UserMessage,
			this.makeContext(messages, cfg),
			this.makeLoopConfig(cfg),
			signal,
			streamFn as any,
		);
	}

	/** Continue from the current context without a new user message. */
	async *continue(messages: Message[], cfg: AgentRunConfig, signal?: AbortSignal) {
		const streamFn = this.makeStreamFn(await this.requireAuthToken());
		yield* agentLoopContinue(this.makeContext(messages, cfg), this.makeLoopConfig(cfg), signal, streamFn as any);
	}
}

View file

@ -1,85 +0,0 @@
import {
type AgentContext,
type AgentLoopConfig,
agentLoop,
agentLoopContinue,
type Message,
type UserMessage,
} from "@mariozechner/pi-ai";
import type { AgentRunConfig, AgentTransport } from "./types.js";
export interface ProviderTransportOptions {
	/**
	 * Function to retrieve API key for a given provider.
	 * If not provided, transport will try to use environment variables.
	 */
	getApiKey?: (provider: string) => Promise<string | undefined> | string | undefined;
	/**
	 * Optional CORS proxy URL for browser environments.
	 * If provided, all requests will be routed through this proxy
	 * (the model's baseUrl is rewritten to go through it).
	 * Format: "https://proxy.example.com"
	 */
	corsProxyUrl?: string;
}
/**
 * Transport that calls LLM providers directly.
 * Optionally routes calls through a CORS proxy if configured.
 */
export class ProviderTransport implements AgentTransport {
	private options: ProviderTransportOptions;

	constructor(options: ProviderTransportOptions = {}) {
		this.options = options;
	}

	/** Rewrite the model's baseUrl through the CORS proxy when one is configured. */
	private resolveModel(cfg: AgentRunConfig) {
		const { corsProxyUrl } = this.options;
		if (!corsProxyUrl || !cfg.model.baseUrl) {
			return cfg.model;
		}
		return {
			...cfg.model,
			baseUrl: `${corsProxyUrl}/?url=${encodeURIComponent(cfg.model.baseUrl)}`,
		};
	}

	/** Assemble the agent context from the message history and run config. */
	private makeContext(messages: Message[], cfg: AgentRunConfig): AgentContext {
		return {
			systemPrompt: cfg.systemPrompt,
			messages,
			tools: cfg.tools,
		};
	}

	/** Assemble the loop config for the (possibly proxied) model. */
	private makeLoopConfig(model: AgentRunConfig["model"], cfg: AgentRunConfig): AgentLoopConfig {
		return {
			model,
			reasoning: cfg.reasoning,
			// Resolve API key per assistant response (important for expiring OAuth tokens)
			getApiKey: this.options.getApiKey,
			getQueuedMessages: cfg.getQueuedMessages,
		};
	}

	/** Run a turn with a new user message against the provider. */
	async *run(messages: Message[], userMessage: Message, cfg: AgentRunConfig, signal?: AbortSignal) {
		const model = this.getModel(cfg);
		yield* agentLoop(
			userMessage as unknown as UserMessage,
			this.makeContext(messages, cfg),
			this.makeLoopConfig(model, cfg),
			signal,
		);
	}

	/** Continue from the current context without a new user message. */
	async *continue(messages: Message[], cfg: AgentRunConfig, signal?: AbortSignal) {
		const model = this.getModel(cfg);
		yield* agentLoopContinue(this.makeContext(messages, cfg), this.makeLoopConfig(model, cfg), signal);
	}

	/** Kept name-compatible with the original private accessor. */
	private getModel(cfg: AgentRunConfig) {
		return this.resolveModel(cfg);
	}
}

View file

@ -1,4 +0,0 @@
export { AppTransport, type AppTransportOptions } from "./AppTransport.js";
export { ProviderTransport, type ProviderTransportOptions } from "./ProviderTransport.js";
export type { ProxyAssistantMessageEvent } from "./proxy-types.js";
export type { AgentRunConfig, AgentTransport } from "./types.js";

View file

@ -1,20 +0,0 @@
import type { StopReason, Usage } from "@mariozechner/pi-ai";
/**
 * Event types emitted by the proxy server.
 * The server strips the `partial` field from delta events to reduce bandwidth.
 * Clients reconstruct the partial message from these events.
 */
export type ProxyAssistantMessageEvent =
	| { type: "start" }
	// Text content block lifecycle (contentIndex addresses the message's content array)
	| { type: "text_start"; contentIndex: number }
	| { type: "text_delta"; contentIndex: number; delta: string }
	| { type: "text_end"; contentIndex: number; contentSignature?: string }
	// Thinking (reasoning) content block lifecycle
	| { type: "thinking_start"; contentIndex: number }
	| { type: "thinking_delta"; contentIndex: number; delta: string }
	| { type: "thinking_end"; contentIndex: number; contentSignature?: string }
	// Tool call lifecycle; deltas carry raw JSON argument fragments
	| { type: "toolcall_start"; contentIndex: number; id: string; toolName: string }
	| { type: "toolcall_delta"; contentIndex: number; delta: string }
	| { type: "toolcall_end"; contentIndex: number }
	// Terminal events: exactly one of "done" or "error" ends the stream
	| { type: "done"; reason: Extract<StopReason, "stop" | "length" | "toolUse">; usage: Usage }
	| { type: "error"; reason: Extract<StopReason, "aborted" | "error">; errorMessage: string; usage: Usage };

View file

@ -1,32 +0,0 @@
import type { AgentEvent, AgentTool, Message, Model, QueuedMessage, ReasoningEffort } from "@mariozechner/pi-ai";
/**
 * The minimal configuration needed to run an agent turn.
 */
export interface AgentRunConfig {
	/** System prompt sent with every request. */
	systemPrompt: string;
	/** Tools the agent may call during the turn. */
	tools: AgentTool<any>[];
	/** Model to run the turn against. */
	model: Model<any>;
	/** Optional reasoning effort passed through to the provider. */
	reasoning?: ReasoningEffort;
	/** Returns queued messages to inject between turns, if any. */
	getQueuedMessages?: <T>() => Promise<QueuedMessage<T>[]>;
}
/**
 * Transport interface for executing agent turns.
 * Transports handle the communication with LLM providers,
 * abstracting away the details of API calls, proxies, etc.
 *
 * Events yielded must match the @mariozechner/pi-ai AgentEvent types.
 */
export interface AgentTransport {
	/** Run with a new user message */
	run(
		messages: Message[],
		userMessage: Message,
		config: AgentRunConfig,
		signal?: AbortSignal,
	): AsyncIterable<AgentEvent>;
	/** Continue from current context (no new user message) */
	continue(messages: Message[], config: AgentRunConfig, signal?: AbortSignal): AsyncIterable<AgentEvent>;
}

View file

@ -1,26 +1,86 @@
import type {
AgentTool,
AssistantMessage,
AssistantMessageEvent,
ImageContent,
Message,
Model,
SimpleStreamOptions,
streamSimple,
TextContent,
Tool,
ToolResultMessage,
UserMessage,
} from "@mariozechner/pi-ai";
import type { Static, TSchema } from "@sinclair/typebox";
/** Stream function - can return sync or Promise for async config lookup */
export type StreamFn = (
...args: Parameters<typeof streamSimple>
) => ReturnType<typeof streamSimple> | Promise<ReturnType<typeof streamSimple>>;
/**
* Attachment type definition.
* Processing is done by consumers (e.g., document extraction in web-ui).
* Configuration for the agent loop.
*/
export interface Attachment {
id: string;
type: "image" | "document";
fileName: string;
mimeType: string;
size: number;
content: string; // base64 encoded (without data URL prefix)
extractedText?: string; // For documents
preview?: string; // base64 image preview
export interface AgentLoopConfig extends SimpleStreamOptions {
model: Model<any>;
/**
* Converts AgentMessage[] to LLM-compatible Message[] before each LLM call.
*
* Each AgentMessage must be converted to a UserMessage, AssistantMessage, or ToolResultMessage
* that the LLM can understand. AgentMessages that cannot be converted (e.g., UI-only notifications,
* status messages) should be filtered out.
*
* @example
* ```typescript
* convertToLlm: (messages) => messages.flatMap(m => {
* if (m.role === "hookMessage") {
* // Convert custom message to user message
* return [{ role: "user", content: m.content, timestamp: m.timestamp }];
* }
* if (m.role === "notification") {
* // Filter out UI-only messages
* return [];
* }
* // Pass through standard LLM messages
* return [m];
* })
* ```
*/
convertToLlm: (messages: AgentMessage[]) => Message[] | Promise<Message[]>;
/**
* Optional transform applied to the context before `convertToLlm`.
*
* Use this for operations that work at the AgentMessage level:
* - Context window management (pruning old messages)
* - Injecting context from external sources
*
* @example
* ```typescript
* transformContext: async (messages) => {
* if (estimateTokens(messages) > MAX_TOKENS) {
* return pruneOldMessages(messages);
* }
* return messages;
* }
* ```
*/
transformContext?: (messages: AgentMessage[], signal?: AbortSignal) => Promise<AgentMessage[]>;
/**
* Resolves an API key dynamically for each LLM call.
*
* Useful for short-lived OAuth tokens (e.g., GitHub Copilot) that may expire
* during long-running tool execution phases.
*/
getApiKey?: (provider: string) => Promise<string | undefined> | string | undefined;
/**
* Returns queued messages to inject into the conversation.
*
* Called after each turn to check for user interruptions or injected messages.
* If messages are returned, they're added to the context before the next LLM call.
*/
getQueuedMessages?: () => Promise<AgentMessage[]>;
}
/**
@ -29,11 +89,6 @@ export interface Attachment {
*/
export type ThinkingLevel = "off" | "minimal" | "low" | "medium" | "high" | "xhigh";
/**
* User message with optional attachments.
*/
export type UserMessageWithAttachments = UserMessage & { attachments?: Attachment[] };
/**
* Extensible interface for custom app messages.
* Apps can extend via declaration merging:
@ -41,27 +96,23 @@ export type UserMessageWithAttachments = UserMessage & { attachments?: Attachmen
* @example
* ```typescript
* declare module "@mariozechner/agent" {
* interface CustomMessages {
* interface CustomAgentMessages {
* artifact: ArtifactMessage;
* notification: NotificationMessage;
* }
* }
* ```
*/
export interface CustomMessages {
export interface CustomAgentMessages {
// Empty by default - apps extend via declaration merging
}
/**
* AppMessage: Union of LLM messages + attachments + custom messages.
* AgentMessage: Union of LLM messages + custom messages.
* This abstraction allows apps to add custom message types while maintaining
* type safety and compatibility with the base LLM messages.
*/
export type AppMessage =
| AssistantMessage
| UserMessageWithAttachments
| Message // Includes ToolResultMessage
| CustomMessages[keyof CustomMessages];
export type AgentMessage = Message | CustomAgentMessages[keyof CustomAgentMessages];
/**
* Agent state containing all configuration and conversation data.
@ -71,13 +122,42 @@ export interface AgentState {
model: Model<any>;
thinkingLevel: ThinkingLevel;
tools: AgentTool<any>[];
messages: AppMessage[]; // Can include attachments + custom message types
messages: AgentMessage[]; // Can include attachments + custom message types
isStreaming: boolean;
streamMessage: Message | null;
streamMessage: AgentMessage | null;
pendingToolCalls: Set<string>;
error?: string;
}
export interface AgentToolResult<T> {
// Content blocks supporting text and images
content: (TextContent | ImageContent)[];
// Details to be displayed in a UI or logged
details: T;
}
// Callback for streaming tool execution updates
export type AgentToolUpdateCallback<T = any> = (partialResult: AgentToolResult<T>) => void;
// AgentTool extends Tool but adds the execute function
export interface AgentTool<TParameters extends TSchema = TSchema, TDetails = any> extends Tool<TParameters> {
// A human-readable label for the tool to be displayed in UI
label: string;
execute: (
toolCallId: string,
params: Static<TParameters>,
signal?: AbortSignal,
onUpdate?: AgentToolUpdateCallback<TDetails>,
) => Promise<AgentToolResult<TDetails>>;
}
// AgentContext is like Context but uses AgentTool
export interface AgentContext {
systemPrompt: string;
messages: AgentMessage[];
tools?: AgentTool<any>[];
}
/**
* Events emitted by the Agent for UI updates.
* These events provide fine-grained lifecycle information for messages, turns, and tool executions.
@ -85,15 +165,15 @@ export interface AgentState {
export type AgentEvent =
// Agent lifecycle
| { type: "agent_start" }
| { type: "agent_end"; messages: AppMessage[] }
| { type: "agent_end"; messages: AgentMessage[] }
// Turn lifecycle - a turn is one assistant response + any tool calls/results
| { type: "turn_start" }
| { type: "turn_end"; message: AppMessage; toolResults: ToolResultMessage[] }
| { type: "turn_end"; message: AgentMessage; toolResults: ToolResultMessage[] }
// Message lifecycle - emitted for user, assistant, and toolResult messages
| { type: "message_start"; message: AppMessage }
| { type: "message_start"; message: AgentMessage }
// Only emitted for assistant messages during streaming
| { type: "message_update"; message: AppMessage; assistantMessageEvent: AssistantMessageEvent }
| { type: "message_end"; message: AppMessage }
| { type: "message_update"; message: AgentMessage; assistantMessageEvent: AssistantMessageEvent }
| { type: "message_end"; message: AgentMessage }
// Tool execution lifecycle
| { type: "tool_execution_start"; toolCallId: string; toolName: string; args: any }
| { type: "tool_execution_update"; toolCallId: string; toolName: string; args: any; partialResult: any }

View file

@ -0,0 +1,535 @@
import {
type AssistantMessage,
type AssistantMessageEvent,
EventStream,
type Message,
type Model,
type UserMessage,
} from "@mariozechner/pi-ai";
import { Type } from "@sinclair/typebox";
import { describe, expect, it } from "vitest";
import { agentLoop, agentLoopContinue } from "../src/agent-loop.js";
import type { AgentContext, AgentEvent, AgentLoopConfig, AgentMessage, AgentTool } from "../src/types.js";
// Test double for the assistant event stream: treats "done"/"error" as
// terminal events and resolves the stream's final value from them.
class MockAssistantStream extends EventStream<AssistantMessageEvent, AssistantMessage> {
	constructor() {
		// Statements before super() are legal in a derived constructor as
		// long as `this` is not touched.
		const isTerminal = (event: AssistantMessageEvent) => event.type === "error" || event.type === "done";
		const resolveResult = (event: AssistantMessageEvent) => {
			switch (event.type) {
				case "done":
					return event.message;
				case "error":
					return event.error;
				default:
					throw new Error("Unexpected event type");
			}
		};
		super(isTerminal, resolveResult);
	}
}
// Zeroed-out usage record so mock assistant messages carry no token
// counts and accrue no cost.
function createUsage() {
	const zeroCost = { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 };
	return {
		input: 0,
		output: 0,
		cacheRead: 0,
		cacheWrite: 0,
		totalTokens: 0,
		cost: zeroCost,
	};
}
// Builds a throwaway "openai-responses" model descriptor for tests.
// The base URL is intentionally unreachable; tests inject a mock stream.
function createModel(): Model<"openai-responses"> {
	// Free pricing so accumulated usage never produces a non-zero bill.
	const freePricing = { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 };
	return {
		id: "mock",
		name: "mock",
		provider: "openai",
		api: "openai-responses",
		baseUrl: "https://example.invalid",
		reasoning: false,
		input: ["text"],
		cost: freePricing,
		contextWindow: 8192,
		maxTokens: 2048,
	};
}
// Builds a mock assistant reply with zeroed usage; stopReason defaults
// to a plain "stop" (pass "toolUse" when the message carries tool calls).
function createAssistantMessage(
	content: AssistantMessage["content"],
	stopReason: AssistantMessage["stopReason"] = "stop",
): AssistantMessage {
	const message: AssistantMessage = {
		role: "assistant",
		api: "openai-responses",
		provider: "openai",
		model: "mock",
		content,
		stopReason,
		usage: createUsage(),
		timestamp: Date.now(),
	};
	return message;
}
// Wraps plain text in a user message stamped with the current time.
function createUserMessage(text: string): UserMessage {
	return { role: "user", content: text, timestamp: Date.now() };
}
// Minimal AgentMessage -> Message converter for tests: drops anything
// that is not a standard LLM role and passes the rest through unchanged.
function identityConverter(messages: AgentMessage[]): Message[] {
	const llmRoles = new Set(["user", "assistant", "toolResult"]);
	return messages.filter((m) => llmRoles.has(m.role)) as Message[];
}
// Integration-style tests for agentLoop driven entirely by an injected
// mock stream function (no network): each test's streamFn synthesizes
// "done" events, and assertions run against the emitted AgentEvents.
describe("agentLoop with AgentMessage", () => {
	it("should emit events with AgentMessage types", async () => {
		const context: AgentContext = {
			systemPrompt: "You are helpful.",
			messages: [],
			tools: [],
		};
		const userPrompt: AgentMessage = createUserMessage("Hello");
		const config: AgentLoopConfig = {
			model: createModel(),
			convertToLlm: identityConverter,
		};
		// Single-turn stream: one assistant text message, then done.
		const streamFn = () => {
			const stream = new MockAssistantStream();
			queueMicrotask(() => {
				const message = createAssistantMessage([{ type: "text", text: "Hi there!" }]);
				stream.push({ type: "done", reason: "stop", message });
			});
			return stream;
		};
		const events: AgentEvent[] = [];
		const stream = agentLoop([userPrompt], context, config, undefined, streamFn);
		for await (const event of stream) {
			events.push(event);
		}
		const messages = await stream.result();
		// Should have user message and assistant message
		expect(messages.length).toBe(2);
		expect(messages[0].role).toBe("user");
		expect(messages[1].role).toBe("assistant");
		// Verify event sequence
		const eventTypes = events.map((e) => e.type);
		expect(eventTypes).toContain("agent_start");
		expect(eventTypes).toContain("turn_start");
		expect(eventTypes).toContain("message_start");
		expect(eventTypes).toContain("message_end");
		expect(eventTypes).toContain("turn_end");
		expect(eventTypes).toContain("agent_end");
	});
	it("should handle custom message types via convertToLlm", async () => {
		// Create a custom message type
		interface CustomNotification {
			role: "notification";
			text: string;
			timestamp: number;
		}
		const notification: CustomNotification = {
			role: "notification",
			text: "This is a notification",
			timestamp: Date.now(),
		};
		const context: AgentContext = {
			systemPrompt: "You are helpful.",
			messages: [notification as unknown as AgentMessage], // Custom message in context
			tools: [],
		};
		const userPrompt: AgentMessage = createUserMessage("Hello");
		// Captures what the loop actually hands to the LLM after conversion.
		let convertedMessages: Message[] = [];
		const config: AgentLoopConfig = {
			model: createModel(),
			convertToLlm: (messages) => {
				// Filter out notifications, convert rest
				convertedMessages = messages
					.filter((m) => (m as { role: string }).role !== "notification")
					.filter((m) => m.role === "user" || m.role === "assistant" || m.role === "toolResult") as Message[];
				return convertedMessages;
			},
		};
		const streamFn = () => {
			const stream = new MockAssistantStream();
			queueMicrotask(() => {
				const message = createAssistantMessage([{ type: "text", text: "Response" }]);
				stream.push({ type: "done", reason: "stop", message });
			});
			return stream;
		};
		const events: AgentEvent[] = [];
		const stream = agentLoop([userPrompt], context, config, undefined, streamFn);
		for await (const event of stream) {
			events.push(event);
		}
		// The notification should have been filtered out in convertToLlm
		expect(convertedMessages.length).toBe(1); // Only user message
		expect(convertedMessages[0].role).toBe("user");
	});
	it("should apply transformContext before convertToLlm", async () => {
		const context: AgentContext = {
			systemPrompt: "You are helpful.",
			messages: [
				createUserMessage("old message 1"),
				createAssistantMessage([{ type: "text", text: "old response 1" }]),
				createUserMessage("old message 2"),
				createAssistantMessage([{ type: "text", text: "old response 2" }]),
			],
			tools: [],
		};
		const userPrompt: AgentMessage = createUserMessage("new message");
		let transformedMessages: AgentMessage[] = [];
		let convertedMessages: Message[] = [];
		const config: AgentLoopConfig = {
			model: createModel(),
			transformContext: async (messages) => {
				// Keep only last 2 messages (prune old ones)
				transformedMessages = messages.slice(-2);
				return transformedMessages;
			},
			convertToLlm: (messages) => {
				convertedMessages = messages.filter(
					(m) => m.role === "user" || m.role === "assistant" || m.role === "toolResult",
				) as Message[];
				return convertedMessages;
			},
		};
		const streamFn = () => {
			const stream = new MockAssistantStream();
			queueMicrotask(() => {
				const message = createAssistantMessage([{ type: "text", text: "Response" }]);
				stream.push({ type: "done", reason: "stop", message });
			});
			return stream;
		};
		const stream = agentLoop([userPrompt], context, config, undefined, streamFn);
		for await (const _ of stream) {
			// consume
		}
		// transformContext should have been called first, keeping only last 2
		expect(transformedMessages.length).toBe(2);
		// Then convertToLlm receives the pruned messages
		expect(convertedMessages.length).toBe(2);
	});
	it("should handle tool calls and results", async () => {
		const toolSchema = Type.Object({ value: Type.String() });
		// Records every value the tool was invoked with.
		const executed: string[] = [];
		const tool: AgentTool<typeof toolSchema, { value: string }> = {
			name: "echo",
			label: "Echo",
			description: "Echo tool",
			parameters: toolSchema,
			async execute(_toolCallId, params) {
				executed.push(params.value);
				return {
					content: [{ type: "text", text: `echoed: ${params.value}` }],
					details: { value: params.value },
				};
			},
		};
		const context: AgentContext = {
			systemPrompt: "",
			messages: [],
			tools: [tool],
		};
		const userPrompt: AgentMessage = createUserMessage("echo something");
		const config: AgentLoopConfig = {
			model: createModel(),
			convertToLlm: identityConverter,
		};
		// Two-phase stream: first call requests a tool, second call finishes.
		let callIndex = 0;
		const streamFn = () => {
			const stream = new MockAssistantStream();
			queueMicrotask(() => {
				if (callIndex === 0) {
					// First call: return tool call
					const message = createAssistantMessage(
						[{ type: "toolCall", id: "tool-1", name: "echo", arguments: { value: "hello" } }],
						"toolUse",
					);
					stream.push({ type: "done", reason: "toolUse", message });
				} else {
					// Second call: return final response
					const message = createAssistantMessage([{ type: "text", text: "done" }]);
					stream.push({ type: "done", reason: "stop", message });
				}
				callIndex++;
			});
			return stream;
		};
		const events: AgentEvent[] = [];
		const stream = agentLoop([userPrompt], context, config, undefined, streamFn);
		for await (const event of stream) {
			events.push(event);
		}
		// Tool should have been executed
		expect(executed).toEqual(["hello"]);
		// Should have tool execution events
		const toolStart = events.find((e) => e.type === "tool_execution_start");
		const toolEnd = events.find((e) => e.type === "tool_execution_end");
		expect(toolStart).toBeDefined();
		expect(toolEnd).toBeDefined();
		if (toolEnd?.type === "tool_execution_end") {
			expect(toolEnd.isError).toBe(false);
		}
	});
	it("should inject queued messages and skip remaining tool calls", async () => {
		const toolSchema = Type.Object({ value: Type.String() });
		const executed: string[] = [];
		const tool: AgentTool<typeof toolSchema, { value: string }> = {
			name: "echo",
			label: "Echo",
			description: "Echo tool",
			parameters: toolSchema,
			async execute(_toolCallId, params) {
				executed.push(params.value);
				return {
					content: [{ type: "text", text: `ok:${params.value}` }],
					details: { value: params.value },
				};
			},
		};
		const context: AgentContext = {
			systemPrompt: "",
			messages: [],
			tools: [tool],
		};
		const userPrompt: AgentMessage = createUserMessage("start");
		const queuedUserMessage: AgentMessage = createUserMessage("interrupt");
		let queuedDelivered = false;
		let callIndex = 0;
		let sawInterruptInContext = false;
		const config: AgentLoopConfig = {
			model: createModel(),
			convertToLlm: identityConverter,
			getQueuedMessages: async () => {
				// Return queued message after first tool executes
				if (executed.length === 1 && !queuedDelivered) {
					queuedDelivered = true;
					return [queuedUserMessage];
				}
				return [];
			},
		};
		const events: AgentEvent[] = [];
		// Inline streamFn so the second call can inspect the live context.
		const stream = agentLoop([userPrompt], context, config, undefined, (_model, ctx, _options) => {
			// Check if interrupt message is in context on second call
			if (callIndex === 1) {
				sawInterruptInContext = ctx.messages.some(
					(m) => m.role === "user" && typeof m.content === "string" && m.content === "interrupt",
				);
			}
			const mockStream = new MockAssistantStream();
			queueMicrotask(() => {
				if (callIndex === 0) {
					// First call: return two tool calls
					const message = createAssistantMessage(
						[
							{ type: "toolCall", id: "tool-1", name: "echo", arguments: { value: "first" } },
							{ type: "toolCall", id: "tool-2", name: "echo", arguments: { value: "second" } },
						],
						"toolUse",
					);
					mockStream.push({ type: "done", reason: "toolUse", message });
				} else {
					// Second call: return final response
					const message = createAssistantMessage([{ type: "text", text: "done" }]);
					mockStream.push({ type: "done", reason: "stop", message });
				}
				callIndex++;
			});
			return mockStream;
		});
		for await (const event of stream) {
			events.push(event);
		}
		// Only first tool should have executed
		expect(executed).toEqual(["first"]);
		// Second tool should be skipped
		const toolEnds = events.filter(
			(e): e is Extract<AgentEvent, { type: "tool_execution_end" }> => e.type === "tool_execution_end",
		);
		expect(toolEnds.length).toBe(2);
		expect(toolEnds[0].isError).toBe(false);
		expect(toolEnds[1].isError).toBe(true);
		if (toolEnds[1].result.content[0]?.type === "text") {
			expect(toolEnds[1].result.content[0].text).toContain("Skipped due to queued user message");
		}
		// Queued message should appear in events
		const queuedMessageEvent = events.find(
			(e) =>
				e.type === "message_start" &&
				e.message.role === "user" &&
				typeof e.message.content === "string" &&
				e.message.content === "interrupt",
		);
		expect(queuedMessageEvent).toBeDefined();
		// Interrupt message should be in context when second LLM call is made
		expect(sawInterruptInContext).toBe(true);
	});
});
// Tests for agentLoopContinue: resuming a loop from an existing context
// without emitting events for a new user message.
describe("agentLoopContinue with AgentMessage", () => {
	it("should throw when context has no messages", () => {
		const context: AgentContext = {
			systemPrompt: "You are helpful.",
			messages: [],
			tools: [],
		};
		const config: AgentLoopConfig = {
			model: createModel(),
			convertToLlm: identityConverter,
		};
		expect(() => agentLoopContinue(context, config)).toThrow("Cannot continue: no messages in context");
	});
	it("should continue from existing context without emitting user message events", async () => {
		const userMessage: AgentMessage = createUserMessage("Hello");
		const context: AgentContext = {
			systemPrompt: "You are helpful.",
			messages: [userMessage],
			tools: [],
		};
		const config: AgentLoopConfig = {
			model: createModel(),
			convertToLlm: identityConverter,
		};
		const streamFn = () => {
			const stream = new MockAssistantStream();
			queueMicrotask(() => {
				const message = createAssistantMessage([{ type: "text", text: "Response" }]);
				stream.push({ type: "done", reason: "stop", message });
			});
			return stream;
		};
		const events: AgentEvent[] = [];
		const stream = agentLoopContinue(context, config, undefined, streamFn);
		for await (const event of stream) {
			events.push(event);
		}
		const messages = await stream.result();
		// Should only return the new assistant message (not the existing user message)
		expect(messages.length).toBe(1);
		expect(messages[0].role).toBe("assistant");
		// Should NOT have user message events (that's the key difference from agentLoop)
		const messageEndEvents = events.filter((e) => e.type === "message_end");
		expect(messageEndEvents.length).toBe(1);
		expect((messageEndEvents[0] as any).message.role).toBe("assistant");
	});
	it("should allow custom message types as last message (caller responsibility)", async () => {
		// Custom message that will be converted to user message by convertToLlm
		interface HookMessage {
			role: "hookMessage";
			text: string;
			timestamp: number;
		}
		const hookMessage: HookMessage = {
			role: "hookMessage",
			text: "Hook content",
			timestamp: Date.now(),
		};
		const context: AgentContext = {
			systemPrompt: "You are helpful.",
			messages: [hookMessage as unknown as AgentMessage],
			tools: [],
		};
		const config: AgentLoopConfig = {
			model: createModel(),
			convertToLlm: (messages) => {
				// Convert hookMessage to user message
				return messages
					.map((m) => {
						if ((m as any).role === "hookMessage") {
							return {
								role: "user" as const,
								content: (m as any).text,
								timestamp: m.timestamp,
							};
						}
						return m;
					})
					.filter((m) => m.role === "user" || m.role === "assistant" || m.role === "toolResult") as Message[];
			},
		};
		const streamFn = () => {
			const stream = new MockAssistantStream();
			queueMicrotask(() => {
				const message = createAssistantMessage([{ type: "text", text: "Response to hook" }]);
				stream.push({ type: "done", reason: "stop", message });
			});
			return stream;
		};
		// Should not throw - the hookMessage will be converted to user message
		const stream = agentLoopContinue(context, config, undefined, streamFn);
		const events: AgentEvent[] = [];
		for await (const event of stream) {
			events.push(event);
		}
		const messages = await stream.result();
		expect(messages.length).toBe(1);
		expect(messages[0].role).toBe("assistant");
	});
});

View file

@ -1,12 +1,10 @@
import { getModel } from "@mariozechner/pi-ai";
import { describe, expect, it } from "vitest";
import { Agent, ProviderTransport } from "../src/index.js";
import { Agent } from "../src/index.js";
describe("Agent", () => {
it("should create an agent instance with default state", () => {
const agent = new Agent({
transport: new ProviderTransport(),
});
const agent = new Agent();
expect(agent.state).toBeDefined();
expect(agent.state.systemPrompt).toBe("");
@ -23,7 +21,6 @@ describe("Agent", () => {
it("should create an agent instance with custom initial state", () => {
const customModel = getModel("openai", "gpt-4o-mini");
const agent = new Agent({
transport: new ProviderTransport(),
initialState: {
systemPrompt: "You are a helpful assistant.",
model: customModel,
@ -37,9 +34,7 @@ describe("Agent", () => {
});
it("should subscribe to events", () => {
const agent = new Agent({
transport: new ProviderTransport(),
});
const agent = new Agent();
let eventCount = 0;
const unsubscribe = agent.subscribe((_event) => {
@ -61,9 +56,7 @@ describe("Agent", () => {
});
it("should update state with mutators", () => {
const agent = new Agent({
transport: new ProviderTransport(),
});
const agent = new Agent();
// Test setSystemPrompt
agent.setSystemPrompt("Custom prompt");
@ -101,38 +94,19 @@ describe("Agent", () => {
});
it("should support message queueing", async () => {
const agent = new Agent({
transport: new ProviderTransport(),
});
const agent = new Agent();
const message = { role: "user" as const, content: "Queued message", timestamp: Date.now() };
await agent.queueMessage(message);
agent.queueMessage(message);
// The message is queued but not yet in state.messages
expect(agent.state.messages).not.toContainEqual(message);
});
it("should handle abort controller", () => {
const agent = new Agent({
transport: new ProviderTransport(),
});
const agent = new Agent();
// Should not throw even if nothing is running
expect(() => agent.abort()).not.toThrow();
});
});
describe("ProviderTransport", () => {
it("should create a provider transport instance", () => {
const transport = new ProviderTransport();
expect(transport).toBeDefined();
});
it("should create a provider transport with options", () => {
const transport = new ProviderTransport({
getApiKey: async (provider) => `test-key-${provider}`,
corsProxyUrl: "https://proxy.example.com",
});
expect(transport).toBeDefined();
});
});

View file

@ -1,25 +1,8 @@
import type { AssistantMessage, Model, ToolResultMessage, UserMessage } from "@mariozechner/pi-ai";
import { calculateTool, getModel } from "@mariozechner/pi-ai";
import { getModel } from "@mariozechner/pi-ai";
import { describe, expect, it } from "vitest";
import { Agent, ProviderTransport } from "../src/index.js";
function createTransport() {
return new ProviderTransport({
getApiKey: async (provider) => {
const envVarMap: Record<string, string> = {
google: "GEMINI_API_KEY",
openai: "OPENAI_API_KEY",
anthropic: "ANTHROPIC_API_KEY",
xai: "XAI_API_KEY",
groq: "GROQ_API_KEY",
cerebras: "CEREBRAS_API_KEY",
zai: "ZAI_API_KEY",
};
const envVar = envVarMap[provider] || `${provider.toUpperCase()}_API_KEY`;
return process.env[envVar];
},
});
}
import { Agent } from "../src/index.js";
import { calculateTool } from "./utils/calculate.js";
async function basicPrompt(model: Model<any>) {
const agent = new Agent({
@ -29,7 +12,6 @@ async function basicPrompt(model: Model<any>) {
thinkingLevel: "off",
tools: [],
},
transport: createTransport(),
});
await agent.prompt("What is 2+2? Answer with just the number.");
@ -57,7 +39,6 @@ async function toolExecution(model: Model<any>) {
thinkingLevel: "off",
tools: [calculateTool],
},
transport: createTransport(),
});
await agent.prompt("Calculate 123 * 456 using the calculator tool.");
@ -99,7 +80,6 @@ async function abortExecution(model: Model<any>) {
thinkingLevel: "off",
tools: [calculateTool],
},
transport: createTransport(),
});
const promptPromise = agent.prompt("Calculate 100 * 200, then 300 * 400, then sum the results.");
@ -129,7 +109,6 @@ async function stateUpdates(model: Model<any>) {
thinkingLevel: "off",
tools: [],
},
transport: createTransport(),
});
const events: Array<string> = [];
@ -162,7 +141,6 @@ async function multiTurnConversation(model: Model<any>) {
thinkingLevel: "off",
tools: [],
},
transport: createTransport(),
});
await agent.prompt("My name is Alice.");
@ -356,7 +334,6 @@ describe("Agent.continue()", () => {
systemPrompt: "Test",
model: getModel("anthropic", "claude-haiku-4-5"),
},
transport: createTransport(),
});
await expect(agent.continue()).rejects.toThrow("No messages to continue from");
@ -368,7 +345,6 @@ describe("Agent.continue()", () => {
systemPrompt: "Test",
model: getModel("anthropic", "claude-haiku-4-5"),
},
transport: createTransport(),
});
const assistantMessage: AssistantMessage = {
@ -405,7 +381,6 @@ describe("Agent.continue()", () => {
thinkingLevel: "off",
tools: [],
},
transport: createTransport(),
});
// Manually add a user message without calling prompt()
@ -445,7 +420,6 @@ describe("Agent.continue()", () => {
thinkingLevel: "off",
tools: [calculateTool],
},
transport: createTransport(),
});
// Set up a conversation state as if tool was just executed

View file

@ -1,5 +1,5 @@
import { type Static, Type } from "@sinclair/typebox";
import type { AgentTool, AgentToolResult } from "../../agent/types.js";
import type { AgentTool, AgentToolResult } from "../../src/types.js";
export interface CalculateResult extends AgentToolResult<undefined> {
content: Array<{ type: "text"; text: string }>;

View file

@ -1,6 +1,5 @@
import { type Static, Type } from "@sinclair/typebox";
import type { AgentTool } from "../../agent/index.js";
import type { AgentToolResult } from "../types.js";
import type { AgentTool, AgentToolResult } from "../../src/types.js";
export interface GetCurrentTimeResult extends AgentToolResult<{ utcTimestamp: number }> {}

View file

@ -2,6 +2,10 @@
## [Unreleased]
### Breaking Changes
- **Agent API moved**: All agent functionality (`agentLoop`, `agentLoopContinue`, `AgentContext`, `AgentEvent`, `AgentTool`, `AgentToolResult`, etc.) has moved to `@mariozechner/pi-agent-core`. See the [agent-core README](../agent/README.md) for documentation.
## [0.28.0] - 2025-12-25
### Breaking Changes

View file

@ -782,276 +782,6 @@ const continuation = await complete(newModel, restored);
> **Note**: If the context contains images (encoded as base64 as shown in the Image Input section), those will also be serialized.
## Agent API
The Agent API provides a higher-level interface for building agents with tools. It handles tool execution, validation, and provides detailed event streaming for interactive applications.
### Event System
The Agent API streams events during execution, allowing you to build reactive UIs and track agent progress. The agent processes prompts in **turns**, where each turn consists of:
1. An assistant message (the LLM's response)
2. Optional tool executions if the assistant calls tools
3. Tool result messages that are fed back to the LLM
This continues until the assistant produces a response without tool calls.
**Queued messages**: If you provide `getQueuedMessages` in the loop config, the agent checks for queued user messages after each tool call. When queued messages are found, any remaining tool calls from the current assistant message are skipped and returned as error tool results (`isError: true`) with the message "Skipped due to queued user message." The queued user messages are injected before the next assistant response.
### Event Flow Example
Given a prompt asking to calculate two expressions and sum them:
```typescript
import { agentLoop, AgentContext, calculateTool } from '@mariozechner/pi-ai';
const context: AgentContext = {
systemPrompt: 'You are a helpful math assistant.',
messages: [],
tools: [calculateTool]
};
const stream = agentLoop(
{ role: 'user', content: 'Calculate 15 * 20 and 30 * 40, then sum the results', timestamp: Date.now() },
context,
{ model: getModel('openai', 'gpt-4o-mini') }
);
// Expected event sequence:
// 1. agent_start - Agent begins processing
// 2. turn_start - First turn begins
// 3. message_start - User message starts
// 4. message_end - User message ends
// 5. message_start - Assistant message starts
// 6. message_update - Assistant streams response with tool calls
// 7. message_end - Assistant message ends
// 8. tool_execution_start - First calculation (15 * 20)
// 9. tool_execution_update - Streaming progress (for long-running tools)
// 10. tool_execution_end - Result: 300
// 11. tool_execution_start - Second calculation (30 * 40)
// 12. tool_execution_update - Streaming progress
// 13. tool_execution_end - Result: 1200
// 14. message_start - Tool result message for first calculation
// 15. message_end - Tool result message ends
// 16. message_start - Tool result message for second calculation
// 17. message_end - Tool result message ends
// 18. turn_end - First turn ends with 2 tool results
// 19. turn_start - Second turn begins
// 20. message_start - Assistant message starts
// 21. message_update - Assistant streams response with sum calculation
// 22. message_end - Assistant message ends
// 23. tool_execution_start - Sum calculation (300 + 1200)
// 24. tool_execution_end - Result: 1500
// 25. message_start - Tool result message for sum
// 26. message_end - Tool result message ends
// 27. turn_end - Second turn ends with 1 tool result
// 28. turn_start - Third turn begins
// 29. message_start - Final assistant message starts
// 30. message_update - Assistant streams final answer
// 31. message_end - Final assistant message ends
// 32. turn_end - Third turn ends with 0 tool results
// 33. agent_end - Agent completes with all messages
```
### Handling Events
```typescript
for await (const event of stream) {
switch (event.type) {
case 'agent_start':
console.log('Agent started');
break;
case 'turn_start':
console.log('New turn started');
break;
case 'message_start':
console.log(`${event.message.role} message started`);
break;
case 'message_update':
// Only for assistant messages during streaming
if (event.message.content.some(c => c.type === 'text')) {
console.log('Assistant:', event.message.content);
}
break;
case 'tool_execution_start':
console.log(`Calling ${event.toolName} with:`, event.args);
break;
case 'tool_execution_update':
// Streaming progress for long-running tools (e.g., bash output)
console.log(`Progress:`, event.partialResult.content);
break;
case 'tool_execution_end':
if (event.isError) {
console.error(`Tool failed:`, event.result);
} else {
console.log(`Tool result:`, event.result.content);
}
break;
case 'turn_end':
console.log(`Turn ended with ${event.toolResults.length} tool calls`);
break;
case 'agent_end':
console.log(`Agent completed with ${event.messages.length} new messages`);
break;
}
}
// Get all messages generated during this agent execution
// These include the user message and can be directly appended to context.messages
const messages = await stream.result();
context.messages.push(...messages);
```
### Continuing from Existing Context
Use `agentLoopContinue` to resume an agent loop without adding a new user message. This is useful for:
- Retrying after context overflow (after compaction reduces context size)
- Resuming from tool results that were added manually to the context
```typescript
import { agentLoopContinue, AgentContext } from '@mariozechner/pi-ai';
// Context already has messages - last must be 'user' or 'toolResult'
const context: AgentContext = {
systemPrompt: 'You are helpful.',
messages: [userMessage, assistantMessage, toolResult],
tools: [myTool]
};
// Continue processing from the tool result
const stream = agentLoopContinue(context, { model });
for await (const event of stream) {
// Same events as agentLoop, but no user message events emitted
}
const newMessages = await stream.result();
```
**Validation**: Throws if context has no messages or if the last message is an assistant message.
### Defining Tools with TypeBox
Tools use TypeBox schemas for runtime validation and type inference:
```typescript
import { Type, Static, AgentTool, AgentToolResult, StringEnum } from '@mariozechner/pi-ai';
const weatherSchema = Type.Object({
city: Type.String({ minLength: 1 }),
units: StringEnum(['celsius', 'fahrenheit'], { default: 'celsius' })
});
type WeatherParams = Static<typeof weatherSchema>;
const weatherTool: AgentTool<typeof weatherSchema, { temp: number }> = {
label: 'Get Weather',
name: 'get_weather',
description: 'Get current weather for a city',
parameters: weatherSchema,
execute: async (toolCallId, args, signal, onUpdate) => {
// args is fully typed: { city: string, units: 'celsius' | 'fahrenheit' }
// signal: AbortSignal for cancellation
// onUpdate: Optional callback for streaming progress (emits tool_execution_update events)
const temp = Math.round(Math.random() * 30);
return {
content: [{ type: 'text', text: `Temperature in ${args.city}: ${temp}°${args.units[0].toUpperCase()}` }],
details: { temp }
};
}
};
// Tools can also return images alongside text
const chartTool: AgentTool<typeof Type.Object({ data: Type.Array(Type.Number()) })> = {
label: 'Generate Chart',
name: 'generate_chart',
description: 'Generate a chart from data',
parameters: Type.Object({ data: Type.Array(Type.Number()) }),
execute: async (toolCallId, args) => {
const chartImage = await generateChartImage(args.data);
return {
content: [
{ type: 'text', text: `Generated chart with ${args.data.length} data points` },
{ type: 'image', data: chartImage.toString('base64'), mimeType: 'image/png' }
]
};
}
};
// Tools can stream progress via the onUpdate callback (emits tool_execution_update events)
const bashTool: AgentTool<typeof Type.Object({ command: Type.String() }), { exitCode: number }> = {
label: 'Run Bash',
name: 'bash',
description: 'Execute a bash command',
parameters: Type.Object({ command: Type.String() }),
execute: async (toolCallId, args, signal, onUpdate) => {
let output = '';
const child = spawn('bash', ['-c', args.command]);
child.stdout.on('data', (data) => {
output += data.toString();
// Stream partial output to UI via tool_execution_update events
onUpdate?.({
content: [{ type: 'text', text: output }],
details: { exitCode: -1 } // Not finished yet
});
});
const exitCode = await new Promise<number>((resolve) => {
child.on('close', resolve);
});
return {
content: [{ type: 'text', text: output }],
details: { exitCode }
};
}
};
```
### Validation and Error Handling
Tool arguments are automatically validated using AJV with the TypeBox schema. Invalid arguments result in detailed error messages:
```typescript
// If the LLM calls with invalid arguments:
// get_weather({ city: '', units: 'kelvin' })
// The tool execution will fail with:
/*
Validation failed for tool "get_weather":
- city: must NOT have fewer than 1 characters
- units: must be equal to one of the allowed values
Received arguments:
{
"city": "",
"units": "kelvin"
}
*/
```
### Built-in Example Tools
The library includes example tools for common operations:
```typescript
import { calculateTool, getCurrentTimeTool } from '@mariozechner/pi-ai';
const context: AgentContext = {
systemPrompt: 'You are a helpful assistant.',
messages: [],
tools: [calculateTool, getCurrentTimeTool]
};
```
## Browser Usage
The library supports browser environments. You must pass the API key explicitly since environment variables are not available in browsers:

View file

@ -1,11 +0,0 @@
export { agentLoop, agentLoopContinue } from "./agent-loop.js";
export * from "./tools/index.js";
export type {
AgentContext,
AgentEvent,
AgentLoopConfig,
AgentTool,
AgentToolResult,
AgentToolUpdateCallback,
QueuedMessage,
} from "./types.js";

View file

@ -1,2 +0,0 @@
export { calculate, calculateTool } from "./calculate.js";
export { getCurrentTime, getCurrentTimeTool } from "./get-current-time.js";

View file

@ -1,105 +0,0 @@
import type { Static, TSchema } from "@sinclair/typebox";
import type {
AssistantMessage,
AssistantMessageEvent,
ImageContent,
Message,
Model,
SimpleStreamOptions,
TextContent,
Tool,
ToolResultMessage,
} from "../types.js";
// Result returned by an AgentTool's execute(): UI-renderable content plus structured details.
export interface AgentToolResult<T> {
	// Content blocks supporting text and images
	content: (TextContent | ImageContent)[];
	// Details to be displayed in a UI or logged
	details: T;
}
// Callback for streaming tool execution updates
// (invoked with a partial result while execute() is still running)
export type AgentToolUpdateCallback<T = any> = (partialResult: AgentToolResult<T>) => void;
// AgentTool extends Tool but adds the execute function
export interface AgentTool<TParameters extends TSchema = TSchema, TDetails = any> extends Tool<TParameters> {
	// A human-readable label for the tool to be displayed in UI
	label: string;
	// Runs the tool. `signal` allows cancellation; `onUpdate` streams partial results.
	execute: (
		toolCallId: string,
		params: Static<TParameters>,
		signal?: AbortSignal,
		onUpdate?: AgentToolUpdateCallback<TDetails>,
	) => Promise<AgentToolResult<TDetails>>;
}
// AgentContext is like Context but uses AgentTool
export interface AgentContext {
	// System prompt used for every LLM call
	systemPrompt: string;
	// Conversation history (user, assistant and tool result messages)
	messages: Message[];
	// Tools available to the agent (optional)
	tools?: AgentTool<any>[];
}
// Event types emitted while the agent loop runs
export type AgentEvent =
	// Emitted when the agent starts. An agent can emit multiple turns
	| { type: "agent_start" }
	// Emitted when a turn starts. A turn can emit an optional user message (initial prompt), an assistant message (response) and multiple tool result messages
	| { type: "turn_start" }
	// Emitted when a user, assistant or tool result message starts
	| { type: "message_start"; message: Message }
	// Emitted when an assistant message is updated due to streaming
	| { type: "message_update"; assistantMessageEvent: AssistantMessageEvent; message: AssistantMessage }
	// Emitted when a user, assistant or tool result message is complete
	| { type: "message_end"; message: Message }
	// Emitted when a tool execution starts
	| { type: "tool_execution_start"; toolCallId: string; toolName: string; args: any }
	// Emitted when a tool execution produces output (streaming)
	| {
			type: "tool_execution_update";
			toolCallId: string;
			toolName: string;
			args: any;
			partialResult: AgentToolResult<any>;
	  }
	// Emitted when a tool execution completes
	| {
			type: "tool_execution_end";
			toolCallId: string;
			toolName: string;
			result: AgentToolResult<any>;
			isError: boolean;
	  }
	// Emitted when a full turn completes
	| { type: "turn_end"; message: AssistantMessage; toolResults: ToolResultMessage[] }
	// Emitted when the agent has completed all its turns. All messages from every turn are
	// contained in messages, which can be appended to the context
	| { type: "agent_end"; messages: AgentContext["messages"] };
// Queued message with optional LLM representation
export interface QueuedMessage<TApp = Message> {
	original: TApp; // Original message for UI events
	llm?: Message; // Optional transformed message for loop context (undefined if filtered)
}
// Configuration for agent loop execution
export interface AgentLoopConfig extends SimpleStreamOptions {
	// Model used for every assistant response in the loop
	model: Model<any>;
	/**
	 * Optional hook to resolve an API key dynamically for each LLM call.
	 *
	 * This is useful for short-lived OAuth tokens (e.g. GitHub Copilot) that may
	 * expire during long-running tool execution phases.
	 *
	 * The agent loop will call this before each assistant response and pass the
	 * returned value as `apiKey` to `streamSimple()` (or a custom `streamFn`).
	 *
	 * If it returns `undefined`, the loop falls back to `config.apiKey`, and then
	 * to `streamSimple()`'s own provider key lookup (setApiKey/env vars).
	 */
	getApiKey?: (provider: string) => Promise<string | undefined> | string | undefined;
	// Optional async transform applied to the message history before each LLM call
	preprocessor?: (messages: AgentContext["messages"], abortSignal?: AbortSignal) => Promise<AgentContext["messages"]>;
	// Optional hook polled during the loop for messages queued while the agent is running
	getQueuedMessages?: <T>() => Promise<QueuedMessage<T>[]>;
}

View file

@ -1,4 +1,3 @@
export * from "./agent/index.js";
export * from "./models.js";
export * from "./providers/anthropic.js";
export * from "./providers/google.js";
@ -7,6 +6,7 @@ export * from "./providers/openai-completions.js";
export * from "./providers/openai-responses.js";
export * from "./stream.js";
export * from "./types.js";
export * from "./utils/event-stream.js";
export * from "./utils/oauth/index.js";
export * from "./utils/overflow.js";
export * from "./utils/typebox-helpers.js";

View file

@ -3325,13 +3325,13 @@ export const MODELS = {
reasoning: true,
input: ["text"],
cost: {
input: 0.224,
output: 0.32,
input: 0.25,
output: 0.38,
cacheRead: 0,
cacheWrite: 0,
},
contextWindow: 163840,
maxTokens: 4096,
maxTokens: 65536,
} satisfies Model<"openai-completions">,
"deepseek/deepseek-v3.2-exp": {
id: "deepseek/deepseek-v3.2-exp",
@ -3892,7 +3892,7 @@ export const MODELS = {
cacheWrite: 0,
},
contextWindow: 196608,
maxTokens: 131072,
maxTokens: 65536,
} satisfies Model<"openai-completions">,
"minimax/minimax-m2.1": {
id: "minimax/minimax-m2.1",
@ -5371,7 +5371,7 @@ export const MODELS = {
cacheWrite: 0,
},
contextWindow: 131072,
maxTokens: 128000,
maxTokens: 131072,
} satisfies Model<"openai-completions">,
"openai/gpt-oss-safeguard-20b": {
id: "openai/gpt-oss-safeguard-20b",
@ -6249,8 +6249,8 @@ export const MODELS = {
reasoning: true,
input: ["text"],
cost: {
input: 0.3,
output: 1.2,
input: 0.25,
output: 0.85,
cacheRead: 0,
cacheWrite: 0,
},
@ -6266,8 +6266,8 @@ export const MODELS = {
reasoning: true,
input: ["text"],
cost: {
input: 0.3,
output: 1.2,
input: 0.25,
output: 0.85,
cacheRead: 0,
cacheWrite: 0,
},
@ -6538,13 +6538,13 @@ export const MODELS = {
reasoning: true,
input: ["text"],
cost: {
input: 0.39,
output: 1.9,
input: 0.35,
output: 1.5,
cacheRead: 0,
cacheWrite: 0,
},
contextWindow: 204800,
maxTokens: 204800,
contextWindow: 202752,
maxTokens: 65536,
} satisfies Model<"openai-completions">,
"z-ai/glm-4.6:exacto": {
id: "z-ai/glm-4.6:exacto",

View file

@ -4,7 +4,7 @@
* Uses the Cloud Code Assist API endpoint to access Gemini and Claude models.
*/
import type { Content, ThinkingConfig, ThinkingLevel } from "@google/genai";
import type { Content, ThinkingConfig } from "@google/genai";
import { calculateCost } from "../models.js";
import type {
Api,
@ -21,6 +21,12 @@ import { AssistantMessageEventStream } from "../utils/event-stream.js";
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
import { convertMessages, convertTools, mapStopReasonString, mapToolChoice } from "./google-shared.js";
/**
* Thinking level for Gemini 3 models.
* Mirrors Google's ThinkingLevel enum values.
*/
export type GoogleThinkingLevel = "THINKING_LEVEL_UNSPECIFIED" | "MINIMAL" | "LOW" | "MEDIUM" | "HIGH";
export interface GoogleGeminiCliOptions extends StreamOptions {
toolChoice?: "auto" | "none" | "any";
/**
@ -35,7 +41,7 @@ export interface GoogleGeminiCliOptions extends StreamOptions {
/** Thinking budget in tokens. Use for Gemini 2.x models. */
budgetTokens?: number;
/** Thinking level. Use for Gemini 3 models (LOW/HIGH for Pro, MINIMAL/LOW/MEDIUM/HIGH for Flash). */
level?: ThinkingLevel;
level?: GoogleThinkingLevel;
};
projectId?: string;
}
@ -436,7 +442,8 @@ function buildRequest(
};
// Gemini 3 models use thinkingLevel, older models use thinkingBudget
if (options.thinking.level !== undefined) {
generationConfig.thinkingConfig.thinkingLevel = options.thinking.level;
// Cast to any since our GoogleThinkingLevel mirrors Google's ThinkingLevel enum values
generationConfig.thinkingConfig.thinkingLevel = options.thinking.level as any;
} else if (options.thinking.budgetTokens !== undefined) {
generationConfig.thinkingConfig.thinkingBudget = options.thinking.budgetTokens;
}

View file

@ -3,7 +3,6 @@ import {
type GenerateContentParameters,
GoogleGenAI,
type ThinkingConfig,
type ThinkingLevel,
} from "@google/genai";
import { calculateCost } from "../models.js";
import { getEnvApiKey } from "../stream.js";
@ -20,6 +19,7 @@ import type {
} from "../types.js";
import { AssistantMessageEventStream } from "../utils/event-stream.js";
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
import type { GoogleThinkingLevel } from "./google-gemini-cli.js";
import { convertMessages, convertTools, mapStopReason, mapToolChoice } from "./google-shared.js";
export interface GoogleOptions extends StreamOptions {
@ -27,7 +27,7 @@ export interface GoogleOptions extends StreamOptions {
thinking?: {
enabled: boolean;
budgetTokens?: number; // -1 for dynamic, 0 to disable
level?: ThinkingLevel;
level?: GoogleThinkingLevel;
};
}
@ -299,7 +299,8 @@ function buildParams(
if (options.thinking?.enabled && model.reasoning) {
const thinkingConfig: ThinkingConfig = { includeThoughts: true };
if (options.thinking.level !== undefined) {
thinkingConfig.thinkingLevel = options.thinking.level;
// Cast to any since our GoogleThinkingLevel mirrors Google's ThinkingLevel enum values
thinkingConfig.thinkingLevel = options.thinking.level as any;
} else if (options.thinking.budgetTokens !== undefined) {
thinkingConfig.thinkingBudget = options.thinking.budgetTokens;
}

View file

@ -1,8 +1,11 @@
import { ThinkingLevel } from "@google/genai";
import { supportsXhigh } from "./models.js";
import { type AnthropicOptions, streamAnthropic } from "./providers/anthropic.js";
import { type GoogleOptions, streamGoogle } from "./providers/google.js";
import { type GoogleGeminiCliOptions, streamGoogleGeminiCli } from "./providers/google-gemini-cli.js";
import {
type GoogleGeminiCliOptions,
type GoogleThinkingLevel,
streamGoogleGeminiCli,
} from "./providers/google-gemini-cli.js";
import { type OpenAICompletionsOptions, streamOpenAICompletions } from "./providers/openai-completions.js";
import { type OpenAIResponsesOptions, streamOpenAIResponses } from "./providers/openai-responses.js";
import type {
@ -30,9 +33,13 @@ export function getEnvApiKey(provider: any): string | undefined {
return process.env.COPILOT_GITHUB_TOKEN || process.env.GH_TOKEN || process.env.GITHUB_TOKEN;
}
// ANTHROPIC_OAUTH_TOKEN takes precedence over ANTHROPIC_API_KEY
if (provider === "anthropic") {
return process.env.ANTHROPIC_OAUTH_TOKEN || process.env.ANTHROPIC_API_KEY;
}
const envMap: Record<string, string> = {
openai: "OPENAI_API_KEY",
anthropic: "ANTHROPIC_API_KEY",
google: "GEMINI_API_KEY",
groq: "GROQ_API_KEY",
cerebras: "CEREBRAS_API_KEY",
@ -252,53 +259,56 @@ function isGemini3FlashModel(model: Model<"google-generative-ai">): boolean {
return model.id.includes("3-flash");
}
function getGemini3ThinkingLevel(effort: ClampedReasoningEffort, model: Model<"google-generative-ai">): ThinkingLevel {
function getGemini3ThinkingLevel(
effort: ClampedReasoningEffort,
model: Model<"google-generative-ai">,
): GoogleThinkingLevel {
if (isGemini3ProModel(model)) {
// Gemini 3 Pro only supports LOW/HIGH (for now)
switch (effort) {
case "minimal":
case "low":
return ThinkingLevel.LOW;
return "LOW";
case "medium":
case "high":
return ThinkingLevel.HIGH;
return "HIGH";
}
}
// Gemini 3 Flash supports all four levels
switch (effort) {
case "minimal":
return ThinkingLevel.MINIMAL;
return "MINIMAL";
case "low":
return ThinkingLevel.LOW;
return "LOW";
case "medium":
return ThinkingLevel.MEDIUM;
return "MEDIUM";
case "high":
return ThinkingLevel.HIGH;
return "HIGH";
}
}
function getGeminiCliThinkingLevel(effort: ClampedReasoningEffort, modelId: string): ThinkingLevel {
function getGeminiCliThinkingLevel(effort: ClampedReasoningEffort, modelId: string): GoogleThinkingLevel {
if (modelId.includes("3-pro")) {
// Gemini 3 Pro only supports LOW/HIGH (for now)
switch (effort) {
case "minimal":
case "low":
return ThinkingLevel.LOW;
return "LOW";
case "medium":
case "high":
return ThinkingLevel.HIGH;
return "HIGH";
}
}
// Gemini 3 Flash supports all four levels
switch (effort) {
case "minimal":
return ThinkingLevel.MINIMAL;
return "MINIMAL";
case "low":
return ThinkingLevel.LOW;
return "LOW";
case "medium":
return ThinkingLevel.MEDIUM;
return "MEDIUM";
case "high":
return ThinkingLevel.HIGH;
return "HIGH";
}
}

View file

@ -2,25 +2,16 @@
* Anthropic OAuth flow (Claude Pro/Max)
*/
import { createHash, randomBytes } from "crypto";
import { generatePKCE } from "./pkce.js";
import type { OAuthCredentials } from "./types.js";
const decode = (s: string) => Buffer.from(s, "base64").toString();
const decode = (s: string) => atob(s);
const CLIENT_ID = decode("OWQxYzI1MGEtZTYxYi00NGQ5LTg4ZWQtNTk0NGQxOTYyZjVl");
const AUTHORIZE_URL = "https://claude.ai/oauth/authorize";
const TOKEN_URL = "https://console.anthropic.com/v1/oauth/token";
const REDIRECT_URI = "https://console.anthropic.com/oauth/code/callback";
const SCOPES = "org:create_api_key user:profile user:inference";
/**
* Generate PKCE code verifier and challenge
*/
function generatePKCE(): { verifier: string; challenge: string } {
const verifier = randomBytes(32).toString("base64url");
const challenge = createHash("sha256").update(verifier).digest("base64url");
return { verifier, challenge };
}
/**
* Login with Anthropic OAuth (device code flow)
*
@ -31,7 +22,7 @@ export async function loginAnthropic(
onAuthUrl: (url: string) => void,
onPromptCode: () => Promise<string>,
): Promise<OAuthCredentials> {
const { verifier, challenge } = generatePKCE();
const { verifier, challenge } = await generatePKCE();
// Build authorization URL
const authParams = new URLSearchParams({

View file

@ -5,7 +5,7 @@
import { getModels } from "../../models.js";
import type { OAuthCredentials } from "./types.js";
const decode = (s: string) => Buffer.from(s, "base64").toString();
const decode = (s: string) => atob(s);
const CLIENT_ID = decode("SXYxLmI1MDdhMDhjODdlY2ZlOTg=");
const COPILOT_HEADERS = {

View file

@ -1,14 +1,17 @@
/**
* Antigravity OAuth flow (Gemini 3, Claude, GPT-OSS via Google Cloud)
* Uses different OAuth credentials than google-gemini-cli for access to additional models.
*
* NOTE: This module uses Node.js http.createServer for the OAuth callback.
* It is only intended for CLI use, not browser environments.
*/
import { createHash, randomBytes } from "crypto";
import { createServer, type Server } from "http";
import type { Server } from "http";
import { generatePKCE } from "./pkce.js";
import type { OAuthCredentials } from "./types.js";
// Antigravity OAuth credentials (different from Gemini CLI)
const decode = (s: string) => Buffer.from(s, "base64").toString();
const decode = (s: string) => atob(s);
const CLIENT_ID = decode(
"MTA3MTAwNjA2MDU5MS10bWhzc2luMmgyMWxjcmUyMzV2dG9sb2poNGc0MDNlcC5hcHBzLmdvb2dsZXVzZXJjb250ZW50LmNvbQ==",
);
@ -30,19 +33,15 @@ const TOKEN_URL = "https://oauth2.googleapis.com/token";
// Fallback project ID when discovery fails
const DEFAULT_PROJECT_ID = "rising-fact-p41fc";
/**
* Generate PKCE code verifier and challenge
*/
function generatePKCE(): { verifier: string; challenge: string } {
const verifier = randomBytes(32).toString("base64url");
const challenge = createHash("sha256").update(verifier).digest("base64url");
return { verifier, challenge };
}
/**
* Start a local HTTP server to receive the OAuth callback
*/
function startCallbackServer(): Promise<{ server: Server; getCode: () => Promise<{ code: string; state: string }> }> {
async function startCallbackServer(): Promise<{
server: Server;
getCode: () => Promise<{ code: string; state: string }>;
}> {
const { createServer } = await import("http");
return new Promise((resolve, reject) => {
let codeResolve: (value: { code: string; state: string }) => void;
let codeReject: (error: Error) => void;
@ -232,7 +231,7 @@ export async function loginAntigravity(
onAuth: (info: { url: string; instructions?: string }) => void,
onProgress?: (message: string) => void,
): Promise<OAuthCredentials> {
const { verifier, challenge } = generatePKCE();
const { verifier, challenge } = await generatePKCE();
// Start local server for callback
onProgress?.("Starting local server for OAuth callback...");

View file

@ -1,13 +1,16 @@
/**
* Gemini CLI OAuth flow (Google Cloud Code Assist)
* Standard Gemini models only (gemini-2.0-flash, gemini-2.5-*)
*
* NOTE: This module uses Node.js http.createServer for the OAuth callback.
* It is only intended for CLI use, not browser environments.
*/
import { createHash, randomBytes } from "crypto";
import { createServer, type Server } from "http";
import type { Server } from "http";
import { generatePKCE } from "./pkce.js";
import type { OAuthCredentials } from "./types.js";
const decode = (s: string) => Buffer.from(s, "base64").toString();
const decode = (s: string) => atob(s);
const CLIENT_ID = decode(
"NjgxMjU1ODA5Mzk1LW9vOGZ0Mm9wcmRybnA5ZTNhcWY2YXYzaG1kaWIxMzVqLmFwcHMuZ29vZ2xldXNlcmNvbnRlbnQuY29t",
);
@ -22,19 +25,15 @@ const AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth";
const TOKEN_URL = "https://oauth2.googleapis.com/token";
const CODE_ASSIST_ENDPOINT = "https://cloudcode-pa.googleapis.com";
/**
* Generate PKCE code verifier and challenge
*/
function generatePKCE(): { verifier: string; challenge: string } {
const verifier = randomBytes(32).toString("base64url");
const challenge = createHash("sha256").update(verifier).digest("base64url");
return { verifier, challenge };
}
/**
* Start a local HTTP server to receive the OAuth callback
*/
function startCallbackServer(): Promise<{ server: Server; getCode: () => Promise<{ code: string; state: string }> }> {
async function startCallbackServer(): Promise<{
server: Server;
getCode: () => Promise<{ code: string; state: string }>;
}> {
const { createServer } = await import("http");
return new Promise((resolve, reject) => {
let codeResolve: (value: { code: string; state: string }) => void;
let codeReject: (error: Error) => void;
@ -263,7 +262,7 @@ export async function loginGeminiCli(
onAuth: (info: { url: string; instructions?: string }) => void,
onProgress?: (message: string) => void,
): Promise<OAuthCredentials> {
const { verifier, challenge } = generatePKCE();
const { verifier, challenge } = await generatePKCE();
// Start local server for callback
onProgress?.("Starting local server for OAuth callback...");

View file

@ -0,0 +1,34 @@
/**
* PKCE utilities using Web Crypto API.
* Works in both Node.js 20+ and browsers.
*/
/**
 * Encode bytes as base64url string (RFC 4648 §5 alphabet, padding stripped).
 */
function base64urlEncode(bytes: Uint8Array): string {
	const binary = Array.from(bytes, (b) => String.fromCharCode(b)).join("");
	// btoa yields standard base64; translate '+' -> '-', '/' -> '_', drop '=' padding.
	return btoa(binary).replace(/[+/=]/g, (ch) => (ch === "+" ? "-" : ch === "/" ? "_" : ""));
}
/**
 * Generate PKCE code verifier and challenge.
 * Uses Web Crypto API for cross-platform compatibility.
 */
export async function generatePKCE(): Promise<{ verifier: string; challenge: string }> {
	// 32 random bytes -> 43-char base64url verifier
	const randomBytes = crypto.getRandomValues(new Uint8Array(32));
	const verifier = base64urlEncode(randomBytes);
	// challenge = BASE64URL(SHA-256(ASCII(verifier))) per the S256 method
	const digest = await crypto.subtle.digest("SHA-256", new TextEncoder().encode(verifier));
	return { verifier, challenge: base64urlEncode(new Uint8Array(digest)) };
}

View file

@ -1,166 +0,0 @@
import { Type } from "@sinclair/typebox";
import { describe, expect, it } from "vitest";
import { agentLoop } from "../src/agent/agent-loop.js";
import type { AgentContext, AgentEvent, AgentLoopConfig, AgentTool, QueuedMessage } from "../src/agent/types.js";
import type { AssistantMessage, Message, Model, UserMessage } from "../src/types.js";
import { AssistantMessageEventStream } from "../src/utils/event-stream.js";
// Zeroed-out token usage stats for mock assistant messages.
function createUsage() {
	const zeroed = { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 };
	return {
		...zeroed,
		totalTokens: 0,
		cost: { ...zeroed, total: 0 },
	};
}
// Minimal mock model descriptor targeting the openai-responses API;
// baseUrl is intentionally unreachable since the tests stub the stream function.
function createModel(): Model<"openai-responses"> {
	const mockModel: Model<"openai-responses"> = {
		id: "mock",
		name: "mock",
		api: "openai-responses",
		provider: "openai",
		baseUrl: "https://example.invalid",
		reasoning: false,
		input: ["text"],
		cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
		contextWindow: 8192,
		maxTokens: 2048,
	};
	return mockModel;
}
describe("agentLoop queued message interrupt", () => {
	it("injects queued messages after a tool call and skips remaining tool calls", async () => {
		const toolSchema = Type.Object({ value: Type.String() });
		// Records which tool calls actually executed (the second one should be skipped).
		const executed: string[] = [];
		const tool: AgentTool<typeof toolSchema, { value: string }> = {
			name: "echo",
			label: "Echo",
			description: "Echo tool",
			parameters: toolSchema,
			async execute(_toolCallId, params) {
				executed.push(params.value);
				return {
					content: [{ type: "text", text: `ok:${params.value}` }],
					details: { value: params.value },
				};
			},
		};
		const context: AgentContext = {
			systemPrompt: "",
			messages: [],
			tools: [tool],
		};
		const userPrompt: UserMessage = {
			role: "user",
			content: "start",
			timestamp: Date.now(),
		};
		// The message "queued" mid-run; it should be injected after the first tool result.
		const queuedUserMessage: Message = {
			role: "user",
			content: "interrupt",
			timestamp: Date.now(),
		};
		const queuedMessages: QueuedMessage<Message>[] = [{ original: queuedUserMessage, llm: queuedUserMessage }];
		let queuedDelivered = false;
		let sawInterruptInContext = false;
		let callIndex = 0;
		// Mock LLM stream: first call requests two parallel tool calls, second call finishes with text.
		const streamFn = () => {
			const stream = new AssistantMessageEventStream();
			queueMicrotask(() => {
				if (callIndex === 0) {
					const message: AssistantMessage = {
						role: "assistant",
						content: [
							{ type: "toolCall", id: "tool-1", name: "echo", arguments: { value: "first" } },
							{ type: "toolCall", id: "tool-2", name: "echo", arguments: { value: "second" } },
						],
						api: "openai-responses",
						provider: "openai",
						model: "mock",
						usage: createUsage(),
						stopReason: "toolUse",
						timestamp: Date.now(),
					};
					stream.push({ type: "done", reason: "toolUse", message });
				} else {
					const message: AssistantMessage = {
						role: "assistant",
						content: [{ type: "text", text: "done" }],
						api: "openai-responses",
						provider: "openai",
						model: "mock",
						usage: createUsage(),
						stopReason: "stop",
						timestamp: Date.now(),
					};
					stream.push({ type: "done", reason: "stop", message });
				}
				callIndex += 1;
			});
			return stream;
		};
		// Deliver the queued message exactly once, right after the first tool has executed.
		const getQueuedMessages: AgentLoopConfig["getQueuedMessages"] = async <T>() => {
			if (executed.length === 1 && !queuedDelivered) {
				queuedDelivered = true;
				return queuedMessages as QueuedMessage<T>[];
			}
			return [];
		};
		const config: AgentLoopConfig = {
			model: createModel(),
			getQueuedMessages,
		};
		const events: AgentEvent[] = [];
		const stream = agentLoop(userPrompt, context, config, undefined, (_model, ctx, _options) => {
			// On the second LLM call, the injected "interrupt" message must already be in the context.
			if (callIndex === 1) {
				sawInterruptInContext = ctx.messages.some(
					(m) => m.role === "user" && typeof m.content === "string" && m.content === "interrupt",
				);
			}
			return streamFn();
		});
		for await (const event of stream) {
			events.push(event);
		}
		// Only the first tool call ran; the second was skipped because of the queued message.
		expect(executed).toEqual(["first"]);
		const toolEnds = events.filter(
			(event): event is Extract<AgentEvent, { type: "tool_execution_end" }> => event.type === "tool_execution_end",
		);
		// Both tool calls still produce tool_execution_end events; the skipped one is an error result.
		expect(toolEnds.length).toBe(2);
		expect(toolEnds[1].isError).toBe(true);
		expect(toolEnds[1].result.content[0]?.type).toBe("text");
		if (toolEnds[1].result.content[0]?.type === "text") {
			expect(toolEnds[1].result.content[0].text).toContain("Skipped due to queued user message");
		}
		// Event ordering must be: first turn_end -> queued user message -> next assistant message.
		const firstTurnEndIndex = events.findIndex((event) => event.type === "turn_end");
		const queuedMessageIndex = events.findIndex(
			(event) =>
				event.type === "message_start" &&
				event.message.role === "user" &&
				typeof event.message.content === "string" &&
				event.message.content === "interrupt",
		);
		const nextAssistantIndex = events.findIndex(
			(event, index) =>
				index > queuedMessageIndex && event.type === "message_start" && event.message.role === "assistant",
		);
		expect(queuedMessageIndex).toBeGreaterThan(firstTurnEndIndex);
		expect(queuedMessageIndex).toBeLessThan(nextAssistantIndex);
		expect(sawInterruptInContext).toBe(true);
	});
});

View file

@ -1,701 +0,0 @@
import { describe, expect, it } from "vitest";
import { agentLoop, agentLoopContinue } from "../src/agent/agent-loop.js";
import { calculateTool } from "../src/agent/tools/calculate.js";
import type { AgentContext, AgentEvent, AgentLoopConfig } from "../src/agent/types.js";
import { getModel } from "../src/models.js";
import type {
Api,
AssistantMessage,
Message,
Model,
OptionsForApi,
ToolResultMessage,
UserMessage,
} from "../src/types.js";
import { resolveApiKey } from "./oauth.js";
// Resolve OAuth tokens at module level (async, runs before tests)
// NOTE(review): uses top-level await — requires an ESM-aware runner (vitest); confirm if reused elsewhere.
const oauthTokens = await Promise.all([
	resolveApiKey("anthropic"),
	resolveApiKey("github-copilot"),
	resolveApiKey("google-gemini-cli"),
	resolveApiKey("google-antigravity"),
]);
// Destructured in the same order as the Promise.all array above
const [anthropicOAuthToken, githubCopilotToken, geminiCliToken, antigravityToken] = oauthTokens;
/**
 * Drives an agent through a multi-step calculator task against a real provider
 * and verifies turn counts, tool-call counts, and the final numeric answer.
 *
 * NOTE: performs live LLM calls for `model`; requires valid provider credentials.
 * (Fixes a typo in the prompt: "mulit-step" -> "multi-step".)
 */
async function calculateTest<TApi extends Api>(model: Model<TApi>, options: OptionsForApi<TApi> = {}) {
	// Create the agent context with the calculator tool
	const context: AgentContext = {
		systemPrompt:
			"You are a helpful assistant that performs mathematical calculations. When asked to calculate multiple expressions, you can use parallel tool calls if the model supports it. In your final answer, output ONLY the final sum as a single integer number, nothing else.",
		messages: [],
		tools: [calculateTool],
	};
	// Create the prompt config
	const config: AgentLoopConfig = {
		model,
		...options,
	};
	// Create the user prompt asking for multiple calculations
	const userPrompt: UserMessage = {
		role: "user",
		content: `Use the calculator tool to complete the following multi-step task.
1. Calculate 3485 * 4234 and 88823 * 3482 in parallel
2. Calculate the sum of the two results using the calculator tool
3. Output ONLY the final sum as a single integer number, nothing else.`,
		timestamp: Date.now(),
	};
	// Calculate expected results (using integers)
	const expectedFirst = 3485 * 4234; // = 14755490
	const expectedSecond = 88823 * 3482; // = 309281786
	const expectedSum = expectedFirst + expectedSecond; // = 324037276
	// Track events for verification
	const events: AgentEvent[] = [];
	let turns = 0;
	let toolCallCount = 0;
	const toolResults: number[] = [];
	let finalAnswer: number | undefined;
	// Execute the prompt
	const stream = agentLoop(userPrompt, context, config);
	for await (const event of stream) {
		events.push(event);
		switch (event.type) {
			case "turn_start":
				turns++;
				console.log(`\n=== Turn ${turns} started ===`);
				break;
			case "turn_end":
				console.log(`=== Turn ${turns} ended with ${event.toolResults.length} tool results ===`);
				console.log(event.message);
				break;
			case "tool_execution_end":
				if (!event.isError && typeof event.result === "object" && event.result.content) {
					const textOutput = event.result.content
						.filter((c: any) => c.type === "text")
						.map((c: any) => c.text)
						.join("\n");
					toolCallCount++;
					// Extract number from output like "expression = result"
					const match = textOutput.match(/=\s*([\d.]+)/);
					if (match) {
						const value = parseFloat(match[1]);
						toolResults.push(value);
						console.log(`Tool ${toolCallCount}: ${textOutput}`);
					}
				}
				break;
			case "message_end":
				// Just track the message end event, don't extract answer here
				break;
		}
	}
	// Get the final messages
	const finalMessages = await stream.result();
	// Verify the results
	expect(finalMessages).toBeDefined();
	expect(finalMessages.length).toBeGreaterThan(0);
	const finalMessage = finalMessages[finalMessages.length - 1];
	expect(finalMessage).toBeDefined();
	expect(finalMessage.role).toBe("assistant");
	if (finalMessage.role !== "assistant") throw new Error("Final message is not from assistant");
	// Extract the final answer from the last assistant message
	const content = finalMessage.content
		.filter((c) => c.type === "text")
		.map((c) => (c.type === "text" ? c.text : ""))
		.join(" ");
	// Look for integers in the response that might be the final answer
	const numbers = content.match(/\b\d+\b/g);
	if (numbers) {
		// Check if any of the numbers matches our expected sum
		for (const num of numbers) {
			const value = parseInt(num, 10);
			if (Math.abs(value - expectedSum) < 10) {
				finalAnswer = value;
				break;
			}
		}
		// If no exact match, take the last large number as likely the answer
		if (finalAnswer === undefined) {
			const largeNumbers = numbers.map((n) => parseInt(n, 10)).filter((n) => n > 1000000);
			if (largeNumbers.length > 0) {
				finalAnswer = largeNumbers[largeNumbers.length - 1];
			}
		}
	}
	// Should have executed at least 3 tool calls: 2 for the initial calculations, 1 for the sum
	// (or possibly 2 if the model calculates the sum itself without a tool)
	expect(toolCallCount).toBeGreaterThanOrEqual(2);
	// Must be at least 3 turns: first to calculate the expressions, then to sum them, then give the answer
	// Could be 3 turns if model does parallel calls, or 4 turns if sequential calculation of expressions
	expect(turns).toBeGreaterThanOrEqual(3);
	expect(turns).toBeLessThanOrEqual(4);
	// Verify the individual calculations are in the results
	const hasFirstCalc = toolResults.some((r) => r === expectedFirst);
	const hasSecondCalc = toolResults.some((r) => r === expectedSecond);
	expect(hasFirstCalc).toBe(true);
	expect(hasSecondCalc).toBe(true);
	// Verify the final sum
	if (finalAnswer !== undefined) {
		expect(finalAnswer).toBe(expectedSum);
		console.log(`Final answer: ${finalAnswer} (expected: ${expectedSum})`);
	} else {
		// If we couldn't extract the final answer from text, check if it's in the tool results
		const hasSum = toolResults.some((r) => r === expectedSum);
		expect(hasSum).toBe(true);
	}
	// Log summary
	console.log(`\nTest completed with ${turns} turns and ${toolCallCount} tool calls`);
	if (turns === 3) {
		console.log("Model used parallel tool calls for initial calculations");
	} else {
		console.log("Model used sequential tool calls");
	}
	return {
		turns,
		toolCallCount,
		toolResults,
		finalAnswer,
		events,
	};
}
/**
 * Verifies abort behavior: after aborting mid-run, the loop stops issuing further
 * tool calls and the final assistant message carries stopReason "aborted".
 *
 * NOTE: performs live LLM calls for `model`; requires valid provider credentials.
 */
async function abortTest<TApi extends Api>(model: Model<TApi>, options: OptionsForApi<TApi> = {}) {
	// Create the agent context with the calculator tool
	const context: AgentContext = {
		systemPrompt:
			"You are a helpful assistant that performs mathematical calculations. Always use the calculator tool for each calculation.",
		messages: [],
		tools: [calculateTool],
	};
	// Create the prompt config
	const config: AgentLoopConfig = {
		model,
		...options,
	};
	// Create a prompt that will require multiple calculations
	const userPrompt: UserMessage = {
		role: "user",
		content: "Calculate 100 * 200, then 300 * 400, then 500 * 600, then sum all three results.",
		timestamp: Date.now(),
	};
	// Create abort controller
	const abortController = new AbortController();
	// Track events for verification
	const events: AgentEvent[] = [];
	let toolCallCount = 0;
	const errorReceived = false; // NOTE(review): never set to true; kept only for the return shape
	let finalMessages: Message[] | undefined;
	// Execute the prompt
	const stream = agentLoop(userPrompt, context, config, abortController.signal);
	// Abort after first tool execution
	// (fire-and-forget consumer; stream.result() below resolves once the loop finishes)
	(async () => {
		for await (const event of stream) {
			events.push(event);
			if (event.type === "tool_execution_end" && !event.isError) {
				toolCallCount++;
				// Abort after first successful tool execution
				if (toolCallCount === 1) {
					console.log("Aborting after first tool execution");
					abortController.abort();
				}
			}
			if (event.type === "agent_end") {
				finalMessages = event.messages;
			}
		}
	})();
	finalMessages = await stream.result();
	// Verify abort behavior
	console.log(`\nAbort test completed with ${toolCallCount} tool calls`);
	const assistantMessage = finalMessages[finalMessages.length - 1];
	if (!assistantMessage) throw new Error("No final message received");
	expect(assistantMessage).toBeDefined();
	expect(assistantMessage.role).toBe("assistant");
	if (assistantMessage.role !== "assistant") throw new Error("Final message is not from assistant");
	// Should have executed 1 tool call before abort
	expect(toolCallCount).toBeGreaterThanOrEqual(1);
	expect(assistantMessage.stopReason).toBe("aborted");
	return {
		toolCallCount,
		events,
		errorReceived,
		finalMessages,
	};
}
describe("Agent Calculator Tests", () => {
describe.skipIf(!process.env.GEMINI_API_KEY)("Google Provider Agent", () => {
const model = getModel("google", "gemini-2.5-flash");
it("should calculate multiple expressions and sum the results", { retry: 3 }, async () => {
const result = await calculateTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
});
it("should handle abort during tool execution", { retry: 3 }, async () => {
const result = await abortTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
});
});
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Completions Provider Agent", () => {
const model = getModel("openai", "gpt-4o-mini");
it("should calculate multiple expressions and sum the results", { retry: 3 }, async () => {
const result = await calculateTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
});
it("should handle abort during tool execution", { retry: 3 }, async () => {
const result = await abortTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
});
});
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Responses Provider Agent", () => {
const model = getModel("openai", "gpt-5-mini");
it("should calculate multiple expressions and sum the results", { retry: 3 }, async () => {
const result = await calculateTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
});
it("should handle abort during tool execution", { retry: 3 }, async () => {
const result = await abortTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
});
});
describe.skipIf(!process.env.ANTHROPIC_API_KEY)("Anthropic Provider Agent", () => {
const model = getModel("anthropic", "claude-haiku-4-5");
it("should calculate multiple expressions and sum the results", { retry: 3 }, async () => {
const result = await calculateTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
});
it("should handle abort during tool execution", { retry: 3 }, async () => {
const result = await abortTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
});
});
describe.skipIf(!process.env.XAI_API_KEY)("xAI Provider Agent", () => {
const model = getModel("xai", "grok-3");
it("should calculate multiple expressions and sum the results", { retry: 3 }, async () => {
const result = await calculateTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
});
it("should handle abort during tool execution", { retry: 3 }, async () => {
const result = await abortTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
});
});
describe.skipIf(!process.env.GROQ_API_KEY)("Groq Provider Agent", () => {
const model = getModel("groq", "openai/gpt-oss-20b");
it("should calculate multiple expressions and sum the results", { retry: 3 }, async () => {
const result = await calculateTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
});
it("should handle abort during tool execution", { retry: 3 }, async () => {
const result = await abortTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
});
});
describe.skipIf(!process.env.CEREBRAS_API_KEY)("Cerebras Provider Agent", () => {
const model = getModel("cerebras", "gpt-oss-120b");
it("should calculate multiple expressions and sum the results", { retry: 3 }, async () => {
const result = await calculateTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
});
it("should handle abort during tool execution", { retry: 3 }, async () => {
const result = await abortTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
});
});
describe.skipIf(!process.env.ZAI_API_KEY)("zAI Provider Agent", () => {
const model = getModel("zai", "glm-4.5-air");
it("should calculate multiple expressions and sum the results", { retry: 3 }, async () => {
const result = await calculateTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
});
it("should handle abort during tool execution", { retry: 3 }, async () => {
const result = await abortTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
});
});
describe.skipIf(!process.env.MISTRAL_API_KEY)("Mistral Provider Agent", () => {
const model = getModel("mistral", "devstral-medium-latest");
it("should calculate multiple expressions and sum the results", { retry: 3 }, async () => {
const result = await calculateTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
});
it("should handle abort during tool execution", { retry: 3 }, async () => {
const result = await abortTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
});
});
// =========================================================================
// OAuth-based providers (credentials from ~/.pi/agent/oauth.json)
// =========================================================================
describe("Anthropic OAuth Provider Agent", () => {
const model = getModel("anthropic", "claude-haiku-4-5");
it.skipIf(!anthropicOAuthToken)(
"should calculate multiple expressions and sum the results",
{ retry: 3 },
async () => {
const result = await calculateTest(model, { apiKey: anthropicOAuthToken });
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
},
);
it.skipIf(!anthropicOAuthToken)("should handle abort during tool execution", { retry: 3 }, async () => {
const result = await abortTest(model, { apiKey: anthropicOAuthToken });
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
});
});
describe("GitHub Copilot Provider Agent", () => {
it.skipIf(!githubCopilotToken)(
"gpt-4o - should calculate multiple expressions and sum the results",
{ retry: 3 },
async () => {
const model = getModel("github-copilot", "gpt-4o");
const result = await calculateTest(model, { apiKey: githubCopilotToken });
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
},
);
it.skipIf(!githubCopilotToken)("gpt-4o - should handle abort during tool execution", { retry: 3 }, async () => {
const model = getModel("github-copilot", "gpt-4o");
const result = await abortTest(model, { apiKey: githubCopilotToken });
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
});
it.skipIf(!githubCopilotToken)(
"claude-sonnet-4 - should calculate multiple expressions and sum the results",
{ retry: 3 },
async () => {
const model = getModel("github-copilot", "claude-sonnet-4");
const result = await calculateTest(model, { apiKey: githubCopilotToken });
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
},
);
it.skipIf(!githubCopilotToken)(
"claude-sonnet-4 - should handle abort during tool execution",
{ retry: 3 },
async () => {
const model = getModel("github-copilot", "claude-sonnet-4");
const result = await abortTest(model, { apiKey: githubCopilotToken });
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
},
);
});
describe("Google Gemini CLI Provider Agent", () => {
it.skipIf(!geminiCliToken)(
"gemini-2.5-flash - should calculate multiple expressions and sum the results",
{ retry: 3 },
async () => {
const model = getModel("google-gemini-cli", "gemini-2.5-flash");
const result = await calculateTest(model, { apiKey: geminiCliToken });
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
},
);
it.skipIf(!geminiCliToken)(
"gemini-2.5-flash - should handle abort during tool execution",
{ retry: 3 },
async () => {
const model = getModel("google-gemini-cli", "gemini-2.5-flash");
const result = await abortTest(model, { apiKey: geminiCliToken });
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
},
);
});
describe("Google Antigravity Provider Agent", () => {
it.skipIf(!antigravityToken)(
"gemini-3-flash - should calculate multiple expressions and sum the results",
{ retry: 3 },
async () => {
const model = getModel("google-antigravity", "gemini-3-flash");
const result = await calculateTest(model, { apiKey: antigravityToken });
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
},
);
it.skipIf(!antigravityToken)(
"gemini-3-flash - should handle abort during tool execution",
{ retry: 3 },
async () => {
const model = getModel("google-antigravity", "gemini-3-flash");
const result = await abortTest(model, { apiKey: antigravityToken });
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
},
);
it.skipIf(!antigravityToken)(
"claude-sonnet-4-5 - should calculate multiple expressions and sum the results",
{ retry: 3 },
async () => {
const model = getModel("google-antigravity", "claude-sonnet-4-5");
const result = await calculateTest(model, { apiKey: antigravityToken });
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
},
);
it.skipIf(!antigravityToken)(
"claude-sonnet-4-5 - should handle abort during tool execution",
{ retry: 3 },
async () => {
const model = getModel("google-antigravity", "claude-sonnet-4-5");
const result = await abortTest(model, { apiKey: antigravityToken });
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
},
);
it.skipIf(!antigravityToken)(
"gpt-oss-120b-medium - should calculate multiple expressions and sum the results",
{ retry: 3 },
async () => {
const model = getModel("google-antigravity", "gpt-oss-120b-medium");
const result = await calculateTest(model, { apiKey: antigravityToken });
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
},
);
it.skipIf(!antigravityToken)(
"gpt-oss-120b-medium - should handle abort during tool execution",
{ retry: 3 },
async () => {
const model = getModel("google-antigravity", "gpt-oss-120b-medium");
const result = await abortTest(model, { apiKey: antigravityToken });
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
},
);
});
});
describe("agentLoopContinue", () => {
	describe("validation", () => {
		// These tests exercise only the synchronous argument checks performed
		// before a stream is started, so no API key / network is required.
		const model = getModel("anthropic", "claude-haiku-4-5");
		const baseContext: AgentContext = {
			systemPrompt: "You are a helpful assistant.",
			messages: [],
			tools: [],
		};
		const config: AgentLoopConfig = { model };
		it("should throw when context has no messages", () => {
			expect(() => agentLoopContinue(baseContext, config)).toThrow("Cannot continue: no messages in context");
		});
		it("should throw when last message is an assistant message", () => {
			// Minimal assistant fixture: usage/cost values are irrelevant to the
			// validation logic and are zeroed out; only `role` matters here.
			const assistantMessage: AssistantMessage = {
				role: "assistant",
				content: [{ type: "text", text: "Hello" }],
				api: "anthropic-messages",
				provider: "anthropic",
				model: "claude-haiku-4-5",
				usage: {
					input: 0,
					output: 0,
					cacheRead: 0,
					cacheWrite: 0,
					totalTokens: 0,
					cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
				},
				stopReason: "stop",
				timestamp: Date.now(),
			};
			const context: AgentContext = {
				...baseContext,
				messages: [assistantMessage],
			};
			expect(() => agentLoopContinue(context, config)).toThrow(
				"Cannot continue from message role: assistant. Expected 'user' or 'toolResult'.",
			);
		});
		// Note: "should not throw" tests for valid inputs are covered by the E2E tests below
		// which actually consume the stream and verify the output
	});
	describe.skipIf(!process.env.ANTHROPIC_API_KEY)("continue from user message", () => {
		const model = getModel("anthropic", "claude-haiku-4-5");
		it("should continue and get assistant response when last message is user", { retry: 3 }, async () => {
			const userMessage: UserMessage = {
				role: "user",
				content: [{ type: "text", text: "Say exactly: HELLO WORLD" }],
				timestamp: Date.now(),
			};
			const context: AgentContext = {
				systemPrompt: "You are a helpful assistant. Follow instructions exactly.",
				messages: [userMessage],
				tools: [],
			};
			const config: AgentLoopConfig = { model };
			const events: AgentEvent[] = [];
			const stream = agentLoopContinue(context, config);
			// Drain the stream fully before asking for the aggregate result.
			for await (const event of stream) {
				events.push(event);
			}
			const messages = await stream.result();
			// Should have gotten an assistant response
			expect(messages.length).toBe(1);
			expect(messages[0].role).toBe("assistant");
			// Verify event sequence - no user message events since we're continuing
			const eventTypes = events.map((e) => e.type);
			expect(eventTypes).toContain("agent_start");
			expect(eventTypes).toContain("turn_start");
			expect(eventTypes).toContain("message_start");
			expect(eventTypes).toContain("message_end");
			expect(eventTypes).toContain("turn_end");
			expect(eventTypes).toContain("agent_end");
			// Should NOT have user message events (that's the difference from agentLoop)
			const messageEndEvents = events.filter((e) => e.type === "message_end");
			expect(messageEndEvents.length).toBe(1); // Only assistant message
			expect((messageEndEvents[0] as any).message.role).toBe("assistant");
		});
	});
	describe.skipIf(!process.env.ANTHROPIC_API_KEY)("continue from tool result", () => {
		const model = getModel("anthropic", "claude-haiku-4-5");
		it("should continue processing after tool results", { retry: 3 }, async () => {
			// Simulate a conversation where:
			// 1. User asked to calculate something
			// 2. Assistant made a tool call
			// 3. Tool result is ready
			// 4. We continue from here
			const userMessage: UserMessage = {
				role: "user",
				content: [{ type: "text", text: "What is 5 + 3? Use the calculator." }],
				timestamp: Date.now(),
			};
			// Hand-built assistant turn containing the tool call; zeroed usage as
			// above. `stopReason: "toolUse"` is what a real tool-call turn ends with.
			const assistantMessage: AssistantMessage = {
				role: "assistant",
				content: [
					{ type: "text", text: "Let me calculate that for you." },
					{ type: "toolCall", id: "calc-1", name: "calculate", arguments: { expression: "5 + 3" } },
				],
				api: "anthropic-messages",
				provider: "anthropic",
				model: "claude-haiku-4-5",
				usage: {
					input: 0,
					output: 0,
					cacheRead: 0,
					cacheWrite: 0,
					totalTokens: 0,
					cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
				},
				stopReason: "toolUse",
				timestamp: Date.now(),
			};
			// The tool result's toolCallId must match the id of the call above.
			const toolResult: ToolResultMessage = {
				role: "toolResult",
				toolCallId: "calc-1",
				toolName: "calculate",
				content: [{ type: "text", text: "5 + 3 = 8" }],
				isError: false,
				timestamp: Date.now(),
			};
			const context: AgentContext = {
				systemPrompt: "You are a helpful assistant. After getting a calculation result, state the answer clearly.",
				messages: [userMessage, assistantMessage, toolResult],
				tools: [calculateTool],
			};
			const config: AgentLoopConfig = { model };
			const events: AgentEvent[] = [];
			const stream = agentLoopContinue(context, config);
			for await (const event of stream) {
				events.push(event);
			}
			const messages = await stream.result();
			// Should have gotten an assistant response
			expect(messages.length).toBeGreaterThanOrEqual(1);
			const lastMessage = messages[messages.length - 1];
			expect(lastMessage.role).toBe("assistant");
			// The assistant should mention the result (8)
			if (lastMessage.role === "assistant") {
				const textContent = lastMessage.content
					.filter((c) => c.type === "text")
					.map((c) => (c as any).text)
					.join(" ");
				expect(textContent).toMatch(/8/);
			}
		});
	});
});

View file

@ -752,13 +752,11 @@ describe("Generate E2E Tests", () => {
const llm = getModel("google-gemini-cli", "gemini-3-flash-preview");
it.skipIf(!geminiCliToken)("should handle thinking with thinkingLevel", { retry: 3 }, async () => {
const { ThinkingLevel } = await import("@google/genai");
await handleThinking(llm, { apiKey: geminiCliToken, thinking: { enabled: true, level: ThinkingLevel.LOW } });
await handleThinking(llm, { apiKey: geminiCliToken, thinking: { enabled: true, level: "LOW" } });
});
it.skipIf(!geminiCliToken)("should handle multi-turn with thinking and tools", { retry: 3 }, async () => {
const { ThinkingLevel } = await import("@google/genai");
await multiTurn(llm, { apiKey: geminiCliToken, thinking: { enabled: true, level: ThinkingLevel.MEDIUM } });
await multiTurn(llm, { apiKey: geminiCliToken, thinking: { enabled: true, level: "MEDIUM" } });
});
});
@ -778,17 +776,15 @@ describe("Generate E2E Tests", () => {
});
it.skipIf(!antigravityToken)("should handle thinking with thinkingLevel", { retry: 3 }, async () => {
const { ThinkingLevel } = await import("@google/genai");
// gemini-3-flash supports all four levels: MINIMAL, LOW, MEDIUM, HIGH
await handleThinking(llm, {
apiKey: antigravityToken,
thinking: { enabled: true, level: ThinkingLevel.LOW },
thinking: { enabled: true, level: "LOW" },
});
});
it.skipIf(!antigravityToken)("should handle multi-turn with thinking and tools", { retry: 3 }, async () => {
const { ThinkingLevel } = await import("@google/genai");
await multiTurn(llm, { apiKey: antigravityToken, thinking: { enabled: true, level: ThinkingLevel.MEDIUM } });
await multiTurn(llm, { apiKey: antigravityToken, thinking: { enabled: true, level: "MEDIUM" } });
});
it.skipIf(!antigravityToken)("should handle image input", { retry: 3 }, async () => {
@ -800,11 +796,10 @@ describe("Generate E2E Tests", () => {
const llm = getModel("google-antigravity", "gemini-3-pro-high");
it.skipIf(!antigravityToken)("should handle thinking with thinkingLevel HIGH", { retry: 3 }, async () => {
const { ThinkingLevel } = await import("@google/genai");
// gemini-3-pro only supports LOW/HIGH
await handleThinking(llm, {
apiKey: antigravityToken,
thinking: { enabled: true, level: ThinkingLevel.HIGH },
thinking: { enabled: true, level: "HIGH" },
});
});
});

View file

@ -1,140 +0,0 @@
import { type Static, Type } from "@sinclair/typebox";
import AjvModule from "ajv";
import addFormatsModule from "ajv-formats";
// Handle both default and named exports
const Ajv = (AjvModule as any).default || AjvModule;
const addFormats = (addFormatsModule as any).default || addFormatsModule;
import { describe, expect, it } from "vitest";
import type { AgentTool } from "../src/agent/types.js";
describe("Tool Validation with TypeBox and AJV", () => {
	// Define a test tool with TypeBox schema
	const testSchema = Type.Object({
		name: Type.String({ minLength: 1 }),
		age: Type.Integer({ minimum: 0, maximum: 150 }),
		email: Type.String({ format: "email" }),
		tags: Type.Optional(Type.Array(Type.String())),
	});
	type TestParams = Static<typeof testSchema>;
	const testTool: AgentTool<typeof testSchema, void> = {
		label: "Test Tool",
		name: "test_tool",
		description: "A test tool for validation",
		parameters: testSchema,
		execute: async (_toolCallId, args) => {
			return {
				content: [{ type: "text", text: `Processed: ${args.name}, ${args.age}, ${args.email}` }],
				details: undefined,
			};
		},
	};
	// Create AJV instance for validation
	const ajv = new Ajv({ allErrors: true });
	addFormats(ajv);
	// Compile the schema once instead of in every test: ajv.compile() is the
	// expensive step and the schema never changes. Sharing the validator is
	// safe because each test calls validate() and reads validate.errors
	// synchronously before any other test runs.
	const validate = ajv.compile(testTool.parameters);
	it("should validate correct input", () => {
		const validInput = {
			name: "John Doe",
			age: 30,
			email: "john@example.com",
			tags: ["developer", "typescript"],
		};
		// Validate with AJV
		const isValid = validate(validInput);
		expect(isValid).toBe(true);
	});
	it("should reject invalid email", () => {
		const invalidInput = {
			name: "John Doe",
			age: 30,
			email: "not-an-email",
		};
		const isValid = validate(invalidInput);
		expect(isValid).toBe(false);
		expect(validate.errors).toBeDefined();
	});
	it("should reject missing required fields", () => {
		// 'name' is required by the schema and intentionally omitted here.
		const invalidInput = {
			age: 30,
			email: "john@example.com",
		};
		const isValid = validate(invalidInput);
		expect(isValid).toBe(false);
		expect(validate.errors).toBeDefined();
	});
	it("should reject invalid age", () => {
		// Violates the `minimum: 0` constraint.
		const invalidInput = {
			name: "John Doe",
			age: -5,
			email: "john@example.com",
		};
		const isValid = validate(invalidInput);
		expect(isValid).toBe(false);
		expect(validate.errors).toBeDefined();
	});
	it("should format validation errors nicely", () => {
		// Every field violates a constraint; allErrors:true reports all of them.
		const invalidInput = {
			name: "",
			age: 200,
			email: "invalid",
		};
		const isValid = validate(invalidInput);
		expect(isValid).toBe(false);
		expect(validate.errors).toBeDefined();
		if (validate.errors) {
			const errors = validate.errors
				.map((err: any) => {
					const path = err.instancePath ? err.instancePath.substring(1) : err.params.missingProperty || "root";
					return ` - ${path}: ${err.message}`;
				})
				.join("\n");
			// AJV error messages are different from Zod
			expect(errors).toContain("name: must NOT have fewer than 1 characters");
			expect(errors).toContain("age: must be <= 150");
			expect(errors).toContain('email: must match format "email"');
		}
	});
	it("should have type-safe execute function", async () => {
		const validInput = {
			name: "John Doe",
			age: 30,
			email: "john@example.com",
		};
		// Validate and execute
		const isValid = validate(validInput);
		expect(isValid).toBe(true);
		const result = await testTool.execute("test-id", validInput as TestParams);
		const textOutput = result.content
			.filter((c: any) => c.type === "text")
			.map((c: any) => c.text)
			.join("\n");
		expect(textOutput).toBe("Processed: John Doe, 30, john@example.com");
		expect(result.details).toBeUndefined();
	});
});

View file

@ -2,13 +2,68 @@
## [Unreleased]
### Breaking Changes
- **Session tree structure (v2)**: Sessions now store entries as a tree with `id`/`parentId` fields, enabling in-place branching without creating new files. Existing v1 sessions are auto-migrated on load.
- **SessionManager API**:
- `saveXXX()` renamed to `appendXXX()` (e.g., `appendMessage`, `appendCompaction`)
- `branchInPlace()` renamed to `branch()`
- `reset()` renamed to `newSession()`
- `createBranchedSessionFromEntries(entries, index)` replaced with `createBranchedSession(leafId)`
- `saveCompaction(entry)` replaced with `appendCompaction(summary, firstKeptEntryId, tokensBefore)`
- `getEntries()` now excludes the session header (use `getHeader()` separately)
- New methods: `getTree()`, `getPath()`, `getLeafUuid()`, `getLeafEntry()`, `getEntry()`, `branchWithSummary()`
- New `appendCustomEntry(customType, data)` for hooks to store custom data (not in LLM context)
- New `appendCustomMessageEntry(customType, content, display, details?)` for hooks to inject messages into LLM context
- **Compaction API**:
- `CompactionEntry<T>` and `CompactionResult<T>` are now generic with optional `details?: T` for hook-specific data
- `compact()` now returns `CompactionResult` (`{ summary, firstKeptEntryId, tokensBefore, details? }`) instead of `CompactionEntry`
- `appendCompaction()` now accepts optional `details` parameter
- `CompactionEntry.firstKeptEntryIndex` replaced with `firstKeptEntryId`
- `prepareCompaction()` now returns `firstKeptEntryId` in its result
- **Hook types**:
- `SessionEventBase` no longer has `sessionManager`/`modelRegistry` - access them via `HookEventContext` instead
- `HookEventContext` now has `sessionManager` and `modelRegistry` (moved from events)
- `HookEventContext` no longer has `exec()` - use `pi.exec()` instead
- `HookCommandContext` no longer has `exec()` - use `pi.exec()` instead
- `before_compact` event passes `preparation: CompactionPreparation` and `previousCompactions: CompactionEntry[]` (newest first)
- `before_switch` event now has `targetSessionFile`, `switch` event has `previousSessionFile`
- Removed `resolveApiKey` (use `modelRegistry.getApiKey(model)`)
- Hooks can return `compaction.details` to store custom data (e.g., ArtifactIndex for structured compaction)
- **Hook API**:
- `pi.send(text, attachments?)` replaced with `pi.sendMessage(message, triggerTurn?)` which creates `CustomMessageEntry` instead of user messages
- New `pi.appendEntry(customType, data?)` to persist hook state (does NOT participate in LLM context)
- New `pi.registerCommand(name, options)` to register custom slash commands
- New `pi.registerMessageRenderer(customType, renderer)` to register custom renderers for hook messages
- New `pi.exec(command, args, options?)` to execute shell commands (moved from `HookEventContext`/`HookCommandContext`)
- `HookMessageRenderer` type: `(message: HookMessage, options, theme) => Component | null`
- Renderers return inner content; the TUI wraps it in a styled Box
- New types: `HookMessage<T>`, `RegisteredCommand`, `HookCommandContext`
- Handler types renamed: `SendHandler``SendMessageHandler`, new `AppendEntryHandler`
- **SessionManager**:
- `getSessionFile()` now returns `string | undefined` (undefined for in-memory sessions)
- **Themes**: Custom themes must add `customMessageBg`, `customMessageText`, `customMessageLabel` color tokens
### Added
- **`enabledModels` setting**: Configure whitelisted models in `settings.json` (same format as `--models` CLI flag). CLI `--models` takes precedence over the setting.
### Changed
- **Entry IDs**: Session entries now use short 8-character hex IDs instead of full UUIDs
- **API key priority**: `ANTHROPIC_OAUTH_TOKEN` now takes precedence over `ANTHROPIC_API_KEY`
- **New entry types**: `BranchSummaryEntry` for branch context, `CustomEntry<T>` for hook state persistence, `CustomMessageEntry<T>` for hook-injected context messages, `LabelEntry` for user-defined bookmarks
- **Entry labels**: New `getLabel(id)` and `appendLabelChange(targetId, label)` methods for labeling entries. Labels are included in `SessionTreeNode` for UI/export.
- **TUI**: `CustomMessageEntry` renders with purple styling (customMessageBg, customMessageText, customMessageLabel theme colors). Entries with `display: false` are hidden.
- **AgentSession**: New `sendHookMessage(message, triggerTurn?)` method for hooks to inject messages. Handles queuing during streaming, direct append when idle, and optional turn triggering.
- **HookMessage**: New message type with `role: "hookMessage"` for hook-injected messages in agent events. Use `isHookMessage(msg)` type guard to identify them. These are converted to user messages for LLM context via `messageTransformer`.
- **Agent.prompt()**: Now accepts `AppMessage` directly (in addition to `string, attachments?`) for custom message types like `HookMessage`.
### Fixed
- **Edit tool fails on Windows due to CRLF line endings**: Files with CRLF line endings now match correctly when LLMs send LF-only text. Line endings are normalized before matching and restored to original style on write. ([#355](https://github.com/badlogic/pi-mono/issues/355))
- **Session file validation**: `findMostRecentSession()` now validates session headers before returning, preventing non-session JSONL files from being loaded
- **Compaction error handling**: `generateSummary()` and `generateTurnPrefixSummary()` now throw on LLM errors instead of returning empty strings
## [0.30.2] - 2025-12-26

View file

@ -0,0 +1,441 @@
# Session Tree Implementation Plan
Reference: [session-tree.md](./session-tree.md)
## Phase 1: SessionManager Core ✅
- [x] Update entry types with `id`, `parentId` fields (using SessionEntryBase)
- [x] Add `version` field to `SessionHeader`
- [x] Change `CompactionEntry.firstKeptEntryIndex``firstKeptEntryId`
- [x] Add `BranchSummaryEntry` type
- [x] Add `CustomEntry` type for hooks
- [x] Add `byId: Map<string, SessionEntry>` index
- [x] Add `leafId: string` tracking
- [x] Implement `getPath(fromId?)` tree traversal
- [x] Implement `getTree()` returning `SessionTreeNode[]`
- [x] Implement `getEntry(id)` lookup
- [x] Implement `getLeafUuid()` and `getLeafEntry()` helpers
- [x] Update `_buildIndex()` to populate `byId` map
- [x] Rename `saveXXX()` to `appendXXX()` (returns id, advances leaf)
- [x] Add `appendCustomEntry(customType, data)` for hooks
- [x] Update `buildSessionContext()` to use `getPath()` traversal
## Phase 2: Migration ✅
- [x] Add `CURRENT_SESSION_VERSION = 2` constant
- [x] Implement `migrateV1ToV2()` with extensible migration chain
- [x] Update `setSessionFile()` to detect version and migrate
- [x] Implement `_rewriteFile()` for post-migration persistence
- [x] Handle `firstKeptEntryIndex``firstKeptEntryId` conversion in migration
## Phase 3: Branching ✅
- [x] Implement `branch(id)` - switch leaf pointer
- [x] Implement `branchWithSummary(id, summary)` - create summary entry
- [x] Implement `createBranchedSession(leafId)` - extract path to new file
- [x] Update `AgentSession.branch()` to use new API
## Phase 4: Compaction Integration ✅
- [x] Update `compaction.ts` to work with IDs
- [x] Update `prepareCompaction()` to return `firstKeptEntryId`
- [x] Update `compact()` to return `CompactionResult` with `firstKeptEntryId`
- [x] Update `AgentSession` compaction methods
- [x] Add `firstKeptEntryId` to `before_compact` hook event
## Phase 5: Testing ✅
- [x] `migration.test.ts` - v1 to v2 migration, idempotency
- [x] `build-context.test.ts` - context building with tree structure, compaction, branches
- [x] `tree-traversal.test.ts` - append operations, getPath, getTree, branching
- [x] `file-operations.test.ts` - loadEntriesFromFile, findMostRecentSession
- [x] `save-entry.test.ts` - custom entry integration
- [x] Update existing compaction tests for new types
---
## Remaining Work
### Compaction Refactor
- [x] Use `CompactionResult` type for hook return value
- [x] Make `CompactionEntry<T>` generic with optional `details?: T` field for hook-specific data
- [x] Make `CompactionResult<T>` generic to match
- [x] Update `SessionEventBase` to pass `sessionManager` and `modelRegistry` instead of derived fields
- [x] Update `before_compact` event:
- Pass `preparation: CompactionPreparation` instead of individual fields
- Pass `previousCompactions: CompactionEntry[]` (newest first) instead of `previousSummary?: string`
- Keep: `customInstructions`, `model`, `signal`
- Drop: `resolveApiKey` (use `modelRegistry.getApiKey()`), `cutPoint`, `entries`
- [x] Update hook example `custom-compaction.ts` to use new API
- [x] Update `getSessionFile()` to return `string | undefined` for in-memory sessions
- [x] Update `before_switch` to have `targetSessionFile`, `switch` to have `previousSessionFile`
Reference: [#314](https://github.com/badlogic/pi-mono/pull/314) - Structured compaction with anchored iterative summarization needs `details` field to store `ArtifactIndex` and version markers.
### Branch Summary Design ✅
Current type:
```typescript
export interface BranchSummaryEntry extends SessionEntryBase {
type: "branch_summary";
summary: string;
fromId: string; // References the abandoned leaf
fromHook?: boolean; // Whether summary was generated by a hook
details?: unknown; // File tracking: { readFiles, modifiedFiles }
}
```
- [x] `fromId` field references the abandoned leaf
- [x] `fromHook` field distinguishes pi-generated vs hook-generated summaries
- [x] `details` field for file tracking
- [x] Branch summarizer implemented with structured output format
- [x] Uses serialization approach (same as compaction) to prevent model confusion
- [x] Tests for `branchWithSummary()` flow
### Entry Labels ✅
- [x] Add `LabelEntry` type with `targetId` and `label` fields
- [x] Add `labelsById: Map<string, string>` private field
- [x] Build labels map in `_buildIndex()` via linear scan
- [x] Add `getLabel(id)` method
- [x] Add `appendLabelChange(targetId, label)` method (undefined clears)
- [x] Update `createBranchedSession()` to filter out LabelEntry and recreate from resolved map
- [x] `buildSessionContext()` already ignores LabelEntry (only handles message types)
- [x] Add `label?: string` to `SessionTreeNode`, populated by `getTree()`
- [x] Display labels in UI (tree-selector shows labels)
- [x] `/label` command (implemented in tree-selector)
### CustomMessageEntry<T>
Hook-injected messages that participate in LLM context. Unlike `CustomEntry<T>` (for hook state only), these are sent to the model.
```typescript
export interface CustomMessageEntry<T = unknown> extends SessionEntryBase {
type: "custom_message";
customType: string; // Hook identifier
content: string | (TextContent | ImageContent)[]; // Message content (same as UserMessage)
details?: T; // Hook-specific data for state reconstruction on reload
display: boolean; // Whether to display in TUI
}
```
Behavior:
- [x] Type definition matching plan
- [x] `appendCustomMessageEntry(customType, content, display, details?)` in SessionManager
- [x] `buildSessionContext()` includes custom_message entries as user messages
- [x] Exported from main index
- [x] TUI rendering:
- `display: false` - hidden entirely
- `display: true` - rendered with purple styling (customMessageBg, customMessageText, customMessageLabel theme colors)
- [x] `registerCustomMessageRenderer(customType, renderer)` in HookAPI for custom renderers
- [x] Renderer returns inner Component, TUI wraps in styled Box
### Hook API Changes ✅
**Renamed:**
- `renderCustomMessage()``registerCustomMessageRenderer()`
**New: `sendMessage()` ✅**
Replaces `send()`. Always creates CustomMessageEntry, never user messages.
```typescript
type HookMessage<T = unknown> = Pick<CustomMessageEntry<T>, 'customType' | 'content' | 'display' | 'details'>;
sendMessage(message: HookMessage, triggerTurn?: boolean): void;
```
Implementation:
- Uses agent's queue mechanism with `_hookData` marker on AppMessage
- `message_end` handler routes based on marker presence
- `AgentSession.sendHookMessage()` handles three cases:
- Streaming: queues via `agent.queueMessage()`, loop processes and emits `message_end`
- Not streaming + triggerTurn: direct append + `agent.continue()`
- Not streaming + no trigger: direct append only
- TUI updates via event (streaming) or explicit rebuild (non-streaming)
**New: `appendEntry()` ✅**
For hook state persistence (NOT in LLM context):
```typescript
appendEntry(customType: string, data?: unknown): void;
```
Calls `sessionManager.appendCustomEntry()` directly.
**New: `registerCommand()` (types ✅, wiring TODO)**
```typescript
// HookAPI (the `pi` object) - utilities available to all hooks:
interface HookAPI {
sendMessage(message: HookMessage, triggerTurn?: boolean): void;
appendEntry(customType: string, data?: unknown): void;
registerCommand(name: string, options: RegisteredCommand): void;
registerCustomMessageRenderer(customType: string, renderer: CustomMessageRenderer): void;
exec(command: string, args: string[], options?: ExecOptions): Promise<ExecResult>;
}
// HookEventContext - passed to event handlers, has stable context:
interface HookEventContext {
ui: HookUIContext;
hasUI: boolean;
cwd: string;
sessionManager: SessionManager;
modelRegistry: ModelRegistry;
}
// Note: exec moved to HookAPI, sessionManager/modelRegistry moved from SessionEventBase
// HookCommandContext - passed to command handlers:
interface HookCommandContext {
args: string; // Everything after /commandname
ui: HookUIContext;
hasUI: boolean;
cwd: string;
sessionManager: SessionManager;
modelRegistry: ModelRegistry;
}
// Note: exec and sendMessage accessed via `pi` closure
registerCommand(name: string, options: {
description?: string;
handler: (ctx: HookCommandContext) => Promise<void>;
}): void;
```
Handler return:
- `void` - command completed (use `sendMessage()` with `triggerTurn: true` to prompt LLM)
Wiring (all in AgentSession.prompt()):
- [x] Add hook commands to autocomplete in interactive-mode
- [x] `_tryExecuteHookCommand()` in AgentSession handles command execution
- [x] Build HookCommandContext with ui (from hookRunner), exec, sessionManager, etc.
- [x] If handler returns string, use as prompt text (NOTE(review): conflicts with the `Promise<void>` handler signature above — confirm which behavior is current)
- [x] If handler returns undefined, return early (no LLM call)
- [x] Works for all modes (interactive, RPC, print) via shared AgentSession
**New: `ui.custom()` ✅**
For arbitrary hook UI with keyboard focus:
```typescript
interface HookUIContext {
// ... existing: select, confirm, input, notify
/** Show custom component with keyboard focus. Call done() when finished. */
custom(component: Component, done: () => void): void;
}
```
See also: `CustomEntry<T>` for storing hook state that does NOT participate in context.
**New: `context` event ✅**
Fires before messages are sent to the LLM, allowing hooks to modify context non-destructively.
```typescript
interface ContextEvent {
type: "context";
/** Messages that will be sent to the LLM */
messages: Message[];
}
interface ContextEventResult {
/** Modified messages to send instead */
messages?: Message[];
}
// In HookAPI:
on(event: "context", handler: HookHandler<ContextEvent, ContextEventResult | void>): void;
```
Example use case: **Dynamic Context Pruning** ([discussion #330](https://github.com/badlogic/pi-mono/discussions/330))
Non-destructive pruning of tool results to reduce context size:
```typescript
export default function(pi: HookAPI) {
// Register /prune command
pi.registerCommand("prune", {
description: "Mark tool results for pruning",
handler: async (ctx) => {
// Show UI to select which tool results to prune
// Append custom entry recording pruning decisions:
// { toolResultId, strategy: "summary" | "truncate" | "remove" }
pi.appendEntry("tool-result-pruning", { ... });
}
});
// Intercept context before LLM call
pi.on("context", async (event, ctx) => {
// Find all pruning entries in session
const entries = ctx.sessionManager.getEntries();
const pruningRules = entries
.filter(e => e.type === "custom" && e.customType === "tool-result-pruning")
.map(e => e.data);
// Apply pruning rules to messages
const prunedMessages = applyPruning(event.messages, pruningRules);
return { messages: prunedMessages };
});
}
```
Benefits:
- Original tool results stay intact in session
- Pruning is stored as custom entries, survives session reload
- Works with branching (pruning entries are part of the tree)
- Trade-off: cache busting on first submission after pruning
### Investigate: `context` event vs `before_agent_start`
References:
- [#324](https://github.com/badlogic/pi-mono/issues/324) - `before_agent_start` proposal
- [#330](https://github.com/badlogic/pi-mono/discussions/330) - Dynamic Context Pruning (why `context` was added)
**Current `context` event:**
- Fires before each LLM call within the agent loop
- Receives `AgentMessage[]` (deep copy, safe to modify)
- Returns `Message[]` (inconsistent with input type)
- Modifications are transient (not persisted to session)
- No TUI visibility of what was changed
- Use case: non-destructive pruning, dynamic context manipulation
**Type inconsistency:** Event receives `AgentMessage[]` but result returns `Message[]`:
```typescript
interface ContextEvent {
messages: AgentMessage[]; // Input
}
interface ContextEventResult {
messages?: Message[]; // Output - different type!
}
```
Questions:
- [ ] Should input/output both be `Message[]` (LLM format)?
- [ ] Or both be `AgentMessage[]` with conversion happening after?
- [ ] Where does `AgentMessage[]``Message[]` conversion currently happen?
**Proposed `before_agent_start` event:**
- Fires once when user submits a prompt, before `agent_start`
- Allows hooks to inject additional content that gets **persisted** to session
- Injected content is visible in TUI (observability)
- Does not bust prompt cache (appended after user message, not modifying system prompt)
**Key difference:**
| Aspect | `context` | `before_agent_start` |
|--------|-----------|---------------------|
| When | Before each LLM call | Once per user prompt |
| Persisted | No | Yes (persisted to the session; implementation reuses `HookMessage` — see implementation notes below) |
| TUI visible | No | Yes (collapsible) |
| Cache impact | Can bust cache | Append-only, cache-safe |
| Use case | Transient manipulation | Persistent context injection |
**Implementation (completed):**
- Reuses `HookMessage` type (no new message type needed)
- Handler returns `{ message: Pick<HookMessage, "customType" | "content" | "display" | "details"> }`
- Message is appended to agent state AND persisted to session before `agent.prompt()` is called
- Renders using existing `HookMessageComponent` (or custom renderer if registered)
- [ ] How does it interact with compaction? (treated like user messages?)
- [ ] Can hook return multiple messages or just one?
**Implementation sketch:**
```typescript
interface BeforeAgentStartEvent {
type: "before_agent_start";
userMessage: UserMessage; // The prompt user just submitted
}
interface BeforeAgentStartResult {
/** Additional context to inject (persisted as SystemMessage) */
inject?: {
label: string; // Shown in collapsed TUI state
content: string | (TextContent | ImageContent)[];
};
}
```
### HTML Export
- [ ] Add collapsible sidebar showing full tree structure
- [ ] Allow selecting any node in tree to view that path
- [ ] Add "reset to session leaf" button
- [ ] Render full path (no compaction resolution needed)
- [ ] Responsive: collapse sidebar on mobile
### UI Commands ✅
- [x] `/branch` - Creates new session file from current path (uses `createBranchedSession()`)
- [x] `/tree` - In-session tree navigation via tree-selector component
- Shows full tree structure with labels
- Navigate between branches (moves leaf pointer)
- Shows current position
- Generates branch summaries when switching branches
### Tree Selector Improvements ✅
- [x] Active line highlight using `selectedBg` theme color
- [x] Filter modes via `^O` (forward) / `Shift+^O` (backward):
- `default`: hides label/custom entries
- `no-tools`: default minus tool results
- `user-only`: just user messages
- `labeled-only`: just labeled entries
- `all`: everything
### Documentation
Review and update all docs:
- [ ] `docs/hooks.md` - Major update for hook API:
- `pi.send()``pi.sendMessage()` with new signature
- New `pi.appendEntry()` for state persistence
- New `pi.registerCommand()` for custom slash commands
- New `pi.registerCustomMessageRenderer()` for custom TUI rendering
- `HookCommandContext` interface and handler patterns
- `HookMessage<T>` type
- Updated event signatures (`SessionEventBase`, `before_compact`, etc.)
- [ ] `docs/hooks-v2.md` - Review/merge or remove if obsolete
- [ ] `docs/sdk.md` - Update for:
- `HookMessage` and `isHookMessage()`
- `Agent.prompt(AppMessage)` overload
- Session v2 tree structure
- SessionManager API changes
- [ ] `docs/session.md` - Update for v2 tree structure, new entry types
- [ ] `docs/custom-tools.md` - Check if hook changes affect custom tools
- [ ] `docs/rpc.md` - Check if hook commands work in RPC mode
- [ ] `docs/skills.md` - Review for any hook-related updates
- [ ] `docs/extension-loading.md` - Review
- [x] `docs/theme.md` - Added selectedBg, customMessageBg/Text/Label color tokens (50 total)
- [ ] `README.md` - Update hook examples if any
### Examples
Review and update examples:
- [ ] `examples/hooks/` - Update existing, add new examples:
- [ ] Review `custom-compaction.ts` for new API
- [ ] Add `registerCommand()` example
- [ ] Add `sendMessage()` example
- [ ] Add `registerCustomMessageRenderer()` example
- [ ] `examples/sdk/` - Update for new session/hook APIs
- [ ] `examples/custom-tools/` - Review for compatibility
---
## Before Release
- [ ] Run full automated test suite: `npm test`
- [ ] Manual testing of tree navigation and branch summarization
- [ ] Verify compaction with file tracking works correctly
---
## Notes
- All append methods return the new entry's ID
- Migration rewrites file on first load if version < CURRENT_VERSION
- Existing sessions become linear chains after migration (parentId = previous entry)
- Tree features available immediately after migration
- SessionHeader does NOT have id/parentId (it's metadata, not part of tree)
- Session is append-only: entries cannot be modified or deleted, only branching changes the leaf pointer

View file

@ -21,12 +21,16 @@ Every theme must define all color tokens. There are no optional colors.
| `dim` | Very dimmed text | Less important info, placeholders |
| `text` | Default text color | Main content (usually `""`) |
### Backgrounds & Content Text (7 colors)
### Backgrounds & Content Text (11 colors)
| Token | Purpose |
|-------|---------|
| `selectedBg` | Selected/active line background (e.g., tree selector) |
| `userMessageBg` | User message background |
| `userMessageText` | User message text color |
| `customMessageBg` | Hook custom message background |
| `customMessageText` | Hook custom message text color |
| `customMessageLabel` | Hook custom message label/type text |
| `toolPendingBg` | Tool execution box (pending state) |
| `toolSuccessBg` | Tool execution box (success state) |
| `toolErrorBg` | Tool execution box (error state) |
@ -95,7 +99,7 @@ These create a visual hierarchy: off → minimal → low → medium → high →
|-------|---------|
| `bashMode` | Editor border color when in bash mode (! prefix) |
**Total: 46 color tokens** (all required)
**Total: 50 color tokens** (all required)
## Theme Format

View file

@ -0,0 +1,197 @@
# Session Tree Navigation
The `/tree` command provides tree-based navigation of the session history.
## Overview
Sessions are stored as trees where each entry has an `id` and `parentId`. The "leaf" pointer tracks the current position. `/tree` lets you navigate to any point and optionally summarize the branch you're leaving.
### Comparison with `/branch`
| Feature | `/branch` | `/tree` |
|---------|-----------|---------|
| View | Flat list of user messages | Full tree structure |
| Action | Extracts path to **new session file** | Changes leaf in **same session** |
| Summary | Never | Optional (user prompted) |
| Events | `session_before_branch` / `session_branch` | `session_before_tree` / `session_tree` |
## Tree UI
```
├─ user: "Hello, can you help..."
│ └─ assistant: "Of course! I can..."
│ ├─ user: "Let's try approach A..."
│ │ └─ assistant: "For approach A..."
│ │ └─ [compaction: 12k tokens]
│ │ └─ user: "That worked..." ← active
│ └─ user: "Actually, approach B..."
│ └─ assistant: "For approach B..."
```
### Controls
| Key | Action |
|-----|--------|
| ↑/↓ | Navigate (depth-first order) |
| Enter | Select node |
| Escape/Ctrl+C | Cancel |
| Ctrl+U | Toggle: user messages only |
| Ctrl+O | Toggle: show all (including custom/label entries) |
### Display
- Height: half terminal height
- Current leaf marked with `← active`
- Labels shown inline: `[label-name]`
- Default filter hides `label` and `custom` entries (shown in Ctrl+O mode)
- Children sorted by timestamp (oldest first)
## Selection Behavior
### User Message or Custom Message
1. Leaf set to **parent** of selected node (or `null` if root)
2. Message text placed in **editor** for re-submission
3. User edits and submits, creating a new branch
### Non-User Message (assistant, compaction, etc.)
1. Leaf set to **selected node**
2. Editor stays empty
3. User continues from that point
### Selecting Root User Message
If user selects the very first message (has no parent):
1. Leaf reset to `null` (empty conversation)
2. Message text placed in editor
3. User effectively restarts from scratch
## Branch Summarization
When switching, user is prompted: "Summarize the branch you're leaving?"
### What Gets Summarized
Path from old leaf back to common ancestor with target:
```
A → B → C → D → E → F ← old leaf
↘ G → H ← target
```
Abandoned path: D → E → F (summarized)
Summarization stops at:
1. Common ancestor (always)
2. Compaction node (if encountered first)
### Summary Storage
Stored as `BranchSummaryEntry`:
```typescript
interface BranchSummaryEntry {
type: "branch_summary";
id: string;
parentId: string; // New leaf position
timestamp: string;
fromId: string; // Old leaf we abandoned
summary: string; // LLM-generated summary
details?: unknown; // Optional hook data
}
```
## Implementation
### AgentSession.navigateTree()
```typescript
async navigateTree(
targetId: string,
options?: { summarize?: boolean; customInstructions?: string }
): Promise<{ editorText?: string; cancelled: boolean }>
```
Flow:
1. Validate target, check no-op (target === current leaf)
2. Find common ancestor between old leaf and target
3. Collect entries to summarize (if requested)
4. Fire `session_before_tree` event (hook can cancel or provide summary)
5. Run default summarizer if needed
6. Switch leaf via `branch()` or `branchWithSummary()`
7. Update agent: `agent.replaceMessages(sessionManager.buildSessionContext().messages)`
8. Fire `session_tree` event
9. Notify custom tools via session event
10. Return result with `editorText` if user message was selected
### SessionManager
- `getLeafUuid(): string | null` - Current leaf (null if empty)
- `resetLeaf(): void` - Set leaf to null (for root user message navigation)
- `getTree(): SessionTreeNode[]` - Full tree with children sorted by timestamp
- `branch(id)` - Change leaf pointer
- `branchWithSummary(id, summary)` - Change leaf and create summary entry
### InteractiveMode
`/tree` command shows `TreeSelectorComponent`, then:
1. Prompt for summarization
2. Call `session.navigateTree()`
3. Clear and re-render chat
4. Set editor text if applicable
## Hook Events
### `session_before_tree`
```typescript
interface TreePreparation {
targetId: string;
oldLeafId: string | null;
commonAncestorId: string | null;
entriesToSummarize: SessionEntry[];
userWantsSummary: boolean;
}
interface SessionBeforeTreeEvent {
type: "session_before_tree";
preparation: TreePreparation;
model: Model;
signal: AbortSignal;
}
interface SessionBeforeTreeResult {
cancel?: boolean;
summary?: { summary: string; details?: unknown };
}
```
### `session_tree`
```typescript
interface SessionTreeEvent {
type: "session_tree";
newLeafId: string | null;
oldLeafId: string | null;
summaryEntry?: BranchSummaryEntry;
fromHook?: boolean;
}
```
### Example: Custom Summarizer
```typescript
export default function(pi: HookAPI) {
pi.on("session_before_tree", async (event, ctx) => {
if (!event.preparation.userWantsSummary) return;
if (event.preparation.entriesToSummarize.length === 0) return;
const summary = await myCustomSummarizer(event.preparation.entriesToSummarize);
return { summary: { summary, details: { custom: true } } };
});
}
```
## Error Handling
- Summarization failure: cancels navigation, shows error
- User abort (Escape): cancels navigation
- Hook returns `cancel: true`: cancels navigation silently

View file

@ -41,7 +41,7 @@ const factory: CustomToolFactory = (pi) => {
const answer = await pi.ui.select(params.question, params.options);
if (answer === null) {
if (answer === undefined) {
return {
content: [{ type: "text", text: "User cancelled the selection" }],
details: { question: params.question, options: params.options, answer: null },

View file

@ -16,7 +16,8 @@ import { spawn } from "node:child_process";
import * as fs from "node:fs";
import * as os from "node:os";
import * as path from "node:path";
import type { AgentToolResult, Message } from "@mariozechner/pi-ai";
import type { AgentToolResult } from "@mariozechner/pi-agent-core";
import type { Message } from "@mariozechner/pi-ai";
import { StringEnum } from "@mariozechner/pi-ai";
import {
type CustomAgentTool,

View file

@ -8,11 +8,9 @@
import type { HookAPI } from "@mariozechner/pi-coding-agent/hooks";
export default function (pi: HookAPI) {
pi.on("session", async (event, ctx) => {
if (event.reason !== "shutdown") return;
pi.on("session_shutdown", async (_event, ctx) => {
// Check for uncommitted changes
const { stdout: status, code } = await ctx.exec("git", ["status", "--porcelain"]);
const { stdout: status, code } = await pi.exec("git", ["status", "--porcelain"]);
if (code !== 0 || status.trim().length === 0) {
// Not a git repo or no changes
@ -20,9 +18,10 @@ export default function (pi: HookAPI) {
}
// Find the last assistant message for commit context
const entries = ctx.sessionManager.getEntries();
let lastAssistantText = "";
for (let i = event.entries.length - 1; i >= 0; i--) {
const entry = event.entries[i];
for (let i = entries.length - 1; i >= 0; i--) {
const entry = entries[i];
if (entry.type === "message" && entry.message.role === "assistant") {
const content = entry.message.content;
if (Array.isArray(content)) {
@ -40,8 +39,8 @@ export default function (pi: HookAPI) {
const commitMessage = `[pi] ${firstLine.slice(0, 50)}${firstLine.length > 50 ? "..." : ""}`;
// Stage and commit
await ctx.exec("git", ["add", "-A"]);
const { code: commitCode } = await ctx.exec("git", ["commit", "-m", commitMessage]);
await pi.exec("git", ["add", "-A"]);
const { code: commitCode } = await pi.exec("git", ["commit", "-m", commitMessage]);
if (commitCode === 0 && ctx.hasUI) {
ctx.ui.notify(`Auto-committed: ${commitMessage}`, "info");

View file

@ -2,59 +2,57 @@
* Confirm Destructive Actions Hook
*
* Prompts for confirmation before destructive session actions (clear, switch, branch).
* Demonstrates how to cancel session events using the before_* variants.
* Demonstrates how to cancel session events using the before_* events.
*/
import type { SessionMessageEntry } from "@mariozechner/pi-coding-agent";
import type { HookAPI } from "@mariozechner/pi-coding-agent/hooks";
export default function (pi: HookAPI) {
pi.on("session", async (event, ctx) => {
// Only handle before_* events (the ones that can be cancelled)
if (event.reason === "before_new") {
if (!ctx.hasUI) return;
pi.on("session_before_new", async (_event, ctx) => {
if (!ctx.hasUI) return;
const confirmed = await ctx.ui.confirm("Clear session?", "This will delete all messages in the current session.");
if (!confirmed) {
ctx.ui.notify("Clear cancelled", "info");
return { cancel: true };
}
});
pi.on("session_before_switch", async (_event, ctx) => {
if (!ctx.hasUI) return;
// Check if there are unsaved changes (messages since last assistant response)
const entries = ctx.sessionManager.getEntries();
const hasUnsavedWork = entries.some(
(e): e is SessionMessageEntry => e.type === "message" && e.message.role === "user",
);
if (hasUnsavedWork) {
const confirmed = await ctx.ui.confirm(
"Clear session?",
"This will delete all messages in the current session.",
"Switch session?",
"You have messages in the current session. Switch anyway?",
);
if (!confirmed) {
ctx.ui.notify("Clear cancelled", "info");
return { cancel: true };
}
}
if (event.reason === "before_switch") {
if (!ctx.hasUI) return;
// Check if there are unsaved changes (messages since last assistant response)
const hasUnsavedWork = event.entries.some((e) => e.type === "message" && e.message.role === "user");
if (hasUnsavedWork) {
const confirmed = await ctx.ui.confirm(
"Switch session?",
"You have messages in the current session. Switch anyway?",
);
if (!confirmed) {
ctx.ui.notify("Switch cancelled", "info");
return { cancel: true };
}
}
}
if (event.reason === "before_branch") {
if (!ctx.hasUI) return;
const choice = await ctx.ui.select(`Branch from turn ${event.targetTurnIndex}?`, [
"Yes, create branch",
"No, stay in current session",
]);
if (choice !== "Yes, create branch") {
ctx.ui.notify("Branch cancelled", "info");
ctx.ui.notify("Switch cancelled", "info");
return { cancel: true };
}
}
});
pi.on("session_before_branch", async (event, ctx) => {
if (!ctx.hasUI) return;
const choice = await ctx.ui.select(`Branch from turn ${event.entryIndex}?`, [
"Yes, create branch",
"No, stay in current session",
]);
if (choice !== "Yes, create branch") {
ctx.ui.notify("Branch cancelled", "info");
return { cancel: true };
}
});
}

View file

@ -14,17 +14,18 @@
*/
import { complete, getModel } from "@mariozechner/pi-ai";
import { messageTransformer } from "@mariozechner/pi-coding-agent";
import { convertToLlm } from "@mariozechner/pi-coding-agent";
import type { HookAPI } from "@mariozechner/pi-coding-agent/hooks";
export default function (pi: HookAPI) {
pi.on("session", async (event, ctx) => {
if (event.reason !== "before_compact") return;
pi.on("session_before_compact", async (event, ctx) => {
ctx.ui.notify("Custom compaction hook triggered", "info");
const { messagesToSummarize, messagesToKeep, previousSummary, tokensBefore, resolveApiKey, entries, signal } =
event;
const { preparation, previousCompactions, signal } = event;
const { messagesToSummarize, messagesToKeep, tokensBefore, firstKeptEntryId } = preparation;
// Get previous summary from most recent compaction (if any)
const previousSummary = previousCompactions[0]?.summary;
// Use Gemini Flash for summarization (cheaper/faster than most conversation models)
const model = getModel("google", "gemini-2.5-flash");
@ -34,7 +35,7 @@ export default function (pi: HookAPI) {
}
// Resolve API key for the summarization model
const apiKey = await resolveApiKey(model);
const apiKey = await ctx.modelRegistry.getApiKey(model);
if (!apiKey) {
ctx.ui.notify(`No API key for ${model.provider}, using default compaction`, "warning");
return;
@ -49,7 +50,7 @@ export default function (pi: HookAPI) {
);
// Transform app messages to pi-ai package format
const transformedMessages = messageTransformer(allMessages);
const transformedMessages = convertToLlm(allMessages);
// Include previous summary context if available
const previousContext = previousSummary ? `\n\nPrevious session summary for context:\n${previousSummary}` : "";
@ -94,14 +95,12 @@ Format the summary as structured markdown with clear sections.`,
return;
}
// Return a compaction entry that discards ALL messages
// firstKeptEntryIndex points past all current entries
// Return compaction content - SessionManager adds id/parentId
// Use firstKeptEntryId from preparation to keep recent messages
return {
compactionEntry: {
type: "compaction" as const,
timestamp: new Date().toISOString(),
compaction: {
summary,
firstKeptEntryIndex: entries.length,
firstKeptEntryId,
tokensBefore,
},
};

View file

@ -5,47 +5,55 @@
* Useful to ensure work is committed before switching context.
*/
import type { HookAPI } from "@mariozechner/pi-coding-agent/hooks";
import type { HookAPI, HookEventContext } from "@mariozechner/pi-coding-agent/hooks";
async function checkDirtyRepo(
pi: HookAPI,
ctx: HookEventContext,
action: string,
): Promise<{ cancel: boolean } | undefined> {
// Check for uncommitted changes
const { stdout, code } = await pi.exec("git", ["status", "--porcelain"]);
if (code !== 0) {
// Not a git repo, allow the action
return;
}
const hasChanges = stdout.trim().length > 0;
if (!hasChanges) {
return;
}
if (!ctx.hasUI) {
// In non-interactive mode, block by default
return { cancel: true };
}
// Count changed files
const changedFiles = stdout.trim().split("\n").filter(Boolean).length;
const choice = await ctx.ui.select(`You have ${changedFiles} uncommitted file(s). ${action} anyway?`, [
"Yes, proceed anyway",
"No, let me commit first",
]);
if (choice !== "Yes, proceed anyway") {
ctx.ui.notify("Commit your changes first", "warning");
return { cancel: true };
}
}
export default function (pi: HookAPI) {
pi.on("session", async (event, ctx) => {
// Only guard destructive actions
if (event.reason !== "before_new" && event.reason !== "before_switch" && event.reason !== "before_branch") {
return;
}
pi.on("session_before_new", async (_event, ctx) => {
return checkDirtyRepo(pi, ctx, "new session");
});
// Check for uncommitted changes
const { stdout, code } = await ctx.exec("git", ["status", "--porcelain"]);
pi.on("session_before_switch", async (_event, ctx) => {
return checkDirtyRepo(pi, ctx, "switch session");
});
if (code !== 0) {
// Not a git repo, allow the action
return;
}
const hasChanges = stdout.trim().length > 0;
if (!hasChanges) {
return;
}
if (!ctx.hasUI) {
// In non-interactive mode, block by default
return { cancel: true };
}
// Count changed files
const changedFiles = stdout.trim().split("\n").filter(Boolean).length;
const action =
event.reason === "before_new" ? "new session" : event.reason === "before_switch" ? "switch session" : "branch";
const choice = await ctx.ui.select(`You have ${changedFiles} uncommitted file(s). ${action} anyway?`, [
"Yes, proceed anyway",
"No, let me commit first",
]);
if (choice !== "Yes, proceed anyway") {
ctx.ui.notify("Commit your changes first", "warning");
return { cancel: true };
}
pi.on("session_before_branch", async (_event, ctx) => {
return checkDirtyRepo(pi, ctx, "branch");
});
}

View file

@ -12,16 +12,21 @@ import * as fs from "node:fs";
import type { HookAPI } from "@mariozechner/pi-coding-agent/hooks";
export default function (pi: HookAPI) {
pi.on("session", async (event, ctx) => {
if (event.reason !== "start") return;
pi.on("session_start", async (_event, ctx) => {
const triggerFile = "/tmp/agent-trigger.txt";
fs.watch(triggerFile, () => {
try {
const content = fs.readFileSync(triggerFile, "utf-8").trim();
if (content) {
pi.send(`External trigger: ${content}`);
pi.sendMessage(
{
customType: "file-trigger",
content: `External trigger: ${content}`,
display: true,
},
true, // triggerTurn - get LLM to respond
);
fs.writeFileSync(triggerFile, ""); // Clear after reading
}
} catch {

View file

@ -10,20 +10,17 @@ import type { HookAPI } from "@mariozechner/pi-coding-agent/hooks";
export default function (pi: HookAPI) {
const checkpoints = new Map<number, string>();
pi.on("turn_start", async (event, ctx) => {
pi.on("turn_start", async (event) => {
// Create a git stash entry before LLM makes changes
const { stdout } = await ctx.exec("git", ["stash", "create"]);
const { stdout } = await pi.exec("git", ["stash", "create"]);
const ref = stdout.trim();
if (ref) {
checkpoints.set(event.turnIndex, ref);
}
});
pi.on("session", async (event, ctx) => {
// Only handle before_branch events
if (event.reason !== "before_branch") return;
const ref = checkpoints.get(event.targetTurnIndex);
pi.on("session_before_branch", async (event, ctx) => {
const ref = checkpoints.get(event.entryIndex);
if (!ref) return;
if (!ctx.hasUI) {
@ -37,7 +34,7 @@ export default function (pi: HookAPI) {
]);
if (choice?.startsWith("Yes")) {
await ctx.exec("git", ["stash", "apply", ref]);
await pi.exec("git", ["stash", "apply", ref]);
ctx.ui.notify("Code restored to checkpoint", "info");
}
});

View file

@ -0,0 +1,345 @@
/**
* Snake game hook - play snake with /snake command
*/
import { isArrowDown, isArrowLeft, isArrowRight, isArrowUp, isEscape, visibleWidth } from "@mariozechner/pi-tui";
import type { HookAPI } from "../../src/core/hooks/types.js";
const GAME_WIDTH = 40;
const GAME_HEIGHT = 15;
const TICK_MS = 100;

type Direction = "up" | "down" | "left" | "right";
type Point = { x: number; y: number };

/** Full snapshot of a running (or finished) snake game. */
interface GameState {
	snake: Point[]; // head first, tail last
	food: Point;
	direction: Direction; // direction applied on the current tick
	nextDirection: Direction; // queued input, applied on the next tick
	score: number;
	gameOver: boolean;
	highScore: number; // best score seen so far (carried across restarts)
}

/**
 * Create a fresh game: a 3-segment snake centered on the board, heading
 * right, with food placed on a cell not occupied by the snake.
 */
function createInitialState(): GameState {
	const startX = Math.floor(GAME_WIDTH / 2);
	const startY = Math.floor(GAME_HEIGHT / 2);
	const snake: Point[] = [
		{ x: startX, y: startY },
		{ x: startX - 1, y: startY },
		{ x: startX - 2, y: startY },
	];
	return {
		snake,
		// Bug fix: pass the whole snake (not just the head) so the first
		// food cannot spawn hidden under a body segment.
		food: spawnFood(snake),
		direction: "right",
		nextDirection: "right",
		score: 0,
		gameOver: false,
		highScore: 0,
	};
}

/**
 * Pick a uniformly random board cell that the snake does not occupy,
 * via rejection sampling. Assumes the snake does not fill the entire
 * board (the loop would otherwise never terminate).
 */
function spawnFood(snake: Point[]): Point {
	let food: Point;
	do {
		food = {
			x: Math.floor(Math.random() * GAME_WIDTH),
			y: Math.floor(Math.random() * GAME_HEIGHT),
		};
	} while (snake.some((s) => s.x === food.x && s.y === food.y));
	return food;
}
class SnakeComponent {
private state: GameState;
private interval: ReturnType<typeof setInterval> | null = null;
private onClose: () => void;
private onSave: (state: GameState | null) => void;
private requestRender: () => void;
private cachedLines: string[] = [];
private cachedWidth = 0;
private version = 0;
private cachedVersion = -1;
private paused: boolean;
constructor(
onClose: () => void,
onSave: (state: GameState | null) => void,
requestRender: () => void,
savedState?: GameState,
) {
if (savedState && !savedState.gameOver) {
// Resume from saved state, start paused
this.state = savedState;
this.paused = true;
} else {
// New game or saved game was over
this.state = createInitialState();
if (savedState) {
this.state.highScore = savedState.highScore;
}
this.paused = false;
this.startGame();
}
this.onClose = onClose;
this.onSave = onSave;
this.requestRender = requestRender;
}
private startGame(): void {
this.interval = setInterval(() => {
if (!this.state.gameOver) {
this.tick();
this.version++;
this.requestRender();
}
}, TICK_MS);
}
private tick(): void {
// Apply queued direction change
this.state.direction = this.state.nextDirection;
// Calculate new head position
const head = this.state.snake[0];
let newHead: Point;
switch (this.state.direction) {
case "up":
newHead = { x: head.x, y: head.y - 1 };
break;
case "down":
newHead = { x: head.x, y: head.y + 1 };
break;
case "left":
newHead = { x: head.x - 1, y: head.y };
break;
case "right":
newHead = { x: head.x + 1, y: head.y };
break;
}
// Check wall collision
if (newHead.x < 0 || newHead.x >= GAME_WIDTH || newHead.y < 0 || newHead.y >= GAME_HEIGHT) {
this.state.gameOver = true;
return;
}
// Check self collision
if (this.state.snake.some((s) => s.x === newHead.x && s.y === newHead.y)) {
this.state.gameOver = true;
return;
}
// Move snake
this.state.snake.unshift(newHead);
// Check food collision
if (newHead.x === this.state.food.x && newHead.y === this.state.food.y) {
this.state.score += 10;
if (this.state.score > this.state.highScore) {
this.state.highScore = this.state.score;
}
this.state.food = spawnFood(this.state.snake);
} else {
this.state.snake.pop();
}
}
/**
 * Handle one chunk of raw terminal input.
 *
 * Key map: on the pause screen any key resumes (ESC/Q quits, preserving the
 * save); while playing, ESC pauses+saves, Q quits and clears the save,
 * arrows/WASD steer, and R/space restarts after game over.
 */
handleInput(data: string): void {
  // Pause screen (shown when resuming a saved game).
  if (this.paused) {
    if (isEscape(data) || data === "q" || data === "Q") {
      // Quit without clearing save
      this.dispose();
      this.onClose();
      return;
    }
    // Any other key resumes
    this.paused = false;
    this.startGame();
    return;
  }
  // ESC to pause and save
  if (isEscape(data)) {
    this.dispose();
    this.onSave(this.state);
    this.onClose();
    return;
  }
  // Q to quit without saving (clears saved state)
  if (data === "q" || data === "Q") {
    this.dispose();
    this.onSave(null); // Clear saved state
    this.onClose();
    return;
  }
  // Steering. Compare against the QUEUED direction (nextDirection), not the
  // currently applied one: two quick presses between ticks (e.g. up then
  // left while moving right) could otherwise queue a 180° reversal into the
  // snake's own neck and cause an instant, unfair death.
  if (isArrowUp(data) || data === "w" || data === "W") {
    if (this.state.nextDirection !== "down") this.state.nextDirection = "up";
  } else if (isArrowDown(data) || data === "s" || data === "S") {
    if (this.state.nextDirection !== "up") this.state.nextDirection = "down";
  } else if (isArrowRight(data) || data === "d" || data === "D") {
    if (this.state.nextDirection !== "left") this.state.nextDirection = "right";
  } else if (isArrowLeft(data) || data === "a" || data === "A") {
    if (this.state.nextDirection !== "right") this.state.nextDirection = "left";
  }
  // Restart on game over (R or space); clears the saved state.
  if (this.state.gameOver && (data === "r" || data === "R" || data === " ")) {
    const highScore = this.state.highScore;
    this.state = createInitialState();
    this.state.highScore = highScore;
    this.onSave(null); // Clear saved state on restart
    this.version++;
    this.requestRender();
  }
}
invalidate(): void {
  // Force render() to recompute: a cached width of 0 never matches a real width.
  this.cachedWidth = 0;
}
/**
 * Render the game as an array of terminal lines for the given width.
 *
 * Results are cached: recomputed only when the width changes or the tick
 * version counter differs from the cached one.
 */
render(width: number): string[] {
  if (width === this.cachedWidth && this.cachedVersion === this.version) {
    return this.cachedLines;
  }
  const lines: string[] = [];
  // Each game cell is 2 chars wide to appear square (terminal cells are ~2:1 aspect)
  const cellWidth = 2;
  // Clamp the board to what fits in the available width (minus frame chars).
  const effectiveWidth = Math.min(GAME_WIDTH, Math.floor((width - 4) / cellWidth));
  const effectiveHeight = GAME_HEIGHT;
  // Colors — raw ANSI SGR helpers.
  const dim = (s: string) => `\x1b[2m${s}\x1b[22m`;
  const green = (s: string) => `\x1b[32m${s}\x1b[0m`;
  const red = (s: string) => `\x1b[31m${s}\x1b[0m`;
  const yellow = (s: string) => `\x1b[33m${s}\x1b[0m`;
  const bold = (s: string) => `\x1b[1m${s}\x1b[22m`;
  const boxWidth = effectiveWidth * cellWidth;
  // Helper to pad content inside box
  const boxLine = (content: string) => {
    const contentLen = visibleWidth(content);
    const padding = Math.max(0, boxWidth - contentLen);
    return dim(" │") + content + " ".repeat(padding) + dim("│");
  };
  // Top border
  lines.push(this.padLine(dim(`${"─".repeat(boxWidth)}`), width));
  // Header with score
  const scoreText = `Score: ${bold(yellow(String(this.state.score)))}`;
  const highText = `High: ${bold(yellow(String(this.state.highScore)))}`;
  // NOTE(review): title/score/high are concatenated with no separator —
  // confirm intended spacing in the original template literal.
  const title = `${bold(green("SNAKE"))}${scoreText}${highText}`;
  lines.push(this.padLine(boxLine(title), width));
  // Separator
  lines.push(this.padLine(dim(`${"─".repeat(boxWidth)}`), width));
  // Game grid
  for (let y = 0; y < effectiveHeight; y++) {
    let row = "";
    for (let x = 0; x < effectiveWidth; x++) {
      const isHead = this.state.snake[0].x === x && this.state.snake[0].y === y;
      const isBody = this.state.snake.slice(1).some((s) => s.x === x && s.y === y);
      const isFood = this.state.food.x === x && this.state.food.y === y;
      if (isHead) {
        row += green("██"); // Snake head (2 chars)
      } else if (isBody) {
        row += green("▓▓"); // Snake body (2 chars)
      } else if (isFood) {
        row += red("◆ "); // Food (2 chars)
      } else {
        row += "  "; // Empty cell (2 spaces)
      }
    }
    lines.push(this.padLine(dim(" │") + row + dim("│"), width));
  }
  // Separator
  lines.push(this.padLine(dim(`${"─".repeat(boxWidth)}`), width));
  // Footer — one of three states: paused, game over, or playing.
  let footer: string;
  if (this.paused) {
    footer = `${yellow(bold("PAUSED"))} Press any key to continue, ${bold("Q")} to quit`;
  } else if (this.state.gameOver) {
    footer = `${red(bold("GAME OVER!"))} Press ${bold("R")} to restart, ${bold("Q")} to quit`;
  } else {
    footer = `↑↓←→ or WASD to move, ${bold("ESC")} pause, ${bold("Q")} quit`;
  }
  lines.push(this.padLine(boxLine(footer), width));
  // Bottom border
  lines.push(this.padLine(dim(`${"─".repeat(boxWidth)}`), width));
  this.cachedLines = lines;
  this.cachedWidth = width;
  this.cachedVersion = this.version;
  return lines;
}
/** Right-pad a line with spaces to `width` columns, ignoring ANSI SGR codes. */
private padLine(line: string, width: number): string {
  // Strip ANSI escape sequences before measuring printable length.
  const printable = line.replace(/\x1b\[[0-9;]*m/g, "");
  const deficit = width - printable.length;
  return deficit > 0 ? line + " ".repeat(deficit) : line;
}
dispose(): void {
if (this.interval) {
clearInterval(this.interval);
this.interval = null;
}
}
}
// Session-entry customType under which the snake game persists its state.
const SNAKE_SAVE_TYPE = "snake-save";
/**
 * Hook entry point: registers the /snake command.
 *
 * Game state is persisted as a custom session entry (SNAKE_SAVE_TYPE) so a
 * paused game survives across invocations; the newest entry wins.
 */
export default function (pi: HookAPI) {
  pi.registerCommand("snake", {
    description: "Play Snake!",
    handler: async (ctx) => {
      if (!ctx.hasUI) {
        // NOTE(review): notify is called even though hasUI is false —
        // presumably it degrades to a console message; confirm.
        ctx.ui.notify("Snake requires interactive mode", "error");
        return;
      }
      // Load saved state from session: scan backwards so the most recent
      // save entry takes precedence.
      const entries = ctx.sessionManager.getEntries();
      let savedState: GameState | undefined;
      for (let i = entries.length - 1; i >= 0; i--) {
        const entry = entries[i];
        if (entry.type === "custom" && entry.customType === SNAKE_SAVE_TYPE) {
          // assumes entries of this customType always hold a GameState — unvalidated cast
          savedState = entry.data as GameState;
          break;
        }
      }
      let ui: { close: () => void; requestRender: () => void } | null = null;
      const component = new SnakeComponent(
        () => ui?.close(),
        (state) => {
          // Save or clear state (null clears the save entry's payload)
          pi.appendEntry(SNAKE_SAVE_TYPE, state);
        },
        () => ui?.requestRender(),
        savedState,
      );
      ui = ctx.ui.custom(component);
    },
  });
}

View file

@ -3,21 +3,21 @@
*/
import { access, readFile, stat } from "node:fs/promises";
import type { Attachment } from "@mariozechner/pi-agent-core";
import type { ImageContent } from "@mariozechner/pi-ai";
import chalk from "chalk";
import { resolve } from "path";
import { resolveReadPath } from "../core/tools/path-utils.js";
import { detectSupportedImageMimeTypeFromFile } from "../utils/mime.js";
export interface ProcessedFiles {
textContent: string;
imageAttachments: Attachment[];
text: string;
images: ImageContent[];
}
/** Process @file arguments into text content and image attachments */
export async function processFileArguments(fileArgs: string[]): Promise<ProcessedFiles> {
let textContent = "";
const imageAttachments: Attachment[] = [];
let text = "";
const images: ImageContent[] = [];
for (const fileArg of fileArgs) {
// Expand and resolve path (handles ~ expansion and macOS screenshot Unicode spaces)
@ -45,24 +45,21 @@ export async function processFileArguments(fileArgs: string[]): Promise<Processe
const content = await readFile(absolutePath);
const base64Content = content.toString("base64");
const attachment: Attachment = {
id: `file-${Date.now()}-${Math.random().toString(36).slice(2, 9)}`,
const attachment: ImageContent = {
type: "image",
fileName: absolutePath.split("/").pop() || absolutePath,
mimeType,
size: stats.size,
content: base64Content,
data: base64Content,
};
imageAttachments.push(attachment);
images.push(attachment);
// Add text reference to image
textContent += `<file name="${absolutePath}"></file>\n`;
text += `<file name="${absolutePath}"></file>\n`;
} else {
// Handle text file
try {
const content = await readFile(absolutePath, "utf-8");
textContent += `<file name="${absolutePath}">\n${content}\n</file>\n`;
text += `<file name="${absolutePath}">\n${content}\n</file>\n`;
} catch (error: unknown) {
const message = error instanceof Error ? error.message : String(error);
console.error(chalk.red(`Error: Could not read file ${absolutePath}: ${message}`));
@ -71,5 +68,5 @@ export async function processFileArguments(fileArgs: string[]): Promise<Processe
}
}
return { textContent, imageAttachments };
return { text, images };
}

File diff suppressed because it is too large Load diff

View file

@ -94,8 +94,8 @@ export class AuthStorage {
/**
* Get credential for a provider.
*/
get(provider: string): AuthCredential | null {
return this.data[provider] ?? null;
get(provider: string): AuthCredential | undefined {
return this.data[provider] ?? undefined;
}
/**
@ -191,7 +191,7 @@ export class AuthStorage {
* 4. Environment variable
* 5. Fallback resolver (models.json custom providers)
*/
async getApiKey(provider: string): Promise<string | null> {
async getApiKey(provider: string): Promise<string | undefined> {
// Runtime override takes highest priority
const runtimeKey = this.runtimeOverrides.get(provider);
if (runtimeKey) {
@ -230,6 +230,6 @@ export class AuthStorage {
if (envKey) return envKey;
// Fall back to custom resolver (e.g., models.json custom providers)
return this.fallbackResolver?.(provider) ?? null;
return this.fallbackResolver?.(provider) ?? undefined;
}
}

View file

@ -29,8 +29,8 @@ export interface BashExecutorOptions {
export interface BashResult {
/** Combined stdout + stderr output (sanitized, possibly truncated) */
output: string;
/** Process exit code (null if killed/cancelled) */
exitCode: number | null;
/** Process exit code (undefined if killed/cancelled) */
exitCode: number | undefined;
/** Whether the command was cancelled via signal */
cancelled: boolean;
/** Whether the output was truncated */
@ -88,7 +88,7 @@ export function executeBash(command: string, options?: BashExecutorOptions): Pro
child.kill();
resolve({
output: "",
exitCode: null,
exitCode: undefined,
cancelled: true,
truncated: false,
});
@ -154,7 +154,7 @@ export function executeBash(command: string, options?: BashExecutorOptions): Pro
resolve({
output: truncationResult.truncated ? truncationResult.content : fullOutput,
exitCode: code,
exitCode: cancelled ? undefined : code,
cancelled,
truncated: truncationResult.truncated,
fullOutputPath: tempFilePath,

View file

@ -1,530 +0,0 @@
/**
* Context compaction for long sessions.
*
* Pure functions for compaction logic. The session manager handles I/O,
* and after compaction the session is reloaded.
*/
import type { AppMessage } from "@mariozechner/pi-agent-core";
import type { AssistantMessage, Model, Usage } from "@mariozechner/pi-ai";
import { complete } from "@mariozechner/pi-ai";
import { messageTransformer } from "./messages.js";
import type { CompactionEntry, SessionEntry } from "./session-manager.js";
// ============================================================================
// Types
// ============================================================================
/** Tunables controlling automatic context compaction. */
export interface CompactionSettings {
  // Master switch for automatic compaction.
  enabled: boolean;
  // Headroom kept free for the prompt + the summarization response.
  reserveTokens: number;
  // Approximate token budget of recent turns to keep verbatim after a cut.
  keepRecentTokens: number;
}
/** Defaults used when no explicit compaction settings are provided. */
export const DEFAULT_COMPACTION_SETTINGS: CompactionSettings = {
  enabled: true,
  reserveTokens: 16384,
  keepRecentTokens: 20000,
};
// ============================================================================
// Token calculation
// ============================================================================
/**
 * Total context tokens for a usage record.
 *
 * Prefers the provider-reported totalTokens; when that is absent or zero,
 * derives the total from the individual components.
 */
export function calculateContextTokens(usage: Usage): number {
  const reported = usage.totalTokens;
  if (reported) return reported;
  return usage.input + usage.output + usage.cacheRead + usage.cacheWrite;
}
/**
 * Usage from an assistant message, if it carries valid data.
 * Aborted and errored responses are excluded — their usage is not meaningful.
 */
function getAssistantUsage(msg: AppMessage): Usage | null {
  if (msg.role !== "assistant" || !("usage" in msg)) return null;
  const assistantMsg = msg as AssistantMessage;
  const invalid = assistantMsg.stopReason === "aborted" || assistantMsg.stopReason === "error";
  return !invalid && assistantMsg.usage ? assistantMsg.usage : null;
}
/**
 * Walk session entries from newest to oldest and return the first valid
 * (non-aborted, non-error) assistant usage, or null when none exists.
 */
export function getLastAssistantUsage(entries: SessionEntry[]): Usage | null {
  for (let i = entries.length - 1; i >= 0; i--) {
    const entry = entries[i];
    if (entry.type !== "message") continue;
    const usage = getAssistantUsage(entry.message);
    if (usage) return usage;
  }
  return null;
}
/**
 * True when compaction is enabled and the context has grown into the
 * reserved headroom (contextWindow minus reserveTokens).
 */
export function shouldCompact(contextTokens: number, contextWindow: number, settings: CompactionSettings): boolean {
  const threshold = contextWindow - settings.reserveTokens;
  return settings.enabled && contextTokens > threshold;
}
// ============================================================================
// Cut point detection
// ============================================================================
/**
 * Rough token estimate for a message using the chars/4 heuristic
 * (intentionally conservative — tends to overestimate).
 *
 * Counts the textual payload per role; unknown roles estimate to 0.
 */
export function estimateTokens(message: AppMessage): number {
  let chars = 0;
  switch (message.role) {
    case "bashExecution": {
      const bash = message as unknown as { command: string; output: string };
      chars = bash.command.length + bash.output.length;
      break;
    }
    case "user": {
      const content = (message as { content: string | Array<{ type: string; text?: string }> }).content;
      if (typeof content === "string") {
        chars = content.length;
      } else if (Array.isArray(content)) {
        for (const block of content) {
          if (block.type === "text" && block.text) chars += block.text.length;
        }
      }
      break;
    }
    case "assistant": {
      for (const block of (message as AssistantMessage).content) {
        if (block.type === "text") {
          chars += block.text.length;
        } else if (block.type === "thinking") {
          chars += block.thinking.length;
        } else if (block.type === "toolCall") {
          chars += block.name.length + JSON.stringify(block.arguments).length;
        }
      }
      break;
    }
    case "toolResult": {
      const toolResult = message as { content: Array<{ type: string; text?: string }> };
      for (const block of toolResult.content) {
        if (block.type === "text" && block.text) chars += block.text.length;
      }
      break;
    }
    default:
      // Roles with no estimable content.
      return 0;
  }
  return Math.ceil(chars / 4);
}
/**
 * Indices of entries in [startIndex, endIndex) where a compaction cut is
 * allowed: user, assistant, or bashExecution messages.
 *
 * Tool results are never valid cut points — they must stay attached to the
 * tool call that precedes them. When a cut lands on an assistant message
 * with tool calls, its tool results follow it and are kept.
 * BashExecutionMessage is treated like a user message (user-initiated context).
 */
function findValidCutPoints(entries: SessionEntry[], startIndex: number, endIndex: number): number[] {
  const cuttableRoles = new Set(["user", "assistant", "bashExecution"]);
  const cutPoints: number[] = [];
  for (let i = startIndex; i < endIndex; i++) {
    const entry = entries[i];
    if (entry.type === "message" && cuttableRoles.has(entry.message.role)) {
      cutPoints.push(i);
    }
  }
  return cutPoints;
}
/**
 * Index of the user (or bashExecution) message that starts the turn
 * containing entryIndex, scanning no further back than startIndex.
 * Returns -1 when no turn start exists in that range.
 * BashExecutionMessage is treated like a user message for turn boundaries.
 */
export function findTurnStartIndex(entries: SessionEntry[], entryIndex: number, startIndex: number): number {
  for (let i = entryIndex; i >= startIndex; i--) {
    const entry = entries[i];
    if (entry.type !== "message") continue;
    const role = entry.message.role;
    if (role === "user" || role === "bashExecution") return i;
  }
  return -1;
}
export interface CutPointResult {
  /** Index of first entry to keep */
  firstKeptEntryIndex: number;
  /** Index of user message that starts the turn being split, or -1 if not splitting */
  turnStartIndex: number;
  /** Whether this cut splits a turn (cut point is not a user message) */
  isSplitTurn: boolean;
}
/**
 * Find the cut point in session entries that keeps approximately `keepRecentTokens`.
 *
 * Algorithm: Walk backwards from newest, accumulating estimated message sizes.
 * Stop when we've accumulated >= keepRecentTokens. Cut at that point.
 *
 * Can cut at user OR assistant messages (never tool results). When cutting at an
 * assistant message with tool calls, its tool results come after and will be kept.
 *
 * Returns CutPointResult with:
 * - firstKeptEntryIndex: the entry index to start keeping from
 * - turnStartIndex: if cutting mid-turn, the user message that started that turn
 * - isSplitTurn: whether we're cutting in the middle of a turn
 *
 * Only considers entries between `startIndex` and `endIndex` (exclusive).
 *
 * @param entries - All session entries
 * @param startIndex - First entry index eligible for cutting
 * @param endIndex - One past the last entry index to consider
 * @param keepRecentTokens - Approximate token budget for the kept suffix
 */
export function findCutPoint(
  entries: SessionEntry[],
  startIndex: number,
  endIndex: number,
  keepRecentTokens: number,
): CutPointResult {
  const cutPoints = findValidCutPoints(entries, startIndex, endIndex);
  if (cutPoints.length === 0) {
    // No safe place to cut — keep the whole range.
    return { firstKeptEntryIndex: startIndex, turnStartIndex: -1, isSplitTurn: false };
  }
  // Walk backwards from newest, accumulating estimated message sizes
  let accumulatedTokens = 0;
  let cutIndex = startIndex; // Default: keep everything in range
  for (let i = endIndex - 1; i >= startIndex; i--) {
    const entry = entries[i];
    if (entry.type !== "message") continue;
    // Estimate this message's size
    const messageTokens = estimateTokens(entry.message);
    accumulatedTokens += messageTokens;
    // Check if we've exceeded the budget
    if (accumulatedTokens >= keepRecentTokens) {
      // Find the closest valid cut point at or after this entry
      for (let c = 0; c < cutPoints.length; c++) {
        if (cutPoints[c] >= i) {
          cutIndex = cutPoints[c];
          break;
        }
      }
      break;
    }
  }
  // Scan backwards from cutIndex to include any non-message entries (bash, settings, etc.)
  while (cutIndex > startIndex) {
    const prevEntry = entries[cutIndex - 1];
    // Stop at compaction boundaries
    if (prevEntry.type === "compaction") {
      break;
    }
    if (prevEntry.type === "message") {
      // Stop if we hit any message
      break;
    }
    // Include this non-message entry (bash, settings change, etc.)
    cutIndex--;
  }
  // Determine if this is a split turn: cutting anywhere other than a user
  // message means the containing turn is being divided in two.
  const cutEntry = entries[cutIndex];
  const isUserMessage = cutEntry.type === "message" && cutEntry.message.role === "user";
  const turnStartIndex = isUserMessage ? -1 : findTurnStartIndex(entries, cutIndex, startIndex);
  return {
    firstKeptEntryIndex: cutIndex,
    turnStartIndex,
    isSplitTurn: !isUserMessage && turnStartIndex !== -1,
  };
}
// ============================================================================
// Summarization
// ============================================================================
// Appended as a final user message when asking the model to summarize the
// conversation being compacted away.
const SUMMARIZATION_PROMPT = `You are performing a CONTEXT CHECKPOINT COMPACTION. Create a handoff summary for another LLM that will resume the task.
Include:
- Current progress and key decisions made
- Important context, constraints, or user preferences
- Absolute file paths of any relevant files that were read or modified
- What remains to be done (clear next steps)
- Any critical data, examples, or references needed to continue
Be concise, structured, and focused on helping the next LLM seamlessly continue the work.`;
/**
 * Generate a summary of the conversation using the LLM.
 *
 * @param currentMessages - Messages to summarize (pre-transformation)
 * @param model - Model used for summarization
 * @param reserveTokens - Token reserve; the response budget is 80% of this
 * @param apiKey - API key for the provider
 * @param signal - Optional abort signal for cancellation
 * @param customInstructions - Optional extra focus appended to the prompt
 * @returns Concatenated text content of the model's response
 */
export async function generateSummary(
  currentMessages: AppMessage[],
  model: Model<any>,
  reserveTokens: number,
  apiKey: string,
  signal?: AbortSignal,
  customInstructions?: string,
): Promise<string> {
  // Cap the response below the reserve so the summary itself fits in context.
  const maxTokens = Math.floor(0.8 * reserveTokens);
  const prompt = customInstructions
    ? `${SUMMARIZATION_PROMPT}\n\nAdditional focus: ${customInstructions}`
    : SUMMARIZATION_PROMPT;
  // Transform custom messages (like bashExecution) to LLM-compatible messages
  const transformedMessages = messageTransformer(currentMessages);
  const summarizationMessages = [
    ...transformedMessages,
    {
      role: "user" as const,
      content: [{ type: "text" as const, text: prompt }],
      timestamp: Date.now(),
    },
  ];
  const response = await complete(model, { messages: summarizationMessages }, { maxTokens, signal, apiKey });
  const textContent = response.content
    .filter((c): c is { type: "text"; text: string } => c.type === "text")
    .map((c) => c.text)
    .join("\n");
  return textContent;
}
// ============================================================================
// Compaction Preparation (for hooks)
// ============================================================================
export interface CompactionPreparation {
  /** Cut point computed for this compaction */
  cutPoint: CutPointResult;
  /** Messages that will be summarized and discarded */
  messagesToSummarize: AppMessage[];
  /** Messages that will be kept after the summary (recent turns) */
  messagesToKeep: AppMessage[];
  /** Context tokens before compaction (from the last assistant usage) */
  tokensBefore: number;
  /** First entry index after the previous compaction (start of the range) */
  boundaryStart: number;
}
/**
 * Compute what a compaction would do — cut point, messages to summarize,
 * messages to keep — without calling the LLM. Used by hooks (see section
 * header) to inspect a pending compaction.
 *
 * Returns null when the last entry is already a compaction.
 */
export function prepareCompaction(entries: SessionEntry[], settings: CompactionSettings): CompactionPreparation | null {
  if (entries.length > 0 && entries[entries.length - 1].type === "compaction") {
    return null;
  }
  // Find the previous compaction; only entries after it participate.
  let prevCompactionIndex = -1;
  for (let i = entries.length - 1; i >= 0; i--) {
    if (entries[i].type === "compaction") {
      prevCompactionIndex = i;
      break;
    }
  }
  const boundaryStart = prevCompactionIndex + 1;
  const boundaryEnd = entries.length;
  const lastUsage = getLastAssistantUsage(entries);
  const tokensBefore = lastUsage ? calculateContextTokens(lastUsage) : 0;
  const cutPoint = findCutPoint(entries, boundaryStart, boundaryEnd, settings.keepRecentTokens);
  // When splitting a turn, history ends at the turn start; otherwise at the cut.
  const historyEnd = cutPoint.isSplitTurn ? cutPoint.turnStartIndex : cutPoint.firstKeptEntryIndex;
  // Messages to summarize (will be discarded after summary)
  const messagesToSummarize: AppMessage[] = [];
  for (let i = boundaryStart; i < historyEnd; i++) {
    const entry = entries[i];
    if (entry.type === "message") {
      messagesToSummarize.push(entry.message);
    }
  }
  // Messages to keep (recent turns, kept after summary)
  const messagesToKeep: AppMessage[] = [];
  for (let i = cutPoint.firstKeptEntryIndex; i < boundaryEnd; i++) {
    const entry = entries[i];
    if (entry.type === "message") {
      messagesToKeep.push(entry.message);
    }
  }
  return { cutPoint, messagesToSummarize, messagesToKeep, tokensBefore, boundaryStart };
}
// ============================================================================
// Main compaction function
// ============================================================================
// Prompt used for the prefix of a turn that had to be split at the cut point.
const TURN_PREFIX_SUMMARIZATION_PROMPT = `You are performing a CONTEXT CHECKPOINT COMPACTION for a split turn.
This is the PREFIX of a turn that was too large to keep in full. The SUFFIX (recent work) is being kept.
Create a handoff summary that captures:
- What the user originally asked for in this turn
- Key decisions and progress made early in this turn
- Important context needed to understand the kept suffix
Be concise. Focus on information needed to understand the retained recent work.`;
/**
 * Calculate compaction and generate summary.
 * Returns the CompactionEntry to append to the session file.
 *
 * @param entries - All session entries
 * @param model - Model to use for summarization
 * @param settings - Compaction settings
 * @param apiKey - API key for LLM
 * @param signal - Optional abort signal
 * @param customInstructions - Optional custom focus for the summary
 * @throws Error when the last entry is already a compaction
 */
export async function compact(
  entries: SessionEntry[],
  model: Model<any>,
  settings: CompactionSettings,
  apiKey: string,
  signal?: AbortSignal,
  customInstructions?: string,
): Promise<CompactionEntry> {
  // Don't compact if the last entry is already a compaction
  if (entries.length > 0 && entries[entries.length - 1].type === "compaction") {
    throw new Error("Already compacted");
  }
  // Find previous compaction boundary
  let prevCompactionIndex = -1;
  for (let i = entries.length - 1; i >= 0; i--) {
    if (entries[i].type === "compaction") {
      prevCompactionIndex = i;
      break;
    }
  }
  const boundaryStart = prevCompactionIndex + 1;
  const boundaryEnd = entries.length;
  // Get token count before compaction
  const lastUsage = getLastAssistantUsage(entries);
  const tokensBefore = lastUsage ? calculateContextTokens(lastUsage) : 0;
  // Find cut point (entry index) within the valid range
  const cutResult = findCutPoint(entries, boundaryStart, boundaryEnd, settings.keepRecentTokens);
  // Extract messages for history summary (before the turn that contains the cut point)
  const historyEnd = cutResult.isSplitTurn ? cutResult.turnStartIndex : cutResult.firstKeptEntryIndex;
  const historyMessages: AppMessage[] = [];
  for (let i = boundaryStart; i < historyEnd; i++) {
    const entry = entries[i];
    if (entry.type === "message") {
      historyMessages.push(entry.message);
    }
  }
  // Include previous summary if there was a compaction, so context chains
  // across successive compactions instead of being lost.
  if (prevCompactionIndex >= 0) {
    const prevCompaction = entries[prevCompactionIndex] as CompactionEntry;
    historyMessages.unshift({
      role: "user",
      content: `Previous session summary:\n${prevCompaction.summary}`,
      timestamp: Date.now(),
    });
  }
  // Extract messages for turn prefix summary (if splitting a turn)
  const turnPrefixMessages: AppMessage[] = [];
  if (cutResult.isSplitTurn) {
    for (let i = cutResult.turnStartIndex; i < cutResult.firstKeptEntryIndex; i++) {
      const entry = entries[i];
      if (entry.type === "message") {
        turnPrefixMessages.push(entry.message);
      }
    }
  }
  // Generate summaries (can be parallel if both needed) and merge into one
  let summary: string;
  if (cutResult.isSplitTurn && turnPrefixMessages.length > 0) {
    // Generate both summaries in parallel
    const [historyResult, turnPrefixResult] = await Promise.all([
      historyMessages.length > 0
        ? generateSummary(historyMessages, model, settings.reserveTokens, apiKey, signal, customInstructions)
        : Promise.resolve("No prior history."),
      generateTurnPrefixSummary(turnPrefixMessages, model, settings.reserveTokens, apiKey, signal),
    ]);
    // Merge into single summary
    summary = `${historyResult}\n\n---\n\n**Turn Context (split turn):**\n\n${turnPrefixResult}`;
  } else {
    // Just generate history summary
    summary = await generateSummary(
      historyMessages,
      model,
      settings.reserveTokens,
      apiKey,
      signal,
      customInstructions,
    );
  }
  return {
    type: "compaction",
    timestamp: new Date().toISOString(),
    summary,
    firstKeptEntryIndex: cutResult.firstKeptEntryIndex,
    tokensBefore,
  };
}
/**
 * Generate a summary for a turn prefix (when splitting a turn).
 *
 * @param messages - The prefix messages of the split turn
 * @param model - Model used for summarization
 * @param reserveTokens - Token reserve; the response budget is 50% of this
 * @param apiKey - API key for the provider
 * @param signal - Optional abort signal
 * @returns Concatenated text content of the model's response
 */
async function generateTurnPrefixSummary(
  messages: AppMessage[],
  model: Model<any>,
  reserveTokens: number,
  apiKey: string,
  signal?: AbortSignal,
): Promise<string> {
  const maxTokens = Math.floor(0.5 * reserveTokens); // Smaller budget for turn prefix
  const transformedMessages = messageTransformer(messages);
  const summarizationMessages = [
    ...transformedMessages,
    {
      role: "user" as const,
      content: [{ type: "text" as const, text: TURN_PREFIX_SUMMARIZATION_PROMPT }],
      timestamp: Date.now(),
    },
  ];
  const response = await complete(model, { messages: summarizationMessages }, { maxTokens, signal, apiKey });
  return response.content
    .filter((c): c is { type: "text"; text: string } => c.type === "text")
    .map((c) => c.text)
    .join("\n");
}

View file

@ -0,0 +1,343 @@
/**
* Branch summarization for tree navigation.
*
* When navigating to a different point in the session tree, this generates
* a summary of the branch being left so context isn't lost.
*/
import type { AgentMessage } from "@mariozechner/pi-agent-core";
import type { Model } from "@mariozechner/pi-ai";
import { completeSimple } from "@mariozechner/pi-ai";
import {
convertToLlm,
createBranchSummaryMessage,
createCompactionSummaryMessage,
createHookMessage,
} from "../messages.js";
import type { ReadonlySessionManager, SessionEntry } from "../session-manager.js";
import { estimateTokens } from "./compaction.js";
import {
computeFileLists,
createFileOps,
extractFileOpsFromMessage,
type FileOperations,
formatFileOperations,
SUMMARIZATION_SYSTEM_PROMPT,
serializeConversation,
} from "./utils.js";
// ============================================================================
// Types
// ============================================================================
/** Outcome of a branch summarization attempt. */
export interface BranchSummaryResult {
  /** Generated summary text (absent on abort/error) */
  summary?: string;
  /** Files recorded as read on the branch */
  readFiles?: string[];
  /** Files recorded as modified on the branch */
  modifiedFiles?: string[];
  /** Set when summarization was cancelled — presumably via the abort signal */
  aborted?: boolean;
  /** Error message when summarization failed */
  error?: string;
}
/** Details stored in BranchSummaryEntry.details for file tracking */
export interface BranchSummaryDetails {
  readFiles: string[];
  modifiedFiles: string[];
}
export type { FileOperations } from "./utils.js";
/** Result of extracting summarizable content from a branch's entries. */
export interface BranchPreparation {
  /** Messages extracted for summarization, in chronological order */
  messages: AgentMessage[];
  /** File operations extracted from tool calls */
  fileOps: FileOperations;
  /** Total estimated tokens in messages */
  totalTokens: number;
}
/** Result of walking the tree to find what a navigation leaves behind. */
export interface CollectEntriesResult {
  /** Entries to summarize, in chronological order */
  entries: SessionEntry[];
  /** Common ancestor between old and new position, if any */
  commonAncestorId: string | null;
}
/** Options for generating a branch summary via the LLM. */
export interface GenerateBranchSummaryOptions {
  /** Model to use for summarization */
  model: Model<any>;
  /** API key for the model */
  apiKey: string;
  /** Abort signal for cancellation */
  signal: AbortSignal;
  /** Optional custom instructions for summarization */
  customInstructions?: string;
  /** Tokens reserved for prompt + LLM response (default 16384) */
  reserveTokens?: number;
}
// ============================================================================
// Entry Collection
// ============================================================================
/**
 * Collect the entries to summarize when navigating from oldLeafId to targetId.
 *
 * Walks parent links from the old leaf back toward the deepest node shared
 * with the target's path (the common ancestor). Compaction entries are NOT
 * treated as boundaries — they are collected too, so their summaries become
 * context for the branch summary.
 *
 * @param session - Session manager (read-only access)
 * @param oldLeafId - Current position (where we're navigating from)
 * @param targetId - Target position (where we're navigating to)
 * @returns Entries in chronological order plus the common ancestor id
 */
export function collectEntriesForBranchSummary(
  session: ReadonlySessionManager,
  oldLeafId: string | null,
  targetId: string,
): CollectEntriesResult {
  // Nothing explored yet — nothing to summarize.
  if (!oldLeafId) {
    return { entries: [], commonAncestorId: null };
  }
  // Ids on the root→oldLeaf path, for O(1) membership checks.
  const oldPathIds = new Set(session.getPath(oldLeafId).map((e) => e.id));
  // getPath() is root-first, so the LAST shared id is the deepest common ancestor.
  const targetPath = session.getPath(targetId);
  let commonAncestorId: string | null = null;
  for (let i = targetPath.length - 1; i >= 0; i--) {
    const candidate = targetPath[i].id;
    if (oldPathIds.has(candidate)) {
      commonAncestorId = candidate;
      break;
    }
  }
  // Walk parent links from the old leaf up to (excluding) the ancestor.
  const collected: SessionEntry[] = [];
  let cursor: string | null = oldLeafId;
  while (cursor && cursor !== commonAncestorId) {
    const entry = session.getEntry(cursor);
    if (!entry) break;
    collected.push(entry);
    cursor = entry.parentId;
  }
  // Collected newest-first; flip to chronological order.
  collected.reverse();
  return { entries: collected, commonAncestorId };
}
// ============================================================================
// Entry to Message Conversion
// ============================================================================
/**
 * Convert a session entry into an AgentMessage for summarization, or
 * undefined for entries carrying no conversational content.
 *
 * Unlike the compaction-side equivalent, this also materializes compaction
 * and branch_summary entries as messages so their summaries are preserved.
 */
function getMessageFromEntry(entry: SessionEntry): AgentMessage | undefined {
  if (entry.type === "message") {
    // Tool results are skipped — their context lives in the assistant's tool call.
    return entry.message.role === "toolResult" ? undefined : entry.message;
  }
  if (entry.type === "custom_message") {
    return createHookMessage(entry.customType, entry.content, entry.display, entry.details, entry.timestamp);
  }
  if (entry.type === "branch_summary") {
    return createBranchSummaryMessage(entry.summary, entry.fromId, entry.timestamp);
  }
  if (entry.type === "compaction") {
    return createCompactionSummaryMessage(entry.summary, entry.tokensBefore, entry.timestamp);
  }
  // thinking_level_change / model_change / custom / label: no content.
  return undefined;
}
/**
 * Prepare entries for summarization with token budget.
 *
 * Walks entries from NEWEST to OLDEST, adding messages until we hit the token budget.
 * This ensures we keep the most recent context when the branch is too long.
 *
 * Also collects file operations from:
 * - Tool calls in assistant messages
 * - Existing branch_summary entries' details (for cumulative tracking)
 *
 * @param entries - Entries in chronological order
 * @param tokenBudget - Maximum tokens to include (0 = no limit)
 * @returns Messages (chronological), cumulative file ops, and token total
 */
export function prepareBranchEntries(entries: SessionEntry[], tokenBudget: number = 0): BranchPreparation {
  const messages: AgentMessage[] = [];
  const fileOps = createFileOps();
  let totalTokens = 0;
  // First pass: collect file ops from ALL entries (even if they don't fit in token budget)
  // This ensures we capture cumulative file tracking from nested branch summaries
  // Only extract from pi-generated summaries (fromHook !== true), not hook-generated ones
  for (const entry of entries) {
    if (entry.type === "branch_summary" && !entry.fromHook && entry.details) {
      const details = entry.details as BranchSummaryDetails;
      if (Array.isArray(details.readFiles)) {
        for (const f of details.readFiles) fileOps.read.add(f);
      }
      if (Array.isArray(details.modifiedFiles)) {
        // Modified files go into both edited and written for proper deduplication
        for (const f of details.modifiedFiles) {
          fileOps.edited.add(f);
        }
      }
    }
  }
  // Second pass: walk from newest to oldest, adding messages until token budget
  for (let i = entries.length - 1; i >= 0; i--) {
    const entry = entries[i];
    const message = getMessageFromEntry(entry);
    if (!message) continue;
    // Extract file ops from assistant messages (tool calls)
    extractFileOpsFromMessage(message, fileOps);
    const tokens = estimateTokens(message);
    // Check budget before adding
    if (tokenBudget > 0 && totalTokens + tokens > tokenBudget) {
      // If this is a summary entry, try to fit it anyway as it's important context
      if (entry.type === "compaction" || entry.type === "branch_summary") {
        if (totalTokens < tokenBudget * 0.9) {
          messages.unshift(message);
          totalTokens += tokens;
        }
      }
      // Stop - we've hit the budget
      break;
    }
    messages.unshift(message);
    totalTokens += tokens;
  }
  return { messages, fileOps, totalTokens };
}
// ============================================================================
// Summary Generation
// ============================================================================
/** Prefixed to every branch summary so the model knows why it appears in context. */
const BRANCH_SUMMARY_PREAMBLE = `The user explored a different conversation branch before returning here.
Summary of that exploration:
`;
/**
 * Instructions given to the summarization model for abandoned branches.
 * NOTE(review): the section layout here closely parallels SUMMARIZATION_PROMPT
 * in compaction.ts — confirm whether edits should be mirrored there.
 */
const BRANCH_SUMMARY_PROMPT = `Create a structured summary of this conversation branch for context when returning later.
Use this EXACT format:
## Goal
[What was the user trying to accomplish in this branch?]
## Constraints & Preferences
- [Any constraints, preferences, or requirements mentioned]
- [Or "(none)" if none were mentioned]
## Progress
### Done
- [x] [Completed tasks/changes]
### In Progress
- [ ] [Work that was started but not finished]
### Blocked
- [Issues preventing progress, if any]
## Key Decisions
- **[Decision]**: [Brief rationale]
## Next Steps
1. [What should happen next to continue this work]
Keep each section concise. Preserve exact file paths, function names, and error messages.`;
/**
 * Generate a summary of abandoned branch entries.
 *
 * Fits as many recent messages as possible into the model's context window
 * (minus reserveTokens), serializes them to plain text, and asks the model
 * for a structured summary. File read/modify lists are appended.
 *
 * @param entries - Session entries to summarize (chronological order)
 * @param options - Generation options
 * @returns Summary plus file lists, or aborted/error markers
 */
export async function generateBranchSummary(
  entries: SessionEntry[],
  options: GenerateBranchSummaryOptions,
): Promise<BranchSummaryResult> {
  const { model, apiKey, signal, customInstructions, reserveTokens = 16384 } = options;

  // Token budget = context window minus reserved space for prompt + response
  const contextWindow = model.contextWindow || 128000;
  const tokenBudget = contextWindow - reserveTokens;

  const { messages, fileOps } = prepareBranchEntries(entries, tokenBudget);
  if (messages.length === 0) {
    return { summary: "No content to summarize" };
  }

  // Transform to LLM-compatible messages, then serialize to text.
  // Serialization prevents the model from treating it as a conversation to continue.
  const llmMessages = convertToLlm(messages);
  const conversationText = serializeConversation(llmMessages);

  // Build prompt
  const instructions = customInstructions || BRANCH_SUMMARY_PROMPT;
  const promptText = `<conversation>\n${conversationText}\n</conversation>\n\n${instructions}`;
  const summarizationMessages = [
    {
      role: "user" as const,
      content: [{ type: "text" as const, text: promptText }],
      timestamp: Date.now(),
    },
  ];

  // Call LLM for summarization
  const response = await completeSimple(
    model,
    { systemPrompt: SUMMARIZATION_SYSTEM_PROMPT, messages: summarizationMessages },
    { apiKey, signal, maxTokens: 2048 },
  );

  // Check if aborted or errored
  if (response.stopReason === "aborted") {
    return { aborted: true };
  }
  if (response.stopReason === "error") {
    return { error: response.errorMessage || "Summarization failed" };
  }

  const summaryText = response.content
    .filter((c): c is { type: "text"; text: string } => c.type === "text")
    .map((c) => c.text)
    .join("\n");

  const { readFiles, modifiedFiles } = computeFileLists(fileOps);

  // Bug fix: check for an empty model response BEFORE decorating the text.
  // Previously the preamble was prepended first, so the `summary ||
  // "No summary generated"` fallback at the return was unreachable and an
  // empty response produced a preamble with no actual summary.
  if (!summaryText) {
    return { summary: "No summary generated", readFiles, modifiedFiles };
  }

  // Prepend preamble (context for why this summary exists) and append file lists.
  const summary = BRANCH_SUMMARY_PREAMBLE + summaryText + formatFileOperations(readFiles, modifiedFiles);
  return { summary, readFiles, modifiedFiles };
}

View file

@ -0,0 +1,759 @@
/**
* Context compaction for long sessions.
*
* Pure functions for compaction logic. The session manager handles I/O,
* and after compaction the session is reloaded.
*/
import type { AgentMessage } from "@mariozechner/pi-agent-core";
import type { AssistantMessage, Model, Usage } from "@mariozechner/pi-ai";
import { complete, completeSimple } from "@mariozechner/pi-ai";
import { convertToLlm, createBranchSummaryMessage, createHookMessage } from "../messages.js";
import type { CompactionEntry, SessionEntry } from "../session-manager.js";
import {
computeFileLists,
createFileOps,
extractFileOpsFromMessage,
type FileOperations,
formatFileOperations,
SUMMARIZATION_SYSTEM_PROMPT,
serializeConversation,
} from "./utils.js";
// ============================================================================
// File Operation Tracking
// ============================================================================
/** Details stored in CompactionEntry.details for file tracking */
export interface CompactionDetails {
  /** Files that were only read during the compacted range (modified files excluded) */
  readFiles: string[];
  /** Files that were written or edited during the compacted range */
  modifiedFiles: string[];
}
/**
 * Collect file operations from the given messages plus the previous
 * compaction entry's recorded details (pi-generated entries only).
 */
function extractFileOperations(
  messages: AgentMessage[],
  entries: SessionEntry[],
  prevCompactionIndex: number,
): FileOperations {
  const fileOps = createFileOps();

  // Seed from the previous compaction's details, skipping hook-generated ones.
  if (prevCompactionIndex >= 0) {
    const prev = entries[prevCompactionIndex] as CompactionEntry;
    if (!prev.fromHook && prev.details) {
      const { readFiles, modifiedFiles } = prev.details as CompactionDetails;
      if (Array.isArray(readFiles)) readFiles.forEach((f) => fileOps.read.add(f));
      if (Array.isArray(modifiedFiles)) modifiedFiles.forEach((f) => fileOps.edited.add(f));
    }
  }

  // Add ops from tool calls in the messages being summarized.
  for (const msg of messages) extractFileOpsFromMessage(msg, fileOps);

  return fileOps;
}
// ============================================================================
// Message Extraction
// ============================================================================
/**
 * Extract the AgentMessage an entry contributes to LLM context.
 * Returns undefined for entries that don't carry conversational content.
 */
function getMessageFromEntry(entry: SessionEntry): AgentMessage | undefined {
  switch (entry.type) {
    case "message":
      return entry.message;
    case "custom_message":
      return createHookMessage(entry.customType, entry.content, entry.display, entry.details, entry.timestamp);
    case "branch_summary":
      return createBranchSummaryMessage(entry.summary, entry.fromId, entry.timestamp);
    default:
      return undefined;
  }
}
/** Result from compact() - SessionManager adds uuid/parentUuid when saving */
export interface CompactionResult<T = unknown> {
  /** Generated summary text (file-operation lists appended) */
  summary: string;
  /** UUID of the first entry retained after the cut */
  firstKeptEntryId: string;
  /** Context token count before compaction (from last valid assistant usage, 0 if unknown) */
  tokensBefore: number;
  /** Hook-specific data (e.g., ArtifactIndex, version markers for structured compaction) */
  details?: T;
}
// ============================================================================
// Types
// ============================================================================
/** Tunables controlling when and how much context compaction happens. */
export interface CompactionSettings {
  /** Master switch; shouldCompact() returns false when disabled */
  enabled: boolean;
  /** Headroom kept free at the top of the context window; compaction triggers once usage exceeds (window - reserveTokens). Also caps summary response length. */
  reserveTokens: number;
  /** Approximate token amount of recent conversation to keep verbatim after the cut */
  keepRecentTokens: number;
}
/** Default compaction settings. */
export const DEFAULT_COMPACTION_SETTINGS: CompactionSettings = {
  enabled: true,
  reserveTokens: 16384,
  keepRecentTokens: 20000,
};
// ============================================================================
// Token calculation
// ============================================================================
/**
 * Total context tokens from a Usage record.
 * Prefers the provider-native totalTokens field; when that is absent or
 * zero, sums the individual components instead.
 */
export function calculateContextTokens(usage: Usage): number {
  if (usage.totalTokens) return usage.totalTokens;
  return usage.input + usage.output + usage.cacheRead + usage.cacheWrite;
}
/**
 * Pull usage off an assistant message when it is valid.
 * Aborted and errored responses carry no meaningful usage, so they yield
 * undefined, as do non-assistant messages.
 */
function getAssistantUsage(msg: AgentMessage): Usage | undefined {
  if (msg.role !== "assistant" || !("usage" in msg)) return undefined;
  const assistantMsg = msg as AssistantMessage;
  const invalid = assistantMsg.stopReason === "aborted" || assistantMsg.stopReason === "error";
  return !invalid && assistantMsg.usage ? assistantMsg.usage : undefined;
}
/**
* Find the last non-aborted assistant message usage from session entries.
*/
export function getLastAssistantUsage(entries: SessionEntry[]): Usage | undefined {
for (let i = entries.length - 1; i >= 0; i--) {
const entry = entries[i];
if (entry.type === "message") {
const usage = getAssistantUsage(entry.message);
if (usage) return usage;
}
}
return undefined;
}
/**
 * Decide whether compaction should trigger: enabled, and context usage has
 * eaten into the reserved headroom at the top of the window.
 */
export function shouldCompact(contextTokens: number, contextWindow: number, settings: CompactionSettings): boolean {
  return settings.enabled && contextTokens > contextWindow - settings.reserveTokens;
}
// ============================================================================
// Cut point detection
// ============================================================================
/**
 * Rough token estimate for a message using the chars/4 heuristic.
 * Conservative: tends to overestimate actual token counts.
 */
export function estimateTokens(message: AgentMessage): number {
  const toTokens = (chars: number) => Math.ceil(chars / 4);
  switch (message.role) {
    case "user": {
      const content = (message as { content: string | Array<{ type: string; text?: string }> }).content;
      if (typeof content === "string") return toTokens(content.length);
      if (!Array.isArray(content)) return 0;
      let chars = 0;
      for (const block of content) {
        if (block.type === "text" && block.text) chars += block.text.length;
      }
      return toTokens(chars);
    }
    case "assistant": {
      let chars = 0;
      for (const block of (message as AssistantMessage).content) {
        if (block.type === "text") {
          chars += block.text.length;
        } else if (block.type === "thinking") {
          chars += block.thinking.length;
        } else if (block.type === "toolCall") {
          chars += block.name.length + JSON.stringify(block.arguments).length;
        }
      }
      return toTokens(chars);
    }
    case "hookMessage":
    case "toolResult": {
      if (typeof message.content === "string") return toTokens(message.content.length);
      let chars = 0;
      for (const block of message.content) {
        if (block.type === "text" && block.text) chars += block.text.length;
        // Images are charged a flat 4800 chars (~1200 tokens).
        if (block.type === "image") chars += 4800;
      }
      return toTokens(chars);
    }
    case "bashExecution":
      return toTokens(message.command.length + message.output.length);
    case "branchSummary":
    case "compactionSummary":
      return toTokens(message.summary.length);
  }
  // Unknown role: contributes nothing.
  return 0;
}
/**
 * Collect indices of entries where a compaction cut may legally happen.
 *
 * Valid cut points are messages of any role except toolResult (tool results
 * must stay glued to their tool call), plus branch_summary and custom_message
 * entries (both render as user-role messages). Cutting at an assistant
 * message with tool calls is fine: its results follow it and are kept.
 */
function findValidCutPoints(entries: SessionEntry[], startIndex: number, endIndex: number): number[] {
  // Every message role except toolResult may start the kept region.
  const cuttableRoles = new Set([
    "bashExecution",
    "hookMessage",
    "branchSummary",
    "compactionSummary",
    "user",
    "assistant",
  ]);
  const cutPoints: number[] = [];
  for (let i = startIndex; i < endIndex; i++) {
    const entry = entries[i];
    if (entry.type === "message") {
      if (cuttableRoles.has(entry.message.role)) cutPoints.push(i);
    } else if (entry.type === "branch_summary" || entry.type === "custom_message") {
      // User-role entries, hence valid cut points.
      cutPoints.push(i);
    }
  }
  return cutPoints;
}
/**
 * Locate the entry that starts the turn containing entryIndex, scanning
 * backwards no further than startIndex. Turn starters are user messages,
 * bashExecution messages (user-initiated), branch_summary entries, and
 * custom_message entries (both user-role). Returns -1 when none is found.
 */
export function findTurnStartIndex(entries: SessionEntry[], entryIndex: number, startIndex: number): number {
  for (let i = entryIndex; i >= startIndex; i--) {
    const entry = entries[i];
    const isUserRoleEntry = entry.type === "branch_summary" || entry.type === "custom_message";
    const isTurnStartMessage =
      entry.type === "message" && (entry.message.role === "user" || entry.message.role === "bashExecution");
    if (isUserRoleEntry || isTurnStartMessage) return i;
  }
  return -1;
}
/** Result of findCutPoint(): where to cut and whether a turn is split. */
export interface CutPointResult {
  /** Index of first entry to keep */
  firstKeptEntryIndex: number;
  /** Index of user message that starts the turn being split, or -1 if not splitting */
  turnStartIndex: number;
  /** Whether this cut splits a turn (cut point is not a user message) */
  isSplitTurn: boolean;
}
/**
* Find the cut point in session entries that keeps approximately `keepRecentTokens`.
*
* Algorithm: Walk backwards from newest, accumulating estimated message sizes.
* Stop when we've accumulated >= keepRecentTokens. Cut at that point.
*
* Can cut at user OR assistant messages (never tool results). When cutting at an
* assistant message with tool calls, its tool results come after and will be kept.
*
* Returns CutPointResult with:
* - firstKeptEntryIndex: the entry index to start keeping from
* - turnStartIndex: if cutting mid-turn, the user message that started that turn
* - isSplitTurn: whether we're cutting in the middle of a turn
*
* Only considers entries between `startIndex` and `endIndex` (exclusive).
*/
export function findCutPoint(
  entries: SessionEntry[],
  startIndex: number,
  endIndex: number,
  keepRecentTokens: number,
): CutPointResult {
  const cutPoints = findValidCutPoints(entries, startIndex, endIndex);
  if (cutPoints.length === 0) {
    // Nothing cuttable in range: keep everything from startIndex on.
    return { firstKeptEntryIndex: startIndex, turnStartIndex: -1, isSplitTurn: false };
  }
  // Walk backwards from newest, accumulating estimated message sizes
  let accumulatedTokens = 0;
  let cutIndex = cutPoints[0]; // Default: keep from first message (not header)
  for (let i = endIndex - 1; i >= startIndex; i--) {
    const entry = entries[i];
    if (entry.type !== "message") continue;
    // Estimate this message's size
    const messageTokens = estimateTokens(entry.message);
    accumulatedTokens += messageTokens;
    // Check if we've exceeded the budget
    if (accumulatedTokens >= keepRecentTokens) {
      // Find the closest valid cut point at or after this entry
      // (cutPoints is in ascending order, so the first match is closest)
      for (let c = 0; c < cutPoints.length; c++) {
        if (cutPoints[c] >= i) {
          cutIndex = cutPoints[c];
          break;
        }
      }
      break;
    }
  }
  // Scan backwards from cutIndex to include any non-message entries
  // (bash, settings, etc.) so they stay attached to the first kept message.
  while (cutIndex > startIndex) {
    const prevEntry = entries[cutIndex - 1];
    // Stop at session header or compaction boundaries
    if (prevEntry.type === "compaction") {
      break;
    }
    if (prevEntry.type === "message") {
      // Stop if we hit any message
      break;
    }
    // Include this non-message entry (bash, settings change, etc.)
    cutIndex--;
  }
  // Determine if this is a split turn: cutting anywhere other than a user
  // message means the turn containing the cut began earlier.
  const cutEntry = entries[cutIndex];
  const isUserMessage = cutEntry.type === "message" && cutEntry.message.role === "user";
  const turnStartIndex = isUserMessage ? -1 : findTurnStartIndex(entries, cutIndex, startIndex);
  return {
    firstKeptEntryIndex: cutIndex,
    turnStartIndex,
    isSplitTurn: !isUserMessage && turnStartIndex !== -1,
  };
}
// ============================================================================
// Summarization
// ============================================================================
/**
 * Initial compaction prompt, used when no previous summary exists.
 * generateSummary() places it AFTER the serialized <conversation> block.
 */
const SUMMARIZATION_PROMPT = `The messages above are a conversation to summarize. Create a structured context checkpoint summary that another LLM will use to continue the work.
Use this EXACT format:
## Goal
[What is the user trying to accomplish? Can be multiple items if the session covers different tasks.]
## Constraints & Preferences
- [Any constraints, preferences, or requirements mentioned by user]
- [Or "(none)" if none were mentioned]
## Progress
### Done
- [x] [Completed tasks/changes]
### In Progress
- [ ] [Current work]
### Blocked
- [Issues preventing progress, if any]
## Key Decisions
- **[Decision]**: [Brief rationale]
## Next Steps
1. [Ordered list of what should happen next]
## Critical Context
- [Any data, examples, or references needed to continue]
- [Or "(none)" if not applicable]
Keep each section concise. Preserve exact file paths, function names, and error messages.`;
/**
 * Iterative compaction prompt: merges NEW messages into the existing summary
 * that generateSummary() supplies inside <previous-summary> tags.
 */
const UPDATE_SUMMARIZATION_PROMPT = `The messages above are NEW conversation messages to incorporate into the existing summary provided in <previous-summary> tags.
Update the existing structured summary with new information. RULES:
- PRESERVE all existing information from the previous summary
- ADD new progress, decisions, and context from the new messages
- UPDATE the Progress section: move items from "In Progress" to "Done" when completed
- UPDATE "Next Steps" based on what was accomplished
- PRESERVE exact file paths, function names, and error messages
- If something is no longer relevant, you may remove it
Use this EXACT format:
## Goal
[Preserve existing goals, add new ones if the task expanded]
## Constraints & Preferences
- [Preserve existing, add new ones discovered]
## Progress
### Done
- [x] [Include previously done items AND newly completed items]
### In Progress
- [ ] [Current work - update based on progress]
### Blocked
- [Current blockers - remove if resolved]
## Key Decisions
- **[Decision]**: [Brief rationale] (preserve all previous, add new)
## Next Steps
1. [Update based on current state]
## Critical Context
- [Preserve important context, add new if needed]
Keep each section concise. Preserve exact file paths, function names, and error messages.`;
/**
 * Summarize a conversation via the LLM.
 *
 * With previousSummary set, the update prompt is used so the model merges
 * the new messages into the existing summary instead of starting over.
 *
 * @throws Error when the LLM call ends with an error stop reason
 */
export async function generateSummary(
  currentMessages: AgentMessage[],
  model: Model<any>,
  reserveTokens: number,
  apiKey: string,
  signal?: AbortSignal,
  customInstructions?: string,
  previousSummary?: string,
): Promise<string> {
  // Cap the response at 80% of the reserved token budget.
  const maxTokens = Math.floor(0.8 * reserveTokens);

  // Update prompt when merging into a previous summary, initial prompt otherwise.
  let basePrompt = previousSummary ? UPDATE_SUMMARIZATION_PROMPT : SUMMARIZATION_PROMPT;
  if (customInstructions) {
    basePrompt = `${basePrompt}\n\nAdditional focus: ${customInstructions}`;
  }

  // Convert to LLM messages (handles bashExecution, hookMessage, etc.) and
  // serialize to plain text so the model doesn't try to continue the chat.
  const conversationText = serializeConversation(convertToLlm(currentMessages));

  // Assemble: tagged conversation, optional tagged previous summary, instructions.
  const promptParts = [`<conversation>\n${conversationText}\n</conversation>\n\n`];
  if (previousSummary) {
    promptParts.push(`<previous-summary>\n${previousSummary}\n</previous-summary>\n\n`);
  }
  promptParts.push(basePrompt);

  const summarizationMessages = [
    {
      role: "user" as const,
      content: [{ type: "text" as const, text: promptParts.join("") }],
      timestamp: Date.now(),
    },
  ];

  const response = await completeSimple(
    model,
    { systemPrompt: SUMMARIZATION_SYSTEM_PROMPT, messages: summarizationMessages },
    { maxTokens, signal, apiKey, reasoning: "high" },
  );
  if (response.stopReason === "error") {
    throw new Error(`Summarization failed: ${response.errorMessage || "Unknown error"}`);
  }

  return response.content
    .filter((c): c is { type: "text"; text: string } => c.type === "text")
    .map((c) => c.text)
    .join("\n");
}
// ============================================================================
// Compaction Preparation (for hooks)
// ============================================================================
/** Everything a hook needs to run its own compaction pass. */
export interface CompactionPreparation {
  /** Cut point computed by findCutPoint() */
  cutPoint: CutPointResult;
  /** UUID of first entry to keep */
  firstKeptEntryId: string;
  /** Messages that will be summarized and discarded */
  messagesToSummarize: AgentMessage[];
  /** Messages that will be kept after the summary (recent turns) */
  messagesToKeep: AgentMessage[];
  /** Context tokens before compaction (from last valid assistant usage, 0 if unknown) */
  tokensBefore: number;
  /** Index of the first entry after the previous compaction (start of the compactable range) */
  boundaryStart: number;
}
/**
 * Compute everything a compaction pass needs without calling the LLM.
 * Returns undefined when the last entry is already a compaction, or when
 * the first kept entry lacks a UUID (session needs migration).
 */
export function prepareCompaction(
  entries: SessionEntry[],
  settings: CompactionSettings,
): CompactionPreparation | undefined {
  // Nothing to do right after a compaction.
  const last = entries[entries.length - 1];
  if (last && last.type === "compaction") return undefined;

  // The compactable range starts just after the previous compaction (if any).
  let prevCompactionIndex = -1;
  for (let i = entries.length - 1; i >= 0; i--) {
    if (entries[i].type === "compaction") {
      prevCompactionIndex = i;
      break;
    }
  }
  const boundaryStart = prevCompactionIndex + 1;
  const boundaryEnd = entries.length;

  const lastUsage = getLastAssistantUsage(entries);
  const tokensBefore = lastUsage ? calculateContextTokens(lastUsage) : 0;

  const cutPoint = findCutPoint(entries, boundaryStart, boundaryEnd, settings.keepRecentTokens);

  const firstKeptEntry = entries[cutPoint.firstKeptEntryIndex];
  if (!firstKeptEntry?.id) {
    return undefined; // Session needs migration
  }
  const firstKeptEntryId = firstKeptEntry.id;

  // When the cut splits a turn, the history summary stops at that turn's start.
  const historyEnd = cutPoint.isSplitTurn ? cutPoint.turnStartIndex : cutPoint.firstKeptEntryIndex;

  // Gather the AgentMessages contributed by entries in [from, to).
  const collect = (from: number, to: number): AgentMessage[] => {
    const result: AgentMessage[] = [];
    for (let i = from; i < to; i++) {
      const msg = getMessageFromEntry(entries[i]);
      if (msg) result.push(msg);
    }
    return result;
  };

  // Everything before the cut gets summarized and dropped; the rest is kept.
  const messagesToSummarize = collect(boundaryStart, historyEnd);
  const messagesToKeep = collect(cutPoint.firstKeptEntryIndex, boundaryEnd);

  return { cutPoint, firstKeptEntryId, messagesToSummarize, messagesToKeep, tokensBefore, boundaryStart };
}
// ============================================================================
// Main compaction function
// ============================================================================
/**
 * Prompt for summarizing the PREFIX of a turn split at the cut point;
 * the suffix (recent work) is kept verbatim after the summary.
 */
const TURN_PREFIX_SUMMARIZATION_PROMPT = `This is the PREFIX of a turn that was too large to keep. The SUFFIX (recent work) is retained.
Summarize the prefix to provide context for the retained suffix:
## Original Request
[What did the user ask for in this turn?]
## Early Progress
- [Key decisions and work done in the prefix]
## Context for Suffix
- [Information needed to understand the retained recent work]
Be concise. Focus on what's needed to understand the kept suffix.`;
/**
* Calculate compaction and generate summary.
* Returns CompactionResult - SessionManager adds uuid/parentUuid when saving.
*
* @param entries - All session entries (must have uuid fields for v2)
* @param model - Model to use for summarization
* @param settings - Compaction settings
* @param apiKey - API key for LLM
* @param signal - Optional abort signal
* @param customInstructions - Optional custom focus for the summary
*/
export async function compact(
  entries: SessionEntry[],
  model: Model<any>,
  settings: CompactionSettings,
  apiKey: string,
  signal?: AbortSignal,
  customInstructions?: string,
): Promise<CompactionResult> {
  // Don't compact if the last entry is already a compaction
  if (entries.length > 0 && entries[entries.length - 1].type === "compaction") {
    throw new Error("Already compacted");
  }
  // Find previous compaction boundary; only entries after it are compactable
  let prevCompactionIndex = -1;
  for (let i = entries.length - 1; i >= 0; i--) {
    if (entries[i].type === "compaction") {
      prevCompactionIndex = i;
      break;
    }
  }
  const boundaryStart = prevCompactionIndex + 1;
  const boundaryEnd = entries.length;
  // Get token count before compaction (0 when no valid assistant usage exists)
  const lastUsage = getLastAssistantUsage(entries);
  const tokensBefore = lastUsage ? calculateContextTokens(lastUsage) : 0;
  // Find cut point (entry index) within the valid range
  const cutResult = findCutPoint(entries, boundaryStart, boundaryEnd, settings.keepRecentTokens);
  // Extract messages for history summary (before the turn that contains the cut point)
  const historyEnd = cutResult.isSplitTurn ? cutResult.turnStartIndex : cutResult.firstKeptEntryIndex;
  const historyMessages: AgentMessage[] = [];
  for (let i = boundaryStart; i < historyEnd; i++) {
    const msg = getMessageFromEntry(entries[i]);
    if (msg) historyMessages.push(msg);
  }
  // Get previous summary for iterative update (if not from hook)
  // NOTE(review): unlike extractFileOperations, fromHook is NOT checked here,
  // so hook-generated previous summaries are also reused — confirm intent.
  let previousSummary: string | undefined;
  if (prevCompactionIndex >= 0) {
    const prevCompaction = entries[prevCompactionIndex] as CompactionEntry;
    previousSummary = prevCompaction.summary;
  }
  // Extract file operations from messages and previous compaction
  const fileOps = extractFileOperations(historyMessages, entries, prevCompactionIndex);
  // Extract messages for turn prefix summary (if splitting a turn)
  const turnPrefixMessages: AgentMessage[] = [];
  if (cutResult.isSplitTurn) {
    for (let i = cutResult.turnStartIndex; i < cutResult.firstKeptEntryIndex; i++) {
      const msg = getMessageFromEntry(entries[i]);
      if (msg) turnPrefixMessages.push(msg);
    }
    // Also extract file ops from turn prefix
    for (const msg of turnPrefixMessages) {
      extractFileOpsFromMessage(msg, fileOps);
    }
  }
  // Generate summaries (can be parallel if both needed) and merge into one
  let summary: string;
  if (cutResult.isSplitTurn && turnPrefixMessages.length > 0) {
    // Generate both summaries in parallel
    const [historyResult, turnPrefixResult] = await Promise.all([
      historyMessages.length > 0
        ? generateSummary(
            historyMessages,
            model,
            settings.reserveTokens,
            apiKey,
            signal,
            customInstructions,
            previousSummary,
          )
        : Promise.resolve("No prior history."),
      generateTurnPrefixSummary(turnPrefixMessages, model, settings.reserveTokens, apiKey, signal),
    ]);
    // Merge into single summary
    summary = `${historyResult}\n\n---\n\n**Turn Context (split turn):**\n\n${turnPrefixResult}`;
  } else {
    // Just generate history summary
    summary = await generateSummary(
      historyMessages,
      model,
      settings.reserveTokens,
      apiKey,
      signal,
      customInstructions,
      previousSummary,
    );
  }
  // Compute file lists and append to summary
  const { readFiles, modifiedFiles } = computeFileLists(fileOps);
  summary += formatFileOperations(readFiles, modifiedFiles);
  // Get UUID of first kept entry (required; a missing id means the session needs migration)
  const firstKeptEntry = entries[cutResult.firstKeptEntryIndex];
  const firstKeptEntryId = firstKeptEntry.id;
  if (!firstKeptEntryId) {
    throw new Error("First kept entry has no UUID - session may need migration");
  }
  return {
    summary,
    firstKeptEntryId,
    tokensBefore,
    details: { readFiles, modifiedFiles } as CompactionDetails,
  };
}
/**
 * Summarize the prefix of a split turn. Unlike generateSummary, the messages
 * are sent to the model directly with the prompt appended as a final user
 * message, and a smaller token budget is used.
 *
 * @throws Error when the LLM call ends with an error stop reason
 */
async function generateTurnPrefixSummary(
  messages: AgentMessage[],
  model: Model<any>,
  reserveTokens: number,
  apiKey: string,
  signal?: AbortSignal,
): Promise<string> {
  // Turn prefixes get half the reserve — a smaller budget than full summaries.
  const maxTokens = Math.floor(0.5 * reserveTokens);

  const summarizationMessages = [
    ...convertToLlm(messages),
    {
      role: "user" as const,
      content: [{ type: "text" as const, text: TURN_PREFIX_SUMMARIZATION_PROMPT }],
      timestamp: Date.now(),
    },
  ];

  const response = await complete(model, { messages: summarizationMessages }, { maxTokens, signal, apiKey });
  if (response.stopReason === "error") {
    throw new Error(`Turn prefix summarization failed: ${response.errorMessage || "Unknown error"}`);
  }

  const textBlocks = response.content.filter((c): c is { type: "text"; text: string } => c.type === "text");
  return textBlocks.map((c) => c.text).join("\n");
}

View file

@ -0,0 +1,7 @@
/**
* Compaction and summarization utilities.
*/
export * from "./branch-summarization.js";
export * from "./compaction.js";
export * from "./utils.js";

View file

@ -0,0 +1,154 @@
/**
* Shared utilities for compaction and branch summarization.
*/
import type { AgentMessage } from "@mariozechner/pi-agent-core";
import type { Message } from "@mariozechner/pi-ai";
// ============================================================================
// File Operation Tracking
// ============================================================================
/** Sets of file paths touched via the read/write/edit tools. */
export interface FileOperations {
  /** Paths passed to the `read` tool */
  read: Set<string>;
  /** Paths passed to the `write` tool */
  written: Set<string>;
  /** Paths passed to the `edit` tool */
  edited: Set<string>;
}
/** Create a fresh, empty FileOperations accumulator. */
export function createFileOps(): FileOperations {
  const read = new Set<string>();
  const written = new Set<string>();
  const edited = new Set<string>();
  return { read, written, edited };
}
/**
 * Record file paths from read/write/edit tool calls in an assistant message.
 * Non-assistant messages, non-toolCall blocks, and tool calls without a
 * string `path` argument are ignored. Mutates fileOps in place.
 */
export function extractFileOpsFromMessage(message: AgentMessage, fileOps: FileOperations): void {
  if (message.role !== "assistant") return;
  if (!("content" in message) || !Array.isArray(message.content)) return;

  for (const block of message.content) {
    if (typeof block !== "object" || block === null) continue;
    if (!("type" in block) || block.type !== "toolCall") continue;
    if (!("arguments" in block) || !("name" in block)) continue;

    const args = block.arguments as Record<string, unknown> | undefined;
    const path = typeof args?.path === "string" ? args.path : undefined;
    if (!path) continue;

    // Route the path to the bucket for its tool; other tools are ignored.
    if (block.name === "read") fileOps.read.add(path);
    else if (block.name === "write") fileOps.written.add(path);
    else if (block.name === "edit") fileOps.edited.add(path);
  }
}
/**
 * Split tracked file operations into two sorted lists: files that were only
 * read, and files that were written or edited (modification wins over read).
 */
export function computeFileLists(fileOps: FileOperations): { readFiles: string[]; modifiedFiles: string[] } {
  const modified = new Set<string>([...fileOps.edited, ...fileOps.written]);
  const readFiles = [...fileOps.read].filter((f) => !modified.has(f)).sort();
  return { readFiles, modifiedFiles: [...modified].sort() };
}
/**
 * Render file lists as XML-style tags for appending to a summary.
 * Returns "" when both lists are empty; otherwise the result begins with a
 * blank-line separator so it can be concatenated directly onto the summary.
 */
export function formatFileOperations(readFiles: string[], modifiedFiles: string[]): string {
  const sections = [
    readFiles.length > 0 ? `<read-files>\n${readFiles.join("\n")}\n</read-files>` : undefined,
    modifiedFiles.length > 0 ? `<modified-files>\n${modifiedFiles.join("\n")}\n</modified-files>` : undefined,
  ].filter((s): s is string => s !== undefined);
  return sections.length > 0 ? `\n\n${sections.join("\n\n")}` : "";
}
// ============================================================================
// Message Serialization
// ============================================================================
/**
 * Flatten LLM messages into labeled plain text so a summarization model
 * treats them as material to summarize rather than a chat to continue.
 * Run convertToLlm() first so custom message types are handled.
 */
export function serializeConversation(messages: Message[]): string {
  const parts: string[] = [];

  // Concatenate the text blocks of a content array.
  const textOf = (content: Array<{ type: string; text?: string }>): string =>
    content
      .filter((c): c is { type: "text"; text: string } => c.type === "text")
      .map((c) => c.text)
      .join("");

  for (const msg of messages) {
    switch (msg.role) {
      case "user": {
        const content = typeof msg.content === "string" ? msg.content : textOf(msg.content);
        if (content) parts.push(`[User]: ${content}`);
        break;
      }
      case "assistant": {
        const textParts: string[] = [];
        const thinkingParts: string[] = [];
        const toolCalls: string[] = [];
        for (const block of msg.content) {
          if (block.type === "text") {
            textParts.push(block.text);
          } else if (block.type === "thinking") {
            thinkingParts.push(block.thinking);
          } else if (block.type === "toolCall") {
            const args = block.arguments as Record<string, unknown>;
            const argsStr = Object.entries(args)
              .map(([k, v]) => `${k}=${JSON.stringify(v)}`)
              .join(", ");
            toolCalls.push(`${block.name}(${argsStr})`);
          }
        }
        // Thinking first, then visible text, then tool calls — each only if present.
        if (thinkingParts.length > 0) parts.push(`[Assistant thinking]: ${thinkingParts.join("\n")}`);
        if (textParts.length > 0) parts.push(`[Assistant]: ${textParts.join("\n")}`);
        if (toolCalls.length > 0) parts.push(`[Assistant tool calls]: ${toolCalls.join("; ")}`);
        break;
      }
      case "toolResult": {
        const content = textOf(msg.content);
        if (content) parts.push(`[Tool result]: ${content}`);
        break;
      }
    }
  }
  return parts.join("\n\n");
}
// ============================================================================
// Summarization System Prompt
// ============================================================================
// System prompt for the compaction/summarization request.
// NOTE(review): presumably paired with serializeConversation() output as the user message — confirm at call site.
export const SUMMARIZATION_SYSTEM_PROMPT = `You are a context summarization assistant. Your task is to read a conversation between a user and an AI coding assistant, then produce a structured summary following the exact format specified.
Do NOT continue the conversation. Do NOT respond to any questions in the conversation. ONLY output the structured summary.`;

View file

@ -7,7 +7,6 @@
* for custom tools that depend on pi packages.
*/
import { spawn } from "node:child_process";
import * as fs from "node:fs";
import { createRequire } from "node:module";
import * as os from "node:os";
@ -15,15 +14,10 @@ import * as path from "node:path";
import { fileURLToPath } from "node:url";
import { createJiti } from "jiti";
import { getAgentDir, isBunBinary } from "../../config.js";
import type { ExecOptions } from "../exec.js";
import { execCommand } from "../exec.js";
import type { HookUIContext } from "../hooks/types.js";
import type {
CustomToolFactory,
CustomToolsLoadResult,
ExecOptions,
ExecResult,
LoadedCustomTool,
ToolAPI,
} from "./types.js";
import type { CustomToolFactory, CustomToolsLoadResult, LoadedCustomTool, ToolAPI } from "./types.js";
// Create require function to resolve module paths at runtime
const require = createRequire(import.meta.url);
@ -87,97 +81,16 @@ function resolveToolPath(toolPath: string, cwd: string): string {
return path.resolve(cwd, expanded);
}
/**
* Execute a command and return stdout/stderr/code.
* Supports cancellation via AbortSignal and timeout.
*/
async function execCommand(command: string, args: string[], cwd: string, options?: ExecOptions): Promise<ExecResult> {
return new Promise((resolve) => {
const proc = spawn(command, args, {
cwd,
shell: false,
stdio: ["ignore", "pipe", "pipe"],
});
let stdout = "";
let stderr = "";
let killed = false;
let timeoutId: NodeJS.Timeout | undefined;
const killProcess = () => {
if (!killed) {
killed = true;
proc.kill("SIGTERM");
// Force kill after 5 seconds if SIGTERM doesn't work
setTimeout(() => {
if (!proc.killed) {
proc.kill("SIGKILL");
}
}, 5000);
}
};
// Handle abort signal
if (options?.signal) {
if (options.signal.aborted) {
killProcess();
} else {
options.signal.addEventListener("abort", killProcess, { once: true });
}
}
// Handle timeout
if (options?.timeout && options.timeout > 0) {
timeoutId = setTimeout(() => {
killProcess();
}, options.timeout);
}
proc.stdout.on("data", (data) => {
stdout += data.toString();
});
proc.stderr.on("data", (data) => {
stderr += data.toString();
});
proc.on("close", (code) => {
if (timeoutId) clearTimeout(timeoutId);
if (options?.signal) {
options.signal.removeEventListener("abort", killProcess);
}
resolve({
stdout,
stderr,
code: code ?? 0,
killed,
});
});
proc.on("error", (err) => {
if (timeoutId) clearTimeout(timeoutId);
if (options?.signal) {
options.signal.removeEventListener("abort", killProcess);
}
resolve({
stdout,
stderr: stderr || err.message,
code: 1,
killed,
});
});
});
}
/**
* Create a no-op UI context for headless modes.
*/
function createNoOpUIContext(): HookUIContext {
return {
select: async () => null,
select: async () => undefined,
confirm: async () => false,
input: async () => null,
input: async () => undefined,
notify: () => {},
custom: () => ({ close: () => {}, requestRender: () => {} }),
};
}
@ -298,7 +211,8 @@ export async function loadCustomTools(
// Shared API object - all tools get the same instance
const sharedApi: ToolAPI = {
cwd,
exec: (command: string, args: string[], options?: ExecOptions) => execCommand(command, args, cwd, options),
exec: (command: string, args: string[], options?: ExecOptions) =>
execCommand(command, args, options?.cwd ?? cwd, options),
ui: createNoOpUIContext(),
hasUI: false,
};

View file

@ -5,10 +5,11 @@
* They can provide custom rendering for tool calls and results in the TUI.
*/
import type { AgentTool, AgentToolResult, AgentToolUpdateCallback } from "@mariozechner/pi-ai";
import type { AgentTool, AgentToolResult, AgentToolUpdateCallback } from "@mariozechner/pi-agent-core";
import type { Component } from "@mariozechner/pi-tui";
import type { Static, TSchema } from "@sinclair/typebox";
import type { Theme } from "../../modes/interactive/theme/theme.js";
import type { ExecOptions, ExecResult } from "../exec.js";
import type { HookUIContext } from "../hooks/types.js";
import type { SessionEntry } from "../session-manager.js";
@ -18,20 +19,8 @@ export type ToolUIContext = HookUIContext;
/** Re-export for custom tools to use in execute signature */
export type { AgentToolUpdateCallback };
export interface ExecResult {
stdout: string;
stderr: string;
code: number;
/** True if the process was killed due to signal or timeout */
killed?: boolean;
}
export interface ExecOptions {
/** AbortSignal to cancel the process */
signal?: AbortSignal;
/** Timeout in milliseconds */
timeout?: number;
}
// Re-export for backward compatibility
export type { ExecOptions, ExecResult } from "../exec.js";
/** API passed to custom tool factory (stable across session changes) */
export interface ToolAPI {
@ -49,12 +38,12 @@ export interface ToolAPI {
export interface SessionEvent {
/** All session entries (including pre-compaction history) */
entries: SessionEntry[];
/** Current session file path, or null in --no-session mode */
sessionFile: string | null;
/** Previous session file path, or null for "start" and "new" */
previousSessionFile: string | null;
/** Current session file path, or undefined in --no-session mode */
sessionFile: string | undefined;
/** Previous session file path, or undefined for "start" and "new" */
previousSessionFile: string | undefined;
/** Reason for the session event */
reason: "start" | "switch" | "branch" | "new";
reason: "start" | "switch" | "branch" | "new" | "tree";
}
/** Rendering options passed to renderResult */

View file

@ -0,0 +1,104 @@
/**
* Shared command execution utilities for hooks and custom tools.
*/
import { spawn } from "node:child_process";
/**
* Options for executing shell commands.
*/
export interface ExecOptions {
	/** AbortSignal to cancel the command; the process gets SIGTERM, escalating to SIGKILL after 5s. */
	signal?: AbortSignal;
	/** Timeout in milliseconds; when it elapses the process is killed and `killed` is set. */
	timeout?: number;
	/** Working directory. execCommand itself takes cwd positionally; callers apply `options.cwd ?? cwd` before invoking it. */
	cwd?: string;
}
/**
* Result of executing a shell command.
*/
export interface ExecResult {
	/** Accumulated standard output as text. */
	stdout: string;
	/** Accumulated standard error as text. */
	stderr: string;
	/** Process exit code; 0 when the exit code is unavailable, 1 on spawn failure. */
	code: number;
	/** True if the process was killed due to the abort signal or timeout. */
	killed: boolean;
}
/**
* Execute a shell command and return stdout/stderr/code.
* Supports timeout and abort signal.
*/
export async function execCommand(
	command: string,
	args: string[],
	cwd: string,
	options?: ExecOptions,
): Promise<ExecResult> {
	return new Promise((resolve) => {
		const proc = spawn(command, args, {
			cwd,
			shell: false,
			stdio: ["ignore", "pipe", "pipe"],
		});

		let stdout = "";
		let stderr = "";
		let killed = false;
		let timeoutId: NodeJS.Timeout | undefined;

		// Terminate gracefully first, escalating to SIGKILL if SIGTERM is ignored.
		const killProcess = () => {
			if (!killed) {
				killed = true;
				proc.kill("SIGTERM");
				// Force kill after 5 seconds if SIGTERM doesn't work.
				// unref() so this escalation timer cannot keep the event loop
				// alive for up to 5s after the process has already exited.
				setTimeout(() => {
					if (!proc.killed) {
						proc.kill("SIGKILL");
					}
				}, 5000).unref();
			}
		};

		// Handle abort signal (may already be aborted when we get here)
		if (options?.signal) {
			if (options.signal.aborted) {
				killProcess();
			} else {
				options.signal.addEventListener("abort", killProcess, { once: true });
			}
		}

		// Handle timeout
		if (options?.timeout && options.timeout > 0) {
			timeoutId = setTimeout(() => {
				killProcess();
			}, options.timeout);
		}

		proc.stdout?.on("data", (data) => {
			stdout += data.toString();
		});
		proc.stderr?.on("data", (data) => {
			stderr += data.toString();
		});

		proc.on("close", (code) => {
			if (timeoutId) clearTimeout(timeoutId);
			if (options?.signal) {
				options.signal.removeEventListener("abort", killProcess);
			}
			resolve({ stdout, stderr, code: code ?? 0, killed });
		});

		proc.on("error", (err) => {
			if (timeoutId) clearTimeout(timeoutId);
			if (options?.signal) {
				options.signal.removeEventListener("abort", killProcess);
			}
			// Surface the spawn failure (e.g. ENOENT) instead of silently
			// dropping it — callers otherwise see code 1 with an empty stderr.
			resolve({ stdout, stderr: stderr || err.message, code: 1, killed });
		});
	});
}

View file

@ -1,4 +1,4 @@
import type { AgentState } from "@mariozechner/pi-agent-core";
import type { AgentMessage, AgentState } from "@mariozechner/pi-agent-core";
import type { AssistantMessage, ImageContent, Message, ToolResultMessage, UserMessage } from "@mariozechner/pi-ai";
import { existsSync, readFileSync, writeFileSync } from "fs";
import hljs from "highlight.js";
@ -7,7 +7,6 @@ import { homedir } from "os";
import * as path from "path";
import { basename } from "path";
import { APP_NAME, getCustomThemesDir, getThemesDir, VERSION } from "../config.js";
import { type BashExecutionMessage, isBashExecutionMessage } from "./messages.js";
import type { SessionManager } from "./session-manager.js";
// ============================================================================
@ -122,7 +121,7 @@ function resolveColorValue(
}
/** Load theme JSON from built-in or custom themes directory. */
function loadThemeJson(name: string): ThemeJson | null {
function loadThemeJson(name: string): ThemeJson | undefined {
// Try built-in themes first
const themesDir = getThemesDir();
const builtinPath = path.join(themesDir, `${name}.json`);
@ -130,7 +129,7 @@ function loadThemeJson(name: string): ThemeJson | null {
try {
return JSON.parse(readFileSync(builtinPath, "utf-8")) as ThemeJson;
} catch {
return null;
return undefined;
}
}
@ -141,11 +140,11 @@ function loadThemeJson(name: string): ThemeJson | null {
try {
return JSON.parse(readFileSync(customPath, "utf-8")) as ThemeJson;
} catch {
return null;
return undefined;
}
}
return null;
return undefined;
}
/** Build complete theme colors object, resolving theme JSON values against defaults. */
@ -821,110 +820,138 @@ function formatToolExecution(
return { html, bgColor };
}
function formatMessage(message: Message, toolResultsMap: Map<string, ToolResultMessage>, colors: ThemeColors): string {
function formatMessage(
message: AgentMessage,
toolResultsMap: Map<string, ToolResultMessage>,
colors: ThemeColors,
): string {
let html = "";
const timestamp = (message as { timestamp?: number }).timestamp;
const timestampHtml = timestamp ? `<div class="message-timestamp">${formatTimestamp(timestamp)}</div>` : "";
// Handle bash execution messages (user-executed via ! command)
if (isBashExecutionMessage(message)) {
const bashMsg = message as unknown as BashExecutionMessage;
const isError = bashMsg.cancelled || (bashMsg.exitCode !== 0 && bashMsg.exitCode !== null);
switch (message.role) {
case "bashExecution": {
const isError =
message.cancelled ||
(message.exitCode !== 0 && message.exitCode !== null && message.exitCode !== undefined);
html += `<div class="tool-execution user-bash${isError ? " user-bash-error" : ""}">`;
html += timestampHtml;
html += `<div class="tool-command">$ ${escapeHtml(bashMsg.command)}</div>`;
html += `<div class="tool-execution user-bash${isError ? " user-bash-error" : ""}">`;
html += timestampHtml;
html += `<div class="tool-command">$ ${escapeHtml(message.command)}</div>`;
if (bashMsg.output) {
const lines = bashMsg.output.split("\n");
html += formatExpandableOutput(lines, 10);
if (message.output) {
const lines = message.output.split("\n");
html += formatExpandableOutput(lines, 10);
}
if (message.cancelled) {
html += `<div class="bash-status warning">(cancelled)</div>`;
} else if (message.exitCode !== 0 && message.exitCode !== null && message.exitCode !== undefined) {
html += `<div class="bash-status error">(exit ${message.exitCode})</div>`;
}
if (message.truncated && message.fullOutputPath) {
html += `<div class="bash-truncation warning">Output truncated. Full output: ${escapeHtml(message.fullOutputPath)}</div>`;
}
html += `</div>`;
break;
}
case "user": {
const userMsg = message as UserMessage;
let textContent = "";
const images: ImageContent[] = [];
if (bashMsg.cancelled) {
html += `<div class="bash-status warning">(cancelled)</div>`;
} else if (bashMsg.exitCode !== 0 && bashMsg.exitCode !== null) {
html += `<div class="bash-status error">(exit ${bashMsg.exitCode})</div>`;
}
if (bashMsg.truncated && bashMsg.fullOutputPath) {
html += `<div class="bash-truncation warning">Output truncated. Full output: ${escapeHtml(bashMsg.fullOutputPath)}</div>`;
}
html += `</div>`;
return html;
}
if (message.role === "user") {
const userMsg = message as UserMessage;
let textContent = "";
const images: ImageContent[] = [];
if (typeof userMsg.content === "string") {
textContent = userMsg.content;
} else {
for (const block of userMsg.content) {
if (block.type === "text") {
textContent += block.text;
} else if (block.type === "image") {
images.push(block as ImageContent);
if (typeof userMsg.content === "string") {
textContent = userMsg.content;
} else {
for (const block of userMsg.content) {
if (block.type === "text") {
textContent += block.text;
} else if (block.type === "image") {
images.push(block as ImageContent);
}
}
}
}
html += `<div class="user-message">${timestampHtml}`;
html += `<div class="user-message">${timestampHtml}`;
// Render images first
if (images.length > 0) {
html += `<div class="message-images">`;
for (const img of images) {
html += `<img src="data:${img.mimeType};base64,${img.data}" alt="User uploaded image" class="message-image" />`;
// Render images first
if (images.length > 0) {
html += `<div class="message-images">`;
for (const img of images) {
html += `<img src="data:${img.mimeType};base64,${img.data}" alt="User uploaded image" class="message-image" />`;
}
html += `</div>`;
}
// Render text as markdown (server-side)
if (textContent.trim()) {
html += `<div class="markdown-content">${renderMarkdown(textContent)}</div>`;
}
html += `</div>`;
break;
}
case "assistant": {
html += timestampHtml ? `<div class="assistant-message">${timestampHtml}` : "";
// Render text as markdown (server-side)
if (textContent.trim()) {
html += `<div class="markdown-content">${renderMarkdown(textContent)}</div>`;
}
html += `</div>`;
} else if (message.role === "assistant") {
const assistantMsg = message as AssistantMessage;
html += timestampHtml ? `<div class="assistant-message">${timestampHtml}` : "";
for (const content of assistantMsg.content) {
if (content.type === "text" && content.text.trim()) {
// Render markdown server-side
html += `<div class="assistant-text markdown-content">${renderMarkdown(content.text)}</div>`;
} else if (content.type === "thinking" && content.thinking.trim()) {
html += `<div class="thinking-text">${escapeHtml(content.thinking.trim()).replace(/\n/g, "<br>")}</div>`;
for (const content of message.content) {
if (content.type === "text" && content.text.trim()) {
// Render markdown server-side
html += `<div class="assistant-text markdown-content">${renderMarkdown(content.text)}</div>`;
} else if (content.type === "thinking" && content.thinking.trim()) {
html += `<div class="thinking-text">${escapeHtml(content.thinking.trim()).replace(/\n/g, "<br>")}</div>`;
}
}
}
for (const content of assistantMsg.content) {
if (content.type === "toolCall") {
const toolResult = toolResultsMap.get(content.id);
const { html: toolHtml, bgColor } = formatToolExecution(
content.name,
content.arguments as Record<string, unknown>,
toolResult,
colors,
);
html += `<div class="tool-execution" style="background-color: ${bgColor}">${toolHtml}</div>`;
for (const content of message.content) {
if (content.type === "toolCall") {
const toolResult = toolResultsMap.get(content.id);
const { html: toolHtml, bgColor } = formatToolExecution(
content.name,
content.arguments as Record<string, unknown>,
toolResult,
colors,
);
html += `<div class="tool-execution" style="background-color: ${bgColor}">${toolHtml}</div>`;
}
}
}
const hasToolCalls = assistantMsg.content.some((c) => c.type === "toolCall");
if (!hasToolCalls) {
if (assistantMsg.stopReason === "aborted") {
html += '<div class="error-text">Aborted</div>';
} else if (assistantMsg.stopReason === "error") {
html += `<div class="error-text">Error: ${escapeHtml(assistantMsg.errorMessage || "Unknown error")}</div>`;
const hasToolCalls = message.content.some((c) => c.type === "toolCall");
if (!hasToolCalls) {
if (message.stopReason === "aborted") {
html += '<div class="error-text">Aborted</div>';
} else if (message.stopReason === "error") {
html += `<div class="error-text">Error: ${escapeHtml(message.errorMessage || "Unknown error")}</div>`;
}
}
}
if (timestampHtml) {
html += "</div>";
if (timestampHtml) {
html += "</div>";
}
break;
}
case "toolResult":
// Tool results are rendered inline with tool calls
break;
case "hookMessage":
// Hook messages with display:true shown as info boxes
if (message.display) {
const content = typeof message.content === "string" ? message.content : JSON.stringify(message.content);
html += `<div class="hook-message">${timestampHtml}<div class="hook-type">[${escapeHtml(message.customType)}]</div><div class="markdown-content">${renderMarkdown(content)}</div></div>`;
}
break;
case "compactionSummary":
// Rendered separately via formatCompaction
break;
case "branchSummary":
// Rendered as compaction-like summary
html += `<div class="compaction-container expanded"><div class="compaction-content"><div class="compaction-summary"><div class="compaction-summary-header">Branch Summary</div><div class="compaction-summary-content">${escapeHtml(message.summary).replace(/\n/g, "<br>")}</div></div></div></div>`;
break;
default: {
// Exhaustive check
const _exhaustive: never = message;
}
}
@ -995,7 +1022,7 @@ function generateHtml(data: ParsedSessionData, filename: string, colors: ThemeCo
const lastModelInfo = lastProvider ? `${lastProvider}/${lastModel}` : lastModel;
const contextWindow = data.contextWindow || 0;
const contextPercent = contextWindow > 0 ? ((contextTokens / contextWindow) * 100).toFixed(1) : null;
const contextPercent = contextWindow > 0 ? ((contextTokens / contextWindow) * 100).toFixed(1) : undefined;
let messagesHtml = "";
for (const event of data.sessionEvents) {
@ -1343,6 +1370,9 @@ export function exportSessionToHtml(
const opts: ExportOptions = typeof options === "string" ? { outputPath: options } : options || {};
const sessionFile = sessionManager.getSessionFile();
if (!sessionFile) {
throw new Error("Cannot export in-memory session to HTML");
}
const content = readFileSync(sessionFile, "utf8");
const data = parseSessionFile(content);

View file

@ -1,39 +1,13 @@
export { discoverAndLoadHooks, type LoadedHook, type LoadHooksResult, loadHooks, type SendHandler } from "./loader.js";
export { type HookErrorListener, HookRunner } from "./runner.js";
export { wrapToolsWithHooks, wrapToolWithHooks } from "./tool-wrapper.js";
export type {
AgentEndEvent,
AgentStartEvent,
BashToolResultEvent,
CustomToolResultEvent,
EditToolResultEvent,
ExecResult,
FindToolResultEvent,
GrepToolResultEvent,
HookAPI,
HookError,
HookEvent,
HookEventContext,
HookFactory,
HookUIContext,
LsToolResultEvent,
ReadToolResultEvent,
SessionEvent,
SessionEventResult,
ToolCallEvent,
ToolCallEventResult,
ToolResultEvent,
ToolResultEventResult,
TurnEndEvent,
TurnStartEvent,
WriteToolResultEvent,
} from "./types.js";
// biome-ignore assist/source/organizeImports: biome is not smart
export {
isBashToolResult,
isEditToolResult,
isFindToolResult,
isGrepToolResult,
isLsToolResult,
isReadToolResult,
isWriteToolResult,
} from "./types.js";
discoverAndLoadHooks,
loadHooks,
type AppendEntryHandler,
type LoadedHook,
type LoadHooksResult,
type SendMessageHandler,
} from "./loader.js";
export { execCommand, HookRunner, type HookErrorListener } from "./runner.js";
export { wrapToolsWithHooks, wrapToolWithHooks } from "./tool-wrapper.js";
export type * from "./types.js";
export type { ReadonlySessionManager } from "../session-manager.js";

View file

@ -7,10 +7,11 @@ import { createRequire } from "node:module";
import * as os from "node:os";
import * as path from "node:path";
import { fileURLToPath } from "node:url";
import type { Attachment } from "@mariozechner/pi-agent-core";
import { createJiti } from "jiti";
import { getAgentDir } from "../../config.js";
import type { HookAPI, HookFactory } from "./types.js";
import type { HookMessage } from "../messages.js";
import { execCommand } from "./runner.js";
import type { ExecOptions, HookAPI, HookFactory, HookMessageRenderer, RegisteredCommand } from "./types.js";
// Create require function to resolve module paths at runtime
const require = createRequire(import.meta.url);
@ -47,9 +48,17 @@ function getAliases(): Record<string, string> {
type HandlerFn = (...args: unknown[]) => Promise<unknown>;
/**
* Send handler type for pi.send().
* Send message handler type for pi.sendMessage().
*/
export type SendHandler = (text: string, attachments?: Attachment[]) => void;
export type SendMessageHandler = <T = unknown>(
message: Pick<HookMessage<T>, "customType" | "content" | "display" | "details">,
triggerTurn?: boolean,
) => void;
/**
* Append entry handler type for pi.appendEntry().
*/
export type AppendEntryHandler = <T = unknown>(customType: string, data?: T) => void;
/**
* Registered handlers for a loaded hook.
@ -61,8 +70,14 @@ export interface LoadedHook {
resolvedPath: string;
/** Map of event type to handler functions */
handlers: Map<string, HandlerFn[]>;
/** Set the send handler for this hook's pi.send() */
setSendHandler: (handler: SendHandler) => void;
/** Map of customType to hook message renderer */
messageRenderers: Map<string, HookMessageRenderer>;
/** Map of command name to registered command */
commands: Map<string, RegisteredCommand>;
/** Set the send message handler for this hook's pi.sendMessage() */
setSendMessageHandler: (handler: SendMessageHandler) => void;
/** Set the append entry handler for this hook's pi.appendEntry() */
setAppendEntryHandler: (handler: AppendEntryHandler) => void;
}
/**
@ -110,32 +125,62 @@ function resolveHookPath(hookPath: string, cwd: string): string {
}
/**
* Create a HookAPI instance that collects handlers.
* Returns the API and a function to set the send handler later.
* Create a HookAPI instance that collects handlers, renderers, and commands.
* Returns the API, maps, and a function to set the send message handler later.
*/
function createHookAPI(handlers: Map<string, HandlerFn[]>): {
function createHookAPI(
handlers: Map<string, HandlerFn[]>,
cwd: string,
): {
api: HookAPI;
setSendHandler: (handler: SendHandler) => void;
messageRenderers: Map<string, HookMessageRenderer>;
commands: Map<string, RegisteredCommand>;
setSendMessageHandler: (handler: SendMessageHandler) => void;
setAppendEntryHandler: (handler: AppendEntryHandler) => void;
} {
let sendHandler: SendHandler = () => {
let sendMessageHandler: SendMessageHandler = () => {
// Default no-op until mode sets the handler
};
let appendEntryHandler: AppendEntryHandler = () => {
// Default no-op until mode sets the handler
};
const messageRenderers = new Map<string, HookMessageRenderer>();
const commands = new Map<string, RegisteredCommand>();
const api: HookAPI = {
// Cast to HookAPI - the implementation is more general (string event names)
// but the interface has specific overloads for type safety in hooks
const api = {
on(event: string, handler: HandlerFn): void {
const list = handlers.get(event) ?? [];
list.push(handler);
handlers.set(event, list);
},
send(text: string, attachments?: Attachment[]): void {
sendHandler(text, attachments);
sendMessage<T = unknown>(message: HookMessage<T>, triggerTurn?: boolean): void {
sendMessageHandler(message, triggerTurn);
},
appendEntry<T = unknown>(customType: string, data?: T): void {
appendEntryHandler(customType, data);
},
registerMessageRenderer<T = unknown>(customType: string, renderer: HookMessageRenderer<T>): void {
messageRenderers.set(customType, renderer as HookMessageRenderer);
},
registerCommand(name: string, options: { description?: string; handler: RegisteredCommand["handler"] }): void {
commands.set(name, { name, ...options });
},
exec(command: string, args: string[], options?: ExecOptions) {
return execCommand(command, args, options?.cwd ?? cwd, options);
},
} as HookAPI;
return {
api,
setSendHandler: (handler: SendHandler) => {
sendHandler = handler;
messageRenderers,
commands,
setSendMessageHandler: (handler: SendMessageHandler) => {
sendMessageHandler = handler;
},
setAppendEntryHandler: (handler: AppendEntryHandler) => {
appendEntryHandler = handler;
},
};
}
@ -164,13 +209,24 @@ async function loadHook(hookPath: string, cwd: string): Promise<{ hook: LoadedHo
// Create handlers map and API
const handlers = new Map<string, HandlerFn[]>();
const { api, setSendHandler } = createHookAPI(handlers);
const { api, messageRenderers, commands, setSendMessageHandler, setAppendEntryHandler } = createHookAPI(
handlers,
cwd,
);
// Call factory to register handlers
factory(api);
return {
hook: { path: hookPath, resolvedPath, handlers, setSendHandler },
hook: {
path: hookPath,
resolvedPath,
handlers,
messageRenderers,
commands,
setSendMessageHandler,
setAppendEntryHandler,
},
error: null,
};
} catch (err) {

View file

@ -2,17 +2,23 @@
* Hook runner - executes hooks and manages their lifecycle.
*/
import { spawn } from "node:child_process";
import type { LoadedHook, SendHandler } from "./loader.js";
import type { AgentMessage } from "@mariozechner/pi-agent-core";
import type { ModelRegistry } from "../model-registry.js";
import type { SessionManager } from "../session-manager.js";
import type { AppendEntryHandler, LoadedHook, SendMessageHandler } from "./loader.js";
import type {
ExecOptions,
ExecResult,
BeforeAgentStartEvent,
BeforeAgentStartEventResult,
ContextEvent,
ContextEventResult,
HookError,
HookEvent,
HookEventContext,
HookMessageRenderer,
HookUIContext,
SessionEvent,
SessionEventResult,
RegisteredCommand,
SessionBeforeCompactResult,
SessionBeforeTreeResult,
ToolCallEvent,
ToolCallEventResult,
ToolResultEventResult,
@ -28,73 +34,8 @@ const DEFAULT_TIMEOUT = 30000;
*/
export type HookErrorListener = (error: HookError) => void;
/**
* Execute a command and return stdout/stderr/code.
* Supports cancellation via AbortSignal and timeout.
*/
async function exec(command: string, args: string[], cwd: string, options?: ExecOptions): Promise<ExecResult> {
return new Promise((resolve) => {
const proc = spawn(command, args, { cwd, shell: false });
let stdout = "";
let stderr = "";
let killed = false;
let timeoutId: NodeJS.Timeout | undefined;
const killProcess = () => {
if (!killed) {
killed = true;
proc.kill("SIGTERM");
// Force kill after 5 seconds if SIGTERM doesn't work
setTimeout(() => {
if (!proc.killed) {
proc.kill("SIGKILL");
}
}, 5000);
}
};
// Handle abort signal
if (options?.signal) {
if (options.signal.aborted) {
killProcess();
} else {
options.signal.addEventListener("abort", killProcess, { once: true });
}
}
// Handle timeout
if (options?.timeout && options.timeout > 0) {
timeoutId = setTimeout(() => {
killProcess();
}, options.timeout);
}
proc.stdout?.on("data", (data) => {
stdout += data.toString();
});
proc.stderr?.on("data", (data) => {
stderr += data.toString();
});
proc.on("close", (code) => {
if (timeoutId) clearTimeout(timeoutId);
if (options?.signal) {
options.signal.removeEventListener("abort", killProcess);
}
resolve({ stdout, stderr, code: code ?? 0, killed });
});
proc.on("error", (_err) => {
if (timeoutId) clearTimeout(timeoutId);
if (options?.signal) {
options.signal.removeEventListener("abort", killProcess);
}
resolve({ stdout, stderr, code: 1, killed });
});
});
}
// Re-export execCommand for backward compatibility
export { execCommand } from "../exec.js";
/**
* Create a promise that rejects after a timeout.
@ -112,10 +53,11 @@ function createTimeout(ms: number): { promise: Promise<never>; clear: () => void
/** No-op UI context used when no UI is available */
const noOpUIContext: HookUIContext = {
select: async () => null,
select: async () => undefined,
confirm: async () => false,
input: async () => null,
input: async () => undefined,
notify: () => {},
custom: () => ({ close: () => {}, requestRender: () => {} }),
};
/**
@ -126,16 +68,24 @@ export class HookRunner {
private uiContext: HookUIContext;
private hasUI: boolean;
private cwd: string;
private sessionFile: string | null;
private sessionManager: SessionManager;
private modelRegistry: ModelRegistry;
private timeout: number;
private errorListeners: Set<HookErrorListener> = new Set();
constructor(hooks: LoadedHook[], cwd: string, timeout: number = DEFAULT_TIMEOUT) {
constructor(
hooks: LoadedHook[],
cwd: string,
sessionManager: SessionManager,
modelRegistry: ModelRegistry,
timeout: number = DEFAULT_TIMEOUT,
) {
this.hooks = hooks;
this.uiContext = noOpUIContext;
this.hasUI = false;
this.cwd = cwd;
this.sessionFile = null;
this.sessionManager = sessionManager;
this.modelRegistry = modelRegistry;
this.timeout = timeout;
}
@ -148,6 +98,20 @@ export class HookRunner {
this.hasUI = hasUI;
}
	/**
	 * Get the UI context (set by the mode via setUIContext; defaults to the no-op context).
	 * NOTE(review): never actually null — the field is initialized to noOpUIContext in the
	 * constructor, so the `| null` in the return type looks dead. TODO confirm callers and narrow.
	 */
	getUIContext(): HookUIContext | null {
		return this.uiContext;
	}
	/**
	 * Get whether a real UI is available (false until the mode sets it; see setUIContext).
	 */
	getHasUI(): boolean {
		return this.hasUI;
	}
/**
* Get the paths of all loaded hooks.
*/
@ -156,19 +120,22 @@ export class HookRunner {
}
/**
* Set the session file path.
* Set the send message handler for all hooks' pi.sendMessage().
* Call this when the mode initializes.
*/
setSessionFile(sessionFile: string | null): void {
this.sessionFile = sessionFile;
setSendMessageHandler(handler: SendMessageHandler): void {
for (const hook of this.hooks) {
hook.setSendMessageHandler(handler);
}
}
/**
* Set the send handler for all hooks' pi.send().
* Set the append entry handler for all hooks' pi.appendEntry().
* Call this when the mode initializes.
*/
setSendHandler(handler: SendHandler): void {
setAppendEntryHandler(handler: AppendEntryHandler): void {
for (const hook of this.hooks) {
hook.setSendHandler(handler);
hook.setAppendEntryHandler(handler);
}
}
@ -184,7 +151,10 @@ export class HookRunner {
/**
* Emit an error to all listeners.
*/
private emitError(error: HookError): void {
/**
* Emit an error to all error listeners.
*/
emitError(error: HookError): void {
for (const listener of this.errorListeners) {
listener(error);
}
@ -203,26 +173,89 @@ export class HookRunner {
return false;
}
/**
* Get a message renderer for the given customType.
* Returns the first renderer found across all hooks, or undefined if none.
*/
getMessageRenderer(customType: string): HookMessageRenderer | undefined {
for (const hook of this.hooks) {
const renderer = hook.messageRenderers.get(customType);
if (renderer) {
return renderer;
}
}
return undefined;
}
/**
* Get all registered commands from all hooks.
*/
getRegisteredCommands(): RegisteredCommand[] {
const commands: RegisteredCommand[] = [];
for (const hook of this.hooks) {
for (const command of hook.commands.values()) {
commands.push(command);
}
}
return commands;
}
/**
* Get a registered command by name.
* Returns the first command found across all hooks, or undefined if none.
*/
getCommand(name: string): RegisteredCommand | undefined {
for (const hook of this.hooks) {
const command = hook.commands.get(name);
if (command) {
return command;
}
}
return undefined;
}
/**
 * Create the per-emit event context handed to hook handlers.
 * Bundles command execution, UI access, and read-only session/model state.
 */
private createContext(): HookEventContext {
	return {
		// Child processes run relative to the runner's cwd
		exec: (command: string, args: string[], options?: ExecOptions) => exec(command, args, this.cwd, options),
		ui: this.uiContext,
		hasUI: this.hasUI,
		cwd: this.cwd,
		// null when running with --no-session
		sessionFile: this.sessionFile,
		// Read-only view; hooks write via pi.sendMessage()/pi.appendEntry()
		sessionManager: this.sessionManager,
		modelRegistry: this.modelRegistry,
	};
}
/**
* Emit an event to all hooks.
* Returns the result from session/tool_result events (if any handler returns one).
* Check if event type is a session "before_*" event that can be cancelled.
*/
async emit(event: HookEvent): Promise<SessionEventResult | ToolResultEventResult | undefined> {
/**
 * Check if an event type is a session "before_*" event whose handlers
 * may cancel the pending action.
 */
private isSessionBeforeEvent(
	type: string,
): type is
	| "session_before_switch"
	| "session_before_new"
	| "session_before_branch"
	| "session_before_compact"
	| "session_before_tree" {
	const cancellable: readonly string[] = [
		"session_before_switch",
		"session_before_new",
		"session_before_branch",
		"session_before_compact",
		"session_before_tree",
	];
	return cancellable.includes(type);
}
/**
* Emit an event to all hooks.
* Returns the result from session before_* / tool_result events (if any handler returns one).
*/
async emit(
event: HookEvent,
): Promise<SessionBeforeCompactResult | SessionBeforeTreeResult | ToolResultEventResult | undefined> {
const ctx = this.createContext();
let result: SessionEventResult | ToolResultEventResult | undefined;
let result: SessionBeforeCompactResult | SessionBeforeTreeResult | ToolResultEventResult | undefined;
for (const hook of this.hooks) {
const handlers = hook.handlers.get(event.type);
@ -230,11 +263,10 @@ export class HookRunner {
for (const handler of handlers) {
try {
// No timeout for before_compact events (like tool_call, they may take a while)
const isBeforeCompact = event.type === "session" && (event as SessionEvent).reason === "before_compact";
// No timeout for session_before_compact events (like tool_call, they may take a while)
let handlerResult: unknown;
if (isBeforeCompact) {
if (event.type === "session_before_compact") {
handlerResult = await handler(event, ctx);
} else {
const timeout = createTimeout(this.timeout);
@ -242,9 +274,9 @@ export class HookRunner {
timeout.clear();
}
// For session events, capture the result (for before_* cancellation)
if (event.type === "session" && handlerResult) {
result = handlerResult as SessionEventResult;
// For session before_* events, capture the result (for cancellation)
if (this.isSessionBeforeEvent(event.type) && handlerResult) {
result = handlerResult as SessionBeforeCompactResult | SessionBeforeTreeResult;
// If cancelled, stop processing further hooks
if (result.cancel) {
return result;
@ -298,4 +330,83 @@ export class HookRunner {
return result;
}
/**
 * Emit a context event to all hooks.
 * Handlers are chained - each gets the previous handler's output (if any).
 * Returns the final modified messages, or the original if no modifications.
 * Handler failures are reported via emitError and do not abort the chain.
 *
 * Note: Messages are already deep-copied by the caller (pi-ai preprocessor).
 */
async emitContext(messages: AgentMessage[]): Promise<AgentMessage[]> {
	const ctx = this.createContext();
	let currentMessages = messages;
	for (const hook of this.hooks) {
		const handlers = hook.handlers.get("context");
		if (!handlers || handlers.length === 0) continue;
		for (const handler of handlers) {
			const event: ContextEvent = { type: "context", messages: currentMessages };
			const timeout = createTimeout(this.timeout);
			try {
				const handlerResult = await Promise.race([handler(event, ctx), timeout.promise]);
				// Chain: a handler that returns messages feeds the next handler
				if (handlerResult && (handlerResult as ContextEventResult).messages) {
					currentMessages = (handlerResult as ContextEventResult).messages!;
				}
			} catch (err) {
				const message = err instanceof Error ? err.message : String(err);
				this.emitError({
					hookPath: hook.path,
					event: "context",
					error: message,
				});
			} finally {
				// Bug fix: previously clear() was only reached on success, so a
				// rejecting handler leaked a live timer until this.timeout elapsed.
				timeout.clear();
			}
		}
	}
	return currentMessages;
}
/**
 * Emit before_agent_start event to all hooks.
 * All handlers run; the first handler that returns a message wins, but later
 * handlers still execute (they may have side effects). Handler failures are
 * reported via emitError and do not abort the loop.
 *
 * @param prompt - The user's prompt text
 * @param images - Optional images attached to the prompt
 * @returns The first result containing a message, or undefined if none
 */
async emitBeforeAgentStart(
	prompt: string,
	images?: import("@mariozechner/pi-ai").ImageContent[],
): Promise<BeforeAgentStartEventResult | undefined> {
	const ctx = this.createContext();
	let result: BeforeAgentStartEventResult | undefined;
	for (const hook of this.hooks) {
		const handlers = hook.handlers.get("before_agent_start");
		if (!handlers || handlers.length === 0) continue;
		for (const handler of handlers) {
			const event: BeforeAgentStartEvent = { type: "before_agent_start", prompt, images };
			const timeout = createTimeout(this.timeout);
			try {
				const handlerResult = await Promise.race([handler(event, ctx), timeout.promise]);
				// Take the first message returned; later ones are ignored
				if (handlerResult && (handlerResult as BeforeAgentStartEventResult).message && !result) {
					result = handlerResult as BeforeAgentStartEventResult;
				}
			} catch (err) {
				const message = err instanceof Error ? err.message : String(err);
				this.emitError({
					hookPath: hook.path,
					event: "before_agent_start",
					error: message,
				});
			} finally {
				// Bug fix: previously clear() was only reached on success, so a
				// rejecting handler leaked a live timer until this.timeout elapsed.
				timeout.clear();
			}
		}
	}
	return result;
}
}

View file

@ -2,7 +2,7 @@
* Tool wrapper - wraps tools with hook callbacks for interception.
*/
import type { AgentTool, AgentToolUpdateCallback } from "@mariozechner/pi-ai";
import type { AgentTool, AgentToolUpdateCallback } from "@mariozechner/pi-agent-core";
import type { HookRunner } from "./runner.js";
import type { ToolCallEventResult, ToolResultEventResult } from "./types.js";

View file

@ -5,10 +5,37 @@
* and interact with the user via UI primitives.
*/
import type { AppMessage, Attachment } from "@mariozechner/pi-agent-core";
import type { ImageContent, Model, TextContent, ToolResultMessage } from "@mariozechner/pi-ai";
import type { CutPointResult } from "../compaction.js";
import type { CompactionEntry, SessionEntry } from "../session-manager.js";
import type { AgentMessage } from "@mariozechner/pi-agent-core";
import type { ImageContent, Message, Model, TextContent, ToolResultMessage } from "@mariozechner/pi-ai";
import type { Component } from "@mariozechner/pi-tui";
import type { Theme } from "../../modes/interactive/theme/theme.js";
import type { CompactionPreparation, CompactionResult } from "../compaction/index.js";
import type { ExecOptions, ExecResult } from "../exec.js";
import type { HookMessage } from "../messages.js";
import type { ModelRegistry } from "../model-registry.js";
import type { BranchSummaryEntry, CompactionEntry, SessionEntry, SessionManager } from "../session-manager.js";
/**
 * Read-only view of SessionManager for hooks.
 *
 * Exposes only the getter surface of SessionManager so hooks can inspect
 * session state without mutating it. Hooks should use pi.sendMessage() and
 * pi.appendEntry() for writes.
 */
export type ReadonlySessionManager = Pick<
	SessionManager,
	| "getCwd"
	| "getSessionDir"
	| "getSessionId"
	| "getSessionFile"
	| "getLeafId"
	| "getLeafEntry"
	| "getEntry"
	| "getLabel"
	| "getPath"
	| "getHeader"
	| "getEntries"
	| "getTree"
>;
import type { EditToolDetails } from "../tools/edit.js";
import type {
BashToolDetails,
FindToolDetails,
@ -17,27 +44,8 @@ import type {
ReadToolDetails,
} from "../tools/index.js";
// ============================================================================
// Execution Context
// ============================================================================
/**
* Result of executing a command via ctx.exec()
*/
export interface ExecResult {
stdout: string;
stderr: string;
code: number;
/** True if the process was killed due to signal or timeout */
killed?: boolean;
}
export interface ExecOptions {
/** AbortSignal to cancel the process */
signal?: AbortSignal;
/** Timeout in milliseconds */
timeout?: number;
}
// Re-export for backward compatibility
export type { ExecOptions, ExecResult } from "../exec.js";
/**
* UI context for hooks to request interactive UI from the harness.
@ -50,7 +58,7 @@ export interface HookUIContext {
* @param options - Array of string options
* @returns Selected option string, or null if cancelled
*/
select(title: string, options: string[]): Promise<string | null>;
select(title: string, options: string[]): Promise<string | undefined>;
/**
* Show a confirmation dialog.
@ -60,97 +68,191 @@ export interface HookUIContext {
/**
* Show a text input dialog.
* @returns User input, or null if cancelled
* @returns User input, or undefined if cancelled
*/
input(title: string, placeholder?: string): Promise<string | null>;
input(title: string, placeholder?: string): Promise<string | undefined>;
/**
* Show a notification to the user.
*/
notify(message: string, type?: "info" | "warning" | "error"): void;
/**
* Show a custom component with keyboard focus.
* The component receives keyboard input via handleInput() if implemented.
*
* @param component - Component to display (implement handleInput for keyboard, dispose for cleanup)
* @returns Object with close() to restore normal UI and requestRender() to trigger redraw
*/
custom(component: Component & { dispose?(): void }): { close: () => void; requestRender: () => void };
}
/**
* Context passed to hook event handlers.
*/
export interface HookEventContext {
/** Execute a command and return stdout/stderr/code */
exec(command: string, args: string[], options?: ExecOptions): Promise<ExecResult>;
/** UI methods for user interaction */
ui: HookUIContext;
/** Whether UI is available (false in print mode) */
hasUI: boolean;
/** Current working directory */
cwd: string;
/** Path to session file, or null if --no-session */
sessionFile: string | null;
/** Session manager (read-only) - use pi.sendMessage()/pi.appendEntry() for writes */
sessionManager: ReadonlySessionManager;
/** Model registry - use for API key resolution and model retrieval */
modelRegistry: ModelRegistry;
}
// ============================================================================
// Events
// Session Events
// ============================================================================
/**
* Base fields shared by all session events.
*/
interface SessionEventBase {
type: "session";
/** All session entries (including pre-compaction history) */
entries: SessionEntry[];
/** Current session file path, or null in --no-session mode */
sessionFile: string | null;
/** Previous session file path, or null for "start" and "new" */
previousSessionFile: string | null;
/** Fired once on initial session load. Carries no payload; use ctx.sessionManager to inspect state. */
export interface SessionStartEvent {
	type: "session_start";
}
/**
* Event data for session events.
* Discriminated union based on reason.
*
* Lifecycle:
* - start: Initial session load
* - before_switch / switch: Session switch (e.g., /resume command)
* - before_new / new: New session (e.g., /new command)
* - before_branch / branch: Session branch (e.g., /branch command)
* - before_compact / compact: Before/after context compaction
* - shutdown: Process exit (SIGINT/SIGTERM)
*
* "before_*" events fire before the action and can be cancelled via SessionEventResult.
* Other events fire after the action completes.
*/
/** Fired before switching to another session (cancellable via SessionBeforeSwitchResult). */
export interface SessionBeforeSwitchEvent {
	type: "session_before_switch";
	/** Session file we're switching to */
	targetSessionFile: string;
}
/** Fired after switching to another session */
export interface SessionSwitchEvent {
	type: "session_switch";
	/** Session file we came from, undefined if there was none */
	previousSessionFile: string | undefined;
}
/** Fired before creating a new session (cancellable via SessionBeforeNewResult). */
export interface SessionBeforeNewEvent {
	type: "session_before_new";
}
/** Fired after creating a new session */
export interface SessionNewEvent {
	type: "session_new";
}
/** Fired before branching a session (cancellable via SessionBeforeBranchResult). */
export interface SessionBeforeBranchEvent {
	type: "session_before_branch";
	/** Index of the entry in the session (SessionManager.getEntries()) to branch from */
	entryIndex: number;
}
/** Fired after branching a session */
export interface SessionBranchEvent {
	type: "session_branch";
	/** Session file we branched away from, undefined if there was none */
	previousSessionFile: string | undefined;
}
/** Fired before context compaction (cancellable or customizable via SessionBeforeCompactResult). */
export interface SessionBeforeCompactEvent {
	type: "session_before_compact";
	/** Compaction preparation with cut point, messages to summarize/keep, etc. */
	preparation: CompactionPreparation;
	/** Previous compaction entries, newest first. Use for iterative summarization. */
	previousCompactions: CompactionEntry[];
	/** Optional user-provided instructions for the summary */
	customInstructions?: string;
	/** Current conversation model */
	model: Model<any>;
	/** Abort signal - hooks should pass this to LLM calls and check it periodically */
	signal: AbortSignal;
}
/** Fired after context compaction completes */
export interface SessionCompactEvent {
	type: "session_compact";
	/** The compaction entry that was produced */
	compactionEntry: CompactionEntry;
	/** Whether the compaction entry was provided by a hook (via SessionBeforeCompactResult) */
	fromHook: boolean;
}
/** Fired on process exit (SIGINT/SIGTERM) */
export interface SessionShutdownEvent {
	type: "session_shutdown";
}
/** Preparation data for tree navigation (used by session_before_tree event) */
export interface TreePreparation {
	/** Node being switched to */
	targetId: string;
	/** Current active leaf (being abandoned), null if no current position */
	oldLeafId: string | null;
	/** Common ancestor of target and old leaf, null if no common ancestor */
	commonAncestorId: string | null;
	/** Entries to summarize (old leaf back to common ancestor or compaction) */
	entriesToSummarize: SessionEntry[];
	/** Whether the user chose to summarize the abandoned branch */
	userWantsSummary: boolean;
}
/** Fired before navigating to a different node in the session tree (cancellable via SessionBeforeTreeResult). */
export interface SessionBeforeTreeEvent {
	type: "session_before_tree";
	/** Preparation data for the navigation */
	preparation: TreePreparation;
	/** Model to use for summarization (conversation model) */
	model: Model<any>;
	/** Abort signal - honors Escape during summarization */
	signal: AbortSignal;
}
/** Fired after navigating to a different node in the session tree */
export interface SessionTreeEvent {
	type: "session_tree";
	/** The new active leaf, null if navigated to before the first entry */
	newLeafId: string | null;
	/** Previous active leaf, null if there was no position */
	oldLeafId: string | null;
	/** Branch summary entry if one was created */
	summaryEntry?: BranchSummaryEntry;
	/** Whether the summary came from a hook */
	fromHook?: boolean;
}
/** Union of all session event types */
export type SessionEvent =
| (SessionEventBase & {
reason: "start" | "switch" | "new" | "before_switch" | "before_new" | "shutdown";
})
| (SessionEventBase & {
reason: "branch" | "before_branch";
/** Index of the turn to branch from */
targetTurnIndex: number;
})
| (SessionEventBase & {
reason: "before_compact";
cutPoint: CutPointResult;
/** Summary from previous compaction, if any. Include this in your summary to preserve context. */
previousSummary?: string;
/** Messages that will be summarized and discarded */
messagesToSummarize: AppMessage[];
/** Messages that will be kept after the summary (recent turns) */
messagesToKeep: AppMessage[];
tokensBefore: number;
customInstructions?: string;
model: Model<any>;
/** Resolve API key for any model (checks settings, OAuth, env vars) */
resolveApiKey: (model: Model<any>) => Promise<string | undefined>;
/** Abort signal - hooks should pass this to LLM calls and check it periodically */
signal: AbortSignal;
})
| (SessionEventBase & {
reason: "compact";
compactionEntry: CompactionEntry;
tokensBefore: number;
/** Whether the compaction entry was provided by a hook */
fromHook: boolean;
});
| SessionStartEvent
| SessionBeforeSwitchEvent
| SessionSwitchEvent
| SessionBeforeNewEvent
| SessionNewEvent
| SessionBeforeBranchEvent
| SessionBranchEvent
| SessionBeforeCompactEvent
| SessionCompactEvent
| SessionShutdownEvent
| SessionBeforeTreeEvent
| SessionTreeEvent;
/**
 * Event data for context event.
 * Fired before each LLM call, allowing hooks to modify context non-destructively.
 * Original session messages are NOT modified - only the messages sent to the LLM are affected.
 */
export interface ContextEvent {
	type: "context";
	/** Messages about to be sent to the LLM (deep copy, safe to modify) */
	messages: AgentMessage[];
}
/**
 * Event data for before_agent_start event.
 * Fired after the user submits a prompt but before the agent loop starts.
 * Allows hooks to inject context that will be persisted and visible in the TUI
 * (via BeforeAgentStartEventResult.message).
 */
export interface BeforeAgentStartEvent {
	type: "before_agent_start";
	/** The user's prompt text */
	prompt: string;
	/** Any images attached to the prompt */
	images?: ImageContent[];
}
/**
* Event data for agent_start event.
@ -165,7 +267,7 @@ export interface AgentStartEvent {
*/
export interface AgentEndEvent {
type: "agent_end";
messages: AppMessage[];
messages: AgentMessage[];
}
/**
@ -183,7 +285,7 @@ export interface TurnStartEvent {
export interface TurnEndEvent {
type: "turn_end";
turnIndex: number;
message: AppMessage;
message: AgentMessage;
toolResults: ToolResultMessage[];
}
@ -231,7 +333,7 @@ export interface ReadToolResultEvent extends ToolResultEventBase {
/** Tool result event for edit tool */
export interface EditToolResultEvent extends ToolResultEventBase {
toolName: "edit";
details: undefined;
details: EditToolDetails | undefined;
}
/** Tool result event for write tool */
@ -307,6 +409,8 @@ export function isLsToolResult(e: ToolResultEvent): e is LsToolResultEvent {
*/
export type HookEvent =
| SessionEvent
| ContextEvent
| BeforeAgentStartEvent
| AgentStartEvent
| AgentEndEvent
| TurnStartEvent
@ -318,6 +422,15 @@ export type HookEvent =
// Event Results
// ============================================================================
/**
 * Return type for context event handlers.
 * Allows hooks to modify messages before they're sent to the LLM.
 * Returning no messages leaves the context unchanged for the next handler.
 */
export interface ContextEventResult {
	/** Modified messages to send instead of the original */
	messages?: Message[];
}
/**
* Return type for tool_call event handlers.
* Allows hooks to block tool execution.
@ -343,16 +456,68 @@ export interface ToolResultEventResult {
}
/**
* Return type for session event handlers.
* Allows hooks to cancel "before_*" actions.
* Return type for before_agent_start event handlers.
* Allows hooks to inject context before the agent runs.
*/
export interface SessionEventResult {
/** If true, cancel the pending action (switch, clear, or branch) */
export interface BeforeAgentStartEventResult {
/** Message to inject into context (persisted to session, visible in TUI) */
message?: Pick<HookMessage, "customType" | "content" | "display" | "details">;
}
/** Return type for session_before_switch handlers */
export interface SessionBeforeSwitchResult {
/** If true, cancel the switch */
cancel?: boolean;
/** If true (for before_branch only), skip restoring conversation to branch point while still creating the branched session file */
}
/** Return type for session_before_new handlers */
export interface SessionBeforeNewResult {
	/** If true, cancel creating the new session */
	cancel?: boolean;
}
/** Return type for session_before_branch handlers */
export interface SessionBeforeBranchResult {
/**
* If true, abort the branch entirely. No new session file is created,
* conversation stays unchanged.
*/
cancel?: boolean;
/**
* If true, the branch proceeds (new session file created, session state updated)
* but the in-memory conversation is NOT rewound to the branch point.
*
* Use case: git-checkpoint hook that restores code state separately.
* The hook handles state restoration itself, so it doesn't want the
* agent's conversation to be rewound (which would lose recent context).
*
* - `cancel: true` nothing happens, user stays in current session
* - `skipConversationRestore: true` branch happens, but messages stay as-is
* - neither branch happens AND messages rewind to branch point (default)
*/
skipConversationRestore?: boolean;
/** Custom compaction entry (for before_compact event) */
compactionEntry?: CompactionEntry;
}
/** Return type for session_before_compact handlers */
export interface SessionBeforeCompactResult {
	/** If true, cancel the compaction */
	cancel?: boolean;
	/** Custom compaction result - SessionManager adds id/parentId */
	compaction?: CompactionResult;
}
/** Return type for session_before_tree handlers */
export interface SessionBeforeTreeResult {
	/** If true, cancel the navigation entirely */
	cancel?: boolean;
	/**
	 * Custom summary of the abandoned branch (skips default summarizer).
	 * Only used if preparation.userWantsSummary is true.
	 */
	summary?: {
		/** Summary text */
		summary: string;
		/** Optional hook-specific metadata */
		details?: unknown;
	};
}
// ============================================================================
@ -361,29 +526,152 @@ export interface SessionEventResult {
/**
* Handler function type for each event.
* Handlers can return R, undefined, or void (bare return statements).
*/
export type HookHandler<E, R = void> = (event: E, ctx: HookEventContext) => Promise<R>;
// biome-ignore lint/suspicious/noConfusingVoidType: void allows bare return statements in handlers
export type HookHandler<E, R = undefined> = (event: E, ctx: HookEventContext) => Promise<R | void> | R | void;
/** Options passed to a HookMessageRenderer. */
export interface HookMessageRenderOptions {
	/** Whether the view is expanded */
	expanded: boolean;
}
/**
 * Renderer for hook messages.
 * Hooks register these to provide custom TUI rendering for their message types.
 * Return undefined to fall back to the default renderer.
 */
export type HookMessageRenderer<T = unknown> = (
	message: HookMessage<T>,
	options: HookMessageRenderOptions,
	theme: Theme,
) => Component | undefined;
/**
 * Context passed to hook command handlers.
 */
export interface HookCommandContext {
	/** Arguments after the command name */
	args: string;
	/** UI methods for user interaction */
	ui: HookUIContext;
	/** Whether UI is available (false in print mode) */
	hasUI: boolean;
	/** Current working directory */
	cwd: string;
	/** Session manager (read-only) - use pi.sendMessage()/pi.appendEntry() for writes */
	sessionManager: ReadonlySessionManager;
	/** Model registry for API keys */
	modelRegistry: ModelRegistry;
}
/**
 * A slash command registered by a hook via pi.registerCommand().
 */
export interface RegisteredCommand {
	/** Command name (without the leading slash) */
	name: string;
	/** One-line description shown in command listings */
	description?: string;
	/** Invoked when the user runs the command */
	handler: (ctx: HookCommandContext) => Promise<void>;
}
/**
* HookAPI passed to hook factory functions.
* Hooks use pi.on() to subscribe to events and pi.send() to inject messages.
* Hooks use pi.on() to subscribe to events and pi.sendMessage() to inject messages.
*/
export interface HookAPI {
// biome-ignore lint/suspicious/noConfusingVoidType: void allows handlers to not return anything
on(event: "session", handler: HookHandler<SessionEvent, SessionEventResult | void>): void;
// Session events
on(event: "session_start", handler: HookHandler<SessionStartEvent>): void;
on(event: "session_before_switch", handler: HookHandler<SessionBeforeSwitchEvent, SessionBeforeSwitchResult>): void;
on(event: "session_switch", handler: HookHandler<SessionSwitchEvent>): void;
on(event: "session_before_new", handler: HookHandler<SessionBeforeNewEvent, SessionBeforeNewResult>): void;
on(event: "session_new", handler: HookHandler<SessionNewEvent>): void;
on(event: "session_before_branch", handler: HookHandler<SessionBeforeBranchEvent, SessionBeforeBranchResult>): void;
on(event: "session_branch", handler: HookHandler<SessionBranchEvent>): void;
on(
event: "session_before_compact",
handler: HookHandler<SessionBeforeCompactEvent, SessionBeforeCompactResult>,
): void;
on(event: "session_compact", handler: HookHandler<SessionCompactEvent>): void;
on(event: "session_shutdown", handler: HookHandler<SessionShutdownEvent>): void;
on(event: "session_before_tree", handler: HookHandler<SessionBeforeTreeEvent, SessionBeforeTreeResult>): void;
on(event: "session_tree", handler: HookHandler<SessionTreeEvent>): void;
// Context and agent events
on(event: "context", handler: HookHandler<ContextEvent, ContextEventResult>): void;
on(event: "before_agent_start", handler: HookHandler<BeforeAgentStartEvent, BeforeAgentStartEventResult>): void;
on(event: "agent_start", handler: HookHandler<AgentStartEvent>): void;
on(event: "agent_end", handler: HookHandler<AgentEndEvent>): void;
on(event: "turn_start", handler: HookHandler<TurnStartEvent>): void;
on(event: "turn_end", handler: HookHandler<TurnEndEvent>): void;
on(event: "tool_call", handler: HookHandler<ToolCallEvent, ToolCallEventResult | undefined>): void;
on(event: "tool_result", handler: HookHandler<ToolResultEvent, ToolResultEventResult | undefined>): void;
on(event: "tool_call", handler: HookHandler<ToolCallEvent, ToolCallEventResult>): void;
on(event: "tool_result", handler: HookHandler<ToolResultEvent, ToolResultEventResult>): void;
/**
* Send a message to the agent.
* If the agent is streaming, the message is queued.
* If the agent is idle, a new agent loop is started.
* Send a custom message to the session. Creates a CustomMessageEntry that
* participates in LLM context and can be displayed in the TUI.
*
* Use this when you want the LLM to see the message content.
* For hook state that should NOT be sent to the LLM, use appendEntry() instead.
*
* @param message - The message to send
* @param message.customType - Identifier for your hook (used for filtering on reload)
* @param message.content - Message content (string or TextContent/ImageContent array)
* @param message.display - Whether to show in TUI (true = styled display, false = hidden)
* @param message.details - Optional hook-specific metadata (not sent to LLM)
* @param triggerTurn - If true and agent is idle, triggers a new LLM turn. Default: false.
* If agent is streaming, message is queued and triggerTurn is ignored.
*/
send(text: string, attachments?: Attachment[]): void;
sendMessage<T = unknown>(
message: Pick<HookMessage<T>, "customType" | "content" | "display" | "details">,
triggerTurn?: boolean,
): void;
/**
* Append a custom entry to the session for hook state persistence.
* Creates a CustomEntry that does NOT participate in LLM context.
*
* Use this to store hook-specific data that should persist across session reloads
* but should NOT be sent to the LLM. On reload, scan session entries for your
* customType to reconstruct hook state.
*
* For messages that SHOULD be sent to the LLM, use sendMessage() instead.
*
* @param customType - Identifier for your hook (used for filtering on reload)
* @param data - Hook-specific data to persist (must be JSON-serializable)
*
* @example
* // Store permission state
* pi.appendEntry("permissions", { level: "full", grantedAt: Date.now() });
*
* // On reload, reconstruct state from entries
 * pi.on("session_start", async (event, ctx) => {
 *   const entries = ctx.sessionManager.getEntries();
 *   const myEntries = entries.filter(e => e.type === "custom" && e.customType === "permissions");
 *   // Reconstruct state from myEntries...
 * });
*/
appendEntry<T = unknown>(customType: string, data?: T): void;
/**
* Register a custom renderer for CustomMessageEntry with a specific customType.
* The renderer is called when rendering the entry in the TUI.
* Return nothing to use the default renderer.
*/
registerMessageRenderer<T = unknown>(customType: string, renderer: HookMessageRenderer<T>): void;
/**
* Register a custom slash command.
* Handler receives HookCommandContext.
*/
registerCommand(name: string, options: { description?: string; handler: RegisteredCommand["handler"] }): void;
/**
* Execute a shell command and return stdout/stderr/code.
* Supports timeout and abort signal.
*/
exec(command: string, args: string[], options?: ExecOptions): Promise<ExecResult>;
}
/**

View file

@ -7,12 +7,12 @@ export {
type AgentSessionConfig,
type AgentSessionEvent,
type AgentSessionEventListener,
type CompactionResult,
type ModelCycleResult,
type PromptOptions,
type SessionStats,
} from "./agent-session.js";
export { type BashExecutorOptions, type BashResult, executeBash } from "./bash-executor.js";
export type { CompactionResult } from "./compaction/index.js";
export {
type CustomAgentTool,
type CustomToolFactory,

View file

@ -1,16 +1,27 @@
/**
* Custom message types and transformers for the coding agent.
*
* Extends the base AppMessage type with coding-agent specific message types,
* Extends the base AgentMessage type with coding-agent specific message types,
* and provides a transformer to convert them to LLM-compatible messages.
*/
import type { AppMessage } from "@mariozechner/pi-agent-core";
import type { Message } from "@mariozechner/pi-ai";
import type { AgentMessage } from "@mariozechner/pi-agent-core";
import type { ImageContent, Message, TextContent } from "@mariozechner/pi-ai";
// ============================================================================
// Custom Message Types
// ============================================================================
/** Wrapper prepended to a compaction summary when it is rendered into LLM context. */
export const COMPACTION_SUMMARY_PREFIX = `The conversation history before this point was compacted into the following summary:
<summary>
`;
/** Closing wrapper for a compaction summary. */
export const COMPACTION_SUMMARY_SUFFIX = `
</summary>`;
/** Wrapper prepended to a branch summary when it is rendered into LLM context. */
export const BRANCH_SUMMARY_PREFIX = `The following is a summary of a branch that this conversation came back from:
<summary>
`;
/** Closing wrapper for a branch summary. NOTE(review): unlike COMPACTION_SUMMARY_SUFFIX this has no leading newline — confirm intentional. */
export const BRANCH_SUMMARY_SUFFIX = `</summary>`;
/**
* Message type for bash executions via the ! command.
@ -19,35 +30,50 @@ export interface BashExecutionMessage {
role: "bashExecution";
command: string;
output: string;
exitCode: number | null;
exitCode: number | undefined;
cancelled: boolean;
truncated: boolean;
fullOutputPath?: string;
timestamp: number;
}
// Extend CustomMessages via declaration merging
/**
 * Message type for hook-injected messages via sendMessage().
 * These are custom messages that hooks can inject into the conversation.
 */
export interface HookMessage<T = unknown> {
	role: "hookMessage";
	/** Identifier for the owning hook (used for renderer lookup and filtering) */
	customType: string;
	/** Message content; converted to a user message for the LLM */
	content: string | (TextContent | ImageContent)[];
	/** Whether to show the message in the TUI */
	display: boolean;
	/** Hook-specific metadata; not included in LLM context */
	details?: T;
	/** Creation time in epoch milliseconds */
	timestamp: number;
}
/** Summary of a branch the conversation returned from. */
export interface BranchSummaryMessage {
	role: "branchSummary";
	/** Summary text */
	summary: string;
	/** Id of the node the branch came from — presumably a session entry id; confirm against SessionManager */
	fromId: string;
	/** Creation time in epoch milliseconds */
	timestamp: number;
}
/** Summary that replaces compacted conversation history. */
export interface CompactionSummaryMessage {
	role: "compactionSummary";
	/** Summary text */
	summary: string;
	/** Token count before compaction */
	tokensBefore: number;
	/** Creation time in epoch milliseconds */
	timestamp: number;
}
// Extend CustomAgentMessages via declaration merging
declare module "@mariozechner/pi-agent-core" {
interface CustomMessages {
interface CustomAgentMessages {
bashExecution: BashExecutionMessage;
hookMessage: HookMessage;
branchSummary: BranchSummaryMessage;
compactionSummary: CompactionSummaryMessage;
}
}
// ============================================================================
// Type Guards
// ============================================================================
/**
 * Type guard for BashExecutionMessage.
 * True when the message's role discriminant is "bashExecution".
 */
export function isBashExecutionMessage(msg: AppMessage | Message): msg is BashExecutionMessage {
	const role = (msg as BashExecutionMessage).role;
	return role === "bashExecution";
}
// ============================================================================
// Message Formatting
// ============================================================================
/**
* Convert a BashExecutionMessage to user message text for LLM context.
*/
@ -60,7 +86,7 @@ export function bashExecutionToText(msg: BashExecutionMessage): string {
}
if (msg.cancelled) {
text += "\n\n(command cancelled)";
} else if (msg.exitCode !== null && msg.exitCode !== 0) {
} else if (msg.exitCode !== null && msg.exitCode !== undefined && msg.exitCode !== 0) {
text += `\n\nCommand exited with code ${msg.exitCode}`;
}
if (msg.truncated && msg.fullOutputPath) {
@ -69,34 +95,95 @@ export function bashExecutionToText(msg: BashExecutionMessage): string {
return text;
}
// ============================================================================
// Message Transformer
// ============================================================================
/**
 * Build a BranchSummaryMessage from a summary, the originating node id,
 * and an ISO-8601 timestamp string (stored as epoch milliseconds).
 */
export function createBranchSummaryMessage(summary: string, fromId: string, timestamp: string): BranchSummaryMessage {
	const epochMs = Date.parse(timestamp);
	return {
		role: "branchSummary",
		summary,
		fromId,
		timestamp: epochMs,
	};
}
/**
 * Build a CompactionSummaryMessage from a summary, the pre-compaction token
 * count, and an ISO-8601 timestamp string (stored as epoch milliseconds).
 */
export function createCompactionSummaryMessage(
	summary: string,
	tokensBefore: number,
	timestamp: string,
): CompactionSummaryMessage {
	const epochMs = Date.parse(timestamp);
	return {
		role: "compactionSummary",
		summary,
		tokensBefore,
		timestamp: epochMs,
	};
}
/** Convert CustomMessageEntry to AgentMessage format */
/**
 * Convert CustomMessageEntry fields to a HookMessage.
 *
 * @param customType  Hook-defined message type used to select a renderer.
 * @param content     Plain text or mixed text/image content blocks.
 * @param display     Whether the message should be rendered in the UI.
 * @param details     Arbitrary hook-provided payload, passed through untouched.
 * @param timestamp   ISO timestamp string; stored as epoch milliseconds.
 */
export function createHookMessage(
	customType: string,
	content: string | (TextContent | ImageContent)[],
	display: boolean,
	details: unknown | undefined,
	timestamp: string,
): HookMessage {
	const millis = new Date(timestamp).getTime();
	return { role: "hookMessage", customType, content, display, details, timestamp: millis };
}
/**
* Transform AppMessages (including custom types) to LLM-compatible Messages.
* Transform AgentMessages (including custom types) to LLM-compatible Messages.
*
* This is used by:
* - Agent's messageTransformer option (for prompt calls)
* - Agent's transormToLlm option (for prompt calls and queued messages)
* - Compaction's generateSummary (for summarization)
* - Custom hooks and tools
*/
export function messageTransformer(messages: AppMessage[]): Message[] {
export function convertToLlm(messages: AgentMessage[]): Message[] {
return messages
.map((m): Message | null => {
if (isBashExecutionMessage(m)) {
// Convert bash execution to user message
return {
role: "user",
content: [{ type: "text", text: bashExecutionToText(m) }],
timestamp: m.timestamp,
};
.map((m): Message | undefined => {
switch (m.role) {
case "bashExecution":
return {
role: "user",
content: [{ type: "text", text: bashExecutionToText(m) }],
timestamp: m.timestamp,
};
case "hookMessage": {
const content = typeof m.content === "string" ? [{ type: "text" as const, text: m.content }] : m.content;
return {
role: "user",
content,
timestamp: m.timestamp,
};
}
case "branchSummary":
return {
role: "user",
content: [{ type: "text" as const, text: BRANCH_SUMMARY_PREFIX + m.summary + BRANCH_SUMMARY_SUFFIX }],
timestamp: m.timestamp,
};
case "compactionSummary":
return {
role: "user",
content: [
{ type: "text" as const, text: COMPACTION_SUMMARY_PREFIX + m.summary + COMPACTION_SUMMARY_SUFFIX },
],
timestamp: m.timestamp,
};
case "user":
case "assistant":
case "toolResult":
return m;
default:
// biome-ignore lint/correctness/noSwitchDeclarations: fine
const _exhaustiveCheck: never = m;
return undefined;
}
// Pass through standard LLM roles
if (m.role === "user" || m.role === "assistant" || m.role === "toolResult") {
return m as Message;
}
// Filter out unknown message types
return null;
})
.filter((m): m is Message => m !== null);
.filter((m) => m !== undefined);
}

View file

@ -90,11 +90,11 @@ function resolveApiKeyConfig(keyConfig: string): string | undefined {
export class ModelRegistry {
private models: Model<Api>[] = [];
private customProviderApiKeys: Map<string, string> = new Map();
private loadError: string | null = null;
private loadError: string | undefined = undefined;
constructor(
readonly authStorage: AuthStorage,
private modelsJsonPath: string | null = null,
private modelsJsonPath: string | undefined = undefined,
) {
// Set up fallback resolver for custom provider API keys
this.authStorage.setFallbackResolver((provider) => {
@ -114,14 +114,14 @@ export class ModelRegistry {
*/
refresh(): void {
this.customProviderApiKeys.clear();
this.loadError = null;
this.loadError = undefined;
this.loadModels();
}
/**
* Get any error from loading models.json (null if no error).
* Get any error from loading models.json (undefined if no error).
*/
getError(): string | null {
getError(): string | undefined {
return this.loadError;
}
@ -160,9 +160,9 @@ export class ModelRegistry {
}
}
private loadCustomModels(modelsJsonPath: string): { models: Model<Api>[]; error: string | null } {
private loadCustomModels(modelsJsonPath: string): { models: Model<Api>[]; error: string | undefined } {
if (!existsSync(modelsJsonPath)) {
return { models: [], error: null };
return { models: [], error: undefined };
}
try {
@ -186,7 +186,7 @@ export class ModelRegistry {
this.validateConfig(config);
// Parse models
return { models: this.parseModels(config), error: null };
return { models: this.parseModels(config), error: undefined };
} catch (error) {
if (error instanceof SyntaxError) {
return {
@ -294,14 +294,14 @@ export class ModelRegistry {
/**
* Find a model by provider and ID.
*/
find(provider: string, modelId: string): Model<Api> | null {
return this.models.find((m) => m.provider === provider && m.id === modelId) ?? null;
find(provider: string, modelId: string): Model<Api> | undefined {
return this.models.find((m) => m.provider === provider && m.id === modelId) ?? undefined;
}
/**
* Get API key for a model.
*/
async getApiKey(model: Model<Api>): Promise<string | null> {
async getApiKey(model: Model<Api>): Promise<string | undefined> {
return this.authStorage.getApiKey(model.provider);
}

View file

@ -44,9 +44,9 @@ function isAlias(id: string): boolean {
/**
* Try to match a pattern to a model from the available models list.
* Returns the matched model or null if no match found.
* Returns the matched model or undefined if no match found.
*/
function tryMatchModel(modelPattern: string, availableModels: Model<Api>[]): Model<Api> | null {
function tryMatchModel(modelPattern: string, availableModels: Model<Api>[]): Model<Api> | undefined {
// Check for provider/modelId format (provider is everything before the first /)
const slashIndex = modelPattern.indexOf("/");
if (slashIndex !== -1) {
@ -75,7 +75,7 @@ function tryMatchModel(modelPattern: string, availableModels: Model<Api>[]): Mod
);
if (matches.length === 0) {
return null;
return undefined;
}
// Separate into aliases and dated versions
@ -94,9 +94,9 @@ function tryMatchModel(modelPattern: string, availableModels: Model<Api>[]): Mod
}
export interface ParsedModelResult {
model: Model<Api> | null;
model: Model<Api> | undefined;
thinkingLevel: ThinkingLevel;
warning: string | null;
warning: string | undefined;
}
/**
@ -116,14 +116,14 @@ export function parseModelPattern(pattern: string, availableModels: Model<Api>[]
// Try exact match first
const exactMatch = tryMatchModel(pattern, availableModels);
if (exactMatch) {
return { model: exactMatch, thinkingLevel: "off", warning: null };
return { model: exactMatch, thinkingLevel: "off", warning: undefined };
}
// No match - try splitting on last colon if present
const lastColonIndex = pattern.lastIndexOf(":");
if (lastColonIndex === -1) {
// No colons, pattern simply doesn't match any model
return { model: null, thinkingLevel: "off", warning: null };
return { model: undefined, thinkingLevel: "off", warning: undefined };
}
const prefix = pattern.substring(0, lastColonIndex);
@ -193,9 +193,9 @@ export async function resolveModelScope(patterns: string[], modelRegistry: Model
}
export interface InitialModelResult {
model: Model<Api> | null;
model: Model<Api> | undefined;
thinkingLevel: ThinkingLevel;
fallbackMessage: string | null;
fallbackMessage: string | undefined;
}
/**
@ -227,7 +227,7 @@ export async function findInitialModel(options: {
modelRegistry,
} = options;
let model: Model<Api> | null = null;
let model: Model<Api> | undefined;
let thinkingLevel: ThinkingLevel = "off";
// 1. CLI args take priority
@ -237,7 +237,7 @@ export async function findInitialModel(options: {
console.error(chalk.red(`Model ${cliProvider}/${cliModel} not found`));
process.exit(1);
}
return { model: found, thinkingLevel: "off", fallbackMessage: null };
return { model: found, thinkingLevel: "off", fallbackMessage: undefined };
}
// 2. Use first model from scoped models (skip if continuing/resuming)
@ -245,7 +245,7 @@ export async function findInitialModel(options: {
return {
model: scopedModels[0].model,
thinkingLevel: scopedModels[0].thinkingLevel,
fallbackMessage: null,
fallbackMessage: undefined,
};
}
@ -257,7 +257,7 @@ export async function findInitialModel(options: {
if (defaultThinkingLevel) {
thinkingLevel = defaultThinkingLevel;
}
return { model, thinkingLevel, fallbackMessage: null };
return { model, thinkingLevel, fallbackMessage: undefined };
}
}
@ -270,16 +270,16 @@ export async function findInitialModel(options: {
const defaultId = defaultModelPerProvider[provider];
const match = availableModels.find((m) => m.provider === provider && m.id === defaultId);
if (match) {
return { model: match, thinkingLevel: "off", fallbackMessage: null };
return { model: match, thinkingLevel: "off", fallbackMessage: undefined };
}
}
// If no default found, use first available
return { model: availableModels[0], thinkingLevel: "off", fallbackMessage: null };
return { model: availableModels[0], thinkingLevel: "off", fallbackMessage: undefined };
}
// 5. No model found
return { model: null, thinkingLevel: "off", fallbackMessage: null };
return { model: undefined, thinkingLevel: "off", fallbackMessage: undefined };
}
/**
@ -288,10 +288,10 @@ export async function findInitialModel(options: {
export async function restoreModelFromSession(
savedProvider: string,
savedModelId: string,
currentModel: Model<Api> | null,
currentModel: Model<Api> | undefined,
shouldPrintMessages: boolean,
modelRegistry: ModelRegistry,
): Promise<{ model: Model<Api> | null; fallbackMessage: string | null }> {
): Promise<{ model: Model<Api> | undefined; fallbackMessage: string | undefined }> {
const restoredModel = modelRegistry.find(savedProvider, savedModelId);
// Check if restored model exists and has a valid API key
@ -301,7 +301,7 @@ export async function restoreModelFromSession(
if (shouldPrintMessages) {
console.log(chalk.dim(`Restored model: ${savedProvider}/${savedModelId}`));
}
return { model: restoredModel, fallbackMessage: null };
return { model: restoredModel, fallbackMessage: undefined };
}
// Model not found or no API key - fall back
@ -327,7 +327,7 @@ export async function restoreModelFromSession(
if (availableModels.length > 0) {
// Try to find a default model from known providers
let fallbackModel: Model<Api> | null = null;
let fallbackModel: Model<Api> | undefined;
for (const provider of Object.keys(defaultModelPerProvider) as KnownProvider[]) {
const defaultId = defaultModelPerProvider[provider];
const match = availableModels.find((m) => m.provider === provider && m.id === defaultId);
@ -353,5 +353,5 @@ export async function restoreModelFromSession(
}
// No models available
return { model: null, fallbackMessage: null };
return { model: undefined, fallbackMessage: undefined };
}

View file

@ -29,7 +29,7 @@
* ```
*/
import { Agent, ProviderTransport, type ThinkingLevel } from "@mariozechner/pi-agent-core";
import { Agent, type ThinkingLevel } from "@mariozechner/pi-agent-core";
import type { Model } from "@mariozechner/pi-ai";
import { join } from "path";
import { getAgentDir } from "../config.js";
@ -39,7 +39,7 @@ import { discoverAndLoadCustomTools, type LoadedCustomTool } from "./custom-tool
import type { CustomAgentTool } from "./custom-tools/types.js";
import { discoverAndLoadHooks, HookRunner, type LoadedHook, wrapToolsWithHooks } from "./hooks/index.js";
import type { HookFactory } from "./hooks/types.js";
import { messageTransformer } from "./messages.js";
import { convertToLlm } from "./messages.js";
import { ModelRegistry } from "./model-registry.js";
import { SessionManager } from "./session-manager.js";
import { type Settings, SettingsManager, type SkillsSettings } from "./settings-manager.js";
@ -340,7 +340,10 @@ function createFactoryFromLoadedHook(loaded: LoadedHook): HookFactory {
function createLoadedHooksFromDefinitions(definitions: Array<{ path?: string; factory: HookFactory }>): LoadedHook[] {
return definitions.map((def) => {
const handlers = new Map<string, Array<(...args: unknown[]) => Promise<unknown>>>();
let sendHandler: (text: string, attachments?: any[]) => void = () => {};
const messageRenderers = new Map<string, any>();
const commands = new Map<string, any>();
let sendMessageHandler: (message: any, triggerTurn?: boolean) => void = () => {};
let appendEntryHandler: (customType: string, data?: any) => void = () => {};
const api = {
on: (event: string, handler: (...args: unknown[]) => Promise<unknown>) => {
@ -348,8 +351,17 @@ function createLoadedHooksFromDefinitions(definitions: Array<{ path?: string; fa
list.push(handler);
handlers.set(event, list);
},
send: (text: string, attachments?: any[]) => {
sendHandler(text, attachments);
sendMessage: (message: any, triggerTurn?: boolean) => {
sendMessageHandler(message, triggerTurn);
},
appendEntry: (customType: string, data?: any) => {
appendEntryHandler(customType, data);
},
registerMessageRenderer: (customType: string, renderer: any) => {
messageRenderers.set(customType, renderer);
},
registerCommand: (name: string, options: any) => {
commands.set(name, { name, ...options });
},
};
@ -359,8 +371,13 @@ function createLoadedHooksFromDefinitions(definitions: Array<{ path?: string; fa
path: def.path ?? "<inline>",
resolvedPath: def.path ?? "<inline>",
handlers,
setSendHandler: (handler: (text: string, attachments?: any[]) => void) => {
sendHandler = handler;
messageRenderers,
commands,
setSendMessageHandler: (handler: (message: any, triggerTurn?: boolean) => void) => {
sendMessageHandler = handler;
},
setAppendEntryHandler: (handler: (customType: string, data?: any) => void) => {
appendEntryHandler = handler;
},
};
});
@ -513,11 +530,11 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
customToolsResult = result;
}
let hookRunner: HookRunner | null = null;
let hookRunner: HookRunner | undefined;
if (options.hooks !== undefined) {
if (options.hooks.length > 0) {
const loadedHooks = createLoadedHooksFromDefinitions(options.hooks);
hookRunner = new HookRunner(loadedHooks, cwd, settingsManager.getHookTimeout());
hookRunner = new HookRunner(loadedHooks, cwd, sessionManager, modelRegistry, settingsManager.getHookTimeout());
}
} else {
// Discover hooks, merging with additional paths
@ -528,7 +545,7 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
console.error(`Failed to load hook "${path}": ${error}`);
}
if (hooks.length > 0) {
hookRunner = new HookRunner(hooks, cwd, settingsManager.getHookTimeout());
hookRunner = new HookRunner(hooks, cwd, sessionManager, modelRegistry, settingsManager.getHookTimeout());
}
}
@ -571,21 +588,24 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
thinkingLevel,
tools: allToolsArray,
},
messageTransformer,
convertToLlm,
transformContext: hookRunner
? async (messages) => {
return hookRunner.emitContext(messages);
}
: undefined,
queueMode: settingsManager.getQueueMode(),
transport: new ProviderTransport({
getApiKey: async () => {
const currentModel = agent.state.model;
if (!currentModel) {
throw new Error("No model selected");
}
const key = await modelRegistry.getApiKey(currentModel);
if (!key) {
throw new Error(`No API key found for provider "${currentModel.provider}"`);
}
return key;
},
}),
getApiKey: async () => {
const currentModel = agent.state.model;
if (!currentModel) {
throw new Error("No model selected");
}
const key = await modelRegistry.getApiKey(currentModel);
if (!key) {
throw new Error(`No API key found for provider "${currentModel.provider}"`);
}
return key;
},
});
time("createAgent");

File diff suppressed because it is too large Load diff

View file

@ -8,6 +8,10 @@ export interface CompactionSettings {
keepRecentTokens?: number; // default: 20000
}
export interface BranchSummarySettings {
reserveTokens?: number; // default: 16384 (tokens reserved for prompt + LLM response)
}
export interface RetrySettings {
enabled?: boolean; // default: true
maxRetries?: number; // default: 3
@ -38,6 +42,7 @@ export interface Settings {
queueMode?: "all" | "one-at-a-time";
theme?: string;
compaction?: CompactionSettings;
branchSummary?: BranchSummarySettings;
retry?: RetrySettings;
hideThinkingBlock?: boolean;
shellPath?: string; // Custom shell path (e.g., for Cygwin users on Windows)
@ -255,6 +260,12 @@ export class SettingsManager {
};
}
/**
 * Branch-summary settings with defaults applied.
 * `reserveTokens` falls back to 16384 when unset (tokens reserved for the
 * summarization prompt + LLM response, per BranchSummarySettings).
 */
getBranchSummarySettings(): { reserveTokens: number } {
return {
reserveTokens: this.settings.branchSummary?.reserveTokens ?? 16384,
};
}
getRetryEnabled(): boolean {
return this.settings.retry?.enabled ?? true;
}
@ -303,7 +314,7 @@ export class SettingsManager {
}
getHookPaths(): string[] {
return this.settings.hooks ?? [];
return [...(this.settings.hooks ?? [])];
}
setHookPaths(paths: string[]): void {
@ -321,7 +332,7 @@ export class SettingsManager {
}
getCustomToolPaths(): string[] {
return this.settings.customTools ?? [];
return [...(this.settings.customTools ?? [])];
}
setCustomToolPaths(paths: string[]): void {
@ -349,9 +360,9 @@ export class SettingsManager {
enableClaudeProject: this.settings.skills?.enableClaudeProject ?? true,
enablePiUser: this.settings.skills?.enablePiUser ?? true,
enablePiProject: this.settings.skills?.enablePiProject ?? true,
customDirectories: this.settings.skills?.customDirectories ?? [],
ignoredSkills: this.settings.skills?.ignoredSkills ?? [],
includeSkills: this.settings.skills?.includeSkills ?? [],
customDirectories: [...(this.settings.skills?.customDirectories ?? [])],
ignoredSkills: [...(this.settings.skills?.ignoredSkills ?? [])],
includeSkills: [...(this.settings.skills?.includeSkills ?? [])],
};
}

View file

@ -2,7 +2,7 @@ import { randomBytes } from "node:crypto";
import { createWriteStream } from "node:fs";
import { tmpdir } from "node:os";
import { join } from "node:path";
import type { AgentTool } from "@mariozechner/pi-ai";
import type { AgentTool } from "@mariozechner/pi-agent-core";
import { Type } from "@sinclair/typebox";
import { spawn } from "child_process";
import { getShellConfig, killProcessTree } from "../../utils/shell.js";

View file

@ -1,4 +1,4 @@
import type { AgentTool } from "@mariozechner/pi-ai";
import type { AgentTool } from "@mariozechner/pi-agent-core";
import { Type } from "@sinclair/typebox";
import * as Diff from "diff";
import { constants } from "fs";
@ -23,8 +23,13 @@ function restoreLineEndings(text: string, ending: "\r\n" | "\n"): string {
/**
* Generate a unified diff string with line numbers and context
* Returns both the diff string and the first changed line number (in the new file)
*/
function generateDiffString(oldContent: string, newContent: string, contextLines = 4): string {
function generateDiffString(
oldContent: string,
newContent: string,
contextLines = 4,
): { diff: string; firstChangedLine: number | undefined } {
const parts = Diff.diffLines(oldContent, newContent);
const output: string[] = [];
@ -36,6 +41,7 @@ function generateDiffString(oldContent: string, newContent: string, contextLines
let oldLineNum = 1;
let newLineNum = 1;
let lastWasChange = false;
let firstChangedLine: number | undefined;
for (let i = 0; i < parts.length; i++) {
const part = parts[i];
@ -45,6 +51,11 @@ function generateDiffString(oldContent: string, newContent: string, contextLines
}
if (part.added || part.removed) {
// Capture the first changed line (in the new file)
if (firstChangedLine === undefined) {
firstChangedLine = newLineNum;
}
// Show the change
for (const line of raw) {
if (part.added) {
@ -113,7 +124,7 @@ function generateDiffString(oldContent: string, newContent: string, contextLines
}
}
return output.join("\n");
return { diff: output.join("\n"), firstChangedLine };
}
const editSchema = Type.Object({
@ -122,6 +133,13 @@ const editSchema = Type.Object({
newText: Type.String({ description: "New text to replace the old text with" }),
});
export interface EditToolDetails {
/** Unified diff of the changes made */
diff: string;
/** Line number of the first change in the new file (for editor navigation) */
firstChangedLine?: number;
}
export function createEditTool(cwd: string): AgentTool<typeof editSchema> {
return {
name: "edit",
@ -138,7 +156,7 @@ export function createEditTool(cwd: string): AgentTool<typeof editSchema> {
return new Promise<{
content: Array<{ type: "text"; text: string }>;
details: { diff: string } | undefined;
details: EditToolDetails | undefined;
}>((resolve, reject) => {
// Check if already aborted
if (signal?.aborted) {
@ -257,6 +275,7 @@ export function createEditTool(cwd: string): AgentTool<typeof editSchema> {
signal.removeEventListener("abort", onAbort);
}
const diffResult = generateDiffString(normalizedContent, normalizedNewContent);
resolve({
content: [
{
@ -264,7 +283,7 @@ export function createEditTool(cwd: string): AgentTool<typeof editSchema> {
text: `Successfully replaced text in ${path}.`,
},
],
details: { diff: generateDiffString(normalizedContent, normalizedNewContent) },
details: { diff: diffResult.diff, firstChangedLine: diffResult.firstChangedLine },
});
} catch (error: any) {
// Clean up abort handler

View file

@ -1,4 +1,4 @@
import type { AgentTool } from "@mariozechner/pi-ai";
import type { AgentTool } from "@mariozechner/pi-agent-core";
import { Type } from "@sinclair/typebox";
import { spawnSync } from "child_process";
import { existsSync } from "fs";

View file

@ -1,5 +1,5 @@
import { createInterface } from "node:readline";
import type { AgentTool } from "@mariozechner/pi-ai";
import type { AgentTool } from "@mariozechner/pi-agent-core";
import { Type } from "@sinclair/typebox";
import { spawn } from "child_process";
import { readFileSync, type Stats, statSync } from "fs";

View file

@ -1,5 +1,3 @@
import type { AgentTool } from "@mariozechner/pi-ai";
export { type BashToolDetails, bashTool, createBashTool } from "./bash.js";
export { createEditTool, editTool } from "./edit.js";
export { createFindTool, type FindToolDetails, findTool } from "./find.js";
@ -9,6 +7,7 @@ export { createReadTool, type ReadToolDetails, readTool } from "./read.js";
export type { TruncationResult } from "./truncate.js";
export { createWriteTool, writeTool } from "./write.js";
import type { AgentTool } from "@mariozechner/pi-agent-core";
import { bashTool, createBashTool } from "./bash.js";
import { createEditTool, editTool } from "./edit.js";
import { createFindTool, findTool } from "./find.js";

View file

@ -1,4 +1,4 @@
import type { AgentTool } from "@mariozechner/pi-ai";
import type { AgentTool } from "@mariozechner/pi-agent-core";
import { Type } from "@sinclair/typebox";
import { existsSync, readdirSync, statSync } from "fs";
import nodePath from "path";

View file

@ -1,4 +1,5 @@
import type { AgentTool, ImageContent, TextContent } from "@mariozechner/pi-ai";
import type { AgentTool } from "@mariozechner/pi-agent-core";
import type { ImageContent, TextContent } from "@mariozechner/pi-ai";
import { Type } from "@sinclair/typebox";
import { constants } from "fs";
import { access, readFile } from "fs/promises";

View file

@ -1,4 +1,4 @@
import type { AgentTool } from "@mariozechner/pi-ai";
import type { AgentTool } from "@mariozechner/pi-agent-core";
import { Type } from "@sinclair/typebox";
import { mkdir, writeFile } from "fs/promises";
import { dirname } from "path";

View file

@ -4,7 +4,6 @@ export {
type AgentSessionConfig,
type AgentSessionEvent,
type AgentSessionEventListener,
type CompactionResult,
type ModelCycleResult,
type PromptOptions,
type SessionStats,
@ -13,17 +12,26 @@ export {
export { type ApiKeyCredential, type AuthCredential, AuthStorage, type OAuthCredential } from "./core/auth-storage.js";
// Compaction
export {
type BranchPreparation,
type BranchSummaryResult,
type CollectEntriesResult,
type CompactionResult,
type CutPointResult,
calculateContextTokens,
collectEntriesForBranchSummary,
compact,
DEFAULT_COMPACTION_SETTINGS,
estimateTokens,
type FileOperations,
findCutPoint,
findTurnStartIndex,
type GenerateBranchSummaryOptions,
generateBranchSummary,
generateSummary,
getLastAssistantUsage,
prepareBranchEntries,
shouldCompact,
} from "./core/compaction.js";
} from "./core/compaction/index.js";
// Custom tools
export type {
AgentToolUpdateCallback,
@ -38,31 +46,7 @@ export type {
ToolUIContext,
} from "./core/custom-tools/index.js";
export { discoverAndLoadCustomTools, loadCustomTools } from "./core/custom-tools/index.js";
export type {
AgentEndEvent,
AgentStartEvent,
BashToolResultEvent,
CustomToolResultEvent,
EditToolResultEvent,
FindToolResultEvent,
GrepToolResultEvent,
HookAPI,
HookEvent,
HookEventContext,
HookFactory,
HookUIContext,
LsToolResultEvent,
ReadToolResultEvent,
SessionEvent,
SessionEventResult,
ToolCallEvent,
ToolCallEventResult,
ToolResultEvent,
ToolResultEventResult,
TurnEndEvent,
TurnStartEvent,
WriteToolResultEvent,
} from "./core/hooks/index.js";
export type * from "./core/hooks/index.js";
// Hook system types and type guards
export {
isBashToolResult,
@ -73,7 +57,7 @@ export {
isReadToolResult,
isWriteToolResult,
} from "./core/hooks/index.js";
export { messageTransformer } from "./core/messages.js";
export { convertToLlm } from "./core/messages.js";
export { ModelRegistry } from "./core/model-registry.js";
// SDK for programmatic usage
export {
@ -107,20 +91,24 @@ export {
readOnlyTools,
} from "./core/sdk.js";
export {
type BranchSummaryEntry,
buildSessionContext,
type CompactionEntry,
createSummaryMessage,
CURRENT_SESSION_VERSION,
type CustomEntry,
type CustomMessageEntry,
type FileEntry,
getLatestCompactionEntry,
type ModelChangeEntry,
migrateSessionEntries,
parseSessionEntries,
type SessionContext as LoadedSession,
type SessionContext,
type SessionEntry,
type SessionEntryBase,
type SessionHeader,
type SessionInfo,
SessionManager,
type SessionMessageEntry,
SUMMARY_PREFIX,
SUMMARY_SUFFIX,
type ThinkingLevelChangeEntry,
} from "./core/session-manager.js";
export {

View file

@ -5,8 +5,7 @@
* createAgentSession() options. The SDK does the heavy lifting.
*/
import type { Attachment } from "@mariozechner/pi-agent-core";
import { supportsXhigh } from "@mariozechner/pi-ai";
import { type ImageContent, supportsXhigh } from "@mariozechner/pi-ai";
import chalk from "chalk";
import { existsSync } from "fs";
import { join } from "path";
@ -34,10 +33,10 @@ import { initTheme, stopThemeWatcher } from "./modes/interactive/theme/theme.js"
import { getChangelogPath, getNewEntries, parseChangelog } from "./utils/changelog.js";
import { ensureTool } from "./utils/tools-manager.js";
async function checkForNewVersion(currentVersion: string): Promise<string | null> {
async function checkForNewVersion(currentVersion: string): Promise<string | undefined> {
try {
const response = await fetch("https://registry.npmjs.org/@mariozechner/pi-coding-agent/latest");
if (!response.ok) return null;
if (!response.ok) return undefined;
const data = (await response.json()) as { version?: string };
const latestVersion = data.version;
@ -46,26 +45,26 @@ async function checkForNewVersion(currentVersion: string): Promise<string | null
return latestVersion;
}
return null;
return undefined;
} catch {
return null;
return undefined;
}
}
async function runInteractiveMode(
session: AgentSession,
version: string,
changelogMarkdown: string | null,
changelogMarkdown: string | undefined,
modelFallbackMessage: string | undefined,
modelsJsonError: string | null,
modelsJsonError: string | undefined,
migratedProviders: string[],
versionCheckPromise: Promise<string | null>,
versionCheckPromise: Promise<string | undefined>,
initialMessages: string[],
customTools: LoadedCustomTool[],
setToolUIContext: (uiContext: HookUIContext, hasUI: boolean) => void,
initialMessage?: string,
initialAttachments?: Attachment[],
fdPath: string | null = null,
initialImages?: ImageContent[],
fdPath: string | undefined = undefined,
): Promise<void> {
const mode = new InteractiveMode(session, version, changelogMarkdown, customTools, setToolUIContext, fdPath);
@ -77,7 +76,7 @@ async function runInteractiveMode(
}
});
mode.renderInitialMessages(session.state);
mode.renderInitialMessages();
if (migratedProviders.length > 0) {
mode.showWarning(`Migrated credentials to auth.json: ${migratedProviders.join(", ")}`);
@ -93,7 +92,7 @@ async function runInteractiveMode(
if (initialMessage) {
try {
await session.prompt(initialMessage, { attachments: initialAttachments });
await session.prompt(initialMessage, { images: initialImages });
} catch (error: unknown) {
const errorMessage = error instanceof Error ? error.message : "Unknown error occurred";
mode.showError(errorMessage);
@ -122,31 +121,31 @@ async function runInteractiveMode(
async function prepareInitialMessage(parsed: Args): Promise<{
initialMessage?: string;
initialAttachments?: Attachment[];
initialImages?: ImageContent[];
}> {
if (parsed.fileArgs.length === 0) {
return {};
}
const { textContent, imageAttachments } = await processFileArguments(parsed.fileArgs);
const { text, images } = await processFileArguments(parsed.fileArgs);
let initialMessage: string;
if (parsed.messages.length > 0) {
initialMessage = textContent + parsed.messages[0];
initialMessage = text + parsed.messages[0];
parsed.messages.shift();
} else {
initialMessage = textContent;
initialMessage = text;
}
return {
initialMessage,
initialAttachments: imageAttachments.length > 0 ? imageAttachments : undefined,
initialImages: images.length > 0 ? images : undefined,
};
}
function getChangelogForDisplay(parsed: Args, settingsManager: SettingsManager): string | null {
function getChangelogForDisplay(parsed: Args, settingsManager: SettingsManager): string | undefined {
if (parsed.continue || parsed.resume) {
return null;
return undefined;
}
const lastVersion = settingsManager.getLastChangelogVersion();
@ -166,10 +165,10 @@ function getChangelogForDisplay(parsed: Args, settingsManager: SettingsManager):
}
}
return null;
return undefined;
}
function createSessionManager(parsed: Args, cwd: string): SessionManager | null {
function createSessionManager(parsed: Args, cwd: string): SessionManager | undefined {
if (parsed.noSession) {
return SessionManager.inMemory();
}
@ -184,8 +183,8 @@ function createSessionManager(parsed: Args, cwd: string): SessionManager | null
if (parsed.sessionDir) {
return SessionManager.create(cwd, parsed.sessionDir);
}
// Default case (new session) returns null, SDK will create one
return null;
// Default case (new session) returns undefined, SDK will create one
return undefined;
}
/** Discover SYSTEM.md file if no CLI system prompt was provided */
@ -208,7 +207,7 @@ function discoverSystemPromptFile(): string | undefined {
function buildSessionOptions(
parsed: Args,
scopedModels: ScopedModel[],
sessionManager: SessionManager | null,
sessionManager: SessionManager | undefined,
modelRegistry: ModelRegistry,
): CreateAgentSessionOptions {
const options: CreateAgentSessionOptions = {};
@ -330,7 +329,7 @@ export async function main(args: string[]) {
}
const cwd = process.cwd();
const { initialMessage, initialAttachments } = await prepareInitialMessage(parsed);
const { initialMessage, initialImages } = await prepareInitialMessage(parsed);
time("prepareInitialMessage");
const isInteractive = !parsed.print && parsed.mode === undefined;
const mode = parsed.mode || "text";
@ -409,7 +408,7 @@ export async function main(args: string[]) {
if (mode === "rpc") {
await runRpcMode(session);
} else if (isInteractive) {
const versionCheckPromise = checkForNewVersion(VERSION).catch(() => null);
const versionCheckPromise = checkForNewVersion(VERSION).catch(() => undefined);
const changelogMarkdown = getChangelogForDisplay(parsed, settingsManager);
if (scopedModels.length > 0) {
@ -438,11 +437,11 @@ export async function main(args: string[]) {
customToolsResult.tools,
customToolsResult.setUIContext,
initialMessage,
initialAttachments,
initialImages,
fdPath,
);
} else {
await runPrintMode(session, mode, parsed.messages, initialMessage, initialAttachments);
await runPrintMode(session, mode, parsed.messages, initialMessage, initialImages);
stopThemeWatcher();
if (process.stdout.writableLength > 0) {
await new Promise<void>((resolve) => process.stdout.once("drain", resolve));

View file

@ -21,7 +21,7 @@ export class BashExecutionComponent extends Container {
private command: string;
private outputLines: string[] = [];
private status: "running" | "complete" | "cancelled" | "error" = "running";
private exitCode: number | null = null;
private exitCode: number | undefined = undefined;
private loader: Loader;
private truncationResult?: TruncationResult;
private fullOutputPath?: string;
@ -90,13 +90,17 @@ export class BashExecutionComponent extends Container {
}
setComplete(
exitCode: number | null,
exitCode: number | undefined,
cancelled: boolean,
truncationResult?: TruncationResult,
fullOutputPath?: string,
): void {
this.exitCode = exitCode;
this.status = cancelled ? "cancelled" : exitCode !== 0 && exitCode !== null ? "error" : "complete";
this.status = cancelled
? "cancelled"
: exitCode !== 0 && exitCode !== undefined && exitCode !== null
? "error"
: "complete";
this.truncationResult = truncationResult;
this.fullOutputPath = fullOutputPath;

View file

@ -0,0 +1,42 @@
import { Box, Markdown, Spacer, Text } from "@mariozechner/pi-tui";
import type { BranchSummaryMessage } from "../../../core/messages.js";
import { getMarkdownTheme, theme } from "../theme/theme.js";
/**
 * Renders a branch summary entry, toggling between a one-line collapsed
 * view and the full summary rendered as markdown. Shares the hook-message
 * background color so related message types look visually consistent.
 */
export class BranchSummaryMessageComponent extends Box {
	private message: BranchSummaryMessage;
	private expanded = false;

	constructor(message: BranchSummaryMessage) {
		super(1, 1, (t) => theme.bg("customMessageBg", t));
		this.message = message;
		this.refresh();
	}

	/** Switch between collapsed and expanded rendering. */
	setExpanded(expanded: boolean): void {
		this.expanded = expanded;
		this.refresh();
	}

	/** Rebuild all children to match the current expansion state. */
	private refresh(): void {
		this.clear();
		this.addChild(new Text(theme.fg("customMessageLabel", `\x1b[1m[branch]\x1b[22m`), 0, 0));
		this.addChild(new Spacer(1));
		if (!this.expanded) {
			this.addChild(new Text(theme.fg("customMessageText", "Branch summary (ctrl+o to expand)"), 0, 0));
			return;
		}
		const body = "**Branch Summary**\n\n" + this.message.summary;
		this.addChild(
			new Markdown(body, 0, 0, getMarkdownTheme(), {
				color: (text: string) => theme.fg("customMessageText", text),
			}),
		);
	}
}

View file

@ -0,0 +1,45 @@
import { Box, Markdown, Spacer, Text } from "@mariozechner/pi-tui";
import type { CompactionSummaryMessage } from "../../../core/messages.js";
import { getMarkdownTheme, theme } from "../theme/theme.js";
/**
 * Renders a compaction summary entry, toggling between a one-line collapsed
 * view (with the pre-compaction token count) and the full summary rendered
 * as markdown. Shares the hook-message background color for consistency.
 */
export class CompactionSummaryMessageComponent extends Box {
	private message: CompactionSummaryMessage;
	private expanded = false;

	constructor(message: CompactionSummaryMessage) {
		super(1, 1, (t) => theme.bg("customMessageBg", t));
		this.message = message;
		this.refresh();
	}

	/** Switch between collapsed and expanded rendering. */
	setExpanded(expanded: boolean): void {
		this.expanded = expanded;
		this.refresh();
	}

	/** Rebuild all children to match the current expansion state. */
	private refresh(): void {
		this.clear();
		const tokenStr = this.message.tokensBefore.toLocaleString();
		this.addChild(new Text(theme.fg("customMessageLabel", `\x1b[1m[compaction]\x1b[22m`), 0, 0));
		this.addChild(new Spacer(1));
		if (!this.expanded) {
			this.addChild(
				new Text(theme.fg("customMessageText", `Compacted from ${tokenStr} tokens (ctrl+o to expand)`), 0, 0),
			);
			return;
		}
		const body = `**Compacted from ${tokenStr} tokens**\n\n` + this.message.summary;
		this.addChild(
			new Markdown(body, 0, 0, getMarkdownTheme(), {
				color: (text: string) => theme.fg("customMessageText", text),
			}),
		);
	}
}

View file

@ -1,52 +0,0 @@
import { Container, Markdown, Spacer, Text } from "@mariozechner/pi-tui";
import { getMarkdownTheme, theme } from "../theme/theme.js";
/**
 * Renders a compaction indicator with collapsed/expanded state.
 * Collapsed: single warning-colored line with the pre-compaction token count.
 * Expanded: header plus the full summary as markdown, styled like a user message.
 */
export class CompactionComponent extends Container {
	private tokensBefore: number;
	private summary: string;
	private expanded = false;

	constructor(tokensBefore: number, summary: string) {
		super();
		this.tokensBefore = tokensBefore;
		this.summary = summary;
		this.refresh();
	}

	/** Switch between collapsed and expanded rendering. */
	setExpanded(expanded: boolean): void {
		this.expanded = expanded;
		this.refresh();
	}

	/** Rebuild all children to match the current expansion state. */
	private refresh(): void {
		this.clear();
		if (!this.expanded) {
			// Collapsed: one warning-colored line with the token count
			const tokenStr = this.tokensBefore.toLocaleString();
			this.addChild(
				new Text(
					theme.fg("warning", `Earlier messages compacted from ${tokenStr} tokens (ctrl+o to expand)`),
					1,
					1,
				),
			);
			return;
		}
		// Expanded: header + summary as markdown, padded like a user message
		this.addChild(new Spacer(1));
		const body = `**Context compacted from ${this.tokensBefore.toLocaleString()} tokens**\n\n` + this.summary;
		this.addChild(
			new Markdown(body, 1, 1, getMarkdownTheme(), {
				bgColor: (text: string) => theme.bg("userMessageBg", text),
				color: (text: string) => theme.fg("userMessageText", text),
			}),
		);
		this.addChild(new Spacer(1));
	}
}

View file

@ -0,0 +1,96 @@
import type { TextContent } from "@mariozechner/pi-ai";
import type { Component } from "@mariozechner/pi-tui";
import { Box, Container, Markdown, Spacer, Text } from "@mariozechner/pi-tui";
import type { HookMessageRenderer } from "../../../core/hooks/types.js";
import type { HookMessage } from "../../../core/messages.js";
import { getMarkdownTheme, theme } from "../theme/theme.js";
/**
 * Renders a custom message entry produced by hooks.
 * A hook-supplied renderer takes precedence and owns its styling; otherwise
 * a default purple box with a `[customType]` label and markdown body is shown.
 * When collapsed, the default rendering is limited to the first 5 lines.
 */
export class HookMessageComponent extends Container {
	private message: HookMessage<unknown>;
	private customRenderer?: HookMessageRenderer;
	private box: Box;
	private customComponent?: Component;
	private _expanded = false;

	constructor(message: HookMessage<unknown>, customRenderer?: HookMessageRenderer) {
		super();
		this.message = message;
		this.customRenderer = customRenderer;
		this.addChild(new Spacer(1));
		// Box with purple background, used only for the default rendering path
		this.box = new Box(1, 1, (t) => theme.bg("customMessageBg", t));
		this.rebuild();
	}

	/** Switch between collapsed and expanded rendering (no-op if unchanged). */
	setExpanded(expanded: boolean): void {
		if (this._expanded === expanded) return;
		this._expanded = expanded;
		this.rebuild();
	}

	/** Tear down the previous content and render via custom renderer or default. */
	private rebuild(): void {
		if (this.customComponent) {
			this.removeChild(this.customComponent);
			this.customComponent = undefined;
		}
		this.removeChild(this.box);
		if (this.tryCustomRenderer()) return;
		this.renderDefault();
	}

	/** Attempt the hook-supplied renderer; returns true when it produced a component. */
	private tryCustomRenderer(): boolean {
		if (!this.customRenderer) return false;
		try {
			const component = this.customRenderer(this.message, { expanded: this._expanded }, theme);
			if (!component) return false;
			// Custom renderer provides its own styled component
			this.customComponent = component;
			this.addChild(component);
			return true;
		} catch {
			// Renderer threw — fall back to default rendering
			return false;
		}
	}

	/** Default rendering: [customType] label plus (possibly truncated) markdown body. */
	private renderDefault(): void {
		this.addChild(this.box);
		this.box.clear();
		const label = theme.fg("customMessageLabel", `\x1b[1m[${this.message.customType}]\x1b[22m`);
		this.box.addChild(new Text(label, 0, 0));
		this.box.addChild(new Spacer(1));
		let text = this.extractText();
		// Limit lines when collapsed
		if (!this._expanded) {
			const lines = text.split("\n");
			if (lines.length > 5) {
				text = `${lines.slice(0, 5).join("\n")}\n...`;
			}
		}
		this.box.addChild(
			new Markdown(text, 0, 0, getMarkdownTheme(), {
				color: (text: string) => theme.fg("customMessageText", text),
			}),
		);
	}

	/** Flatten message content to plain text (text blocks joined with newlines). */
	private extractText(): string {
		if (typeof this.message.content === "string") {
			return this.message.content;
		}
		return this.message.content
			.filter((c): c is TextContent => c.type === "text")
			.map((c) => c.text)
			.join("\n");
	}
}

View file

@ -36,18 +36,18 @@ export class ModelSelectorComponent extends Container {
private allModels: ModelItem[] = [];
private filteredModels: ModelItem[] = [];
private selectedIndex: number = 0;
private currentModel: Model<any> | null;
private currentModel?: Model<any>;
private settingsManager: SettingsManager;
private modelRegistry: ModelRegistry;
private onSelectCallback: (model: Model<any>) => void;
private onCancelCallback: () => void;
private errorMessage: string | null = null;
private errorMessage?: string;
private tui: TUI;
private scopedModels: ReadonlyArray<ScopedModelItem>;
constructor(
tui: TUI,
currentModel: Model<any> | null,
currentModel: Model<any> | undefined,
settingsManager: SettingsManager,
modelRegistry: ModelRegistry,
scopedModels: ReadonlyArray<ScopedModelItem>,

View file

@ -415,10 +415,14 @@ export class ToolExecutionComponent extends Container {
} else if (this.toolName === "edit") {
const rawPath = this.args?.file_path || this.args?.path || "";
const path = shortenPath(rawPath);
text =
theme.fg("toolTitle", theme.bold("edit")) +
" " +
(path ? theme.fg("accent", path) : theme.fg("toolOutput", "..."));
// Build path display, appending :line if we have a successful result with line info
let pathDisplay = path ? theme.fg("accent", path) : theme.fg("toolOutput", "...");
if (this.result && !this.result.isError && this.result.details?.firstChangedLine) {
pathDisplay += theme.fg("warning", `:${this.result.details.firstChangedLine}`);
}
text = `${theme.fg("toolTitle", theme.bold("edit"))} ${pathDisplay}`;
if (this.result) {
if (this.result.isError) {

View file

@ -0,0 +1,866 @@
import {
type Component,
Container,
Input,
isArrowDown,
isArrowLeft,
isArrowRight,
isArrowUp,
isBackspace,
isCtrlC,
isCtrlO,
isEnter,
isEscape,
isShiftCtrlO,
Spacer,
Text,
TruncatedText,
truncateToWidth,
} from "@mariozechner/pi-tui";
import type { SessionTreeNode } from "../../../core/session-manager.js";
import { theme } from "../theme/theme.js";
import { DynamicBorder } from "./dynamic-border.js";
/** Gutter info: position (displayIndent where the connector was) and whether to draw │ there */
interface GutterInfo {
	position: number; // displayIndent level where the ancestor's connector was shown
	show: boolean; // true = draw │ (more siblings follow below), false = draw spaces
}
/** Flattened tree node, precomputed for linear navigation and rendering */
interface FlatNode {
	node: SessionTreeNode;
	/** Indentation level (each level = 3 chars) */
	indent: number;
	/** Whether to show connector (├─ or └─) - true if parent has multiple children */
	showConnector: boolean;
	/** If showConnector, true = last sibling (└─), false = not last (├─) */
	isLast: boolean;
	/** Gutter info for each ancestor branch point */
	gutters: GutterInfo[];
	/** True if this node is a root under a virtual branching root (multiple roots) */
	isVirtualRootChild: boolean;
}
/** Filter mode for tree display; cycled with ctrl+o / shift+ctrl+o */
type FilterMode = "default" | "no-tools" | "user-only" | "labeled-only" | "all";
/**
 * Tree list component with selection and ASCII art visualization
 */
/** Tool call info cached by tool-call id, for labeling toolResult entries */
interface ToolCallInfo {
	name: string;
	arguments: Record<string, unknown>;
}
/**
 * Scrollable, searchable, filterable list over a flattened session tree.
 * Handles keyboard navigation, filter cycling, incremental text search,
 * label editing dispatch, and ASCII-art rendering of branch structure.
 *
 * Fixes over the previous revision:
 * - render() no longer shadows its outer loop variable `i`.
 * - arrow-key navigation on an empty filtered list can no longer drive
 *   selectedIndex to -1 (which previously persisted across filter changes).
 * - formatToolCall() serializes args once in the default case.
 */
class TreeList implements Component {
	private flatNodes: FlatNode[] = [];
	private filteredNodes: FlatNode[] = [];
	private selectedIndex = 0;
	private currentLeafId: string | null;
	private maxVisibleLines: number;
	private filterMode: FilterMode = "default";
	private searchQuery = "";
	private toolCallMap: Map<string, ToolCallInfo> = new Map();
	private multipleRoots = false;
	private activePathIds: Set<string> = new Set();
	public onSelect?: (entryId: string) => void;
	public onCancel?: () => void;
	public onLabelEdit?: (entryId: string, currentLabel: string | undefined) => void;
	constructor(tree: SessionTreeNode[], currentLeafId: string | null, maxVisibleLines: number) {
		this.currentLeafId = currentLeafId;
		this.maxVisibleLines = maxVisibleLines;
		this.multipleRoots = tree.length > 1;
		this.flatNodes = this.flattenTree(tree);
		this.buildActivePath();
		this.applyFilter();
		// Start with current leaf selected
		const leafIndex = this.filteredNodes.findIndex((n) => n.node.entry.id === currentLeafId);
		if (leafIndex !== -1) {
			this.selectedIndex = leafIndex;
		} else {
			this.selectedIndex = Math.max(0, this.filteredNodes.length - 1);
		}
	}
	/** Build the set of entry IDs on the path from root to current leaf */
	private buildActivePath(): void {
		this.activePathIds.clear();
		if (!this.currentLeafId) return;
		// Build a map of id -> entry for parent lookup
		const entryMap = new Map<string, FlatNode>();
		for (const flatNode of this.flatNodes) {
			entryMap.set(flatNode.node.entry.id, flatNode);
		}
		// Walk from leaf to root
		let currentId: string | null = this.currentLeafId;
		while (currentId) {
			this.activePathIds.add(currentId);
			const node = entryMap.get(currentId);
			if (!node) break;
			currentId = node.node.entry.parentId ?? null;
		}
	}
	/**
	 * Flatten the tree into a pre-order list with indentation, connector,
	 * and gutter info. Branches containing the active leaf are ordered first.
	 * Also caches tool-call info from assistant messages for toolResult display.
	 */
	private flattenTree(roots: SessionTreeNode[]): FlatNode[] {
		const result: FlatNode[] = [];
		this.toolCallMap.clear();
		// Indentation rules:
		// - At indent 0: stay at 0 unless parent has >1 children (then +1)
		// - At indent 1: children always go to indent 2 (visual grouping of subtree)
		// - At indent 2+: stay flat for single-child chains, +1 only if parent branches
		// Stack items: [node, indent, justBranched, showConnector, isLast, gutters, isVirtualRootChild]
		type StackItem = [SessionTreeNode, number, boolean, boolean, boolean, GutterInfo[], boolean];
		const stack: StackItem[] = [];
		// Determine which subtrees contain the active leaf (to sort current branch first)
		// Use iterative post-order traversal to avoid stack overflow
		const containsActive = new Map<SessionTreeNode, boolean>();
		const leafId = this.currentLeafId;
		{
			// Build list in pre-order, then process in reverse for post-order effect
			const allNodes: SessionTreeNode[] = [];
			const preOrderStack: SessionTreeNode[] = [...roots];
			while (preOrderStack.length > 0) {
				const node = preOrderStack.pop()!;
				allNodes.push(node);
				// Push children in reverse so they're processed left-to-right
				for (let i = node.children.length - 1; i >= 0; i--) {
					preOrderStack.push(node.children[i]);
				}
			}
			// Process in reverse (post-order): children before parents
			for (let i = allNodes.length - 1; i >= 0; i--) {
				const node = allNodes[i];
				let has = leafId !== null && node.entry.id === leafId;
				for (const child of node.children) {
					if (containsActive.get(child)) {
						has = true;
					}
				}
				containsActive.set(node, has);
			}
		}
		// Add roots in reverse order, prioritizing the one containing the active leaf
		// If multiple roots, treat them as children of a virtual root that branches
		const multipleRoots = roots.length > 1;
		const orderedRoots = [...roots].sort((a, b) => Number(containsActive.get(b)) - Number(containsActive.get(a)));
		for (let i = orderedRoots.length - 1; i >= 0; i--) {
			const isLast = i === orderedRoots.length - 1;
			stack.push([orderedRoots[i], multipleRoots ? 1 : 0, multipleRoots, multipleRoots, isLast, [], multipleRoots]);
		}
		while (stack.length > 0) {
			const [node, indent, justBranched, showConnector, isLast, gutters, isVirtualRootChild] = stack.pop()!;
			// Extract tool calls from assistant messages for later lookup
			const entry = node.entry;
			if (entry.type === "message" && entry.message.role === "assistant") {
				const content = (entry.message as { content?: unknown }).content;
				if (Array.isArray(content)) {
					for (const block of content) {
						if (typeof block === "object" && block !== null && "type" in block && block.type === "toolCall") {
							const tc = block as { id: string; name: string; arguments: Record<string, unknown> };
							this.toolCallMap.set(tc.id, { name: tc.name, arguments: tc.arguments });
						}
					}
				}
			}
			result.push({ node, indent, showConnector, isLast, gutters, isVirtualRootChild });
			const children = node.children;
			const multipleChildren = children.length > 1;
			// Order children so the branch containing the active leaf comes first
			const orderedChildren = (() => {
				const prioritized: SessionTreeNode[] = [];
				const rest: SessionTreeNode[] = [];
				for (const child of children) {
					if (containsActive.get(child)) {
						prioritized.push(child);
					} else {
						rest.push(child);
					}
				}
				return [...prioritized, ...rest];
			})();
			// Calculate child indent
			let childIndent: number;
			if (multipleChildren) {
				// Parent branches: children get +1
				childIndent = indent + 1;
			} else if (justBranched && indent > 0) {
				// First generation after a branch: +1 for visual grouping
				childIndent = indent + 1;
			} else {
				// Single-child chain: stay flat
				childIndent = indent;
			}
			// Build gutters for children
			// If this node showed a connector, add a gutter entry for descendants
			// Only add gutter if connector is actually displayed (not suppressed for virtual root children)
			const connectorDisplayed = showConnector && !isVirtualRootChild;
			// When connector is displayed, add a gutter entry at the connector's position
			// Connector is at position (displayIndent - 1), so gutter should be there too
			const currentDisplayIndent = this.multipleRoots ? Math.max(0, indent - 1) : indent;
			const connectorPosition = Math.max(0, currentDisplayIndent - 1);
			const childGutters: GutterInfo[] = connectorDisplayed
				? [...gutters, { position: connectorPosition, show: !isLast }]
				: gutters;
			// Add children in reverse order
			for (let i = orderedChildren.length - 1; i >= 0; i--) {
				const childIsLast = i === orderedChildren.length - 1;
				stack.push([
					orderedChildren[i],
					childIndent,
					multipleChildren,
					multipleChildren,
					childIsLast,
					childGutters,
					false,
				]);
			}
		}
		return result;
	}
	/**
	 * Recompute filteredNodes from the current filter mode and search query,
	 * preserving the cursor position on the previously-selected node when possible.
	 */
	private applyFilter(): void {
		// Remember currently selected node to preserve cursor position
		const previouslySelectedId = this.filteredNodes[this.selectedIndex]?.node.entry.id;
		const searchTokens = this.searchQuery.toLowerCase().split(/\s+/).filter(Boolean);
		this.filteredNodes = this.flatNodes.filter((flatNode) => {
			const entry = flatNode.node.entry;
			const isCurrentLeaf = entry.id === this.currentLeafId;
			// Skip assistant messages with only tool calls (no text) unless error/aborted
			// Always show current leaf so active position is visible
			if (entry.type === "message" && entry.message.role === "assistant" && !isCurrentLeaf) {
				const msg = entry.message as { stopReason?: string; content?: unknown };
				const hasText = this.hasTextContent(msg.content);
				const isErrorOrAborted = msg.stopReason && msg.stopReason !== "stop" && msg.stopReason !== "toolUse";
				// Only hide if no text AND not an error/aborted message
				if (!hasText && !isErrorOrAborted) {
					return false;
				}
			}
			// Apply filter mode
			let passesFilter = true;
			// Entry types hidden in default view (settings/bookkeeping)
			const isSettingsEntry =
				entry.type === "label" ||
				entry.type === "custom" ||
				entry.type === "model_change" ||
				entry.type === "thinking_level_change";
			switch (this.filterMode) {
				case "user-only":
					// Just user messages
					passesFilter = entry.type === "message" && entry.message.role === "user";
					break;
				case "no-tools":
					// Default minus tool results
					passesFilter = !isSettingsEntry && !(entry.type === "message" && entry.message.role === "toolResult");
					break;
				case "labeled-only":
					// Just labeled entries
					passesFilter = flatNode.node.label !== undefined;
					break;
				case "all":
					// Show everything
					passesFilter = true;
					break;
				default:
					// Default mode: hide settings/bookkeeping entries
					passesFilter = !isSettingsEntry;
					break;
			}
			if (!passesFilter) return false;
			// Apply search filter
			if (searchTokens.length > 0) {
				const nodeText = this.getSearchableText(flatNode.node).toLowerCase();
				return searchTokens.every((token) => nodeText.includes(token));
			}
			return true;
		});
		// Try to preserve cursor on the same node after filtering
		if (previouslySelectedId) {
			const newIndex = this.filteredNodes.findIndex((n) => n.node.entry.id === previouslySelectedId);
			if (newIndex !== -1) {
				this.selectedIndex = newIndex;
				return;
			}
		}
		// Fall back: clamp index if out of bounds (either direction)
		if (this.selectedIndex < 0 || this.selectedIndex >= this.filteredNodes.length) {
			this.selectedIndex = Math.max(0, this.filteredNodes.length - 1);
		}
	}
	/** Get searchable text content from a node */
	private getSearchableText(node: SessionTreeNode): string {
		const entry = node.entry;
		const parts: string[] = [];
		if (node.label) {
			parts.push(node.label);
		}
		switch (entry.type) {
			case "message": {
				const msg = entry.message;
				parts.push(msg.role);
				if ("content" in msg && msg.content) {
					parts.push(this.extractContent(msg.content));
				}
				if (msg.role === "bashExecution") {
					const bashMsg = msg as { command?: string };
					if (bashMsg.command) parts.push(bashMsg.command);
				}
				break;
			}
			case "custom_message": {
				parts.push(entry.customType);
				if (typeof entry.content === "string") {
					parts.push(entry.content);
				} else {
					parts.push(this.extractContent(entry.content));
				}
				break;
			}
			case "compaction":
				parts.push("compaction");
				break;
			case "branch_summary":
				parts.push("branch summary", entry.summary);
				break;
			case "model_change":
				parts.push("model", entry.modelId);
				break;
			case "thinking_level_change":
				parts.push("thinking", entry.thinkingLevel);
				break;
			case "custom":
				parts.push("custom", entry.customType);
				break;
			case "label":
				parts.push("label", entry.label ?? "");
				break;
		}
		return parts.join(" ");
	}
	invalidate(): void {}
	/** Current incremental search query (empty string when inactive). */
	getSearchQuery(): string {
		return this.searchQuery;
	}
	/** Node under the cursor, or undefined when the filtered list is empty. */
	getSelectedNode(): SessionTreeNode | undefined {
		return this.filteredNodes[this.selectedIndex]?.node;
	}
	/** Update the in-memory label of a node (after a label edit elsewhere). */
	updateNodeLabel(entryId: string, label: string | undefined): void {
		for (const flatNode of this.flatNodes) {
			if (flatNode.node.entry.id === entryId) {
				flatNode.node.label = label;
				break;
			}
		}
	}
	/** Short suffix describing the active filter mode for the status line. */
	private getFilterLabel(): string {
		switch (this.filterMode) {
			case "no-tools":
				return " [no-tools]";
			case "user-only":
				return " [user]";
			case "labeled-only":
				return " [labeled]";
			case "all":
				return " [all]";
			default:
				return "";
		}
	}
	/**
	 * Render the visible window of the filtered list plus a status line.
	 * The window is centered on the cursor where possible.
	 */
	render(width: number): string[] {
		const lines: string[] = [];
		if (this.filteredNodes.length === 0) {
			lines.push(truncateToWidth(theme.fg("muted", " No entries found"), width));
			lines.push(truncateToWidth(theme.fg("muted", ` (0/0)${this.getFilterLabel()}`), width));
			return lines;
		}
		const startIndex = Math.max(
			0,
			Math.min(
				this.selectedIndex - Math.floor(this.maxVisibleLines / 2),
				this.filteredNodes.length - this.maxVisibleLines,
			),
		);
		const endIndex = Math.min(startIndex + this.maxVisibleLines, this.filteredNodes.length);
		for (let i = startIndex; i < endIndex; i++) {
			const flatNode = this.filteredNodes[i];
			const entry = flatNode.node.entry;
			const isSelected = i === this.selectedIndex;
			// Build line: cursor + prefix + path marker + label + content
			const cursor = isSelected ? theme.fg("accent", " ") : " ";
			// If multiple roots, shift display (roots at 0, not 1)
			const displayIndent = this.multipleRoots ? Math.max(0, flatNode.indent - 1) : flatNode.indent;
			// Build prefix with gutters at their correct positions
			// Each gutter has a position (displayIndent where its connector was shown)
			const connector =
				flatNode.showConnector && !flatNode.isVirtualRootChild ? (flatNode.isLast ? "└─ " : "├─ ") : "";
			const connectorPosition = connector ? displayIndent - 1 : -1;
			// Build prefix char by char, placing gutters and connector at their positions.
			// Note: use `col`, not `i`, to avoid shadowing the outer row loop variable.
			const totalChars = displayIndent * 3;
			const prefixChars: string[] = [];
			for (let col = 0; col < totalChars; col++) {
				const level = Math.floor(col / 3);
				const posInLevel = col % 3;
				// Check if there's a gutter at this level
				const gutter = flatNode.gutters.find((g) => g.position === level);
				if (gutter) {
					if (posInLevel === 0) {
						prefixChars.push(gutter.show ? "│" : " ");
					} else {
						prefixChars.push(" ");
					}
				} else if (connector && level === connectorPosition) {
					// Connector at this level
					if (posInLevel === 0) {
						prefixChars.push(flatNode.isLast ? "└" : "├");
					} else if (posInLevel === 1) {
						prefixChars.push("─");
					} else {
						prefixChars.push(" ");
					}
				} else {
					prefixChars.push(" ");
				}
			}
			const prefix = prefixChars.join("");
			// Active path marker - shown right before the entry text
			const isOnActivePath = this.activePathIds.has(entry.id);
			const pathMarker = isOnActivePath ? theme.fg("accent", "• ") : "";
			const label = flatNode.node.label ? theme.fg("warning", `[${flatNode.node.label}] `) : "";
			const content = this.getEntryDisplayText(flatNode.node, isSelected);
			let line = cursor + theme.fg("dim", prefix) + pathMarker + label + content;
			if (isSelected) {
				line = theme.bg("selectedBg", line);
			}
			lines.push(truncateToWidth(line, width));
		}
		lines.push(
			truncateToWidth(
				theme.fg("muted", ` (${this.selectedIndex + 1}/${this.filteredNodes.length})${this.getFilterLabel()}`),
				width,
			),
		);
		return lines;
	}
	/** One-line display text for an entry, colored by role/type. */
	private getEntryDisplayText(node: SessionTreeNode, isSelected: boolean): string {
		const entry = node.entry;
		let result: string;
		const normalize = (s: string) => s.replace(/[\n\t]/g, " ").trim();
		switch (entry.type) {
			case "message": {
				const msg = entry.message;
				const role = msg.role;
				if (role === "user") {
					const msgWithContent = msg as { content?: unknown };
					const content = normalize(this.extractContent(msgWithContent.content));
					result = theme.fg("accent", "user: ") + content;
				} else if (role === "assistant") {
					const msgWithContent = msg as { content?: unknown; stopReason?: string; errorMessage?: string };
					const textContent = normalize(this.extractContent(msgWithContent.content));
					if (textContent) {
						result = theme.fg("success", "assistant: ") + textContent;
					} else if (msgWithContent.stopReason === "aborted") {
						result = theme.fg("success", "assistant: ") + theme.fg("muted", "(aborted)");
					} else if (msgWithContent.errorMessage) {
						const errMsg = normalize(msgWithContent.errorMessage).slice(0, 80);
						result = theme.fg("success", "assistant: ") + theme.fg("error", errMsg);
					} else {
						result = theme.fg("success", "assistant: ") + theme.fg("muted", "(no content)");
					}
				} else if (role === "toolResult") {
					const toolMsg = msg as { toolCallId?: string; toolName?: string };
					const toolCall = toolMsg.toolCallId ? this.toolCallMap.get(toolMsg.toolCallId) : undefined;
					if (toolCall) {
						result = theme.fg("muted", this.formatToolCall(toolCall.name, toolCall.arguments));
					} else {
						result = theme.fg("muted", `[${toolMsg.toolName ?? "tool"}]`);
					}
				} else if (role === "bashExecution") {
					const bashMsg = msg as { command?: string };
					result = theme.fg("dim", `[bash]: ${normalize(bashMsg.command ?? "")}`);
				} else {
					result = theme.fg("dim", `[${role}]`);
				}
				break;
			}
			case "custom_message": {
				const content =
					typeof entry.content === "string"
						? entry.content
						: entry.content
								.filter((c): c is { type: "text"; text: string } => c.type === "text")
								.map((c) => c.text)
								.join("");
				result = theme.fg("customMessageLabel", `[${entry.customType}]: `) + normalize(content);
				break;
			}
			case "compaction": {
				const tokens = Math.round(entry.tokensBefore / 1000);
				result = theme.fg("borderAccent", `[compaction: ${tokens}k tokens]`);
				break;
			}
			case "branch_summary":
				result = theme.fg("warning", `[branch summary]: `) + normalize(entry.summary);
				break;
			case "model_change":
				result = theme.fg("dim", `[model: ${entry.modelId}]`);
				break;
			case "thinking_level_change":
				result = theme.fg("dim", `[thinking: ${entry.thinkingLevel}]`);
				break;
			case "custom":
				result = theme.fg("dim", `[custom: ${entry.customType}]`);
				break;
			case "label":
				result = theme.fg("dim", `[label: ${entry.label ?? "(cleared)"}]`);
				break;
			default:
				result = "";
		}
		return isSelected ? theme.bold(result) : result;
	}
	/** Concatenate text blocks from message content, capped at 200 chars. */
	private extractContent(content: unknown): string {
		const maxLen = 200;
		if (typeof content === "string") return content.slice(0, maxLen);
		if (Array.isArray(content)) {
			let result = "";
			for (const c of content) {
				if (typeof c === "object" && c !== null && "type" in c && c.type === "text") {
					result += (c as { text: string }).text;
					if (result.length >= maxLen) return result.slice(0, maxLen);
				}
			}
			return result;
		}
		return "";
	}
	/** True if the content contains any non-whitespace text block. */
	private hasTextContent(content: unknown): boolean {
		if (typeof content === "string") return content.trim().length > 0;
		if (Array.isArray(content)) {
			for (const c of content) {
				if (typeof c === "object" && c !== null && "type" in c && c.type === "text") {
					const text = (c as { text?: string }).text;
					if (text && text.trim().length > 0) return true;
				}
			}
		}
		return false;
	}
	/** Compact `[tool: args]` display for known tools, JSON fallback otherwise. */
	private formatToolCall(name: string, args: Record<string, unknown>): string {
		const shortenPath = (p: string): string => {
			const home = process.env.HOME || process.env.USERPROFILE || "";
			if (home && p.startsWith(home)) return `~${p.slice(home.length)}`;
			return p;
		};
		switch (name) {
			case "read": {
				const path = shortenPath(String(args.path || args.file_path || ""));
				const offset = args.offset as number | undefined;
				const limit = args.limit as number | undefined;
				let display = path;
				if (offset !== undefined || limit !== undefined) {
					const start = offset ?? 1;
					const end = limit !== undefined ? start + limit - 1 : "";
					display += `:${start}${end ? `-${end}` : ""}`;
				}
				return `[read: ${display}]`;
			}
			case "write": {
				const path = shortenPath(String(args.path || args.file_path || ""));
				return `[write: ${path}]`;
			}
			case "edit": {
				const path = shortenPath(String(args.path || args.file_path || ""));
				return `[edit: ${path}]`;
			}
			case "bash": {
				const rawCmd = String(args.command || "");
				const cmd = rawCmd
					.replace(/[\n\t]/g, " ")
					.trim()
					.slice(0, 50);
				return `[bash: ${cmd}${rawCmd.length > 50 ? "..." : ""}]`;
			}
			case "grep": {
				const pattern = String(args.pattern || "");
				const path = shortenPath(String(args.path || "."));
				return `[grep: /${pattern}/ in ${path}]`;
			}
			case "find": {
				const pattern = String(args.pattern || "");
				const path = shortenPath(String(args.path || "."));
				return `[find: ${pattern} in ${path}]`;
			}
			case "ls": {
				const path = shortenPath(String(args.path || "."));
				return `[ls: ${path}]`;
			}
			default: {
				// Custom tool - show name and truncated JSON args (serialize once)
				const fullArgs = JSON.stringify(args);
				const argsStr = fullArgs.slice(0, 40);
				return `[${name}: ${argsStr}${fullArgs.length > 40 ? "..." : ""}]`;
			}
		}
	}
	/** Keyboard handler: navigation, selection, filter cycling, search input. */
	handleInput(keyData: string): void {
		// Guard: navigation is meaningless on an empty list and would otherwise
		// drive selectedIndex to -1 (e.g. 0 - 1 wrap, or min(length - 1, ...)).
		const isNavigation =
			isArrowUp(keyData) || isArrowDown(keyData) || isArrowLeft(keyData) || isArrowRight(keyData);
		if (isNavigation && this.filteredNodes.length === 0) return;
		if (isArrowUp(keyData)) {
			this.selectedIndex = this.selectedIndex === 0 ? this.filteredNodes.length - 1 : this.selectedIndex - 1;
		} else if (isArrowDown(keyData)) {
			this.selectedIndex = this.selectedIndex === this.filteredNodes.length - 1 ? 0 : this.selectedIndex + 1;
		} else if (isArrowLeft(keyData)) {
			// Page up
			this.selectedIndex = Math.max(0, this.selectedIndex - this.maxVisibleLines);
		} else if (isArrowRight(keyData)) {
			// Page down
			this.selectedIndex = Math.min(this.filteredNodes.length - 1, this.selectedIndex + this.maxVisibleLines);
		} else if (isEnter(keyData)) {
			const selected = this.filteredNodes[this.selectedIndex];
			if (selected && this.onSelect) {
				this.onSelect(selected.node.entry.id);
			}
		} else if (isEscape(keyData)) {
			if (this.searchQuery) {
				this.searchQuery = "";
				this.applyFilter();
			} else {
				this.onCancel?.();
			}
		} else if (isCtrlC(keyData)) {
			this.onCancel?.();
		} else if (isShiftCtrlO(keyData)) {
			// Cycle filter backwards
			const modes: FilterMode[] = ["default", "no-tools", "user-only", "labeled-only", "all"];
			const currentIndex = modes.indexOf(this.filterMode);
			this.filterMode = modes[(currentIndex - 1 + modes.length) % modes.length];
			this.applyFilter();
		} else if (isCtrlO(keyData)) {
			// Cycle filter forwards: default → no-tools → user-only → labeled-only → all → default
			const modes: FilterMode[] = ["default", "no-tools", "user-only", "labeled-only", "all"];
			const currentIndex = modes.indexOf(this.filterMode);
			this.filterMode = modes[(currentIndex + 1) % modes.length];
			this.applyFilter();
		} else if (isBackspace(keyData)) {
			if (this.searchQuery.length > 0) {
				this.searchQuery = this.searchQuery.slice(0, -1);
				this.applyFilter();
			}
		} else if (keyData === "l" && !this.searchQuery) {
			const selected = this.filteredNodes[this.selectedIndex];
			if (selected && this.onLabelEdit) {
				this.onLabelEdit(selected.node.entry.id, selected.node.label);
			}
		} else {
			const hasControlChars = [...keyData].some((ch) => {
				const code = ch.charCodeAt(0);
				return code < 32 || code === 0x7f || (code >= 0x80 && code <= 0x9f);
			});
			if (!hasControlChars && keyData.length > 0) {
				this.searchQuery += keyData;
				this.applyFilter();
			}
		}
	}
}
/** Single-line component showing the tree list's current search query. */
class SearchLine implements Component {
	constructor(private treeList: TreeList) {}
	invalidate(): void {}
	handleInput(_keyData: string): void {}
	render(width: number): string[] {
		const query = this.treeList.getSearchQuery();
		const label = theme.fg("muted", "Search:");
		const line = query ? ` ${label} ${theme.fg("accent", query)}` : ` ${label}`;
		return [truncateToWidth(line, width)];
	}
}
/**
 * Inline label editor shown while labeling a tree entry.
 * Enter submits the trimmed value (undefined when empty, which removes the
 * label); Escape cancels; all other keys go to the underlying text input.
 */
class LabelInput implements Component {
	private input: Input;
	private entryId: string;
	public onSubmit?: (entryId: string, label: string | undefined) => void;
	public onCancel?: () => void;
	constructor(entryId: string, currentLabel: string | undefined) {
		this.entryId = entryId;
		this.input = new Input();
		if (currentLabel) {
			this.input.setValue(currentLabel);
		}
	}
	invalidate(): void {}
	render(width: number): string[] {
		const indent = " ";
		const availableWidth = width - indent.length;
		const prompt = truncateToWidth(`${indent}${theme.fg("muted", "Label (empty to remove):")}`, width);
		const inputLines = this.input.render(availableWidth).map((line) => truncateToWidth(`${indent}${line}`, width));
		const hint = truncateToWidth(`${indent}${theme.fg("dim", "enter: save esc: cancel")}`, width);
		return [prompt, ...inputLines, hint];
	}
	handleInput(keyData: string): void {
		if (isEnter(keyData)) {
			const trimmed = this.input.getValue().trim();
			this.onSubmit?.(this.entryId, trimmed.length > 0 ? trimmed : undefined);
			return;
		}
		if (isEscape(keyData)) {
			this.onCancel?.();
			return;
		}
		this.input.handleInput(keyData);
	}
}
/**
 * Component that renders a session tree selector for navigation.
 * Hosts a TreeList plus a swappable inline LabelInput; while a label is being
 * edited, the tree view is hidden and all input is routed to the editor.
 */
export class TreeSelectorComponent extends Container {
	private treeList: TreeList;
	private labelInput: LabelInput | null = null;
	private labelInputContainer: Container;
	private treeContainer: Container;
	private onLabelChangeCallback?: (entryId: string, label: string | undefined) => void;

	constructor(
		tree: SessionTreeNode[],
		currentLeafId: string | null,
		terminalHeight: number,
		onSelect: (entryId: string) => void,
		onCancel: () => void,
		onLabelChange?: (entryId: string, label: string | undefined) => void,
	) {
		super();
		this.onLabelChangeCallback = onLabelChange;

		// Give the list roughly half the terminal, but never fewer than 5 rows.
		const visibleRows = Math.max(5, Math.floor(terminalHeight / 2));
		this.treeList = new TreeList(tree, currentLeafId, visibleRows);
		this.treeList.onSelect = onSelect;
		this.treeList.onCancel = onCancel;
		this.treeList.onLabelEdit = (entryId, currentLabel) => this.showLabelInput(entryId, currentLabel);

		this.treeContainer = new Container();
		this.treeContainer.addChild(this.treeList);
		this.labelInputContainer = new Container();

		// Layout top-to-bottom: header + hints + search line, then tree/editor, framed by borders.
		this.addChild(new Spacer(1));
		this.addChild(new DynamicBorder());
		this.addChild(new Text(theme.bold(" Session Tree"), 1, 0));
		this.addChild(
			new TruncatedText(theme.fg("muted", " ↑/↓: move. ←/→: page. l: label. ^O/⇧^O: filter. Type to search"), 0, 0),
		);
		this.addChild(new SearchLine(this.treeList));
		this.addChild(new DynamicBorder());
		this.addChild(new Spacer(1));
		this.addChild(this.treeContainer);
		this.addChild(this.labelInputContainer);
		this.addChild(new Spacer(1));
		this.addChild(new DynamicBorder());

		// Empty tree: nothing to select, so cancel on the next tick to let the caller tear down.
		if (tree.length === 0) {
			setTimeout(() => onCancel(), 100);
		}
	}

	private showLabelInput(entryId: string, currentLabel: string | undefined): void {
		const editor = new LabelInput(entryId, currentLabel);
		editor.onSubmit = (id, label) => {
			this.treeList.updateNodeLabel(id, label);
			this.onLabelChangeCallback?.(id, label);
			this.hideLabelInput();
		};
		editor.onCancel = () => this.hideLabelInput();
		this.labelInput = editor;
		// Swap the tree out for the editor while the label is being edited.
		this.treeContainer.clear();
		this.labelInputContainer.clear();
		this.labelInputContainer.addChild(editor);
	}

	private hideLabelInput(): void {
		this.labelInput = null;
		this.labelInputContainer.clear();
		// Restore the tree view.
		this.treeContainer.clear();
		this.treeContainer.addChild(this.treeList);
	}

	handleInput(keyData: string): void {
		// While the label editor is open it owns all input; otherwise the tree does.
		(this.labelInput ?? this.treeList).handleInput(keyData);
	}

	getTreeList(): TreeList {
		return this.treeList;
	}
}

View file

@ -5,13 +5,9 @@ import { getMarkdownTheme, theme } from "../theme/theme.js";
* Component that renders a user message
*/
export class UserMessageComponent extends Container {
constructor(text: string, isFirst: boolean) {
constructor(text: string) {
super();
// Add spacer before user message (except first one)
if (!isFirst) {
this.addChild(new Spacer(1));
}
this.addChild(new Spacer(1));
this.addChild(
new Markdown(text, 1, 1, getMarkdownTheme(), {
bgColor: (text: string) => theme.bg("userMessageBg", text),

Some files were not shown because too many files have changed in this diff Show more