mirror of
https://github.com/harivansh-afk/clanker-agent.git
synced 2026-04-15 08:03:42 +00:00
move pi-mono into companion-cloud as apps/companion-os
- Copy all pi-mono source into apps/companion-os/ - Update Dockerfile to COPY pre-built binary instead of downloading from GitHub Releases - Update deploy-staging.yml to build pi from source (bun compile) before Docker build - Add apps/companion-os/** to path triggers - No more cross-repo dispatch needed Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
commit
0250f72976
579 changed files with 206942 additions and 0 deletions
262
packages/agent/CHANGELOG.md
Normal file
262
packages/agent/CHANGELOG.md
Normal file
|
|
@ -0,0 +1,262 @@
|
|||
# Changelog
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
## [0.56.2] - 2026-03-05
|
||||
|
||||
## [0.56.1] - 2026-03-05
|
||||
|
||||
## [0.56.0] - 2026-03-04
|
||||
|
||||
## [0.55.4] - 2026-03-02
|
||||
|
||||
## [0.55.3] - 2026-02-27
|
||||
|
||||
## [0.55.2] - 2026-02-27
|
||||
|
||||
## [0.55.1] - 2026-02-26
|
||||
|
||||
## [0.55.0] - 2026-02-24
|
||||
|
||||
## [0.54.2] - 2026-02-23
|
||||
|
||||
## [0.54.1] - 2026-02-22
|
||||
|
||||
## [0.54.0] - 2026-02-19
|
||||
|
||||
## [0.53.1] - 2026-02-19
|
||||
|
||||
## [0.53.0] - 2026-02-17
|
||||
|
||||
## [0.52.12] - 2026-02-13
|
||||
|
||||
### Added
|
||||
|
||||
- Added `transport` to `AgentOptions` and `AgentLoopConfig` forwarding, allowing stream transport preference (`"sse"`, `"websocket"`, `"auto"`) to flow into provider calls.
|
||||
|
||||
## [0.52.11] - 2026-02-13
|
||||
|
||||
## [0.52.10] - 2026-02-12
|
||||
|
||||
## [0.52.9] - 2026-02-08
|
||||
|
||||
## [0.52.8] - 2026-02-07
|
||||
|
||||
## [0.52.7] - 2026-02-06
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed `continue()` to resume queued steering/follow-up messages when context currently ends in an assistant message, and preserved one-at-a-time steering ordering during assistant-tail resumes ([#1312](https://github.com/badlogic/pi-mono/pull/1312) by [@ferologics](https://github.com/ferologics))
|
||||
|
||||
## [0.52.6] - 2026-02-05
|
||||
|
||||
## [0.52.5] - 2026-02-05
|
||||
|
||||
## [0.52.4] - 2026-02-05
|
||||
|
||||
## [0.52.3] - 2026-02-05
|
||||
|
||||
## [0.52.2] - 2026-02-05
|
||||
|
||||
## [0.52.1] - 2026-02-05
|
||||
|
||||
## [0.52.0] - 2026-02-05
|
||||
|
||||
## [0.51.6] - 2026-02-04
|
||||
|
||||
## [0.51.5] - 2026-02-04
|
||||
|
||||
## [0.51.4] - 2026-02-03
|
||||
|
||||
## [0.51.3] - 2026-02-03
|
||||
|
||||
## [0.51.2] - 2026-02-03
|
||||
|
||||
## [0.51.1] - 2026-02-02
|
||||
|
||||
## [0.51.0] - 2026-02-01
|
||||
|
||||
## [0.50.9] - 2026-02-01
|
||||
|
||||
## [0.50.8] - 2026-02-01
|
||||
|
||||
### Added
|
||||
|
||||
- Added `maxRetryDelayMs` option to `AgentOptions` to cap server-requested retry delays. Passed through to the underlying stream function. ([#1123](https://github.com/badlogic/pi-mono/issues/1123))
|
||||
|
||||
## [0.50.7] - 2026-01-31
|
||||
|
||||
## [0.50.6] - 2026-01-30
|
||||
|
||||
## [0.50.5] - 2026-01-30
|
||||
|
||||
## [0.50.3] - 2026-01-29
|
||||
|
||||
## [0.50.2] - 2026-01-29
|
||||
|
||||
## [0.50.1] - 2026-01-26
|
||||
|
||||
## [0.50.0] - 2026-01-26
|
||||
|
||||
## [0.49.3] - 2026-01-22
|
||||
|
||||
## [0.49.2] - 2026-01-19
|
||||
|
||||
## [0.49.1] - 2026-01-18
|
||||
|
||||
## [0.49.0] - 2026-01-17
|
||||
|
||||
## [0.48.0] - 2026-01-16
|
||||
|
||||
## [0.47.0] - 2026-01-16
|
||||
|
||||
## [0.46.0] - 2026-01-15
|
||||
|
||||
## [0.45.7] - 2026-01-13
|
||||
|
||||
## [0.45.6] - 2026-01-13
|
||||
|
||||
## [0.45.5] - 2026-01-13
|
||||
|
||||
## [0.45.4] - 2026-01-13
|
||||
|
||||
## [0.45.3] - 2026-01-13
|
||||
|
||||
## [0.45.2] - 2026-01-13
|
||||
|
||||
## [0.45.1] - 2026-01-13
|
||||
|
||||
## [0.45.0] - 2026-01-13
|
||||
|
||||
## [0.44.0] - 2026-01-12
|
||||
|
||||
## [0.43.0] - 2026-01-11
|
||||
|
||||
## [0.42.5] - 2026-01-11
|
||||
|
||||
## [0.42.4] - 2026-01-10
|
||||
|
||||
## [0.42.3] - 2026-01-10
|
||||
|
||||
## [0.42.2] - 2026-01-10
|
||||
|
||||
## [0.42.1] - 2026-01-09
|
||||
|
||||
## [0.42.0] - 2026-01-09
|
||||
|
||||
## [0.41.0] - 2026-01-09
|
||||
|
||||
## [0.40.1] - 2026-01-09
|
||||
|
||||
## [0.40.0] - 2026-01-08
|
||||
|
||||
## [0.39.1] - 2026-01-08
|
||||
|
||||
## [0.39.0] - 2026-01-08
|
||||
|
||||
## [0.38.0] - 2026-01-08
|
||||
|
||||
### Added
|
||||
|
||||
- `thinkingBudgets` option on `Agent` and `AgentOptions` to customize token budgets per thinking level ([#529](https://github.com/badlogic/pi-mono/pull/529) by [@melihmucuk](https://github.com/melihmucuk))
|
||||
|
||||
## [0.37.8] - 2026-01-07
|
||||
|
||||
## [0.37.7] - 2026-01-07
|
||||
|
||||
## [0.37.6] - 2026-01-06
|
||||
|
||||
## [0.37.5] - 2026-01-06
|
||||
|
||||
## [0.37.4] - 2026-01-06
|
||||
|
||||
## [0.37.3] - 2026-01-06
|
||||
|
||||
### Added
|
||||
|
||||
- `sessionId` option on `Agent` to forward session identifiers to LLM providers for session-based caching.
|
||||
|
||||
## [0.37.2] - 2026-01-05
|
||||
|
||||
## [0.37.1] - 2026-01-05
|
||||
|
||||
## [0.37.0] - 2026-01-05
|
||||
|
||||
### Fixed
|
||||
|
||||
- `minimal` thinking level now maps to `minimal` reasoning effort instead of being treated as `low`.
|
||||
|
||||
## [0.36.0] - 2026-01-05
|
||||
|
||||
## [0.35.0] - 2026-01-05
|
||||
|
||||
## [0.34.2] - 2026-01-04
|
||||
|
||||
## [0.34.1] - 2026-01-04
|
||||
|
||||
## [0.34.0] - 2026-01-04
|
||||
|
||||
## [0.33.0] - 2026-01-04
|
||||
|
||||
## [0.32.3] - 2026-01-03
|
||||
|
||||
## [0.32.2] - 2026-01-03
|
||||
|
||||
## [0.32.1] - 2026-01-03
|
||||
|
||||
## [0.32.0] - 2026-01-03
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
- **Queue API replaced with steer/followUp**: The `queueMessage()` method has been split into two methods with different delivery semantics ([#403](https://github.com/badlogic/pi-mono/issues/403)):
|
||||
- `steer(msg)`: Interrupts the agent mid-run. Delivered after current tool execution, skips remaining tools.
|
||||
- `followUp(msg)`: Waits until the agent finishes. Delivered only when there are no more tool calls or steering messages.
|
||||
- **Queue mode renamed**: `queueMode` option renamed to `steeringMode`. Added new `followUpMode` option. Both control whether messages are delivered one-at-a-time or all at once.
|
||||
- **AgentLoopConfig callbacks renamed**: `getQueuedMessages` split into `getSteeringMessages` and `getFollowUpMessages`.
|
||||
- **Agent methods renamed**:
|
||||
- `queueMessage()` → `steer()` and `followUp()`
|
||||
- `clearMessageQueue()` → `clearSteeringQueue()`, `clearFollowUpQueue()`, `clearAllQueues()`
|
||||
- `setQueueMode()`/`getQueueMode()` → `setSteeringMode()`/`getSteeringMode()` and `setFollowUpMode()`/`getFollowUpMode()`
|
||||
|
||||
### Fixed
|
||||
|
||||
- `prompt()` and `continue()` now throw if called while the agent is already streaming, preventing race conditions and corrupted state. Use `steer()` or `followUp()` to queue messages during streaming, or `await` the previous call.
|
||||
|
||||
## [0.31.1] - 2026-01-02
|
||||
|
||||
## [0.31.0] - 2026-01-02
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
- **Transport abstraction removed**: `ProviderTransport`, `AppTransport`, and `AgentTransport` interface have been removed. Use the `streamFn` option directly for custom streaming implementations.
|
||||
|
||||
- **Agent options renamed**:
|
||||
- `transport` → removed (use `streamFn` instead)
|
||||
- `messageTransformer` → `convertToLlm`
|
||||
- `preprocessor` → `transformContext`
|
||||
|
||||
- **`AppMessage` renamed to `AgentMessage`**: All references to `AppMessage` have been renamed to `AgentMessage` for consistency.
|
||||
|
||||
- **`CustomMessages` renamed to `CustomAgentMessages`**: The declaration merging interface has been renamed.
|
||||
|
||||
- **`UserMessageWithAttachments` and `Attachment` types removed**: Attachment handling is now the responsibility of the `convertToLlm` function.
|
||||
|
||||
- **Agent loop moved from `@mariozechner/pi-ai`**: The `agentLoop`, `agentLoopContinue`, and related types have moved to this package. Import from `@mariozechner/pi-agent-core` instead.
|
||||
|
||||
### Added
|
||||
|
||||
- `streamFn` option on `Agent` for custom stream implementations. Default uses `streamSimple` from pi-ai.
|
||||
|
||||
- `streamProxy()` utility function for browser apps that need to proxy LLM calls through a backend server. Replaces the removed `AppTransport`.
|
||||
|
||||
- `getApiKey` option for dynamic API key resolution (useful for expiring OAuth tokens like GitHub Copilot).
|
||||
|
||||
- `agentLoop()` and `agentLoopContinue()` low-level functions for running the agent loop without the `Agent` class wrapper.
|
||||
|
||||
- New exported types: `AgentLoopConfig`, `AgentContext`, `AgentTool`, `AgentToolResult`, `AgentToolUpdateCallback`, `StreamFn`.
|
||||
|
||||
### Changed
|
||||
|
||||
- `Agent` constructor now has all options optional (empty options use defaults).
|
||||
|
||||
- `queueMessage()` is now synchronous (no longer returns a Promise).
|
||||
426
packages/agent/README.md
Normal file
426
packages/agent/README.md
Normal file
|
|
@ -0,0 +1,426 @@
|
|||
# @mariozechner/pi-agent-core
|
||||
|
||||
Stateful agent with tool execution and event streaming. Built on `@mariozechner/pi-ai`.
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
npm install @mariozechner/pi-agent-core
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
```typescript
|
||||
import { Agent } from "@mariozechner/pi-agent-core";
|
||||
import { getModel } from "@mariozechner/pi-ai";
|
||||
|
||||
const agent = new Agent({
|
||||
initialState: {
|
||||
systemPrompt: "You are a helpful assistant.",
|
||||
model: getModel("anthropic", "claude-sonnet-4-20250514"),
|
||||
},
|
||||
});
|
||||
|
||||
agent.subscribe((event) => {
|
||||
if (
|
||||
event.type === "message_update" &&
|
||||
event.assistantMessageEvent.type === "text_delta"
|
||||
) {
|
||||
// Stream just the new text chunk
|
||||
process.stdout.write(event.assistantMessageEvent.delta);
|
||||
}
|
||||
});
|
||||
|
||||
await agent.prompt("Hello!");
|
||||
```
|
||||
|
||||
## Core Concepts
|
||||
|
||||
### AgentMessage vs LLM Message
|
||||
|
||||
The agent works with `AgentMessage`, a flexible type that can include:
|
||||
|
||||
- Standard LLM messages (`user`, `assistant`, `toolResult`)
|
||||
- Custom app-specific message types via declaration merging
|
||||
|
||||
LLMs only understand `user`, `assistant`, and `toolResult`. The `convertToLlm` function bridges this gap by filtering and transforming messages before each LLM call.
|
||||
|
||||
### Message Flow
|
||||
|
||||
```
|
||||
AgentMessage[] → transformContext() → AgentMessage[] → convertToLlm() → Message[] → LLM
|
||||
(optional) (required)
|
||||
```
|
||||
|
||||
1. **transformContext**: Prune old messages, inject external context
|
||||
2. **convertToLlm**: Filter out UI-only messages, convert custom types to LLM format
|
||||
|
||||
## Event Flow
|
||||
|
||||
The agent emits events for UI updates. Understanding the event sequence helps build responsive interfaces.
|
||||
|
||||
### prompt() Event Sequence
|
||||
|
||||
When you call `prompt("Hello")`:
|
||||
|
||||
```
|
||||
prompt("Hello")
|
||||
├─ agent_start
|
||||
├─ turn_start
|
||||
├─ message_start { message: userMessage } // Your prompt
|
||||
├─ message_end { message: userMessage }
|
||||
├─ message_start { message: assistantMessage } // LLM starts responding
|
||||
├─ message_update { message: partial... } // Streaming chunks
|
||||
├─ message_update { message: partial... }
|
||||
├─ message_end { message: assistantMessage } // Complete response
|
||||
├─ turn_end { message, toolResults: [] }
|
||||
└─ agent_end { messages: [...] }
|
||||
```
|
||||
|
||||
### With Tool Calls
|
||||
|
||||
If the assistant calls tools, the loop continues:
|
||||
|
||||
```
|
||||
prompt("Read config.json")
|
||||
├─ agent_start
|
||||
├─ turn_start
|
||||
├─ message_start/end { userMessage }
|
||||
├─ message_start { assistantMessage with toolCall }
|
||||
├─ message_update...
|
||||
├─ message_end { assistantMessage }
|
||||
├─ tool_execution_start { toolCallId, toolName, args }
|
||||
├─ tool_execution_update { partialResult } // If tool streams
|
||||
├─ tool_execution_end { toolCallId, result }
|
||||
├─ message_start/end { toolResultMessage }
|
||||
├─ turn_end { message, toolResults: [toolResult] }
|
||||
│
|
||||
├─ turn_start // Next turn
|
||||
├─ message_start { assistantMessage } // LLM responds to tool result
|
||||
├─ message_update...
|
||||
├─ message_end
|
||||
├─ turn_end
|
||||
└─ agent_end
|
||||
```
|
||||
|
||||
### continue() Event Sequence
|
||||
|
||||
`continue()` resumes from existing context without adding a new message. Use it for retries after errors.
|
||||
|
||||
```typescript
|
||||
// After an error, retry from current state
|
||||
await agent.continue();
|
||||
```
|
||||
|
||||
The last message in context must be `user` or `toolResult` (not `assistant`).
|
||||
|
||||
### Event Types
|
||||
|
||||
| Event | Description |
|
||||
| ----------------------- | --------------------------------------------------------------- |
|
||||
| `agent_start` | Agent begins processing |
|
||||
| `agent_end` | Agent completes with all new messages |
|
||||
| `turn_start` | New turn begins (one LLM call + tool executions) |
|
||||
| `turn_end` | Turn completes with assistant message and tool results |
|
||||
| `message_start` | Any message begins (user, assistant, toolResult) |
|
||||
| `message_update` | **Assistant only.** Includes `assistantMessageEvent` with delta |
|
||||
| `message_end` | Message completes |
|
||||
| `tool_execution_start` | Tool begins |
|
||||
| `tool_execution_update` | Tool streams progress |
|
||||
| `tool_execution_end` | Tool completes |
|
||||
|
||||
## Agent Options
|
||||
|
||||
```typescript
|
||||
const agent = new Agent({
|
||||
// Initial state
|
||||
initialState: {
|
||||
systemPrompt: string,
|
||||
model: Model<any>,
|
||||
thinkingLevel: "off" | "minimal" | "low" | "medium" | "high" | "xhigh",
|
||||
tools: AgentTool<any>[],
|
||||
messages: AgentMessage[],
|
||||
},
|
||||
|
||||
// Convert AgentMessage[] to LLM Message[] (required for custom message types)
|
||||
convertToLlm: (messages) => messages.filter(...),
|
||||
|
||||
// Transform context before convertToLlm (for pruning, compaction)
|
||||
transformContext: async (messages, signal) => pruneOldMessages(messages),
|
||||
|
||||
// Steering mode: "one-at-a-time" (default) or "all"
|
||||
steeringMode: "one-at-a-time",
|
||||
|
||||
// Follow-up mode: "one-at-a-time" (default) or "all"
|
||||
followUpMode: "one-at-a-time",
|
||||
|
||||
// Custom stream function (for proxy backends)
|
||||
streamFn: streamProxy,
|
||||
|
||||
// Session ID for provider caching
|
||||
sessionId: "session-123",
|
||||
|
||||
// Dynamic API key resolution (for expiring OAuth tokens)
|
||||
getApiKey: async (provider) => refreshToken(),
|
||||
|
||||
// Custom thinking budgets for token-based providers
|
||||
thinkingBudgets: {
|
||||
minimal: 128,
|
||||
low: 512,
|
||||
medium: 1024,
|
||||
high: 2048,
|
||||
},
|
||||
});
|
||||
```
|
||||
|
||||
## Agent State
|
||||
|
||||
```typescript
|
||||
interface AgentState {
|
||||
systemPrompt: string;
|
||||
model: Model<any>;
|
||||
thinkingLevel: ThinkingLevel;
|
||||
tools: AgentTool<any>[];
|
||||
messages: AgentMessage[];
|
||||
isStreaming: boolean;
|
||||
streamMessage: AgentMessage | null; // Current partial during streaming
|
||||
pendingToolCalls: Set<string>;
|
||||
error?: string;
|
||||
}
|
||||
```
|
||||
|
||||
Access via `agent.state`. During streaming, `streamMessage` contains the partial assistant message.
|
||||
|
||||
## Methods
|
||||
|
||||
### Prompting
|
||||
|
||||
```typescript
|
||||
// Text prompt
|
||||
await agent.prompt("Hello");
|
||||
|
||||
// With images
|
||||
await agent.prompt("What's in this image?", [
|
||||
{ type: "image", data: base64Data, mimeType: "image/jpeg" },
|
||||
]);
|
||||
|
||||
// AgentMessage directly
|
||||
await agent.prompt({ role: "user", content: "Hello", timestamp: Date.now() });
|
||||
|
||||
// Continue from current context (last message must be user or toolResult)
|
||||
await agent.continue();
|
||||
```
|
||||
|
||||
### State Management
|
||||
|
||||
```typescript
|
||||
agent.setSystemPrompt("New prompt");
|
||||
agent.setModel(getModel("openai", "gpt-4o"));
|
||||
agent.setThinkingLevel("medium");
|
||||
agent.setTools([myTool]);
|
||||
agent.replaceMessages(newMessages);
|
||||
agent.appendMessage(message);
|
||||
agent.clearMessages();
|
||||
agent.reset(); // Clear everything
|
||||
```
|
||||
|
||||
### Session and Thinking Budgets
|
||||
|
||||
```typescript
|
||||
agent.sessionId = "session-123";
|
||||
|
||||
agent.thinkingBudgets = {
|
||||
minimal: 128,
|
||||
low: 512,
|
||||
medium: 1024,
|
||||
high: 2048,
|
||||
};
|
||||
```
|
||||
|
||||
### Control
|
||||
|
||||
```typescript
|
||||
agent.abort(); // Cancel current operation
|
||||
await agent.waitForIdle(); // Wait for completion
|
||||
```
|
||||
|
||||
### Events
|
||||
|
||||
```typescript
|
||||
const unsubscribe = agent.subscribe((event) => {
|
||||
console.log(event.type);
|
||||
});
|
||||
unsubscribe();
|
||||
```
|
||||
|
||||
## Steering and Follow-up
|
||||
|
||||
Steering messages let you interrupt the agent while tools are running. Follow-up messages let you queue work after the agent would otherwise stop.
|
||||
|
||||
```typescript
|
||||
agent.setSteeringMode("one-at-a-time");
|
||||
agent.setFollowUpMode("one-at-a-time");
|
||||
|
||||
// While agent is running tools
|
||||
agent.steer({
|
||||
role: "user",
|
||||
content: "Stop! Do this instead.",
|
||||
timestamp: Date.now(),
|
||||
});
|
||||
|
||||
// After the agent finishes its current work
|
||||
agent.followUp({
|
||||
role: "user",
|
||||
content: "Also summarize the result.",
|
||||
timestamp: Date.now(),
|
||||
});
|
||||
|
||||
const steeringMode = agent.getSteeringMode();
|
||||
const followUpMode = agent.getFollowUpMode();
|
||||
|
||||
agent.clearSteeringQueue();
|
||||
agent.clearFollowUpQueue();
|
||||
agent.clearAllQueues();
|
||||
```
|
||||
|
||||
Use `clearSteeringQueue()`, `clearFollowUpQueue()`, or `clearAllQueues()` to drop queued messages.
|
||||
|
||||
When steering messages are detected after a tool completes:
|
||||
|
||||
1. Remaining tools are skipped with error results
|
||||
2. Steering messages are injected
|
||||
3. LLM responds to the interruption
|
||||
|
||||
Follow-up messages are checked only when there are no more tool calls and no steering messages. If any are queued, they are injected and another turn runs.
|
||||
|
||||
## Custom Message Types
|
||||
|
||||
Extend `AgentMessage` via declaration merging:
|
||||
|
||||
```typescript
|
||||
declare module "@mariozechner/pi-agent-core" {
|
||||
interface CustomAgentMessages {
|
||||
notification: { role: "notification"; text: string; timestamp: number };
|
||||
}
|
||||
}
|
||||
|
||||
// Now valid
|
||||
const msg: AgentMessage = {
|
||||
role: "notification",
|
||||
text: "Info",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
```
|
||||
|
||||
Handle custom types in `convertToLlm`:
|
||||
|
||||
```typescript
|
||||
const agent = new Agent({
|
||||
convertToLlm: (messages) =>
|
||||
messages.flatMap((m) => {
|
||||
if (m.role === "notification") return []; // Filter out
|
||||
return [m];
|
||||
}),
|
||||
});
|
||||
```
|
||||
|
||||
## Tools
|
||||
|
||||
Define tools using `AgentTool`:
|
||||
|
||||
```typescript
|
||||
import { Type } from "@sinclair/typebox";
|
||||
|
||||
const readFileTool: AgentTool = {
|
||||
name: "read_file",
|
||||
label: "Read File", // For UI display
|
||||
description: "Read a file's contents",
|
||||
parameters: Type.Object({
|
||||
path: Type.String({ description: "File path" }),
|
||||
}),
|
||||
execute: async (toolCallId, params, signal, onUpdate) => {
|
||||
const content = await fs.readFile(params.path, "utf-8");
|
||||
|
||||
// Optional: stream progress
|
||||
onUpdate?.({
|
||||
content: [{ type: "text", text: "Reading..." }],
|
||||
details: {},
|
||||
});
|
||||
|
||||
return {
|
||||
content: [{ type: "text", text: content }],
|
||||
details: { path: params.path, size: content.length },
|
||||
};
|
||||
},
|
||||
};
|
||||
|
||||
agent.setTools([readFileTool]);
|
||||
```
|
||||
|
||||
### Error Handling
|
||||
|
||||
**Throw an error** when a tool fails. Do not return error messages as content.
|
||||
|
||||
```typescript
|
||||
execute: async (toolCallId, params, signal, onUpdate) => {
|
||||
if (!fs.existsSync(params.path)) {
|
||||
throw new Error(`File not found: ${params.path}`);
|
||||
}
|
||||
// Return content only on success
|
||||
return { content: [{ type: "text", text: "..." }] };
|
||||
};
|
||||
```
|
||||
|
||||
Thrown errors are caught by the agent and reported to the LLM as tool errors with `isError: true`.
|
||||
|
||||
## Proxy Usage
|
||||
|
||||
For browser apps that proxy through a backend:
|
||||
|
||||
```typescript
|
||||
import { Agent, streamProxy } from "@mariozechner/pi-agent-core";
|
||||
|
||||
const agent = new Agent({
|
||||
streamFn: (model, context, options) =>
|
||||
streamProxy(model, context, {
|
||||
...options,
|
||||
authToken: "...",
|
||||
proxyUrl: "https://your-server.com",
|
||||
}),
|
||||
});
|
||||
```
|
||||
|
||||
## Low-Level API
|
||||
|
||||
For direct control without the Agent class:
|
||||
|
||||
```typescript
|
||||
import { agentLoop, agentLoopContinue } from "@mariozechner/pi-agent-core";
|
||||
|
||||
const context: AgentContext = {
|
||||
systemPrompt: "You are helpful.",
|
||||
messages: [],
|
||||
tools: [],
|
||||
};
|
||||
|
||||
const config: AgentLoopConfig = {
|
||||
model: getModel("openai", "gpt-4o"),
|
||||
convertToLlm: (msgs) =>
|
||||
msgs.filter((m) => ["user", "assistant", "toolResult"].includes(m.role)),
|
||||
};
|
||||
|
||||
const userMessage = { role: "user", content: "Hello", timestamp: Date.now() };
|
||||
|
||||
for await (const event of agentLoop([userMessage], context, config)) {
|
||||
console.log(event.type);
|
||||
}
|
||||
|
||||
// Continue from existing context
|
||||
for await (const event of agentLoopContinue(context, config)) {
|
||||
console.log(event.type);
|
||||
}
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
MIT
|
||||
44
packages/agent/package.json
Normal file
44
packages/agent/package.json
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
{
|
||||
"name": "@mariozechner/pi-agent-core",
|
||||
"version": "0.56.2",
|
||||
"description": "Stateful agent with tool execution, event streaming, and state management, built on @mariozechner/pi-ai",
|
||||
"type": "module",
|
||||
"main": "./dist/index.js",
|
||||
"types": "./dist/index.d.ts",
|
||||
"files": [
|
||||
"dist",
|
||||
"README.md"
|
||||
],
|
||||
"scripts": {
|
||||
"clean": "shx rm -rf dist",
|
||||
"build": "tsgo -p tsconfig.build.json",
|
||||
"dev": "tsgo -p tsconfig.build.json --watch --preserveWatchOutput",
|
||||
"test": "vitest --run",
|
||||
"prepublishOnly": "npm run clean && npm run build"
|
||||
},
|
||||
"dependencies": {
|
||||
"@mariozechner/pi-ai": "^0.56.2"
|
||||
},
|
||||
"keywords": [
|
||||
"ai",
|
||||
"agent",
|
||||
"llm",
|
||||
"transport",
|
||||
"state-management"
|
||||
],
|
||||
"author": "Mario Zechner",
|
||||
"license": "MIT",
|
||||
"repository": {
|
||||
"type": "git",
|
||||
"url": "git+https://github.com/getcompanion-ai/co-mono.git",
|
||||
"directory": "packages/agent"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=20.0.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/node": "^24.3.0",
|
||||
"typescript": "^5.7.3",
|
||||
"vitest": "^3.2.4"
|
||||
}
|
||||
}
|
||||
452
packages/agent/src/agent-loop.ts
Normal file
452
packages/agent/src/agent-loop.ts
Normal file
|
|
@ -0,0 +1,452 @@
|
|||
/**
|
||||
* Agent loop that works with AgentMessage throughout.
|
||||
* Transforms to Message[] only at the LLM call boundary.
|
||||
*/
|
||||
|
||||
import {
|
||||
type AssistantMessage,
|
||||
type Context,
|
||||
EventStream,
|
||||
streamSimple,
|
||||
type ToolResultMessage,
|
||||
validateToolArguments,
|
||||
} from "@mariozechner/pi-ai";
|
||||
import type {
|
||||
AgentContext,
|
||||
AgentEvent,
|
||||
AgentLoopConfig,
|
||||
AgentMessage,
|
||||
AgentTool,
|
||||
AgentToolResult,
|
||||
StreamFn,
|
||||
} from "./types.js";
|
||||
|
||||
/**
|
||||
* Start an agent loop with a new prompt message.
|
||||
* The prompt is added to the context and events are emitted for it.
|
||||
*/
|
||||
export function agentLoop(
|
||||
prompts: AgentMessage[],
|
||||
context: AgentContext,
|
||||
config: AgentLoopConfig,
|
||||
signal?: AbortSignal,
|
||||
streamFn?: StreamFn,
|
||||
): EventStream<AgentEvent, AgentMessage[]> {
|
||||
const stream = createAgentStream();
|
||||
|
||||
(async () => {
|
||||
const newMessages: AgentMessage[] = [...prompts];
|
||||
const currentContext: AgentContext = {
|
||||
...context,
|
||||
messages: [...context.messages, ...prompts],
|
||||
};
|
||||
|
||||
stream.push({ type: "agent_start" });
|
||||
stream.push({ type: "turn_start" });
|
||||
for (const prompt of prompts) {
|
||||
stream.push({ type: "message_start", message: prompt });
|
||||
stream.push({ type: "message_end", message: prompt });
|
||||
}
|
||||
|
||||
await runLoop(
|
||||
currentContext,
|
||||
newMessages,
|
||||
config,
|
||||
signal,
|
||||
stream,
|
||||
streamFn,
|
||||
);
|
||||
})();
|
||||
|
||||
return stream;
|
||||
}
|
||||
|
||||
/**
|
||||
* Continue an agent loop from the current context without adding a new message.
|
||||
* Used for retries - context already has user message or tool results.
|
||||
*
|
||||
* **Important:** The last message in context must convert to a `user` or `toolResult` message
|
||||
* via `convertToLlm`. If it doesn't, the LLM provider will reject the request.
|
||||
* This cannot be validated here since `convertToLlm` is only called once per turn.
|
||||
*/
|
||||
export function agentLoopContinue(
|
||||
context: AgentContext,
|
||||
config: AgentLoopConfig,
|
||||
signal?: AbortSignal,
|
||||
streamFn?: StreamFn,
|
||||
): EventStream<AgentEvent, AgentMessage[]> {
|
||||
if (context.messages.length === 0) {
|
||||
throw new Error("Cannot continue: no messages in context");
|
||||
}
|
||||
|
||||
if (context.messages[context.messages.length - 1].role === "assistant") {
|
||||
throw new Error("Cannot continue from message role: assistant");
|
||||
}
|
||||
|
||||
const stream = createAgentStream();
|
||||
|
||||
(async () => {
|
||||
const newMessages: AgentMessage[] = [];
|
||||
const currentContext: AgentContext = { ...context };
|
||||
|
||||
stream.push({ type: "agent_start" });
|
||||
stream.push({ type: "turn_start" });
|
||||
|
||||
await runLoop(
|
||||
currentContext,
|
||||
newMessages,
|
||||
config,
|
||||
signal,
|
||||
stream,
|
||||
streamFn,
|
||||
);
|
||||
})();
|
||||
|
||||
return stream;
|
||||
}
|
||||
|
||||
function createAgentStream(): EventStream<AgentEvent, AgentMessage[]> {
|
||||
return new EventStream<AgentEvent, AgentMessage[]>(
|
||||
(event: AgentEvent) => event.type === "agent_end",
|
||||
(event: AgentEvent) => (event.type === "agent_end" ? event.messages : []),
|
||||
);
|
||||
}
|
||||
|
||||
/**
 * Main loop logic shared by agentLoop and agentLoopContinue.
 *
 * Runs turns until the assistant stops requesting tools and no queued
 * steering or follow-up messages remain. Mutates `currentContext.messages`
 * in place, appends every newly produced message to `newMessages`, and
 * terminates `stream` with `agent_end` on every exit path (including
 * error/abort).
 */
async function runLoop(
	currentContext: AgentContext,
	newMessages: AgentMessage[],
	config: AgentLoopConfig,
	signal: AbortSignal | undefined,
	stream: EventStream<AgentEvent, AgentMessage[]>,
	streamFn?: StreamFn,
): Promise<void> {
	// The caller already emitted the first turn_start, so suppress it once.
	let firstTurn = true;
	// Check for steering messages at start (user may have typed while waiting)
	let pendingMessages: AgentMessage[] =
		(await config.getSteeringMessages?.()) || [];

	// Outer loop: continues when queued follow-up messages arrive after agent would stop
	while (true) {
		// Starts true so the inner loop always runs at least one turn.
		let hasMoreToolCalls = true;
		// Steering messages collected by executeToolCalls mid-execution; they
		// take precedence over freshly polled ones for the next turn.
		let steeringAfterTools: AgentMessage[] | null = null;

		// Inner loop: process tool calls and steering messages
		while (hasMoreToolCalls || pendingMessages.length > 0) {
			if (!firstTurn) {
				stream.push({ type: "turn_start" });
			} else {
				firstTurn = false;
			}

			// Process pending messages (inject before next assistant response)
			if (pendingMessages.length > 0) {
				for (const message of pendingMessages) {
					stream.push({ type: "message_start", message });
					stream.push({ type: "message_end", message });
					currentContext.messages.push(message);
					newMessages.push(message);
				}
				pendingMessages = [];
			}

			// Stream assistant response
			const message = await streamAssistantResponse(
				currentContext,
				config,
				signal,
				stream,
				streamFn,
			);
			newMessages.push(message);

			// Error or abort ends the whole run immediately; tool calls in the
			// partial message (if any) are not executed.
			if (message.stopReason === "error" || message.stopReason === "aborted") {
				stream.push({ type: "turn_end", message, toolResults: [] });
				stream.push({ type: "agent_end", messages: newMessages });
				stream.end(newMessages);
				return;
			}

			// Check for tool calls
			const toolCalls = message.content.filter((c) => c.type === "toolCall");
			hasMoreToolCalls = toolCalls.length > 0;

			const toolResults: ToolResultMessage[] = [];
			if (hasMoreToolCalls) {
				const toolExecution = await executeToolCalls(
					currentContext.tools,
					message,
					signal,
					stream,
					config.getSteeringMessages,
				);
				toolResults.push(...toolExecution.toolResults);
				steeringAfterTools = toolExecution.steeringMessages ?? null;

				for (const result of toolResults) {
					currentContext.messages.push(result);
					newMessages.push(result);
				}
			}

			stream.push({ type: "turn_end", message, toolResults });

			// Get steering messages after turn completes
			if (steeringAfterTools && steeringAfterTools.length > 0) {
				pendingMessages = steeringAfterTools;
				steeringAfterTools = null;
			} else {
				pendingMessages = (await config.getSteeringMessages?.()) || [];
			}
		}

		// Agent would stop here. Check for follow-up messages.
		const followUpMessages = (await config.getFollowUpMessages?.()) || [];
		if (followUpMessages.length > 0) {
			// Set as pending so inner loop processes them
			pendingMessages = followUpMessages;
			continue;
		}

		// No more messages, exit
		break;
	}

	stream.push({ type: "agent_end", messages: newMessages });
	stream.end(newMessages);
}
|
||||
|
||||
/**
 * Stream an assistant response from the LLM.
 * This is where AgentMessage[] gets transformed to Message[] for the LLM.
 *
 * Pushes the in-flight partial message into `context.messages` so observers
 * see live state, then swaps in the final message when the provider stream
 * finishes. Emits message_start, message_update (one per provider delta
 * event), and message_end on `stream`.
 */
async function streamAssistantResponse(
	context: AgentContext,
	config: AgentLoopConfig,
	signal: AbortSignal | undefined,
	stream: EventStream<AgentEvent, AgentMessage[]>,
	streamFn?: StreamFn,
): Promise<AssistantMessage> {
	// Apply context transform if configured (AgentMessage[] → AgentMessage[])
	let messages = context.messages;
	if (config.transformContext) {
		messages = await config.transformContext(messages, signal);
	}

	// Convert to LLM-compatible messages (AgentMessage[] → Message[])
	const llmMessages = await config.convertToLlm(messages);

	// Build LLM context
	const llmContext: Context = {
		systemPrompt: context.systemPrompt,
		messages: llmMessages,
		tools: context.tools,
	};

	// Custom stream function takes precedence; default is streamSimple.
	const streamFunction = streamFn || streamSimple;

	// Resolve API key (important for expiring tokens)
	// Falls back to the static config.apiKey when getApiKey is absent or
	// returns a falsy value.
	const resolvedApiKey =
		(config.getApiKey
			? await config.getApiKey(config.model.provider)
			: undefined) || config.apiKey;

	const response = await streamFunction(config.model, llmContext, {
		...config,
		apiKey: resolvedApiKey,
		signal,
	});

	// The provider's evolving partial message; addedPartial tracks whether it
	// was already pushed into context.messages (so we replace, not re-push).
	let partialMessage: AssistantMessage | null = null;
	let addedPartial = false;

	for await (const event of response) {
		switch (event.type) {
			case "start":
				partialMessage = event.partial;
				context.messages.push(partialMessage);
				addedPartial = true;
				stream.push({ type: "message_start", message: { ...partialMessage } });
				break;

			case "text_start":
			case "text_delta":
			case "text_end":
			case "thinking_start":
			case "thinking_delta":
			case "thinking_end":
			case "toolcall_start":
			case "toolcall_delta":
			case "toolcall_end":
				// Delta events replace the tail of context.messages with the
				// latest partial; a shallow copy is emitted so subscribers don't
				// observe later in-place mutation.
				if (partialMessage) {
					partialMessage = event.partial;
					context.messages[context.messages.length - 1] = partialMessage;
					stream.push({
						type: "message_update",
						assistantMessageEvent: event,
						message: { ...partialMessage },
					});
				}
				break;

			case "done":
			case "error": {
				// Both terminal events resolve to the final message; an error
				// surfaces via the message's stopReason, not an exception here.
				const finalMessage = await response.result();
				if (addedPartial) {
					context.messages[context.messages.length - 1] = finalMessage;
				} else {
					context.messages.push(finalMessage);
				}
				if (!addedPartial) {
					// No "start" event arrived (e.g. immediate error), so the
					// message_start was never emitted — emit it now.
					stream.push({ type: "message_start", message: { ...finalMessage } });
				}
				stream.push({ type: "message_end", message: finalMessage });
				return finalMessage;
			}
		}
	}

	// Defensive: provider stream ended without a done/error event.
	// NOTE(review): context.messages and stream events are not finalized on
	// this path — presumably unreachable with well-behaved providers; confirm.
	return await response.result();
}
|
||||
|
||||
/**
|
||||
* Execute tool calls from an assistant message.
|
||||
*/
|
||||
async function executeToolCalls(
|
||||
tools: AgentTool<any>[] | undefined,
|
||||
assistantMessage: AssistantMessage,
|
||||
signal: AbortSignal | undefined,
|
||||
stream: EventStream<AgentEvent, AgentMessage[]>,
|
||||
getSteeringMessages?: AgentLoopConfig["getSteeringMessages"],
|
||||
): Promise<{
|
||||
toolResults: ToolResultMessage[];
|
||||
steeringMessages?: AgentMessage[];
|
||||
}> {
|
||||
const toolCalls = assistantMessage.content.filter(
|
||||
(c) => c.type === "toolCall",
|
||||
);
|
||||
const results: ToolResultMessage[] = [];
|
||||
let steeringMessages: AgentMessage[] | undefined;
|
||||
|
||||
for (let index = 0; index < toolCalls.length; index++) {
|
||||
const toolCall = toolCalls[index];
|
||||
const tool = tools?.find((t) => t.name === toolCall.name);
|
||||
|
||||
stream.push({
|
||||
type: "tool_execution_start",
|
||||
toolCallId: toolCall.id,
|
||||
toolName: toolCall.name,
|
||||
args: toolCall.arguments,
|
||||
});
|
||||
|
||||
let result: AgentToolResult<any>;
|
||||
let isError = false;
|
||||
|
||||
try {
|
||||
if (!tool) throw new Error(`Tool ${toolCall.name} not found`);
|
||||
|
||||
const validatedArgs = validateToolArguments(tool, toolCall);
|
||||
|
||||
result = await tool.execute(
|
||||
toolCall.id,
|
||||
validatedArgs,
|
||||
signal,
|
||||
(partialResult) => {
|
||||
stream.push({
|
||||
type: "tool_execution_update",
|
||||
toolCallId: toolCall.id,
|
||||
toolName: toolCall.name,
|
||||
args: toolCall.arguments,
|
||||
partialResult,
|
||||
});
|
||||
},
|
||||
);
|
||||
} catch (e) {
|
||||
result = {
|
||||
content: [
|
||||
{ type: "text", text: e instanceof Error ? e.message : String(e) },
|
||||
],
|
||||
details: {},
|
||||
};
|
||||
isError = true;
|
||||
}
|
||||
|
||||
stream.push({
|
||||
type: "tool_execution_end",
|
||||
toolCallId: toolCall.id,
|
||||
toolName: toolCall.name,
|
||||
result,
|
||||
isError,
|
||||
});
|
||||
|
||||
const toolResultMessage: ToolResultMessage = {
|
||||
role: "toolResult",
|
||||
toolCallId: toolCall.id,
|
||||
toolName: toolCall.name,
|
||||
content: result.content,
|
||||
details: result.details,
|
||||
isError,
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
results.push(toolResultMessage);
|
||||
stream.push({ type: "message_start", message: toolResultMessage });
|
||||
stream.push({ type: "message_end", message: toolResultMessage });
|
||||
|
||||
// Check for steering messages - skip remaining tools if user interrupted
|
||||
if (getSteeringMessages) {
|
||||
const steering = await getSteeringMessages();
|
||||
if (steering.length > 0) {
|
||||
steeringMessages = steering;
|
||||
const remainingCalls = toolCalls.slice(index + 1);
|
||||
for (const skipped of remainingCalls) {
|
||||
results.push(skipToolCall(skipped, stream));
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return { toolResults: results, steeringMessages };
|
||||
}
|
||||
|
||||
function skipToolCall(
|
||||
toolCall: Extract<AssistantMessage["content"][number], { type: "toolCall" }>,
|
||||
stream: EventStream<AgentEvent, AgentMessage[]>,
|
||||
): ToolResultMessage {
|
||||
const result: AgentToolResult<any> = {
|
||||
content: [{ type: "text", text: "Skipped due to queued user message." }],
|
||||
details: {},
|
||||
};
|
||||
|
||||
stream.push({
|
||||
type: "tool_execution_start",
|
||||
toolCallId: toolCall.id,
|
||||
toolName: toolCall.name,
|
||||
args: toolCall.arguments,
|
||||
});
|
||||
stream.push({
|
||||
type: "tool_execution_end",
|
||||
toolCallId: toolCall.id,
|
||||
toolName: toolCall.name,
|
||||
result,
|
||||
isError: true,
|
||||
});
|
||||
|
||||
const toolResultMessage: ToolResultMessage = {
|
||||
role: "toolResult",
|
||||
toolCallId: toolCall.id,
|
||||
toolName: toolCall.name,
|
||||
content: result.content,
|
||||
details: {},
|
||||
isError: true,
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
stream.push({ type: "message_start", message: toolResultMessage });
|
||||
stream.push({ type: "message_end", message: toolResultMessage });
|
||||
|
||||
return toolResultMessage;
|
||||
}
|
||||
605
packages/agent/src/agent.ts
Normal file
605
packages/agent/src/agent.ts
Normal file
|
|
@ -0,0 +1,605 @@
|
|||
/**
|
||||
* Agent class that uses the agent-loop directly.
|
||||
* No transport abstraction - calls streamSimple via the loop.
|
||||
*/
|
||||
|
||||
import {
|
||||
getModel,
|
||||
type ImageContent,
|
||||
type Message,
|
||||
type Model,
|
||||
streamSimple,
|
||||
type TextContent,
|
||||
type ThinkingBudgets,
|
||||
type Transport,
|
||||
} from "@mariozechner/pi-ai";
|
||||
import { agentLoop, agentLoopContinue } from "./agent-loop.js";
|
||||
import type {
|
||||
AgentContext,
|
||||
AgentEvent,
|
||||
AgentLoopConfig,
|
||||
AgentMessage,
|
||||
AgentState,
|
||||
AgentTool,
|
||||
StreamFn,
|
||||
ThinkingLevel,
|
||||
} from "./types.js";
|
||||
|
||||
/**
|
||||
* Default convertToLlm: Keep only LLM-compatible messages, convert attachments.
|
||||
*/
|
||||
function defaultConvertToLlm(messages: AgentMessage[]): Message[] {
|
||||
return messages.filter(
|
||||
(m) =>
|
||||
m.role === "user" || m.role === "assistant" || m.role === "toolResult",
|
||||
);
|
||||
}
|
||||
|
||||
export interface AgentOptions {
  /**
   * Partial initial state merged over the Agent's defaults
   * (system prompt, model, thinking level, tools, messages, …).
   */
  initialState?: Partial<AgentState>;

  /**
   * Converts AgentMessage[] to LLM-compatible Message[] before each LLM call.
   * Default filters to user/assistant/toolResult and converts attachments.
   */
  convertToLlm?: (messages: AgentMessage[]) => Message[] | Promise<Message[]>;

  /**
   * Optional transform applied to context before convertToLlm.
   * Use for context pruning, injecting external context, etc.
   */
  transformContext?: (
    messages: AgentMessage[],
    signal?: AbortSignal,
  ) => Promise<AgentMessage[]>;

  /**
   * Steering mode: "all" = send all steering messages at once, "one-at-a-time" = one per turn.
   * Default: "one-at-a-time".
   */
  steeringMode?: "all" | "one-at-a-time";

  /**
   * Follow-up mode: "all" = send all follow-up messages at once, "one-at-a-time" = one per turn.
   * Default: "one-at-a-time".
   */
  followUpMode?: "all" | "one-at-a-time";

  /**
   * Custom stream function (for proxy backends, etc.). Default uses streamSimple.
   */
  streamFn?: StreamFn;

  /**
   * Optional session identifier forwarded to LLM providers.
   * Used by providers that support session-based caching (e.g., OpenAI Codex).
   */
  sessionId?: string;

  /**
   * Resolves an API key dynamically for each LLM call.
   * Useful for expiring tokens (e.g., GitHub Copilot OAuth).
   */
  getApiKey?: (
    provider: string,
  ) => Promise<string | undefined> | string | undefined;

  /**
   * Custom token budgets for thinking levels (token-based providers only).
   */
  thinkingBudgets?: ThinkingBudgets;

  /**
   * Preferred transport for providers that support multiple transports.
   * Default: "sse".
   */
  transport?: Transport;

  /**
   * Maximum delay in milliseconds to wait for a retry when the server requests a long wait.
   * If the server's requested delay exceeds this value, the request fails immediately,
   * allowing higher-level retry logic to handle it with user visibility.
   * Default: 60000 (60 seconds). Set to 0 to disable the cap.
   */
  maxRetryDelayMs?: number;
}
|
||||
|
||||
export class Agent {
  // Mutable state snapshot; fields are updated in place during a run and
  // read by consumers via the `state` getter.
  private _state: AgentState = {
    systemPrompt: "",
    model: getModel("google", "gemini-2.5-flash-lite-preview-06-17"),
    thinkingLevel: "off",
    tools: [],
    messages: [],
    isStreaming: false,
    streamMessage: null,
    pendingToolCalls: new Set<string>(),
    error: undefined,
  };

  // Subscribers notified of every AgentEvent emitted during a run.
  private listeners = new Set<(e: AgentEvent) => void>();
  // Controller for the currently running prompt; undefined when idle.
  private abortController?: AbortController;
  private convertToLlm: (
    messages: AgentMessage[],
  ) => Message[] | Promise<Message[]>;
  private transformContext?: (
    messages: AgentMessage[],
    signal?: AbortSignal,
  ) => Promise<AgentMessage[]>;
  // Messages queued to interrupt the agent mid-run (delivered between tools).
  private steeringQueue: AgentMessage[] = [];
  // Messages queued to run after the agent would otherwise stop.
  private followUpQueue: AgentMessage[] = [];
  private steeringMode: "all" | "one-at-a-time";
  private followUpMode: "all" | "one-at-a-time";
  public streamFn: StreamFn;
  private _sessionId?: string;
  public getApiKey?: (
    provider: string,
  ) => Promise<string | undefined> | string | undefined;
  // Promise resolved when the current run finishes; backs waitForIdle().
  private runningPrompt?: Promise<void>;
  private resolveRunningPrompt?: () => void;
  private _thinkingBudgets?: ThinkingBudgets;
  private _transport: Transport;
  private _maxRetryDelayMs?: number;

  constructor(opts: AgentOptions = {}) {
    this._state = { ...this._state, ...opts.initialState };
    this.convertToLlm = opts.convertToLlm || defaultConvertToLlm;
    this.transformContext = opts.transformContext;
    this.steeringMode = opts.steeringMode || "one-at-a-time";
    this.followUpMode = opts.followUpMode || "one-at-a-time";
    this.streamFn = opts.streamFn || streamSimple;
    this._sessionId = opts.sessionId;
    this.getApiKey = opts.getApiKey;
    this._thinkingBudgets = opts.thinkingBudgets;
    this._transport = opts.transport ?? "sse";
    this._maxRetryDelayMs = opts.maxRetryDelayMs;
  }

  /**
   * Get the current session ID used for provider caching.
   */
  get sessionId(): string | undefined {
    return this._sessionId;
  }

  /**
   * Set the session ID for provider caching.
   * Call this when switching sessions (new session, branch, resume).
   */
  set sessionId(value: string | undefined) {
    this._sessionId = value;
  }

  /**
   * Get the current thinking budgets.
   */
  get thinkingBudgets(): ThinkingBudgets | undefined {
    return this._thinkingBudgets;
  }

  /**
   * Set custom thinking budgets for token-based providers.
   */
  set thinkingBudgets(value: ThinkingBudgets | undefined) {
    this._thinkingBudgets = value;
  }

  /**
   * Get the current preferred transport.
   */
  get transport(): Transport {
    return this._transport;
  }

  /**
   * Set the preferred transport.
   */
  setTransport(value: Transport) {
    this._transport = value;
  }

  /**
   * Get the current max retry delay in milliseconds.
   */
  get maxRetryDelayMs(): number | undefined {
    return this._maxRetryDelayMs;
  }

  /**
   * Set the maximum delay to wait for server-requested retries.
   * Set to 0 to disable the cap.
   */
  set maxRetryDelayMs(value: number | undefined) {
    this._maxRetryDelayMs = value;
  }

  /** Current state snapshot (mutated in place during a run). */
  get state(): AgentState {
    return this._state;
  }

  /**
   * Subscribe to agent events.
   * @returns An unsubscribe function.
   */
  subscribe(fn: (e: AgentEvent) => void): () => void {
    this.listeners.add(fn);
    return () => this.listeners.delete(fn);
  }

  // State mutators
  setSystemPrompt(v: string) {
    this._state.systemPrompt = v;
  }

  setModel(m: Model<any>) {
    this._state.model = m;
  }

  setThinkingLevel(l: ThinkingLevel) {
    this._state.thinkingLevel = l;
  }

  setSteeringMode(mode: "all" | "one-at-a-time") {
    this.steeringMode = mode;
  }

  getSteeringMode(): "all" | "one-at-a-time" {
    return this.steeringMode;
  }

  setFollowUpMode(mode: "all" | "one-at-a-time") {
    this.followUpMode = mode;
  }

  getFollowUpMode(): "all" | "one-at-a-time" {
    return this.followUpMode;
  }

  setTools(t: AgentTool<any>[]) {
    this._state.tools = t;
  }

  /** Replace the full message history (defensive copy is taken). */
  replaceMessages(ms: AgentMessage[]) {
    this._state.messages = ms.slice();
  }

  /** Append one message, replacing the messages array (immutable update). */
  appendMessage(m: AgentMessage) {
    this._state.messages = [...this._state.messages, m];
  }

  /**
   * Queue a steering message to interrupt the agent mid-run.
   * Delivered after current tool execution, skips remaining tools.
   */
  steer(m: AgentMessage) {
    this.steeringQueue.push(m);
  }

  /**
   * Queue a follow-up message to be processed after the agent finishes.
   * Delivered only when agent has no more tool calls or steering messages.
   */
  followUp(m: AgentMessage) {
    this.followUpQueue.push(m);
  }

  clearSteeringQueue() {
    this.steeringQueue = [];
  }

  clearFollowUpQueue() {
    this.followUpQueue = [];
  }

  clearAllQueues() {
    this.steeringQueue = [];
    this.followUpQueue = [];
  }

  /** True if any steering or follow-up messages are waiting. */
  hasQueuedMessages(): boolean {
    return this.steeringQueue.length > 0 || this.followUpQueue.length > 0;
  }

  // Dequeue steering messages according to steeringMode:
  // "one-at-a-time" returns at most one; "all" drains the queue.
  private dequeueSteeringMessages(): AgentMessage[] {
    if (this.steeringMode === "one-at-a-time") {
      if (this.steeringQueue.length > 0) {
        const first = this.steeringQueue[0];
        this.steeringQueue = this.steeringQueue.slice(1);
        return [first];
      }
      return [];
    }

    const steering = this.steeringQueue.slice();
    this.steeringQueue = [];
    return steering;
  }

  // Dequeue follow-up messages according to followUpMode (same semantics
  // as dequeueSteeringMessages).
  private dequeueFollowUpMessages(): AgentMessage[] {
    if (this.followUpMode === "one-at-a-time") {
      if (this.followUpQueue.length > 0) {
        const first = this.followUpQueue[0];
        this.followUpQueue = this.followUpQueue.slice(1);
        return [first];
      }
      return [];
    }

    const followUp = this.followUpQueue.slice();
    this.followUpQueue = [];
    return followUp;
  }

  clearMessages() {
    this._state.messages = [];
  }

  /** Abort the currently running prompt, if any. */
  abort() {
    this.abortController?.abort();
  }

  /** Resolves when the current run finishes; resolves immediately when idle. */
  waitForIdle(): Promise<void> {
    return this.runningPrompt ?? Promise.resolve();
  }

  /** Reset transient run state, message history, and both queues. */
  reset() {
    this._state.messages = [];
    this._state.isStreaming = false;
    this._state.streamMessage = null;
    this._state.pendingToolCalls = new Set<string>();
    this._state.error = undefined;
    this.steeringQueue = [];
    this.followUpQueue = [];
  }

  /** Send a prompt with an AgentMessage */
  async prompt(message: AgentMessage | AgentMessage[]): Promise<void>;
  async prompt(input: string, images?: ImageContent[]): Promise<void>;
  async prompt(
    input: string | AgentMessage | AgentMessage[],
    images?: ImageContent[],
  ) {
    if (this._state.isStreaming) {
      throw new Error(
        "Agent is already processing a prompt. Use steer() or followUp() to queue messages, or wait for completion.",
      );
    }

    const model = this._state.model;
    if (!model) throw new Error("No model configured");

    // Normalize all three overload shapes into AgentMessage[].
    let msgs: AgentMessage[];

    if (Array.isArray(input)) {
      msgs = input;
    } else if (typeof input === "string") {
      const content: Array<TextContent | ImageContent> = [
        { type: "text", text: input },
      ];
      if (images && images.length > 0) {
        content.push(...images);
      }
      msgs = [
        {
          role: "user",
          content,
          timestamp: Date.now(),
        },
      ];
    } else {
      msgs = [input];
    }

    await this._runLoop(msgs);
  }

  /**
   * Continue from current context (used for retries and resuming queued messages).
   *
   * When the context ends in an assistant message, resumes with queued
   * steering messages first (without re-polling the steering queue on the
   * first turn), then queued follow-ups; throws if neither queue has
   * anything to resume with.
   */
  async continue() {
    if (this._state.isStreaming) {
      throw new Error(
        "Agent is already processing. Wait for completion before continuing.",
      );
    }

    const messages = this._state.messages;
    if (messages.length === 0) {
      throw new Error("No messages to continue from");
    }
    if (messages[messages.length - 1].role === "assistant") {
      const queuedSteering = this.dequeueSteeringMessages();
      if (queuedSteering.length > 0) {
        // Skip the initial steering poll so the dequeued message is not
        // immediately followed by another one in the same turn.
        await this._runLoop(queuedSteering, { skipInitialSteeringPoll: true });
        return;
      }

      const queuedFollowUp = this.dequeueFollowUpMessages();
      if (queuedFollowUp.length > 0) {
        await this._runLoop(queuedFollowUp);
        return;
      }

      throw new Error("Cannot continue from message role: assistant");
    }

    await this._runLoop(undefined);
  }

  /**
   * Run the agent loop.
   * If messages are provided, starts a new conversation turn with those messages.
   * Otherwise, continues from existing context.
   */
  private async _runLoop(
    messages?: AgentMessage[],
    options?: { skipInitialSteeringPoll?: boolean },
  ) {
    const model = this._state.model;
    if (!model) throw new Error("No model configured");

    // Expose a promise for waitForIdle(); resolved in the finally block.
    this.runningPrompt = new Promise<void>((resolve) => {
      this.resolveRunningPrompt = resolve;
    });

    this.abortController = new AbortController();
    this._state.isStreaming = true;
    this._state.streamMessage = null;
    this._state.error = undefined;

    // "off" maps to no reasoning; any other level is forwarded as-is.
    const reasoning =
      this._state.thinkingLevel === "off"
        ? undefined
        : this._state.thinkingLevel;

    // The loop works on a copy of the message array; state is kept in sync
    // via the message_end events below.
    const context: AgentContext = {
      systemPrompt: this._state.systemPrompt,
      messages: this._state.messages.slice(),
      tools: this._state.tools,
    };

    // One-shot flag: when set, the first steering poll returns [] so a
    // message dequeued by continue() is not immediately chained.
    let skipInitialSteeringPoll = options?.skipInitialSteeringPoll === true;

    const config: AgentLoopConfig = {
      model,
      reasoning,
      sessionId: this._sessionId,
      transport: this._transport,
      thinkingBudgets: this._thinkingBudgets,
      maxRetryDelayMs: this._maxRetryDelayMs,
      convertToLlm: this.convertToLlm,
      transformContext: this.transformContext,
      getApiKey: this.getApiKey,
      getSteeringMessages: async () => {
        if (skipInitialSteeringPoll) {
          skipInitialSteeringPoll = false;
          return [];
        }
        return this.dequeueSteeringMessages();
      },
      getFollowUpMessages: async () => this.dequeueFollowUpMessages(),
    };

    // Tracks the in-flight partial message so it can be flushed on abort.
    let partial: AgentMessage | null = null;

    try {
      const stream = messages
        ? agentLoop(
            messages,
            context,
            config,
            this.abortController.signal,
            this.streamFn,
          )
        : agentLoopContinue(
            context,
            config,
            this.abortController.signal,
            this.streamFn,
          );

      for await (const event of stream) {
        // Update internal state based on events
        switch (event.type) {
          case "message_start":
            partial = event.message;
            this._state.streamMessage = event.message;
            break;

          case "message_update":
            partial = event.message;
            this._state.streamMessage = event.message;
            break;

          case "message_end":
            // Message is complete: move it from streamMessage to history.
            partial = null;
            this._state.streamMessage = null;
            this.appendMessage(event.message);
            break;

          case "tool_execution_start": {
            // Copy-on-write so consumers holding the old Set are unaffected.
            const s = new Set(this._state.pendingToolCalls);
            s.add(event.toolCallId);
            this._state.pendingToolCalls = s;
            break;
          }

          case "tool_execution_end": {
            const s = new Set(this._state.pendingToolCalls);
            s.delete(event.toolCallId);
            this._state.pendingToolCalls = s;
            break;
          }

          case "turn_end":
            // Surface assistant-level errors on state without stopping the run.
            if (
              event.message.role === "assistant" &&
              (event.message as any).errorMessage
            ) {
              this._state.error = (event.message as any).errorMessage;
            }
            break;

          case "agent_end":
            this._state.isStreaming = false;
            this._state.streamMessage = null;
            break;
        }

        // Emit to listeners
        this.emit(event);
      }

      // Handle any remaining partial message (e.g. the stream ended before a
      // message_end was emitted, typically on abort).
      if (
        partial &&
        partial.role === "assistant" &&
        partial.content.length > 0
      ) {
        // Keep the partial only if it carries any non-empty content.
        const onlyEmpty = !partial.content.some(
          (c) =>
            (c.type === "thinking" && c.thinking.trim().length > 0) ||
            (c.type === "text" && c.text.trim().length > 0) ||
            (c.type === "toolCall" && c.name.trim().length > 0),
        );
        if (!onlyEmpty) {
          this.appendMessage(partial);
        } else {
          if (this.abortController?.signal.aborted) {
            throw new Error("Request was aborted");
          }
        }
      }
    } catch (err: any) {
      // Convert the failure into a synthetic assistant error message so the
      // history records what happened.
      const errorMsg: AgentMessage = {
        role: "assistant",
        content: [{ type: "text", text: "" }],
        api: model.api,
        provider: model.provider,
        model: model.id,
        usage: {
          input: 0,
          output: 0,
          cacheRead: 0,
          cacheWrite: 0,
          totalTokens: 0,
          cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
        },
        stopReason: this.abortController?.signal.aborted ? "aborted" : "error",
        errorMessage: err?.message || String(err),
        timestamp: Date.now(),
      } as AgentMessage;

      this.appendMessage(errorMsg);
      this._state.error = err?.message || String(err);
      this.emit({ type: "agent_end", messages: [errorMsg] });
    } finally {
      // Always restore idle state and release waitForIdle() waiters.
      this._state.isStreaming = false;
      this._state.streamMessage = null;
      this._state.pendingToolCalls = new Set<string>();
      this.abortController = undefined;
      this.resolveRunningPrompt?.();
      this.runningPrompt = undefined;
      this.resolveRunningPrompt = undefined;
    }
  }

  // Fan out one event to every subscriber.
  private emit(e: AgentEvent) {
    for (const listener of this.listeners) {
      listener(e);
    }
  }
}
|
||||
8
packages/agent/src/index.ts
Normal file
8
packages/agent/src/index.ts
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
// Core Agent
|
||||
export * from "./agent.js";
|
||||
// Loop functions
|
||||
export * from "./agent-loop.js";
|
||||
// Proxy utilities
|
||||
export * from "./proxy.js";
|
||||
// Types
|
||||
export * from "./types.js";
|
||||
369
packages/agent/src/proxy.ts
Normal file
369
packages/agent/src/proxy.ts
Normal file
|
|
@ -0,0 +1,369 @@
|
|||
/**
|
||||
* Proxy stream function for apps that route LLM calls through a server.
|
||||
* The server manages auth and proxies requests to LLM providers.
|
||||
*/
|
||||
|
||||
// Internal import for JSON parsing utility
|
||||
import {
|
||||
type AssistantMessage,
|
||||
type AssistantMessageEvent,
|
||||
type Context,
|
||||
EventStream,
|
||||
type Model,
|
||||
parseStreamingJson,
|
||||
type SimpleStreamOptions,
|
||||
type StopReason,
|
||||
type ToolCall,
|
||||
} from "@mariozechner/pi-ai";
|
||||
|
||||
// Create stream class matching ProxyMessageEventStream
|
||||
class ProxyMessageEventStream extends EventStream<
|
||||
AssistantMessageEvent,
|
||||
AssistantMessage
|
||||
> {
|
||||
constructor() {
|
||||
super(
|
||||
(event) => event.type === "done" || event.type === "error",
|
||||
(event) => {
|
||||
if (event.type === "done") return event.message;
|
||||
if (event.type === "error") return event.error;
|
||||
throw new Error("Unexpected event type");
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Proxy event types - server sends these with partial field stripped to reduce bandwidth.
 * The client reconstructs the partial AssistantMessage locally (see
 * processProxyEvent). `contentIndex` addresses the slot in the message's
 * content array that a delta applies to.
 */
export type ProxyAssistantMessageEvent =
  | { type: "start" }
  | { type: "text_start"; contentIndex: number }
  | { type: "text_delta"; contentIndex: number; delta: string }
  | { type: "text_end"; contentIndex: number; contentSignature?: string }
  | { type: "thinking_start"; contentIndex: number }
  | { type: "thinking_delta"; contentIndex: number; delta: string }
  | { type: "thinking_end"; contentIndex: number; contentSignature?: string }
  | {
      type: "toolcall_start";
      contentIndex: number;
      id: string;
      toolName: string;
    }
  | { type: "toolcall_delta"; contentIndex: number; delta: string }
  | { type: "toolcall_end"; contentIndex: number }
  // Terminal success event: stream finished normally.
  | {
      type: "done";
      reason: Extract<StopReason, "stop" | "length" | "toolUse">;
      usage: AssistantMessage["usage"];
    }
  // Terminal failure event: stream aborted or errored.
  | {
      type: "error";
      reason: Extract<StopReason, "aborted" | "error">;
      errorMessage?: string;
      usage: AssistantMessage["usage"];
    };
|
||||
|
||||
/** Options for streamProxy, extending the normal stream options with proxy auth. */
export interface ProxyStreamOptions extends SimpleStreamOptions {
  /** Auth token for the proxy server */
  authToken: string;
  /** Proxy server URL (e.g., "https://genai.example.com") */
  proxyUrl: string;
}
|
||||
|
||||
/**
 * Stream function that proxies through a server instead of calling LLM providers directly.
 * The server strips the partial field from delta events to reduce bandwidth.
 * We reconstruct the partial message client-side.
 *
 * Use this as the `streamFn` option when creating an Agent that needs to go through a proxy.
 *
 * @example
 * ```typescript
 * const agent = new Agent({
 *   streamFn: (model, context, options) =>
 *     streamProxy(model, context, {
 *       ...options,
 *       authToken: await getAuthToken(),
 *       proxyUrl: "https://genai.example.com",
 *     }),
 * });
 * ```
 */
export function streamProxy(
  model: Model<any>,
  context: Context,
  options: ProxyStreamOptions,
): ProxyMessageEventStream {
  const stream = new ProxyMessageEventStream();

  // Fire-and-forget worker: the caller consumes `stream` while this IIFE
  // fetches and parses the server-sent event stream.
  (async () => {
    // Initialize the partial message that we'll build up from events
    const partial: AssistantMessage = {
      role: "assistant",
      stopReason: "stop",
      content: [],
      api: model.api,
      provider: model.provider,
      model: model.id,
      usage: {
        input: 0,
        output: 0,
        cacheRead: 0,
        cacheWrite: 0,
        totalTokens: 0,
        cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
      },
      timestamp: Date.now(),
    };

    let reader: ReadableStreamDefaultReader<Uint8Array> | undefined;

    // On abort, cancel the body reader so the read loop unblocks promptly.
    const abortHandler = () => {
      if (reader) {
        reader.cancel("Request aborted by user").catch(() => {});
      }
    };

    if (options.signal) {
      options.signal.addEventListener("abort", abortHandler);
    }

    try {
      const response = await fetch(`${options.proxyUrl}/api/stream`, {
        method: "POST",
        headers: {
          Authorization: `Bearer ${options.authToken}`,
          "Content-Type": "application/json",
        },
        body: JSON.stringify({
          model,
          context,
          options: {
            temperature: options.temperature,
            maxTokens: options.maxTokens,
            reasoning: options.reasoning,
          },
        }),
        signal: options.signal,
      });

      if (!response.ok) {
        // Prefer the server's structured error message when available.
        let errorMessage = `Proxy error: ${response.status} ${response.statusText}`;
        try {
          const errorData = (await response.json()) as { error?: string };
          if (errorData.error) {
            errorMessage = `Proxy error: ${errorData.error}`;
          }
        } catch {
          // Couldn't parse error response
        }
        throw new Error(errorMessage);
      }

      reader = response.body!.getReader();
      const decoder = new TextDecoder();
      // Holds the trailing incomplete line between chunk reads.
      let buffer = "";

      while (true) {
        const { done, value } = await reader.read();
        if (done) break;

        if (options.signal?.aborted) {
          throw new Error("Request aborted by user");
        }

        buffer += decoder.decode(value, { stream: true });
        const lines = buffer.split("\n");
        // Keep the last (possibly partial) line for the next chunk.
        buffer = lines.pop() || "";

        // Parse SSE "data: <json>" lines; other lines are ignored.
        for (const line of lines) {
          if (line.startsWith("data: ")) {
            const data = line.slice(6).trim();
            if (data) {
              const proxyEvent = JSON.parse(data) as ProxyAssistantMessageEvent;
              // Mutates `partial` and returns the full client-side event.
              const event = processProxyEvent(proxyEvent, partial);
              if (event) {
                stream.push(event);
              }
            }
          }
        }
      }

      // NOTE(review): any final "data:" line left in `buffer` when the stream
      // ends without a trailing newline is never parsed — presumably the
      // server always newline-terminates events; confirm against the server.

      if (options.signal?.aborted) {
        throw new Error("Request aborted by user");
      }

      stream.end();
    } catch (error) {
      // Convert any failure (network, abort, parse) into a terminal "error"
      // event carrying the partially reconstructed message.
      const errorMessage =
        error instanceof Error ? error.message : String(error);
      const reason = options.signal?.aborted ? "aborted" : "error";
      partial.stopReason = reason;
      partial.errorMessage = errorMessage;
      stream.push({
        type: "error",
        reason,
        error: partial,
      });
      stream.end();
    } finally {
      if (options.signal) {
        options.signal.removeEventListener("abort", abortHandler);
      }
    }
  })();

  return stream;
}
|
||||
|
||||
/**
|
||||
* Process a proxy event and update the partial message.
|
||||
*/
|
||||
function processProxyEvent(
|
||||
proxyEvent: ProxyAssistantMessageEvent,
|
||||
partial: AssistantMessage,
|
||||
): AssistantMessageEvent | undefined {
|
||||
switch (proxyEvent.type) {
|
||||
case "start":
|
||||
return { type: "start", partial };
|
||||
|
||||
case "text_start":
|
||||
partial.content[proxyEvent.contentIndex] = { type: "text", text: "" };
|
||||
return {
|
||||
type: "text_start",
|
||||
contentIndex: proxyEvent.contentIndex,
|
||||
partial,
|
||||
};
|
||||
|
||||
case "text_delta": {
|
||||
const content = partial.content[proxyEvent.contentIndex];
|
||||
if (content?.type === "text") {
|
||||
content.text += proxyEvent.delta;
|
||||
return {
|
||||
type: "text_delta",
|
||||
contentIndex: proxyEvent.contentIndex,
|
||||
delta: proxyEvent.delta,
|
||||
partial,
|
||||
};
|
||||
}
|
||||
throw new Error("Received text_delta for non-text content");
|
||||
}
|
||||
|
||||
case "text_end": {
|
||||
const content = partial.content[proxyEvent.contentIndex];
|
||||
if (content?.type === "text") {
|
||||
content.textSignature = proxyEvent.contentSignature;
|
||||
return {
|
||||
type: "text_end",
|
||||
contentIndex: proxyEvent.contentIndex,
|
||||
content: content.text,
|
||||
partial,
|
||||
};
|
||||
}
|
||||
throw new Error("Received text_end for non-text content");
|
||||
}
|
||||
|
||||
case "thinking_start":
|
||||
partial.content[proxyEvent.contentIndex] = {
|
||||
type: "thinking",
|
||||
thinking: "",
|
||||
};
|
||||
return {
|
||||
type: "thinking_start",
|
||||
contentIndex: proxyEvent.contentIndex,
|
||||
partial,
|
||||
};
|
||||
|
||||
case "thinking_delta": {
|
||||
const content = partial.content[proxyEvent.contentIndex];
|
||||
if (content?.type === "thinking") {
|
||||
content.thinking += proxyEvent.delta;
|
||||
return {
|
||||
type: "thinking_delta",
|
||||
contentIndex: proxyEvent.contentIndex,
|
||||
delta: proxyEvent.delta,
|
||||
partial,
|
||||
};
|
||||
}
|
||||
throw new Error("Received thinking_delta for non-thinking content");
|
||||
}
|
||||
|
||||
case "thinking_end": {
|
||||
const content = partial.content[proxyEvent.contentIndex];
|
||||
if (content?.type === "thinking") {
|
||||
content.thinkingSignature = proxyEvent.contentSignature;
|
||||
return {
|
||||
type: "thinking_end",
|
||||
contentIndex: proxyEvent.contentIndex,
|
||||
content: content.thinking,
|
||||
partial,
|
||||
};
|
||||
}
|
||||
throw new Error("Received thinking_end for non-thinking content");
|
||||
}
|
||||
|
||||
case "toolcall_start":
|
||||
partial.content[proxyEvent.contentIndex] = {
|
||||
type: "toolCall",
|
||||
id: proxyEvent.id,
|
||||
name: proxyEvent.toolName,
|
||||
arguments: {},
|
||||
partialJson: "",
|
||||
} satisfies ToolCall & { partialJson: string } as ToolCall;
|
||||
return {
|
||||
type: "toolcall_start",
|
||||
contentIndex: proxyEvent.contentIndex,
|
||||
partial,
|
||||
};
|
||||
|
||||
case "toolcall_delta": {
|
||||
const content = partial.content[proxyEvent.contentIndex];
|
||||
if (content?.type === "toolCall") {
|
||||
(content as any).partialJson += proxyEvent.delta;
|
||||
content.arguments =
|
||||
parseStreamingJson((content as any).partialJson) || {};
|
||||
partial.content[proxyEvent.contentIndex] = { ...content }; // Trigger reactivity
|
||||
return {
|
||||
type: "toolcall_delta",
|
||||
contentIndex: proxyEvent.contentIndex,
|
||||
delta: proxyEvent.delta,
|
||||
partial,
|
||||
};
|
||||
}
|
||||
throw new Error("Received toolcall_delta for non-toolCall content");
|
||||
}
|
||||
|
||||
case "toolcall_end": {
|
||||
const content = partial.content[proxyEvent.contentIndex];
|
||||
if (content?.type === "toolCall") {
|
||||
delete (content as any).partialJson;
|
||||
return {
|
||||
type: "toolcall_end",
|
||||
contentIndex: proxyEvent.contentIndex,
|
||||
toolCall: content,
|
||||
partial,
|
||||
};
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
case "done":
|
||||
partial.stopReason = proxyEvent.reason;
|
||||
partial.usage = proxyEvent.usage;
|
||||
return { type: "done", reason: proxyEvent.reason, message: partial };
|
||||
|
||||
case "error":
|
||||
partial.stopReason = proxyEvent.reason;
|
||||
partial.errorMessage = proxyEvent.errorMessage;
|
||||
partial.usage = proxyEvent.usage;
|
||||
return { type: "error", reason: proxyEvent.reason, error: partial };
|
||||
|
||||
default: {
|
||||
const _exhaustiveCheck: never = proxyEvent;
|
||||
console.warn(`Unhandled proxy event type: ${(proxyEvent as any).type}`);
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
}
|
||||
──── packages/agent/src/types.ts · new file · 237 lines ────
@@ -0,0 +1,237 @@
|
|||
import type {
|
||||
AssistantMessageEvent,
|
||||
ImageContent,
|
||||
Message,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
streamSimple,
|
||||
TextContent,
|
||||
Tool,
|
||||
ToolResultMessage,
|
||||
} from "@mariozechner/pi-ai";
|
||||
import type { Static, TSchema } from "@sinclair/typebox";
|
||||
|
||||
/**
 * Stream function - can return sync or Promise for async config lookup.
 *
 * Shares its signature with `streamSimple`; injectable so callers (tests,
 * proxies) can substitute their own transport for LLM calls.
 */
export type StreamFn = (
  ...args: Parameters<typeof streamSimple>
) => ReturnType<typeof streamSimple> | Promise<ReturnType<typeof streamSimple>>;
|
||||
|
||||
/**
 * Configuration for the agent loop.
 */
export interface AgentLoopConfig extends SimpleStreamOptions {
  /** Model used for every LLM call made by the loop. */
  model: Model<any>;

  /**
   * Converts AgentMessage[] to LLM-compatible Message[] before each LLM call.
   *
   * Each AgentMessage must be converted to a UserMessage, AssistantMessage, or ToolResultMessage
   * that the LLM can understand. AgentMessages that cannot be converted (e.g., UI-only notifications,
   * status messages) should be filtered out.
   *
   * @example
   * ```typescript
   * convertToLlm: (messages) => messages.flatMap(m => {
   *   if (m.role === "custom") {
   *     // Convert custom message to user message
   *     return [{ role: "user", content: m.content, timestamp: m.timestamp }];
   *   }
   *   if (m.role === "notification") {
   *     // Filter out UI-only messages
   *     return [];
   *   }
   *   // Pass through standard LLM messages
   *   return [m];
   * })
   * ```
   */
  convertToLlm: (messages: AgentMessage[]) => Message[] | Promise<Message[]>;

  /**
   * Optional transform applied to the context before `convertToLlm`.
   *
   * Use this for operations that work at the AgentMessage level:
   * - Context window management (pruning old messages)
   * - Injecting context from external sources
   *
   * @example
   * ```typescript
   * transformContext: async (messages) => {
   *   if (estimateTokens(messages) > MAX_TOKENS) {
   *     return pruneOldMessages(messages);
   *   }
   *   return messages;
   * }
   * ```
   */
  transformContext?: (
    messages: AgentMessage[],
    signal?: AbortSignal,
  ) => Promise<AgentMessage[]>;

  /**
   * Resolves an API key dynamically for each LLM call.
   *
   * Useful for short-lived OAuth tokens (e.g., GitHub Copilot) that may expire
   * during long-running tool execution phases.
   */
  getApiKey?: (
    provider: string,
  ) => Promise<string | undefined> | string | undefined;

  /**
   * Returns steering messages to inject into the conversation mid-run.
   *
   * Called after each tool execution to check for user interruptions.
   * If messages are returned, remaining tool calls are skipped and
   * these messages are added to the context before the next LLM call.
   *
   * Use this for "steering" the agent while it's working.
   */
  getSteeringMessages?: () => Promise<AgentMessage[]>;

  /**
   * Returns follow-up messages to process after the agent would otherwise stop.
   *
   * Called when the agent has no more tool calls and no steering messages.
   * If messages are returned, they're added to the context and the agent
   * continues with another turn.
   *
   * Use this for follow-up messages that should wait until the agent finishes.
   */
  getFollowUpMessages?: () => Promise<AgentMessage[]>;
}
|
||||
|
||||
/**
 * Thinking/reasoning level for models that support it.
 * Levels are ordered from no reasoning ("off") to maximum ("xhigh").
 * Note: "xhigh" is only supported by OpenAI gpt-5.1-codex-max, gpt-5.2, gpt-5.2-codex, gpt-5.3, and gpt-5.3-codex models.
 */
export type ThinkingLevel =
  | "off"
  | "minimal"
  | "low"
  | "medium"
  | "high"
  | "xhigh";
|
||||
|
||||
/**
 * Extensible interface for custom app messages.
 * Apps can extend via declaration merging:
 *
 * @example
 * ```typescript
 * declare module "@mariozechner/agent" {
 *   interface CustomAgentMessages {
 *     artifact: ArtifactMessage;
 *     notification: NotificationMessage;
 *   }
 * }
 * ```
 */
export interface CustomAgentMessages {
  // Empty by default - apps extend via declaration merging
}
|
||||
|
||||
/**
 * AgentMessage: Union of LLM messages + custom messages.
 * This abstraction allows apps to add custom message types while maintaining
 * type safety and compatibility with the base LLM messages.
 *
 * The custom half is drawn from the values of CustomAgentMessages, which apps
 * populate via declaration merging.
 */
export type AgentMessage =
  | Message
  | CustomAgentMessages[keyof CustomAgentMessages];
|
||||
|
||||
/**
 * Agent state containing all configuration and conversation data.
 */
export interface AgentState {
  // System prompt sent with every LLM call.
  systemPrompt: string;
  // Currently selected model.
  model: Model<any>;
  // Current reasoning level for models that support it.
  thinkingLevel: ThinkingLevel;
  // Tools available to the agent.
  tools: AgentTool<any>[];
  messages: AgentMessage[]; // Can include attachments + custom message types
  // True while an LLM response is streaming in.
  isStreaming: boolean;
  // The in-progress message while streaming, or null when idle.
  streamMessage: AgentMessage | null;
  // IDs of tool calls that have started but not yet finished.
  pendingToolCalls: Set<string>;
  // Last error message, if any.
  error?: string;
}
|
||||
|
||||
/** Result returned by an AgentTool's execute function. */
export interface AgentToolResult<T> {
  // Content blocks supporting text and images
  content: (TextContent | ImageContent)[];
  // Details to be displayed in a UI or logged
  details: T;
}
|
||||
|
||||
// Callback for streaming tool execution updates.
// Invoked with the partial result each time the tool reports progress.
export type AgentToolUpdateCallback<T = any> = (
  partialResult: AgentToolResult<T>,
) => void;
|
||||
|
||||
// AgentTool extends Tool but adds the execute function
export interface AgentTool<
  TParameters extends TSchema = TSchema,
  TDetails = any,
> extends Tool<TParameters> {
  // A human-readable label for the tool to be displayed in UI
  label: string;
  // Runs the tool. `signal` allows cancellation; `onUpdate` streams partial
  // results back to the caller while execution is in flight.
  execute: (
    toolCallId: string,
    params: Static<TParameters>,
    signal?: AbortSignal,
    onUpdate?: AgentToolUpdateCallback<TDetails>,
  ) => Promise<AgentToolResult<TDetails>>;
}
|
||||
|
||||
// AgentContext is like Context but uses AgentTool
export interface AgentContext {
  // System prompt for the conversation.
  systemPrompt: string;
  // Full conversation history, including custom message types.
  messages: AgentMessage[];
  // Tools available during the run (optional).
  tools?: AgentTool<any>[];
}
|
||||
|
||||
/**
 * Events emitted by the Agent for UI updates.
 * These events provide fine-grained lifecycle information for messages, turns, and tool executions.
 */
export type AgentEvent =
  // Agent lifecycle
  | { type: "agent_start" }
  | { type: "agent_end"; messages: AgentMessage[] }
  // Turn lifecycle - a turn is one assistant response + any tool calls/results
  | { type: "turn_start" }
  | {
      type: "turn_end";
      message: AgentMessage;
      toolResults: ToolResultMessage[];
    }
  // Message lifecycle - emitted for user, assistant, and toolResult messages
  | { type: "message_start"; message: AgentMessage }
  // Only emitted for assistant messages during streaming
  | {
      type: "message_update";
      message: AgentMessage;
      assistantMessageEvent: AssistantMessageEvent;
    }
  | { type: "message_end"; message: AgentMessage }
  // Tool execution lifecycle
  | {
      type: "tool_execution_start";
      toolCallId: string;
      toolName: string;
      args: any;
    }
  | {
      type: "tool_execution_update";
      toolCallId: string;
      toolName: string;
      args: any;
      partialResult: any;
    }
  | {
      type: "tool_execution_end";
      toolCallId: string;
      toolName: string;
      result: any;
      isError: boolean;
    };
|
||||
──── packages/agent/test/agent-loop.test.ts · new file · 629 lines ────
@@ -0,0 +1,629 @@
|
|||
import {
|
||||
type AssistantMessage,
|
||||
type AssistantMessageEvent,
|
||||
EventStream,
|
||||
type Message,
|
||||
type Model,
|
||||
type UserMessage,
|
||||
} from "@mariozechner/pi-ai";
|
||||
import { Type } from "@sinclair/typebox";
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { agentLoop, agentLoopContinue } from "../src/agent-loop.js";
|
||||
import type {
|
||||
AgentContext,
|
||||
AgentEvent,
|
||||
AgentLoopConfig,
|
||||
AgentMessage,
|
||||
AgentTool,
|
||||
} from "../src/types.js";
|
||||
|
||||
// Mock stream for testing - mimics MockAssistantStream
|
||||
class MockAssistantStream extends EventStream<
|
||||
AssistantMessageEvent,
|
||||
AssistantMessage
|
||||
> {
|
||||
constructor() {
|
||||
super(
|
||||
(event) => event.type === "done" || event.type === "error",
|
||||
(event) => {
|
||||
if (event.type === "done") return event.message;
|
||||
if (event.type === "error") return event.error;
|
||||
throw new Error("Unexpected event type");
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
function createUsage() {
|
||||
return {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
};
|
||||
}
|
||||
|
||||
function createModel(): Model<"openai-responses"> {
|
||||
return {
|
||||
id: "mock",
|
||||
name: "mock",
|
||||
api: "openai-responses",
|
||||
provider: "openai",
|
||||
baseUrl: "https://example.invalid",
|
||||
reasoning: false,
|
||||
input: ["text"],
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
|
||||
contextWindow: 8192,
|
||||
maxTokens: 2048,
|
||||
};
|
||||
}
|
||||
|
||||
function createAssistantMessage(
|
||||
content: AssistantMessage["content"],
|
||||
stopReason: AssistantMessage["stopReason"] = "stop",
|
||||
): AssistantMessage {
|
||||
return {
|
||||
role: "assistant",
|
||||
content,
|
||||
api: "openai-responses",
|
||||
provider: "openai",
|
||||
model: "mock",
|
||||
usage: createUsage(),
|
||||
stopReason,
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
}
|
||||
|
||||
function createUserMessage(text: string): UserMessage {
|
||||
return {
|
||||
role: "user",
|
||||
content: text,
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
}
|
||||
|
||||
// Simple identity converter for tests - just passes through standard messages
|
||||
function identityConverter(messages: AgentMessage[]): Message[] {
|
||||
return messages.filter(
|
||||
(m) =>
|
||||
m.role === "user" || m.role === "assistant" || m.role === "toolResult",
|
||||
) as Message[];
|
||||
}
|
||||
|
||||
// Suite: agentLoop end-to-end behavior with the AgentMessage abstraction.
// Uses MockAssistantStream + an injected streamFn so no real LLM is called.
describe("agentLoop with AgentMessage", () => {
  it("should emit events with AgentMessage types", async () => {
    const context: AgentContext = {
      systemPrompt: "You are helpful.",
      messages: [],
      tools: [],
    };

    const userPrompt: AgentMessage = createUserMessage("Hello");

    const config: AgentLoopConfig = {
      model: createModel(),
      convertToLlm: identityConverter,
    };

    // Fake LLM: resolves asynchronously with a single text response.
    const streamFn = () => {
      const stream = new MockAssistantStream();
      queueMicrotask(() => {
        const message = createAssistantMessage([
          { type: "text", text: "Hi there!" },
        ]);
        stream.push({ type: "done", reason: "stop", message });
      });
      return stream;
    };

    const events: AgentEvent[] = [];
    const stream = agentLoop(
      [userPrompt],
      context,
      config,
      undefined,
      streamFn,
    );

    for await (const event of stream) {
      events.push(event);
    }

    const messages = await stream.result();

    // Should have user message and assistant message
    expect(messages.length).toBe(2);
    expect(messages[0].role).toBe("user");
    expect(messages[1].role).toBe("assistant");

    // Verify event sequence
    const eventTypes = events.map((e) => e.type);
    expect(eventTypes).toContain("agent_start");
    expect(eventTypes).toContain("turn_start");
    expect(eventTypes).toContain("message_start");
    expect(eventTypes).toContain("message_end");
    expect(eventTypes).toContain("turn_end");
    expect(eventTypes).toContain("agent_end");
  });

  it("should handle custom message types via convertToLlm", async () => {
    // Create a custom message type
    interface CustomNotification {
      role: "notification";
      text: string;
      timestamp: number;
    }

    const notification: CustomNotification = {
      role: "notification",
      text: "This is a notification",
      timestamp: Date.now(),
    };

    const context: AgentContext = {
      systemPrompt: "You are helpful.",
      messages: [notification as unknown as AgentMessage], // Custom message in context
      tools: [],
    };

    const userPrompt: AgentMessage = createUserMessage("Hello");

    let convertedMessages: Message[] = [];
    const config: AgentLoopConfig = {
      model: createModel(),
      convertToLlm: (messages) => {
        // Filter out notifications, convert rest
        convertedMessages = messages
          .filter((m) => (m as { role: string }).role !== "notification")
          .filter(
            (m) =>
              m.role === "user" ||
              m.role === "assistant" ||
              m.role === "toolResult",
          ) as Message[];
        return convertedMessages;
      },
    };

    const streamFn = () => {
      const stream = new MockAssistantStream();
      queueMicrotask(() => {
        const message = createAssistantMessage([
          { type: "text", text: "Response" },
        ]);
        stream.push({ type: "done", reason: "stop", message });
      });
      return stream;
    };

    const events: AgentEvent[] = [];
    const stream = agentLoop(
      [userPrompt],
      context,
      config,
      undefined,
      streamFn,
    );

    for await (const event of stream) {
      events.push(event);
    }

    // The notification should have been filtered out in convertToLlm
    expect(convertedMessages.length).toBe(1); // Only user message
    expect(convertedMessages[0].role).toBe("user");
  });

  it("should apply transformContext before convertToLlm", async () => {
    const context: AgentContext = {
      systemPrompt: "You are helpful.",
      messages: [
        createUserMessage("old message 1"),
        createAssistantMessage([{ type: "text", text: "old response 1" }]),
        createUserMessage("old message 2"),
        createAssistantMessage([{ type: "text", text: "old response 2" }]),
      ],
      tools: [],
    };

    const userPrompt: AgentMessage = createUserMessage("new message");

    let transformedMessages: AgentMessage[] = [];
    let convertedMessages: Message[] = [];

    const config: AgentLoopConfig = {
      model: createModel(),
      transformContext: async (messages) => {
        // Keep only last 2 messages (prune old ones)
        transformedMessages = messages.slice(-2);
        return transformedMessages;
      },
      convertToLlm: (messages) => {
        convertedMessages = messages.filter(
          (m) =>
            m.role === "user" ||
            m.role === "assistant" ||
            m.role === "toolResult",
        ) as Message[];
        return convertedMessages;
      },
    };

    const streamFn = () => {
      const stream = new MockAssistantStream();
      queueMicrotask(() => {
        const message = createAssistantMessage([
          { type: "text", text: "Response" },
        ]);
        stream.push({ type: "done", reason: "stop", message });
      });
      return stream;
    };

    const stream = agentLoop(
      [userPrompt],
      context,
      config,
      undefined,
      streamFn,
    );

    for await (const _ of stream) {
      // consume
    }

    // transformContext should have been called first, keeping only last 2
    expect(transformedMessages.length).toBe(2);
    // Then convertToLlm receives the pruned messages
    expect(convertedMessages.length).toBe(2);
  });

  it("should handle tool calls and results", async () => {
    const toolSchema = Type.Object({ value: Type.String() });
    const executed: string[] = [];
    const tool: AgentTool<typeof toolSchema, { value: string }> = {
      name: "echo",
      label: "Echo",
      description: "Echo tool",
      parameters: toolSchema,
      async execute(_toolCallId, params) {
        executed.push(params.value);
        return {
          content: [{ type: "text", text: `echoed: ${params.value}` }],
          details: { value: params.value },
        };
      },
    };

    const context: AgentContext = {
      systemPrompt: "",
      messages: [],
      tools: [tool],
    };

    const userPrompt: AgentMessage = createUserMessage("echo something");

    const config: AgentLoopConfig = {
      model: createModel(),
      convertToLlm: identityConverter,
    };

    // Two-step mock: first LLM call requests the tool, second wraps up.
    let callIndex = 0;
    const streamFn = () => {
      const stream = new MockAssistantStream();
      queueMicrotask(() => {
        if (callIndex === 0) {
          // First call: return tool call
          const message = createAssistantMessage(
            [
              {
                type: "toolCall",
                id: "tool-1",
                name: "echo",
                arguments: { value: "hello" },
              },
            ],
            "toolUse",
          );
          stream.push({ type: "done", reason: "toolUse", message });
        } else {
          // Second call: return final response
          const message = createAssistantMessage([
            { type: "text", text: "done" },
          ]);
          stream.push({ type: "done", reason: "stop", message });
        }
        callIndex++;
      });
      return stream;
    };

    const events: AgentEvent[] = [];
    const stream = agentLoop(
      [userPrompt],
      context,
      config,
      undefined,
      streamFn,
    );

    for await (const event of stream) {
      events.push(event);
    }

    // Tool should have been executed
    expect(executed).toEqual(["hello"]);

    // Should have tool execution events
    const toolStart = events.find((e) => e.type === "tool_execution_start");
    const toolEnd = events.find((e) => e.type === "tool_execution_end");
    expect(toolStart).toBeDefined();
    expect(toolEnd).toBeDefined();
    if (toolEnd?.type === "tool_execution_end") {
      expect(toolEnd.isError).toBe(false);
    }
  });

  it("should inject queued messages and skip remaining tool calls", async () => {
    const toolSchema = Type.Object({ value: Type.String() });
    const executed: string[] = [];
    const tool: AgentTool<typeof toolSchema, { value: string }> = {
      name: "echo",
      label: "Echo",
      description: "Echo tool",
      parameters: toolSchema,
      async execute(_toolCallId, params) {
        executed.push(params.value);
        return {
          content: [{ type: "text", text: `ok:${params.value}` }],
          details: { value: params.value },
        };
      },
    };

    const context: AgentContext = {
      systemPrompt: "",
      messages: [],
      tools: [tool],
    };

    const userPrompt: AgentMessage = createUserMessage("start");
    const queuedUserMessage: AgentMessage = createUserMessage("interrupt");

    let queuedDelivered = false;
    let callIndex = 0;
    let sawInterruptInContext = false;

    const config: AgentLoopConfig = {
      model: createModel(),
      convertToLlm: identityConverter,
      getSteeringMessages: async () => {
        // Return steering message after first tool executes
        if (executed.length === 1 && !queuedDelivered) {
          queuedDelivered = true;
          return [queuedUserMessage];
        }
        return [];
      },
    };

    const events: AgentEvent[] = [];
    const stream = agentLoop(
      [userPrompt],
      context,
      config,
      undefined,
      (_model, ctx, _options) => {
        // Check if interrupt message is in context on second call
        if (callIndex === 1) {
          sawInterruptInContext = ctx.messages.some(
            (m) =>
              m.role === "user" &&
              typeof m.content === "string" &&
              m.content === "interrupt",
          );
        }

        const mockStream = new MockAssistantStream();
        queueMicrotask(() => {
          if (callIndex === 0) {
            // First call: return two tool calls
            const message = createAssistantMessage(
              [
                {
                  type: "toolCall",
                  id: "tool-1",
                  name: "echo",
                  arguments: { value: "first" },
                },
                {
                  type: "toolCall",
                  id: "tool-2",
                  name: "echo",
                  arguments: { value: "second" },
                },
              ],
              "toolUse",
            );
            mockStream.push({ type: "done", reason: "toolUse", message });
          } else {
            // Second call: return final response
            const message = createAssistantMessage([
              { type: "text", text: "done" },
            ]);
            mockStream.push({ type: "done", reason: "stop", message });
          }
          callIndex++;
        });
        return mockStream;
      },
    );

    for await (const event of stream) {
      events.push(event);
    }

    // Only first tool should have executed
    expect(executed).toEqual(["first"]);

    // Second tool should be skipped
    const toolEnds = events.filter(
      (e): e is Extract<AgentEvent, { type: "tool_execution_end" }> =>
        e.type === "tool_execution_end",
    );
    expect(toolEnds.length).toBe(2);
    expect(toolEnds[0].isError).toBe(false);
    expect(toolEnds[1].isError).toBe(true);
    if (toolEnds[1].result.content[0]?.type === "text") {
      expect(toolEnds[1].result.content[0].text).toContain(
        "Skipped due to queued user message",
      );
    }

    // Queued message should appear in events
    const queuedMessageEvent = events.find(
      (e) =>
        e.type === "message_start" &&
        e.message.role === "user" &&
        typeof e.message.content === "string" &&
        e.message.content === "interrupt",
    );
    expect(queuedMessageEvent).toBeDefined();

    // Interrupt message should be in context when second LLM call is made
    expect(sawInterruptInContext).toBe(true);
  });
});
|
||||
|
||||
// Suite: agentLoopContinue - resuming a run from an existing context,
// rather than starting a fresh one with new prompts.
describe("agentLoopContinue with AgentMessage", () => {
  it("should throw when context has no messages", () => {
    const context: AgentContext = {
      systemPrompt: "You are helpful.",
      messages: [],
      tools: [],
    };

    const config: AgentLoopConfig = {
      model: createModel(),
      convertToLlm: identityConverter,
    };

    expect(() => agentLoopContinue(context, config)).toThrow(
      "Cannot continue: no messages in context",
    );
  });

  it("should continue from existing context without emitting user message events", async () => {
    const userMessage: AgentMessage = createUserMessage("Hello");

    const context: AgentContext = {
      systemPrompt: "You are helpful.",
      messages: [userMessage],
      tools: [],
    };

    const config: AgentLoopConfig = {
      model: createModel(),
      convertToLlm: identityConverter,
    };

    const streamFn = () => {
      const stream = new MockAssistantStream();
      queueMicrotask(() => {
        const message = createAssistantMessage([
          { type: "text", text: "Response" },
        ]);
        stream.push({ type: "done", reason: "stop", message });
      });
      return stream;
    };

    const events: AgentEvent[] = [];
    const stream = agentLoopContinue(context, config, undefined, streamFn);

    for await (const event of stream) {
      events.push(event);
    }

    const messages = await stream.result();

    // Should only return the new assistant message (not the existing user message)
    expect(messages.length).toBe(1);
    expect(messages[0].role).toBe("assistant");

    // Should NOT have user message events (that's the key difference from agentLoop)
    const messageEndEvents = events.filter((e) => e.type === "message_end");
    expect(messageEndEvents.length).toBe(1);
    expect((messageEndEvents[0] as any).message.role).toBe("assistant");
  });

  it("should allow custom message types as last message (caller responsibility)", async () => {
    // Custom message that will be converted to user message by convertToLlm
    interface CustomMessage {
      role: "custom";
      text: string;
      timestamp: number;
    }

    const customMessage: CustomMessage = {
      role: "custom",
      text: "Hook content",
      timestamp: Date.now(),
    };

    const context: AgentContext = {
      systemPrompt: "You are helpful.",
      messages: [customMessage as unknown as AgentMessage],
      tools: [],
    };

    const config: AgentLoopConfig = {
      model: createModel(),
      convertToLlm: (messages) => {
        // Convert custom to user message
        return messages
          .map((m) => {
            if ((m as any).role === "custom") {
              return {
                role: "user" as const,
                content: (m as any).text,
                timestamp: m.timestamp,
              };
            }
            return m;
          })
          .filter(
            (m) =>
              m.role === "user" ||
              m.role === "assistant" ||
              m.role === "toolResult",
          ) as Message[];
      },
    };

    const streamFn = () => {
      const stream = new MockAssistantStream();
      queueMicrotask(() => {
        const message = createAssistantMessage([
          { type: "text", text: "Response to custom message" },
        ]);
        stream.push({ type: "done", reason: "stop", message });
      });
      return stream;
    };

    // Should not throw - the custom message will be converted to user message
    const stream = agentLoopContinue(context, config, undefined, streamFn);

    const events: AgentEvent[] = [];
    for await (const event of stream) {
      events.push(event);
    }

    const messages = await stream.result();
    expect(messages.length).toBe(1);
    expect(messages[0].role).toBe("assistant");
  });
});
|
||||
──── packages/agent/test/agent.test.ts · new file · 383 lines ────
@@ -0,0 +1,383 @@
|
|||
import {
|
||||
type AssistantMessage,
|
||||
type AssistantMessageEvent,
|
||||
EventStream,
|
||||
getModel,
|
||||
} from "@mariozechner/pi-ai";
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { Agent } from "../src/index.js";
|
||||
|
||||
// Mock stream that mimics AssistantMessageEventStream
|
||||
class MockAssistantStream extends EventStream<
|
||||
AssistantMessageEvent,
|
||||
AssistantMessage
|
||||
> {
|
||||
constructor() {
|
||||
super(
|
||||
(event) => event.type === "done" || event.type === "error",
|
||||
(event) => {
|
||||
if (event.type === "done") return event.message;
|
||||
if (event.type === "error") return event.error;
|
||||
throw new Error("Unexpected event type");
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
function createAssistantMessage(text: string): AssistantMessage {
|
||||
return {
|
||||
role: "assistant",
|
||||
content: [{ type: "text", text }],
|
||||
api: "openai-responses",
|
||||
provider: "openai",
|
||||
model: "mock",
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "stop",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
}
|
||||
|
||||
// Unit tests for the Agent wrapper: default/custom state, mutators, event
// subscription, steer/followUp queueing, concurrency guards on prompt() and
// continue(), continue() resume semantics from an assistant tail, and
// sessionId forwarding — all driven by mocked stream functions (no network).
describe("Agent", () => {
	it("should create an agent instance with default state", () => {
		const agent = new Agent();

		expect(agent.state).toBeDefined();
		expect(agent.state.systemPrompt).toBe("");
		expect(agent.state.model).toBeDefined();
		expect(agent.state.thinkingLevel).toBe("off");
		expect(agent.state.tools).toEqual([]);
		expect(agent.state.messages).toEqual([]);
		expect(agent.state.isStreaming).toBe(false);
		expect(agent.state.streamMessage).toBe(null);
		expect(agent.state.pendingToolCalls).toEqual(new Set());
		expect(agent.state.error).toBeUndefined();
	});

	it("should create an agent instance with custom initial state", () => {
		const customModel = getModel("openai", "gpt-4o-mini");
		const agent = new Agent({
			initialState: {
				systemPrompt: "You are a helpful assistant.",
				model: customModel,
				thinkingLevel: "low",
			},
		});

		expect(agent.state.systemPrompt).toBe("You are a helpful assistant.");
		expect(agent.state.model).toBe(customModel);
		expect(agent.state.thinkingLevel).toBe("low");
	});

	it("should subscribe to events", () => {
		const agent = new Agent();

		let eventCount = 0;
		const unsubscribe = agent.subscribe((_event) => {
			eventCount++;
		});

		// No initial event on subscribe
		expect(eventCount).toBe(0);

		// State mutators don't emit events
		agent.setSystemPrompt("Test prompt");
		expect(eventCount).toBe(0);
		expect(agent.state.systemPrompt).toBe("Test prompt");

		// Unsubscribe should work
		unsubscribe();
		agent.setSystemPrompt("Another prompt");
		expect(eventCount).toBe(0); // Should not increase
	});

	it("should update state with mutators", () => {
		const agent = new Agent();

		// Test setSystemPrompt
		agent.setSystemPrompt("Custom prompt");
		expect(agent.state.systemPrompt).toBe("Custom prompt");

		// Test setModel
		const newModel = getModel("google", "gemini-2.5-flash");
		agent.setModel(newModel);
		expect(agent.state.model).toBe(newModel);

		// Test setThinkingLevel
		agent.setThinkingLevel("high");
		expect(agent.state.thinkingLevel).toBe("high");

		// Test setTools
		const tools = [{ name: "test", description: "test tool" } as any];
		agent.setTools(tools);
		expect(agent.state.tools).toBe(tools);

		// Test replaceMessages
		const messages = [
			{ role: "user" as const, content: "Hello", timestamp: Date.now() },
		];
		agent.replaceMessages(messages);
		expect(agent.state.messages).toEqual(messages);
		expect(agent.state.messages).not.toBe(messages); // Should be a copy

		// Test appendMessage
		const newMessage = {
			role: "assistant" as const,
			content: [{ type: "text" as const, text: "Hi" }],
		};
		agent.appendMessage(newMessage as any);
		expect(agent.state.messages).toHaveLength(2);
		expect(agent.state.messages[1]).toBe(newMessage);

		// Test clearMessages
		agent.clearMessages();
		expect(agent.state.messages).toEqual([]);
	});

	it("should support steering message queue", async () => {
		const agent = new Agent();

		const message = {
			role: "user" as const,
			content: "Steering message",
			timestamp: Date.now(),
		};
		agent.steer(message);

		// The message is queued but not yet in state.messages
		expect(agent.state.messages).not.toContainEqual(message);
	});

	it("should support follow-up message queue", async () => {
		const agent = new Agent();

		const message = {
			role: "user" as const,
			content: "Follow-up message",
			timestamp: Date.now(),
		};
		agent.followUp(message);

		// The message is queued but not yet in state.messages
		expect(agent.state.messages).not.toContainEqual(message);
	});

	it("should handle abort controller", () => {
		const agent = new Agent();

		// Should not throw even if nothing is running
		expect(() => agent.abort()).not.toThrow();
	});

	it("should throw when prompt() called while streaming", async () => {
		let abortSignal: AbortSignal | undefined;
		const agent = new Agent({
			// Use a stream function that responds to abort
			streamFn: (_model, _context, options) => {
				abortSignal = options?.signal;
				const stream = new MockAssistantStream();
				queueMicrotask(() => {
					stream.push({ type: "start", partial: createAssistantMessage("") });
					// Check abort signal periodically
					const checkAbort = () => {
						if (abortSignal?.aborted) {
							stream.push({
								type: "error",
								reason: "aborted",
								error: createAssistantMessage("Aborted"),
							});
						} else {
							setTimeout(checkAbort, 5);
						}
					};
					checkAbort();
				});
				return stream;
			},
		});

		// Start first prompt (don't await, it will block until abort)
		const firstPrompt = agent.prompt("First message");

		// Wait a tick for isStreaming to be set
		await new Promise((resolve) => setTimeout(resolve, 10));
		expect(agent.state.isStreaming).toBe(true);

		// Second prompt should reject
		await expect(agent.prompt("Second message")).rejects.toThrow(
			"Agent is already processing a prompt. Use steer() or followUp() to queue messages, or wait for completion.",
		);

		// Cleanup - abort to stop the stream
		agent.abort();
		await firstPrompt.catch(() => {}); // Ignore abort error
	});

	it("should throw when continue() called while streaming", async () => {
		let abortSignal: AbortSignal | undefined;
		const agent = new Agent({
			streamFn: (_model, _context, options) => {
				abortSignal = options?.signal;
				const stream = new MockAssistantStream();
				queueMicrotask(() => {
					stream.push({ type: "start", partial: createAssistantMessage("") });
					const checkAbort = () => {
						if (abortSignal?.aborted) {
							stream.push({
								type: "error",
								reason: "aborted",
								error: createAssistantMessage("Aborted"),
							});
						} else {
							setTimeout(checkAbort, 5);
						}
					};
					checkAbort();
				});
				return stream;
			},
		});

		// Start first prompt
		const firstPrompt = agent.prompt("First message");
		await new Promise((resolve) => setTimeout(resolve, 10));
		expect(agent.state.isStreaming).toBe(true);

		// continue() should reject
		await expect(agent.continue()).rejects.toThrow(
			"Agent is already processing. Wait for completion before continuing.",
		);

		// Cleanup
		agent.abort();
		await firstPrompt.catch(() => {});
	});

	it("continue() should process queued follow-up messages after an assistant turn", async () => {
		const agent = new Agent({
			streamFn: () => {
				const stream = new MockAssistantStream();
				queueMicrotask(() => {
					stream.push({
						type: "done",
						reason: "stop",
						message: createAssistantMessage("Processed"),
					});
				});
				return stream;
			},
		});

		// Seed context that already ends with an assistant message (assistant tail)
		agent.replaceMessages([
			{
				role: "user",
				content: [{ type: "text", text: "Initial" }],
				timestamp: Date.now() - 10,
			},
			createAssistantMessage("Initial response"),
		]);

		agent.followUp({
			role: "user",
			content: [{ type: "text", text: "Queued follow-up" }],
			timestamp: Date.now(),
		});

		await expect(agent.continue()).resolves.toBeUndefined();

		// The queued follow-up must have been appended to the conversation
		const hasQueuedFollowUp = agent.state.messages.some((message) => {
			if (message.role !== "user") return false;
			if (typeof message.content === "string")
				return message.content === "Queued follow-up";
			return message.content.some(
				(part) => part.type === "text" && part.text === "Queued follow-up",
			);
		});

		expect(hasQueuedFollowUp).toBe(true);
		expect(agent.state.messages[agent.state.messages.length - 1].role).toBe(
			"assistant",
		);
	});

	it("continue() should keep one-at-a-time steering semantics from assistant tail", async () => {
		let responseCount = 0;
		const agent = new Agent({
			streamFn: () => {
				const stream = new MockAssistantStream();
				responseCount++;
				queueMicrotask(() => {
					stream.push({
						type: "done",
						reason: "stop",
						message: createAssistantMessage(`Processed ${responseCount}`),
					});
				});
				return stream;
			},
		});

		agent.replaceMessages([
			{
				role: "user",
				content: [{ type: "text", text: "Initial" }],
				timestamp: Date.now() - 10,
			},
			createAssistantMessage("Initial response"),
		]);

		// Two queued steering messages must each get their own assistant turn
		agent.steer({
			role: "user",
			content: [{ type: "text", text: "Steering 1" }],
			timestamp: Date.now(),
		});
		agent.steer({
			role: "user",
			content: [{ type: "text", text: "Steering 2" }],
			timestamp: Date.now() + 1,
		});

		await expect(agent.continue()).resolves.toBeUndefined();

		// Expect strict user/assistant alternation for the steered turns
		const recentMessages = agent.state.messages.slice(-4);
		expect(recentMessages.map((m) => m.role)).toEqual([
			"user",
			"assistant",
			"user",
			"assistant",
		]);
		expect(responseCount).toBe(2);
	});

	it("forwards sessionId to streamFn options", async () => {
		let receivedSessionId: string | undefined;
		const agent = new Agent({
			sessionId: "session-abc",
			streamFn: (_model, _context, options) => {
				receivedSessionId = options?.sessionId;
				const stream = new MockAssistantStream();
				queueMicrotask(() => {
					const message = createAssistantMessage("ok");
					stream.push({ type: "done", reason: "stop", message });
				});
				return stream;
			},
		});

		await agent.prompt("hello");
		expect(receivedSessionId).toBe("session-abc");

		// Test setter
		agent.sessionId = "session-def";
		expect(agent.sessionId).toBe("session-def");

		await agent.prompt("hello again");
		expect(receivedSessionId).toBe("session-def");
	});
});
|
||||
316
packages/agent/test/bedrock-models.test.ts
Normal file
316
packages/agent/test/bedrock-models.test.ts
Normal file
|
|
@ -0,0 +1,316 @@
|
|||
/**
|
||||
* A test suite to ensure Amazon Bedrock models work correctly with the agent loop.
|
||||
*
|
||||
* Some Bedrock models don't support all features (e.g., reasoning signatures).
|
||||
* This test suite verifies that the agent loop works with various Bedrock models.
|
||||
*
|
||||
* This test suite is not enabled by default unless AWS credentials and
|
||||
* `BEDROCK_EXTENSIVE_MODEL_TEST` environment variables are set.
|
||||
*
|
||||
* You can run this test suite with:
|
||||
* ```bash
|
||||
* $ AWS_REGION=us-east-1 BEDROCK_EXTENSIVE_MODEL_TEST=1 AWS_PROFILE=pi npm test -- ./test/bedrock-models.test.ts
|
||||
* ```
|
||||
*
|
||||
* ## Known Issues by Category
|
||||
*
|
||||
* 1. **Inference Profile Required**: Some models require an inference profile ARN instead of on-demand.
|
||||
* 2. **Invalid Model ID**: Model identifiers that don't exist in the current region.
|
||||
* 3. **Max Tokens Exceeded**: Model's maxTokens in our config exceeds the actual limit.
|
||||
* 4. **No Reasoning in User Messages**: Model rejects reasoning content when replayed in conversation.
|
||||
* 5. **Invalid Signature Format**: Model validates signature format (Anthropic newer models).
|
||||
*/
|
||||
|
||||
import type { AssistantMessage } from "@mariozechner/pi-ai";
|
||||
import { getModels } from "@mariozechner/pi-ai";
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { Agent } from "../src/index.js";
|
||||
import { hasBedrockCredentials } from "./bedrock-utils.js";
|
||||
|
||||
// =============================================================================
|
||||
// Known Issue Categories
|
||||
// =============================================================================
|
||||
|
||||
/** Models that require inference profile ARN (not available on-demand in us-east-1) */
// NOTE(review): these sets look region-specific (us-east-1) — verify if AWS_REGION changes
const REQUIRES_INFERENCE_PROFILE = new Set([
	"anthropic.claude-3-5-haiku-20241022-v1:0",
	"anthropic.claude-3-5-sonnet-20241022-v2:0",
	"anthropic.claude-3-opus-20240229-v1:0",
	"meta.llama3-1-70b-instruct-v1:0",
	"meta.llama3-1-8b-instruct-v1:0",
]);

/** Models with invalid identifiers (not available in us-east-1 or don't exist) */
const INVALID_MODEL_ID = new Set([
	"deepseek.v3-v1:0",
	"eu.anthropic.claude-haiku-4-5-20251001-v1:0",
	"eu.anthropic.claude-opus-4-5-20251101-v1:0",
	"eu.anthropic.claude-sonnet-4-5-20250929-v1:0",
	"qwen.qwen3-235b-a22b-2507-v1:0",
	"qwen.qwen3-coder-480b-a35b-v1:0",
]);

/** Models where our maxTokens config exceeds the model's actual limit */
const MAX_TOKENS_EXCEEDED = new Set([
	"us.meta.llama4-maverick-17b-instruct-v1:0",
	"us.meta.llama4-scout-17b-instruct-v1:0",
]);

/**
 * Models that reject reasoning content in user messages (when replaying conversation).
 * These work for multi-turn but fail when synthetic thinking is injected.
 */
const NO_REASONING_IN_USER_MESSAGES = new Set([
	// Mistral models
	"mistral.ministral-3-14b-instruct",
	"mistral.ministral-3-8b-instruct",
	"mistral.mistral-large-2402-v1:0",
	"mistral.voxtral-mini-3b-2507",
	"mistral.voxtral-small-24b-2507",
	// Nvidia models
	"nvidia.nemotron-nano-12b-v2",
	"nvidia.nemotron-nano-9b-v2",
	// Qwen models
	"qwen.qwen3-coder-30b-a3b-v1:0",
	// Amazon Nova models
	"us.amazon.nova-lite-v1:0",
	"us.amazon.nova-micro-v1:0",
	"us.amazon.nova-premier-v1:0",
	"us.amazon.nova-pro-v1:0",
	// Meta Llama models
	"us.meta.llama3-2-11b-instruct-v1:0",
	"us.meta.llama3-2-1b-instruct-v1:0",
	"us.meta.llama3-2-3b-instruct-v1:0",
	"us.meta.llama3-2-90b-instruct-v1:0",
	"us.meta.llama3-3-70b-instruct-v1:0",
	// DeepSeek
	"us.deepseek.r1-v1:0",
	// Older Anthropic models
	"anthropic.claude-3-5-sonnet-20240620-v1:0",
	"anthropic.claude-3-haiku-20240307-v1:0",
	"anthropic.claude-3-sonnet-20240229-v1:0",
	// Cohere models
	"cohere.command-r-plus-v1:0",
	"cohere.command-r-v1:0",
	// Google models
	"google.gemma-3-27b-it",
	"google.gemma-3-4b-it",
	// Non-Anthropic models that don't support signatures (now handled by omitting signature)
	// but still reject reasoning content in user messages
	"global.amazon.nova-2-lite-v1:0",
	"minimax.minimax-m2",
	"moonshot.kimi-k2-thinking",
	"openai.gpt-oss-120b-1:0",
	"openai.gpt-oss-20b-1:0",
	"openai.gpt-oss-safeguard-120b",
	"openai.gpt-oss-safeguard-20b",
	"qwen.qwen3-32b-v1:0",
	"qwen.qwen3-next-80b-a3b",
	"qwen.qwen3-vl-235b-a22b",
]);

/**
 * Models that validate signature format (Anthropic newer models).
 * These work for multi-turn but fail when synthetic/invalid signature is injected.
 */
const VALIDATES_SIGNATURE_FORMAT = new Set([
	"global.anthropic.claude-haiku-4-5-20251001-v1:0",
	"global.anthropic.claude-opus-4-5-20251101-v1:0",
	"global.anthropic.claude-sonnet-4-20250514-v1:0",
	"global.anthropic.claude-sonnet-4-5-20250929-v1:0",
	"us.anthropic.claude-3-7-sonnet-20250219-v1:0",
	"us.anthropic.claude-opus-4-1-20250805-v1:0",
	"us.anthropic.claude-opus-4-20250514-v1:0",
]);

/**
 * DeepSeek R1 fails multi-turn because it rejects reasoning in the replayed assistant message.
 */
const REJECTS_REASONING_ON_REPLAY = new Set(["us.deepseek.r1-v1:0"]);
|
||||
|
||||
// =============================================================================
|
||||
// Helper Functions
|
||||
// =============================================================================
|
||||
|
||||
function isModelUnavailable(modelId: string): boolean {
|
||||
return (
|
||||
REQUIRES_INFERENCE_PROFILE.has(modelId) ||
|
||||
INVALID_MODEL_ID.has(modelId) ||
|
||||
MAX_TOKENS_EXCEEDED.has(modelId)
|
||||
);
|
||||
}
|
||||
|
||||
/** True for models known to reject their own replayed reasoning content in multi-turn runs. */
function failsMultiTurnWithThinking(modelId: string): boolean {
	return REJECTS_REASONING_ON_REPLAY.has(modelId);
}
|
||||
|
||||
function failsSyntheticSignature(modelId: string): boolean {
|
||||
return (
|
||||
NO_REASONING_IN_USER_MESSAGES.has(modelId) ||
|
||||
VALIDATES_SIGNATURE_FORMAT.has(modelId)
|
||||
);
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Tests
|
||||
// =============================================================================
|
||||
|
||||
// Opt-in, per-model smoke tests against real Bedrock endpoints. For every
// known Bedrock model this generates three tests (basic prompt, multi-turn
// with thinking, synthetic-signature replay), each skipped when the model is
// on the relevant known-issue list above.
describe("Amazon Bedrock Models - Agent Loop", () => {
	const shouldRunExtensiveTests =
		hasBedrockCredentials() && process.env.BEDROCK_EXTENSIVE_MODEL_TEST;

	// Get all Amazon Bedrock models
	const allBedrockModels = getModels("amazon-bedrock");

	if (shouldRunExtensiveTests) {
		for (const model of allBedrockModels) {
			const modelId = model.id;

			describe(`Model: ${modelId}`, () => {
				// Skip entirely unavailable models
				const unavailable = isModelUnavailable(modelId);

				it.skipIf(unavailable)(
					"should handle basic text prompt",
					{ timeout: 60_000 },
					async () => {
						const agent = new Agent({
							initialState: {
								systemPrompt:
									"You are a helpful assistant. Be extremely concise.",
								model,
								thinkingLevel: "off",
								tools: [],
							},
						});

						await agent.prompt("Reply with exactly: 'OK'");

						if (agent.state.error) {
							throw new Error(`Basic prompt error: ${agent.state.error}`);
						}

						expect(agent.state.isStreaming).toBe(false);
						expect(agent.state.messages.length).toBe(2);

						const assistantMessage = agent.state.messages[1];
						if (assistantMessage.role !== "assistant")
							throw new Error("Expected assistant message");

						console.log(`${modelId}: OK`);
					},
				);

				// Skip if model is unavailable or known to fail multi-turn with thinking
				const skipMultiTurn =
					unavailable || failsMultiTurnWithThinking(modelId);

				it.skipIf(skipMultiTurn)(
					"should handle multi-turn conversation with thinking content in history",
					{ timeout: 120_000 },
					async () => {
						const agent = new Agent({
							initialState: {
								systemPrompt:
									"You are a helpful assistant. Be extremely concise.",
								model,
								thinkingLevel: "medium",
								tools: [],
							},
						});

						// First turn
						await agent.prompt("My name is Alice.");

						if (agent.state.error) {
							throw new Error(`First turn error: ${agent.state.error}`);
						}

						// Second turn - this should replay the first assistant message which may contain thinking
						await agent.prompt("What is my name?");

						if (agent.state.error) {
							throw new Error(`Second turn error: ${agent.state.error}`);
						}

						expect(agent.state.messages.length).toBe(4);
						console.log(`${modelId}: multi-turn OK`);
					},
				);

				// Skip if model is unavailable or known to fail synthetic signature
				const skipSynthetic = unavailable || failsSyntheticSignature(modelId);

				it.skipIf(skipSynthetic)(
					"should handle conversation with synthetic thinking signature in history",
					{ timeout: 60_000 },
					async () => {
						const agent = new Agent({
							initialState: {
								systemPrompt:
									"You are a helpful assistant. Be extremely concise.",
								model,
								thinkingLevel: "off",
								tools: [],
							},
						});

						// Inject a message with a thinking block that has a signature
						const syntheticAssistantMessage: AssistantMessage = {
							role: "assistant",
							content: [
								{
									type: "thinking",
									thinking: "I need to remember the user's name.",
									thinkingSignature: "synthetic-signature-123",
								},
								{ type: "text", text: "Nice to meet you, Alice!" },
							],
							api: "bedrock-converse-stream",
							provider: "amazon-bedrock",
							model: modelId,
							usage: {
								input: 10,
								output: 20,
								cacheRead: 0,
								cacheWrite: 0,
								totalTokens: 30,
								cost: {
									input: 0,
									output: 0,
									cacheRead: 0,
									cacheWrite: 0,
									total: 0,
								},
							},
							stopReason: "stop",
							timestamp: Date.now(),
						};

						agent.replaceMessages([
							{
								role: "user",
								content: "My name is Alice.",
								timestamp: Date.now(),
							},
							syntheticAssistantMessage,
						]);

						await agent.prompt("What is my name?");

						if (agent.state.error) {
							throw new Error(
								`Synthetic signature error: ${agent.state.error}`,
							);
						}

						expect(agent.state.messages.length).toBe(4);
						console.log(`${modelId}: synthetic signature OK`);
					},
				);
			});
		}
	} else {
		it.skip("skipped - set AWS credentials and BEDROCK_EXTENSIVE_MODEL_TEST=1 to run", () => {});
	}
});
|
||||
18
packages/agent/test/bedrock-utils.ts
Normal file
18
packages/agent/test/bedrock-utils.ts
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
/**
|
||||
* Utility functions for Amazon Bedrock tests
|
||||
*/
|
||||
|
||||
/**
|
||||
* Check if any valid AWS credentials are configured for Bedrock.
|
||||
* Returns true if any of the following are set:
|
||||
* - AWS_PROFILE (named profile from ~/.aws/credentials)
|
||||
* - AWS_ACCESS_KEY_ID + AWS_SECRET_ACCESS_KEY (IAM keys)
|
||||
* - AWS_BEARER_TOKEN_BEDROCK (Bedrock API key)
|
||||
*/
|
||||
export function hasBedrockCredentials(): boolean {
|
||||
return !!(
|
||||
process.env.AWS_PROFILE ||
|
||||
(process.env.AWS_ACCESS_KEY_ID && process.env.AWS_SECRET_ACCESS_KEY) ||
|
||||
process.env.AWS_BEARER_TOKEN_BEDROCK
|
||||
);
|
||||
}
|
||||
571
packages/agent/test/e2e.test.ts
Normal file
571
packages/agent/test/e2e.test.ts
Normal file
|
|
@ -0,0 +1,571 @@
|
|||
import type {
|
||||
AssistantMessage,
|
||||
Model,
|
||||
ToolResultMessage,
|
||||
UserMessage,
|
||||
} from "@mariozechner/pi-ai";
|
||||
import { getModel } from "@mariozechner/pi-ai";
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { Agent } from "../src/index.js";
|
||||
import { hasBedrockCredentials } from "./bedrock-utils.js";
|
||||
import { calculateTool } from "./utils/calculate.js";
|
||||
|
||||
/**
 * E2E helper: one-shot prompt against a real provider; asserts a clean
 * user/assistant exchange whose text answer contains "4".
 */
async function basicPrompt(model: Model<any>) {
	const agent = new Agent({
		initialState: {
			systemPrompt: "You are a helpful assistant. Keep your responses concise.",
			model,
			thinkingLevel: "off",
			tools: [],
		},
	});

	await agent.prompt("What is 2+2? Answer with just the number.");

	expect(agent.state.isStreaming).toBe(false);
	expect(agent.state.messages.length).toBe(2);
	expect(agent.state.messages[0].role).toBe("user");
	expect(agent.state.messages[1].role).toBe("assistant");

	// Narrow to the assistant message so content can be inspected
	const assistantMessage = agent.state.messages[1];
	if (assistantMessage.role !== "assistant")
		throw new Error("Expected assistant message");
	expect(assistantMessage.content.length).toBeGreaterThan(0);

	const textContent = assistantMessage.content.find((c) => c.type === "text");
	expect(textContent).toBeDefined();
	if (textContent?.type !== "text") throw new Error("Expected text content");
	expect(textContent.text).toContain("4");
}
|
||||
|
||||
/**
 * E2E helper: drives a prompt that requires the calculator tool; asserts the
 * tool-result message contains the exact product and the final assistant
 * answer repeats it (with or without thousands separators).
 */
async function toolExecution(model: Model<any>) {
	const agent = new Agent({
		initialState: {
			systemPrompt:
				"You are a helpful assistant. Always use the calculator tool for math.",
			model,
			thinkingLevel: "off",
			tools: [calculateTool],
		},
	});

	await agent.prompt("Calculate 123 * 456 using the calculator tool.");

	expect(agent.state.isStreaming).toBe(false);
	// At least user + assistant(tool call) + toolResult
	expect(agent.state.messages.length).toBeGreaterThanOrEqual(3);

	const toolResultMsg = agent.state.messages.find(
		(m) => m.role === "toolResult",
	);
	expect(toolResultMsg).toBeDefined();
	if (toolResultMsg?.role !== "toolResult")
		throw new Error("Expected tool result message");
	// Flatten all text parts of the tool result into one string
	const textContent =
		toolResultMsg.content
			?.filter((c) => c.type === "text")
			.map((c: any) => c.text)
			.join("\n") || "";
	expect(textContent).toBeDefined();

	const expectedResult = 123 * 456;
	expect(textContent).toContain(String(expectedResult));

	const finalMessage = agent.state.messages[agent.state.messages.length - 1];
	if (finalMessage.role !== "assistant")
		throw new Error("Expected final assistant message");
	const finalText = finalMessage.content.find((c) => c.type === "text");
	expect(finalText).toBeDefined();
	if (finalText?.type !== "text") throw new Error("Expected text content");
	// Check for number with or without comma formatting
	const hasNumber =
		finalText.text.includes(String(expectedResult)) ||
		finalText.text.includes("56,088") ||
		finalText.text.includes("56088");
	expect(hasNumber).toBe(true);
}
|
||||
|
||||
/**
 * E2E helper: starts a multi-step tool prompt and aborts it 100ms in;
 * asserts the run ends cleanly with an aborted assistant message and the
 * agent error mirroring that message's errorMessage.
 */
async function abortExecution(model: Model<any>) {
	const agent = new Agent({
		initialState: {
			systemPrompt: "You are a helpful assistant.",
			model,
			thinkingLevel: "off",
			tools: [calculateTool],
		},
	});

	const promptPromise = agent.prompt(
		"Calculate 100 * 200, then 300 * 400, then sum the results.",
	);

	// Abort mid-flight; 100ms should land during streaming/tool execution
	setTimeout(() => {
		agent.abort();
	}, 100);

	await promptPromise;

	expect(agent.state.isStreaming).toBe(false);
	expect(agent.state.messages.length).toBeGreaterThanOrEqual(2);

	const lastMessage = agent.state.messages[agent.state.messages.length - 1];
	if (lastMessage.role !== "assistant")
		throw new Error("Expected assistant message");
	expect(lastMessage.stopReason).toBe("aborted");
	expect(lastMessage.errorMessage).toBeDefined();
	expect(agent.state.error).toBeDefined();
	expect(agent.state.error).toBe(lastMessage.errorMessage);
}
|
||||
|
||||
/**
 * E2E helper: subscribes to agent events before prompting; asserts the full
 * lifecycle (agent_start/end, message_start/update/end) was emitted and the
 * final state settled to a two-message conversation.
 */
async function stateUpdates(model: Model<any>) {
	const agent = new Agent({
		initialState: {
			systemPrompt: "You are a helpful assistant.",
			model,
			thinkingLevel: "off",
			tools: [],
		},
	});

	const events: Array<string> = [];

	agent.subscribe((event) => {
		events.push(event.type);
	});

	await agent.prompt("Count from 1 to 5.");

	// Should have received lifecycle events
	expect(events).toContain("agent_start");
	expect(events).toContain("agent_end");
	expect(events).toContain("message_start");
	expect(events).toContain("message_end");
	// May have message_update events during streaming
	const hasMessageUpdates = events.some((e) => e === "message_update");
	expect(hasMessageUpdates).toBe(true);

	// Check final state
	expect(agent.state.isStreaming).toBe(false);
	expect(agent.state.messages.length).toBe(2); // User message + assistant response
}
|
||||
|
||||
/**
 * E2E helper: two sequential prompts; asserts history accumulates (2 then 4
 * messages) and the second answer recalls context from the first ("alice").
 */
async function multiTurnConversation(model: Model<any>) {
	const agent = new Agent({
		initialState: {
			systemPrompt: "You are a helpful assistant.",
			model,
			thinkingLevel: "off",
			tools: [],
		},
	});

	await agent.prompt("My name is Alice.");
	expect(agent.state.messages.length).toBe(2);

	await agent.prompt("What is my name?");
	expect(agent.state.messages.length).toBe(4);

	const lastMessage = agent.state.messages[3];
	if (lastMessage.role !== "assistant")
		throw new Error("Expected assistant message");
	const lastText = lastMessage.content.find((c) => c.type === "text");
	if (lastText?.type !== "text") throw new Error("Expected text content");
	expect(lastText.text.toLowerCase()).toContain("alice");
}
|
||||
|
||||
describe("Agent E2E Tests", () => {
|
||||
describe.skipIf(!process.env.GEMINI_API_KEY)(
|
||||
"Google Provider (gemini-2.5-flash)",
|
||||
() => {
|
||||
const model = getModel("google", "gemini-2.5-flash");
|
||||
|
||||
it("should handle basic text prompt", async () => {
|
||||
await basicPrompt(model);
|
||||
});
|
||||
|
||||
it("should execute tools correctly", async () => {
|
||||
await toolExecution(model);
|
||||
});
|
||||
|
||||
it("should handle abort during execution", async () => {
|
||||
await abortExecution(model);
|
||||
});
|
||||
|
||||
it("should emit state updates during streaming", async () => {
|
||||
await stateUpdates(model);
|
||||
});
|
||||
|
||||
it("should maintain context across multiple turns", async () => {
|
||||
await multiTurnConversation(model);
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
describe.skipIf(!process.env.OPENAI_API_KEY)(
|
||||
"OpenAI Provider (gpt-4o-mini)",
|
||||
() => {
|
||||
const model = getModel("openai", "gpt-4o-mini");
|
||||
|
||||
it("should handle basic text prompt", async () => {
|
||||
await basicPrompt(model);
|
||||
});
|
||||
|
||||
it("should execute tools correctly", async () => {
|
||||
await toolExecution(model);
|
||||
});
|
||||
|
||||
it("should handle abort during execution", async () => {
|
||||
await abortExecution(model);
|
||||
});
|
||||
|
||||
it("should emit state updates during streaming", async () => {
|
||||
await stateUpdates(model);
|
||||
});
|
||||
|
||||
it("should maintain context across multiple turns", async () => {
|
||||
await multiTurnConversation(model);
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
describe.skipIf(!process.env.ANTHROPIC_API_KEY)(
|
||||
"Anthropic Provider (claude-haiku-4-5)",
|
||||
() => {
|
||||
const model = getModel("anthropic", "claude-haiku-4-5");
|
||||
|
||||
it("should handle basic text prompt", async () => {
|
||||
await basicPrompt(model);
|
||||
});
|
||||
|
||||
it("should execute tools correctly", async () => {
|
||||
await toolExecution(model);
|
||||
});
|
||||
|
||||
it("should handle abort during execution", async () => {
|
||||
await abortExecution(model);
|
||||
});
|
||||
|
||||
it("should emit state updates during streaming", async () => {
|
||||
await stateUpdates(model);
|
||||
});
|
||||
|
||||
it("should maintain context across multiple turns", async () => {
|
||||
await multiTurnConversation(model);
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
describe.skipIf(!process.env.XAI_API_KEY)("xAI Provider (grok-3)", () => {
|
||||
const model = getModel("xai", "grok-3");
|
||||
|
||||
it("should handle basic text prompt", async () => {
|
||||
await basicPrompt(model);
|
||||
});
|
||||
|
||||
it("should execute tools correctly", async () => {
|
||||
await toolExecution(model);
|
||||
});
|
||||
|
||||
it("should handle abort during execution", async () => {
|
||||
await abortExecution(model);
|
||||
});
|
||||
|
||||
it("should emit state updates during streaming", async () => {
|
||||
await stateUpdates(model);
|
||||
});
|
||||
|
||||
it("should maintain context across multiple turns", async () => {
|
||||
await multiTurnConversation(model);
|
||||
});
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.GROQ_API_KEY)(
|
||||
"Groq Provider (openai/gpt-oss-20b)",
|
||||
() => {
|
||||
const model = getModel("groq", "openai/gpt-oss-20b");
|
||||
|
||||
it("should handle basic text prompt", async () => {
|
||||
await basicPrompt(model);
|
||||
});
|
||||
|
||||
it("should execute tools correctly", async () => {
|
||||
await toolExecution(model);
|
||||
});
|
||||
|
||||
it("should handle abort during execution", async () => {
|
||||
await abortExecution(model);
|
||||
});
|
||||
|
||||
it("should emit state updates during streaming", async () => {
|
||||
await stateUpdates(model);
|
||||
});
|
||||
|
||||
it("should maintain context across multiple turns", async () => {
|
||||
await multiTurnConversation(model);
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
describe.skipIf(!process.env.CEREBRAS_API_KEY)(
|
||||
"Cerebras Provider (gpt-oss-120b)",
|
||||
() => {
|
||||
const model = getModel("cerebras", "gpt-oss-120b");
|
||||
|
||||
it("should handle basic text prompt", async () => {
|
||||
await basicPrompt(model);
|
||||
});
|
||||
|
||||
it("should execute tools correctly", async () => {
|
||||
await toolExecution(model);
|
||||
});
|
||||
|
||||
it("should handle abort during execution", async () => {
|
||||
await abortExecution(model);
|
||||
});
|
||||
|
||||
it("should emit state updates during streaming", async () => {
|
||||
await stateUpdates(model);
|
||||
});
|
||||
|
||||
it("should maintain context across multiple turns", async () => {
|
||||
await multiTurnConversation(model);
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
describe.skipIf(!process.env.ZAI_API_KEY)(
|
||||
"zAI Provider (glm-4.5-air)",
|
||||
() => {
|
||||
const model = getModel("zai", "glm-4.5-air");
|
||||
|
||||
it("should handle basic text prompt", async () => {
|
||||
await basicPrompt(model);
|
||||
});
|
||||
|
||||
it("should execute tools correctly", async () => {
|
||||
await toolExecution(model);
|
||||
});
|
||||
|
||||
it("should handle abort during execution", async () => {
|
||||
await abortExecution(model);
|
||||
});
|
||||
|
||||
it("should emit state updates during streaming", async () => {
|
||||
await stateUpdates(model);
|
||||
});
|
||||
|
||||
it("should maintain context across multiple turns", async () => {
|
||||
await multiTurnConversation(model);
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
describe.skipIf(!hasBedrockCredentials())(
|
||||
"Amazon Bedrock Provider (claude-sonnet-4-5)",
|
||||
() => {
|
||||
const model = getModel(
|
||||
"amazon-bedrock",
|
||||
"global.anthropic.claude-sonnet-4-5-20250929-v1:0",
|
||||
);
|
||||
|
||||
it("should handle basic text prompt", async () => {
|
||||
await basicPrompt(model);
|
||||
});
|
||||
|
||||
it("should execute tools correctly", async () => {
|
||||
await toolExecution(model);
|
||||
});
|
||||
|
||||
it("should handle abort during execution", async () => {
|
||||
await abortExecution(model);
|
||||
});
|
||||
|
||||
it("should emit state updates during streaming", async () => {
|
||||
await stateUpdates(model);
|
||||
});
|
||||
|
||||
it("should maintain context across multiple turns", async () => {
|
||||
await multiTurnConversation(model);
|
||||
});
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
// Tests for Agent.continue(): resuming the agent loop from an existing
// message context without submitting a new prompt.
describe("Agent.continue()", () => {
	// Pure validation tests — no API calls, so these always run.
	describe("validation", () => {
		it("should throw when no messages in context", async () => {
			const agent = new Agent({
				initialState: {
					systemPrompt: "Test",
					model: getModel("anthropic", "claude-haiku-4-5"),
				},
			});

			// Empty context: nothing to continue from.
			await expect(agent.continue()).rejects.toThrow(
				"No messages to continue from",
			);
		});

		it("should throw when last message is assistant", async () => {
			const agent = new Agent({
				initialState: {
					systemPrompt: "Test",
					model: getModel("anthropic", "claude-haiku-4-5"),
				},
			});

			// Minimal well-formed assistant message (zeroed usage/cost) so only
			// the trailing role matters for this validation path.
			const assistantMessage: AssistantMessage = {
				role: "assistant",
				content: [{ type: "text", text: "Hello" }],
				api: "anthropic-messages",
				provider: "anthropic",
				model: "claude-haiku-4-5",
				usage: {
					input: 0,
					output: 0,
					cacheRead: 0,
					cacheWrite: 0,
					totalTokens: 0,
					cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
				},
				stopReason: "stop",
				timestamp: Date.now(),
			};
			agent.replaceMessages([assistantMessage]);

			// NOTE(review): since 0.52.7 continue() can resume queued steering
			// messages from an assistant tail; with nothing queued an assistant
			// tail is presumably still rejected — confirm against Agent.continue().
			await expect(agent.continue()).rejects.toThrow(
				"Cannot continue from message role: assistant",
			);
		});
	});

	// Live-API test: continue() should trigger a completion when the context
	// ends in a user message that was added without calling prompt().
	describe.skipIf(!process.env.ANTHROPIC_API_KEY)(
		"continue from user message",
		() => {
			const model = getModel("anthropic", "claude-haiku-4-5");

			it("should continue and get response when last message is user", async () => {
				const agent = new Agent({
					initialState: {
						systemPrompt:
							"You are a helpful assistant. Follow instructions exactly.",
						model,
						thinkingLevel: "off",
						tools: [],
					},
				});

				// Manually add a user message without calling prompt()
				const userMessage: UserMessage = {
					role: "user",
					content: [{ type: "text", text: "Say exactly: HELLO WORLD" }],
					timestamp: Date.now(),
				};
				agent.replaceMessages([userMessage]);

				// Continue from the user message
				await agent.continue();

				// Exactly one assistant turn should have been appended.
				expect(agent.state.isStreaming).toBe(false);
				expect(agent.state.messages.length).toBe(2);
				expect(agent.state.messages[0].role).toBe("user");
				expect(agent.state.messages[1].role).toBe("assistant");

				const assistantMsg = agent.state.messages[1] as AssistantMessage;
				const textContent = assistantMsg.content.find((c) => c.type === "text");
				expect(textContent).toBeDefined();
				if (textContent?.type === "text") {
					// Case-insensitive check: the model should echo the phrase.
					expect(textContent.text.toUpperCase()).toContain("HELLO WORLD");
				}
			});
		},
	);

	// Live-API test: continue() should pick up from a pending tool result and
	// produce the follow-up assistant response.
	describe.skipIf(!process.env.ANTHROPIC_API_KEY)(
		"continue from tool result",
		() => {
			const model = getModel("anthropic", "claude-haiku-4-5");

			it("should continue and process tool results", async () => {
				const agent = new Agent({
					initialState: {
						systemPrompt:
							"You are a helpful assistant. After getting a calculation result, state the answer clearly.",
						model,
						thinkingLevel: "off",
						tools: [calculateTool],
					},
				});

				// Set up a conversation state as if tool was just executed
				const userMessage: UserMessage = {
					role: "user",
					content: [{ type: "text", text: "What is 5 + 3?" }],
					timestamp: Date.now(),
				};

				// Assistant turn that requested the tool call (stopReason toolUse).
				const assistantMessage: AssistantMessage = {
					role: "assistant",
					content: [
						{ type: "text", text: "Let me calculate that." },
						{
							type: "toolCall",
							id: "calc-1",
							name: "calculate",
							arguments: { expression: "5 + 3" },
						},
					],
					api: "anthropic-messages",
					provider: "anthropic",
					model: "claude-haiku-4-5",
					usage: {
						input: 0,
						output: 0,
						cacheRead: 0,
						cacheWrite: 0,
						totalTokens: 0,
						cost: {
							input: 0,
							output: 0,
							cacheRead: 0,
							cacheWrite: 0,
							total: 0,
						},
					},
					stopReason: "toolUse",
					timestamp: Date.now(),
				};

				// Matching tool result for the "calc-1" call above.
				const toolResult: ToolResultMessage = {
					role: "toolResult",
					toolCallId: "calc-1",
					toolName: "calculate",
					content: [{ type: "text", text: "5 + 3 = 8" }],
					isError: false,
					timestamp: Date.now(),
				};

				agent.replaceMessages([userMessage, assistantMessage, toolResult]);

				// Continue from the tool result
				await agent.continue();

				expect(agent.state.isStreaming).toBe(false);
				// Should have added an assistant response
				expect(agent.state.messages.length).toBeGreaterThanOrEqual(4);

				const lastMessage =
					agent.state.messages[agent.state.messages.length - 1];
				expect(lastMessage.role).toBe("assistant");

				if (lastMessage.role === "assistant") {
					const textContent = lastMessage.content
						.filter((c) => c.type === "text")
						.map((c) => (c as { type: "text"; text: string }).text)
						.join(" ");
					// Should mention 8 in the response
					expect(textContent).toMatch(/8/);
				}
			});
		},
	);
});
|
||||
37
packages/agent/test/utils/calculate.ts
Normal file
37
packages/agent/test/utils/calculate.ts
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
import { type Static, Type } from "@sinclair/typebox";
|
||||
import type { AgentTool, AgentToolResult } from "../../src/types.js";
|
||||
|
||||
// Result shape returned by the calculate test tool: a plain text content
// block and no structured details payload.
export interface CalculateResult extends AgentToolResult<undefined> {
	content: Array<{ type: "text"; text: string }>;
	details: undefined;
}
|
||||
|
||||
export function calculate(expression: string): CalculateResult {
|
||||
try {
|
||||
const result = new Function(`return ${expression}`)();
|
||||
return {
|
||||
content: [{ type: "text", text: `${expression} = ${result}` }],
|
||||
details: undefined,
|
||||
};
|
||||
} catch (e: any) {
|
||||
throw new Error(e.message || String(e));
|
||||
}
|
||||
}
|
||||
|
||||
// TypeBox schema for the calculate tool's arguments.
const calculateSchema = Type.Object({
	expression: Type.String({
		description: "The mathematical expression to evaluate",
	}),
});

// Static (compile-time) type derived from the schema above.
type CalculateParams = Static<typeof calculateSchema>;

// Tool definition wiring the calculate() helper into the agent tool API.
export const calculateTool: AgentTool<typeof calculateSchema, undefined> = {
	label: "Calculator",
	name: "calculate",
	description: "Evaluate mathematical expressions",
	parameters: calculateSchema,
	execute: async (_toolCallId: string, args: CalculateParams) => {
		return calculate(args.expression);
	},
};
|
||||
61
packages/agent/test/utils/get-current-time.ts
Normal file
61
packages/agent/test/utils/get-current-time.ts
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
import { type Static, Type } from "@sinclair/typebox";
|
||||
import type { AgentTool, AgentToolResult } from "../../src/types.js";
|
||||
|
||||
// Result of the get_current_time tool: formatted time string as text content
// plus the raw UTC timestamp (milliseconds since epoch) in details.
export interface GetCurrentTimeResult extends AgentToolResult<{
	utcTimestamp: number;
}> {}
|
||||
|
||||
export async function getCurrentTime(
|
||||
timezone?: string,
|
||||
): Promise<GetCurrentTimeResult> {
|
||||
const date = new Date();
|
||||
if (timezone) {
|
||||
try {
|
||||
const timeStr = date.toLocaleString("en-US", {
|
||||
timeZone: timezone,
|
||||
dateStyle: "full",
|
||||
timeStyle: "long",
|
||||
});
|
||||
return {
|
||||
content: [{ type: "text", text: timeStr }],
|
||||
details: { utcTimestamp: date.getTime() },
|
||||
};
|
||||
} catch (_e) {
|
||||
throw new Error(
|
||||
`Invalid timezone: ${timezone}. Current UTC time: ${date.toISOString()}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
const timeStr = date.toLocaleString("en-US", {
|
||||
dateStyle: "full",
|
||||
timeStyle: "long",
|
||||
});
|
||||
return {
|
||||
content: [{ type: "text", text: timeStr }],
|
||||
details: { utcTimestamp: date.getTime() },
|
||||
};
|
||||
}
|
||||
|
||||
// TypeBox schema for the get_current_time tool's arguments.
const getCurrentTimeSchema = Type.Object({
	timezone: Type.Optional(
		Type.String({
			description:
				"Optional timezone (e.g., 'America/New_York', 'Europe/London')",
		}),
	),
});

// Static (compile-time) type derived from the schema above.
type GetCurrentTimeParams = Static<typeof getCurrentTimeSchema>;

// Tool definition wiring getCurrentTime() into the agent tool API.
export const getCurrentTimeTool: AgentTool<
	typeof getCurrentTimeSchema,
	{ utcTimestamp: number }
> = {
	label: "Current Time",
	name: "get_current_time",
	description: "Get the current date and time",
	parameters: getCurrentTimeSchema,
	execute: async (_toolCallId: string, args: GetCurrentTimeParams) => {
		return getCurrentTime(args.timezone);
	},
};
|
||||
9
packages/agent/tsconfig.build.json
Normal file
9
packages/agent/tsconfig.build.json
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
{
|
||||
"extends": "../../tsconfig.base.json",
|
||||
"compilerOptions": {
|
||||
"outDir": "./dist",
|
||||
"rootDir": "./src"
|
||||
},
|
||||
"include": ["src/**/*.ts"],
|
||||
"exclude": ["node_modules", "dist", "**/*.d.ts", "src/**/*.d.ts"]
|
||||
}
|
||||
9
packages/agent/vitest.config.ts
Normal file
9
packages/agent/vitest.config.ts
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
import { defineConfig } from "vitest/config";
|
||||
|
||||
// Vitest configuration for this package. The E2E tests call live provider
// APIs, hence the generous per-test timeout.
export default defineConfig({
	test: {
		globals: true,
		environment: "node",
		testTimeout: 30000, // 30 seconds for API calls
	},
});
|
||||
Loading…
Add table
Add a link
Reference in a new issue