Merge pull request #325 from getcompanion-ai/chat-timeout

fix chat timeout
This commit is contained in:
Hari 2026-03-14 14:57:53 -04:00 committed by GitHub
commit e59144d2b8
7 changed files with 535 additions and 6 deletions

View file

@ -0,0 +1,246 @@
import type { AgentMessage } from "@mariozechner/companion-agent-core";
import type { AgentSessionEvent } from "../agent-session.js";
import { extractMessageText } from "./helpers.js";
import { messageContentToHistoryParts } from "./session-state.js";
import type { GatewayTransientToolResult } from "./session-state.js";
import type { GatewayMessageResult, GatewayMessageRequest } from "./types.js";
const FLUSH_INTERVAL_MS = 500;
type PersistHistoryItem = {
role: "user" | "assistant" | "toolResult";
text?: string;
partsJson: string;
timestamp: number;
idempotencyKey: string;
};
type ConvexRunStatus = "completed" | "failed" | "interrupted";
function normalizeErrorMessage(error: unknown): string {
if (error instanceof Error) {
return error.message;
}
return typeof error === "string" ? error : String(error);
}
type DurableChatRunEventBody =
| {
items: PersistHistoryItem[];
final?: {
status: ConvexRunStatus;
error?: string;
};
}
| {
items?: PersistHistoryItem[];
final: {
status: ConvexRunStatus;
error?: string;
};
};
function buildAuthHeaders(token: string): Record<string, string> {
return {
"Content-Type": "application/json",
Authorization: `Bearer ${token}`,
};
}
export class DurableChatRunReporter {
private readonly assistantMessageId: string;
private latestAssistantMessage: AgentMessage | null = null;
private readonly knownToolResults = new Map<
string,
GatewayTransientToolResult
>();
private flushTimer: ReturnType<typeof globalThis.setTimeout> | null = null;
private flushChain: Promise<void> = Promise.resolve();
private flushFailure: Error | null = null;
constructor(
private readonly durableRun: NonNullable<
GatewayMessageRequest["durableRun"]
>,
) {
if (
durableRun.callbackUrl.trim().length === 0 ||
durableRun.callbackToken.trim().length === 0
) {
throw new Error(
"Durable chat run reporting requires callbackUrl and callbackToken",
);
}
this.assistantMessageId = `run:${this.durableRun.runId}:assistant`;
}
handleSessionEvent(
event: AgentSessionEvent,
pendingToolResults: GatewayTransientToolResult[],
): void {
for (const toolResult of pendingToolResults) {
this.knownToolResults.set(toolResult.toolCallId, toolResult);
}
if (event.type === "message_start" && event.message.role === "assistant") {
this.latestAssistantMessage = event.message;
return;
}
if (event.type === "message_update" && event.message.role === "assistant") {
this.latestAssistantMessage = event.message;
this.scheduleFlush();
return;
}
if (event.type === "message_end" && event.message.role === "assistant") {
this.latestAssistantMessage = event.message;
this.scheduleFlush();
return;
}
if (
event.type === "tool_execution_end" ||
event.type === "turn_end" ||
(event.type === "message_end" && event.message.role === "toolResult")
) {
this.scheduleFlush();
}
}
async finalize(result: GatewayMessageResult): Promise<void> {
let status: ConvexRunStatus = result.ok
? "completed"
: result.error?.includes("aborted")
? "interrupted"
: "failed";
let errorMessage = result.error;
try {
await this.finalFlush();
} catch (error) {
status = "failed";
errorMessage = normalizeErrorMessage(error);
}
await this.postEvent({
final: {
status,
...(status === "failed" && errorMessage ? { error: errorMessage } : {}),
},
});
}
private scheduleFlush(): void {
if (this.flushTimer) return;
this.flushTimer = globalThis.setTimeout(() => {
this.flushTimer = null;
void this.flush().catch(() => undefined);
}, FLUSH_INTERVAL_MS);
}
private async flush(): Promise<void> {
this.throwIfFlushFailed();
if (this.flushTimer) {
globalThis.clearTimeout(this.flushTimer);
this.flushTimer = null;
}
const items = this.buildItems();
if (items.length === 0) {
return;
}
const flushPromise = this.flushChain.then(async () => {
this.throwIfFlushFailed();
await this.postEvent({
items,
});
});
this.flushChain = flushPromise.catch(() => undefined);
try {
await flushPromise;
} catch (error) {
throw this.markFlushFailed(error);
}
}
private async finalFlush(): Promise<void> {
await this.flush();
await this.flushChain;
this.throwIfFlushFailed();
}
private buildItems(): PersistHistoryItem[] {
const assistantParts =
this.latestAssistantMessage?.role === "assistant"
? messageContentToHistoryParts(this.latestAssistantMessage)
: [];
for (const toolResult of this.knownToolResults.values()) {
assistantParts.push({
type: "tool-invocation",
toolCallId: toolResult.toolCallId,
toolName: toolResult.toolName,
args: undefined,
state: toolResult.isError ? "error" : "result",
result: toolResult.result,
});
}
const firstToolResult = this.knownToolResults.values().next().value;
if (
this.latestAssistantMessage?.role === "assistant" ||
assistantParts.length > 0
) {
return [
{
role: "assistant",
text:
this.latestAssistantMessage?.role === "assistant"
? extractMessageText(this.latestAssistantMessage) || undefined
: undefined,
partsJson: JSON.stringify(assistantParts),
timestamp:
this.latestAssistantMessage?.timestamp ??
firstToolResult?.timestamp ??
Date.now(),
idempotencyKey: this.assistantMessageId,
},
];
}
return [];
}
private async postEvent(body: DurableChatRunEventBody): Promise<void> {
const response = await fetch(this.durableRun.callbackUrl, {
method: "POST",
headers: buildAuthHeaders(this.durableRun.callbackToken),
body: JSON.stringify(body),
});
if (!response.ok) {
const text = await response.text().catch(() => "");
throw new Error(
`Chat run relay failed: ${response.status} ${text}`.trim(),
);
}
}
private throwIfFlushFailed(): void {
if (this.flushFailure) {
throw this.flushFailure;
}
}
private markFlushFailed(error: unknown): Error {
if (this.flushFailure) {
return this.flushFailure;
}
const normalizedError =
error instanceof Error ? error : new Error(normalizeErrorMessage(error));
this.flushFailure = normalizedError;
return normalizedError;
}
}

View file

@ -1,5 +1,6 @@
import type { AgentMessage } from "@mariozechner/companion-agent-core"; import type { AgentMessage } from "@mariozechner/companion-agent-core";
import type { AgentSession } from "../agent-session.js"; import type { AgentSession } from "../agent-session.js";
import type { DurableChatRunReporter } from "./durable-chat-run.js";
import type { import type {
GatewayMessageRequest, GatewayMessageRequest,
GatewayMessageResult, GatewayMessageResult,
@ -63,7 +64,13 @@ export type GatewayEvent =
payload: { payload: {
teamId: string; teamId: string;
status: string; status: string;
members: Array<{ id: string; name: string; role?: string; status: string; message?: string }>; members: Array<{
id: string;
name: string;
role?: string;
status: string;
message?: string;
}>;
}; };
} }
| { | {
@ -84,6 +91,7 @@ export interface ManagedGatewaySession {
session: AgentSession; session: AgentSession;
queue: GatewayQueuedMessage[]; queue: GatewayQueuedMessage[];
processing: boolean; processing: boolean;
activeDurableRun: DurableChatRunReporter | null;
activeAssistantMessage: AgentMessage | null; activeAssistantMessage: AgentMessage | null;
pendingToolResults: GatewayTransientToolResult[]; pendingToolResults: GatewayTransientToolResult[];
createdAt: number; createdAt: number;

View file

@ -10,6 +10,7 @@ import { URL } from "node:url";
import type { AgentMessage } from "@mariozechner/companion-agent-core"; import type { AgentMessage } from "@mariozechner/companion-agent-core";
import type { AgentSession, AgentSessionEvent } from "../agent-session.js"; import type { AgentSession, AgentSessionEvent } from "../agent-session.js";
import type { Settings } from "../settings-manager.js"; import type { Settings } from "../settings-manager.js";
import { DurableChatRunReporter } from "./durable-chat-run.js";
import { extractMessageText, getLastAssistantText } from "./helpers.js"; import { extractMessageText, getLastAssistantText } from "./helpers.js";
import { import {
type GatewayEvent, type GatewayEvent,
@ -108,6 +109,28 @@ function readString(value: unknown): string | undefined {
return trimmed.length > 0 ? trimmed : undefined; return trimmed.length > 0 ? trimmed : undefined;
} }
function readDurableRun(
value: unknown,
): GatewayMessageRequest["durableRun"] | undefined {
if (!isRecord(value)) {
return undefined;
}
const runId = readString(value.runId);
const callbackUrl = readString(value.callbackUrl);
const callbackToken = readString(value.callbackToken);
if (!runId || !callbackUrl || !callbackToken) {
return undefined;
}
return {
runId,
callbackUrl,
callbackToken,
};
}
export function setActiveGatewayRuntime(runtime: GatewayRuntime | null): void { export function setActiveGatewayRuntime(runtime: GatewayRuntime | null): void {
activeGatewayRuntime = runtime; activeGatewayRuntime = runtime;
} }
@ -419,6 +442,7 @@ export class GatewayRuntime {
session, session,
queue: [], queue: [],
processing: false, processing: false,
activeDurableRun: null,
activeAssistantMessage: null, activeAssistantMessage: null,
pendingToolResults: [], pendingToolResults: [],
createdAt: Date.now(), createdAt: Date.now(),
@ -462,18 +486,32 @@ export class GatewayRuntime {
); );
this.emitState(managedSession); this.emitState(managedSession);
let result: GatewayMessageResult = {
ok: false,
response: "",
error: "Unknown error",
sessionKey: managedSession.sessionKey,
};
let durableRunReporter: DurableChatRunReporter | null = null;
try { try {
queued.onStart?.(); queued.onStart?.();
if (queued.request.durableRun) {
durableRunReporter = new DurableChatRunReporter(
queued.request.durableRun,
);
managedSession.activeDurableRun = durableRunReporter;
}
await managedSession.session.prompt(queued.request.text, { await managedSession.session.prompt(queued.request.text, {
images: queued.request.images, images: queued.request.images,
source: queued.request.source ?? "extension", source: queued.request.source ?? "extension",
}); });
const response = getLastAssistantText(managedSession.session); const response = getLastAssistantText(managedSession.session);
queued.resolve({ result = {
ok: true, ok: true,
response, response,
sessionKey: managedSession.sessionKey, sessionKey: managedSession.sessionKey,
}); };
} catch (error) { } catch (error) {
const message = error instanceof Error ? error.message : String(error); const message = error instanceof Error ? error.message : String(error);
this.log( this.log(
@ -491,15 +529,39 @@ export class GatewayRuntime {
error: message, error: message,
}); });
} }
queued.resolve({ result = {
ok: false, ok: false,
response: "", response: "",
error: message, error: message,
sessionKey: managedSession.sessionKey, sessionKey: managedSession.sessionKey,
}); };
} finally { } finally {
queued.onFinish?.(); queued.onFinish?.();
if (durableRunReporter) {
try {
await durableRunReporter.finalize(result);
} catch (error) {
const message =
error instanceof Error ? error.message : String(error);
this.log(
`[chat-run] session=${managedSession.sessionKey} finalize error: ${message}`,
);
this.emit(managedSession, {
type: "error",
sessionKey: managedSession.sessionKey,
error: message,
});
result = {
ok: false,
response: result.response,
error: message,
sessionKey: managedSession.sessionKey,
};
}
}
queued.resolve(result);
managedSession.processing = false; managedSession.processing = false;
managedSession.activeDurableRun = null;
managedSession.activeAssistantMessage = null; managedSession.activeAssistantMessage = null;
managedSession.pendingToolResults = []; managedSession.pendingToolResults = [];
managedSession.lastActiveAt = Date.now(); managedSession.lastActiveAt = Date.now();
@ -529,6 +591,13 @@ export class GatewayRuntime {
managedSession: ManagedGatewaySession, managedSession: ManagedGatewaySession,
event: AgentSessionEvent, event: AgentSessionEvent,
): void { ): void {
const forwardToDurableRun = () => {
managedSession.activeDurableRun?.handleSessionEvent(
event,
managedSession.pendingToolResults,
);
};
switch (event.type) { switch (event.type) {
case "turn_start": case "turn_start":
managedSession.lastActiveAt = Date.now(); managedSession.lastActiveAt = Date.now();
@ -537,6 +606,7 @@ export class GatewayRuntime {
type: "turn_start", type: "turn_start",
sessionKey: managedSession.sessionKey, sessionKey: managedSession.sessionKey,
}); });
forwardToDurableRun();
return; return;
case "turn_end": case "turn_end":
managedSession.lastActiveAt = Date.now(); managedSession.lastActiveAt = Date.now();
@ -545,6 +615,7 @@ export class GatewayRuntime {
type: "turn_end", type: "turn_end",
sessionKey: managedSession.sessionKey, sessionKey: managedSession.sessionKey,
}); });
forwardToDurableRun();
return; return;
case "message_start": case "message_start":
managedSession.lastActiveAt = Date.now(); managedSession.lastActiveAt = Date.now();
@ -556,6 +627,7 @@ export class GatewayRuntime {
sessionKey: managedSession.sessionKey, sessionKey: managedSession.sessionKey,
role: event.message.role, role: event.message.role,
}); });
forwardToDurableRun();
return; return;
case "message_update": case "message_update":
managedSession.lastActiveAt = Date.now(); managedSession.lastActiveAt = Date.now();
@ -570,6 +642,7 @@ export class GatewayRuntime {
delta: event.assistantMessageEvent.delta, delta: event.assistantMessageEvent.delta,
contentIndex: event.assistantMessageEvent.contentIndex, contentIndex: event.assistantMessageEvent.contentIndex,
}); });
forwardToDurableRun();
return; return;
case "thinking_delta": case "thinking_delta":
this.emit(managedSession, { this.emit(managedSession, {
@ -578,8 +651,10 @@ export class GatewayRuntime {
delta: event.assistantMessageEvent.delta, delta: event.assistantMessageEvent.delta,
contentIndex: event.assistantMessageEvent.contentIndex, contentIndex: event.assistantMessageEvent.contentIndex,
}); });
forwardToDurableRun();
return; return;
} }
forwardToDurableRun();
return; return;
case "message_end": case "message_end":
managedSession.lastActiveAt = Date.now(); managedSession.lastActiveAt = Date.now();
@ -595,6 +670,7 @@ export class GatewayRuntime {
text: extractMessageText(event.message), text: extractMessageText(event.message),
}); });
this.emitStructuredParts(managedSession, event.message); this.emitStructuredParts(managedSession, event.message);
forwardToDurableRun();
return; return;
} }
if (event.message.role === "toolResult") { if (event.message.role === "toolResult") {
@ -610,6 +686,7 @@ export class GatewayRuntime {
); );
} }
} }
forwardToDurableRun();
return; return;
case "tool_execution_start": case "tool_execution_start":
managedSession.lastActiveAt = Date.now(); managedSession.lastActiveAt = Date.now();
@ -624,6 +701,7 @@ export class GatewayRuntime {
toolName: event.toolName, toolName: event.toolName,
args: event.args, args: event.args,
}); });
forwardToDurableRun();
return; return;
case "tool_execution_update": case "tool_execution_update":
managedSession.lastActiveAt = Date.now(); managedSession.lastActiveAt = Date.now();
@ -634,6 +712,7 @@ export class GatewayRuntime {
toolName: event.toolName, toolName: event.toolName,
partialResult: event.partialResult, partialResult: event.partialResult,
}); });
forwardToDurableRun();
return; return;
case "tool_execution_end": case "tool_execution_end":
managedSession.lastActiveAt = Date.now(); managedSession.lastActiveAt = Date.now();
@ -661,6 +740,7 @@ export class GatewayRuntime {
result: event.result, result: event.result,
isError: event.isError, isError: event.isError,
}); });
forwardToDurableRun();
return; return;
} }
} }
@ -1030,7 +1110,7 @@ export class GatewayRuntime {
} }
const sessionMatch = path.match( const sessionMatch = path.match(
/^\/sessions\/([^/]+)(?:\/(events|messages|abort|reset|chat|history|model|reload|state|steer))?$/, /^\/sessions\/([^/]+)(?:\/(enqueue|events|messages|abort|reset|chat|history|model|reload|state|steer))?$/,
); );
if (!sessionMatch) { if (!sessionMatch) {
this.writeJson(response, 404, { error: "Not found" }); this.writeJson(response, 404, { error: "Not found" });
@ -1069,6 +1149,37 @@ export class GatewayRuntime {
return; return;
} }
if (action === "enqueue" && method === "POST") {
const body = await this.readJsonBody(request);
const text = extractUserText(body);
if (!text) {
this.writeJson(response, 400, { error: "Missing user message text" });
return;
}
const durableRun = readDurableRun(body.durableRun);
const queued = await this.queueManagedMessage({
request: {
sessionKey,
text,
source: "extension",
metadata: isRecord(body.metadata) ? body.metadata : undefined,
durableRun,
},
});
if (!queued.accepted) {
this.writeJson(response, 409, queued.errorResult);
return;
}
void queued.completion.catch(() => undefined);
this.writeJson(response, 202, {
ok: true,
queued: true,
sessionKey,
...(durableRun ? { runId: durableRun.runId } : {}),
});
return;
}
if (action === "messages" && method === "POST") { if (action === "messages" && method === "POST") {
const body = await this.readJsonBody(request); const body = await this.readJsonBody(request);
const text = typeof body.text === "string" ? body.text : ""; const text = typeof body.text === "string" ? body.text : "";

View file

@ -26,6 +26,11 @@ export interface GatewayMessageRequest {
source?: "interactive" | "rpc" | "extension"; source?: "interactive" | "rpc" | "extension";
images?: ImageContent[]; images?: ImageContent[];
metadata?: Record<string, unknown>; metadata?: Record<string, unknown>;
durableRun?: {
runId: string;
callbackUrl: string;
callbackToken: string;
};
} }
export interface GatewayMessageResult { export interface GatewayMessageResult {

View file

@ -0,0 +1,157 @@
import { afterEach, describe, expect, it, vi } from "vitest";
import { DurableChatRunReporter } from "../src/core/gateway/durable-chat-run.js";
function mockOkResponse() {
return {
ok: true,
status: 200,
text: vi.fn().mockResolvedValue(""),
} as unknown as Response;
}
describe("DurableChatRunReporter", () => {
afterEach(() => {
vi.restoreAllMocks();
});
it("posts assistant state to the relay and completes the run", async () => {
const fetchMock = vi.fn<typeof fetch>().mockResolvedValue(mockOkResponse());
vi.stubGlobal("fetch", fetchMock);
const reporter = new DurableChatRunReporter({
runId: "run-1",
callbackUrl: "https://web.example/api/chat/runs/run-1/events",
callbackToken: "callback-token",
});
const assistantMessage = {
role: "assistant",
timestamp: 111,
content: [{ type: "text", text: "hello from the sandbox" }],
};
reporter.handleSessionEvent(
{
type: "message_start",
message: assistantMessage,
} as never,
[],
);
reporter.handleSessionEvent(
{
type: "message_update",
message: assistantMessage,
assistantMessageEvent: {
type: "text_delta",
delta: "hello from the sandbox",
contentIndex: 0,
},
} as never,
[],
);
reporter.handleSessionEvent(
{
type: "tool_execution_end",
toolCallId: "call-1",
toolName: "bash",
result: { stdout: "done" },
isError: false,
} as never,
[
{
toolCallId: "call-1",
toolName: "bash",
result: { stdout: "done" },
isError: false,
timestamp: 222,
},
],
);
await reporter.finalize({
ok: true,
response: "hello from the sandbox",
sessionKey: "session-1",
});
expect(fetchMock).toHaveBeenCalledTimes(2);
expect(fetchMock.mock.calls[0]?.[0]).toBe(
"https://web.example/api/chat/runs/run-1/events",
);
expect(fetchMock.mock.calls[1]?.[0]).toBe(
"https://web.example/api/chat/runs/run-1/events",
);
expect(fetchMock.mock.calls[0]?.[1]?.headers).toMatchObject({
Authorization: "Bearer callback-token",
"Content-Type": "application/json",
});
const runMessagesCall = fetchMock.mock.calls.find((call) =>
String(call[1]?.body).includes('"items"'),
);
const runMessagesBody = JSON.parse(String(runMessagesCall?.[1]?.body)) as {
items: Array<{
idempotencyKey: string;
partsJson: string;
}>;
};
expect(runMessagesBody.items).toHaveLength(1);
expect(runMessagesBody.items[0]).toMatchObject({
idempotencyKey: "run:run-1:assistant",
});
expect(JSON.parse(runMessagesBody.items[0]?.partsJson ?? "[]")).toEqual(
expect.arrayContaining([
expect.objectContaining({
type: "text",
text: "hello from the sandbox",
}),
expect.objectContaining({
type: "tool-invocation",
toolCallId: "call-1",
toolName: "bash",
state: "result",
result: { stdout: "done" },
}),
]),
);
expect(
JSON.parse(String(fetchMock.mock.calls[1]?.[1]?.body)),
).toMatchObject({
final: {
status: "completed",
},
});
});
it("marks aborted runs as interrupted", async () => {
const fetchMock = vi.fn<typeof fetch>().mockResolvedValue(mockOkResponse());
vi.stubGlobal("fetch", fetchMock);
const reporter = new DurableChatRunReporter({
runId: "run-2",
callbackUrl: "https://web.example/api/chat/runs/run-2/events",
callbackToken: "callback-token",
});
await reporter.finalize({
ok: false,
response: "",
error: "Session aborted",
sessionKey: "session-1",
});
expect(fetchMock).toHaveBeenCalledTimes(1);
expect(fetchMock.mock.calls[0]?.[0]).toBe(
"https://web.example/api/chat/runs/run-2/events",
);
expect(
JSON.parse(String(fetchMock.mock.calls[0]?.[1]?.body)),
).toMatchObject({
final: {
status: "interrupted",
},
});
});
});

View file

@ -49,6 +49,7 @@ function addManagedSession(
session, session,
queue: [], queue: [],
processing: false, processing: false,
activeDurableRun: null,
activeAssistantMessage: null, activeAssistantMessage: null,
pendingToolResults: [], pendingToolResults: [],
createdAt: Date.now(), createdAt: Date.now(),

View file

@ -49,6 +49,7 @@ function addManagedSession(
session: session as never, session: session as never,
queue: [], queue: [],
processing, processing,
activeDurableRun: null,
activeAssistantMessage: null, activeAssistantMessage: null,
pendingToolResults: [], pendingToolResults: [],
createdAt: Date.now(), createdAt: Date.now(),