Agent package + coding agent WIP, refactored web-ui prompts

This commit is contained in:
Mario Zechner 2025-10-17 11:47:01 +02:00
parent 4e7a340460
commit ffc9be8867
58 changed files with 5138 additions and 2206 deletions

View file

@ -0,0 +1,140 @@
import { getModel } from "@mariozechner/pi-ai";
import { describe, expect, it } from "vitest";
import { Agent, ProviderTransport } from "../src/index.js";
describe("Agent", () => {
it("should create an agent instance with default state", () => {
const agent = new Agent({
transport: new ProviderTransport(),
});
expect(agent.state).toBeDefined();
expect(agent.state.systemPrompt).toBe("");
expect(agent.state.model).toBeDefined();
expect(agent.state.thinkingLevel).toBe("off");
expect(agent.state.tools).toEqual([]);
expect(agent.state.messages).toEqual([]);
expect(agent.state.isStreaming).toBe(false);
expect(agent.state.streamMessage).toBe(null);
expect(agent.state.pendingToolCalls).toEqual(new Set());
expect(agent.state.error).toBeUndefined();
});
it("should create an agent instance with custom initial state", () => {
const customModel = getModel("openai", "gpt-4o-mini");
const agent = new Agent({
transport: new ProviderTransport(),
initialState: {
systemPrompt: "You are a helpful assistant.",
model: customModel,
thinkingLevel: "low",
},
});
expect(agent.state.systemPrompt).toBe("You are a helpful assistant.");
expect(agent.state.model).toBe(customModel);
expect(agent.state.thinkingLevel).toBe("low");
});
it("should subscribe to state updates", () => {
const agent = new Agent({
transport: new ProviderTransport(),
});
let updateCount = 0;
const unsubscribe = agent.subscribe((event) => {
if (event.type === "state-update") {
updateCount++;
}
});
// Initial state update on subscribe
expect(updateCount).toBe(1);
// Update state
agent.setSystemPrompt("Test prompt");
expect(updateCount).toBe(2);
expect(agent.state.systemPrompt).toBe("Test prompt");
// Unsubscribe should work
unsubscribe();
agent.setSystemPrompt("Another prompt");
expect(updateCount).toBe(2); // Should not increase
});
it("should update state with mutators", () => {
const agent = new Agent({
transport: new ProviderTransport(),
});
// Test setSystemPrompt
agent.setSystemPrompt("Custom prompt");
expect(agent.state.systemPrompt).toBe("Custom prompt");
// Test setModel
const newModel = getModel("google", "gemini-2.5-flash");
agent.setModel(newModel);
expect(agent.state.model).toBe(newModel);
// Test setThinkingLevel
agent.setThinkingLevel("high");
expect(agent.state.thinkingLevel).toBe("high");
// Test setTools
const tools = [{ name: "test", description: "test tool" } as any];
agent.setTools(tools);
expect(agent.state.tools).toBe(tools);
// Test replaceMessages
const messages = [{ role: "user" as const, content: "Hello" }];
agent.replaceMessages(messages);
expect(agent.state.messages).toEqual(messages);
expect(agent.state.messages).not.toBe(messages); // Should be a copy
// Test appendMessage
const newMessage = { role: "assistant" as const, content: [{ type: "text" as const, text: "Hi" }] };
agent.appendMessage(newMessage as any);
expect(agent.state.messages).toHaveLength(2);
expect(agent.state.messages[1]).toBe(newMessage);
// Test clearMessages
agent.clearMessages();
expect(agent.state.messages).toEqual([]);
});
it("should support message queueing", async () => {
const agent = new Agent({
transport: new ProviderTransport(),
});
const message = { role: "user" as const, content: "Queued message" };
await agent.queueMessage(message);
// The message is queued but not yet in state.messages
expect(agent.state.messages).not.toContainEqual(message);
});
it("should handle abort controller", () => {
const agent = new Agent({
transport: new ProviderTransport(),
});
// Should not throw even if nothing is running
expect(() => agent.abort()).not.toThrow();
});
});
describe("ProviderTransport", () => {
it("should create a provider transport instance", () => {
const transport = new ProviderTransport();
expect(transport).toBeDefined();
});
it("should create a provider transport with options", () => {
const transport = new ProviderTransport({
getApiKey: async (provider) => `test-key-${provider}`,
corsProxyUrl: "https://proxy.example.com",
});
expect(transport).toBeDefined();
});
});

View file

@ -0,0 +1,402 @@
import type { Model } from "@mariozechner/pi-ai";
import { calculateTool, getModel } from "@mariozechner/pi-ai";
import { describe, expect, it } from "vitest";
import { Agent, ProviderTransport } from "../src/index.js";
async function basicPrompt(model: Model<any>) {
const agent = new Agent({
initialState: {
systemPrompt: "You are a helpful assistant. Keep your responses concise.",
model,
thinkingLevel: "off",
tools: [],
},
transport: new ProviderTransport({
getApiKey: async (provider) => {
// Map provider names to env var names
const envVarMap: Record<string, string> = {
google: "GEMINI_API_KEY",
openai: "OPENAI_API_KEY",
anthropic: "ANTHROPIC_API_KEY",
xai: "XAI_API_KEY",
groq: "GROQ_API_KEY",
cerebras: "CEREBRAS_API_KEY",
zai: "ZAI_API_KEY",
};
const envVar = envVarMap[provider] || `${provider.toUpperCase()}_API_KEY`;
return process.env[envVar];
},
}),
});
await agent.prompt("What is 2+2? Answer with just the number.");
expect(agent.state.isStreaming).toBe(false);
expect(agent.state.messages.length).toBe(2);
expect(agent.state.messages[0].role).toBe("user");
expect(agent.state.messages[1].role).toBe("assistant");
const assistantMessage = agent.state.messages[1];
if (assistantMessage.role !== "assistant") throw new Error("Expected assistant message");
expect(assistantMessage.content.length).toBeGreaterThan(0);
const textContent = assistantMessage.content.find((c) => c.type === "text");
expect(textContent).toBeDefined();
if (textContent?.type !== "text") throw new Error("Expected text content");
expect(textContent.text).toContain("4");
}
async function toolExecution(model: Model<any>) {
const agent = new Agent({
initialState: {
systemPrompt: "You are a helpful assistant. Always use the calculator tool for math.",
model,
thinkingLevel: "off",
tools: [calculateTool],
},
transport: new ProviderTransport({
getApiKey: async (provider) => {
// Map provider names to env var names
const envVarMap: Record<string, string> = {
google: "GEMINI_API_KEY",
openai: "OPENAI_API_KEY",
anthropic: "ANTHROPIC_API_KEY",
xai: "XAI_API_KEY",
groq: "GROQ_API_KEY",
cerebras: "CEREBRAS_API_KEY",
zai: "ZAI_API_KEY",
};
const envVar = envVarMap[provider] || `${provider.toUpperCase()}_API_KEY`;
return process.env[envVar];
},
}),
});
await agent.prompt("Calculate 123 * 456 using the calculator tool.");
expect(agent.state.isStreaming).toBe(false);
expect(agent.state.messages.length).toBeGreaterThanOrEqual(3);
const toolResultMsg = agent.state.messages.find((m) => m.role === "toolResult");
expect(toolResultMsg).toBeDefined();
if (toolResultMsg?.role !== "toolResult") throw new Error("Expected tool result message");
expect(toolResultMsg.output).toBeDefined();
const expectedResult = 123 * 456;
expect(toolResultMsg.output).toContain(String(expectedResult));
const finalMessage = agent.state.messages[agent.state.messages.length - 1];
if (finalMessage.role !== "assistant") throw new Error("Expected final assistant message");
const finalText = finalMessage.content.find((c) => c.type === "text");
expect(finalText).toBeDefined();
if (finalText?.type !== "text") throw new Error("Expected text content");
// Check for number with or without comma formatting
const hasNumber =
finalText.text.includes(String(expectedResult)) ||
finalText.text.includes("56,088") ||
finalText.text.includes("56088");
expect(hasNumber).toBe(true);
}
async function abortExecution(model: Model<any>) {
const agent = new Agent({
initialState: {
systemPrompt: "You are a helpful assistant.",
model,
thinkingLevel: "off",
tools: [calculateTool],
},
transport: new ProviderTransport({
getApiKey: async (provider) => {
// Map provider names to env var names
const envVarMap: Record<string, string> = {
google: "GEMINI_API_KEY",
openai: "OPENAI_API_KEY",
anthropic: "ANTHROPIC_API_KEY",
xai: "XAI_API_KEY",
groq: "GROQ_API_KEY",
cerebras: "CEREBRAS_API_KEY",
zai: "ZAI_API_KEY",
};
const envVar = envVarMap[provider] || `${provider.toUpperCase()}_API_KEY`;
return process.env[envVar];
},
}),
});
const promptPromise = agent.prompt("Calculate 100 * 200, then 300 * 400, then sum the results.");
setTimeout(() => {
agent.abort();
}, 100);
await promptPromise;
expect(agent.state.isStreaming).toBe(false);
expect(agent.state.messages.length).toBeGreaterThanOrEqual(2);
const lastMessage = agent.state.messages[agent.state.messages.length - 1];
if (lastMessage.role !== "assistant") throw new Error("Expected assistant message");
expect(lastMessage.stopReason).toBe("aborted");
expect(lastMessage.errorMessage).toBeDefined();
}
async function stateUpdates(model: Model<any>) {
const agent = new Agent({
initialState: {
systemPrompt: "You are a helpful assistant.",
model,
thinkingLevel: "off",
tools: [],
},
transport: new ProviderTransport({
getApiKey: async (provider) => {
// Map provider names to env var names
const envVarMap: Record<string, string> = {
google: "GEMINI_API_KEY",
openai: "OPENAI_API_KEY",
anthropic: "ANTHROPIC_API_KEY",
xai: "XAI_API_KEY",
groq: "GROQ_API_KEY",
cerebras: "CEREBRAS_API_KEY",
zai: "ZAI_API_KEY",
};
const envVar = envVarMap[provider] || `${provider.toUpperCase()}_API_KEY`;
return process.env[envVar];
},
}),
});
const stateSnapshots: Array<{ isStreaming: boolean; messageCount: number; hasStreamMessage: boolean }> = [];
agent.subscribe((event) => {
if (event.type === "state-update") {
stateSnapshots.push({
isStreaming: event.state.isStreaming,
messageCount: event.state.messages.length,
hasStreamMessage: event.state.streamMessage !== null,
});
}
});
await agent.prompt("Count from 1 to 5.");
const streamingStates = stateSnapshots.filter((s) => s.isStreaming);
const nonStreamingStates = stateSnapshots.filter((s) => !s.isStreaming);
expect(streamingStates.length).toBeGreaterThan(0);
expect(nonStreamingStates.length).toBeGreaterThan(0);
const finalState = stateSnapshots[stateSnapshots.length - 1];
expect(finalState.isStreaming).toBe(false);
expect(finalState.messageCount).toBe(2);
}
async function multiTurnConversation(model: Model<any>) {
const agent = new Agent({
initialState: {
systemPrompt: "You are a helpful assistant.",
model,
thinkingLevel: "off",
tools: [],
},
transport: new ProviderTransport({
getApiKey: async (provider) => {
// Map provider names to env var names
const envVarMap: Record<string, string> = {
google: "GEMINI_API_KEY",
openai: "OPENAI_API_KEY",
anthropic: "ANTHROPIC_API_KEY",
xai: "XAI_API_KEY",
groq: "GROQ_API_KEY",
cerebras: "CEREBRAS_API_KEY",
zai: "ZAI_API_KEY",
};
const envVar = envVarMap[provider] || `${provider.toUpperCase()}_API_KEY`;
return process.env[envVar];
},
}),
});
await agent.prompt("My name is Alice.");
expect(agent.state.messages.length).toBe(2);
await agent.prompt("What is my name?");
expect(agent.state.messages.length).toBe(4);
const lastMessage = agent.state.messages[3];
if (lastMessage.role !== "assistant") throw new Error("Expected assistant message");
const lastText = lastMessage.content.find((c) => c.type === "text");
if (lastText?.type !== "text") throw new Error("Expected text content");
expect(lastText.text.toLowerCase()).toContain("alice");
}
describe("Agent E2E Tests", () => {
describe.skipIf(!process.env.GEMINI_API_KEY)("Google Provider (gemini-2.5-flash)", () => {
const model = getModel("google", "gemini-2.5-flash");
it("should handle basic text prompt", async () => {
await basicPrompt(model);
});
it("should execute tools correctly", async () => {
await toolExecution(model);
});
it("should handle abort during execution", async () => {
await abortExecution(model);
});
it("should emit state updates during streaming", async () => {
await stateUpdates(model);
});
it("should maintain context across multiple turns", async () => {
await multiTurnConversation(model);
});
});
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Provider (gpt-4o-mini)", () => {
const model = getModel("openai", "gpt-4o-mini");
it("should handle basic text prompt", async () => {
await basicPrompt(model);
});
it("should execute tools correctly", async () => {
await toolExecution(model);
});
it("should handle abort during execution", async () => {
await abortExecution(model);
});
it("should emit state updates during streaming", async () => {
await stateUpdates(model);
});
it("should maintain context across multiple turns", async () => {
await multiTurnConversation(model);
});
});
describe.skipIf(!process.env.ANTHROPIC_API_KEY)("Anthropic Provider (claude-3-5-haiku-20241022)", () => {
const model = getModel("anthropic", "claude-3-5-haiku-20241022");
it("should handle basic text prompt", async () => {
await basicPrompt(model);
});
it("should execute tools correctly", async () => {
await toolExecution(model);
});
it("should handle abort during execution", async () => {
await abortExecution(model);
});
it("should emit state updates during streaming", async () => {
await stateUpdates(model);
});
it("should maintain context across multiple turns", async () => {
await multiTurnConversation(model);
});
});
describe.skipIf(!process.env.XAI_API_KEY)("xAI Provider (grok-3)", () => {
const model = getModel("xai", "grok-3");
it("should handle basic text prompt", async () => {
await basicPrompt(model);
});
it("should execute tools correctly", async () => {
await toolExecution(model);
});
it("should handle abort during execution", async () => {
await abortExecution(model);
});
it("should emit state updates during streaming", async () => {
await stateUpdates(model);
});
it("should maintain context across multiple turns", async () => {
await multiTurnConversation(model);
});
});
describe.skipIf(!process.env.GROQ_API_KEY)("Groq Provider (openai/gpt-oss-20b)", () => {
const model = getModel("groq", "openai/gpt-oss-20b");
it("should handle basic text prompt", async () => {
await basicPrompt(model);
});
it("should execute tools correctly", async () => {
await toolExecution(model);
});
it("should handle abort during execution", async () => {
await abortExecution(model);
});
it("should emit state updates during streaming", async () => {
await stateUpdates(model);
});
it("should maintain context across multiple turns", async () => {
await multiTurnConversation(model);
});
});
describe.skipIf(!process.env.CEREBRAS_API_KEY)("Cerebras Provider (gpt-oss-120b)", () => {
const model = getModel("cerebras", "gpt-oss-120b");
it("should handle basic text prompt", async () => {
await basicPrompt(model);
});
it("should execute tools correctly", async () => {
await toolExecution(model);
});
it("should handle abort during execution", async () => {
await abortExecution(model);
});
it("should emit state updates during streaming", async () => {
await stateUpdates(model);
});
it("should maintain context across multiple turns", async () => {
await multiTurnConversation(model);
});
});
describe.skipIf(!process.env.ZAI_API_KEY)("zAI Provider (glm-4.5-air)", () => {
const model = getModel("zai", "glm-4.5-air");
it("should handle basic text prompt", async () => {
await basicPrompt(model);
});
it("should execute tools correctly", async () => {
await toolExecution(model);
});
it("should handle abort during execution", async () => {
await abortExecution(model);
});
it("should emit state updates during streaming", async () => {
await stateUpdates(model);
});
it("should maintain context across multiple turns", async () => {
await multiTurnConversation(model);
});
});
});