mirror of
https://github.com/getcompanion-ai/co-mono.git
synced 2026-04-16 13:04:08 +00:00
WIP: Refactor agent package - not compiling
- Renamed AppMessage to AgentMessage throughout - New agent-loop.ts with AgentLoopContext, AgentLoopConfig - Removed transport abstraction, Agent now takes streamFn directly - Extracted streamProxy to proxy.ts utility - Removed agent-loop from pi-ai (now in agent package) - Updated consumers (coding-agent, mom) for AgentMessage rename - Tests updated but some consumers still need migration Known issues: - AgentTool, AgentToolResult not exported from pi-ai - Attachment not exported from pi-agent-core - ProviderTransport removed but still referenced - messageTransformer -> convertToLlm migration incomplete - CustomMessages declaration merging not working properly
This commit is contained in:
parent
f7ef44dc38
commit
a055fd4481
32 changed files with 1312 additions and 2009 deletions
398
packages/agent/src/agent-loop.ts
Normal file
398
packages/agent/src/agent-loop.ts
Normal file
|
|
@ -0,0 +1,398 @@
|
|||
/**
|
||||
* Agent loop that works with AgentMessage throughout.
|
||||
* Transforms to Message[] only at the LLM call boundary.
|
||||
*/
|
||||
|
||||
import {
|
||||
type AssistantMessage,
|
||||
type Context,
|
||||
EventStream,
|
||||
streamSimple,
|
||||
type ToolResultMessage,
|
||||
validateToolArguments,
|
||||
} from "@mariozechner/pi-ai";
|
||||
import type {
|
||||
AgentContext,
|
||||
AgentEvent,
|
||||
AgentLoopConfig,
|
||||
AgentMessage,
|
||||
AgentTool,
|
||||
AgentToolResult,
|
||||
StreamFn,
|
||||
} from "./types.js";
|
||||
|
||||
/**
|
||||
* Start an agent loop with a new prompt message.
|
||||
* The prompt is added to the context and events are emitted for it.
|
||||
*/
|
||||
export function agentLoop(
|
||||
prompt: AgentMessage,
|
||||
context: AgentContext,
|
||||
config: AgentLoopConfig,
|
||||
signal?: AbortSignal,
|
||||
streamFn?: StreamFn,
|
||||
): EventStream<AgentEvent, AgentMessage[]> {
|
||||
const stream = createAgentStream();
|
||||
|
||||
(async () => {
|
||||
const newMessages: AgentMessage[] = [prompt];
|
||||
const currentContext: AgentContext = {
|
||||
...context,
|
||||
messages: [...context.messages, prompt],
|
||||
};
|
||||
|
||||
stream.push({ type: "agent_start" });
|
||||
stream.push({ type: "turn_start" });
|
||||
stream.push({ type: "message_start", message: prompt });
|
||||
stream.push({ type: "message_end", message: prompt });
|
||||
|
||||
await runLoop(currentContext, newMessages, config, signal, stream, streamFn);
|
||||
})();
|
||||
|
||||
return stream;
|
||||
}
|
||||
|
||||
/**
|
||||
* Continue an agent loop from the current context without adding a new message.
|
||||
* Used for retries - context already has user message or tool results.
|
||||
*
|
||||
* **Important:** The last message in context must convert to a `user` or `toolResult` message
|
||||
* via `convertToLlm`. If it doesn't, the LLM provider will reject the request.
|
||||
* This cannot be validated here since `convertToLlm` is only called once per turn.
|
||||
*/
|
||||
export function agentLoopContinue(
|
||||
context: AgentContext,
|
||||
config: AgentLoopConfig,
|
||||
signal?: AbortSignal,
|
||||
streamFn?: StreamFn,
|
||||
): EventStream<AgentEvent, AgentMessage[]> {
|
||||
if (context.messages.length === 0) {
|
||||
throw new Error("Cannot continue: no messages in context");
|
||||
}
|
||||
|
||||
if (context.messages[context.messages.length - 1].role === "assistant") {
|
||||
throw new Error("Cannot continue from message role: assistant");
|
||||
}
|
||||
|
||||
const stream = createAgentStream();
|
||||
|
||||
(async () => {
|
||||
const newMessages: AgentMessage[] = [];
|
||||
const currentContext: AgentContext = { ...context };
|
||||
|
||||
stream.push({ type: "agent_start" });
|
||||
stream.push({ type: "turn_start" });
|
||||
|
||||
await runLoop(currentContext, newMessages, config, signal, stream, streamFn);
|
||||
})();
|
||||
|
||||
return stream;
|
||||
}
|
||||
|
||||
function createAgentStream(): EventStream<AgentEvent, AgentMessage[]> {
|
||||
return new EventStream<AgentEvent, AgentMessage[]>(
|
||||
(event: AgentEvent) => event.type === "agent_end",
|
||||
(event: AgentEvent) => (event.type === "agent_end" ? event.messages : []),
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Main loop logic shared by agentLoop and agentLoopContinue.
|
||||
*/
|
||||
async function runLoop(
|
||||
currentContext: AgentContext,
|
||||
newMessages: AgentMessage[],
|
||||
config: AgentLoopConfig,
|
||||
signal: AbortSignal | undefined,
|
||||
stream: EventStream<AgentEvent, AgentMessage[]>,
|
||||
streamFn?: StreamFn,
|
||||
): Promise<void> {
|
||||
let hasMoreToolCalls = true;
|
||||
let firstTurn = true;
|
||||
let queuedMessages: AgentMessage[] = (await config.getQueuedMessages?.()) || [];
|
||||
let queuedAfterTools: AgentMessage[] | null = null;
|
||||
|
||||
while (hasMoreToolCalls || queuedMessages.length > 0) {
|
||||
if (!firstTurn) {
|
||||
stream.push({ type: "turn_start" });
|
||||
} else {
|
||||
firstTurn = false;
|
||||
}
|
||||
|
||||
// Process queued messages (inject before next assistant response)
|
||||
if (queuedMessages.length > 0) {
|
||||
for (const message of queuedMessages) {
|
||||
stream.push({ type: "message_start", message });
|
||||
stream.push({ type: "message_end", message });
|
||||
currentContext.messages.push(message);
|
||||
newMessages.push(message);
|
||||
}
|
||||
queuedMessages = [];
|
||||
}
|
||||
|
||||
// Stream assistant response
|
||||
const message = await streamAssistantResponse(currentContext, config, signal, stream, streamFn);
|
||||
newMessages.push(message);
|
||||
|
||||
if (message.stopReason === "error" || message.stopReason === "aborted") {
|
||||
stream.push({ type: "turn_end", message, toolResults: [] });
|
||||
stream.push({ type: "agent_end", messages: newMessages });
|
||||
stream.end(newMessages);
|
||||
return;
|
||||
}
|
||||
|
||||
// Check for tool calls
|
||||
const toolCalls = message.content.filter((c) => c.type === "toolCall");
|
||||
hasMoreToolCalls = toolCalls.length > 0;
|
||||
|
||||
const toolResults: ToolResultMessage[] = [];
|
||||
if (hasMoreToolCalls) {
|
||||
const toolExecution = await executeToolCalls(
|
||||
currentContext.tools,
|
||||
message,
|
||||
signal,
|
||||
stream,
|
||||
config.getQueuedMessages,
|
||||
);
|
||||
toolResults.push(...toolExecution.toolResults);
|
||||
queuedAfterTools = toolExecution.queuedMessages ?? null;
|
||||
|
||||
for (const result of toolResults) {
|
||||
currentContext.messages.push(result);
|
||||
newMessages.push(result);
|
||||
}
|
||||
}
|
||||
|
||||
stream.push({ type: "turn_end", message, toolResults });
|
||||
|
||||
// Get queued messages after turn completes
|
||||
if (queuedAfterTools && queuedAfterTools.length > 0) {
|
||||
queuedMessages = queuedAfterTools;
|
||||
queuedAfterTools = null;
|
||||
} else {
|
||||
queuedMessages = (await config.getQueuedMessages?.()) || [];
|
||||
}
|
||||
}
|
||||
|
||||
stream.push({ type: "agent_end", messages: newMessages });
|
||||
stream.end(newMessages);
|
||||
}
|
||||
|
||||
/**
|
||||
* Stream an assistant response from the LLM.
|
||||
* This is where AgentMessage[] gets transformed to Message[] for the LLM.
|
||||
*/
|
||||
async function streamAssistantResponse(
|
||||
context: AgentContext,
|
||||
config: AgentLoopConfig,
|
||||
signal: AbortSignal | undefined,
|
||||
stream: EventStream<AgentEvent, AgentMessage[]>,
|
||||
streamFn?: StreamFn,
|
||||
): Promise<AssistantMessage> {
|
||||
// Apply context transform if configured (AgentMessage[] → AgentMessage[])
|
||||
let messages = context.messages;
|
||||
if (config.transformContext) {
|
||||
messages = await config.transformContext(messages, signal);
|
||||
}
|
||||
|
||||
// Convert to LLM-compatible messages (AgentMessage[] → Message[])
|
||||
const llmMessages = await config.convertToLlm(messages);
|
||||
|
||||
// Build LLM context
|
||||
const llmContext: Context = {
|
||||
systemPrompt: context.systemPrompt,
|
||||
messages: llmMessages,
|
||||
tools: context.tools,
|
||||
};
|
||||
|
||||
const streamFunction = streamFn || streamSimple;
|
||||
|
||||
// Resolve API key (important for expiring tokens)
|
||||
const resolvedApiKey =
|
||||
(config.getApiKey ? await config.getApiKey(config.model.provider) : undefined) || config.apiKey;
|
||||
|
||||
const response = streamFunction(config.model, llmContext, {
|
||||
...config,
|
||||
apiKey: resolvedApiKey,
|
||||
signal,
|
||||
});
|
||||
|
||||
let partialMessage: AssistantMessage | null = null;
|
||||
let addedPartial = false;
|
||||
|
||||
for await (const event of response) {
|
||||
switch (event.type) {
|
||||
case "start":
|
||||
partialMessage = event.partial;
|
||||
context.messages.push(partialMessage);
|
||||
addedPartial = true;
|
||||
stream.push({ type: "message_start", message: { ...partialMessage } });
|
||||
break;
|
||||
|
||||
case "text_start":
|
||||
case "text_delta":
|
||||
case "text_end":
|
||||
case "thinking_start":
|
||||
case "thinking_delta":
|
||||
case "thinking_end":
|
||||
case "toolcall_start":
|
||||
case "toolcall_delta":
|
||||
case "toolcall_end":
|
||||
if (partialMessage) {
|
||||
partialMessage = event.partial;
|
||||
context.messages[context.messages.length - 1] = partialMessage;
|
||||
stream.push({
|
||||
type: "message_update",
|
||||
assistantMessageEvent: event,
|
||||
message: { ...partialMessage },
|
||||
});
|
||||
}
|
||||
break;
|
||||
|
||||
case "done":
|
||||
case "error": {
|
||||
const finalMessage = await response.result();
|
||||
if (addedPartial) {
|
||||
context.messages[context.messages.length - 1] = finalMessage;
|
||||
} else {
|
||||
context.messages.push(finalMessage);
|
||||
}
|
||||
if (!addedPartial) {
|
||||
stream.push({ type: "message_start", message: { ...finalMessage } });
|
||||
}
|
||||
stream.push({ type: "message_end", message: finalMessage });
|
||||
return finalMessage;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return await response.result();
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute tool calls from an assistant message.
|
||||
*/
|
||||
async function executeToolCalls(
|
||||
tools: AgentTool<any>[] | undefined,
|
||||
assistantMessage: AssistantMessage,
|
||||
signal: AbortSignal | undefined,
|
||||
stream: EventStream<AgentEvent, AgentMessage[]>,
|
||||
getQueuedMessages?: AgentLoopConfig["getQueuedMessages"],
|
||||
): Promise<{ toolResults: ToolResultMessage[]; queuedMessages?: AgentMessage[] }> {
|
||||
const toolCalls = assistantMessage.content.filter((c) => c.type === "toolCall");
|
||||
const results: ToolResultMessage[] = [];
|
||||
let queuedMessages: AgentMessage[] | undefined;
|
||||
|
||||
for (let index = 0; index < toolCalls.length; index++) {
|
||||
const toolCall = toolCalls[index];
|
||||
const tool = tools?.find((t) => t.name === toolCall.name);
|
||||
|
||||
stream.push({
|
||||
type: "tool_execution_start",
|
||||
toolCallId: toolCall.id,
|
||||
toolName: toolCall.name,
|
||||
args: toolCall.arguments,
|
||||
});
|
||||
|
||||
let result: AgentToolResult<any>;
|
||||
let isError = false;
|
||||
|
||||
try {
|
||||
if (!tool) throw new Error(`Tool ${toolCall.name} not found`);
|
||||
|
||||
const validatedArgs = validateToolArguments(tool, toolCall);
|
||||
|
||||
result = await tool.execute(toolCall.id, validatedArgs, signal, (partialResult) => {
|
||||
stream.push({
|
||||
type: "tool_execution_update",
|
||||
toolCallId: toolCall.id,
|
||||
toolName: toolCall.name,
|
||||
args: toolCall.arguments,
|
||||
partialResult,
|
||||
});
|
||||
});
|
||||
} catch (e) {
|
||||
result = {
|
||||
content: [{ type: "text", text: e instanceof Error ? e.message : String(e) }],
|
||||
details: {},
|
||||
};
|
||||
isError = true;
|
||||
}
|
||||
|
||||
stream.push({
|
||||
type: "tool_execution_end",
|
||||
toolCallId: toolCall.id,
|
||||
toolName: toolCall.name,
|
||||
result,
|
||||
isError,
|
||||
});
|
||||
|
||||
const toolResultMessage: ToolResultMessage = {
|
||||
role: "toolResult",
|
||||
toolCallId: toolCall.id,
|
||||
toolName: toolCall.name,
|
||||
content: result.content,
|
||||
details: result.details,
|
||||
isError,
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
results.push(toolResultMessage);
|
||||
stream.push({ type: "message_start", message: toolResultMessage });
|
||||
stream.push({ type: "message_end", message: toolResultMessage });
|
||||
|
||||
// Check for queued messages - skip remaining tools if user interrupted
|
||||
if (getQueuedMessages) {
|
||||
const queued = await getQueuedMessages();
|
||||
if (queued.length > 0) {
|
||||
queuedMessages = queued;
|
||||
const remainingCalls = toolCalls.slice(index + 1);
|
||||
for (const skipped of remainingCalls) {
|
||||
results.push(skipToolCall(skipped, stream));
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return { toolResults: results, queuedMessages };
|
||||
}
|
||||
|
||||
function skipToolCall(
|
||||
toolCall: Extract<AssistantMessage["content"][number], { type: "toolCall" }>,
|
||||
stream: EventStream<AgentEvent, AgentMessage[]>,
|
||||
): ToolResultMessage {
|
||||
const result: AgentToolResult<any> = {
|
||||
content: [{ type: "text", text: "Skipped due to queued user message." }],
|
||||
details: {},
|
||||
};
|
||||
|
||||
stream.push({
|
||||
type: "tool_execution_start",
|
||||
toolCallId: toolCall.id,
|
||||
toolName: toolCall.name,
|
||||
args: toolCall.arguments,
|
||||
});
|
||||
stream.push({
|
||||
type: "tool_execution_end",
|
||||
toolCallId: toolCall.id,
|
||||
toolName: toolCall.name,
|
||||
result,
|
||||
isError: true,
|
||||
});
|
||||
|
||||
const toolResultMessage: ToolResultMessage = {
|
||||
role: "toolResult",
|
||||
toolCallId: toolCall.id,
|
||||
toolName: toolCall.name,
|
||||
content: result.content,
|
||||
details: {},
|
||||
isError: true,
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
stream.push({ type: "message_start", message: toolResultMessage });
|
||||
stream.push({ type: "message_end", message: toolResultMessage });
|
||||
|
||||
return toolResultMessage;
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue