This commit is contained in:
Harivansh Rathi 2026-03-06 10:05:58 -08:00
parent ca0861400d
commit 5a2172fb9d
4 changed files with 98 additions and 42 deletions

3
.gitignore vendored
View file

@ -33,3 +33,6 @@ pi-*.html
out.html out.html
packages/coding-agent/binaries/ packages/coding-agent/binaries/
todo.md todo.md
# Riptide artifacts (cloud-synced)
.humanlayer/tasks/

View file

@ -4,7 +4,12 @@ import { URL } from "node:url";
import type { ImageContent } from "@mariozechner/pi-ai"; import type { ImageContent } from "@mariozechner/pi-ai";
import type { AgentSession, AgentSessionEvent } from "./agent-session.js"; import type { AgentSession, AgentSessionEvent } from "./agent-session.js";
import { SessionManager } from "./session-manager.js"; import { SessionManager } from "./session-manager.js";
import { createVercelStreamListener, errorVercelStream, extractUserText, finishVercelStream } from "./vercel-ai-stream.js"; import {
createVercelStreamListener,
errorVercelStream,
extractUserText,
finishVercelStream,
} from "./vercel-ai-stream.js";
export interface GatewayConfig { export interface GatewayConfig {
bind: string; bind: string;
@ -59,6 +64,8 @@ export interface GatewayRuntimeOptions {
interface GatewayQueuedMessage { interface GatewayQueuedMessage {
request: GatewayMessageRequest; request: GatewayMessageRequest;
resolve: (result: GatewayMessageResult) => void; resolve: (result: GatewayMessageResult) => void;
onStart?: () => void;
onFinish?: () => void;
} }
type GatewayEvent = type GatewayEvent =
@ -186,18 +193,26 @@ export class GatewayRuntime {
} }
async enqueueMessage(request: GatewayMessageRequest): Promise<GatewayMessageResult> { async enqueueMessage(request: GatewayMessageRequest): Promise<GatewayMessageResult> {
const managedSession = await this.ensureSession(request.sessionKey); return this.enqueueManagedMessage({ request });
}
private async enqueueManagedMessage(queuedMessage: {
request: GatewayMessageRequest;
onStart?: () => void;
onFinish?: () => void;
}): Promise<GatewayMessageResult> {
const managedSession = await this.ensureSession(queuedMessage.request.sessionKey);
if (managedSession.queue.length >= this.config.session.maxQueuePerSession) { if (managedSession.queue.length >= this.config.session.maxQueuePerSession) {
return { return {
ok: false, ok: false,
response: "", response: "",
error: `Queue full (${this.config.session.maxQueuePerSession} pending).`, error: `Queue full (${this.config.session.maxQueuePerSession} pending).`,
sessionKey: request.sessionKey, sessionKey: queuedMessage.request.sessionKey,
}; };
} }
return new Promise<GatewayMessageResult>((resolve) => { return new Promise<GatewayMessageResult>((resolve) => {
managedSession.queue.push({ request, resolve }); managedSession.queue.push({ ...queuedMessage, resolve });
this.emitState(managedSession); this.emitState(managedSession);
void this.processNext(managedSession); void this.processNext(managedSession);
}); });
@ -303,6 +318,7 @@ export class GatewayRuntime {
this.emitState(managedSession); this.emitState(managedSession);
try { try {
queued.onStart?.();
await managedSession.session.prompt(queued.request.text, { await managedSession.session.prompt(queued.request.text, {
images: queued.request.images, images: queued.request.images,
source: queued.request.source ?? "extension", source: queued.request.source ?? "extension",
@ -327,6 +343,7 @@ export class GatewayRuntime {
sessionKey: managedSession.sessionKey, sessionKey: managedSession.sessionKey,
}); });
} finally { } finally {
queued.onFinish?.();
managedSession.processing = false; managedSession.processing = false;
managedSession.lastActiveAt = Date.now(); managedSession.lastActiveAt = Date.now();
this.emitState(managedSession); this.emitState(managedSession);
@ -593,11 +610,7 @@ export class GatewayRuntime {
}); });
} }
private async handleChat( private async handleChat(sessionKey: string, request: IncomingMessage, response: ServerResponse): Promise<void> {
sessionKey: string,
request: IncomingMessage,
response: ServerResponse,
): Promise<void> {
const body = await this.readJsonBody(request); const body = await this.readJsonBody(request);
const text = extractUserText(body); const text = extractUserText(body);
if (!text) { if (!text) {
@ -614,27 +627,44 @@ export class GatewayRuntime {
}); });
response.write("\n"); response.write("\n");
// Subscribe to session events for Vercel AI SDK translation
const managedSession = await this.ensureSession(sessionKey); const managedSession = await this.ensureSession(sessionKey);
const listener = createVercelStreamListener(response); const listener = createVercelStreamListener(response);
const unsubscribe = managedSession.session.subscribe(listener); let unsubscribe: (() => void) | undefined;
let streamingActive = false;
// Detaches the stream listener if streaming is currently active; otherwise a no-op,
// so repeated calls (client close, onFinish, post-await cleanup) are harmless.
const stopStreaming = (): void => {
	if (streamingActive) {
		// Clear the flag first so any further call short-circuits.
		streamingActive = false;
		unsubscribe?.();
		unsubscribe = undefined;
	}
};
// Clean up on client disconnect // Clean up on client disconnect
let clientDisconnected = false; let clientDisconnected = false;
request.on("close", () => { request.on("close", () => {
clientDisconnected = true; clientDisconnected = true;
unsubscribe(); stopStreaming();
}); });
// Drive the session through the existing queue infrastructure // Drive the session through the existing queue infrastructure
try { try {
const result = await this.enqueueMessage({ const result = await this.enqueueManagedMessage({
sessionKey, request: {
text, sessionKey,
source: "extension", text,
source: "extension",
},
onStart: () => {
if (clientDisconnected || streamingActive) return;
unsubscribe = managedSession.session.subscribe(listener);
streamingActive = true;
},
onFinish: () => {
stopStreaming();
},
}); });
if (!clientDisconnected) { if (!clientDisconnected) {
unsubscribe(); stopStreaming();
if (result.ok) { if (result.ok) {
finishVercelStream(response, "stop"); finishVercelStream(response, "stop");
} else { } else {
@ -648,7 +678,7 @@ export class GatewayRuntime {
} }
} catch (error) { } catch (error) {
if (!clientDisconnected) { if (!clientDisconnected) {
unsubscribe(); stopStreaming();
const message = error instanceof Error ? error.message : String(error); const message = error instanceof Error ? error.message : String(error);
errorVercelStream(response, message); errorVercelStream(response, message);
} }

View file

@ -58,20 +58,32 @@ export function createVercelStreamListener(
response: ServerResponse, response: ServerResponse,
messageId?: string, messageId?: string,
): (event: AgentSessionEvent) => void { ): (event: AgentSessionEvent) => void {
let started = false; // Gate: only forward events within a single prompt's agent_start -> agent_end lifecycle.
// handleChat now subscribes this listener immediately before the queued prompt starts,
// so these guards only need to bound the stream to that prompt's event span.
let active = false;
const msgId = messageId ?? randomUUID(); const msgId = messageId ?? randomUUID();
return (event: AgentSessionEvent) => { return (event: AgentSessionEvent) => {
if (response.writableEnded) return; if (response.writableEnded) return;
switch (event.type) { // Activate on our agent_start, deactivate on agent_end
case "agent_start": if (event.type === "agent_start") {
if (!started) { if (!active) {
writeChunk(response, { type: "start", messageId: msgId }); active = true;
started = true; writeChunk(response, { type: "start", messageId: msgId });
} }
return; return;
}
if (event.type === "agent_end") {
active = false;
return;
}
// Drop events that don't belong to our message
if (!active) return;
switch (event.type) {
case "turn_start": case "turn_start":
writeChunk(response, { type: "start-step" }); writeChunk(response, { type: "start-step" });
return; return;
@ -169,10 +181,7 @@ export function createVercelStreamListener(
/** /**
* Write the terminal finish sequence and end the response. * Write the terminal finish sequence and end the response.
*/ */
export function finishVercelStream( export function finishVercelStream(response: ServerResponse, finishReason: string = "stop"): void {
response: ServerResponse,
finishReason: string = "stop",
): void {
if (response.writableEnded) return; if (response.writableEnded) return;
writeChunk(response, { type: "finish", finishReason }); writeChunk(response, { type: "finish", finishReason });
writeChunk(response, "[DONE]"); writeChunk(response, "[DONE]");
@ -182,10 +191,7 @@ export function finishVercelStream(
/** /**
* Write an error chunk and end the response. * Write an error chunk and end the response.
*/ */
export function errorVercelStream( export function errorVercelStream(response: ServerResponse, errorText: string): void {
response: ServerResponse,
errorText: string,
): void {
if (response.writableEnded) return; if (response.writableEnded) return;
writeChunk(response, { type: "error", errorText }); writeChunk(response, { type: "error", errorText });
writeChunk(response, "[DONE]"); writeChunk(response, "[DONE]");

View file

@ -1,13 +1,11 @@
import { describe, it, expect } from "vitest"; import { describe, expect, it } from "vitest";
import type { AgentSessionEvent } from "../src/core/agent-session.js"; import type { AgentSessionEvent } from "../src/core/agent-session.js";
import { extractUserText, createVercelStreamListener, finishVercelStream } from "../src/core/vercel-ai-stream.js"; import { createVercelStreamListener, extractUserText } from "../src/core/vercel-ai-stream.js";
describe("extractUserText", () => { describe("extractUserText", () => {
it("extracts text from useChat v5+ format with parts", () => { it("extracts text from useChat v5+ format with parts", () => {
const body = { const body = {
messages: [ messages: [{ role: "user", parts: [{ type: "text", text: "hello world" }] }],
{ role: "user", parts: [{ type: "text", text: "hello world" }] },
],
}; };
expect(extractUserText(body)).toBe("hello world"); expect(extractUserText(body)).toBe("hello world");
}); });
@ -70,7 +68,9 @@ describe("createVercelStreamListener", () => {
this.writableEnded = true; this.writableEnded = true;
}, },
chunks, chunks,
get ended() { return ended; }, get ended() {
return ended;
},
} as any; } as any;
} }
@ -79,8 +79,11 @@ describe("createVercelStreamListener", () => {
.filter((c) => c.startsWith("data: ")) .filter((c) => c.startsWith("data: "))
.map((c) => { .map((c) => {
const payload = c.replace(/^data: /, "").replace(/\n\n$/, ""); const payload = c.replace(/^data: /, "").replace(/\n\n$/, "");
try { return JSON.parse(payload); } try {
catch { return payload; } return JSON.parse(payload);
} catch {
return payload;
}
}); });
} }
@ -129,4 +132,18 @@ describe("createVercelStreamListener", () => {
const parsed = parseChunks(response.chunks); const parsed = parseChunks(response.chunks);
expect(parsed).toEqual([{ type: "start", messageId: "test-msg-id" }]); expect(parsed).toEqual([{ type: "start", messageId: "test-msg-id" }]);
}); });
it("ignores events outside the active prompt lifecycle", () => {
	const response = createMockResponse();
	const listener = createVercelStreamListener(response, "test-msg-id");

	// Only the second turn_start sits between agent_start and agent_end;
	// the first arrives before the lifecycle opens and the last after it closes.
	const events: AgentSessionEvent[] = [
		{ type: "turn_start", turnIndex: 0, timestamp: Date.now() } as AgentSessionEvent,
		{ type: "agent_start" } as AgentSessionEvent,
		{ type: "turn_start", turnIndex: 0, timestamp: Date.now() } as AgentSessionEvent,
		{ type: "agent_end", messages: [] } as AgentSessionEvent,
		{ type: "turn_start", turnIndex: 1, timestamp: Date.now() } as AgentSessionEvent,
	];
	for (const event of events) {
		listener(event);
	}

	// Expect exactly the start chunk plus the single in-lifecycle step.
	const expected = [{ type: "start", messageId: "test-msg-id" }, { type: "start-step" }];
	expect(parseChunks(response.chunks)).toEqual(expected);
});
}); });