/**
 * Agent class that uses the agent-loop directly.
 * No transport abstraction - calls streamSimple via the loop.
 */

import {
	getModel,
	type ImageContent,
	type Message,
	type Model,
	streamSimple,
	type TextContent,
	type ThinkingBudgets,
	type Transport,
} from "@mariozechner/pi-ai";
import { agentLoop, agentLoopContinue } from "./agent-loop.js";
import type {
	AgentContext,
	AgentEvent,
	AgentLoopConfig,
	AgentMessage,
	AgentState,
	AgentTool,
	StreamFn,
	ThinkingLevel,
} from "./types.js";

/**
 * Default convertToLlm: keep only LLM-compatible messages (user/assistant/toolResult).
 */
function defaultConvertToLlm(messages: AgentMessage[]): Message[] {
	return messages.filter((m) => m.role === "user" || m.role === "assistant" || m.role === "toolResult");
}

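// Illustrative sketches (not part of this module's API): transformContext can
// prune history before conversion, and a custom convertToLlm can replace the
// default filtering entirely.
//
//   const pruneOldMessages = async (messages: AgentMessage[]): Promise<AgentMessage[]> =>
//       messages.slice(-20); // keep only the 20 most recent messages
//
//   const agent = new Agent({ transformContext: pruneOldMessages });
//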
export interface AgentOptions {
	initialState?: Partial<AgentState>;

	/**
	 * Converts AgentMessage[] to LLM-compatible Message[] before each LLM call.
	 * Default keeps only user/assistant/toolResult messages.
	 */
	convertToLlm?: (messages: AgentMessage[]) => Message[] | Promise<Message[]>;

	/**
	 * Optional transform applied to context before convertToLlm.
	 * Use for context pruning, injecting external context, etc.
	 */
	transformContext?: (messages: AgentMessage[], signal?: AbortSignal) => Promise<AgentMessage[]>;

	/**
	 * Steering mode: "all" = send all queued steering messages at once, "one-at-a-time" = one per turn.
	 */
	steeringMode?: "all" | "one-at-a-time";

	/**
	 * Follow-up mode: "all" = send all queued follow-up messages at once, "one-at-a-time" = one per turn.
	 */
	followUpMode?: "all" | "one-at-a-time";

	/**
	 * Custom stream function (for proxy backends, etc.). Defaults to streamSimple.
	 */
	streamFn?: StreamFn;

	/**
	 * Optional session identifier forwarded to LLM providers.
	 * Used by providers that support session-based caching (e.g., OpenAI Codex).
	 */
	sessionId?: string;

	/**
	 * Resolves an API key dynamically for each LLM call.
	 * Useful for expiring tokens (e.g., GitHub Copilot OAuth).
	 */
	getApiKey?: (provider: string) => Promise<string | undefined> | string | undefined;

	/**
	 * Custom token budgets for thinking levels (token-based providers only).
	 */
	thinkingBudgets?: ThinkingBudgets;

	/**
	 * Preferred transport for providers that support multiple transports.
	 */
	transport?: Transport;

	/**
	 * Maximum delay in milliseconds to wait for a retry when the server requests a long wait.
	 * If the server's requested delay exceeds this value, the request fails immediately,
	 * allowing higher-level retry logic to handle it with user visibility.
	 * Default: 60000 (60 seconds). Set to 0 to disable the cap.
	 */
	maxRetryDelayMs?: number;
}

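// Usage sketch (illustrative; assumes provider credentials are available via
// getApiKey, and the environment variable name below is hypothetical):
//
//   const agent = new Agent({
//       steeringMode: "all",
//       getApiKey: (provider) => process.env.PI_API_KEY,
//   });
//   agent.setSystemPrompt("You are a helpful assistant.");
//   await agent.prompt("Hello!");
//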
export class Agent {
	private _state: AgentState = {
		systemPrompt: "",
		model: getModel("google", "gemini-2.5-flash-lite-preview-06-17"),
		thinkingLevel: "off",
		tools: [],
		messages: [],
		isStreaming: false,
		streamMessage: null,
		pendingToolCalls: new Set<string>(),
		error: undefined,
	};

	private listeners = new Set<(e: AgentEvent) => void>();
	private abortController?: AbortController;
	private convertToLlm: (messages: AgentMessage[]) => Message[] | Promise<Message[]>;
	private transformContext?: (messages: AgentMessage[], signal?: AbortSignal) => Promise<AgentMessage[]>;
	private steeringQueue: AgentMessage[] = [];
	private followUpQueue: AgentMessage[] = [];
	private steeringMode: "all" | "one-at-a-time";
	private followUpMode: "all" | "one-at-a-time";
	public streamFn: StreamFn;
	private _sessionId?: string;
	public getApiKey?: (provider: string) => Promise<string | undefined> | string | undefined;
	private runningPrompt?: Promise<void>;
	private resolveRunningPrompt?: () => void;
	private _thinkingBudgets?: ThinkingBudgets;
	private _transport: Transport;
	private _maxRetryDelayMs?: number;

	constructor(opts: AgentOptions = {}) {
		this._state = { ...this._state, ...opts.initialState };
		this.convertToLlm = opts.convertToLlm || defaultConvertToLlm;
		this.transformContext = opts.transformContext;
		this.steeringMode = opts.steeringMode || "one-at-a-time";
		this.followUpMode = opts.followUpMode || "one-at-a-time";
		this.streamFn = opts.streamFn || streamSimple;
		this._sessionId = opts.sessionId;
		this.getApiKey = opts.getApiKey;
		this._thinkingBudgets = opts.thinkingBudgets;
		this._transport = opts.transport ?? "sse";
		this._maxRetryDelayMs = opts.maxRetryDelayMs;
	}

	/**
	 * Get the current session ID used for provider caching.
	 */
	get sessionId(): string | undefined {
		return this._sessionId;
	}

	/**
	 * Set the session ID for provider caching.
	 * Call this when switching sessions (new session, branch, resume).
	 */
	set sessionId(value: string | undefined) {
		this._sessionId = value;
	}

	/**
	 * Get the current thinking budgets.
	 */
	get thinkingBudgets(): ThinkingBudgets | undefined {
		return this._thinkingBudgets;
	}

	/**
	 * Set custom thinking budgets for token-based providers.
	 */
	set thinkingBudgets(value: ThinkingBudgets | undefined) {
		this._thinkingBudgets = value;
	}

	/**
	 * Get the current preferred transport.
	 */
	get transport(): Transport {
		return this._transport;
	}

	/**
	 * Set the preferred transport.
	 */
	setTransport(value: Transport) {
		this._transport = value;
	}

	/**
	 * Get the current max retry delay in milliseconds.
	 */
	get maxRetryDelayMs(): number | undefined {
		return this._maxRetryDelayMs;
	}

	/**
	 * Set the maximum delay to wait for server-requested retries.
	 * Set to 0 to disable the cap.
	 */
	set maxRetryDelayMs(value: number | undefined) {
		this._maxRetryDelayMs = value;
	}

	get state(): AgentState {
		return this._state;
	}

	subscribe(fn: (e: AgentEvent) => void): () => void {
		this.listeners.add(fn);
		return () => this.listeners.delete(fn);
	}

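	// Illustrative sketch: rendering state from agent events (event names are
	// taken from the handling in _runLoop below; render() is hypothetical).
	//
	//   const unsubscribe = agent.subscribe((e) => {
	//       if (e.type === "message_update") render(agent.state.streamMessage);
	//       if (e.type === "agent_end") unsubscribe();
	//   });
	//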
	// State mutators
	setSystemPrompt(v: string) {
		this._state.systemPrompt = v;
	}

	setModel(m: Model<any>) {
		this._state.model = m;
	}

	setThinkingLevel(l: ThinkingLevel) {
		this._state.thinkingLevel = l;
	}

	setSteeringMode(mode: "all" | "one-at-a-time") {
		this.steeringMode = mode;
	}

	getSteeringMode(): "all" | "one-at-a-time" {
		return this.steeringMode;
	}

	setFollowUpMode(mode: "all" | "one-at-a-time") {
		this.followUpMode = mode;
	}

	getFollowUpMode(): "all" | "one-at-a-time" {
		return this.followUpMode;
	}

	setTools(t: AgentTool<any>[]) {
		this._state.tools = t;
	}

	replaceMessages(ms: AgentMessage[]) {
		this._state.messages = ms.slice();
	}

	appendMessage(m: AgentMessage) {
		this._state.messages = [...this._state.messages, m];
	}

	/**
	 * Queue a steering message to interrupt the agent mid-run.
	 * Delivered after the current tool execution completes; remaining tool calls are skipped.
	 */
	steer(m: AgentMessage) {
		this.steeringQueue.push(m);
	}

	/**
	 * Queue a follow-up message to be processed after the agent finishes.
	 * Delivered only when the agent has no more tool calls or steering messages.
	 */
	followUp(m: AgentMessage) {
		this.followUpQueue.push(m);
	}

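	// Illustrative sketch: redirecting a running agent. steer() interrupts the
	// current run; followUp() waits until the run finishes.
	//
	//   void agent.prompt("Refactor the parser");
	//   agent.steer({ role: "user", content: [{ type: "text", text: "Skip the tests" }], timestamp: Date.now() });
	//   agent.followUp({ role: "user", content: [{ type: "text", text: "Now update the docs" }], timestamp: Date.now() });
	//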
	clearSteeringQueue() {
		this.steeringQueue = [];
	}

	clearFollowUpQueue() {
		this.followUpQueue = [];
	}

	clearAllQueues() {
		this.steeringQueue = [];
		this.followUpQueue = [];
	}

	hasQueuedMessages(): boolean {
		return this.steeringQueue.length > 0 || this.followUpQueue.length > 0;
	}

	private dequeueSteeringMessages(): AgentMessage[] {
		if (this.steeringMode === "one-at-a-time") {
			if (this.steeringQueue.length > 0) {
				const first = this.steeringQueue[0];
				this.steeringQueue = this.steeringQueue.slice(1);
				return [first];
			}
			return [];
		}

		const steering = this.steeringQueue.slice();
		this.steeringQueue = [];
		return steering;
	}

	private dequeueFollowUpMessages(): AgentMessage[] {
		if (this.followUpMode === "one-at-a-time") {
			if (this.followUpQueue.length > 0) {
				const first = this.followUpQueue[0];
				this.followUpQueue = this.followUpQueue.slice(1);
				return [first];
			}
			return [];
		}

		const followUp = this.followUpQueue.slice();
		this.followUpQueue = [];
		return followUp;
	}

	clearMessages() {
		this._state.messages = [];
	}

	abort() {
		this.abortController?.abort();
	}

	waitForIdle(): Promise<void> {
		return this.runningPrompt ?? Promise.resolve();
	}

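	// Illustrative sketch: cancelling a run and waiting for cleanup before
	// reusing the agent.
	//
	//   void agent.prompt("Long task...");
	//   agent.abort();
	//   await agent.waitForIdle();
	//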
	reset() {
		this._state.messages = [];
		this._state.isStreaming = false;
		this._state.streamMessage = null;
		this._state.pendingToolCalls = new Set<string>();
		this._state.error = undefined;
		this.steeringQueue = [];
		this.followUpQueue = [];
	}

	/** Send a prompt as a string, a single AgentMessage, or an AgentMessage[]. */
	async prompt(message: AgentMessage | AgentMessage[]): Promise<void>;
	async prompt(input: string, images?: ImageContent[]): Promise<void>;
	async prompt(input: string | AgentMessage | AgentMessage[], images?: ImageContent[]) {
		if (this._state.isStreaming) {
			throw new Error(
				"Agent is already processing a prompt. Use steer() or followUp() to queue messages, or wait for completion.",
			);
		}

		const model = this._state.model;
		if (!model) throw new Error("No model configured");

		let msgs: AgentMessage[];

		if (Array.isArray(input)) {
			msgs = input;
		} else if (typeof input === "string") {
			const content: Array<TextContent | ImageContent> = [{ type: "text", text: input }];
			if (images && images.length > 0) {
				content.push(...images);
			}
			msgs = [
				{
					role: "user",
					content,
					timestamp: Date.now(),
				},
			];
		} else {
			msgs = [input];
		}

		await this._runLoop(msgs);
	}

	/**
	 * Continue from current context (used for retries and resuming queued messages).
	 */
	async continue() {
		if (this._state.isStreaming) {
			throw new Error("Agent is already processing. Wait for completion before continuing.");
		}

		const messages = this._state.messages;
		if (messages.length === 0) {
			throw new Error("No messages to continue from");
		}
		if (messages[messages.length - 1].role === "assistant") {
			const queuedSteering = this.dequeueSteeringMessages();
			if (queuedSteering.length > 0) {
				await this._runLoop(queuedSteering, { skipInitialSteeringPoll: true });
				return;
			}

			const queuedFollowUp = this.dequeueFollowUpMessages();
			if (queuedFollowUp.length > 0) {
				await this._runLoop(queuedFollowUp);
				return;
			}

			throw new Error("Cannot continue from message role: assistant");
		}

		await this._runLoop(undefined);
	}

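	// Illustrative sketch: resuming queued messages after a run ends. If the
	// last message is an assistant reply, continue() delivers queued steering
	// or follow-up messages; with nothing queued it throws.
	//
	//   agent.followUp({ role: "user", content: [{ type: "text", text: "And the summary?" }], timestamp: Date.now() });
	//   if (!agent.state.isStreaming && agent.hasQueuedMessages()) {
	//       await agent.continue();
	//   }
	//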
	/**
	 * Run the agent loop.
	 * If messages are provided, starts a new conversation turn with those messages.
	 * Otherwise, continues from existing context.
	 */
	private async _runLoop(messages?: AgentMessage[], options?: { skipInitialSteeringPoll?: boolean }) {
		const model = this._state.model;
		if (!model) throw new Error("No model configured");

		this.runningPrompt = new Promise<void>((resolve) => {
			this.resolveRunningPrompt = resolve;
		});

		this.abortController = new AbortController();
		this._state.isStreaming = true;
		this._state.streamMessage = null;
		this._state.error = undefined;

		const reasoning = this._state.thinkingLevel === "off" ? undefined : this._state.thinkingLevel;

		const context: AgentContext = {
			systemPrompt: this._state.systemPrompt,
			messages: this._state.messages.slice(),
			tools: this._state.tools,
		};

		let skipInitialSteeringPoll = options?.skipInitialSteeringPoll === true;

		const config: AgentLoopConfig = {
			model,
			reasoning,
			sessionId: this._sessionId,
			transport: this._transport,
			thinkingBudgets: this._thinkingBudgets,
			maxRetryDelayMs: this._maxRetryDelayMs,
			convertToLlm: this.convertToLlm,
			transformContext: this.transformContext,
			getApiKey: this.getApiKey,
			getSteeringMessages: async () => {
				if (skipInitialSteeringPoll) {
					skipInitialSteeringPoll = false;
					return [];
				}
				return this.dequeueSteeringMessages();
			},
			getFollowUpMessages: async () => this.dequeueFollowUpMessages(),
		};

		let partial: AgentMessage | null = null;

		try {
			const stream = messages
				? agentLoop(messages, context, config, this.abortController.signal, this.streamFn)
				: agentLoopContinue(context, config, this.abortController.signal, this.streamFn);

			for await (const event of stream) {
				// Update internal state based on events
				switch (event.type) {
					case "message_start":
						partial = event.message;
						this._state.streamMessage = event.message;
						break;

					case "message_update":
						partial = event.message;
						this._state.streamMessage = event.message;
						break;

					case "message_end":
						partial = null;
						this._state.streamMessage = null;
						this.appendMessage(event.message);
						break;

					case "tool_execution_start": {
						const s = new Set(this._state.pendingToolCalls);
						s.add(event.toolCallId);
						this._state.pendingToolCalls = s;
						break;
					}

					case "tool_execution_end": {
						const s = new Set(this._state.pendingToolCalls);
						s.delete(event.toolCallId);
						this._state.pendingToolCalls = s;
						break;
					}

					case "turn_end":
						if (event.message.role === "assistant" && (event.message as any).errorMessage) {
							this._state.error = (event.message as any).errorMessage;
						}
						break;

					case "agent_end":
						this._state.isStreaming = false;
						this._state.streamMessage = null;
						break;
				}

				// Emit to listeners
				this.emit(event);
			}

			// Handle any remaining partial message (e.g. when the stream ended
			// without a message_end event)
			if (partial && partial.role === "assistant" && partial.content.length > 0) {
				const onlyEmpty = !partial.content.some(
					(c) =>
						(c.type === "thinking" && c.thinking.trim().length > 0) ||
						(c.type === "text" && c.text.trim().length > 0) ||
						(c.type === "toolCall" && c.name.trim().length > 0),
				);
				if (!onlyEmpty) {
					this.appendMessage(partial);
				} else if (this.abortController?.signal.aborted) {
					throw new Error("Request was aborted");
				}
			}
		} catch (err: any) {
			// Synthesize an assistant message carrying the error so it appears in history
			const errorMsg: AgentMessage = {
				role: "assistant",
				content: [{ type: "text", text: "" }],
				api: model.api,
				provider: model.provider,
				model: model.id,
				usage: {
					input: 0,
					output: 0,
					cacheRead: 0,
					cacheWrite: 0,
					totalTokens: 0,
					cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
				},
				stopReason: this.abortController?.signal.aborted ? "aborted" : "error",
				errorMessage: err?.message || String(err),
				timestamp: Date.now(),
			} as AgentMessage;

			this.appendMessage(errorMsg);
			this._state.error = err?.message || String(err);
			this.emit({ type: "agent_end", messages: [errorMsg] });
		} finally {
			this._state.isStreaming = false;
			this._state.streamMessage = null;
			this._state.pendingToolCalls = new Set<string>();
			this.abortController = undefined;
			this.resolveRunningPrompt?.();
			this.runningPrompt = undefined;
			this.resolveRunningPrompt = undefined;
		}
	}

	private emit(e: AgentEvent) {
		for (const listener of this.listeners) {
			listener(e);
		}
	}
}