diff --git a/packages/agent/src/agent-loop.ts b/packages/agent/src/agent-loop.ts index ee63b89f..753466ac 100644 --- a/packages/agent/src/agent-loop.ts +++ b/packages/agent/src/agent-loop.ts @@ -109,71 +109,88 @@ async function runLoop( stream: EventStream, streamFn?: StreamFn, ): Promise { - let hasMoreToolCalls = true; let firstTurn = true; - let queuedMessages: AgentMessage[] = (await config.getQueuedMessages?.()) || []; - let queuedAfterTools: AgentMessage[] | null = null; + // Check for steering messages at start (user may have typed while waiting) + let pendingMessages: AgentMessage[] = (await config.getSteeringMessages?.()) || []; - while (hasMoreToolCalls || queuedMessages.length > 0) { - if (!firstTurn) { - stream.push({ type: "turn_start" }); - } else { - firstTurn = false; - } + // Outer loop: continues when queued follow-up messages arrive after agent would stop + while (true) { + let hasMoreToolCalls = true; + let steeringAfterTools: AgentMessage[] | null = null; - // Process queued messages (inject before next assistant response) - if (queuedMessages.length > 0) { - for (const message of queuedMessages) { - stream.push({ type: "message_start", message }); - stream.push({ type: "message_end", message }); - currentContext.messages.push(message); - newMessages.push(message); + // Inner loop: process tool calls and steering messages + while (hasMoreToolCalls || pendingMessages.length > 0) { + if (!firstTurn) { + stream.push({ type: "turn_start" }); + } else { + firstTurn = false; } - queuedMessages = []; - } - // Stream assistant response - const message = await streamAssistantResponse(currentContext, config, signal, stream, streamFn); - newMessages.push(message); + // Process pending messages (inject before next assistant response) + if (pendingMessages.length > 0) { + for (const message of pendingMessages) { + stream.push({ type: "message_start", message }); + stream.push({ type: "message_end", message }); + currentContext.messages.push(message); + newMessages.push(message); + } + pendingMessages = []; + } - if (message.stopReason === "error" || message.stopReason === "aborted") { - stream.push({ type: "turn_end", message, toolResults: [] }); - stream.push({ type: "agent_end", messages: newMessages }); - stream.end(newMessages); - return; - } + // Stream assistant response + const message = await streamAssistantResponse(currentContext, config, signal, stream, streamFn); + newMessages.push(message); - // Check for tool calls - const toolCalls = message.content.filter((c) => c.type === "toolCall"); - hasMoreToolCalls = toolCalls.length > 0; + if (message.stopReason === "error" || message.stopReason === "aborted") { + stream.push({ type: "turn_end", message, toolResults: [] }); + stream.push({ type: "agent_end", messages: newMessages }); + stream.end(newMessages); + return; + } - const toolResults: ToolResultMessage[] = []; - if (hasMoreToolCalls) { - const toolExecution = await executeToolCalls( - currentContext.tools, - message, - signal, - stream, - config.getQueuedMessages, - ); - toolResults.push(...toolExecution.toolResults); - queuedAfterTools = toolExecution.queuedMessages ?? null; + // Check for tool calls + const toolCalls = message.content.filter((c) => c.type === "toolCall"); + hasMoreToolCalls = toolCalls.length > 0; - for (const result of toolResults) { - currentContext.messages.push(result); - newMessages.push(result); + const toolResults: ToolResultMessage[] = []; + if (hasMoreToolCalls) { + const toolExecution = await executeToolCalls( + currentContext.tools, + message, + signal, + stream, + config.getSteeringMessages, + ); + toolResults.push(...toolExecution.toolResults); + steeringAfterTools = toolExecution.steeringMessages ?? null; + + for (const result of toolResults) { + currentContext.messages.push(result); + newMessages.push(result); + } + } + + stream.push({ type: "turn_end", message, toolResults }); + + // Get steering messages after turn completes + if (steeringAfterTools && steeringAfterTools.length > 0) { + pendingMessages = steeringAfterTools; + steeringAfterTools = null; + } else { + pendingMessages = (await config.getSteeringMessages?.()) || []; } } - stream.push({ type: "turn_end", message, toolResults }); - - // Get queued messages after turn completes - if (queuedAfterTools && queuedAfterTools.length > 0) { - queuedMessages = queuedAfterTools; - queuedAfterTools = null; - } else { - queuedMessages = (await config.getQueuedMessages?.()) || []; + // Agent would stop here. Check for follow-up messages. + const followUpMessages = (await config.getFollowUpMessages?.()) || []; + if (followUpMessages.length > 0) { + // Set as pending so inner loop processes them + pendingMessages = followUpMessages; + continue; } + + // No more messages, exit + break; } stream.push({ type: "agent_end", messages: newMessages }); @@ -279,11 +296,11 @@ async function executeToolCalls( assistantMessage: AssistantMessage, signal: AbortSignal | undefined, stream: EventStream, - getQueuedMessages?: AgentLoopConfig["getQueuedMessages"], -): Promise<{ toolResults: ToolResultMessage[]; queuedMessages?: AgentMessage[] }> { + getSteeringMessages?: AgentLoopConfig["getSteeringMessages"], +): Promise<{ toolResults: ToolResultMessage[]; steeringMessages?: AgentMessage[] }> { const toolCalls = assistantMessage.content.filter((c) => c.type === "toolCall"); const results: ToolResultMessage[] = []; - let queuedMessages: AgentMessage[] | undefined; + let steeringMessages: AgentMessage[] | undefined; for (let index = 0; index < toolCalls.length; index++) { const toolCall = toolCalls[index]; @@ -343,11 +360,11 @@ async function executeToolCalls( stream.push({ type: "message_start", message: toolResultMessage }); stream.push({ type: "message_end", message: toolResultMessage }); - // Check for queued messages - skip remaining tools if user interrupted - if (getQueuedMessages) { - const queued = await getQueuedMessages(); - if (queued.length > 0) { - queuedMessages = queued; + // Check for steering messages - skip remaining tools if user interrupted + if (getSteeringMessages) { + const steering = await getSteeringMessages(); + if (steering.length > 0) { + steeringMessages = steering; const remainingCalls = toolCalls.slice(index + 1); for (const skipped of remainingCalls) { results.push(skipToolCall(skipped, stream)); @@ -357,7 +374,7 @@ async function executeToolCalls( } } - return { toolResults: results, queuedMessages }; + return { toolResults: results, steeringMessages }; } function skipToolCall( diff --git a/packages/agent/src/agent.ts b/packages/agent/src/agent.ts index 851b70b2..d040e9ec 100644 --- a/packages/agent/src/agent.ts +++ b/packages/agent/src/agent.ts @@ -47,9 +47,14 @@ export interface AgentOptions { transformContext?: (messages: AgentMessage[], signal?: AbortSignal) => Promise; /** - * Queue mode: "all" = send all queued messages at once, "one-at-a-time" = one per turn + * Steering mode: "all" = send all steering messages at once, "one-at-a-time" = one per turn */ - queueMode?: "all" | "one-at-a-time"; + steeringMode?: "all" | "one-at-a-time"; + + /** + * Follow-up mode: "all" = send all follow-up messages at once, "one-at-a-time" = one per turn + */ + followUpMode?: "all" | "one-at-a-time"; /** * Custom stream function (for proxy backends, etc.). Default uses streamSimple. @@ -80,8 +85,10 @@ export class Agent { private abortController?: AbortController; private convertToLlm: (messages: AgentMessage[]) => Message[] | Promise; private transformContext?: (messages: AgentMessage[], signal?: AbortSignal) => Promise; - private messageQueue: AgentMessage[] = []; - private queueMode: "all" | "one-at-a-time"; + private steeringQueue: AgentMessage[] = []; + private followUpQueue: AgentMessage[] = []; + private steeringMode: "all" | "one-at-a-time"; + private followUpMode: "all" | "one-at-a-time"; public streamFn: StreamFn; public getApiKey?: (provider: string) => Promise | string | undefined; private runningPrompt?: Promise; @@ -91,7 +98,8 @@ export class Agent { this._state = { ...this._state, ...opts.initialState }; this.convertToLlm = opts.convertToLlm || defaultConvertToLlm; this.transformContext = opts.transformContext; - this.queueMode = opts.queueMode || "one-at-a-time"; + this.steeringMode = opts.steeringMode || "one-at-a-time"; + this.followUpMode = opts.followUpMode || "one-at-a-time"; this.streamFn = opts.streamFn || streamSimple; this.getApiKey = opts.getApiKey; } @@ -118,12 +126,20 @@ export class Agent { this._state.thinkingLevel = l; } - setQueueMode(mode: "all" | "one-at-a-time") { - this.queueMode = mode; + setSteeringMode(mode: "all" | "one-at-a-time") { + this.steeringMode = mode; } - getQueueMode(): "all" | "one-at-a-time" { - return this.queueMode; + getSteeringMode(): "all" | "one-at-a-time" { + return this.steeringMode; + } + + setFollowUpMode(mode: "all" | "one-at-a-time") { + this.followUpMode = mode; + } + + getFollowUpMode(): "all" | "one-at-a-time" { + return this.followUpMode; } setTools(t: AgentTool[]) { @@ -138,12 +154,33 @@ export class Agent { this._state.messages = [...this._state.messages, m]; } - queueMessage(m: AgentMessage) { - this.messageQueue.push(m); + /** + * Queue a steering message to interrupt the agent mid-run. + * Delivered after current tool execution, skips remaining tools. + */ + steer(m: AgentMessage) { + this.steeringQueue.push(m); } - clearMessageQueue() { - this.messageQueue = []; + /** + * Queue a follow-up message to be processed after the agent finishes. + * Delivered only when agent has no more tool calls or steering messages. + */ + followUp(m: AgentMessage) { + this.followUpQueue.push(m); + } + + clearSteeringQueue() { + this.steeringQueue = []; + } + + clearFollowUpQueue() { + this.followUpQueue = []; + } + + clearAllQueues() { + this.steeringQueue = []; + this.followUpQueue = []; } clearMessages() { @@ -164,7 +201,8 @@ export class Agent { this._state.streamMessage = null; this._state.pendingToolCalls = new Set(); this._state.error = undefined; - this.messageQueue = []; + this.steeringQueue = []; + this.followUpQueue = []; } /** Send a prompt with an AgentMessage */ @@ -172,7 +210,9 @@ export class Agent { async prompt(input: string, images?: ImageContent[]): Promise; async prompt(input: string | AgentMessage | AgentMessage[], images?: ImageContent[]) { if (this._state.isStreaming) { - throw new Error("Agent is already processing a prompt. Use queueMessage() or wait for completion."); + throw new Error( + "Agent is already processing a prompt. Use steer() or followUp() to queue messages, or wait for completion.", + ); } const model = this._state.model; @@ -255,18 +295,32 @@ export class Agent { convertToLlm: this.convertToLlm, transformContext: this.transformContext, getApiKey: this.getApiKey, - getQueuedMessages: async () => { - if (this.queueMode === "one-at-a-time") { - if (this.messageQueue.length > 0) { - const first = this.messageQueue[0]; - this.messageQueue = this.messageQueue.slice(1); + getSteeringMessages: async () => { + if (this.steeringMode === "one-at-a-time") { + if (this.steeringQueue.length > 0) { + const first = this.steeringQueue[0]; + this.steeringQueue = this.steeringQueue.slice(1); return [first]; } return []; } else { - const queued = this.messageQueue.slice(); - this.messageQueue = []; - return queued; + const steering = this.steeringQueue.slice(); + this.steeringQueue = []; + return steering; + } + }, + getFollowUpMessages: async () => { + if (this.followUpMode === "one-at-a-time") { + if (this.followUpQueue.length > 0) { + const first = this.followUpQueue[0]; + this.followUpQueue = this.followUpQueue.slice(1); + return [first]; + } + return []; + } else { + const followUp = this.followUpQueue.slice(); + this.followUpQueue = []; + return followUp; } }, }; diff --git a/packages/agent/src/types.ts b/packages/agent/src/types.ts index e8af618e..b7ea7f7c 100644 --- a/packages/agent/src/types.ts +++ b/packages/agent/src/types.ts @@ -75,12 +75,26 @@ export interface AgentLoopConfig extends SimpleStreamOptions { getApiKey?: (provider: string) => Promise | string | undefined; /** - * Returns queued messages to inject into the conversation. + * Returns steering messages to inject into the conversation mid-run. * - * Called after each turn to check for user interruptions or injected messages. - * If messages are returned, they're added to the context before the next LLM call. + * Called after each tool execution to check for user interruptions. + * If messages are returned, remaining tool calls are skipped and + * these messages are added to the context before the next LLM call. + * + * Use this for "steering" the agent while it's working. */ - getQueuedMessages?: () => Promise; + getSteeringMessages?: () => Promise; + + /** + * Returns follow-up messages to process after the agent would otherwise stop. + * + * Called when the agent has no more tool calls and no steering messages. + * If messages are returned, they're added to the context and the agent + * continues with another turn. + * + * Use this for follow-up messages that should wait until the agent finishes. + */ + getFollowUpMessages?: () => Promise; } /** diff --git a/packages/agent/test/agent-loop.test.ts b/packages/agent/test/agent-loop.test.ts index b8295038..c1ee890c 100644 --- a/packages/agent/test/agent-loop.test.ts +++ b/packages/agent/test/agent-loop.test.ts @@ -340,8 +340,8 @@ describe("agentLoop with AgentMessage", () => { const config: AgentLoopConfig = { model: createModel(), convertToLlm: identityConverter, - getQueuedMessages: async () => { - // Return queued message after first tool executes + getSteeringMessages: async () => { + // Return steering message after first tool executes if (executed.length === 1 && !queuedDelivered) { queuedDelivered = true; return [queuedUserMessage];