co-mono/packages/agent/test/e2e.test.ts
Mario Zechner 5a9d844f9a Simplify compaction: remove proactive abort, use Agent.continue() for retry
- Add agentLoopContinue() to pi-ai for resuming from existing context
- Add Agent.continue() method and transport.continue() interface
- Simplify AgentSession compaction to two cases: overflow (auto-retry) and threshold (no retry)
- Remove proactive mid-turn compaction abort
- Merge turn prefix summary into main summary
- Add isCompacting property to AgentSession and RPC state
- Block input during compaction in interactive mode
- Show compaction count on session resume
- Rename RPC.md to rpc.md for consistency

Related to #128
2025-12-09 21:43:49 +01:00

510 lines
15 KiB
TypeScript

import type { AssistantMessage, Model, ToolResultMessage, UserMessage } from "@mariozechner/pi-ai";
import { calculateTool, getModel } from "@mariozechner/pi-ai";
import { describe, expect, it } from "vitest";
import { Agent, ProviderTransport } from "../src/index.js";
/**
 * Builds a ProviderTransport whose API keys are read from environment variables.
 *
 * Providers listed in the lookup table use the variable named there; any other
 * provider falls back to `<PROVIDER>_API_KEY`, with the provider name upper-cased.
 * The key resolves to `undefined` when the variable is not set.
 */
function createTransport() {
	// Maps a provider id to the name of the environment variable holding its key.
	const envVarNameFor = (provider: string): string => {
		const knownNames: Record<string, string> = {
			google: "GEMINI_API_KEY",
			openai: "OPENAI_API_KEY",
			anthropic: "ANTHROPIC_API_KEY",
			xai: "XAI_API_KEY",
			groq: "GROQ_API_KEY",
			cerebras: "CEREBRAS_API_KEY",
			zai: "ZAI_API_KEY",
		};
		return knownNames[provider] || `${provider.toUpperCase()}_API_KEY`;
	};

	return new ProviderTransport({
		getApiKey: async (provider) => process.env[envVarNameFor(provider)],
	});
}
/**
 * Sends one arithmetic prompt to an agent without tools and checks the outcome:
 * streaming has stopped, the history is exactly [user, assistant], and the
 * assistant reply has a text part containing "4".
 *
 * @param model Model the agent under test should talk to.
 */
async function basicPrompt(model: Model<any>) {
	const transport = createTransport();
	const agent = new Agent({
		initialState: {
			systemPrompt: "You are a helpful assistant. Keep your responses concise.",
			model,
			thinkingLevel: "off",
			tools: [],
		},
		transport,
	});

	await agent.prompt("What is 2+2? Answer with just the number.");

	const { isStreaming, messages } = agent.state;
	expect(isStreaming).toBe(false);
	expect(messages.length).toBe(2);
	expect(messages[0].role).toBe("user");
	expect(messages[1].role).toBe("assistant");

	// The throw-guards below exist to narrow the union types for the compiler.
	const reply = messages[1];
	if (reply.role !== "assistant") throw new Error("Expected assistant message");
	expect(reply.content.length).toBeGreaterThan(0);

	const textPart = reply.content.find((part) => part.type === "text");
	expect(textPart).toBeDefined();
	if (textPart?.type !== "text") throw new Error("Expected text content");
	expect(textPart.text).toContain("4");
}
/**
 * Asks the agent to multiply 123 by 456 with the calculator tool and checks that
 * (1) a tool result message containing the product was recorded, and
 * (2) the final assistant message states the product, with or without a
 *     thousands separator.
 *
 * @param model Model the agent under test should talk to.
 */
async function toolExecution(model: Model<any>) {
	const agent = new Agent({
		initialState: {
			systemPrompt: "You are a helpful assistant. Always use the calculator tool for math.",
			model,
			thinkingLevel: "off",
			tools: [calculateTool],
		},
		transport: createTransport(),
	});

	await agent.prompt("Calculate 123 * 456 using the calculator tool.");

	expect(agent.state.isStreaming).toBe(false);
	expect(agent.state.messages.length).toBeGreaterThanOrEqual(3);

	const toolResultMsg = agent.state.messages.find((m) => m.role === "toolResult");
	expect(toolResultMsg).toBeDefined();
	if (toolResultMsg?.role !== "toolResult") throw new Error("Expected tool result message");

	// Keep only the text parts; narrowing on `type` makes `.text` type-safe without `any`.
	const toolText =
		toolResultMsg.content?.flatMap((part) => (part.type === "text" ? [part.text] : [])).join("\n") ?? "";

	const expectedResult = 123 * 456;
	const plainResult = String(expectedResult);
	expect(toolText).toContain(plainResult);

	const finalMessage = agent.state.messages[agent.state.messages.length - 1];
	if (finalMessage.role !== "assistant") throw new Error("Expected final assistant message");
	const finalText = finalMessage.content.find((c) => c.type === "text");
	expect(finalText).toBeDefined();
	if (finalText?.type !== "text") throw new Error("Expected text content");

	// Models may format the number with a thousands separator ("56,088") or without ("56088").
	const groupedResult = expectedResult.toLocaleString("en-US");
	const hasNumber = finalText.text.includes(plainResult) || finalText.text.includes(groupedResult);
	expect(hasNumber).toBe(true);
}
/**
 * Starts a multi-step calculation prompt, calls `agent.abort()` 100 ms later, and
 * checks that the run ends cleanly in the aborted state: streaming has stopped,
 * the last message is an assistant message with `stopReason === "aborted"` and an
 * error message, and `agent.state.error` equals that error message.
 *
 * @param model Model the agent under test should talk to.
 */
async function abortExecution(model: Model<any>) {
	const agent = new Agent({
		initialState: {
			systemPrompt: "You are a helpful assistant.",
			model,
			thinkingLevel: "off",
			tools: [calculateTool],
		},
		transport: createTransport(),
	});

	const promptPromise = agent.prompt("Calculate 100 * 200, then 300 * 400, then sum the results.");
	const abortTimer = setTimeout(() => {
		agent.abort();
	}, 100);
	try {
		await promptPromise;
	} finally {
		// Never leave the timer pending: if the prompt settles (or rejects) before it
		// fires, a late abort() would otherwise run after this test has finished.
		clearTimeout(abortTimer);
	}

	expect(agent.state.isStreaming).toBe(false);
	expect(agent.state.messages.length).toBeGreaterThanOrEqual(2);

	const lastMessage = agent.state.messages[agent.state.messages.length - 1];
	if (lastMessage.role !== "assistant") throw new Error("Expected assistant message");
	expect(lastMessage.stopReason).toBe("aborted");
	expect(lastMessage.errorMessage).toBeDefined();
	expect(agent.state.error).toBeDefined();
	expect(agent.state.error).toBe(lastMessage.errorMessage);
}
/**
 * Records the type of every event the agent emits while answering one prompt and
 * checks that the lifecycle events and at least one streaming update were seen,
 * and that the final history is exactly one user and one assistant message.
 *
 * @param model Model the agent under test should talk to.
 */
async function stateUpdates(model: Model<any>) {
	const seenEventTypes: string[] = [];

	const agent = new Agent({
		initialState: {
			systemPrompt: "You are a helpful assistant.",
			model,
			thinkingLevel: "off",
			tools: [],
		},
		transport: createTransport(),
	});
	agent.subscribe(({ type }) => {
		seenEventTypes.push(type);
	});

	await agent.prompt("Count from 1 to 5.");

	// Lifecycle events that every run must emit.
	for (const lifecycleEvent of ["agent_start", "agent_end", "message_start", "message_end"]) {
		expect(seenEventTypes).toContain(lifecycleEvent);
	}

	// At least one message_update is required, i.e. the reply must have streamed.
	const sawStreamingUpdate = seenEventTypes.includes("message_update");
	expect(sawStreamingUpdate).toBe(true);

	// Final state: idle, with the user message followed by the assistant response.
	const { isStreaming, messages } = agent.state;
	expect(isStreaming).toBe(false);
	expect(messages.length).toBe(2);
}
/**
 * Runs two prompts on the same agent and checks that the second reply uses
 * information given in the first turn: the history grows from 2 to 4 messages
 * and the last assistant text mentions "alice" (case-insensitive).
 *
 * @param model Model the agent under test should talk to.
 */
async function multiTurnConversation(model: Model<any>) {
	const agent = new Agent({
		transport: createTransport(),
		initialState: {
			systemPrompt: "You are a helpful assistant.",
			model,
			thinkingLevel: "off",
			tools: [],
		},
	});

	// Turn 1: introduce a fact.
	await agent.prompt("My name is Alice.");
	expect(agent.state.messages.length).toBe(2);

	// Turn 2: ask for the fact back.
	await agent.prompt("What is my name?");
	const history = agent.state.messages;
	expect(history.length).toBe(4);

	const secondReply = history[3];
	if (secondReply.role !== "assistant") throw new Error("Expected assistant message");

	const replyText = secondReply.content.find((part) => part.type === "text");
	if (replyText?.type !== "text") throw new Error("Expected text content");

	const normalized = replyText.text.toLowerCase();
	expect(normalized).toContain("alice");
}
/**
 * End-to-end suites: the same five scenarios are run against every provider.
 * A provider's suite is skipped when its API key environment variable is unset
 * or empty.
 */
describe("Agent E2E Tests", () => {
	// Scenario table: test title -> shared helper that runs the scenario for a model.
	const scenarios: ReadonlyArray<{ title: string; run: (model: Model<any>) => Promise<void> }> = [
		{ title: "should handle basic text prompt", run: basicPrompt },
		{ title: "should execute tools correctly", run: toolExecution },
		{ title: "should handle abort during execution", run: abortExecution },
		{ title: "should emit state updates during streaming", run: stateUpdates },
		{ title: "should maintain context across multiple turns", run: multiTurnConversation },
	];

	// Provider table. `loadModel` is a thunk so that getModel() is still called with
	// literal arguments and is still evaluated inside the provider's describe callback.
	const providers: ReadonlyArray<{ title: string; envVar: string; loadModel: () => Model<any> }> = [
		{
			title: "Google Provider (gemini-2.5-flash)",
			envVar: "GEMINI_API_KEY",
			loadModel: () => getModel("google", "gemini-2.5-flash"),
		},
		{
			title: "OpenAI Provider (gpt-4o-mini)",
			envVar: "OPENAI_API_KEY",
			loadModel: () => getModel("openai", "gpt-4o-mini"),
		},
		{
			title: "Anthropic Provider (claude-haiku-4-5)",
			envVar: "ANTHROPIC_API_KEY",
			loadModel: () => getModel("anthropic", "claude-haiku-4-5"),
		},
		{
			title: "xAI Provider (grok-3)",
			envVar: "XAI_API_KEY",
			loadModel: () => getModel("xai", "grok-3"),
		},
		{
			title: "Groq Provider (openai/gpt-oss-20b)",
			envVar: "GROQ_API_KEY",
			loadModel: () => getModel("groq", "openai/gpt-oss-20b"),
		},
		{
			title: "Cerebras Provider (gpt-oss-120b)",
			envVar: "CEREBRAS_API_KEY",
			loadModel: () => getModel("cerebras", "gpt-oss-120b"),
		},
		{
			title: "zAI Provider (glm-4.5-air)",
			envVar: "ZAI_API_KEY",
			loadModel: () => getModel("zai", "glm-4.5-air"),
		},
	];

	for (const provider of providers) {
		describe.skipIf(!process.env[provider.envVar])(provider.title, () => {
			const model = provider.loadModel();
			for (const scenario of scenarios) {
				it(scenario.title, async () => {
					await scenario.run(model);
				});
			}
		});
	}
});
/**
 * Tests for Agent.continue(): argument validation (not gated on an API key) and
 * two live scenarios against Anthropic that resume from a manually seeded history.
 */
describe("Agent.continue()", () => {
	/**
	 * Builds a synthetic Anthropic assistant message with all usage counters at zero,
	 * used to seed conversation history by hand.
	 *
	 * @param content Content parts of the message.
	 * @param stopReason Stop reason to record on the message.
	 */
	function fakeAssistantMessage(
		content: AssistantMessage["content"],
		stopReason: AssistantMessage["stopReason"],
	): AssistantMessage {
		return {
			role: "assistant",
			content,
			api: "anthropic-messages",
			provider: "anthropic",
			model: "claude-haiku-4-5",
			usage: {
				input: 0,
				output: 0,
				cacheRead: 0,
				cacheWrite: 0,
				totalTokens: 0,
				cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
			},
			stopReason,
			timestamp: Date.now(),
		};
	}

	describe("validation", () => {
		it("should throw when no messages in context", async () => {
			const agent = new Agent({
				initialState: {
					systemPrompt: "Test",
					model: getModel("anthropic", "claude-haiku-4-5"),
				},
				transport: createTransport(),
			});

			await expect(agent.continue()).rejects.toThrow("No messages to continue from");
		});

		it("should throw when last message is assistant", async () => {
			const agent = new Agent({
				initialState: {
					systemPrompt: "Test",
					model: getModel("anthropic", "claude-haiku-4-5"),
				},
				transport: createTransport(),
			});

			agent.replaceMessages([fakeAssistantMessage([{ type: "text", text: "Hello" }], "stop")]);

			await expect(agent.continue()).rejects.toThrow("Cannot continue from message role: assistant");
		});
	});

	describe.skipIf(!process.env.ANTHROPIC_API_KEY)("continue from user message", () => {
		const model = getModel("anthropic", "claude-haiku-4-5");

		it("should continue and get response when last message is user", async () => {
			const agent = new Agent({
				initialState: {
					systemPrompt: "You are a helpful assistant. Follow instructions exactly.",
					model,
					thinkingLevel: "off",
					tools: [],
				},
				transport: createTransport(),
			});

			// Manually add a user message without calling prompt()
			const userMessage: UserMessage = {
				role: "user",
				content: [{ type: "text", text: "Say exactly: HELLO WORLD" }],
				timestamp: Date.now(),
			};
			agent.replaceMessages([userMessage]);

			// Continue from the user message
			await agent.continue();

			expect(agent.state.isStreaming).toBe(false);
			expect(agent.state.messages.length).toBe(2);
			expect(agent.state.messages[0].role).toBe("user");
			expect(agent.state.messages[1].role).toBe("assistant");

			// Narrow with guards instead of a type assertion so the checks cannot be bypassed.
			const assistantMsg = agent.state.messages[1];
			if (assistantMsg.role !== "assistant") throw new Error("Expected assistant message");
			const textContent = assistantMsg.content.find((c) => c.type === "text");
			expect(textContent).toBeDefined();
			if (textContent?.type !== "text") throw new Error("Expected text content");
			expect(textContent.text.toUpperCase()).toContain("HELLO WORLD");
		});
	});

	describe.skipIf(!process.env.ANTHROPIC_API_KEY)("continue from tool result", () => {
		const model = getModel("anthropic", "claude-haiku-4-5");

		it("should continue and process tool results", async () => {
			const agent = new Agent({
				initialState: {
					systemPrompt:
						"You are a helpful assistant. After getting a calculation result, state the answer clearly.",
					model,
					thinkingLevel: "off",
					tools: [calculateTool],
				},
				transport: createTransport(),
			});

			// Set up a conversation state as if tool was just executed
			const userMessage: UserMessage = {
				role: "user",
				content: [{ type: "text", text: "What is 5 + 3?" }],
				timestamp: Date.now(),
			};
			const assistantMessage = fakeAssistantMessage(
				[
					{ type: "text", text: "Let me calculate that." },
					{ type: "toolCall", id: "calc-1", name: "calculate", arguments: { expression: "5 + 3" } },
				],
				"toolUse",
			);
			const toolResult: ToolResultMessage = {
				role: "toolResult",
				toolCallId: "calc-1",
				toolName: "calculate",
				content: [{ type: "text", text: "5 + 3 = 8" }],
				isError: false,
				timestamp: Date.now(),
			};
			agent.replaceMessages([userMessage, assistantMessage, toolResult]);

			// Continue from the tool result
			await agent.continue();

			expect(agent.state.isStreaming).toBe(false);
			// Should have added an assistant response
			expect(agent.state.messages.length).toBeGreaterThanOrEqual(4);

			const lastMessage = agent.state.messages[agent.state.messages.length - 1];
			expect(lastMessage.role).toBe("assistant");
			if (lastMessage.role !== "assistant") throw new Error("Expected final assistant message");

			// Join all text parts; narrowing on `type` makes `.text` type-safe without a cast.
			const textContent = lastMessage.content.flatMap((c) => (c.type === "text" ? [c.text] : [])).join(" ");
			// Should mention 8 in the response
			expect(textContent).toMatch(/8/);
		});
	});
});