mirror of
https://github.com/getcompanion-ai/co-mono.git
synced 2026-04-17 08:00:59 +00:00
Simplify compaction: remove proactive abort, use Agent.continue() for retry
- Add agentLoopContinue() to pi-ai for resuming from existing context - Add Agent.continue() method and transport.continue() interface - Simplify AgentSession compaction to two cases: overflow (auto-retry) and threshold (no retry) - Remove proactive mid-turn compaction abort - Merge turn prefix summary into main summary - Add isCompacting property to AgentSession and RPC state - Block input during compaction in interactive mode - Show compaction count on session resume - Rename RPC.md to rpc.md for consistency Related to #128
This commit is contained in:
parent
d67c69c6e9
commit
5a9d844f9a
27 changed files with 1261 additions and 1011 deletions
|
|
@ -176,11 +176,6 @@ export class Agent {
|
|||
throw new Error("No model configured");
|
||||
}
|
||||
|
||||
// Set up running prompt tracking
|
||||
this.runningPrompt = new Promise<void>((resolve) => {
|
||||
this.resolveRunningPrompt = resolve;
|
||||
});
|
||||
|
||||
// Build user message with attachments
|
||||
const content: Array<TextContent | ImageContent> = [{ type: "text", text: input }];
|
||||
if (attachments?.length) {
|
||||
|
|
@ -204,6 +199,62 @@ export class Agent {
|
|||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
await this._runAgentLoop(userMessage);
|
||||
}
|
||||
|
||||
/**
|
||||
* Continue from the current context without adding a new user message.
|
||||
* Used for retry after overflow recovery when context already has user message or tool results.
|
||||
*/
|
||||
async continue() {
|
||||
const messages = this._state.messages;
|
||||
if (messages.length === 0) {
|
||||
throw new Error("No messages to continue from");
|
||||
}
|
||||
|
||||
const lastMessage = messages[messages.length - 1];
|
||||
if (lastMessage.role !== "user" && lastMessage.role !== "toolResult") {
|
||||
throw new Error(`Cannot continue from message role: ${lastMessage.role}`);
|
||||
}
|
||||
|
||||
await this._runAgentLoopContinue();
|
||||
}
|
||||
|
||||
/**
|
||||
* Internal: Run the agent loop with a new user message.
|
||||
*/
|
||||
private async _runAgentLoop(userMessage: AppMessage) {
|
||||
const { llmMessages, cfg } = await this._prepareRun();
|
||||
|
||||
const events = this.transport.run(llmMessages, userMessage as Message, cfg, this.abortController!.signal);
|
||||
|
||||
await this._processEvents(events);
|
||||
}
|
||||
|
||||
/**
|
||||
* Internal: Continue the agent loop from current context.
|
||||
*/
|
||||
private async _runAgentLoopContinue() {
|
||||
const { llmMessages, cfg } = await this._prepareRun();
|
||||
|
||||
const events = this.transport.continue(llmMessages, cfg, this.abortController!.signal);
|
||||
|
||||
await this._processEvents(events);
|
||||
}
|
||||
|
||||
/**
|
||||
* Prepare for running the agent loop.
|
||||
*/
|
||||
private async _prepareRun() {
|
||||
const model = this._state.model;
|
||||
if (!model) {
|
||||
throw new Error("No model configured");
|
||||
}
|
||||
|
||||
this.runningPrompt = new Promise<void>((resolve) => {
|
||||
this.resolveRunningPrompt = resolve;
|
||||
});
|
||||
|
||||
this.abortController = new AbortController();
|
||||
this._state.isStreaming = true;
|
||||
this._state.streamMessage = null;
|
||||
|
|
@ -222,9 +273,7 @@ export class Agent {
|
|||
model,
|
||||
reasoning,
|
||||
getQueuedMessages: async <T>() => {
|
||||
// Return queued messages based on queue mode
|
||||
if (this.queueMode === "one-at-a-time") {
|
||||
// Return only first message
|
||||
if (this.messageQueue.length > 0) {
|
||||
const first = this.messageQueue[0];
|
||||
this.messageQueue = this.messageQueue.slice(1);
|
||||
|
|
@ -232,7 +281,6 @@ export class Agent {
|
|||
}
|
||||
return [];
|
||||
} else {
|
||||
// Return all queued messages at once
|
||||
const queued = this.messageQueue.slice();
|
||||
this.messageQueue = [];
|
||||
return queued as QueuedMessage<T>[];
|
||||
|
|
@ -240,32 +288,30 @@ export class Agent {
|
|||
},
|
||||
};
|
||||
|
||||
// Track all messages generated in this prompt
|
||||
const llmMessages = await this.messageTransformer(this._state.messages);
|
||||
|
||||
return { llmMessages, cfg, model };
|
||||
}
|
||||
|
||||
/**
|
||||
* Process events from the transport.
|
||||
*/
|
||||
private async _processEvents(events: AsyncIterable<AgentEvent>) {
|
||||
const model = this._state.model!;
|
||||
const generatedMessages: AppMessage[] = [];
|
||||
let partial: AppMessage | null = null;
|
||||
|
||||
try {
|
||||
let partial: Message | null = null;
|
||||
|
||||
// Transform app messages to LLM-compatible messages (initial set)
|
||||
const llmMessages = await this.messageTransformer(this._state.messages);
|
||||
|
||||
for await (const ev of this.transport.run(
|
||||
llmMessages,
|
||||
userMessage as Message,
|
||||
cfg,
|
||||
this.abortController.signal,
|
||||
)) {
|
||||
// Update internal state BEFORE emitting events
|
||||
// so handlers see consistent state
|
||||
for await (const ev of events) {
|
||||
switch (ev.type) {
|
||||
case "message_start": {
|
||||
partial = ev.message;
|
||||
this._state.streamMessage = ev.message;
|
||||
partial = ev.message as AppMessage;
|
||||
this._state.streamMessage = ev.message as Message;
|
||||
break;
|
||||
}
|
||||
case "message_update": {
|
||||
partial = ev.message;
|
||||
this._state.streamMessage = ev.message;
|
||||
partial = ev.message as AppMessage;
|
||||
this._state.streamMessage = ev.message as Message;
|
||||
break;
|
||||
}
|
||||
case "message_end": {
|
||||
|
|
@ -299,7 +345,6 @@ export class Agent {
|
|||
}
|
||||
}
|
||||
|
||||
// Emit after state is updated
|
||||
this.emit(ev as AgentEvent);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import type {
|
|||
ToolCall,
|
||||
UserMessage,
|
||||
} from "@mariozechner/pi-ai";
|
||||
import { agentLoop } from "@mariozechner/pi-ai";
|
||||
import { agentLoop, agentLoopContinue } from "@mariozechner/pi-ai";
|
||||
import { AssistantMessageEventStream } from "@mariozechner/pi-ai/dist/utils/event-stream.js";
|
||||
import { parseStreamingJson } from "@mariozechner/pi-ai/dist/utils/json-parse.js";
|
||||
import type { ProxyAssistantMessageEvent } from "./proxy-types.js";
|
||||
|
|
@ -335,14 +335,8 @@ export class AppTransport implements AgentTransport {
|
|||
this.options = options;
|
||||
}
|
||||
|
||||
async *run(messages: Message[], userMessage: Message, cfg: AgentRunConfig, signal?: AbortSignal) {
|
||||
const authToken = await this.options.getAuthToken();
|
||||
if (!authToken) {
|
||||
throw new Error("Auth token is required for AppTransport");
|
||||
}
|
||||
|
||||
// Use proxy - no local API key needed
|
||||
const streamFn = <TApi extends Api>(model: Model<TApi>, context: Context, options?: SimpleStreamOptions) => {
|
||||
private async getStreamFn(authToken: string) {
|
||||
return <TApi extends Api>(model: Model<TApi>, context: Context, options?: SimpleStreamOptions) => {
|
||||
return streamSimpleProxy(
|
||||
model,
|
||||
context,
|
||||
|
|
@ -353,24 +347,51 @@ export class AppTransport implements AgentTransport {
|
|||
this.options.proxyUrl,
|
||||
);
|
||||
};
|
||||
}
|
||||
|
||||
// Messages are already LLM-compatible (filtered by Agent)
|
||||
const context: AgentContext = {
|
||||
private buildContext(messages: Message[], cfg: AgentRunConfig): AgentContext {
|
||||
return {
|
||||
systemPrompt: cfg.systemPrompt,
|
||||
messages,
|
||||
tools: cfg.tools,
|
||||
};
|
||||
}
|
||||
|
||||
const pc: AgentLoopConfig = {
|
||||
private buildLoopConfig(cfg: AgentRunConfig): AgentLoopConfig {
|
||||
return {
|
||||
model: cfg.model,
|
||||
reasoning: cfg.reasoning,
|
||||
getQueuedMessages: cfg.getQueuedMessages,
|
||||
};
|
||||
}
|
||||
|
||||
async *run(messages: Message[], userMessage: Message, cfg: AgentRunConfig, signal?: AbortSignal) {
|
||||
const authToken = await this.options.getAuthToken();
|
||||
if (!authToken) {
|
||||
throw new Error("Auth token is required for AppTransport");
|
||||
}
|
||||
|
||||
const streamFn = await this.getStreamFn(authToken);
|
||||
const context = this.buildContext(messages, cfg);
|
||||
const pc = this.buildLoopConfig(cfg);
|
||||
|
||||
// Yield events from the upstream agentLoop iterator
|
||||
// Pass streamFn as the 5th parameter to use proxy
|
||||
for await (const ev of agentLoop(userMessage as unknown as UserMessage, context, pc, signal, streamFn as any)) {
|
||||
yield ev;
|
||||
}
|
||||
}
|
||||
|
||||
async *continue(messages: Message[], cfg: AgentRunConfig, signal?: AbortSignal) {
|
||||
const authToken = await this.options.getAuthToken();
|
||||
if (!authToken) {
|
||||
throw new Error("Auth token is required for AppTransport");
|
||||
}
|
||||
|
||||
const streamFn = await this.getStreamFn(authToken);
|
||||
const context = this.buildContext(messages, cfg);
|
||||
const pc = this.buildLoopConfig(cfg);
|
||||
|
||||
for await (const ev of agentLoopContinue(context, pc, signal, streamFn as any)) {
|
||||
yield ev;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import {
|
|||
type AgentContext,
|
||||
type AgentLoopConfig,
|
||||
agentLoop,
|
||||
agentLoopContinue,
|
||||
type Message,
|
||||
type UserMessage,
|
||||
} from "@mariozechner/pi-ai";
|
||||
|
|
@ -33,18 +34,15 @@ export class ProviderTransport implements AgentTransport {
|
|||
this.options = options;
|
||||
}
|
||||
|
||||
async *run(messages: Message[], userMessage: Message, cfg: AgentRunConfig, signal?: AbortSignal) {
|
||||
// Get API key
|
||||
private async getModelAndKey(cfg: AgentRunConfig) {
|
||||
let apiKey: string | undefined;
|
||||
if (this.options.getApiKey) {
|
||||
apiKey = await this.options.getApiKey(cfg.model.provider);
|
||||
}
|
||||
|
||||
if (!apiKey) {
|
||||
throw new Error(`No API key found for provider: ${cfg.model.provider}`);
|
||||
}
|
||||
|
||||
// Clone model and modify baseUrl if CORS proxy is enabled
|
||||
let model = cfg.model;
|
||||
if (this.options.corsProxyUrl && cfg.model.baseUrl) {
|
||||
model = {
|
||||
|
|
@ -53,23 +51,43 @@ export class ProviderTransport implements AgentTransport {
|
|||
};
|
||||
}
|
||||
|
||||
// Messages are already LLM-compatible (filtered by Agent)
|
||||
const context: AgentContext = {
|
||||
return { model, apiKey };
|
||||
}
|
||||
|
||||
private buildContext(messages: Message[], cfg: AgentRunConfig): AgentContext {
|
||||
return {
|
||||
systemPrompt: cfg.systemPrompt,
|
||||
messages,
|
||||
tools: cfg.tools,
|
||||
};
|
||||
}
|
||||
|
||||
const pc: AgentLoopConfig = {
|
||||
private buildLoopConfig(model: typeof cfg.model, apiKey: string, cfg: AgentRunConfig): AgentLoopConfig {
|
||||
return {
|
||||
model,
|
||||
reasoning: cfg.reasoning,
|
||||
apiKey,
|
||||
getQueuedMessages: cfg.getQueuedMessages,
|
||||
};
|
||||
}
|
||||
|
||||
async *run(messages: Message[], userMessage: Message, cfg: AgentRunConfig, signal?: AbortSignal) {
|
||||
const { model, apiKey } = await this.getModelAndKey(cfg);
|
||||
const context = this.buildContext(messages, cfg);
|
||||
const pc = this.buildLoopConfig(model, apiKey, cfg);
|
||||
|
||||
// Yield events from agentLoop
|
||||
for await (const ev of agentLoop(userMessage as unknown as UserMessage, context, pc, signal)) {
|
||||
yield ev;
|
||||
}
|
||||
}
|
||||
|
||||
async *continue(messages: Message[], cfg: AgentRunConfig, signal?: AbortSignal) {
|
||||
const { model, apiKey } = await this.getModelAndKey(cfg);
|
||||
const context = this.buildContext(messages, cfg);
|
||||
const pc = this.buildLoopConfig(model, apiKey, cfg);
|
||||
|
||||
for await (const ev of agentLoopContinue(context, pc, signal)) {
|
||||
yield ev;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,10 +19,14 @@ export interface AgentRunConfig {
|
|||
* Events yielded must match the @mariozechner/pi-ai AgentEvent types.
|
||||
*/
|
||||
export interface AgentTransport {
|
||||
/** Run with a new user message */
|
||||
run(
|
||||
messages: Message[],
|
||||
userMessage: Message,
|
||||
config: AgentRunConfig,
|
||||
signal?: AbortSignal,
|
||||
): AsyncIterable<AgentEvent>;
|
||||
|
||||
/** Continue from current context (no new user message) */
|
||||
continue(messages: Message[], config: AgentRunConfig, signal?: AbortSignal): AsyncIterable<AgentEvent>;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,8 +1,26 @@
|
|||
import type { Model } from "@mariozechner/pi-ai";
|
||||
import type { AssistantMessage, Model, ToolResultMessage, UserMessage } from "@mariozechner/pi-ai";
|
||||
import { calculateTool, getModel } from "@mariozechner/pi-ai";
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { Agent, ProviderTransport } from "../src/index.js";
|
||||
|
||||
function createTransport() {
|
||||
return new ProviderTransport({
|
||||
getApiKey: async (provider) => {
|
||||
const envVarMap: Record<string, string> = {
|
||||
google: "GEMINI_API_KEY",
|
||||
openai: "OPENAI_API_KEY",
|
||||
anthropic: "ANTHROPIC_API_KEY",
|
||||
xai: "XAI_API_KEY",
|
||||
groq: "GROQ_API_KEY",
|
||||
cerebras: "CEREBRAS_API_KEY",
|
||||
zai: "ZAI_API_KEY",
|
||||
};
|
||||
const envVar = envVarMap[provider] || `${provider.toUpperCase()}_API_KEY`;
|
||||
return process.env[envVar];
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
async function basicPrompt(model: Model<any>) {
|
||||
const agent = new Agent({
|
||||
initialState: {
|
||||
|
|
@ -11,22 +29,7 @@ async function basicPrompt(model: Model<any>) {
|
|||
thinkingLevel: "off",
|
||||
tools: [],
|
||||
},
|
||||
transport: new ProviderTransport({
|
||||
getApiKey: async (provider) => {
|
||||
// Map provider names to env var names
|
||||
const envVarMap: Record<string, string> = {
|
||||
google: "GEMINI_API_KEY",
|
||||
openai: "OPENAI_API_KEY",
|
||||
anthropic: "ANTHROPIC_API_KEY",
|
||||
xai: "XAI_API_KEY",
|
||||
groq: "GROQ_API_KEY",
|
||||
cerebras: "CEREBRAS_API_KEY",
|
||||
zai: "ZAI_API_KEY",
|
||||
};
|
||||
const envVar = envVarMap[provider] || `${provider.toUpperCase()}_API_KEY`;
|
||||
return process.env[envVar];
|
||||
},
|
||||
}),
|
||||
transport: createTransport(),
|
||||
});
|
||||
|
||||
await agent.prompt("What is 2+2? Answer with just the number.");
|
||||
|
|
@ -54,22 +57,7 @@ async function toolExecution(model: Model<any>) {
|
|||
thinkingLevel: "off",
|
||||
tools: [calculateTool],
|
||||
},
|
||||
transport: new ProviderTransport({
|
||||
getApiKey: async (provider) => {
|
||||
// Map provider names to env var names
|
||||
const envVarMap: Record<string, string> = {
|
||||
google: "GEMINI_API_KEY",
|
||||
openai: "OPENAI_API_KEY",
|
||||
anthropic: "ANTHROPIC_API_KEY",
|
||||
xai: "XAI_API_KEY",
|
||||
groq: "GROQ_API_KEY",
|
||||
cerebras: "CEREBRAS_API_KEY",
|
||||
zai: "ZAI_API_KEY",
|
||||
};
|
||||
const envVar = envVarMap[provider] || `${provider.toUpperCase()}_API_KEY`;
|
||||
return process.env[envVar];
|
||||
},
|
||||
}),
|
||||
transport: createTransport(),
|
||||
});
|
||||
|
||||
await agent.prompt("Calculate 123 * 456 using the calculator tool.");
|
||||
|
|
@ -111,22 +99,7 @@ async function abortExecution(model: Model<any>) {
|
|||
thinkingLevel: "off",
|
||||
tools: [calculateTool],
|
||||
},
|
||||
transport: new ProviderTransport({
|
||||
getApiKey: async (provider) => {
|
||||
// Map provider names to env var names
|
||||
const envVarMap: Record<string, string> = {
|
||||
google: "GEMINI_API_KEY",
|
||||
openai: "OPENAI_API_KEY",
|
||||
anthropic: "ANTHROPIC_API_KEY",
|
||||
xai: "XAI_API_KEY",
|
||||
groq: "GROQ_API_KEY",
|
||||
cerebras: "CEREBRAS_API_KEY",
|
||||
zai: "ZAI_API_KEY",
|
||||
};
|
||||
const envVar = envVarMap[provider] || `${provider.toUpperCase()}_API_KEY`;
|
||||
return process.env[envVar];
|
||||
},
|
||||
}),
|
||||
transport: createTransport(),
|
||||
});
|
||||
|
||||
const promptPromise = agent.prompt("Calculate 100 * 200, then 300 * 400, then sum the results.");
|
||||
|
|
@ -156,22 +129,7 @@ async function stateUpdates(model: Model<any>) {
|
|||
thinkingLevel: "off",
|
||||
tools: [],
|
||||
},
|
||||
transport: new ProviderTransport({
|
||||
getApiKey: async (provider) => {
|
||||
// Map provider names to env var names
|
||||
const envVarMap: Record<string, string> = {
|
||||
google: "GEMINI_API_KEY",
|
||||
openai: "OPENAI_API_KEY",
|
||||
anthropic: "ANTHROPIC_API_KEY",
|
||||
xai: "XAI_API_KEY",
|
||||
groq: "GROQ_API_KEY",
|
||||
cerebras: "CEREBRAS_API_KEY",
|
||||
zai: "ZAI_API_KEY",
|
||||
};
|
||||
const envVar = envVarMap[provider] || `${provider.toUpperCase()}_API_KEY`;
|
||||
return process.env[envVar];
|
||||
},
|
||||
}),
|
||||
transport: createTransport(),
|
||||
});
|
||||
|
||||
const events: Array<string> = [];
|
||||
|
|
@ -204,22 +162,7 @@ async function multiTurnConversation(model: Model<any>) {
|
|||
thinkingLevel: "off",
|
||||
tools: [],
|
||||
},
|
||||
transport: new ProviderTransport({
|
||||
getApiKey: async (provider) => {
|
||||
// Map provider names to env var names
|
||||
const envVarMap: Record<string, string> = {
|
||||
google: "GEMINI_API_KEY",
|
||||
openai: "OPENAI_API_KEY",
|
||||
anthropic: "ANTHROPIC_API_KEY",
|
||||
xai: "XAI_API_KEY",
|
||||
groq: "GROQ_API_KEY",
|
||||
cerebras: "CEREBRAS_API_KEY",
|
||||
zai: "ZAI_API_KEY",
|
||||
};
|
||||
const envVar = envVarMap[provider] || `${provider.toUpperCase()}_API_KEY`;
|
||||
return process.env[envVar];
|
||||
},
|
||||
}),
|
||||
transport: createTransport(),
|
||||
});
|
||||
|
||||
await agent.prompt("My name is Alice.");
|
||||
|
|
@ -284,8 +227,8 @@ describe("Agent E2E Tests", () => {
|
|||
});
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.ANTHROPIC_API_KEY)("Anthropic Provider (claude-3-5-haiku-20241022)", () => {
|
||||
const model = getModel("anthropic", "claude-3-5-haiku-20241022");
|
||||
describe.skipIf(!process.env.ANTHROPIC_API_KEY)("Anthropic Provider (claude-haiku-4-5)", () => {
|
||||
const model = getModel("anthropic", "claude-haiku-4-5");
|
||||
|
||||
it("should handle basic text prompt", async () => {
|
||||
await basicPrompt(model);
|
||||
|
|
@ -404,3 +347,164 @@ describe("Agent E2E Tests", () => {
|
|||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("Agent.continue()", () => {
|
||||
describe("validation", () => {
|
||||
it("should throw when no messages in context", async () => {
|
||||
const agent = new Agent({
|
||||
initialState: {
|
||||
systemPrompt: "Test",
|
||||
model: getModel("anthropic", "claude-haiku-4-5"),
|
||||
},
|
||||
transport: createTransport(),
|
||||
});
|
||||
|
||||
await expect(agent.continue()).rejects.toThrow("No messages to continue from");
|
||||
});
|
||||
|
||||
it("should throw when last message is assistant", async () => {
|
||||
const agent = new Agent({
|
||||
initialState: {
|
||||
systemPrompt: "Test",
|
||||
model: getModel("anthropic", "claude-haiku-4-5"),
|
||||
},
|
||||
transport: createTransport(),
|
||||
});
|
||||
|
||||
const assistantMessage: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [{ type: "text", text: "Hello" }],
|
||||
api: "anthropic-messages",
|
||||
provider: "anthropic",
|
||||
model: "claude-haiku-4-5",
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "stop",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
agent.replaceMessages([assistantMessage]);
|
||||
|
||||
await expect(agent.continue()).rejects.toThrow("Cannot continue from message role: assistant");
|
||||
});
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.ANTHROPIC_API_KEY)("continue from user message", () => {
|
||||
const model = getModel("anthropic", "claude-haiku-4-5");
|
||||
|
||||
it("should continue and get response when last message is user", async () => {
|
||||
const agent = new Agent({
|
||||
initialState: {
|
||||
systemPrompt: "You are a helpful assistant. Follow instructions exactly.",
|
||||
model,
|
||||
thinkingLevel: "off",
|
||||
tools: [],
|
||||
},
|
||||
transport: createTransport(),
|
||||
});
|
||||
|
||||
// Manually add a user message without calling prompt()
|
||||
const userMessage: UserMessage = {
|
||||
role: "user",
|
||||
content: [{ type: "text", text: "Say exactly: HELLO WORLD" }],
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
agent.replaceMessages([userMessage]);
|
||||
|
||||
// Continue from the user message
|
||||
await agent.continue();
|
||||
|
||||
expect(agent.state.isStreaming).toBe(false);
|
||||
expect(agent.state.messages.length).toBe(2);
|
||||
expect(agent.state.messages[0].role).toBe("user");
|
||||
expect(agent.state.messages[1].role).toBe("assistant");
|
||||
|
||||
const assistantMsg = agent.state.messages[1] as AssistantMessage;
|
||||
const textContent = assistantMsg.content.find((c) => c.type === "text");
|
||||
expect(textContent).toBeDefined();
|
||||
if (textContent?.type === "text") {
|
||||
expect(textContent.text.toUpperCase()).toContain("HELLO WORLD");
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.ANTHROPIC_API_KEY)("continue from tool result", () => {
|
||||
const model = getModel("anthropic", "claude-haiku-4-5");
|
||||
|
||||
it("should continue and process tool results", async () => {
|
||||
const agent = new Agent({
|
||||
initialState: {
|
||||
systemPrompt:
|
||||
"You are a helpful assistant. After getting a calculation result, state the answer clearly.",
|
||||
model,
|
||||
thinkingLevel: "off",
|
||||
tools: [calculateTool],
|
||||
},
|
||||
transport: createTransport(),
|
||||
});
|
||||
|
||||
// Set up a conversation state as if tool was just executed
|
||||
const userMessage: UserMessage = {
|
||||
role: "user",
|
||||
content: [{ type: "text", text: "What is 5 + 3?" }],
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
const assistantMessage: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [
|
||||
{ type: "text", text: "Let me calculate that." },
|
||||
{ type: "toolCall", id: "calc-1", name: "calculate", arguments: { expression: "5 + 3" } },
|
||||
],
|
||||
api: "anthropic-messages",
|
||||
provider: "anthropic",
|
||||
model: "claude-haiku-4-5",
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "toolUse",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
const toolResult: ToolResultMessage = {
|
||||
role: "toolResult",
|
||||
toolCallId: "calc-1",
|
||||
toolName: "calculate",
|
||||
content: [{ type: "text", text: "5 + 3 = 8" }],
|
||||
isError: false,
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
agent.replaceMessages([userMessage, assistantMessage, toolResult]);
|
||||
|
||||
// Continue from the tool result
|
||||
await agent.continue();
|
||||
|
||||
expect(agent.state.isStreaming).toBe(false);
|
||||
// Should have added an assistant response
|
||||
expect(agent.state.messages.length).toBeGreaterThanOrEqual(4);
|
||||
|
||||
const lastMessage = agent.state.messages[agent.state.messages.length - 1];
|
||||
expect(lastMessage.role).toBe("assistant");
|
||||
|
||||
if (lastMessage.role === "assistant") {
|
||||
const textContent = lastMessage.content
|
||||
.filter((c) => c.type === "text")
|
||||
.map((c) => (c as { type: "text"; text: string }).text)
|
||||
.join(" ");
|
||||
// Should mention 8 in the response
|
||||
expect(textContent).toMatch(/8/);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -4,6 +4,6 @@ export default defineConfig({
|
|||
test: {
|
||||
globals: true,
|
||||
environment: "node",
|
||||
testTimeout: 10000, // 10 seconds
|
||||
testTimeout: 30000, // 30 seconds for API calls
|
||||
},
|
||||
});
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue