mirror of
https://github.com/getcompanion-ai/co-mono.git
synced 2026-04-16 05:03:26 +00:00
Merge session-tree: tree structure with branching, compaction, and hook API improvements
This commit is contained in:
commit
1f3f851185
174 changed files with 10978 additions and 6295 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -29,3 +29,4 @@ compaction-results/
|
|||
.opencode/
|
||||
syntax.jsonl
|
||||
out.jsonl
|
||||
pi-*.html
|
||||
|
|
|
|||
12
.pi/commands/review.md
Normal file
12
.pi/commands/review.md
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
---
|
||||
description: Review a file for issues
|
||||
---
|
||||
Please review the following file for potential issues, bugs, or improvements:
|
||||
|
||||
$1
|
||||
|
||||
Focus on:
|
||||
- Logic errors
|
||||
- Edge cases
|
||||
- Code style
|
||||
- Performance concerns
|
||||
86
.pi/hooks/test-command.ts
Normal file
86
.pi/hooks/test-command.ts
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
/**
|
||||
* Test hook demonstrating custom commands, message rendering, and before_agent_start.
|
||||
*/
|
||||
import type { BeforeAgentStartEvent, HookAPI } from "@mariozechner/pi-coding-agent";
|
||||
import { Box, Spacer, Text } from "@mariozechner/pi-tui";
|
||||
|
||||
export default function (pi: HookAPI) {
|
||||
// Track whether injection is enabled
|
||||
let injectEnabled = false;
|
||||
|
||||
// Register a custom message renderer for our "test-info" type
|
||||
pi.registerMessageRenderer("test-info", (message, options, theme) => {
|
||||
const box = new Box(1, 1, (t) => theme.bg("customMessageBg", t));
|
||||
|
||||
const label = theme.fg("success", "[TEST INFO]");
|
||||
box.addChild(new Text(label, 0, 0));
|
||||
box.addChild(new Spacer(1));
|
||||
|
||||
const content =
|
||||
typeof message.content === "string"
|
||||
? message.content
|
||||
: message.content.map((c) => (c.type === "text" ? c.text : "[image]")).join("");
|
||||
|
||||
box.addChild(new Text(theme.fg("text", content), 0, 0));
|
||||
|
||||
if (options.expanded && message.details) {
|
||||
box.addChild(new Spacer(1));
|
||||
box.addChild(new Text(theme.fg("dim", `Details: ${JSON.stringify(message.details)}`), 0, 0));
|
||||
}
|
||||
|
||||
return box;
|
||||
});
|
||||
|
||||
// Register /test-msg command
|
||||
pi.registerCommand("test-msg", {
|
||||
description: "Send a test custom message",
|
||||
handler: async () => {
|
||||
pi.sendMessage(
|
||||
{
|
||||
customType: "test-info",
|
||||
content: "This is a test message with custom rendering!",
|
||||
display: true,
|
||||
details: { timestamp: Date.now(), source: "test-command hook" },
|
||||
},
|
||||
true, // triggerTurn: start agent run
|
||||
);
|
||||
},
|
||||
});
|
||||
|
||||
// Register /test-hidden command
|
||||
pi.registerCommand("test-hidden", {
|
||||
description: "Send a hidden message (display: false)",
|
||||
handler: async (ctx) => {
|
||||
pi.sendMessage({
|
||||
customType: "test-info",
|
||||
content: "This message is in context but not displayed",
|
||||
display: false,
|
||||
});
|
||||
ctx.ui.notify("Sent hidden message (check session file)");
|
||||
},
|
||||
});
|
||||
|
||||
// Register /test-inject command to toggle before_agent_start injection
|
||||
pi.registerCommand("test-inject", {
|
||||
description: "Toggle context injection before agent starts",
|
||||
handler: async (ctx) => {
|
||||
injectEnabled = !injectEnabled;
|
||||
ctx.ui.notify(`Context injection ${injectEnabled ? "enabled" : "disabled"}`);
|
||||
},
|
||||
});
|
||||
|
||||
// Demonstrate before_agent_start: inject context when enabled
|
||||
pi.on("before_agent_start", async (event: BeforeAgentStartEvent) => {
|
||||
if (!injectEnabled) return;
|
||||
|
||||
// Return a message to inject before the user's prompt
|
||||
return {
|
||||
message: {
|
||||
customType: "test-info",
|
||||
content: `[Injected context for prompt: "${event.prompt.slice(0, 50)}..."]`,
|
||||
display: true,
|
||||
details: { injectedAt: Date.now() },
|
||||
},
|
||||
};
|
||||
});
|
||||
}
|
||||
56
AGENTS.md
56
AGENTS.md
|
|
@ -30,7 +30,7 @@ When reading issues:
|
|||
|
||||
When creating issues:
|
||||
- Add `pkg:*` labels to indicate which package(s) the issue affects
|
||||
- Available labels: `pkg:agent`, `pkg:ai`, `pkg:coding-agent`, `pkg:mom`, `pkg:pods`, `pkg:proxy`, `pkg:tui`, `pkg:web-ui`
|
||||
- Available labels: `pkg:agent`, `pkg:ai`, `pkg:coding-agent`, `pkg:mom`, `pkg:pods`, `pkg:tui`, `pkg:web-ui`
|
||||
- If an issue spans multiple packages, add all relevant labels
|
||||
|
||||
When closing issues via commit:
|
||||
|
|
@ -39,7 +39,7 @@ When closing issues via commit:
|
|||
|
||||
## Tools
|
||||
- GitHub CLI for issues/PRs
|
||||
- Add package labels to issues/PRs: pkg:agent, pkg:ai, pkg:coding-agent, pkg:mom, pkg:pods, pkg:proxy, pkg:tui, pkg:web-ui
|
||||
- Add package labels to issues/PRs: pkg:agent, pkg:ai, pkg:coding-agent, pkg:mom, pkg:pods, pkg:tui, pkg:web-ui
|
||||
- TUI interaction: use tmux
|
||||
|
||||
## Style
|
||||
|
|
@ -49,43 +49,37 @@ When closing issues via commit:
|
|||
- Technical prose only, be kind but direct (e.g., "Thanks @user" not "Thanks so much @user!")
|
||||
|
||||
## Changelog
|
||||
Location: `packages/coding-agent/CHANGELOG.md`, `packages/ai/CHANGELOG.md`, `packages/tui/CHANGELOG.md`, pick the one relevant to the changes or ask user.
|
||||
Location: `packages/*/CHANGELOG.md` (each package has its own)
|
||||
|
||||
### Format
|
||||
Use these sections under `## [Unreleased]`:
|
||||
- `### Breaking Changes` - API changes requiring migration
|
||||
- `### Added` - New features
|
||||
- `### Changed` - Changes to existing functionality
|
||||
- `### Fixed` - Bug fixes
|
||||
- `### Removed` - Removed features
|
||||
|
||||
### Rules
|
||||
- New entries ALWAYS go under `## [Unreleased]` section
|
||||
- NEVER modify already-released version sections (e.g., `## [0.12.2]`)
|
||||
- Each version section is immutable once released
|
||||
- When releasing: rename `[Unreleased]` to the new version, then add a fresh empty `[Unreleased]` section
|
||||
|
||||
### Attribution format
|
||||
- **Internal changes (from issues)**: Reference issue only
|
||||
- Example: `Fixed foo bar ([#123](https://github.com/badlogic/pi-mono/issues/123))`
|
||||
- **External contributions (PRs from others)**: Reference PR and credit the contributor
|
||||
- Example: `Added feature X ([#456](https://github.com/badlogic/pi-mono/pull/456) by [@username](https://github.com/username))`
|
||||
- If a PR addresses an issue, reference both: `([#123](...issues/123), [#456](...pull/456) by [@user](...))` or just the PR if the issue context is clear from the description
|
||||
### Attribution
|
||||
- **Internal changes (from issues)**: `Fixed foo bar ([#123](https://github.com/badlogic/pi-mono/issues/123))`
|
||||
- **External contributions**: `Added feature X ([#456](https://github.com/badlogic/pi-mono/pull/456) by [@username](https://github.com/username))`
|
||||
|
||||
## Releasing
|
||||
|
||||
1. **Bump version** (all packages use lockstep versioning):
|
||||
1. **Update CHANGELOGs**: Ensure all changes since last release are documented in the `[Unreleased]` section of each affected package's CHANGELOG.md
|
||||
|
||||
2. **Run release script**:
|
||||
```bash
|
||||
npm run version:patch # For bug fixes
|
||||
npm run version:minor # For new features
|
||||
npm run version:major # For breaking changes
|
||||
npm run release:patch # Bug fixes
|
||||
npm run release:minor # New features
|
||||
npm run release:major # Breaking changes
|
||||
```
|
||||
|
||||
2. **Finalize CHANGELOG.md**: Change `[Unreleased]` to the new version with today's date (e.g., `## [0.12.12] - 2025-12-05`)
|
||||
|
||||
3. **Commit and tag**:
|
||||
```bash
|
||||
git add .
|
||||
git commit -m "Release v0.12.12"
|
||||
git tag v0.12.12
|
||||
git push origin main
|
||||
git push origin v0.12.12
|
||||
```
|
||||
|
||||
4. **Publish to npm**:
|
||||
```bash
|
||||
npm run publish
|
||||
```
|
||||
|
||||
5. **Add new [Unreleased] section** at top of CHANGELOG.md for next cycle, commit it
|
||||
The script handles: version bump, CHANGELOG finalization, commit, tag, publish, and adding new `[Unreleased]` sections.
|
||||
|
||||
### Tool Usage
|
||||
**CTRICIAL**: NEVER use sed/cat to read a file or a range of a file. Always use the read tool (use offset + limit for ranged reads).
|
||||
60
README.md
60
README.md
|
|
@ -7,12 +7,11 @@ Tools for building AI agents and managing LLM deployments.
|
|||
| Package | Description |
|
||||
|---------|-------------|
|
||||
| **[@mariozechner/pi-ai](packages/ai)** | Unified multi-provider LLM API (OpenAI, Anthropic, Google, etc.) |
|
||||
| **[@mariozechner/pi-agent](packages/agent)** | Agent runtime with tool calling and state management |
|
||||
| **[@mariozechner/pi-agent-core](packages/agent)** | Agent runtime with tool calling and state management |
|
||||
| **[@mariozechner/pi-coding-agent](packages/coding-agent)** | Interactive coding agent CLI |
|
||||
| **[@mariozechner/pi-mom](packages/mom)** | Slack bot that delegates messages to the pi coding agent |
|
||||
| **[@mariozechner/pi-tui](packages/tui)** | Terminal UI library with differential rendering |
|
||||
| **[@mariozechner/pi-web-ui](packages/web-ui)** | Web components for AI chat interfaces |
|
||||
| **[@mariozechner/pi-proxy](packages/proxy)** | CORS proxy for browser-based LLM API calls |
|
||||
| **[@mariozechner/pi-pods](packages/pods)** | CLI for managing vLLM deployments on GPU pods |
|
||||
|
||||
## Development
|
||||
|
|
@ -71,55 +70,18 @@ These commands:
|
|||
|
||||
### Publishing
|
||||
|
||||
Complete release process:
|
||||
```bash
|
||||
npm run release:patch # Bug fixes
|
||||
npm run release:minor # New features
|
||||
npm run release:major # Breaking changes
|
||||
```
|
||||
|
||||
1. **Add changes to CHANGELOG.md** (if changes affect coding-agent):
|
||||
```bash
|
||||
# Add your changes to the [Unreleased] section in packages/coding-agent/CHANGELOG.md
|
||||
# Always add new entries under [Unreleased], never under already-released versions
|
||||
```
|
||||
This handles version bump, CHANGELOG updates, commit, tag, publish, and push.
|
||||
|
||||
2. **Bump version** (all packages):
|
||||
```bash
|
||||
npm run version:patch # For bug fixes
|
||||
npm run version:minor # For new features
|
||||
npm run version:major # For breaking changes
|
||||
```
|
||||
|
||||
3. **Finalize CHANGELOG.md for release** (if changes affect coding-agent):
|
||||
```bash
|
||||
# Change [Unreleased] to the new version number with today's date
|
||||
# e.g., ## [0.7.16] - 2025-11-17
|
||||
# NEVER add entries to already-released version sections
|
||||
# Each version section is immutable once released
|
||||
```
|
||||
|
||||
4. **Commit and tag**:
|
||||
```bash
|
||||
git add .
|
||||
git commit -m "Release v0.7.16"
|
||||
git tag v0.7.16
|
||||
git push origin main
|
||||
git push origin v0.7.16
|
||||
```
|
||||
|
||||
5. **Publish to npm**:
|
||||
```bash
|
||||
npm run publish # Publish all packages to npm
|
||||
```
|
||||
|
||||
**NPM Token Setup**: Publishing requires a granular access token with "Bypass 2FA on publish" enabled.
|
||||
- Go to https://www.npmjs.com/settings/badlogic/tokens/
|
||||
- Create a new "Granular Access Token"
|
||||
- Select "Bypass 2FA on publish"
|
||||
- Tokens expire after 90 days, so regenerate when needed
|
||||
- Set the token: `npm config set //registry.npmjs.org/:_authToken=YOUR_TOKEN`
|
||||
|
||||
6. **Add new [Unreleased] section** (for next development cycle):
|
||||
```bash
|
||||
# Add a new [Unreleased] section at the top of CHANGELOG.md
|
||||
# Commit: git commit -am "Add [Unreleased] section"
|
||||
```
|
||||
**NPM Token Setup**: Requires a granular access token with "Bypass 2FA on publish" enabled.
|
||||
- Go to https://www.npmjs.com/settings/badlogic/tokens/
|
||||
- Create a new "Granular Access Token" with "Bypass 2FA on publish"
|
||||
- Set the token: `npm config set //registry.npmjs.org/:_authToken=YOUR_TOKEN`
|
||||
|
||||
## License
|
||||
|
||||
|
|
|
|||
149
package-lock.json
generated
149
package-lock.json
generated
|
|
@ -12,6 +12,7 @@
|
|||
"packages/web-ui/example"
|
||||
],
|
||||
"dependencies": {
|
||||
"@mariozechner/pi-coding-agent": "^0.30.2",
|
||||
"get-east-asian-width": "^1.4.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
|
|
@ -695,18 +696,6 @@
|
|||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@hono/node-server": {
|
||||
"version": "1.19.7",
|
||||
"resolved": "https://registry.npmjs.org/@hono/node-server/-/node-server-1.19.7.tgz",
|
||||
"integrity": "sha512-vUcD0uauS7EU2caukW8z5lJKtoGMokxNbJtBiwHgpqxEXokaHCBkQUmCHhjFB1VUTWdqj25QoMkMKzgjq+uhrw==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=18.14.1"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"hono": "^4"
|
||||
}
|
||||
},
|
||||
"node_modules/@isaacs/balanced-match": {
|
||||
"version": "4.0.1",
|
||||
"resolved": "https://registry.npmjs.org/@isaacs/balanced-match/-/balanced-match-4.0.1.tgz",
|
||||
|
|
@ -964,10 +953,6 @@
|
|||
"resolved": "packages/mom",
|
||||
"link": true
|
||||
},
|
||||
"node_modules/@mariozechner/pi-proxy": {
|
||||
"resolved": "packages/proxy",
|
||||
"link": true
|
||||
},
|
||||
"node_modules/@mariozechner/pi-tui": {
|
||||
"resolved": "packages/tui",
|
||||
"link": true
|
||||
|
|
@ -995,9 +980,9 @@
|
|||
}
|
||||
},
|
||||
"node_modules/@napi-rs/canvas": {
|
||||
"version": "0.1.86",
|
||||
"resolved": "https://registry.npmjs.org/@napi-rs/canvas/-/canvas-0.1.86.tgz",
|
||||
"integrity": "sha512-hOkywnrkdFdVpsuaNsZWfEY7kc96eROV2DuMTTvGF15AZfwobzdG2w0eDlU5UBx3Lg/XlWUnqVT5zLUWyo5h6A==",
|
||||
"version": "0.1.88",
|
||||
"resolved": "https://registry.npmjs.org/@napi-rs/canvas/-/canvas-0.1.88.tgz",
|
||||
"integrity": "sha512-/p08f93LEbsL5mDZFQ3DBxcPv/I4QG9EDYRRq1WNlCOXVfAHBTHMSVMwxlqG/AtnSfUr9+vgfN7MKiyDo0+Weg==",
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"workspaces": [
|
||||
|
|
@ -1011,23 +996,23 @@
|
|||
"url": "https://github.com/sponsors/Brooooooklyn"
|
||||
},
|
||||
"optionalDependencies": {
|
||||
"@napi-rs/canvas-android-arm64": "0.1.86",
|
||||
"@napi-rs/canvas-darwin-arm64": "0.1.86",
|
||||
"@napi-rs/canvas-darwin-x64": "0.1.86",
|
||||
"@napi-rs/canvas-linux-arm-gnueabihf": "0.1.86",
|
||||
"@napi-rs/canvas-linux-arm64-gnu": "0.1.86",
|
||||
"@napi-rs/canvas-linux-arm64-musl": "0.1.86",
|
||||
"@napi-rs/canvas-linux-riscv64-gnu": "0.1.86",
|
||||
"@napi-rs/canvas-linux-x64-gnu": "0.1.86",
|
||||
"@napi-rs/canvas-linux-x64-musl": "0.1.86",
|
||||
"@napi-rs/canvas-win32-arm64-msvc": "0.1.86",
|
||||
"@napi-rs/canvas-win32-x64-msvc": "0.1.86"
|
||||
"@napi-rs/canvas-android-arm64": "0.1.88",
|
||||
"@napi-rs/canvas-darwin-arm64": "0.1.88",
|
||||
"@napi-rs/canvas-darwin-x64": "0.1.88",
|
||||
"@napi-rs/canvas-linux-arm-gnueabihf": "0.1.88",
|
||||
"@napi-rs/canvas-linux-arm64-gnu": "0.1.88",
|
||||
"@napi-rs/canvas-linux-arm64-musl": "0.1.88",
|
||||
"@napi-rs/canvas-linux-riscv64-gnu": "0.1.88",
|
||||
"@napi-rs/canvas-linux-x64-gnu": "0.1.88",
|
||||
"@napi-rs/canvas-linux-x64-musl": "0.1.88",
|
||||
"@napi-rs/canvas-win32-arm64-msvc": "0.1.88",
|
||||
"@napi-rs/canvas-win32-x64-msvc": "0.1.88"
|
||||
}
|
||||
},
|
||||
"node_modules/@napi-rs/canvas-android-arm64": {
|
||||
"version": "0.1.86",
|
||||
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-android-arm64/-/canvas-android-arm64-0.1.86.tgz",
|
||||
"integrity": "sha512-IjkZFKUr6GzMzzrawJaN3v+yY3Fvpa71e0DcbePfxWelFKnESIir+XUcdAbim29JOd0JE0/hQJdfUCb5t/Fjrw==",
|
||||
"version": "0.1.88",
|
||||
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-android-arm64/-/canvas-android-arm64-0.1.88.tgz",
|
||||
"integrity": "sha512-KEaClPnZuVxJ8smUWjV1wWFkByBO/D+vy4lN+Dm5DFH514oqwukxKGeck9xcKJhaWJGjfruGmYGiwRe//+/zQQ==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
|
|
@ -1045,9 +1030,9 @@
|
|||
}
|
||||
},
|
||||
"node_modules/@napi-rs/canvas-darwin-arm64": {
|
||||
"version": "0.1.86",
|
||||
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-darwin-arm64/-/canvas-darwin-arm64-0.1.86.tgz",
|
||||
"integrity": "sha512-PUCxDq0wSSJbtaOqoKj3+t5tyDbtxWumziOTykdn3T839hu6koMaBFpGk9lXpsGaPNgyFpPqjxhtsPljBGnDHg==",
|
||||
"version": "0.1.88",
|
||||
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-darwin-arm64/-/canvas-darwin-arm64-0.1.88.tgz",
|
||||
"integrity": "sha512-Xgywz0dDxOKSgx3eZnK85WgGMmGrQEW7ZLA/E7raZdlEE+xXCozobgqz2ZvYigpB6DJFYkqnwHjqCOTSDGlFdg==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
|
|
@ -1065,9 +1050,9 @@
|
|||
}
|
||||
},
|
||||
"node_modules/@napi-rs/canvas-darwin-x64": {
|
||||
"version": "0.1.86",
|
||||
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-darwin-x64/-/canvas-darwin-x64-0.1.86.tgz",
|
||||
"integrity": "sha512-rlCFLv4Rrg45qFZq7mysrKnsUbMhwdNg3YPuVfo9u4RkOqm7ooAJvdyDFxiqfSsJJTqupYqa9VQCUt8WKxKhNQ==",
|
||||
"version": "0.1.88",
|
||||
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-darwin-x64/-/canvas-darwin-x64-0.1.88.tgz",
|
||||
"integrity": "sha512-Yz4wSCIQOUgNucgk+8NFtQxQxZV5NO8VKRl9ePKE6XoNyNVC8JDqtvhh3b3TPqKK8W5p2EQpAr1rjjm0mfBxdg==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
|
|
@ -1085,9 +1070,9 @@
|
|||
}
|
||||
},
|
||||
"node_modules/@napi-rs/canvas-linux-arm-gnueabihf": {
|
||||
"version": "0.1.86",
|
||||
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-arm-gnueabihf/-/canvas-linux-arm-gnueabihf-0.1.86.tgz",
|
||||
"integrity": "sha512-6xWwyMc9BlDBt+9XHN/GzUo3MozHta/2fxQHMb80x0K2zpZuAdDKUYHmYzx9dFWDY3SbPYnx6iRlQl6wxnwS1w==",
|
||||
"version": "0.1.88",
|
||||
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-arm-gnueabihf/-/canvas-linux-arm-gnueabihf-0.1.88.tgz",
|
||||
"integrity": "sha512-9gQM2SlTo76hYhxHi2XxWTAqpTOb+JtxMPEIr+H5nAhHhyEtNmTSDRtz93SP7mGd2G3Ojf2oF5tP9OdgtgXyKg==",
|
||||
"cpu": [
|
||||
"arm"
|
||||
],
|
||||
|
|
@ -1105,9 +1090,9 @@
|
|||
}
|
||||
},
|
||||
"node_modules/@napi-rs/canvas-linux-arm64-gnu": {
|
||||
"version": "0.1.86",
|
||||
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-arm64-gnu/-/canvas-linux-arm64-gnu-0.1.86.tgz",
|
||||
"integrity": "sha512-r2OX3w50xHxrToTovOSQWwkVfSq752CUzH9dzlVXyr8UDKFV8dMjfa9hePXvAJhN3NBp4TkHcGx15QCdaCIwnA==",
|
||||
"version": "0.1.88",
|
||||
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-arm64-gnu/-/canvas-linux-arm64-gnu-0.1.88.tgz",
|
||||
"integrity": "sha512-7qgaOBMXuVRk9Fzztzr3BchQKXDxGbY+nwsovD3I/Sx81e+sX0ReEDYHTItNb0Je4NHbAl7D0MKyd4SvUc04sg==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
|
|
@ -1125,9 +1110,9 @@
|
|||
}
|
||||
},
|
||||
"node_modules/@napi-rs/canvas-linux-arm64-musl": {
|
||||
"version": "0.1.86",
|
||||
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-arm64-musl/-/canvas-linux-arm64-musl-0.1.86.tgz",
|
||||
"integrity": "sha512-jbXuh8zVFUPw6a9SGpgc6EC+fRbGGyP1NFfeQiVqGLs6bN93ROtPLPL6MH9Bp6yt0CXUFallk2vgKdWDbmW+bw==",
|
||||
"version": "0.1.88",
|
||||
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-arm64-musl/-/canvas-linux-arm64-musl-0.1.88.tgz",
|
||||
"integrity": "sha512-kYyNrUsHLkoGHBc77u4Unh067GrfiCUMbGHC2+OTxbeWfZkPt2o32UOQkhnSswKd9Fko/wSqqGkY956bIUzruA==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
|
|
@ -1145,9 +1130,9 @@
|
|||
}
|
||||
},
|
||||
"node_modules/@napi-rs/canvas-linux-riscv64-gnu": {
|
||||
"version": "0.1.86",
|
||||
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-riscv64-gnu/-/canvas-linux-riscv64-gnu-0.1.86.tgz",
|
||||
"integrity": "sha512-9IwHR2qbq2HceM9fgwyL7x37Jy3ptt1uxvikQEuWR0FisIx9QEdt7F3huljCky76aoouF2vSd0R2fHo3ESRoPw==",
|
||||
"version": "0.1.88",
|
||||
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-riscv64-gnu/-/canvas-linux-riscv64-gnu-0.1.88.tgz",
|
||||
"integrity": "sha512-HVuH7QgzB0yavYdNZDRyAsn/ejoXB0hn8twwFnOqUbCCdkV+REna7RXjSR7+PdfW0qMQ2YYWsLvVBT5iL/mGpw==",
|
||||
"cpu": [
|
||||
"riscv64"
|
||||
],
|
||||
|
|
@ -1165,9 +1150,9 @@
|
|||
}
|
||||
},
|
||||
"node_modules/@napi-rs/canvas-linux-x64-gnu": {
|
||||
"version": "0.1.86",
|
||||
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-x64-gnu/-/canvas-linux-x64-gnu-0.1.86.tgz",
|
||||
"integrity": "sha512-Jor+rhRN6ubix+D2QkNn9XlPPVAYl+2qFrkZ4oZN9UgtqIUZ+n+HljxhlkkDFRaX1mlxXOXPQjxaZg17zDSFcQ==",
|
||||
"version": "0.1.88",
|
||||
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-x64-gnu/-/canvas-linux-x64-gnu-0.1.88.tgz",
|
||||
"integrity": "sha512-hvcvKIcPEQrvvJtJnwD35B3qk6umFJ8dFIr8bSymfrSMem0EQsfn1ztys8ETIFndTwdNWJKWluvxztA41ivsEw==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
|
|
@ -1185,9 +1170,9 @@
|
|||
}
|
||||
},
|
||||
"node_modules/@napi-rs/canvas-linux-x64-musl": {
|
||||
"version": "0.1.86",
|
||||
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-x64-musl/-/canvas-linux-x64-musl-0.1.86.tgz",
|
||||
"integrity": "sha512-A28VTy91DbclopSGZ2tIon3p8hcVI1JhnNpDpJ5N9rYlUnVz1WQo4waEMh+FICTZF07O3coxBNZc4Vu4doFw7A==",
|
||||
"version": "0.1.88",
|
||||
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-x64-musl/-/canvas-linux-x64-musl-0.1.88.tgz",
|
||||
"integrity": "sha512-eSMpGYY2xnZSQ6UxYJ6plDboxq4KeJ4zT5HaVkUnbObNN6DlbJe0Mclh3wifAmquXfrlgTZt6zhHsUgz++AK6g==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
|
|
@ -1205,9 +1190,9 @@
|
|||
}
|
||||
},
|
||||
"node_modules/@napi-rs/canvas-win32-arm64-msvc": {
|
||||
"version": "0.1.86",
|
||||
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-win32-arm64-msvc/-/canvas-win32-arm64-msvc-0.1.86.tgz",
|
||||
"integrity": "sha512-q6G1YXUt3gBCAS2bcDMCaBL4y20di8eVVBi1XhjUqZSVyZZxxwIuRQHy31NlPJUCMiyNiMuc6zeI0uqgkWwAmA==",
|
||||
"version": "0.1.88",
|
||||
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-win32-arm64-msvc/-/canvas-win32-arm64-msvc-0.1.88.tgz",
|
||||
"integrity": "sha512-qcIFfEgHrchyYqRrxsCeTQgpJZ/GqHiqPcU/Fvw/ARVlQeDX1VyFH+X+0gCR2tca6UJrq96vnW+5o7buCq+erA==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
|
|
@ -1225,9 +1210,9 @@
|
|||
}
|
||||
},
|
||||
"node_modules/@napi-rs/canvas-win32-x64-msvc": {
|
||||
"version": "0.1.86",
|
||||
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-win32-x64-msvc/-/canvas-win32-x64-msvc-0.1.86.tgz",
|
||||
"integrity": "sha512-X0g46uRVgnvCM1cOjRXAOSFSG63ktUFIf/TIfbKCUc7QpmYUcHmSP9iR6DGOYfk+SggLsXoJCIhPTotYeZEAmg==",
|
||||
"version": "0.1.88",
|
||||
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-win32-x64-msvc/-/canvas-win32-x64-msvc-0.1.88.tgz",
|
||||
"integrity": "sha512-ROVqbfS4QyZxYkqmaIBBpbz/BQvAR+05FXM5PAtTYVc0uyY8Y4BHJSMdGAaMf6TdIVRsQsiq+FG/dH9XhvWCFQ==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
|
|
@ -4008,16 +3993,6 @@
|
|||
"node": "*"
|
||||
}
|
||||
},
|
||||
"node_modules/hono": {
|
||||
"version": "4.11.2",
|
||||
"resolved": "https://registry.npmjs.org/hono/-/hono-4.11.2.tgz",
|
||||
"integrity": "sha512-o+avdUAD1v94oHkjGBhiMhBV4WBHxhbu0+CUVH78hhphKy/OKQLxtKjkmmNcrMlbYAhAbsM/9F+l3KnYxyD3Lg==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=16.9.0"
|
||||
}
|
||||
},
|
||||
"node_modules/html-parse-string": {
|
||||
"version": "0.0.9",
|
||||
"resolved": "https://registry.npmjs.org/html-parse-string/-/html-parse-string-0.0.9.tgz",
|
||||
|
|
@ -4583,7 +4558,6 @@
|
|||
"resolved": "https://registry.npmjs.org/lit/-/lit-3.3.2.tgz",
|
||||
"integrity": "sha512-NF9zbsP79l4ao2SNrH3NkfmFgN/hBYSQo90saIVI1o5GpjAdCPVstVzO1MrLOakHoEhYkrtRjPK6Ob521aoYWQ==",
|
||||
"license": "BSD-3-Clause",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@lit/reactive-element": "^2.1.0",
|
||||
"lit-element": "^4.2.0",
|
||||
|
|
@ -5704,7 +5678,6 @@
|
|||
"resolved": "https://registry.npmjs.org/tailwind-merge/-/tailwind-merge-3.4.0.tgz",
|
||||
"integrity": "sha512-uSaO4gnW+b3Y2aWoWfFpX62vn2sR3skfhbjsEnaBI81WD1wBLlHZe5sWf0AqjksNdYTbGBEd0UasQMT3SNV15g==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"funding": {
|
||||
"type": "github",
|
||||
"url": "https://github.com/sponsors/dcastil"
|
||||
|
|
@ -5733,8 +5706,7 @@
|
|||
"version": "4.1.18",
|
||||
"resolved": "https://registry.npmjs.org/tailwindcss/-/tailwindcss-4.1.18.tgz",
|
||||
"integrity": "sha512-4+Z+0yiYyEtUVCScyfHCxOYP06L5Ne+JiHhY2IjR2KWMIWhJOYZKLSGZaP5HkZ8+bY0cxfzwDE5uOmzFXyIwxw==",
|
||||
"license": "MIT",
|
||||
"peer": true
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/tapable": {
|
||||
"version": "2.3.0",
|
||||
|
|
@ -5999,7 +5971,6 @@
|
|||
"resolved": "https://registry.npmjs.org/vite/-/vite-7.3.0.tgz",
|
||||
"integrity": "sha512-dZwN5L1VlUBewiP6H9s2+B3e3Jg96D0vzN+Ry73sOefebhYr9f94wwkMNN/9ouoU8pV1BqA1d1zGk8928cx0rg==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"esbuild": "^0.27.0",
|
||||
"fdir": "^6.5.0",
|
||||
|
|
@ -6442,9 +6413,9 @@
|
|||
}
|
||||
},
|
||||
"node_modules/zod-to-json-schema": {
|
||||
"version": "3.25.0",
|
||||
"resolved": "https://registry.npmjs.org/zod-to-json-schema/-/zod-to-json-schema-3.25.0.tgz",
|
||||
"integrity": "sha512-HvWtU2UG41LALjajJrML6uQejQhNJx+JBO9IflpSja4R03iNWfKXrj6W2h7ljuLyc1nKS+9yDyL/9tD1U/yBnQ==",
|
||||
"version": "3.25.1",
|
||||
"resolved": "https://registry.npmjs.org/zod-to-json-schema/-/zod-to-json-schema-3.25.1.tgz",
|
||||
"integrity": "sha512-pM/SU9d3YAggzi6MtR4h7ruuQlqKtad8e9S0fmxcMi+ueAK5Korys/aWcV9LIIHTVbj01NdzxcnXSN+O74ZIVA==",
|
||||
"license": "ISC",
|
||||
"peerDependencies": {
|
||||
"zod": "^3.25 || ^4"
|
||||
|
|
@ -6636,22 +6607,6 @@
|
|||
"node": ">=20.0.0"
|
||||
}
|
||||
},
|
||||
"packages/proxy": {
|
||||
"name": "@mariozechner/pi-proxy",
|
||||
"version": "0.30.2",
|
||||
"dependencies": {
|
||||
"@hono/node-server": "^1.14.0",
|
||||
"hono": "^4.6.16"
|
||||
},
|
||||
"bin": {
|
||||
"pi-proxy": "dist/cli.js"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/node": "^22.10.5",
|
||||
"tsx": "^4.19.2",
|
||||
"typescript": "^5.7.3"
|
||||
}
|
||||
},
|
||||
"packages/tui": {
|
||||
"name": "@mariozechner/pi-tui",
|
||||
"version": "0.30.2",
|
||||
|
|
|
|||
|
|
@ -8,8 +8,8 @@
|
|||
],
|
||||
"scripts": {
|
||||
"clean": "npm run clean --workspaces",
|
||||
"build": "npm run build -w @mariozechner/pi-tui && npm run build -w @mariozechner/pi-ai && npm run build -w @mariozechner/pi-agent-core && npm run build -w @mariozechner/pi-coding-agent && npm run build -w @mariozechner/pi-mom && npm run build -w @mariozechner/pi-web-ui && npm run build -w @mariozechner/pi-proxy && npm run build -w @mariozechner/pi",
|
||||
"dev": "concurrently --names \"ai,agent,coding-agent,mom,web-ui,tui,proxy\" --prefix-colors \"cyan,yellow,red,white,green,magenta,blue\" \"npm run dev -w @mariozechner/pi-ai\" \"npm run dev -w @mariozechner/pi-agent-core\" \"npm run dev -w @mariozechner/pi-coding-agent\" \"npm run dev -w @mariozechner/pi-mom\" \"npm run dev -w @mariozechner/pi-web-ui\" \"npm run dev -w @mariozechner/pi-tui\" \"npm run dev -w @mariozechner/pi-proxy\"",
|
||||
"build": "npm run build -w @mariozechner/pi-tui && npm run build -w @mariozechner/pi-ai && npm run build -w @mariozechner/pi-agent-core && npm run build -w @mariozechner/pi-coding-agent && npm run build -w @mariozechner/pi-mom && npm run build -w @mariozechner/pi-web-ui && npm run build -w @mariozechner/pi",
|
||||
"dev": "concurrently --names \"ai,agent,coding-agent,mom,web-ui,tui\" --prefix-colors \"cyan,yellow,red,white,green,magenta\" \"npm run dev -w @mariozechner/pi-ai\" \"npm run dev -w @mariozechner/pi-agent-core\" \"npm run dev -w @mariozechner/pi-coding-agent\" \"npm run dev -w @mariozechner/pi-mom\" \"npm run dev -w @mariozechner/pi-web-ui\" \"npm run dev -w @mariozechner/pi-tui\"",
|
||||
"dev:tsc": "concurrently --names \"ai,web-ui\" --prefix-colors \"cyan,green\" \"npm run dev:tsc -w @mariozechner/pi-ai\" \"npm run dev:tsc -w @mariozechner/pi-web-ui\"",
|
||||
"check": "biome check --write . && tsgo --noEmit && npm run check -w @mariozechner/pi-web-ui",
|
||||
"test": "npm run test --workspaces --if-present",
|
||||
|
|
@ -20,6 +20,9 @@
|
|||
"prepublishOnly": "npm run clean && npm run build && npm run check",
|
||||
"publish": "npm run prepublishOnly && npm publish -ws --access public",
|
||||
"publish:dry": "npm run prepublishOnly && npm publish -ws --access public --dry-run",
|
||||
"release:patch": "node scripts/release.mjs patch",
|
||||
"release:minor": "node scripts/release.mjs minor",
|
||||
"release:major": "node scripts/release.mjs major",
|
||||
"prepare": "husky"
|
||||
},
|
||||
"devDependencies": {
|
||||
|
|
@ -36,6 +39,7 @@
|
|||
},
|
||||
"version": "0.0.3",
|
||||
"dependencies": {
|
||||
"@mariozechner/pi-coding-agent": "^0.30.2",
|
||||
"get-east-asian-width": "^1.4.0"
|
||||
}
|
||||
}
|
||||
|
|
|
|||
69
packages/agent/CHANGELOG.md
Normal file
69
packages/agent/CHANGELOG.md
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
# Changelog
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
- **Transport abstraction removed**: `ProviderTransport`, `AppTransport`, and `AgentTransport` interface have been removed. The `Agent` class now takes a `streamFn` option directly for custom streaming implementations.
|
||||
|
||||
- **Agent options renamed**:
|
||||
- `transport` → removed (use `streamFn` instead)
|
||||
- `messageTransformer` → `convertToLlm` (converts `AgentMessage[]` to LLM-compatible `Message[]`)
|
||||
- `preprocessor` → `transformContext` (transforms `AgentMessage[]` before `convertToLlm`)
|
||||
|
||||
- **AppMessage renamed to AgentMessage**: All references to `AppMessage` have been renamed to `AgentMessage` for consistency.
|
||||
|
||||
- **Agent loop moved from pi-ai**: The `agentLoop`, `agentLoopContinue`, and related types (`AgentContext`, `AgentEvent`, `AgentTool`, `AgentToolResult`, `AgentToolUpdateCallback`, `AgentLoopConfig`) have moved from `@mariozechner/pi-ai` to this package.
|
||||
|
||||
### Added
|
||||
|
||||
- **`streamFn` option**: Pass a custom stream function to the Agent for proxy backends or custom implementations. Default uses `streamSimple` from pi-ai.
|
||||
|
||||
- **`streamProxy` utility**: New helper function for browser apps that need to proxy through a backend server. Replaces `AppTransport`.
|
||||
|
||||
- **`getApiKey` option**: Dynamic API key resolution for expiring OAuth tokens (e.g., GitHub Copilot).
|
||||
|
||||
- **`AgentLoopContext` and `AgentLoopConfig`**: Exported types for the low-level agent loop API.
|
||||
|
||||
- **`agentLoop` and `agentLoopContinue`**: Low-level functions for running the agent loop directly without the `Agent` class wrapper.
|
||||
|
||||
### Migration Guide
|
||||
|
||||
**Before (0.30.x):**
|
||||
```typescript
|
||||
import { Agent, ProviderTransport } from '@mariozechner/pi-agent-core';
|
||||
|
||||
const agent = new Agent({
|
||||
transport: new ProviderTransport({ apiKey: '...' }),
|
||||
messageTransformer: (messages) => messages.filter(...),
|
||||
preprocessor: async (messages) => compactMessages(messages)
|
||||
});
|
||||
```
|
||||
|
||||
**After:**
|
||||
```typescript
|
||||
import { Agent } from '@mariozechner/pi-agent-core';
|
||||
import { streamSimple } from '@mariozechner/pi-ai';
|
||||
|
||||
const agent = new Agent({
|
||||
streamFn: streamSimple, // or omit for default
|
||||
convertToLlm: (messages) => messages.filter(...),
|
||||
transformContext: async (messages) => compactMessages(messages),
|
||||
getApiKey: async (provider) => resolveApiKey(provider)
|
||||
});
|
||||
```
|
||||
|
||||
**For proxy usage (replaces AppTransport):**
|
||||
```typescript
|
||||
import { Agent, streamProxy } from '@mariozechner/pi-agent-core';
|
||||
|
||||
const agent = new Agent({
|
||||
streamFn: (model, context, options) => streamProxy(
|
||||
'/api/agent',
|
||||
model,
|
||||
context,
|
||||
options,
|
||||
{ 'Authorization': 'Bearer ...' }
|
||||
)
|
||||
});
|
||||
```
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
# @mariozechner/pi-agent-core
|
||||
|
||||
Stateful agent abstraction with transport layer for LLM interactions. Provides a reactive `Agent` class that manages conversation state, emits granular events, and supports pluggable transports for different deployment scenarios.
|
||||
Stateful agent with tool execution, event streaming, and extensible message types. Built on `@mariozechner/pi-ai`.
|
||||
|
||||
## Installation
|
||||
|
||||
|
|
@ -11,12 +11,10 @@ npm install @mariozechner/pi-agent-core
|
|||
## Quick Start
|
||||
|
||||
```typescript
|
||||
import { Agent, ProviderTransport } from '@mariozechner/pi-agent-core';
|
||||
import { Agent } from '@mariozechner/pi-agent-core';
|
||||
import { getModel } from '@mariozechner/pi-ai';
|
||||
|
||||
// Create agent with direct provider transport
|
||||
const agent = new Agent({
|
||||
transport: new ProviderTransport(),
|
||||
initialState: {
|
||||
systemPrompt: 'You are a helpful assistant.',
|
||||
model: getModel('anthropic', 'claude-sonnet-4-20250514'),
|
||||
|
|
@ -28,38 +26,105 @@ const agent = new Agent({
|
|||
// Subscribe to events for reactive UI updates
|
||||
agent.subscribe((event) => {
|
||||
switch (event.type) {
|
||||
case 'message_start':
|
||||
console.log(`${event.message.role} message started`);
|
||||
break;
|
||||
case 'message_update':
|
||||
// Stream text to UI
|
||||
const content = event.message.content;
|
||||
for (const block of content) {
|
||||
if (block.type === 'text') console.log(block.text);
|
||||
// Only emitted for assistant messages during streaming
|
||||
// event.message is partial - may have incomplete content
|
||||
for (const block of event.message.content) {
|
||||
if (block.type === 'text') process.stdout.write(block.text);
|
||||
}
|
||||
break;
|
||||
case 'message_end':
|
||||
console.log(`${event.message.role} message complete`);
|
||||
break;
|
||||
case 'tool_execution_start':
|
||||
console.log(`Calling ${event.toolName}...`);
|
||||
break;
|
||||
case 'tool_execution_update':
|
||||
// Stream tool output (e.g., bash stdout)
|
||||
console.log('Progress:', event.partialResult.content);
|
||||
break;
|
||||
case 'tool_execution_end':
|
||||
console.log(`Result:`, event.result.content);
|
||||
break;
|
||||
}
|
||||
});
|
||||
|
||||
// Send a prompt
|
||||
await agent.prompt('Hello, world!');
|
||||
|
||||
// Access conversation state
|
||||
console.log(agent.state.messages);
|
||||
```
|
||||
|
||||
## Core Concepts
|
||||
## AgentMessage vs LLM Message
|
||||
|
||||
### Agent State
|
||||
The agent internally works with `AgentMessage`, a flexible type that can include:
|
||||
- Standard LLM messages (`user`, `assistant`, `toolResult`)
|
||||
- Custom app-specific message types (via declaration merging)
|
||||
|
||||
The `Agent` maintains reactive state:
|
||||
LLMs only understand a subset: `user`, `assistant`, and `toolResult` messages with specific content formats. The `convertToLlm` function bridges this gap.
|
||||
|
||||
### Why This Separation?
|
||||
|
||||
1. **Rich UI state**: Store UI-specific data (attachments metadata, custom message types) alongside the conversation
|
||||
2. **Session persistence**: Save the full conversation state including app-specific messages
|
||||
3. **Context manipulation**: Transform messages before sending to LLM (compaction, injection, filtering)
|
||||
|
||||
### The Conversion Flow
|
||||
|
||||
```
|
||||
AgentMessage[] → transformContext() → AgentMessage[] → convertToLlm() → Message[] → LLM
|
||||
↑ (optional) (required)
|
||||
|
|
||||
App state with custom types,
|
||||
attachments, UI metadata
|
||||
```
|
||||
|
||||
### Constraints
|
||||
|
||||
**Messages passed to `prompt()` or queued via `queueMessage()` must convert to LLM messages with `role: "user"` or `role: "toolResult"`.**
|
||||
|
||||
When calling `continue()`, the last message in the context must also convert to `user` or `toolResult`. The LLM expects to respond to a user or tool result, not to its own assistant message.
|
||||
|
||||
```typescript
|
||||
// OK: Standard user message
|
||||
await agent.prompt('Hello');
|
||||
|
||||
// OK: Custom type that converts to user message
|
||||
await agent.prompt({ role: 'hookMessage', content: 'System notification', timestamp: Date.now() });
|
||||
// But convertToLlm must handle this:
|
||||
convertToLlm: (messages) => messages.map(m => {
|
||||
if (m.role === 'hookMessage') {
|
||||
return { role: 'user', content: m.content, timestamp: m.timestamp };
|
||||
}
|
||||
return m;
|
||||
})
|
||||
|
||||
// ERROR: Cannot prompt with assistant message
|
||||
await agent.prompt({ role: 'assistant', content: [...], ... }); // Will fail at LLM
|
||||
```
|
||||
|
||||
## Agent Options
|
||||
|
||||
```typescript
|
||||
interface AgentOptions {
|
||||
initialState?: Partial<AgentState>;
|
||||
|
||||
// Converts AgentMessage[] to LLM-compatible Message[] before each LLM call.
|
||||
// Default: filters to user/assistant/toolResult and converts image attachments.
|
||||
convertToLlm?: (messages: AgentMessage[]) => Message[] | Promise<Message[]>;
|
||||
|
||||
// Transform context before convertToLlm (for pruning, compaction, injecting context)
|
||||
transformContext?: (messages: AgentMessage[], signal?: AbortSignal) => Promise<AgentMessage[]>;
|
||||
|
||||
// Queue mode: 'all' sends all queued messages, 'one-at-a-time' sends one per turn
|
||||
queueMode?: 'all' | 'one-at-a-time';
|
||||
|
||||
// Custom stream function (for proxy backends). Default: streamSimple from pi-ai
|
||||
streamFn?: StreamFn;
|
||||
|
||||
// Dynamic API key resolution (useful for expiring OAuth tokens)
|
||||
getApiKey?: (provider: string) => Promise<string | undefined> | string | undefined;
|
||||
}
|
||||
```
|
||||
|
||||
## Agent State
|
||||
|
||||
```typescript
|
||||
interface AgentState {
|
||||
|
|
@ -67,17 +132,19 @@ interface AgentState {
|
|||
model: Model<any>;
|
||||
thinkingLevel: ThinkingLevel; // 'off' | 'minimal' | 'low' | 'medium' | 'high' | 'xhigh'
|
||||
tools: AgentTool<any>[];
|
||||
messages: AppMessage[];
|
||||
messages: AgentMessage[]; // Full conversation including custom types
|
||||
isStreaming: boolean;
|
||||
streamMessage: Message | null;
|
||||
streamMessage: AgentMessage | null; // Current partial message during streaming
|
||||
pendingToolCalls: Set<string>;
|
||||
error?: string;
|
||||
}
|
||||
```
|
||||
|
||||
### Events
|
||||
## Events
|
||||
|
||||
Events provide fine-grained lifecycle information:
|
||||
Events provide fine-grained lifecycle information for building reactive UIs.
|
||||
|
||||
### Event Types
|
||||
|
||||
| Event | Description |
|
||||
|-------|-------------|
|
||||
|
|
@ -86,33 +153,118 @@ Events provide fine-grained lifecycle information:
|
|||
| `turn_start` | New turn begins (one LLM response + tool executions) |
|
||||
| `turn_end` | Turn completes with assistant message and tool results |
|
||||
| `message_start` | Message begins (user, assistant, or toolResult) |
|
||||
| `message_update` | Assistant message streaming update |
|
||||
| `message_update` | **Assistant messages only.** Partial message during streaming |
|
||||
| `message_end` | Message completes |
|
||||
| `tool_execution_start` | Tool begins execution |
|
||||
| `tool_execution_update` | Tool streams progress (e.g., bash output) |
|
||||
| `tool_execution_update` | Tool streams progress |
|
||||
| `tool_execution_end` | Tool completes with result |
|
||||
|
||||
### Transports
|
||||
### Message Events for prompt() and queueMessage()
|
||||
|
||||
Transports abstract LLM communication:
|
||||
When you call `prompt(message)`, the agent emits `message_start` and `message_end` events for that message before the assistant responds:
|
||||
|
||||
- **`ProviderTransport`**: Direct API calls using `@mariozechner/pi-ai`
|
||||
- **`AppTransport`**: Proxy through a backend server (for browser apps)
|
||||
```
|
||||
prompt(userMessage)
|
||||
→ agent_start
|
||||
→ turn_start
|
||||
→ message_start { message: userMessage }
|
||||
→ message_end { message: userMessage }
|
||||
→ message_start { message: assistantMessage } // LLM starts responding
|
||||
→ message_update { message: partialAssistant } // streaming...
|
||||
→ message_end { message: assistantMessage }
|
||||
...
|
||||
```
|
||||
|
||||
Queued messages (via `queueMessage()`) emit the same events when injected:
|
||||
|
||||
```
|
||||
// During tool execution, a message is queued
|
||||
agent.queueMessage(interruptMessage)
|
||||
|
||||
// After tool completes, before next LLM call:
|
||||
→ message_start { message: interruptMessage }
|
||||
→ message_end { message: interruptMessage }
|
||||
→ message_start { message: assistantMessage } // LLM responds to interrupt
|
||||
...
|
||||
```
|
||||
|
||||
### Handling Partial Messages in Reactive UIs
|
||||
|
||||
`message_update` events contain partial assistant messages during streaming. The `event.message` may have:
|
||||
- Incomplete text (truncated mid-word)
|
||||
- Partial tool call arguments
|
||||
- Missing content blocks that haven't started streaming yet
|
||||
|
||||
**Pattern for reactive UIs:**
|
||||
|
||||
```typescript
|
||||
// Direct provider access (Node.js)
|
||||
const agent = new Agent({
|
||||
transport: new ProviderTransport({
|
||||
apiKey: process.env.ANTHROPIC_API_KEY
|
||||
})
|
||||
});
|
||||
agent.subscribe((event) => {
|
||||
switch (event.type) {
|
||||
case 'message_start':
|
||||
if (event.message.role === 'assistant') {
|
||||
// Create placeholder in UI
|
||||
ui.addMessage({ id: tempId, role: 'assistant', content: [] });
|
||||
}
|
||||
break;
|
||||
|
||||
// Via proxy (browser)
|
||||
case 'message_update':
|
||||
// Replace placeholder content with partial content
|
||||
// This is only emitted for assistant messages
|
||||
ui.updateMessage(tempId, event.message.content);
|
||||
break;
|
||||
|
||||
case 'message_end':
|
||||
if (event.message.role === 'assistant') {
|
||||
// Finalize with complete message
|
||||
ui.finalizeMessage(tempId, event.message);
|
||||
}
|
||||
break;
|
||||
}
|
||||
});
|
||||
```
|
||||
|
||||
**Accessing the current partial message:**
|
||||
|
||||
During streaming, `agent.state.streamMessage` contains the current partial message. This is useful for rendering outside the event handler:
|
||||
|
||||
```typescript
|
||||
// In a render loop or reactive binding
|
||||
if (agent.state.isStreaming && agent.state.streamMessage) {
|
||||
renderPartialMessage(agent.state.streamMessage);
|
||||
}
|
||||
```
|
||||
|
||||
## Custom Message Types
|
||||
|
||||
Extend `AgentMessage` for app-specific messages via declaration merging:
|
||||
|
||||
```typescript
|
||||
declare module '@mariozechner/pi-agent-core' {
|
||||
interface CustomAgentMessages {
|
||||
artifact: { role: 'artifact'; code: string; language: string; timestamp: number };
|
||||
notification: { role: 'notification'; text: string; timestamp: number };
|
||||
}
|
||||
}
|
||||
|
||||
// AgentMessage now includes your custom types
|
||||
const msg: AgentMessage = { role: 'artifact', code: '...', language: 'typescript', timestamp: Date.now() };
|
||||
```
|
||||
|
||||
Custom messages are stored in state but filtered out by the default `convertToLlm`. Provide your own converter to handle them:
|
||||
|
||||
```typescript
|
||||
const agent = new Agent({
|
||||
transport: new AppTransport({
|
||||
endpoint: '/api/agent',
|
||||
headers: { 'Authorization': 'Bearer ...' }
|
||||
})
|
||||
convertToLlm: (messages) => {
|
||||
return messages
|
||||
.filter(m => m.role !== 'notification') // Filter out UI-only messages
|
||||
.map(m => {
|
||||
if (m.role === 'artifact') {
|
||||
// Convert to user message so LLM sees the artifact
|
||||
return { role: 'user', content: `[Artifact: ${m.language}]\n${m.code}`, timestamp: m.timestamp };
|
||||
}
|
||||
return m;
|
||||
});
|
||||
}
|
||||
});
|
||||
```
|
||||
|
||||
|
|
@ -121,45 +273,76 @@ const agent = new Agent({
|
|||
Queue messages to inject at the next turn:
|
||||
|
||||
```typescript
|
||||
// Queue mode: 'all' or 'one-at-a-time'
|
||||
agent.setQueueMode('one-at-a-time');
|
||||
|
||||
// Queue a message while agent is streaming
|
||||
await agent.queueMessage({
|
||||
// Queue while agent is streaming
|
||||
agent.queueMessage({
|
||||
role: 'user',
|
||||
content: 'Additional context...',
|
||||
content: 'Stop what you are doing and focus on this instead.',
|
||||
timestamp: Date.now()
|
||||
});
|
||||
```
|
||||
|
||||
## Attachments
|
||||
When queued messages are detected after a tool call, remaining tool calls are skipped with error results ("Skipped due to queued user message"). The queued message is then injected before the next assistant response.
|
||||
|
||||
User messages can include attachments:
|
||||
## Images
|
||||
|
||||
User messages can include images:
|
||||
|
||||
```typescript
|
||||
await agent.prompt('What is in this image?', [{
|
||||
id: 'img1',
|
||||
type: 'image',
|
||||
fileName: 'photo.jpg',
|
||||
mimeType: 'image/jpeg',
|
||||
size: 102400,
|
||||
content: base64ImageData
|
||||
}]);
|
||||
await agent.prompt('What is in this image?', [
|
||||
{ type: 'image', data: base64ImageData, mimeType: 'image/jpeg' }
|
||||
]);
|
||||
```
|
||||
|
||||
## Custom Message Types
|
||||
## Proxy Usage
|
||||
|
||||
Extend `AppMessage` for app-specific messages via declaration merging:
|
||||
For browser apps that need to proxy through a backend, use `streamProxy`:
|
||||
|
||||
```typescript
|
||||
declare module '@mariozechner/pi-agent-core' {
|
||||
interface CustomMessages {
|
||||
artifact: { role: 'artifact'; code: string; language: string };
|
||||
}
|
||||
import { Agent, streamProxy } from '@mariozechner/pi-agent-core';
|
||||
|
||||
const agent = new Agent({
|
||||
streamFn: (model, context, options) => streamProxy(
|
||||
'/api/agent',
|
||||
model,
|
||||
context,
|
||||
options,
|
||||
{ 'Authorization': 'Bearer ...' }
|
||||
)
|
||||
});
|
||||
```
|
||||
|
||||
## Low-Level API
|
||||
|
||||
For more control, use `agentLoop` and `agentLoopContinue` directly:
|
||||
|
||||
```typescript
|
||||
import { agentLoop, agentLoopContinue, AgentLoopContext, AgentLoopConfig } from '@mariozechner/pi-agent-core';
|
||||
import { getModel, streamSimple } from '@mariozechner/pi-ai';
|
||||
|
||||
const context: AgentLoopContext = {
|
||||
systemPrompt: 'You are helpful.',
|
||||
messages: [],
|
||||
tools: [myTool]
|
||||
};
|
||||
|
||||
const config: AgentLoopConfig = {
|
||||
model: getModel('openai', 'gpt-4o-mini'),
|
||||
convertToLlm: (msgs) => msgs.filter(m => ['user', 'assistant', 'toolResult'].includes(m.role))
|
||||
};
|
||||
|
||||
const userMessage = { role: 'user', content: 'Hello', timestamp: Date.now() };
|
||||
|
||||
for await (const event of agentLoop(userMessage, context, config, undefined, streamSimple)) {
|
||||
console.log(event.type);
|
||||
}
|
||||
|
||||
// Now AppMessage includes your custom type
|
||||
const msg: AppMessage = { role: 'artifact', code: '...', language: 'typescript' };
|
||||
// Continue from existing context (e.g., after overflow recovery)
|
||||
// Last message in context must convert to 'user' or 'toolResult'
|
||||
for await (const event of agentLoopContinue(context, config, undefined, streamSimple)) {
|
||||
console.log(event.type);
|
||||
}
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
|
@ -168,13 +351,14 @@ const msg: AppMessage = { role: 'artifact', code: '...', language: 'typescript'
|
|||
|
||||
| Method | Description |
|
||||
|--------|-------------|
|
||||
| `prompt(text, attachments?)` | Send a user prompt |
|
||||
| `continue()` | Continue from current context (for retry after overflow) |
|
||||
| `prompt(text, images?)` | Send a user prompt with optional images |
|
||||
| `prompt(message)` | Send an AgentMessage directly (must convert to user/toolResult) |
|
||||
| `continue()` | Continue from current context (last message must convert to user/toolResult) |
|
||||
| `abort()` | Abort current operation |
|
||||
| `waitForIdle()` | Returns promise that resolves when agent is idle |
|
||||
| `waitForIdle()` | Promise that resolves when agent is idle |
|
||||
| `reset()` | Clear all messages and state |
|
||||
| `subscribe(fn)` | Subscribe to events, returns unsubscribe function |
|
||||
| `queueMessage(msg)` | Queue message for next turn |
|
||||
| `queueMessage(msg)` | Queue message for next turn (must convert to user/toolResult) |
|
||||
| `clearMessageQueue()` | Clear queued messages |
|
||||
|
||||
### State Mutators
|
||||
|
|
@ -184,7 +368,7 @@ const msg: AppMessage = { role: 'artifact', code: '...', language: 'typescript'
|
|||
| `setSystemPrompt(v)` | Update system prompt |
|
||||
| `setModel(m)` | Switch model |
|
||||
| `setThinkingLevel(l)` | Set reasoning level |
|
||||
| `setQueueMode(m)` | Set queue mode ('all' or 'one-at-a-time') |
|
||||
| `setQueueMode(m)` | Set queue mode |
|
||||
| `setTools(t)` | Update available tools |
|
||||
| `replaceMessages(ms)` | Replace all messages |
|
||||
| `appendMessage(m)` | Append a message |
|
||||
|
|
|
|||
|
|
@ -1,33 +1,52 @@
|
|||
import { streamSimple } from "../stream.js";
|
||||
import type { AssistantMessage, Context, Message, ToolResultMessage, UserMessage } from "../types.js";
|
||||
import { EventStream } from "../utils/event-stream.js";
|
||||
import { validateToolArguments } from "../utils/validation.js";
|
||||
import type { AgentContext, AgentEvent, AgentLoopConfig, AgentTool, AgentToolResult, QueuedMessage } from "./types.js";
|
||||
/**
|
||||
* Agent loop that works with AgentMessage throughout.
|
||||
* Transforms to Message[] only at the LLM call boundary.
|
||||
*/
|
||||
|
||||
import {
|
||||
type AssistantMessage,
|
||||
type Context,
|
||||
EventStream,
|
||||
streamSimple,
|
||||
type ToolResultMessage,
|
||||
validateToolArguments,
|
||||
} from "@mariozechner/pi-ai";
|
||||
import type {
|
||||
AgentContext,
|
||||
AgentEvent,
|
||||
AgentLoopConfig,
|
||||
AgentMessage,
|
||||
AgentTool,
|
||||
AgentToolResult,
|
||||
StreamFn,
|
||||
} from "./types.js";
|
||||
|
||||
/**
|
||||
* Start an agent loop with a new user message.
|
||||
* Start an agent loop with a new prompt message.
|
||||
* The prompt is added to the context and events are emitted for it.
|
||||
*/
|
||||
export function agentLoop(
|
||||
prompt: UserMessage,
|
||||
prompts: AgentMessage[],
|
||||
context: AgentContext,
|
||||
config: AgentLoopConfig,
|
||||
signal?: AbortSignal,
|
||||
streamFn?: typeof streamSimple,
|
||||
): EventStream<AgentEvent, AgentContext["messages"]> {
|
||||
streamFn?: StreamFn,
|
||||
): EventStream<AgentEvent, AgentMessage[]> {
|
||||
const stream = createAgentStream();
|
||||
|
||||
(async () => {
|
||||
const newMessages: AgentContext["messages"] = [prompt];
|
||||
const newMessages: AgentMessage[] = [...prompts];
|
||||
const currentContext: AgentContext = {
|
||||
...context,
|
||||
messages: [...context.messages, prompt],
|
||||
messages: [...context.messages, ...prompts],
|
||||
};
|
||||
|
||||
stream.push({ type: "agent_start" });
|
||||
stream.push({ type: "turn_start" });
|
||||
stream.push({ type: "message_start", message: prompt });
|
||||
stream.push({ type: "message_end", message: prompt });
|
||||
for (const prompt of prompts) {
|
||||
stream.push({ type: "message_start", message: prompt });
|
||||
stream.push({ type: "message_end", message: prompt });
|
||||
}
|
||||
|
||||
await runLoop(currentContext, newMessages, config, signal, stream, streamFn);
|
||||
})();
|
||||
|
|
@ -37,33 +56,34 @@ export function agentLoop(
|
|||
|
||||
/**
|
||||
* Continue an agent loop from the current context without adding a new message.
|
||||
* Used for retry after overflow - context already has user message or tool results.
|
||||
* Throws if the last message is not a user message or tool result.
|
||||
* Used for retries - context already has user message or tool results.
|
||||
*
|
||||
* **Important:** The last message in context must convert to a `user` or `toolResult` message
|
||||
* via `convertToLlm`. If it doesn't, the LLM provider will reject the request.
|
||||
* This cannot be validated here since `convertToLlm` is only called once per turn.
|
||||
*/
|
||||
export function agentLoopContinue(
|
||||
context: AgentContext,
|
||||
config: AgentLoopConfig,
|
||||
signal?: AbortSignal,
|
||||
streamFn?: typeof streamSimple,
|
||||
): EventStream<AgentEvent, AgentContext["messages"]> {
|
||||
// Validate that we can continue from this context
|
||||
const lastMessage = context.messages[context.messages.length - 1];
|
||||
if (!lastMessage) {
|
||||
streamFn?: StreamFn,
|
||||
): EventStream<AgentEvent, AgentMessage[]> {
|
||||
if (context.messages.length === 0) {
|
||||
throw new Error("Cannot continue: no messages in context");
|
||||
}
|
||||
if (lastMessage.role !== "user" && lastMessage.role !== "toolResult") {
|
||||
throw new Error(`Cannot continue from message role: ${lastMessage.role}. Expected 'user' or 'toolResult'.`);
|
||||
|
||||
if (context.messages[context.messages.length - 1].role === "assistant") {
|
||||
throw new Error("Cannot continue from message role: assistant");
|
||||
}
|
||||
|
||||
const stream = createAgentStream();
|
||||
|
||||
(async () => {
|
||||
const newMessages: AgentContext["messages"] = [];
|
||||
const newMessages: AgentMessage[] = [];
|
||||
const currentContext: AgentContext = { ...context };
|
||||
|
||||
stream.push({ type: "agent_start" });
|
||||
stream.push({ type: "turn_start" });
|
||||
// No user message events - we're continuing from existing context
|
||||
|
||||
await runLoop(currentContext, newMessages, config, signal, stream, streamFn);
|
||||
})();
|
||||
|
|
@ -71,28 +91,28 @@ export function agentLoopContinue(
|
|||
return stream;
|
||||
}
|
||||
|
||||
function createAgentStream(): EventStream<AgentEvent, AgentContext["messages"]> {
|
||||
return new EventStream<AgentEvent, AgentContext["messages"]>(
|
||||
function createAgentStream(): EventStream<AgentEvent, AgentMessage[]> {
|
||||
return new EventStream<AgentEvent, AgentMessage[]>(
|
||||
(event: AgentEvent) => event.type === "agent_end",
|
||||
(event: AgentEvent) => (event.type === "agent_end" ? event.messages : []),
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Shared loop logic for both agentLoop and agentLoopContinue.
|
||||
* Main loop logic shared by agentLoop and agentLoopContinue.
|
||||
*/
|
||||
async function runLoop(
|
||||
currentContext: AgentContext,
|
||||
newMessages: AgentContext["messages"],
|
||||
newMessages: AgentMessage[],
|
||||
config: AgentLoopConfig,
|
||||
signal: AbortSignal | undefined,
|
||||
stream: EventStream<AgentEvent, AgentContext["messages"]>,
|
||||
streamFn?: typeof streamSimple,
|
||||
stream: EventStream<AgentEvent, AgentMessage[]>,
|
||||
streamFn?: StreamFn,
|
||||
): Promise<void> {
|
||||
let hasMoreToolCalls = true;
|
||||
let firstTurn = true;
|
||||
let queuedMessages: QueuedMessage<any>[] = (await config.getQueuedMessages?.()) || [];
|
||||
let queuedAfterTools: QueuedMessage<any>[] | null = null;
|
||||
let queuedMessages: AgentMessage[] = (await config.getQueuedMessages?.()) || [];
|
||||
let queuedAfterTools: AgentMessage[] | null = null;
|
||||
|
||||
while (hasMoreToolCalls || queuedMessages.length > 0) {
|
||||
if (!firstTurn) {
|
||||
|
|
@ -101,15 +121,13 @@ async function runLoop(
|
|||
firstTurn = false;
|
||||
}
|
||||
|
||||
// Process queued messages first (inject before next assistant response)
|
||||
// Process queued messages (inject before next assistant response)
|
||||
if (queuedMessages.length > 0) {
|
||||
for (const { original, llm } of queuedMessages) {
|
||||
stream.push({ type: "message_start", message: original });
|
||||
stream.push({ type: "message_end", message: original });
|
||||
if (llm) {
|
||||
currentContext.messages.push(llm);
|
||||
newMessages.push(llm);
|
||||
}
|
||||
for (const message of queuedMessages) {
|
||||
stream.push({ type: "message_start", message });
|
||||
stream.push({ type: "message_end", message });
|
||||
currentContext.messages.push(message);
|
||||
newMessages.push(message);
|
||||
}
|
||||
queuedMessages = [];
|
||||
}
|
||||
|
|
@ -119,7 +137,6 @@ async function runLoop(
|
|||
newMessages.push(message);
|
||||
|
||||
if (message.stopReason === "error" || message.stopReason === "aborted") {
|
||||
// Stop the loop on error or abort
|
||||
stream.push({ type: "turn_end", message, toolResults: [] });
|
||||
stream.push({ type: "agent_end", messages: newMessages });
|
||||
stream.end(newMessages);
|
||||
|
|
@ -132,7 +149,6 @@ async function runLoop(
|
|||
|
||||
const toolResults: ToolResultMessage[] = [];
|
||||
if (hasMoreToolCalls) {
|
||||
// Execute tool calls
|
||||
const toolExecution = await executeToolCalls(
|
||||
currentContext.tools,
|
||||
message,
|
||||
|
|
@ -142,10 +158,14 @@ async function runLoop(
|
|||
);
|
||||
toolResults.push(...toolExecution.toolResults);
|
||||
queuedAfterTools = toolExecution.queuedMessages ?? null;
|
||||
currentContext.messages.push(...toolResults);
|
||||
newMessages.push(...toolResults);
|
||||
|
||||
for (const result of toolResults) {
|
||||
currentContext.messages.push(result);
|
||||
newMessages.push(result);
|
||||
}
|
||||
}
|
||||
stream.push({ type: "turn_end", message, toolResults: toolResults });
|
||||
|
||||
stream.push({ type: "turn_end", message, toolResults });
|
||||
|
||||
// Get queued messages after turn completes
|
||||
if (queuedAfterTools && queuedAfterTools.length > 0) {
|
||||
|
|
@ -160,41 +180,44 @@ async function runLoop(
|
|||
stream.end(newMessages);
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
/**
|
||||
* Stream an assistant response from the LLM.
|
||||
* This is where AgentMessage[] gets transformed to Message[] for the LLM.
|
||||
*/
|
||||
async function streamAssistantResponse(
|
||||
context: AgentContext,
|
||||
config: AgentLoopConfig,
|
||||
signal: AbortSignal | undefined,
|
||||
stream: EventStream<AgentEvent, AgentContext["messages"]>,
|
||||
streamFn?: typeof streamSimple,
|
||||
stream: EventStream<AgentEvent, AgentMessage[]>,
|
||||
streamFn?: StreamFn,
|
||||
): Promise<AssistantMessage> {
|
||||
// Convert AgentContext to Context for streamSimple
|
||||
// Use a copy of messages to avoid mutating the original context
|
||||
const processedMessages = config.preprocessor
|
||||
? await config.preprocessor(context.messages, signal)
|
||||
: [...context.messages];
|
||||
const processedContext: Context = {
|
||||
// Apply context transform if configured (AgentMessage[] → AgentMessage[])
|
||||
let messages = context.messages;
|
||||
if (config.transformContext) {
|
||||
messages = await config.transformContext(messages, signal);
|
||||
}
|
||||
|
||||
// Convert to LLM-compatible messages (AgentMessage[] → Message[])
|
||||
const llmMessages = await config.convertToLlm(messages);
|
||||
|
||||
// Build LLM context
|
||||
const llmContext: Context = {
|
||||
systemPrompt: context.systemPrompt,
|
||||
messages: [...processedMessages].map((m) => {
|
||||
if (m.role === "toolResult") {
|
||||
// biome-ignore lint/correctness/noUnusedVariables: fine here
|
||||
const { details, ...rest } = m;
|
||||
return rest;
|
||||
} else {
|
||||
return m;
|
||||
}
|
||||
}),
|
||||
tools: context.tools, // AgentTool extends Tool, so this works
|
||||
messages: llmMessages,
|
||||
tools: context.tools,
|
||||
};
|
||||
|
||||
// Use custom stream function if provided, otherwise use default streamSimple
|
||||
const streamFunction = streamFn || streamSimple;
|
||||
|
||||
// Resolve API key for every assistant response (important for expiring tokens)
|
||||
// Resolve API key (important for expiring tokens)
|
||||
const resolvedApiKey =
|
||||
(config.getApiKey ? await config.getApiKey(config.model.provider) : undefined) || config.apiKey;
|
||||
|
||||
const response = await streamFunction(config.model, processedContext, { ...config, apiKey: resolvedApiKey, signal });
|
||||
const response = await streamFunction(config.model, llmContext, {
|
||||
...config,
|
||||
apiKey: resolvedApiKey,
|
||||
signal,
|
||||
});
|
||||
|
||||
let partialMessage: AssistantMessage | null = null;
|
||||
let addedPartial = false;
|
||||
|
|
@ -220,7 +243,11 @@ async function streamAssistantResponse(
|
|||
if (partialMessage) {
|
||||
partialMessage = event.partial;
|
||||
context.messages[context.messages.length - 1] = partialMessage;
|
||||
stream.push({ type: "message_update", assistantMessageEvent: event, message: { ...partialMessage } });
|
||||
stream.push({
|
||||
type: "message_update",
|
||||
assistantMessageEvent: event,
|
||||
message: { ...partialMessage },
|
||||
});
|
||||
}
|
||||
break;
|
||||
|
||||
|
|
@ -244,16 +271,19 @@ async function streamAssistantResponse(
|
|||
return await response.result();
|
||||
}
|
||||
|
||||
async function executeToolCalls<T>(
|
||||
tools: AgentTool<any, T>[] | undefined,
|
||||
/**
|
||||
* Execute tool calls from an assistant message.
|
||||
*/
|
||||
async function executeToolCalls(
|
||||
tools: AgentTool<any>[] | undefined,
|
||||
assistantMessage: AssistantMessage,
|
||||
signal: AbortSignal | undefined,
|
||||
stream: EventStream<AgentEvent, Message[]>,
|
||||
stream: EventStream<AgentEvent, AgentMessage[]>,
|
||||
getQueuedMessages?: AgentLoopConfig["getQueuedMessages"],
|
||||
): Promise<{ toolResults: ToolResultMessage<T>[]; queuedMessages?: QueuedMessage<any>[] }> {
|
||||
): Promise<{ toolResults: ToolResultMessage[]; queuedMessages?: AgentMessage[] }> {
|
||||
const toolCalls = assistantMessage.content.filter((c) => c.type === "toolCall");
|
||||
const results: ToolResultMessage<any>[] = [];
|
||||
let queuedMessages: QueuedMessage<any>[] | undefined;
|
||||
const results: ToolResultMessage[] = [];
|
||||
let queuedMessages: AgentMessage[] | undefined;
|
||||
|
||||
for (let index = 0; index < toolCalls.length; index++) {
|
||||
const toolCall = toolCalls[index];
|
||||
|
|
@ -266,16 +296,14 @@ async function executeToolCalls<T>(
|
|||
args: toolCall.arguments,
|
||||
});
|
||||
|
||||
let result: AgentToolResult<T>;
|
||||
let result: AgentToolResult<any>;
|
||||
let isError = false;
|
||||
|
||||
try {
|
||||
if (!tool) throw new Error(`Tool ${toolCall.name} not found`);
|
||||
|
||||
// Validate arguments using shared validation function
|
||||
const validatedArgs = validateToolArguments(tool, toolCall);
|
||||
|
||||
// Execute with validated, typed arguments, passing update callback
|
||||
result = await tool.execute(toolCall.id, validatedArgs, signal, (partialResult) => {
|
||||
stream.push({
|
||||
type: "tool_execution_update",
|
||||
|
|
@ -288,7 +316,7 @@ async function executeToolCalls<T>(
|
|||
} catch (e) {
|
||||
result = {
|
||||
content: [{ type: "text", text: e instanceof Error ? e.message : String(e) }],
|
||||
details: {} as T,
|
||||
details: {},
|
||||
};
|
||||
isError = true;
|
||||
}
|
||||
|
|
@ -301,7 +329,7 @@ async function executeToolCalls<T>(
|
|||
isError,
|
||||
});
|
||||
|
||||
const toolResultMessage: ToolResultMessage<T> = {
|
||||
const toolResultMessage: ToolResultMessage = {
|
||||
role: "toolResult",
|
||||
toolCallId: toolCall.id,
|
||||
toolName: toolCall.name,
|
||||
|
|
@ -315,6 +343,7 @@ async function executeToolCalls<T>(
|
|||
stream.push({ type: "message_start", message: toolResultMessage });
|
||||
stream.push({ type: "message_end", message: toolResultMessage });
|
||||
|
||||
// Check for queued messages - skip remaining tools if user interrupted
|
||||
if (getQueuedMessages) {
|
||||
const queued = await getQueuedMessages();
|
||||
if (queued.length > 0) {
|
||||
|
|
@ -331,13 +360,13 @@ async function executeToolCalls<T>(
|
|||
return { toolResults: results, queuedMessages };
|
||||
}
|
||||
|
||||
function skipToolCall<T>(
|
||||
function skipToolCall(
|
||||
toolCall: Extract<AssistantMessage["content"][number], { type: "toolCall" }>,
|
||||
stream: EventStream<AgentEvent, Message[]>,
|
||||
): ToolResultMessage<T> {
|
||||
const result: AgentToolResult<T> = {
|
||||
stream: EventStream<AgentEvent, AgentMessage[]>,
|
||||
): ToolResultMessage {
|
||||
const result: AgentToolResult<any> = {
|
||||
content: [{ type: "text", text: "Skipped due to queued user message." }],
|
||||
details: {} as T,
|
||||
details: {},
|
||||
};
|
||||
|
||||
stream.push({
|
||||
|
|
@ -354,12 +383,12 @@ function skipToolCall<T>(
|
|||
isError: true,
|
||||
});
|
||||
|
||||
const toolResultMessage: ToolResultMessage<T> = {
|
||||
const toolResultMessage: ToolResultMessage = {
|
||||
role: "toolResult",
|
||||
toolCallId: toolCall.id,
|
||||
toolName: toolCall.name,
|
||||
content: result.content,
|
||||
details: result.details,
|
||||
details: {},
|
||||
isError: true,
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
|
@ -1,62 +1,66 @@
|
|||
import type { ImageContent, Message, QueuedMessage, ReasoningEffort, TextContent } from "@mariozechner/pi-ai";
|
||||
import { getModel } from "@mariozechner/pi-ai";
|
||||
import type { AgentTransport } from "./transports/types.js";
|
||||
import type { AgentEvent, AgentState, AppMessage, Attachment, ThinkingLevel } from "./types.js";
|
||||
/**
|
||||
* Agent class that uses the agent-loop directly.
|
||||
* No transport abstraction - calls streamSimple via the loop.
|
||||
*/
|
||||
|
||||
import {
|
||||
getModel,
|
||||
type ImageContent,
|
||||
type Message,
|
||||
type Model,
|
||||
type ReasoningEffort,
|
||||
streamSimple,
|
||||
type TextContent,
|
||||
} from "@mariozechner/pi-ai";
|
||||
import { agentLoop, agentLoopContinue } from "./agent-loop.js";
|
||||
import type {
|
||||
AgentContext,
|
||||
AgentEvent,
|
||||
AgentLoopConfig,
|
||||
AgentMessage,
|
||||
AgentState,
|
||||
AgentTool,
|
||||
StreamFn,
|
||||
ThinkingLevel,
|
||||
} from "./types.js";
|
||||
|
||||
/**
|
||||
* Default message transformer: Keep only LLM-compatible messages, strip app-specific fields.
|
||||
* Converts attachments to proper content blocks (images → ImageContent, documents → TextContent).
|
||||
* Default convertToLlm: Keep only LLM-compatible messages, convert attachments.
|
||||
*/
|
||||
function defaultMessageTransformer(messages: AppMessage[]): Message[] {
|
||||
return messages
|
||||
.filter((m) => {
|
||||
// Only keep standard LLM message roles
|
||||
return m.role === "user" || m.role === "assistant" || m.role === "toolResult";
|
||||
})
|
||||
.map((m) => {
|
||||
if (m.role === "user") {
|
||||
const { attachments, ...rest } = m as any;
|
||||
|
||||
// If no attachments, return as-is
|
||||
if (!attachments || attachments.length === 0) {
|
||||
return rest as Message;
|
||||
}
|
||||
|
||||
// Convert attachments to content blocks
|
||||
const content = Array.isArray(rest.content) ? [...rest.content] : [{ type: "text", text: rest.content }];
|
||||
|
||||
for (const attachment of attachments as Attachment[]) {
|
||||
// Add image blocks for image attachments
|
||||
if (attachment.type === "image") {
|
||||
content.push({
|
||||
type: "image",
|
||||
data: attachment.content,
|
||||
mimeType: attachment.mimeType,
|
||||
} as ImageContent);
|
||||
}
|
||||
// Add text blocks for documents with extracted text
|
||||
else if (attachment.type === "document" && attachment.extractedText) {
|
||||
content.push({
|
||||
type: "text",
|
||||
text: `\n\n[Document: ${attachment.fileName}]\n${attachment.extractedText}`,
|
||||
isDocument: true,
|
||||
} as TextContent);
|
||||
}
|
||||
}
|
||||
|
||||
return { ...rest, content } as Message;
|
||||
}
|
||||
return m as Message;
|
||||
});
|
||||
function defaultConvertToLlm(messages: AgentMessage[]): Message[] {
|
||||
return messages.filter((m) => m.role === "user" || m.role === "assistant" || m.role === "toolResult");
|
||||
}
|
||||
|
||||
export interface AgentOptions {
|
||||
initialState?: Partial<AgentState>;
|
||||
transport: AgentTransport;
|
||||
// Transform app messages to LLM-compatible messages before sending to transport
|
||||
messageTransformer?: (messages: AppMessage[]) => Message[] | Promise<Message[]>;
|
||||
// Queue mode: "all" = send all queued messages at once, "one-at-a-time" = send one queued message per turn
|
||||
|
||||
/**
|
||||
* Converts AgentMessage[] to LLM-compatible Message[] before each LLM call.
|
||||
* Default filters to user/assistant/toolResult and converts attachments.
|
||||
*/
|
||||
convertToLlm?: (messages: AgentMessage[]) => Message[] | Promise<Message[]>;
|
||||
|
||||
/**
|
||||
* Optional transform applied to context before convertToLlm.
|
||||
* Use for context pruning, injecting external context, etc.
|
||||
*/
|
||||
transformContext?: (messages: AgentMessage[], signal?: AbortSignal) => Promise<AgentMessage[]>;
|
||||
|
||||
/**
|
||||
* Queue mode: "all" = send all queued messages at once, "one-at-a-time" = one per turn
|
||||
*/
|
||||
queueMode?: "all" | "one-at-a-time";
|
||||
|
||||
/**
|
||||
* Custom stream function (for proxy backends, etc.). Default uses streamSimple.
|
||||
*/
|
||||
streamFn?: StreamFn;
|
||||
|
||||
/**
|
||||
* Resolves an API key dynamically for each LLM call.
|
||||
* Useful for expiring tokens (e.g., GitHub Copilot OAuth).
|
||||
*/
|
||||
getApiKey?: (provider: string) => Promise<string | undefined> | string | undefined;
|
||||
}
|
||||
|
||||
export class Agent {
|
||||
|
|
@ -71,20 +75,25 @@ export class Agent {
|
|||
pendingToolCalls: new Set<string>(),
|
||||
error: undefined,
|
||||
};
|
||||
|
||||
private listeners = new Set<(e: AgentEvent) => void>();
|
||||
private abortController?: AbortController;
|
||||
private transport: AgentTransport;
|
||||
private messageTransformer: (messages: AppMessage[]) => Message[] | Promise<Message[]>;
|
||||
private messageQueue: Array<QueuedMessage<AppMessage>> = [];
|
||||
private convertToLlm: (messages: AgentMessage[]) => Message[] | Promise<Message[]>;
|
||||
private transformContext?: (messages: AgentMessage[], signal?: AbortSignal) => Promise<AgentMessage[]>;
|
||||
private messageQueue: AgentMessage[] = [];
|
||||
private queueMode: "all" | "one-at-a-time";
|
||||
public streamFn: StreamFn;
|
||||
public getApiKey?: (provider: string) => Promise<string | undefined> | string | undefined;
|
||||
private runningPrompt?: Promise<void>;
|
||||
private resolveRunningPrompt?: () => void;
|
||||
|
||||
constructor(opts: AgentOptions) {
|
||||
constructor(opts: AgentOptions = {}) {
|
||||
this._state = { ...this._state, ...opts.initialState };
|
||||
this.transport = opts.transport;
|
||||
this.messageTransformer = opts.messageTransformer || defaultMessageTransformer;
|
||||
this.convertToLlm = opts.convertToLlm || defaultConvertToLlm;
|
||||
this.transformContext = opts.transformContext;
|
||||
this.queueMode = opts.queueMode || "one-at-a-time";
|
||||
this.streamFn = opts.streamFn || streamSimple;
|
||||
this.getApiKey = opts.getApiKey;
|
||||
}
|
||||
|
||||
get state(): AgentState {
|
||||
|
|
@ -96,12 +105,12 @@ export class Agent {
|
|||
return () => this.listeners.delete(fn);
|
||||
}
|
||||
|
||||
// State mutators - update internal state without emitting events
|
||||
// State mutators
|
||||
setSystemPrompt(v: string) {
|
||||
this._state.systemPrompt = v;
|
||||
}
|
||||
|
||||
setModel(m: typeof this._state.model) {
|
||||
setModel(m: Model<any>) {
|
||||
this._state.model = m;
|
||||
}
|
||||
|
||||
|
|
@ -117,25 +126,20 @@ export class Agent {
|
|||
return this.queueMode;
|
||||
}
|
||||
|
||||
setTools(t: typeof this._state.tools) {
|
||||
setTools(t: AgentTool<any>[]) {
|
||||
this._state.tools = t;
|
||||
}
|
||||
|
||||
replaceMessages(ms: AppMessage[]) {
|
||||
replaceMessages(ms: AgentMessage[]) {
|
||||
this._state.messages = ms.slice();
|
||||
}
|
||||
|
||||
appendMessage(m: AppMessage) {
|
||||
appendMessage(m: AgentMessage) {
|
||||
this._state.messages = [...this._state.messages, m];
|
||||
}
|
||||
|
||||
async queueMessage(m: AppMessage) {
|
||||
// Transform message and queue it for injection at next turn
|
||||
const transformed = await this.messageTransformer([m]);
|
||||
this.messageQueue.push({
|
||||
original: m,
|
||||
llm: transformed[0], // undefined if filtered out
|
||||
});
|
||||
queueMessage(m: AgentMessage) {
|
||||
this.messageQueue.push(m);
|
||||
}
|
||||
|
||||
clearMessageQueue() {
|
||||
|
|
@ -150,17 +154,10 @@ export class Agent {
|
|||
this.abortController?.abort();
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a promise that resolves when the current prompt completes.
|
||||
* Returns immediately resolved promise if no prompt is running.
|
||||
*/
|
||||
waitForIdle(): Promise<void> {
|
||||
return this.runningPrompt ?? Promise.resolve();
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear all messages and state. Call abort() first if a prompt is in flight.
|
||||
*/
|
||||
reset() {
|
||||
this._state.messages = [];
|
||||
this._state.isStreaming = false;
|
||||
|
|
@ -170,86 +167,57 @@ export class Agent {
|
|||
this.messageQueue = [];
|
||||
}
|
||||
|
||||
async prompt(input: string, attachments?: Attachment[]) {
|
||||
/** Send a prompt with an AgentMessage */
|
||||
async prompt(message: AgentMessage | AgentMessage[]): Promise<void>;
|
||||
async prompt(input: string, images?: ImageContent[]): Promise<void>;
|
||||
async prompt(input: string | AgentMessage | AgentMessage[], images?: ImageContent[]) {
|
||||
const model = this._state.model;
|
||||
if (!model) {
|
||||
throw new Error("No model configured");
|
||||
}
|
||||
if (!model) throw new Error("No model configured");
|
||||
|
||||
// Build user message with attachments
|
||||
const content: Array<TextContent | ImageContent> = [{ type: "text", text: input }];
|
||||
if (attachments?.length) {
|
||||
for (const a of attachments) {
|
||||
if (a.type === "image") {
|
||||
content.push({ type: "image", data: a.content, mimeType: a.mimeType });
|
||||
} else if (a.type === "document" && a.extractedText) {
|
||||
content.push({
|
||||
type: "text",
|
||||
text: `\n\n[Document: ${a.fileName}]\n${a.extractedText}`,
|
||||
isDocument: true,
|
||||
} as TextContent);
|
||||
}
|
||||
let msgs: AgentMessage[];
|
||||
|
||||
if (Array.isArray(input)) {
|
||||
msgs = input;
|
||||
} else if (typeof input === "string") {
|
||||
const content: Array<TextContent | ImageContent> = [{ type: "text", text: input }];
|
||||
if (images && images.length > 0) {
|
||||
content.push(...images);
|
||||
}
|
||||
msgs = [
|
||||
{
|
||||
role: "user",
|
||||
content,
|
||||
timestamp: Date.now(),
|
||||
},
|
||||
];
|
||||
} else {
|
||||
msgs = [input];
|
||||
}
|
||||
|
||||
const userMessage: AppMessage = {
|
||||
role: "user",
|
||||
content,
|
||||
attachments: attachments?.length ? attachments : undefined,
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
await this._runAgentLoop(userMessage);
|
||||
await this._runLoop(msgs);
|
||||
}
|
||||
|
||||
/**
|
||||
* Continue from the current context without adding a new user message.
|
||||
* Used for retry after overflow recovery when context already has user message or tool results.
|
||||
*/
|
||||
/** Continue from current context (for retry after overflow) */
|
||||
async continue() {
|
||||
const messages = this._state.messages;
|
||||
if (messages.length === 0) {
|
||||
throw new Error("No messages to continue from");
|
||||
}
|
||||
|
||||
const lastMessage = messages[messages.length - 1];
|
||||
if (lastMessage.role !== "user" && lastMessage.role !== "toolResult") {
|
||||
throw new Error(`Cannot continue from message role: ${lastMessage.role}`);
|
||||
if (messages[messages.length - 1].role === "assistant") {
|
||||
throw new Error("Cannot continue from message role: assistant");
|
||||
}
|
||||
|
||||
await this._runAgentLoopContinue();
|
||||
await this._runLoop(undefined);
|
||||
}
|
||||
|
||||
/**
|
||||
* Internal: Run the agent loop with a new user message.
|
||||
* Run the agent loop.
|
||||
* If messages are provided, starts a new conversation turn with those messages.
|
||||
* Otherwise, continues from existing context.
|
||||
*/
|
||||
private async _runAgentLoop(userMessage: AppMessage) {
|
||||
const { llmMessages, cfg } = await this._prepareRun();
|
||||
|
||||
const events = this.transport.run(llmMessages, userMessage as Message, cfg, this.abortController!.signal);
|
||||
|
||||
await this._processEvents(events);
|
||||
}
|
||||
|
||||
/**
|
||||
* Internal: Continue the agent loop from current context.
|
||||
*/
|
||||
private async _runAgentLoopContinue() {
|
||||
const { llmMessages, cfg } = await this._prepareRun();
|
||||
|
||||
const events = this.transport.continue(llmMessages, cfg, this.abortController!.signal);
|
||||
|
||||
await this._processEvents(events);
|
||||
}
|
||||
|
||||
/**
|
||||
* Prepare for running the agent loop.
|
||||
*/
|
||||
private async _prepareRun() {
|
||||
private async _runLoop(messages?: AgentMessage[]) {
|
||||
const model = this._state.model;
|
||||
if (!model) {
|
||||
throw new Error("No model configured");
|
||||
}
|
||||
if (!model) throw new Error("No model configured");
|
||||
|
||||
this.runningPrompt = new Promise<void>((resolve) => {
|
||||
this.resolveRunningPrompt = resolve;
|
||||
|
|
@ -265,87 +233,90 @@ export class Agent {
|
|||
? undefined
|
||||
: this._state.thinkingLevel === "minimal"
|
||||
? "low"
|
||||
: this._state.thinkingLevel;
|
||||
: (this._state.thinkingLevel as ReasoningEffort);
|
||||
|
||||
const cfg = {
|
||||
const context: AgentContext = {
|
||||
systemPrompt: this._state.systemPrompt,
|
||||
messages: this._state.messages.slice(),
|
||||
tools: this._state.tools,
|
||||
};
|
||||
|
||||
const config: AgentLoopConfig = {
|
||||
model,
|
||||
reasoning,
|
||||
getQueuedMessages: async <T>() => {
|
||||
convertToLlm: this.convertToLlm,
|
||||
transformContext: this.transformContext,
|
||||
getApiKey: this.getApiKey,
|
||||
getQueuedMessages: async () => {
|
||||
if (this.queueMode === "one-at-a-time") {
|
||||
if (this.messageQueue.length > 0) {
|
||||
const first = this.messageQueue[0];
|
||||
this.messageQueue = this.messageQueue.slice(1);
|
||||
return [first] as QueuedMessage<T>[];
|
||||
return [first];
|
||||
}
|
||||
return [];
|
||||
} else {
|
||||
const queued = this.messageQueue.slice();
|
||||
this.messageQueue = [];
|
||||
return queued as QueuedMessage<T>[];
|
||||
return queued;
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
const llmMessages = await this.messageTransformer(this._state.messages);
|
||||
|
||||
return { llmMessages, cfg, model };
|
||||
}
|
||||
|
||||
/**
|
||||
* Process events from the transport.
|
||||
*/
|
||||
private async _processEvents(events: AsyncIterable<AgentEvent>) {
|
||||
const model = this._state.model!;
|
||||
const generatedMessages: AppMessage[] = [];
|
||||
let partial: AppMessage | null = null;
|
||||
let partial: AgentMessage | null = null;
|
||||
|
||||
try {
|
||||
for await (const ev of events) {
|
||||
switch (ev.type) {
|
||||
case "message_start": {
|
||||
partial = ev.message as AppMessage;
|
||||
this._state.streamMessage = ev.message as Message;
|
||||
const stream = messages
|
||||
? agentLoop(messages, context, config, this.abortController.signal, this.streamFn)
|
||||
: agentLoopContinue(context, config, this.abortController.signal, this.streamFn);
|
||||
|
||||
for await (const event of stream) {
|
||||
// Update internal state based on events
|
||||
switch (event.type) {
|
||||
case "message_start":
|
||||
partial = event.message;
|
||||
this._state.streamMessage = event.message;
|
||||
break;
|
||||
}
|
||||
case "message_update": {
|
||||
partial = ev.message as AppMessage;
|
||||
this._state.streamMessage = ev.message as Message;
|
||||
|
||||
case "message_update":
|
||||
partial = event.message;
|
||||
this._state.streamMessage = event.message;
|
||||
break;
|
||||
}
|
||||
case "message_end": {
|
||||
|
||||
case "message_end":
|
||||
partial = null;
|
||||
this._state.streamMessage = null;
|
||||
this.appendMessage(ev.message as AppMessage);
|
||||
generatedMessages.push(ev.message as AppMessage);
|
||||
this.appendMessage(event.message);
|
||||
break;
|
||||
}
|
||||
|
||||
case "tool_execution_start": {
|
||||
const s = new Set(this._state.pendingToolCalls);
|
||||
s.add(ev.toolCallId);
|
||||
s.add(event.toolCallId);
|
||||
this._state.pendingToolCalls = s;
|
||||
break;
|
||||
}
|
||||
|
||||
case "tool_execution_end": {
|
||||
const s = new Set(this._state.pendingToolCalls);
|
||||
s.delete(ev.toolCallId);
|
||||
s.delete(event.toolCallId);
|
||||
this._state.pendingToolCalls = s;
|
||||
break;
|
||||
}
|
||||
case "turn_end": {
|
||||
if (ev.message.role === "assistant" && ev.message.errorMessage) {
|
||||
this._state.error = ev.message.errorMessage;
|
||||
|
||||
case "turn_end":
|
||||
if (event.message.role === "assistant" && (event.message as any).errorMessage) {
|
||||
this._state.error = (event.message as any).errorMessage;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case "agent_end": {
|
||||
|
||||
case "agent_end":
|
||||
this._state.isStreaming = false;
|
||||
this._state.streamMessage = null;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
this.emit(ev as AgentEvent);
|
||||
// Emit to listeners
|
||||
this.emit(event);
|
||||
}
|
||||
|
||||
// Handle any remaining partial message
|
||||
|
|
@ -357,8 +328,7 @@ export class Agent {
|
|||
(c.type === "toolCall" && c.name.trim().length > 0),
|
||||
);
|
||||
if (!onlyEmpty) {
|
||||
this.appendMessage(partial as AppMessage);
|
||||
generatedMessages.push(partial as AppMessage);
|
||||
this.appendMessage(partial);
|
||||
} else {
|
||||
if (this.abortController?.signal.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
|
|
@ -366,7 +336,7 @@ export class Agent {
|
|||
}
|
||||
}
|
||||
} catch (err: any) {
|
||||
const msg: Message = {
|
||||
const errorMsg: AgentMessage = {
|
||||
role: "assistant",
|
||||
content: [{ type: "text", text: "" }],
|
||||
api: model.api,
|
||||
|
|
@ -383,10 +353,11 @@ export class Agent {
|
|||
stopReason: this.abortController?.signal.aborted ? "aborted" : "error",
|
||||
errorMessage: err?.message || String(err),
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
this.appendMessage(msg as AppMessage);
|
||||
generatedMessages.push(msg as AppMessage);
|
||||
} as AgentMessage;
|
||||
|
||||
this.appendMessage(errorMsg);
|
||||
this._state.error = err?.message || String(err);
|
||||
this.emit({ type: "agent_end", messages: [errorMsg] });
|
||||
} finally {
|
||||
this._state.isStreaming = false;
|
||||
this._state.streamMessage = null;
|
||||
|
|
|
|||
|
|
@ -1,22 +1,6 @@
|
|||
// Core Agent
|
||||
export { Agent, type AgentOptions } from "./agent.js";
|
||||
// Transports
|
||||
export {
|
||||
type AgentRunConfig,
|
||||
type AgentTransport,
|
||||
AppTransport,
|
||||
type AppTransportOptions,
|
||||
ProviderTransport,
|
||||
type ProviderTransportOptions,
|
||||
type ProxyAssistantMessageEvent,
|
||||
} from "./transports/index.js";
|
||||
export * from "./agent.js";
|
||||
// Loop functions
|
||||
export * from "./agent-loop.js";
|
||||
// Types
|
||||
export type {
|
||||
AgentEvent,
|
||||
AgentState,
|
||||
AppMessage,
|
||||
Attachment,
|
||||
CustomMessages,
|
||||
ThinkingLevel,
|
||||
UserMessageWithAttachments,
|
||||
} from "./types.js";
|
||||
export * from "./types.js";
|
||||
|
|
|
|||
340
packages/agent/src/proxy.ts
Normal file
340
packages/agent/src/proxy.ts
Normal file
|
|
@ -0,0 +1,340 @@
|
|||
/**
|
||||
* Proxy stream function for apps that route LLM calls through a server.
|
||||
* The server manages auth and proxies requests to LLM providers.
|
||||
*/
|
||||
|
||||
import {
|
||||
type AssistantMessage,
|
||||
type AssistantMessageEvent,
|
||||
type Context,
|
||||
EventStream,
|
||||
type Model,
|
||||
type SimpleStreamOptions,
|
||||
type StopReason,
|
||||
type ToolCall,
|
||||
} from "@mariozechner/pi-ai";
|
||||
// Internal import for JSON parsing utility
|
||||
import { parseStreamingJson } from "@mariozechner/pi-ai/dist/utils/json-parse.js";
|
||||
|
||||
// Create stream class matching ProxyMessageEventStream
|
||||
class ProxyMessageEventStream extends EventStream<AssistantMessageEvent, AssistantMessage> {
|
||||
constructor() {
|
||||
super(
|
||||
(event) => event.type === "done" || event.type === "error",
|
||||
(event) => {
|
||||
if (event.type === "done") return event.message;
|
||||
if (event.type === "error") return event.error;
|
||||
throw new Error("Unexpected event type");
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Proxy event types - server sends these with partial field stripped to reduce bandwidth.
|
||||
*/
|
||||
export type ProxyAssistantMessageEvent =
|
||||
| { type: "start" }
|
||||
| { type: "text_start"; contentIndex: number }
|
||||
| { type: "text_delta"; contentIndex: number; delta: string }
|
||||
| { type: "text_end"; contentIndex: number; contentSignature?: string }
|
||||
| { type: "thinking_start"; contentIndex: number }
|
||||
| { type: "thinking_delta"; contentIndex: number; delta: string }
|
||||
| { type: "thinking_end"; contentIndex: number; contentSignature?: string }
|
||||
| { type: "toolcall_start"; contentIndex: number; id: string; toolName: string }
|
||||
| { type: "toolcall_delta"; contentIndex: number; delta: string }
|
||||
| { type: "toolcall_end"; contentIndex: number }
|
||||
| {
|
||||
type: "done";
|
||||
reason: Extract<StopReason, "stop" | "length" | "toolUse">;
|
||||
usage: AssistantMessage["usage"];
|
||||
}
|
||||
| {
|
||||
type: "error";
|
||||
reason: Extract<StopReason, "aborted" | "error">;
|
||||
errorMessage?: string;
|
||||
usage: AssistantMessage["usage"];
|
||||
};
|
||||
|
||||
export interface ProxyStreamOptions extends SimpleStreamOptions {
|
||||
/** Auth token for the proxy server */
|
||||
authToken: string;
|
||||
/** Proxy server URL (e.g., "https://genai.example.com") */
|
||||
proxyUrl: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Stream function that proxies through a server instead of calling LLM providers directly.
|
||||
* The server strips the partial field from delta events to reduce bandwidth.
|
||||
* We reconstruct the partial message client-side.
|
||||
*
|
||||
* Use this as the `streamFn` option when creating an Agent that needs to go through a proxy.
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* const agent = new Agent({
|
||||
* streamFn: (model, context, options) =>
|
||||
* streamProxy(model, context, {
|
||||
* ...options,
|
||||
* authToken: await getAuthToken(),
|
||||
* proxyUrl: "https://genai.example.com",
|
||||
* }),
|
||||
* });
|
||||
* ```
|
||||
*/
|
||||
export function streamProxy(model: Model<any>, context: Context, options: ProxyStreamOptions): ProxyMessageEventStream {
|
||||
const stream = new ProxyMessageEventStream();
|
||||
|
||||
(async () => {
|
||||
// Initialize the partial message that we'll build up from events
|
||||
const partial: AssistantMessage = {
|
||||
role: "assistant",
|
||||
stopReason: "stop",
|
||||
content: [],
|
||||
api: model.api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
let reader: ReadableStreamDefaultReader<Uint8Array> | undefined;
|
||||
|
||||
const abortHandler = () => {
|
||||
if (reader) {
|
||||
reader.cancel("Request aborted by user").catch(() => {});
|
||||
}
|
||||
};
|
||||
|
||||
if (options.signal) {
|
||||
options.signal.addEventListener("abort", abortHandler);
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await fetch(`${options.proxyUrl}/api/stream`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
Authorization: `Bearer ${options.authToken}`,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
model,
|
||||
context,
|
||||
options: {
|
||||
temperature: options.temperature,
|
||||
maxTokens: options.maxTokens,
|
||||
reasoning: options.reasoning,
|
||||
},
|
||||
}),
|
||||
signal: options.signal,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
let errorMessage = `Proxy error: ${response.status} ${response.statusText}`;
|
||||
try {
|
||||
const errorData = (await response.json()) as { error?: string };
|
||||
if (errorData.error) {
|
||||
errorMessage = `Proxy error: ${errorData.error}`;
|
||||
}
|
||||
} catch {
|
||||
// Couldn't parse error response
|
||||
}
|
||||
throw new Error(errorMessage);
|
||||
}
|
||||
|
||||
reader = response.body!.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
let buffer = "";
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
|
||||
if (options.signal?.aborted) {
|
||||
throw new Error("Request aborted by user");
|
||||
}
|
||||
|
||||
buffer += decoder.decode(value, { stream: true });
|
||||
const lines = buffer.split("\n");
|
||||
buffer = lines.pop() || "";
|
||||
|
||||
for (const line of lines) {
|
||||
if (line.startsWith("data: ")) {
|
||||
const data = line.slice(6).trim();
|
||||
if (data) {
|
||||
const proxyEvent = JSON.parse(data) as ProxyAssistantMessageEvent;
|
||||
const event = processProxyEvent(proxyEvent, partial);
|
||||
if (event) {
|
||||
stream.push(event);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (options.signal?.aborted) {
|
||||
throw new Error("Request aborted by user");
|
||||
}
|
||||
|
||||
stream.end();
|
||||
} catch (error) {
|
||||
const errorMessage = error instanceof Error ? error.message : String(error);
|
||||
const reason = options.signal?.aborted ? "aborted" : "error";
|
||||
partial.stopReason = reason;
|
||||
partial.errorMessage = errorMessage;
|
||||
stream.push({
|
||||
type: "error",
|
||||
reason,
|
||||
error: partial,
|
||||
});
|
||||
stream.end();
|
||||
} finally {
|
||||
if (options.signal) {
|
||||
options.signal.removeEventListener("abort", abortHandler);
|
||||
}
|
||||
}
|
||||
})();
|
||||
|
||||
return stream;
|
||||
}
|
||||
|
||||
/**
|
||||
* Process a proxy event and update the partial message.
|
||||
*/
|
||||
function processProxyEvent(
|
||||
proxyEvent: ProxyAssistantMessageEvent,
|
||||
partial: AssistantMessage,
|
||||
): AssistantMessageEvent | undefined {
|
||||
switch (proxyEvent.type) {
|
||||
case "start":
|
||||
return { type: "start", partial };
|
||||
|
||||
case "text_start":
|
||||
partial.content[proxyEvent.contentIndex] = { type: "text", text: "" };
|
||||
return { type: "text_start", contentIndex: proxyEvent.contentIndex, partial };
|
||||
|
||||
case "text_delta": {
|
||||
const content = partial.content[proxyEvent.contentIndex];
|
||||
if (content?.type === "text") {
|
||||
content.text += proxyEvent.delta;
|
||||
return {
|
||||
type: "text_delta",
|
||||
contentIndex: proxyEvent.contentIndex,
|
||||
delta: proxyEvent.delta,
|
||||
partial,
|
||||
};
|
||||
}
|
||||
throw new Error("Received text_delta for non-text content");
|
||||
}
|
||||
|
||||
case "text_end": {
|
||||
const content = partial.content[proxyEvent.contentIndex];
|
||||
if (content?.type === "text") {
|
||||
content.textSignature = proxyEvent.contentSignature;
|
||||
return {
|
||||
type: "text_end",
|
||||
contentIndex: proxyEvent.contentIndex,
|
||||
content: content.text,
|
||||
partial,
|
||||
};
|
||||
}
|
||||
throw new Error("Received text_end for non-text content");
|
||||
}
|
||||
|
||||
case "thinking_start":
|
||||
partial.content[proxyEvent.contentIndex] = { type: "thinking", thinking: "" };
|
||||
return { type: "thinking_start", contentIndex: proxyEvent.contentIndex, partial };
|
||||
|
||||
case "thinking_delta": {
|
||||
const content = partial.content[proxyEvent.contentIndex];
|
||||
if (content?.type === "thinking") {
|
||||
content.thinking += proxyEvent.delta;
|
||||
return {
|
||||
type: "thinking_delta",
|
||||
contentIndex: proxyEvent.contentIndex,
|
||||
delta: proxyEvent.delta,
|
||||
partial,
|
||||
};
|
||||
}
|
||||
throw new Error("Received thinking_delta for non-thinking content");
|
||||
}
|
||||
|
||||
case "thinking_end": {
|
||||
const content = partial.content[proxyEvent.contentIndex];
|
||||
if (content?.type === "thinking") {
|
||||
content.thinkingSignature = proxyEvent.contentSignature;
|
||||
return {
|
||||
type: "thinking_end",
|
||||
contentIndex: proxyEvent.contentIndex,
|
||||
content: content.thinking,
|
||||
partial,
|
||||
};
|
||||
}
|
||||
throw new Error("Received thinking_end for non-thinking content");
|
||||
}
|
||||
|
||||
case "toolcall_start":
|
||||
partial.content[proxyEvent.contentIndex] = {
|
||||
type: "toolCall",
|
||||
id: proxyEvent.id,
|
||||
name: proxyEvent.toolName,
|
||||
arguments: {},
|
||||
partialJson: "",
|
||||
} satisfies ToolCall & { partialJson: string } as ToolCall;
|
||||
return { type: "toolcall_start", contentIndex: proxyEvent.contentIndex, partial };
|
||||
|
||||
case "toolcall_delta": {
|
||||
const content = partial.content[proxyEvent.contentIndex];
|
||||
if (content?.type === "toolCall") {
|
||||
(content as any).partialJson += proxyEvent.delta;
|
||||
content.arguments = parseStreamingJson((content as any).partialJson) || {};
|
||||
partial.content[proxyEvent.contentIndex] = { ...content }; // Trigger reactivity
|
||||
return {
|
||||
type: "toolcall_delta",
|
||||
contentIndex: proxyEvent.contentIndex,
|
||||
delta: proxyEvent.delta,
|
||||
partial,
|
||||
};
|
||||
}
|
||||
throw new Error("Received toolcall_delta for non-toolCall content");
|
||||
}
|
||||
|
||||
case "toolcall_end": {
|
||||
const content = partial.content[proxyEvent.contentIndex];
|
||||
if (content?.type === "toolCall") {
|
||||
delete (content as any).partialJson;
|
||||
return {
|
||||
type: "toolcall_end",
|
||||
contentIndex: proxyEvent.contentIndex,
|
||||
toolCall: content,
|
||||
partial,
|
||||
};
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
case "done":
|
||||
partial.stopReason = proxyEvent.reason;
|
||||
partial.usage = proxyEvent.usage;
|
||||
return { type: "done", reason: proxyEvent.reason, message: partial };
|
||||
|
||||
case "error":
|
||||
partial.stopReason = proxyEvent.reason;
|
||||
partial.errorMessage = proxyEvent.errorMessage;
|
||||
partial.usage = proxyEvent.usage;
|
||||
return { type: "error", reason: proxyEvent.reason, error: partial };
|
||||
|
||||
default: {
|
||||
const _exhaustiveCheck: never = proxyEvent;
|
||||
console.warn(`Unhandled proxy event type: ${(proxyEvent as any).type}`);
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,397 +0,0 @@
|
|||
import type {
|
||||
AgentContext,
|
||||
AgentLoopConfig,
|
||||
Api,
|
||||
AssistantMessage,
|
||||
AssistantMessageEvent,
|
||||
Context,
|
||||
Message,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
ToolCall,
|
||||
UserMessage,
|
||||
} from "@mariozechner/pi-ai";
|
||||
import { agentLoop, agentLoopContinue } from "@mariozechner/pi-ai";
|
||||
import { AssistantMessageEventStream } from "@mariozechner/pi-ai/dist/utils/event-stream.js";
|
||||
import { parseStreamingJson } from "@mariozechner/pi-ai/dist/utils/json-parse.js";
|
||||
import type { ProxyAssistantMessageEvent } from "./proxy-types.js";
|
||||
import type { AgentRunConfig, AgentTransport } from "./types.js";
|
||||
|
||||
/**
|
||||
* Stream function that proxies through a server instead of calling providers directly.
|
||||
* The server strips the partial field from delta events to reduce bandwidth.
|
||||
* We reconstruct the partial message client-side.
|
||||
*/
|
||||
function streamSimpleProxy(
|
||||
model: Model<any>,
|
||||
context: Context,
|
||||
options: SimpleStreamOptions & { authToken: string },
|
||||
proxyUrl: string,
|
||||
): AssistantMessageEventStream {
|
||||
const stream = new AssistantMessageEventStream();
|
||||
|
||||
(async () => {
|
||||
// Initialize the partial message that we'll build up from events
|
||||
const partial: AssistantMessage = {
|
||||
role: "assistant",
|
||||
stopReason: "stop",
|
||||
content: [],
|
||||
api: model.api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
let reader: ReadableStreamDefaultReader<Uint8Array> | undefined;
|
||||
|
||||
// Set up abort handler to cancel the reader
|
||||
const abortHandler = () => {
|
||||
if (reader) {
|
||||
reader.cancel("Request aborted by user").catch(() => {});
|
||||
}
|
||||
};
|
||||
|
||||
if (options.signal) {
|
||||
options.signal.addEventListener("abort", abortHandler);
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await fetch(`${proxyUrl}/api/stream`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
Authorization: `Bearer ${options.authToken}`,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
model,
|
||||
context,
|
||||
options: {
|
||||
temperature: options.temperature,
|
||||
maxTokens: options.maxTokens,
|
||||
reasoning: options.reasoning,
|
||||
// Don't send apiKey or signal - those are added server-side
|
||||
},
|
||||
}),
|
||||
signal: options.signal,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
let errorMessage = `Proxy error: ${response.status} ${response.statusText}`;
|
||||
try {
|
||||
const errorData = (await response.json()) as { error?: string };
|
||||
if (errorData.error) {
|
||||
errorMessage = `Proxy error: ${errorData.error}`;
|
||||
}
|
||||
} catch {
|
||||
// Couldn't parse error response, use default message
|
||||
}
|
||||
throw new Error(errorMessage);
|
||||
}
|
||||
|
||||
// Parse SSE stream
|
||||
reader = response.body!.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
let buffer = "";
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
|
||||
// Check if aborted after reading
|
||||
if (options.signal?.aborted) {
|
||||
throw new Error("Request aborted by user");
|
||||
}
|
||||
|
||||
buffer += decoder.decode(value, { stream: true });
|
||||
const lines = buffer.split("\n");
|
||||
buffer = lines.pop() || "";
|
||||
|
||||
for (const line of lines) {
|
||||
if (line.startsWith("data: ")) {
|
||||
const data = line.slice(6).trim();
|
||||
if (data) {
|
||||
const proxyEvent = JSON.parse(data) as ProxyAssistantMessageEvent;
|
||||
let event: AssistantMessageEvent | undefined;
|
||||
|
||||
// Handle different event types
|
||||
// Server sends events with partial for non-delta events,
|
||||
// and without partial for delta events
|
||||
switch (proxyEvent.type) {
|
||||
case "start":
|
||||
event = { type: "start", partial };
|
||||
break;
|
||||
|
||||
case "text_start":
|
||||
partial.content[proxyEvent.contentIndex] = {
|
||||
type: "text",
|
||||
text: "",
|
||||
};
|
||||
event = { type: "text_start", contentIndex: proxyEvent.contentIndex, partial };
|
||||
break;
|
||||
|
||||
case "text_delta": {
|
||||
const content = partial.content[proxyEvent.contentIndex];
|
||||
if (content?.type === "text") {
|
||||
content.text += proxyEvent.delta;
|
||||
event = {
|
||||
type: "text_delta",
|
||||
contentIndex: proxyEvent.contentIndex,
|
||||
delta: proxyEvent.delta,
|
||||
partial,
|
||||
};
|
||||
} else {
|
||||
throw new Error("Received text_delta for non-text content");
|
||||
}
|
||||
break;
|
||||
}
|
||||
case "text_end": {
|
||||
const content = partial.content[proxyEvent.contentIndex];
|
||||
if (content?.type === "text") {
|
||||
content.textSignature = proxyEvent.contentSignature;
|
||||
event = {
|
||||
type: "text_end",
|
||||
contentIndex: proxyEvent.contentIndex,
|
||||
content: content.text,
|
||||
partial,
|
||||
};
|
||||
} else {
|
||||
throw new Error("Received text_end for non-text content");
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case "thinking_start":
|
||||
partial.content[proxyEvent.contentIndex] = {
|
||||
type: "thinking",
|
||||
thinking: "",
|
||||
};
|
||||
event = { type: "thinking_start", contentIndex: proxyEvent.contentIndex, partial };
|
||||
break;
|
||||
|
||||
case "thinking_delta": {
|
||||
const content = partial.content[proxyEvent.contentIndex];
|
||||
if (content?.type === "thinking") {
|
||||
content.thinking += proxyEvent.delta;
|
||||
event = {
|
||||
type: "thinking_delta",
|
||||
contentIndex: proxyEvent.contentIndex,
|
||||
delta: proxyEvent.delta,
|
||||
partial,
|
||||
};
|
||||
} else {
|
||||
throw new Error("Received thinking_delta for non-thinking content");
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case "thinking_end": {
|
||||
const content = partial.content[proxyEvent.contentIndex];
|
||||
if (content?.type === "thinking") {
|
||||
content.thinkingSignature = proxyEvent.contentSignature;
|
||||
event = {
|
||||
type: "thinking_end",
|
||||
contentIndex: proxyEvent.contentIndex,
|
||||
content: content.thinking,
|
||||
partial,
|
||||
};
|
||||
} else {
|
||||
throw new Error("Received thinking_end for non-thinking content");
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case "toolcall_start":
|
||||
partial.content[proxyEvent.contentIndex] = {
|
||||
type: "toolCall",
|
||||
id: proxyEvent.id,
|
||||
name: proxyEvent.toolName,
|
||||
arguments: {},
|
||||
partialJson: "",
|
||||
} satisfies ToolCall & { partialJson: string } as ToolCall;
|
||||
event = { type: "toolcall_start", contentIndex: proxyEvent.contentIndex, partial };
|
||||
break;
|
||||
|
||||
case "toolcall_delta": {
|
||||
const content = partial.content[proxyEvent.contentIndex];
|
||||
if (content?.type === "toolCall") {
|
||||
(content as any).partialJson += proxyEvent.delta;
|
||||
content.arguments = parseStreamingJson((content as any).partialJson) || {};
|
||||
event = {
|
||||
type: "toolcall_delta",
|
||||
contentIndex: proxyEvent.contentIndex,
|
||||
delta: proxyEvent.delta,
|
||||
partial,
|
||||
};
|
||||
partial.content[proxyEvent.contentIndex] = { ...content }; // Trigger reactivity
|
||||
} else {
|
||||
throw new Error("Received toolcall_delta for non-toolCall content");
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case "toolcall_end": {
|
||||
const content = partial.content[proxyEvent.contentIndex];
|
||||
if (content?.type === "toolCall") {
|
||||
delete (content as any).partialJson;
|
||||
event = {
|
||||
type: "toolcall_end",
|
||||
contentIndex: proxyEvent.contentIndex,
|
||||
toolCall: content,
|
||||
partial,
|
||||
};
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case "done":
|
||||
partial.stopReason = proxyEvent.reason;
|
||||
partial.usage = proxyEvent.usage;
|
||||
event = { type: "done", reason: proxyEvent.reason, message: partial };
|
||||
break;
|
||||
|
||||
case "error":
|
||||
partial.stopReason = proxyEvent.reason;
|
||||
partial.errorMessage = proxyEvent.errorMessage;
|
||||
partial.usage = proxyEvent.usage;
|
||||
event = { type: "error", reason: proxyEvent.reason, error: partial };
|
||||
break;
|
||||
|
||||
default: {
|
||||
// Exhaustive check
|
||||
const _exhaustiveCheck: never = proxyEvent;
|
||||
console.warn(`Unhandled event type: ${(proxyEvent as any).type}`);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Push the event to stream
|
||||
if (event) {
|
||||
stream.push(event);
|
||||
} else {
|
||||
throw new Error("Failed to create event from proxy event");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if aborted after reading
|
||||
if (options.signal?.aborted) {
|
||||
throw new Error("Request aborted by user");
|
||||
}
|
||||
|
||||
stream.end();
|
||||
} catch (error) {
|
||||
const errorMessage = error instanceof Error ? error.message : String(error);
|
||||
partial.stopReason = options.signal?.aborted ? "aborted" : "error";
|
||||
partial.errorMessage = errorMessage;
|
||||
stream.push({
|
||||
type: "error",
|
||||
reason: partial.stopReason,
|
||||
error: partial,
|
||||
} satisfies AssistantMessageEvent);
|
||||
stream.end();
|
||||
} finally {
|
||||
// Clean up abort handler
|
||||
if (options.signal) {
|
||||
options.signal.removeEventListener("abort", abortHandler);
|
||||
}
|
||||
}
|
||||
})();
|
||||
|
||||
return stream;
|
||||
}
|
||||
|
||||
export interface AppTransportOptions {
|
||||
/**
|
||||
* Proxy server URL. The server manages user accounts and proxies requests to LLM providers.
|
||||
* Example: "https://genai.mariozechner.at"
|
||||
*/
|
||||
proxyUrl: string;
|
||||
|
||||
/**
|
||||
* Function to retrieve auth token for the proxy server.
|
||||
* The token is used for user authentication and authorization.
|
||||
*/
|
||||
getAuthToken: () => Promise<string> | string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Transport that uses an app server with user authentication tokens.
|
||||
* The server manages user accounts and proxies requests to LLM providers.
|
||||
*/
|
||||
export class AppTransport implements AgentTransport {
|
||||
private options: AppTransportOptions;
|
||||
|
||||
constructor(options: AppTransportOptions) {
|
||||
this.options = options;
|
||||
}
|
||||
|
||||
private async getStreamFn(authToken: string) {
|
||||
return <TApi extends Api>(model: Model<TApi>, context: Context, options?: SimpleStreamOptions) => {
|
||||
return streamSimpleProxy(
|
||||
model,
|
||||
context,
|
||||
{
|
||||
...options,
|
||||
authToken,
|
||||
},
|
||||
this.options.proxyUrl,
|
||||
);
|
||||
};
|
||||
}
|
||||
|
||||
private buildContext(messages: Message[], cfg: AgentRunConfig): AgentContext {
|
||||
return {
|
||||
systemPrompt: cfg.systemPrompt,
|
||||
messages,
|
||||
tools: cfg.tools,
|
||||
};
|
||||
}
|
||||
|
||||
private buildLoopConfig(cfg: AgentRunConfig): AgentLoopConfig {
|
||||
return {
|
||||
model: cfg.model,
|
||||
reasoning: cfg.reasoning,
|
||||
getQueuedMessages: cfg.getQueuedMessages,
|
||||
};
|
||||
}
|
||||
|
||||
async *run(messages: Message[], userMessage: Message, cfg: AgentRunConfig, signal?: AbortSignal) {
|
||||
const authToken = await this.options.getAuthToken();
|
||||
if (!authToken) {
|
||||
throw new Error("Auth token is required for AppTransport");
|
||||
}
|
||||
|
||||
const streamFn = await this.getStreamFn(authToken);
|
||||
const context = this.buildContext(messages, cfg);
|
||||
const pc = this.buildLoopConfig(cfg);
|
||||
|
||||
for await (const ev of agentLoop(userMessage as unknown as UserMessage, context, pc, signal, streamFn as any)) {
|
||||
yield ev;
|
||||
}
|
||||
}
|
||||
|
||||
async *continue(messages: Message[], cfg: AgentRunConfig, signal?: AbortSignal) {
|
||||
const authToken = await this.options.getAuthToken();
|
||||
if (!authToken) {
|
||||
throw new Error("Auth token is required for AppTransport");
|
||||
}
|
||||
|
||||
const streamFn = await this.getStreamFn(authToken);
|
||||
const context = this.buildContext(messages, cfg);
|
||||
const pc = this.buildLoopConfig(cfg);
|
||||
|
||||
for await (const ev of agentLoopContinue(context, pc, signal, streamFn as any)) {
|
||||
yield ev;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,85 +0,0 @@
|
|||
import {
|
||||
type AgentContext,
|
||||
type AgentLoopConfig,
|
||||
agentLoop,
|
||||
agentLoopContinue,
|
||||
type Message,
|
||||
type UserMessage,
|
||||
} from "@mariozechner/pi-ai";
|
||||
import type { AgentRunConfig, AgentTransport } from "./types.js";
|
||||
|
||||
export interface ProviderTransportOptions {
|
||||
/**
|
||||
* Function to retrieve API key for a given provider.
|
||||
* If not provided, transport will try to use environment variables.
|
||||
*/
|
||||
getApiKey?: (provider: string) => Promise<string | undefined> | string | undefined;
|
||||
|
||||
/**
|
||||
* Optional CORS proxy URL for browser environments.
|
||||
* If provided, all requests will be routed through this proxy.
|
||||
* Format: "https://proxy.example.com"
|
||||
*/
|
||||
corsProxyUrl?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Transport that calls LLM providers directly.
|
||||
* Optionally routes calls through a CORS proxy if configured.
|
||||
*/
|
||||
export class ProviderTransport implements AgentTransport {
|
||||
private options: ProviderTransportOptions;
|
||||
|
||||
constructor(options: ProviderTransportOptions = {}) {
|
||||
this.options = options;
|
||||
}
|
||||
|
||||
private getModel(cfg: AgentRunConfig) {
|
||||
let model = cfg.model;
|
||||
if (this.options.corsProxyUrl && cfg.model.baseUrl) {
|
||||
model = {
|
||||
...cfg.model,
|
||||
baseUrl: `${this.options.corsProxyUrl}/?url=${encodeURIComponent(cfg.model.baseUrl)}`,
|
||||
};
|
||||
}
|
||||
return model;
|
||||
}
|
||||
|
||||
private buildContext(messages: Message[], cfg: AgentRunConfig): AgentContext {
|
||||
return {
|
||||
systemPrompt: cfg.systemPrompt,
|
||||
messages,
|
||||
tools: cfg.tools,
|
||||
};
|
||||
}
|
||||
|
||||
private buildLoopConfig(model: AgentRunConfig["model"], cfg: AgentRunConfig): AgentLoopConfig {
|
||||
return {
|
||||
model,
|
||||
reasoning: cfg.reasoning,
|
||||
// Resolve API key per assistant response (important for expiring OAuth tokens)
|
||||
getApiKey: this.options.getApiKey,
|
||||
getQueuedMessages: cfg.getQueuedMessages,
|
||||
};
|
||||
}
|
||||
|
||||
async *run(messages: Message[], userMessage: Message, cfg: AgentRunConfig, signal?: AbortSignal) {
|
||||
const model = this.getModel(cfg);
|
||||
const context = this.buildContext(messages, cfg);
|
||||
const pc = this.buildLoopConfig(model, cfg);
|
||||
|
||||
for await (const ev of agentLoop(userMessage as unknown as UserMessage, context, pc, signal)) {
|
||||
yield ev;
|
||||
}
|
||||
}
|
||||
|
||||
async *continue(messages: Message[], cfg: AgentRunConfig, signal?: AbortSignal) {
|
||||
const model = this.getModel(cfg);
|
||||
const context = this.buildContext(messages, cfg);
|
||||
const pc = this.buildLoopConfig(model, cfg);
|
||||
|
||||
for await (const ev of agentLoopContinue(context, pc, signal)) {
|
||||
yield ev;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,4 +0,0 @@
|
|||
export { AppTransport, type AppTransportOptions } from "./AppTransport.js";
|
||||
export { ProviderTransport, type ProviderTransportOptions } from "./ProviderTransport.js";
|
||||
export type { ProxyAssistantMessageEvent } from "./proxy-types.js";
|
||||
export type { AgentRunConfig, AgentTransport } from "./types.js";
|
||||
|
|
@ -1,20 +0,0 @@
|
|||
import type { StopReason, Usage } from "@mariozechner/pi-ai";
|
||||
|
||||
/**
|
||||
* Event types emitted by the proxy server.
|
||||
* The server strips the `partial` field from delta events to reduce bandwidth.
|
||||
* Clients reconstruct the partial message from these events.
|
||||
*/
|
||||
export type ProxyAssistantMessageEvent =
|
||||
| { type: "start" }
|
||||
| { type: "text_start"; contentIndex: number }
|
||||
| { type: "text_delta"; contentIndex: number; delta: string }
|
||||
| { type: "text_end"; contentIndex: number; contentSignature?: string }
|
||||
| { type: "thinking_start"; contentIndex: number }
|
||||
| { type: "thinking_delta"; contentIndex: number; delta: string }
|
||||
| { type: "thinking_end"; contentIndex: number; contentSignature?: string }
|
||||
| { type: "toolcall_start"; contentIndex: number; id: string; toolName: string }
|
||||
| { type: "toolcall_delta"; contentIndex: number; delta: string }
|
||||
| { type: "toolcall_end"; contentIndex: number }
|
||||
| { type: "done"; reason: Extract<StopReason, "stop" | "length" | "toolUse">; usage: Usage }
|
||||
| { type: "error"; reason: Extract<StopReason, "aborted" | "error">; errorMessage: string; usage: Usage };
|
||||
|
|
@ -1,32 +0,0 @@
|
|||
import type { AgentEvent, AgentTool, Message, Model, QueuedMessage, ReasoningEffort } from "@mariozechner/pi-ai";
|
||||
|
||||
/**
|
||||
* The minimal configuration needed to run an agent turn.
|
||||
*/
|
||||
export interface AgentRunConfig {
|
||||
systemPrompt: string;
|
||||
tools: AgentTool<any>[];
|
||||
model: Model<any>;
|
||||
reasoning?: ReasoningEffort;
|
||||
getQueuedMessages?: <T>() => Promise<QueuedMessage<T>[]>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Transport interface for executing agent turns.
|
||||
* Transports handle the communication with LLM providers,
|
||||
* abstracting away the details of API calls, proxies, etc.
|
||||
*
|
||||
* Events yielded must match the @mariozechner/pi-ai AgentEvent types.
|
||||
*/
|
||||
export interface AgentTransport {
|
||||
/** Run with a new user message */
|
||||
run(
|
||||
messages: Message[],
|
||||
userMessage: Message,
|
||||
config: AgentRunConfig,
|
||||
signal?: AbortSignal,
|
||||
): AsyncIterable<AgentEvent>;
|
||||
|
||||
/** Continue from current context (no new user message) */
|
||||
continue(messages: Message[], config: AgentRunConfig, signal?: AbortSignal): AsyncIterable<AgentEvent>;
|
||||
}
|
||||
|
|
@ -1,26 +1,86 @@
|
|||
import type {
|
||||
AgentTool,
|
||||
AssistantMessage,
|
||||
AssistantMessageEvent,
|
||||
ImageContent,
|
||||
Message,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
streamSimple,
|
||||
TextContent,
|
||||
Tool,
|
||||
ToolResultMessage,
|
||||
UserMessage,
|
||||
} from "@mariozechner/pi-ai";
|
||||
import type { Static, TSchema } from "@sinclair/typebox";
|
||||
|
||||
/** Stream function - can return sync or Promise for async config lookup */
|
||||
export type StreamFn = (
|
||||
...args: Parameters<typeof streamSimple>
|
||||
) => ReturnType<typeof streamSimple> | Promise<ReturnType<typeof streamSimple>>;
|
||||
|
||||
/**
|
||||
* Attachment type definition.
|
||||
* Processing is done by consumers (e.g., document extraction in web-ui).
|
||||
* Configuration for the agent loop.
|
||||
*/
|
||||
export interface Attachment {
|
||||
id: string;
|
||||
type: "image" | "document";
|
||||
fileName: string;
|
||||
mimeType: string;
|
||||
size: number;
|
||||
content: string; // base64 encoded (without data URL prefix)
|
||||
extractedText?: string; // For documents
|
||||
preview?: string; // base64 image preview
|
||||
export interface AgentLoopConfig extends SimpleStreamOptions {
|
||||
model: Model<any>;
|
||||
|
||||
/**
|
||||
* Converts AgentMessage[] to LLM-compatible Message[] before each LLM call.
|
||||
*
|
||||
* Each AgentMessage must be converted to a UserMessage, AssistantMessage, or ToolResultMessage
|
||||
* that the LLM can understand. AgentMessages that cannot be converted (e.g., UI-only notifications,
|
||||
* status messages) should be filtered out.
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* convertToLlm: (messages) => messages.flatMap(m => {
|
||||
* if (m.role === "hookMessage") {
|
||||
* // Convert custom message to user message
|
||||
* return [{ role: "user", content: m.content, timestamp: m.timestamp }];
|
||||
* }
|
||||
* if (m.role === "notification") {
|
||||
* // Filter out UI-only messages
|
||||
* return [];
|
||||
* }
|
||||
* // Pass through standard LLM messages
|
||||
* return [m];
|
||||
* })
|
||||
* ```
|
||||
*/
|
||||
convertToLlm: (messages: AgentMessage[]) => Message[] | Promise<Message[]>;
|
||||
|
||||
/**
|
||||
* Optional transform applied to the context before `convertToLlm`.
|
||||
*
|
||||
* Use this for operations that work at the AgentMessage level:
|
||||
* - Context window management (pruning old messages)
|
||||
* - Injecting context from external sources
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* transformContext: async (messages) => {
|
||||
* if (estimateTokens(messages) > MAX_TOKENS) {
|
||||
* return pruneOldMessages(messages);
|
||||
* }
|
||||
* return messages;
|
||||
* }
|
||||
* ```
|
||||
*/
|
||||
transformContext?: (messages: AgentMessage[], signal?: AbortSignal) => Promise<AgentMessage[]>;
|
||||
|
||||
/**
|
||||
* Resolves an API key dynamically for each LLM call.
|
||||
*
|
||||
* Useful for short-lived OAuth tokens (e.g., GitHub Copilot) that may expire
|
||||
* during long-running tool execution phases.
|
||||
*/
|
||||
getApiKey?: (provider: string) => Promise<string | undefined> | string | undefined;
|
||||
|
||||
/**
|
||||
* Returns queued messages to inject into the conversation.
|
||||
*
|
||||
* Called after each turn to check for user interruptions or injected messages.
|
||||
* If messages are returned, they're added to the context before the next LLM call.
|
||||
*/
|
||||
getQueuedMessages?: () => Promise<AgentMessage[]>;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -29,11 +89,6 @@ export interface Attachment {
|
|||
*/
|
||||
export type ThinkingLevel = "off" | "minimal" | "low" | "medium" | "high" | "xhigh";
|
||||
|
||||
/**
|
||||
* User message with optional attachments.
|
||||
*/
|
||||
export type UserMessageWithAttachments = UserMessage & { attachments?: Attachment[] };
|
||||
|
||||
/**
|
||||
* Extensible interface for custom app messages.
|
||||
* Apps can extend via declaration merging:
|
||||
|
|
@ -41,27 +96,23 @@ export type UserMessageWithAttachments = UserMessage & { attachments?: Attachmen
|
|||
* @example
|
||||
* ```typescript
|
||||
* declare module "@mariozechner/agent" {
|
||||
* interface CustomMessages {
|
||||
* interface CustomAgentMessages {
|
||||
* artifact: ArtifactMessage;
|
||||
* notification: NotificationMessage;
|
||||
* }
|
||||
* }
|
||||
* ```
|
||||
*/
|
||||
export interface CustomMessages {
|
||||
export interface CustomAgentMessages {
|
||||
// Empty by default - apps extend via declaration merging
|
||||
}
|
||||
|
||||
/**
|
||||
* AppMessage: Union of LLM messages + attachments + custom messages.
|
||||
* AgentMessage: Union of LLM messages + custom messages.
|
||||
* This abstraction allows apps to add custom message types while maintaining
|
||||
* type safety and compatibility with the base LLM messages.
|
||||
*/
|
||||
export type AppMessage =
|
||||
| AssistantMessage
|
||||
| UserMessageWithAttachments
|
||||
| Message // Includes ToolResultMessage
|
||||
| CustomMessages[keyof CustomMessages];
|
||||
export type AgentMessage = Message | CustomAgentMessages[keyof CustomAgentMessages];
|
||||
|
||||
/**
|
||||
* Agent state containing all configuration and conversation data.
|
||||
|
|
@ -71,13 +122,42 @@ export interface AgentState {
|
|||
model: Model<any>;
|
||||
thinkingLevel: ThinkingLevel;
|
||||
tools: AgentTool<any>[];
|
||||
messages: AppMessage[]; // Can include attachments + custom message types
|
||||
messages: AgentMessage[]; // Can include attachments + custom message types
|
||||
isStreaming: boolean;
|
||||
streamMessage: Message | null;
|
||||
streamMessage: AgentMessage | null;
|
||||
pendingToolCalls: Set<string>;
|
||||
error?: string;
|
||||
}
|
||||
|
||||
export interface AgentToolResult<T> {
|
||||
// Content blocks supporting text and images
|
||||
content: (TextContent | ImageContent)[];
|
||||
// Details to be displayed in a UI or logged
|
||||
details: T;
|
||||
}
|
||||
|
||||
// Callback for streaming tool execution updates
|
||||
export type AgentToolUpdateCallback<T = any> = (partialResult: AgentToolResult<T>) => void;
|
||||
|
||||
// AgentTool extends Tool but adds the execute function
|
||||
export interface AgentTool<TParameters extends TSchema = TSchema, TDetails = any> extends Tool<TParameters> {
|
||||
// A human-readable label for the tool to be displayed in UI
|
||||
label: string;
|
||||
execute: (
|
||||
toolCallId: string,
|
||||
params: Static<TParameters>,
|
||||
signal?: AbortSignal,
|
||||
onUpdate?: AgentToolUpdateCallback<TDetails>,
|
||||
) => Promise<AgentToolResult<TDetails>>;
|
||||
}
|
||||
|
||||
// AgentContext is like Context but uses AgentTool
|
||||
export interface AgentContext {
|
||||
systemPrompt: string;
|
||||
messages: AgentMessage[];
|
||||
tools?: AgentTool<any>[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Events emitted by the Agent for UI updates.
|
||||
* These events provide fine-grained lifecycle information for messages, turns, and tool executions.
|
||||
|
|
@ -85,15 +165,15 @@ export interface AgentState {
|
|||
export type AgentEvent =
|
||||
// Agent lifecycle
|
||||
| { type: "agent_start" }
|
||||
| { type: "agent_end"; messages: AppMessage[] }
|
||||
| { type: "agent_end"; messages: AgentMessage[] }
|
||||
// Turn lifecycle - a turn is one assistant response + any tool calls/results
|
||||
| { type: "turn_start" }
|
||||
| { type: "turn_end"; message: AppMessage; toolResults: ToolResultMessage[] }
|
||||
| { type: "turn_end"; message: AgentMessage; toolResults: ToolResultMessage[] }
|
||||
// Message lifecycle - emitted for user, assistant, and toolResult messages
|
||||
| { type: "message_start"; message: AppMessage }
|
||||
| { type: "message_start"; message: AgentMessage }
|
||||
// Only emitted for assistant messages during streaming
|
||||
| { type: "message_update"; message: AppMessage; assistantMessageEvent: AssistantMessageEvent }
|
||||
| { type: "message_end"; message: AppMessage }
|
||||
| { type: "message_update"; message: AgentMessage; assistantMessageEvent: AssistantMessageEvent }
|
||||
| { type: "message_end"; message: AgentMessage }
|
||||
// Tool execution lifecycle
|
||||
| { type: "tool_execution_start"; toolCallId: string; toolName: string; args: any }
|
||||
| { type: "tool_execution_update"; toolCallId: string; toolName: string; args: any; partialResult: any }
|
||||
|
|
|
|||
535
packages/agent/test/agent-loop.test.ts
Normal file
535
packages/agent/test/agent-loop.test.ts
Normal file
|
|
@ -0,0 +1,535 @@
|
|||
import {
|
||||
type AssistantMessage,
|
||||
type AssistantMessageEvent,
|
||||
EventStream,
|
||||
type Message,
|
||||
type Model,
|
||||
type UserMessage,
|
||||
} from "@mariozechner/pi-ai";
|
||||
import { Type } from "@sinclair/typebox";
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { agentLoop, agentLoopContinue } from "../src/agent-loop.js";
|
||||
import type { AgentContext, AgentEvent, AgentLoopConfig, AgentMessage, AgentTool } from "../src/types.js";
|
||||
|
||||
// Mock stream for testing - mimics MockAssistantStream
|
||||
class MockAssistantStream extends EventStream<AssistantMessageEvent, AssistantMessage> {
|
||||
constructor() {
|
||||
super(
|
||||
(event) => event.type === "done" || event.type === "error",
|
||||
(event) => {
|
||||
if (event.type === "done") return event.message;
|
||||
if (event.type === "error") return event.error;
|
||||
throw new Error("Unexpected event type");
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
function createUsage() {
|
||||
return {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
};
|
||||
}
|
||||
|
||||
function createModel(): Model<"openai-responses"> {
|
||||
return {
|
||||
id: "mock",
|
||||
name: "mock",
|
||||
api: "openai-responses",
|
||||
provider: "openai",
|
||||
baseUrl: "https://example.invalid",
|
||||
reasoning: false,
|
||||
input: ["text"],
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
|
||||
contextWindow: 8192,
|
||||
maxTokens: 2048,
|
||||
};
|
||||
}
|
||||
|
||||
function createAssistantMessage(
|
||||
content: AssistantMessage["content"],
|
||||
stopReason: AssistantMessage["stopReason"] = "stop",
|
||||
): AssistantMessage {
|
||||
return {
|
||||
role: "assistant",
|
||||
content,
|
||||
api: "openai-responses",
|
||||
provider: "openai",
|
||||
model: "mock",
|
||||
usage: createUsage(),
|
||||
stopReason,
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
}
|
||||
|
||||
function createUserMessage(text: string): UserMessage {
|
||||
return {
|
||||
role: "user",
|
||||
content: text,
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
}
|
||||
|
||||
// Simple identity converter for tests - just passes through standard messages
|
||||
function identityConverter(messages: AgentMessage[]): Message[] {
|
||||
return messages.filter((m) => m.role === "user" || m.role === "assistant" || m.role === "toolResult") as Message[];
|
||||
}
|
||||
|
||||
describe("agentLoop with AgentMessage", () => {
|
||||
it("should emit events with AgentMessage types", async () => {
|
||||
const context: AgentContext = {
|
||||
systemPrompt: "You are helpful.",
|
||||
messages: [],
|
||||
tools: [],
|
||||
};
|
||||
|
||||
const userPrompt: AgentMessage = createUserMessage("Hello");
|
||||
|
||||
const config: AgentLoopConfig = {
|
||||
model: createModel(),
|
||||
convertToLlm: identityConverter,
|
||||
};
|
||||
|
||||
const streamFn = () => {
|
||||
const stream = new MockAssistantStream();
|
||||
queueMicrotask(() => {
|
||||
const message = createAssistantMessage([{ type: "text", text: "Hi there!" }]);
|
||||
stream.push({ type: "done", reason: "stop", message });
|
||||
});
|
||||
return stream;
|
||||
};
|
||||
|
||||
const events: AgentEvent[] = [];
|
||||
const stream = agentLoop([userPrompt], context, config, undefined, streamFn);
|
||||
|
||||
for await (const event of stream) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
const messages = await stream.result();
|
||||
|
||||
// Should have user message and assistant message
|
||||
expect(messages.length).toBe(2);
|
||||
expect(messages[0].role).toBe("user");
|
||||
expect(messages[1].role).toBe("assistant");
|
||||
|
||||
// Verify event sequence
|
||||
const eventTypes = events.map((e) => e.type);
|
||||
expect(eventTypes).toContain("agent_start");
|
||||
expect(eventTypes).toContain("turn_start");
|
||||
expect(eventTypes).toContain("message_start");
|
||||
expect(eventTypes).toContain("message_end");
|
||||
expect(eventTypes).toContain("turn_end");
|
||||
expect(eventTypes).toContain("agent_end");
|
||||
});
|
||||
|
||||
it("should handle custom message types via convertToLlm", async () => {
|
||||
// Create a custom message type
|
||||
interface CustomNotification {
|
||||
role: "notification";
|
||||
text: string;
|
||||
timestamp: number;
|
||||
}
|
||||
|
||||
const notification: CustomNotification = {
|
||||
role: "notification",
|
||||
text: "This is a notification",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
const context: AgentContext = {
|
||||
systemPrompt: "You are helpful.",
|
||||
messages: [notification as unknown as AgentMessage], // Custom message in context
|
||||
tools: [],
|
||||
};
|
||||
|
||||
const userPrompt: AgentMessage = createUserMessage("Hello");
|
||||
|
||||
let convertedMessages: Message[] = [];
|
||||
const config: AgentLoopConfig = {
|
||||
model: createModel(),
|
||||
convertToLlm: (messages) => {
|
||||
// Filter out notifications, convert rest
|
||||
convertedMessages = messages
|
||||
.filter((m) => (m as { role: string }).role !== "notification")
|
||||
.filter((m) => m.role === "user" || m.role === "assistant" || m.role === "toolResult") as Message[];
|
||||
return convertedMessages;
|
||||
},
|
||||
};
|
||||
|
||||
const streamFn = () => {
|
||||
const stream = new MockAssistantStream();
|
||||
queueMicrotask(() => {
|
||||
const message = createAssistantMessage([{ type: "text", text: "Response" }]);
|
||||
stream.push({ type: "done", reason: "stop", message });
|
||||
});
|
||||
return stream;
|
||||
};
|
||||
|
||||
const events: AgentEvent[] = [];
|
||||
const stream = agentLoop([userPrompt], context, config, undefined, streamFn);
|
||||
|
||||
for await (const event of stream) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
// The notification should have been filtered out in convertToLlm
|
||||
expect(convertedMessages.length).toBe(1); // Only user message
|
||||
expect(convertedMessages[0].role).toBe("user");
|
||||
});
|
||||
|
||||
it("should apply transformContext before convertToLlm", async () => {
|
||||
const context: AgentContext = {
|
||||
systemPrompt: "You are helpful.",
|
||||
messages: [
|
||||
createUserMessage("old message 1"),
|
||||
createAssistantMessage([{ type: "text", text: "old response 1" }]),
|
||||
createUserMessage("old message 2"),
|
||||
createAssistantMessage([{ type: "text", text: "old response 2" }]),
|
||||
],
|
||||
tools: [],
|
||||
};
|
||||
|
||||
const userPrompt: AgentMessage = createUserMessage("new message");
|
||||
|
||||
let transformedMessages: AgentMessage[] = [];
|
||||
let convertedMessages: Message[] = [];
|
||||
|
||||
const config: AgentLoopConfig = {
|
||||
model: createModel(),
|
||||
transformContext: async (messages) => {
|
||||
// Keep only last 2 messages (prune old ones)
|
||||
transformedMessages = messages.slice(-2);
|
||||
return transformedMessages;
|
||||
},
|
||||
convertToLlm: (messages) => {
|
||||
convertedMessages = messages.filter(
|
||||
(m) => m.role === "user" || m.role === "assistant" || m.role === "toolResult",
|
||||
) as Message[];
|
||||
return convertedMessages;
|
||||
},
|
||||
};
|
||||
|
||||
const streamFn = () => {
|
||||
const stream = new MockAssistantStream();
|
||||
queueMicrotask(() => {
|
||||
const message = createAssistantMessage([{ type: "text", text: "Response" }]);
|
||||
stream.push({ type: "done", reason: "stop", message });
|
||||
});
|
||||
return stream;
|
||||
};
|
||||
|
||||
const stream = agentLoop([userPrompt], context, config, undefined, streamFn);
|
||||
|
||||
for await (const _ of stream) {
|
||||
// consume
|
||||
}
|
||||
|
||||
// transformContext should have been called first, keeping only last 2
|
||||
expect(transformedMessages.length).toBe(2);
|
||||
// Then convertToLlm receives the pruned messages
|
||||
expect(convertedMessages.length).toBe(2);
|
||||
});
|
||||
|
||||
it("should handle tool calls and results", async () => {
|
||||
const toolSchema = Type.Object({ value: Type.String() });
|
||||
const executed: string[] = [];
|
||||
const tool: AgentTool<typeof toolSchema, { value: string }> = {
|
||||
name: "echo",
|
||||
label: "Echo",
|
||||
description: "Echo tool",
|
||||
parameters: toolSchema,
|
||||
async execute(_toolCallId, params) {
|
||||
executed.push(params.value);
|
||||
return {
|
||||
content: [{ type: "text", text: `echoed: ${params.value}` }],
|
||||
details: { value: params.value },
|
||||
};
|
||||
},
|
||||
};
|
||||
|
||||
const context: AgentContext = {
|
||||
systemPrompt: "",
|
||||
messages: [],
|
||||
tools: [tool],
|
||||
};
|
||||
|
||||
const userPrompt: AgentMessage = createUserMessage("echo something");
|
||||
|
||||
const config: AgentLoopConfig = {
|
||||
model: createModel(),
|
||||
convertToLlm: identityConverter,
|
||||
};
|
||||
|
||||
let callIndex = 0;
|
||||
const streamFn = () => {
|
||||
const stream = new MockAssistantStream();
|
||||
queueMicrotask(() => {
|
||||
if (callIndex === 0) {
|
||||
// First call: return tool call
|
||||
const message = createAssistantMessage(
|
||||
[{ type: "toolCall", id: "tool-1", name: "echo", arguments: { value: "hello" } }],
|
||||
"toolUse",
|
||||
);
|
||||
stream.push({ type: "done", reason: "toolUse", message });
|
||||
} else {
|
||||
// Second call: return final response
|
||||
const message = createAssistantMessage([{ type: "text", text: "done" }]);
|
||||
stream.push({ type: "done", reason: "stop", message });
|
||||
}
|
||||
callIndex++;
|
||||
});
|
||||
return stream;
|
||||
};
|
||||
|
||||
const events: AgentEvent[] = [];
|
||||
const stream = agentLoop([userPrompt], context, config, undefined, streamFn);
|
||||
|
||||
for await (const event of stream) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
// Tool should have been executed
|
||||
expect(executed).toEqual(["hello"]);
|
||||
|
||||
// Should have tool execution events
|
||||
const toolStart = events.find((e) => e.type === "tool_execution_start");
|
||||
const toolEnd = events.find((e) => e.type === "tool_execution_end");
|
||||
expect(toolStart).toBeDefined();
|
||||
expect(toolEnd).toBeDefined();
|
||||
if (toolEnd?.type === "tool_execution_end") {
|
||||
expect(toolEnd.isError).toBe(false);
|
||||
}
|
||||
});
|
||||
|
||||
it("should inject queued messages and skip remaining tool calls", async () => {
|
||||
const toolSchema = Type.Object({ value: Type.String() });
|
||||
const executed: string[] = [];
|
||||
const tool: AgentTool<typeof toolSchema, { value: string }> = {
|
||||
name: "echo",
|
||||
label: "Echo",
|
||||
description: "Echo tool",
|
||||
parameters: toolSchema,
|
||||
async execute(_toolCallId, params) {
|
||||
executed.push(params.value);
|
||||
return {
|
||||
content: [{ type: "text", text: `ok:${params.value}` }],
|
||||
details: { value: params.value },
|
||||
};
|
||||
},
|
||||
};
|
||||
|
||||
const context: AgentContext = {
|
||||
systemPrompt: "",
|
||||
messages: [],
|
||||
tools: [tool],
|
||||
};
|
||||
|
||||
const userPrompt: AgentMessage = createUserMessage("start");
|
||||
const queuedUserMessage: AgentMessage = createUserMessage("interrupt");
|
||||
|
||||
let queuedDelivered = false;
|
||||
let callIndex = 0;
|
||||
let sawInterruptInContext = false;
|
||||
|
||||
const config: AgentLoopConfig = {
|
||||
model: createModel(),
|
||||
convertToLlm: identityConverter,
|
||||
getQueuedMessages: async () => {
|
||||
// Return queued message after first tool executes
|
||||
if (executed.length === 1 && !queuedDelivered) {
|
||||
queuedDelivered = true;
|
||||
return [queuedUserMessage];
|
||||
}
|
||||
return [];
|
||||
},
|
||||
};
|
||||
|
||||
const events: AgentEvent[] = [];
|
||||
const stream = agentLoop([userPrompt], context, config, undefined, (_model, ctx, _options) => {
|
||||
// Check if interrupt message is in context on second call
|
||||
if (callIndex === 1) {
|
||||
sawInterruptInContext = ctx.messages.some(
|
||||
(m) => m.role === "user" && typeof m.content === "string" && m.content === "interrupt",
|
||||
);
|
||||
}
|
||||
|
||||
const mockStream = new MockAssistantStream();
|
||||
queueMicrotask(() => {
|
||||
if (callIndex === 0) {
|
||||
// First call: return two tool calls
|
||||
const message = createAssistantMessage(
|
||||
[
|
||||
{ type: "toolCall", id: "tool-1", name: "echo", arguments: { value: "first" } },
|
||||
{ type: "toolCall", id: "tool-2", name: "echo", arguments: { value: "second" } },
|
||||
],
|
||||
"toolUse",
|
||||
);
|
||||
mockStream.push({ type: "done", reason: "toolUse", message });
|
||||
} else {
|
||||
// Second call: return final response
|
||||
const message = createAssistantMessage([{ type: "text", text: "done" }]);
|
||||
mockStream.push({ type: "done", reason: "stop", message });
|
||||
}
|
||||
callIndex++;
|
||||
});
|
||||
return mockStream;
|
||||
});
|
||||
|
||||
for await (const event of stream) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
// Only first tool should have executed
|
||||
expect(executed).toEqual(["first"]);
|
||||
|
||||
// Second tool should be skipped
|
||||
const toolEnds = events.filter(
|
||||
(e): e is Extract<AgentEvent, { type: "tool_execution_end" }> => e.type === "tool_execution_end",
|
||||
);
|
||||
expect(toolEnds.length).toBe(2);
|
||||
expect(toolEnds[0].isError).toBe(false);
|
||||
expect(toolEnds[1].isError).toBe(true);
|
||||
if (toolEnds[1].result.content[0]?.type === "text") {
|
||||
expect(toolEnds[1].result.content[0].text).toContain("Skipped due to queued user message");
|
||||
}
|
||||
|
||||
// Queued message should appear in events
|
||||
const queuedMessageEvent = events.find(
|
||||
(e) =>
|
||||
e.type === "message_start" &&
|
||||
e.message.role === "user" &&
|
||||
typeof e.message.content === "string" &&
|
||||
e.message.content === "interrupt",
|
||||
);
|
||||
expect(queuedMessageEvent).toBeDefined();
|
||||
|
||||
// Interrupt message should be in context when second LLM call is made
|
||||
expect(sawInterruptInContext).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe("agentLoopContinue with AgentMessage", () => {
|
||||
it("should throw when context has no messages", () => {
|
||||
const context: AgentContext = {
|
||||
systemPrompt: "You are helpful.",
|
||||
messages: [],
|
||||
tools: [],
|
||||
};
|
||||
|
||||
const config: AgentLoopConfig = {
|
||||
model: createModel(),
|
||||
convertToLlm: identityConverter,
|
||||
};
|
||||
|
||||
expect(() => agentLoopContinue(context, config)).toThrow("Cannot continue: no messages in context");
|
||||
});
|
||||
|
||||
it("should continue from existing context without emitting user message events", async () => {
|
||||
const userMessage: AgentMessage = createUserMessage("Hello");
|
||||
|
||||
const context: AgentContext = {
|
||||
systemPrompt: "You are helpful.",
|
||||
messages: [userMessage],
|
||||
tools: [],
|
||||
};
|
||||
|
||||
const config: AgentLoopConfig = {
|
||||
model: createModel(),
|
||||
convertToLlm: identityConverter,
|
||||
};
|
||||
|
||||
const streamFn = () => {
|
||||
const stream = new MockAssistantStream();
|
||||
queueMicrotask(() => {
|
||||
const message = createAssistantMessage([{ type: "text", text: "Response" }]);
|
||||
stream.push({ type: "done", reason: "stop", message });
|
||||
});
|
||||
return stream;
|
||||
};
|
||||
|
||||
const events: AgentEvent[] = [];
|
||||
const stream = agentLoopContinue(context, config, undefined, streamFn);
|
||||
|
||||
for await (const event of stream) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
const messages = await stream.result();
|
||||
|
||||
// Should only return the new assistant message (not the existing user message)
|
||||
expect(messages.length).toBe(1);
|
||||
expect(messages[0].role).toBe("assistant");
|
||||
|
||||
// Should NOT have user message events (that's the key difference from agentLoop)
|
||||
const messageEndEvents = events.filter((e) => e.type === "message_end");
|
||||
expect(messageEndEvents.length).toBe(1);
|
||||
expect((messageEndEvents[0] as any).message.role).toBe("assistant");
|
||||
});
|
||||
|
||||
it("should allow custom message types as last message (caller responsibility)", async () => {
|
||||
// Custom message that will be converted to user message by convertToLlm
|
||||
interface HookMessage {
|
||||
role: "hookMessage";
|
||||
text: string;
|
||||
timestamp: number;
|
||||
}
|
||||
|
||||
const hookMessage: HookMessage = {
|
||||
role: "hookMessage",
|
||||
text: "Hook content",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
const context: AgentContext = {
|
||||
systemPrompt: "You are helpful.",
|
||||
messages: [hookMessage as unknown as AgentMessage],
|
||||
tools: [],
|
||||
};
|
||||
|
||||
const config: AgentLoopConfig = {
|
||||
model: createModel(),
|
||||
convertToLlm: (messages) => {
|
||||
// Convert hookMessage to user message
|
||||
return messages
|
||||
.map((m) => {
|
||||
if ((m as any).role === "hookMessage") {
|
||||
return {
|
||||
role: "user" as const,
|
||||
content: (m as any).text,
|
||||
timestamp: m.timestamp,
|
||||
};
|
||||
}
|
||||
return m;
|
||||
})
|
||||
.filter((m) => m.role === "user" || m.role === "assistant" || m.role === "toolResult") as Message[];
|
||||
},
|
||||
};
|
||||
|
||||
const streamFn = () => {
|
||||
const stream = new MockAssistantStream();
|
||||
queueMicrotask(() => {
|
||||
const message = createAssistantMessage([{ type: "text", text: "Response to hook" }]);
|
||||
stream.push({ type: "done", reason: "stop", message });
|
||||
});
|
||||
return stream;
|
||||
};
|
||||
|
||||
// Should not throw - the hookMessage will be converted to user message
|
||||
const stream = agentLoopContinue(context, config, undefined, streamFn);
|
||||
|
||||
const events: AgentEvent[] = [];
|
||||
for await (const event of stream) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
const messages = await stream.result();
|
||||
expect(messages.length).toBe(1);
|
||||
expect(messages[0].role).toBe("assistant");
|
||||
});
|
||||
});
|
||||
|
|
@ -1,12 +1,10 @@
|
|||
import { getModel } from "@mariozechner/pi-ai";
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { Agent, ProviderTransport } from "../src/index.js";
|
||||
import { Agent } from "../src/index.js";
|
||||
|
||||
describe("Agent", () => {
|
||||
it("should create an agent instance with default state", () => {
|
||||
const agent = new Agent({
|
||||
transport: new ProviderTransport(),
|
||||
});
|
||||
const agent = new Agent();
|
||||
|
||||
expect(agent.state).toBeDefined();
|
||||
expect(agent.state.systemPrompt).toBe("");
|
||||
|
|
@ -23,7 +21,6 @@ describe("Agent", () => {
|
|||
it("should create an agent instance with custom initial state", () => {
|
||||
const customModel = getModel("openai", "gpt-4o-mini");
|
||||
const agent = new Agent({
|
||||
transport: new ProviderTransport(),
|
||||
initialState: {
|
||||
systemPrompt: "You are a helpful assistant.",
|
||||
model: customModel,
|
||||
|
|
@ -37,9 +34,7 @@ describe("Agent", () => {
|
|||
});
|
||||
|
||||
it("should subscribe to events", () => {
|
||||
const agent = new Agent({
|
||||
transport: new ProviderTransport(),
|
||||
});
|
||||
const agent = new Agent();
|
||||
|
||||
let eventCount = 0;
|
||||
const unsubscribe = agent.subscribe((_event) => {
|
||||
|
|
@ -61,9 +56,7 @@ describe("Agent", () => {
|
|||
});
|
||||
|
||||
it("should update state with mutators", () => {
|
||||
const agent = new Agent({
|
||||
transport: new ProviderTransport(),
|
||||
});
|
||||
const agent = new Agent();
|
||||
|
||||
// Test setSystemPrompt
|
||||
agent.setSystemPrompt("Custom prompt");
|
||||
|
|
@ -101,38 +94,19 @@ describe("Agent", () => {
|
|||
});
|
||||
|
||||
it("should support message queueing", async () => {
|
||||
const agent = new Agent({
|
||||
transport: new ProviderTransport(),
|
||||
});
|
||||
const agent = new Agent();
|
||||
|
||||
const message = { role: "user" as const, content: "Queued message", timestamp: Date.now() };
|
||||
await agent.queueMessage(message);
|
||||
agent.queueMessage(message);
|
||||
|
||||
// The message is queued but not yet in state.messages
|
||||
expect(agent.state.messages).not.toContainEqual(message);
|
||||
});
|
||||
|
||||
it("should handle abort controller", () => {
|
||||
const agent = new Agent({
|
||||
transport: new ProviderTransport(),
|
||||
});
|
||||
const agent = new Agent();
|
||||
|
||||
// Should not throw even if nothing is running
|
||||
expect(() => agent.abort()).not.toThrow();
|
||||
});
|
||||
});
|
||||
|
||||
describe("ProviderTransport", () => {
|
||||
it("should create a provider transport instance", () => {
|
||||
const transport = new ProviderTransport();
|
||||
expect(transport).toBeDefined();
|
||||
});
|
||||
|
||||
it("should create a provider transport with options", () => {
|
||||
const transport = new ProviderTransport({
|
||||
getApiKey: async (provider) => `test-key-${provider}`,
|
||||
corsProxyUrl: "https://proxy.example.com",
|
||||
});
|
||||
expect(transport).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -1,25 +1,8 @@
|
|||
import type { AssistantMessage, Model, ToolResultMessage, UserMessage } from "@mariozechner/pi-ai";
|
||||
import { calculateTool, getModel } from "@mariozechner/pi-ai";
|
||||
import { getModel } from "@mariozechner/pi-ai";
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { Agent, ProviderTransport } from "../src/index.js";
|
||||
|
||||
function createTransport() {
|
||||
return new ProviderTransport({
|
||||
getApiKey: async (provider) => {
|
||||
const envVarMap: Record<string, string> = {
|
||||
google: "GEMINI_API_KEY",
|
||||
openai: "OPENAI_API_KEY",
|
||||
anthropic: "ANTHROPIC_API_KEY",
|
||||
xai: "XAI_API_KEY",
|
||||
groq: "GROQ_API_KEY",
|
||||
cerebras: "CEREBRAS_API_KEY",
|
||||
zai: "ZAI_API_KEY",
|
||||
};
|
||||
const envVar = envVarMap[provider] || `${provider.toUpperCase()}_API_KEY`;
|
||||
return process.env[envVar];
|
||||
},
|
||||
});
|
||||
}
|
||||
import { Agent } from "../src/index.js";
|
||||
import { calculateTool } from "./utils/calculate.js";
|
||||
|
||||
async function basicPrompt(model: Model<any>) {
|
||||
const agent = new Agent({
|
||||
|
|
@ -29,7 +12,6 @@ async function basicPrompt(model: Model<any>) {
|
|||
thinkingLevel: "off",
|
||||
tools: [],
|
||||
},
|
||||
transport: createTransport(),
|
||||
});
|
||||
|
||||
await agent.prompt("What is 2+2? Answer with just the number.");
|
||||
|
|
@ -57,7 +39,6 @@ async function toolExecution(model: Model<any>) {
|
|||
thinkingLevel: "off",
|
||||
tools: [calculateTool],
|
||||
},
|
||||
transport: createTransport(),
|
||||
});
|
||||
|
||||
await agent.prompt("Calculate 123 * 456 using the calculator tool.");
|
||||
|
|
@ -99,7 +80,6 @@ async function abortExecution(model: Model<any>) {
|
|||
thinkingLevel: "off",
|
||||
tools: [calculateTool],
|
||||
},
|
||||
transport: createTransport(),
|
||||
});
|
||||
|
||||
const promptPromise = agent.prompt("Calculate 100 * 200, then 300 * 400, then sum the results.");
|
||||
|
|
@ -129,7 +109,6 @@ async function stateUpdates(model: Model<any>) {
|
|||
thinkingLevel: "off",
|
||||
tools: [],
|
||||
},
|
||||
transport: createTransport(),
|
||||
});
|
||||
|
||||
const events: Array<string> = [];
|
||||
|
|
@ -162,7 +141,6 @@ async function multiTurnConversation(model: Model<any>) {
|
|||
thinkingLevel: "off",
|
||||
tools: [],
|
||||
},
|
||||
transport: createTransport(),
|
||||
});
|
||||
|
||||
await agent.prompt("My name is Alice.");
|
||||
|
|
@ -356,7 +334,6 @@ describe("Agent.continue()", () => {
|
|||
systemPrompt: "Test",
|
||||
model: getModel("anthropic", "claude-haiku-4-5"),
|
||||
},
|
||||
transport: createTransport(),
|
||||
});
|
||||
|
||||
await expect(agent.continue()).rejects.toThrow("No messages to continue from");
|
||||
|
|
@ -368,7 +345,6 @@ describe("Agent.continue()", () => {
|
|||
systemPrompt: "Test",
|
||||
model: getModel("anthropic", "claude-haiku-4-5"),
|
||||
},
|
||||
transport: createTransport(),
|
||||
});
|
||||
|
||||
const assistantMessage: AssistantMessage = {
|
||||
|
|
@ -405,7 +381,6 @@ describe("Agent.continue()", () => {
|
|||
thinkingLevel: "off",
|
||||
tools: [],
|
||||
},
|
||||
transport: createTransport(),
|
||||
});
|
||||
|
||||
// Manually add a user message without calling prompt()
|
||||
|
|
@ -445,7 +420,6 @@ describe("Agent.continue()", () => {
|
|||
thinkingLevel: "off",
|
||||
tools: [calculateTool],
|
||||
},
|
||||
transport: createTransport(),
|
||||
});
|
||||
|
||||
// Set up a conversation state as if tool was just executed
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import { type Static, Type } from "@sinclair/typebox";
|
||||
import type { AgentTool, AgentToolResult } from "../../agent/types.js";
|
||||
import type { AgentTool, AgentToolResult } from "../../src/types.js";
|
||||
|
||||
export interface CalculateResult extends AgentToolResult<undefined> {
|
||||
content: Array<{ type: "text"; text: string }>;
|
||||
|
|
@ -1,6 +1,5 @@
|
|||
import { type Static, Type } from "@sinclair/typebox";
|
||||
import type { AgentTool } from "../../agent/index.js";
|
||||
import type { AgentToolResult } from "../types.js";
|
||||
import type { AgentTool, AgentToolResult } from "../../src/types.js";
|
||||
|
||||
export interface GetCurrentTimeResult extends AgentToolResult<{ utcTimestamp: number }> {}
|
||||
|
||||
|
|
@ -2,6 +2,10 @@
|
|||
|
||||
## [Unreleased]
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
- **Agent API moved**: All agent functionality (`agentLoop`, `agentLoopContinue`, `AgentContext`, `AgentEvent`, `AgentTool`, `AgentToolResult`, etc.) has moved to `@mariozechner/pi-agent-core`. See the [agent-core README](../agent/README.md) for documentation.
|
||||
|
||||
## [0.28.0] - 2025-12-25
|
||||
|
||||
### Breaking Changes
|
||||
|
|
|
|||
|
|
@ -782,276 +782,6 @@ const continuation = await complete(newModel, restored);
|
|||
|
||||
> **Note**: If the context contains images (encoded as base64 as shown in the Image Input section), those will also be serialized.
|
||||
|
||||
## Agent API
|
||||
|
||||
The Agent API provides a higher-level interface for building agents with tools. It handles tool execution, validation, and provides detailed event streaming for interactive applications.
|
||||
|
||||
### Event System
|
||||
|
||||
The Agent API streams events during execution, allowing you to build reactive UIs and track agent progress. The agent processes prompts in **turns**, where each turn consists of:
|
||||
1. An assistant message (the LLM's response)
|
||||
2. Optional tool executions if the assistant calls tools
|
||||
3. Tool result messages that are fed back to the LLM
|
||||
|
||||
This continues until the assistant produces a response without tool calls.
|
||||
|
||||
**Queued messages**: If you provide `getQueuedMessages` in the loop config, the agent checks for queued user messages after each tool call. When queued messages are found, any remaining tool calls from the current assistant message are skipped and returned as error tool results (`isError: true`) with the message "Skipped due to queued user message." The queued user messages are injected before the next assistant response.
|
||||
|
||||
### Event Flow Example
|
||||
|
||||
Given a prompt asking to calculate two expressions and sum them:
|
||||
|
||||
```typescript
|
||||
import { agentLoop, AgentContext, calculateTool } from '@mariozechner/pi-ai';
|
||||
|
||||
const context: AgentContext = {
|
||||
systemPrompt: 'You are a helpful math assistant.',
|
||||
messages: [],
|
||||
tools: [calculateTool]
|
||||
};
|
||||
|
||||
const stream = agentLoop(
|
||||
{ role: 'user', content: 'Calculate 15 * 20 and 30 * 40, then sum the results', timestamp: Date.now() },
|
||||
context,
|
||||
{ model: getModel('openai', 'gpt-4o-mini') }
|
||||
);
|
||||
|
||||
// Expected event sequence:
|
||||
// 1. agent_start - Agent begins processing
|
||||
// 2. turn_start - First turn begins
|
||||
// 3. message_start - User message starts
|
||||
// 4. message_end - User message ends
|
||||
// 5. message_start - Assistant message starts
|
||||
// 6. message_update - Assistant streams response with tool calls
|
||||
// 7. message_end - Assistant message ends
|
||||
// 8. tool_execution_start - First calculation (15 * 20)
|
||||
// 9. tool_execution_update - Streaming progress (for long-running tools)
|
||||
// 10. tool_execution_end - Result: 300
|
||||
// 11. tool_execution_start - Second calculation (30 * 40)
|
||||
// 12. tool_execution_update - Streaming progress
|
||||
// 13. tool_execution_end - Result: 1200
|
||||
// 12. message_start - Tool result message for first calculation
|
||||
// 13. message_end - Tool result message ends
|
||||
// 14. message_start - Tool result message for second calculation
|
||||
// 15. message_end - Tool result message ends
|
||||
// 16. turn_end - First turn ends with 2 tool results
|
||||
// 17. turn_start - Second turn begins
|
||||
// 18. message_start - Assistant message starts
|
||||
// 19. message_update - Assistant streams response with sum calculation
|
||||
// 20. message_end - Assistant message ends
|
||||
// 21. tool_execution_start - Sum calculation (300 + 1200)
|
||||
// 22. tool_execution_end - Result: 1500
|
||||
// 23. message_start - Tool result message for sum
|
||||
// 24. message_end - Tool result message ends
|
||||
// 25. turn_end - Second turn ends with 1 tool result
|
||||
// 26. turn_start - Third turn begins
|
||||
// 27. message_start - Final assistant message starts
|
||||
// 28. message_update - Assistant streams final answer
|
||||
// 29. message_end - Final assistant message ends
|
||||
// 30. turn_end - Third turn ends with 0 tool results
|
||||
// 31. agent_end - Agent completes with all messages
|
||||
```
|
||||
|
||||
### Handling Events
|
||||
|
||||
```typescript
|
||||
for await (const event of stream) {
|
||||
switch (event.type) {
|
||||
case 'agent_start':
|
||||
console.log('Agent started');
|
||||
break;
|
||||
|
||||
case 'turn_start':
|
||||
console.log('New turn started');
|
||||
break;
|
||||
|
||||
case 'message_start':
|
||||
console.log(`${event.message.role} message started`);
|
||||
break;
|
||||
|
||||
case 'message_update':
|
||||
// Only for assistant messages during streaming
|
||||
if (event.message.content.some(c => c.type === 'text')) {
|
||||
console.log('Assistant:', event.message.content);
|
||||
}
|
||||
break;
|
||||
|
||||
case 'tool_execution_start':
|
||||
console.log(`Calling ${event.toolName} with:`, event.args);
|
||||
break;
|
||||
|
||||
case 'tool_execution_update':
|
||||
// Streaming progress for long-running tools (e.g., bash output)
|
||||
console.log(`Progress:`, event.partialResult.content);
|
||||
break;
|
||||
|
||||
case 'tool_execution_end':
|
||||
if (event.isError) {
|
||||
console.error(`Tool failed:`, event.result);
|
||||
} else {
|
||||
console.log(`Tool result:`, event.result.content);
|
||||
}
|
||||
break;
|
||||
|
||||
case 'turn_end':
|
||||
console.log(`Turn ended with ${event.toolResults.length} tool calls`);
|
||||
break;
|
||||
|
||||
case 'agent_end':
|
||||
console.log(`Agent completed with ${event.messages.length} new messages`);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Get all messages generated during this agent execution
|
||||
// These include the user message and can be directly appended to context.messages
|
||||
const messages = await stream.result();
|
||||
context.messages.push(...messages);
|
||||
```
|
||||
|
||||
### Continuing from Existing Context
|
||||
|
||||
Use `agentLoopContinue` to resume an agent loop without adding a new user message. This is useful for:
|
||||
- Retrying after context overflow (after compaction reduces context size)
|
||||
- Resuming from tool results that were added manually to the context
|
||||
|
||||
```typescript
|
||||
import { agentLoopContinue, AgentContext } from '@mariozechner/pi-ai';
|
||||
|
||||
// Context already has messages - last must be 'user' or 'toolResult'
|
||||
const context: AgentContext = {
|
||||
systemPrompt: 'You are helpful.',
|
||||
messages: [userMessage, assistantMessage, toolResult],
|
||||
tools: [myTool]
|
||||
};
|
||||
|
||||
// Continue processing from the tool result
|
||||
const stream = agentLoopContinue(context, { model });
|
||||
|
||||
for await (const event of stream) {
|
||||
// Same events as agentLoop, but no user message events emitted
|
||||
}
|
||||
|
||||
const newMessages = await stream.result();
|
||||
```
|
||||
|
||||
**Validation**: Throws if context has no messages or if the last message is an assistant message.
|
||||
|
||||
### Defining Tools with TypeBox
|
||||
|
||||
Tools use TypeBox schemas for runtime validation and type inference:
|
||||
|
||||
```typescript
|
||||
import { Type, Static, AgentTool, AgentToolResult, StringEnum } from '@mariozechner/pi-ai';
|
||||
|
||||
const weatherSchema = Type.Object({
|
||||
city: Type.String({ minLength: 1 }),
|
||||
units: StringEnum(['celsius', 'fahrenheit'], { default: 'celsius' })
|
||||
});
|
||||
|
||||
type WeatherParams = Static<typeof weatherSchema>;
|
||||
|
||||
const weatherTool: AgentTool<typeof weatherSchema, { temp: number }> = {
|
||||
label: 'Get Weather',
|
||||
name: 'get_weather',
|
||||
description: 'Get current weather for a city',
|
||||
parameters: weatherSchema,
|
||||
execute: async (toolCallId, args, signal, onUpdate) => {
|
||||
// args is fully typed: { city: string, units: 'celsius' | 'fahrenheit' }
|
||||
// signal: AbortSignal for cancellation
|
||||
// onUpdate: Optional callback for streaming progress (emits tool_execution_update events)
|
||||
const temp = Math.round(Math.random() * 30);
|
||||
return {
|
||||
content: [{ type: 'text', text: `Temperature in ${args.city}: ${temp}°${args.units[0].toUpperCase()}` }],
|
||||
details: { temp }
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
// Tools can also return images alongside text
|
||||
const chartTool: AgentTool<typeof Type.Object({ data: Type.Array(Type.Number()) })> = {
|
||||
label: 'Generate Chart',
|
||||
name: 'generate_chart',
|
||||
description: 'Generate a chart from data',
|
||||
parameters: Type.Object({ data: Type.Array(Type.Number()) }),
|
||||
execute: async (toolCallId, args) => {
|
||||
const chartImage = await generateChartImage(args.data);
|
||||
return {
|
||||
content: [
|
||||
{ type: 'text', text: `Generated chart with ${args.data.length} data points` },
|
||||
{ type: 'image', data: chartImage.toString('base64'), mimeType: 'image/png' }
|
||||
]
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
// Tools can stream progress via the onUpdate callback (emits tool_execution_update events)
|
||||
const bashTool: AgentTool<typeof Type.Object({ command: Type.String() }), { exitCode: number }> = {
|
||||
label: 'Run Bash',
|
||||
name: 'bash',
|
||||
description: 'Execute a bash command',
|
||||
parameters: Type.Object({ command: Type.String() }),
|
||||
execute: async (toolCallId, args, signal, onUpdate) => {
|
||||
let output = '';
|
||||
const child = spawn('bash', ['-c', args.command]);
|
||||
|
||||
child.stdout.on('data', (data) => {
|
||||
output += data.toString();
|
||||
// Stream partial output to UI via tool_execution_update events
|
||||
onUpdate?.({
|
||||
content: [{ type: 'text', text: output }],
|
||||
details: { exitCode: -1 } // Not finished yet
|
||||
});
|
||||
});
|
||||
|
||||
const exitCode = await new Promise<number>((resolve) => {
|
||||
child.on('close', resolve);
|
||||
});
|
||||
|
||||
return {
|
||||
content: [{ type: 'text', text: output }],
|
||||
details: { exitCode }
|
||||
};
|
||||
}
|
||||
};
|
||||
```
|
||||
|
||||
### Validation and Error Handling
|
||||
|
||||
Tool arguments are automatically validated using AJV with the TypeBox schema. Invalid arguments result in detailed error messages:
|
||||
|
||||
```typescript
|
||||
// If the LLM calls with invalid arguments:
|
||||
// get_weather({ city: '', units: 'kelvin' })
|
||||
|
||||
// The tool execution will fail with:
|
||||
/*
|
||||
Validation failed for tool "get_weather":
|
||||
- city: must NOT have fewer than 1 characters
|
||||
- units: must be equal to one of the allowed values
|
||||
|
||||
Received arguments:
|
||||
{
|
||||
"city": "",
|
||||
"units": "kelvin"
|
||||
}
|
||||
*/
|
||||
```
|
||||
|
||||
### Built-in Example Tools
|
||||
|
||||
The library includes example tools for common operations:
|
||||
|
||||
```typescript
|
||||
import { calculateTool, getCurrentTimeTool } from '@mariozechner/pi-ai';
|
||||
|
||||
const context: AgentContext = {
|
||||
systemPrompt: 'You are a helpful assistant.',
|
||||
messages: [],
|
||||
tools: [calculateTool, getCurrentTimeTool]
|
||||
};
|
||||
```
|
||||
|
||||
## Browser Usage
|
||||
|
||||
The library supports browser environments. You must pass the API key explicitly since environment variables are not available in browsers:
|
||||
|
|
|
|||
|
|
@ -1,11 +0,0 @@
|
|||
export { agentLoop, agentLoopContinue } from "./agent-loop.js";
|
||||
export * from "./tools/index.js";
|
||||
export type {
|
||||
AgentContext,
|
||||
AgentEvent,
|
||||
AgentLoopConfig,
|
||||
AgentTool,
|
||||
AgentToolResult,
|
||||
AgentToolUpdateCallback,
|
||||
QueuedMessage,
|
||||
} from "./types.js";
|
||||
|
|
@ -1,2 +0,0 @@
|
|||
export { calculate, calculateTool } from "./calculate.js";
|
||||
export { getCurrentTime, getCurrentTimeTool } from "./get-current-time.js";
|
||||
|
|
@ -1,105 +0,0 @@
|
|||
import type { Static, TSchema } from "@sinclair/typebox";
|
||||
import type {
|
||||
AssistantMessage,
|
||||
AssistantMessageEvent,
|
||||
ImageContent,
|
||||
Message,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
TextContent,
|
||||
Tool,
|
||||
ToolResultMessage,
|
||||
} from "../types.js";
|
||||
|
||||
export interface AgentToolResult<T> {
|
||||
// Content blocks supporting text and images
|
||||
content: (TextContent | ImageContent)[];
|
||||
// Details to be displayed in a UI or logged
|
||||
details: T;
|
||||
}
|
||||
|
||||
// Callback for streaming tool execution updates
|
||||
export type AgentToolUpdateCallback<T = any> = (partialResult: AgentToolResult<T>) => void;
|
||||
|
||||
// AgentTool extends Tool but adds the execute function
|
||||
export interface AgentTool<TParameters extends TSchema = TSchema, TDetails = any> extends Tool<TParameters> {
|
||||
// A human-readable label for the tool to be displayed in UI
|
||||
label: string;
|
||||
execute: (
|
||||
toolCallId: string,
|
||||
params: Static<TParameters>,
|
||||
signal?: AbortSignal,
|
||||
onUpdate?: AgentToolUpdateCallback<TDetails>,
|
||||
) => Promise<AgentToolResult<TDetails>>;
|
||||
}
|
||||
|
||||
// AgentContext is like Context but uses AgentTool
|
||||
export interface AgentContext {
|
||||
systemPrompt: string;
|
||||
messages: Message[];
|
||||
tools?: AgentTool<any>[];
|
||||
}
|
||||
|
||||
// Event types
|
||||
export type AgentEvent =
|
||||
// Emitted when the agent starts. An agent can emit multiple turns
|
||||
| { type: "agent_start" }
|
||||
// Emitted when a turn starts. A turn can emit an optional user message (initial prompt), an assistant message (response) and multiple tool result messages
|
||||
| { type: "turn_start" }
|
||||
// Emitted when a user, assistant or tool result message starts
|
||||
| { type: "message_start"; message: Message }
|
||||
// Emitted when an asssitant messages is updated due to streaming
|
||||
| { type: "message_update"; assistantMessageEvent: AssistantMessageEvent; message: AssistantMessage }
|
||||
// Emitted when a user, assistant or tool result message is complete
|
||||
| { type: "message_end"; message: Message }
|
||||
// Emitted when a tool execution starts
|
||||
| { type: "tool_execution_start"; toolCallId: string; toolName: string; args: any }
|
||||
// Emitted when a tool execution produces output (streaming)
|
||||
| {
|
||||
type: "tool_execution_update";
|
||||
toolCallId: string;
|
||||
toolName: string;
|
||||
args: any;
|
||||
partialResult: AgentToolResult<any>;
|
||||
}
|
||||
// Emitted when a tool execution completes
|
||||
| {
|
||||
type: "tool_execution_end";
|
||||
toolCallId: string;
|
||||
toolName: string;
|
||||
result: AgentToolResult<any>;
|
||||
isError: boolean;
|
||||
}
|
||||
// Emitted when a full turn completes
|
||||
| { type: "turn_end"; message: AssistantMessage; toolResults: ToolResultMessage[] }
|
||||
// Emitted when the agent has completed all its turns. All messages from every turn are
|
||||
// contained in messages, which can be appended to the context
|
||||
| { type: "agent_end"; messages: AgentContext["messages"] };
|
||||
|
||||
// Queued message with optional LLM representation
|
||||
export interface QueuedMessage<TApp = Message> {
|
||||
original: TApp; // Original message for UI events
|
||||
llm?: Message; // Optional transformed message for loop context (undefined if filtered)
|
||||
}
|
||||
|
||||
// Configuration for agent loop execution
|
||||
export interface AgentLoopConfig extends SimpleStreamOptions {
|
||||
model: Model<any>;
|
||||
|
||||
/**
|
||||
* Optional hook to resolve an API key dynamically for each LLM call.
|
||||
*
|
||||
* This is useful for short-lived OAuth tokens (e.g. GitHub Copilot) that may
|
||||
* expire during long-running tool execution phases.
|
||||
*
|
||||
* The agent loop will call this before each assistant response and pass the
|
||||
* returned value as `apiKey` to `streamSimple()` (or a custom `streamFn`).
|
||||
*
|
||||
* If it returns `undefined`, the loop falls back to `config.apiKey`, and then
|
||||
* to `streamSimple()`'s own provider key lookup (setApiKey/env vars).
|
||||
*/
|
||||
getApiKey?: (provider: string) => Promise<string | undefined> | string | undefined;
|
||||
|
||||
preprocessor?: (messages: AgentContext["messages"], abortSignal?: AbortSignal) => Promise<AgentContext["messages"]>;
|
||||
getQueuedMessages?: <T>() => Promise<QueuedMessage<T>[]>;
|
||||
}
|
||||
|
|
@ -1,4 +1,3 @@
|
|||
export * from "./agent/index.js";
|
||||
export * from "./models.js";
|
||||
export * from "./providers/anthropic.js";
|
||||
export * from "./providers/google.js";
|
||||
|
|
@ -7,6 +6,7 @@ export * from "./providers/openai-completions.js";
|
|||
export * from "./providers/openai-responses.js";
|
||||
export * from "./stream.js";
|
||||
export * from "./types.js";
|
||||
export * from "./utils/event-stream.js";
|
||||
export * from "./utils/oauth/index.js";
|
||||
export * from "./utils/overflow.js";
|
||||
export * from "./utils/typebox-helpers.js";
|
||||
|
|
|
|||
|
|
@ -3325,13 +3325,13 @@ export const MODELS = {
|
|||
reasoning: true,
|
||||
input: ["text"],
|
||||
cost: {
|
||||
input: 0.224,
|
||||
output: 0.32,
|
||||
input: 0.25,
|
||||
output: 0.38,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
},
|
||||
contextWindow: 163840,
|
||||
maxTokens: 4096,
|
||||
maxTokens: 65536,
|
||||
} satisfies Model<"openai-completions">,
|
||||
"deepseek/deepseek-v3.2-exp": {
|
||||
id: "deepseek/deepseek-v3.2-exp",
|
||||
|
|
@ -3892,7 +3892,7 @@ export const MODELS = {
|
|||
cacheWrite: 0,
|
||||
},
|
||||
contextWindow: 196608,
|
||||
maxTokens: 131072,
|
||||
maxTokens: 65536,
|
||||
} satisfies Model<"openai-completions">,
|
||||
"minimax/minimax-m2.1": {
|
||||
id: "minimax/minimax-m2.1",
|
||||
|
|
@ -5371,7 +5371,7 @@ export const MODELS = {
|
|||
cacheWrite: 0,
|
||||
},
|
||||
contextWindow: 131072,
|
||||
maxTokens: 128000,
|
||||
maxTokens: 131072,
|
||||
} satisfies Model<"openai-completions">,
|
||||
"openai/gpt-oss-safeguard-20b": {
|
||||
id: "openai/gpt-oss-safeguard-20b",
|
||||
|
|
@ -6249,8 +6249,8 @@ export const MODELS = {
|
|||
reasoning: true,
|
||||
input: ["text"],
|
||||
cost: {
|
||||
input: 0.3,
|
||||
output: 1.2,
|
||||
input: 0.25,
|
||||
output: 0.85,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
},
|
||||
|
|
@ -6266,8 +6266,8 @@ export const MODELS = {
|
|||
reasoning: true,
|
||||
input: ["text"],
|
||||
cost: {
|
||||
input: 0.3,
|
||||
output: 1.2,
|
||||
input: 0.25,
|
||||
output: 0.85,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
},
|
||||
|
|
@ -6538,13 +6538,13 @@ export const MODELS = {
|
|||
reasoning: true,
|
||||
input: ["text"],
|
||||
cost: {
|
||||
input: 0.39,
|
||||
output: 1.9,
|
||||
input: 0.35,
|
||||
output: 1.5,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
},
|
||||
contextWindow: 204800,
|
||||
maxTokens: 204800,
|
||||
contextWindow: 202752,
|
||||
maxTokens: 65536,
|
||||
} satisfies Model<"openai-completions">,
|
||||
"z-ai/glm-4.6:exacto": {
|
||||
id: "z-ai/glm-4.6:exacto",
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
* Uses the Cloud Code Assist API endpoint to access Gemini and Claude models.
|
||||
*/
|
||||
|
||||
import type { Content, ThinkingConfig, ThinkingLevel } from "@google/genai";
|
||||
import type { Content, ThinkingConfig } from "@google/genai";
|
||||
import { calculateCost } from "../models.js";
|
||||
import type {
|
||||
Api,
|
||||
|
|
@ -21,6 +21,12 @@ import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
|||
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
||||
import { convertMessages, convertTools, mapStopReasonString, mapToolChoice } from "./google-shared.js";
|
||||
|
||||
/**
|
||||
* Thinking level for Gemini 3 models.
|
||||
* Mirrors Google's ThinkingLevel enum values.
|
||||
*/
|
||||
export type GoogleThinkingLevel = "THINKING_LEVEL_UNSPECIFIED" | "MINIMAL" | "LOW" | "MEDIUM" | "HIGH";
|
||||
|
||||
export interface GoogleGeminiCliOptions extends StreamOptions {
|
||||
toolChoice?: "auto" | "none" | "any";
|
||||
/**
|
||||
|
|
@ -35,7 +41,7 @@ export interface GoogleGeminiCliOptions extends StreamOptions {
|
|||
/** Thinking budget in tokens. Use for Gemini 2.x models. */
|
||||
budgetTokens?: number;
|
||||
/** Thinking level. Use for Gemini 3 models (LOW/HIGH for Pro, MINIMAL/LOW/MEDIUM/HIGH for Flash). */
|
||||
level?: ThinkingLevel;
|
||||
level?: GoogleThinkingLevel;
|
||||
};
|
||||
projectId?: string;
|
||||
}
|
||||
|
|
@ -436,7 +442,8 @@ function buildRequest(
|
|||
};
|
||||
// Gemini 3 models use thinkingLevel, older models use thinkingBudget
|
||||
if (options.thinking.level !== undefined) {
|
||||
generationConfig.thinkingConfig.thinkingLevel = options.thinking.level;
|
||||
// Cast to any since our GoogleThinkingLevel mirrors Google's ThinkingLevel enum values
|
||||
generationConfig.thinkingConfig.thinkingLevel = options.thinking.level as any;
|
||||
} else if (options.thinking.budgetTokens !== undefined) {
|
||||
generationConfig.thinkingConfig.thinkingBudget = options.thinking.budgetTokens;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ import {
|
|||
type GenerateContentParameters,
|
||||
GoogleGenAI,
|
||||
type ThinkingConfig,
|
||||
type ThinkingLevel,
|
||||
} from "@google/genai";
|
||||
import { calculateCost } from "../models.js";
|
||||
import { getEnvApiKey } from "../stream.js";
|
||||
|
|
@ -20,6 +19,7 @@ import type {
|
|||
} from "../types.js";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
||||
import type { GoogleThinkingLevel } from "./google-gemini-cli.js";
|
||||
import { convertMessages, convertTools, mapStopReason, mapToolChoice } from "./google-shared.js";
|
||||
|
||||
export interface GoogleOptions extends StreamOptions {
|
||||
|
|
@ -27,7 +27,7 @@ export interface GoogleOptions extends StreamOptions {
|
|||
thinking?: {
|
||||
enabled: boolean;
|
||||
budgetTokens?: number; // -1 for dynamic, 0 to disable
|
||||
level?: ThinkingLevel;
|
||||
level?: GoogleThinkingLevel;
|
||||
};
|
||||
}
|
||||
|
||||
|
|
@ -299,7 +299,8 @@ function buildParams(
|
|||
if (options.thinking?.enabled && model.reasoning) {
|
||||
const thinkingConfig: ThinkingConfig = { includeThoughts: true };
|
||||
if (options.thinking.level !== undefined) {
|
||||
thinkingConfig.thinkingLevel = options.thinking.level;
|
||||
// Cast to any since our GoogleThinkingLevel mirrors Google's ThinkingLevel enum values
|
||||
thinkingConfig.thinkingLevel = options.thinking.level as any;
|
||||
} else if (options.thinking.budgetTokens !== undefined) {
|
||||
thinkingConfig.thinkingBudget = options.thinking.budgetTokens;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,8 +1,11 @@
|
|||
import { ThinkingLevel } from "@google/genai";
|
||||
import { supportsXhigh } from "./models.js";
|
||||
import { type AnthropicOptions, streamAnthropic } from "./providers/anthropic.js";
|
||||
import { type GoogleOptions, streamGoogle } from "./providers/google.js";
|
||||
import { type GoogleGeminiCliOptions, streamGoogleGeminiCli } from "./providers/google-gemini-cli.js";
|
||||
import {
|
||||
type GoogleGeminiCliOptions,
|
||||
type GoogleThinkingLevel,
|
||||
streamGoogleGeminiCli,
|
||||
} from "./providers/google-gemini-cli.js";
|
||||
import { type OpenAICompletionsOptions, streamOpenAICompletions } from "./providers/openai-completions.js";
|
||||
import { type OpenAIResponsesOptions, streamOpenAIResponses } from "./providers/openai-responses.js";
|
||||
import type {
|
||||
|
|
@ -30,9 +33,13 @@ export function getEnvApiKey(provider: any): string | undefined {
|
|||
return process.env.COPILOT_GITHUB_TOKEN || process.env.GH_TOKEN || process.env.GITHUB_TOKEN;
|
||||
}
|
||||
|
||||
// ANTHROPIC_OAUTH_TOKEN takes precedence over ANTHROPIC_API_KEY
|
||||
if (provider === "anthropic") {
|
||||
return process.env.ANTHROPIC_OAUTH_TOKEN || process.env.ANTHROPIC_API_KEY;
|
||||
}
|
||||
|
||||
const envMap: Record<string, string> = {
|
||||
openai: "OPENAI_API_KEY",
|
||||
anthropic: "ANTHROPIC_API_KEY",
|
||||
google: "GEMINI_API_KEY",
|
||||
groq: "GROQ_API_KEY",
|
||||
cerebras: "CEREBRAS_API_KEY",
|
||||
|
|
@ -252,53 +259,56 @@ function isGemini3FlashModel(model: Model<"google-generative-ai">): boolean {
|
|||
return model.id.includes("3-flash");
|
||||
}
|
||||
|
||||
function getGemini3ThinkingLevel(effort: ClampedReasoningEffort, model: Model<"google-generative-ai">): ThinkingLevel {
|
||||
function getGemini3ThinkingLevel(
|
||||
effort: ClampedReasoningEffort,
|
||||
model: Model<"google-generative-ai">,
|
||||
): GoogleThinkingLevel {
|
||||
if (isGemini3ProModel(model)) {
|
||||
// Gemini 3 Pro only supports LOW/HIGH (for now)
|
||||
switch (effort) {
|
||||
case "minimal":
|
||||
case "low":
|
||||
return ThinkingLevel.LOW;
|
||||
return "LOW";
|
||||
case "medium":
|
||||
case "high":
|
||||
return ThinkingLevel.HIGH;
|
||||
return "HIGH";
|
||||
}
|
||||
}
|
||||
// Gemini 3 Flash supports all four levels
|
||||
switch (effort) {
|
||||
case "minimal":
|
||||
return ThinkingLevel.MINIMAL;
|
||||
return "MINIMAL";
|
||||
case "low":
|
||||
return ThinkingLevel.LOW;
|
||||
return "LOW";
|
||||
case "medium":
|
||||
return ThinkingLevel.MEDIUM;
|
||||
return "MEDIUM";
|
||||
case "high":
|
||||
return ThinkingLevel.HIGH;
|
||||
return "HIGH";
|
||||
}
|
||||
}
|
||||
|
||||
function getGeminiCliThinkingLevel(effort: ClampedReasoningEffort, modelId: string): ThinkingLevel {
|
||||
function getGeminiCliThinkingLevel(effort: ClampedReasoningEffort, modelId: string): GoogleThinkingLevel {
|
||||
if (modelId.includes("3-pro")) {
|
||||
// Gemini 3 Pro only supports LOW/HIGH (for now)
|
||||
switch (effort) {
|
||||
case "minimal":
|
||||
case "low":
|
||||
return ThinkingLevel.LOW;
|
||||
return "LOW";
|
||||
case "medium":
|
||||
case "high":
|
||||
return ThinkingLevel.HIGH;
|
||||
return "HIGH";
|
||||
}
|
||||
}
|
||||
// Gemini 3 Flash supports all four levels
|
||||
switch (effort) {
|
||||
case "minimal":
|
||||
return ThinkingLevel.MINIMAL;
|
||||
return "MINIMAL";
|
||||
case "low":
|
||||
return ThinkingLevel.LOW;
|
||||
return "LOW";
|
||||
case "medium":
|
||||
return ThinkingLevel.MEDIUM;
|
||||
return "MEDIUM";
|
||||
case "high":
|
||||
return ThinkingLevel.HIGH;
|
||||
return "HIGH";
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -2,25 +2,16 @@
|
|||
* Anthropic OAuth flow (Claude Pro/Max)
|
||||
*/
|
||||
|
||||
import { createHash, randomBytes } from "crypto";
|
||||
import { generatePKCE } from "./pkce.js";
|
||||
import type { OAuthCredentials } from "./types.js";
|
||||
|
||||
const decode = (s: string) => Buffer.from(s, "base64").toString();
|
||||
const decode = (s: string) => atob(s);
|
||||
const CLIENT_ID = decode("OWQxYzI1MGEtZTYxYi00NGQ5LTg4ZWQtNTk0NGQxOTYyZjVl");
|
||||
const AUTHORIZE_URL = "https://claude.ai/oauth/authorize";
|
||||
const TOKEN_URL = "https://console.anthropic.com/v1/oauth/token";
|
||||
const REDIRECT_URI = "https://console.anthropic.com/oauth/code/callback";
|
||||
const SCOPES = "org:create_api_key user:profile user:inference";
|
||||
|
||||
/**
|
||||
* Generate PKCE code verifier and challenge
|
||||
*/
|
||||
function generatePKCE(): { verifier: string; challenge: string } {
|
||||
const verifier = randomBytes(32).toString("base64url");
|
||||
const challenge = createHash("sha256").update(verifier).digest("base64url");
|
||||
return { verifier, challenge };
|
||||
}
|
||||
|
||||
/**
|
||||
* Login with Anthropic OAuth (device code flow)
|
||||
*
|
||||
|
|
@ -31,7 +22,7 @@ export async function loginAnthropic(
|
|||
onAuthUrl: (url: string) => void,
|
||||
onPromptCode: () => Promise<string>,
|
||||
): Promise<OAuthCredentials> {
|
||||
const { verifier, challenge } = generatePKCE();
|
||||
const { verifier, challenge } = await generatePKCE();
|
||||
|
||||
// Build authorization URL
|
||||
const authParams = new URLSearchParams({
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
import { getModels } from "../../models.js";
|
||||
import type { OAuthCredentials } from "./types.js";
|
||||
|
||||
const decode = (s: string) => Buffer.from(s, "base64").toString();
|
||||
const decode = (s: string) => atob(s);
|
||||
const CLIENT_ID = decode("SXYxLmI1MDdhMDhjODdlY2ZlOTg=");
|
||||
|
||||
const COPILOT_HEADERS = {
|
||||
|
|
|
|||
|
|
@ -1,14 +1,17 @@
|
|||
/**
|
||||
* Antigravity OAuth flow (Gemini 3, Claude, GPT-OSS via Google Cloud)
|
||||
* Uses different OAuth credentials than google-gemini-cli for access to additional models.
|
||||
*
|
||||
* NOTE: This module uses Node.js http.createServer for the OAuth callback.
|
||||
* It is only intended for CLI use, not browser environments.
|
||||
*/
|
||||
|
||||
import { createHash, randomBytes } from "crypto";
|
||||
import { createServer, type Server } from "http";
|
||||
import type { Server } from "http";
|
||||
import { generatePKCE } from "./pkce.js";
|
||||
import type { OAuthCredentials } from "./types.js";
|
||||
|
||||
// Antigravity OAuth credentials (different from Gemini CLI)
|
||||
const decode = (s: string) => Buffer.from(s, "base64").toString();
|
||||
const decode = (s: string) => atob(s);
|
||||
const CLIENT_ID = decode(
|
||||
"MTA3MTAwNjA2MDU5MS10bWhzc2luMmgyMWxjcmUyMzV2dG9sb2poNGc0MDNlcC5hcHBzLmdvb2dsZXVzZXJjb250ZW50LmNvbQ==",
|
||||
);
|
||||
|
|
@ -30,19 +33,15 @@ const TOKEN_URL = "https://oauth2.googleapis.com/token";
|
|||
// Fallback project ID when discovery fails
|
||||
const DEFAULT_PROJECT_ID = "rising-fact-p41fc";
|
||||
|
||||
/**
|
||||
* Generate PKCE code verifier and challenge
|
||||
*/
|
||||
function generatePKCE(): { verifier: string; challenge: string } {
|
||||
const verifier = randomBytes(32).toString("base64url");
|
||||
const challenge = createHash("sha256").update(verifier).digest("base64url");
|
||||
return { verifier, challenge };
|
||||
}
|
||||
|
||||
/**
|
||||
* Start a local HTTP server to receive the OAuth callback
|
||||
*/
|
||||
function startCallbackServer(): Promise<{ server: Server; getCode: () => Promise<{ code: string; state: string }> }> {
|
||||
async function startCallbackServer(): Promise<{
|
||||
server: Server;
|
||||
getCode: () => Promise<{ code: string; state: string }>;
|
||||
}> {
|
||||
const { createServer } = await import("http");
|
||||
|
||||
return new Promise((resolve, reject) => {
|
||||
let codeResolve: (value: { code: string; state: string }) => void;
|
||||
let codeReject: (error: Error) => void;
|
||||
|
|
@ -232,7 +231,7 @@ export async function loginAntigravity(
|
|||
onAuth: (info: { url: string; instructions?: string }) => void,
|
||||
onProgress?: (message: string) => void,
|
||||
): Promise<OAuthCredentials> {
|
||||
const { verifier, challenge } = generatePKCE();
|
||||
const { verifier, challenge } = await generatePKCE();
|
||||
|
||||
// Start local server for callback
|
||||
onProgress?.("Starting local server for OAuth callback...");
|
||||
|
|
|
|||
|
|
@ -1,13 +1,16 @@
|
|||
/**
|
||||
* Gemini CLI OAuth flow (Google Cloud Code Assist)
|
||||
* Standard Gemini models only (gemini-2.0-flash, gemini-2.5-*)
|
||||
*
|
||||
* NOTE: This module uses Node.js http.createServer for the OAuth callback.
|
||||
* It is only intended for CLI use, not browser environments.
|
||||
*/
|
||||
|
||||
import { createHash, randomBytes } from "crypto";
|
||||
import { createServer, type Server } from "http";
|
||||
import type { Server } from "http";
|
||||
import { generatePKCE } from "./pkce.js";
|
||||
import type { OAuthCredentials } from "./types.js";
|
||||
|
||||
const decode = (s: string) => Buffer.from(s, "base64").toString();
|
||||
const decode = (s: string) => atob(s);
|
||||
const CLIENT_ID = decode(
|
||||
"NjgxMjU1ODA5Mzk1LW9vOGZ0Mm9wcmRybnA5ZTNhcWY2YXYzaG1kaWIxMzVqLmFwcHMuZ29vZ2xldXNlcmNvbnRlbnQuY29t",
|
||||
);
|
||||
|
|
@ -22,19 +25,15 @@ const AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth";
|
|||
const TOKEN_URL = "https://oauth2.googleapis.com/token";
|
||||
const CODE_ASSIST_ENDPOINT = "https://cloudcode-pa.googleapis.com";
|
||||
|
||||
/**
|
||||
* Generate PKCE code verifier and challenge
|
||||
*/
|
||||
function generatePKCE(): { verifier: string; challenge: string } {
|
||||
const verifier = randomBytes(32).toString("base64url");
|
||||
const challenge = createHash("sha256").update(verifier).digest("base64url");
|
||||
return { verifier, challenge };
|
||||
}
|
||||
|
||||
/**
|
||||
* Start a local HTTP server to receive the OAuth callback
|
||||
*/
|
||||
function startCallbackServer(): Promise<{ server: Server; getCode: () => Promise<{ code: string; state: string }> }> {
|
||||
async function startCallbackServer(): Promise<{
|
||||
server: Server;
|
||||
getCode: () => Promise<{ code: string; state: string }>;
|
||||
}> {
|
||||
const { createServer } = await import("http");
|
||||
|
||||
return new Promise((resolve, reject) => {
|
||||
let codeResolve: (value: { code: string; state: string }) => void;
|
||||
let codeReject: (error: Error) => void;
|
||||
|
|
@ -263,7 +262,7 @@ export async function loginGeminiCli(
|
|||
onAuth: (info: { url: string; instructions?: string }) => void,
|
||||
onProgress?: (message: string) => void,
|
||||
): Promise<OAuthCredentials> {
|
||||
const { verifier, challenge } = generatePKCE();
|
||||
const { verifier, challenge } = await generatePKCE();
|
||||
|
||||
// Start local server for callback
|
||||
onProgress?.("Starting local server for OAuth callback...");
|
||||
|
|
|
|||
34
packages/ai/src/utils/oauth/pkce.ts
Normal file
34
packages/ai/src/utils/oauth/pkce.ts
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
/**
|
||||
* PKCE utilities using Web Crypto API.
|
||||
* Works in both Node.js 20+ and browsers.
|
||||
*/
|
||||
|
||||
/**
|
||||
* Encode bytes as base64url string.
|
||||
*/
|
||||
function base64urlEncode(bytes: Uint8Array): string {
|
||||
let binary = "";
|
||||
for (const byte of bytes) {
|
||||
binary += String.fromCharCode(byte);
|
||||
}
|
||||
return btoa(binary).replace(/\+/g, "-").replace(/\//g, "_").replace(/=/g, "");
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate PKCE code verifier and challenge.
|
||||
* Uses Web Crypto API for cross-platform compatibility.
|
||||
*/
|
||||
export async function generatePKCE(): Promise<{ verifier: string; challenge: string }> {
|
||||
// Generate random verifier
|
||||
const verifierBytes = new Uint8Array(32);
|
||||
crypto.getRandomValues(verifierBytes);
|
||||
const verifier = base64urlEncode(verifierBytes);
|
||||
|
||||
// Compute SHA-256 challenge
|
||||
const encoder = new TextEncoder();
|
||||
const data = encoder.encode(verifier);
|
||||
const hashBuffer = await crypto.subtle.digest("SHA-256", data);
|
||||
const challenge = base64urlEncode(new Uint8Array(hashBuffer));
|
||||
|
||||
return { verifier, challenge };
|
||||
}
|
||||
|
|
@ -1,166 +0,0 @@
|
|||
import { Type } from "@sinclair/typebox";
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { agentLoop } from "../src/agent/agent-loop.js";
|
||||
import type { AgentContext, AgentEvent, AgentLoopConfig, AgentTool, QueuedMessage } from "../src/agent/types.js";
|
||||
import type { AssistantMessage, Message, Model, UserMessage } from "../src/types.js";
|
||||
import { AssistantMessageEventStream } from "../src/utils/event-stream.js";
|
||||
|
||||
function createUsage() {
|
||||
return {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
};
|
||||
}
|
||||
|
||||
function createModel(): Model<"openai-responses"> {
|
||||
return {
|
||||
id: "mock",
|
||||
name: "mock",
|
||||
api: "openai-responses",
|
||||
provider: "openai",
|
||||
baseUrl: "https://example.invalid",
|
||||
reasoning: false,
|
||||
input: ["text"],
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
|
||||
contextWindow: 8192,
|
||||
maxTokens: 2048,
|
||||
};
|
||||
}
|
||||
|
||||
describe("agentLoop queued message interrupt", () => {
|
||||
it("injects queued messages after a tool call and skips remaining tool calls", async () => {
|
||||
const toolSchema = Type.Object({ value: Type.String() });
|
||||
const executed: string[] = [];
|
||||
const tool: AgentTool<typeof toolSchema, { value: string }> = {
|
||||
name: "echo",
|
||||
label: "Echo",
|
||||
description: "Echo tool",
|
||||
parameters: toolSchema,
|
||||
async execute(_toolCallId, params) {
|
||||
executed.push(params.value);
|
||||
return {
|
||||
content: [{ type: "text", text: `ok:${params.value}` }],
|
||||
details: { value: params.value },
|
||||
};
|
||||
},
|
||||
};
|
||||
|
||||
const context: AgentContext = {
|
||||
systemPrompt: "",
|
||||
messages: [],
|
||||
tools: [tool],
|
||||
};
|
||||
|
||||
const userPrompt: UserMessage = {
|
||||
role: "user",
|
||||
content: "start",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
const queuedUserMessage: Message = {
|
||||
role: "user",
|
||||
content: "interrupt",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
const queuedMessages: QueuedMessage<Message>[] = [{ original: queuedUserMessage, llm: queuedUserMessage }];
|
||||
|
||||
let queuedDelivered = false;
|
||||
let sawInterruptInContext = false;
|
||||
let callIndex = 0;
|
||||
|
||||
const streamFn = () => {
|
||||
const stream = new AssistantMessageEventStream();
|
||||
queueMicrotask(() => {
|
||||
if (callIndex === 0) {
|
||||
const message: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [
|
||||
{ type: "toolCall", id: "tool-1", name: "echo", arguments: { value: "first" } },
|
||||
{ type: "toolCall", id: "tool-2", name: "echo", arguments: { value: "second" } },
|
||||
],
|
||||
api: "openai-responses",
|
||||
provider: "openai",
|
||||
model: "mock",
|
||||
usage: createUsage(),
|
||||
stopReason: "toolUse",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
stream.push({ type: "done", reason: "toolUse", message });
|
||||
} else {
|
||||
const message: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [{ type: "text", text: "done" }],
|
||||
api: "openai-responses",
|
||||
provider: "openai",
|
||||
model: "mock",
|
||||
usage: createUsage(),
|
||||
stopReason: "stop",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
stream.push({ type: "done", reason: "stop", message });
|
||||
}
|
||||
callIndex += 1;
|
||||
});
|
||||
return stream;
|
||||
};
|
||||
|
||||
const getQueuedMessages: AgentLoopConfig["getQueuedMessages"] = async <T>() => {
|
||||
if (executed.length === 1 && !queuedDelivered) {
|
||||
queuedDelivered = true;
|
||||
return queuedMessages as QueuedMessage<T>[];
|
||||
}
|
||||
return [];
|
||||
};
|
||||
|
||||
const config: AgentLoopConfig = {
|
||||
model: createModel(),
|
||||
getQueuedMessages,
|
||||
};
|
||||
|
||||
const events: AgentEvent[] = [];
|
||||
const stream = agentLoop(userPrompt, context, config, undefined, (_model, ctx, _options) => {
|
||||
if (callIndex === 1) {
|
||||
sawInterruptInContext = ctx.messages.some(
|
||||
(m) => m.role === "user" && typeof m.content === "string" && m.content === "interrupt",
|
||||
);
|
||||
}
|
||||
return streamFn();
|
||||
});
|
||||
|
||||
for await (const event of stream) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
expect(executed).toEqual(["first"]);
|
||||
const toolEnds = events.filter(
|
||||
(event): event is Extract<AgentEvent, { type: "tool_execution_end" }> => event.type === "tool_execution_end",
|
||||
);
|
||||
expect(toolEnds.length).toBe(2);
|
||||
expect(toolEnds[1].isError).toBe(true);
|
||||
expect(toolEnds[1].result.content[0]?.type).toBe("text");
|
||||
if (toolEnds[1].result.content[0]?.type === "text") {
|
||||
expect(toolEnds[1].result.content[0].text).toContain("Skipped due to queued user message");
|
||||
}
|
||||
|
||||
const firstTurnEndIndex = events.findIndex((event) => event.type === "turn_end");
|
||||
const queuedMessageIndex = events.findIndex(
|
||||
(event) =>
|
||||
event.type === "message_start" &&
|
||||
event.message.role === "user" &&
|
||||
typeof event.message.content === "string" &&
|
||||
event.message.content === "interrupt",
|
||||
);
|
||||
const nextAssistantIndex = events.findIndex(
|
||||
(event, index) =>
|
||||
index > queuedMessageIndex && event.type === "message_start" && event.message.role === "assistant",
|
||||
);
|
||||
|
||||
expect(queuedMessageIndex).toBeGreaterThan(firstTurnEndIndex);
|
||||
expect(queuedMessageIndex).toBeLessThan(nextAssistantIndex);
|
||||
expect(sawInterruptInContext).toBe(true);
|
||||
});
|
||||
});
|
||||
|
|
@ -1,701 +0,0 @@
|
|||
import { describe, expect, it } from "vitest";
|
||||
import { agentLoop, agentLoopContinue } from "../src/agent/agent-loop.js";
|
||||
import { calculateTool } from "../src/agent/tools/calculate.js";
|
||||
import type { AgentContext, AgentEvent, AgentLoopConfig } from "../src/agent/types.js";
|
||||
import { getModel } from "../src/models.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
Message,
|
||||
Model,
|
||||
OptionsForApi,
|
||||
ToolResultMessage,
|
||||
UserMessage,
|
||||
} from "../src/types.js";
|
||||
import { resolveApiKey } from "./oauth.js";
|
||||
|
||||
// Resolve OAuth tokens at module level (async, runs before tests)
|
||||
const oauthTokens = await Promise.all([
|
||||
resolveApiKey("anthropic"),
|
||||
resolveApiKey("github-copilot"),
|
||||
resolveApiKey("google-gemini-cli"),
|
||||
resolveApiKey("google-antigravity"),
|
||||
]);
|
||||
const [anthropicOAuthToken, githubCopilotToken, geminiCliToken, antigravityToken] = oauthTokens;
|
||||
|
||||
async function calculateTest<TApi extends Api>(model: Model<TApi>, options: OptionsForApi<TApi> = {}) {
|
||||
// Create the agent context with the calculator tool
|
||||
const context: AgentContext = {
|
||||
systemPrompt:
|
||||
"You are a helpful assistant that performs mathematical calculations. When asked to calculate multiple expressions, you can use parallel tool calls if the model supports it. In your final answer, output ONLY the final sum as a single integer number, nothing else.",
|
||||
messages: [],
|
||||
tools: [calculateTool],
|
||||
};
|
||||
|
||||
// Create the prompt config
|
||||
const config: AgentLoopConfig = {
|
||||
model,
|
||||
...options,
|
||||
};
|
||||
|
||||
// Create the user prompt asking for multiple calculations
|
||||
const userPrompt: UserMessage = {
|
||||
role: "user",
|
||||
content: `Use the calculator tool to complete the following mulit-step task.
|
||||
1. Calculate 3485 * 4234 and 88823 * 3482 in parallel
|
||||
2. Calculate the sum of the two results using the calculator tool
|
||||
3. Output ONLY the final sum as a single integer number, nothing else.`,
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
// Calculate expected results (using integers)
|
||||
const expectedFirst = 3485 * 4234; // = 14755490
|
||||
const expectedSecond = 88823 * 3482; // = 309281786
|
||||
const expectedSum = expectedFirst + expectedSecond; // = 324037276
|
||||
|
||||
// Track events for verification
|
||||
const events: AgentEvent[] = [];
|
||||
let turns = 0;
|
||||
let toolCallCount = 0;
|
||||
const toolResults: number[] = [];
|
||||
let finalAnswer: number | undefined;
|
||||
|
||||
// Execute the prompt
|
||||
const stream = agentLoop(userPrompt, context, config);
|
||||
|
||||
for await (const event of stream) {
|
||||
events.push(event);
|
||||
|
||||
switch (event.type) {
|
||||
case "turn_start":
|
||||
turns++;
|
||||
console.log(`\n=== Turn ${turns} started ===`);
|
||||
break;
|
||||
|
||||
case "turn_end":
|
||||
console.log(`=== Turn ${turns} ended with ${event.toolResults.length} tool results ===`);
|
||||
console.log(event.message);
|
||||
break;
|
||||
|
||||
case "tool_execution_end":
|
||||
if (!event.isError && typeof event.result === "object" && event.result.content) {
|
||||
const textOutput = event.result.content
|
||||
.filter((c: any) => c.type === "text")
|
||||
.map((c: any) => c.text)
|
||||
.join("\n");
|
||||
toolCallCount++;
|
||||
// Extract number from output like "expression = result"
|
||||
const match = textOutput.match(/=\s*([\d.]+)/);
|
||||
if (match) {
|
||||
const value = parseFloat(match[1]);
|
||||
toolResults.push(value);
|
||||
console.log(`Tool ${toolCallCount}: ${textOutput}`);
|
||||
}
|
||||
}
|
||||
break;
|
||||
|
||||
case "message_end":
|
||||
// Just track the message end event, don't extract answer here
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Get the final messages
|
||||
const finalMessages = await stream.result();
|
||||
|
||||
// Verify the results
|
||||
expect(finalMessages).toBeDefined();
|
||||
expect(finalMessages.length).toBeGreaterThan(0);
|
||||
|
||||
const finalMessage = finalMessages[finalMessages.length - 1];
|
||||
expect(finalMessage).toBeDefined();
|
||||
expect(finalMessage.role).toBe("assistant");
|
||||
if (finalMessage.role !== "assistant") throw new Error("Final message is not from assistant");
|
||||
|
||||
// Extract the final answer from the last assistant message
|
||||
const content = finalMessage.content
|
||||
.filter((c) => c.type === "text")
|
||||
.map((c) => (c.type === "text" ? c.text : ""))
|
||||
.join(" ");
|
||||
|
||||
// Look for integers in the response that might be the final answer
|
||||
const numbers = content.match(/\b\d+\b/g);
|
||||
if (numbers) {
|
||||
// Check if any of the numbers matches our expected sum
|
||||
for (const num of numbers) {
|
||||
const value = parseInt(num, 10);
|
||||
if (Math.abs(value - expectedSum) < 10) {
|
||||
finalAnswer = value;
|
||||
break;
|
||||
}
|
||||
}
|
||||
// If no exact match, take the last large number as likely the answer
|
||||
if (finalAnswer === undefined) {
|
||||
const largeNumbers = numbers.map((n) => parseInt(n, 10)).filter((n) => n > 1000000);
|
||||
if (largeNumbers.length > 0) {
|
||||
finalAnswer = largeNumbers[largeNumbers.length - 1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Should have executed at least 3 tool calls: 2 for the initial calculations, 1 for the sum
|
||||
// (or possibly 2 if the model calculates the sum itself without a tool)
|
||||
expect(toolCallCount).toBeGreaterThanOrEqual(2);
|
||||
|
||||
// Must be at least 3 turns: first to calculate the expressions, then to sum them, then give the answer
|
||||
// Could be 3 turns if model does parallel calls, or 4 turns if sequential calculation of expressions
|
||||
expect(turns).toBeGreaterThanOrEqual(3);
|
||||
expect(turns).toBeLessThanOrEqual(4);
|
||||
|
||||
// Verify the individual calculations are in the results
|
||||
const hasFirstCalc = toolResults.some((r) => r === expectedFirst);
|
||||
const hasSecondCalc = toolResults.some((r) => r === expectedSecond);
|
||||
expect(hasFirstCalc).toBe(true);
|
||||
expect(hasSecondCalc).toBe(true);
|
||||
|
||||
// Verify the final sum
|
||||
if (finalAnswer !== undefined) {
|
||||
expect(finalAnswer).toBe(expectedSum);
|
||||
console.log(`Final answer: ${finalAnswer} (expected: ${expectedSum})`);
|
||||
} else {
|
||||
// If we couldn't extract the final answer from text, check if it's in the tool results
|
||||
const hasSum = toolResults.some((r) => r === expectedSum);
|
||||
expect(hasSum).toBe(true);
|
||||
}
|
||||
|
||||
// Log summary
|
||||
console.log(`\nTest completed with ${turns} turns and ${toolCallCount} tool calls`);
|
||||
if (turns === 3) {
|
||||
console.log("Model used parallel tool calls for initial calculations");
|
||||
} else {
|
||||
console.log("Model used sequential tool calls");
|
||||
}
|
||||
|
||||
return {
|
||||
turns,
|
||||
toolCallCount,
|
||||
toolResults,
|
||||
finalAnswer,
|
||||
events,
|
||||
};
|
||||
}
|
||||
|
||||
async function abortTest<TApi extends Api>(model: Model<TApi>, options: OptionsForApi<TApi> = {}) {
|
||||
// Create the agent context with the calculator tool
|
||||
const context: AgentContext = {
|
||||
systemPrompt:
|
||||
"You are a helpful assistant that performs mathematical calculations. Always use the calculator tool for each calculation.",
|
||||
messages: [],
|
||||
tools: [calculateTool],
|
||||
};
|
||||
|
||||
// Create the prompt config
|
||||
const config: AgentLoopConfig = {
|
||||
model,
|
||||
...options,
|
||||
};
|
||||
|
||||
// Create a prompt that will require multiple calculations
|
||||
const userPrompt: UserMessage = {
|
||||
role: "user",
|
||||
content: "Calculate 100 * 200, then 300 * 400, then 500 * 600, then sum all three results.",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
// Create abort controller
|
||||
const abortController = new AbortController();
|
||||
|
||||
// Track events for verification
|
||||
const events: AgentEvent[] = [];
|
||||
let toolCallCount = 0;
|
||||
const errorReceived = false;
|
||||
let finalMessages: Message[] | undefined;
|
||||
|
||||
// Execute the prompt
|
||||
const stream = agentLoop(userPrompt, context, config, abortController.signal);
|
||||
|
||||
// Abort after first tool execution
|
||||
(async () => {
|
||||
for await (const event of stream) {
|
||||
events.push(event);
|
||||
|
||||
if (event.type === "tool_execution_end" && !event.isError) {
|
||||
toolCallCount++;
|
||||
// Abort after first successful tool execution
|
||||
if (toolCallCount === 1) {
|
||||
console.log("Aborting after first tool execution");
|
||||
abortController.abort();
|
||||
}
|
||||
}
|
||||
|
||||
if (event.type === "agent_end") {
|
||||
finalMessages = event.messages;
|
||||
}
|
||||
}
|
||||
})();
|
||||
|
||||
finalMessages = await stream.result();
|
||||
|
||||
// Verify abort behavior
|
||||
console.log(`\nAbort test completed with ${toolCallCount} tool calls`);
|
||||
const assistantMessage = finalMessages[finalMessages.length - 1];
|
||||
if (!assistantMessage) throw new Error("No final message received");
|
||||
expect(assistantMessage).toBeDefined();
|
||||
expect(assistantMessage.role).toBe("assistant");
|
||||
if (assistantMessage.role !== "assistant") throw new Error("Final message is not from assistant");
|
||||
|
||||
// Should have executed 1 tool call before abort
|
||||
expect(toolCallCount).toBeGreaterThanOrEqual(1);
|
||||
expect(assistantMessage.stopReason).toBe("aborted");
|
||||
|
||||
return {
|
||||
toolCallCount,
|
||||
events,
|
||||
errorReceived,
|
||||
finalMessages,
|
||||
};
|
||||
}
|
||||
|
||||
describe("Agent Calculator Tests", () => {
|
||||
describe.skipIf(!process.env.GEMINI_API_KEY)("Google Provider Agent", () => {
|
||||
const model = getModel("google", "gemini-2.5-flash");
|
||||
|
||||
it("should calculate multiple expressions and sum the results", { retry: 3 }, async () => {
|
||||
const result = await calculateTest(model);
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
|
||||
});
|
||||
|
||||
it("should handle abort during tool execution", { retry: 3 }, async () => {
|
||||
const result = await abortTest(model);
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Completions Provider Agent", () => {
|
||||
const model = getModel("openai", "gpt-4o-mini");
|
||||
|
||||
it("should calculate multiple expressions and sum the results", { retry: 3 }, async () => {
|
||||
const result = await calculateTest(model);
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
|
||||
});
|
||||
|
||||
it("should handle abort during tool execution", { retry: 3 }, async () => {
|
||||
const result = await abortTest(model);
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Responses Provider Agent", () => {
|
||||
const model = getModel("openai", "gpt-5-mini");
|
||||
|
||||
it("should calculate multiple expressions and sum the results", { retry: 3 }, async () => {
|
||||
const result = await calculateTest(model);
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
|
||||
});
|
||||
|
||||
it("should handle abort during tool execution", { retry: 3 }, async () => {
|
||||
const result = await abortTest(model);
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.ANTHROPIC_API_KEY)("Anthropic Provider Agent", () => {
|
||||
const model = getModel("anthropic", "claude-haiku-4-5");
|
||||
|
||||
it("should calculate multiple expressions and sum the results", { retry: 3 }, async () => {
|
||||
const result = await calculateTest(model);
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
|
||||
});
|
||||
|
||||
it("should handle abort during tool execution", { retry: 3 }, async () => {
|
||||
const result = await abortTest(model);
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.XAI_API_KEY)("xAI Provider Agent", () => {
|
||||
const model = getModel("xai", "grok-3");
|
||||
|
||||
it("should calculate multiple expressions and sum the results", { retry: 3 }, async () => {
|
||||
const result = await calculateTest(model);
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
|
||||
});
|
||||
|
||||
it("should handle abort during tool execution", { retry: 3 }, async () => {
|
||||
const result = await abortTest(model);
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.GROQ_API_KEY)("Groq Provider Agent", () => {
|
||||
const model = getModel("groq", "openai/gpt-oss-20b");
|
||||
|
||||
it("should calculate multiple expressions and sum the results", { retry: 3 }, async () => {
|
||||
const result = await calculateTest(model);
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
|
||||
});
|
||||
|
||||
it("should handle abort during tool execution", { retry: 3 }, async () => {
|
||||
const result = await abortTest(model);
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.CEREBRAS_API_KEY)("Cerebras Provider Agent", () => {
|
||||
const model = getModel("cerebras", "gpt-oss-120b");
|
||||
|
||||
it("should calculate multiple expressions and sum the results", { retry: 3 }, async () => {
|
||||
const result = await calculateTest(model);
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
|
||||
});
|
||||
|
||||
it("should handle abort during tool execution", { retry: 3 }, async () => {
|
||||
const result = await abortTest(model);
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.ZAI_API_KEY)("zAI Provider Agent", () => {
|
||||
const model = getModel("zai", "glm-4.5-air");
|
||||
|
||||
it("should calculate multiple expressions and sum the results", { retry: 3 }, async () => {
|
||||
const result = await calculateTest(model);
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
|
||||
});
|
||||
|
||||
it("should handle abort during tool execution", { retry: 3 }, async () => {
|
||||
const result = await abortTest(model);
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.MISTRAL_API_KEY)("Mistral Provider Agent", () => {
|
||||
const model = getModel("mistral", "devstral-medium-latest");
|
||||
|
||||
it("should calculate multiple expressions and sum the results", { retry: 3 }, async () => {
|
||||
const result = await calculateTest(model);
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
|
||||
});
|
||||
|
||||
it("should handle abort during tool execution", { retry: 3 }, async () => {
|
||||
const result = await abortTest(model);
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
|
||||
});
|
||||
});
|
||||
|
||||
// =========================================================================
|
||||
// OAuth-based providers (credentials from ~/.pi/agent/oauth.json)
|
||||
// =========================================================================
|
||||
|
||||
describe("Anthropic OAuth Provider Agent", () => {
|
||||
const model = getModel("anthropic", "claude-haiku-4-5");
|
||||
|
||||
it.skipIf(!anthropicOAuthToken)(
|
||||
"should calculate multiple expressions and sum the results",
|
||||
{ retry: 3 },
|
||||
async () => {
|
||||
const result = await calculateTest(model, { apiKey: anthropicOAuthToken });
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
|
||||
},
|
||||
);
|
||||
|
||||
it.skipIf(!anthropicOAuthToken)("should handle abort during tool execution", { retry: 3 }, async () => {
|
||||
const result = await abortTest(model, { apiKey: anthropicOAuthToken });
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe("GitHub Copilot Provider Agent", () => {
|
||||
it.skipIf(!githubCopilotToken)(
|
||||
"gpt-4o - should calculate multiple expressions and sum the results",
|
||||
{ retry: 3 },
|
||||
async () => {
|
||||
const model = getModel("github-copilot", "gpt-4o");
|
||||
const result = await calculateTest(model, { apiKey: githubCopilotToken });
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
|
||||
},
|
||||
);
|
||||
|
||||
it.skipIf(!githubCopilotToken)("gpt-4o - should handle abort during tool execution", { retry: 3 }, async () => {
|
||||
const model = getModel("github-copilot", "gpt-4o");
|
||||
const result = await abortTest(model, { apiKey: githubCopilotToken });
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
|
||||
});
|
||||
|
||||
it.skipIf(!githubCopilotToken)(
|
||||
"claude-sonnet-4 - should calculate multiple expressions and sum the results",
|
||||
{ retry: 3 },
|
||||
async () => {
|
||||
const model = getModel("github-copilot", "claude-sonnet-4");
|
||||
const result = await calculateTest(model, { apiKey: githubCopilotToken });
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
|
||||
},
|
||||
);
|
||||
|
||||
it.skipIf(!githubCopilotToken)(
|
||||
"claude-sonnet-4 - should handle abort during tool execution",
|
||||
{ retry: 3 },
|
||||
async () => {
|
||||
const model = getModel("github-copilot", "claude-sonnet-4");
|
||||
const result = await abortTest(model, { apiKey: githubCopilotToken });
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
describe("Google Gemini CLI Provider Agent", () => {
|
||||
it.skipIf(!geminiCliToken)(
|
||||
"gemini-2.5-flash - should calculate multiple expressions and sum the results",
|
||||
{ retry: 3 },
|
||||
async () => {
|
||||
const model = getModel("google-gemini-cli", "gemini-2.5-flash");
|
||||
const result = await calculateTest(model, { apiKey: geminiCliToken });
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
|
||||
},
|
||||
);
|
||||
|
||||
it.skipIf(!geminiCliToken)(
|
||||
"gemini-2.5-flash - should handle abort during tool execution",
|
||||
{ retry: 3 },
|
||||
async () => {
|
||||
const model = getModel("google-gemini-cli", "gemini-2.5-flash");
|
||||
const result = await abortTest(model, { apiKey: geminiCliToken });
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
describe("Google Antigravity Provider Agent", () => {
|
||||
it.skipIf(!antigravityToken)(
|
||||
"gemini-3-flash - should calculate multiple expressions and sum the results",
|
||||
{ retry: 3 },
|
||||
async () => {
|
||||
const model = getModel("google-antigravity", "gemini-3-flash");
|
||||
const result = await calculateTest(model, { apiKey: antigravityToken });
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
|
||||
},
|
||||
);
|
||||
|
||||
it.skipIf(!antigravityToken)(
|
||||
"gemini-3-flash - should handle abort during tool execution",
|
||||
{ retry: 3 },
|
||||
async () => {
|
||||
const model = getModel("google-antigravity", "gemini-3-flash");
|
||||
const result = await abortTest(model, { apiKey: antigravityToken });
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
|
||||
},
|
||||
);
|
||||
|
||||
it.skipIf(!antigravityToken)(
|
||||
"claude-sonnet-4-5 - should calculate multiple expressions and sum the results",
|
||||
{ retry: 3 },
|
||||
async () => {
|
||||
const model = getModel("google-antigravity", "claude-sonnet-4-5");
|
||||
const result = await calculateTest(model, { apiKey: antigravityToken });
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
|
||||
},
|
||||
);
|
||||
|
||||
it.skipIf(!antigravityToken)(
|
||||
"claude-sonnet-4-5 - should handle abort during tool execution",
|
||||
{ retry: 3 },
|
||||
async () => {
|
||||
const model = getModel("google-antigravity", "claude-sonnet-4-5");
|
||||
const result = await abortTest(model, { apiKey: antigravityToken });
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
|
||||
},
|
||||
);
|
||||
|
||||
it.skipIf(!antigravityToken)(
|
||||
"gpt-oss-120b-medium - should calculate multiple expressions and sum the results",
|
||||
{ retry: 3 },
|
||||
async () => {
|
||||
const model = getModel("google-antigravity", "gpt-oss-120b-medium");
|
||||
const result = await calculateTest(model, { apiKey: antigravityToken });
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
|
||||
},
|
||||
);
|
||||
|
||||
it.skipIf(!antigravityToken)(
|
||||
"gpt-oss-120b-medium - should handle abort during tool execution",
|
||||
{ retry: 3 },
|
||||
async () => {
|
||||
const model = getModel("google-antigravity", "gpt-oss-120b-medium");
|
||||
const result = await abortTest(model, { apiKey: antigravityToken });
|
||||
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
|
||||
},
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("agentLoopContinue", () => {
|
||||
describe("validation", () => {
|
||||
const model = getModel("anthropic", "claude-haiku-4-5");
|
||||
const baseContext: AgentContext = {
|
||||
systemPrompt: "You are a helpful assistant.",
|
||||
messages: [],
|
||||
tools: [],
|
||||
};
|
||||
const config: AgentLoopConfig = { model };
|
||||
|
||||
it("should throw when context has no messages", () => {
|
||||
expect(() => agentLoopContinue(baseContext, config)).toThrow("Cannot continue: no messages in context");
|
||||
});
|
||||
|
||||
it("should throw when last message is an assistant message", () => {
|
||||
const assistantMessage: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [{ type: "text", text: "Hello" }],
|
||||
api: "anthropic-messages",
|
||||
provider: "anthropic",
|
||||
model: "claude-haiku-4-5",
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "stop",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
const context: AgentContext = {
|
||||
...baseContext,
|
||||
messages: [assistantMessage],
|
||||
};
|
||||
expect(() => agentLoopContinue(context, config)).toThrow(
|
||||
"Cannot continue from message role: assistant. Expected 'user' or 'toolResult'.",
|
||||
);
|
||||
});
|
||||
|
||||
// Note: "should not throw" tests for valid inputs are covered by the E2E tests below
|
||||
// which actually consume the stream and verify the output
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.ANTHROPIC_API_KEY)("continue from user message", () => {
|
||||
const model = getModel("anthropic", "claude-haiku-4-5");
|
||||
|
||||
it("should continue and get assistant response when last message is user", { retry: 3 }, async () => {
|
||||
const userMessage: UserMessage = {
|
||||
role: "user",
|
||||
content: [{ type: "text", text: "Say exactly: HELLO WORLD" }],
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
const context: AgentContext = {
|
||||
systemPrompt: "You are a helpful assistant. Follow instructions exactly.",
|
||||
messages: [userMessage],
|
||||
tools: [],
|
||||
};
|
||||
|
||||
const config: AgentLoopConfig = { model };
|
||||
|
||||
const events: AgentEvent[] = [];
|
||||
const stream = agentLoopContinue(context, config);
|
||||
|
||||
for await (const event of stream) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
const messages = await stream.result();
|
||||
|
||||
// Should have gotten an assistant response
|
||||
expect(messages.length).toBe(1);
|
||||
expect(messages[0].role).toBe("assistant");
|
||||
|
||||
// Verify event sequence - no user message events since we're continuing
|
||||
const eventTypes = events.map((e) => e.type);
|
||||
expect(eventTypes).toContain("agent_start");
|
||||
expect(eventTypes).toContain("turn_start");
|
||||
expect(eventTypes).toContain("message_start");
|
||||
expect(eventTypes).toContain("message_end");
|
||||
expect(eventTypes).toContain("turn_end");
|
||||
expect(eventTypes).toContain("agent_end");
|
||||
|
||||
// Should NOT have user message events (that's the difference from agentLoop)
|
||||
const messageEndEvents = events.filter((e) => e.type === "message_end");
|
||||
expect(messageEndEvents.length).toBe(1); // Only assistant message
|
||||
expect((messageEndEvents[0] as any).message.role).toBe("assistant");
|
||||
});
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.ANTHROPIC_API_KEY)("continue from tool result", () => {
|
||||
const model = getModel("anthropic", "claude-haiku-4-5");
|
||||
|
||||
it("should continue processing after tool results", { retry: 3 }, async () => {
|
||||
// Simulate a conversation where:
|
||||
// 1. User asked to calculate something
|
||||
// 2. Assistant made a tool call
|
||||
// 3. Tool result is ready
|
||||
// 4. We continue from here
|
||||
|
||||
const userMessage: UserMessage = {
|
||||
role: "user",
|
||||
content: [{ type: "text", text: "What is 5 + 3? Use the calculator." }],
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
const assistantMessage: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [
|
||||
{ type: "text", text: "Let me calculate that for you." },
|
||||
{ type: "toolCall", id: "calc-1", name: "calculate", arguments: { expression: "5 + 3" } },
|
||||
],
|
||||
api: "anthropic-messages",
|
||||
provider: "anthropic",
|
||||
model: "claude-haiku-4-5",
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "toolUse",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
const toolResult: ToolResultMessage = {
|
||||
role: "toolResult",
|
||||
toolCallId: "calc-1",
|
||||
toolName: "calculate",
|
||||
content: [{ type: "text", text: "5 + 3 = 8" }],
|
||||
isError: false,
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
const context: AgentContext = {
|
||||
systemPrompt: "You are a helpful assistant. After getting a calculation result, state the answer clearly.",
|
||||
messages: [userMessage, assistantMessage, toolResult],
|
||||
tools: [calculateTool],
|
||||
};
|
||||
|
||||
const config: AgentLoopConfig = { model };
|
||||
|
||||
const events: AgentEvent[] = [];
|
||||
const stream = agentLoopContinue(context, config);
|
||||
|
||||
for await (const event of stream) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
const messages = await stream.result();
|
||||
|
||||
// Should have gotten an assistant response
|
||||
expect(messages.length).toBeGreaterThanOrEqual(1);
|
||||
const lastMessage = messages[messages.length - 1];
|
||||
expect(lastMessage.role).toBe("assistant");
|
||||
|
||||
// The assistant should mention the result (8)
|
||||
if (lastMessage.role === "assistant") {
|
||||
const textContent = lastMessage.content
|
||||
.filter((c) => c.type === "text")
|
||||
.map((c) => (c as any).text)
|
||||
.join(" ");
|
||||
expect(textContent).toMatch(/8/);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -752,13 +752,11 @@ describe("Generate E2E Tests", () => {
|
|||
const llm = getModel("google-gemini-cli", "gemini-3-flash-preview");
|
||||
|
||||
it.skipIf(!geminiCliToken)("should handle thinking with thinkingLevel", { retry: 3 }, async () => {
|
||||
const { ThinkingLevel } = await import("@google/genai");
|
||||
await handleThinking(llm, { apiKey: geminiCliToken, thinking: { enabled: true, level: ThinkingLevel.LOW } });
|
||||
await handleThinking(llm, { apiKey: geminiCliToken, thinking: { enabled: true, level: "LOW" } });
|
||||
});
|
||||
|
||||
it.skipIf(!geminiCliToken)("should handle multi-turn with thinking and tools", { retry: 3 }, async () => {
|
||||
const { ThinkingLevel } = await import("@google/genai");
|
||||
await multiTurn(llm, { apiKey: geminiCliToken, thinking: { enabled: true, level: ThinkingLevel.MEDIUM } });
|
||||
await multiTurn(llm, { apiKey: geminiCliToken, thinking: { enabled: true, level: "MEDIUM" } });
|
||||
});
|
||||
});
|
||||
|
||||
|
|
@ -778,17 +776,15 @@ describe("Generate E2E Tests", () => {
|
|||
});
|
||||
|
||||
it.skipIf(!antigravityToken)("should handle thinking with thinkingLevel", { retry: 3 }, async () => {
|
||||
const { ThinkingLevel } = await import("@google/genai");
|
||||
// gemini-3-flash supports all four levels: MINIMAL, LOW, MEDIUM, HIGH
|
||||
await handleThinking(llm, {
|
||||
apiKey: antigravityToken,
|
||||
thinking: { enabled: true, level: ThinkingLevel.LOW },
|
||||
thinking: { enabled: true, level: "LOW" },
|
||||
});
|
||||
});
|
||||
|
||||
it.skipIf(!antigravityToken)("should handle multi-turn with thinking and tools", { retry: 3 }, async () => {
|
||||
const { ThinkingLevel } = await import("@google/genai");
|
||||
await multiTurn(llm, { apiKey: antigravityToken, thinking: { enabled: true, level: ThinkingLevel.MEDIUM } });
|
||||
await multiTurn(llm, { apiKey: antigravityToken, thinking: { enabled: true, level: "MEDIUM" } });
|
||||
});
|
||||
|
||||
it.skipIf(!antigravityToken)("should handle image input", { retry: 3 }, async () => {
|
||||
|
|
@ -800,11 +796,10 @@ describe("Generate E2E Tests", () => {
|
|||
const llm = getModel("google-antigravity", "gemini-3-pro-high");
|
||||
|
||||
it.skipIf(!antigravityToken)("should handle thinking with thinkingLevel HIGH", { retry: 3 }, async () => {
|
||||
const { ThinkingLevel } = await import("@google/genai");
|
||||
// gemini-3-pro only supports LOW/HIGH
|
||||
await handleThinking(llm, {
|
||||
apiKey: antigravityToken,
|
||||
thinking: { enabled: true, level: ThinkingLevel.HIGH },
|
||||
thinking: { enabled: true, level: "HIGH" },
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -1,140 +0,0 @@
|
|||
import { type Static, Type } from "@sinclair/typebox";
|
||||
import AjvModule from "ajv";
|
||||
import addFormatsModule from "ajv-formats";
|
||||
|
||||
// Handle both default and named exports
|
||||
const Ajv = (AjvModule as any).default || AjvModule;
|
||||
const addFormats = (addFormatsModule as any).default || addFormatsModule;
|
||||
|
||||
import { describe, expect, it } from "vitest";
|
||||
import type { AgentTool } from "../src/agent/types.js";
|
||||
|
||||
describe("Tool Validation with TypeBox and AJV", () => {
|
||||
// Define a test tool with TypeBox schema
|
||||
const testSchema = Type.Object({
|
||||
name: Type.String({ minLength: 1 }),
|
||||
age: Type.Integer({ minimum: 0, maximum: 150 }),
|
||||
email: Type.String({ format: "email" }),
|
||||
tags: Type.Optional(Type.Array(Type.String())),
|
||||
});
|
||||
|
||||
type TestParams = Static<typeof testSchema>;
|
||||
|
||||
const testTool: AgentTool<typeof testSchema, void> = {
|
||||
label: "Test Tool",
|
||||
name: "test_tool",
|
||||
description: "A test tool for validation",
|
||||
parameters: testSchema,
|
||||
execute: async (_toolCallId, args) => {
|
||||
return {
|
||||
content: [{ type: "text", text: `Processed: ${args.name}, ${args.age}, ${args.email}` }],
|
||||
details: undefined,
|
||||
};
|
||||
},
|
||||
};
|
||||
|
||||
// Create AJV instance for validation
|
||||
const ajv = new Ajv({ allErrors: true });
|
||||
addFormats(ajv);
|
||||
|
||||
it("should validate correct input", () => {
|
||||
const validInput = {
|
||||
name: "John Doe",
|
||||
age: 30,
|
||||
email: "john@example.com",
|
||||
tags: ["developer", "typescript"],
|
||||
};
|
||||
|
||||
// Validate with AJV
|
||||
const validate = ajv.compile(testTool.parameters);
|
||||
const isValid = validate(validInput);
|
||||
expect(isValid).toBe(true);
|
||||
});
|
||||
|
||||
it("should reject invalid email", () => {
|
||||
const invalidInput = {
|
||||
name: "John Doe",
|
||||
age: 30,
|
||||
email: "not-an-email",
|
||||
};
|
||||
|
||||
const validate = ajv.compile(testTool.parameters);
|
||||
const isValid = validate(invalidInput);
|
||||
expect(isValid).toBe(false);
|
||||
expect(validate.errors).toBeDefined();
|
||||
});
|
||||
|
||||
it("should reject missing required fields", () => {
|
||||
const invalidInput = {
|
||||
age: 30,
|
||||
email: "john@example.com",
|
||||
};
|
||||
|
||||
const validate = ajv.compile(testTool.parameters);
|
||||
const isValid = validate(invalidInput);
|
||||
expect(isValid).toBe(false);
|
||||
expect(validate.errors).toBeDefined();
|
||||
});
|
||||
|
||||
it("should reject invalid age", () => {
|
||||
const invalidInput = {
|
||||
name: "John Doe",
|
||||
age: -5,
|
||||
email: "john@example.com",
|
||||
};
|
||||
|
||||
const validate = ajv.compile(testTool.parameters);
|
||||
const isValid = validate(invalidInput);
|
||||
expect(isValid).toBe(false);
|
||||
expect(validate.errors).toBeDefined();
|
||||
});
|
||||
|
||||
it("should format validation errors nicely", () => {
|
||||
const invalidInput = {
|
||||
name: "",
|
||||
age: 200,
|
||||
email: "invalid",
|
||||
};
|
||||
|
||||
const validate = ajv.compile(testTool.parameters);
|
||||
const isValid = validate(invalidInput);
|
||||
expect(isValid).toBe(false);
|
||||
expect(validate.errors).toBeDefined();
|
||||
|
||||
if (validate.errors) {
|
||||
const errors = validate.errors
|
||||
.map((err: any) => {
|
||||
const path = err.instancePath ? err.instancePath.substring(1) : err.params.missingProperty || "root";
|
||||
return ` - ${path}: ${err.message}`;
|
||||
})
|
||||
.join("\n");
|
||||
|
||||
// AJV error messages are different from Zod
|
||||
expect(errors).toContain("name: must NOT have fewer than 1 characters");
|
||||
expect(errors).toContain("age: must be <= 150");
|
||||
expect(errors).toContain('email: must match format "email"');
|
||||
}
|
||||
});
|
||||
|
||||
it("should have type-safe execute function", async () => {
|
||||
const validInput = {
|
||||
name: "John Doe",
|
||||
age: 30,
|
||||
email: "john@example.com",
|
||||
};
|
||||
|
||||
// Validate and execute
|
||||
const validate = ajv.compile(testTool.parameters);
|
||||
const isValid = validate(validInput);
|
||||
expect(isValid).toBe(true);
|
||||
|
||||
const result = await testTool.execute("test-id", validInput as TestParams);
|
||||
|
||||
const textOutput = result.content
|
||||
.filter((c: any) => c.type === "text")
|
||||
.map((c: any) => c.text)
|
||||
.join("\n");
|
||||
expect(textOutput).toBe("Processed: John Doe, 30, john@example.com");
|
||||
expect(result.details).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
|
@ -2,13 +2,68 @@
|
|||
|
||||
## [Unreleased]
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
- **Session tree structure (v2)**: Sessions now store entries as a tree with `id`/`parentId` fields, enabling in-place branching without creating new files. Existing v1 sessions are auto-migrated on load.
|
||||
- **SessionManager API**:
|
||||
- `saveXXX()` renamed to `appendXXX()` (e.g., `appendMessage`, `appendCompaction`)
|
||||
- `branchInPlace()` renamed to `branch()`
|
||||
- `reset()` renamed to `newSession()`
|
||||
- `createBranchedSessionFromEntries(entries, index)` replaced with `createBranchedSession(leafId)`
|
||||
- `saveCompaction(entry)` replaced with `appendCompaction(summary, firstKeptEntryId, tokensBefore)`
|
||||
- `getEntries()` now excludes the session header (use `getHeader()` separately)
|
||||
- New methods: `getTree()`, `getPath()`, `getLeafUuid()`, `getLeafEntry()`, `getEntry()`, `branchWithSummary()`
|
||||
- New `appendCustomEntry(customType, data)` for hooks to store custom data (not in LLM context)
|
||||
- New `appendCustomMessageEntry(customType, content, display, details?)` for hooks to inject messages into LLM context
|
||||
- **Compaction API**:
|
||||
- `CompactionEntry<T>` and `CompactionResult<T>` are now generic with optional `details?: T` for hook-specific data
|
||||
- `compact()` now returns `CompactionResult` (`{ summary, firstKeptEntryId, tokensBefore, details? }`) instead of `CompactionEntry`
|
||||
- `appendCompaction()` now accepts optional `details` parameter
|
||||
- `CompactionEntry.firstKeptEntryIndex` replaced with `firstKeptEntryId`
|
||||
- `prepareCompaction()` now returns `firstKeptEntryId` in its result
|
||||
- **Hook types**:
|
||||
- `SessionEventBase` no longer has `sessionManager`/`modelRegistry` - access them via `HookEventContext` instead
|
||||
- `HookEventContext` now has `sessionManager` and `modelRegistry` (moved from events)
|
||||
- `HookEventContext` no longer has `exec()` - use `pi.exec()` instead
|
||||
- `HookCommandContext` no longer has `exec()` - use `pi.exec()` instead
|
||||
- `before_compact` event passes `preparation: CompactionPreparation` and `previousCompactions: CompactionEntry[]` (newest first)
|
||||
- `before_switch` event now has `targetSessionFile`, `switch` event has `previousSessionFile`
|
||||
- Removed `resolveApiKey` (use `modelRegistry.getApiKey(model)`)
|
||||
- Hooks can return `compaction.details` to store custom data (e.g., ArtifactIndex for structured compaction)
|
||||
- **Hook API**:
|
||||
- `pi.send(text, attachments?)` replaced with `pi.sendMessage(message, triggerTurn?)` which creates `CustomMessageEntry` instead of user messages
|
||||
- New `pi.appendEntry(customType, data?)` to persist hook state (does NOT participate in LLM context)
|
||||
- New `pi.registerCommand(name, options)` to register custom slash commands
|
||||
- New `pi.registerMessageRenderer(customType, renderer)` to register custom renderers for hook messages
|
||||
- New `pi.exec(command, args, options?)` to execute shell commands (moved from `HookEventContext`/`HookCommandContext`)
|
||||
- `HookMessageRenderer` type: `(message: HookMessage, options, theme) => Component | null`
|
||||
- Renderers return inner content; the TUI wraps it in a styled Box
|
||||
- New types: `HookMessage<T>`, `RegisteredCommand`, `HookCommandContext`
|
||||
- Handler types renamed: `SendHandler` → `SendMessageHandler`, new `AppendEntryHandler`
|
||||
- **SessionManager**:
|
||||
- `getSessionFile()` now returns `string | undefined` (undefined for in-memory sessions)
|
||||
- **Themes**: Custom themes must add `customMessageBg`, `customMessageText`, `customMessageLabel` color tokens
|
||||
|
||||
### Added
|
||||
|
||||
- **`enabledModels` setting**: Configure whitelisted models in `settings.json` (same format as `--models` CLI flag). CLI `--models` takes precedence over the setting.
|
||||
|
||||
### Changed
|
||||
|
||||
- **Entry IDs**: Session entries now use short 8-character hex IDs instead of full UUIDs
|
||||
- **API key priority**: `ANTHROPIC_OAUTH_TOKEN` now takes precedence over `ANTHROPIC_API_KEY`
|
||||
- **New entry types**: `BranchSummaryEntry` for branch context, `CustomEntry<T>` for hook state persistence, `CustomMessageEntry<T>` for hook-injected context messages, `LabelEntry` for user-defined bookmarks
|
||||
- **Entry labels**: New `getLabel(id)` and `appendLabelChange(targetId, label)` methods for labeling entries. Labels are included in `SessionTreeNode` for UI/export.
|
||||
- **TUI**: `CustomMessageEntry` renders with purple styling (customMessageBg, customMessageText, customMessageLabel theme colors). Entries with `display: false` are hidden.
|
||||
- **AgentSession**: New `sendHookMessage(message, triggerTurn?)` method for hooks to inject messages. Handles queuing during streaming, direct append when idle, and optional turn triggering.
|
||||
- **HookMessage**: New message type with `role: "hookMessage"` for hook-injected messages in agent events. Use `isHookMessage(msg)` type guard to identify them. These are converted to user messages for LLM context via `messageTransformer`.
|
||||
- **Agent.prompt()**: Now accepts `AppMessage` directly (in addition to `string, attachments?`) for custom message types like `HookMessage`.
|
||||
|
||||
### Fixed
|
||||
|
||||
- **Edit tool fails on Windows due to CRLF line endings**: Files with CRLF line endings now match correctly when LLMs send LF-only text. Line endings are normalized before matching and restored to original style on write. ([#355](https://github.com/badlogic/pi-mono/issues/355))
|
||||
- **Session file validation**: `findMostRecentSession()` now validates session headers before returning, preventing non-session JSONL files from being loaded
|
||||
- **Compaction error handling**: `generateSummary()` and `generateTurnPrefixSummary()` now throw on LLM errors instead of returning empty strings
|
||||
|
||||
## [0.30.2] - 2025-12-26
|
||||
|
||||
|
|
|
|||
441
packages/coding-agent/docs/session-tree-plan.md
Normal file
441
packages/coding-agent/docs/session-tree-plan.md
Normal file
|
|
@ -0,0 +1,441 @@
|
|||
# Session Tree Implementation Plan
|
||||
|
||||
Reference: [session-tree.md](./session-tree.md)
|
||||
|
||||
## Phase 1: SessionManager Core ✅
|
||||
|
||||
- [x] Update entry types with `id`, `parentId` fields (using SessionEntryBase)
|
||||
- [x] Add `version` field to `SessionHeader`
|
||||
- [x] Change `CompactionEntry.firstKeptEntryIndex` → `firstKeptEntryId`
|
||||
- [x] Add `BranchSummaryEntry` type
|
||||
- [x] Add `CustomEntry` type for hooks
|
||||
- [x] Add `byId: Map<string, SessionEntry>` index
|
||||
- [x] Add `leafId: string` tracking
|
||||
- [x] Implement `getPath(fromId?)` tree traversal
|
||||
- [x] Implement `getTree()` returning `SessionTreeNode[]`
|
||||
- [x] Implement `getEntry(id)` lookup
|
||||
- [x] Implement `getLeafUuid()` and `getLeafEntry()` helpers
|
||||
- [x] Update `_buildIndex()` to populate `byId` map
|
||||
- [x] Rename `saveXXX()` to `appendXXX()` (returns id, advances leaf)
|
||||
- [x] Add `appendCustomEntry(customType, data)` for hooks
|
||||
- [x] Update `buildSessionContext()` to use `getPath()` traversal
|
||||
|
||||
## Phase 2: Migration ✅
|
||||
|
||||
- [x] Add `CURRENT_SESSION_VERSION = 2` constant
|
||||
- [x] Implement `migrateV1ToV2()` with extensible migration chain
|
||||
- [x] Update `setSessionFile()` to detect version and migrate
|
||||
- [x] Implement `_rewriteFile()` for post-migration persistence
|
||||
- [x] Handle `firstKeptEntryIndex` → `firstKeptEntryId` conversion in migration
|
||||
|
||||
## Phase 3: Branching ✅
|
||||
|
||||
- [x] Implement `branch(id)` - switch leaf pointer
|
||||
- [x] Implement `branchWithSummary(id, summary)` - create summary entry
|
||||
- [x] Implement `createBranchedSession(leafId)` - extract path to new file
|
||||
- [x] Update `AgentSession.branch()` to use new API
|
||||
|
||||
## Phase 4: Compaction Integration ✅
|
||||
|
||||
- [x] Update `compaction.ts` to work with IDs
|
||||
- [x] Update `prepareCompaction()` to return `firstKeptEntryId`
|
||||
- [x] Update `compact()` to return `CompactionResult` with `firstKeptEntryId`
|
||||
- [x] Update `AgentSession` compaction methods
|
||||
- [x] Add `firstKeptEntryId` to `before_compact` hook event
|
||||
|
||||
## Phase 5: Testing ✅
|
||||
|
||||
- [x] `migration.test.ts` - v1 to v2 migration, idempotency
|
||||
- [x] `build-context.test.ts` - context building with tree structure, compaction, branches
|
||||
- [x] `tree-traversal.test.ts` - append operations, getPath, getTree, branching
|
||||
- [x] `file-operations.test.ts` - loadEntriesFromFile, findMostRecentSession
|
||||
- [x] `save-entry.test.ts` - custom entry integration
|
||||
- [x] Update existing compaction tests for new types
|
||||
|
||||
---
|
||||
|
||||
## Remaining Work
|
||||
|
||||
### Compaction Refactor
|
||||
|
||||
- [x] Use `CompactionResult` type for hook return value
|
||||
- [x] Make `CompactionEntry<T>` generic with optional `details?: T` field for hook-specific data
|
||||
- [x] Make `CompactionResult<T>` generic to match
|
||||
- [x] Update `SessionEventBase` to pass `sessionManager` and `modelRegistry` instead of derived fields
|
||||
- [x] Update `before_compact` event:
|
||||
- Pass `preparation: CompactionPreparation` instead of individual fields
|
||||
- Pass `previousCompactions: CompactionEntry[]` (newest first) instead of `previousSummary?: string`
|
||||
- Keep: `customInstructions`, `model`, `signal`
|
||||
- Drop: `resolveApiKey` (use `modelRegistry.getApiKey()`), `cutPoint`, `entries`
|
||||
- [x] Update hook example `custom-compaction.ts` to use new API
|
||||
- [x] Update `getSessionFile()` to return `string | undefined` for in-memory sessions
|
||||
- [x] Update `before_switch` to have `targetSessionFile`, `switch` to have `previousSessionFile`
|
||||
|
||||
Reference: [#314](https://github.com/badlogic/pi-mono/pull/314) - Structured compaction with anchored iterative summarization needs `details` field to store `ArtifactIndex` and version markers.
|
||||
|
||||
### Branch Summary Design ✅
|
||||
|
||||
Current type:
|
||||
```typescript
|
||||
export interface BranchSummaryEntry extends SessionEntryBase {
|
||||
type: "branch_summary";
|
||||
summary: string;
|
||||
fromId: string; // References the abandoned leaf
|
||||
fromHook?: boolean; // Whether summary was generated by a hook
|
||||
details?: unknown; // File tracking: { readFiles, modifiedFiles }
|
||||
}
|
||||
```
|
||||
|
||||
- [x] `fromId` field references the abandoned leaf
|
||||
- [x] `fromHook` field distinguishes pi-generated vs hook-generated summaries
|
||||
- [x] `details` field for file tracking
|
||||
- [x] Branch summarizer implemented with structured output format
|
||||
- [x] Uses serialization approach (same as compaction) to prevent model confusion
|
||||
- [x] Tests for `branchWithSummary()` flow
|
||||
|
||||
### Entry Labels ✅
|
||||
|
||||
- [x] Add `LabelEntry` type with `targetId` and `label` fields
|
||||
- [x] Add `labelsById: Map<string, string>` private field
|
||||
- [x] Build labels map in `_buildIndex()` via linear scan
|
||||
- [x] Add `getLabel(id)` method
|
||||
- [x] Add `appendLabelChange(targetId, label)` method (undefined clears)
|
||||
- [x] Update `createBranchedSession()` to filter out LabelEntry and recreate from resolved map
|
||||
- [x] `buildSessionContext()` already ignores LabelEntry (only handles message types)
|
||||
- [x] Add `label?: string` to `SessionTreeNode`, populated by `getTree()`
|
||||
- [x] Display labels in UI (tree-selector shows labels)
|
||||
- [x] `/label` command (implemented in tree-selector)
|
||||
|
||||
### CustomMessageEntry<T>
|
||||
|
||||
Hook-injected messages that participate in LLM context. Unlike `CustomEntry<T>` (for hook state only), these are sent to the model.
|
||||
|
||||
```typescript
|
||||
export interface CustomMessageEntry<T = unknown> extends SessionEntryBase {
|
||||
type: "custom_message";
|
||||
customType: string; // Hook identifier
|
||||
content: string | (TextContent | ImageContent)[]; // Message content (same as UserMessage)
|
||||
details?: T; // Hook-specific data for state reconstruction on reload
|
||||
display: boolean; // Whether to display in TUI
|
||||
}
|
||||
```
|
||||
|
||||
Behavior:
|
||||
- [x] Type definition matching plan
|
||||
- [x] `appendCustomMessageEntry(customType, content, display, details?)` in SessionManager
|
||||
- [x] `buildSessionContext()` includes custom_message entries as user messages
|
||||
- [x] Exported from main index
|
||||
- [x] TUI rendering:
|
||||
- `display: false` - hidden entirely
|
||||
- `display: true` - rendered with purple styling (customMessageBg, customMessageText, customMessageLabel theme colors)
|
||||
- [x] `registerCustomMessageRenderer(customType, renderer)` in HookAPI for custom renderers
|
||||
- [x] Renderer returns inner Component, TUI wraps in styled Box
|
||||
|
||||
### Hook API Changes ✅
|
||||
|
||||
**Renamed:**
|
||||
- `renderCustomMessage()` → `registerCustomMessageRenderer()`
|
||||
|
||||
**New: `sendMessage()` ✅**
|
||||
|
||||
Replaces `send()`. Always creates CustomMessageEntry, never user messages.
|
||||
|
||||
```typescript
|
||||
type HookMessage<T = unknown> = Pick<CustomMessageEntry<T>, 'customType' | 'content' | 'display' | 'details'>;
|
||||
|
||||
sendMessage(message: HookMessage, triggerTurn?: boolean): void;
|
||||
```
|
||||
|
||||
Implementation:
|
||||
- Uses agent's queue mechanism with `_hookData` marker on AppMessage
|
||||
- `message_end` handler routes based on marker presence
|
||||
- `AgentSession.sendHookMessage()` handles three cases:
|
||||
- Streaming: queues via `agent.queueMessage()`, loop processes and emits `message_end`
|
||||
- Not streaming + triggerTurn: direct append + `agent.continue()`
|
||||
- Not streaming + no trigger: direct append only
|
||||
- TUI updates via event (streaming) or explicit rebuild (non-streaming)
|
||||
|
||||
**New: `appendEntry()` ✅**
|
||||
|
||||
For hook state persistence (NOT in LLM context):
|
||||
|
||||
```typescript
|
||||
appendEntry(customType: string, data?: unknown): void;
|
||||
```
|
||||
|
||||
Calls `sessionManager.appendCustomEntry()` directly.
|
||||
|
||||
**New: `registerCommand()` (types ✅, wiring TODO)**
|
||||
|
||||
```typescript
|
||||
// HookAPI (the `pi` object) - utilities available to all hooks:
|
||||
interface HookAPI {
|
||||
sendMessage(message: HookMessage, triggerTurn?: boolean): void;
|
||||
appendEntry(customType: string, data?: unknown): void;
|
||||
registerCommand(name: string, options: RegisteredCommand): void;
|
||||
registerCustomMessageRenderer(customType: string, renderer: CustomMessageRenderer): void;
|
||||
exec(command: string, args: string[], options?: ExecOptions): Promise<ExecResult>;
|
||||
}
|
||||
|
||||
// HookEventContext - passed to event handlers, has stable context:
|
||||
interface HookEventContext {
|
||||
ui: HookUIContext;
|
||||
hasUI: boolean;
|
||||
cwd: string;
|
||||
sessionManager: SessionManager;
|
||||
modelRegistry: ModelRegistry;
|
||||
}
|
||||
// Note: exec moved to HookAPI, sessionManager/modelRegistry moved from SessionEventBase
|
||||
|
||||
// HookCommandContext - passed to command handlers:
|
||||
interface HookCommandContext {
|
||||
args: string; // Everything after /commandname
|
||||
ui: HookUIContext;
|
||||
hasUI: boolean;
|
||||
cwd: string;
|
||||
sessionManager: SessionManager;
|
||||
modelRegistry: ModelRegistry;
|
||||
}
|
||||
// Note: exec and sendMessage accessed via `pi` closure
|
||||
|
||||
registerCommand(name: string, options: {
|
||||
description?: string;
|
||||
handler: (ctx: HookCommandContext) => Promise<void>;
|
||||
}): void;
|
||||
```
|
||||
|
||||
Handler return:
|
||||
- `void` - command completed (use `sendMessage()` with `triggerTurn: true` to prompt LLM)
|
||||
|
||||
Wiring (all in AgentSession.prompt()):
|
||||
- [x] Add hook commands to autocomplete in interactive-mode
|
||||
- [x] `_tryExecuteHookCommand()` in AgentSession handles command execution
|
||||
- [x] Build HookCommandContext with ui (from hookRunner), exec, sessionManager, etc.
|
||||
- [x] If handler returns string, use as prompt text
|
||||
- [x] If handler returns undefined, return early (no LLM call)
|
||||
- [x] Works for all modes (interactive, RPC, print) via shared AgentSession
|
||||
|
||||
**New: `ui.custom()` ✅**
|
||||
|
||||
For arbitrary hook UI with keyboard focus:
|
||||
|
||||
```typescript
|
||||
interface HookUIContext {
|
||||
// ... existing: select, confirm, input, notify
|
||||
|
||||
/** Show custom component with keyboard focus. Call done() when finished. */
|
||||
custom(component: Component, done: () => void): void;
|
||||
}
|
||||
```
|
||||
|
||||
See also: `CustomEntry<T>` for storing hook state that does NOT participate in context.
|
||||
|
||||
**New: `context` event ✅**
|
||||
|
||||
Fires before messages are sent to the LLM, allowing hooks to modify context non-destructively.
|
||||
|
||||
```typescript
|
||||
interface ContextEvent {
|
||||
type: "context";
|
||||
/** Messages that will be sent to the LLM */
|
||||
messages: Message[];
|
||||
}
|
||||
|
||||
interface ContextEventResult {
|
||||
/** Modified messages to send instead */
|
||||
messages?: Message[];
|
||||
}
|
||||
|
||||
// In HookAPI:
|
||||
on(event: "context", handler: HookHandler<ContextEvent, ContextEventResult | void>): void;
|
||||
```
|
||||
|
||||
Example use case: **Dynamic Context Pruning** ([discussion #330](https://github.com/badlogic/pi-mono/discussions/330))
|
||||
|
||||
Non-destructive pruning of tool results to reduce context size:
|
||||
|
||||
```typescript
|
||||
export default function(pi: HookAPI) {
|
||||
// Register /prune command
|
||||
pi.registerCommand("prune", {
|
||||
description: "Mark tool results for pruning",
|
||||
handler: async (ctx) => {
|
||||
// Show UI to select which tool results to prune
|
||||
// Append custom entry recording pruning decisions:
|
||||
// { toolResultId, strategy: "summary" | "truncate" | "remove" }
|
||||
pi.appendEntry("tool-result-pruning", { ... });
|
||||
}
|
||||
});
|
||||
|
||||
// Intercept context before LLM call
|
||||
pi.on("context", async (event, ctx) => {
|
||||
// Find all pruning entries in session
|
||||
const entries = ctx.sessionManager.getEntries();
|
||||
const pruningRules = entries
|
||||
.filter(e => e.type === "custom" && e.customType === "tool-result-pruning")
|
||||
.map(e => e.data);
|
||||
|
||||
// Apply pruning rules to messages
|
||||
const prunedMessages = applyPruning(event.messages, pruningRules);
|
||||
return { messages: prunedMessages };
|
||||
});
|
||||
}
|
||||
```
|
||||
|
||||
Benefits:
|
||||
- Original tool results stay intact in session
|
||||
- Pruning is stored as custom entries, survives session reload
|
||||
- Works with branching (pruning entries are part of the tree)
|
||||
- Trade-off: cache busting on first submission after pruning
|
||||
|
||||
### Investigate: `context` event vs `before_agent_start` ✅
|
||||
|
||||
References:
|
||||
- [#324](https://github.com/badlogic/pi-mono/issues/324) - `before_agent_start` proposal
|
||||
- [#330](https://github.com/badlogic/pi-mono/discussions/330) - Dynamic Context Pruning (why `context` was added)
|
||||
|
||||
**Current `context` event:**
|
||||
- Fires before each LLM call within the agent loop
|
||||
- Receives `AgentMessage[]` (deep copy, safe to modify)
|
||||
- Returns `Message[]` (inconsistent with input type)
|
||||
- Modifications are transient (not persisted to session)
|
||||
- No TUI visibility of what was changed
|
||||
- Use case: non-destructive pruning, dynamic context manipulation
|
||||
|
||||
**Type inconsistency:** Event receives `AgentMessage[]` but result returns `Message[]`:
|
||||
```typescript
|
||||
interface ContextEvent {
|
||||
messages: AgentMessage[]; // Input
|
||||
}
|
||||
interface ContextEventResult {
|
||||
messages?: Message[]; // Output - different type!
|
||||
}
|
||||
```
|
||||
|
||||
Questions:
|
||||
- [ ] Should input/output both be `Message[]` (LLM format)?
|
||||
- [ ] Or both be `AgentMessage[]` with conversion happening after?
|
||||
- [ ] Where does `AgentMessage[]` → `Message[]` conversion currently happen?
|
||||
|
||||
**Proposed `before_agent_start` event:**
|
||||
- Fires once when user submits a prompt, before `agent_start`
|
||||
- Allows hooks to inject additional content that gets **persisted** to session
|
||||
- Injected content is visible in TUI (observability)
|
||||
- Does not bust prompt cache (appended after user message, not modifying system prompt)
|
||||
|
||||
**Key difference:**
|
||||
| Aspect | `context` | `before_agent_start` |
|
||||
|--------|-----------|---------------------|
|
||||
| When | Before each LLM call | Once per user prompt |
|
||||
| Persisted | No | Yes (as SystemMessage) |
|
||||
| TUI visible | No | Yes (collapsible) |
|
||||
| Cache impact | Can bust cache | Append-only, cache-safe |
|
||||
| Use case | Transient manipulation | Persistent context injection |
|
||||
|
||||
**Implementation (completed):**
|
||||
- Reuses `HookMessage` type (no new message type needed)
|
||||
- Handler returns `{ message: Pick<HookMessage, "customType" | "content" | "display" | "details"> }`
|
||||
- Message is appended to agent state AND persisted to session before `agent.prompt()` is called
|
||||
- Renders using existing `HookMessageComponent` (or custom renderer if registered)
|
||||
- [ ] How does it interact with compaction? (treated like user messages?)
|
||||
- [ ] Can hook return multiple messages or just one?
|
||||
|
||||
**Implementation sketch:**
|
||||
```typescript
|
||||
interface BeforeAgentStartEvent {
|
||||
type: "before_agent_start";
|
||||
userMessage: UserMessage; // The prompt user just submitted
|
||||
}
|
||||
|
||||
interface BeforeAgentStartResult {
|
||||
/** Additional context to inject (persisted as SystemMessage) */
|
||||
inject?: {
|
||||
label: string; // Shown in collapsed TUI state
|
||||
content: string | (TextContent | ImageContent)[];
|
||||
};
|
||||
}
|
||||
```
|
||||
|
||||
### HTML Export
|
||||
|
||||
- [ ] Add collapsible sidebar showing full tree structure
|
||||
- [ ] Allow selecting any node in tree to view that path
|
||||
- [ ] Add "reset to session leaf" button
|
||||
- [ ] Render full path (no compaction resolution needed)
|
||||
- [ ] Responsive: collapse sidebar on mobile
|
||||
|
||||
### UI Commands ✅
|
||||
|
||||
- [x] `/branch` - Creates new session file from current path (uses `createBranchedSession()`)
|
||||
- [x] `/tree` - In-session tree navigation via tree-selector component
|
||||
- Shows full tree structure with labels
|
||||
- Navigate between branches (moves leaf pointer)
|
||||
- Shows current position
|
||||
- Generates branch summaries when switching branches
|
||||
|
||||
### Tree Selector Improvements ✅
|
||||
|
||||
- [x] Active line highlight using `selectedBg` theme color
|
||||
- [x] Filter modes via `^O` (forward) / `Shift+^O` (backward):
|
||||
- `default`: hides label/custom entries
|
||||
- `no-tools`: default minus tool results
|
||||
- `user-only`: just user messages
|
||||
- `labeled-only`: just labeled entries
|
||||
- `all`: everything
|
||||
|
||||
### Documentation
|
||||
|
||||
Review and update all docs:
|
||||
|
||||
- [ ] `docs/hooks.md` - Major update for hook API:
|
||||
- `pi.send()` → `pi.sendMessage()` with new signature
|
||||
- New `pi.appendEntry()` for state persistence
|
||||
- New `pi.registerCommand()` for custom slash commands
|
||||
- New `pi.registerCustomMessageRenderer()` for custom TUI rendering
|
||||
- `HookCommandContext` interface and handler patterns
|
||||
- `HookMessage<T>` type
|
||||
- Updated event signatures (`SessionEventBase`, `before_compact`, etc.)
|
||||
- [ ] `docs/hooks-v2.md` - Review/merge or remove if obsolete
|
||||
- [ ] `docs/sdk.md` - Update for:
|
||||
- `HookMessage` and `isHookMessage()`
|
||||
- `Agent.prompt(AppMessage)` overload
|
||||
- Session v2 tree structure
|
||||
- SessionManager API changes
|
||||
- [ ] `docs/session.md` - Update for v2 tree structure, new entry types
|
||||
- [ ] `docs/custom-tools.md` - Check if hook changes affect custom tools
|
||||
- [ ] `docs/rpc.md` - Check if hook commands work in RPC mode
|
||||
- [ ] `docs/skills.md` - Review for any hook-related updates
|
||||
- [ ] `docs/extension-loading.md` - Review
|
||||
- [x] `docs/theme.md` - Added selectedBg, customMessageBg/Text/Label color tokens (50 total)
|
||||
- [ ] `README.md` - Update hook examples if any
|
||||
|
||||
### Examples
|
||||
|
||||
Review and update examples:
|
||||
|
||||
- [ ] `examples/hooks/` - Update existing, add new examples:
|
||||
- [ ] Review `custom-compaction.ts` for new API
|
||||
- [ ] Add `registerCommand()` example
|
||||
- [ ] Add `sendMessage()` example
|
||||
- [ ] Add `registerCustomMessageRenderer()` example
|
||||
- [ ] `examples/sdk/` - Update for new session/hook APIs
|
||||
- [ ] `examples/custom-tools/` - Review for compatibility
|
||||
|
||||
---
|
||||
|
||||
## Before Release
|
||||
|
||||
- [ ] Run full automated test suite: `npm test`
|
||||
- [ ] Manual testing of tree navigation and branch summarization
|
||||
- [ ] Verify compaction with file tracking works correctly
|
||||
|
||||
---
|
||||
|
||||
## Notes
|
||||
|
||||
- All append methods return the new entry's ID
|
||||
- Migration rewrites file on first load if version < CURRENT_VERSION
|
||||
- Existing sessions become linear chains after migration (parentId = previous entry)
|
||||
- Tree features available immediately after migration
|
||||
- SessionHeader does NOT have id/parentId (it's metadata, not part of tree)
|
||||
- Session is append-only: entries cannot be modified or deleted, only branching changes the leaf pointer
|
||||
|
|
@ -21,12 +21,16 @@ Every theme must define all color tokens. There are no optional colors.
|
|||
| `dim` | Very dimmed text | Less important info, placeholders |
|
||||
| `text` | Default text color | Main content (usually `""`) |
|
||||
|
||||
### Backgrounds & Content Text (7 colors)
|
||||
### Backgrounds & Content Text (11 colors)
|
||||
|
||||
| Token | Purpose |
|
||||
|-------|---------|
|
||||
| `selectedBg` | Selected/active line background (e.g., tree selector) |
|
||||
| `userMessageBg` | User message background |
|
||||
| `userMessageText` | User message text color |
|
||||
| `customMessageBg` | Hook custom message background |
|
||||
| `customMessageText` | Hook custom message text color |
|
||||
| `customMessageLabel` | Hook custom message label/type text |
|
||||
| `toolPendingBg` | Tool execution box (pending state) |
|
||||
| `toolSuccessBg` | Tool execution box (success state) |
|
||||
| `toolErrorBg` | Tool execution box (error state) |
|
||||
|
|
@ -95,7 +99,7 @@ These create a visual hierarchy: off → minimal → low → medium → high →
|
|||
|-------|---------|
|
||||
| `bashMode` | Editor border color when in bash mode (! prefix) |
|
||||
|
||||
**Total: 46 color tokens** (all required)
|
||||
**Total: 50 color tokens** (all required)
|
||||
|
||||
## Theme Format
|
||||
|
||||
|
|
|
|||
197
packages/coding-agent/docs/tree.md
Normal file
197
packages/coding-agent/docs/tree.md
Normal file
|
|
@ -0,0 +1,197 @@
|
|||
# Session Tree Navigation
|
||||
|
||||
The `/tree` command provides tree-based navigation of the session history.
|
||||
|
||||
## Overview
|
||||
|
||||
Sessions are stored as trees where each entry has an `id` and `parentId`. The "leaf" pointer tracks the current position. `/tree` lets you navigate to any point and optionally summarize the branch you're leaving.
|
||||
|
||||
### Comparison with `/branch`
|
||||
|
||||
| Feature | `/branch` | `/tree` |
|
||||
|---------|-----------|---------|
|
||||
| View | Flat list of user messages | Full tree structure |
|
||||
| Action | Extracts path to **new session file** | Changes leaf in **same session** |
|
||||
| Summary | Never | Optional (user prompted) |
|
||||
| Events | `session_before_branch` / `session_branch` | `session_before_tree` / `session_tree` |
|
||||
|
||||
## Tree UI
|
||||
|
||||
```
|
||||
├─ user: "Hello, can you help..."
|
||||
│ └─ assistant: "Of course! I can..."
|
||||
│ ├─ user: "Let's try approach A..."
|
||||
│ │ └─ assistant: "For approach A..."
|
||||
│ │ └─ [compaction: 12k tokens]
|
||||
│ │ └─ user: "That worked..." ← active
|
||||
│ └─ user: "Actually, approach B..."
|
||||
│ └─ assistant: "For approach B..."
|
||||
```
|
||||
|
||||
### Controls
|
||||
|
||||
| Key | Action |
|
||||
|-----|--------|
|
||||
| ↑/↓ | Navigate (depth-first order) |
|
||||
| Enter | Select node |
|
||||
| Escape/Ctrl+C | Cancel |
|
||||
| Ctrl+U | Toggle: user messages only |
|
||||
| Ctrl+O | Toggle: show all (including custom/label entries) |
|
||||
|
||||
### Display
|
||||
|
||||
- Height: half terminal height
|
||||
- Current leaf marked with `← active`
|
||||
- Labels shown inline: `[label-name]`
|
||||
- Default filter hides `label` and `custom` entries (shown in Ctrl+O mode)
|
||||
- Children sorted by timestamp (oldest first)
|
||||
|
||||
## Selection Behavior
|
||||
|
||||
### User Message or Custom Message
|
||||
1. Leaf set to **parent** of selected node (or `null` if root)
|
||||
2. Message text placed in **editor** for re-submission
|
||||
3. User edits and submits, creating a new branch
|
||||
|
||||
### Non-User Message (assistant, compaction, etc.)
|
||||
1. Leaf set to **selected node**
|
||||
2. Editor stays empty
|
||||
3. User continues from that point
|
||||
|
||||
### Selecting Root User Message
|
||||
If user selects the very first message (has no parent):
|
||||
1. Leaf reset to `null` (empty conversation)
|
||||
2. Message text placed in editor
|
||||
3. User effectively restarts from scratch
|
||||
|
||||
## Branch Summarization
|
||||
|
||||
When switching, user is prompted: "Summarize the branch you're leaving?"
|
||||
|
||||
### What Gets Summarized
|
||||
|
||||
Path from old leaf back to common ancestor with target:
|
||||
|
||||
```
|
||||
A → B → C → D → E → F ← old leaf
|
||||
↘ G → H ← target
|
||||
```
|
||||
|
||||
Abandoned path: D → E → F (summarized)
|
||||
|
||||
Summarization stops at:
|
||||
1. Common ancestor (always)
|
||||
2. Compaction node (if encountered first)
|
||||
|
||||
### Summary Storage
|
||||
|
||||
Stored as `BranchSummaryEntry`:
|
||||
|
||||
```typescript
|
||||
interface BranchSummaryEntry {
|
||||
type: "branch_summary";
|
||||
id: string;
|
||||
parentId: string; // New leaf position
|
||||
timestamp: string;
|
||||
fromId: string; // Old leaf we abandoned
|
||||
summary: string; // LLM-generated summary
|
||||
details?: unknown; // Optional hook data
|
||||
}
|
||||
```
|
||||
|
||||
## Implementation
|
||||
|
||||
### AgentSession.navigateTree()
|
||||
|
||||
```typescript
|
||||
async navigateTree(
|
||||
targetId: string,
|
||||
options?: { summarize?: boolean; customInstructions?: string }
|
||||
): Promise<{ editorText?: string; cancelled: boolean }>
|
||||
```
|
||||
|
||||
Flow:
|
||||
1. Validate target, check no-op (target === current leaf)
|
||||
2. Find common ancestor between old leaf and target
|
||||
3. Collect entries to summarize (if requested)
|
||||
4. Fire `session_before_tree` event (hook can cancel or provide summary)
|
||||
5. Run default summarizer if needed
|
||||
6. Switch leaf via `branch()` or `branchWithSummary()`
|
||||
7. Update agent: `agent.replaceMessages(sessionManager.buildSessionContext().messages)`
|
||||
8. Fire `session_tree` event
|
||||
9. Notify custom tools via session event
|
||||
10. Return result with `editorText` if user message was selected
|
||||
|
||||
### SessionManager
|
||||
|
||||
- `getLeafUuid(): string | null` - Current leaf (null if empty)
|
||||
- `resetLeaf(): void` - Set leaf to null (for root user message navigation)
|
||||
- `getTree(): SessionTreeNode[]` - Full tree with children sorted by timestamp
|
||||
- `branch(id)` - Change leaf pointer
|
||||
- `branchWithSummary(id, summary)` - Change leaf and create summary entry
|
||||
|
||||
### InteractiveMode
|
||||
|
||||
`/tree` command shows `TreeSelectorComponent`, then:
|
||||
1. Prompt for summarization
|
||||
2. Call `session.navigateTree()`
|
||||
3. Clear and re-render chat
|
||||
4. Set editor text if applicable
|
||||
|
||||
## Hook Events
|
||||
|
||||
### `session_before_tree`
|
||||
|
||||
```typescript
|
||||
interface TreePreparation {
|
||||
targetId: string;
|
||||
oldLeafId: string | null;
|
||||
commonAncestorId: string | null;
|
||||
entriesToSummarize: SessionEntry[];
|
||||
userWantsSummary: boolean;
|
||||
}
|
||||
|
||||
interface SessionBeforeTreeEvent {
|
||||
type: "session_before_tree";
|
||||
preparation: TreePreparation;
|
||||
model: Model;
|
||||
signal: AbortSignal;
|
||||
}
|
||||
|
||||
interface SessionBeforeTreeResult {
|
||||
cancel?: boolean;
|
||||
summary?: { summary: string; details?: unknown };
|
||||
}
|
||||
```
|
||||
|
||||
### `session_tree`
|
||||
|
||||
```typescript
|
||||
interface SessionTreeEvent {
|
||||
type: "session_tree";
|
||||
newLeafId: string | null;
|
||||
oldLeafId: string | null;
|
||||
summaryEntry?: BranchSummaryEntry;
|
||||
fromHook?: boolean;
|
||||
}
|
||||
```
|
||||
|
||||
### Example: Custom Summarizer
|
||||
|
||||
```typescript
|
||||
export default function(pi: HookAPI) {
|
||||
pi.on("session_before_tree", async (event, ctx) => {
|
||||
if (!event.preparation.userWantsSummary) return;
|
||||
if (event.preparation.entriesToSummarize.length === 0) return;
|
||||
|
||||
const summary = await myCustomSummarizer(event.preparation.entriesToSummarize);
|
||||
return { summary: { summary, details: { custom: true } } };
|
||||
});
|
||||
}
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
- Summarization failure: cancels navigation, shows error
|
||||
- User abort (Escape): cancels navigation
|
||||
- Hook returns `cancel: true`: cancels navigation silently
|
||||
|
|
@ -41,7 +41,7 @@ const factory: CustomToolFactory = (pi) => {
|
|||
|
||||
const answer = await pi.ui.select(params.question, params.options);
|
||||
|
||||
if (answer === null) {
|
||||
if (answer === undefined) {
|
||||
return {
|
||||
content: [{ type: "text", text: "User cancelled the selection" }],
|
||||
details: { question: params.question, options: params.options, answer: null },
|
||||
|
|
|
|||
|
|
@ -16,7 +16,8 @@ import { spawn } from "node:child_process";
|
|||
import * as fs from "node:fs";
|
||||
import * as os from "node:os";
|
||||
import * as path from "node:path";
|
||||
import type { AgentToolResult, Message } from "@mariozechner/pi-ai";
|
||||
import type { AgentToolResult } from "@mariozechner/pi-agent-core";
|
||||
import type { Message } from "@mariozechner/pi-ai";
|
||||
import { StringEnum } from "@mariozechner/pi-ai";
|
||||
import {
|
||||
type CustomAgentTool,
|
||||
|
|
|
|||
|
|
@ -8,11 +8,9 @@
|
|||
import type { HookAPI } from "@mariozechner/pi-coding-agent/hooks";
|
||||
|
||||
export default function (pi: HookAPI) {
|
||||
pi.on("session", async (event, ctx) => {
|
||||
if (event.reason !== "shutdown") return;
|
||||
|
||||
pi.on("session_shutdown", async (_event, ctx) => {
|
||||
// Check for uncommitted changes
|
||||
const { stdout: status, code } = await ctx.exec("git", ["status", "--porcelain"]);
|
||||
const { stdout: status, code } = await pi.exec("git", ["status", "--porcelain"]);
|
||||
|
||||
if (code !== 0 || status.trim().length === 0) {
|
||||
// Not a git repo or no changes
|
||||
|
|
@ -20,9 +18,10 @@ export default function (pi: HookAPI) {
|
|||
}
|
||||
|
||||
// Find the last assistant message for commit context
|
||||
const entries = ctx.sessionManager.getEntries();
|
||||
let lastAssistantText = "";
|
||||
for (let i = event.entries.length - 1; i >= 0; i--) {
|
||||
const entry = event.entries[i];
|
||||
for (let i = entries.length - 1; i >= 0; i--) {
|
||||
const entry = entries[i];
|
||||
if (entry.type === "message" && entry.message.role === "assistant") {
|
||||
const content = entry.message.content;
|
||||
if (Array.isArray(content)) {
|
||||
|
|
@ -40,8 +39,8 @@ export default function (pi: HookAPI) {
|
|||
const commitMessage = `[pi] ${firstLine.slice(0, 50)}${firstLine.length > 50 ? "..." : ""}`;
|
||||
|
||||
// Stage and commit
|
||||
await ctx.exec("git", ["add", "-A"]);
|
||||
const { code: commitCode } = await ctx.exec("git", ["commit", "-m", commitMessage]);
|
||||
await pi.exec("git", ["add", "-A"]);
|
||||
const { code: commitCode } = await pi.exec("git", ["commit", "-m", commitMessage]);
|
||||
|
||||
if (commitCode === 0 && ctx.hasUI) {
|
||||
ctx.ui.notify(`Auto-committed: ${commitMessage}`, "info");
|
||||
|
|
|
|||
|
|
@ -2,59 +2,57 @@
|
|||
* Confirm Destructive Actions Hook
|
||||
*
|
||||
* Prompts for confirmation before destructive session actions (clear, switch, branch).
|
||||
* Demonstrates how to cancel session events using the before_* variants.
|
||||
* Demonstrates how to cancel session events using the before_* events.
|
||||
*/
|
||||
|
||||
import type { SessionMessageEntry } from "@mariozechner/pi-coding-agent";
|
||||
import type { HookAPI } from "@mariozechner/pi-coding-agent/hooks";
|
||||
|
||||
export default function (pi: HookAPI) {
|
||||
pi.on("session", async (event, ctx) => {
|
||||
// Only handle before_* events (the ones that can be cancelled)
|
||||
if (event.reason === "before_new") {
|
||||
if (!ctx.hasUI) return;
|
||||
pi.on("session_before_new", async (_event, ctx) => {
|
||||
if (!ctx.hasUI) return;
|
||||
|
||||
const confirmed = await ctx.ui.confirm("Clear session?", "This will delete all messages in the current session.");
|
||||
|
||||
if (!confirmed) {
|
||||
ctx.ui.notify("Clear cancelled", "info");
|
||||
return { cancel: true };
|
||||
}
|
||||
});
|
||||
|
||||
pi.on("session_before_switch", async (_event, ctx) => {
|
||||
if (!ctx.hasUI) return;
|
||||
|
||||
// Check if there are unsaved changes (messages since last assistant response)
|
||||
const entries = ctx.sessionManager.getEntries();
|
||||
const hasUnsavedWork = entries.some(
|
||||
(e): e is SessionMessageEntry => e.type === "message" && e.message.role === "user",
|
||||
);
|
||||
|
||||
if (hasUnsavedWork) {
|
||||
const confirmed = await ctx.ui.confirm(
|
||||
"Clear session?",
|
||||
"This will delete all messages in the current session.",
|
||||
"Switch session?",
|
||||
"You have messages in the current session. Switch anyway?",
|
||||
);
|
||||
|
||||
if (!confirmed) {
|
||||
ctx.ui.notify("Clear cancelled", "info");
|
||||
return { cancel: true };
|
||||
}
|
||||
}
|
||||
|
||||
if (event.reason === "before_switch") {
|
||||
if (!ctx.hasUI) return;
|
||||
|
||||
// Check if there are unsaved changes (messages since last assistant response)
|
||||
const hasUnsavedWork = event.entries.some((e) => e.type === "message" && e.message.role === "user");
|
||||
|
||||
if (hasUnsavedWork) {
|
||||
const confirmed = await ctx.ui.confirm(
|
||||
"Switch session?",
|
||||
"You have messages in the current session. Switch anyway?",
|
||||
);
|
||||
|
||||
if (!confirmed) {
|
||||
ctx.ui.notify("Switch cancelled", "info");
|
||||
return { cancel: true };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (event.reason === "before_branch") {
|
||||
if (!ctx.hasUI) return;
|
||||
|
||||
const choice = await ctx.ui.select(`Branch from turn ${event.targetTurnIndex}?`, [
|
||||
"Yes, create branch",
|
||||
"No, stay in current session",
|
||||
]);
|
||||
|
||||
if (choice !== "Yes, create branch") {
|
||||
ctx.ui.notify("Branch cancelled", "info");
|
||||
ctx.ui.notify("Switch cancelled", "info");
|
||||
return { cancel: true };
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
pi.on("session_before_branch", async (event, ctx) => {
|
||||
if (!ctx.hasUI) return;
|
||||
|
||||
const choice = await ctx.ui.select(`Branch from turn ${event.entryIndex}?`, [
|
||||
"Yes, create branch",
|
||||
"No, stay in current session",
|
||||
]);
|
||||
|
||||
if (choice !== "Yes, create branch") {
|
||||
ctx.ui.notify("Branch cancelled", "info");
|
||||
return { cancel: true };
|
||||
}
|
||||
});
|
||||
}
|
||||
|
|
|
|||
|
|
@ -14,17 +14,18 @@
|
|||
*/
|
||||
|
||||
import { complete, getModel } from "@mariozechner/pi-ai";
|
||||
import { messageTransformer } from "@mariozechner/pi-coding-agent";
|
||||
import { convertToLlm } from "@mariozechner/pi-coding-agent";
|
||||
import type { HookAPI } from "@mariozechner/pi-coding-agent/hooks";
|
||||
|
||||
export default function (pi: HookAPI) {
|
||||
pi.on("session", async (event, ctx) => {
|
||||
if (event.reason !== "before_compact") return;
|
||||
|
||||
pi.on("session_before_compact", async (event, ctx) => {
|
||||
ctx.ui.notify("Custom compaction hook triggered", "info");
|
||||
|
||||
const { messagesToSummarize, messagesToKeep, previousSummary, tokensBefore, resolveApiKey, entries, signal } =
|
||||
event;
|
||||
const { preparation, previousCompactions, signal } = event;
|
||||
const { messagesToSummarize, messagesToKeep, tokensBefore, firstKeptEntryId } = preparation;
|
||||
|
||||
// Get previous summary from most recent compaction (if any)
|
||||
const previousSummary = previousCompactions[0]?.summary;
|
||||
|
||||
// Use Gemini Flash for summarization (cheaper/faster than most conversation models)
|
||||
const model = getModel("google", "gemini-2.5-flash");
|
||||
|
|
@ -34,7 +35,7 @@ export default function (pi: HookAPI) {
|
|||
}
|
||||
|
||||
// Resolve API key for the summarization model
|
||||
const apiKey = await resolveApiKey(model);
|
||||
const apiKey = await ctx.modelRegistry.getApiKey(model);
|
||||
if (!apiKey) {
|
||||
ctx.ui.notify(`No API key for ${model.provider}, using default compaction`, "warning");
|
||||
return;
|
||||
|
|
@ -49,7 +50,7 @@ export default function (pi: HookAPI) {
|
|||
);
|
||||
|
||||
// Transform app messages to pi-ai package format
|
||||
const transformedMessages = messageTransformer(allMessages);
|
||||
const transformedMessages = convertToLlm(allMessages);
|
||||
|
||||
// Include previous summary context if available
|
||||
const previousContext = previousSummary ? `\n\nPrevious session summary for context:\n${previousSummary}` : "";
|
||||
|
|
@ -94,14 +95,12 @@ Format the summary as structured markdown with clear sections.`,
|
|||
return;
|
||||
}
|
||||
|
||||
// Return a compaction entry that discards ALL messages
|
||||
// firstKeptEntryIndex points past all current entries
|
||||
// Return compaction content - SessionManager adds id/parentId
|
||||
// Use firstKeptEntryId from preparation to keep recent messages
|
||||
return {
|
||||
compactionEntry: {
|
||||
type: "compaction" as const,
|
||||
timestamp: new Date().toISOString(),
|
||||
compaction: {
|
||||
summary,
|
||||
firstKeptEntryIndex: entries.length,
|
||||
firstKeptEntryId,
|
||||
tokensBefore,
|
||||
},
|
||||
};
|
||||
|
|
|
|||
|
|
@ -5,47 +5,55 @@
|
|||
* Useful to ensure work is committed before switching context.
|
||||
*/
|
||||
|
||||
import type { HookAPI } from "@mariozechner/pi-coding-agent/hooks";
|
||||
import type { HookAPI, HookEventContext } from "@mariozechner/pi-coding-agent/hooks";
|
||||
|
||||
async function checkDirtyRepo(
|
||||
pi: HookAPI,
|
||||
ctx: HookEventContext,
|
||||
action: string,
|
||||
): Promise<{ cancel: boolean } | undefined> {
|
||||
// Check for uncommitted changes
|
||||
const { stdout, code } = await pi.exec("git", ["status", "--porcelain"]);
|
||||
|
||||
if (code !== 0) {
|
||||
// Not a git repo, allow the action
|
||||
return;
|
||||
}
|
||||
|
||||
const hasChanges = stdout.trim().length > 0;
|
||||
if (!hasChanges) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!ctx.hasUI) {
|
||||
// In non-interactive mode, block by default
|
||||
return { cancel: true };
|
||||
}
|
||||
|
||||
// Count changed files
|
||||
const changedFiles = stdout.trim().split("\n").filter(Boolean).length;
|
||||
|
||||
const choice = await ctx.ui.select(`You have ${changedFiles} uncommitted file(s). ${action} anyway?`, [
|
||||
"Yes, proceed anyway",
|
||||
"No, let me commit first",
|
||||
]);
|
||||
|
||||
if (choice !== "Yes, proceed anyway") {
|
||||
ctx.ui.notify("Commit your changes first", "warning");
|
||||
return { cancel: true };
|
||||
}
|
||||
}
|
||||
|
||||
export default function (pi: HookAPI) {
|
||||
pi.on("session", async (event, ctx) => {
|
||||
// Only guard destructive actions
|
||||
if (event.reason !== "before_new" && event.reason !== "before_switch" && event.reason !== "before_branch") {
|
||||
return;
|
||||
}
|
||||
pi.on("session_before_new", async (_event, ctx) => {
|
||||
return checkDirtyRepo(pi, ctx, "new session");
|
||||
});
|
||||
|
||||
// Check for uncommitted changes
|
||||
const { stdout, code } = await ctx.exec("git", ["status", "--porcelain"]);
|
||||
pi.on("session_before_switch", async (_event, ctx) => {
|
||||
return checkDirtyRepo(pi, ctx, "switch session");
|
||||
});
|
||||
|
||||
if (code !== 0) {
|
||||
// Not a git repo, allow the action
|
||||
return;
|
||||
}
|
||||
|
||||
const hasChanges = stdout.trim().length > 0;
|
||||
if (!hasChanges) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!ctx.hasUI) {
|
||||
// In non-interactive mode, block by default
|
||||
return { cancel: true };
|
||||
}
|
||||
|
||||
// Count changed files
|
||||
const changedFiles = stdout.trim().split("\n").filter(Boolean).length;
|
||||
|
||||
const action =
|
||||
event.reason === "before_new" ? "new session" : event.reason === "before_switch" ? "switch session" : "branch";
|
||||
|
||||
const choice = await ctx.ui.select(`You have ${changedFiles} uncommitted file(s). ${action} anyway?`, [
|
||||
"Yes, proceed anyway",
|
||||
"No, let me commit first",
|
||||
]);
|
||||
|
||||
if (choice !== "Yes, proceed anyway") {
|
||||
ctx.ui.notify("Commit your changes first", "warning");
|
||||
return { cancel: true };
|
||||
}
|
||||
pi.on("session_before_branch", async (_event, ctx) => {
|
||||
return checkDirtyRepo(pi, ctx, "branch");
|
||||
});
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,16 +12,21 @@ import * as fs from "node:fs";
|
|||
import type { HookAPI } from "@mariozechner/pi-coding-agent/hooks";
|
||||
|
||||
export default function (pi: HookAPI) {
|
||||
pi.on("session", async (event, ctx) => {
|
||||
if (event.reason !== "start") return;
|
||||
|
||||
pi.on("session_start", async (_event, ctx) => {
|
||||
const triggerFile = "/tmp/agent-trigger.txt";
|
||||
|
||||
fs.watch(triggerFile, () => {
|
||||
try {
|
||||
const content = fs.readFileSync(triggerFile, "utf-8").trim();
|
||||
if (content) {
|
||||
pi.send(`External trigger: ${content}`);
|
||||
pi.sendMessage(
|
||||
{
|
||||
customType: "file-trigger",
|
||||
content: `External trigger: ${content}`,
|
||||
display: true,
|
||||
},
|
||||
true, // triggerTurn - get LLM to respond
|
||||
);
|
||||
fs.writeFileSync(triggerFile, ""); // Clear after reading
|
||||
}
|
||||
} catch {
|
||||
|
|
|
|||
|
|
@ -10,20 +10,17 @@ import type { HookAPI } from "@mariozechner/pi-coding-agent/hooks";
|
|||
export default function (pi: HookAPI) {
|
||||
const checkpoints = new Map<number, string>();
|
||||
|
||||
pi.on("turn_start", async (event, ctx) => {
|
||||
pi.on("turn_start", async (event) => {
|
||||
// Create a git stash entry before LLM makes changes
|
||||
const { stdout } = await ctx.exec("git", ["stash", "create"]);
|
||||
const { stdout } = await pi.exec("git", ["stash", "create"]);
|
||||
const ref = stdout.trim();
|
||||
if (ref) {
|
||||
checkpoints.set(event.turnIndex, ref);
|
||||
}
|
||||
});
|
||||
|
||||
pi.on("session", async (event, ctx) => {
|
||||
// Only handle before_branch events
|
||||
if (event.reason !== "before_branch") return;
|
||||
|
||||
const ref = checkpoints.get(event.targetTurnIndex);
|
||||
pi.on("session_before_branch", async (event, ctx) => {
|
||||
const ref = checkpoints.get(event.entryIndex);
|
||||
if (!ref) return;
|
||||
|
||||
if (!ctx.hasUI) {
|
||||
|
|
@ -37,7 +34,7 @@ export default function (pi: HookAPI) {
|
|||
]);
|
||||
|
||||
if (choice?.startsWith("Yes")) {
|
||||
await ctx.exec("git", ["stash", "apply", ref]);
|
||||
await pi.exec("git", ["stash", "apply", ref]);
|
||||
ctx.ui.notify("Code restored to checkpoint", "info");
|
||||
}
|
||||
});
|
||||
|
|
|
|||
345
packages/coding-agent/examples/hooks/snake.ts
Normal file
345
packages/coding-agent/examples/hooks/snake.ts
Normal file
|
|
@ -0,0 +1,345 @@
|
|||
/**
|
||||
* Snake game hook - play snake with /snake command
|
||||
*/
|
||||
|
||||
import { isArrowDown, isArrowLeft, isArrowRight, isArrowUp, isEscape, visibleWidth } from "@mariozechner/pi-tui";
|
||||
import type { HookAPI } from "../../src/core/hooks/types.js";
|
||||
|
||||
const GAME_WIDTH = 40;
|
||||
const GAME_HEIGHT = 15;
|
||||
const TICK_MS = 100;
|
||||
|
||||
type Direction = "up" | "down" | "left" | "right";
|
||||
type Point = { x: number; y: number };
|
||||
|
||||
interface GameState {
|
||||
snake: Point[];
|
||||
food: Point;
|
||||
direction: Direction;
|
||||
nextDirection: Direction;
|
||||
score: number;
|
||||
gameOver: boolean;
|
||||
highScore: number;
|
||||
}
|
||||
|
||||
function createInitialState(): GameState {
|
||||
const startX = Math.floor(GAME_WIDTH / 2);
|
||||
const startY = Math.floor(GAME_HEIGHT / 2);
|
||||
return {
|
||||
snake: [
|
||||
{ x: startX, y: startY },
|
||||
{ x: startX - 1, y: startY },
|
||||
{ x: startX - 2, y: startY },
|
||||
],
|
||||
food: spawnFood([{ x: startX, y: startY }]),
|
||||
direction: "right",
|
||||
nextDirection: "right",
|
||||
score: 0,
|
||||
gameOver: false,
|
||||
highScore: 0,
|
||||
};
|
||||
}
|
||||
|
||||
function spawnFood(snake: Point[]): Point {
|
||||
let food: Point;
|
||||
do {
|
||||
food = {
|
||||
x: Math.floor(Math.random() * GAME_WIDTH),
|
||||
y: Math.floor(Math.random() * GAME_HEIGHT),
|
||||
};
|
||||
} while (snake.some((s) => s.x === food.x && s.y === food.y));
|
||||
return food;
|
||||
}
|
||||
|
||||
class SnakeComponent {
|
||||
private state: GameState;
|
||||
private interval: ReturnType<typeof setInterval> | null = null;
|
||||
private onClose: () => void;
|
||||
private onSave: (state: GameState | null) => void;
|
||||
private requestRender: () => void;
|
||||
private cachedLines: string[] = [];
|
||||
private cachedWidth = 0;
|
||||
private version = 0;
|
||||
private cachedVersion = -1;
|
||||
private paused: boolean;
|
||||
|
||||
constructor(
|
||||
onClose: () => void,
|
||||
onSave: (state: GameState | null) => void,
|
||||
requestRender: () => void,
|
||||
savedState?: GameState,
|
||||
) {
|
||||
if (savedState && !savedState.gameOver) {
|
||||
// Resume from saved state, start paused
|
||||
this.state = savedState;
|
||||
this.paused = true;
|
||||
} else {
|
||||
// New game or saved game was over
|
||||
this.state = createInitialState();
|
||||
if (savedState) {
|
||||
this.state.highScore = savedState.highScore;
|
||||
}
|
||||
this.paused = false;
|
||||
this.startGame();
|
||||
}
|
||||
this.onClose = onClose;
|
||||
this.onSave = onSave;
|
||||
this.requestRender = requestRender;
|
||||
}
|
||||
|
||||
private startGame(): void {
|
||||
this.interval = setInterval(() => {
|
||||
if (!this.state.gameOver) {
|
||||
this.tick();
|
||||
this.version++;
|
||||
this.requestRender();
|
||||
}
|
||||
}, TICK_MS);
|
||||
}
|
||||
|
||||
private tick(): void {
|
||||
// Apply queued direction change
|
||||
this.state.direction = this.state.nextDirection;
|
||||
|
||||
// Calculate new head position
|
||||
const head = this.state.snake[0];
|
||||
let newHead: Point;
|
||||
|
||||
switch (this.state.direction) {
|
||||
case "up":
|
||||
newHead = { x: head.x, y: head.y - 1 };
|
||||
break;
|
||||
case "down":
|
||||
newHead = { x: head.x, y: head.y + 1 };
|
||||
break;
|
||||
case "left":
|
||||
newHead = { x: head.x - 1, y: head.y };
|
||||
break;
|
||||
case "right":
|
||||
newHead = { x: head.x + 1, y: head.y };
|
||||
break;
|
||||
}
|
||||
|
||||
// Check wall collision
|
||||
if (newHead.x < 0 || newHead.x >= GAME_WIDTH || newHead.y < 0 || newHead.y >= GAME_HEIGHT) {
|
||||
this.state.gameOver = true;
|
||||
return;
|
||||
}
|
||||
|
||||
// Check self collision
|
||||
if (this.state.snake.some((s) => s.x === newHead.x && s.y === newHead.y)) {
|
||||
this.state.gameOver = true;
|
||||
return;
|
||||
}
|
||||
|
||||
// Move snake
|
||||
this.state.snake.unshift(newHead);
|
||||
|
||||
// Check food collision
|
||||
if (newHead.x === this.state.food.x && newHead.y === this.state.food.y) {
|
||||
this.state.score += 10;
|
||||
if (this.state.score > this.state.highScore) {
|
||||
this.state.highScore = this.state.score;
|
||||
}
|
||||
this.state.food = spawnFood(this.state.snake);
|
||||
} else {
|
||||
this.state.snake.pop();
|
||||
}
|
||||
}
|
||||
|
||||
handleInput(data: string): void {
|
||||
// If paused (resuming), wait for any key
|
||||
if (this.paused) {
|
||||
if (isEscape(data) || data === "q" || data === "Q") {
|
||||
// Quit without clearing save
|
||||
this.dispose();
|
||||
this.onClose();
|
||||
return;
|
||||
}
|
||||
// Any other key resumes
|
||||
this.paused = false;
|
||||
this.startGame();
|
||||
return;
|
||||
}
|
||||
|
||||
// ESC to pause and save
|
||||
if (isEscape(data)) {
|
||||
this.dispose();
|
||||
this.onSave(this.state);
|
||||
this.onClose();
|
||||
return;
|
||||
}
|
||||
|
||||
// Q to quit without saving (clears saved state)
|
||||
if (data === "q" || data === "Q") {
|
||||
this.dispose();
|
||||
this.onSave(null); // Clear saved state
|
||||
this.onClose();
|
||||
return;
|
||||
}
|
||||
|
||||
// Arrow keys or WASD
|
||||
if (isArrowUp(data) || data === "w" || data === "W") {
|
||||
if (this.state.direction !== "down") this.state.nextDirection = "up";
|
||||
} else if (isArrowDown(data) || data === "s" || data === "S") {
|
||||
if (this.state.direction !== "up") this.state.nextDirection = "down";
|
||||
} else if (isArrowRight(data) || data === "d" || data === "D") {
|
||||
if (this.state.direction !== "left") this.state.nextDirection = "right";
|
||||
} else if (isArrowLeft(data) || data === "a" || data === "A") {
|
||||
if (this.state.direction !== "right") this.state.nextDirection = "left";
|
||||
}
|
||||
|
||||
// Restart on game over
|
||||
if (this.state.gameOver && (data === "r" || data === "R" || data === " ")) {
|
||||
const highScore = this.state.highScore;
|
||||
this.state = createInitialState();
|
||||
this.state.highScore = highScore;
|
||||
this.onSave(null); // Clear saved state on restart
|
||||
this.version++;
|
||||
this.requestRender();
|
||||
}
|
||||
}
|
||||
|
||||
invalidate(): void {
|
||||
this.cachedWidth = 0;
|
||||
}
|
||||
|
||||
render(width: number): string[] {
|
||||
if (width === this.cachedWidth && this.cachedVersion === this.version) {
|
||||
return this.cachedLines;
|
||||
}
|
||||
|
||||
const lines: string[] = [];
|
||||
|
||||
// Each game cell is 2 chars wide to appear square (terminal cells are ~2:1 aspect)
|
||||
const cellWidth = 2;
|
||||
const effectiveWidth = Math.min(GAME_WIDTH, Math.floor((width - 4) / cellWidth));
|
||||
const effectiveHeight = GAME_HEIGHT;
|
||||
|
||||
// Colors
|
||||
const dim = (s: string) => `\x1b[2m${s}\x1b[22m`;
|
||||
const green = (s: string) => `\x1b[32m${s}\x1b[0m`;
|
||||
const red = (s: string) => `\x1b[31m${s}\x1b[0m`;
|
||||
const yellow = (s: string) => `\x1b[33m${s}\x1b[0m`;
|
||||
const bold = (s: string) => `\x1b[1m${s}\x1b[22m`;
|
||||
|
||||
const boxWidth = effectiveWidth * cellWidth;
|
||||
|
||||
// Helper to pad content inside box
|
||||
const boxLine = (content: string) => {
|
||||
const contentLen = visibleWidth(content);
|
||||
const padding = Math.max(0, boxWidth - contentLen);
|
||||
return dim(" │") + content + " ".repeat(padding) + dim("│");
|
||||
};
|
||||
|
||||
// Top border
|
||||
lines.push(this.padLine(dim(` ╭${"─".repeat(boxWidth)}╮`), width));
|
||||
|
||||
// Header with score
|
||||
const scoreText = `Score: ${bold(yellow(String(this.state.score)))}`;
|
||||
const highText = `High: ${bold(yellow(String(this.state.highScore)))}`;
|
||||
const title = `${bold(green("SNAKE"))} │ ${scoreText} │ ${highText}`;
|
||||
lines.push(this.padLine(boxLine(title), width));
|
||||
|
||||
// Separator
|
||||
lines.push(this.padLine(dim(` ├${"─".repeat(boxWidth)}┤`), width));
|
||||
|
||||
// Game grid
|
||||
for (let y = 0; y < effectiveHeight; y++) {
|
||||
let row = "";
|
||||
for (let x = 0; x < effectiveWidth; x++) {
|
||||
const isHead = this.state.snake[0].x === x && this.state.snake[0].y === y;
|
||||
const isBody = this.state.snake.slice(1).some((s) => s.x === x && s.y === y);
|
||||
const isFood = this.state.food.x === x && this.state.food.y === y;
|
||||
|
||||
if (isHead) {
|
||||
row += green("██"); // Snake head (2 chars)
|
||||
} else if (isBody) {
|
||||
row += green("▓▓"); // Snake body (2 chars)
|
||||
} else if (isFood) {
|
||||
row += red("◆ "); // Food (2 chars)
|
||||
} else {
|
||||
row += " "; // Empty cell (2 spaces)
|
||||
}
|
||||
}
|
||||
lines.push(this.padLine(dim(" │") + row + dim("│"), width));
|
||||
}
|
||||
|
||||
// Separator
|
||||
lines.push(this.padLine(dim(` ├${"─".repeat(boxWidth)}┤`), width));
|
||||
|
||||
// Footer
|
||||
let footer: string;
|
||||
if (this.paused) {
|
||||
footer = `${yellow(bold("PAUSED"))} Press any key to continue, ${bold("Q")} to quit`;
|
||||
} else if (this.state.gameOver) {
|
||||
footer = `${red(bold("GAME OVER!"))} Press ${bold("R")} to restart, ${bold("Q")} to quit`;
|
||||
} else {
|
||||
footer = `↑↓←→ or WASD to move, ${bold("ESC")} pause, ${bold("Q")} quit`;
|
||||
}
|
||||
lines.push(this.padLine(boxLine(footer), width));
|
||||
|
||||
// Bottom border
|
||||
lines.push(this.padLine(dim(` ╰${"─".repeat(boxWidth)}╯`), width));
|
||||
|
||||
this.cachedLines = lines;
|
||||
this.cachedWidth = width;
|
||||
this.cachedVersion = this.version;
|
||||
|
||||
return lines;
|
||||
}
|
||||
|
||||
private padLine(line: string, width: number): string {
|
||||
// Calculate visible length (strip ANSI codes)
|
||||
const visibleLen = line.replace(/\x1b\[[0-9;]*m/g, "").length;
|
||||
const padding = Math.max(0, width - visibleLen);
|
||||
return line + " ".repeat(padding);
|
||||
}
|
||||
|
||||
dispose(): void {
|
||||
if (this.interval) {
|
||||
clearInterval(this.interval);
|
||||
this.interval = null;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const SNAKE_SAVE_TYPE = "snake-save";
|
||||
|
||||
export default function (pi: HookAPI) {
|
||||
pi.registerCommand("snake", {
|
||||
description: "Play Snake!",
|
||||
|
||||
handler: async (ctx) => {
|
||||
if (!ctx.hasUI) {
|
||||
ctx.ui.notify("Snake requires interactive mode", "error");
|
||||
return;
|
||||
}
|
||||
|
||||
// Load saved state from session
|
||||
const entries = ctx.sessionManager.getEntries();
|
||||
let savedState: GameState | undefined;
|
||||
for (let i = entries.length - 1; i >= 0; i--) {
|
||||
const entry = entries[i];
|
||||
if (entry.type === "custom" && entry.customType === SNAKE_SAVE_TYPE) {
|
||||
savedState = entry.data as GameState;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let ui: { close: () => void; requestRender: () => void } | null = null;
|
||||
|
||||
const component = new SnakeComponent(
|
||||
() => ui?.close(),
|
||||
(state) => {
|
||||
// Save or clear state
|
||||
pi.appendEntry(SNAKE_SAVE_TYPE, state);
|
||||
},
|
||||
() => ui?.requestRender(),
|
||||
savedState,
|
||||
);
|
||||
|
||||
ui = ctx.ui.custom(component);
|
||||
},
|
||||
});
|
||||
}
|
||||
|
|
@ -3,21 +3,21 @@
|
|||
*/
|
||||
|
||||
import { access, readFile, stat } from "node:fs/promises";
|
||||
import type { Attachment } from "@mariozechner/pi-agent-core";
|
||||
import type { ImageContent } from "@mariozechner/pi-ai";
|
||||
import chalk from "chalk";
|
||||
import { resolve } from "path";
|
||||
import { resolveReadPath } from "../core/tools/path-utils.js";
|
||||
import { detectSupportedImageMimeTypeFromFile } from "../utils/mime.js";
|
||||
|
||||
export interface ProcessedFiles {
|
||||
textContent: string;
|
||||
imageAttachments: Attachment[];
|
||||
text: string;
|
||||
images: ImageContent[];
|
||||
}
|
||||
|
||||
/** Process @file arguments into text content and image attachments */
|
||||
export async function processFileArguments(fileArgs: string[]): Promise<ProcessedFiles> {
|
||||
let textContent = "";
|
||||
const imageAttachments: Attachment[] = [];
|
||||
let text = "";
|
||||
const images: ImageContent[] = [];
|
||||
|
||||
for (const fileArg of fileArgs) {
|
||||
// Expand and resolve path (handles ~ expansion and macOS screenshot Unicode spaces)
|
||||
|
|
@ -45,24 +45,21 @@ export async function processFileArguments(fileArgs: string[]): Promise<Processe
|
|||
const content = await readFile(absolutePath);
|
||||
const base64Content = content.toString("base64");
|
||||
|
||||
const attachment: Attachment = {
|
||||
id: `file-${Date.now()}-${Math.random().toString(36).slice(2, 9)}`,
|
||||
const attachment: ImageContent = {
|
||||
type: "image",
|
||||
fileName: absolutePath.split("/").pop() || absolutePath,
|
||||
mimeType,
|
||||
size: stats.size,
|
||||
content: base64Content,
|
||||
data: base64Content,
|
||||
};
|
||||
|
||||
imageAttachments.push(attachment);
|
||||
images.push(attachment);
|
||||
|
||||
// Add text reference to image
|
||||
textContent += `<file name="${absolutePath}"></file>\n`;
|
||||
text += `<file name="${absolutePath}"></file>\n`;
|
||||
} else {
|
||||
// Handle text file
|
||||
try {
|
||||
const content = await readFile(absolutePath, "utf-8");
|
||||
textContent += `<file name="${absolutePath}">\n${content}\n</file>\n`;
|
||||
text += `<file name="${absolutePath}">\n${content}\n</file>\n`;
|
||||
} catch (error: unknown) {
|
||||
const message = error instanceof Error ? error.message : String(error);
|
||||
console.error(chalk.red(`Error: Could not read file ${absolutePath}: ${message}`));
|
||||
|
|
@ -71,5 +68,5 @@ export async function processFileArguments(fileArgs: string[]): Promise<Processe
|
|||
}
|
||||
}
|
||||
|
||||
return { textContent, imageAttachments };
|
||||
return { text, images };
|
||||
}
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -94,8 +94,8 @@ export class AuthStorage {
|
|||
/**
|
||||
* Get credential for a provider.
|
||||
*/
|
||||
get(provider: string): AuthCredential | null {
|
||||
return this.data[provider] ?? null;
|
||||
get(provider: string): AuthCredential | undefined {
|
||||
return this.data[provider] ?? undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -191,7 +191,7 @@ export class AuthStorage {
|
|||
* 4. Environment variable
|
||||
* 5. Fallback resolver (models.json custom providers)
|
||||
*/
|
||||
async getApiKey(provider: string): Promise<string | null> {
|
||||
async getApiKey(provider: string): Promise<string | undefined> {
|
||||
// Runtime override takes highest priority
|
||||
const runtimeKey = this.runtimeOverrides.get(provider);
|
||||
if (runtimeKey) {
|
||||
|
|
@ -230,6 +230,6 @@ export class AuthStorage {
|
|||
if (envKey) return envKey;
|
||||
|
||||
// Fall back to custom resolver (e.g., models.json custom providers)
|
||||
return this.fallbackResolver?.(provider) ?? null;
|
||||
return this.fallbackResolver?.(provider) ?? undefined;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -29,8 +29,8 @@ export interface BashExecutorOptions {
|
|||
export interface BashResult {
|
||||
/** Combined stdout + stderr output (sanitized, possibly truncated) */
|
||||
output: string;
|
||||
/** Process exit code (null if killed/cancelled) */
|
||||
exitCode: number | null;
|
||||
/** Process exit code (undefined if killed/cancelled) */
|
||||
exitCode: number | undefined;
|
||||
/** Whether the command was cancelled via signal */
|
||||
cancelled: boolean;
|
||||
/** Whether the output was truncated */
|
||||
|
|
@ -88,7 +88,7 @@ export function executeBash(command: string, options?: BashExecutorOptions): Pro
|
|||
child.kill();
|
||||
resolve({
|
||||
output: "",
|
||||
exitCode: null,
|
||||
exitCode: undefined,
|
||||
cancelled: true,
|
||||
truncated: false,
|
||||
});
|
||||
|
|
@ -154,7 +154,7 @@ export function executeBash(command: string, options?: BashExecutorOptions): Pro
|
|||
|
||||
resolve({
|
||||
output: truncationResult.truncated ? truncationResult.content : fullOutput,
|
||||
exitCode: code,
|
||||
exitCode: cancelled ? undefined : code,
|
||||
cancelled,
|
||||
truncated: truncationResult.truncated,
|
||||
fullOutputPath: tempFilePath,
|
||||
|
|
|
|||
|
|
@ -1,530 +0,0 @@
|
|||
/**
|
||||
* Context compaction for long sessions.
|
||||
*
|
||||
* Pure functions for compaction logic. The session manager handles I/O,
|
||||
* and after compaction the session is reloaded.
|
||||
*/
|
||||
|
||||
import type { AppMessage } from "@mariozechner/pi-agent-core";
|
||||
import type { AssistantMessage, Model, Usage } from "@mariozechner/pi-ai";
|
||||
import { complete } from "@mariozechner/pi-ai";
|
||||
import { messageTransformer } from "./messages.js";
|
||||
import type { CompactionEntry, SessionEntry } from "./session-manager.js";
|
||||
|
||||
// ============================================================================
|
||||
// Types
|
||||
// ============================================================================
|
||||
|
||||
export interface CompactionSettings {
|
||||
enabled: boolean;
|
||||
reserveTokens: number;
|
||||
keepRecentTokens: number;
|
||||
}
|
||||
|
||||
export const DEFAULT_COMPACTION_SETTINGS: CompactionSettings = {
|
||||
enabled: true,
|
||||
reserveTokens: 16384,
|
||||
keepRecentTokens: 20000,
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// Token calculation
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Calculate total context tokens from usage.
|
||||
* Uses the native totalTokens field when available, falls back to computing from components.
|
||||
*/
|
||||
export function calculateContextTokens(usage: Usage): number {
|
||||
return usage.totalTokens || usage.input + usage.output + usage.cacheRead + usage.cacheWrite;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get usage from an assistant message if available.
|
||||
* Skips aborted and error messages as they don't have valid usage data.
|
||||
*/
|
||||
function getAssistantUsage(msg: AppMessage): Usage | null {
|
||||
if (msg.role === "assistant" && "usage" in msg) {
|
||||
const assistantMsg = msg as AssistantMessage;
|
||||
if (assistantMsg.stopReason !== "aborted" && assistantMsg.stopReason !== "error" && assistantMsg.usage) {
|
||||
return assistantMsg.usage;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Find the last non-aborted assistant message usage from session entries.
|
||||
*/
|
||||
export function getLastAssistantUsage(entries: SessionEntry[]): Usage | null {
|
||||
for (let i = entries.length - 1; i >= 0; i--) {
|
||||
const entry = entries[i];
|
||||
if (entry.type === "message") {
|
||||
const usage = getAssistantUsage(entry.message);
|
||||
if (usage) return usage;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if compaction should trigger based on context usage.
|
||||
*/
|
||||
export function shouldCompact(contextTokens: number, contextWindow: number, settings: CompactionSettings): boolean {
|
||||
if (!settings.enabled) return false;
|
||||
return contextTokens > contextWindow - settings.reserveTokens;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Cut point detection
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Estimate token count for a message using chars/4 heuristic.
|
||||
* This is conservative (overestimates tokens).
|
||||
*/
|
||||
export function estimateTokens(message: AppMessage): number {
|
||||
let chars = 0;
|
||||
|
||||
// Handle bashExecution messages
|
||||
if (message.role === "bashExecution") {
|
||||
const bash = message as unknown as { command: string; output: string };
|
||||
chars = bash.command.length + bash.output.length;
|
||||
return Math.ceil(chars / 4);
|
||||
}
|
||||
|
||||
// Handle user messages
|
||||
if (message.role === "user") {
|
||||
const content = (message as { content: string | Array<{ type: string; text?: string }> }).content;
|
||||
if (typeof content === "string") {
|
||||
chars = content.length;
|
||||
} else if (Array.isArray(content)) {
|
||||
for (const block of content) {
|
||||
if (block.type === "text" && block.text) {
|
||||
chars += block.text.length;
|
||||
}
|
||||
}
|
||||
}
|
||||
return Math.ceil(chars / 4);
|
||||
}
|
||||
|
||||
// Handle assistant messages
|
||||
if (message.role === "assistant") {
|
||||
const assistant = message as AssistantMessage;
|
||||
for (const block of assistant.content) {
|
||||
if (block.type === "text") {
|
||||
chars += block.text.length;
|
||||
} else if (block.type === "thinking") {
|
||||
chars += block.thinking.length;
|
||||
} else if (block.type === "toolCall") {
|
||||
chars += block.name.length + JSON.stringify(block.arguments).length;
|
||||
}
|
||||
}
|
||||
return Math.ceil(chars / 4);
|
||||
}
|
||||
|
||||
// Handle tool results
|
||||
if (message.role === "toolResult") {
|
||||
const toolResult = message as { content: Array<{ type: string; text?: string }> };
|
||||
for (const block of toolResult.content) {
|
||||
if (block.type === "text" && block.text) {
|
||||
chars += block.text.length;
|
||||
}
|
||||
}
|
||||
return Math.ceil(chars / 4);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Find valid cut points: indices of user, assistant, or bashExecution messages.
|
||||
* Never cut at tool results (they must follow their tool call).
|
||||
* When we cut at an assistant message with tool calls, its tool results follow it
|
||||
* and will be kept.
|
||||
* BashExecutionMessage is treated like a user message (user-initiated context).
|
||||
*/
|
||||
function findValidCutPoints(entries: SessionEntry[], startIndex: number, endIndex: number): number[] {
|
||||
const cutPoints: number[] = [];
|
||||
for (let i = startIndex; i < endIndex; i++) {
|
||||
const entry = entries[i];
|
||||
if (entry.type === "message") {
|
||||
const role = entry.message.role;
|
||||
// user, assistant, and bashExecution are valid cut points
|
||||
// toolResult must stay with its preceding tool call
|
||||
if (role === "user" || role === "assistant" || role === "bashExecution") {
|
||||
cutPoints.push(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
return cutPoints;
|
||||
}
|
||||
|
||||
/**
|
||||
* Find the user message (or bashExecution) that starts the turn containing the given entry index.
|
||||
* Returns -1 if no turn start found before the index.
|
||||
* BashExecutionMessage is treated like a user message for turn boundaries.
|
||||
*/
|
||||
export function findTurnStartIndex(entries: SessionEntry[], entryIndex: number, startIndex: number): number {
|
||||
for (let i = entryIndex; i >= startIndex; i--) {
|
||||
const entry = entries[i];
|
||||
if (entry.type === "message") {
|
||||
const role = entry.message.role;
|
||||
if (role === "user" || role === "bashExecution") {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
export interface CutPointResult {
|
||||
/** Index of first entry to keep */
|
||||
firstKeptEntryIndex: number;
|
||||
/** Index of user message that starts the turn being split, or -1 if not splitting */
|
||||
turnStartIndex: number;
|
||||
/** Whether this cut splits a turn (cut point is not a user message) */
|
||||
isSplitTurn: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Find the cut point in session entries that keeps approximately `keepRecentTokens`.
|
||||
*
|
||||
* Algorithm: Walk backwards from newest, accumulating estimated message sizes.
|
||||
* Stop when we've accumulated >= keepRecentTokens. Cut at that point.
|
||||
*
|
||||
* Can cut at user OR assistant messages (never tool results). When cutting at an
|
||||
* assistant message with tool calls, its tool results come after and will be kept.
|
||||
*
|
||||
* Returns CutPointResult with:
|
||||
* - firstKeptEntryIndex: the entry index to start keeping from
|
||||
* - turnStartIndex: if cutting mid-turn, the user message that started that turn
|
||||
* - isSplitTurn: whether we're cutting in the middle of a turn
|
||||
*
|
||||
* Only considers entries between `startIndex` and `endIndex` (exclusive).
|
||||
*/
|
||||
export function findCutPoint(
|
||||
entries: SessionEntry[],
|
||||
startIndex: number,
|
||||
endIndex: number,
|
||||
keepRecentTokens: number,
|
||||
): CutPointResult {
|
||||
const cutPoints = findValidCutPoints(entries, startIndex, endIndex);
|
||||
|
||||
if (cutPoints.length === 0) {
|
||||
return { firstKeptEntryIndex: startIndex, turnStartIndex: -1, isSplitTurn: false };
|
||||
}
|
||||
|
||||
// Walk backwards from newest, accumulating estimated message sizes
|
||||
let accumulatedTokens = 0;
|
||||
let cutIndex = startIndex; // Default: keep everything in range
|
||||
|
||||
for (let i = endIndex - 1; i >= startIndex; i--) {
|
||||
const entry = entries[i];
|
||||
if (entry.type !== "message") continue;
|
||||
|
||||
// Estimate this message's size
|
||||
const messageTokens = estimateTokens(entry.message);
|
||||
accumulatedTokens += messageTokens;
|
||||
|
||||
// Check if we've exceeded the budget
|
||||
if (accumulatedTokens >= keepRecentTokens) {
|
||||
// Find the closest valid cut point at or after this entry
|
||||
for (let c = 0; c < cutPoints.length; c++) {
|
||||
if (cutPoints[c] >= i) {
|
||||
cutIndex = cutPoints[c];
|
||||
break;
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Scan backwards from cutIndex to include any non-message entries (bash, settings, etc.)
|
||||
while (cutIndex > startIndex) {
|
||||
const prevEntry = entries[cutIndex - 1];
|
||||
// Stop at compaction boundaries
|
||||
if (prevEntry.type === "compaction") {
|
||||
break;
|
||||
}
|
||||
if (prevEntry.type === "message") {
|
||||
// Stop if we hit any message
|
||||
break;
|
||||
}
|
||||
// Include this non-message entry (bash, settings change, etc.)
|
||||
cutIndex--;
|
||||
}
|
||||
|
||||
// Determine if this is a split turn
|
||||
const cutEntry = entries[cutIndex];
|
||||
const isUserMessage = cutEntry.type === "message" && cutEntry.message.role === "user";
|
||||
const turnStartIndex = isUserMessage ? -1 : findTurnStartIndex(entries, cutIndex, startIndex);
|
||||
|
||||
return {
|
||||
firstKeptEntryIndex: cutIndex,
|
||||
turnStartIndex,
|
||||
isSplitTurn: !isUserMessage && turnStartIndex !== -1,
|
||||
};
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Summarization
|
||||
// ============================================================================
|
||||
|
||||
const SUMMARIZATION_PROMPT = `You are performing a CONTEXT CHECKPOINT COMPACTION. Create a handoff summary for another LLM that will resume the task.
|
||||
|
||||
Include:
|
||||
- Current progress and key decisions made
|
||||
- Important context, constraints, or user preferences
|
||||
- Absolute file paths of any relevant files that were read or modified
|
||||
- What remains to be done (clear next steps)
|
||||
- Any critical data, examples, or references needed to continue
|
||||
|
||||
Be concise, structured, and focused on helping the next LLM seamlessly continue the work.`;
|
||||
|
||||
/**
|
||||
* Generate a summary of the conversation using the LLM.
|
||||
*/
|
||||
export async function generateSummary(
|
||||
currentMessages: AppMessage[],
|
||||
model: Model<any>,
|
||||
reserveTokens: number,
|
||||
apiKey: string,
|
||||
signal?: AbortSignal,
|
||||
customInstructions?: string,
|
||||
): Promise<string> {
|
||||
const maxTokens = Math.floor(0.8 * reserveTokens);
|
||||
|
||||
const prompt = customInstructions
|
||||
? `${SUMMARIZATION_PROMPT}\n\nAdditional focus: ${customInstructions}`
|
||||
: SUMMARIZATION_PROMPT;
|
||||
|
||||
// Transform custom messages (like bashExecution) to LLM-compatible messages
|
||||
const transformedMessages = messageTransformer(currentMessages);
|
||||
|
||||
const summarizationMessages = [
|
||||
...transformedMessages,
|
||||
{
|
||||
role: "user" as const,
|
||||
content: [{ type: "text" as const, text: prompt }],
|
||||
timestamp: Date.now(),
|
||||
},
|
||||
];
|
||||
|
||||
const response = await complete(model, { messages: summarizationMessages }, { maxTokens, signal, apiKey });
|
||||
|
||||
const textContent = response.content
|
||||
.filter((c): c is { type: "text"; text: string } => c.type === "text")
|
||||
.map((c) => c.text)
|
||||
.join("\n");
|
||||
|
||||
return textContent;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Compaction Preparation (for hooks)
|
||||
// ============================================================================
|
||||
|
||||
export interface CompactionPreparation {
|
||||
cutPoint: CutPointResult;
|
||||
/** Messages that will be summarized and discarded */
|
||||
messagesToSummarize: AppMessage[];
|
||||
/** Messages that will be kept after the summary (recent turns) */
|
||||
messagesToKeep: AppMessage[];
|
||||
tokensBefore: number;
|
||||
boundaryStart: number;
|
||||
}
|
||||
|
||||
export function prepareCompaction(entries: SessionEntry[], settings: CompactionSettings): CompactionPreparation | null {
|
||||
if (entries.length > 0 && entries[entries.length - 1].type === "compaction") {
|
||||
return null;
|
||||
}
|
||||
|
||||
let prevCompactionIndex = -1;
|
||||
for (let i = entries.length - 1; i >= 0; i--) {
|
||||
if (entries[i].type === "compaction") {
|
||||
prevCompactionIndex = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
const boundaryStart = prevCompactionIndex + 1;
|
||||
const boundaryEnd = entries.length;
|
||||
|
||||
const lastUsage = getLastAssistantUsage(entries);
|
||||
const tokensBefore = lastUsage ? calculateContextTokens(lastUsage) : 0;
|
||||
|
||||
const cutPoint = findCutPoint(entries, boundaryStart, boundaryEnd, settings.keepRecentTokens);
|
||||
|
||||
const historyEnd = cutPoint.isSplitTurn ? cutPoint.turnStartIndex : cutPoint.firstKeptEntryIndex;
|
||||
|
||||
// Messages to summarize (will be discarded after summary)
|
||||
const messagesToSummarize: AppMessage[] = [];
|
||||
for (let i = boundaryStart; i < historyEnd; i++) {
|
||||
const entry = entries[i];
|
||||
if (entry.type === "message") {
|
||||
messagesToSummarize.push(entry.message);
|
||||
}
|
||||
}
|
||||
|
||||
// Messages to keep (recent turns, kept after summary)
|
||||
const messagesToKeep: AppMessage[] = [];
|
||||
for (let i = cutPoint.firstKeptEntryIndex; i < boundaryEnd; i++) {
|
||||
const entry = entries[i];
|
||||
if (entry.type === "message") {
|
||||
messagesToKeep.push(entry.message);
|
||||
}
|
||||
}
|
||||
|
||||
return { cutPoint, messagesToSummarize, messagesToKeep, tokensBefore, boundaryStart };
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Main compaction function
|
||||
// ============================================================================
|
||||
|
||||
const TURN_PREFIX_SUMMARIZATION_PROMPT = `You are performing a CONTEXT CHECKPOINT COMPACTION for a split turn.
|
||||
This is the PREFIX of a turn that was too large to keep in full. The SUFFIX (recent work) is being kept.
|
||||
|
||||
Create a handoff summary that captures:
|
||||
- What the user originally asked for in this turn
|
||||
- Key decisions and progress made early in this turn
|
||||
- Important context needed to understand the kept suffix
|
||||
|
||||
Be concise. Focus on information needed to understand the retained recent work.`;
|
||||
|
||||
/**
|
||||
* Calculate compaction and generate summary.
|
||||
* Returns the CompactionEntry to append to the session file.
|
||||
*
|
||||
* @param entries - All session entries
|
||||
* @param model - Model to use for summarization
|
||||
* @param settings - Compaction settings
|
||||
* @param apiKey - API key for LLM
|
||||
* @param signal - Optional abort signal
|
||||
* @param customInstructions - Optional custom focus for the summary
|
||||
*/
|
||||
export async function compact(
|
||||
entries: SessionEntry[],
|
||||
model: Model<any>,
|
||||
settings: CompactionSettings,
|
||||
apiKey: string,
|
||||
signal?: AbortSignal,
|
||||
customInstructions?: string,
|
||||
): Promise<CompactionEntry> {
|
||||
// Don't compact if the last entry is already a compaction
|
||||
if (entries.length > 0 && entries[entries.length - 1].type === "compaction") {
|
||||
throw new Error("Already compacted");
|
||||
}
|
||||
|
||||
// Find previous compaction boundary
|
||||
let prevCompactionIndex = -1;
|
||||
for (let i = entries.length - 1; i >= 0; i--) {
|
||||
if (entries[i].type === "compaction") {
|
||||
prevCompactionIndex = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
const boundaryStart = prevCompactionIndex + 1;
|
||||
const boundaryEnd = entries.length;
|
||||
|
||||
// Get token count before compaction
|
||||
const lastUsage = getLastAssistantUsage(entries);
|
||||
const tokensBefore = lastUsage ? calculateContextTokens(lastUsage) : 0;
|
||||
|
||||
// Find cut point (entry index) within the valid range
|
||||
const cutResult = findCutPoint(entries, boundaryStart, boundaryEnd, settings.keepRecentTokens);
|
||||
|
||||
// Extract messages for history summary (before the turn that contains the cut point)
|
||||
const historyEnd = cutResult.isSplitTurn ? cutResult.turnStartIndex : cutResult.firstKeptEntryIndex;
|
||||
const historyMessages: AppMessage[] = [];
|
||||
for (let i = boundaryStart; i < historyEnd; i++) {
|
||||
const entry = entries[i];
|
||||
if (entry.type === "message") {
|
||||
historyMessages.push(entry.message);
|
||||
}
|
||||
}
|
||||
|
||||
// Include previous summary if there was a compaction
|
||||
if (prevCompactionIndex >= 0) {
|
||||
const prevCompaction = entries[prevCompactionIndex] as CompactionEntry;
|
||||
historyMessages.unshift({
|
||||
role: "user",
|
||||
content: `Previous session summary:\n${prevCompaction.summary}`,
|
||||
timestamp: Date.now(),
|
||||
});
|
||||
}
|
||||
|
||||
// Extract messages for turn prefix summary (if splitting a turn)
|
||||
const turnPrefixMessages: AppMessage[] = [];
|
||||
if (cutResult.isSplitTurn) {
|
||||
for (let i = cutResult.turnStartIndex; i < cutResult.firstKeptEntryIndex; i++) {
|
||||
const entry = entries[i];
|
||||
if (entry.type === "message") {
|
||||
turnPrefixMessages.push(entry.message);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Generate summaries (can be parallel if both needed) and merge into one
|
||||
let summary: string;
|
||||
|
||||
if (cutResult.isSplitTurn && turnPrefixMessages.length > 0) {
|
||||
// Generate both summaries in parallel
|
||||
const [historyResult, turnPrefixResult] = await Promise.all([
|
||||
historyMessages.length > 0
|
||||
? generateSummary(historyMessages, model, settings.reserveTokens, apiKey, signal, customInstructions)
|
||||
: Promise.resolve("No prior history."),
|
||||
generateTurnPrefixSummary(turnPrefixMessages, model, settings.reserveTokens, apiKey, signal),
|
||||
]);
|
||||
// Merge into single summary
|
||||
summary = `${historyResult}\n\n---\n\n**Turn Context (split turn):**\n\n${turnPrefixResult}`;
|
||||
} else {
|
||||
// Just generate history summary
|
||||
summary = await generateSummary(
|
||||
historyMessages,
|
||||
model,
|
||||
settings.reserveTokens,
|
||||
apiKey,
|
||||
signal,
|
||||
customInstructions,
|
||||
);
|
||||
}
|
||||
|
||||
return {
|
||||
type: "compaction",
|
||||
timestamp: new Date().toISOString(),
|
||||
summary,
|
||||
firstKeptEntryIndex: cutResult.firstKeptEntryIndex,
|
||||
tokensBefore,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a summary for a turn prefix (when splitting a turn).
|
||||
*/
|
||||
async function generateTurnPrefixSummary(
|
||||
messages: AppMessage[],
|
||||
model: Model<any>,
|
||||
reserveTokens: number,
|
||||
apiKey: string,
|
||||
signal?: AbortSignal,
|
||||
): Promise<string> {
|
||||
const maxTokens = Math.floor(0.5 * reserveTokens); // Smaller budget for turn prefix
|
||||
|
||||
const transformedMessages = messageTransformer(messages);
|
||||
const summarizationMessages = [
|
||||
...transformedMessages,
|
||||
{
|
||||
role: "user" as const,
|
||||
content: [{ type: "text" as const, text: TURN_PREFIX_SUMMARIZATION_PROMPT }],
|
||||
timestamp: Date.now(),
|
||||
},
|
||||
];
|
||||
|
||||
const response = await complete(model, { messages: summarizationMessages }, { maxTokens, signal, apiKey });
|
||||
|
||||
return response.content
|
||||
.filter((c): c is { type: "text"; text: string } => c.type === "text")
|
||||
.map((c) => c.text)
|
||||
.join("\n");
|
||||
}
|
||||
|
|
@ -0,0 +1,343 @@
|
|||
/**
|
||||
* Branch summarization for tree navigation.
|
||||
*
|
||||
* When navigating to a different point in the session tree, this generates
|
||||
* a summary of the branch being left so context isn't lost.
|
||||
*/
|
||||
|
||||
import type { AgentMessage } from "@mariozechner/pi-agent-core";
|
||||
import type { Model } from "@mariozechner/pi-ai";
|
||||
import { completeSimple } from "@mariozechner/pi-ai";
|
||||
import {
|
||||
convertToLlm,
|
||||
createBranchSummaryMessage,
|
||||
createCompactionSummaryMessage,
|
||||
createHookMessage,
|
||||
} from "../messages.js";
|
||||
import type { ReadonlySessionManager, SessionEntry } from "../session-manager.js";
|
||||
import { estimateTokens } from "./compaction.js";
|
||||
import {
|
||||
computeFileLists,
|
||||
createFileOps,
|
||||
extractFileOpsFromMessage,
|
||||
type FileOperations,
|
||||
formatFileOperations,
|
||||
SUMMARIZATION_SYSTEM_PROMPT,
|
||||
serializeConversation,
|
||||
} from "./utils.js";
|
||||
|
||||
// ============================================================================
|
||||
// Types
|
||||
// ============================================================================
|
||||
|
||||
export interface BranchSummaryResult {
|
||||
summary?: string;
|
||||
readFiles?: string[];
|
||||
modifiedFiles?: string[];
|
||||
aborted?: boolean;
|
||||
error?: string;
|
||||
}
|
||||
|
||||
/** Details stored in BranchSummaryEntry.details for file tracking */
|
||||
export interface BranchSummaryDetails {
|
||||
readFiles: string[];
|
||||
modifiedFiles: string[];
|
||||
}
|
||||
|
||||
export type { FileOperations } from "./utils.js";
|
||||
|
||||
export interface BranchPreparation {
|
||||
/** Messages extracted for summarization, in chronological order */
|
||||
messages: AgentMessage[];
|
||||
/** File operations extracted from tool calls */
|
||||
fileOps: FileOperations;
|
||||
/** Total estimated tokens in messages */
|
||||
totalTokens: number;
|
||||
}
|
||||
|
||||
export interface CollectEntriesResult {
|
||||
/** Entries to summarize, in chronological order */
|
||||
entries: SessionEntry[];
|
||||
/** Common ancestor between old and new position, if any */
|
||||
commonAncestorId: string | null;
|
||||
}
|
||||
|
||||
export interface GenerateBranchSummaryOptions {
|
||||
/** Model to use for summarization */
|
||||
model: Model<any>;
|
||||
/** API key for the model */
|
||||
apiKey: string;
|
||||
/** Abort signal for cancellation */
|
||||
signal: AbortSignal;
|
||||
/** Optional custom instructions for summarization */
|
||||
customInstructions?: string;
|
||||
/** Tokens reserved for prompt + LLM response (default 16384) */
|
||||
reserveTokens?: number;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Entry Collection
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Collect entries that should be summarized when navigating from one position to another.
|
||||
*
|
||||
* Walks from oldLeafId back to the common ancestor with targetId, collecting entries
|
||||
* along the way. Does NOT stop at compaction boundaries - those are included and their
|
||||
* summaries become context.
|
||||
*
|
||||
* @param session - Session manager (read-only access)
|
||||
* @param oldLeafId - Current position (where we're navigating from)
|
||||
* @param targetId - Target position (where we're navigating to)
|
||||
* @returns Entries to summarize and the common ancestor
|
||||
*/
|
||||
export function collectEntriesForBranchSummary(
|
||||
session: ReadonlySessionManager,
|
||||
oldLeafId: string | null,
|
||||
targetId: string,
|
||||
): CollectEntriesResult {
|
||||
// If no old position, nothing to summarize
|
||||
if (!oldLeafId) {
|
||||
return { entries: [], commonAncestorId: null };
|
||||
}
|
||||
|
||||
// Find common ancestor (deepest node that's on both paths)
|
||||
const oldPath = new Set(session.getPath(oldLeafId).map((e) => e.id));
|
||||
const targetPath = session.getPath(targetId);
|
||||
|
||||
// targetPath is root-first, so iterate backwards to find deepest common ancestor
|
||||
let commonAncestorId: string | null = null;
|
||||
for (let i = targetPath.length - 1; i >= 0; i--) {
|
||||
if (oldPath.has(targetPath[i].id)) {
|
||||
commonAncestorId = targetPath[i].id;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Collect entries from old leaf back to common ancestor
|
||||
const entries: SessionEntry[] = [];
|
||||
let current: string | null = oldLeafId;
|
||||
|
||||
while (current && current !== commonAncestorId) {
|
||||
const entry = session.getEntry(current);
|
||||
if (!entry) break;
|
||||
entries.push(entry);
|
||||
current = entry.parentId;
|
||||
}
|
||||
|
||||
// Reverse to get chronological order
|
||||
entries.reverse();
|
||||
|
||||
return { entries, commonAncestorId };
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Entry to Message Conversion
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Extract AgentMessage from a session entry.
|
||||
* Similar to getMessageFromEntry in compaction.ts but also handles compaction entries.
|
||||
*/
|
||||
function getMessageFromEntry(entry: SessionEntry): AgentMessage | undefined {
|
||||
switch (entry.type) {
|
||||
case "message":
|
||||
// Skip tool results - context is in assistant's tool call
|
||||
if (entry.message.role === "toolResult") return undefined;
|
||||
return entry.message;
|
||||
|
||||
case "custom_message":
|
||||
return createHookMessage(entry.customType, entry.content, entry.display, entry.details, entry.timestamp);
|
||||
|
||||
case "branch_summary":
|
||||
return createBranchSummaryMessage(entry.summary, entry.fromId, entry.timestamp);
|
||||
|
||||
case "compaction":
|
||||
return createCompactionSummaryMessage(entry.summary, entry.tokensBefore, entry.timestamp);
|
||||
|
||||
// These don't contribute to conversation content
|
||||
case "thinking_level_change":
|
||||
case "model_change":
|
||||
case "custom":
|
||||
case "label":
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Prepare entries for summarization with token budget.
|
||||
*
|
||||
* Walks entries from NEWEST to OLDEST, adding messages until we hit the token budget.
|
||||
* This ensures we keep the most recent context when the branch is too long.
|
||||
*
|
||||
* Also collects file operations from:
|
||||
* - Tool calls in assistant messages
|
||||
* - Existing branch_summary entries' details (for cumulative tracking)
|
||||
*
|
||||
* @param entries - Entries in chronological order
|
||||
* @param tokenBudget - Maximum tokens to include (0 = no limit)
|
||||
*/
|
||||
export function prepareBranchEntries(entries: SessionEntry[], tokenBudget: number = 0): BranchPreparation {
|
||||
const messages: AgentMessage[] = [];
|
||||
const fileOps = createFileOps();
|
||||
let totalTokens = 0;
|
||||
|
||||
// First pass: collect file ops from ALL entries (even if they don't fit in token budget)
|
||||
// This ensures we capture cumulative file tracking from nested branch summaries
|
||||
// Only extract from pi-generated summaries (fromHook !== true), not hook-generated ones
|
||||
for (const entry of entries) {
|
||||
if (entry.type === "branch_summary" && !entry.fromHook && entry.details) {
|
||||
const details = entry.details as BranchSummaryDetails;
|
||||
if (Array.isArray(details.readFiles)) {
|
||||
for (const f of details.readFiles) fileOps.read.add(f);
|
||||
}
|
||||
if (Array.isArray(details.modifiedFiles)) {
|
||||
// Modified files go into both edited and written for proper deduplication
|
||||
for (const f of details.modifiedFiles) {
|
||||
fileOps.edited.add(f);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Second pass: walk from newest to oldest, adding messages until token budget
|
||||
for (let i = entries.length - 1; i >= 0; i--) {
|
||||
const entry = entries[i];
|
||||
const message = getMessageFromEntry(entry);
|
||||
if (!message) continue;
|
||||
|
||||
// Extract file ops from assistant messages (tool calls)
|
||||
extractFileOpsFromMessage(message, fileOps);
|
||||
|
||||
const tokens = estimateTokens(message);
|
||||
|
||||
// Check budget before adding
|
||||
if (tokenBudget > 0 && totalTokens + tokens > tokenBudget) {
|
||||
// If this is a summary entry, try to fit it anyway as it's important context
|
||||
if (entry.type === "compaction" || entry.type === "branch_summary") {
|
||||
if (totalTokens < tokenBudget * 0.9) {
|
||||
messages.unshift(message);
|
||||
totalTokens += tokens;
|
||||
}
|
||||
}
|
||||
// Stop - we've hit the budget
|
||||
break;
|
||||
}
|
||||
|
||||
messages.unshift(message);
|
||||
totalTokens += tokens;
|
||||
}
|
||||
|
||||
return { messages, fileOps, totalTokens };
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Summary Generation
|
||||
// ============================================================================
|
||||
|
||||
const BRANCH_SUMMARY_PREAMBLE = `The user explored a different conversation branch before returning here.
|
||||
Summary of that exploration:
|
||||
|
||||
`;
|
||||
|
||||
const BRANCH_SUMMARY_PROMPT = `Create a structured summary of this conversation branch for context when returning later.
|
||||
|
||||
Use this EXACT format:
|
||||
|
||||
## Goal
|
||||
[What was the user trying to accomplish in this branch?]
|
||||
|
||||
## Constraints & Preferences
|
||||
- [Any constraints, preferences, or requirements mentioned]
|
||||
- [Or "(none)" if none were mentioned]
|
||||
|
||||
## Progress
|
||||
### Done
|
||||
- [x] [Completed tasks/changes]
|
||||
|
||||
### In Progress
|
||||
- [ ] [Work that was started but not finished]
|
||||
|
||||
### Blocked
|
||||
- [Issues preventing progress, if any]
|
||||
|
||||
## Key Decisions
|
||||
- **[Decision]**: [Brief rationale]
|
||||
|
||||
## Next Steps
|
||||
1. [What should happen next to continue this work]
|
||||
|
||||
Keep each section concise. Preserve exact file paths, function names, and error messages.`;
|
||||
|
||||
/**
|
||||
* Generate a summary of abandoned branch entries.
|
||||
*
|
||||
* @param entries - Session entries to summarize (chronological order)
|
||||
* @param options - Generation options
|
||||
*/
|
||||
export async function generateBranchSummary(
|
||||
entries: SessionEntry[],
|
||||
options: GenerateBranchSummaryOptions,
|
||||
): Promise<BranchSummaryResult> {
|
||||
const { model, apiKey, signal, customInstructions, reserveTokens = 16384 } = options;
|
||||
|
||||
// Token budget = context window minus reserved space for prompt + response
|
||||
const contextWindow = model.contextWindow || 128000;
|
||||
const tokenBudget = contextWindow - reserveTokens;
|
||||
|
||||
const { messages, fileOps } = prepareBranchEntries(entries, tokenBudget);
|
||||
|
||||
if (messages.length === 0) {
|
||||
return { summary: "No content to summarize" };
|
||||
}
|
||||
|
||||
// Transform to LLM-compatible messages, then serialize to text
|
||||
// Serialization prevents the model from treating it as a conversation to continue
|
||||
const llmMessages = convertToLlm(messages);
|
||||
const conversationText = serializeConversation(llmMessages);
|
||||
|
||||
// Build prompt
|
||||
const instructions = customInstructions || BRANCH_SUMMARY_PROMPT;
|
||||
const promptText = `<conversation>\n${conversationText}\n</conversation>\n\n${instructions}`;
|
||||
|
||||
const summarizationMessages = [
|
||||
{
|
||||
role: "user" as const,
|
||||
content: [{ type: "text" as const, text: promptText }],
|
||||
timestamp: Date.now(),
|
||||
},
|
||||
];
|
||||
|
||||
// Call LLM for summarization
|
||||
const response = await completeSimple(
|
||||
model,
|
||||
{ systemPrompt: SUMMARIZATION_SYSTEM_PROMPT, messages: summarizationMessages },
|
||||
{ apiKey, signal, maxTokens: 2048 },
|
||||
);
|
||||
|
||||
// Check if aborted or errored
|
||||
if (response.stopReason === "aborted") {
|
||||
return { aborted: true };
|
||||
}
|
||||
if (response.stopReason === "error") {
|
||||
return { error: response.errorMessage || "Summarization failed" };
|
||||
}
|
||||
|
||||
let summary = response.content
|
||||
.filter((c): c is { type: "text"; text: string } => c.type === "text")
|
||||
.map((c) => c.text)
|
||||
.join("\n");
|
||||
|
||||
// Prepend preamble to provide context about the branch summary
|
||||
summary = BRANCH_SUMMARY_PREAMBLE + summary;
|
||||
|
||||
// Compute file lists and append to summary
|
||||
const { readFiles, modifiedFiles } = computeFileLists(fileOps);
|
||||
summary += formatFileOperations(readFiles, modifiedFiles);
|
||||
|
||||
return {
|
||||
summary: summary || "No summary generated",
|
||||
readFiles,
|
||||
modifiedFiles,
|
||||
};
|
||||
}
|
||||
759
packages/coding-agent/src/core/compaction/compaction.ts
Normal file
759
packages/coding-agent/src/core/compaction/compaction.ts
Normal file
|
|
@ -0,0 +1,759 @@
|
|||
/**
|
||||
* Context compaction for long sessions.
|
||||
*
|
||||
* Pure functions for compaction logic. The session manager handles I/O,
|
||||
* and after compaction the session is reloaded.
|
||||
*/
|
||||
|
||||
import type { AgentMessage } from "@mariozechner/pi-agent-core";
|
||||
import type { AssistantMessage, Model, Usage } from "@mariozechner/pi-ai";
|
||||
import { complete, completeSimple } from "@mariozechner/pi-ai";
|
||||
import { convertToLlm, createBranchSummaryMessage, createHookMessage } from "../messages.js";
|
||||
import type { CompactionEntry, SessionEntry } from "../session-manager.js";
|
||||
import {
|
||||
computeFileLists,
|
||||
createFileOps,
|
||||
extractFileOpsFromMessage,
|
||||
type FileOperations,
|
||||
formatFileOperations,
|
||||
SUMMARIZATION_SYSTEM_PROMPT,
|
||||
serializeConversation,
|
||||
} from "./utils.js";
|
||||
|
||||
// ============================================================================
|
||||
// File Operation Tracking
|
||||
// ============================================================================
|
||||
|
||||
/** Details stored in CompactionEntry.details for file tracking */
|
||||
export interface CompactionDetails {
|
||||
readFiles: string[];
|
||||
modifiedFiles: string[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract file operations from messages and previous compaction entries.
|
||||
*/
|
||||
function extractFileOperations(
|
||||
messages: AgentMessage[],
|
||||
entries: SessionEntry[],
|
||||
prevCompactionIndex: number,
|
||||
): FileOperations {
|
||||
const fileOps = createFileOps();
|
||||
|
||||
// Collect from previous compaction's details (if pi-generated)
|
||||
if (prevCompactionIndex >= 0) {
|
||||
const prevCompaction = entries[prevCompactionIndex] as CompactionEntry;
|
||||
if (!prevCompaction.fromHook && prevCompaction.details) {
|
||||
const details = prevCompaction.details as CompactionDetails;
|
||||
if (Array.isArray(details.readFiles)) {
|
||||
for (const f of details.readFiles) fileOps.read.add(f);
|
||||
}
|
||||
if (Array.isArray(details.modifiedFiles)) {
|
||||
for (const f of details.modifiedFiles) fileOps.edited.add(f);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract from tool calls in messages
|
||||
for (const msg of messages) {
|
||||
extractFileOpsFromMessage(msg, fileOps);
|
||||
}
|
||||
|
||||
return fileOps;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Message Extraction
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Extract AgentMessage from an entry if it produces one.
|
||||
* Returns undefined for entries that don't contribute to LLM context.
|
||||
*/
|
||||
function getMessageFromEntry(entry: SessionEntry): AgentMessage | undefined {
|
||||
if (entry.type === "message") {
|
||||
return entry.message;
|
||||
}
|
||||
if (entry.type === "custom_message") {
|
||||
return createHookMessage(entry.customType, entry.content, entry.display, entry.details, entry.timestamp);
|
||||
}
|
||||
if (entry.type === "branch_summary") {
|
||||
return createBranchSummaryMessage(entry.summary, entry.fromId, entry.timestamp);
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/** Result from compact() - SessionManager adds uuid/parentUuid when saving */
|
||||
export interface CompactionResult<T = unknown> {
|
||||
summary: string;
|
||||
firstKeptEntryId: string;
|
||||
tokensBefore: number;
|
||||
/** Hook-specific data (e.g., ArtifactIndex, version markers for structured compaction) */
|
||||
details?: T;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Types
|
||||
// ============================================================================
|
||||
|
||||
export interface CompactionSettings {
|
||||
enabled: boolean;
|
||||
reserveTokens: number;
|
||||
keepRecentTokens: number;
|
||||
}
|
||||
|
||||
export const DEFAULT_COMPACTION_SETTINGS: CompactionSettings = {
|
||||
enabled: true,
|
||||
reserveTokens: 16384,
|
||||
keepRecentTokens: 20000,
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// Token calculation
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Calculate total context tokens from usage.
|
||||
* Uses the native totalTokens field when available, falls back to computing from components.
|
||||
*/
|
||||
export function calculateContextTokens(usage: Usage): number {
|
||||
return usage.totalTokens || usage.input + usage.output + usage.cacheRead + usage.cacheWrite;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get usage from an assistant message if available.
|
||||
* Skips aborted and error messages as they don't have valid usage data.
|
||||
*/
|
||||
function getAssistantUsage(msg: AgentMessage): Usage | undefined {
|
||||
if (msg.role === "assistant" && "usage" in msg) {
|
||||
const assistantMsg = msg as AssistantMessage;
|
||||
if (assistantMsg.stopReason !== "aborted" && assistantMsg.stopReason !== "error" && assistantMsg.usage) {
|
||||
return assistantMsg.usage;
|
||||
}
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Find the last non-aborted assistant message usage from session entries.
|
||||
*/
|
||||
export function getLastAssistantUsage(entries: SessionEntry[]): Usage | undefined {
|
||||
for (let i = entries.length - 1; i >= 0; i--) {
|
||||
const entry = entries[i];
|
||||
if (entry.type === "message") {
|
||||
const usage = getAssistantUsage(entry.message);
|
||||
if (usage) return usage;
|
||||
}
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if compaction should trigger based on context usage.
|
||||
*/
|
||||
export function shouldCompact(contextTokens: number, contextWindow: number, settings: CompactionSettings): boolean {
|
||||
if (!settings.enabled) return false;
|
||||
return contextTokens > contextWindow - settings.reserveTokens;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Cut point detection
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Estimate token count for a message using chars/4 heuristic.
|
||||
* This is conservative (overestimates tokens).
|
||||
*/
|
||||
export function estimateTokens(message: AgentMessage): number {
|
||||
let chars = 0;
|
||||
|
||||
switch (message.role) {
|
||||
case "user": {
|
||||
const content = (message as { content: string | Array<{ type: string; text?: string }> }).content;
|
||||
if (typeof content === "string") {
|
||||
chars = content.length;
|
||||
} else if (Array.isArray(content)) {
|
||||
for (const block of content) {
|
||||
if (block.type === "text" && block.text) {
|
||||
chars += block.text.length;
|
||||
}
|
||||
}
|
||||
}
|
||||
return Math.ceil(chars / 4);
|
||||
}
|
||||
case "assistant": {
|
||||
const assistant = message as AssistantMessage;
|
||||
for (const block of assistant.content) {
|
||||
if (block.type === "text") {
|
||||
chars += block.text.length;
|
||||
} else if (block.type === "thinking") {
|
||||
chars += block.thinking.length;
|
||||
} else if (block.type === "toolCall") {
|
||||
chars += block.name.length + JSON.stringify(block.arguments).length;
|
||||
}
|
||||
}
|
||||
return Math.ceil(chars / 4);
|
||||
}
|
||||
case "hookMessage":
|
||||
case "toolResult": {
|
||||
if (typeof message.content === "string") {
|
||||
chars = message.content.length;
|
||||
} else {
|
||||
for (const block of message.content) {
|
||||
if (block.type === "text" && block.text) {
|
||||
chars += block.text.length;
|
||||
}
|
||||
if (block.type === "image") {
|
||||
chars += 4800; // Estimate images as 4000 chars, or 1200 tokens
|
||||
}
|
||||
}
|
||||
}
|
||||
return Math.ceil(chars / 4);
|
||||
}
|
||||
case "bashExecution": {
|
||||
chars = message.command.length + message.output.length;
|
||||
return Math.ceil(chars / 4);
|
||||
}
|
||||
case "branchSummary":
|
||||
case "compactionSummary": {
|
||||
chars = message.summary.length;
|
||||
return Math.ceil(chars / 4);
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Find valid cut points: indices of user, assistant, custom, or bashExecution messages.
|
||||
* Never cut at tool results (they must follow their tool call).
|
||||
* When we cut at an assistant message with tool calls, its tool results follow it
|
||||
* and will be kept.
|
||||
* BashExecutionMessage is treated like a user message (user-initiated context).
|
||||
*/
|
||||
function findValidCutPoints(entries: SessionEntry[], startIndex: number, endIndex: number): number[] {
|
||||
const cutPoints: number[] = [];
|
||||
for (let i = startIndex; i < endIndex; i++) {
|
||||
const entry = entries[i];
|
||||
switch (entry.type) {
|
||||
case "message": {
|
||||
const role = entry.message.role;
|
||||
switch (role) {
|
||||
case "bashExecution":
|
||||
case "hookMessage":
|
||||
case "branchSummary":
|
||||
case "compactionSummary":
|
||||
case "user":
|
||||
case "assistant":
|
||||
cutPoints.push(i);
|
||||
break;
|
||||
case "toolResult":
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case "thinking_level_change":
|
||||
case "model_change":
|
||||
case "compaction":
|
||||
case "branch_summary":
|
||||
case "custom":
|
||||
case "custom_message":
|
||||
case "label":
|
||||
}
|
||||
// branch_summary and custom_message are user-role messages, valid cut points
|
||||
if (entry.type === "branch_summary" || entry.type === "custom_message") {
|
||||
cutPoints.push(i);
|
||||
}
|
||||
}
|
||||
return cutPoints;
|
||||
}
|
||||
|
||||
/**
|
||||
* Find the user message (or bashExecution) that starts the turn containing the given entry index.
|
||||
* Returns -1 if no turn start found before the index.
|
||||
* BashExecutionMessage is treated like a user message for turn boundaries.
|
||||
*/
|
||||
export function findTurnStartIndex(entries: SessionEntry[], entryIndex: number, startIndex: number): number {
|
||||
for (let i = entryIndex; i >= startIndex; i--) {
|
||||
const entry = entries[i];
|
||||
// branch_summary and custom_message are user-role messages, can start a turn
|
||||
if (entry.type === "branch_summary" || entry.type === "custom_message") {
|
||||
return i;
|
||||
}
|
||||
if (entry.type === "message") {
|
||||
const role = entry.message.role;
|
||||
if (role === "user" || role === "bashExecution") {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
export interface CutPointResult {
|
||||
/** Index of first entry to keep */
|
||||
firstKeptEntryIndex: number;
|
||||
/** Index of user message that starts the turn being split, or -1 if not splitting */
|
||||
turnStartIndex: number;
|
||||
/** Whether this cut splits a turn (cut point is not a user message) */
|
||||
isSplitTurn: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Find the cut point in session entries that keeps approximately `keepRecentTokens`.
|
||||
*
|
||||
* Algorithm: Walk backwards from newest, accumulating estimated message sizes.
|
||||
* Stop when we've accumulated >= keepRecentTokens. Cut at that point.
|
||||
*
|
||||
* Can cut at user OR assistant messages (never tool results). When cutting at an
|
||||
* assistant message with tool calls, its tool results come after and will be kept.
|
||||
*
|
||||
* Returns CutPointResult with:
|
||||
* - firstKeptEntryIndex: the entry index to start keeping from
|
||||
* - turnStartIndex: if cutting mid-turn, the user message that started that turn
|
||||
* - isSplitTurn: whether we're cutting in the middle of a turn
|
||||
*
|
||||
* Only considers entries between `startIndex` and `endIndex` (exclusive).
|
||||
*/
|
||||
export function findCutPoint(
|
||||
entries: SessionEntry[],
|
||||
startIndex: number,
|
||||
endIndex: number,
|
||||
keepRecentTokens: number,
|
||||
): CutPointResult {
|
||||
const cutPoints = findValidCutPoints(entries, startIndex, endIndex);
|
||||
|
||||
if (cutPoints.length === 0) {
|
||||
return { firstKeptEntryIndex: startIndex, turnStartIndex: -1, isSplitTurn: false };
|
||||
}
|
||||
|
||||
// Walk backwards from newest, accumulating estimated message sizes
|
||||
let accumulatedTokens = 0;
|
||||
let cutIndex = cutPoints[0]; // Default: keep from first message (not header)
|
||||
|
||||
for (let i = endIndex - 1; i >= startIndex; i--) {
|
||||
const entry = entries[i];
|
||||
if (entry.type !== "message") continue;
|
||||
|
||||
// Estimate this message's size
|
||||
const messageTokens = estimateTokens(entry.message);
|
||||
accumulatedTokens += messageTokens;
|
||||
|
||||
// Check if we've exceeded the budget
|
||||
if (accumulatedTokens >= keepRecentTokens) {
|
||||
// Find the closest valid cut point at or after this entry
|
||||
for (let c = 0; c < cutPoints.length; c++) {
|
||||
if (cutPoints[c] >= i) {
|
||||
cutIndex = cutPoints[c];
|
||||
break;
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Scan backwards from cutIndex to include any non-message entries (bash, settings, etc.)
|
||||
while (cutIndex > startIndex) {
|
||||
const prevEntry = entries[cutIndex - 1];
|
||||
// Stop at session header or compaction boundaries
|
||||
if (prevEntry.type === "compaction") {
|
||||
break;
|
||||
}
|
||||
if (prevEntry.type === "message") {
|
||||
// Stop if we hit any message
|
||||
break;
|
||||
}
|
||||
// Include this non-message entry (bash, settings change, etc.)
|
||||
cutIndex--;
|
||||
}
|
||||
|
||||
// Determine if this is a split turn
|
||||
const cutEntry = entries[cutIndex];
|
||||
const isUserMessage = cutEntry.type === "message" && cutEntry.message.role === "user";
|
||||
const turnStartIndex = isUserMessage ? -1 : findTurnStartIndex(entries, cutIndex, startIndex);
|
||||
|
||||
return {
|
||||
firstKeptEntryIndex: cutIndex,
|
||||
turnStartIndex,
|
||||
isSplitTurn: !isUserMessage && turnStartIndex !== -1,
|
||||
};
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Summarization
|
||||
// ============================================================================
|
||||
|
||||
const SUMMARIZATION_PROMPT = `The messages above are a conversation to summarize. Create a structured context checkpoint summary that another LLM will use to continue the work.
|
||||
|
||||
Use this EXACT format:
|
||||
|
||||
## Goal
|
||||
[What is the user trying to accomplish? Can be multiple items if the session covers different tasks.]
|
||||
|
||||
## Constraints & Preferences
|
||||
- [Any constraints, preferences, or requirements mentioned by user]
|
||||
- [Or "(none)" if none were mentioned]
|
||||
|
||||
## Progress
|
||||
### Done
|
||||
- [x] [Completed tasks/changes]
|
||||
|
||||
### In Progress
|
||||
- [ ] [Current work]
|
||||
|
||||
### Blocked
|
||||
- [Issues preventing progress, if any]
|
||||
|
||||
## Key Decisions
|
||||
- **[Decision]**: [Brief rationale]
|
||||
|
||||
## Next Steps
|
||||
1. [Ordered list of what should happen next]
|
||||
|
||||
## Critical Context
|
||||
- [Any data, examples, or references needed to continue]
|
||||
- [Or "(none)" if not applicable]
|
||||
|
||||
Keep each section concise. Preserve exact file paths, function names, and error messages.`;
|
||||
|
||||
const UPDATE_SUMMARIZATION_PROMPT = `The messages above are NEW conversation messages to incorporate into the existing summary provided in <previous-summary> tags.
|
||||
|
||||
Update the existing structured summary with new information. RULES:
|
||||
- PRESERVE all existing information from the previous summary
|
||||
- ADD new progress, decisions, and context from the new messages
|
||||
- UPDATE the Progress section: move items from "In Progress" to "Done" when completed
|
||||
- UPDATE "Next Steps" based on what was accomplished
|
||||
- PRESERVE exact file paths, function names, and error messages
|
||||
- If something is no longer relevant, you may remove it
|
||||
|
||||
Use this EXACT format:
|
||||
|
||||
## Goal
|
||||
[Preserve existing goals, add new ones if the task expanded]
|
||||
|
||||
## Constraints & Preferences
|
||||
- [Preserve existing, add new ones discovered]
|
||||
|
||||
## Progress
|
||||
### Done
|
||||
- [x] [Include previously done items AND newly completed items]
|
||||
|
||||
### In Progress
|
||||
- [ ] [Current work - update based on progress]
|
||||
|
||||
### Blocked
|
||||
- [Current blockers - remove if resolved]
|
||||
|
||||
## Key Decisions
|
||||
- **[Decision]**: [Brief rationale] (preserve all previous, add new)
|
||||
|
||||
## Next Steps
|
||||
1. [Update based on current state]
|
||||
|
||||
## Critical Context
|
||||
- [Preserve important context, add new if needed]
|
||||
|
||||
Keep each section concise. Preserve exact file paths, function names, and error messages.`;
|
||||
|
||||
/**
|
||||
* Generate a summary of the conversation using the LLM.
|
||||
* If previousSummary is provided, uses the update prompt to merge.
|
||||
*/
|
||||
export async function generateSummary(
|
||||
currentMessages: AgentMessage[],
|
||||
model: Model<any>,
|
||||
reserveTokens: number,
|
||||
apiKey: string,
|
||||
signal?: AbortSignal,
|
||||
customInstructions?: string,
|
||||
previousSummary?: string,
|
||||
): Promise<string> {
|
||||
const maxTokens = Math.floor(0.8 * reserveTokens);
|
||||
|
||||
// Use update prompt if we have a previous summary, otherwise initial prompt
|
||||
let basePrompt = previousSummary ? UPDATE_SUMMARIZATION_PROMPT : SUMMARIZATION_PROMPT;
|
||||
if (customInstructions) {
|
||||
basePrompt = `${basePrompt}\n\nAdditional focus: ${customInstructions}`;
|
||||
}
|
||||
|
||||
// Serialize conversation to text so model doesn't try to continue it
|
||||
// Convert to LLM messages first (handles custom types like bashExecution, hookMessage, etc.)
|
||||
const llmMessages = convertToLlm(currentMessages);
|
||||
const conversationText = serializeConversation(llmMessages);
|
||||
|
||||
// Build the prompt with conversation wrapped in tags
|
||||
let promptText = `<conversation>\n${conversationText}\n</conversation>\n\n`;
|
||||
if (previousSummary) {
|
||||
promptText += `<previous-summary>\n${previousSummary}\n</previous-summary>\n\n`;
|
||||
}
|
||||
promptText += basePrompt;
|
||||
|
||||
const summarizationMessages = [
|
||||
{
|
||||
role: "user" as const,
|
||||
content: [{ type: "text" as const, text: promptText }],
|
||||
timestamp: Date.now(),
|
||||
},
|
||||
];
|
||||
|
||||
const response = await completeSimple(
|
||||
model,
|
||||
{ systemPrompt: SUMMARIZATION_SYSTEM_PROMPT, messages: summarizationMessages },
|
||||
{ maxTokens, signal, apiKey, reasoning: "high" },
|
||||
);
|
||||
|
||||
if (response.stopReason === "error") {
|
||||
throw new Error(`Summarization failed: ${response.errorMessage || "Unknown error"}`);
|
||||
}
|
||||
|
||||
const textContent = response.content
|
||||
.filter((c): c is { type: "text"; text: string } => c.type === "text")
|
||||
.map((c) => c.text)
|
||||
.join("\n");
|
||||
|
||||
return textContent;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Compaction Preparation (for hooks)
|
||||
// ============================================================================
|
||||
|
||||
export interface CompactionPreparation {
|
||||
cutPoint: CutPointResult;
|
||||
/** UUID of first entry to keep */
|
||||
firstKeptEntryId: string;
|
||||
/** Messages that will be summarized and discarded */
|
||||
messagesToSummarize: AgentMessage[];
|
||||
/** Messages that will be kept after the summary (recent turns) */
|
||||
messagesToKeep: AgentMessage[];
|
||||
tokensBefore: number;
|
||||
boundaryStart: number;
|
||||
}
|
||||
|
||||
export function prepareCompaction(
|
||||
entries: SessionEntry[],
|
||||
settings: CompactionSettings,
|
||||
): CompactionPreparation | undefined {
|
||||
if (entries.length > 0 && entries[entries.length - 1].type === "compaction") {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
let prevCompactionIndex = -1;
|
||||
for (let i = entries.length - 1; i >= 0; i--) {
|
||||
if (entries[i].type === "compaction") {
|
||||
prevCompactionIndex = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
const boundaryStart = prevCompactionIndex + 1;
|
||||
const boundaryEnd = entries.length;
|
||||
|
||||
const lastUsage = getLastAssistantUsage(entries);
|
||||
const tokensBefore = lastUsage ? calculateContextTokens(lastUsage) : 0;
|
||||
|
||||
const cutPoint = findCutPoint(entries, boundaryStart, boundaryEnd, settings.keepRecentTokens);
|
||||
|
||||
// Get UUID of first kept entry
|
||||
const firstKeptEntry = entries[cutPoint.firstKeptEntryIndex];
|
||||
if (!firstKeptEntry?.id) {
|
||||
return undefined; // Session needs migration
|
||||
}
|
||||
const firstKeptEntryId = firstKeptEntry.id;
|
||||
|
||||
const historyEnd = cutPoint.isSplitTurn ? cutPoint.turnStartIndex : cutPoint.firstKeptEntryIndex;
|
||||
|
||||
// Messages to summarize (will be discarded after summary)
|
||||
const messagesToSummarize: AgentMessage[] = [];
|
||||
for (let i = boundaryStart; i < historyEnd; i++) {
|
||||
const msg = getMessageFromEntry(entries[i]);
|
||||
if (msg) messagesToSummarize.push(msg);
|
||||
}
|
||||
|
||||
// Messages to keep (recent turns, kept after summary)
|
||||
const messagesToKeep: AgentMessage[] = [];
|
||||
for (let i = cutPoint.firstKeptEntryIndex; i < boundaryEnd; i++) {
|
||||
const msg = getMessageFromEntry(entries[i]);
|
||||
if (msg) messagesToKeep.push(msg);
|
||||
}
|
||||
|
||||
return { cutPoint, firstKeptEntryId, messagesToSummarize, messagesToKeep, tokensBefore, boundaryStart };
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Main compaction function
|
||||
// ============================================================================
|
||||
|
||||
const TURN_PREFIX_SUMMARIZATION_PROMPT = `This is the PREFIX of a turn that was too large to keep. The SUFFIX (recent work) is retained.
|
||||
|
||||
Summarize the prefix to provide context for the retained suffix:
|
||||
|
||||
## Original Request
|
||||
[What did the user ask for in this turn?]
|
||||
|
||||
## Early Progress
|
||||
- [Key decisions and work done in the prefix]
|
||||
|
||||
## Context for Suffix
|
||||
- [Information needed to understand the retained recent work]
|
||||
|
||||
Be concise. Focus on what's needed to understand the kept suffix.`;
|
||||
|
||||
/**
|
||||
* Calculate compaction and generate summary.
|
||||
* Returns CompactionResult - SessionManager adds uuid/parentUuid when saving.
|
||||
*
|
||||
* @param entries - All session entries (must have uuid fields for v2)
|
||||
* @param model - Model to use for summarization
|
||||
* @param settings - Compaction settings
|
||||
* @param apiKey - API key for LLM
|
||||
* @param signal - Optional abort signal
|
||||
* @param customInstructions - Optional custom focus for the summary
|
||||
*/
|
||||
export async function compact(
|
||||
entries: SessionEntry[],
|
||||
model: Model<any>,
|
||||
settings: CompactionSettings,
|
||||
apiKey: string,
|
||||
signal?: AbortSignal,
|
||||
customInstructions?: string,
|
||||
): Promise<CompactionResult> {
|
||||
// Don't compact if the last entry is already a compaction
|
||||
if (entries.length > 0 && entries[entries.length - 1].type === "compaction") {
|
||||
throw new Error("Already compacted");
|
||||
}
|
||||
|
||||
// Find previous compaction boundary
|
||||
let prevCompactionIndex = -1;
|
||||
for (let i = entries.length - 1; i >= 0; i--) {
|
||||
if (entries[i].type === "compaction") {
|
||||
prevCompactionIndex = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
const boundaryStart = prevCompactionIndex + 1;
|
||||
const boundaryEnd = entries.length;
|
||||
|
||||
// Get token count before compaction
|
||||
const lastUsage = getLastAssistantUsage(entries);
|
||||
const tokensBefore = lastUsage ? calculateContextTokens(lastUsage) : 0;
|
||||
|
||||
// Find cut point (entry index) within the valid range
|
||||
const cutResult = findCutPoint(entries, boundaryStart, boundaryEnd, settings.keepRecentTokens);
|
||||
|
||||
// Extract messages for history summary (before the turn that contains the cut point)
|
||||
const historyEnd = cutResult.isSplitTurn ? cutResult.turnStartIndex : cutResult.firstKeptEntryIndex;
|
||||
const historyMessages: AgentMessage[] = [];
|
||||
for (let i = boundaryStart; i < historyEnd; i++) {
|
||||
const msg = getMessageFromEntry(entries[i]);
|
||||
if (msg) historyMessages.push(msg);
|
||||
}
|
||||
|
||||
// Get previous summary for iterative update (if not from hook)
|
||||
let previousSummary: string | undefined;
|
||||
if (prevCompactionIndex >= 0) {
|
||||
const prevCompaction = entries[prevCompactionIndex] as CompactionEntry;
|
||||
previousSummary = prevCompaction.summary;
|
||||
}
|
||||
|
||||
// Extract file operations from messages and previous compaction
|
||||
const fileOps = extractFileOperations(historyMessages, entries, prevCompactionIndex);
|
||||
|
||||
// Extract messages for turn prefix summary (if splitting a turn)
|
||||
const turnPrefixMessages: AgentMessage[] = [];
|
||||
if (cutResult.isSplitTurn) {
|
||||
for (let i = cutResult.turnStartIndex; i < cutResult.firstKeptEntryIndex; i++) {
|
||||
const msg = getMessageFromEntry(entries[i]);
|
||||
if (msg) turnPrefixMessages.push(msg);
|
||||
}
|
||||
// Also extract file ops from turn prefix
|
||||
for (const msg of turnPrefixMessages) {
|
||||
extractFileOpsFromMessage(msg, fileOps);
|
||||
}
|
||||
}
|
||||
|
||||
// Generate summaries (can be parallel if both needed) and merge into one
|
||||
let summary: string;
|
||||
|
||||
if (cutResult.isSplitTurn && turnPrefixMessages.length > 0) {
|
||||
// Generate both summaries in parallel
|
||||
const [historyResult, turnPrefixResult] = await Promise.all([
|
||||
historyMessages.length > 0
|
||||
? generateSummary(
|
||||
historyMessages,
|
||||
model,
|
||||
settings.reserveTokens,
|
||||
apiKey,
|
||||
signal,
|
||||
customInstructions,
|
||||
previousSummary,
|
||||
)
|
||||
: Promise.resolve("No prior history."),
|
||||
generateTurnPrefixSummary(turnPrefixMessages, model, settings.reserveTokens, apiKey, signal),
|
||||
]);
|
||||
// Merge into single summary
|
||||
summary = `${historyResult}\n\n---\n\n**Turn Context (split turn):**\n\n${turnPrefixResult}`;
|
||||
} else {
|
||||
// Just generate history summary
|
||||
summary = await generateSummary(
|
||||
historyMessages,
|
||||
model,
|
||||
settings.reserveTokens,
|
||||
apiKey,
|
||||
signal,
|
||||
customInstructions,
|
||||
previousSummary,
|
||||
);
|
||||
}
|
||||
|
||||
// Compute file lists and append to summary
|
||||
const { readFiles, modifiedFiles } = computeFileLists(fileOps);
|
||||
summary += formatFileOperations(readFiles, modifiedFiles);
|
||||
|
||||
// Get UUID of first kept entry
|
||||
const firstKeptEntry = entries[cutResult.firstKeptEntryIndex];
|
||||
const firstKeptEntryId = firstKeptEntry.id;
|
||||
if (!firstKeptEntryId) {
|
||||
throw new Error("First kept entry has no UUID - session may need migration");
|
||||
}
|
||||
|
||||
return {
|
||||
summary,
|
||||
firstKeptEntryId,
|
||||
tokensBefore,
|
||||
details: { readFiles, modifiedFiles } as CompactionDetails,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a summary for a turn prefix (when splitting a turn).
|
||||
*/
|
||||
async function generateTurnPrefixSummary(
|
||||
messages: AgentMessage[],
|
||||
model: Model<any>,
|
||||
reserveTokens: number,
|
||||
apiKey: string,
|
||||
signal?: AbortSignal,
|
||||
): Promise<string> {
|
||||
const maxTokens = Math.floor(0.5 * reserveTokens); // Smaller budget for turn prefix
|
||||
|
||||
const transformedMessages = convertToLlm(messages);
|
||||
const summarizationMessages = [
|
||||
...transformedMessages,
|
||||
{
|
||||
role: "user" as const,
|
||||
content: [{ type: "text" as const, text: TURN_PREFIX_SUMMARIZATION_PROMPT }],
|
||||
timestamp: Date.now(),
|
||||
},
|
||||
];
|
||||
|
||||
const response = await complete(model, { messages: summarizationMessages }, { maxTokens, signal, apiKey });
|
||||
|
||||
if (response.stopReason === "error") {
|
||||
throw new Error(`Turn prefix summarization failed: ${response.errorMessage || "Unknown error"}`);
|
||||
}
|
||||
|
||||
return response.content
|
||||
.filter((c): c is { type: "text"; text: string } => c.type === "text")
|
||||
.map((c) => c.text)
|
||||
.join("\n");
|
||||
}
|
||||
7
packages/coding-agent/src/core/compaction/index.ts
Normal file
7
packages/coding-agent/src/core/compaction/index.ts
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
/**
|
||||
* Compaction and summarization utilities.
|
||||
*/
|
||||
|
||||
export * from "./branch-summarization.js";
|
||||
export * from "./compaction.js";
|
||||
export * from "./utils.js";
|
||||
154
packages/coding-agent/src/core/compaction/utils.ts
Normal file
154
packages/coding-agent/src/core/compaction/utils.ts
Normal file
|
|
@ -0,0 +1,154 @@
|
|||
/**
|
||||
* Shared utilities for compaction and branch summarization.
|
||||
*/
|
||||
|
||||
import type { AgentMessage } from "@mariozechner/pi-agent-core";
|
||||
import type { Message } from "@mariozechner/pi-ai";
|
||||
|
||||
// ============================================================================
|
||||
// File Operation Tracking
|
||||
// ============================================================================
|
||||
|
||||
export interface FileOperations {
|
||||
read: Set<string>;
|
||||
written: Set<string>;
|
||||
edited: Set<string>;
|
||||
}
|
||||
|
||||
export function createFileOps(): FileOperations {
|
||||
return {
|
||||
read: new Set(),
|
||||
written: new Set(),
|
||||
edited: new Set(),
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract file operations from tool calls in an assistant message.
|
||||
*/
|
||||
export function extractFileOpsFromMessage(message: AgentMessage, fileOps: FileOperations): void {
|
||||
if (message.role !== "assistant") return;
|
||||
if (!("content" in message) || !Array.isArray(message.content)) return;
|
||||
|
||||
for (const block of message.content) {
|
||||
if (typeof block !== "object" || block === null) continue;
|
||||
if (!("type" in block) || block.type !== "toolCall") continue;
|
||||
if (!("arguments" in block) || !("name" in block)) continue;
|
||||
|
||||
const args = block.arguments as Record<string, unknown> | undefined;
|
||||
if (!args) continue;
|
||||
|
||||
const path = typeof args.path === "string" ? args.path : undefined;
|
||||
if (!path) continue;
|
||||
|
||||
switch (block.name) {
|
||||
case "read":
|
||||
fileOps.read.add(path);
|
||||
break;
|
||||
case "write":
|
||||
fileOps.written.add(path);
|
||||
break;
|
||||
case "edit":
|
||||
fileOps.edited.add(path);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute final file lists from file operations.
|
||||
* Returns readFiles (files only read, not modified) and modifiedFiles.
|
||||
*/
|
||||
export function computeFileLists(fileOps: FileOperations): { readFiles: string[]; modifiedFiles: string[] } {
|
||||
const modified = new Set([...fileOps.edited, ...fileOps.written]);
|
||||
const readOnly = [...fileOps.read].filter((f) => !modified.has(f)).sort();
|
||||
const modifiedFiles = [...modified].sort();
|
||||
return { readFiles: readOnly, modifiedFiles };
|
||||
}
|
||||
|
||||
/**
|
||||
* Format file operations as XML tags for summary.
|
||||
*/
|
||||
export function formatFileOperations(readFiles: string[], modifiedFiles: string[]): string {
|
||||
const sections: string[] = [];
|
||||
if (readFiles.length > 0) {
|
||||
sections.push(`<read-files>\n${readFiles.join("\n")}\n</read-files>`);
|
||||
}
|
||||
if (modifiedFiles.length > 0) {
|
||||
sections.push(`<modified-files>\n${modifiedFiles.join("\n")}\n</modified-files>`);
|
||||
}
|
||||
if (sections.length === 0) return "";
|
||||
return `\n\n${sections.join("\n\n")}`;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Message Serialization
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Serialize LLM messages to text for summarization.
|
||||
* This prevents the model from treating it as a conversation to continue.
|
||||
* Call convertToLlm() first to handle custom message types.
|
||||
*/
|
||||
export function serializeConversation(messages: Message[]): string {
|
||||
const parts: string[] = [];
|
||||
|
||||
for (const msg of messages) {
|
||||
if (msg.role === "user") {
|
||||
const content =
|
||||
typeof msg.content === "string"
|
||||
? msg.content
|
||||
: msg.content
|
||||
.filter((c): c is { type: "text"; text: string } => c.type === "text")
|
||||
.map((c) => c.text)
|
||||
.join("");
|
||||
if (content) parts.push(`[User]: ${content}`);
|
||||
} else if (msg.role === "assistant") {
|
||||
const textParts: string[] = [];
|
||||
const thinkingParts: string[] = [];
|
||||
const toolCalls: string[] = [];
|
||||
|
||||
for (const block of msg.content) {
|
||||
if (block.type === "text") {
|
||||
textParts.push(block.text);
|
||||
} else if (block.type === "thinking") {
|
||||
thinkingParts.push(block.thinking);
|
||||
} else if (block.type === "toolCall") {
|
||||
const args = block.arguments as Record<string, unknown>;
|
||||
const argsStr = Object.entries(args)
|
||||
.map(([k, v]) => `${k}=${JSON.stringify(v)}`)
|
||||
.join(", ");
|
||||
toolCalls.push(`${block.name}(${argsStr})`);
|
||||
}
|
||||
}
|
||||
|
||||
if (thinkingParts.length > 0) {
|
||||
parts.push(`[Assistant thinking]: ${thinkingParts.join("\n")}`);
|
||||
}
|
||||
if (textParts.length > 0) {
|
||||
parts.push(`[Assistant]: ${textParts.join("\n")}`);
|
||||
}
|
||||
if (toolCalls.length > 0) {
|
||||
parts.push(`[Assistant tool calls]: ${toolCalls.join("; ")}`);
|
||||
}
|
||||
} else if (msg.role === "toolResult") {
|
||||
const content = msg.content
|
||||
.filter((c): c is { type: "text"; text: string } => c.type === "text")
|
||||
.map((c) => c.text)
|
||||
.join("");
|
||||
if (content) {
|
||||
parts.push(`[Tool result]: ${content}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return parts.join("\n\n");
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Summarization System Prompt
|
||||
// ============================================================================
|
||||
|
||||
export const SUMMARIZATION_SYSTEM_PROMPT = `You are a context summarization assistant. Your task is to read a conversation between a user and an AI coding assistant, then produce a structured summary following the exact format specified.
|
||||
|
||||
Do NOT continue the conversation. Do NOT respond to any questions in the conversation. ONLY output the structured summary.`;
|
||||
|
|
@ -7,7 +7,6 @@
|
|||
* for custom tools that depend on pi packages.
|
||||
*/
|
||||
|
||||
import { spawn } from "node:child_process";
|
||||
import * as fs from "node:fs";
|
||||
import { createRequire } from "node:module";
|
||||
import * as os from "node:os";
|
||||
|
|
@ -15,15 +14,10 @@ import * as path from "node:path";
|
|||
import { fileURLToPath } from "node:url";
|
||||
import { createJiti } from "jiti";
|
||||
import { getAgentDir, isBunBinary } from "../../config.js";
|
||||
import type { ExecOptions } from "../exec.js";
|
||||
import { execCommand } from "../exec.js";
|
||||
import type { HookUIContext } from "../hooks/types.js";
|
||||
import type {
|
||||
CustomToolFactory,
|
||||
CustomToolsLoadResult,
|
||||
ExecOptions,
|
||||
ExecResult,
|
||||
LoadedCustomTool,
|
||||
ToolAPI,
|
||||
} from "./types.js";
|
||||
import type { CustomToolFactory, CustomToolsLoadResult, LoadedCustomTool, ToolAPI } from "./types.js";
|
||||
|
||||
// Create require function to resolve module paths at runtime
|
||||
const require = createRequire(import.meta.url);
|
||||
|
|
@ -87,97 +81,16 @@ function resolveToolPath(toolPath: string, cwd: string): string {
|
|||
return path.resolve(cwd, expanded);
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute a command and return stdout/stderr/code.
|
||||
* Supports cancellation via AbortSignal and timeout.
|
||||
*/
|
||||
async function execCommand(command: string, args: string[], cwd: string, options?: ExecOptions): Promise<ExecResult> {
|
||||
return new Promise((resolve) => {
|
||||
const proc = spawn(command, args, {
|
||||
cwd,
|
||||
shell: false,
|
||||
stdio: ["ignore", "pipe", "pipe"],
|
||||
});
|
||||
|
||||
let stdout = "";
|
||||
let stderr = "";
|
||||
let killed = false;
|
||||
let timeoutId: NodeJS.Timeout | undefined;
|
||||
|
||||
const killProcess = () => {
|
||||
if (!killed) {
|
||||
killed = true;
|
||||
proc.kill("SIGTERM");
|
||||
// Force kill after 5 seconds if SIGTERM doesn't work
|
||||
setTimeout(() => {
|
||||
if (!proc.killed) {
|
||||
proc.kill("SIGKILL");
|
||||
}
|
||||
}, 5000);
|
||||
}
|
||||
};
|
||||
|
||||
// Handle abort signal
|
||||
if (options?.signal) {
|
||||
if (options.signal.aborted) {
|
||||
killProcess();
|
||||
} else {
|
||||
options.signal.addEventListener("abort", killProcess, { once: true });
|
||||
}
|
||||
}
|
||||
|
||||
// Handle timeout
|
||||
if (options?.timeout && options.timeout > 0) {
|
||||
timeoutId = setTimeout(() => {
|
||||
killProcess();
|
||||
}, options.timeout);
|
||||
}
|
||||
|
||||
proc.stdout.on("data", (data) => {
|
||||
stdout += data.toString();
|
||||
});
|
||||
|
||||
proc.stderr.on("data", (data) => {
|
||||
stderr += data.toString();
|
||||
});
|
||||
|
||||
proc.on("close", (code) => {
|
||||
if (timeoutId) clearTimeout(timeoutId);
|
||||
if (options?.signal) {
|
||||
options.signal.removeEventListener("abort", killProcess);
|
||||
}
|
||||
resolve({
|
||||
stdout,
|
||||
stderr,
|
||||
code: code ?? 0,
|
||||
killed,
|
||||
});
|
||||
});
|
||||
|
||||
proc.on("error", (err) => {
|
||||
if (timeoutId) clearTimeout(timeoutId);
|
||||
if (options?.signal) {
|
||||
options.signal.removeEventListener("abort", killProcess);
|
||||
}
|
||||
resolve({
|
||||
stdout,
|
||||
stderr: stderr || err.message,
|
||||
code: 1,
|
||||
killed,
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a no-op UI context for headless modes.
|
||||
*/
|
||||
function createNoOpUIContext(): HookUIContext {
|
||||
return {
|
||||
select: async () => null,
|
||||
select: async () => undefined,
|
||||
confirm: async () => false,
|
||||
input: async () => null,
|
||||
input: async () => undefined,
|
||||
notify: () => {},
|
||||
custom: () => ({ close: () => {}, requestRender: () => {} }),
|
||||
};
|
||||
}
|
||||
|
||||
|
|
@ -298,7 +211,8 @@ export async function loadCustomTools(
|
|||
// Shared API object - all tools get the same instance
|
||||
const sharedApi: ToolAPI = {
|
||||
cwd,
|
||||
exec: (command: string, args: string[], options?: ExecOptions) => execCommand(command, args, cwd, options),
|
||||
exec: (command: string, args: string[], options?: ExecOptions) =>
|
||||
execCommand(command, args, options?.cwd ?? cwd, options),
|
||||
ui: createNoOpUIContext(),
|
||||
hasUI: false,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -5,10 +5,11 @@
|
|||
* They can provide custom rendering for tool calls and results in the TUI.
|
||||
*/
|
||||
|
||||
import type { AgentTool, AgentToolResult, AgentToolUpdateCallback } from "@mariozechner/pi-ai";
|
||||
import type { AgentTool, AgentToolResult, AgentToolUpdateCallback } from "@mariozechner/pi-agent-core";
|
||||
import type { Component } from "@mariozechner/pi-tui";
|
||||
import type { Static, TSchema } from "@sinclair/typebox";
|
||||
import type { Theme } from "../../modes/interactive/theme/theme.js";
|
||||
import type { ExecOptions, ExecResult } from "../exec.js";
|
||||
import type { HookUIContext } from "../hooks/types.js";
|
||||
import type { SessionEntry } from "../session-manager.js";
|
||||
|
||||
|
|
@ -18,20 +19,8 @@ export type ToolUIContext = HookUIContext;
|
|||
/** Re-export for custom tools to use in execute signature */
|
||||
export type { AgentToolUpdateCallback };
|
||||
|
||||
export interface ExecResult {
|
||||
stdout: string;
|
||||
stderr: string;
|
||||
code: number;
|
||||
/** True if the process was killed due to signal or timeout */
|
||||
killed?: boolean;
|
||||
}
|
||||
|
||||
export interface ExecOptions {
|
||||
/** AbortSignal to cancel the process */
|
||||
signal?: AbortSignal;
|
||||
/** Timeout in milliseconds */
|
||||
timeout?: number;
|
||||
}
|
||||
// Re-export for backward compatibility
|
||||
export type { ExecOptions, ExecResult } from "../exec.js";
|
||||
|
||||
/** API passed to custom tool factory (stable across session changes) */
|
||||
export interface ToolAPI {
|
||||
|
|
@ -49,12 +38,12 @@ export interface ToolAPI {
|
|||
export interface SessionEvent {
|
||||
/** All session entries (including pre-compaction history) */
|
||||
entries: SessionEntry[];
|
||||
/** Current session file path, or null in --no-session mode */
|
||||
sessionFile: string | null;
|
||||
/** Previous session file path, or null for "start" and "new" */
|
||||
previousSessionFile: string | null;
|
||||
/** Current session file path, or undefined in --no-session mode */
|
||||
sessionFile: string | undefined;
|
||||
/** Previous session file path, or undefined for "start" and "new" */
|
||||
previousSessionFile: string | undefined;
|
||||
/** Reason for the session event */
|
||||
reason: "start" | "switch" | "branch" | "new";
|
||||
reason: "start" | "switch" | "branch" | "new" | "tree";
|
||||
}
|
||||
|
||||
/** Rendering options passed to renderResult */
|
||||
|
|
|
|||
104
packages/coding-agent/src/core/exec.ts
Normal file
104
packages/coding-agent/src/core/exec.ts
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
/**
|
||||
* Shared command execution utilities for hooks and custom tools.
|
||||
*/
|
||||
|
||||
import { spawn } from "node:child_process";
|
||||
|
||||
/**
|
||||
* Options for executing shell commands.
|
||||
*/
|
||||
export interface ExecOptions {
|
||||
/** AbortSignal to cancel the command */
|
||||
signal?: AbortSignal;
|
||||
/** Timeout in milliseconds */
|
||||
timeout?: number;
|
||||
/** Working directory */
|
||||
cwd?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Result of executing a shell command.
|
||||
*/
|
||||
export interface ExecResult {
|
||||
stdout: string;
|
||||
stderr: string;
|
||||
code: number;
|
||||
killed: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute a shell command and return stdout/stderr/code.
|
||||
* Supports timeout and abort signal.
|
||||
*/
|
||||
export async function execCommand(
|
||||
command: string,
|
||||
args: string[],
|
||||
cwd: string,
|
||||
options?: ExecOptions,
|
||||
): Promise<ExecResult> {
|
||||
return new Promise((resolve) => {
|
||||
const proc = spawn(command, args, {
|
||||
cwd,
|
||||
shell: false,
|
||||
stdio: ["ignore", "pipe", "pipe"],
|
||||
});
|
||||
|
||||
let stdout = "";
|
||||
let stderr = "";
|
||||
let killed = false;
|
||||
let timeoutId: NodeJS.Timeout | undefined;
|
||||
|
||||
const killProcess = () => {
|
||||
if (!killed) {
|
||||
killed = true;
|
||||
proc.kill("SIGTERM");
|
||||
// Force kill after 5 seconds if SIGTERM doesn't work
|
||||
setTimeout(() => {
|
||||
if (!proc.killed) {
|
||||
proc.kill("SIGKILL");
|
||||
}
|
||||
}, 5000);
|
||||
}
|
||||
};
|
||||
|
||||
// Handle abort signal
|
||||
if (options?.signal) {
|
||||
if (options.signal.aborted) {
|
||||
killProcess();
|
||||
} else {
|
||||
options.signal.addEventListener("abort", killProcess, { once: true });
|
||||
}
|
||||
}
|
||||
|
||||
// Handle timeout
|
||||
if (options?.timeout && options.timeout > 0) {
|
||||
timeoutId = setTimeout(() => {
|
||||
killProcess();
|
||||
}, options.timeout);
|
||||
}
|
||||
|
||||
proc.stdout?.on("data", (data) => {
|
||||
stdout += data.toString();
|
||||
});
|
||||
|
||||
proc.stderr?.on("data", (data) => {
|
||||
stderr += data.toString();
|
||||
});
|
||||
|
||||
proc.on("close", (code) => {
|
||||
if (timeoutId) clearTimeout(timeoutId);
|
||||
if (options?.signal) {
|
||||
options.signal.removeEventListener("abort", killProcess);
|
||||
}
|
||||
resolve({ stdout, stderr, code: code ?? 0, killed });
|
||||
});
|
||||
|
||||
proc.on("error", (_err) => {
|
||||
if (timeoutId) clearTimeout(timeoutId);
|
||||
if (options?.signal) {
|
||||
options.signal.removeEventListener("abort", killProcess);
|
||||
}
|
||||
resolve({ stdout, stderr, code: 1, killed });
|
||||
});
|
||||
});
|
||||
}
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
import type { AgentState } from "@mariozechner/pi-agent-core";
|
||||
import type { AgentMessage, AgentState } from "@mariozechner/pi-agent-core";
|
||||
import type { AssistantMessage, ImageContent, Message, ToolResultMessage, UserMessage } from "@mariozechner/pi-ai";
|
||||
import { existsSync, readFileSync, writeFileSync } from "fs";
|
||||
import hljs from "highlight.js";
|
||||
|
|
@ -7,7 +7,6 @@ import { homedir } from "os";
|
|||
import * as path from "path";
|
||||
import { basename } from "path";
|
||||
import { APP_NAME, getCustomThemesDir, getThemesDir, VERSION } from "../config.js";
|
||||
import { type BashExecutionMessage, isBashExecutionMessage } from "./messages.js";
|
||||
import type { SessionManager } from "./session-manager.js";
|
||||
|
||||
// ============================================================================
|
||||
|
|
@ -122,7 +121,7 @@ function resolveColorValue(
|
|||
}
|
||||
|
||||
/** Load theme JSON from built-in or custom themes directory. */
|
||||
function loadThemeJson(name: string): ThemeJson | null {
|
||||
function loadThemeJson(name: string): ThemeJson | undefined {
|
||||
// Try built-in themes first
|
||||
const themesDir = getThemesDir();
|
||||
const builtinPath = path.join(themesDir, `${name}.json`);
|
||||
|
|
@ -130,7 +129,7 @@ function loadThemeJson(name: string): ThemeJson | null {
|
|||
try {
|
||||
return JSON.parse(readFileSync(builtinPath, "utf-8")) as ThemeJson;
|
||||
} catch {
|
||||
return null;
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -141,11 +140,11 @@ function loadThemeJson(name: string): ThemeJson | null {
|
|||
try {
|
||||
return JSON.parse(readFileSync(customPath, "utf-8")) as ThemeJson;
|
||||
} catch {
|
||||
return null;
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/** Build complete theme colors object, resolving theme JSON values against defaults. */
|
||||
|
|
@ -821,110 +820,138 @@ function formatToolExecution(
|
|||
return { html, bgColor };
|
||||
}
|
||||
|
||||
function formatMessage(message: Message, toolResultsMap: Map<string, ToolResultMessage>, colors: ThemeColors): string {
|
||||
function formatMessage(
|
||||
message: AgentMessage,
|
||||
toolResultsMap: Map<string, ToolResultMessage>,
|
||||
colors: ThemeColors,
|
||||
): string {
|
||||
let html = "";
|
||||
const timestamp = (message as { timestamp?: number }).timestamp;
|
||||
const timestampHtml = timestamp ? `<div class="message-timestamp">${formatTimestamp(timestamp)}</div>` : "";
|
||||
|
||||
// Handle bash execution messages (user-executed via ! command)
|
||||
if (isBashExecutionMessage(message)) {
|
||||
const bashMsg = message as unknown as BashExecutionMessage;
|
||||
const isError = bashMsg.cancelled || (bashMsg.exitCode !== 0 && bashMsg.exitCode !== null);
|
||||
switch (message.role) {
|
||||
case "bashExecution": {
|
||||
const isError =
|
||||
message.cancelled ||
|
||||
(message.exitCode !== 0 && message.exitCode !== null && message.exitCode !== undefined);
|
||||
|
||||
html += `<div class="tool-execution user-bash${isError ? " user-bash-error" : ""}">`;
|
||||
html += timestampHtml;
|
||||
html += `<div class="tool-command">$ ${escapeHtml(bashMsg.command)}</div>`;
|
||||
html += `<div class="tool-execution user-bash${isError ? " user-bash-error" : ""}">`;
|
||||
html += timestampHtml;
|
||||
html += `<div class="tool-command">$ ${escapeHtml(message.command)}</div>`;
|
||||
|
||||
if (bashMsg.output) {
|
||||
const lines = bashMsg.output.split("\n");
|
||||
html += formatExpandableOutput(lines, 10);
|
||||
if (message.output) {
|
||||
const lines = message.output.split("\n");
|
||||
html += formatExpandableOutput(lines, 10);
|
||||
}
|
||||
|
||||
if (message.cancelled) {
|
||||
html += `<div class="bash-status warning">(cancelled)</div>`;
|
||||
} else if (message.exitCode !== 0 && message.exitCode !== null && message.exitCode !== undefined) {
|
||||
html += `<div class="bash-status error">(exit ${message.exitCode})</div>`;
|
||||
}
|
||||
|
||||
if (message.truncated && message.fullOutputPath) {
|
||||
html += `<div class="bash-truncation warning">Output truncated. Full output: ${escapeHtml(message.fullOutputPath)}</div>`;
|
||||
}
|
||||
|
||||
html += `</div>`;
|
||||
break;
|
||||
}
|
||||
case "user": {
|
||||
const userMsg = message as UserMessage;
|
||||
let textContent = "";
|
||||
const images: ImageContent[] = [];
|
||||
|
||||
if (bashMsg.cancelled) {
|
||||
html += `<div class="bash-status warning">(cancelled)</div>`;
|
||||
} else if (bashMsg.exitCode !== 0 && bashMsg.exitCode !== null) {
|
||||
html += `<div class="bash-status error">(exit ${bashMsg.exitCode})</div>`;
|
||||
}
|
||||
|
||||
if (bashMsg.truncated && bashMsg.fullOutputPath) {
|
||||
html += `<div class="bash-truncation warning">Output truncated. Full output: ${escapeHtml(bashMsg.fullOutputPath)}</div>`;
|
||||
}
|
||||
|
||||
html += `</div>`;
|
||||
return html;
|
||||
}
|
||||
|
||||
if (message.role === "user") {
|
||||
const userMsg = message as UserMessage;
|
||||
let textContent = "";
|
||||
const images: ImageContent[] = [];
|
||||
|
||||
if (typeof userMsg.content === "string") {
|
||||
textContent = userMsg.content;
|
||||
} else {
|
||||
for (const block of userMsg.content) {
|
||||
if (block.type === "text") {
|
||||
textContent += block.text;
|
||||
} else if (block.type === "image") {
|
||||
images.push(block as ImageContent);
|
||||
if (typeof userMsg.content === "string") {
|
||||
textContent = userMsg.content;
|
||||
} else {
|
||||
for (const block of userMsg.content) {
|
||||
if (block.type === "text") {
|
||||
textContent += block.text;
|
||||
} else if (block.type === "image") {
|
||||
images.push(block as ImageContent);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
html += `<div class="user-message">${timestampHtml}`;
|
||||
html += `<div class="user-message">${timestampHtml}`;
|
||||
|
||||
// Render images first
|
||||
if (images.length > 0) {
|
||||
html += `<div class="message-images">`;
|
||||
for (const img of images) {
|
||||
html += `<img src="data:${img.mimeType};base64,${img.data}" alt="User uploaded image" class="message-image" />`;
|
||||
// Render images first
|
||||
if (images.length > 0) {
|
||||
html += `<div class="message-images">`;
|
||||
for (const img of images) {
|
||||
html += `<img src="data:${img.mimeType};base64,${img.data}" alt="User uploaded image" class="message-image" />`;
|
||||
}
|
||||
html += `</div>`;
|
||||
}
|
||||
|
||||
// Render text as markdown (server-side)
|
||||
if (textContent.trim()) {
|
||||
html += `<div class="markdown-content">${renderMarkdown(textContent)}</div>`;
|
||||
}
|
||||
|
||||
html += `</div>`;
|
||||
break;
|
||||
}
|
||||
case "assistant": {
|
||||
html += timestampHtml ? `<div class="assistant-message">${timestampHtml}` : "";
|
||||
|
||||
// Render text as markdown (server-side)
|
||||
if (textContent.trim()) {
|
||||
html += `<div class="markdown-content">${renderMarkdown(textContent)}</div>`;
|
||||
}
|
||||
|
||||
html += `</div>`;
|
||||
} else if (message.role === "assistant") {
|
||||
const assistantMsg = message as AssistantMessage;
|
||||
html += timestampHtml ? `<div class="assistant-message">${timestampHtml}` : "";
|
||||
|
||||
for (const content of assistantMsg.content) {
|
||||
if (content.type === "text" && content.text.trim()) {
|
||||
// Render markdown server-side
|
||||
html += `<div class="assistant-text markdown-content">${renderMarkdown(content.text)}</div>`;
|
||||
} else if (content.type === "thinking" && content.thinking.trim()) {
|
||||
html += `<div class="thinking-text">${escapeHtml(content.thinking.trim()).replace(/\n/g, "<br>")}</div>`;
|
||||
for (const content of message.content) {
|
||||
if (content.type === "text" && content.text.trim()) {
|
||||
// Render markdown server-side
|
||||
html += `<div class="assistant-text markdown-content">${renderMarkdown(content.text)}</div>`;
|
||||
} else if (content.type === "thinking" && content.thinking.trim()) {
|
||||
html += `<div class="thinking-text">${escapeHtml(content.thinking.trim()).replace(/\n/g, "<br>")}</div>`;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const content of assistantMsg.content) {
|
||||
if (content.type === "toolCall") {
|
||||
const toolResult = toolResultsMap.get(content.id);
|
||||
const { html: toolHtml, bgColor } = formatToolExecution(
|
||||
content.name,
|
||||
content.arguments as Record<string, unknown>,
|
||||
toolResult,
|
||||
colors,
|
||||
);
|
||||
html += `<div class="tool-execution" style="background-color: ${bgColor}">${toolHtml}</div>`;
|
||||
for (const content of message.content) {
|
||||
if (content.type === "toolCall") {
|
||||
const toolResult = toolResultsMap.get(content.id);
|
||||
const { html: toolHtml, bgColor } = formatToolExecution(
|
||||
content.name,
|
||||
content.arguments as Record<string, unknown>,
|
||||
toolResult,
|
||||
colors,
|
||||
);
|
||||
html += `<div class="tool-execution" style="background-color: ${bgColor}">${toolHtml}</div>`;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const hasToolCalls = assistantMsg.content.some((c) => c.type === "toolCall");
|
||||
if (!hasToolCalls) {
|
||||
if (assistantMsg.stopReason === "aborted") {
|
||||
html += '<div class="error-text">Aborted</div>';
|
||||
} else if (assistantMsg.stopReason === "error") {
|
||||
html += `<div class="error-text">Error: ${escapeHtml(assistantMsg.errorMessage || "Unknown error")}</div>`;
|
||||
const hasToolCalls = message.content.some((c) => c.type === "toolCall");
|
||||
if (!hasToolCalls) {
|
||||
if (message.stopReason === "aborted") {
|
||||
html += '<div class="error-text">Aborted</div>';
|
||||
} else if (message.stopReason === "error") {
|
||||
html += `<div class="error-text">Error: ${escapeHtml(message.errorMessage || "Unknown error")}</div>`;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (timestampHtml) {
|
||||
html += "</div>";
|
||||
if (timestampHtml) {
|
||||
html += "</div>";
|
||||
}
|
||||
break;
|
||||
}
|
||||
case "toolResult":
|
||||
// Tool results are rendered inline with tool calls
|
||||
break;
|
||||
case "hookMessage":
|
||||
// Hook messages with display:true shown as info boxes
|
||||
if (message.display) {
|
||||
const content = typeof message.content === "string" ? message.content : JSON.stringify(message.content);
|
||||
html += `<div class="hook-message">${timestampHtml}<div class="hook-type">[${escapeHtml(message.customType)}]</div><div class="markdown-content">${renderMarkdown(content)}</div></div>`;
|
||||
}
|
||||
break;
|
||||
case "compactionSummary":
|
||||
// Rendered separately via formatCompaction
|
||||
break;
|
||||
case "branchSummary":
|
||||
// Rendered as compaction-like summary
|
||||
html += `<div class="compaction-container expanded"><div class="compaction-content"><div class="compaction-summary"><div class="compaction-summary-header">Branch Summary</div><div class="compaction-summary-content">${escapeHtml(message.summary).replace(/\n/g, "<br>")}</div></div></div></div>`;
|
||||
break;
|
||||
default: {
|
||||
// Exhaustive check
|
||||
const _exhaustive: never = message;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -995,7 +1022,7 @@ function generateHtml(data: ParsedSessionData, filename: string, colors: ThemeCo
|
|||
const lastModelInfo = lastProvider ? `${lastProvider}/${lastModel}` : lastModel;
|
||||
|
||||
const contextWindow = data.contextWindow || 0;
|
||||
const contextPercent = contextWindow > 0 ? ((contextTokens / contextWindow) * 100).toFixed(1) : null;
|
||||
const contextPercent = contextWindow > 0 ? ((contextTokens / contextWindow) * 100).toFixed(1) : undefined;
|
||||
|
||||
let messagesHtml = "";
|
||||
for (const event of data.sessionEvents) {
|
||||
|
|
@ -1343,6 +1370,9 @@ export function exportSessionToHtml(
|
|||
const opts: ExportOptions = typeof options === "string" ? { outputPath: options } : options || {};
|
||||
|
||||
const sessionFile = sessionManager.getSessionFile();
|
||||
if (!sessionFile) {
|
||||
throw new Error("Cannot export in-memory session to HTML");
|
||||
}
|
||||
const content = readFileSync(sessionFile, "utf8");
|
||||
const data = parseSessionFile(content);
|
||||
|
||||
|
|
|
|||
|
|
@ -1,39 +1,13 @@
|
|||
export { discoverAndLoadHooks, type LoadedHook, type LoadHooksResult, loadHooks, type SendHandler } from "./loader.js";
|
||||
export { type HookErrorListener, HookRunner } from "./runner.js";
|
||||
export { wrapToolsWithHooks, wrapToolWithHooks } from "./tool-wrapper.js";
|
||||
export type {
|
||||
AgentEndEvent,
|
||||
AgentStartEvent,
|
||||
BashToolResultEvent,
|
||||
CustomToolResultEvent,
|
||||
EditToolResultEvent,
|
||||
ExecResult,
|
||||
FindToolResultEvent,
|
||||
GrepToolResultEvent,
|
||||
HookAPI,
|
||||
HookError,
|
||||
HookEvent,
|
||||
HookEventContext,
|
||||
HookFactory,
|
||||
HookUIContext,
|
||||
LsToolResultEvent,
|
||||
ReadToolResultEvent,
|
||||
SessionEvent,
|
||||
SessionEventResult,
|
||||
ToolCallEvent,
|
||||
ToolCallEventResult,
|
||||
ToolResultEvent,
|
||||
ToolResultEventResult,
|
||||
TurnEndEvent,
|
||||
TurnStartEvent,
|
||||
WriteToolResultEvent,
|
||||
} from "./types.js";
|
||||
// biome-ignore assist/source/organizeImports: biome is not smart
|
||||
export {
|
||||
isBashToolResult,
|
||||
isEditToolResult,
|
||||
isFindToolResult,
|
||||
isGrepToolResult,
|
||||
isLsToolResult,
|
||||
isReadToolResult,
|
||||
isWriteToolResult,
|
||||
} from "./types.js";
|
||||
discoverAndLoadHooks,
|
||||
loadHooks,
|
||||
type AppendEntryHandler,
|
||||
type LoadedHook,
|
||||
type LoadHooksResult,
|
||||
type SendMessageHandler,
|
||||
} from "./loader.js";
|
||||
export { execCommand, HookRunner, type HookErrorListener } from "./runner.js";
|
||||
export { wrapToolsWithHooks, wrapToolWithHooks } from "./tool-wrapper.js";
|
||||
export type * from "./types.js";
|
||||
export type { ReadonlySessionManager } from "../session-manager.js";
|
||||
|
|
|
|||
|
|
@ -7,10 +7,11 @@ import { createRequire } from "node:module";
|
|||
import * as os from "node:os";
|
||||
import * as path from "node:path";
|
||||
import { fileURLToPath } from "node:url";
|
||||
import type { Attachment } from "@mariozechner/pi-agent-core";
|
||||
import { createJiti } from "jiti";
|
||||
import { getAgentDir } from "../../config.js";
|
||||
import type { HookAPI, HookFactory } from "./types.js";
|
||||
import type { HookMessage } from "../messages.js";
|
||||
import { execCommand } from "./runner.js";
|
||||
import type { ExecOptions, HookAPI, HookFactory, HookMessageRenderer, RegisteredCommand } from "./types.js";
|
||||
|
||||
// Create require function to resolve module paths at runtime
|
||||
const require = createRequire(import.meta.url);
|
||||
|
|
@ -47,9 +48,17 @@ function getAliases(): Record<string, string> {
|
|||
type HandlerFn = (...args: unknown[]) => Promise<unknown>;
|
||||
|
||||
/**
|
||||
* Send handler type for pi.send().
|
||||
* Send message handler type for pi.sendMessage().
|
||||
*/
|
||||
export type SendHandler = (text: string, attachments?: Attachment[]) => void;
|
||||
export type SendMessageHandler = <T = unknown>(
|
||||
message: Pick<HookMessage<T>, "customType" | "content" | "display" | "details">,
|
||||
triggerTurn?: boolean,
|
||||
) => void;
|
||||
|
||||
/**
|
||||
* Append entry handler type for pi.appendEntry().
|
||||
*/
|
||||
export type AppendEntryHandler = <T = unknown>(customType: string, data?: T) => void;
|
||||
|
||||
/**
|
||||
* Registered handlers for a loaded hook.
|
||||
|
|
@ -61,8 +70,14 @@ export interface LoadedHook {
|
|||
resolvedPath: string;
|
||||
/** Map of event type to handler functions */
|
||||
handlers: Map<string, HandlerFn[]>;
|
||||
/** Set the send handler for this hook's pi.send() */
|
||||
setSendHandler: (handler: SendHandler) => void;
|
||||
/** Map of customType to hook message renderer */
|
||||
messageRenderers: Map<string, HookMessageRenderer>;
|
||||
/** Map of command name to registered command */
|
||||
commands: Map<string, RegisteredCommand>;
|
||||
/** Set the send message handler for this hook's pi.sendMessage() */
|
||||
setSendMessageHandler: (handler: SendMessageHandler) => void;
|
||||
/** Set the append entry handler for this hook's pi.appendEntry() */
|
||||
setAppendEntryHandler: (handler: AppendEntryHandler) => void;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -110,32 +125,62 @@ function resolveHookPath(hookPath: string, cwd: string): string {
|
|||
}
|
||||
|
||||
/**
|
||||
* Create a HookAPI instance that collects handlers.
|
||||
* Returns the API and a function to set the send handler later.
|
||||
* Create a HookAPI instance that collects handlers, renderers, and commands.
|
||||
* Returns the API, maps, and a function to set the send message handler later.
|
||||
*/
|
||||
function createHookAPI(handlers: Map<string, HandlerFn[]>): {
|
||||
function createHookAPI(
|
||||
handlers: Map<string, HandlerFn[]>,
|
||||
cwd: string,
|
||||
): {
|
||||
api: HookAPI;
|
||||
setSendHandler: (handler: SendHandler) => void;
|
||||
messageRenderers: Map<string, HookMessageRenderer>;
|
||||
commands: Map<string, RegisteredCommand>;
|
||||
setSendMessageHandler: (handler: SendMessageHandler) => void;
|
||||
setAppendEntryHandler: (handler: AppendEntryHandler) => void;
|
||||
} {
|
||||
let sendHandler: SendHandler = () => {
|
||||
let sendMessageHandler: SendMessageHandler = () => {
|
||||
// Default no-op until mode sets the handler
|
||||
};
|
||||
let appendEntryHandler: AppendEntryHandler = () => {
|
||||
// Default no-op until mode sets the handler
|
||||
};
|
||||
const messageRenderers = new Map<string, HookMessageRenderer>();
|
||||
const commands = new Map<string, RegisteredCommand>();
|
||||
|
||||
const api: HookAPI = {
|
||||
// Cast to HookAPI - the implementation is more general (string event names)
|
||||
// but the interface has specific overloads for type safety in hooks
|
||||
const api = {
|
||||
on(event: string, handler: HandlerFn): void {
|
||||
const list = handlers.get(event) ?? [];
|
||||
list.push(handler);
|
||||
handlers.set(event, list);
|
||||
},
|
||||
send(text: string, attachments?: Attachment[]): void {
|
||||
sendHandler(text, attachments);
|
||||
sendMessage<T = unknown>(message: HookMessage<T>, triggerTurn?: boolean): void {
|
||||
sendMessageHandler(message, triggerTurn);
|
||||
},
|
||||
appendEntry<T = unknown>(customType: string, data?: T): void {
|
||||
appendEntryHandler(customType, data);
|
||||
},
|
||||
registerMessageRenderer<T = unknown>(customType: string, renderer: HookMessageRenderer<T>): void {
|
||||
messageRenderers.set(customType, renderer as HookMessageRenderer);
|
||||
},
|
||||
registerCommand(name: string, options: { description?: string; handler: RegisteredCommand["handler"] }): void {
|
||||
commands.set(name, { name, ...options });
|
||||
},
|
||||
exec(command: string, args: string[], options?: ExecOptions) {
|
||||
return execCommand(command, args, options?.cwd ?? cwd, options);
|
||||
},
|
||||
} as HookAPI;
|
||||
|
||||
return {
|
||||
api,
|
||||
setSendHandler: (handler: SendHandler) => {
|
||||
sendHandler = handler;
|
||||
messageRenderers,
|
||||
commands,
|
||||
setSendMessageHandler: (handler: SendMessageHandler) => {
|
||||
sendMessageHandler = handler;
|
||||
},
|
||||
setAppendEntryHandler: (handler: AppendEntryHandler) => {
|
||||
appendEntryHandler = handler;
|
||||
},
|
||||
};
|
||||
}
|
||||
|
|
@ -164,13 +209,24 @@ async function loadHook(hookPath: string, cwd: string): Promise<{ hook: LoadedHo
|
|||
|
||||
// Create handlers map and API
|
||||
const handlers = new Map<string, HandlerFn[]>();
|
||||
const { api, setSendHandler } = createHookAPI(handlers);
|
||||
const { api, messageRenderers, commands, setSendMessageHandler, setAppendEntryHandler } = createHookAPI(
|
||||
handlers,
|
||||
cwd,
|
||||
);
|
||||
|
||||
// Call factory to register handlers
|
||||
factory(api);
|
||||
|
||||
return {
|
||||
hook: { path: hookPath, resolvedPath, handlers, setSendHandler },
|
||||
hook: {
|
||||
path: hookPath,
|
||||
resolvedPath,
|
||||
handlers,
|
||||
messageRenderers,
|
||||
commands,
|
||||
setSendMessageHandler,
|
||||
setAppendEntryHandler,
|
||||
},
|
||||
error: null,
|
||||
};
|
||||
} catch (err) {
|
||||
|
|
|
|||
|
|
@ -2,17 +2,23 @@
|
|||
* Hook runner - executes hooks and manages their lifecycle.
|
||||
*/
|
||||
|
||||
import { spawn } from "node:child_process";
|
||||
import type { LoadedHook, SendHandler } from "./loader.js";
|
||||
import type { AgentMessage } from "@mariozechner/pi-agent-core";
|
||||
import type { ModelRegistry } from "../model-registry.js";
|
||||
import type { SessionManager } from "../session-manager.js";
|
||||
import type { AppendEntryHandler, LoadedHook, SendMessageHandler } from "./loader.js";
|
||||
import type {
|
||||
ExecOptions,
|
||||
ExecResult,
|
||||
BeforeAgentStartEvent,
|
||||
BeforeAgentStartEventResult,
|
||||
ContextEvent,
|
||||
ContextEventResult,
|
||||
HookError,
|
||||
HookEvent,
|
||||
HookEventContext,
|
||||
HookMessageRenderer,
|
||||
HookUIContext,
|
||||
SessionEvent,
|
||||
SessionEventResult,
|
||||
RegisteredCommand,
|
||||
SessionBeforeCompactResult,
|
||||
SessionBeforeTreeResult,
|
||||
ToolCallEvent,
|
||||
ToolCallEventResult,
|
||||
ToolResultEventResult,
|
||||
|
|
@ -28,73 +34,8 @@ const DEFAULT_TIMEOUT = 30000;
|
|||
*/
|
||||
export type HookErrorListener = (error: HookError) => void;
|
||||
|
||||
/**
|
||||
* Execute a command and return stdout/stderr/code.
|
||||
* Supports cancellation via AbortSignal and timeout.
|
||||
*/
|
||||
async function exec(command: string, args: string[], cwd: string, options?: ExecOptions): Promise<ExecResult> {
|
||||
return new Promise((resolve) => {
|
||||
const proc = spawn(command, args, { cwd, shell: false });
|
||||
|
||||
let stdout = "";
|
||||
let stderr = "";
|
||||
let killed = false;
|
||||
let timeoutId: NodeJS.Timeout | undefined;
|
||||
|
||||
const killProcess = () => {
|
||||
if (!killed) {
|
||||
killed = true;
|
||||
proc.kill("SIGTERM");
|
||||
// Force kill after 5 seconds if SIGTERM doesn't work
|
||||
setTimeout(() => {
|
||||
if (!proc.killed) {
|
||||
proc.kill("SIGKILL");
|
||||
}
|
||||
}, 5000);
|
||||
}
|
||||
};
|
||||
|
||||
// Handle abort signal
|
||||
if (options?.signal) {
|
||||
if (options.signal.aborted) {
|
||||
killProcess();
|
||||
} else {
|
||||
options.signal.addEventListener("abort", killProcess, { once: true });
|
||||
}
|
||||
}
|
||||
|
||||
// Handle timeout
|
||||
if (options?.timeout && options.timeout > 0) {
|
||||
timeoutId = setTimeout(() => {
|
||||
killProcess();
|
||||
}, options.timeout);
|
||||
}
|
||||
|
||||
proc.stdout?.on("data", (data) => {
|
||||
stdout += data.toString();
|
||||
});
|
||||
|
||||
proc.stderr?.on("data", (data) => {
|
||||
stderr += data.toString();
|
||||
});
|
||||
|
||||
proc.on("close", (code) => {
|
||||
if (timeoutId) clearTimeout(timeoutId);
|
||||
if (options?.signal) {
|
||||
options.signal.removeEventListener("abort", killProcess);
|
||||
}
|
||||
resolve({ stdout, stderr, code: code ?? 0, killed });
|
||||
});
|
||||
|
||||
proc.on("error", (_err) => {
|
||||
if (timeoutId) clearTimeout(timeoutId);
|
||||
if (options?.signal) {
|
||||
options.signal.removeEventListener("abort", killProcess);
|
||||
}
|
||||
resolve({ stdout, stderr, code: 1, killed });
|
||||
});
|
||||
});
|
||||
}
|
||||
// Re-export execCommand for backward compatibility
|
||||
export { execCommand } from "../exec.js";
|
||||
|
||||
/**
|
||||
* Create a promise that rejects after a timeout.
|
||||
|
|
@ -112,10 +53,11 @@ function createTimeout(ms: number): { promise: Promise<never>; clear: () => void
|
|||
|
||||
/** No-op UI context used when no UI is available */
|
||||
const noOpUIContext: HookUIContext = {
|
||||
select: async () => null,
|
||||
select: async () => undefined,
|
||||
confirm: async () => false,
|
||||
input: async () => null,
|
||||
input: async () => undefined,
|
||||
notify: () => {},
|
||||
custom: () => ({ close: () => {}, requestRender: () => {} }),
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
@ -126,16 +68,24 @@ export class HookRunner {
|
|||
private uiContext: HookUIContext;
|
||||
private hasUI: boolean;
|
||||
private cwd: string;
|
||||
private sessionFile: string | null;
|
||||
private sessionManager: SessionManager;
|
||||
private modelRegistry: ModelRegistry;
|
||||
private timeout: number;
|
||||
private errorListeners: Set<HookErrorListener> = new Set();
|
||||
|
||||
constructor(hooks: LoadedHook[], cwd: string, timeout: number = DEFAULT_TIMEOUT) {
|
||||
constructor(
|
||||
hooks: LoadedHook[],
|
||||
cwd: string,
|
||||
sessionManager: SessionManager,
|
||||
modelRegistry: ModelRegistry,
|
||||
timeout: number = DEFAULT_TIMEOUT,
|
||||
) {
|
||||
this.hooks = hooks;
|
||||
this.uiContext = noOpUIContext;
|
||||
this.hasUI = false;
|
||||
this.cwd = cwd;
|
||||
this.sessionFile = null;
|
||||
this.sessionManager = sessionManager;
|
||||
this.modelRegistry = modelRegistry;
|
||||
this.timeout = timeout;
|
||||
}
|
||||
|
||||
|
|
@ -148,6 +98,20 @@ export class HookRunner {
|
|||
this.hasUI = hasUI;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the UI context (set by mode).
|
||||
*/
|
||||
getUIContext(): HookUIContext | null {
|
||||
return this.uiContext;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get whether UI is available.
|
||||
*/
|
||||
getHasUI(): boolean {
|
||||
return this.hasUI;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the paths of all loaded hooks.
|
||||
*/
|
||||
|
|
@ -156,19 +120,22 @@ export class HookRunner {
|
|||
}
|
||||
|
||||
/**
|
||||
* Set the session file path.
|
||||
* Set the send message handler for all hooks' pi.sendMessage().
|
||||
* Call this when the mode initializes.
|
||||
*/
|
||||
setSessionFile(sessionFile: string | null): void {
|
||||
this.sessionFile = sessionFile;
|
||||
setSendMessageHandler(handler: SendMessageHandler): void {
|
||||
for (const hook of this.hooks) {
|
||||
hook.setSendMessageHandler(handler);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the send handler for all hooks' pi.send().
|
||||
* Set the append entry handler for all hooks' pi.appendEntry().
|
||||
* Call this when the mode initializes.
|
||||
*/
|
||||
setSendHandler(handler: SendHandler): void {
|
||||
setAppendEntryHandler(handler: AppendEntryHandler): void {
|
||||
for (const hook of this.hooks) {
|
||||
hook.setSendHandler(handler);
|
||||
hook.setAppendEntryHandler(handler);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -184,7 +151,10 @@ export class HookRunner {
|
|||
/**
|
||||
* Emit an error to all listeners.
|
||||
*/
|
||||
private emitError(error: HookError): void {
|
||||
/**
|
||||
* Emit an error to all error listeners.
|
||||
*/
|
||||
emitError(error: HookError): void {
|
||||
for (const listener of this.errorListeners) {
|
||||
listener(error);
|
||||
}
|
||||
|
|
@ -203,26 +173,89 @@ export class HookRunner {
|
|||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a message renderer for the given customType.
|
||||
* Returns the first renderer found across all hooks, or undefined if none.
|
||||
*/
|
||||
getMessageRenderer(customType: string): HookMessageRenderer | undefined {
|
||||
for (const hook of this.hooks) {
|
||||
const renderer = hook.messageRenderers.get(customType);
|
||||
if (renderer) {
|
||||
return renderer;
|
||||
}
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all registered commands from all hooks.
|
||||
*/
|
||||
getRegisteredCommands(): RegisteredCommand[] {
|
||||
const commands: RegisteredCommand[] = [];
|
||||
for (const hook of this.hooks) {
|
||||
for (const command of hook.commands.values()) {
|
||||
commands.push(command);
|
||||
}
|
||||
}
|
||||
return commands;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a registered command by name.
|
||||
* Returns the first command found across all hooks, or undefined if none.
|
||||
*/
|
||||
getCommand(name: string): RegisteredCommand | undefined {
|
||||
for (const hook of this.hooks) {
|
||||
const command = hook.commands.get(name);
|
||||
if (command) {
|
||||
return command;
|
||||
}
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create the event context for handlers.
|
||||
*/
|
||||
private createContext(): HookEventContext {
|
||||
return {
|
||||
exec: (command: string, args: string[], options?: ExecOptions) => exec(command, args, this.cwd, options),
|
||||
ui: this.uiContext,
|
||||
hasUI: this.hasUI,
|
||||
cwd: this.cwd,
|
||||
sessionFile: this.sessionFile,
|
||||
sessionManager: this.sessionManager,
|
||||
modelRegistry: this.modelRegistry,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Emit an event to all hooks.
|
||||
* Returns the result from session/tool_result events (if any handler returns one).
|
||||
* Check if event type is a session "before_*" event that can be cancelled.
|
||||
*/
|
||||
async emit(event: HookEvent): Promise<SessionEventResult | ToolResultEventResult | undefined> {
|
||||
private isSessionBeforeEvent(
|
||||
type: string,
|
||||
): type is
|
||||
| "session_before_switch"
|
||||
| "session_before_new"
|
||||
| "session_before_branch"
|
||||
| "session_before_compact"
|
||||
| "session_before_tree" {
|
||||
return (
|
||||
type === "session_before_switch" ||
|
||||
type === "session_before_new" ||
|
||||
type === "session_before_branch" ||
|
||||
type === "session_before_compact" ||
|
||||
type === "session_before_tree"
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Emit an event to all hooks.
|
||||
* Returns the result from session before_* / tool_result events (if any handler returns one).
|
||||
*/
|
||||
async emit(
|
||||
event: HookEvent,
|
||||
): Promise<SessionBeforeCompactResult | SessionBeforeTreeResult | ToolResultEventResult | undefined> {
|
||||
const ctx = this.createContext();
|
||||
let result: SessionEventResult | ToolResultEventResult | undefined;
|
||||
let result: SessionBeforeCompactResult | SessionBeforeTreeResult | ToolResultEventResult | undefined;
|
||||
|
||||
for (const hook of this.hooks) {
|
||||
const handlers = hook.handlers.get(event.type);
|
||||
|
|
@ -230,11 +263,10 @@ export class HookRunner {
|
|||
|
||||
for (const handler of handlers) {
|
||||
try {
|
||||
// No timeout for before_compact events (like tool_call, they may take a while)
|
||||
const isBeforeCompact = event.type === "session" && (event as SessionEvent).reason === "before_compact";
|
||||
// No timeout for session_before_compact events (like tool_call, they may take a while)
|
||||
let handlerResult: unknown;
|
||||
|
||||
if (isBeforeCompact) {
|
||||
if (event.type === "session_before_compact") {
|
||||
handlerResult = await handler(event, ctx);
|
||||
} else {
|
||||
const timeout = createTimeout(this.timeout);
|
||||
|
|
@ -242,9 +274,9 @@ export class HookRunner {
|
|||
timeout.clear();
|
||||
}
|
||||
|
||||
// For session events, capture the result (for before_* cancellation)
|
||||
if (event.type === "session" && handlerResult) {
|
||||
result = handlerResult as SessionEventResult;
|
||||
// For session before_* events, capture the result (for cancellation)
|
||||
if (this.isSessionBeforeEvent(event.type) && handlerResult) {
|
||||
result = handlerResult as SessionBeforeCompactResult | SessionBeforeTreeResult;
|
||||
// If cancelled, stop processing further hooks
|
||||
if (result.cancel) {
|
||||
return result;
|
||||
|
|
@ -298,4 +330,83 @@ export class HookRunner {
|
|||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Emit a context event to all hooks.
|
||||
* Handlers are chained - each gets the previous handler's output (if any).
|
||||
* Returns the final modified messages, or the original if no modifications.
|
||||
*
|
||||
* Note: Messages are already deep-copied by the caller (pi-ai preprocessor).
|
||||
*/
|
||||
async emitContext(messages: AgentMessage[]): Promise<AgentMessage[]> {
|
||||
const ctx = this.createContext();
|
||||
let currentMessages = messages;
|
||||
|
||||
for (const hook of this.hooks) {
|
||||
const handlers = hook.handlers.get("context");
|
||||
if (!handlers || handlers.length === 0) continue;
|
||||
|
||||
for (const handler of handlers) {
|
||||
try {
|
||||
const event: ContextEvent = { type: "context", messages: currentMessages };
|
||||
const timeout = createTimeout(this.timeout);
|
||||
const handlerResult = await Promise.race([handler(event, ctx), timeout.promise]);
|
||||
timeout.clear();
|
||||
|
||||
if (handlerResult && (handlerResult as ContextEventResult).messages) {
|
||||
currentMessages = (handlerResult as ContextEventResult).messages!;
|
||||
}
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : String(err);
|
||||
this.emitError({
|
||||
hookPath: hook.path,
|
||||
event: "context",
|
||||
error: message,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return currentMessages;
|
||||
}
|
||||
|
||||
/**
|
||||
* Emit before_agent_start event to all hooks.
|
||||
* Returns the first message to inject (if any handler returns one).
|
||||
*/
|
||||
async emitBeforeAgentStart(
|
||||
prompt: string,
|
||||
images?: import("@mariozechner/pi-ai").ImageContent[],
|
||||
): Promise<BeforeAgentStartEventResult | undefined> {
|
||||
const ctx = this.createContext();
|
||||
let result: BeforeAgentStartEventResult | undefined;
|
||||
|
||||
for (const hook of this.hooks) {
|
||||
const handlers = hook.handlers.get("before_agent_start");
|
||||
if (!handlers || handlers.length === 0) continue;
|
||||
|
||||
for (const handler of handlers) {
|
||||
try {
|
||||
const event: BeforeAgentStartEvent = { type: "before_agent_start", prompt, images };
|
||||
const timeout = createTimeout(this.timeout);
|
||||
const handlerResult = await Promise.race([handler(event, ctx), timeout.promise]);
|
||||
timeout.clear();
|
||||
|
||||
// Take the first message returned
|
||||
if (handlerResult && (handlerResult as BeforeAgentStartEventResult).message && !result) {
|
||||
result = handlerResult as BeforeAgentStartEventResult;
|
||||
}
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : String(err);
|
||||
this.emitError({
|
||||
hookPath: hook.path,
|
||||
event: "before_agent_start",
|
||||
error: message,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
* Tool wrapper - wraps tools with hook callbacks for interception.
|
||||
*/
|
||||
|
||||
import type { AgentTool, AgentToolUpdateCallback } from "@mariozechner/pi-ai";
|
||||
import type { AgentTool, AgentToolUpdateCallback } from "@mariozechner/pi-agent-core";
|
||||
import type { HookRunner } from "./runner.js";
|
||||
import type { ToolCallEventResult, ToolResultEventResult } from "./types.js";
|
||||
|
||||
|
|
|
|||
|
|
@ -5,10 +5,37 @@
|
|||
* and interact with the user via UI primitives.
|
||||
*/
|
||||
|
||||
import type { AppMessage, Attachment } from "@mariozechner/pi-agent-core";
|
||||
import type { ImageContent, Model, TextContent, ToolResultMessage } from "@mariozechner/pi-ai";
|
||||
import type { CutPointResult } from "../compaction.js";
|
||||
import type { CompactionEntry, SessionEntry } from "../session-manager.js";
|
||||
import type { AgentMessage } from "@mariozechner/pi-agent-core";
|
||||
import type { ImageContent, Message, Model, TextContent, ToolResultMessage } from "@mariozechner/pi-ai";
|
||||
import type { Component } from "@mariozechner/pi-tui";
|
||||
import type { Theme } from "../../modes/interactive/theme/theme.js";
|
||||
import type { CompactionPreparation, CompactionResult } from "../compaction/index.js";
|
||||
import type { ExecOptions, ExecResult } from "../exec.js";
|
||||
import type { HookMessage } from "../messages.js";
|
||||
import type { ModelRegistry } from "../model-registry.js";
|
||||
import type { BranchSummaryEntry, CompactionEntry, SessionEntry, SessionManager } from "../session-manager.js";
|
||||
|
||||
/**
|
||||
* Read-only view of SessionManager for hooks.
|
||||
* Hooks should use pi.sendMessage() and pi.appendEntry() for writes.
|
||||
*/
|
||||
export type ReadonlySessionManager = Pick<
|
||||
SessionManager,
|
||||
| "getCwd"
|
||||
| "getSessionDir"
|
||||
| "getSessionId"
|
||||
| "getSessionFile"
|
||||
| "getLeafId"
|
||||
| "getLeafEntry"
|
||||
| "getEntry"
|
||||
| "getLabel"
|
||||
| "getPath"
|
||||
| "getHeader"
|
||||
| "getEntries"
|
||||
| "getTree"
|
||||
>;
|
||||
|
||||
import type { EditToolDetails } from "../tools/edit.js";
|
||||
import type {
|
||||
BashToolDetails,
|
||||
FindToolDetails,
|
||||
|
|
@ -17,27 +44,8 @@ import type {
|
|||
ReadToolDetails,
|
||||
} from "../tools/index.js";
|
||||
|
||||
// ============================================================================
|
||||
// Execution Context
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Result of executing a command via ctx.exec()
|
||||
*/
|
||||
export interface ExecResult {
|
||||
stdout: string;
|
||||
stderr: string;
|
||||
code: number;
|
||||
/** True if the process was killed due to signal or timeout */
|
||||
killed?: boolean;
|
||||
}
|
||||
|
||||
export interface ExecOptions {
|
||||
/** AbortSignal to cancel the process */
|
||||
signal?: AbortSignal;
|
||||
/** Timeout in milliseconds */
|
||||
timeout?: number;
|
||||
}
|
||||
// Re-export for backward compatibility
|
||||
export type { ExecOptions, ExecResult } from "../exec.js";
|
||||
|
||||
/**
|
||||
* UI context for hooks to request interactive UI from the harness.
|
||||
|
|
@ -50,7 +58,7 @@ export interface HookUIContext {
|
|||
* @param options - Array of string options
|
||||
* @returns Selected option string, or null if cancelled
|
||||
*/
|
||||
select(title: string, options: string[]): Promise<string | null>;
|
||||
select(title: string, options: string[]): Promise<string | undefined>;
|
||||
|
||||
/**
|
||||
* Show a confirmation dialog.
|
||||
|
|
@ -60,97 +68,191 @@ export interface HookUIContext {
|
|||
|
||||
/**
|
||||
* Show a text input dialog.
|
||||
* @returns User input, or null if cancelled
|
||||
* @returns User input, or undefined if cancelled
|
||||
*/
|
||||
input(title: string, placeholder?: string): Promise<string | null>;
|
||||
input(title: string, placeholder?: string): Promise<string | undefined>;
|
||||
|
||||
/**
|
||||
* Show a notification to the user.
|
||||
*/
|
||||
notify(message: string, type?: "info" | "warning" | "error"): void;
|
||||
|
||||
/**
|
||||
* Show a custom component with keyboard focus.
|
||||
* The component receives keyboard input via handleInput() if implemented.
|
||||
*
|
||||
* @param component - Component to display (implement handleInput for keyboard, dispose for cleanup)
|
||||
* @returns Object with close() to restore normal UI and requestRender() to trigger redraw
|
||||
*/
|
||||
custom(component: Component & { dispose?(): void }): { close: () => void; requestRender: () => void };
|
||||
}
|
||||
|
||||
/**
|
||||
* Context passed to hook event handlers.
|
||||
*/
|
||||
export interface HookEventContext {
|
||||
/** Execute a command and return stdout/stderr/code */
|
||||
exec(command: string, args: string[], options?: ExecOptions): Promise<ExecResult>;
|
||||
/** UI methods for user interaction */
|
||||
ui: HookUIContext;
|
||||
/** Whether UI is available (false in print mode) */
|
||||
hasUI: boolean;
|
||||
/** Current working directory */
|
||||
cwd: string;
|
||||
/** Path to session file, or null if --no-session */
|
||||
sessionFile: string | null;
|
||||
/** Session manager (read-only) - use pi.sendMessage()/pi.appendEntry() for writes */
|
||||
sessionManager: ReadonlySessionManager;
|
||||
/** Model registry - use for API key resolution and model retrieval */
|
||||
modelRegistry: ModelRegistry;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Events
|
||||
// Session Events
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Base fields shared by all session events.
|
||||
*/
|
||||
interface SessionEventBase {
|
||||
type: "session";
|
||||
/** All session entries (including pre-compaction history) */
|
||||
entries: SessionEntry[];
|
||||
/** Current session file path, or null in --no-session mode */
|
||||
sessionFile: string | null;
|
||||
/** Previous session file path, or null for "start" and "new" */
|
||||
previousSessionFile: string | null;
|
||||
/** Fired on initial session load */
|
||||
export interface SessionStartEvent {
|
||||
type: "session_start";
|
||||
}
|
||||
|
||||
/**
|
||||
* Event data for session events.
|
||||
* Discriminated union based on reason.
|
||||
*
|
||||
* Lifecycle:
|
||||
* - start: Initial session load
|
||||
* - before_switch / switch: Session switch (e.g., /resume command)
|
||||
* - before_new / new: New session (e.g., /new command)
|
||||
* - before_branch / branch: Session branch (e.g., /branch command)
|
||||
* - before_compact / compact: Before/after context compaction
|
||||
* - shutdown: Process exit (SIGINT/SIGTERM)
|
||||
*
|
||||
* "before_*" events fire before the action and can be cancelled via SessionEventResult.
|
||||
* Other events fire after the action completes.
|
||||
*/
|
||||
/** Fired before switching to another session (can be cancelled) */
|
||||
export interface SessionBeforeSwitchEvent {
|
||||
type: "session_before_switch";
|
||||
/** Session file we're switching to */
|
||||
targetSessionFile: string;
|
||||
}
|
||||
|
||||
/** Fired after switching to another session */
|
||||
export interface SessionSwitchEvent {
|
||||
type: "session_switch";
|
||||
/** Session file we came from */
|
||||
previousSessionFile: string | undefined;
|
||||
}
|
||||
|
||||
/** Fired before creating a new session (can be cancelled) */
|
||||
export interface SessionBeforeNewEvent {
|
||||
type: "session_before_new";
|
||||
}
|
||||
|
||||
/** Fired after creating a new session */
|
||||
export interface SessionNewEvent {
|
||||
type: "session_new";
|
||||
}
|
||||
|
||||
/** Fired before branching a session (can be cancelled) */
|
||||
export interface SessionBeforeBranchEvent {
|
||||
type: "session_before_branch";
|
||||
/** Index of the entry in the session (SessionManager.getEntries()) to branch from */
|
||||
entryIndex: number;
|
||||
}
|
||||
|
||||
/** Fired after branching a session */
|
||||
export interface SessionBranchEvent {
|
||||
type: "session_branch";
|
||||
previousSessionFile: string | undefined;
|
||||
}
|
||||
|
||||
/** Fired before context compaction (can be cancelled or customized) */
|
||||
export interface SessionBeforeCompactEvent {
|
||||
type: "session_before_compact";
|
||||
/** Compaction preparation with cut point, messages to summarize/keep, etc. */
|
||||
preparation: CompactionPreparation;
|
||||
/** Previous compaction entries, newest first. Use for iterative summarization. */
|
||||
previousCompactions: CompactionEntry[];
|
||||
/** Optional user-provided instructions for the summary */
|
||||
customInstructions?: string;
|
||||
/** Current model */
|
||||
model: Model<any>;
|
||||
/** Abort signal - hooks should pass this to LLM calls and check it periodically */
|
||||
signal: AbortSignal;
|
||||
}
|
||||
|
||||
/** Fired after context compaction */
|
||||
export interface SessionCompactEvent {
|
||||
type: "session_compact";
|
||||
compactionEntry: CompactionEntry;
|
||||
/** Whether the compaction entry was provided by a hook */
|
||||
fromHook: boolean;
|
||||
}
|
||||
|
||||
/** Fired on process exit (SIGINT/SIGTERM) */
|
||||
export interface SessionShutdownEvent {
|
||||
type: "session_shutdown";
|
||||
}
|
||||
|
||||
/** Preparation data for tree navigation (used by session_before_tree event) */
|
||||
export interface TreePreparation {
|
||||
/** Node being switched to */
|
||||
targetId: string;
|
||||
/** Current active leaf (being abandoned), null if no current position */
|
||||
oldLeafId: string | null;
|
||||
/** Common ancestor of target and old leaf, null if no common ancestor */
|
||||
commonAncestorId: string | null;
|
||||
/** Entries to summarize (old leaf back to common ancestor or compaction) */
|
||||
entriesToSummarize: SessionEntry[];
|
||||
/** Whether user chose to summarize */
|
||||
userWantsSummary: boolean;
|
||||
}
|
||||
|
||||
/** Fired before navigating to a different node in the session tree (can be cancelled) */
|
||||
export interface SessionBeforeTreeEvent {
|
||||
type: "session_before_tree";
|
||||
/** Preparation data for the navigation */
|
||||
preparation: TreePreparation;
|
||||
/** Model to use for summarization (conversation model) */
|
||||
model: Model<any>;
|
||||
/** Abort signal - honors Escape during summarization */
|
||||
signal: AbortSignal;
|
||||
}
|
||||
|
||||
/** Fired after navigating to a different node in the session tree */
|
||||
export interface SessionTreeEvent {
|
||||
type: "session_tree";
|
||||
/** The new active leaf, null if navigated to before first entry */
|
||||
newLeafId: string | null;
|
||||
/** Previous active leaf, null if there was no position */
|
||||
oldLeafId: string | null;
|
||||
/** Branch summary entry if one was created */
|
||||
summaryEntry?: BranchSummaryEntry;
|
||||
/** Whether summary came from hook */
|
||||
fromHook?: boolean;
|
||||
}
|
||||
|
||||
/** Union of all session event types */
|
||||
export type SessionEvent =
|
||||
| (SessionEventBase & {
|
||||
reason: "start" | "switch" | "new" | "before_switch" | "before_new" | "shutdown";
|
||||
})
|
||||
| (SessionEventBase & {
|
||||
reason: "branch" | "before_branch";
|
||||
/** Index of the turn to branch from */
|
||||
targetTurnIndex: number;
|
||||
})
|
||||
| (SessionEventBase & {
|
||||
reason: "before_compact";
|
||||
cutPoint: CutPointResult;
|
||||
/** Summary from previous compaction, if any. Include this in your summary to preserve context. */
|
||||
previousSummary?: string;
|
||||
/** Messages that will be summarized and discarded */
|
||||
messagesToSummarize: AppMessage[];
|
||||
/** Messages that will be kept after the summary (recent turns) */
|
||||
messagesToKeep: AppMessage[];
|
||||
tokensBefore: number;
|
||||
customInstructions?: string;
|
||||
model: Model<any>;
|
||||
/** Resolve API key for any model (checks settings, OAuth, env vars) */
|
||||
resolveApiKey: (model: Model<any>) => Promise<string | undefined>;
|
||||
/** Abort signal - hooks should pass this to LLM calls and check it periodically */
|
||||
signal: AbortSignal;
|
||||
})
|
||||
| (SessionEventBase & {
|
||||
reason: "compact";
|
||||
compactionEntry: CompactionEntry;
|
||||
tokensBefore: number;
|
||||
/** Whether the compaction entry was provided by a hook */
|
||||
fromHook: boolean;
|
||||
});
|
||||
| SessionStartEvent
|
||||
| SessionBeforeSwitchEvent
|
||||
| SessionSwitchEvent
|
||||
| SessionBeforeNewEvent
|
||||
| SessionNewEvent
|
||||
| SessionBeforeBranchEvent
|
||||
| SessionBranchEvent
|
||||
| SessionBeforeCompactEvent
|
||||
| SessionCompactEvent
|
||||
| SessionShutdownEvent
|
||||
| SessionBeforeTreeEvent
|
||||
| SessionTreeEvent;
|
||||
|
||||
/**
|
||||
* Event data for context event.
|
||||
* Fired before each LLM call, allowing hooks to modify context non-destructively.
|
||||
* Original session messages are NOT modified - only the messages sent to the LLM are affected.
|
||||
*/
|
||||
export interface ContextEvent {
|
||||
type: "context";
|
||||
/** Messages about to be sent to the LLM (deep copy, safe to modify) */
|
||||
messages: AgentMessage[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Event data for before_agent_start event.
|
||||
* Fired after user submits a prompt but before the agent loop starts.
|
||||
* Allows hooks to inject context that will be persisted and visible in TUI.
|
||||
*/
|
||||
export interface BeforeAgentStartEvent {
|
||||
type: "before_agent_start";
|
||||
/** The user's prompt text */
|
||||
prompt: string;
|
||||
/** Any images attached to the prompt */
|
||||
images?: ImageContent[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Event data for agent_start event.
|
||||
|
|
@ -165,7 +267,7 @@ export interface AgentStartEvent {
|
|||
*/
|
||||
export interface AgentEndEvent {
|
||||
type: "agent_end";
|
||||
messages: AppMessage[];
|
||||
messages: AgentMessage[];
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -183,7 +285,7 @@ export interface TurnStartEvent {
|
|||
export interface TurnEndEvent {
|
||||
type: "turn_end";
|
||||
turnIndex: number;
|
||||
message: AppMessage;
|
||||
message: AgentMessage;
|
||||
toolResults: ToolResultMessage[];
|
||||
}
|
||||
|
||||
|
|
@ -231,7 +333,7 @@ export interface ReadToolResultEvent extends ToolResultEventBase {
|
|||
/** Tool result event for edit tool */
|
||||
export interface EditToolResultEvent extends ToolResultEventBase {
|
||||
toolName: "edit";
|
||||
details: undefined;
|
||||
details: EditToolDetails | undefined;
|
||||
}
|
||||
|
||||
/** Tool result event for write tool */
|
||||
|
|
@ -307,6 +409,8 @@ export function isLsToolResult(e: ToolResultEvent): e is LsToolResultEvent {
|
|||
*/
|
||||
export type HookEvent =
|
||||
| SessionEvent
|
||||
| ContextEvent
|
||||
| BeforeAgentStartEvent
|
||||
| AgentStartEvent
|
||||
| AgentEndEvent
|
||||
| TurnStartEvent
|
||||
|
|
@ -318,6 +422,15 @@ export type HookEvent =
|
|||
// Event Results
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Return type for context event handlers.
|
||||
* Allows hooks to modify messages before they're sent to the LLM.
|
||||
*/
|
||||
export interface ContextEventResult {
|
||||
/** Modified messages to send instead of the original */
|
||||
messages?: Message[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Return type for tool_call event handlers.
|
||||
* Allows hooks to block tool execution.
|
||||
|
|
@ -343,16 +456,68 @@ export interface ToolResultEventResult {
|
|||
}
|
||||
|
||||
/**
|
||||
* Return type for session event handlers.
|
||||
* Allows hooks to cancel "before_*" actions.
|
||||
* Return type for before_agent_start event handlers.
|
||||
* Allows hooks to inject context before the agent runs.
|
||||
*/
|
||||
export interface SessionEventResult {
|
||||
/** If true, cancel the pending action (switch, clear, or branch) */
|
||||
export interface BeforeAgentStartEventResult {
|
||||
/** Message to inject into context (persisted to session, visible in TUI) */
|
||||
message?: Pick<HookMessage, "customType" | "content" | "display" | "details">;
|
||||
}
|
||||
|
||||
/** Return type for session_before_switch handlers */
|
||||
export interface SessionBeforeSwitchResult {
|
||||
/** If true, cancel the switch */
|
||||
cancel?: boolean;
|
||||
/** If true (for before_branch only), skip restoring conversation to branch point while still creating the branched session file */
|
||||
}
|
||||
|
||||
/** Return type for session_before_new handlers */
|
||||
export interface SessionBeforeNewResult {
|
||||
/** If true, cancel the new session */
|
||||
cancel?: boolean;
|
||||
}
|
||||
|
||||
/** Return type for session_before_branch handlers */
|
||||
export interface SessionBeforeBranchResult {
|
||||
/**
|
||||
* If true, abort the branch entirely. No new session file is created,
|
||||
* conversation stays unchanged.
|
||||
*/
|
||||
cancel?: boolean;
|
||||
/**
|
||||
* If true, the branch proceeds (new session file created, session state updated)
|
||||
* but the in-memory conversation is NOT rewound to the branch point.
|
||||
*
|
||||
* Use case: git-checkpoint hook that restores code state separately.
|
||||
* The hook handles state restoration itself, so it doesn't want the
|
||||
* agent's conversation to be rewound (which would lose recent context).
|
||||
*
|
||||
* - `cancel: true` → nothing happens, user stays in current session
|
||||
* - `skipConversationRestore: true` → branch happens, but messages stay as-is
|
||||
* - neither → branch happens AND messages rewind to branch point (default)
|
||||
*/
|
||||
skipConversationRestore?: boolean;
|
||||
/** Custom compaction entry (for before_compact event) */
|
||||
compactionEntry?: CompactionEntry;
|
||||
}
|
||||
|
||||
/** Return type for session_before_compact handlers */
|
||||
export interface SessionBeforeCompactResult {
|
||||
/** If true, cancel the compaction */
|
||||
cancel?: boolean;
|
||||
/** Custom compaction result - SessionManager adds id/parentId */
|
||||
compaction?: CompactionResult;
|
||||
}
|
||||
|
||||
/** Return type for session_before_tree handlers */
|
||||
export interface SessionBeforeTreeResult {
|
||||
/** If true, cancel the navigation entirely */
|
||||
cancel?: boolean;
|
||||
/**
|
||||
* Custom summary (skips default summarizer).
|
||||
* Only used if preparation.userWantsSummary is true.
|
||||
*/
|
||||
summary?: {
|
||||
summary: string;
|
||||
details?: unknown;
|
||||
};
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
|
|
@ -361,29 +526,152 @@ export interface SessionEventResult {
|
|||
|
||||
/**
|
||||
* Handler function type for each event.
|
||||
* Handlers can return R, undefined, or void (bare return statements).
|
||||
*/
|
||||
export type HookHandler<E, R = void> = (event: E, ctx: HookEventContext) => Promise<R>;
|
||||
// biome-ignore lint/suspicious/noConfusingVoidType: void allows bare return statements in handlers
|
||||
export type HookHandler<E, R = undefined> = (event: E, ctx: HookEventContext) => Promise<R | void> | R | void;
|
||||
|
||||
export interface HookMessageRenderOptions {
|
||||
/** Whether the view is expanded */
|
||||
expanded: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Renderer for hook messages.
|
||||
* Hooks register these to provide custom TUI rendering for their message types.
|
||||
*/
|
||||
export type HookMessageRenderer<T = unknown> = (
|
||||
message: HookMessage<T>,
|
||||
options: HookMessageRenderOptions,
|
||||
theme: Theme,
|
||||
) => Component | undefined;
|
||||
|
||||
/**
|
||||
* Context passed to hook command handlers.
|
||||
*/
|
||||
export interface HookCommandContext {
|
||||
/** Arguments after the command name */
|
||||
args: string;
|
||||
/** UI methods for user interaction */
|
||||
ui: HookUIContext;
|
||||
/** Whether UI is available (false in print mode) */
|
||||
hasUI: boolean;
|
||||
/** Current working directory */
|
||||
cwd: string;
|
||||
/** Session manager (read-only) - use pi.sendMessage()/pi.appendEntry() for writes */
|
||||
sessionManager: ReadonlySessionManager;
|
||||
/** Model registry for API keys */
|
||||
modelRegistry: ModelRegistry;
|
||||
}
|
||||
|
||||
/**
|
||||
* Command registration options.
|
||||
*/
|
||||
export interface RegisteredCommand {
|
||||
name: string;
|
||||
description?: string;
|
||||
|
||||
handler: (ctx: HookCommandContext) => Promise<void>;
|
||||
}
|
||||
|
||||
/**
|
||||
* HookAPI passed to hook factory functions.
|
||||
* Hooks use pi.on() to subscribe to events and pi.send() to inject messages.
|
||||
* Hooks use pi.on() to subscribe to events and pi.sendMessage() to inject messages.
|
||||
*/
|
||||
export interface HookAPI {
|
||||
// biome-ignore lint/suspicious/noConfusingVoidType: void allows handlers to not return anything
|
||||
on(event: "session", handler: HookHandler<SessionEvent, SessionEventResult | void>): void;
|
||||
// Session events
|
||||
on(event: "session_start", handler: HookHandler<SessionStartEvent>): void;
|
||||
on(event: "session_before_switch", handler: HookHandler<SessionBeforeSwitchEvent, SessionBeforeSwitchResult>): void;
|
||||
on(event: "session_switch", handler: HookHandler<SessionSwitchEvent>): void;
|
||||
on(event: "session_before_new", handler: HookHandler<SessionBeforeNewEvent, SessionBeforeNewResult>): void;
|
||||
on(event: "session_new", handler: HookHandler<SessionNewEvent>): void;
|
||||
on(event: "session_before_branch", handler: HookHandler<SessionBeforeBranchEvent, SessionBeforeBranchResult>): void;
|
||||
on(event: "session_branch", handler: HookHandler<SessionBranchEvent>): void;
|
||||
on(
|
||||
event: "session_before_compact",
|
||||
handler: HookHandler<SessionBeforeCompactEvent, SessionBeforeCompactResult>,
|
||||
): void;
|
||||
on(event: "session_compact", handler: HookHandler<SessionCompactEvent>): void;
|
||||
on(event: "session_shutdown", handler: HookHandler<SessionShutdownEvent>): void;
|
||||
on(event: "session_before_tree", handler: HookHandler<SessionBeforeTreeEvent, SessionBeforeTreeResult>): void;
|
||||
on(event: "session_tree", handler: HookHandler<SessionTreeEvent>): void;
|
||||
|
||||
// Context and agent events
|
||||
on(event: "context", handler: HookHandler<ContextEvent, ContextEventResult>): void;
|
||||
on(event: "before_agent_start", handler: HookHandler<BeforeAgentStartEvent, BeforeAgentStartEventResult>): void;
|
||||
on(event: "agent_start", handler: HookHandler<AgentStartEvent>): void;
|
||||
on(event: "agent_end", handler: HookHandler<AgentEndEvent>): void;
|
||||
on(event: "turn_start", handler: HookHandler<TurnStartEvent>): void;
|
||||
on(event: "turn_end", handler: HookHandler<TurnEndEvent>): void;
|
||||
on(event: "tool_call", handler: HookHandler<ToolCallEvent, ToolCallEventResult | undefined>): void;
|
||||
on(event: "tool_result", handler: HookHandler<ToolResultEvent, ToolResultEventResult | undefined>): void;
|
||||
on(event: "tool_call", handler: HookHandler<ToolCallEvent, ToolCallEventResult>): void;
|
||||
on(event: "tool_result", handler: HookHandler<ToolResultEvent, ToolResultEventResult>): void;
|
||||
|
||||
/**
|
||||
* Send a message to the agent.
|
||||
* If the agent is streaming, the message is queued.
|
||||
* If the agent is idle, a new agent loop is started.
|
||||
* Send a custom message to the session. Creates a CustomMessageEntry that
|
||||
* participates in LLM context and can be displayed in the TUI.
|
||||
*
|
||||
* Use this when you want the LLM to see the message content.
|
||||
* For hook state that should NOT be sent to the LLM, use appendEntry() instead.
|
||||
*
|
||||
* @param message - The message to send
|
||||
* @param message.customType - Identifier for your hook (used for filtering on reload)
|
||||
* @param message.content - Message content (string or TextContent/ImageContent array)
|
||||
* @param message.display - Whether to show in TUI (true = styled display, false = hidden)
|
||||
* @param message.details - Optional hook-specific metadata (not sent to LLM)
|
||||
* @param triggerTurn - If true and agent is idle, triggers a new LLM turn. Default: false.
|
||||
* If agent is streaming, message is queued and triggerTurn is ignored.
|
||||
*/
|
||||
send(text: string, attachments?: Attachment[]): void;
|
||||
sendMessage<T = unknown>(
|
||||
message: Pick<HookMessage<T>, "customType" | "content" | "display" | "details">,
|
||||
triggerTurn?: boolean,
|
||||
): void;
|
||||
|
||||
/**
|
||||
* Append a custom entry to the session for hook state persistence.
|
||||
* Creates a CustomEntry that does NOT participate in LLM context.
|
||||
*
|
||||
* Use this to store hook-specific data that should persist across session reloads
|
||||
* but should NOT be sent to the LLM. On reload, scan session entries for your
|
||||
* customType to reconstruct hook state.
|
||||
*
|
||||
* For messages that SHOULD be sent to the LLM, use sendMessage() instead.
|
||||
*
|
||||
* @param customType - Identifier for your hook (used for filtering on reload)
|
||||
* @param data - Hook-specific data to persist (must be JSON-serializable)
|
||||
*
|
||||
* @example
|
||||
* // Store permission state
|
||||
* pi.appendEntry("permissions", { level: "full", grantedAt: Date.now() });
|
||||
*
|
||||
* // On reload, reconstruct state from entries
|
||||
* pi.on("session", async (event, ctx) => {
|
||||
* if (event.reason === "start") {
|
||||
* const entries = event.sessionManager.getEntries();
|
||||
* const myEntries = entries.filter(e => e.type === "custom" && e.customType === "permissions");
|
||||
* // Reconstruct state from myEntries...
|
||||
* }
|
||||
* });
|
||||
*/
|
||||
appendEntry<T = unknown>(customType: string, data?: T): void;
|
||||
|
||||
/**
|
||||
* Register a custom renderer for CustomMessageEntry with a specific customType.
|
||||
* The renderer is called when rendering the entry in the TUI.
|
||||
* Return nothing to use the default renderer.
|
||||
*/
|
||||
registerMessageRenderer<T = unknown>(customType: string, renderer: HookMessageRenderer<T>): void;
|
||||
|
||||
/**
|
||||
* Register a custom slash command.
|
||||
* Handler receives HookCommandContext.
|
||||
*/
|
||||
registerCommand(name: string, options: { description?: string; handler: RegisteredCommand["handler"] }): void;
|
||||
|
||||
/**
|
||||
* Execute a shell command and return stdout/stderr/code.
|
||||
* Supports timeout and abort signal.
|
||||
*/
|
||||
exec(command: string, args: string[], options?: ExecOptions): Promise<ExecResult>;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -7,12 +7,12 @@ export {
|
|||
type AgentSessionConfig,
|
||||
type AgentSessionEvent,
|
||||
type AgentSessionEventListener,
|
||||
type CompactionResult,
|
||||
type ModelCycleResult,
|
||||
type PromptOptions,
|
||||
type SessionStats,
|
||||
} from "./agent-session.js";
|
||||
export { type BashExecutorOptions, type BashResult, executeBash } from "./bash-executor.js";
|
||||
export type { CompactionResult } from "./compaction/index.js";
|
||||
export {
|
||||
type CustomAgentTool,
|
||||
type CustomToolFactory,
|
||||
|
|
|
|||
|
|
@ -1,16 +1,27 @@
|
|||
/**
|
||||
* Custom message types and transformers for the coding agent.
|
||||
*
|
||||
* Extends the base AppMessage type with coding-agent specific message types,
|
||||
* Extends the base AgentMessage type with coding-agent specific message types,
|
||||
* and provides a transformer to convert them to LLM-compatible messages.
|
||||
*/
|
||||
|
||||
import type { AppMessage } from "@mariozechner/pi-agent-core";
|
||||
import type { Message } from "@mariozechner/pi-ai";
|
||||
import type { AgentMessage } from "@mariozechner/pi-agent-core";
|
||||
import type { ImageContent, Message, TextContent } from "@mariozechner/pi-ai";
|
||||
|
||||
// ============================================================================
|
||||
// Custom Message Types
|
||||
// ============================================================================
|
||||
export const COMPACTION_SUMMARY_PREFIX = `The conversation history before this point was compacted into the following summary:
|
||||
|
||||
<summary>
|
||||
`;
|
||||
|
||||
export const COMPACTION_SUMMARY_SUFFIX = `
|
||||
</summary>`;
|
||||
|
||||
export const BRANCH_SUMMARY_PREFIX = `The following is a summary of a branch that this conversation came back from:
|
||||
|
||||
<summary>
|
||||
`;
|
||||
|
||||
export const BRANCH_SUMMARY_SUFFIX = `</summary>`;
|
||||
|
||||
/**
|
||||
* Message type for bash executions via the ! command.
|
||||
|
|
@ -19,35 +30,50 @@ export interface BashExecutionMessage {
|
|||
role: "bashExecution";
|
||||
command: string;
|
||||
output: string;
|
||||
exitCode: number | null;
|
||||
exitCode: number | undefined;
|
||||
cancelled: boolean;
|
||||
truncated: boolean;
|
||||
fullOutputPath?: string;
|
||||
timestamp: number;
|
||||
}
|
||||
|
||||
// Extend CustomMessages via declaration merging
|
||||
/**
|
||||
* Message type for hook-injected messages via sendMessage().
|
||||
* These are custom messages that hooks can inject into the conversation.
|
||||
*/
|
||||
export interface HookMessage<T = unknown> {
|
||||
role: "hookMessage";
|
||||
customType: string;
|
||||
content: string | (TextContent | ImageContent)[];
|
||||
display: boolean;
|
||||
details?: T;
|
||||
timestamp: number;
|
||||
}
|
||||
|
||||
export interface BranchSummaryMessage {
|
||||
role: "branchSummary";
|
||||
summary: string;
|
||||
fromId: string;
|
||||
timestamp: number;
|
||||
}
|
||||
|
||||
export interface CompactionSummaryMessage {
|
||||
role: "compactionSummary";
|
||||
summary: string;
|
||||
tokensBefore: number;
|
||||
timestamp: number;
|
||||
}
|
||||
|
||||
// Extend CustomAgentMessages via declaration merging
|
||||
declare module "@mariozechner/pi-agent-core" {
|
||||
interface CustomMessages {
|
||||
interface CustomAgentMessages {
|
||||
bashExecution: BashExecutionMessage;
|
||||
hookMessage: HookMessage;
|
||||
branchSummary: BranchSummaryMessage;
|
||||
compactionSummary: CompactionSummaryMessage;
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Type Guards
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Type guard for BashExecutionMessage.
|
||||
*/
|
||||
export function isBashExecutionMessage(msg: AppMessage | Message): msg is BashExecutionMessage {
|
||||
return (msg as BashExecutionMessage).role === "bashExecution";
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Message Formatting
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Convert a BashExecutionMessage to user message text for LLM context.
|
||||
*/
|
||||
|
|
@ -60,7 +86,7 @@ export function bashExecutionToText(msg: BashExecutionMessage): string {
|
|||
}
|
||||
if (msg.cancelled) {
|
||||
text += "\n\n(command cancelled)";
|
||||
} else if (msg.exitCode !== null && msg.exitCode !== 0) {
|
||||
} else if (msg.exitCode !== null && msg.exitCode !== undefined && msg.exitCode !== 0) {
|
||||
text += `\n\nCommand exited with code ${msg.exitCode}`;
|
||||
}
|
||||
if (msg.truncated && msg.fullOutputPath) {
|
||||
|
|
@ -69,34 +95,95 @@ export function bashExecutionToText(msg: BashExecutionMessage): string {
|
|||
return text;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Message Transformer
|
||||
// ============================================================================
|
||||
export function createBranchSummaryMessage(summary: string, fromId: string, timestamp: string): BranchSummaryMessage {
|
||||
return {
|
||||
role: "branchSummary",
|
||||
summary,
|
||||
fromId,
|
||||
timestamp: new Date(timestamp).getTime(),
|
||||
};
|
||||
}
|
||||
|
||||
export function createCompactionSummaryMessage(
|
||||
summary: string,
|
||||
tokensBefore: number,
|
||||
timestamp: string,
|
||||
): CompactionSummaryMessage {
|
||||
return {
|
||||
role: "compactionSummary",
|
||||
summary: summary,
|
||||
tokensBefore,
|
||||
timestamp: new Date(timestamp).getTime(),
|
||||
};
|
||||
}
|
||||
|
||||
/** Convert CustomMessageEntry to AgentMessage format */
|
||||
export function createHookMessage(
|
||||
customType: string,
|
||||
content: string | (TextContent | ImageContent)[],
|
||||
display: boolean,
|
||||
details: unknown | undefined,
|
||||
timestamp: string,
|
||||
): HookMessage {
|
||||
return {
|
||||
role: "hookMessage",
|
||||
customType,
|
||||
content,
|
||||
display,
|
||||
details,
|
||||
timestamp: new Date(timestamp).getTime(),
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Transform AppMessages (including custom types) to LLM-compatible Messages.
|
||||
* Transform AgentMessages (including custom types) to LLM-compatible Messages.
|
||||
*
|
||||
* This is used by:
|
||||
* - Agent's messageTransformer option (for prompt calls)
|
||||
* - Agent's transormToLlm option (for prompt calls and queued messages)
|
||||
* - Compaction's generateSummary (for summarization)
|
||||
* - Custom hooks and tools
|
||||
*/
|
||||
export function messageTransformer(messages: AppMessage[]): Message[] {
|
||||
export function convertToLlm(messages: AgentMessage[]): Message[] {
|
||||
return messages
|
||||
.map((m): Message | null => {
|
||||
if (isBashExecutionMessage(m)) {
|
||||
// Convert bash execution to user message
|
||||
return {
|
||||
role: "user",
|
||||
content: [{ type: "text", text: bashExecutionToText(m) }],
|
||||
timestamp: m.timestamp,
|
||||
};
|
||||
.map((m): Message | undefined => {
|
||||
switch (m.role) {
|
||||
case "bashExecution":
|
||||
return {
|
||||
role: "user",
|
||||
content: [{ type: "text", text: bashExecutionToText(m) }],
|
||||
timestamp: m.timestamp,
|
||||
};
|
||||
case "hookMessage": {
|
||||
const content = typeof m.content === "string" ? [{ type: "text" as const, text: m.content }] : m.content;
|
||||
return {
|
||||
role: "user",
|
||||
content,
|
||||
timestamp: m.timestamp,
|
||||
};
|
||||
}
|
||||
case "branchSummary":
|
||||
return {
|
||||
role: "user",
|
||||
content: [{ type: "text" as const, text: BRANCH_SUMMARY_PREFIX + m.summary + BRANCH_SUMMARY_SUFFIX }],
|
||||
timestamp: m.timestamp,
|
||||
};
|
||||
case "compactionSummary":
|
||||
return {
|
||||
role: "user",
|
||||
content: [
|
||||
{ type: "text" as const, text: COMPACTION_SUMMARY_PREFIX + m.summary + COMPACTION_SUMMARY_SUFFIX },
|
||||
],
|
||||
timestamp: m.timestamp,
|
||||
};
|
||||
case "user":
|
||||
case "assistant":
|
||||
case "toolResult":
|
||||
return m;
|
||||
default:
|
||||
// biome-ignore lint/correctness/noSwitchDeclarations: fine
|
||||
const _exhaustiveCheck: never = m;
|
||||
return undefined;
|
||||
}
|
||||
// Pass through standard LLM roles
|
||||
if (m.role === "user" || m.role === "assistant" || m.role === "toolResult") {
|
||||
return m as Message;
|
||||
}
|
||||
// Filter out unknown message types
|
||||
return null;
|
||||
})
|
||||
.filter((m): m is Message => m !== null);
|
||||
.filter((m) => m !== undefined);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -90,11 +90,11 @@ function resolveApiKeyConfig(keyConfig: string): string | undefined {
|
|||
export class ModelRegistry {
|
||||
private models: Model<Api>[] = [];
|
||||
private customProviderApiKeys: Map<string, string> = new Map();
|
||||
private loadError: string | null = null;
|
||||
private loadError: string | undefined = undefined;
|
||||
|
||||
constructor(
|
||||
readonly authStorage: AuthStorage,
|
||||
private modelsJsonPath: string | null = null,
|
||||
private modelsJsonPath: string | undefined = undefined,
|
||||
) {
|
||||
// Set up fallback resolver for custom provider API keys
|
||||
this.authStorage.setFallbackResolver((provider) => {
|
||||
|
|
@ -114,14 +114,14 @@ export class ModelRegistry {
|
|||
*/
|
||||
refresh(): void {
|
||||
this.customProviderApiKeys.clear();
|
||||
this.loadError = null;
|
||||
this.loadError = undefined;
|
||||
this.loadModels();
|
||||
}
|
||||
|
||||
/**
|
||||
* Get any error from loading models.json (null if no error).
|
||||
* Get any error from loading models.json (undefined if no error).
|
||||
*/
|
||||
getError(): string | null {
|
||||
getError(): string | undefined {
|
||||
return this.loadError;
|
||||
}
|
||||
|
||||
|
|
@ -160,9 +160,9 @@ export class ModelRegistry {
|
|||
}
|
||||
}
|
||||
|
||||
private loadCustomModels(modelsJsonPath: string): { models: Model<Api>[]; error: string | null } {
|
||||
private loadCustomModels(modelsJsonPath: string): { models: Model<Api>[]; error: string | undefined } {
|
||||
if (!existsSync(modelsJsonPath)) {
|
||||
return { models: [], error: null };
|
||||
return { models: [], error: undefined };
|
||||
}
|
||||
|
||||
try {
|
||||
|
|
@ -186,7 +186,7 @@ export class ModelRegistry {
|
|||
this.validateConfig(config);
|
||||
|
||||
// Parse models
|
||||
return { models: this.parseModels(config), error: null };
|
||||
return { models: this.parseModels(config), error: undefined };
|
||||
} catch (error) {
|
||||
if (error instanceof SyntaxError) {
|
||||
return {
|
||||
|
|
@ -294,14 +294,14 @@ export class ModelRegistry {
|
|||
/**
|
||||
* Find a model by provider and ID.
|
||||
*/
|
||||
find(provider: string, modelId: string): Model<Api> | null {
|
||||
return this.models.find((m) => m.provider === provider && m.id === modelId) ?? null;
|
||||
find(provider: string, modelId: string): Model<Api> | undefined {
|
||||
return this.models.find((m) => m.provider === provider && m.id === modelId) ?? undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get API key for a model.
|
||||
*/
|
||||
async getApiKey(model: Model<Api>): Promise<string | null> {
|
||||
async getApiKey(model: Model<Api>): Promise<string | undefined> {
|
||||
return this.authStorage.getApiKey(model.provider);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -44,9 +44,9 @@ function isAlias(id: string): boolean {
|
|||
|
||||
/**
|
||||
* Try to match a pattern to a model from the available models list.
|
||||
* Returns the matched model or null if no match found.
|
||||
* Returns the matched model or undefined if no match found.
|
||||
*/
|
||||
function tryMatchModel(modelPattern: string, availableModels: Model<Api>[]): Model<Api> | null {
|
||||
function tryMatchModel(modelPattern: string, availableModels: Model<Api>[]): Model<Api> | undefined {
|
||||
// Check for provider/modelId format (provider is everything before the first /)
|
||||
const slashIndex = modelPattern.indexOf("/");
|
||||
if (slashIndex !== -1) {
|
||||
|
|
@ -75,7 +75,7 @@ function tryMatchModel(modelPattern: string, availableModels: Model<Api>[]): Mod
|
|||
);
|
||||
|
||||
if (matches.length === 0) {
|
||||
return null;
|
||||
return undefined;
|
||||
}
|
||||
|
||||
// Separate into aliases and dated versions
|
||||
|
|
@ -94,9 +94,9 @@ function tryMatchModel(modelPattern: string, availableModels: Model<Api>[]): Mod
|
|||
}
|
||||
|
||||
export interface ParsedModelResult {
|
||||
model: Model<Api> | null;
|
||||
model: Model<Api> | undefined;
|
||||
thinkingLevel: ThinkingLevel;
|
||||
warning: string | null;
|
||||
warning: string | undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -116,14 +116,14 @@ export function parseModelPattern(pattern: string, availableModels: Model<Api>[]
|
|||
// Try exact match first
|
||||
const exactMatch = tryMatchModel(pattern, availableModels);
|
||||
if (exactMatch) {
|
||||
return { model: exactMatch, thinkingLevel: "off", warning: null };
|
||||
return { model: exactMatch, thinkingLevel: "off", warning: undefined };
|
||||
}
|
||||
|
||||
// No match - try splitting on last colon if present
|
||||
const lastColonIndex = pattern.lastIndexOf(":");
|
||||
if (lastColonIndex === -1) {
|
||||
// No colons, pattern simply doesn't match any model
|
||||
return { model: null, thinkingLevel: "off", warning: null };
|
||||
return { model: undefined, thinkingLevel: "off", warning: undefined };
|
||||
}
|
||||
|
||||
const prefix = pattern.substring(0, lastColonIndex);
|
||||
|
|
@ -193,9 +193,9 @@ export async function resolveModelScope(patterns: string[], modelRegistry: Model
|
|||
}
|
||||
|
||||
export interface InitialModelResult {
|
||||
model: Model<Api> | null;
|
||||
model: Model<Api> | undefined;
|
||||
thinkingLevel: ThinkingLevel;
|
||||
fallbackMessage: string | null;
|
||||
fallbackMessage: string | undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -227,7 +227,7 @@ export async function findInitialModel(options: {
|
|||
modelRegistry,
|
||||
} = options;
|
||||
|
||||
let model: Model<Api> | null = null;
|
||||
let model: Model<Api> | undefined;
|
||||
let thinkingLevel: ThinkingLevel = "off";
|
||||
|
||||
// 1. CLI args take priority
|
||||
|
|
@ -237,7 +237,7 @@ export async function findInitialModel(options: {
|
|||
console.error(chalk.red(`Model ${cliProvider}/${cliModel} not found`));
|
||||
process.exit(1);
|
||||
}
|
||||
return { model: found, thinkingLevel: "off", fallbackMessage: null };
|
||||
return { model: found, thinkingLevel: "off", fallbackMessage: undefined };
|
||||
}
|
||||
|
||||
// 2. Use first model from scoped models (skip if continuing/resuming)
|
||||
|
|
@ -245,7 +245,7 @@ export async function findInitialModel(options: {
|
|||
return {
|
||||
model: scopedModels[0].model,
|
||||
thinkingLevel: scopedModels[0].thinkingLevel,
|
||||
fallbackMessage: null,
|
||||
fallbackMessage: undefined,
|
||||
};
|
||||
}
|
||||
|
||||
|
|
@ -257,7 +257,7 @@ export async function findInitialModel(options: {
|
|||
if (defaultThinkingLevel) {
|
||||
thinkingLevel = defaultThinkingLevel;
|
||||
}
|
||||
return { model, thinkingLevel, fallbackMessage: null };
|
||||
return { model, thinkingLevel, fallbackMessage: undefined };
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -270,16 +270,16 @@ export async function findInitialModel(options: {
|
|||
const defaultId = defaultModelPerProvider[provider];
|
||||
const match = availableModels.find((m) => m.provider === provider && m.id === defaultId);
|
||||
if (match) {
|
||||
return { model: match, thinkingLevel: "off", fallbackMessage: null };
|
||||
return { model: match, thinkingLevel: "off", fallbackMessage: undefined };
|
||||
}
|
||||
}
|
||||
|
||||
// If no default found, use first available
|
||||
return { model: availableModels[0], thinkingLevel: "off", fallbackMessage: null };
|
||||
return { model: availableModels[0], thinkingLevel: "off", fallbackMessage: undefined };
|
||||
}
|
||||
|
||||
// 5. No model found
|
||||
return { model: null, thinkingLevel: "off", fallbackMessage: null };
|
||||
return { model: undefined, thinkingLevel: "off", fallbackMessage: undefined };
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -288,10 +288,10 @@ export async function findInitialModel(options: {
|
|||
export async function restoreModelFromSession(
|
||||
savedProvider: string,
|
||||
savedModelId: string,
|
||||
currentModel: Model<Api> | null,
|
||||
currentModel: Model<Api> | undefined,
|
||||
shouldPrintMessages: boolean,
|
||||
modelRegistry: ModelRegistry,
|
||||
): Promise<{ model: Model<Api> | null; fallbackMessage: string | null }> {
|
||||
): Promise<{ model: Model<Api> | undefined; fallbackMessage: string | undefined }> {
|
||||
const restoredModel = modelRegistry.find(savedProvider, savedModelId);
|
||||
|
||||
// Check if restored model exists and has a valid API key
|
||||
|
|
@ -301,7 +301,7 @@ export async function restoreModelFromSession(
|
|||
if (shouldPrintMessages) {
|
||||
console.log(chalk.dim(`Restored model: ${savedProvider}/${savedModelId}`));
|
||||
}
|
||||
return { model: restoredModel, fallbackMessage: null };
|
||||
return { model: restoredModel, fallbackMessage: undefined };
|
||||
}
|
||||
|
||||
// Model not found or no API key - fall back
|
||||
|
|
@ -327,7 +327,7 @@ export async function restoreModelFromSession(
|
|||
|
||||
if (availableModels.length > 0) {
|
||||
// Try to find a default model from known providers
|
||||
let fallbackModel: Model<Api> | null = null;
|
||||
let fallbackModel: Model<Api> | undefined;
|
||||
for (const provider of Object.keys(defaultModelPerProvider) as KnownProvider[]) {
|
||||
const defaultId = defaultModelPerProvider[provider];
|
||||
const match = availableModels.find((m) => m.provider === provider && m.id === defaultId);
|
||||
|
|
@ -353,5 +353,5 @@ export async function restoreModelFromSession(
|
|||
}
|
||||
|
||||
// No models available
|
||||
return { model: null, fallbackMessage: null };
|
||||
return { model: undefined, fallbackMessage: undefined };
|
||||
}
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@
|
|||
* ```
|
||||
*/
|
||||
|
||||
import { Agent, ProviderTransport, type ThinkingLevel } from "@mariozechner/pi-agent-core";
|
||||
import { Agent, type ThinkingLevel } from "@mariozechner/pi-agent-core";
|
||||
import type { Model } from "@mariozechner/pi-ai";
|
||||
import { join } from "path";
|
||||
import { getAgentDir } from "../config.js";
|
||||
|
|
@ -39,7 +39,7 @@ import { discoverAndLoadCustomTools, type LoadedCustomTool } from "./custom-tool
|
|||
import type { CustomAgentTool } from "./custom-tools/types.js";
|
||||
import { discoverAndLoadHooks, HookRunner, type LoadedHook, wrapToolsWithHooks } from "./hooks/index.js";
|
||||
import type { HookFactory } from "./hooks/types.js";
|
||||
import { messageTransformer } from "./messages.js";
|
||||
import { convertToLlm } from "./messages.js";
|
||||
import { ModelRegistry } from "./model-registry.js";
|
||||
import { SessionManager } from "./session-manager.js";
|
||||
import { type Settings, SettingsManager, type SkillsSettings } from "./settings-manager.js";
|
||||
|
|
@ -340,7 +340,10 @@ function createFactoryFromLoadedHook(loaded: LoadedHook): HookFactory {
|
|||
function createLoadedHooksFromDefinitions(definitions: Array<{ path?: string; factory: HookFactory }>): LoadedHook[] {
|
||||
return definitions.map((def) => {
|
||||
const handlers = new Map<string, Array<(...args: unknown[]) => Promise<unknown>>>();
|
||||
let sendHandler: (text: string, attachments?: any[]) => void = () => {};
|
||||
const messageRenderers = new Map<string, any>();
|
||||
const commands = new Map<string, any>();
|
||||
let sendMessageHandler: (message: any, triggerTurn?: boolean) => void = () => {};
|
||||
let appendEntryHandler: (customType: string, data?: any) => void = () => {};
|
||||
|
||||
const api = {
|
||||
on: (event: string, handler: (...args: unknown[]) => Promise<unknown>) => {
|
||||
|
|
@ -348,8 +351,17 @@ function createLoadedHooksFromDefinitions(definitions: Array<{ path?: string; fa
|
|||
list.push(handler);
|
||||
handlers.set(event, list);
|
||||
},
|
||||
send: (text: string, attachments?: any[]) => {
|
||||
sendHandler(text, attachments);
|
||||
sendMessage: (message: any, triggerTurn?: boolean) => {
|
||||
sendMessageHandler(message, triggerTurn);
|
||||
},
|
||||
appendEntry: (customType: string, data?: any) => {
|
||||
appendEntryHandler(customType, data);
|
||||
},
|
||||
registerMessageRenderer: (customType: string, renderer: any) => {
|
||||
messageRenderers.set(customType, renderer);
|
||||
},
|
||||
registerCommand: (name: string, options: any) => {
|
||||
commands.set(name, { name, ...options });
|
||||
},
|
||||
};
|
||||
|
||||
|
|
@ -359,8 +371,13 @@ function createLoadedHooksFromDefinitions(definitions: Array<{ path?: string; fa
|
|||
path: def.path ?? "<inline>",
|
||||
resolvedPath: def.path ?? "<inline>",
|
||||
handlers,
|
||||
setSendHandler: (handler: (text: string, attachments?: any[]) => void) => {
|
||||
sendHandler = handler;
|
||||
messageRenderers,
|
||||
commands,
|
||||
setSendMessageHandler: (handler: (message: any, triggerTurn?: boolean) => void) => {
|
||||
sendMessageHandler = handler;
|
||||
},
|
||||
setAppendEntryHandler: (handler: (customType: string, data?: any) => void) => {
|
||||
appendEntryHandler = handler;
|
||||
},
|
||||
};
|
||||
});
|
||||
|
|
@ -513,11 +530,11 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
|
|||
customToolsResult = result;
|
||||
}
|
||||
|
||||
let hookRunner: HookRunner | null = null;
|
||||
let hookRunner: HookRunner | undefined;
|
||||
if (options.hooks !== undefined) {
|
||||
if (options.hooks.length > 0) {
|
||||
const loadedHooks = createLoadedHooksFromDefinitions(options.hooks);
|
||||
hookRunner = new HookRunner(loadedHooks, cwd, settingsManager.getHookTimeout());
|
||||
hookRunner = new HookRunner(loadedHooks, cwd, sessionManager, modelRegistry, settingsManager.getHookTimeout());
|
||||
}
|
||||
} else {
|
||||
// Discover hooks, merging with additional paths
|
||||
|
|
@ -528,7 +545,7 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
|
|||
console.error(`Failed to load hook "${path}": ${error}`);
|
||||
}
|
||||
if (hooks.length > 0) {
|
||||
hookRunner = new HookRunner(hooks, cwd, settingsManager.getHookTimeout());
|
||||
hookRunner = new HookRunner(hooks, cwd, sessionManager, modelRegistry, settingsManager.getHookTimeout());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -571,21 +588,24 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
|
|||
thinkingLevel,
|
||||
tools: allToolsArray,
|
||||
},
|
||||
messageTransformer,
|
||||
convertToLlm,
|
||||
transformContext: hookRunner
|
||||
? async (messages) => {
|
||||
return hookRunner.emitContext(messages);
|
||||
}
|
||||
: undefined,
|
||||
queueMode: settingsManager.getQueueMode(),
|
||||
transport: new ProviderTransport({
|
||||
getApiKey: async () => {
|
||||
const currentModel = agent.state.model;
|
||||
if (!currentModel) {
|
||||
throw new Error("No model selected");
|
||||
}
|
||||
const key = await modelRegistry.getApiKey(currentModel);
|
||||
if (!key) {
|
||||
throw new Error(`No API key found for provider "${currentModel.provider}"`);
|
||||
}
|
||||
return key;
|
||||
},
|
||||
}),
|
||||
getApiKey: async () => {
|
||||
const currentModel = agent.state.model;
|
||||
if (!currentModel) {
|
||||
throw new Error("No model selected");
|
||||
}
|
||||
const key = await modelRegistry.getApiKey(currentModel);
|
||||
if (!key) {
|
||||
throw new Error(`No API key found for provider "${currentModel.provider}"`);
|
||||
}
|
||||
return key;
|
||||
},
|
||||
});
|
||||
time("createAgent");
|
||||
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -8,6 +8,10 @@ export interface CompactionSettings {
|
|||
keepRecentTokens?: number; // default: 20000
|
||||
}
|
||||
|
||||
export interface BranchSummarySettings {
|
||||
reserveTokens?: number; // default: 16384 (tokens reserved for prompt + LLM response)
|
||||
}
|
||||
|
||||
export interface RetrySettings {
|
||||
enabled?: boolean; // default: true
|
||||
maxRetries?: number; // default: 3
|
||||
|
|
@ -38,6 +42,7 @@ export interface Settings {
|
|||
queueMode?: "all" | "one-at-a-time";
|
||||
theme?: string;
|
||||
compaction?: CompactionSettings;
|
||||
branchSummary?: BranchSummarySettings;
|
||||
retry?: RetrySettings;
|
||||
hideThinkingBlock?: boolean;
|
||||
shellPath?: string; // Custom shell path (e.g., for Cygwin users on Windows)
|
||||
|
|
@ -255,6 +260,12 @@ export class SettingsManager {
|
|||
};
|
||||
}
|
||||
|
||||
getBranchSummarySettings(): { reserveTokens: number } {
|
||||
return {
|
||||
reserveTokens: this.settings.branchSummary?.reserveTokens ?? 16384,
|
||||
};
|
||||
}
|
||||
|
||||
getRetryEnabled(): boolean {
|
||||
return this.settings.retry?.enabled ?? true;
|
||||
}
|
||||
|
|
@ -303,7 +314,7 @@ export class SettingsManager {
|
|||
}
|
||||
|
||||
getHookPaths(): string[] {
|
||||
return this.settings.hooks ?? [];
|
||||
return [...(this.settings.hooks ?? [])];
|
||||
}
|
||||
|
||||
setHookPaths(paths: string[]): void {
|
||||
|
|
@ -321,7 +332,7 @@ export class SettingsManager {
|
|||
}
|
||||
|
||||
getCustomToolPaths(): string[] {
|
||||
return this.settings.customTools ?? [];
|
||||
return [...(this.settings.customTools ?? [])];
|
||||
}
|
||||
|
||||
setCustomToolPaths(paths: string[]): void {
|
||||
|
|
@ -349,9 +360,9 @@ export class SettingsManager {
|
|||
enableClaudeProject: this.settings.skills?.enableClaudeProject ?? true,
|
||||
enablePiUser: this.settings.skills?.enablePiUser ?? true,
|
||||
enablePiProject: this.settings.skills?.enablePiProject ?? true,
|
||||
customDirectories: this.settings.skills?.customDirectories ?? [],
|
||||
ignoredSkills: this.settings.skills?.ignoredSkills ?? [],
|
||||
includeSkills: this.settings.skills?.includeSkills ?? [],
|
||||
customDirectories: [...(this.settings.skills?.customDirectories ?? [])],
|
||||
ignoredSkills: [...(this.settings.skills?.ignoredSkills ?? [])],
|
||||
includeSkills: [...(this.settings.skills?.includeSkills ?? [])],
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import { randomBytes } from "node:crypto";
|
|||
import { createWriteStream } from "node:fs";
|
||||
import { tmpdir } from "node:os";
|
||||
import { join } from "node:path";
|
||||
import type { AgentTool } from "@mariozechner/pi-ai";
|
||||
import type { AgentTool } from "@mariozechner/pi-agent-core";
|
||||
import { Type } from "@sinclair/typebox";
|
||||
import { spawn } from "child_process";
|
||||
import { getShellConfig, killProcessTree } from "../../utils/shell.js";
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
import type { AgentTool } from "@mariozechner/pi-ai";
|
||||
import type { AgentTool } from "@mariozechner/pi-agent-core";
|
||||
import { Type } from "@sinclair/typebox";
|
||||
import * as Diff from "diff";
|
||||
import { constants } from "fs";
|
||||
|
|
@ -23,8 +23,13 @@ function restoreLineEndings(text: string, ending: "\r\n" | "\n"): string {
|
|||
|
||||
/**
|
||||
* Generate a unified diff string with line numbers and context
|
||||
* Returns both the diff string and the first changed line number (in the new file)
|
||||
*/
|
||||
function generateDiffString(oldContent: string, newContent: string, contextLines = 4): string {
|
||||
function generateDiffString(
|
||||
oldContent: string,
|
||||
newContent: string,
|
||||
contextLines = 4,
|
||||
): { diff: string; firstChangedLine: number | undefined } {
|
||||
const parts = Diff.diffLines(oldContent, newContent);
|
||||
const output: string[] = [];
|
||||
|
||||
|
|
@ -36,6 +41,7 @@ function generateDiffString(oldContent: string, newContent: string, contextLines
|
|||
let oldLineNum = 1;
|
||||
let newLineNum = 1;
|
||||
let lastWasChange = false;
|
||||
let firstChangedLine: number | undefined;
|
||||
|
||||
for (let i = 0; i < parts.length; i++) {
|
||||
const part = parts[i];
|
||||
|
|
@ -45,6 +51,11 @@ function generateDiffString(oldContent: string, newContent: string, contextLines
|
|||
}
|
||||
|
||||
if (part.added || part.removed) {
|
||||
// Capture the first changed line (in the new file)
|
||||
if (firstChangedLine === undefined) {
|
||||
firstChangedLine = newLineNum;
|
||||
}
|
||||
|
||||
// Show the change
|
||||
for (const line of raw) {
|
||||
if (part.added) {
|
||||
|
|
@ -113,7 +124,7 @@ function generateDiffString(oldContent: string, newContent: string, contextLines
|
|||
}
|
||||
}
|
||||
|
||||
return output.join("\n");
|
||||
return { diff: output.join("\n"), firstChangedLine };
|
||||
}
|
||||
|
||||
const editSchema = Type.Object({
|
||||
|
|
@ -122,6 +133,13 @@ const editSchema = Type.Object({
|
|||
newText: Type.String({ description: "New text to replace the old text with" }),
|
||||
});
|
||||
|
||||
export interface EditToolDetails {
|
||||
/** Unified diff of the changes made */
|
||||
diff: string;
|
||||
/** Line number of the first change in the new file (for editor navigation) */
|
||||
firstChangedLine?: number;
|
||||
}
|
||||
|
||||
export function createEditTool(cwd: string): AgentTool<typeof editSchema> {
|
||||
return {
|
||||
name: "edit",
|
||||
|
|
@ -138,7 +156,7 @@ export function createEditTool(cwd: string): AgentTool<typeof editSchema> {
|
|||
|
||||
return new Promise<{
|
||||
content: Array<{ type: "text"; text: string }>;
|
||||
details: { diff: string } | undefined;
|
||||
details: EditToolDetails | undefined;
|
||||
}>((resolve, reject) => {
|
||||
// Check if already aborted
|
||||
if (signal?.aborted) {
|
||||
|
|
@ -257,6 +275,7 @@ export function createEditTool(cwd: string): AgentTool<typeof editSchema> {
|
|||
signal.removeEventListener("abort", onAbort);
|
||||
}
|
||||
|
||||
const diffResult = generateDiffString(normalizedContent, normalizedNewContent);
|
||||
resolve({
|
||||
content: [
|
||||
{
|
||||
|
|
@ -264,7 +283,7 @@ export function createEditTool(cwd: string): AgentTool<typeof editSchema> {
|
|||
text: `Successfully replaced text in ${path}.`,
|
||||
},
|
||||
],
|
||||
details: { diff: generateDiffString(normalizedContent, normalizedNewContent) },
|
||||
details: { diff: diffResult.diff, firstChangedLine: diffResult.firstChangedLine },
|
||||
});
|
||||
} catch (error: any) {
|
||||
// Clean up abort handler
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
import type { AgentTool } from "@mariozechner/pi-ai";
|
||||
import type { AgentTool } from "@mariozechner/pi-agent-core";
|
||||
import { Type } from "@sinclair/typebox";
|
||||
import { spawnSync } from "child_process";
|
||||
import { existsSync } from "fs";
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import { createInterface } from "node:readline";
|
||||
import type { AgentTool } from "@mariozechner/pi-ai";
|
||||
import type { AgentTool } from "@mariozechner/pi-agent-core";
|
||||
import { Type } from "@sinclair/typebox";
|
||||
import { spawn } from "child_process";
|
||||
import { readFileSync, type Stats, statSync } from "fs";
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
import type { AgentTool } from "@mariozechner/pi-ai";
|
||||
|
||||
export { type BashToolDetails, bashTool, createBashTool } from "./bash.js";
|
||||
export { createEditTool, editTool } from "./edit.js";
|
||||
export { createFindTool, type FindToolDetails, findTool } from "./find.js";
|
||||
|
|
@ -9,6 +7,7 @@ export { createReadTool, type ReadToolDetails, readTool } from "./read.js";
|
|||
export type { TruncationResult } from "./truncate.js";
|
||||
export { createWriteTool, writeTool } from "./write.js";
|
||||
|
||||
import type { AgentTool } from "@mariozechner/pi-agent-core";
|
||||
import { bashTool, createBashTool } from "./bash.js";
|
||||
import { createEditTool, editTool } from "./edit.js";
|
||||
import { createFindTool, findTool } from "./find.js";
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
import type { AgentTool } from "@mariozechner/pi-ai";
|
||||
import type { AgentTool } from "@mariozechner/pi-agent-core";
|
||||
import { Type } from "@sinclair/typebox";
|
||||
import { existsSync, readdirSync, statSync } from "fs";
|
||||
import nodePath from "path";
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import type { AgentTool, ImageContent, TextContent } from "@mariozechner/pi-ai";
|
||||
import type { AgentTool } from "@mariozechner/pi-agent-core";
|
||||
import type { ImageContent, TextContent } from "@mariozechner/pi-ai";
|
||||
import { Type } from "@sinclair/typebox";
|
||||
import { constants } from "fs";
|
||||
import { access, readFile } from "fs/promises";
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
import type { AgentTool } from "@mariozechner/pi-ai";
|
||||
import type { AgentTool } from "@mariozechner/pi-agent-core";
|
||||
import { Type } from "@sinclair/typebox";
|
||||
import { mkdir, writeFile } from "fs/promises";
|
||||
import { dirname } from "path";
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ export {
|
|||
type AgentSessionConfig,
|
||||
type AgentSessionEvent,
|
||||
type AgentSessionEventListener,
|
||||
type CompactionResult,
|
||||
type ModelCycleResult,
|
||||
type PromptOptions,
|
||||
type SessionStats,
|
||||
|
|
@ -13,17 +12,26 @@ export {
|
|||
export { type ApiKeyCredential, type AuthCredential, AuthStorage, type OAuthCredential } from "./core/auth-storage.js";
|
||||
// Compaction
|
||||
export {
|
||||
type BranchPreparation,
|
||||
type BranchSummaryResult,
|
||||
type CollectEntriesResult,
|
||||
type CompactionResult,
|
||||
type CutPointResult,
|
||||
calculateContextTokens,
|
||||
collectEntriesForBranchSummary,
|
||||
compact,
|
||||
DEFAULT_COMPACTION_SETTINGS,
|
||||
estimateTokens,
|
||||
type FileOperations,
|
||||
findCutPoint,
|
||||
findTurnStartIndex,
|
||||
type GenerateBranchSummaryOptions,
|
||||
generateBranchSummary,
|
||||
generateSummary,
|
||||
getLastAssistantUsage,
|
||||
prepareBranchEntries,
|
||||
shouldCompact,
|
||||
} from "./core/compaction.js";
|
||||
} from "./core/compaction/index.js";
|
||||
// Custom tools
|
||||
export type {
|
||||
AgentToolUpdateCallback,
|
||||
|
|
@ -38,31 +46,7 @@ export type {
|
|||
ToolUIContext,
|
||||
} from "./core/custom-tools/index.js";
|
||||
export { discoverAndLoadCustomTools, loadCustomTools } from "./core/custom-tools/index.js";
|
||||
export type {
|
||||
AgentEndEvent,
|
||||
AgentStartEvent,
|
||||
BashToolResultEvent,
|
||||
CustomToolResultEvent,
|
||||
EditToolResultEvent,
|
||||
FindToolResultEvent,
|
||||
GrepToolResultEvent,
|
||||
HookAPI,
|
||||
HookEvent,
|
||||
HookEventContext,
|
||||
HookFactory,
|
||||
HookUIContext,
|
||||
LsToolResultEvent,
|
||||
ReadToolResultEvent,
|
||||
SessionEvent,
|
||||
SessionEventResult,
|
||||
ToolCallEvent,
|
||||
ToolCallEventResult,
|
||||
ToolResultEvent,
|
||||
ToolResultEventResult,
|
||||
TurnEndEvent,
|
||||
TurnStartEvent,
|
||||
WriteToolResultEvent,
|
||||
} from "./core/hooks/index.js";
|
||||
export type * from "./core/hooks/index.js";
|
||||
// Hook system types and type guards
|
||||
export {
|
||||
isBashToolResult,
|
||||
|
|
@ -73,7 +57,7 @@ export {
|
|||
isReadToolResult,
|
||||
isWriteToolResult,
|
||||
} from "./core/hooks/index.js";
|
||||
export { messageTransformer } from "./core/messages.js";
|
||||
export { convertToLlm } from "./core/messages.js";
|
||||
export { ModelRegistry } from "./core/model-registry.js";
|
||||
// SDK for programmatic usage
|
||||
export {
|
||||
|
|
@ -107,20 +91,24 @@ export {
|
|||
readOnlyTools,
|
||||
} from "./core/sdk.js";
|
||||
export {
|
||||
type BranchSummaryEntry,
|
||||
buildSessionContext,
|
||||
type CompactionEntry,
|
||||
createSummaryMessage,
|
||||
CURRENT_SESSION_VERSION,
|
||||
type CustomEntry,
|
||||
type CustomMessageEntry,
|
||||
type FileEntry,
|
||||
getLatestCompactionEntry,
|
||||
type ModelChangeEntry,
|
||||
migrateSessionEntries,
|
||||
parseSessionEntries,
|
||||
type SessionContext as LoadedSession,
|
||||
type SessionContext,
|
||||
type SessionEntry,
|
||||
type SessionEntryBase,
|
||||
type SessionHeader,
|
||||
type SessionInfo,
|
||||
SessionManager,
|
||||
type SessionMessageEntry,
|
||||
SUMMARY_PREFIX,
|
||||
SUMMARY_SUFFIX,
|
||||
type ThinkingLevelChangeEntry,
|
||||
} from "./core/session-manager.js";
|
||||
export {
|
||||
|
|
|
|||
|
|
@ -5,8 +5,7 @@
|
|||
* createAgentSession() options. The SDK does the heavy lifting.
|
||||
*/
|
||||
|
||||
import type { Attachment } from "@mariozechner/pi-agent-core";
|
||||
import { supportsXhigh } from "@mariozechner/pi-ai";
|
||||
import { type ImageContent, supportsXhigh } from "@mariozechner/pi-ai";
|
||||
import chalk from "chalk";
|
||||
import { existsSync } from "fs";
|
||||
import { join } from "path";
|
||||
|
|
@ -34,10 +33,10 @@ import { initTheme, stopThemeWatcher } from "./modes/interactive/theme/theme.js"
|
|||
import { getChangelogPath, getNewEntries, parseChangelog } from "./utils/changelog.js";
|
||||
import { ensureTool } from "./utils/tools-manager.js";
|
||||
|
||||
async function checkForNewVersion(currentVersion: string): Promise<string | null> {
|
||||
async function checkForNewVersion(currentVersion: string): Promise<string | undefined> {
|
||||
try {
|
||||
const response = await fetch("https://registry.npmjs.org/@mariozechner/pi -coding-agent/latest");
|
||||
if (!response.ok) return null;
|
||||
if (!response.ok) return undefined;
|
||||
|
||||
const data = (await response.json()) as { version?: string };
|
||||
const latestVersion = data.version;
|
||||
|
|
@ -46,26 +45,26 @@ async function checkForNewVersion(currentVersion: string): Promise<string | null
|
|||
return latestVersion;
|
||||
}
|
||||
|
||||
return null;
|
||||
return undefined;
|
||||
} catch {
|
||||
return null;
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
async function runInteractiveMode(
|
||||
session: AgentSession,
|
||||
version: string,
|
||||
changelogMarkdown: string | null,
|
||||
changelogMarkdown: string | undefined,
|
||||
modelFallbackMessage: string | undefined,
|
||||
modelsJsonError: string | null,
|
||||
modelsJsonError: string | undefined,
|
||||
migratedProviders: string[],
|
||||
versionCheckPromise: Promise<string | null>,
|
||||
versionCheckPromise: Promise<string | undefined>,
|
||||
initialMessages: string[],
|
||||
customTools: LoadedCustomTool[],
|
||||
setToolUIContext: (uiContext: HookUIContext, hasUI: boolean) => void,
|
||||
initialMessage?: string,
|
||||
initialAttachments?: Attachment[],
|
||||
fdPath: string | null = null,
|
||||
initialImages?: ImageContent[],
|
||||
fdPath: string | undefined = undefined,
|
||||
): Promise<void> {
|
||||
const mode = new InteractiveMode(session, version, changelogMarkdown, customTools, setToolUIContext, fdPath);
|
||||
|
||||
|
|
@ -77,7 +76,7 @@ async function runInteractiveMode(
|
|||
}
|
||||
});
|
||||
|
||||
mode.renderInitialMessages(session.state);
|
||||
mode.renderInitialMessages();
|
||||
|
||||
if (migratedProviders.length > 0) {
|
||||
mode.showWarning(`Migrated credentials to auth.json: ${migratedProviders.join(", ")}`);
|
||||
|
|
@ -93,7 +92,7 @@ async function runInteractiveMode(
|
|||
|
||||
if (initialMessage) {
|
||||
try {
|
||||
await session.prompt(initialMessage, { attachments: initialAttachments });
|
||||
await session.prompt(initialMessage, { images: initialImages });
|
||||
} catch (error: unknown) {
|
||||
const errorMessage = error instanceof Error ? error.message : "Unknown error occurred";
|
||||
mode.showError(errorMessage);
|
||||
|
|
@ -122,31 +121,31 @@ async function runInteractiveMode(
|
|||
|
||||
async function prepareInitialMessage(parsed: Args): Promise<{
|
||||
initialMessage?: string;
|
||||
initialAttachments?: Attachment[];
|
||||
initialImages?: ImageContent[];
|
||||
}> {
|
||||
if (parsed.fileArgs.length === 0) {
|
||||
return {};
|
||||
}
|
||||
|
||||
const { textContent, imageAttachments } = await processFileArguments(parsed.fileArgs);
|
||||
const { text, images } = await processFileArguments(parsed.fileArgs);
|
||||
|
||||
let initialMessage: string;
|
||||
if (parsed.messages.length > 0) {
|
||||
initialMessage = textContent + parsed.messages[0];
|
||||
initialMessage = text + parsed.messages[0];
|
||||
parsed.messages.shift();
|
||||
} else {
|
||||
initialMessage = textContent;
|
||||
initialMessage = text;
|
||||
}
|
||||
|
||||
return {
|
||||
initialMessage,
|
||||
initialAttachments: imageAttachments.length > 0 ? imageAttachments : undefined,
|
||||
initialImages: images.length > 0 ? images : undefined,
|
||||
};
|
||||
}
|
||||
|
||||
function getChangelogForDisplay(parsed: Args, settingsManager: SettingsManager): string | null {
|
||||
function getChangelogForDisplay(parsed: Args, settingsManager: SettingsManager): string | undefined {
|
||||
if (parsed.continue || parsed.resume) {
|
||||
return null;
|
||||
return undefined;
|
||||
}
|
||||
|
||||
const lastVersion = settingsManager.getLastChangelogVersion();
|
||||
|
|
@ -166,10 +165,10 @@ function getChangelogForDisplay(parsed: Args, settingsManager: SettingsManager):
|
|||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
return undefined;
|
||||
}
|
||||
|
||||
function createSessionManager(parsed: Args, cwd: string): SessionManager | null {
|
||||
function createSessionManager(parsed: Args, cwd: string): SessionManager | undefined {
|
||||
if (parsed.noSession) {
|
||||
return SessionManager.inMemory();
|
||||
}
|
||||
|
|
@ -184,8 +183,8 @@ function createSessionManager(parsed: Args, cwd: string): SessionManager | null
|
|||
if (parsed.sessionDir) {
|
||||
return SessionManager.create(cwd, parsed.sessionDir);
|
||||
}
|
||||
// Default case (new session) returns null, SDK will create one
|
||||
return null;
|
||||
// Default case (new session) returns undefined, SDK will create one
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/** Discover SYSTEM.md file if no CLI system prompt was provided */
|
||||
|
|
@ -208,7 +207,7 @@ function discoverSystemPromptFile(): string | undefined {
|
|||
function buildSessionOptions(
|
||||
parsed: Args,
|
||||
scopedModels: ScopedModel[],
|
||||
sessionManager: SessionManager | null,
|
||||
sessionManager: SessionManager | undefined,
|
||||
modelRegistry: ModelRegistry,
|
||||
): CreateAgentSessionOptions {
|
||||
const options: CreateAgentSessionOptions = {};
|
||||
|
|
@ -330,7 +329,7 @@ export async function main(args: string[]) {
|
|||
}
|
||||
|
||||
const cwd = process.cwd();
|
||||
const { initialMessage, initialAttachments } = await prepareInitialMessage(parsed);
|
||||
const { initialMessage, initialImages } = await prepareInitialMessage(parsed);
|
||||
time("prepareInitialMessage");
|
||||
const isInteractive = !parsed.print && parsed.mode === undefined;
|
||||
const mode = parsed.mode || "text";
|
||||
|
|
@ -409,7 +408,7 @@ export async function main(args: string[]) {
|
|||
if (mode === "rpc") {
|
||||
await runRpcMode(session);
|
||||
} else if (isInteractive) {
|
||||
const versionCheckPromise = checkForNewVersion(VERSION).catch(() => null);
|
||||
const versionCheckPromise = checkForNewVersion(VERSION).catch(() => undefined);
|
||||
const changelogMarkdown = getChangelogForDisplay(parsed, settingsManager);
|
||||
|
||||
if (scopedModels.length > 0) {
|
||||
|
|
@ -438,11 +437,11 @@ export async function main(args: string[]) {
|
|||
customToolsResult.tools,
|
||||
customToolsResult.setUIContext,
|
||||
initialMessage,
|
||||
initialAttachments,
|
||||
initialImages,
|
||||
fdPath,
|
||||
);
|
||||
} else {
|
||||
await runPrintMode(session, mode, parsed.messages, initialMessage, initialAttachments);
|
||||
await runPrintMode(session, mode, parsed.messages, initialMessage, initialImages);
|
||||
stopThemeWatcher();
|
||||
if (process.stdout.writableLength > 0) {
|
||||
await new Promise<void>((resolve) => process.stdout.once("drain", resolve));
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ export class BashExecutionComponent extends Container {
|
|||
private command: string;
|
||||
private outputLines: string[] = [];
|
||||
private status: "running" | "complete" | "cancelled" | "error" = "running";
|
||||
private exitCode: number | null = null;
|
||||
private exitCode: number | undefined = undefined;
|
||||
private loader: Loader;
|
||||
private truncationResult?: TruncationResult;
|
||||
private fullOutputPath?: string;
|
||||
|
|
@ -90,13 +90,17 @@ export class BashExecutionComponent extends Container {
|
|||
}
|
||||
|
||||
setComplete(
|
||||
exitCode: number | null,
|
||||
exitCode: number | undefined,
|
||||
cancelled: boolean,
|
||||
truncationResult?: TruncationResult,
|
||||
fullOutputPath?: string,
|
||||
): void {
|
||||
this.exitCode = exitCode;
|
||||
this.status = cancelled ? "cancelled" : exitCode !== 0 && exitCode !== null ? "error" : "complete";
|
||||
this.status = cancelled
|
||||
? "cancelled"
|
||||
: exitCode !== 0 && exitCode !== undefined && exitCode !== null
|
||||
? "error"
|
||||
: "complete";
|
||||
this.truncationResult = truncationResult;
|
||||
this.fullOutputPath = fullOutputPath;
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,42 @@
|
|||
import { Box, Markdown, Spacer, Text } from "@mariozechner/pi-tui";
|
||||
import type { BranchSummaryMessage } from "../../../core/messages.js";
|
||||
import { getMarkdownTheme, theme } from "../theme/theme.js";
|
||||
|
||||
/**
|
||||
* Component that renders a branch summary message with collapsed/expanded state.
|
||||
* Uses same background color as hook messages for visual consistency.
|
||||
*/
|
||||
export class BranchSummaryMessageComponent extends Box {
|
||||
private expanded = false;
|
||||
private message: BranchSummaryMessage;
|
||||
|
||||
constructor(message: BranchSummaryMessage) {
|
||||
super(1, 1, (t) => theme.bg("customMessageBg", t));
|
||||
this.message = message;
|
||||
this.updateDisplay();
|
||||
}
|
||||
|
||||
setExpanded(expanded: boolean): void {
|
||||
this.expanded = expanded;
|
||||
this.updateDisplay();
|
||||
}
|
||||
|
||||
private updateDisplay(): void {
|
||||
this.clear();
|
||||
|
||||
const label = theme.fg("customMessageLabel", `\x1b[1m[branch]\x1b[22m`);
|
||||
this.addChild(new Text(label, 0, 0));
|
||||
this.addChild(new Spacer(1));
|
||||
|
||||
if (this.expanded) {
|
||||
const header = "**Branch Summary**\n\n";
|
||||
this.addChild(
|
||||
new Markdown(header + this.message.summary, 0, 0, getMarkdownTheme(), {
|
||||
color: (text: string) => theme.fg("customMessageText", text),
|
||||
}),
|
||||
);
|
||||
} else {
|
||||
this.addChild(new Text(theme.fg("customMessageText", "Branch summary (ctrl+o to expand)"), 0, 0));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
import { Box, Markdown, Spacer, Text } from "@mariozechner/pi-tui";
|
||||
import type { CompactionSummaryMessage } from "../../../core/messages.js";
|
||||
import { getMarkdownTheme, theme } from "../theme/theme.js";
|
||||
|
||||
/**
|
||||
* Component that renders a compaction message with collapsed/expanded state.
|
||||
* Uses same background color as hook messages for visual consistency.
|
||||
*/
|
||||
export class CompactionSummaryMessageComponent extends Box {
|
||||
private expanded = false;
|
||||
private message: CompactionSummaryMessage;
|
||||
|
||||
constructor(message: CompactionSummaryMessage) {
|
||||
super(1, 1, (t) => theme.bg("customMessageBg", t));
|
||||
this.message = message;
|
||||
this.updateDisplay();
|
||||
}
|
||||
|
||||
setExpanded(expanded: boolean): void {
|
||||
this.expanded = expanded;
|
||||
this.updateDisplay();
|
||||
}
|
||||
|
||||
private updateDisplay(): void {
|
||||
this.clear();
|
||||
|
||||
const tokenStr = this.message.tokensBefore.toLocaleString();
|
||||
const label = theme.fg("customMessageLabel", `\x1b[1m[compaction]\x1b[22m`);
|
||||
this.addChild(new Text(label, 0, 0));
|
||||
this.addChild(new Spacer(1));
|
||||
|
||||
if (this.expanded) {
|
||||
const header = `**Compacted from ${tokenStr} tokens**\n\n`;
|
||||
this.addChild(
|
||||
new Markdown(header + this.message.summary, 0, 0, getMarkdownTheme(), {
|
||||
color: (text: string) => theme.fg("customMessageText", text),
|
||||
}),
|
||||
);
|
||||
} else {
|
||||
this.addChild(
|
||||
new Text(theme.fg("customMessageText", `Compacted from ${tokenStr} tokens (ctrl+o to expand)`), 0, 0),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,52 +0,0 @@
|
|||
import { Container, Markdown, Spacer, Text } from "@mariozechner/pi-tui";
|
||||
import { getMarkdownTheme, theme } from "../theme/theme.js";
|
||||
|
||||
/**
|
||||
* Component that renders a compaction indicator with collapsed/expanded state.
|
||||
* Collapsed: shows "Context compacted from X tokens"
|
||||
* Expanded: shows the full summary rendered as markdown (like a user message)
|
||||
*/
|
||||
export class CompactionComponent extends Container {
|
||||
private expanded = false;
|
||||
private tokensBefore: number;
|
||||
private summary: string;
|
||||
|
||||
constructor(tokensBefore: number, summary: string) {
|
||||
super();
|
||||
this.tokensBefore = tokensBefore;
|
||||
this.summary = summary;
|
||||
this.updateDisplay();
|
||||
}
|
||||
|
||||
setExpanded(expanded: boolean): void {
|
||||
this.expanded = expanded;
|
||||
this.updateDisplay();
|
||||
}
|
||||
|
||||
private updateDisplay(): void {
|
||||
this.clear();
|
||||
|
||||
if (this.expanded) {
|
||||
// Show header + summary as markdown (like user message)
|
||||
this.addChild(new Spacer(1));
|
||||
const header = `**Context compacted from ${this.tokensBefore.toLocaleString()} tokens**\n\n`;
|
||||
this.addChild(
|
||||
new Markdown(header + this.summary, 1, 1, getMarkdownTheme(), {
|
||||
bgColor: (text: string) => theme.bg("userMessageBg", text),
|
||||
color: (text: string) => theme.fg("userMessageText", text),
|
||||
}),
|
||||
);
|
||||
this.addChild(new Spacer(1));
|
||||
} else {
|
||||
// Collapsed: simple text in warning color with token count
|
||||
const tokenStr = this.tokensBefore.toLocaleString();
|
||||
this.addChild(
|
||||
new Text(
|
||||
theme.fg("warning", `Earlier messages compacted from ${tokenStr} tokens (ctrl+o to expand)`),
|
||||
1,
|
||||
1,
|
||||
),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,96 @@
|
|||
import type { TextContent } from "@mariozechner/pi-ai";
|
||||
import type { Component } from "@mariozechner/pi-tui";
|
||||
import { Box, Container, Markdown, Spacer, Text } from "@mariozechner/pi-tui";
|
||||
import type { HookMessageRenderer } from "../../../core/hooks/types.js";
|
||||
import type { HookMessage } from "../../../core/messages.js";
|
||||
import { getMarkdownTheme, theme } from "../theme/theme.js";
|
||||
|
||||
/**
|
||||
* Component that renders a custom message entry from hooks.
|
||||
* Uses distinct styling to differentiate from user messages.
|
||||
*/
|
||||
export class HookMessageComponent extends Container {
|
||||
private message: HookMessage<unknown>;
|
||||
private customRenderer?: HookMessageRenderer;
|
||||
private box: Box;
|
||||
private customComponent?: Component;
|
||||
private _expanded = false;
|
||||
|
||||
constructor(message: HookMessage<unknown>, customRenderer?: HookMessageRenderer) {
|
||||
super();
|
||||
this.message = message;
|
||||
this.customRenderer = customRenderer;
|
||||
|
||||
this.addChild(new Spacer(1));
|
||||
|
||||
// Create box with purple background (used for default rendering)
|
||||
this.box = new Box(1, 1, (t) => theme.bg("customMessageBg", t));
|
||||
|
||||
this.rebuild();
|
||||
}
|
||||
|
||||
setExpanded(expanded: boolean): void {
|
||||
if (this._expanded !== expanded) {
|
||||
this._expanded = expanded;
|
||||
this.rebuild();
|
||||
}
|
||||
}
|
||||
|
||||
private rebuild(): void {
|
||||
// Remove previous content component
|
||||
if (this.customComponent) {
|
||||
this.removeChild(this.customComponent);
|
||||
this.customComponent = undefined;
|
||||
}
|
||||
this.removeChild(this.box);
|
||||
|
||||
// Try custom renderer first - it handles its own styling
|
||||
if (this.customRenderer) {
|
||||
try {
|
||||
const component = this.customRenderer(this.message, { expanded: this._expanded }, theme);
|
||||
if (component) {
|
||||
// Custom renderer provides its own styled component
|
||||
this.customComponent = component;
|
||||
this.addChild(component);
|
||||
return;
|
||||
}
|
||||
} catch {
|
||||
// Fall through to default rendering
|
||||
}
|
||||
}
|
||||
|
||||
// Default rendering uses our box
|
||||
this.addChild(this.box);
|
||||
this.box.clear();
|
||||
|
||||
// Default rendering: label + content
|
||||
const label = theme.fg("customMessageLabel", `\x1b[1m[${this.message.customType}]\x1b[22m`);
|
||||
this.box.addChild(new Text(label, 0, 0));
|
||||
this.box.addChild(new Spacer(1));
|
||||
|
||||
// Extract text content
|
||||
let text: string;
|
||||
if (typeof this.message.content === "string") {
|
||||
text = this.message.content;
|
||||
} else {
|
||||
text = this.message.content
|
||||
.filter((c): c is TextContent => c.type === "text")
|
||||
.map((c) => c.text)
|
||||
.join("\n");
|
||||
}
|
||||
|
||||
// Limit lines when collapsed
|
||||
if (!this._expanded) {
|
||||
const lines = text.split("\n");
|
||||
if (lines.length > 5) {
|
||||
text = `${lines.slice(0, 5).join("\n")}\n...`;
|
||||
}
|
||||
}
|
||||
|
||||
this.box.addChild(
|
||||
new Markdown(text, 0, 0, getMarkdownTheme(), {
|
||||
color: (text: string) => theme.fg("customMessageText", text),
|
||||
}),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
@ -36,18 +36,18 @@ export class ModelSelectorComponent extends Container {
|
|||
private allModels: ModelItem[] = [];
|
||||
private filteredModels: ModelItem[] = [];
|
||||
private selectedIndex: number = 0;
|
||||
private currentModel: Model<any> | null;
|
||||
private currentModel?: Model<any>;
|
||||
private settingsManager: SettingsManager;
|
||||
private modelRegistry: ModelRegistry;
|
||||
private onSelectCallback: (model: Model<any>) => void;
|
||||
private onCancelCallback: () => void;
|
||||
private errorMessage: string | null = null;
|
||||
private errorMessage?: string;
|
||||
private tui: TUI;
|
||||
private scopedModels: ReadonlyArray<ScopedModelItem>;
|
||||
|
||||
constructor(
|
||||
tui: TUI,
|
||||
currentModel: Model<any> | null,
|
||||
currentModel: Model<any> | undefined,
|
||||
settingsManager: SettingsManager,
|
||||
modelRegistry: ModelRegistry,
|
||||
scopedModels: ReadonlyArray<ScopedModelItem>,
|
||||
|
|
|
|||
|
|
@ -415,10 +415,14 @@ export class ToolExecutionComponent extends Container {
|
|||
} else if (this.toolName === "edit") {
|
||||
const rawPath = this.args?.file_path || this.args?.path || "";
|
||||
const path = shortenPath(rawPath);
|
||||
text =
|
||||
theme.fg("toolTitle", theme.bold("edit")) +
|
||||
" " +
|
||||
(path ? theme.fg("accent", path) : theme.fg("toolOutput", "..."));
|
||||
|
||||
// Build path display, appending :line if we have a successful result with line info
|
||||
let pathDisplay = path ? theme.fg("accent", path) : theme.fg("toolOutput", "...");
|
||||
if (this.result && !this.result.isError && this.result.details?.firstChangedLine) {
|
||||
pathDisplay += theme.fg("warning", `:${this.result.details.firstChangedLine}`);
|
||||
}
|
||||
|
||||
text = `${theme.fg("toolTitle", theme.bold("edit"))} ${pathDisplay}`;
|
||||
|
||||
if (this.result) {
|
||||
if (this.result.isError) {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,866 @@
|
|||
import {
|
||||
type Component,
|
||||
Container,
|
||||
Input,
|
||||
isArrowDown,
|
||||
isArrowLeft,
|
||||
isArrowRight,
|
||||
isArrowUp,
|
||||
isBackspace,
|
||||
isCtrlC,
|
||||
isCtrlO,
|
||||
isEnter,
|
||||
isEscape,
|
||||
isShiftCtrlO,
|
||||
Spacer,
|
||||
Text,
|
||||
TruncatedText,
|
||||
truncateToWidth,
|
||||
} from "@mariozechner/pi-tui";
|
||||
import type { SessionTreeNode } from "../../../core/session-manager.js";
|
||||
import { theme } from "../theme/theme.js";
|
||||
import { DynamicBorder } from "./dynamic-border.js";
|
||||
|
||||
/** Gutter info: position (displayIndent where connector was) and whether to show │ */
|
||||
interface GutterInfo {
|
||||
position: number; // displayIndent level where the connector was shown
|
||||
show: boolean; // true = show │, false = show spaces
|
||||
}
|
||||
|
||||
/** Flattened tree node for navigation */
|
||||
interface FlatNode {
|
||||
node: SessionTreeNode;
|
||||
/** Indentation level (each level = 3 chars) */
|
||||
indent: number;
|
||||
/** Whether to show connector (├─ or └─) - true if parent has multiple children */
|
||||
showConnector: boolean;
|
||||
/** If showConnector, true = last sibling (└─), false = not last (├─) */
|
||||
isLast: boolean;
|
||||
/** Gutter info for each ancestor branch point */
|
||||
gutters: GutterInfo[];
|
||||
/** True if this node is a root under a virtual branching root (multiple roots) */
|
||||
isVirtualRootChild: boolean;
|
||||
}
|
||||
|
||||
/** Filter mode for tree display */
|
||||
type FilterMode = "default" | "no-tools" | "user-only" | "labeled-only" | "all";
|
||||
|
||||
/**
|
||||
* Tree list component with selection and ASCII art visualization
|
||||
*/
|
||||
/** Tool call info for lookup */
|
||||
interface ToolCallInfo {
|
||||
name: string;
|
||||
arguments: Record<string, unknown>;
|
||||
}
|
||||
|
||||
class TreeList implements Component {
|
||||
private flatNodes: FlatNode[] = [];
|
||||
private filteredNodes: FlatNode[] = [];
|
||||
private selectedIndex = 0;
|
||||
private currentLeafId: string | null;
|
||||
private maxVisibleLines: number;
|
||||
private filterMode: FilterMode = "default";
|
||||
private searchQuery = "";
|
||||
private toolCallMap: Map<string, ToolCallInfo> = new Map();
|
||||
private multipleRoots = false;
|
||||
private activePathIds: Set<string> = new Set();
|
||||
|
||||
public onSelect?: (entryId: string) => void;
|
||||
public onCancel?: () => void;
|
||||
public onLabelEdit?: (entryId: string, currentLabel: string | undefined) => void;
|
||||
|
||||
constructor(tree: SessionTreeNode[], currentLeafId: string | null, maxVisibleLines: number) {
|
||||
this.currentLeafId = currentLeafId;
|
||||
this.maxVisibleLines = maxVisibleLines;
|
||||
this.multipleRoots = tree.length > 1;
|
||||
this.flatNodes = this.flattenTree(tree);
|
||||
this.buildActivePath();
|
||||
this.applyFilter();
|
||||
|
||||
// Start with current leaf selected
|
||||
const leafIndex = this.filteredNodes.findIndex((n) => n.node.entry.id === currentLeafId);
|
||||
if (leafIndex !== -1) {
|
||||
this.selectedIndex = leafIndex;
|
||||
} else {
|
||||
this.selectedIndex = Math.max(0, this.filteredNodes.length - 1);
|
||||
}
|
||||
}
|
||||
|
||||
/** Build the set of entry IDs on the path from root to current leaf */
|
||||
private buildActivePath(): void {
|
||||
this.activePathIds.clear();
|
||||
if (!this.currentLeafId) return;
|
||||
|
||||
// Build a map of id -> entry for parent lookup
|
||||
const entryMap = new Map<string, FlatNode>();
|
||||
for (const flatNode of this.flatNodes) {
|
||||
entryMap.set(flatNode.node.entry.id, flatNode);
|
||||
}
|
||||
|
||||
// Walk from leaf to root
|
||||
let currentId: string | null = this.currentLeafId;
|
||||
while (currentId) {
|
||||
this.activePathIds.add(currentId);
|
||||
const node = entryMap.get(currentId);
|
||||
if (!node) break;
|
||||
currentId = node.node.entry.parentId ?? null;
|
||||
}
|
||||
}
|
||||
|
||||
private flattenTree(roots: SessionTreeNode[]): FlatNode[] {
|
||||
const result: FlatNode[] = [];
|
||||
this.toolCallMap.clear();
|
||||
|
||||
// Indentation rules:
|
||||
// - At indent 0: stay at 0 unless parent has >1 children (then +1)
|
||||
// - At indent 1: children always go to indent 2 (visual grouping of subtree)
|
||||
// - At indent 2+: stay flat for single-child chains, +1 only if parent branches
|
||||
|
||||
// Stack items: [node, indent, justBranched, showConnector, isLast, gutters, isVirtualRootChild]
|
||||
type StackItem = [SessionTreeNode, number, boolean, boolean, boolean, GutterInfo[], boolean];
|
||||
const stack: StackItem[] = [];
|
||||
|
||||
// Determine which subtrees contain the active leaf (to sort current branch first)
|
||||
// Use iterative post-order traversal to avoid stack overflow
|
||||
const containsActive = new Map<SessionTreeNode, boolean>();
|
||||
const leafId = this.currentLeafId;
|
||||
{
|
||||
// Build list in pre-order, then process in reverse for post-order effect
|
||||
const allNodes: SessionTreeNode[] = [];
|
||||
const preOrderStack: SessionTreeNode[] = [...roots];
|
||||
while (preOrderStack.length > 0) {
|
||||
const node = preOrderStack.pop()!;
|
||||
allNodes.push(node);
|
||||
// Push children in reverse so they're processed left-to-right
|
||||
for (let i = node.children.length - 1; i >= 0; i--) {
|
||||
preOrderStack.push(node.children[i]);
|
||||
}
|
||||
}
|
||||
// Process in reverse (post-order): children before parents
|
||||
for (let i = allNodes.length - 1; i >= 0; i--) {
|
||||
const node = allNodes[i];
|
||||
let has = leafId !== null && node.entry.id === leafId;
|
||||
for (const child of node.children) {
|
||||
if (containsActive.get(child)) {
|
||||
has = true;
|
||||
}
|
||||
}
|
||||
containsActive.set(node, has);
|
||||
}
|
||||
}
|
||||
|
||||
// Add roots in reverse order, prioritizing the one containing the active leaf
|
||||
// If multiple roots, treat them as children of a virtual root that branches
|
||||
const multipleRoots = roots.length > 1;
|
||||
const orderedRoots = [...roots].sort((a, b) => Number(containsActive.get(b)) - Number(containsActive.get(a)));
|
||||
for (let i = orderedRoots.length - 1; i >= 0; i--) {
|
||||
const isLast = i === orderedRoots.length - 1;
|
||||
stack.push([orderedRoots[i], multipleRoots ? 1 : 0, multipleRoots, multipleRoots, isLast, [], multipleRoots]);
|
||||
}
|
||||
|
||||
while (stack.length > 0) {
|
||||
const [node, indent, justBranched, showConnector, isLast, gutters, isVirtualRootChild] = stack.pop()!;
|
||||
|
||||
// Extract tool calls from assistant messages for later lookup
|
||||
const entry = node.entry;
|
||||
if (entry.type === "message" && entry.message.role === "assistant") {
|
||||
const content = (entry.message as { content?: unknown }).content;
|
||||
if (Array.isArray(content)) {
|
||||
for (const block of content) {
|
||||
if (typeof block === "object" && block !== null && "type" in block && block.type === "toolCall") {
|
||||
const tc = block as { id: string; name: string; arguments: Record<string, unknown> };
|
||||
this.toolCallMap.set(tc.id, { name: tc.name, arguments: tc.arguments });
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result.push({ node, indent, showConnector, isLast, gutters, isVirtualRootChild });
|
||||
|
||||
const children = node.children;
|
||||
const multipleChildren = children.length > 1;
|
||||
|
||||
// Order children so the branch containing the active leaf comes first
|
||||
const orderedChildren = (() => {
|
||||
const prioritized: SessionTreeNode[] = [];
|
||||
const rest: SessionTreeNode[] = [];
|
||||
for (const child of children) {
|
||||
if (containsActive.get(child)) {
|
||||
prioritized.push(child);
|
||||
} else {
|
||||
rest.push(child);
|
||||
}
|
||||
}
|
||||
return [...prioritized, ...rest];
|
||||
})();
|
||||
|
||||
// Calculate child indent
|
||||
let childIndent: number;
|
||||
if (multipleChildren) {
|
||||
// Parent branches: children get +1
|
||||
childIndent = indent + 1;
|
||||
} else if (justBranched && indent > 0) {
|
||||
// First generation after a branch: +1 for visual grouping
|
||||
childIndent = indent + 1;
|
||||
} else {
|
||||
// Single-child chain: stay flat
|
||||
childIndent = indent;
|
||||
}
|
||||
|
||||
// Build gutters for children
|
||||
// If this node showed a connector, add a gutter entry for descendants
|
||||
// Only add gutter if connector is actually displayed (not suppressed for virtual root children)
|
||||
const connectorDisplayed = showConnector && !isVirtualRootChild;
|
||||
// When connector is displayed, add a gutter entry at the connector's position
|
||||
// Connector is at position (displayIndent - 1), so gutter should be there too
|
||||
const currentDisplayIndent = this.multipleRoots ? Math.max(0, indent - 1) : indent;
|
||||
const connectorPosition = Math.max(0, currentDisplayIndent - 1);
|
||||
const childGutters: GutterInfo[] = connectorDisplayed
|
||||
? [...gutters, { position: connectorPosition, show: !isLast }]
|
||||
: gutters;
|
||||
|
||||
// Add children in reverse order
|
||||
for (let i = orderedChildren.length - 1; i >= 0; i--) {
|
||||
const childIsLast = i === orderedChildren.length - 1;
|
||||
stack.push([
|
||||
orderedChildren[i],
|
||||
childIndent,
|
||||
multipleChildren,
|
||||
multipleChildren,
|
||||
childIsLast,
|
||||
childGutters,
|
||||
false,
|
||||
]);
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
private applyFilter(): void {
|
||||
// Remember currently selected node to preserve cursor position
|
||||
const previouslySelectedId = this.filteredNodes[this.selectedIndex]?.node.entry.id;
|
||||
|
||||
const searchTokens = this.searchQuery.toLowerCase().split(/\s+/).filter(Boolean);
|
||||
|
||||
this.filteredNodes = this.flatNodes.filter((flatNode) => {
|
||||
const entry = flatNode.node.entry;
|
||||
const isCurrentLeaf = entry.id === this.currentLeafId;
|
||||
|
||||
// Skip assistant messages with only tool calls (no text) unless error/aborted
|
||||
// Always show current leaf so active position is visible
|
||||
if (entry.type === "message" && entry.message.role === "assistant" && !isCurrentLeaf) {
|
||||
const msg = entry.message as { stopReason?: string; content?: unknown };
|
||||
const hasText = this.hasTextContent(msg.content);
|
||||
const isErrorOrAborted = msg.stopReason && msg.stopReason !== "stop" && msg.stopReason !== "toolUse";
|
||||
// Only hide if no text AND not an error/aborted message
|
||||
if (!hasText && !isErrorOrAborted) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Apply filter mode
|
||||
let passesFilter = true;
|
||||
// Entry types hidden in default view (settings/bookkeeping)
|
||||
const isSettingsEntry =
|
||||
entry.type === "label" ||
|
||||
entry.type === "custom" ||
|
||||
entry.type === "model_change" ||
|
||||
entry.type === "thinking_level_change";
|
||||
|
||||
switch (this.filterMode) {
|
||||
case "user-only":
|
||||
// Just user messages
|
||||
passesFilter = entry.type === "message" && entry.message.role === "user";
|
||||
break;
|
||||
case "no-tools":
|
||||
// Default minus tool results
|
||||
passesFilter = !isSettingsEntry && !(entry.type === "message" && entry.message.role === "toolResult");
|
||||
break;
|
||||
case "labeled-only":
|
||||
// Just labeled entries
|
||||
passesFilter = flatNode.node.label !== undefined;
|
||||
break;
|
||||
case "all":
|
||||
// Show everything
|
||||
passesFilter = true;
|
||||
break;
|
||||
default:
|
||||
// Default mode: hide settings/bookkeeping entries
|
||||
passesFilter = !isSettingsEntry;
|
||||
break;
|
||||
}
|
||||
|
||||
if (!passesFilter) return false;
|
||||
|
||||
// Apply search filter
|
||||
if (searchTokens.length > 0) {
|
||||
const nodeText = this.getSearchableText(flatNode.node).toLowerCase();
|
||||
return searchTokens.every((token) => nodeText.includes(token));
|
||||
}
|
||||
|
||||
return true;
|
||||
});
|
||||
|
||||
// Try to preserve cursor on the same node after filtering
|
||||
if (previouslySelectedId) {
|
||||
const newIndex = this.filteredNodes.findIndex((n) => n.node.entry.id === previouslySelectedId);
|
||||
if (newIndex !== -1) {
|
||||
this.selectedIndex = newIndex;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back: clamp index if out of bounds
|
||||
if (this.selectedIndex >= this.filteredNodes.length) {
|
||||
this.selectedIndex = Math.max(0, this.filteredNodes.length - 1);
|
||||
}
|
||||
}
|
||||
|
||||
/** Get searchable text content from a node */
|
||||
private getSearchableText(node: SessionTreeNode): string {
|
||||
const entry = node.entry;
|
||||
const parts: string[] = [];
|
||||
|
||||
if (node.label) {
|
||||
parts.push(node.label);
|
||||
}
|
||||
|
||||
switch (entry.type) {
|
||||
case "message": {
|
||||
const msg = entry.message;
|
||||
parts.push(msg.role);
|
||||
if ("content" in msg && msg.content) {
|
||||
parts.push(this.extractContent(msg.content));
|
||||
}
|
||||
if (msg.role === "bashExecution") {
|
||||
const bashMsg = msg as { command?: string };
|
||||
if (bashMsg.command) parts.push(bashMsg.command);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case "custom_message": {
|
||||
parts.push(entry.customType);
|
||||
if (typeof entry.content === "string") {
|
||||
parts.push(entry.content);
|
||||
} else {
|
||||
parts.push(this.extractContent(entry.content));
|
||||
}
|
||||
break;
|
||||
}
|
||||
case "compaction":
|
||||
parts.push("compaction");
|
||||
break;
|
||||
case "branch_summary":
|
||||
parts.push("branch summary", entry.summary);
|
||||
break;
|
||||
case "model_change":
|
||||
parts.push("model", entry.modelId);
|
||||
break;
|
||||
case "thinking_level_change":
|
||||
parts.push("thinking", entry.thinkingLevel);
|
||||
break;
|
||||
case "custom":
|
||||
parts.push("custom", entry.customType);
|
||||
break;
|
||||
case "label":
|
||||
parts.push("label", entry.label ?? "");
|
||||
break;
|
||||
}
|
||||
|
||||
return parts.join(" ");
|
||||
}
|
||||
|
||||
invalidate(): void {}
|
||||
|
||||
getSearchQuery(): string {
|
||||
return this.searchQuery;
|
||||
}
|
||||
|
||||
getSelectedNode(): SessionTreeNode | undefined {
|
||||
return this.filteredNodes[this.selectedIndex]?.node;
|
||||
}
|
||||
|
||||
updateNodeLabel(entryId: string, label: string | undefined): void {
|
||||
for (const flatNode of this.flatNodes) {
|
||||
if (flatNode.node.entry.id === entryId) {
|
||||
flatNode.node.label = label;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private getFilterLabel(): string {
|
||||
switch (this.filterMode) {
|
||||
case "no-tools":
|
||||
return " [no-tools]";
|
||||
case "user-only":
|
||||
return " [user]";
|
||||
case "labeled-only":
|
||||
return " [labeled]";
|
||||
case "all":
|
||||
return " [all]";
|
||||
default:
|
||||
return "";
|
||||
}
|
||||
}
|
||||
|
||||
render(width: number): string[] {
|
||||
const lines: string[] = [];
|
||||
|
||||
if (this.filteredNodes.length === 0) {
|
||||
lines.push(truncateToWidth(theme.fg("muted", " No entries found"), width));
|
||||
lines.push(truncateToWidth(theme.fg("muted", ` (0/0)${this.getFilterLabel()}`), width));
|
||||
return lines;
|
||||
}
|
||||
|
||||
const startIndex = Math.max(
|
||||
0,
|
||||
Math.min(
|
||||
this.selectedIndex - Math.floor(this.maxVisibleLines / 2),
|
||||
this.filteredNodes.length - this.maxVisibleLines,
|
||||
),
|
||||
);
|
||||
const endIndex = Math.min(startIndex + this.maxVisibleLines, this.filteredNodes.length);
|
||||
|
||||
for (let i = startIndex; i < endIndex; i++) {
|
||||
const flatNode = this.filteredNodes[i];
|
||||
const entry = flatNode.node.entry;
|
||||
const isSelected = i === this.selectedIndex;
|
||||
|
||||
// Build line: cursor + prefix + path marker + label + content
|
||||
const cursor = isSelected ? theme.fg("accent", "› ") : " ";
|
||||
|
||||
// If multiple roots, shift display (roots at 0, not 1)
|
||||
const displayIndent = this.multipleRoots ? Math.max(0, flatNode.indent - 1) : flatNode.indent;
|
||||
|
||||
// Build prefix with gutters at their correct positions
|
||||
// Each gutter has a position (displayIndent where its connector was shown)
|
||||
const connector =
|
||||
flatNode.showConnector && !flatNode.isVirtualRootChild ? (flatNode.isLast ? "└─ " : "├─ ") : "";
|
||||
const connectorPosition = connector ? displayIndent - 1 : -1;
|
||||
|
||||
// Build prefix char by char, placing gutters and connector at their positions
|
||||
const totalChars = displayIndent * 3;
|
||||
const prefixChars: string[] = [];
|
||||
for (let i = 0; i < totalChars; i++) {
|
||||
const level = Math.floor(i / 3);
|
||||
const posInLevel = i % 3;
|
||||
|
||||
// Check if there's a gutter at this level
|
||||
const gutter = flatNode.gutters.find((g) => g.position === level);
|
||||
if (gutter) {
|
||||
if (posInLevel === 0) {
|
||||
prefixChars.push(gutter.show ? "│" : " ");
|
||||
} else {
|
||||
prefixChars.push(" ");
|
||||
}
|
||||
} else if (connector && level === connectorPosition) {
|
||||
// Connector at this level
|
||||
if (posInLevel === 0) {
|
||||
prefixChars.push(flatNode.isLast ? "└" : "├");
|
||||
} else if (posInLevel === 1) {
|
||||
prefixChars.push("─");
|
||||
} else {
|
||||
prefixChars.push(" ");
|
||||
}
|
||||
} else {
|
||||
prefixChars.push(" ");
|
||||
}
|
||||
}
|
||||
const prefix = prefixChars.join("");
|
||||
|
||||
// Active path marker - shown right before the entry text
|
||||
const isOnActivePath = this.activePathIds.has(entry.id);
|
||||
const pathMarker = isOnActivePath ? theme.fg("accent", "• ") : "";
|
||||
|
||||
const label = flatNode.node.label ? theme.fg("warning", `[${flatNode.node.label}] `) : "";
|
||||
const content = this.getEntryDisplayText(flatNode.node, isSelected);
|
||||
|
||||
let line = cursor + theme.fg("dim", prefix) + pathMarker + label + content;
|
||||
if (isSelected) {
|
||||
line = theme.bg("selectedBg", line);
|
||||
}
|
||||
lines.push(truncateToWidth(line, width));
|
||||
}
|
||||
|
||||
lines.push(
|
||||
truncateToWidth(
|
||||
theme.fg("muted", ` (${this.selectedIndex + 1}/${this.filteredNodes.length})${this.getFilterLabel()}`),
|
||||
width,
|
||||
),
|
||||
);
|
||||
|
||||
return lines;
|
||||
}
|
||||
|
||||
private getEntryDisplayText(node: SessionTreeNode, isSelected: boolean): string {
|
||||
const entry = node.entry;
|
||||
let result: string;
|
||||
|
||||
const normalize = (s: string) => s.replace(/[\n\t]/g, " ").trim();
|
||||
|
||||
switch (entry.type) {
|
||||
case "message": {
|
||||
const msg = entry.message;
|
||||
const role = msg.role;
|
||||
if (role === "user") {
|
||||
const msgWithContent = msg as { content?: unknown };
|
||||
const content = normalize(this.extractContent(msgWithContent.content));
|
||||
result = theme.fg("accent", "user: ") + content;
|
||||
} else if (role === "assistant") {
|
||||
const msgWithContent = msg as { content?: unknown; stopReason?: string; errorMessage?: string };
|
||||
const textContent = normalize(this.extractContent(msgWithContent.content));
|
||||
if (textContent) {
|
||||
result = theme.fg("success", "assistant: ") + textContent;
|
||||
} else if (msgWithContent.stopReason === "aborted") {
|
||||
result = theme.fg("success", "assistant: ") + theme.fg("muted", "(aborted)");
|
||||
} else if (msgWithContent.errorMessage) {
|
||||
const errMsg = normalize(msgWithContent.errorMessage).slice(0, 80);
|
||||
result = theme.fg("success", "assistant: ") + theme.fg("error", errMsg);
|
||||
} else {
|
||||
result = theme.fg("success", "assistant: ") + theme.fg("muted", "(no content)");
|
||||
}
|
||||
} else if (role === "toolResult") {
|
||||
const toolMsg = msg as { toolCallId?: string; toolName?: string };
|
||||
const toolCall = toolMsg.toolCallId ? this.toolCallMap.get(toolMsg.toolCallId) : undefined;
|
||||
if (toolCall) {
|
||||
result = theme.fg("muted", this.formatToolCall(toolCall.name, toolCall.arguments));
|
||||
} else {
|
||||
result = theme.fg("muted", `[${toolMsg.toolName ?? "tool"}]`);
|
||||
}
|
||||
} else if (role === "bashExecution") {
|
||||
const bashMsg = msg as { command?: string };
|
||||
result = theme.fg("dim", `[bash]: ${normalize(bashMsg.command ?? "")}`);
|
||||
} else {
|
||||
result = theme.fg("dim", `[${role}]`);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case "custom_message": {
|
||||
const content =
|
||||
typeof entry.content === "string"
|
||||
? entry.content
|
||||
: entry.content
|
||||
.filter((c): c is { type: "text"; text: string } => c.type === "text")
|
||||
.map((c) => c.text)
|
||||
.join("");
|
||||
result = theme.fg("customMessageLabel", `[${entry.customType}]: `) + normalize(content);
|
||||
break;
|
||||
}
|
||||
case "compaction": {
|
||||
const tokens = Math.round(entry.tokensBefore / 1000);
|
||||
result = theme.fg("borderAccent", `[compaction: ${tokens}k tokens]`);
|
||||
break;
|
||||
}
|
||||
case "branch_summary":
|
||||
result = theme.fg("warning", `[branch summary]: `) + normalize(entry.summary);
|
||||
break;
|
||||
case "model_change":
|
||||
result = theme.fg("dim", `[model: ${entry.modelId}]`);
|
||||
break;
|
||||
case "thinking_level_change":
|
||||
result = theme.fg("dim", `[thinking: ${entry.thinkingLevel}]`);
|
||||
break;
|
||||
case "custom":
|
||||
result = theme.fg("dim", `[custom: ${entry.customType}]`);
|
||||
break;
|
||||
case "label":
|
||||
result = theme.fg("dim", `[label: ${entry.label ?? "(cleared)"}]`);
|
||||
break;
|
||||
default:
|
||||
result = "";
|
||||
}
|
||||
|
||||
return isSelected ? theme.bold(result) : result;
|
||||
}
|
||||
|
||||
private extractContent(content: unknown): string {
|
||||
const maxLen = 200;
|
||||
if (typeof content === "string") return content.slice(0, maxLen);
|
||||
if (Array.isArray(content)) {
|
||||
let result = "";
|
||||
for (const c of content) {
|
||||
if (typeof c === "object" && c !== null && "type" in c && c.type === "text") {
|
||||
result += (c as { text: string }).text;
|
||||
if (result.length >= maxLen) return result.slice(0, maxLen);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
private hasTextContent(content: unknown): boolean {
|
||||
if (typeof content === "string") return content.trim().length > 0;
|
||||
if (Array.isArray(content)) {
|
||||
for (const c of content) {
|
||||
if (typeof c === "object" && c !== null && "type" in c && c.type === "text") {
|
||||
const text = (c as { text?: string }).text;
|
||||
if (text && text.trim().length > 0) return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private formatToolCall(name: string, args: Record<string, unknown>): string {
|
||||
const shortenPath = (p: string): string => {
|
||||
const home = process.env.HOME || process.env.USERPROFILE || "";
|
||||
if (home && p.startsWith(home)) return `~${p.slice(home.length)}`;
|
||||
return p;
|
||||
};
|
||||
|
||||
switch (name) {
|
||||
case "read": {
|
||||
const path = shortenPath(String(args.path || args.file_path || ""));
|
||||
const offset = args.offset as number | undefined;
|
||||
const limit = args.limit as number | undefined;
|
||||
let display = path;
|
||||
if (offset !== undefined || limit !== undefined) {
|
||||
const start = offset ?? 1;
|
||||
const end = limit !== undefined ? start + limit - 1 : "";
|
||||
display += `:${start}${end ? `-${end}` : ""}`;
|
||||
}
|
||||
return `[read: ${display}]`;
|
||||
}
|
||||
case "write": {
|
||||
const path = shortenPath(String(args.path || args.file_path || ""));
|
||||
return `[write: ${path}]`;
|
||||
}
|
||||
case "edit": {
|
||||
const path = shortenPath(String(args.path || args.file_path || ""));
|
||||
return `[edit: ${path}]`;
|
||||
}
|
||||
case "bash": {
|
||||
const rawCmd = String(args.command || "");
|
||||
const cmd = rawCmd
|
||||
.replace(/[\n\t]/g, " ")
|
||||
.trim()
|
||||
.slice(0, 50);
|
||||
return `[bash: ${cmd}${rawCmd.length > 50 ? "..." : ""}]`;
|
||||
}
|
||||
case "grep": {
|
||||
const pattern = String(args.pattern || "");
|
||||
const path = shortenPath(String(args.path || "."));
|
||||
return `[grep: /${pattern}/ in ${path}]`;
|
||||
}
|
||||
case "find": {
|
||||
const pattern = String(args.pattern || "");
|
||||
const path = shortenPath(String(args.path || "."));
|
||||
return `[find: ${pattern} in ${path}]`;
|
||||
}
|
||||
case "ls": {
|
||||
const path = shortenPath(String(args.path || "."));
|
||||
return `[ls: ${path}]`;
|
||||
}
|
||||
default: {
|
||||
// Custom tool - show name and truncated JSON args
|
||||
const argsStr = JSON.stringify(args).slice(0, 40);
|
||||
return `[${name}: ${argsStr}${JSON.stringify(args).length > 40 ? "..." : ""}]`;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
handleInput(keyData: string): void {
|
||||
if (isArrowUp(keyData)) {
|
||||
this.selectedIndex = this.selectedIndex === 0 ? this.filteredNodes.length - 1 : this.selectedIndex - 1;
|
||||
} else if (isArrowDown(keyData)) {
|
||||
this.selectedIndex = this.selectedIndex === this.filteredNodes.length - 1 ? 0 : this.selectedIndex + 1;
|
||||
} else if (isArrowLeft(keyData)) {
|
||||
// Page up
|
||||
this.selectedIndex = Math.max(0, this.selectedIndex - this.maxVisibleLines);
|
||||
} else if (isArrowRight(keyData)) {
|
||||
// Page down
|
||||
this.selectedIndex = Math.min(this.filteredNodes.length - 1, this.selectedIndex + this.maxVisibleLines);
|
||||
} else if (isEnter(keyData)) {
|
||||
const selected = this.filteredNodes[this.selectedIndex];
|
||||
if (selected && this.onSelect) {
|
||||
this.onSelect(selected.node.entry.id);
|
||||
}
|
||||
} else if (isEscape(keyData)) {
|
||||
if (this.searchQuery) {
|
||||
this.searchQuery = "";
|
||||
this.applyFilter();
|
||||
} else {
|
||||
this.onCancel?.();
|
||||
}
|
||||
} else if (isCtrlC(keyData)) {
|
||||
this.onCancel?.();
|
||||
} else if (isShiftCtrlO(keyData)) {
|
||||
// Cycle filter backwards
|
||||
const modes: FilterMode[] = ["default", "no-tools", "user-only", "labeled-only", "all"];
|
||||
const currentIndex = modes.indexOf(this.filterMode);
|
||||
this.filterMode = modes[(currentIndex - 1 + modes.length) % modes.length];
|
||||
this.applyFilter();
|
||||
} else if (isCtrlO(keyData)) {
|
||||
// Cycle filter forwards: default → no-tools → user-only → labeled-only → all → default
|
||||
const modes: FilterMode[] = ["default", "no-tools", "user-only", "labeled-only", "all"];
|
||||
const currentIndex = modes.indexOf(this.filterMode);
|
||||
this.filterMode = modes[(currentIndex + 1) % modes.length];
|
||||
this.applyFilter();
|
||||
} else if (isBackspace(keyData)) {
|
||||
if (this.searchQuery.length > 0) {
|
||||
this.searchQuery = this.searchQuery.slice(0, -1);
|
||||
this.applyFilter();
|
||||
}
|
||||
} else if (keyData === "l" && !this.searchQuery) {
|
||||
const selected = this.filteredNodes[this.selectedIndex];
|
||||
if (selected && this.onLabelEdit) {
|
||||
this.onLabelEdit(selected.node.entry.id, selected.node.label);
|
||||
}
|
||||
} else {
|
||||
const hasControlChars = [...keyData].some((ch) => {
|
||||
const code = ch.charCodeAt(0);
|
||||
return code < 32 || code === 0x7f || (code >= 0x80 && code <= 0x9f);
|
||||
});
|
||||
if (!hasControlChars && keyData.length > 0) {
|
||||
this.searchQuery += keyData;
|
||||
this.applyFilter();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** Component that displays the current search query */
|
||||
class SearchLine implements Component {
|
||||
constructor(private treeList: TreeList) {}
|
||||
|
||||
invalidate(): void {}
|
||||
|
||||
render(width: number): string[] {
|
||||
const query = this.treeList.getSearchQuery();
|
||||
if (query) {
|
||||
return [truncateToWidth(` ${theme.fg("muted", "Search:")} ${theme.fg("accent", query)}`, width)];
|
||||
}
|
||||
return [truncateToWidth(` ${theme.fg("muted", "Search:")}`, width)];
|
||||
}
|
||||
|
||||
handleInput(_keyData: string): void {}
|
||||
}
|
||||
|
||||
/** Label input component shown when editing a label */
|
||||
class LabelInput implements Component {
|
||||
private input: Input;
|
||||
private entryId: string;
|
||||
public onSubmit?: (entryId: string, label: string | undefined) => void;
|
||||
public onCancel?: () => void;
|
||||
|
||||
constructor(entryId: string, currentLabel: string | undefined) {
|
||||
this.entryId = entryId;
|
||||
this.input = new Input();
|
||||
if (currentLabel) {
|
||||
this.input.setValue(currentLabel);
|
||||
}
|
||||
}
|
||||
|
||||
invalidate(): void {}
|
||||
|
||||
render(width: number): string[] {
|
||||
const lines: string[] = [];
|
||||
const indent = " ";
|
||||
const availableWidth = width - indent.length;
|
||||
lines.push(truncateToWidth(`${indent}${theme.fg("muted", "Label (empty to remove):")}`, width));
|
||||
lines.push(...this.input.render(availableWidth).map((line) => truncateToWidth(`${indent}${line}`, width)));
|
||||
lines.push(truncateToWidth(`${indent}${theme.fg("dim", "enter: save esc: cancel")}`, width));
|
||||
return lines;
|
||||
}
|
||||
|
||||
handleInput(keyData: string): void {
|
||||
if (isEnter(keyData)) {
|
||||
const value = this.input.getValue().trim();
|
||||
this.onSubmit?.(this.entryId, value || undefined);
|
||||
} else if (isEscape(keyData)) {
|
||||
this.onCancel?.();
|
||||
} else {
|
||||
this.input.handleInput(keyData);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Component that renders a session tree selector for navigation
|
||||
*/
|
||||
export class TreeSelectorComponent extends Container {
|
||||
private treeList: TreeList;
|
||||
private labelInput: LabelInput | null = null;
|
||||
private labelInputContainer: Container;
|
||||
private treeContainer: Container;
|
||||
private onLabelChangeCallback?: (entryId: string, label: string | undefined) => void;
|
||||
|
||||
constructor(
|
||||
tree: SessionTreeNode[],
|
||||
currentLeafId: string | null,
|
||||
terminalHeight: number,
|
||||
onSelect: (entryId: string) => void,
|
||||
onCancel: () => void,
|
||||
onLabelChange?: (entryId: string, label: string | undefined) => void,
|
||||
) {
|
||||
super();
|
||||
|
||||
this.onLabelChangeCallback = onLabelChange;
|
||||
const maxVisibleLines = Math.max(5, Math.floor(terminalHeight / 2));
|
||||
|
||||
this.treeList = new TreeList(tree, currentLeafId, maxVisibleLines);
|
||||
this.treeList.onSelect = onSelect;
|
||||
this.treeList.onCancel = onCancel;
|
||||
this.treeList.onLabelEdit = (entryId, currentLabel) => this.showLabelInput(entryId, currentLabel);
|
||||
|
||||
this.treeContainer = new Container();
|
||||
this.treeContainer.addChild(this.treeList);
|
||||
|
||||
this.labelInputContainer = new Container();
|
||||
|
||||
this.addChild(new Spacer(1));
|
||||
this.addChild(new DynamicBorder());
|
||||
this.addChild(new Text(theme.bold(" Session Tree"), 1, 0));
|
||||
this.addChild(
|
||||
new TruncatedText(theme.fg("muted", " ↑/↓: move. ←/→: page. l: label. ^O/⇧^O: filter. Type to search"), 0, 0),
|
||||
);
|
||||
this.addChild(new SearchLine(this.treeList));
|
||||
this.addChild(new DynamicBorder());
|
||||
this.addChild(new Spacer(1));
|
||||
this.addChild(this.treeContainer);
|
||||
this.addChild(this.labelInputContainer);
|
||||
this.addChild(new Spacer(1));
|
||||
this.addChild(new DynamicBorder());
|
||||
|
||||
if (tree.length === 0) {
|
||||
setTimeout(() => onCancel(), 100);
|
||||
}
|
||||
}
|
||||
|
||||
private showLabelInput(entryId: string, currentLabel: string | undefined): void {
|
||||
this.labelInput = new LabelInput(entryId, currentLabel);
|
||||
this.labelInput.onSubmit = (id, label) => {
|
||||
this.treeList.updateNodeLabel(id, label);
|
||||
this.onLabelChangeCallback?.(id, label);
|
||||
this.hideLabelInput();
|
||||
};
|
||||
this.labelInput.onCancel = () => this.hideLabelInput();
|
||||
|
||||
this.treeContainer.clear();
|
||||
this.labelInputContainer.clear();
|
||||
this.labelInputContainer.addChild(this.labelInput);
|
||||
}
|
||||
|
||||
private hideLabelInput(): void {
|
||||
this.labelInput = null;
|
||||
this.labelInputContainer.clear();
|
||||
this.treeContainer.clear();
|
||||
this.treeContainer.addChild(this.treeList);
|
||||
}
|
||||
|
||||
handleInput(keyData: string): void {
|
||||
if (this.labelInput) {
|
||||
this.labelInput.handleInput(keyData);
|
||||
} else {
|
||||
this.treeList.handleInput(keyData);
|
||||
}
|
||||
}
|
||||
|
||||
getTreeList(): TreeList {
|
||||
return this.treeList;
|
||||
}
|
||||
}
|
||||
|
|
@ -5,13 +5,9 @@ import { getMarkdownTheme, theme } from "../theme/theme.js";
|
|||
* Component that renders a user message
|
||||
*/
|
||||
export class UserMessageComponent extends Container {
|
||||
constructor(text: string, isFirst: boolean) {
|
||||
constructor(text: string) {
|
||||
super();
|
||||
|
||||
// Add spacer before user message (except first one)
|
||||
if (!isFirst) {
|
||||
this.addChild(new Spacer(1));
|
||||
}
|
||||
this.addChild(new Spacer(1));
|
||||
this.addChild(
|
||||
new Markdown(text, 1, 1, getMarkdownTheme(), {
|
||||
bgColor: (text: string) => theme.bg("userMessageBg", text),
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue