// packages/coding-agent/src/core/gateway/durable-chat-run.ts
import type { AgentMessage } from "@mariozechner/companion-agent-core";
import type { AgentSessionEvent } from "../agent-session.js";
import { extractMessageText } from "./helpers.js";
import { messageContentToHistoryParts } from "./session-state.js";
import type { GatewayTransientToolResult } from "./session-state.js";
import type { GatewayMessageResult, GatewayMessageRequest } from "./types.js";

/**
 * Debounce window in milliseconds. Session events that arrive while a flush
 * timer is already pending are coalesced into the single POST issued when the
 * timer fires.
 */
const FLUSH_INTERVAL_MS = 500;

/** Recognizes an abort in a run's error text, regardless of letter case. */
const ABORTED_PATTERN = /aborted/i;

/** One history entry as sent to the callback endpoint. */
type PersistHistoryItem = {
  role: "user" | "assistant" | "toolResult";
  text?: string;
  /** JSON-serialized array of history parts. */
  partsJson: string;
  timestamp: number;
  /** Stable key; the reporter re-sends the same key with a fresher snapshot. */
  idempotencyKey: string;
};

/** Terminal status reported in the `final` event. */
type ConvexRunStatus = "completed" | "failed" | "interrupted";

/** Reduces any thrown value to a printable message. */
function normalizeErrorMessage(error: unknown): string {
  if (error instanceof Error) {
    return error.message;
  }
  return typeof error === "string" ? error : String(error);
}

/**
 * Body of a callback POST: history items, a terminal `final` record, or both.
 * The union guarantees that at least one of the two is present.
 */
type DurableChatRunEventBody =
  | {
      items: PersistHistoryItem[];
      final?: {
        status: ConvexRunStatus;
        error?: string;
      };
    }
  | {
      items?: PersistHistoryItem[];
      final: {
        status: ConvexRunStatus;
        error?: string;
      };
    };

/** Headers for an authenticated JSON POST to the callback endpoint. */
function buildAuthHeaders(token: string): Record<string, string> {
  return {
    "Content-Type": "application/json",
    Authorization: `Bearer ${token}`,
  };
}

/**
 * Mirrors the assistant output of one gateway run to a remote callback URL.
 *
 * While the run is in progress, session events update an in-memory snapshot
 * (latest assistant message plus every tool result seen so far). The snapshot
 * is POSTed, debounced by {@link FLUSH_INTERVAL_MS}, as a single assistant
 * item whose idempotency key is `run:<runId>:assistant`. `finalize` sends the
 * last snapshot followed by a `final` status record.
 *
 * The first failed POST is remembered; later flushes rethrow it instead of
 * contacting the endpoint again, and `finalize` reports the run as `failed`.
 *
 * In this change set the reporter is created by `GatewayRuntime` for queued
 * messages that carry `durableRun`, is held on
 * `ManagedGatewaySession.activeDurableRun` for the duration of the run, and
 * is finalized before the queued message is resolved.
 *
 * NOTE(review): only the most recent assistant message is kept, so in a run
 * with several assistant messages the earlier ones are replaced in the
 * snapshot — confirm the receiver expects that.
 */
export class DurableChatRunReporter {
  private readonly assistantMessageId: string;
  private latestAssistantMessage: AgentMessage | null = null;
  /** Tool results keyed by tool call id; insertion order is preserved. */
  private readonly knownToolResults = new Map<
    string,
    GatewayTransientToolResult
  >();
  private flushTimer: ReturnType<typeof setTimeout> | null = null;
  /** Tail of the POST queue; never rejects, so it is always safe to await. */
  private flushChain: Promise<void> = Promise.resolve();
  /** First flush error observed; once set, the reporter stops posting items. */
  private flushFailure: Error | null = null;

  /**
   * @param durableRun Run id plus the callback URL and bearer token to report to.
   * @throws Error when the callback URL or token is blank.
   */
  constructor(
    private readonly durableRun: NonNullable<
      GatewayMessageRequest["durableRun"]
    >,
  ) {
    if (
      durableRun.callbackUrl.trim().length === 0 ||
      durableRun.callbackToken.trim().length === 0
    ) {
      throw new Error(
        "Durable chat run reporting requires callbackUrl and callbackToken",
      );
    }
    this.assistantMessageId = `run:${this.durableRun.runId}:assistant`;
  }

  /**
   * Folds one session event into the snapshot and schedules a flush when the
   * event changes what should be persisted.
   *
   * @param event Session event emitted during the run.
   * @param pendingToolResults Tool results currently held by the caller; all
   *   of them are recorded (by `toolCallId`) regardless of the event type.
   */
  handleSessionEvent(
    event: AgentSessionEvent,
    pendingToolResults: GatewayTransientToolResult[],
  ): void {
    for (const toolResult of pendingToolResults) {
      this.knownToolResults.set(toolResult.toolCallId, toolResult);
    }

    // A freshly started assistant message is remembered but not flushed yet.
    if (event.type === "message_start" && event.message.role === "assistant") {
      this.latestAssistantMessage = event.message;
      return;
    }

    if (
      (event.type === "message_update" || event.type === "message_end") &&
      event.message.role === "assistant"
    ) {
      this.latestAssistantMessage = event.message;
      this.scheduleFlush();
      return;
    }

    if (
      event.type === "tool_execution_end" ||
      event.type === "turn_end" ||
      (event.type === "message_end" && event.message.role === "toolResult")
    ) {
      this.scheduleFlush();
    }
  }

  /**
   * Sends the last snapshot and then the terminal status of the run.
   *
   * Status is `completed` for a successful result, `interrupted` when the
   * error text mentions "aborted" (any letter case), and `failed` otherwise.
   * If the snapshot itself could not be delivered the run is reported as
   * `failed` with the delivery error. The error text is only included for
   * `failed` runs.
   *
   * @param result Outcome of the gateway message that drove this run.
   * @throws Error when the POST carrying the `final` record fails.
   */
  async finalize(result: GatewayMessageResult): Promise<void> {
    let status: ConvexRunStatus = result.ok
      ? "completed"
      : ABORTED_PATTERN.test(result.error ?? "")
        ? "interrupted"
        : "failed";
    let errorMessage = result.error;

    try {
      await this.finalFlush();
    } catch (error) {
      status = "failed";
      errorMessage = normalizeErrorMessage(error);
    }

    await this.postEvent({
      final: {
        status,
        ...(status === "failed" && errorMessage ? { error: errorMessage } : {}),
      },
    });
  }

  /** Arms the debounce timer unless one is pending or the reporter has failed. */
  private scheduleFlush(): void {
    // After a failure every flush rethrows immediately, so a timer would only
    // keep the event loop busy for nothing.
    if (this.flushTimer || this.flushFailure) return;
    this.flushTimer = globalThis.setTimeout(() => {
      this.flushTimer = null;
      // The error is not lost: flush() records it in flushFailure and
      // finalize() surfaces it.
      void this.flush().catch(() => undefined);
    }, FLUSH_INTERVAL_MS);
  }

  /** Cancels the pending debounce timer, if any. */
  private cancelFlushTimer(): void {
    if (this.flushTimer) {
      globalThis.clearTimeout(this.flushTimer);
      this.flushTimer = null;
    }
  }

  /**
   * Posts the current snapshot after all previously queued posts.
   *
   * @throws Error the recorded flush failure, or the error of this POST
   *   (which then becomes the recorded failure).
   */
  private async flush(): Promise<void> {
    // Cancel first so that a reporter that has already failed does not leave
    // a live timer behind when the check below throws.
    this.cancelFlushTimer();
    this.throwIfFlushFailed();

    // The snapshot is taken now, not when the queued POST eventually runs.
    const items = this.buildItems();
    if (items.length === 0) {
      return;
    }

    const flushPromise = this.flushChain.then(async () => {
      // An earlier queued POST may have failed while this one was waiting.
      this.throwIfFlushFailed();
      await this.postEvent({
        items,
      });
    });
    // Keep the chain itself non-rejecting; the failure is reported below.
    this.flushChain = flushPromise.catch(() => undefined);

    try {
      await flushPromise;
    } catch (error) {
      throw this.markFlushFailed(error);
    }
  }

  /** Flushes once more and waits until every queued POST has settled. */
  private async finalFlush(): Promise<void> {
    await this.flush();
    await this.flushChain;
    this.throwIfFlushFailed();
  }

  /**
   * Builds the single assistant item describing the current snapshot.
   *
   * @returns One item when there is an assistant message or at least one tool
   *   result, otherwise an empty array.
   */
  private buildItems(): PersistHistoryItem[] {
    const assistant =
      this.latestAssistantMessage?.role === "assistant"
        ? this.latestAssistantMessage
        : null;

    // Copy before appending so the array handed back by the helper is never
    // mutated here.
    const assistantParts = assistant
      ? [...messageContentToHistoryParts(assistant)]
      : [];

    for (const toolResult of this.knownToolResults.values()) {
      assistantParts.push({
        type: "tool-invocation",
        toolCallId: toolResult.toolCallId,
        toolName: toolResult.toolName,
        args: undefined,
        state: toolResult.isError ? "error" : "result",
        result: toolResult.result,
      });
    }

    if (!assistant && assistantParts.length === 0) {
      return [];
    }

    const firstToolResult = this.knownToolResults.values().next().value;

    return [
      {
        role: "assistant",
        // An empty text is omitted rather than sent as "".
        text: assistant ? extractMessageText(assistant) || undefined : undefined,
        partsJson: JSON.stringify(assistantParts),
        timestamp:
          assistant?.timestamp ?? firstToolResult?.timestamp ?? Date.now(),
        idempotencyKey: this.assistantMessageId,
      },
    ];
  }

  /**
   * POSTs one event body to the callback URL.
   *
   * @throws Error when the response status is not in the 2xx range; the
   *   message carries the status code and, when readable, the response text.
   */
  private async postEvent(body: DurableChatRunEventBody): Promise<void> {
    const response = await fetch(this.durableRun.callbackUrl, {
      method: "POST",
      headers: buildAuthHeaders(this.durableRun.callbackToken),
      body: JSON.stringify(body),
    });
    if (!response.ok) {
      const text = await response.text().catch(() => "");
      throw new Error(
        `Chat run relay failed: ${response.status} ${text}`.trim(),
      );
    }
  }

  /** Rethrows the recorded flush failure, if there is one. */
  private throwIfFlushFailed(): void {
    if (this.flushFailure) {
      throw this.flushFailure;
    }
  }

  /**
   * Records the first flush failure and returns the error callers should
   * throw; failures after the first one return the originally recorded error.
   */
  private markFlushFailed(error: unknown): Error {
    if (this.flushFailure) {
      return this.flushFailure;
    }
    const normalizedError =
      error instanceof Error ? error : new Error(normalizeErrorMessage(error));
    this.flushFailure = normalizedError;
    return normalizedError;
  }
}
readString(value: unknown): string | undefined { return trimmed.length > 0 ? trimmed : undefined; } +function readDurableRun( + value: unknown, +): GatewayMessageRequest["durableRun"] | undefined { + if (!isRecord(value)) { + return undefined; + } + + const runId = readString(value.runId); + const callbackUrl = readString(value.callbackUrl); + const callbackToken = readString(value.callbackToken); + + if (!runId || !callbackUrl || !callbackToken) { + return undefined; + } + + return { + runId, + callbackUrl, + callbackToken, + }; +} + export function setActiveGatewayRuntime(runtime: GatewayRuntime | null): void { activeGatewayRuntime = runtime; } @@ -419,6 +442,7 @@ export class GatewayRuntime { session, queue: [], processing: false, + activeDurableRun: null, activeAssistantMessage: null, pendingToolResults: [], createdAt: Date.now(), @@ -462,18 +486,32 @@ export class GatewayRuntime { ); this.emitState(managedSession); + let result: GatewayMessageResult = { + ok: false, + response: "", + error: "Unknown error", + sessionKey: managedSession.sessionKey, + }; + let durableRunReporter: DurableChatRunReporter | null = null; + try { queued.onStart?.(); + if (queued.request.durableRun) { + durableRunReporter = new DurableChatRunReporter( + queued.request.durableRun, + ); + managedSession.activeDurableRun = durableRunReporter; + } await managedSession.session.prompt(queued.request.text, { images: queued.request.images, source: queued.request.source ?? "extension", }); const response = getLastAssistantText(managedSession.session); - queued.resolve({ + result = { ok: true, response, sessionKey: managedSession.sessionKey, - }); + }; } catch (error) { const message = error instanceof Error ? 
error.message : String(error); this.log( @@ -491,15 +529,39 @@ export class GatewayRuntime { error: message, }); } - queued.resolve({ + result = { ok: false, response: "", error: message, sessionKey: managedSession.sessionKey, - }); + }; } finally { queued.onFinish?.(); + if (durableRunReporter) { + try { + await durableRunReporter.finalize(result); + } catch (error) { + const message = + error instanceof Error ? error.message : String(error); + this.log( + `[chat-run] session=${managedSession.sessionKey} finalize error: ${message}`, + ); + this.emit(managedSession, { + type: "error", + sessionKey: managedSession.sessionKey, + error: message, + }); + result = { + ok: false, + response: result.response, + error: message, + sessionKey: managedSession.sessionKey, + }; + } + } + queued.resolve(result); managedSession.processing = false; + managedSession.activeDurableRun = null; managedSession.activeAssistantMessage = null; managedSession.pendingToolResults = []; managedSession.lastActiveAt = Date.now(); @@ -529,6 +591,13 @@ export class GatewayRuntime { managedSession: ManagedGatewaySession, event: AgentSessionEvent, ): void { + const forwardToDurableRun = () => { + managedSession.activeDurableRun?.handleSessionEvent( + event, + managedSession.pendingToolResults, + ); + }; + switch (event.type) { case "turn_start": managedSession.lastActiveAt = Date.now(); @@ -537,6 +606,7 @@ export class GatewayRuntime { type: "turn_start", sessionKey: managedSession.sessionKey, }); + forwardToDurableRun(); return; case "turn_end": managedSession.lastActiveAt = Date.now(); @@ -545,6 +615,7 @@ export class GatewayRuntime { type: "turn_end", sessionKey: managedSession.sessionKey, }); + forwardToDurableRun(); return; case "message_start": managedSession.lastActiveAt = Date.now(); @@ -556,6 +627,7 @@ export class GatewayRuntime { sessionKey: managedSession.sessionKey, role: event.message.role, }); + forwardToDurableRun(); return; case "message_update": managedSession.lastActiveAt = 
Date.now(); @@ -570,6 +642,7 @@ export class GatewayRuntime { delta: event.assistantMessageEvent.delta, contentIndex: event.assistantMessageEvent.contentIndex, }); + forwardToDurableRun(); return; case "thinking_delta": this.emit(managedSession, { @@ -578,8 +651,10 @@ export class GatewayRuntime { delta: event.assistantMessageEvent.delta, contentIndex: event.assistantMessageEvent.contentIndex, }); + forwardToDurableRun(); return; } + forwardToDurableRun(); return; case "message_end": managedSession.lastActiveAt = Date.now(); @@ -595,6 +670,7 @@ export class GatewayRuntime { text: extractMessageText(event.message), }); this.emitStructuredParts(managedSession, event.message); + forwardToDurableRun(); return; } if (event.message.role === "toolResult") { @@ -610,6 +686,7 @@ export class GatewayRuntime { ); } } + forwardToDurableRun(); return; case "tool_execution_start": managedSession.lastActiveAt = Date.now(); @@ -624,6 +701,7 @@ export class GatewayRuntime { toolName: event.toolName, args: event.args, }); + forwardToDurableRun(); return; case "tool_execution_update": managedSession.lastActiveAt = Date.now(); @@ -634,6 +712,7 @@ export class GatewayRuntime { toolName: event.toolName, partialResult: event.partialResult, }); + forwardToDurableRun(); return; case "tool_execution_end": managedSession.lastActiveAt = Date.now(); @@ -661,6 +740,7 @@ export class GatewayRuntime { result: event.result, isError: event.isError, }); + forwardToDurableRun(); return; } } @@ -1030,7 +1110,7 @@ export class GatewayRuntime { } const sessionMatch = path.match( - /^\/sessions\/([^/]+)(?:\/(events|messages|abort|reset|chat|history|model|reload|state|steer))?$/, + /^\/sessions\/([^/]+)(?:\/(enqueue|events|messages|abort|reset|chat|history|model|reload|state|steer))?$/, ); if (!sessionMatch) { this.writeJson(response, 404, { error: "Not found" }); @@ -1069,6 +1149,37 @@ export class GatewayRuntime { return; } + if (action === "enqueue" && method === "POST") { + const body = await 
this.readJsonBody(request); + const text = extractUserText(body); + if (!text) { + this.writeJson(response, 400, { error: "Missing user message text" }); + return; + } + const durableRun = readDurableRun(body.durableRun); + const queued = await this.queueManagedMessage({ + request: { + sessionKey, + text, + source: "extension", + metadata: isRecord(body.metadata) ? body.metadata : undefined, + durableRun, + }, + }); + if (!queued.accepted) { + this.writeJson(response, 409, queued.errorResult); + return; + } + void queued.completion.catch(() => undefined); + this.writeJson(response, 202, { + ok: true, + queued: true, + sessionKey, + ...(durableRun ? { runId: durableRun.runId } : {}), + }); + return; + } + if (action === "messages" && method === "POST") { const body = await this.readJsonBody(request); const text = typeof body.text === "string" ? body.text : ""; diff --git a/packages/coding-agent/src/core/gateway/types.ts b/packages/coding-agent/src/core/gateway/types.ts index 9aed60d..2311a43 100644 --- a/packages/coding-agent/src/core/gateway/types.ts +++ b/packages/coding-agent/src/core/gateway/types.ts @@ -26,6 +26,11 @@ export interface GatewayMessageRequest { source?: "interactive" | "rpc" | "extension"; images?: ImageContent[]; metadata?: Record; + durableRun?: { + runId: string; + callbackUrl: string; + callbackToken: string; + }; } export interface GatewayMessageResult { diff --git a/packages/coding-agent/test/durable-chat-run.test.ts b/packages/coding-agent/test/durable-chat-run.test.ts new file mode 100644 index 0000000..6a774da --- /dev/null +++ b/packages/coding-agent/test/durable-chat-run.test.ts @@ -0,0 +1,157 @@ +import { afterEach, describe, expect, it, vi } from "vitest"; +import { DurableChatRunReporter } from "../src/core/gateway/durable-chat-run.js"; + +function mockOkResponse() { + return { + ok: true, + status: 200, + text: vi.fn().mockResolvedValue(""), + } as unknown as Response; +} + +describe("DurableChatRunReporter", () => { + afterEach(() 
=> { + vi.restoreAllMocks(); + }); + + it("posts assistant state to the relay and completes the run", async () => { + const fetchMock = vi.fn().mockResolvedValue(mockOkResponse()); + vi.stubGlobal("fetch", fetchMock); + + const reporter = new DurableChatRunReporter({ + runId: "run-1", + callbackUrl: "https://web.example/api/chat/runs/run-1/events", + callbackToken: "callback-token", + }); + + const assistantMessage = { + role: "assistant", + timestamp: 111, + content: [{ type: "text", text: "hello from the sandbox" }], + }; + + reporter.handleSessionEvent( + { + type: "message_start", + message: assistantMessage, + } as never, + [], + ); + reporter.handleSessionEvent( + { + type: "message_update", + message: assistantMessage, + assistantMessageEvent: { + type: "text_delta", + delta: "hello from the sandbox", + contentIndex: 0, + }, + } as never, + [], + ); + reporter.handleSessionEvent( + { + type: "tool_execution_end", + toolCallId: "call-1", + toolName: "bash", + result: { stdout: "done" }, + isError: false, + } as never, + [ + { + toolCallId: "call-1", + toolName: "bash", + result: { stdout: "done" }, + isError: false, + timestamp: 222, + }, + ], + ); + + await reporter.finalize({ + ok: true, + response: "hello from the sandbox", + sessionKey: "session-1", + }); + + expect(fetchMock).toHaveBeenCalledTimes(2); + expect(fetchMock.mock.calls[0]?.[0]).toBe( + "https://web.example/api/chat/runs/run-1/events", + ); + expect(fetchMock.mock.calls[1]?.[0]).toBe( + "https://web.example/api/chat/runs/run-1/events", + ); + + expect(fetchMock.mock.calls[0]?.[1]?.headers).toMatchObject({ + Authorization: "Bearer callback-token", + "Content-Type": "application/json", + }); + + const runMessagesCall = fetchMock.mock.calls.find((call) => + String(call[1]?.body).includes('"items"'), + ); + const runMessagesBody = JSON.parse(String(runMessagesCall?.[1]?.body)) as { + items: Array<{ + idempotencyKey: string; + partsJson: string; + }>; + }; + 
expect(runMessagesBody.items).toHaveLength(1); + expect(runMessagesBody.items[0]).toMatchObject({ + idempotencyKey: "run:run-1:assistant", + }); + expect(JSON.parse(runMessagesBody.items[0]?.partsJson ?? "[]")).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + type: "text", + text: "hello from the sandbox", + }), + expect.objectContaining({ + type: "tool-invocation", + toolCallId: "call-1", + toolName: "bash", + state: "result", + result: { stdout: "done" }, + }), + ]), + ); + + expect( + JSON.parse(String(fetchMock.mock.calls[1]?.[1]?.body)), + ).toMatchObject({ + final: { + status: "completed", + }, + }); + }); + + it("marks aborted runs as interrupted", async () => { + const fetchMock = vi.fn().mockResolvedValue(mockOkResponse()); + vi.stubGlobal("fetch", fetchMock); + + const reporter = new DurableChatRunReporter({ + runId: "run-2", + callbackUrl: "https://web.example/api/chat/runs/run-2/events", + callbackToken: "callback-token", + }); + + await reporter.finalize({ + ok: false, + response: "", + error: "Session aborted", + sessionKey: "session-1", + }); + + expect(fetchMock).toHaveBeenCalledTimes(1); + expect(fetchMock.mock.calls[0]?.[0]).toBe( + "https://web.example/api/chat/runs/run-2/events", + ); + expect( + JSON.parse(String(fetchMock.mock.calls[0]?.[1]?.body)), + ).toMatchObject({ + final: { + status: "interrupted", + }, + }); + }); +}); diff --git a/packages/coding-agent/test/gateway-session-titles.test.ts b/packages/coding-agent/test/gateway-session-titles.test.ts index 9846f19..2573e2f 100644 --- a/packages/coding-agent/test/gateway-session-titles.test.ts +++ b/packages/coding-agent/test/gateway-session-titles.test.ts @@ -49,6 +49,7 @@ function addManagedSession( session, queue: [], processing: false, + activeDurableRun: null, activeAssistantMessage: null, pendingToolResults: [], createdAt: Date.now(), diff --git a/packages/coding-agent/test/gateway-steer.test.ts b/packages/coding-agent/test/gateway-steer.test.ts index 19d2d29..0a63ca4 
100644 --- a/packages/coding-agent/test/gateway-steer.test.ts +++ b/packages/coding-agent/test/gateway-steer.test.ts @@ -49,6 +49,7 @@ function addManagedSession( session: session as never, queue: [], processing, + activeDurableRun: null, activeAssistantMessage: null, pendingToolResults: [], createdAt: Date.now(),