diff --git a/packages/ai/README.md b/packages/ai/README.md index 94bab2bf..052d22e0 100644 --- a/packages/ai/README.md +++ b/packages/ai/README.md @@ -795,6 +795,8 @@ The Agent API streams events during execution, allowing you to build reactive UI This continues until the assistant produces a response without tool calls. +**Queued messages**: If you provide `getQueuedMessages` in the loop config, the agent checks for queued user messages after each tool call. When queued messages are found, any remaining tool calls from the current assistant message are skipped and returned as error tool results (`isError: true`) with the message "Skipped due to queued user message." The queued user messages are injected before the next assistant response. + ### Event Flow Example Given a prompt asking to calculate two expressions and sum them: @@ -1139,4 +1141,4 @@ If you get "The requested model is not supported" error, enable the model manual ## License -MIT \ No newline at end of file +MIT diff --git a/packages/ai/src/agent/agent-loop.ts b/packages/ai/src/agent/agent-loop.ts index 2d844495..a33d1ffb 100644 --- a/packages/ai/src/agent/agent-loop.ts +++ b/packages/ai/src/agent/agent-loop.ts @@ -92,6 +92,7 @@ async function runLoop( let hasMoreToolCalls = true; let firstTurn = true; let queuedMessages: QueuedMessage[] = (await config.getQueuedMessages?.()) || []; + let queuedAfterTools: QueuedMessage[] | null = null; while (hasMoreToolCalls || queuedMessages.length > 0) { if (!firstTurn) { @@ -132,14 +133,27 @@ async function runLoop( const toolResults: ToolResultMessage[] = []; if (hasMoreToolCalls) { // Execute tool calls - toolResults.push(...(await executeToolCalls(currentContext.tools, message, signal, stream))); + const toolExecution = await executeToolCalls( + currentContext.tools, + message, + signal, + stream, + config.getQueuedMessages, + ); + toolResults.push(...toolExecution.toolResults); + queuedAfterTools = toolExecution.queuedMessages ?? null; currentContext.messages.push(...toolResults); newMessages.push(...toolResults); } stream.push({ type: "turn_end", message, toolResults: toolResults }); // Get queued messages after turn completes - queuedMessages = (await config.getQueuedMessages?.()) || []; + if (queuedAfterTools && queuedAfterTools.length > 0) { + queuedMessages = queuedAfterTools; + queuedAfterTools = null; + } else { + queuedMessages = (await config.getQueuedMessages?.()) || []; + } } stream.push({ type: "agent_end", messages: newMessages }); @@ -234,11 +248,14 @@ async function executeToolCalls( assistantMessage: AssistantMessage, signal: AbortSignal | undefined, stream: EventStream, -): Promise[]> { + getQueuedMessages?: AgentLoopConfig["getQueuedMessages"], +): Promise<{ toolResults: ToolResultMessage[]; queuedMessages?: QueuedMessage[] }> { const toolCalls = assistantMessage.content.filter((c) => c.type === "toolCall"); const results: ToolResultMessage[] = []; + let queuedMessages: QueuedMessage[] | undefined; - for (const toolCall of toolCalls) { + for (let index = 0; index < toolCalls.length; index++) { + const toolCall = toolCalls[index]; const tool = tools?.find((t) => t.name === toolCall.name); stream.push({ @@ -296,7 +313,58 @@ async function executeToolCalls( results.push(toolResultMessage); stream.push({ type: "message_start", message: toolResultMessage }); stream.push({ type: "message_end", message: toolResultMessage }); + + if (getQueuedMessages) { + const queued = await getQueuedMessages(); + if (queued.length > 0) { + queuedMessages = queued; + const remainingCalls = toolCalls.slice(index + 1); + for (const skipped of remainingCalls) { + results.push(skipToolCall(skipped, stream)); + } + break; + } + } } - return results; + return { toolResults: results, queuedMessages }; +} + +function skipToolCall( + toolCall: Extract, + stream: EventStream, +): ToolResultMessage { + const result: AgentToolResult = { + content: [{ type: "text", text: "Skipped due to queued user message." }], + details: {} as T, + }; + + stream.push({ + type: "tool_execution_start", + toolCallId: toolCall.id, + toolName: toolCall.name, + args: toolCall.arguments, + }); + stream.push({ + type: "tool_execution_end", + toolCallId: toolCall.id, + toolName: toolCall.name, + result, + isError: true, + }); + + const toolResultMessage: ToolResultMessage = { + role: "toolResult", + toolCallId: toolCall.id, + toolName: toolCall.name, + content: result.content, + details: result.details, + isError: true, + timestamp: Date.now(), + }; + + stream.push({ type: "message_start", message: toolResultMessage }); + stream.push({ type: "message_end", message: toolResultMessage }); + + return toolResultMessage; } diff --git a/packages/ai/test/agent-queue-interrupt.test.ts b/packages/ai/test/agent-queue-interrupt.test.ts new file mode 100644 index 00000000..42a5db45 --- /dev/null +++ b/packages/ai/test/agent-queue-interrupt.test.ts @@ -0,0 +1,166 @@ +import { Type } from "@sinclair/typebox"; +import { describe, expect, it } from "vitest"; +import { agentLoop } from "../src/agent/agent-loop.js"; +import type { AgentContext, AgentEvent, AgentLoopConfig, AgentTool, QueuedMessage } from "../src/agent/types.js"; +import type { AssistantMessage, Message, Model, UserMessage } from "../src/types.js"; +import { AssistantMessageEventStream } from "../src/utils/event-stream.js"; + +function createUsage() { + return { + input: 0, + output: 0, + cacheRead: 0, + cacheWrite: 0, + totalTokens: 0, + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, + }; +} + +function createModel(): Model<"openai-responses"> { + return { + id: "mock", + name: "mock", + api: "openai-responses", + provider: "openai", + baseUrl: "https://example.invalid", + reasoning: false, + input: ["text"], + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 }, + contextWindow: 8192, + maxTokens: 2048, + }; +} + +describe("agentLoop queued message interrupt", () => { + it("injects queued messages after a tool call and skips remaining tool calls", async () => { + const toolSchema = Type.Object({ value: Type.String() }); + const executed: string[] = []; + const tool: AgentTool = { + name: "echo", + label: "Echo", + description: "Echo tool", + parameters: toolSchema, + async execute(_toolCallId, params) { + executed.push(params.value); + return { + content: [{ type: "text", text: `ok:${params.value}` }], + details: { value: params.value }, + }; + }, + }; + + const context: AgentContext = { + systemPrompt: "", + messages: [], + tools: [tool], + }; + + const userPrompt: UserMessage = { + role: "user", + content: "start", + timestamp: Date.now(), + }; + + const queuedUserMessage: Message = { + role: "user", + content: "interrupt", + timestamp: Date.now(), + }; + const queuedMessages: QueuedMessage[] = [{ original: queuedUserMessage, llm: queuedUserMessage }]; + + let queuedDelivered = false; + let sawInterruptInContext = false; + let callIndex = 0; + + const streamFn = () => { + const stream = new AssistantMessageEventStream(); + queueMicrotask(() => { + if (callIndex === 0) { + const message: AssistantMessage = { + role: "assistant", + content: [ + { type: "toolCall", id: "tool-1", name: "echo", arguments: { value: "first" } }, + { type: "toolCall", id: "tool-2", name: "echo", arguments: { value: "second" } }, + ], + api: "openai-responses", + provider: "openai", + model: "mock", + usage: createUsage(), + stopReason: "toolUse", + timestamp: Date.now(), + }; + stream.push({ type: "done", reason: "toolUse", message }); + } else { + const message: AssistantMessage = { + role: "assistant", + content: [{ type: "text", text: "done" }], + api: "openai-responses", + provider: "openai", + model: "mock", + usage: createUsage(), + stopReason: "stop", + timestamp: Date.now(), + }; + stream.push({ type: "done", reason: "stop", message }); + } + callIndex += 1; + }); + return stream; + }; + + const getQueuedMessages: AgentLoopConfig["getQueuedMessages"] = async () => { + if (executed.length === 1 && !queuedDelivered) { + queuedDelivered = true; + return queuedMessages as QueuedMessage[]; + } + return []; + }; + + const config: AgentLoopConfig = { + model: createModel(), + getQueuedMessages, + }; + + const events: AgentEvent[] = []; + const stream = agentLoop(userPrompt, context, config, undefined, (_model, ctx, _options) => { + if (callIndex === 1) { + sawInterruptInContext = ctx.messages.some( + (m) => m.role === "user" && typeof m.content === "string" && m.content === "interrupt", + ); + } + return streamFn(); + }); + + for await (const event of stream) { + events.push(event); + } + + expect(executed).toEqual(["first"]); + const toolEnds = events.filter( + (event): event is Extract => event.type === "tool_execution_end", + ); + expect(toolEnds.length).toBe(2); + expect(toolEnds[1].isError).toBe(true); + expect(toolEnds[1].result.content[0]?.type).toBe("text"); + if (toolEnds[1].result.content[0]?.type === "text") { + expect(toolEnds[1].result.content[0].text).toContain("Skipped due to queued user message"); + } + + const firstTurnEndIndex = events.findIndex((event) => event.type === "turn_end"); + const queuedMessageIndex = events.findIndex( + (event) => + event.type === "message_start" && + event.message.role === "user" && + typeof event.message.content === "string" && + event.message.content === "interrupt", + ); + const nextAssistantIndex = events.findIndex( + (event, index) => + index > queuedMessageIndex && event.type === "message_start" && event.message.role === "assistant", + ); + + expect(queuedMessageIndex).toBeGreaterThan(firstTurnEndIndex); + expect(queuedMessageIndex).toBeLessThan(nextAssistantIndex); + expect(sawInterruptInContext).toBe(true); + }); +});