WIP: Refactor agent package - not compiling

- Renamed AppMessage to AgentMessage throughout
- New agent-loop.ts with AgentLoopContext, AgentLoopConfig
- Removed transport abstraction, Agent now takes streamFn directly
- Extracted streamProxy to proxy.ts utility
- Removed agent-loop from pi-ai (now in agent package)
- Updated consumers (coding-agent, mom) for AgentMessage rename
- Tests updated but some consumers still need migration

Known issues:
- AgentTool, AgentToolResult not exported from pi-ai
- Attachment not exported from pi-agent-core
- ProviderTransport removed but still referenced
- messageTransformer -> convertToLlm migration incomplete
- CustomMessages declaration merging not working properly
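
For orientation, a minimal sketch of the post-refactor consumer API implied by the diffs below: the Agent takes a `streamFn` directly instead of a transport. The package name is an assumption (taken from the known-issues note), the proxy URL and token are placeholders, and `model` is left abstract since `getModel`'s signature isn't shown here.

```typescript
import type { Model } from "@mariozechner/pi-ai";
import { Agent, streamProxy } from "@mariozechner/pi-agent-core"; // package name assumed

declare const model: Model<any>; // e.g. resolved via getModel(...)

const agent = new Agent({
  // Route every LLM call through the proxy server (see proxy.ts below).
  streamFn: (m, ctx, opts) =>
    streamProxy(m, ctx, { ...opts, authToken: "user-token", proxyUrl: "https://genai.example.com" }),
});
agent.setModel(model);
await agent.prompt("Hello");
```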
Mario Zechner 2025-12-28 09:23:38 +01:00
parent f7ef44dc38
commit a055fd4481
32 changed files with 1312 additions and 2009 deletions


@@ -1,24 +1,41 @@
import { streamSimple } from "../stream.js";
import type { AssistantMessage, Context, Message, ToolResultMessage, UserMessage } from "../types.js";
import { EventStream } from "../utils/event-stream.js";
import { validateToolArguments } from "../utils/validation.js";
import type { AgentContext, AgentEvent, AgentLoopConfig, AgentTool, AgentToolResult, QueuedMessage } from "./types.js";
/**
* Agent loop that works with AgentMessage throughout.
* Transforms to Message[] only at the LLM call boundary.
*/
import {
type AssistantMessage,
type Context,
EventStream,
streamSimple,
type ToolResultMessage,
validateToolArguments,
} from "@mariozechner/pi-ai";
import type {
AgentContext,
AgentEvent,
AgentLoopConfig,
AgentMessage,
AgentTool,
AgentToolResult,
StreamFn,
} from "./types.js";
/**
* Start an agent loop with a new user message.
* Start an agent loop with a new prompt message.
* The prompt is added to the context and events are emitted for it.
*/
export function agentLoop(
prompt: UserMessage,
prompt: AgentMessage,
context: AgentContext,
config: AgentLoopConfig,
signal?: AbortSignal,
streamFn?: typeof streamSimple,
): EventStream<AgentEvent, AgentContext["messages"]> {
streamFn?: StreamFn,
): EventStream<AgentEvent, AgentMessage[]> {
const stream = createAgentStream();
(async () => {
const newMessages: AgentContext["messages"] = [prompt];
const newMessages: AgentMessage[] = [prompt];
const currentContext: AgentContext = {
...context,
messages: [...context.messages, prompt],
@@ -37,38 +54,34 @@ export function agentLoop(
/**
* Continue an agent loop from the current context without adding a new message.
* Used for retry after overflow - context already has user message or tool results.
* Throws if the last message is not a user message or tool result.
*/
/**
* Continue an agent loop from the current context without adding a new message.
* Used for retry after overflow - context already has user message or tool results.
* Throws if the last message is not a user message or tool result.
* Used for retries - context already has user message or tool results.
*
* **Important:** The last message in context must convert to a `user` or `toolResult` message
* via `convertToLlm`. If it doesn't, the LLM provider will reject the request.
* This cannot be validated here since `convertToLlm` is only called once per turn.
*/
export function agentLoopContinue(
context: AgentContext,
config: AgentLoopConfig,
signal?: AbortSignal,
streamFn?: typeof streamSimple,
): EventStream<AgentEvent, AgentContext["messages"]> {
// Validate that we can continue from this context
const lastMessage = context.messages[context.messages.length - 1];
if (!lastMessage) {
streamFn?: StreamFn,
): EventStream<AgentEvent, AgentMessage[]> {
if (context.messages.length === 0) {
throw new Error("Cannot continue: no messages in context");
}
if (lastMessage.role !== "user" && lastMessage.role !== "toolResult") {
throw new Error(`Cannot continue from message role: ${lastMessage.role}. Expected 'user' or 'toolResult'.`);
if (context.messages[context.messages.length - 1].role === "assistant") {
throw new Error("Cannot continue from message role: assistant");
}
const stream = createAgentStream();
(async () => {
const newMessages: AgentContext["messages"] = [];
const newMessages: AgentMessage[] = [];
const currentContext: AgentContext = { ...context };
stream.push({ type: "agent_start" });
stream.push({ type: "turn_start" });
// No user message events - we're continuing from existing context
await runLoop(currentContext, newMessages, config, signal, stream, streamFn);
})();
@@ -76,28 +89,28 @@ export function agentLoopContinue(
return stream;
}
function createAgentStream(): EventStream<AgentEvent, AgentContext["messages"]> {
return new EventStream<AgentEvent, AgentContext["messages"]>(
function createAgentStream(): EventStream<AgentEvent, AgentMessage[]> {
return new EventStream<AgentEvent, AgentMessage[]>(
(event: AgentEvent) => event.type === "agent_end",
(event: AgentEvent) => (event.type === "agent_end" ? event.messages : []),
);
}
/**
* Shared loop logic for both agentLoop and agentLoopContinue.
* Main loop logic shared by agentLoop and agentLoopContinue.
*/
async function runLoop(
currentContext: AgentContext,
newMessages: AgentContext["messages"],
newMessages: AgentMessage[],
config: AgentLoopConfig,
signal: AbortSignal | undefined,
stream: EventStream<AgentEvent, AgentContext["messages"]>,
streamFn?: typeof streamSimple,
stream: EventStream<AgentEvent, AgentMessage[]>,
streamFn?: StreamFn,
): Promise<void> {
let hasMoreToolCalls = true;
let firstTurn = true;
let queuedMessages: QueuedMessage<any>[] = (await config.getQueuedMessages?.()) || [];
let queuedAfterTools: QueuedMessage<any>[] | null = null;
let queuedMessages: AgentMessage[] = (await config.getQueuedMessages?.()) || [];
let queuedAfterTools: AgentMessage[] | null = null;
while (hasMoreToolCalls || queuedMessages.length > 0) {
if (!firstTurn) {
@@ -106,15 +119,13 @@ async function runLoop(
firstTurn = false;
}
// Process queued messages first (inject before next assistant response)
// Process queued messages (inject before next assistant response)
if (queuedMessages.length > 0) {
for (const { original, llm } of queuedMessages) {
stream.push({ type: "message_start", message: original });
stream.push({ type: "message_end", message: original });
if (llm) {
currentContext.messages.push(llm);
newMessages.push(llm);
}
for (const message of queuedMessages) {
stream.push({ type: "message_start", message });
stream.push({ type: "message_end", message });
currentContext.messages.push(message);
newMessages.push(message);
}
queuedMessages = [];
}
@@ -124,7 +135,6 @@
newMessages.push(message);
if (message.stopReason === "error" || message.stopReason === "aborted") {
// Stop the loop on error or abort
stream.push({ type: "turn_end", message, toolResults: [] });
stream.push({ type: "agent_end", messages: newMessages });
stream.end(newMessages);
@@ -137,7 +147,6 @@
const toolResults: ToolResultMessage[] = [];
if (hasMoreToolCalls) {
// Execute tool calls
const toolExecution = await executeToolCalls(
currentContext.tools,
message,
@@ -147,10 +156,14 @@
);
toolResults.push(...toolExecution.toolResults);
queuedAfterTools = toolExecution.queuedMessages ?? null;
currentContext.messages.push(...toolResults);
newMessages.push(...toolResults);
for (const result of toolResults) {
currentContext.messages.push(result);
newMessages.push(result);
}
}
stream.push({ type: "turn_end", message, toolResults: toolResults });
stream.push({ type: "turn_end", message, toolResults });
// Get queued messages after turn completes
if (queuedAfterTools && queuedAfterTools.length > 0) {
@@ -165,41 +178,44 @@
stream.end(newMessages);
}
// Helper functions
/**
* Stream an assistant response from the LLM.
* This is where AgentMessage[] gets transformed to Message[] for the LLM.
*/
async function streamAssistantResponse(
context: AgentContext,
config: AgentLoopConfig,
signal: AbortSignal | undefined,
stream: EventStream<AgentEvent, AgentContext["messages"]>,
streamFn?: typeof streamSimple,
stream: EventStream<AgentEvent, AgentMessage[]>,
streamFn?: StreamFn,
): Promise<AssistantMessage> {
// Convert AgentContext to Context for streamSimple
// Use a copy of messages to avoid mutating the original context
const processedMessages = config.preprocessor
? await config.preprocessor(context.messages, signal)
: [...context.messages];
const processedContext: Context = {
// Apply context transform if configured (AgentMessage[] → AgentMessage[])
let messages = context.messages;
if (config.transformContext) {
messages = await config.transformContext(messages, signal);
}
// Convert to LLM-compatible messages (AgentMessage[] → Message[])
const llmMessages = await config.convertToLlm(messages);
// Build LLM context
const llmContext: Context = {
systemPrompt: context.systemPrompt,
messages: [...processedMessages].map((m) => {
if (m.role === "toolResult") {
// biome-ignore lint/correctness/noUnusedVariables: fine here
const { details, ...rest } = m;
return rest;
} else {
return m;
}
}),
tools: context.tools, // AgentTool extends Tool, so this works
messages: llmMessages,
tools: context.tools,
};
// Use custom stream function if provided, otherwise use default streamSimple
const streamFunction = streamFn || streamSimple;
// Resolve API key for every assistant response (important for expiring tokens)
// Resolve API key (important for expiring tokens)
const resolvedApiKey =
(config.getApiKey ? await config.getApiKey(config.model.provider) : undefined) || config.apiKey;
const response = await streamFunction(config.model, processedContext, { ...config, apiKey: resolvedApiKey, signal });
const response = streamFunction(config.model, llmContext, {
...config,
apiKey: resolvedApiKey,
signal,
});
let partialMessage: AssistantMessage | null = null;
let addedPartial = false;
@@ -225,7 +241,11 @@ async function streamAssistantResponse(
if (partialMessage) {
partialMessage = event.partial;
context.messages[context.messages.length - 1] = partialMessage;
stream.push({ type: "message_update", assistantMessageEvent: event, message: { ...partialMessage } });
stream.push({
type: "message_update",
assistantMessageEvent: event,
message: { ...partialMessage },
});
}
break;
@@ -249,16 +269,19 @@
return await response.result();
}
async function executeToolCalls<T>(
tools: AgentTool<any, T>[] | undefined,
/**
* Execute tool calls from an assistant message.
*/
async function executeToolCalls(
tools: AgentTool<any>[] | undefined,
assistantMessage: AssistantMessage,
signal: AbortSignal | undefined,
stream: EventStream<AgentEvent, Message[]>,
stream: EventStream<AgentEvent, AgentMessage[]>,
getQueuedMessages?: AgentLoopConfig["getQueuedMessages"],
): Promise<{ toolResults: ToolResultMessage<T>[]; queuedMessages?: QueuedMessage<any>[] }> {
): Promise<{ toolResults: ToolResultMessage[]; queuedMessages?: AgentMessage[] }> {
const toolCalls = assistantMessage.content.filter((c) => c.type === "toolCall");
const results: ToolResultMessage<any>[] = [];
let queuedMessages: QueuedMessage<any>[] | undefined;
const results: ToolResultMessage[] = [];
let queuedMessages: AgentMessage[] | undefined;
for (let index = 0; index < toolCalls.length; index++) {
const toolCall = toolCalls[index];
@@ -271,16 +294,14 @@ async function executeToolCalls<T>(
args: toolCall.arguments,
});
let result: AgentToolResult<T>;
let result: AgentToolResult<any>;
let isError = false;
try {
if (!tool) throw new Error(`Tool ${toolCall.name} not found`);
// Validate arguments using shared validation function
const validatedArgs = validateToolArguments(tool, toolCall);
// Execute with validated, typed arguments, passing update callback
result = await tool.execute(toolCall.id, validatedArgs, signal, (partialResult) => {
stream.push({
type: "tool_execution_update",
@@ -293,7 +314,7 @@
} catch (e) {
result = {
content: [{ type: "text", text: e instanceof Error ? e.message : String(e) }],
details: {} as T,
details: {},
};
isError = true;
}
@@ -306,7 +327,7 @@
isError,
});
const toolResultMessage: ToolResultMessage<T> = {
const toolResultMessage: ToolResultMessage = {
role: "toolResult",
toolCallId: toolCall.id,
toolName: toolCall.name,
@@ -320,6 +341,7 @@
stream.push({ type: "message_start", message: toolResultMessage });
stream.push({ type: "message_end", message: toolResultMessage });
// Check for queued messages - skip remaining tools if user interrupted
if (getQueuedMessages) {
const queued = await getQueuedMessages();
if (queued.length > 0) {
@@ -336,13 +358,13 @@
return { toolResults: results, queuedMessages };
}
function skipToolCall<T>(
function skipToolCall(
toolCall: Extract<AssistantMessage["content"][number], { type: "toolCall" }>,
stream: EventStream<AgentEvent, Message[]>,
): ToolResultMessage<T> {
const result: AgentToolResult<T> = {
stream: EventStream<AgentEvent, AgentMessage[]>,
): ToolResultMessage {
const result: AgentToolResult<any> = {
content: [{ type: "text", text: "Skipped due to queued user message." }],
details: {} as T,
details: {},
};
stream.push({
@@ -359,12 +381,12 @@ function skipToolCall<T>(
isError: true,
});
const toolResultMessage: ToolResultMessage<T> = {
const toolResultMessage: ToolResultMessage = {
role: "toolResult",
toolCallId: toolCall.id,
toolName: toolCall.name,
content: result.content,
details: result.details,
details: {},
isError: true,
timestamp: Date.now(),
};
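
Taken together, a hedged sketch of driving the loop directly (shapes as defined in this file and types.ts; `context` and `config` left abstract). Note the precondition documented on `agentLoopContinue`: the last context message must convert to a `user` or `toolResult` message.

```typescript
import { agentLoop, agentLoopContinue } from "./agent-loop.js";
import type { AgentContext, AgentLoopConfig, AgentMessage } from "./types.js";

declare const context: AgentContext;
declare const config: AgentLoopConfig;

const prompt: AgentMessage = {
  role: "user",
  content: [{ type: "text", text: "Summarize the repo" }],
  timestamp: Date.now(),
};

for await (const event of agentLoop(prompt, context, config)) {
  if (event.type === "turn_end") console.log("turn done, tool results:", event.toolResults.length);
}

// Retry after overflow: prune the context, then resume without a new prompt.
for await (const event of agentLoopContinue(context, config)) {
  if (event.type === "agent_end") console.log("new messages:", event.messages.length);
}
```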


@@ -1,64 +1,66 @@
import type { ImageContent, Message, QueuedMessage, ReasoningEffort, TextContent } from "@mariozechner/pi-ai";
import { getModel } from "@mariozechner/pi-ai";
import type { AgentTransport } from "./transports/types.js";
import type { AgentEvent, AgentState, AppMessage, Attachment, ThinkingLevel } from "./types.js";
/**
* Agent class that uses the agent-loop directly.
* No transport abstraction - calls streamSimple via the loop.
*/
import {
getModel,
type ImageContent,
type Message,
type Model,
type ReasoningEffort,
streamSimple,
type TextContent,
} from "@mariozechner/pi-ai";
import { agentLoop, agentLoopContinue } from "./agent-loop.js";
import type {
AgentContext,
AgentEvent,
AgentLoopConfig,
AgentMessage,
AgentState,
AgentTool,
StreamFn,
ThinkingLevel,
} from "./types.js";
/**
* Default message transformer: Keep only LLM-compatible messages, strip app-specific fields.
* Converts attachments to proper content blocks (images → ImageContent, documents → TextContent).
* Default convertToLlm: keep only LLM-compatible messages (user, assistant, toolResult).
*/
function defaultMessageTransformer(messages: AppMessage[]): Message[] {
return messages
.filter((m) => {
// Only keep standard LLM message roles
return m.role === "user" || m.role === "assistant" || m.role === "toolResult";
})
.map((m) => {
if (m.role === "user") {
const { attachments, ...rest } = m as any;
// If no attachments, return as-is
if (!attachments || attachments.length === 0) {
return rest as Message;
}
// Convert attachments to content blocks
const content = Array.isArray(rest.content) ? [...rest.content] : [{ type: "text", text: rest.content }];
for (const attachment of attachments as Attachment[]) {
// Add image blocks for image attachments
if (attachment.type === "image") {
content.push({
type: "image",
data: attachment.content,
mimeType: attachment.mimeType,
} as ImageContent);
}
// Add text blocks for documents with extracted text
else if (attachment.type === "document" && attachment.extractedText) {
content.push({
type: "text",
text: `\n\n[Document: ${attachment.fileName}]\n${attachment.extractedText}`,
isDocument: true,
} as TextContent);
}
}
return { ...rest, content } as Message;
}
return m as Message;
});
function defaultConvertToLlm(messages: AgentMessage[]): Message[] {
return messages.filter((m) => m.role === "user" || m.role === "assistant" || m.role === "toolResult");
}
export interface AgentOptions {
initialState?: Partial<AgentState>;
transport: AgentTransport;
// Transform app messages to LLM-compatible messages before sending to transport
messageTransformer?: (messages: AppMessage[]) => Message[] | Promise<Message[]>;
// Called before each LLM call inside the agent loop - can modify messages (e.g., for pruning)
preprocessor?: (messages: Message[]) => Promise<Message[]>;
// Queue mode: "all" = send all queued messages at once, "one-at-a-time" = send one queued message per turn
/**
* Converts AgentMessage[] to LLM-compatible Message[] before each LLM call.
* Default filters to user/assistant/toolResult messages.
*/
convertToLlm?: (messages: AgentMessage[]) => Message[] | Promise<Message[]>;
/**
* Optional transform applied to context before convertToLlm.
* Use for context pruning, injecting external context, etc.
*/
transformContext?: (messages: AgentMessage[], signal?: AbortSignal) => Promise<AgentMessage[]>;
/**
* Queue mode: "all" = send all queued messages at once, "one-at-a-time" = one per turn
*/
queueMode?: "all" | "one-at-a-time";
/**
* Custom stream function (for proxy backends, etc.). Default uses streamSimple.
*/
streamFn?: StreamFn;
/**
* Resolves an API key dynamically for each LLM call.
* Useful for expiring tokens (e.g., GitHub Copilot OAuth).
*/
getApiKey?: (provider: string) => Promise<string | undefined> | string | undefined;
}
export class Agent {
@@ -73,22 +75,25 @@ export class Agent {
pendingToolCalls: new Set<string>(),
error: undefined,
};
private listeners = new Set<(e: AgentEvent) => void>();
private abortController?: AbortController;
private transport: AgentTransport;
private messageTransformer: (messages: AppMessage[]) => Message[] | Promise<Message[]>;
private preprocessor?: (messages: Message[]) => Promise<Message[]>;
private messageQueue: Array<QueuedMessage<AppMessage>> = [];
private convertToLlm: (messages: AgentMessage[]) => Message[] | Promise<Message[]>;
private transformContext?: (messages: AgentMessage[], signal?: AbortSignal) => Promise<AgentMessage[]>;
private messageQueue: AgentMessage[] = [];
private queueMode: "all" | "one-at-a-time";
private streamFn: StreamFn;
private getApiKey?: (provider: string) => Promise<string | undefined> | string | undefined;
private runningPrompt?: Promise<void>;
private resolveRunningPrompt?: () => void;
constructor(opts: AgentOptions) {
constructor(opts: AgentOptions = {}) {
this._state = { ...this._state, ...opts.initialState };
this.transport = opts.transport;
this.messageTransformer = opts.messageTransformer || defaultMessageTransformer;
this.preprocessor = opts.preprocessor;
this.convertToLlm = opts.convertToLlm || defaultConvertToLlm;
this.transformContext = opts.transformContext;
this.queueMode = opts.queueMode || "one-at-a-time";
this.streamFn = opts.streamFn || streamSimple;
this.getApiKey = opts.getApiKey;
}
get state(): AgentState {
@@ -100,12 +105,12 @@ export class Agent {
return () => this.listeners.delete(fn);
}
// State mutators - update internal state without emitting events
// State mutators
setSystemPrompt(v: string) {
this._state.systemPrompt = v;
}
setModel(m: typeof this._state.model) {
setModel(m: Model<any>) {
this._state.model = m;
}
@@ -121,25 +126,20 @@
return this.queueMode;
}
setTools(t: typeof this._state.tools) {
setTools(t: AgentTool<any>[]) {
this._state.tools = t;
}
replaceMessages(ms: AppMessage[]) {
replaceMessages(ms: AgentMessage[]) {
this._state.messages = ms.slice();
}
appendMessage(m: AppMessage) {
appendMessage(m: AgentMessage) {
this._state.messages = [...this._state.messages, m];
}
async queueMessage(m: AppMessage) {
// Transform message and queue it for injection at next turn
const transformed = await this.messageTransformer([m]);
this.messageQueue.push({
original: m,
llm: transformed[0], // undefined if filtered out
});
queueMessage(m: AgentMessage) {
this.messageQueue.push(m);
}
clearMessageQueue() {
@@ -154,17 +154,10 @@
this.abortController?.abort();
}
/**
* Returns a promise that resolves when the current prompt completes.
* Returns immediately resolved promise if no prompt is running.
*/
waitForIdle(): Promise<void> {
return this.runningPrompt ?? Promise.resolve();
}
/**
* Clear all messages and state. Call abort() first if a prompt is in flight.
*/
reset() {
this._state.messages = [];
this._state.isStreaming = false;
@@ -174,99 +167,53 @@
this.messageQueue = [];
}
/** Send a prompt to the agent with an AppMessage. */
async prompt(message: AppMessage): Promise<void>;
/** Send a prompt to the agent with text and optional attachments. */
async prompt(input: string, attachments?: Attachment[]): Promise<void>;
async prompt(input: string | AppMessage, attachments?: Attachment[]) {
/** Send a prompt with an AgentMessage */
async prompt(message: AgentMessage): Promise<void>;
async prompt(input: string, images?: ImageContent[]): Promise<void>;
async prompt(input: string | AgentMessage, images?: ImageContent[]) {
const model = this._state.model;
if (!model) {
throw new Error("No model configured");
}
if (!model) throw new Error("No model configured");
let userMessage: AppMessage;
let userMessage: AgentMessage;
if (typeof input === "string") {
// Build user message from text + attachments
const content: Array<TextContent | ImageContent> = [{ type: "text", text: input }];
if (attachments?.length) {
for (const a of attachments) {
if (a.type === "image") {
content.push({ type: "image", data: a.content, mimeType: a.mimeType });
} else if (a.type === "document" && a.extractedText) {
content.push({
type: "text",
text: `\n\n[Document: ${a.fileName}]\n${a.extractedText}`,
isDocument: true,
} as TextContent);
}
}
if (images && images.length > 0) {
content.push(...images);
}
userMessage = {
role: "user",
content,
attachments: attachments?.length ? attachments : undefined,
timestamp: Date.now(),
};
} else {
// Use provided AppMessage directly
userMessage = input;
}
await this._runAgentLoop(userMessage);
await this._runLoop(userMessage);
}
/**
* Continue from the current context without adding a new user message.
* Used for retry after overflow recovery when context already has user message or tool results.
*/
/** Continue from current context (for retry after overflow) */
async continue() {
const messages = this._state.messages;
if (messages.length === 0) {
throw new Error("No messages to continue from");
}
const lastMessage = messages[messages.length - 1];
if (lastMessage.role !== "user" && lastMessage.role !== "toolResult") {
throw new Error(`Cannot continue from message role: ${lastMessage.role}`);
if (messages[messages.length - 1].role === "assistant") {
throw new Error("Cannot continue from message role: assistant");
}
await this._runAgentLoopContinue();
await this._runLoop(undefined);
}
/**
* Internal: Run the agent loop with a new user message.
* Run the agent loop.
* If userMessage is provided, starts a new conversation turn.
* Otherwise, continues from existing context.
*/
private async _runAgentLoop(userMessage: AppMessage) {
const { llmMessages, cfg } = await this._prepareRun();
// Transform user message (e.g., HookMessage -> user message)
const [transformedUserMessage] = await this.messageTransformer([userMessage]);
const events = this.transport.run(llmMessages, transformedUserMessage, cfg, this.abortController!.signal);
await this._processEvents(events);
}
/**
* Internal: Continue the agent loop from current context.
*/
private async _runAgentLoopContinue() {
const { llmMessages, cfg } = await this._prepareRun();
const events = this.transport.continue(llmMessages, cfg, this.abortController!.signal);
await this._processEvents(events);
}
/**
* Prepare for running the agent loop.
*/
private async _prepareRun() {
private async _runLoop(userMessage?: AgentMessage) {
const model = this._state.model;
if (!model) {
throw new Error("No model configured");
}
if (!model) throw new Error("No model configured");
this.runningPrompt = new Promise<void>((resolve) => {
this.resolveRunningPrompt = resolve;
@@ -282,88 +229,89 @@ export class Agent {
? undefined
: this._state.thinkingLevel === "minimal"
? "low"
: this._state.thinkingLevel;
: (this._state.thinkingLevel as ReasoningEffort);
const cfg = {
const context: AgentContext = {
systemPrompt: this._state.systemPrompt,
messages: this._state.messages.slice(),
tools: this._state.tools,
};
const config: AgentLoopConfig = {
model,
reasoning,
preprocessor: this.preprocessor,
getQueuedMessages: async <T>() => {
convertToLlm: this.convertToLlm,
transformContext: this.transformContext,
getApiKey: this.getApiKey,
getQueuedMessages: async () => {
if (this.queueMode === "one-at-a-time") {
if (this.messageQueue.length > 0) {
const first = this.messageQueue[0];
this.messageQueue = this.messageQueue.slice(1);
return [first] as QueuedMessage<T>[];
return [first];
}
return [];
} else {
const queued = this.messageQueue.slice();
this.messageQueue = [];
return queued as QueuedMessage<T>[];
return queued;
}
},
};
const llmMessages = await this.messageTransformer(this._state.messages);
return { llmMessages, cfg, model };
}
/**
* Process events from the transport.
*/
private async _processEvents(events: AsyncIterable<AgentEvent>) {
const model = this._state.model!;
const generatedMessages: AppMessage[] = [];
let partial: AppMessage | null = null;
let partial: AgentMessage | null = null;
try {
for await (const ev of events) {
switch (ev.type) {
case "message_start": {
partial = ev.message as AppMessage;
this._state.streamMessage = ev.message as Message;
const stream = userMessage
? agentLoop(userMessage, context, config, this.abortController.signal, this.streamFn)
: agentLoopContinue(context, config, this.abortController.signal, this.streamFn);
for await (const event of stream) {
// Update internal state based on events
switch (event.type) {
case "message_start":
partial = event.message;
this._state.streamMessage = event.message;
break;
}
case "message_update": {
partial = ev.message;
this._state.streamMessage = ev.message;
case "message_update":
partial = event.message;
this._state.streamMessage = event.message;
break;
}
case "message_end": {
case "message_end":
partial = null;
this._state.streamMessage = null;
this.appendMessage(ev.message);
generatedMessages.push(ev.message);
this.appendMessage(event.message);
break;
}
case "tool_execution_start": {
const s = new Set(this._state.pendingToolCalls);
s.add(ev.toolCallId);
s.add(event.toolCallId);
this._state.pendingToolCalls = s;
break;
}
case "tool_execution_end": {
const s = new Set(this._state.pendingToolCalls);
s.delete(ev.toolCallId);
s.delete(event.toolCallId);
this._state.pendingToolCalls = s;
break;
}
case "turn_end": {
if (ev.message.role === "assistant" && ev.message.errorMessage) {
this._state.error = ev.message.errorMessage;
case "turn_end":
if (event.message.role === "assistant" && (event.message as any).errorMessage) {
this._state.error = (event.message as any).errorMessage;
}
break;
}
case "agent_end": {
case "agent_end":
this._state.streamMessage = null;
break;
}
}
this.emit(ev as AgentEvent);
// Emit to listeners
this.emit(event);
}
// Handle any remaining partial message
@@ -375,8 +323,7 @@
(c.type === "toolCall" && c.name.trim().length > 0),
);
if (!onlyEmpty) {
this.appendMessage(partial as AppMessage);
generatedMessages.push(partial as AppMessage);
this.appendMessage(partial);
} else {
if (this.abortController?.signal.aborted) {
throw new Error("Request was aborted");
@@ -384,7 +331,7 @@
}
}
} catch (err: any) {
const msg: Message = {
const errorMsg: AgentMessage = {
role: "assistant",
content: [{ type: "text", text: "" }],
api: model.api,
@@ -401,10 +348,11 @@
stopReason: this.abortController?.signal.aborted ? "aborted" : "error",
errorMessage: err?.message || String(err),
timestamp: Date.now(),
};
this.appendMessage(msg as AppMessage);
generatedMessages.push(msg as AppMessage);
} as AgentMessage;
this.appendMessage(errorMsg);
this._state.error = err?.message || String(err);
this.emit({ type: "agent_end", messages: [errorMsg] });
} finally {
this._state.isStreaming = false;
this._state.streamMessage = null;
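
To make the queueing semantics concrete, a hedged sketch: in the default "one-at-a-time" mode one queued message is injected per turn, while "all" flushes the entire queue before the next assistant response; tool calls still pending when a message arrives are skipped.

```typescript
import { Agent } from "./agent.js";

const agent = new Agent({ queueMode: "all" });
// model and tools configured elsewhere...

void agent.prompt("Refactor the parser");

// Interrupt mid-run: the message is emitted as message_start/message_end and
// added to the context before the next LLM call; remaining tool calls in the
// current turn are skipped ("Skipped due to queued user message.").
agent.queueMessage({
  role: "user",
  content: [{ type: "text", text: "Stop and summarize instead." }],
  timestamp: Date.now(),
});

await agent.waitForIdle();
```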


@@ -1,22 +1,6 @@
// Core Agent
export { Agent, type AgentOptions } from "./agent.js";
// Transports
export {
type AgentRunConfig,
type AgentTransport,
AppTransport,
type AppTransportOptions,
ProviderTransport,
type ProviderTransportOptions,
type ProxyAssistantMessageEvent,
} from "./transports/index.js";
export * from "./agent.js";
// Loop functions
export * from "./agent-loop.js";
// Types
export type {
AgentEvent,
AgentState,
AppMessage,
Attachment,
CustomMessages,
ThinkingLevel,
UserMessageWithAttachments,
} from "./types.js";
export * from "./types.js";
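
With everything re-exported wholesale, app-side message extension goes through declaration merging on `CustomAgentMessages` (defined in types.ts below; module name taken from that file's example). A hedged sketch; note the commit's known issues say this merging is not yet working:

```typescript
// Hypothetical UI-only message an app threads through AgentMessage.
interface NotificationMessage {
  role: "notification";
  text: string;
  timestamp: number;
}

declare module "@mariozechner/agent" {
  interface CustomAgentMessages {
    notification: NotificationMessage;
  }
}

// Such messages must be filtered out (or converted) by convertToLlm
// before they reach the LLM.
```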

packages/agent/src/proxy.ts (new file, 340 lines)

@@ -0,0 +1,340 @@
/**
* Proxy stream function for apps that route LLM calls through a server.
* The server manages auth and proxies requests to LLM providers.
*/
import {
type AssistantMessage,
type AssistantMessageEvent,
type Context,
EventStream,
type Model,
type SimpleStreamOptions,
type StopReason,
type ToolCall,
} from "@mariozechner/pi-ai";
// Internal import for JSON parsing utility
import { parseStreamingJson } from "@mariozechner/pi-ai/dist/utils/json-parse.js";
// Local stand-in for pi-ai's AssistantMessageEventStream
class ProxyMessageEventStream extends EventStream<AssistantMessageEvent, AssistantMessage> {
constructor() {
super(
(event) => event.type === "done" || event.type === "error",
(event) => {
if (event.type === "done") return event.message;
if (event.type === "error") return event.error;
throw new Error("Unexpected event type");
},
);
}
}
/**
* Proxy event types - server sends these with partial field stripped to reduce bandwidth.
*/
export type ProxyAssistantMessageEvent =
| { type: "start" }
| { type: "text_start"; contentIndex: number }
| { type: "text_delta"; contentIndex: number; delta: string }
| { type: "text_end"; contentIndex: number; contentSignature?: string }
| { type: "thinking_start"; contentIndex: number }
| { type: "thinking_delta"; contentIndex: number; delta: string }
| { type: "thinking_end"; contentIndex: number; contentSignature?: string }
| { type: "toolcall_start"; contentIndex: number; id: string; toolName: string }
| { type: "toolcall_delta"; contentIndex: number; delta: string }
| { type: "toolcall_end"; contentIndex: number }
| {
type: "done";
reason: Extract<StopReason, "stop" | "length" | "toolUse">;
usage: AssistantMessage["usage"];
}
| {
type: "error";
reason: Extract<StopReason, "aborted" | "error">;
errorMessage?: string;
usage: AssistantMessage["usage"];
};
export interface ProxyStreamOptions extends SimpleStreamOptions {
/** Auth token for the proxy server */
authToken: string;
/** Proxy server URL (e.g., "https://genai.example.com") */
proxyUrl: string;
}
/**
* Stream function that proxies through a server instead of calling LLM providers directly.
* The server strips the partial field from delta events to reduce bandwidth.
* We reconstruct the partial message client-side.
*
* Use this as the `streamFn` option when creating an Agent that needs to go through a proxy.
*
* @example
* ```typescript
* const agent = new Agent({
* streamFn: (model, context, options) =>
* streamProxy(model, context, {
* ...options,
* authToken: await getAuthToken(),
* proxyUrl: "https://genai.example.com",
* }),
* });
* ```
*/
export function streamProxy(model: Model<any>, context: Context, options: ProxyStreamOptions): ProxyMessageEventStream {
const stream = new ProxyMessageEventStream();
(async () => {
// Initialize the partial message that we'll build up from events
const partial: AssistantMessage = {
role: "assistant",
stopReason: "stop",
content: [],
api: model.api,
provider: model.provider,
model: model.id,
usage: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 0,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
},
timestamp: Date.now(),
};
let reader: ReadableStreamDefaultReader<Uint8Array> | undefined;
const abortHandler = () => {
if (reader) {
reader.cancel("Request aborted by user").catch(() => {});
}
};
if (options.signal) {
options.signal.addEventListener("abort", abortHandler);
}
try {
const response = await fetch(`${options.proxyUrl}/api/stream`, {
method: "POST",
headers: {
Authorization: `Bearer ${options.authToken}`,
"Content-Type": "application/json",
},
body: JSON.stringify({
model,
context,
options: {
temperature: options.temperature,
maxTokens: options.maxTokens,
reasoning: options.reasoning,
},
}),
signal: options.signal,
});
if (!response.ok) {
let errorMessage = `Proxy error: ${response.status} ${response.statusText}`;
try {
const errorData = (await response.json()) as { error?: string };
if (errorData.error) {
errorMessage = `Proxy error: ${errorData.error}`;
}
} catch {
// Couldn't parse error response
}
throw new Error(errorMessage);
}
reader = response.body!.getReader();
const decoder = new TextDecoder();
let buffer = "";
while (true) {
const { done, value } = await reader.read();
if (done) break;
if (options.signal?.aborted) {
throw new Error("Request aborted by user");
}
buffer += decoder.decode(value, { stream: true });
const lines = buffer.split("\n");
buffer = lines.pop() || "";
for (const line of lines) {
if (line.startsWith("data: ")) {
const data = line.slice(6).trim();
if (data) {
const proxyEvent = JSON.parse(data) as ProxyAssistantMessageEvent;
const event = processProxyEvent(proxyEvent, partial);
if (event) {
stream.push(event);
}
}
}
}
}
if (options.signal?.aborted) {
throw new Error("Request aborted by user");
}
stream.end();
} catch (error) {
const errorMessage = error instanceof Error ? error.message : String(error);
const reason = options.signal?.aborted ? "aborted" : "error";
partial.stopReason = reason;
partial.errorMessage = errorMessage;
stream.push({
type: "error",
reason,
error: partial,
});
stream.end();
} finally {
if (options.signal) {
options.signal.removeEventListener("abort", abortHandler);
}
}
})();
return stream;
}
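
For reference, the loop above parses standard SSE frames, one JSON-encoded ProxyAssistantMessageEvent per `data:` line (illustrative payloads; usage object abbreviated):

```
data: {"type":"start"}

data: {"type":"text_start","contentIndex":0}

data: {"type":"text_delta","contentIndex":0,"delta":"Hel"}

data: {"type":"done","reason":"stop","usage":{...}}
```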
/**
* Process a proxy event and update the partial message.
*/
function processProxyEvent(
proxyEvent: ProxyAssistantMessageEvent,
partial: AssistantMessage,
): AssistantMessageEvent | undefined {
switch (proxyEvent.type) {
case "start":
return { type: "start", partial };
case "text_start":
partial.content[proxyEvent.contentIndex] = { type: "text", text: "" };
return { type: "text_start", contentIndex: proxyEvent.contentIndex, partial };
case "text_delta": {
const content = partial.content[proxyEvent.contentIndex];
if (content?.type === "text") {
content.text += proxyEvent.delta;
return {
type: "text_delta",
contentIndex: proxyEvent.contentIndex,
delta: proxyEvent.delta,
partial,
};
}
throw new Error("Received text_delta for non-text content");
}
case "text_end": {
const content = partial.content[proxyEvent.contentIndex];
if (content?.type === "text") {
content.textSignature = proxyEvent.contentSignature;
return {
type: "text_end",
contentIndex: proxyEvent.contentIndex,
content: content.text,
partial,
};
}
throw new Error("Received text_end for non-text content");
}
case "thinking_start":
partial.content[proxyEvent.contentIndex] = { type: "thinking", thinking: "" };
return { type: "thinking_start", contentIndex: proxyEvent.contentIndex, partial };
case "thinking_delta": {
const content = partial.content[proxyEvent.contentIndex];
if (content?.type === "thinking") {
content.thinking += proxyEvent.delta;
return {
type: "thinking_delta",
contentIndex: proxyEvent.contentIndex,
delta: proxyEvent.delta,
partial,
};
}
throw new Error("Received thinking_delta for non-thinking content");
}
case "thinking_end": {
const content = partial.content[proxyEvent.contentIndex];
if (content?.type === "thinking") {
content.thinkingSignature = proxyEvent.contentSignature;
return {
type: "thinking_end",
contentIndex: proxyEvent.contentIndex,
content: content.thinking,
partial,
};
}
throw new Error("Received thinking_end for non-thinking content");
}
case "toolcall_start":
partial.content[proxyEvent.contentIndex] = {
type: "toolCall",
id: proxyEvent.id,
name: proxyEvent.toolName,
arguments: {},
partialJson: "",
} satisfies ToolCall & { partialJson: string } as ToolCall;
return { type: "toolcall_start", contentIndex: proxyEvent.contentIndex, partial };
case "toolcall_delta": {
const content = partial.content[proxyEvent.contentIndex];
if (content?.type === "toolCall") {
(content as any).partialJson += proxyEvent.delta;
content.arguments = parseStreamingJson((content as any).partialJson) || {};
partial.content[proxyEvent.contentIndex] = { ...content }; // Trigger reactivity
return {
type: "toolcall_delta",
contentIndex: proxyEvent.contentIndex,
delta: proxyEvent.delta,
partial,
};
}
throw new Error("Received toolcall_delta for non-toolCall content");
}
case "toolcall_end": {
const content = partial.content[proxyEvent.contentIndex];
if (content?.type === "toolCall") {
delete (content as any).partialJson;
return {
type: "toolcall_end",
contentIndex: proxyEvent.contentIndex,
toolCall: content,
partial,
};
}
return undefined;
}
case "done":
partial.stopReason = proxyEvent.reason;
partial.usage = proxyEvent.usage;
return { type: "done", reason: proxyEvent.reason, message: partial };
case "error":
partial.stopReason = proxyEvent.reason;
partial.errorMessage = proxyEvent.errorMessage;
partial.usage = proxyEvent.usage;
return { type: "error", reason: proxyEvent.reason, error: partial };
default: {
const _exhaustiveCheck: never = proxyEvent;
console.warn(`Unhandled proxy event type: ${(proxyEvent as any).type}`);
return undefined;
}
}
}


@@ -1,397 +0,0 @@
import type {
AgentContext,
AgentLoopConfig,
Api,
AssistantMessage,
AssistantMessageEvent,
Context,
Message,
Model,
SimpleStreamOptions,
ToolCall,
UserMessage,
} from "@mariozechner/pi-ai";
import { agentLoop, agentLoopContinue } from "@mariozechner/pi-ai";
import { AssistantMessageEventStream } from "@mariozechner/pi-ai/dist/utils/event-stream.js";
import { parseStreamingJson } from "@mariozechner/pi-ai/dist/utils/json-parse.js";
import type { ProxyAssistantMessageEvent } from "./proxy-types.js";
import type { AgentRunConfig, AgentTransport } from "./types.js";
/**
* Stream function that proxies through a server instead of calling providers directly.
* The server strips the partial field from delta events to reduce bandwidth.
* We reconstruct the partial message client-side.
*/
function streamSimpleProxy(
model: Model<any>,
context: Context,
options: SimpleStreamOptions & { authToken: string },
proxyUrl: string,
): AssistantMessageEventStream {
const stream = new AssistantMessageEventStream();
(async () => {
// Initialize the partial message that we'll build up from events
const partial: AssistantMessage = {
role: "assistant",
stopReason: "stop",
content: [],
api: model.api,
provider: model.provider,
model: model.id,
usage: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 0,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
},
timestamp: Date.now(),
};
let reader: ReadableStreamDefaultReader<Uint8Array> | undefined;
// Set up abort handler to cancel the reader
const abortHandler = () => {
if (reader) {
reader.cancel("Request aborted by user").catch(() => {});
}
};
if (options.signal) {
options.signal.addEventListener("abort", abortHandler);
}
try {
const response = await fetch(`${proxyUrl}/api/stream`, {
method: "POST",
headers: {
Authorization: `Bearer ${options.authToken}`,
"Content-Type": "application/json",
},
body: JSON.stringify({
model,
context,
options: {
temperature: options.temperature,
maxTokens: options.maxTokens,
reasoning: options.reasoning,
// Don't send apiKey or signal - those are added server-side
},
}),
signal: options.signal,
});
if (!response.ok) {
let errorMessage = `Proxy error: ${response.status} ${response.statusText}`;
try {
const errorData = (await response.json()) as { error?: string };
if (errorData.error) {
errorMessage = `Proxy error: ${errorData.error}`;
}
} catch {
// Couldn't parse error response, use default message
}
throw new Error(errorMessage);
}
// Parse SSE stream
reader = response.body!.getReader();
const decoder = new TextDecoder();
let buffer = "";
while (true) {
const { done, value } = await reader.read();
if (done) break;
// Check if aborted after reading
if (options.signal?.aborted) {
throw new Error("Request aborted by user");
}
buffer += decoder.decode(value, { stream: true });
const lines = buffer.split("\n");
buffer = lines.pop() || "";
for (const line of lines) {
if (line.startsWith("data: ")) {
const data = line.slice(6).trim();
if (data) {
const proxyEvent = JSON.parse(data) as ProxyAssistantMessageEvent;
let event: AssistantMessageEvent | undefined;
// Handle different event types
// Server sends events with partial for non-delta events,
// and without partial for delta events
switch (proxyEvent.type) {
case "start":
event = { type: "start", partial };
break;
case "text_start":
partial.content[proxyEvent.contentIndex] = {
type: "text",
text: "",
};
event = { type: "text_start", contentIndex: proxyEvent.contentIndex, partial };
break;
case "text_delta": {
const content = partial.content[proxyEvent.contentIndex];
if (content?.type === "text") {
content.text += proxyEvent.delta;
event = {
type: "text_delta",
contentIndex: proxyEvent.contentIndex,
delta: proxyEvent.delta,
partial,
};
} else {
throw new Error("Received text_delta for non-text content");
}
break;
}
case "text_end": {
const content = partial.content[proxyEvent.contentIndex];
if (content?.type === "text") {
content.textSignature = proxyEvent.contentSignature;
event = {
type: "text_end",
contentIndex: proxyEvent.contentIndex,
content: content.text,
partial,
};
} else {
throw new Error("Received text_end for non-text content");
}
break;
}
case "thinking_start":
partial.content[proxyEvent.contentIndex] = {
type: "thinking",
thinking: "",
};
event = { type: "thinking_start", contentIndex: proxyEvent.contentIndex, partial };
break;
case "thinking_delta": {
const content = partial.content[proxyEvent.contentIndex];
if (content?.type === "thinking") {
content.thinking += proxyEvent.delta;
event = {
type: "thinking_delta",
contentIndex: proxyEvent.contentIndex,
delta: proxyEvent.delta,
partial,
};
} else {
throw new Error("Received thinking_delta for non-thinking content");
}
break;
}
case "thinking_end": {
const content = partial.content[proxyEvent.contentIndex];
if (content?.type === "thinking") {
content.thinkingSignature = proxyEvent.contentSignature;
event = {
type: "thinking_end",
contentIndex: proxyEvent.contentIndex,
content: content.thinking,
partial,
};
} else {
throw new Error("Received thinking_end for non-thinking content");
}
break;
}
case "toolcall_start":
partial.content[proxyEvent.contentIndex] = {
type: "toolCall",
id: proxyEvent.id,
name: proxyEvent.toolName,
arguments: {},
partialJson: "",
} satisfies ToolCall & { partialJson: string } as ToolCall;
event = { type: "toolcall_start", contentIndex: proxyEvent.contentIndex, partial };
break;
case "toolcall_delta": {
const content = partial.content[proxyEvent.contentIndex];
if (content?.type === "toolCall") {
(content as any).partialJson += proxyEvent.delta;
content.arguments = parseStreamingJson((content as any).partialJson) || {};
event = {
type: "toolcall_delta",
contentIndex: proxyEvent.contentIndex,
delta: proxyEvent.delta,
partial,
};
partial.content[proxyEvent.contentIndex] = { ...content }; // Trigger reactivity
} else {
throw new Error("Received toolcall_delta for non-toolCall content");
}
break;
}
case "toolcall_end": {
const content = partial.content[proxyEvent.contentIndex];
if (content?.type === "toolCall") {
delete (content as any).partialJson;
event = {
type: "toolcall_end",
contentIndex: proxyEvent.contentIndex,
toolCall: content,
partial,
};
}
break;
}
case "done":
partial.stopReason = proxyEvent.reason;
partial.usage = proxyEvent.usage;
event = { type: "done", reason: proxyEvent.reason, message: partial };
break;
case "error":
partial.stopReason = proxyEvent.reason;
partial.errorMessage = proxyEvent.errorMessage;
partial.usage = proxyEvent.usage;
event = { type: "error", reason: proxyEvent.reason, error: partial };
break;
default: {
// Exhaustive check
const _exhaustiveCheck: never = proxyEvent;
console.warn(`Unhandled event type: ${(proxyEvent as any).type}`);
break;
}
}
// Push the event to stream
if (event) {
stream.push(event);
} else {
throw new Error("Failed to create event from proxy event");
}
}
}
}
}
// Check if aborted after reading
if (options.signal?.aborted) {
throw new Error("Request aborted by user");
}
stream.end();
} catch (error) {
const errorMessage = error instanceof Error ? error.message : String(error);
partial.stopReason = options.signal?.aborted ? "aborted" : "error";
partial.errorMessage = errorMessage;
stream.push({
type: "error",
reason: partial.stopReason,
error: partial,
} satisfies AssistantMessageEvent);
stream.end();
} finally {
// Clean up abort handler
if (options.signal) {
options.signal.removeEventListener("abort", abortHandler);
}
}
})();
return stream;
}
export interface AppTransportOptions {
/**
* Proxy server URL. The server manages user accounts and proxies requests to LLM providers.
* Example: "https://genai.mariozechner.at"
*/
proxyUrl: string;
/**
* Function to retrieve auth token for the proxy server.
* The token is used for user authentication and authorization.
*/
getAuthToken: () => Promise<string> | string;
}
/**
* Transport that uses an app server with user authentication tokens.
* The server manages user accounts and proxies requests to LLM providers.
*/
export class AppTransport implements AgentTransport {
private options: AppTransportOptions;
constructor(options: AppTransportOptions) {
this.options = options;
}
private async getStreamFn(authToken: string) {
return <TApi extends Api>(model: Model<TApi>, context: Context, options?: SimpleStreamOptions) => {
return streamSimpleProxy(
model,
context,
{
...options,
authToken,
},
this.options.proxyUrl,
);
};
}
private buildContext(messages: Message[], cfg: AgentRunConfig): AgentContext {
return {
systemPrompt: cfg.systemPrompt,
messages,
tools: cfg.tools,
};
}
private buildLoopConfig(cfg: AgentRunConfig): AgentLoopConfig {
return {
model: cfg.model,
reasoning: cfg.reasoning,
getQueuedMessages: cfg.getQueuedMessages,
};
}
async *run(messages: Message[], userMessage: Message, cfg: AgentRunConfig, signal?: AbortSignal) {
const authToken = await this.options.getAuthToken();
if (!authToken) {
throw new Error("Auth token is required for AppTransport");
}
const streamFn = await this.getStreamFn(authToken);
const context = this.buildContext(messages, cfg);
const pc = this.buildLoopConfig(cfg);
for await (const ev of agentLoop(userMessage as unknown as UserMessage, context, pc, signal, streamFn as any)) {
yield ev;
}
}
async *continue(messages: Message[], cfg: AgentRunConfig, signal?: AbortSignal) {
const authToken = await this.options.getAuthToken();
if (!authToken) {
throw new Error("Auth token is required for AppTransport");
}
const streamFn = await this.getStreamFn(authToken);
const context = this.buildContext(messages, cfg);
const pc = this.buildLoopConfig(cfg);
for await (const ev of agentLoopContinue(context, pc, signal, streamFn as any)) {
yield ev;
}
}
}


@@ -1,86 +0,0 @@
import {
type AgentContext,
type AgentLoopConfig,
agentLoop,
agentLoopContinue,
type Message,
type UserMessage,
} from "@mariozechner/pi-ai";
import type { AgentRunConfig, AgentTransport } from "./types.js";
export interface ProviderTransportOptions {
/**
* Function to retrieve API key for a given provider.
* If not provided, transport will try to use environment variables.
*/
getApiKey?: (provider: string) => Promise<string | undefined> | string | undefined;
/**
* Optional CORS proxy URL for browser environments.
* If provided, all requests will be routed through this proxy.
* Format: "https://proxy.example.com"
*/
corsProxyUrl?: string;
}
/**
* Transport that calls LLM providers directly.
* Optionally routes calls through a CORS proxy if configured.
*/
export class ProviderTransport implements AgentTransport {
private options: ProviderTransportOptions;
constructor(options: ProviderTransportOptions = {}) {
this.options = options;
}
private getModel(cfg: AgentRunConfig) {
let model = cfg.model;
if (this.options.corsProxyUrl && cfg.model.baseUrl) {
model = {
...cfg.model,
baseUrl: `${this.options.corsProxyUrl}/?url=${encodeURIComponent(cfg.model.baseUrl)}`,
};
}
return model;
}
private buildContext(messages: Message[], cfg: AgentRunConfig): AgentContext {
return {
systemPrompt: cfg.systemPrompt,
messages,
tools: cfg.tools,
};
}
private buildLoopConfig(model: AgentRunConfig["model"], cfg: AgentRunConfig): AgentLoopConfig {
return {
model,
reasoning: cfg.reasoning,
// Resolve API key per assistant response (important for expiring OAuth tokens)
getApiKey: this.options.getApiKey,
getQueuedMessages: cfg.getQueuedMessages,
preprocessor: cfg.preprocessor,
};
}
async *run(messages: Message[], userMessage: Message, cfg: AgentRunConfig, signal?: AbortSignal) {
const model = this.getModel(cfg);
const context = this.buildContext(messages, cfg);
const pc = this.buildLoopConfig(model, cfg);
for await (const ev of agentLoop(userMessage as unknown as UserMessage, context, pc, signal)) {
yield ev;
}
}
async *continue(messages: Message[], cfg: AgentRunConfig, signal?: AbortSignal) {
const model = this.getModel(cfg);
const context = this.buildContext(messages, cfg);
const pc = this.buildLoopConfig(model, cfg);
for await (const ev of agentLoopContinue(context, pc, signal)) {
yield ev;
}
}
}


@@ -1,4 +0,0 @@
export { AppTransport, type AppTransportOptions } from "./AppTransport.js";
export { ProviderTransport, type ProviderTransportOptions } from "./ProviderTransport.js";
export type { ProxyAssistantMessageEvent } from "./proxy-types.js";
export type { AgentRunConfig, AgentTransport } from "./types.js";


@@ -1,20 +0,0 @@
import type { StopReason, Usage } from "@mariozechner/pi-ai";
/**
* Event types emitted by the proxy server.
* The server strips the `partial` field from delta events to reduce bandwidth.
* Clients reconstruct the partial message from these events.
*/
export type ProxyAssistantMessageEvent =
| { type: "start" }
| { type: "text_start"; contentIndex: number }
| { type: "text_delta"; contentIndex: number; delta: string }
| { type: "text_end"; contentIndex: number; contentSignature?: string }
| { type: "thinking_start"; contentIndex: number }
| { type: "thinking_delta"; contentIndex: number; delta: string }
| { type: "thinking_end"; contentIndex: number; contentSignature?: string }
| { type: "toolcall_start"; contentIndex: number; id: string; toolName: string }
| { type: "toolcall_delta"; contentIndex: number; delta: string }
| { type: "toolcall_end"; contentIndex: number }
| { type: "done"; reason: Extract<StopReason, "stop" | "length" | "toolUse">; usage: Usage }
| { type: "error"; reason: Extract<StopReason, "aborted" | "error">; errorMessage: string; usage: Usage };


@@ -1,34 +0,0 @@
import type { AgentEvent, AgentTool, Message, Model, QueuedMessage, ReasoningEffort } from "@mariozechner/pi-ai";
/**
* The minimal configuration needed to run an agent turn.
*/
export interface AgentRunConfig {
systemPrompt: string;
tools: AgentTool<any>[];
model: Model<any>;
reasoning?: ReasoningEffort;
getQueuedMessages?: <T>() => Promise<QueuedMessage<T>[]>;
/** Called before each LLM call - can modify messages (e.g., for pruning) */
preprocessor?: (messages: Message[]) => Promise<Message[]>;
}
/**
* Transport interface for executing agent turns.
* Transports handle the communication with LLM providers,
* abstracting away the details of API calls, proxies, etc.
*
* Events yielded must match the @mariozechner/pi-ai AgentEvent types.
*/
export interface AgentTransport {
/** Run with a new user message */
run(
messages: Message[],
userMessage: Message,
config: AgentRunConfig,
signal?: AbortSignal,
): AsyncIterable<AgentEvent>;
/** Continue from current context (no new user message) */
continue(messages: Message[], config: AgentRunConfig, signal?: AbortSignal): AsyncIterable<AgentEvent>;
}


@@ -1,26 +1,83 @@
import type {
AgentTool,
AssistantMessage,
AssistantMessageEvent,
ImageContent,
Message,
Model,
SimpleStreamOptions,
streamSimple,
TextContent,
Tool,
ToolResultMessage,
UserMessage,
} from "@mariozechner/pi-ai";
import type { Static, TSchema } from "@sinclair/typebox";
export type StreamFn = typeof streamSimple;
/**
* Attachment type definition.
* Processing is done by consumers (e.g., document extraction in web-ui).
* Configuration for the agent loop.
*/
export interface Attachment {
id: string;
type: "image" | "document";
fileName: string;
mimeType: string;
size: number;
content: string; // base64 encoded (without data URL prefix)
extractedText?: string; // For documents
preview?: string; // base64 image preview
export interface AgentLoopConfig extends SimpleStreamOptions {
model: Model<any>;
/**
* Converts AgentMessage[] to LLM-compatible Message[] before each LLM call.
*
* Each AgentMessage must be converted to a UserMessage, AssistantMessage, or ToolResultMessage
* that the LLM can understand. AgentMessages that cannot be converted (e.g., UI-only notifications,
* status messages) should be filtered out.
*
* @example
* ```typescript
* convertToLlm: (messages) => messages.flatMap(m => {
* if (m.role === "hookMessage") {
* // Convert custom message to user message
* return [{ role: "user", content: m.content, timestamp: m.timestamp }];
* }
* if (m.role === "notification") {
* // Filter out UI-only messages
* return [];
* }
* // Pass through standard LLM messages
* return [m];
* })
* ```
*/
convertToLlm: (messages: AgentMessage[]) => Message[] | Promise<Message[]>;
/**
* Optional transform applied to the context before `convertToLlm`.
*
* Use this for operations that work at the AgentMessage level:
* - Context window management (pruning old messages)
* - Injecting context from external sources
*
* @example
* ```typescript
* transformContext: async (messages) => {
* if (estimateTokens(messages) > MAX_TOKENS) {
* return pruneOldMessages(messages);
* }
* return messages;
* }
* ```
*/
transformContext?: (messages: AgentMessage[], signal?: AbortSignal) => Promise<AgentMessage[]>;
/**
* Resolves an API key dynamically for each LLM call.
*
* Useful for short-lived OAuth tokens (e.g., GitHub Copilot) that may expire
* during long-running tool execution phases.
*/
getApiKey?: (provider: string) => Promise<string | undefined> | string | undefined;
/**
* Returns queued messages to inject into the conversation.
*
* Called after each turn to check for user interruptions or injected messages.
* If messages are returned, they're added to the context before the next LLM call.
*/
getQueuedMessages?: () => Promise<AgentMessage[]>;
}
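
Putting the config together, a hedged sketch (model left abstract; the queue is a plain array an app would feed from its UI):

```typescript
import type { Message, Model } from "@mariozechner/pi-ai";
import type { AgentLoopConfig, AgentMessage } from "./types.js";

declare const model: Model<any>;
const queue: AgentMessage[] = [];

const config: AgentLoopConfig = {
  model,
  convertToLlm: (messages) =>
    messages.filter((m): m is Message => m.role === "user" || m.role === "assistant" || m.role === "toolResult"),
  transformContext: async (messages) => messages.slice(-50), // naive pruning
  getQueuedMessages: async () => queue.splice(0, queue.length),
};
```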
/**
@@ -29,11 +86,6 @@ export interface Attachment {
*/
export type ThinkingLevel = "off" | "minimal" | "low" | "medium" | "high" | "xhigh";
/**
* User message with optional attachments.
*/
export type UserMessageWithAttachments = UserMessage & { attachments?: Attachment[] };
/**
* Extensible interface for custom app messages.
* Apps can extend via declaration merging:
@@ -41,27 +93,23 @@ export type UserMessageWithAttachments = UserMessage & { attachments?: Attachment[] };
* @example
* ```typescript
* declare module "@mariozechner/agent" {
* interface CustomMessages {
* interface CustomAgentMessages {
* artifact: ArtifactMessage;
* notification: NotificationMessage;
* }
* }
* ```
*/
export interface CustomMessages {
export interface CustomAgentMessages {
// Empty by default - apps extend via declaration merging
}
/**
* AppMessage: Union of LLM messages + attachments + custom messages.
* AgentMessage: Union of LLM messages + custom messages.
* This abstraction allows apps to add custom message types while maintaining
* type safety and compatibility with the base LLM messages.
*/
export type AppMessage =
| AssistantMessage
| UserMessageWithAttachments
| Message // Includes ToolResultMessage
| CustomMessages[keyof CustomMessages];
export type AgentMessage = Message | CustomAgentMessages[keyof CustomAgentMessages];
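
Since Attachment and UserMessageWithAttachments are gone, attachment handling moves app-side: a custom field on user messages, expanded in `convertToLlm`, mirroring the deleted defaultMessageTransformer. A hedged sketch with hypothetical names:

```typescript
import type { ImageContent, Message, UserMessage } from "@mariozechner/pi-ai";

// Hypothetical app-side replacement for the removed attachment plumbing.
type AttachedImage = { content: string; mimeType: string }; // base64 payload
type UserMessageWithImages = UserMessage & { images?: AttachedImage[] };

function convertToLlm(messages: (Message | UserMessageWithImages)[]): Message[] {
  return messages.map((m) => {
    if (m.role !== "user" || !("images" in m) || !m.images?.length) return m as Message;
    const { images, ...rest } = m;
    const blocks: ImageContent[] = images.map((a) => ({ type: "image", data: a.content, mimeType: a.mimeType }));
    const content = Array.isArray(rest.content)
      ? [...rest.content, ...blocks]
      : [{ type: "text" as const, text: rest.content }, ...blocks];
    return { ...rest, content } as Message;
  });
}
```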
/**
* Agent state containing all configuration and conversation data.
@@ -71,13 +119,42 @@ export interface AgentState {
model: Model<any>;
thinkingLevel: ThinkingLevel;
tools: AgentTool<any>[];
messages: AppMessage[]; // Can include attachments + custom message types
messages: AgentMessage[]; // Can include custom message types
isStreaming: boolean;
streamMessage: AppMessage | null;
streamMessage: AgentMessage | null;
pendingToolCalls: Set<string>;
error?: string;
}
export interface AgentToolResult<T> {
// Content blocks supporting text and images
content: (TextContent | ImageContent)[];
// Details to be displayed in a UI or logged
details: T;
}
// Callback for streaming tool execution updates
export type AgentToolUpdateCallback<T = any> = (partialResult: AgentToolResult<T>) => void;
// AgentTool extends Tool but adds the execute function
export interface AgentTool<TParameters extends TSchema = TSchema, TDetails = any> extends Tool<TParameters> {
// A human-readable label for the tool to be displayed in UI
label: string;
execute: (
toolCallId: string,
params: Static<TParameters>,
signal?: AbortSignal,
onUpdate?: AgentToolUpdateCallback<TDetails>,
) => Promise<AgentToolResult<TDetails>>;
}
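// Illustrative sketch (not part of this commit): a minimal AgentTool, mirroring
// the echo tool used in the tests below.
//
//   const echoSchema = Type.Object({ value: Type.String() });
//   const echoTool: AgentTool<typeof echoSchema, { value: string }> = {
//     name: "echo",
//     label: "Echo",
//     description: "Echo tool",
//     parameters: echoSchema,
//     async execute(_toolCallId, params) {
//       return {
//         content: [{ type: "text", text: `echoed: ${params.value}` }],
//         details: { value: params.value },
//       };
//     },
//   };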
// AgentContext is like Context but uses AgentTool
export interface AgentContext {
systemPrompt: string;
messages: AgentMessage[];
tools?: AgentTool<any>[];
}
/**
* Events emitted by the Agent for UI updates.
* These events provide fine-grained lifecycle information for messages, turns, and tool executions.
@ -85,15 +162,15 @@ export interface AgentState {
export type AgentEvent =
// Agent lifecycle
| { type: "agent_start" }
| { type: "agent_end"; messages: AppMessage[] }
| { type: "agent_end"; messages: AgentMessage[] }
// Turn lifecycle - a turn is one assistant response + any tool calls/results
| { type: "turn_start" }
| { type: "turn_end"; message: AppMessage; toolResults: ToolResultMessage[] }
| { type: "turn_end"; message: AgentMessage; toolResults: ToolResultMessage[] }
// Message lifecycle - emitted for user, assistant, and toolResult messages
| { type: "message_start"; message: AppMessage }
| { type: "message_start"; message: AgentMessage }
// Only emitted for assistant messages during streaming
| { type: "message_update"; message: AppMessage; assistantMessageEvent: AssistantMessageEvent }
| { type: "message_end"; message: AppMessage }
| { type: "message_update"; message: AgentMessage; assistantMessageEvent: AssistantMessageEvent }
| { type: "message_end"; message: AgentMessage }
// Tool execution lifecycle
| { type: "tool_execution_start"; toolCallId: string; toolName: string; args: any }
| { type: "tool_execution_update"; toolCallId: string; toolName: string; args: any; partialResult: any }


@ -0,0 +1,535 @@
import {
type AssistantMessage,
type AssistantMessageEvent,
EventStream,
type Message,
type Model,
type UserMessage,
} from "@mariozechner/pi-ai";
import { Type } from "@sinclair/typebox";
import { describe, expect, it } from "vitest";
import { agentLoop, agentLoopContinue } from "../src/agent-loop.js";
import type { AgentContext, AgentEvent, AgentLoopConfig, AgentMessage, AgentTool } from "../src/types.js";
// Mock stream for testing - mimics pi-ai's AssistantMessageEventStream
class MockAssistantStream extends EventStream<AssistantMessageEvent, AssistantMessage> {
constructor() {
super(
(event) => event.type === "done" || event.type === "error",
(event) => {
if (event.type === "done") return event.message;
if (event.type === "error") return event.error;
throw new Error("Unexpected event type");
},
);
}
}
function createUsage() {
return {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 0,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
};
}
function createModel(): Model<"openai-responses"> {
return {
id: "mock",
name: "mock",
api: "openai-responses",
provider: "openai",
baseUrl: "https://example.invalid",
reasoning: false,
input: ["text"],
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
contextWindow: 8192,
maxTokens: 2048,
};
}
function createAssistantMessage(
content: AssistantMessage["content"],
stopReason: AssistantMessage["stopReason"] = "stop",
): AssistantMessage {
return {
role: "assistant",
content,
api: "openai-responses",
provider: "openai",
model: "mock",
usage: createUsage(),
stopReason,
timestamp: Date.now(),
};
}
function createUserMessage(text: string): UserMessage {
return {
role: "user",
content: text,
timestamp: Date.now(),
};
}
// Simple identity converter for tests - just passes through standard messages
function identityConverter(messages: AgentMessage[]): Message[] {
return messages.filter((m) => m.role === "user" || m.role === "assistant" || m.role === "toolResult") as Message[];
}
describe("agentLoop with AgentMessage", () => {
it("should emit events with AgentMessage types", async () => {
const context: AgentContext = {
systemPrompt: "You are helpful.",
messages: [],
tools: [],
};
const userPrompt: AgentMessage = createUserMessage("Hello");
const config: AgentLoopConfig = {
model: createModel(),
convertToLlm: identityConverter,
};
const streamFn = () => {
const stream = new MockAssistantStream();
queueMicrotask(() => {
const message = createAssistantMessage([{ type: "text", text: "Hi there!" }]);
stream.push({ type: "done", reason: "stop", message });
});
return stream;
};
const events: AgentEvent[] = [];
const stream = agentLoop(userPrompt, context, config, undefined, streamFn);
for await (const event of stream) {
events.push(event);
}
const messages = await stream.result();
// Should have user message and assistant message
expect(messages.length).toBe(2);
expect(messages[0].role).toBe("user");
expect(messages[1].role).toBe("assistant");
// Verify event sequence
const eventTypes = events.map((e) => e.type);
expect(eventTypes).toContain("agent_start");
expect(eventTypes).toContain("turn_start");
expect(eventTypes).toContain("message_start");
expect(eventTypes).toContain("message_end");
expect(eventTypes).toContain("turn_end");
expect(eventTypes).toContain("agent_end");
});
it("should handle custom message types via convertToLlm", async () => {
// Create a custom message type
interface CustomNotification {
role: "notification";
text: string;
timestamp: number;
}
const notification: CustomNotification = {
role: "notification",
text: "This is a notification",
timestamp: Date.now(),
};
const context: AgentContext = {
systemPrompt: "You are helpful.",
messages: [notification as unknown as AgentMessage], // Custom message in context
tools: [],
};
const userPrompt: AgentMessage = createUserMessage("Hello");
let convertedMessages: Message[] = [];
const config: AgentLoopConfig = {
model: createModel(),
convertToLlm: (messages) => {
// Filter out notifications, convert rest
convertedMessages = messages
.filter((m) => (m as { role: string }).role !== "notification")
.filter((m) => m.role === "user" || m.role === "assistant" || m.role === "toolResult") as Message[];
return convertedMessages;
},
};
const streamFn = () => {
const stream = new MockAssistantStream();
queueMicrotask(() => {
const message = createAssistantMessage([{ type: "text", text: "Response" }]);
stream.push({ type: "done", reason: "stop", message });
});
return stream;
};
const events: AgentEvent[] = [];
const stream = agentLoop(userPrompt, context, config, undefined, streamFn);
for await (const event of stream) {
events.push(event);
}
// The notification should have been filtered out in convertToLlm
expect(convertedMessages.length).toBe(1); // Only user message
expect(convertedMessages[0].role).toBe("user");
});
it("should apply transformContext before convertToLlm", async () => {
const context: AgentContext = {
systemPrompt: "You are helpful.",
messages: [
createUserMessage("old message 1"),
createAssistantMessage([{ type: "text", text: "old response 1" }]),
createUserMessage("old message 2"),
createAssistantMessage([{ type: "text", text: "old response 2" }]),
],
tools: [],
};
const userPrompt: AgentMessage = createUserMessage("new message");
let transformedMessages: AgentMessage[] = [];
let convertedMessages: Message[] = [];
const config: AgentLoopConfig = {
model: createModel(),
transformContext: async (messages) => {
// Keep only last 2 messages (prune old ones)
transformedMessages = messages.slice(-2);
return transformedMessages;
},
convertToLlm: (messages) => {
convertedMessages = messages.filter(
(m) => m.role === "user" || m.role === "assistant" || m.role === "toolResult",
) as Message[];
return convertedMessages;
},
};
const streamFn = () => {
const stream = new MockAssistantStream();
queueMicrotask(() => {
const message = createAssistantMessage([{ type: "text", text: "Response" }]);
stream.push({ type: "done", reason: "stop", message });
});
return stream;
};
const stream = agentLoop(userPrompt, context, config, undefined, streamFn);
for await (const _ of stream) {
// consume
}
// transformContext should have been called first, keeping only last 2
expect(transformedMessages.length).toBe(2);
// Then convertToLlm receives the pruned messages
expect(convertedMessages.length).toBe(2);
});
it("should handle tool calls and results", async () => {
const toolSchema = Type.Object({ value: Type.String() });
const executed: string[] = [];
const tool: AgentTool<typeof toolSchema, { value: string }> = {
name: "echo",
label: "Echo",
description: "Echo tool",
parameters: toolSchema,
async execute(_toolCallId, params) {
executed.push(params.value);
return {
content: [{ type: "text", text: `echoed: ${params.value}` }],
details: { value: params.value },
};
},
};
const context: AgentContext = {
systemPrompt: "",
messages: [],
tools: [tool],
};
const userPrompt: AgentMessage = createUserMessage("echo something");
const config: AgentLoopConfig = {
model: createModel(),
convertToLlm: identityConverter,
};
let callIndex = 0;
const streamFn = () => {
const stream = new MockAssistantStream();
queueMicrotask(() => {
if (callIndex === 0) {
// First call: return tool call
const message = createAssistantMessage(
[{ type: "toolCall", id: "tool-1", name: "echo", arguments: { value: "hello" } }],
"toolUse",
);
stream.push({ type: "done", reason: "toolUse", message });
} else {
// Second call: return final response
const message = createAssistantMessage([{ type: "text", text: "done" }]);
stream.push({ type: "done", reason: "stop", message });
}
callIndex++;
});
return stream;
};
const events: AgentEvent[] = [];
const stream = agentLoop(userPrompt, context, config, undefined, streamFn);
for await (const event of stream) {
events.push(event);
}
// Tool should have been executed
expect(executed).toEqual(["hello"]);
// Should have tool execution events
const toolStart = events.find((e) => e.type === "tool_execution_start");
const toolEnd = events.find((e) => e.type === "tool_execution_end");
expect(toolStart).toBeDefined();
expect(toolEnd).toBeDefined();
if (toolEnd?.type === "tool_execution_end") {
expect(toolEnd.isError).toBe(false);
}
});
it("should inject queued messages and skip remaining tool calls", async () => {
const toolSchema = Type.Object({ value: Type.String() });
const executed: string[] = [];
const tool: AgentTool<typeof toolSchema, { value: string }> = {
name: "echo",
label: "Echo",
description: "Echo tool",
parameters: toolSchema,
async execute(_toolCallId, params) {
executed.push(params.value);
return {
content: [{ type: "text", text: `ok:${params.value}` }],
details: { value: params.value },
};
},
};
const context: AgentContext = {
systemPrompt: "",
messages: [],
tools: [tool],
};
const userPrompt: AgentMessage = createUserMessage("start");
const queuedUserMessage: AgentMessage = createUserMessage("interrupt");
let queuedDelivered = false;
let callIndex = 0;
let sawInterruptInContext = false;
const config: AgentLoopConfig = {
model: createModel(),
convertToLlm: identityConverter,
getQueuedMessages: async () => {
// Return queued message after first tool executes
if (executed.length === 1 && !queuedDelivered) {
queuedDelivered = true;
return [queuedUserMessage];
}
return [];
},
};
const events: AgentEvent[] = [];
const stream = agentLoop(userPrompt, context, config, undefined, (_model, ctx, _options) => {
// Check if interrupt message is in context on second call
if (callIndex === 1) {
sawInterruptInContext = ctx.messages.some(
(m) => m.role === "user" && typeof m.content === "string" && m.content === "interrupt",
);
}
const mockStream = new MockAssistantStream();
queueMicrotask(() => {
if (callIndex === 0) {
// First call: return two tool calls
const message = createAssistantMessage(
[
{ type: "toolCall", id: "tool-1", name: "echo", arguments: { value: "first" } },
{ type: "toolCall", id: "tool-2", name: "echo", arguments: { value: "second" } },
],
"toolUse",
);
mockStream.push({ type: "done", reason: "toolUse", message });
} else {
// Second call: return final response
const message = createAssistantMessage([{ type: "text", text: "done" }]);
mockStream.push({ type: "done", reason: "stop", message });
}
callIndex++;
});
return mockStream;
});
for await (const event of stream) {
events.push(event);
}
// Only first tool should have executed
expect(executed).toEqual(["first"]);
// Second tool should be skipped
const toolEnds = events.filter(
(e): e is Extract<AgentEvent, { type: "tool_execution_end" }> => e.type === "tool_execution_end",
);
expect(toolEnds.length).toBe(2);
expect(toolEnds[0].isError).toBe(false);
expect(toolEnds[1].isError).toBe(true);
if (toolEnds[1].result.content[0]?.type === "text") {
expect(toolEnds[1].result.content[0].text).toContain("Skipped due to queued user message");
}
// Queued message should appear in events
const queuedMessageEvent = events.find(
(e) =>
e.type === "message_start" &&
e.message.role === "user" &&
typeof e.message.content === "string" &&
e.message.content === "interrupt",
);
expect(queuedMessageEvent).toBeDefined();
// Interrupt message should be in context when second LLM call is made
expect(sawInterruptInContext).toBe(true);
});
});
describe("agentLoopContinue with AgentMessage", () => {
it("should throw when context has no messages", () => {
const context: AgentContext = {
systemPrompt: "You are helpful.",
messages: [],
tools: [],
};
const config: AgentLoopConfig = {
model: createModel(),
convertToLlm: identityConverter,
};
expect(() => agentLoopContinue(context, config)).toThrow("Cannot continue: no messages in context");
});
it("should continue from existing context without emitting user message events", async () => {
const userMessage: AgentMessage = createUserMessage("Hello");
const context: AgentContext = {
systemPrompt: "You are helpful.",
messages: [userMessage],
tools: [],
};
const config: AgentLoopConfig = {
model: createModel(),
convertToLlm: identityConverter,
};
const streamFn = () => {
const stream = new MockAssistantStream();
queueMicrotask(() => {
const message = createAssistantMessage([{ type: "text", text: "Response" }]);
stream.push({ type: "done", reason: "stop", message });
});
return stream;
};
const events: AgentEvent[] = [];
const stream = agentLoopContinue(context, config, undefined, streamFn);
for await (const event of stream) {
events.push(event);
}
const messages = await stream.result();
// Should only return the new assistant message (not the existing user message)
expect(messages.length).toBe(1);
expect(messages[0].role).toBe("assistant");
// Should NOT have user message events (that's the key difference from agentLoop)
const messageEndEvents = events.filter((e) => e.type === "message_end");
expect(messageEndEvents.length).toBe(1);
expect((messageEndEvents[0] as any).message.role).toBe("assistant");
});
it("should allow custom message types as last message (caller responsibility)", async () => {
// Custom message that will be converted to user message by convertToLlm
interface HookMessage {
role: "hookMessage";
text: string;
timestamp: number;
}
const hookMessage: HookMessage = {
role: "hookMessage",
text: "Hook content",
timestamp: Date.now(),
};
const context: AgentContext = {
systemPrompt: "You are helpful.",
messages: [hookMessage as unknown as AgentMessage],
tools: [],
};
const config: AgentLoopConfig = {
model: createModel(),
convertToLlm: (messages) => {
// Convert hookMessage to user message
return messages
.map((m) => {
if ((m as any).role === "hookMessage") {
return {
role: "user" as const,
content: (m as any).text,
timestamp: m.timestamp,
};
}
return m;
})
.filter((m) => m.role === "user" || m.role === "assistant" || m.role === "toolResult") as Message[];
},
};
const streamFn = () => {
const stream = new MockAssistantStream();
queueMicrotask(() => {
const message = createAssistantMessage([{ type: "text", text: "Response to hook" }]);
stream.push({ type: "done", reason: "stop", message });
});
return stream;
};
// Should not throw - the hookMessage will be converted to user message
const stream = agentLoopContinue(context, config, undefined, streamFn);
const events: AgentEvent[] = [];
for await (const event of stream) {
events.push(event);
}
const messages = await stream.result();
expect(messages.length).toBe(1);
expect(messages[0].role).toBe("assistant");
});
});


@ -1,12 +1,10 @@
import { getModel } from "@mariozechner/pi-ai";
import { describe, expect, it } from "vitest";
import { Agent, ProviderTransport } from "../src/index.js";
import { Agent } from "../src/index.js";
describe("Agent", () => {
it("should create an agent instance with default state", () => {
const agent = new Agent({
transport: new ProviderTransport(),
});
const agent = new Agent();
expect(agent.state).toBeDefined();
expect(agent.state.systemPrompt).toBe("");
@ -23,7 +21,6 @@ describe("Agent", () => {
it("should create an agent instance with custom initial state", () => {
const customModel = getModel("openai", "gpt-4o-mini");
const agent = new Agent({
transport: new ProviderTransport(),
initialState: {
systemPrompt: "You are a helpful assistant.",
model: customModel,
@ -37,9 +34,7 @@ describe("Agent", () => {
});
it("should subscribe to events", () => {
const agent = new Agent({
transport: new ProviderTransport(),
});
const agent = new Agent();
let eventCount = 0;
const unsubscribe = agent.subscribe((_event) => {
@ -61,9 +56,7 @@ describe("Agent", () => {
});
it("should update state with mutators", () => {
const agent = new Agent({
transport: new ProviderTransport(),
});
const agent = new Agent();
// Test setSystemPrompt
agent.setSystemPrompt("Custom prompt");
@ -101,38 +94,19 @@ describe("Agent", () => {
});
it("should support message queueing", async () => {
const agent = new Agent({
transport: new ProviderTransport(),
});
const agent = new Agent();
const message = { role: "user" as const, content: "Queued message", timestamp: Date.now() };
await agent.queueMessage(message);
agent.queueMessage(message);
// The message is queued but not yet in state.messages
expect(agent.state.messages).not.toContainEqual(message);
});
it("should handle abort controller", () => {
const agent = new Agent({
transport: new ProviderTransport(),
});
const agent = new Agent();
// Should not throw even if nothing is running
expect(() => agent.abort()).not.toThrow();
});
});
describe("ProviderTransport", () => {
it("should create a provider transport instance", () => {
const transport = new ProviderTransport();
expect(transport).toBeDefined();
});
it("should create a provider transport with options", () => {
const transport = new ProviderTransport({
getApiKey: async (provider) => `test-key-${provider}`,
corsProxyUrl: "https://proxy.example.com",
});
expect(transport).toBeDefined();
});
});


@ -1,25 +1,8 @@
import type { AssistantMessage, Model, ToolResultMessage, UserMessage } from "@mariozechner/pi-ai";
import { calculateTool, getModel } from "@mariozechner/pi-ai";
import { getModel } from "@mariozechner/pi-ai";
import { describe, expect, it } from "vitest";
import { Agent, ProviderTransport } from "../src/index.js";
function createTransport() {
return new ProviderTransport({
getApiKey: async (provider) => {
const envVarMap: Record<string, string> = {
google: "GEMINI_API_KEY",
openai: "OPENAI_API_KEY",
anthropic: "ANTHROPIC_API_KEY",
xai: "XAI_API_KEY",
groq: "GROQ_API_KEY",
cerebras: "CEREBRAS_API_KEY",
zai: "ZAI_API_KEY",
};
const envVar = envVarMap[provider] || `${provider.toUpperCase()}_API_KEY`;
return process.env[envVar];
},
});
}
import { Agent } from "../src/index.js";
import { calculateTool } from "./utils/calculate.js";
async function basicPrompt(model: Model<any>) {
const agent = new Agent({
@ -29,7 +12,6 @@ async function basicPrompt(model: Model<any>) {
thinkingLevel: "off",
tools: [],
},
transport: createTransport(),
});
await agent.prompt("What is 2+2? Answer with just the number.");
@ -57,7 +39,6 @@ async function toolExecution(model: Model<any>) {
thinkingLevel: "off",
tools: [calculateTool],
},
transport: createTransport(),
});
await agent.prompt("Calculate 123 * 456 using the calculator tool.");
@ -99,7 +80,6 @@ async function abortExecution(model: Model<any>) {
thinkingLevel: "off",
tools: [calculateTool],
},
transport: createTransport(),
});
const promptPromise = agent.prompt("Calculate 100 * 200, then 300 * 400, then sum the results.");
@ -129,7 +109,6 @@ async function stateUpdates(model: Model<any>) {
thinkingLevel: "off",
tools: [],
},
transport: createTransport(),
});
const events: Array<string> = [];
@ -162,7 +141,6 @@ async function multiTurnConversation(model: Model<any>) {
thinkingLevel: "off",
tools: [],
},
transport: createTransport(),
});
await agent.prompt("My name is Alice.");
@ -356,7 +334,6 @@ describe("Agent.continue()", () => {
systemPrompt: "Test",
model: getModel("anthropic", "claude-haiku-4-5"),
},
transport: createTransport(),
});
await expect(agent.continue()).rejects.toThrow("No messages to continue from");
@ -368,7 +345,6 @@ describe("Agent.continue()", () => {
systemPrompt: "Test",
model: getModel("anthropic", "claude-haiku-4-5"),
},
transport: createTransport(),
});
const assistantMessage: AssistantMessage = {
@ -405,7 +381,6 @@ describe("Agent.continue()", () => {
thinkingLevel: "off",
tools: [],
},
transport: createTransport(),
});
// Manually add a user message without calling prompt()
@ -445,7 +420,6 @@ describe("Agent.continue()", () => {
thinkingLevel: "off",
tools: [calculateTool],
},
transport: createTransport(),
});
// Set up a conversation state as if tool was just executed


@ -1,5 +1,5 @@
import { type Static, Type } from "@sinclair/typebox";
import type { AgentTool, AgentToolResult } from "../../agent/types.js";
import type { AgentTool, AgentToolResult } from "../../src/types.js";
export interface CalculateResult extends AgentToolResult<undefined> {
content: Array<{ type: "text"; text: string }>;


@ -1,6 +1,5 @@
import { type Static, Type } from "@sinclair/typebox";
import type { AgentTool } from "../../agent/index.js";
import type { AgentToolResult } from "../types.js";
import type { AgentTool, AgentToolResult } from "../../src/types.js";
export interface GetCurrentTimeResult extends AgentToolResult<{ utcTimestamp: number }> {}


@ -1,11 +0,0 @@
export { agentLoop, agentLoopContinue } from "./agent-loop.js";
export * from "./tools/index.js";
export type {
AgentContext,
AgentEvent,
AgentLoopConfig,
AgentTool,
AgentToolResult,
AgentToolUpdateCallback,
QueuedMessage,
} from "./types.js";


@ -1,2 +0,0 @@
export { calculate, calculateTool } from "./calculate.js";
export { getCurrentTime, getCurrentTimeTool } from "./get-current-time.js";


@ -1,105 +0,0 @@
import type { Static, TSchema } from "@sinclair/typebox";
import type {
AssistantMessage,
AssistantMessageEvent,
ImageContent,
Message,
Model,
SimpleStreamOptions,
TextContent,
Tool,
ToolResultMessage,
} from "../types.js";
export interface AgentToolResult<T> {
// Content blocks supporting text and images
content: (TextContent | ImageContent)[];
// Details to be displayed in a UI or logged
details: T;
}
// Callback for streaming tool execution updates
export type AgentToolUpdateCallback<T = any> = (partialResult: AgentToolResult<T>) => void;
// AgentTool extends Tool but adds the execute function
export interface AgentTool<TParameters extends TSchema = TSchema, TDetails = any> extends Tool<TParameters> {
// A human-readable label for the tool to be displayed in UI
label: string;
execute: (
toolCallId: string,
params: Static<TParameters>,
signal?: AbortSignal,
onUpdate?: AgentToolUpdateCallback<TDetails>,
) => Promise<AgentToolResult<TDetails>>;
}
// AgentContext is like Context but uses AgentTool
export interface AgentContext {
systemPrompt: string;
messages: Message[];
tools?: AgentTool<any>[];
}
// Event types
export type AgentEvent =
// Emitted when the agent starts. An agent can emit multiple turns
| { type: "agent_start" }
// Emitted when a turn starts. A turn can emit an optional user message (initial prompt), an assistant message (response) and multiple tool result messages
| { type: "turn_start" }
// Emitted when a user, assistant or tool result message starts
| { type: "message_start"; message: Message }
// Emitted when an assistant message is updated due to streaming
| { type: "message_update"; assistantMessageEvent: AssistantMessageEvent; message: AssistantMessage }
// Emitted when a user, assistant or tool result message is complete
| { type: "message_end"; message: Message }
// Emitted when a tool execution starts
| { type: "tool_execution_start"; toolCallId: string; toolName: string; args: any }
// Emitted when a tool execution produces output (streaming)
| {
type: "tool_execution_update";
toolCallId: string;
toolName: string;
args: any;
partialResult: AgentToolResult<any>;
}
// Emitted when a tool execution completes
| {
type: "tool_execution_end";
toolCallId: string;
toolName: string;
result: AgentToolResult<any>;
isError: boolean;
}
// Emitted when a full turn completes
| { type: "turn_end"; message: AssistantMessage; toolResults: ToolResultMessage[] }
// Emitted when the agent has completed all its turns. All messages from every turn are
// contained in messages, which can be appended to the context
| { type: "agent_end"; messages: AgentContext["messages"] };
// Queued message with optional LLM representation
export interface QueuedMessage<TApp = Message> {
original: TApp; // Original message for UI events
llm?: Message; // Optional transformed message for loop context (undefined if filtered)
}
// Configuration for agent loop execution
export interface AgentLoopConfig extends SimpleStreamOptions {
model: Model<any>;
/**
* Optional hook to resolve an API key dynamically for each LLM call.
*
* This is useful for short-lived OAuth tokens (e.g. GitHub Copilot) that may
* expire during long-running tool execution phases.
*
* The agent loop will call this before each assistant response and pass the
* returned value as `apiKey` to `streamSimple()` (or a custom `streamFn`).
*
* If it returns `undefined`, the loop falls back to `config.apiKey`, and then
* to `streamSimple()`'s own provider key lookup (setApiKey/env vars).
*/
getApiKey?: (provider: string) => Promise<string | undefined> | string | undefined;
preprocessor?: (messages: AgentContext["messages"], abortSignal?: AbortSignal) => Promise<AgentContext["messages"]>;
getQueuedMessages?: <T>() => Promise<QueuedMessage<T>[]>;
}


@ -1,4 +1,3 @@
export * from "./agent/index.js";
export * from "./models.js";
export * from "./providers/anthropic.js";
export * from "./providers/google.js";
@ -7,6 +6,7 @@ export * from "./providers/openai-completions.js";
export * from "./providers/openai-responses.js";
export * from "./stream.js";
export * from "./types.js";
export * from "./utils/event-stream.js";
export * from "./utils/oauth/index.js";
export * from "./utils/overflow.js";
export * from "./utils/typebox-helpers.js";


@ -1,166 +0,0 @@
import { Type } from "@sinclair/typebox";
import { describe, expect, it } from "vitest";
import { agentLoop } from "../src/agent/agent-loop.js";
import type { AgentContext, AgentEvent, AgentLoopConfig, AgentTool, QueuedMessage } from "../src/agent/types.js";
import type { AssistantMessage, Message, Model, UserMessage } from "../src/types.js";
import { AssistantMessageEventStream } from "../src/utils/event-stream.js";
function createUsage() {
return {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 0,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
};
}
function createModel(): Model<"openai-responses"> {
return {
id: "mock",
name: "mock",
api: "openai-responses",
provider: "openai",
baseUrl: "https://example.invalid",
reasoning: false,
input: ["text"],
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
contextWindow: 8192,
maxTokens: 2048,
};
}
describe("agentLoop queued message interrupt", () => {
it("injects queued messages after a tool call and skips remaining tool calls", async () => {
const toolSchema = Type.Object({ value: Type.String() });
const executed: string[] = [];
const tool: AgentTool<typeof toolSchema, { value: string }> = {
name: "echo",
label: "Echo",
description: "Echo tool",
parameters: toolSchema,
async execute(_toolCallId, params) {
executed.push(params.value);
return {
content: [{ type: "text", text: `ok:${params.value}` }],
details: { value: params.value },
};
},
};
const context: AgentContext = {
systemPrompt: "",
messages: [],
tools: [tool],
};
const userPrompt: UserMessage = {
role: "user",
content: "start",
timestamp: Date.now(),
};
const queuedUserMessage: Message = {
role: "user",
content: "interrupt",
timestamp: Date.now(),
};
const queuedMessages: QueuedMessage<Message>[] = [{ original: queuedUserMessage, llm: queuedUserMessage }];
let queuedDelivered = false;
let sawInterruptInContext = false;
let callIndex = 0;
const streamFn = () => {
const stream = new AssistantMessageEventStream();
queueMicrotask(() => {
if (callIndex === 0) {
const message: AssistantMessage = {
role: "assistant",
content: [
{ type: "toolCall", id: "tool-1", name: "echo", arguments: { value: "first" } },
{ type: "toolCall", id: "tool-2", name: "echo", arguments: { value: "second" } },
],
api: "openai-responses",
provider: "openai",
model: "mock",
usage: createUsage(),
stopReason: "toolUse",
timestamp: Date.now(),
};
stream.push({ type: "done", reason: "toolUse", message });
} else {
const message: AssistantMessage = {
role: "assistant",
content: [{ type: "text", text: "done" }],
api: "openai-responses",
provider: "openai",
model: "mock",
usage: createUsage(),
stopReason: "stop",
timestamp: Date.now(),
};
stream.push({ type: "done", reason: "stop", message });
}
callIndex += 1;
});
return stream;
};
const getQueuedMessages: AgentLoopConfig["getQueuedMessages"] = async <T>() => {
if (executed.length === 1 && !queuedDelivered) {
queuedDelivered = true;
return queuedMessages as QueuedMessage<T>[];
}
return [];
};
const config: AgentLoopConfig = {
model: createModel(),
getQueuedMessages,
};
const events: AgentEvent[] = [];
const stream = agentLoop(userPrompt, context, config, undefined, (_model, ctx, _options) => {
if (callIndex === 1) {
sawInterruptInContext = ctx.messages.some(
(m) => m.role === "user" && typeof m.content === "string" && m.content === "interrupt",
);
}
return streamFn();
});
for await (const event of stream) {
events.push(event);
}
expect(executed).toEqual(["first"]);
const toolEnds = events.filter(
(event): event is Extract<AgentEvent, { type: "tool_execution_end" }> => event.type === "tool_execution_end",
);
expect(toolEnds.length).toBe(2);
expect(toolEnds[1].isError).toBe(true);
expect(toolEnds[1].result.content[0]?.type).toBe("text");
if (toolEnds[1].result.content[0]?.type === "text") {
expect(toolEnds[1].result.content[0].text).toContain("Skipped due to queued user message");
}
const firstTurnEndIndex = events.findIndex((event) => event.type === "turn_end");
const queuedMessageIndex = events.findIndex(
(event) =>
event.type === "message_start" &&
event.message.role === "user" &&
typeof event.message.content === "string" &&
event.message.content === "interrupt",
);
const nextAssistantIndex = events.findIndex(
(event, index) =>
index > queuedMessageIndex && event.type === "message_start" && event.message.role === "assistant",
);
expect(queuedMessageIndex).toBeGreaterThan(firstTurnEndIndex);
expect(queuedMessageIndex).toBeLessThan(nextAssistantIndex);
expect(sawInterruptInContext).toBe(true);
});
});


@ -1,701 +0,0 @@
import { describe, expect, it } from "vitest";
import { agentLoop, agentLoopContinue } from "../src/agent/agent-loop.js";
import { calculateTool } from "../src/agent/tools/calculate.js";
import type { AgentContext, AgentEvent, AgentLoopConfig } from "../src/agent/types.js";
import { getModel } from "../src/models.js";
import type {
Api,
AssistantMessage,
Message,
Model,
OptionsForApi,
ToolResultMessage,
UserMessage,
} from "../src/types.js";
import { resolveApiKey } from "./oauth.js";
// Resolve OAuth tokens at module level (async, runs before tests)
const oauthTokens = await Promise.all([
resolveApiKey("anthropic"),
resolveApiKey("github-copilot"),
resolveApiKey("google-gemini-cli"),
resolveApiKey("google-antigravity"),
]);
const [anthropicOAuthToken, githubCopilotToken, geminiCliToken, antigravityToken] = oauthTokens;
async function calculateTest<TApi extends Api>(model: Model<TApi>, options: OptionsForApi<TApi> = {}) {
// Create the agent context with the calculator tool
const context: AgentContext = {
systemPrompt:
"You are a helpful assistant that performs mathematical calculations. When asked to calculate multiple expressions, you can use parallel tool calls if the model supports it. In your final answer, output ONLY the final sum as a single integer number, nothing else.",
messages: [],
tools: [calculateTool],
};
// Create the prompt config
const config: AgentLoopConfig = {
model,
...options,
};
// Create the user prompt asking for multiple calculations
const userPrompt: UserMessage = {
role: "user",
content: `Use the calculator tool to complete the following multi-step task.
1. Calculate 3485 * 4234 and 88823 * 3482 in parallel
2. Calculate the sum of the two results using the calculator tool
3. Output ONLY the final sum as a single integer number, nothing else.`,
timestamp: Date.now(),
};
// Calculate expected results (using integers)
const expectedFirst = 3485 * 4234; // = 14755490
const expectedSecond = 88823 * 3482; // = 309281786
const expectedSum = expectedFirst + expectedSecond; // = 324037276
// Track events for verification
const events: AgentEvent[] = [];
let turns = 0;
let toolCallCount = 0;
const toolResults: number[] = [];
let finalAnswer: number | undefined;
// Execute the prompt
const stream = agentLoop(userPrompt, context, config);
for await (const event of stream) {
events.push(event);
switch (event.type) {
case "turn_start":
turns++;
console.log(`\n=== Turn ${turns} started ===`);
break;
case "turn_end":
console.log(`=== Turn ${turns} ended with ${event.toolResults.length} tool results ===`);
console.log(event.message);
break;
case "tool_execution_end":
if (!event.isError && typeof event.result === "object" && event.result.content) {
const textOutput = event.result.content
.filter((c: any) => c.type === "text")
.map((c: any) => c.text)
.join("\n");
toolCallCount++;
// Extract number from output like "expression = result"
const match = textOutput.match(/=\s*([\d.]+)/);
if (match) {
const value = parseFloat(match[1]);
toolResults.push(value);
console.log(`Tool ${toolCallCount}: ${textOutput}`);
}
}
break;
case "message_end":
// Just track the message end event, don't extract answer here
break;
}
}
// Get the final messages
const finalMessages = await stream.result();
// Verify the results
expect(finalMessages).toBeDefined();
expect(finalMessages.length).toBeGreaterThan(0);
const finalMessage = finalMessages[finalMessages.length - 1];
expect(finalMessage).toBeDefined();
expect(finalMessage.role).toBe("assistant");
if (finalMessage.role !== "assistant") throw new Error("Final message is not from assistant");
// Extract the final answer from the last assistant message
const content = finalMessage.content
.filter((c) => c.type === "text")
.map((c) => (c.type === "text" ? c.text : ""))
.join(" ");
// Look for integers in the response that might be the final answer
const numbers = content.match(/\b\d+\b/g);
if (numbers) {
// Check if any of the numbers matches our expected sum
for (const num of numbers) {
const value = parseInt(num, 10);
if (Math.abs(value - expectedSum) < 10) {
finalAnswer = value;
break;
}
}
// If no exact match, take the last large number as likely the answer
if (finalAnswer === undefined) {
const largeNumbers = numbers.map((n) => parseInt(n, 10)).filter((n) => n > 1000000);
if (largeNumbers.length > 0) {
finalAnswer = largeNumbers[largeNumbers.length - 1];
}
}
}
// Should have executed at least 3 tool calls: 2 for the initial calculations, 1 for the sum
// (or possibly 2 if the model calculates the sum itself without a tool)
expect(toolCallCount).toBeGreaterThanOrEqual(2);
// Must be at least 3 turns: first to calculate the expressions, then to sum them, then give the answer
// Could be 3 turns if model does parallel calls, or 4 turns if sequential calculation of expressions
expect(turns).toBeGreaterThanOrEqual(3);
expect(turns).toBeLessThanOrEqual(4);
// Verify the individual calculations are in the results
const hasFirstCalc = toolResults.some((r) => r === expectedFirst);
const hasSecondCalc = toolResults.some((r) => r === expectedSecond);
expect(hasFirstCalc).toBe(true);
expect(hasSecondCalc).toBe(true);
// Verify the final sum
if (finalAnswer !== undefined) {
expect(finalAnswer).toBe(expectedSum);
console.log(`Final answer: ${finalAnswer} (expected: ${expectedSum})`);
} else {
// If we couldn't extract the final answer from text, check if it's in the tool results
const hasSum = toolResults.some((r) => r === expectedSum);
expect(hasSum).toBe(true);
}
// Log summary
console.log(`\nTest completed with ${turns} turns and ${toolCallCount} tool calls`);
if (turns === 3) {
console.log("Model used parallel tool calls for initial calculations");
} else {
console.log("Model used sequential tool calls");
}
return {
turns,
toolCallCount,
toolResults,
finalAnswer,
events,
};
}
async function abortTest<TApi extends Api>(model: Model<TApi>, options: OptionsForApi<TApi> = {}) {
// Create the agent context with the calculator tool
const context: AgentContext = {
systemPrompt:
"You are a helpful assistant that performs mathematical calculations. Always use the calculator tool for each calculation.",
messages: [],
tools: [calculateTool],
};
// Create the prompt config
const config: AgentLoopConfig = {
model,
...options,
};
// Create a prompt that will require multiple calculations
const userPrompt: UserMessage = {
role: "user",
content: "Calculate 100 * 200, then 300 * 400, then 500 * 600, then sum all three results.",
timestamp: Date.now(),
};
// Create abort controller
const abortController = new AbortController();
// Track events for verification
const events: AgentEvent[] = [];
let toolCallCount = 0;
const errorReceived = false;
let finalMessages: Message[] | undefined;
// Execute the prompt
const stream = agentLoop(userPrompt, context, config, abortController.signal);
// Abort after first tool execution
(async () => {
for await (const event of stream) {
events.push(event);
if (event.type === "tool_execution_end" && !event.isError) {
toolCallCount++;
// Abort after first successful tool execution
if (toolCallCount === 1) {
console.log("Aborting after first tool execution");
abortController.abort();
}
}
if (event.type === "agent_end") {
finalMessages = event.messages;
}
}
})();
finalMessages = await stream.result();
// Verify abort behavior
console.log(`\nAbort test completed with ${toolCallCount} tool calls`);
const assistantMessage = finalMessages[finalMessages.length - 1];
if (!assistantMessage) throw new Error("No final message received");
expect(assistantMessage).toBeDefined();
expect(assistantMessage.role).toBe("assistant");
if (assistantMessage.role !== "assistant") throw new Error("Final message is not from assistant");
// Should have executed 1 tool call before abort
expect(toolCallCount).toBeGreaterThanOrEqual(1);
expect(assistantMessage.stopReason).toBe("aborted");
return {
toolCallCount,
events,
errorReceived,
finalMessages,
};
}
describe("Agent Calculator Tests", () => {
describe.skipIf(!process.env.GEMINI_API_KEY)("Google Provider Agent", () => {
const model = getModel("google", "gemini-2.5-flash");
it("should calculate multiple expressions and sum the results", { retry: 3 }, async () => {
const result = await calculateTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
});
it("should handle abort during tool execution", { retry: 3 }, async () => {
const result = await abortTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
});
});
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Completions Provider Agent", () => {
const model = getModel("openai", "gpt-4o-mini");
it("should calculate multiple expressions and sum the results", { retry: 3 }, async () => {
const result = await calculateTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
});
it("should handle abort during tool execution", { retry: 3 }, async () => {
const result = await abortTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
});
});
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Responses Provider Agent", () => {
const model = getModel("openai", "gpt-5-mini");
it("should calculate multiple expressions and sum the results", { retry: 3 }, async () => {
const result = await calculateTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
});
it("should handle abort during tool execution", { retry: 3 }, async () => {
const result = await abortTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
});
});
describe.skipIf(!process.env.ANTHROPIC_API_KEY)("Anthropic Provider Agent", () => {
const model = getModel("anthropic", "claude-haiku-4-5");
it("should calculate multiple expressions and sum the results", { retry: 3 }, async () => {
const result = await calculateTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
});
it("should handle abort during tool execution", { retry: 3 }, async () => {
const result = await abortTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
});
});
describe.skipIf(!process.env.XAI_API_KEY)("xAI Provider Agent", () => {
const model = getModel("xai", "grok-3");
it("should calculate multiple expressions and sum the results", { retry: 3 }, async () => {
const result = await calculateTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
});
it("should handle abort during tool execution", { retry: 3 }, async () => {
const result = await abortTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
});
});
describe.skipIf(!process.env.GROQ_API_KEY)("Groq Provider Agent", () => {
const model = getModel("groq", "openai/gpt-oss-20b");
it("should calculate multiple expressions and sum the results", { retry: 3 }, async () => {
const result = await calculateTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
});
it("should handle abort during tool execution", { retry: 3 }, async () => {
const result = await abortTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
});
});
describe.skipIf(!process.env.CEREBRAS_API_KEY)("Cerebras Provider Agent", () => {
const model = getModel("cerebras", "gpt-oss-120b");
it("should calculate multiple expressions and sum the results", { retry: 3 }, async () => {
const result = await calculateTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
});
it("should handle abort during tool execution", { retry: 3 }, async () => {
const result = await abortTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
});
});
describe.skipIf(!process.env.ZAI_API_KEY)("zAI Provider Agent", () => {
const model = getModel("zai", "glm-4.5-air");
it("should calculate multiple expressions and sum the results", { retry: 3 }, async () => {
const result = await calculateTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
});
it("should handle abort during tool execution", { retry: 3 }, async () => {
const result = await abortTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
});
});
describe.skipIf(!process.env.MISTRAL_API_KEY)("Mistral Provider Agent", () => {
const model = getModel("mistral", "devstral-medium-latest");
it("should calculate multiple expressions and sum the results", { retry: 3 }, async () => {
const result = await calculateTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
});
it("should handle abort during tool execution", { retry: 3 }, async () => {
const result = await abortTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
});
});
// =========================================================================
// OAuth-based providers (credentials from ~/.pi/agent/oauth.json)
// =========================================================================
describe("Anthropic OAuth Provider Agent", () => {
const model = getModel("anthropic", "claude-haiku-4-5");
it.skipIf(!anthropicOAuthToken)(
"should calculate multiple expressions and sum the results",
{ retry: 3 },
async () => {
const result = await calculateTest(model, { apiKey: anthropicOAuthToken });
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
},
);
it.skipIf(!anthropicOAuthToken)("should handle abort during tool execution", { retry: 3 }, async () => {
const result = await abortTest(model, { apiKey: anthropicOAuthToken });
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
});
});
describe("GitHub Copilot Provider Agent", () => {
it.skipIf(!githubCopilotToken)(
"gpt-4o - should calculate multiple expressions and sum the results",
{ retry: 3 },
async () => {
const model = getModel("github-copilot", "gpt-4o");
const result = await calculateTest(model, { apiKey: githubCopilotToken });
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
},
);
it.skipIf(!githubCopilotToken)("gpt-4o - should handle abort during tool execution", { retry: 3 }, async () => {
const model = getModel("github-copilot", "gpt-4o");
const result = await abortTest(model, { apiKey: githubCopilotToken });
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
});
it.skipIf(!githubCopilotToken)(
"claude-sonnet-4 - should calculate multiple expressions and sum the results",
{ retry: 3 },
async () => {
const model = getModel("github-copilot", "claude-sonnet-4");
const result = await calculateTest(model, { apiKey: githubCopilotToken });
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
},
);
it.skipIf(!githubCopilotToken)(
"claude-sonnet-4 - should handle abort during tool execution",
{ retry: 3 },
async () => {
const model = getModel("github-copilot", "claude-sonnet-4");
const result = await abortTest(model, { apiKey: githubCopilotToken });
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
},
);
});
describe("Google Gemini CLI Provider Agent", () => {
it.skipIf(!geminiCliToken)(
"gemini-2.5-flash - should calculate multiple expressions and sum the results",
{ retry: 3 },
async () => {
const model = getModel("google-gemini-cli", "gemini-2.5-flash");
const result = await calculateTest(model, { apiKey: geminiCliToken });
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
},
);
it.skipIf(!geminiCliToken)(
"gemini-2.5-flash - should handle abort during tool execution",
{ retry: 3 },
async () => {
const model = getModel("google-gemini-cli", "gemini-2.5-flash");
const result = await abortTest(model, { apiKey: geminiCliToken });
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
},
);
});
describe("Google Antigravity Provider Agent", () => {
it.skipIf(!antigravityToken)(
"gemini-3-flash - should calculate multiple expressions and sum the results",
{ retry: 3 },
async () => {
const model = getModel("google-antigravity", "gemini-3-flash");
const result = await calculateTest(model, { apiKey: antigravityToken });
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
},
);
it.skipIf(!antigravityToken)(
"gemini-3-flash - should handle abort during tool execution",
{ retry: 3 },
async () => {
const model = getModel("google-antigravity", "gemini-3-flash");
const result = await abortTest(model, { apiKey: antigravityToken });
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
},
);
it.skipIf(!antigravityToken)(
"claude-sonnet-4-5 - should calculate multiple expressions and sum the results",
{ retry: 3 },
async () => {
const model = getModel("google-antigravity", "claude-sonnet-4-5");
const result = await calculateTest(model, { apiKey: antigravityToken });
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
},
);
it.skipIf(!antigravityToken)(
"claude-sonnet-4-5 - should handle abort during tool execution",
{ retry: 3 },
async () => {
const model = getModel("google-antigravity", "claude-sonnet-4-5");
const result = await abortTest(model, { apiKey: antigravityToken });
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
},
);
it.skipIf(!antigravityToken)(
"gpt-oss-120b-medium - should calculate multiple expressions and sum the results",
{ retry: 3 },
async () => {
const model = getModel("google-antigravity", "gpt-oss-120b-medium");
const result = await calculateTest(model, { apiKey: antigravityToken });
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
},
);
it.skipIf(!antigravityToken)(
"gpt-oss-120b-medium - should handle abort during tool execution",
{ retry: 3 },
async () => {
const model = getModel("google-antigravity", "gpt-oss-120b-medium");
const result = await abortTest(model, { apiKey: antigravityToken });
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
},
);
});
});
describe("agentLoopContinue", () => {
describe("validation", () => {
const model = getModel("anthropic", "claude-haiku-4-5");
const baseContext: AgentContext = {
systemPrompt: "You are a helpful assistant.",
messages: [],
tools: [],
};
const config: AgentLoopConfig = { model };
it("should throw when context has no messages", () => {
expect(() => agentLoopContinue(baseContext, config)).toThrow("Cannot continue: no messages in context");
});
it("should throw when last message is an assistant message", () => {
const assistantMessage: AssistantMessage = {
role: "assistant",
content: [{ type: "text", text: "Hello" }],
api: "anthropic-messages",
provider: "anthropic",
model: "claude-haiku-4-5",
usage: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 0,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
},
stopReason: "stop",
timestamp: Date.now(),
};
const context: AgentContext = {
...baseContext,
messages: [assistantMessage],
};
expect(() => agentLoopContinue(context, config)).toThrow(
"Cannot continue from message role: assistant. Expected 'user' or 'toolResult'.",
);
});
// Note: "should not throw" tests for valid inputs are covered by the E2E tests below
// which actually consume the stream and verify the output
});
describe.skipIf(!process.env.ANTHROPIC_API_KEY)("continue from user message", () => {
const model = getModel("anthropic", "claude-haiku-4-5");
it("should continue and get assistant response when last message is user", { retry: 3 }, async () => {
const userMessage: UserMessage = {
role: "user",
content: [{ type: "text", text: "Say exactly: HELLO WORLD" }],
timestamp: Date.now(),
};
const context: AgentContext = {
systemPrompt: "You are a helpful assistant. Follow instructions exactly.",
messages: [userMessage],
tools: [],
};
const config: AgentLoopConfig = { model };
const events: AgentEvent[] = [];
const stream = agentLoopContinue(context, config);
for await (const event of stream) {
events.push(event);
}
const messages = await stream.result();
// Should have gotten an assistant response
expect(messages.length).toBe(1);
expect(messages[0].role).toBe("assistant");
// Verify event sequence - no user message events since we're continuing
const eventTypes = events.map((e) => e.type);
expect(eventTypes).toContain("agent_start");
expect(eventTypes).toContain("turn_start");
expect(eventTypes).toContain("message_start");
expect(eventTypes).toContain("message_end");
expect(eventTypes).toContain("turn_end");
expect(eventTypes).toContain("agent_end");
// Should NOT have user message events (that's the difference from agentLoop)
const messageEndEvents = events.filter((e) => e.type === "message_end");
expect(messageEndEvents.length).toBe(1); // Only assistant message
expect((messageEndEvents[0] as any).message.role).toBe("assistant");
});
});
describe.skipIf(!process.env.ANTHROPIC_API_KEY)("continue from tool result", () => {
const model = getModel("anthropic", "claude-haiku-4-5");
it("should continue processing after tool results", { retry: 3 }, async () => {
// Simulate a conversation where:
// 1. User asked to calculate something
// 2. Assistant made a tool call
// 3. Tool result is ready
// 4. We continue from here
const userMessage: UserMessage = {
role: "user",
content: [{ type: "text", text: "What is 5 + 3? Use the calculator." }],
timestamp: Date.now(),
};
const assistantMessage: AssistantMessage = {
role: "assistant",
content: [
{ type: "text", text: "Let me calculate that for you." },
{ type: "toolCall", id: "calc-1", name: "calculate", arguments: { expression: "5 + 3" } },
],
api: "anthropic-messages",
provider: "anthropic",
model: "claude-haiku-4-5",
usage: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 0,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
},
stopReason: "toolUse",
timestamp: Date.now(),
};
const toolResult: ToolResultMessage = {
role: "toolResult",
toolCallId: "calc-1",
toolName: "calculate",
content: [{ type: "text", text: "5 + 3 = 8" }],
isError: false,
timestamp: Date.now(),
};
const context: AgentContext = {
systemPrompt: "You are a helpful assistant. After getting a calculation result, state the answer clearly.",
messages: [userMessage, assistantMessage, toolResult],
tools: [calculateTool],
};
const config: AgentLoopConfig = { model };
const events: AgentEvent[] = [];
const stream = agentLoopContinue(context, config);
for await (const event of stream) {
events.push(event);
}
const messages = await stream.result();
// Should have gotten an assistant response
expect(messages.length).toBeGreaterThanOrEqual(1);
const lastMessage = messages[messages.length - 1];
expect(lastMessage.role).toBe("assistant");
// The assistant should mention the result (8)
if (lastMessage.role === "assistant") {
const textContent = lastMessage.content
.filter((c) => c.type === "text")
.map((c) => (c as any).text)
.join(" ");
expect(textContent).toMatch(/8/);
}
});
});
});


@ -1,4 +1,4 @@
import { type Static, Type } from "@sinclair/typebox";
import { Type } from "@sinclair/typebox";
import AjvModule from "ajv";
import addFormatsModule from "ajv-formats";
@ -7,7 +7,7 @@ const Ajv = (AjvModule as any).default || AjvModule;
const addFormats = (addFormatsModule as any).default || addFormatsModule;
import { describe, expect, it } from "vitest";
import type { AgentTool } from "../src/agent/types.js";
import type { Tool } from "../src/types.js";
describe("Tool Validation with TypeBox and AJV", () => {
// Define a test tool with TypeBox schema
@ -18,20 +18,11 @@ describe("Tool Validation with TypeBox and AJV", () => {
tags: Type.Optional(Type.Array(Type.String())),
});
type TestParams = Static<typeof testSchema>;
const testTool: AgentTool<typeof testSchema, void> = {
label: "Test Tool",
const testTool = {
name: "test_tool",
description: "A test tool for validation",
parameters: testSchema,
execute: async (_toolCallId, args) => {
return {
content: [{ type: "text", text: `Processed: ${args.name}, ${args.age}, ${args.email}` }],
details: undefined,
};
},
};
} satisfies Tool<typeof testSchema>;
// Create AJV instance for validation
const ajv = new Ajv({ allErrors: true });
@ -115,26 +106,4 @@ describe("Tool Validation with TypeBox and AJV", () => {
expect(errors).toContain('email: must match format "email"');
}
});
it("should have type-safe execute function", async () => {
const validInput = {
name: "John Doe",
age: 30,
email: "john@example.com",
};
// Validate and execute
const validate = ajv.compile(testTool.parameters);
const isValid = validate(validInput);
expect(isValid).toBe(true);
const result = await testTool.execute("test-id", validInput as TestParams);
const textOutput = result.content
.filter((c: any) => c.type === "text")
.map((c: any) => c.text)
.join("\n");
expect(textOutput).toBe("Processed: John Doe, 30, john@example.com");
expect(result.details).toBeUndefined();
});
});


@ -13,7 +13,14 @@
* Modes use this class and add their own I/O layer on top.
*/
import type { Agent, AgentEvent, AgentState, AppMessage, Attachment, ThinkingLevel } from "@mariozechner/pi-agent-core";
import type {
Agent,
AgentEvent,
AgentMessage,
AgentState,
Attachment,
ThinkingLevel,
} from "@mariozechner/pi-agent-core";
import type { AssistantMessage, Message, Model, TextContent } from "@mariozechner/pi-ai";
import { isContextOverflow, modelsAreEqual, supportsXhigh } from "@mariozechner/pi-ai";
import { getAuthPath } from "../config.js";
@ -403,7 +410,7 @@ export class AgentSession {
}
/** All messages including custom types like BashExecutionMessage */
get messages(): AppMessage[] {
get messages(): AgentMessage[] {
return this.agent.state.messages;
}


@ -5,17 +5,17 @@
* and after compaction the session is reloaded.
*/
import type { AppMessage } from "@mariozechner/pi-agent-core";
import type { AgentMessage } from "@mariozechner/pi-agent-core";
import type { AssistantMessage, Model, Usage } from "@mariozechner/pi-ai";
import { complete } from "@mariozechner/pi-ai";
import { messageTransformer } from "./messages.js";
import { type CompactionEntry, createSummaryMessage, type SessionEntry } from "./session-manager.js";
/**
* Extract AppMessage from an entry if it produces one.
* Extract AgentMessage from an entry if it produces one.
* Returns null for entries that don't contribute to LLM context.
*/
function getMessageFromEntry(entry: SessionEntry): AppMessage | null {
function getMessageFromEntry(entry: SessionEntry): AgentMessage | null {
if (entry.type === "message") {
return entry.message;
}
@@ -73,7 +73,7 @@ export function calculateContextTokens(usage: Usage): number {
* Get usage from an assistant message if available.
* Skips aborted and error messages as they don't have valid usage data.
*/
function getAssistantUsage(msg: AppMessage): Usage | null {
function getAssistantUsage(msg: AgentMessage): Usage | null {
if (msg.role === "assistant" && "usage" in msg) {
const assistantMsg = msg as AssistantMessage;
if (assistantMsg.stopReason !== "aborted" && assistantMsg.stopReason !== "error" && assistantMsg.usage) {
@@ -113,7 +113,7 @@ export function shouldCompact(contextTokens: number, contextWindow: number, sett
* Estimate token count for a message using chars/4 heuristic.
* This is conservative (overestimates tokens).
*/
export function estimateTokens(message: AppMessage): number {
export function estimateTokens(message: AgentMessage): number {
let chars = 0;
// Handle bashExecution messages
@@ -323,7 +323,7 @@ Be concise, structured, and focused on helping the next LLM seamlessly continue
* Generate a summary of the conversation using the LLM.
*/
export async function generateSummary(
currentMessages: AppMessage[],
currentMessages: AgentMessage[],
model: Model<any>,
reserveTokens: number,
apiKey: string,
@@ -371,9 +371,9 @@ export interface CompactionPreparation {
/** UUID of first entry to keep */
firstKeptEntryId: string;
/** Messages that will be summarized and discarded */
messagesToSummarize: AppMessage[];
messagesToSummarize: AgentMessage[];
/** Messages that will be kept after the summary (recent turns) */
messagesToKeep: AppMessage[];
messagesToKeep: AgentMessage[];
tokensBefore: number;
boundaryStart: number;
}
@@ -408,14 +408,14 @@ export function prepareCompaction(entries: SessionEntry[], settings: CompactionS
const historyEnd = cutPoint.isSplitTurn ? cutPoint.turnStartIndex : cutPoint.firstKeptEntryIndex;
// Messages to summarize (will be discarded after summary)
const messagesToSummarize: AppMessage[] = [];
const messagesToSummarize: AgentMessage[] = [];
for (let i = boundaryStart; i < historyEnd; i++) {
const msg = getMessageFromEntry(entries[i]);
if (msg) messagesToSummarize.push(msg);
}
// Messages to keep (recent turns, kept after summary)
const messagesToKeep: AppMessage[] = [];
const messagesToKeep: AgentMessage[] = [];
for (let i = cutPoint.firstKeptEntryIndex; i < boundaryEnd; i++) {
const msg = getMessageFromEntry(entries[i]);
if (msg) messagesToKeep.push(msg);
@@ -482,7 +482,7 @@ export async function compact(
// Extract messages for history summary (before the turn that contains the cut point)
const historyEnd = cutResult.isSplitTurn ? cutResult.turnStartIndex : cutResult.firstKeptEntryIndex;
const historyMessages: AppMessage[] = [];
const historyMessages: AgentMessage[] = [];
for (let i = boundaryStart; i < historyEnd; i++) {
const msg = getMessageFromEntry(entries[i]);
if (msg) historyMessages.push(msg);
@@ -499,7 +499,7 @@
}
// Extract messages for turn prefix summary (if splitting a turn)
const turnPrefixMessages: AppMessage[] = [];
const turnPrefixMessages: AgentMessage[] = [];
if (cutResult.isSplitTurn) {
for (let i = cutResult.turnStartIndex; i < cutResult.firstKeptEntryIndex; i++) {
const msg = getMessageFromEntry(entries[i]);
@@ -550,7 +550,7 @@
* Generate a summary for a turn prefix (when splitting a turn).
*/
async function generateTurnPrefixSummary(
messages: AppMessage[],
messages: AgentMessage[],
model: Model<any>,
reserveTokens: number,
apiKey: string,
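The chars/4 rule documented above for estimateTokens is simple enough to verify in isolation. A minimal sketch, under the assumption that non-string content is serialized before counting and that the estimate rounds up (the full function body is elided in this diff):

// Conservative estimate: ~4 characters per token, rounded up so the
// heuristic overestimates rather than underestimates.
function roughTokenCount(content: unknown): number {
  const chars =
    typeof content === "string" ? content.length : JSON.stringify(content ?? "").length;
  return Math.ceil(chars / 4);
}

roughTokenCount("Hello, world!"); // 13 chars -> 4 tokens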

View file

@@ -5,7 +5,7 @@
* and interact with the user via UI primitives.
*/
import type { AppMessage } from "@mariozechner/pi-agent-core";
import type { AgentMessage } from "@mariozechner/pi-agent-core";
import type { ImageContent, Message, Model, TextContent, ToolResultMessage } from "@mariozechner/pi-ai";
import type { Component } from "@mariozechner/pi-tui";
import type { Theme } from "../../modes/interactive/theme/theme.js";
@@ -151,7 +151,7 @@ export type SessionEvent =
* Event data for context event.
* Fired before each LLM call, allowing hooks to modify context non-destructively.
* Original session messages are NOT modified - only the messages sent to the LLM are affected.
* Messages are already in LLM format (Message[], not AppMessage[]).
* Messages are already in LLM format (Message[], not AgentMessage[]).
*/
export interface ContextEvent {
type: "context";
@@ -172,7 +172,7 @@ export interface AgentStartEvent {
*/
export interface AgentEndEvent {
type: "agent_end";
messages: AppMessage[];
messages: AgentMessage[];
}
/**
@@ -190,7 +190,7 @@ export interface TurnStartEvent {
export interface TurnEndEvent {
type: "turn_end";
turnIndex: number;
message: AppMessage;
message: AgentMessage;
toolResults: ToolResultMessage[];
}
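Consumers narrow these events on the `type` discriminant. A minimal sketch, assuming AgentEndEvent and TurnEndEvent are members of the SessionEvent union declared above and that this sits in the same module:

// Narrowed handling of two of the session events above.
function logSessionEvent(event: SessionEvent): void {
  switch (event.type) {
    case "agent_end":
      // event.messages is AgentMessage[] per AgentEndEvent.
      console.log(`agent finished with ${event.messages.length} messages`);
      break;
    case "turn_end":
      console.log(`turn ${event.turnIndex}: ${event.toolResults.length} tool results`);
      break;
  }
}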

View file

@@ -1,11 +1,11 @@
/**
* Custom message types and transformers for the coding agent.
*
* Extends the base AppMessage type with coding-agent specific message types,
* Extends the base AgentMessage type with coding-agent specific message types,
* and provides a transformer to convert them to LLM-compatible messages.
*/
import type { AppMessage } from "@mariozechner/pi-agent-core";
import type { AgentMessage } from "@mariozechner/pi-agent-core";
import type { Message } from "@mariozechner/pi-ai";
// ============================================================================
@@ -56,14 +56,14 @@ declare module "@mariozechner/pi-agent-core" {
/**
* Type guard for BashExecutionMessage.
*/
export function isBashExecutionMessage(msg: AppMessage | Message): msg is BashExecutionMessage {
export function isBashExecutionMessage(msg: AgentMessage | Message): msg is BashExecutionMessage {
return (msg as BashExecutionMessage).role === "bashExecution";
}
/**
* Type guard for HookAppMessage.
* Type guard for HookMessage.
*/
export function isHookMessage(msg: AppMessage | Message): msg is HookMessage {
export function isHookMessage(msg: AgentMessage | Message): msg is HookMessage {
return (msg as HookMessage).role === "hookMessage";
}
@@ -97,13 +97,13 @@ export function bashExecutionToText(msg: BashExecutionMessage): string {
// ============================================================================
/**
* Transform AppMessages (including custom types) to LLM-compatible Messages.
* Transform AgentMessages (including custom types) to LLM-compatible Messages.
*
* This is used by:
* - Agent's messageTransformer option (for prompt calls)
* - Compaction's generateSummary (for summarization)
*/
export function messageTransformer(messages: AppMessage[]): Message[] {
export function messageTransformer(messages: AgentMessage[]): Message[] {
return messages
.map((m): Message | null => {
if (isBashExecutionMessage(m)) {
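The transformer's shape follows from the guards above: each custom role either renders to an LLM-compatible message or is dropped. One arm sketched below; the exact rendering and the hook-message handling are assumptions, since the body is elided in this diff:

// Assumed arm: bash executions render to a plain user message via
// bashExecutionToText; hook messages are treated as UI-only and dropped
// (null entries are filtered out afterwards).
function toLlm(m: AgentMessage): Message | null {
  if (isBashExecutionMessage(m)) {
    return { role: "user", content: bashExecutionToText(m) } as Message;
  }
  if (isHookMessage(m)) return null;
  return m as Message;
}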

View file

@@ -1,4 +1,4 @@
import type { AppMessage } from "@mariozechner/pi-agent-core";
import type { AgentMessage } from "@mariozechner/pi-agent-core";
import type { ImageContent, TextContent } from "@mariozechner/pi-ai";
import { randomUUID } from "crypto";
import {
@@ -36,7 +36,7 @@ export interface SessionEntryBase {
export interface SessionMessageEntry extends SessionEntryBase {
type: "message";
message: AppMessage;
message: AgentMessage;
}
export interface ThinkingLevelChangeEntry extends SessionEntryBase {
@@ -130,7 +130,7 @@ export interface SessionTreeNode {
}
export interface SessionContext {
messages: AppMessage[];
messages: AgentMessage[];
thinkingLevel: string;
model: { provider: string; modelId: string } | null;
}
@@ -154,7 +154,7 @@ export const SUMMARY_SUFFIX = `
</summary>`;
/** Exported for compaction.test.ts */
export function createSummaryMessage(summary: string, timestamp: string): AppMessage {
export function createSummaryMessage(summary: string, timestamp: string): AgentMessage {
return {
role: "user",
content: SUMMARY_PREFIX + summary + SUMMARY_SUFFIX,
@@ -162,8 +162,8 @@ export function createSummaryMessage(summary: string, timestamp: string): AppMes
};
}
/** Convert CustomMessageEntry to AppMessage format */
function createCustomMessage(entry: CustomMessageEntry): AppMessage {
/** Convert CustomMessageEntry to AgentMessage format */
function createCustomMessage(entry: CustomMessageEntry): AgentMessage {
return {
role: "user",
content: entry.content,
@@ -323,7 +323,7 @@ export function buildSessionContext(
// 1. Emit summary first (entry = compaction)
// 2. Emit kept messages (from firstKeptEntryId up to compaction)
// 3. Emit messages after compaction
const messages: AppMessage[] = [];
const messages: AgentMessage[] = [];
if (compaction) {
// Emit summary first
@@ -595,7 +595,7 @@ export class SessionManager {
}
/** Append a message as child of current leaf, then advance leaf. Returns entry id. */
appendMessage(message: AppMessage): string {
appendMessage(message: AgentMessage): string {
const entry: SessionMessageEntry = {
type: "message",
id: generateId(this.byId),
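Given the ordering documented in buildSessionContext above (summary first, then kept messages, then everything after the compaction entry), replaying a compacted session reduces to concatenation. A sketch; the parameter grouping is an assumption:

import type { AgentMessage } from "@mariozechner/pi-agent-core";

// Rebuild a compacted context: summary message, kept recent turns, then
// messages recorded after the compaction entry.
function replayCompacted(
  summary: AgentMessage,
  kept: AgentMessage[],
  after: AgentMessage[],
): AgentMessage[] {
  return [summary, ...kept, ...after];
}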

View file

@@ -6,7 +6,7 @@
import * as fs from "node:fs";
import * as os from "node:os";
import * as path from "node:path";
import type { AgentState, AppMessage } from "@mariozechner/pi-agent-core";
import type { AgentMessage, AgentState } from "@mariozechner/pi-agent-core";
import type { AssistantMessage, Message, OAuthProvider } from "@mariozechner/pi-ai";
import type { SlashCommand } from "@mariozechner/pi-tui";
import {
@@ -1051,7 +1051,7 @@ export class InteractiveMode {
this.ui.requestRender();
}
private addMessageToChat(message: AppMessage): void {
private addMessageToChat(message: AgentMessage): void {
if (isBashExecutionMessage(message)) {
const component = new BashExecutionComponent(message.command, this.ui);
if (message.output) {

View file

@@ -6,7 +6,7 @@
import { type ChildProcess, spawn } from "node:child_process";
import * as readline from "node:readline";
import type { AgentEvent, AppMessage, Attachment, ThinkingLevel } from "@mariozechner/pi-agent-core";
import type { AgentEvent, AgentMessage, Attachment, ThinkingLevel } from "@mariozechner/pi-agent-core";
import type { SessionStats } from "../../core/agent-session.js";
import type { BashResult } from "../../core/bash-executor.js";
import type { CompactionResult } from "../../core/compaction.js";
@@ -349,9 +349,9 @@ export class RpcClient {
/**
* Get all messages in the session.
*/
async getMessages(): Promise<AppMessage[]> {
async getMessages(): Promise<AgentMessage[]> {
const response = await this.send({ type: "get_messages" });
return this.getData<{ messages: AppMessage[] }>(response).messages;
return this.getData<{ messages: AgentMessage[] }>(response).messages;
}
// =========================================================================
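getMessages shows the client discipline: send a typed command, then unwrap `data` or surface the error. A sketch of the unwrap step, modeled on the RpcResponse union in the next file; the real getData body is not shown here, so treat this as an assumption:

type Envelope<T> =
  | { id?: string; type: "response"; command: string; success: true; data: T }
  | { id?: string; type: "response"; command: string; success: false; error: string };

// Assumed behavior: throw on failure so callers only ever see `data`.
function unwrap<T>(response: Envelope<T>): T {
  if (!response.success) throw new Error(`RPC error: ${response.error}`);
  return response.data;
}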

View file

@@ -5,7 +5,7 @@
* Responses and events are emitted as JSON lines on stdout.
*/
import type { AppMessage, Attachment, ThinkingLevel } from "@mariozechner/pi-agent-core";
import type { AgentMessage, Attachment, ThinkingLevel } from "@mariozechner/pi-agent-core";
import type { Model } from "@mariozechner/pi-ai";
import type { SessionStats } from "../../core/agent-session.js";
import type { BashResult } from "../../core/bash-executor.js";
@@ -161,7 +161,7 @@ export type RpcResponse =
}
// Messages
| { id?: string; type: "response"; command: "get_messages"; success: true; data: { messages: AppMessage[] } }
| { id?: string; type: "response"; command: "get_messages"; success: true; data: { messages: AgentMessage[] } }
// Error response (any command can fail)
| { id?: string; type: "response"; command: string; success: false; error: string };

View file

@@ -1,4 +1,4 @@
import type { AppMessage } from "@mariozechner/pi-agent-core";
import type { AgentMessage } from "@mariozechner/pi-agent-core";
import type { AssistantMessage, Usage } from "@mariozechner/pi-ai";
import { getModel } from "@mariozechner/pi-ai";
import { readFileSync } from "fs";
@@ -48,7 +48,7 @@ function createMockUsage(input: number, output: number, cacheRead = 0, cacheWrit
};
}
function createUserMessage(text: string): AppMessage {
function createUserMessage(text: string): AgentMessage {
return { role: "user", content: text, timestamp: Date.now() };
}
@@ -78,7 +78,7 @@ beforeEach(() => {
resetEntryCounter();
});
function createMessageEntry(message: AppMessage): SessionMessageEntry {
function createMessageEntry(message: AgentMessage): SessionMessageEntry {
const id = `test-id-${entryCounter++}`;
const entry: SessionMessageEntry = {
type: "message",

View file

@@ -10,7 +10,7 @@
* - MomSettingsManager: Simple settings for mom (compaction, retry, model preferences)
*/
import type { AppMessage } from "@mariozechner/pi-agent-core";
import type { AgentMessage } from "@mariozechner/pi-agent-core";
import {
buildSessionContext,
type CompactionEntry,
@@ -153,7 +153,7 @@ export class MomSessionManager {
contextSlackTimestamps.add(entry.timestamp);
// Also store message text to catch duplicates added via prompt()
// AppMessage has different shapes, check for content property
// AgentMessage has different shapes, check for content property
const msg = msgEntry.message as { role: string; content?: unknown };
if (msg.role === "user" && msg.content !== undefined) {
const content = msg.content;
@@ -189,7 +189,7 @@ export class MomSessionManager {
isBot?: boolean;
}
const newMessages: Array<{ timestamp: string; slackTs: string; message: AppMessage }> = [];
const newMessages: Array<{ timestamp: string; slackTs: string; message: AgentMessage }> = [];
for (const line of logLines) {
try {
@@ -215,7 +215,7 @@ export class MomSessionManager {
if (contextMessageTexts.has(messageText)) continue;
const msgTime = new Date(date).getTime() || Date.now();
const userMessage: AppMessage = {
const userMessage: AgentMessage = {
role: "user",
content: messageText,
timestamp: msgTime,
@@ -277,7 +277,7 @@ export class MomSessionManager {
return entries;
}
saveMessage(message: AppMessage): void {
saveMessage(message: AgentMessage): void {
const entry: SessionMessageEntry = { ...this._createEntryBase(), type: "message", message };
this.inMemoryEntries.push(entry);
this._persist(entry);
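The duplicate screening described in the comments above keys on both the Slack timestamp and the rendered message text, so each log line costs two set lookups. A minimal sketch, with the set names taken from the code above:

// Skip a Slack log line if either its timestamp or its rendered text is
// already present in the session context.
function isAlreadyInContext(
  slackTs: string,
  messageText: string,
  contextSlackTimestamps: Set<string>,
  contextMessageTexts: Set<string>,
): boolean {
  return contextSlackTimestamps.has(slackTs) || contextMessageTexts.has(messageText);
}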