mirror of
https://github.com/getcompanion-ai/co-mono.git
synced 2026-04-20 17:02:11 +00:00
feat(agent): split queue into steer() and followUp() APIs
Breaking change: replaces queueMessage() with two separate methods: - steer(msg): interrupt mid-run, delivered after current tool execution - followUp(msg): wait until agent finishes before delivery Also renames: - queueMode -> steeringMode/followUpMode - getQueuedMessages -> getSteeringMessages/getFollowUpMessages Refs #403
This commit is contained in:
parent
345fa975f1
commit
d0a4c37028
4 changed files with 175 additions and 90 deletions
|
|
@ -109,71 +109,88 @@ async function runLoop(
|
||||||
stream: EventStream<AgentEvent, AgentMessage[]>,
|
stream: EventStream<AgentEvent, AgentMessage[]>,
|
||||||
streamFn?: StreamFn,
|
streamFn?: StreamFn,
|
||||||
): Promise<void> {
|
): Promise<void> {
|
||||||
let hasMoreToolCalls = true;
|
|
||||||
let firstTurn = true;
|
let firstTurn = true;
|
||||||
let queuedMessages: AgentMessage[] = (await config.getQueuedMessages?.()) || [];
|
// Check for steering messages at start (user may have typed while waiting)
|
||||||
let queuedAfterTools: AgentMessage[] | null = null;
|
let pendingMessages: AgentMessage[] = (await config.getSteeringMessages?.()) || [];
|
||||||
|
|
||||||
while (hasMoreToolCalls || queuedMessages.length > 0) {
|
// Outer loop: continues when queued follow-up messages arrive after agent would stop
|
||||||
if (!firstTurn) {
|
while (true) {
|
||||||
stream.push({ type: "turn_start" });
|
let hasMoreToolCalls = true;
|
||||||
} else {
|
let steeringAfterTools: AgentMessage[] | null = null;
|
||||||
firstTurn = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process queued messages (inject before next assistant response)
|
// Inner loop: process tool calls and steering messages
|
||||||
if (queuedMessages.length > 0) {
|
while (hasMoreToolCalls || pendingMessages.length > 0) {
|
||||||
for (const message of queuedMessages) {
|
if (!firstTurn) {
|
||||||
stream.push({ type: "message_start", message });
|
stream.push({ type: "turn_start" });
|
||||||
stream.push({ type: "message_end", message });
|
} else {
|
||||||
currentContext.messages.push(message);
|
firstTurn = false;
|
||||||
newMessages.push(message);
|
|
||||||
}
|
}
|
||||||
queuedMessages = [];
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stream assistant response
|
// Process pending messages (inject before next assistant response)
|
||||||
const message = await streamAssistantResponse(currentContext, config, signal, stream, streamFn);
|
if (pendingMessages.length > 0) {
|
||||||
newMessages.push(message);
|
for (const message of pendingMessages) {
|
||||||
|
stream.push({ type: "message_start", message });
|
||||||
|
stream.push({ type: "message_end", message });
|
||||||
|
currentContext.messages.push(message);
|
||||||
|
newMessages.push(message);
|
||||||
|
}
|
||||||
|
pendingMessages = [];
|
||||||
|
}
|
||||||
|
|
||||||
if (message.stopReason === "error" || message.stopReason === "aborted") {
|
// Stream assistant response
|
||||||
stream.push({ type: "turn_end", message, toolResults: [] });
|
const message = await streamAssistantResponse(currentContext, config, signal, stream, streamFn);
|
||||||
stream.push({ type: "agent_end", messages: newMessages });
|
newMessages.push(message);
|
||||||
stream.end(newMessages);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for tool calls
|
if (message.stopReason === "error" || message.stopReason === "aborted") {
|
||||||
const toolCalls = message.content.filter((c) => c.type === "toolCall");
|
stream.push({ type: "turn_end", message, toolResults: [] });
|
||||||
hasMoreToolCalls = toolCalls.length > 0;
|
stream.push({ type: "agent_end", messages: newMessages });
|
||||||
|
stream.end(newMessages);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
const toolResults: ToolResultMessage[] = [];
|
// Check for tool calls
|
||||||
if (hasMoreToolCalls) {
|
const toolCalls = message.content.filter((c) => c.type === "toolCall");
|
||||||
const toolExecution = await executeToolCalls(
|
hasMoreToolCalls = toolCalls.length > 0;
|
||||||
currentContext.tools,
|
|
||||||
message,
|
|
||||||
signal,
|
|
||||||
stream,
|
|
||||||
config.getQueuedMessages,
|
|
||||||
);
|
|
||||||
toolResults.push(...toolExecution.toolResults);
|
|
||||||
queuedAfterTools = toolExecution.queuedMessages ?? null;
|
|
||||||
|
|
||||||
for (const result of toolResults) {
|
const toolResults: ToolResultMessage[] = [];
|
||||||
currentContext.messages.push(result);
|
if (hasMoreToolCalls) {
|
||||||
newMessages.push(result);
|
const toolExecution = await executeToolCalls(
|
||||||
|
currentContext.tools,
|
||||||
|
message,
|
||||||
|
signal,
|
||||||
|
stream,
|
||||||
|
config.getSteeringMessages,
|
||||||
|
);
|
||||||
|
toolResults.push(...toolExecution.toolResults);
|
||||||
|
steeringAfterTools = toolExecution.steeringMessages ?? null;
|
||||||
|
|
||||||
|
for (const result of toolResults) {
|
||||||
|
currentContext.messages.push(result);
|
||||||
|
newMessages.push(result);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
stream.push({ type: "turn_end", message, toolResults });
|
||||||
|
|
||||||
|
// Get steering messages after turn completes
|
||||||
|
if (steeringAfterTools && steeringAfterTools.length > 0) {
|
||||||
|
pendingMessages = steeringAfterTools;
|
||||||
|
steeringAfterTools = null;
|
||||||
|
} else {
|
||||||
|
pendingMessages = (await config.getSteeringMessages?.()) || [];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
stream.push({ type: "turn_end", message, toolResults });
|
// Agent would stop here. Check for follow-up messages.
|
||||||
|
const followUpMessages = (await config.getFollowUpMessages?.()) || [];
|
||||||
// Get queued messages after turn completes
|
if (followUpMessages.length > 0) {
|
||||||
if (queuedAfterTools && queuedAfterTools.length > 0) {
|
// Set as pending so inner loop processes them
|
||||||
queuedMessages = queuedAfterTools;
|
pendingMessages = followUpMessages;
|
||||||
queuedAfterTools = null;
|
continue;
|
||||||
} else {
|
|
||||||
queuedMessages = (await config.getQueuedMessages?.()) || [];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// No more messages, exit
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
stream.push({ type: "agent_end", messages: newMessages });
|
stream.push({ type: "agent_end", messages: newMessages });
|
||||||
|
|
@ -279,11 +296,11 @@ async function executeToolCalls(
|
||||||
assistantMessage: AssistantMessage,
|
assistantMessage: AssistantMessage,
|
||||||
signal: AbortSignal | undefined,
|
signal: AbortSignal | undefined,
|
||||||
stream: EventStream<AgentEvent, AgentMessage[]>,
|
stream: EventStream<AgentEvent, AgentMessage[]>,
|
||||||
getQueuedMessages?: AgentLoopConfig["getQueuedMessages"],
|
getSteeringMessages?: AgentLoopConfig["getSteeringMessages"],
|
||||||
): Promise<{ toolResults: ToolResultMessage[]; queuedMessages?: AgentMessage[] }> {
|
): Promise<{ toolResults: ToolResultMessage[]; steeringMessages?: AgentMessage[] }> {
|
||||||
const toolCalls = assistantMessage.content.filter((c) => c.type === "toolCall");
|
const toolCalls = assistantMessage.content.filter((c) => c.type === "toolCall");
|
||||||
const results: ToolResultMessage[] = [];
|
const results: ToolResultMessage[] = [];
|
||||||
let queuedMessages: AgentMessage[] | undefined;
|
let steeringMessages: AgentMessage[] | undefined;
|
||||||
|
|
||||||
for (let index = 0; index < toolCalls.length; index++) {
|
for (let index = 0; index < toolCalls.length; index++) {
|
||||||
const toolCall = toolCalls[index];
|
const toolCall = toolCalls[index];
|
||||||
|
|
@ -343,11 +360,11 @@ async function executeToolCalls(
|
||||||
stream.push({ type: "message_start", message: toolResultMessage });
|
stream.push({ type: "message_start", message: toolResultMessage });
|
||||||
stream.push({ type: "message_end", message: toolResultMessage });
|
stream.push({ type: "message_end", message: toolResultMessage });
|
||||||
|
|
||||||
// Check for queued messages - skip remaining tools if user interrupted
|
// Check for steering messages - skip remaining tools if user interrupted
|
||||||
if (getQueuedMessages) {
|
if (getSteeringMessages) {
|
||||||
const queued = await getQueuedMessages();
|
const steering = await getSteeringMessages();
|
||||||
if (queued.length > 0) {
|
if (steering.length > 0) {
|
||||||
queuedMessages = queued;
|
steeringMessages = steering;
|
||||||
const remainingCalls = toolCalls.slice(index + 1);
|
const remainingCalls = toolCalls.slice(index + 1);
|
||||||
for (const skipped of remainingCalls) {
|
for (const skipped of remainingCalls) {
|
||||||
results.push(skipToolCall(skipped, stream));
|
results.push(skipToolCall(skipped, stream));
|
||||||
|
|
@ -357,7 +374,7 @@ async function executeToolCalls(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return { toolResults: results, queuedMessages };
|
return { toolResults: results, steeringMessages };
|
||||||
}
|
}
|
||||||
|
|
||||||
function skipToolCall(
|
function skipToolCall(
|
||||||
|
|
|
||||||
|
|
@ -47,9 +47,14 @@ export interface AgentOptions {
|
||||||
transformContext?: (messages: AgentMessage[], signal?: AbortSignal) => Promise<AgentMessage[]>;
|
transformContext?: (messages: AgentMessage[], signal?: AbortSignal) => Promise<AgentMessage[]>;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Queue mode: "all" = send all queued messages at once, "one-at-a-time" = one per turn
|
* Steering mode: "all" = send all steering messages at once, "one-at-a-time" = one per turn
|
||||||
*/
|
*/
|
||||||
queueMode?: "all" | "one-at-a-time";
|
steeringMode?: "all" | "one-at-a-time";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Follow-up mode: "all" = send all follow-up messages at once, "one-at-a-time" = one per turn
|
||||||
|
*/
|
||||||
|
followUpMode?: "all" | "one-at-a-time";
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Custom stream function (for proxy backends, etc.). Default uses streamSimple.
|
* Custom stream function (for proxy backends, etc.). Default uses streamSimple.
|
||||||
|
|
@ -80,8 +85,10 @@ export class Agent {
|
||||||
private abortController?: AbortController;
|
private abortController?: AbortController;
|
||||||
private convertToLlm: (messages: AgentMessage[]) => Message[] | Promise<Message[]>;
|
private convertToLlm: (messages: AgentMessage[]) => Message[] | Promise<Message[]>;
|
||||||
private transformContext?: (messages: AgentMessage[], signal?: AbortSignal) => Promise<AgentMessage[]>;
|
private transformContext?: (messages: AgentMessage[], signal?: AbortSignal) => Promise<AgentMessage[]>;
|
||||||
private messageQueue: AgentMessage[] = [];
|
private steeringQueue: AgentMessage[] = [];
|
||||||
private queueMode: "all" | "one-at-a-time";
|
private followUpQueue: AgentMessage[] = [];
|
||||||
|
private steeringMode: "all" | "one-at-a-time";
|
||||||
|
private followUpMode: "all" | "one-at-a-time";
|
||||||
public streamFn: StreamFn;
|
public streamFn: StreamFn;
|
||||||
public getApiKey?: (provider: string) => Promise<string | undefined> | string | undefined;
|
public getApiKey?: (provider: string) => Promise<string | undefined> | string | undefined;
|
||||||
private runningPrompt?: Promise<void>;
|
private runningPrompt?: Promise<void>;
|
||||||
|
|
@ -91,7 +98,8 @@ export class Agent {
|
||||||
this._state = { ...this._state, ...opts.initialState };
|
this._state = { ...this._state, ...opts.initialState };
|
||||||
this.convertToLlm = opts.convertToLlm || defaultConvertToLlm;
|
this.convertToLlm = opts.convertToLlm || defaultConvertToLlm;
|
||||||
this.transformContext = opts.transformContext;
|
this.transformContext = opts.transformContext;
|
||||||
this.queueMode = opts.queueMode || "one-at-a-time";
|
this.steeringMode = opts.steeringMode || "one-at-a-time";
|
||||||
|
this.followUpMode = opts.followUpMode || "one-at-a-time";
|
||||||
this.streamFn = opts.streamFn || streamSimple;
|
this.streamFn = opts.streamFn || streamSimple;
|
||||||
this.getApiKey = opts.getApiKey;
|
this.getApiKey = opts.getApiKey;
|
||||||
}
|
}
|
||||||
|
|
@ -118,12 +126,20 @@ export class Agent {
|
||||||
this._state.thinkingLevel = l;
|
this._state.thinkingLevel = l;
|
||||||
}
|
}
|
||||||
|
|
||||||
setQueueMode(mode: "all" | "one-at-a-time") {
|
setSteeringMode(mode: "all" | "one-at-a-time") {
|
||||||
this.queueMode = mode;
|
this.steeringMode = mode;
|
||||||
}
|
}
|
||||||
|
|
||||||
getQueueMode(): "all" | "one-at-a-time" {
|
getSteeringMode(): "all" | "one-at-a-time" {
|
||||||
return this.queueMode;
|
return this.steeringMode;
|
||||||
|
}
|
||||||
|
|
||||||
|
setFollowUpMode(mode: "all" | "one-at-a-time") {
|
||||||
|
this.followUpMode = mode;
|
||||||
|
}
|
||||||
|
|
||||||
|
getFollowUpMode(): "all" | "one-at-a-time" {
|
||||||
|
return this.followUpMode;
|
||||||
}
|
}
|
||||||
|
|
||||||
setTools(t: AgentTool<any>[]) {
|
setTools(t: AgentTool<any>[]) {
|
||||||
|
|
@ -138,12 +154,33 @@ export class Agent {
|
||||||
this._state.messages = [...this._state.messages, m];
|
this._state.messages = [...this._state.messages, m];
|
||||||
}
|
}
|
||||||
|
|
||||||
queueMessage(m: AgentMessage) {
|
/**
|
||||||
this.messageQueue.push(m);
|
* Queue a steering message to interrupt the agent mid-run.
|
||||||
|
* Delivered after current tool execution, skips remaining tools.
|
||||||
|
*/
|
||||||
|
steer(m: AgentMessage) {
|
||||||
|
this.steeringQueue.push(m);
|
||||||
}
|
}
|
||||||
|
|
||||||
clearMessageQueue() {
|
/**
|
||||||
this.messageQueue = [];
|
* Queue a follow-up message to be processed after the agent finishes.
|
||||||
|
* Delivered only when agent has no more tool calls or steering messages.
|
||||||
|
*/
|
||||||
|
followUp(m: AgentMessage) {
|
||||||
|
this.followUpQueue.push(m);
|
||||||
|
}
|
||||||
|
|
||||||
|
clearSteeringQueue() {
|
||||||
|
this.steeringQueue = [];
|
||||||
|
}
|
||||||
|
|
||||||
|
clearFollowUpQueue() {
|
||||||
|
this.followUpQueue = [];
|
||||||
|
}
|
||||||
|
|
||||||
|
clearAllQueues() {
|
||||||
|
this.steeringQueue = [];
|
||||||
|
this.followUpQueue = [];
|
||||||
}
|
}
|
||||||
|
|
||||||
clearMessages() {
|
clearMessages() {
|
||||||
|
|
@ -164,7 +201,8 @@ export class Agent {
|
||||||
this._state.streamMessage = null;
|
this._state.streamMessage = null;
|
||||||
this._state.pendingToolCalls = new Set<string>();
|
this._state.pendingToolCalls = new Set<string>();
|
||||||
this._state.error = undefined;
|
this._state.error = undefined;
|
||||||
this.messageQueue = [];
|
this.steeringQueue = [];
|
||||||
|
this.followUpQueue = [];
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Send a prompt with an AgentMessage */
|
/** Send a prompt with an AgentMessage */
|
||||||
|
|
@ -172,7 +210,9 @@ export class Agent {
|
||||||
async prompt(input: string, images?: ImageContent[]): Promise<void>;
|
async prompt(input: string, images?: ImageContent[]): Promise<void>;
|
||||||
async prompt(input: string | AgentMessage | AgentMessage[], images?: ImageContent[]) {
|
async prompt(input: string | AgentMessage | AgentMessage[], images?: ImageContent[]) {
|
||||||
if (this._state.isStreaming) {
|
if (this._state.isStreaming) {
|
||||||
throw new Error("Agent is already processing a prompt. Use queueMessage() or wait for completion.");
|
throw new Error(
|
||||||
|
"Agent is already processing a prompt. Use steer() or followUp() to queue messages, or wait for completion.",
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
const model = this._state.model;
|
const model = this._state.model;
|
||||||
|
|
@ -255,18 +295,32 @@ export class Agent {
|
||||||
convertToLlm: this.convertToLlm,
|
convertToLlm: this.convertToLlm,
|
||||||
transformContext: this.transformContext,
|
transformContext: this.transformContext,
|
||||||
getApiKey: this.getApiKey,
|
getApiKey: this.getApiKey,
|
||||||
getQueuedMessages: async () => {
|
getSteeringMessages: async () => {
|
||||||
if (this.queueMode === "one-at-a-time") {
|
if (this.steeringMode === "one-at-a-time") {
|
||||||
if (this.messageQueue.length > 0) {
|
if (this.steeringQueue.length > 0) {
|
||||||
const first = this.messageQueue[0];
|
const first = this.steeringQueue[0];
|
||||||
this.messageQueue = this.messageQueue.slice(1);
|
this.steeringQueue = this.steeringQueue.slice(1);
|
||||||
return [first];
|
return [first];
|
||||||
}
|
}
|
||||||
return [];
|
return [];
|
||||||
} else {
|
} else {
|
||||||
const queued = this.messageQueue.slice();
|
const steering = this.steeringQueue.slice();
|
||||||
this.messageQueue = [];
|
this.steeringQueue = [];
|
||||||
return queued;
|
return steering;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
getFollowUpMessages: async () => {
|
||||||
|
if (this.followUpMode === "one-at-a-time") {
|
||||||
|
if (this.followUpQueue.length > 0) {
|
||||||
|
const first = this.followUpQueue[0];
|
||||||
|
this.followUpQueue = this.followUpQueue.slice(1);
|
||||||
|
return [first];
|
||||||
|
}
|
||||||
|
return [];
|
||||||
|
} else {
|
||||||
|
const followUp = this.followUpQueue.slice();
|
||||||
|
this.followUpQueue = [];
|
||||||
|
return followUp;
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -75,12 +75,26 @@ export interface AgentLoopConfig extends SimpleStreamOptions {
|
||||||
getApiKey?: (provider: string) => Promise<string | undefined> | string | undefined;
|
getApiKey?: (provider: string) => Promise<string | undefined> | string | undefined;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns queued messages to inject into the conversation.
|
* Returns steering messages to inject into the conversation mid-run.
|
||||||
*
|
*
|
||||||
* Called after each turn to check for user interruptions or injected messages.
|
* Called after each tool execution to check for user interruptions.
|
||||||
* If messages are returned, they're added to the context before the next LLM call.
|
* If messages are returned, remaining tool calls are skipped and
|
||||||
|
* these messages are added to the context before the next LLM call.
|
||||||
|
*
|
||||||
|
* Use this for "steering" the agent while it's working.
|
||||||
*/
|
*/
|
||||||
getQueuedMessages?: () => Promise<AgentMessage[]>;
|
getSteeringMessages?: () => Promise<AgentMessage[]>;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns follow-up messages to process after the agent would otherwise stop.
|
||||||
|
*
|
||||||
|
* Called when the agent has no more tool calls and no steering messages.
|
||||||
|
* If messages are returned, they're added to the context and the agent
|
||||||
|
* continues with another turn.
|
||||||
|
*
|
||||||
|
* Use this for follow-up messages that should wait until the agent finishes.
|
||||||
|
*/
|
||||||
|
getFollowUpMessages?: () => Promise<AgentMessage[]>;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
||||||
|
|
@ -340,8 +340,8 @@ describe("agentLoop with AgentMessage", () => {
|
||||||
const config: AgentLoopConfig = {
|
const config: AgentLoopConfig = {
|
||||||
model: createModel(),
|
model: createModel(),
|
||||||
convertToLlm: identityConverter,
|
convertToLlm: identityConverter,
|
||||||
getQueuedMessages: async () => {
|
getSteeringMessages: async () => {
|
||||||
// Return queued message after first tool executes
|
// Return steering message after first tool executes
|
||||||
if (executed.length === 1 && !queuedDelivered) {
|
if (executed.length === 1 && !queuedDelivered) {
|
||||||
queuedDelivered = true;
|
queuedDelivered = true;
|
||||||
return [queuedUserMessage];
|
return [queuedUserMessage];
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue