Merge pull request #12 from getcompanion-ai/rathi/vercelaisdk

Add Vercel AI SDK v6 chat endpoint
This commit is contained in:
Hari 2026-03-06 14:57:03 -05:00 committed by GitHub
commit 0c31586efa
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 453 additions and 4 deletions

3
.gitignore vendored
View file

@ -33,3 +33,6 @@ pi-*.html
out.html out.html
packages/coding-agent/binaries/ packages/coding-agent/binaries/
todo.md todo.md
# Riptide artifacts (cloud-synced)
.humanlayer/tasks/

View file

@ -4,6 +4,12 @@ import { URL } from "node:url";
import type { ImageContent } from "@mariozechner/pi-ai"; import type { ImageContent } from "@mariozechner/pi-ai";
import type { AgentSession, AgentSessionEvent } from "./agent-session.js"; import type { AgentSession, AgentSessionEvent } from "./agent-session.js";
import { SessionManager } from "./session-manager.js"; import { SessionManager } from "./session-manager.js";
import {
createVercelStreamListener,
errorVercelStream,
extractUserText,
finishVercelStream,
} from "./vercel-ai-stream.js";
export interface GatewayConfig { export interface GatewayConfig {
bind: string; bind: string;
@ -58,6 +64,8 @@ export interface GatewayRuntimeOptions {
interface GatewayQueuedMessage { interface GatewayQueuedMessage {
request: GatewayMessageRequest; request: GatewayMessageRequest;
resolve: (result: GatewayMessageResult) => void; resolve: (result: GatewayMessageResult) => void;
onStart?: () => void;
onFinish?: () => void;
} }
type GatewayEvent = type GatewayEvent =
@ -185,18 +193,26 @@ export class GatewayRuntime {
} }
async enqueueMessage(request: GatewayMessageRequest): Promise<GatewayMessageResult> { async enqueueMessage(request: GatewayMessageRequest): Promise<GatewayMessageResult> {
const managedSession = await this.ensureSession(request.sessionKey); return this.enqueueManagedMessage({ request });
}
private async enqueueManagedMessage(queuedMessage: {
request: GatewayMessageRequest;
onStart?: () => void;
onFinish?: () => void;
}): Promise<GatewayMessageResult> {
const managedSession = await this.ensureSession(queuedMessage.request.sessionKey);
if (managedSession.queue.length >= this.config.session.maxQueuePerSession) { if (managedSession.queue.length >= this.config.session.maxQueuePerSession) {
return { return {
ok: false, ok: false,
response: "", response: "",
error: `Queue full (${this.config.session.maxQueuePerSession} pending).`, error: `Queue full (${this.config.session.maxQueuePerSession} pending).`,
sessionKey: request.sessionKey, sessionKey: queuedMessage.request.sessionKey,
}; };
} }
return new Promise<GatewayMessageResult>((resolve) => { return new Promise<GatewayMessageResult>((resolve) => {
managedSession.queue.push({ request, resolve }); managedSession.queue.push({ ...queuedMessage, resolve });
this.emitState(managedSession); this.emitState(managedSession);
void this.processNext(managedSession); void this.processNext(managedSession);
}); });
@ -302,6 +318,7 @@ export class GatewayRuntime {
this.emitState(managedSession); this.emitState(managedSession);
try { try {
queued.onStart?.();
await managedSession.session.prompt(queued.request.text, { await managedSession.session.prompt(queued.request.text, {
images: queued.request.images, images: queued.request.images,
source: queued.request.source ?? "extension", source: queued.request.source ?? "extension",
@ -326,6 +343,7 @@ export class GatewayRuntime {
sessionKey: managedSession.sessionKey, sessionKey: managedSession.sessionKey,
}); });
} finally { } finally {
queued.onFinish?.();
managedSession.processing = false; managedSession.processing = false;
managedSession.lastActiveAt = Date.now(); managedSession.lastActiveAt = Date.now();
this.emitState(managedSession); this.emitState(managedSession);
@ -491,7 +509,7 @@ export class GatewayRuntime {
return; return;
} }
const sessionMatch = path.match(/^\/sessions\/([^/]+)(?:\/(events|messages|abort|reset))?$/); const sessionMatch = path.match(/^\/sessions\/([^/]+)(?:\/(events|messages|abort|reset|chat))?$/);
if (!sessionMatch) { if (!sessionMatch) {
this.writeJson(response, 404, { error: "Not found" }); this.writeJson(response, 404, { error: "Not found" });
return; return;
@ -511,6 +529,11 @@ export class GatewayRuntime {
return; return;
} }
if (action === "chat" && method === "POST") {
await this.handleChat(sessionKey, request, response);
return;
}
if (action === "messages" && method === "POST") { if (action === "messages" && method === "POST") {
const body = await this.readJsonBody(request); const body = await this.readJsonBody(request);
const text = typeof body.text === "string" ? body.text : ""; const text = typeof body.text === "string" ? body.text : "";
@ -587,6 +610,81 @@ export class GatewayRuntime {
}); });
} }
/**
 * POST /sessions/:key/chat — stream one prompt's lifecycle to the client as
 * a Vercel AI SDK v5+ UI-message SSE stream.
 *
 * The prompt is funneled through the normal per-session queue via
 * enqueueManagedMessage; the SSE listener is attached only for the span of
 * this prompt (onStart -> onFinish), so other queued prompts on the same
 * session do not leak events into this response.
 *
 * @param sessionKey - key of the session to prompt.
 * @param request - incoming HTTP request; body may be useChat messages,
 *   `{ prompt }`, or `{ text }` (see extractUserText).
 * @param response - HTTP response; written to as an SSE stream.
 */
private async handleChat(sessionKey: string, request: IncomingMessage, response: ServerResponse): Promise<void> {
  const body = await this.readJsonBody(request);
  const text = extractUserText(body);
  if (!text) {
    // No streamable work — fail fast with a plain JSON 400 before any SSE
    // headers are committed.
    this.writeJson(response, 400, { error: "Missing user message text" });
    return;
  }
  // Set up SSE response headers
  response.writeHead(200, {
    "Content-Type": "text/event-stream",
    "Cache-Control": "no-cache, no-transform",
    Connection: "keep-alive",
    "x-vercel-ai-ui-message-stream": "v1",
  });
  // Initial flush so intermediaries/clients see the stream open immediately.
  response.write("\n");
  const listener = createVercelStreamListener(response);
  let unsubscribe: (() => void) | undefined;
  let streamingActive = false;
  // Idempotent teardown: detaches the listener exactly once.
  const stopStreaming = () => {
    if (!streamingActive) return;
    streamingActive = false;
    unsubscribe?.();
    unsubscribe = undefined;
  };
  // Clean up on client disconnect
  let clientDisconnected = false;
  request.on("close", () => {
    clientDisconnected = true;
    stopStreaming();
  });
  // Drive the session through the existing queue infrastructure
  try {
    // NOTE(review): ensureSession is also called inside
    // enqueueManagedMessage — presumably idempotent; confirm, otherwise this
    // double lookup is redundant.
    const managedSession = await this.ensureSession(sessionKey);
    const result = await this.enqueueManagedMessage({
      request: {
        sessionKey,
        text,
        source: "extension",
      },
      onStart: () => {
        // Subscribe only when this prompt actually starts processing, so
        // events from earlier queued prompts are never forwarded.
        if (clientDisconnected || streamingActive) return;
        unsubscribe = managedSession.session.subscribe(listener);
        streamingActive = true;
      },
      onFinish: () => {
        stopStreaming();
      },
    });
    if (!clientDisconnected) {
      stopStreaming();
      if (result.ok) {
        finishVercelStream(response, "stop");
      } else {
        // NOTE(review): abort detection by substring match on the error
        // message is fragile — a structured flag on GatewayMessageResult
        // would be sturdier.
        const isAbort = result.error?.includes("aborted");
        if (isAbort) {
          finishVercelStream(response, "error");
        } else {
          errorVercelStream(response, result.error ?? "Unknown error");
        }
      }
    }
  } catch (error) {
    if (!clientDisconnected) {
      stopStreaming();
      const message = error instanceof Error ? error.message : String(error);
      errorVercelStream(response, message);
    }
  }
}
private requireAuth(request: IncomingMessage, response: ServerResponse): void { private requireAuth(request: IncomingMessage, response: ServerResponse): void {
if (!this.config.bearerToken) { if (!this.config.bearerToken) {
return; return;

View file

@ -0,0 +1,199 @@
import { randomUUID } from "node:crypto";
import type { ServerResponse } from "node:http";
import type { AgentSessionEvent } from "./agent-session.js";
/**
* Write a single Vercel AI SDK v5+ SSE chunk to the response.
* Format: `data: <JSON>\n\n`
* For the terminal [DONE] sentinel: `data: [DONE]\n\n`
*/
/**
 * Serialize one Vercel AI SDK v5+ SSE chunk onto the response.
 *
 * Objects are JSON-encoded; strings (e.g. the terminal "[DONE]" sentinel)
 * pass through verbatim. Wire format: `data: <payload>\n\n`. Silently skips
 * the write once the response has ended.
 */
function writeChunk(response: ServerResponse, chunk: object | string): void {
  if (response.writableEnded) {
    return;
  }
  const payload = typeof chunk !== "string" ? JSON.stringify(chunk) : chunk;
  response.write("data: " + payload + "\n\n");
}
/**
* Extract the user's text from the request body.
* Supports both useChat format ({ messages: UIMessage[] }) and simple gateway format ({ text: string }).
*/
/**
 * Extract the user's text from the request body.
 *
 * Accepts three shapes, checked in priority order:
 *  1. Simple gateway format: `{ text: string }` (returned untrimmed).
 *  2. Convenience format: `{ prompt: string }` (returned untrimmed).
 *  3. Vercel AI SDK useChat format: `{ messages: UIMessage[] }` — the most
 *     recent user message wins; both the v5+ `parts` array and the v4
 *     `content` string are supported.
 *
 * Fix vs. previous revision: a v5+ user message may carry SEVERAL text
 * parts; previously only the first one was returned, silently truncating
 * the message. All text parts are now joined with newlines. Single-part
 * messages behave exactly as before.
 *
 * @param body - parsed JSON request body.
 * @returns the extracted text, or null when no user text is present.
 */
export function extractUserText(body: Record<string, unknown>): string | null {
  // Simple gateway format
  if (typeof body.text === "string" && body.text.trim()) {
    return body.text;
  }
  // Convenience format
  if (typeof body.prompt === "string" && body.prompt.trim()) {
    return body.prompt;
  }
  // Vercel AI SDK useChat format - walk backwards to the last user message
  if (Array.isArray(body.messages)) {
    for (let i = body.messages.length - 1; i >= 0; i--) {
      const msg = body.messages[i] as Record<string, unknown>;
      if (msg.role !== "user") continue;
      // v5+ format with parts array: collect ALL text parts so multi-part
      // messages are not truncated to their first part.
      if (Array.isArray(msg.parts)) {
        const texts: string[] = [];
        for (const part of msg.parts as Array<Record<string, unknown>>) {
          if (part.type === "text" && typeof part.text === "string") {
            texts.push(part.text);
          }
        }
        if (texts.length > 0) {
          return texts.join("\n");
        }
      }
      // v4 format with content string
      if (typeof msg.content === "string" && msg.content.trim()) {
        return msg.content;
      }
    }
  }
  return null;
}
/**
* Create an AgentSessionEvent listener that translates events to Vercel AI SDK v5+ SSE
* chunks and writes them to the HTTP response.
*
* Returns the listener function. The caller is responsible for subscribing/unsubscribing.
*/
/**
 * Build an AgentSessionEvent listener that maps session events onto Vercel
 * AI SDK v5+ SSE chunks and writes them straight to the HTTP response.
 *
 * The listener is gated to a single prompt lifecycle: it opens on the first
 * `agent_start` it observes (emitting the "start" chunk exactly once),
 * forwards everything until the matching `agent_end`, and drops all events
 * outside that window. Subscribing and unsubscribing is the caller's
 * responsibility.
 *
 * @param response - SSE response stream to write chunks to.
 * @param messageId - optional stable message id; a random UUID otherwise.
 * @returns the listener function to pass to `session.subscribe`.
 */
export function createVercelStreamListener(
  response: ServerResponse,
  messageId?: string,
): (event: AgentSessionEvent) => void {
  const id = messageId ?? randomUUID();
  // True only between this prompt's agent_start and agent_end; everything
  // outside that span is ignored.
  let inPrompt = false;

  return (event: AgentSessionEvent): void => {
    if (response.writableEnded) return;

    // Lifecycle gates first: open on agent_start, close on agent_end.
    if (event.type === "agent_start") {
      if (!inPrompt) {
        inPrompt = true;
        writeChunk(response, { type: "start", messageId: id });
      }
      return;
    }
    if (event.type === "agent_end") {
      inPrompt = false;
      return;
    }
    // Events arriving outside our prompt's span are not ours to forward.
    if (!inPrompt) return;

    if (event.type === "turn_start") {
      writeChunk(response, { type: "start-step" });
    } else if (event.type === "turn_end") {
      writeChunk(response, { type: "finish-step" });
    } else if (event.type === "tool_execution_end") {
      writeChunk(response, {
        type: "tool-output-available",
        toolCallId: event.toolCallId,
        output: event.result,
      });
    } else if (event.type === "message_update") {
      const inner = event.assistantMessageEvent;
      if (inner.type === "text_start") {
        writeChunk(response, { type: "text-start", id: `text_${inner.contentIndex}` });
      } else if (inner.type === "text_delta") {
        writeChunk(response, { type: "text-delta", id: `text_${inner.contentIndex}`, delta: inner.delta });
      } else if (inner.type === "text_end") {
        writeChunk(response, { type: "text-end", id: `text_${inner.contentIndex}` });
      } else if (inner.type === "thinking_start") {
        writeChunk(response, { type: "reasoning-start", id: `reasoning_${inner.contentIndex}` });
      } else if (inner.type === "thinking_delta") {
        writeChunk(response, { type: "reasoning-delta", id: `reasoning_${inner.contentIndex}`, delta: inner.delta });
      } else if (inner.type === "thinking_end") {
        writeChunk(response, { type: "reasoning-end", id: `reasoning_${inner.contentIndex}` });
      } else if (inner.type === "toolcall_start") {
        // Tool-call chunks need the call id/name, which live on the partial
        // message content rather than the event itself.
        const started = inner.partial.content[inner.contentIndex];
        if (started?.type === "toolCall") {
          writeChunk(response, {
            type: "tool-input-start",
            toolCallId: started.id,
            toolName: started.name,
          });
        }
      } else if (inner.type === "toolcall_delta") {
        const updated = inner.partial.content[inner.contentIndex];
        if (updated?.type === "toolCall") {
          writeChunk(response, {
            type: "tool-input-delta",
            toolCallId: updated.id,
            inputTextDelta: inner.delta,
          });
        }
      } else if (inner.type === "toolcall_end") {
        writeChunk(response, {
          type: "tool-input-available",
          toolCallId: inner.toolCall.id,
          toolName: inner.toolCall.name,
          input: inner.toolCall.arguments,
        });
      }
    }
  };
}
/**
* Write the terminal finish sequence and end the response.
*/
/**
 * Terminate the SSE stream normally: emit the "finish" chunk followed by
 * the [DONE] sentinel, then close the response. No-op when the response
 * has already ended.
 */
export function finishVercelStream(response: ServerResponse, finishReason: string = "stop"): void {
  if (response.writableEnded) {
    return;
  }
  const terminal = { type: "finish", finishReason };
  writeChunk(response, terminal);
  writeChunk(response, "[DONE]");
  response.end();
}
/**
* Write an error chunk and end the response.
*/
/**
 * Terminate the SSE stream with an error: emit the "error" chunk followed
 * by the [DONE] sentinel, then close the response. No-op when the response
 * has already ended.
 */
export function errorVercelStream(response: ServerResponse, errorText: string): void {
  if (response.writableEnded) {
    return;
  }
  const failure = { type: "error", errorText };
  writeChunk(response, failure);
  writeChunk(response, "[DONE]");
  response.end();
}

View file

@ -0,0 +1,149 @@
import { describe, expect, it } from "vitest";
import type { AgentSessionEvent } from "../src/core/agent-session.js";
import { createVercelStreamListener, extractUserText } from "../src/core/vercel-ai-stream.js";
// Covers the three accepted request-body shapes (useChat messages / prompt /
// text) and the precedence between them.
describe("extractUserText", () => {
  it("extracts text from useChat v5+ format with parts", () => {
    const body = {
      messages: [{ role: "user", parts: [{ type: "text", text: "hello world" }] }],
    };
    expect(extractUserText(body)).toBe("hello world");
  });
  it("extracts text from useChat v4 format with content string", () => {
    const body = {
      messages: [{ role: "user", content: "hello world" }],
    };
    expect(extractUserText(body)).toBe("hello world");
  });
  it("extracts last user message when multiple messages present", () => {
    // The most recent user turn is the one to prompt with; earlier turns
    // and assistant messages must be skipped.
    const body = {
      messages: [
        { role: "user", parts: [{ type: "text", text: "first" }] },
        { role: "assistant", parts: [{ type: "text", text: "response" }] },
        { role: "user", parts: [{ type: "text", text: "second" }] },
      ],
    };
    expect(extractUserText(body)).toBe("second");
  });
  it("extracts text from simple gateway format", () => {
    expect(extractUserText({ text: "hello" })).toBe("hello");
  });
  it("extracts text from prompt format", () => {
    expect(extractUserText({ prompt: "hello" })).toBe("hello");
  });
  it("returns null for empty body", () => {
    expect(extractUserText({})).toBeNull();
  });
  it("returns null for empty messages array", () => {
    expect(extractUserText({ messages: [] })).toBeNull();
  });
  it("prefers text field over messages", () => {
    // `text` is the highest-priority shape; `messages` is only consulted
    // when neither `text` nor `prompt` is present.
    const body = {
      text: "direct",
      messages: [{ role: "user", parts: [{ type: "text", text: "from messages" }] }],
    };
    expect(extractUserText(body)).toBe("direct");
  });
});
// Verifies the event -> SSE-chunk translation, the writableEnded guard, and
// the agent_start/agent_end lifecycle gating.
describe("createVercelStreamListener", () => {
  // Minimal ServerResponse stand-in: records every write and tracks end().
  function createMockResponse() {
    const chunks: string[] = [];
    let ended = false;
    return {
      writableEnded: false,
      write(data: string) {
        chunks.push(data);
        return true;
      },
      end() {
        ended = true;
        this.writableEnded = true;
      },
      chunks,
      get ended() {
        return ended;
      },
    } as any;
  }
  // Decode recorded `data: <payload>\n\n` frames back into JSON objects
  // (or the raw string for non-JSON payloads like "[DONE]").
  function parseChunks(chunks: string[]): Array<object | string> {
    return chunks
      .filter((c) => c.startsWith("data: "))
      .map((c) => {
        const payload = c.replace(/^data: /, "").replace(/\n\n$/, "");
        try {
          return JSON.parse(payload);
        } catch {
          return payload;
        }
      });
  }
  it("translates text streaming events", () => {
    const response = createMockResponse();
    const listener = createVercelStreamListener(response, "test-msg-id");
    // Full happy path: agent_start -> turn -> text start/delta/end -> turn_end.
    listener({ type: "agent_start" } as AgentSessionEvent);
    listener({ type: "turn_start", turnIndex: 0, timestamp: Date.now() } as AgentSessionEvent);
    listener({
      type: "message_update",
      message: {} as any,
      assistantMessageEvent: { type: "text_start", contentIndex: 0, partial: {} as any },
    } as AgentSessionEvent);
    listener({
      type: "message_update",
      message: {} as any,
      assistantMessageEvent: { type: "text_delta", contentIndex: 0, delta: "hello", partial: {} as any },
    } as AgentSessionEvent);
    listener({
      type: "message_update",
      message: {} as any,
      assistantMessageEvent: { type: "text_end", contentIndex: 0, content: "hello", partial: {} as any },
    } as AgentSessionEvent);
    listener({ type: "turn_end", turnIndex: 0, message: {} as any, toolResults: [] } as AgentSessionEvent);
    const parsed = parseChunks(response.chunks);
    // Exact wire sequence expected by Vercel AI SDK UI-message consumers.
    expect(parsed).toEqual([
      { type: "start", messageId: "test-msg-id" },
      { type: "start-step" },
      { type: "text-start", id: "text_0" },
      { type: "text-delta", id: "text_0", delta: "hello" },
      { type: "text-end", id: "text_0" },
      { type: "finish-step" },
    ]);
  });
  it("does not write after response has ended", () => {
    const response = createMockResponse();
    const listener = createVercelStreamListener(response, "test-msg-id");
    listener({ type: "agent_start" } as AgentSessionEvent);
    response.end();
    // Events after end() must be dropped by the writableEnded guard.
    listener({ type: "turn_start", turnIndex: 0, timestamp: Date.now() } as AgentSessionEvent);
    const parsed = parseChunks(response.chunks);
    expect(parsed).toEqual([{ type: "start", messageId: "test-msg-id" }]);
  });
  it("ignores events outside the active prompt lifecycle", () => {
    const response = createMockResponse();
    const listener = createVercelStreamListener(response, "test-msg-id");
    // Before agent_start: dropped. Between start and end: forwarded.
    // After agent_end: dropped again.
    listener({ type: "turn_start", turnIndex: 0, timestamp: Date.now() } as AgentSessionEvent);
    listener({ type: "agent_start" } as AgentSessionEvent);
    listener({ type: "turn_start", turnIndex: 0, timestamp: Date.now() } as AgentSessionEvent);
    listener({ type: "agent_end", messages: [] } as AgentSessionEvent);
    listener({ type: "turn_start", turnIndex: 1, timestamp: Date.now() } as AgentSessionEvent);
    const parsed = parseChunks(response.chunks);
    expect(parsed).toEqual([{ type: "start", messageId: "test-msg-id" }, { type: "start-step" }]);
  });
});