WIP: Refactor agent package - not compiling

- Renamed AppMessage to AgentMessage throughout
- New agent-loop.ts with AgentLoopContext, AgentLoopConfig
- Removed transport abstraction, Agent now takes streamFn directly
- Extracted streamProxy to proxy.ts utility
- Removed agent-loop from pi-ai (now in agent package)
- Updated consumers (coding-agent, mom) for AgentMessage rename
- Tests updated but some consumers still need migration

Known issues:
- AgentTool, AgentToolResult not exported from pi-ai
- Attachment not exported from pi-agent-core
- ProviderTransport removed but still referenced
- messageTransformer -> convertToLlm migration incomplete
- CustomMessages declaration merging not working properly
This commit is contained in:
Mario Zechner 2025-12-28 09:23:38 +01:00
parent f7ef44dc38
commit a055fd4481
32 changed files with 1312 additions and 2009 deletions

View file

@ -1,376 +0,0 @@
import { streamSimple } from "../stream.js";
import type { AssistantMessage, Context, Message, ToolResultMessage, UserMessage } from "../types.js";
import { EventStream } from "../utils/event-stream.js";
import { validateToolArguments } from "../utils/validation.js";
import type { AgentContext, AgentEvent, AgentLoopConfig, AgentTool, AgentToolResult, QueuedMessage } from "./types.js";
/**
 * Start an agent loop with a new user message.
 * The prompt is appended to the context, start events are emitted for it, and
 * the returned stream yields every AgentEvent until the loop finishes.
 */
export function agentLoop(
	prompt: UserMessage,
	context: AgentContext,
	config: AgentLoopConfig,
	signal?: AbortSignal,
	streamFn?: typeof streamSimple,
): EventStream<AgentEvent, AgentContext["messages"]> {
	const stream = createAgentStream();
	void (async () => {
		// Everything produced during this run; resolves the stream's result so
		// callers can append it to their own context.
		const produced: AgentContext["messages"] = [prompt];
		// Work on a copy of the context whose history includes the new prompt.
		const workingContext: AgentContext = {
			...context,
			messages: [...context.messages, prompt],
		};
		// Announce the run, the first turn, and the prompt itself before
		// handing control to the shared loop.
		stream.push({ type: "agent_start" });
		stream.push({ type: "turn_start" });
		stream.push({ type: "message_start", message: prompt });
		stream.push({ type: "message_end", message: prompt });
		await runLoop(workingContext, produced, config, signal, stream, streamFn);
	})();
	return stream;
}
/**
 * Continue an agent loop from the current context without adding a new message.
 * Used for retry after overflow - context already has user message or tool results.
 *
 * @throws Error if the context is empty or its last message is not a user
 *   message or tool result (the LLM needs one of those to respond to).
 */
export function agentLoopContinue(
	context: AgentContext,
	config: AgentLoopConfig,
	signal?: AbortSignal,
	streamFn?: typeof streamSimple,
): EventStream<AgentEvent, AgentContext["messages"]> {
	// Validate that we can continue from this context
	const lastMessage = context.messages[context.messages.length - 1];
	if (!lastMessage) {
		throw new Error("Cannot continue: no messages in context");
	}
	if (lastMessage.role !== "user" && lastMessage.role !== "toolResult") {
		throw new Error(`Cannot continue from message role: ${lastMessage.role}. Expected 'user' or 'toolResult'.`);
	}
	const stream = createAgentStream();
	void (async () => {
		const newMessages: AgentContext["messages"] = [];
		const currentContext: AgentContext = { ...context };
		stream.push({ type: "agent_start" });
		stream.push({ type: "turn_start" });
		// No user message events - we're continuing from existing context
		await runLoop(currentContext, newMessages, config, signal, stream, streamFn);
	})();
	return stream;
}
/**
 * Build the EventStream shared by both loop entry points: it completes on the
 * `agent_end` event and resolves with that event's accumulated messages.
 */
function createAgentStream(): EventStream<AgentEvent, AgentContext["messages"]> {
	const isFinal = (event: AgentEvent) => event.type === "agent_end";
	const toResult = (event: AgentEvent) => (event.type === "agent_end" ? event.messages : []);
	return new EventStream<AgentEvent, AgentContext["messages"]>(isFinal, toResult);
}
/**
 * Shared loop logic for both agentLoop and agentLoopContinue.
 *
 * Each iteration is one "turn": inject any host-queued messages, stream one
 * assistant response, execute requested tool calls, emit turn events. The
 * loop exits when the assistant makes no tool calls and nothing is queued,
 * or immediately when a response stops with "error"/"aborted".
 */
async function runLoop(
	currentContext: AgentContext,
	newMessages: AgentContext["messages"],
	config: AgentLoopConfig,
	signal: AbortSignal | undefined,
	stream: EventStream<AgentEvent, AgentContext["messages"]>,
	streamFn?: typeof streamSimple,
): Promise<void> {
	let hasMoreToolCalls = true;
	let firstTurn = true;
	// Messages the host queued before the run started; refreshed after each turn.
	let queuedMessages: QueuedMessage<any>[] = (await config.getQueuedMessages?.()) || [];
	// Messages discovered mid-tool-execution; they pre-empt the regular
	// getQueuedMessages() fetch at the end of the turn.
	let queuedAfterTools: QueuedMessage<any>[] | null = null;
	while (hasMoreToolCalls || queuedMessages.length > 0) {
		// The caller already emitted turn_start for the very first turn.
		if (!firstTurn) {
			stream.push({ type: "turn_start" });
		} else {
			firstTurn = false;
		}
		// Process queued messages first (inject before next assistant response)
		if (queuedMessages.length > 0) {
			for (const { original, llm } of queuedMessages) {
				// UI observers see the original message; only the LLM-transformed
				// variant (if any) is appended to the model context.
				stream.push({ type: "message_start", message: original });
				stream.push({ type: "message_end", message: original });
				if (llm) {
					currentContext.messages.push(llm);
					newMessages.push(llm);
				}
			}
			queuedMessages = [];
		}
		// Stream assistant response
		const message = await streamAssistantResponse(currentContext, config, signal, stream, streamFn);
		newMessages.push(message);
		if (message.stopReason === "error" || message.stopReason === "aborted") {
			// Stop the loop on error or abort
			stream.push({ type: "turn_end", message, toolResults: [] });
			stream.push({ type: "agent_end", messages: newMessages });
			stream.end(newMessages);
			return;
		}
		// Check for tool calls
		const toolCalls = message.content.filter((c) => c.type === "toolCall");
		hasMoreToolCalls = toolCalls.length > 0;
		const toolResults: ToolResultMessage[] = [];
		if (hasMoreToolCalls) {
			// Execute tool calls
			const toolExecution = await executeToolCalls(
				currentContext.tools,
				message,
				signal,
				stream,
				config.getQueuedMessages,
			);
			toolResults.push(...toolExecution.toolResults);
			queuedAfterTools = toolExecution.queuedMessages ?? null;
			currentContext.messages.push(...toolResults);
			newMessages.push(...toolResults);
		}
		stream.push({ type: "turn_end", message, toolResults: toolResults });
		// Get queued messages after turn completes
		if (queuedAfterTools && queuedAfterTools.length > 0) {
			queuedMessages = queuedAfterTools;
			queuedAfterTools = null;
		} else {
			queuedMessages = (await config.getQueuedMessages?.()) || [];
		}
	}
	stream.push({ type: "agent_end", messages: newMessages });
	stream.end(newMessages);
}
// Helper functions
/**
 * Stream a single assistant response and mirror its partial state into
 * `context.messages`, so observers of the context always see the in-flight
 * message. Returns the final AssistantMessage (also left in the context).
 *
 * Applies the optional preprocessor, strips UI-only `details` from tool
 * result messages before they reach the LLM, and resolves the API key per
 * call (supports expiring OAuth tokens).
 */
async function streamAssistantResponse(
	context: AgentContext,
	config: AgentLoopConfig,
	signal: AbortSignal | undefined,
	stream: EventStream<AgentEvent, AgentContext["messages"]>,
	streamFn?: typeof streamSimple,
): Promise<AssistantMessage> {
	// Convert AgentContext to Context for streamSimple
	// Use a copy of messages to avoid mutating the original context
	const processedMessages = config.preprocessor
		? await config.preprocessor(context.messages, signal)
		: [...context.messages];
	const processedContext: Context = {
		systemPrompt: context.systemPrompt,
		messages: [...processedMessages].map((m) => {
			if (m.role === "toolResult") {
				// Drop `details`: it is for UI/logging only and must not be sent to the LLM.
				// biome-ignore lint/correctness/noUnusedVariables: fine here
				const { details, ...rest } = m;
				return rest;
			} else {
				return m;
			}
		}),
		tools: context.tools, // AgentTool extends Tool, so this works
	};
	// Use custom stream function if provided, otherwise use default streamSimple
	const streamFunction = streamFn || streamSimple;
	// Resolve API key for every assistant response (important for expiring tokens)
	const resolvedApiKey =
		(config.getApiKey ? await config.getApiKey(config.model.provider) : undefined) || config.apiKey;
	const response = await streamFunction(config.model, processedContext, { ...config, apiKey: resolvedApiKey, signal });
	// The in-flight assistant message; kept as the LAST entry of
	// context.messages and replaced wholesale on every provider event.
	let partialMessage: AssistantMessage | null = null;
	let addedPartial = false;
	for await (const event of response) {
		switch (event.type) {
			case "start":
				partialMessage = event.partial;
				context.messages.push(partialMessage);
				addedPartial = true;
				stream.push({ type: "message_start", message: { ...partialMessage } });
				break;
			case "text_start":
			case "text_delta":
			case "text_end":
			case "thinking_start":
			case "thinking_delta":
			case "thinking_end":
			case "toolcall_start":
			case "toolcall_delta":
			case "toolcall_end":
				if (partialMessage) {
					partialMessage = event.partial;
					context.messages[context.messages.length - 1] = partialMessage;
					stream.push({ type: "message_update", assistantMessageEvent: event, message: { ...partialMessage } });
				}
				break;
			case "done":
			case "error": {
				const finalMessage = await response.result();
				// Swap the partial for the final message, or append it if no
				// "start" event was ever seen (e.g. an immediate error).
				if (addedPartial) {
					context.messages[context.messages.length - 1] = finalMessage;
				} else {
					context.messages.push(finalMessage);
				}
				if (!addedPartial) {
					stream.push({ type: "message_start", message: { ...finalMessage } });
				}
				stream.push({ type: "message_end", message: finalMessage });
				return finalMessage;
			}
		}
	}
	// Provider stream ended without an explicit done/error event; fall back to result().
	return await response.result();
}
/**
 * Execute the tool calls from an assistant message, one at a time and in order.
 *
 * Emits tool_execution_start/update/end plus a message_start/message_end pair
 * for every call. After each call the host's getQueuedMessages() is polled;
 * if the host queued input mid-execution, the remaining calls are skipped
 * (as error results) so the queued messages can be injected before the next
 * assistant response.
 *
 * @returns The tool result messages plus any queued messages discovered.
 */
async function executeToolCalls<T>(
	tools: AgentTool<any, T>[] | undefined,
	assistantMessage: AssistantMessage,
	signal: AbortSignal | undefined,
	stream: EventStream<AgentEvent, Message[]>,
	getQueuedMessages?: AgentLoopConfig["getQueuedMessages"],
): Promise<{ toolResults: ToolResultMessage<T>[]; queuedMessages?: QueuedMessage<any>[] }> {
	const toolCalls = assistantMessage.content.filter((c) => c.type === "toolCall");
	const results: ToolResultMessage<any>[] = [];
	let queuedMessages: QueuedMessage<any>[] | undefined;
	for (let index = 0; index < toolCalls.length; index++) {
		const toolCall = toolCalls[index];
		const tool = tools?.find((t) => t.name === toolCall.name);
		stream.push({
			type: "tool_execution_start",
			toolCallId: toolCall.id,
			toolName: toolCall.name,
			args: toolCall.arguments,
		});
		let result: AgentToolResult<T>;
		let isError = false;
		try {
			if (!tool) throw new Error(`Tool ${toolCall.name} not found`);
			// Validate arguments using shared validation function
			const validatedArgs = validateToolArguments(tool, toolCall);
			// Execute with validated, typed arguments, passing update callback
			result = await tool.execute(toolCall.id, validatedArgs, signal, (partialResult) => {
				stream.push({
					type: "tool_execution_update",
					toolCallId: toolCall.id,
					toolName: toolCall.name,
					args: toolCall.arguments,
					partialResult,
				});
			});
		} catch (e) {
			// Tool failures become error results rather than aborting the loop,
			// so the model can see the failure text and react to it.
			result = {
				content: [{ type: "text", text: e instanceof Error ? e.message : String(e) }],
				details: {} as T,
			};
			isError = true;
		}
		stream.push({
			type: "tool_execution_end",
			toolCallId: toolCall.id,
			toolName: toolCall.name,
			result,
			isError,
		});
		const toolResultMessage: ToolResultMessage<T> = {
			role: "toolResult",
			toolCallId: toolCall.id,
			toolName: toolCall.name,
			content: result.content,
			details: result.details,
			isError,
			timestamp: Date.now(),
		};
		results.push(toolResultMessage);
		stream.push({ type: "message_start", message: toolResultMessage });
		stream.push({ type: "message_end", message: toolResultMessage });
		if (getQueuedMessages) {
			const queued = await getQueuedMessages();
			if (queued.length > 0) {
				// Interrupt: hand the queued messages back and skip what's left.
				queuedMessages = queued;
				const remainingCalls = toolCalls.slice(index + 1);
				for (const skipped of remainingCalls) {
					results.push(skipToolCall(skipped, stream));
				}
				break;
			}
		}
	}
	return { toolResults: results, queuedMessages };
}
/**
 * Mark a pending tool call as skipped (a queued user message interrupted the
 * turn) and emit the same event sequence a real execution would: start, end
 * (as an error), then a tool result message pair.
 */
function skipToolCall<T>(
	toolCall: Extract<AssistantMessage["content"][number], { type: "toolCall" }>,
	stream: EventStream<AgentEvent, Message[]>,
): ToolResultMessage<T> {
	const { id: toolCallId, name: toolName } = toolCall;
	const skippedResult: AgentToolResult<T> = {
		content: [{ type: "text", text: "Skipped due to queued user message." }],
		details: {} as T,
	};
	// Observers still see a complete (failed) execution for this call.
	stream.push({
		type: "tool_execution_start",
		toolCallId,
		toolName,
		args: toolCall.arguments,
	});
	stream.push({
		type: "tool_execution_end",
		toolCallId,
		toolName,
		result: skippedResult,
		isError: true,
	});
	const message: ToolResultMessage<T> = {
		role: "toolResult",
		toolCallId,
		toolName,
		content: skippedResult.content,
		details: skippedResult.details,
		isError: true,
		timestamp: Date.now(),
	};
	stream.push({ type: "message_start", message });
	stream.push({ type: "message_end", message });
	return message;
}

View file

@ -1,11 +0,0 @@
// Public surface of the agent package: the loop entry points, the bundled
// example tools, and the agent-facing types.
export { agentLoop, agentLoopContinue } from "./agent-loop.js";
export * from "./tools/index.js";
export type {
	AgentContext,
	AgentEvent,
	AgentLoopConfig,
	AgentTool,
	AgentToolResult,
	AgentToolUpdateCallback,
	QueuedMessage,
} from "./types.js";

View file

@ -1,32 +0,0 @@
import { type Static, Type } from "@sinclair/typebox";
import type { AgentTool, AgentToolResult } from "../../agent/types.js";
// Result of the calculate tool: a single text block ("expr = value"), no details.
export interface CalculateResult extends AgentToolResult<undefined> {
	content: Array<{ type: "text"; text: string }>;
	details: undefined;
}
/**
 * Evaluate a math expression and return it formatted as "expr = result".
 *
 * SECURITY NOTE: this evaluates the expression as arbitrary JavaScript via
 * `new Function`. Only feed it model-generated input in trusted contexts;
 * never expose it to untrusted callers.
 *
 * @param expression Expression to evaluate (any JavaScript expression).
 * @returns A CalculateResult with one text content block and no details.
 * @throws Error when the expression fails to parse or evaluate.
 */
export function calculate(expression: string): CalculateResult {
	try {
		const result = new Function(`return ${expression}`)();
		return { content: [{ type: "text", text: `${expression} = ${result}` }], details: undefined };
	} catch (e) {
		// Narrow the unknown catch value instead of typing it `any`.
		throw new Error(e instanceof Error && e.message ? e.message : String(e));
	}
}
// Argument schema: a single required expression string; used both for the LLM
// tool declaration and for runtime argument validation.
const calculateSchema = Type.Object({
	expression: Type.String({ description: "The mathematical expression to evaluate" }),
});
type CalculateParams = Static<typeof calculateSchema>;
// Agent tool wrapper around calculate(). NOTE: calculate() evaluates the
// expression as JavaScript via `new Function` — trusted input only.
export const calculateTool: AgentTool<typeof calculateSchema, undefined> = {
	label: "Calculator",
	name: "calculate",
	description: "Evaluate mathematical expressions",
	parameters: calculateSchema,
	execute: async (_toolCallId: string, args: CalculateParams) => {
		return calculate(args.expression);
	},
};

View file

@ -1,47 +0,0 @@
import { type Static, Type } from "@sinclair/typebox";
import type { AgentTool } from "../../agent/index.js";
import type { AgentToolResult } from "../types.js";
// Result of get_current_time: formatted time text plus the raw UTC epoch millis.
export interface GetCurrentTimeResult extends AgentToolResult<{ utcTimestamp: number }> {}
/**
 * Get the current date/time as human-readable text plus the UTC epoch millis.
 *
 * @param timezone Optional IANA timezone name; when omitted the process's
 *   default zone is used. An invalid name causes an Error whose message still
 *   includes the current UTC time so the caller can recover.
 */
export async function getCurrentTime(timezone?: string): Promise<GetCurrentTimeResult> {
	const now = new Date();
	const details = { utcTimestamp: now.getTime() };
	if (!timezone) {
		const text = now.toLocaleString("en-US", { dateStyle: "full", timeStyle: "long" });
		return { content: [{ type: "text", text }], details };
	}
	try {
		const text = now.toLocaleString("en-US", {
			timeZone: timezone,
			dateStyle: "full",
			timeStyle: "long",
		});
		return { content: [{ type: "text", text }], details };
	} catch (_e) {
		// Surface the bad timezone but still hand back a usable UTC time.
		throw new Error(`Invalid timezone: ${timezone}. Current UTC time: ${now.toISOString()}`);
	}
}
// Argument schema: a single optional IANA timezone string.
const getCurrentTimeSchema = Type.Object({
	timezone: Type.Optional(
		Type.String({ description: "Optional timezone (e.g., 'America/New_York', 'Europe/London')" }),
	),
});
type GetCurrentTimeParams = Static<typeof getCurrentTimeSchema>;
// Agent tool wrapper around getCurrentTime().
export const getCurrentTimeTool: AgentTool<typeof getCurrentTimeSchema, { utcTimestamp: number }> = {
	label: "Current Time",
	name: "get_current_time",
	description: "Get the current date and time",
	parameters: getCurrentTimeSchema,
	execute: async (_toolCallId: string, args: GetCurrentTimeParams) => {
		return getCurrentTime(args.timezone);
	},
};

View file

@ -1,2 +0,0 @@
// Barrel export for the bundled example tools.
export { calculate, calculateTool } from "./calculate.js";
export { getCurrentTime, getCurrentTimeTool } from "./get-current-time.js";

View file

@ -1,105 +0,0 @@
import type { Static, TSchema } from "@sinclair/typebox";
import type {
AssistantMessage,
AssistantMessageEvent,
ImageContent,
Message,
Model,
SimpleStreamOptions,
TextContent,
Tool,
ToolResultMessage,
} from "../types.js";
// Result returned by an AgentTool execution (final or partial/streaming).
export interface AgentToolResult<T> {
	// Content blocks supporting text and images
	content: (TextContent | ImageContent)[];
	// Details to be displayed in a UI or logged
	details: T;
}
// Callback for streaming tool execution updates; invoked with a partial
// result so UIs can render progress while the tool is still running.
export type AgentToolUpdateCallback<T = any> = (partialResult: AgentToolResult<T>) => void;
// AgentTool extends Tool but adds the execute function
export interface AgentTool<TParameters extends TSchema = TSchema, TDetails = any> extends Tool<TParameters> {
	// A human-readable label for the tool to be displayed in UI
	label: string;
	// Runs the tool. `params` has already been validated against `parameters`;
	// `onUpdate` may be called repeatedly with partial results for streaming.
	execute: (
		toolCallId: string,
		params: Static<TParameters>,
		signal?: AbortSignal,
		onUpdate?: AgentToolUpdateCallback<TDetails>,
	) => Promise<AgentToolResult<TDetails>>;
}
// AgentContext is like Context but uses AgentTool
export interface AgentContext {
	systemPrompt: string;
	// Full conversation history; the agent loop appends to this in place.
	messages: Message[];
	tools?: AgentTool<any>[];
}
// Event types
// Rough per-run order: agent_start -> (turn_start -> message/tool events -> turn_end)* -> agent_end.
export type AgentEvent =
	// Emitted when the agent starts. An agent can emit multiple turns
	| { type: "agent_start" }
	// Emitted when a turn starts. A turn can emit an optional user message (initial prompt), an assistant message (response) and multiple tool result messages
	| { type: "turn_start" }
	// Emitted when a user, assistant or tool result message starts
	| { type: "message_start"; message: Message }
	// Emitted when an assistant message is updated due to streaming
	| { type: "message_update"; assistantMessageEvent: AssistantMessageEvent; message: AssistantMessage }
	// Emitted when a user, assistant or tool result message is complete
	| { type: "message_end"; message: Message }
	// Emitted when a tool execution starts
	| { type: "tool_execution_start"; toolCallId: string; toolName: string; args: any }
	// Emitted when a tool execution produces output (streaming)
	| {
			type: "tool_execution_update";
			toolCallId: string;
			toolName: string;
			args: any;
			partialResult: AgentToolResult<any>;
	  }
	// Emitted when a tool execution completes
	| {
			type: "tool_execution_end";
			toolCallId: string;
			toolName: string;
			result: AgentToolResult<any>;
			isError: boolean;
	  }
	// Emitted when a full turn completes
	| { type: "turn_end"; message: AssistantMessage; toolResults: ToolResultMessage[] }
	// Emitted when the agent has completed all its turns. All messages from every turn are
	// contained in messages, which can be appended to the context
	| { type: "agent_end"; messages: AgentContext["messages"] };
// Queued message with optional LLM representation
// (a host-side message that arrived while the agent was busy)
export interface QueuedMessage<TApp = Message> {
	original: TApp; // Original message for UI events
	llm?: Message; // Optional transformed message for loop context (undefined if filtered)
}
// Configuration for agent loop execution
export interface AgentLoopConfig extends SimpleStreamOptions {
	model: Model<any>;
	/**
	 * Optional hook to resolve an API key dynamically for each LLM call.
	 *
	 * This is useful for short-lived OAuth tokens (e.g. GitHub Copilot) that may
	 * expire during long-running tool execution phases.
	 *
	 * The agent loop will call this before each assistant response and pass the
	 * returned value as `apiKey` to `streamSimple()` (or a custom `streamFn`).
	 *
	 * If it returns `undefined`, the loop falls back to `config.apiKey`, and then
	 * to `streamSimple()`'s own provider key lookup (setApiKey/env vars).
	 */
	getApiKey?: (provider: string) => Promise<string | undefined> | string | undefined;
	// Optional hook run on the message history right before each LLM call;
	// its return value (not the context) is what gets sent to the model.
	preprocessor?: (messages: AgentContext["messages"], abortSignal?: AbortSignal) => Promise<AgentContext["messages"]>;
	// Polled by the loop for host-queued messages: before the run, after each
	// tool call (to interrupt remaining calls), and after each turn.
	getQueuedMessages?: <T>() => Promise<QueuedMessage<T>[]>;
}

View file

@ -1,4 +1,3 @@
export * from "./agent/index.js";
export * from "./models.js";
export * from "./providers/anthropic.js";
export * from "./providers/google.js";
@ -7,6 +6,7 @@ export * from "./providers/openai-completions.js";
export * from "./providers/openai-responses.js";
export * from "./stream.js";
export * from "./types.js";
export * from "./utils/event-stream.js";
export * from "./utils/oauth/index.js";
export * from "./utils/overflow.js";
export * from "./utils/typebox-helpers.js";

View file

@ -1,166 +0,0 @@
import { Type } from "@sinclair/typebox";
import { describe, expect, it } from "vitest";
import { agentLoop } from "../src/agent/agent-loop.js";
import type { AgentContext, AgentEvent, AgentLoopConfig, AgentTool, QueuedMessage } from "../src/agent/types.js";
import type { AssistantMessage, Message, Model, UserMessage } from "../src/types.js";
import { AssistantMessageEventStream } from "../src/utils/event-stream.js";
/** Zeroed-out usage record for mock assistant messages. */
function createUsage() {
	const zeroCost = { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 };
	return {
		input: 0,
		output: 0,
		cacheRead: 0,
		cacheWrite: 0,
		totalTokens: 0,
		cost: zeroCost,
	};
}
/** Minimal mock model descriptor pointing at an unreachable base URL. */
function createModel(): Model<"openai-responses"> {
	const freeCost = { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 };
	return {
		id: "mock",
		name: "mock",
		api: "openai-responses",
		provider: "openai",
		baseUrl: "https://example.invalid",
		reasoning: false,
		input: ["text"],
		cost: freeCost,
		contextWindow: 8192,
		maxTokens: 2048,
	};
}
// End-to-end test of the queued-message interrupt path with a scripted
// streamFn: the first assistant turn requests two tool calls, a user message
// is queued right after the first executes, and the loop must skip the second
// call, inject the queued message, then finish on a plain text response.
describe("agentLoop queued message interrupt", () => {
	it("injects queued messages after a tool call and skips remaining tool calls", async () => {
		const toolSchema = Type.Object({ value: Type.String() });
		// Records which tool invocations actually ran.
		const executed: string[] = [];
		const tool: AgentTool<typeof toolSchema, { value: string }> = {
			name: "echo",
			label: "Echo",
			description: "Echo tool",
			parameters: toolSchema,
			async execute(_toolCallId, params) {
				executed.push(params.value);
				return {
					content: [{ type: "text", text: `ok:${params.value}` }],
					details: { value: params.value },
				};
			},
		};
		const context: AgentContext = {
			systemPrompt: "",
			messages: [],
			tools: [tool],
		};
		const userPrompt: UserMessage = {
			role: "user",
			content: "start",
			timestamp: Date.now(),
		};
		const queuedUserMessage: Message = {
			role: "user",
			content: "interrupt",
			timestamp: Date.now(),
		};
		const queuedMessages: QueuedMessage<Message>[] = [{ original: queuedUserMessage, llm: queuedUserMessage }];
		let queuedDelivered = false;
		let sawInterruptInContext = false;
		let callIndex = 0;
		// Scripted provider: first call returns two tool calls, second returns text.
		const streamFn = () => {
			const stream = new AssistantMessageEventStream();
			queueMicrotask(() => {
				if (callIndex === 0) {
					const message: AssistantMessage = {
						role: "assistant",
						content: [
							{ type: "toolCall", id: "tool-1", name: "echo", arguments: { value: "first" } },
							{ type: "toolCall", id: "tool-2", name: "echo", arguments: { value: "second" } },
						],
						api: "openai-responses",
						provider: "openai",
						model: "mock",
						usage: createUsage(),
						stopReason: "toolUse",
						timestamp: Date.now(),
					};
					stream.push({ type: "done", reason: "toolUse", message });
				} else {
					const message: AssistantMessage = {
						role: "assistant",
						content: [{ type: "text", text: "done" }],
						api: "openai-responses",
						provider: "openai",
						model: "mock",
						usage: createUsage(),
						stopReason: "stop",
						timestamp: Date.now(),
					};
					stream.push({ type: "done", reason: "stop", message });
				}
				callIndex += 1;
			});
			return stream;
		};
		// Delivers the queued message exactly once, after the first tool ran.
		const getQueuedMessages: AgentLoopConfig["getQueuedMessages"] = async <T>() => {
			if (executed.length === 1 && !queuedDelivered) {
				queuedDelivered = true;
				return queuedMessages as QueuedMessage<T>[];
			}
			return [];
		};
		const config: AgentLoopConfig = {
			model: createModel(),
			getQueuedMessages,
		};
		const events: AgentEvent[] = [];
		const stream = agentLoop(userPrompt, context, config, undefined, (_model, ctx, _options) => {
			if (callIndex === 1) {
				// By the second LLM call the queued message must be in the context.
				sawInterruptInContext = ctx.messages.some(
					(m) => m.role === "user" && typeof m.content === "string" && m.content === "interrupt",
				);
			}
			return streamFn();
		});
		for await (const event of stream) {
			events.push(event);
		}
		// Only the first tool call may have executed; the second was skipped.
		expect(executed).toEqual(["first"]);
		const toolEnds = events.filter(
			(event): event is Extract<AgentEvent, { type: "tool_execution_end" }> => event.type === "tool_execution_end",
		);
		expect(toolEnds.length).toBe(2);
		expect(toolEnds[1].isError).toBe(true);
		expect(toolEnds[1].result.content[0]?.type).toBe("text");
		if (toolEnds[1].result.content[0]?.type === "text") {
			expect(toolEnds[1].result.content[0].text).toContain("Skipped due to queued user message");
		}
		// The queued message must be emitted after the first turn ends and
		// before the next assistant message starts.
		const firstTurnEndIndex = events.findIndex((event) => event.type === "turn_end");
		const queuedMessageIndex = events.findIndex(
			(event) =>
				event.type === "message_start" &&
				event.message.role === "user" &&
				typeof event.message.content === "string" &&
				event.message.content === "interrupt",
		);
		const nextAssistantIndex = events.findIndex(
			(event, index) =>
				index > queuedMessageIndex && event.type === "message_start" && event.message.role === "assistant",
		);
		expect(queuedMessageIndex).toBeGreaterThan(firstTurnEndIndex);
		expect(queuedMessageIndex).toBeLessThan(nextAssistantIndex);
		expect(sawInterruptInContext).toBe(true);
	});
});

View file

@ -1,701 +0,0 @@
import { describe, expect, it } from "vitest";
import { agentLoop, agentLoopContinue } from "../src/agent/agent-loop.js";
import { calculateTool } from "../src/agent/tools/calculate.js";
import type { AgentContext, AgentEvent, AgentLoopConfig } from "../src/agent/types.js";
import { getModel } from "../src/models.js";
import type {
Api,
AssistantMessage,
Message,
Model,
OptionsForApi,
ToolResultMessage,
UserMessage,
} from "../src/types.js";
import { resolveApiKey } from "./oauth.js";
// Resolve OAuth tokens at module level (async, runs before tests)
const oauthTokens = await Promise.all([
	resolveApiKey("anthropic"),
	resolveApiKey("github-copilot"),
	resolveApiKey("google-gemini-cli"),
	resolveApiKey("google-antigravity"),
]);
// Destructured in the same order as the Promise.all input above.
const [anthropicOAuthToken, githubCopilotToken, geminiCliToken, antigravityToken] = oauthTokens;
/**
 * Drives a full agent loop against a real provider: the model must multiply
 * two pairs with the calculator tool (ideally in parallel), sum the results
 * with the tool, and answer with the final integer only.
 *
 * @returns Stats (turns, tool call count, tool results, parsed final answer,
 *   raw events) so callers can layer extra assertions on top.
 */
async function calculateTest<TApi extends Api>(model: Model<TApi>, options: OptionsForApi<TApi> = {}) {
	// Create the agent context with the calculator tool
	const context: AgentContext = {
		systemPrompt:
			"You are a helpful assistant that performs mathematical calculations. When asked to calculate multiple expressions, you can use parallel tool calls if the model supports it. In your final answer, output ONLY the final sum as a single integer number, nothing else.",
		messages: [],
		tools: [calculateTool],
	};
	// Create the prompt config
	const config: AgentLoopConfig = {
		model,
		...options,
	};
	// Create the user prompt asking for multiple calculations
	// (fixed typo: "mulit-step" -> "multi-step")
	const userPrompt: UserMessage = {
		role: "user",
		content: `Use the calculator tool to complete the following multi-step task.
1. Calculate 3485 * 4234 and 88823 * 3482 in parallel
2. Calculate the sum of the two results using the calculator tool
3. Output ONLY the final sum as a single integer number, nothing else.`,
		timestamp: Date.now(),
	};
	// Calculate expected results (using integers)
	const expectedFirst = 3485 * 4234; // = 14755490
	const expectedSecond = 88823 * 3482; // = 309281786
	const expectedSum = expectedFirst + expectedSecond; // = 324037276
	// Track events for verification
	const events: AgentEvent[] = [];
	let turns = 0;
	let toolCallCount = 0;
	const toolResults: number[] = [];
	let finalAnswer: number | undefined;
	// Execute the prompt
	const stream = agentLoop(userPrompt, context, config);
	for await (const event of stream) {
		events.push(event);
		switch (event.type) {
			case "turn_start":
				turns++;
				console.log(`\n=== Turn ${turns} started ===`);
				break;
			case "turn_end":
				console.log(`=== Turn ${turns} ended with ${event.toolResults.length} tool results ===`);
				console.log(event.message);
				break;
			case "tool_execution_end":
				if (!event.isError && typeof event.result === "object" && event.result.content) {
					const textOutput = event.result.content
						.filter((c: any) => c.type === "text")
						.map((c: any) => c.text)
						.join("\n");
					toolCallCount++;
					// Extract number from output like "expression = result"
					const match = textOutput.match(/=\s*([\d.]+)/);
					if (match) {
						const value = parseFloat(match[1]);
						toolResults.push(value);
						console.log(`Tool ${toolCallCount}: ${textOutput}`);
					}
				}
				break;
			case "message_end":
				// Just track the message end event, don't extract answer here
				break;
		}
	}
	// Get the final messages
	const finalMessages = await stream.result();
	// Verify the results
	expect(finalMessages).toBeDefined();
	expect(finalMessages.length).toBeGreaterThan(0);
	const finalMessage = finalMessages[finalMessages.length - 1];
	expect(finalMessage).toBeDefined();
	expect(finalMessage.role).toBe("assistant");
	if (finalMessage.role !== "assistant") throw new Error("Final message is not from assistant");
	// Extract the final answer from the last assistant message
	const content = finalMessage.content
		.filter((c) => c.type === "text")
		.map((c) => (c.type === "text" ? c.text : ""))
		.join(" ");
	// Look for integers in the response that might be the final answer
	const numbers = content.match(/\b\d+\b/g);
	if (numbers) {
		// Check if any of the numbers matches our expected sum
		for (const num of numbers) {
			const value = parseInt(num, 10);
			if (Math.abs(value - expectedSum) < 10) {
				finalAnswer = value;
				break;
			}
		}
		// If no exact match, take the last large number as likely the answer
		if (finalAnswer === undefined) {
			const largeNumbers = numbers.map((n) => parseInt(n, 10)).filter((n) => n > 1000000);
			if (largeNumbers.length > 0) {
				finalAnswer = largeNumbers[largeNumbers.length - 1];
			}
		}
	}
	// Should have executed at least 3 tool calls: 2 for the initial calculations, 1 for the sum
	// (or possibly 2 if the model calculates the sum itself without a tool)
	expect(toolCallCount).toBeGreaterThanOrEqual(2);
	// Must be at least 3 turns: first to calculate the expressions, then to sum them, then give the answer
	// Could be 3 turns if model does parallel calls, or 4 turns if sequential calculation of expressions
	expect(turns).toBeGreaterThanOrEqual(3);
	expect(turns).toBeLessThanOrEqual(4);
	// Verify the individual calculations are in the results
	const hasFirstCalc = toolResults.some((r) => r === expectedFirst);
	const hasSecondCalc = toolResults.some((r) => r === expectedSecond);
	expect(hasFirstCalc).toBe(true);
	expect(hasSecondCalc).toBe(true);
	// Verify the final sum
	if (finalAnswer !== undefined) {
		expect(finalAnswer).toBe(expectedSum);
		console.log(`Final answer: ${finalAnswer} (expected: ${expectedSum})`);
	} else {
		// If we couldn't extract the final answer from text, check if it's in the tool results
		const hasSum = toolResults.some((r) => r === expectedSum);
		expect(hasSum).toBe(true);
	}
	// Log summary
	console.log(`\nTest completed with ${turns} turns and ${toolCallCount} tool calls`);
	if (turns === 3) {
		console.log("Model used parallel tool calls for initial calculations");
	} else {
		console.log("Model used sequential tool calls");
	}
	return {
		turns,
		toolCallCount,
		toolResults,
		finalAnswer,
		events,
	};
}
/**
 * Verifies abort behavior: starts a multi-step calculation, aborts after the
 * first successful tool execution, and checks that the final assistant
 * message carries stopReason "aborted".
 */
async function abortTest<TApi extends Api>(model: Model<TApi>, options: OptionsForApi<TApi> = {}) {
	// Create the agent context with the calculator tool
	const context: AgentContext = {
		systemPrompt:
			"You are a helpful assistant that performs mathematical calculations. Always use the calculator tool for each calculation.",
		messages: [],
		tools: [calculateTool],
	};
	// Create the prompt config
	const config: AgentLoopConfig = {
		model,
		...options,
	};
	// Create a prompt that will require multiple calculations
	const userPrompt: UserMessage = {
		role: "user",
		content: "Calculate 100 * 200, then 300 * 400, then 500 * 600, then sum all three results.",
		timestamp: Date.now(),
	};
	// Create abort controller
	const abortController = new AbortController();
	// Track events for verification
	const events: AgentEvent[] = [];
	let toolCallCount = 0;
	// NOTE(review): errorReceived is never reassigned anywhere in this
	// function, so the returned value is always false — confirm intent.
	const errorReceived = false;
	let finalMessages: Message[] | undefined;
	// Execute the prompt
	const stream = agentLoop(userPrompt, context, config, abortController.signal);
	// Abort after first tool execution
	// NOTE(review): events are consumed in this detached IIFE while result()
	// is awaited below — assumes EventStream supports that; confirm.
	(async () => {
		for await (const event of stream) {
			events.push(event);
			if (event.type === "tool_execution_end" && !event.isError) {
				toolCallCount++;
				// Abort after first successful tool execution
				if (toolCallCount === 1) {
					console.log("Aborting after first tool execution");
					abortController.abort();
				}
			}
			if (event.type === "agent_end") {
				finalMessages = event.messages;
			}
		}
	})();
	finalMessages = await stream.result();
	// Verify abort behavior
	console.log(`\nAbort test completed with ${toolCallCount} tool calls`);
	const assistantMessage = finalMessages[finalMessages.length - 1];
	if (!assistantMessage) throw new Error("No final message received");
	expect(assistantMessage).toBeDefined();
	expect(assistantMessage.role).toBe("assistant");
	if (assistantMessage.role !== "assistant") throw new Error("Final message is not from assistant");
	// Should have executed 1 tool call before abort
	expect(toolCallCount).toBeGreaterThanOrEqual(1);
	expect(assistantMessage.stopReason).toBe("aborted");
	return {
		toolCallCount,
		events,
		errorReceived,
		finalMessages,
	};
}
describe("Agent Calculator Tests", () => {
describe.skipIf(!process.env.GEMINI_API_KEY)("Google Provider Agent", () => {
const model = getModel("google", "gemini-2.5-flash");
it("should calculate multiple expressions and sum the results", { retry: 3 }, async () => {
const result = await calculateTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
});
it("should handle abort during tool execution", { retry: 3 }, async () => {
const result = await abortTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
});
});
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Completions Provider Agent", () => {
const model = getModel("openai", "gpt-4o-mini");
it("should calculate multiple expressions and sum the results", { retry: 3 }, async () => {
const result = await calculateTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
});
it("should handle abort during tool execution", { retry: 3 }, async () => {
const result = await abortTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
});
});
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Responses Provider Agent", () => {
const model = getModel("openai", "gpt-5-mini");
it("should calculate multiple expressions and sum the results", { retry: 3 }, async () => {
const result = await calculateTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
});
it("should handle abort during tool execution", { retry: 3 }, async () => {
const result = await abortTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
});
});
describe.skipIf(!process.env.ANTHROPIC_API_KEY)("Anthropic Provider Agent", () => {
const model = getModel("anthropic", "claude-haiku-4-5");
it("should calculate multiple expressions and sum the results", { retry: 3 }, async () => {
const result = await calculateTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
});
it("should handle abort during tool execution", { retry: 3 }, async () => {
const result = await abortTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
});
});
describe.skipIf(!process.env.XAI_API_KEY)("xAI Provider Agent", () => {
const model = getModel("xai", "grok-3");
it("should calculate multiple expressions and sum the results", { retry: 3 }, async () => {
const result = await calculateTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
});
it("should handle abort during tool execution", { retry: 3 }, async () => {
const result = await abortTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
});
});
describe.skipIf(!process.env.GROQ_API_KEY)("Groq Provider Agent", () => {
const model = getModel("groq", "openai/gpt-oss-20b");
it("should calculate multiple expressions and sum the results", { retry: 3 }, async () => {
const result = await calculateTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
});
it("should handle abort during tool execution", { retry: 3 }, async () => {
const result = await abortTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
});
});
describe.skipIf(!process.env.CEREBRAS_API_KEY)("Cerebras Provider Agent", () => {
const model = getModel("cerebras", "gpt-oss-120b");
it("should calculate multiple expressions and sum the results", { retry: 3 }, async () => {
const result = await calculateTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
});
it("should handle abort during tool execution", { retry: 3 }, async () => {
const result = await abortTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
});
});
describe.skipIf(!process.env.ZAI_API_KEY)("zAI Provider Agent", () => {
const model = getModel("zai", "glm-4.5-air");
it("should calculate multiple expressions and sum the results", { retry: 3 }, async () => {
const result = await calculateTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
});
it("should handle abort during tool execution", { retry: 3 }, async () => {
const result = await abortTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
});
});
describe.skipIf(!process.env.MISTRAL_API_KEY)("Mistral Provider Agent", () => {
const model = getModel("mistral", "devstral-medium-latest");
it("should calculate multiple expressions and sum the results", { retry: 3 }, async () => {
const result = await calculateTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
});
it("should handle abort during tool execution", { retry: 3 }, async () => {
const result = await abortTest(model);
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
});
});
// =========================================================================
// OAuth-based providers (credentials from ~/.pi/agent/oauth.json)
// =========================================================================
describe("Anthropic OAuth Provider Agent", () => {
const model = getModel("anthropic", "claude-haiku-4-5");
it.skipIf(!anthropicOAuthToken)(
"should calculate multiple expressions and sum the results",
{ retry: 3 },
async () => {
const result = await calculateTest(model, { apiKey: anthropicOAuthToken });
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
},
);
it.skipIf(!anthropicOAuthToken)("should handle abort during tool execution", { retry: 3 }, async () => {
const result = await abortTest(model, { apiKey: anthropicOAuthToken });
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
});
});
describe("GitHub Copilot Provider Agent", () => {
it.skipIf(!githubCopilotToken)(
"gpt-4o - should calculate multiple expressions and sum the results",
{ retry: 3 },
async () => {
const model = getModel("github-copilot", "gpt-4o");
const result = await calculateTest(model, { apiKey: githubCopilotToken });
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
},
);
it.skipIf(!githubCopilotToken)("gpt-4o - should handle abort during tool execution", { retry: 3 }, async () => {
const model = getModel("github-copilot", "gpt-4o");
const result = await abortTest(model, { apiKey: githubCopilotToken });
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
});
it.skipIf(!githubCopilotToken)(
"claude-sonnet-4 - should calculate multiple expressions and sum the results",
{ retry: 3 },
async () => {
const model = getModel("github-copilot", "claude-sonnet-4");
const result = await calculateTest(model, { apiKey: githubCopilotToken });
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
},
);
it.skipIf(!githubCopilotToken)(
"claude-sonnet-4 - should handle abort during tool execution",
{ retry: 3 },
async () => {
const model = getModel("github-copilot", "claude-sonnet-4");
const result = await abortTest(model, { apiKey: githubCopilotToken });
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
},
);
});
describe("Google Gemini CLI Provider Agent", () => {
it.skipIf(!geminiCliToken)(
"gemini-2.5-flash - should calculate multiple expressions and sum the results",
{ retry: 3 },
async () => {
const model = getModel("google-gemini-cli", "gemini-2.5-flash");
const result = await calculateTest(model, { apiKey: geminiCliToken });
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
},
);
it.skipIf(!geminiCliToken)(
"gemini-2.5-flash - should handle abort during tool execution",
{ retry: 3 },
async () => {
const model = getModel("google-gemini-cli", "gemini-2.5-flash");
const result = await abortTest(model, { apiKey: geminiCliToken });
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
},
);
});
describe("Google Antigravity Provider Agent", () => {
it.skipIf(!antigravityToken)(
"gemini-3-flash - should calculate multiple expressions and sum the results",
{ retry: 3 },
async () => {
const model = getModel("google-antigravity", "gemini-3-flash");
const result = await calculateTest(model, { apiKey: antigravityToken });
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
},
);
it.skipIf(!antigravityToken)(
"gemini-3-flash - should handle abort during tool execution",
{ retry: 3 },
async () => {
const model = getModel("google-antigravity", "gemini-3-flash");
const result = await abortTest(model, { apiKey: antigravityToken });
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
},
);
it.skipIf(!antigravityToken)(
"claude-sonnet-4-5 - should calculate multiple expressions and sum the results",
{ retry: 3 },
async () => {
const model = getModel("google-antigravity", "claude-sonnet-4-5");
const result = await calculateTest(model, { apiKey: antigravityToken });
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
},
);
it.skipIf(!antigravityToken)(
"claude-sonnet-4-5 - should handle abort during tool execution",
{ retry: 3 },
async () => {
const model = getModel("google-antigravity", "claude-sonnet-4-5");
const result = await abortTest(model, { apiKey: antigravityToken });
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
},
);
it.skipIf(!antigravityToken)(
"gpt-oss-120b-medium - should calculate multiple expressions and sum the results",
{ retry: 3 },
async () => {
const model = getModel("google-antigravity", "gpt-oss-120b-medium");
const result = await calculateTest(model, { apiKey: antigravityToken });
expect(result.toolCallCount).toBeGreaterThanOrEqual(2);
},
);
it.skipIf(!antigravityToken)(
"gpt-oss-120b-medium - should handle abort during tool execution",
{ retry: 3 },
async () => {
const model = getModel("google-antigravity", "gpt-oss-120b-medium");
const result = await abortTest(model, { apiKey: antigravityToken });
expect(result.toolCallCount).toBeGreaterThanOrEqual(1);
},
);
});
});
describe("agentLoopContinue", () => {
describe("validation", () => {
const model = getModel("anthropic", "claude-haiku-4-5");
const baseContext: AgentContext = {
systemPrompt: "You are a helpful assistant.",
messages: [],
tools: [],
};
const config: AgentLoopConfig = { model };
it("should throw when context has no messages", () => {
expect(() => agentLoopContinue(baseContext, config)).toThrow("Cannot continue: no messages in context");
});
it("should throw when last message is an assistant message", () => {
const assistantMessage: AssistantMessage = {
role: "assistant",
content: [{ type: "text", text: "Hello" }],
api: "anthropic-messages",
provider: "anthropic",
model: "claude-haiku-4-5",
usage: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 0,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
},
stopReason: "stop",
timestamp: Date.now(),
};
const context: AgentContext = {
...baseContext,
messages: [assistantMessage],
};
expect(() => agentLoopContinue(context, config)).toThrow(
"Cannot continue from message role: assistant. Expected 'user' or 'toolResult'.",
);
});
// Note: "should not throw" tests for valid inputs are covered by the E2E tests below
// which actually consume the stream and verify the output
});
describe.skipIf(!process.env.ANTHROPIC_API_KEY)("continue from user message", () => {
const model = getModel("anthropic", "claude-haiku-4-5");
it("should continue and get assistant response when last message is user", { retry: 3 }, async () => {
const userMessage: UserMessage = {
role: "user",
content: [{ type: "text", text: "Say exactly: HELLO WORLD" }],
timestamp: Date.now(),
};
const context: AgentContext = {
systemPrompt: "You are a helpful assistant. Follow instructions exactly.",
messages: [userMessage],
tools: [],
};
const config: AgentLoopConfig = { model };
const events: AgentEvent[] = [];
const stream = agentLoopContinue(context, config);
for await (const event of stream) {
events.push(event);
}
const messages = await stream.result();
// Should have gotten an assistant response
expect(messages.length).toBe(1);
expect(messages[0].role).toBe("assistant");
// Verify event sequence - no user message events since we're continuing
const eventTypes = events.map((e) => e.type);
expect(eventTypes).toContain("agent_start");
expect(eventTypes).toContain("turn_start");
expect(eventTypes).toContain("message_start");
expect(eventTypes).toContain("message_end");
expect(eventTypes).toContain("turn_end");
expect(eventTypes).toContain("agent_end");
// Should NOT have user message events (that's the difference from agentLoop)
const messageEndEvents = events.filter((e) => e.type === "message_end");
expect(messageEndEvents.length).toBe(1); // Only assistant message
expect((messageEndEvents[0] as any).message.role).toBe("assistant");
});
});
describe.skipIf(!process.env.ANTHROPIC_API_KEY)("continue from tool result", () => {
const model = getModel("anthropic", "claude-haiku-4-5");
it("should continue processing after tool results", { retry: 3 }, async () => {
// Simulate a conversation where:
// 1. User asked to calculate something
// 2. Assistant made a tool call
// 3. Tool result is ready
// 4. We continue from here
const userMessage: UserMessage = {
role: "user",
content: [{ type: "text", text: "What is 5 + 3? Use the calculator." }],
timestamp: Date.now(),
};
const assistantMessage: AssistantMessage = {
role: "assistant",
content: [
{ type: "text", text: "Let me calculate that for you." },
{ type: "toolCall", id: "calc-1", name: "calculate", arguments: { expression: "5 + 3" } },
],
api: "anthropic-messages",
provider: "anthropic",
model: "claude-haiku-4-5",
usage: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 0,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
},
stopReason: "toolUse",
timestamp: Date.now(),
};
const toolResult: ToolResultMessage = {
role: "toolResult",
toolCallId: "calc-1",
toolName: "calculate",
content: [{ type: "text", text: "5 + 3 = 8" }],
isError: false,
timestamp: Date.now(),
};
const context: AgentContext = {
systemPrompt: "You are a helpful assistant. After getting a calculation result, state the answer clearly.",
messages: [userMessage, assistantMessage, toolResult],
tools: [calculateTool],
};
const config: AgentLoopConfig = { model };
const events: AgentEvent[] = [];
const stream = agentLoopContinue(context, config);
for await (const event of stream) {
events.push(event);
}
const messages = await stream.result();
// Should have gotten an assistant response
expect(messages.length).toBeGreaterThanOrEqual(1);
const lastMessage = messages[messages.length - 1];
expect(lastMessage.role).toBe("assistant");
// The assistant should mention the result (8)
if (lastMessage.role === "assistant") {
const textContent = lastMessage.content
.filter((c) => c.type === "text")
.map((c) => (c as any).text)
.join(" ");
expect(textContent).toMatch(/8/);
}
});
});
});

View file

@ -1,4 +1,4 @@
import { type Static, Type } from "@sinclair/typebox";
import { Type } from "@sinclair/typebox";
import AjvModule from "ajv";
import addFormatsModule from "ajv-formats";
@ -7,7 +7,7 @@ const Ajv = (AjvModule as any).default || AjvModule;
const addFormats = (addFormatsModule as any).default || addFormatsModule;
import { describe, expect, it } from "vitest";
import type { AgentTool } from "../src/agent/types.js";
import type { Tool } from "../src/types.js";
describe("Tool Validation with TypeBox and AJV", () => {
// Define a test tool with TypeBox schema
@ -18,20 +18,11 @@ describe("Tool Validation with TypeBox and AJV", () => {
tags: Type.Optional(Type.Array(Type.String())),
});
type TestParams = Static<typeof testSchema>;
const testTool: AgentTool<typeof testSchema, void> = {
label: "Test Tool",
const testTool = {
name: "test_tool",
description: "A test tool for validation",
parameters: testSchema,
execute: async (_toolCallId, args) => {
return {
content: [{ type: "text", text: `Processed: ${args.name}, ${args.age}, ${args.email}` }],
details: undefined,
};
},
};
} satisfies Tool<typeof testSchema>;
// Create AJV instance for validation
const ajv = new Ajv({ allErrors: true });
@ -115,26 +106,4 @@ describe("Tool Validation with TypeBox and AJV", () => {
expect(errors).toContain('email: must match format "email"');
}
});
it("should have type-safe execute function", async () => {
const validInput = {
name: "John Doe",
age: 30,
email: "john@example.com",
};
// Validate and execute
const validate = ajv.compile(testTool.parameters);
const isValid = validate(validInput);
expect(isValid).toBe(true);
const result = await testTool.execute("test-id", validInput as TestParams);
const textOutput = result.content
.filter((c: any) => c.type === "text")
.map((c: any) => c.text)
.join("\n");
expect(textOutput).toBe("Processed: John Doe, 30, john@example.com");
expect(result.details).toBeUndefined();
});
});