feat(ai): interrupt tool batch on queued messages

This commit is contained in:
Peter Steinberger 2025-12-20 16:15:40 +01:00 committed by Mario Zechner
parent 6a319f9c3c
commit 117af076c4
3 changed files with 242 additions and 6 deletions

View file

@ -795,6 +795,8 @@ The Agent API streams events during execution, allowing you to build reactive UI
This continues until the assistant produces a response without tool calls.
**Queued messages**: If you provide `getQueuedMessages` in the loop config, the agent checks for queued user messages after each tool call. When queued messages are found, any remaining tool calls from the current assistant message are skipped and returned as error tool results (`isError: true`) with the message "Skipped due to queued user message." The queued user messages are injected before the next assistant response.
### Event Flow Example
Given a prompt asking to calculate two expressions and sum them:
@ -1139,4 +1141,4 @@ If you get "The requested model is not supported" error, enable the model manual
## License
MIT
MIT

View file

@ -92,6 +92,7 @@ async function runLoop(
let hasMoreToolCalls = true;
let firstTurn = true;
let queuedMessages: QueuedMessage<any>[] = (await config.getQueuedMessages?.()) || [];
let queuedAfterTools: QueuedMessage<any>[] | null = null;
while (hasMoreToolCalls || queuedMessages.length > 0) {
if (!firstTurn) {
@ -132,14 +133,27 @@ async function runLoop(
const toolResults: ToolResultMessage[] = [];
if (hasMoreToolCalls) {
// Execute tool calls
toolResults.push(...(await executeToolCalls(currentContext.tools, message, signal, stream)));
const toolExecution = await executeToolCalls(
currentContext.tools,
message,
signal,
stream,
config.getQueuedMessages,
);
toolResults.push(...toolExecution.toolResults);
queuedAfterTools = toolExecution.queuedMessages ?? null;
currentContext.messages.push(...toolResults);
newMessages.push(...toolResults);
}
stream.push({ type: "turn_end", message, toolResults: toolResults });
// Get queued messages after turn completes
queuedMessages = (await config.getQueuedMessages?.()) || [];
if (queuedAfterTools && queuedAfterTools.length > 0) {
queuedMessages = queuedAfterTools;
queuedAfterTools = null;
} else {
queuedMessages = (await config.getQueuedMessages?.()) || [];
}
}
stream.push({ type: "agent_end", messages: newMessages });
@ -234,11 +248,14 @@ async function executeToolCalls<T>(
assistantMessage: AssistantMessage,
signal: AbortSignal | undefined,
stream: EventStream<AgentEvent, Message[]>,
): Promise<ToolResultMessage<T>[]> {
getQueuedMessages?: AgentLoopConfig["getQueuedMessages"],
): Promise<{ toolResults: ToolResultMessage<T>[]; queuedMessages?: QueuedMessage<any>[] }> {
const toolCalls = assistantMessage.content.filter((c) => c.type === "toolCall");
const results: ToolResultMessage<any>[] = [];
let queuedMessages: QueuedMessage<any>[] | undefined;
for (const toolCall of toolCalls) {
for (let index = 0; index < toolCalls.length; index++) {
const toolCall = toolCalls[index];
const tool = tools?.find((t) => t.name === toolCall.name);
stream.push({
@ -296,7 +313,58 @@ async function executeToolCalls<T>(
results.push(toolResultMessage);
stream.push({ type: "message_start", message: toolResultMessage });
stream.push({ type: "message_end", message: toolResultMessage });
if (getQueuedMessages) {
const queued = await getQueuedMessages();
if (queued.length > 0) {
queuedMessages = queued;
const remainingCalls = toolCalls.slice(index + 1);
for (const skipped of remainingCalls) {
results.push(skipToolCall(skipped, stream));
}
break;
}
}
}
return results;
return { toolResults: results, queuedMessages };
}
/**
 * Produces an error tool result for a tool call that was never executed
 * because queued user messages interrupted the batch. Emits the same
 * start/end event sequence as a real execution so stream consumers need
 * no special handling for skipped calls.
 */
function skipToolCall<T>(
  toolCall: Extract<AssistantMessage["content"][number], { type: "toolCall" }>,
  stream: EventStream<AgentEvent, Message[]>,
): ToolResultMessage<T> {
  const { id: toolCallId, name: toolName } = toolCall;
  const skipResult: AgentToolResult<T> = {
    content: [{ type: "text", text: "Skipped due to queued user message." }],
    details: {} as T,
  };
  // Mirror a genuine execution's event trace, flagged as an error.
  stream.push({
    type: "tool_execution_start",
    toolCallId,
    toolName,
    args: toolCall.arguments,
  });
  stream.push({
    type: "tool_execution_end",
    toolCallId,
    toolName,
    result: skipResult,
    isError: true,
  });
  const resultMessage: ToolResultMessage<T> = {
    role: "toolResult",
    toolCallId,
    toolName,
    content: skipResult.content,
    details: skipResult.details,
    isError: true,
    timestamp: Date.now(),
  };
  stream.push({ type: "message_start", message: resultMessage });
  stream.push({ type: "message_end", message: resultMessage });
  return resultMessage;
}

View file

@ -0,0 +1,166 @@
import { Type } from "@sinclair/typebox";
import { describe, expect, it } from "vitest";
import { agentLoop } from "../src/agent/agent-loop.js";
import type { AgentContext, AgentEvent, AgentLoopConfig, AgentTool, QueuedMessage } from "../src/agent/types.js";
import type { AssistantMessage, Message, Model, UserMessage } from "../src/types.js";
import { AssistantMessageEventStream } from "../src/utils/event-stream.js";
/** Returns zeroed-out token usage stats for mock assistant messages. */
function createUsage() {
  const zeroCost = { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 };
  return {
    input: 0,
    output: 0,
    cacheRead: 0,
    cacheWrite: 0,
    totalTokens: 0,
    cost: zeroCost,
  };
}
/** Builds a minimal mock model descriptor targeting the openai-responses API. */
function createModel(): Model<"openai-responses"> {
  // All costs are zero so test assertions never depend on pricing math.
  const freeCost = { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 };
  return {
    id: "mock",
    name: "mock",
    api: "openai-responses",
    provider: "openai",
    baseUrl: "https://example.invalid",
    reasoning: false,
    input: ["text"],
    cost: freeCost,
    contextWindow: 8192,
    maxTokens: 2048,
  };
}
// Verifies the queued-message interrupt path: when getQueuedMessages returns
// messages mid-batch, the remaining tool calls of the current assistant
// message are skipped as error results and the queued user message is
// injected into the context before the next assistant response.
describe("agentLoop queued message interrupt", () => {
it("injects queued messages after a tool call and skips remaining tool calls", async () => {
const toolSchema = Type.Object({ value: Type.String() });
// Records which tool invocations actually ran; skipped calls never land here.
const executed: string[] = [];
const tool: AgentTool<typeof toolSchema, { value: string }> = {
name: "echo",
label: "Echo",
description: "Echo tool",
parameters: toolSchema,
async execute(_toolCallId, params) {
executed.push(params.value);
return {
content: [{ type: "text", text: `ok:${params.value}` }],
details: { value: params.value },
};
},
};
const context: AgentContext = {
systemPrompt: "",
messages: [],
tools: [tool],
};
const userPrompt: UserMessage = {
role: "user",
content: "start",
timestamp: Date.now(),
};
// The message that will be "queued" by the user mid-batch.
const queuedUserMessage: Message = {
role: "user",
content: "interrupt",
timestamp: Date.now(),
};
const queuedMessages: QueuedMessage<Message>[] = [{ original: queuedUserMessage, llm: queuedUserMessage }];
// One-shot latch so the queue is delivered exactly once (after tool-1 runs).
let queuedDelivered = false;
// Set when the second LLM call observes the injected "interrupt" message.
let sawInterruptInContext = false;
let callIndex = 0;
// Mock LLM stream: first call emits an assistant message with two tool
// calls; every later call emits a plain text completion (no tools).
// Events are pushed via queueMicrotask so the loop consumes them async.
const streamFn = () => {
const stream = new AssistantMessageEventStream();
queueMicrotask(() => {
if (callIndex === 0) {
const message: AssistantMessage = {
role: "assistant",
content: [
{ type: "toolCall", id: "tool-1", name: "echo", arguments: { value: "first" } },
{ type: "toolCall", id: "tool-2", name: "echo", arguments: { value: "second" } },
],
api: "openai-responses",
provider: "openai",
model: "mock",
usage: createUsage(),
stopReason: "toolUse",
timestamp: Date.now(),
};
stream.push({ type: "done", reason: "toolUse", message });
} else {
const message: AssistantMessage = {
role: "assistant",
content: [{ type: "text", text: "done" }],
api: "openai-responses",
provider: "openai",
model: "mock",
usage: createUsage(),
stopReason: "stop",
timestamp: Date.now(),
};
stream.push({ type: "done", reason: "stop", message });
}
callIndex += 1;
});
return stream;
};
// Delivers the queued message only once, and only after the first tool has
// executed — this is what should cause tool-2 to be skipped.
const getQueuedMessages: AgentLoopConfig["getQueuedMessages"] = async <T>() => {
if (executed.length === 1 && !queuedDelivered) {
queuedDelivered = true;
return queuedMessages as QueuedMessage<T>[];
}
return [];
};
const config: AgentLoopConfig = {
model: createModel(),
getQueuedMessages,
};
const events: AgentEvent[] = [];
// The stream factory doubles as a probe: on the second LLM call it checks
// whether the queued "interrupt" message was injected into the context.
const stream = agentLoop(userPrompt, context, config, undefined, (_model, ctx, _options) => {
if (callIndex === 1) {
sawInterruptInContext = ctx.messages.some(
(m) => m.role === "user" && typeof m.content === "string" && m.content === "interrupt",
);
}
return streamFn();
});
for await (const event of stream) {
events.push(event);
}
// Only tool-1 actually ran; tool-2 was skipped because of the queued message.
expect(executed).toEqual(["first"]);
const toolEnds = events.filter(
(event): event is Extract<AgentEvent, { type: "tool_execution_end" }> => event.type === "tool_execution_end",
);
// Both tool calls still produce end events, but the second is the error skip.
expect(toolEnds.length).toBe(2);
expect(toolEnds[1].isError).toBe(true);
expect(toolEnds[1].result.content[0]?.type).toBe("text");
if (toolEnds[1].result.content[0]?.type === "text") {
expect(toolEnds[1].result.content[0].text).toContain("Skipped due to queued user message");
}
// Ordering: the queued user message must surface after the first turn ends
// and before the next assistant message starts.
const firstTurnEndIndex = events.findIndex((event) => event.type === "turn_end");
const queuedMessageIndex = events.findIndex(
(event) =>
event.type === "message_start" &&
event.message.role === "user" &&
typeof event.message.content === "string" &&
event.message.content === "interrupt",
);
const nextAssistantIndex = events.findIndex(
(event, index) =>
index > queuedMessageIndex && event.type === "message_start" && event.message.role === "assistant",
);
expect(queuedMessageIndex).toBeGreaterThan(firstTurnEndIndex);
expect(queuedMessageIndex).toBeLessThan(nextAssistantIndex);
expect(sawInterruptInContext).toBe(true);
});
});