Merge pull request #270 from getcompanion-ai/chat-single

single chat sot
This commit is contained in:
Hari 2026-03-09 01:54:04 -04:00 committed by GitHub
commit 444ab5820d
6 changed files with 443 additions and 84 deletions

View file

@ -12,6 +12,7 @@ export type {
GatewayMessageResult,
GatewayRuntimeOptions,
GatewaySessionFactory,
GatewaySessionState,
GatewaySessionSnapshot,
HistoryMessage,
HistoryPart,

View file

@ -1,9 +1,11 @@
import type { AgentMessage } from "@mariozechner/pi-agent-core";
import type { AgentSession } from "../agent-session.js";
import type {
GatewayMessageRequest,
GatewayMessageResult,
GatewaySessionSnapshot,
} from "./types.js";
import type { GatewayTransientToolResult } from "./session-state.js";
export interface GatewayQueuedMessage {
request: GatewayMessageRequest;
@ -60,6 +62,8 @@ export interface ManagedGatewaySession {
session: AgentSession;
queue: GatewayQueuedMessage[];
processing: boolean;
activeAssistantMessage: AgentMessage | null;
pendingToolResults: GatewayTransientToolResult[];
createdAt: number;
lastActiveAt: number;
listeners: Set<(event: GatewayEvent) => void>;

View file

@ -7,7 +7,6 @@ import {
import { rm } from "node:fs/promises";
import { join } from "node:path";
import { URL } from "node:url";
import type { AgentMessage } from "@mariozechner/pi-agent-core";
import type { AgentSession, AgentSessionEvent } from "../agent-session.js";
import { extractMessageText, getLastAssistantText } from "./helpers.js";
import {
@ -24,6 +23,7 @@ import type {
GatewayMessageResult,
GatewayRuntimeOptions,
GatewaySessionFactory,
GatewaySessionState,
GatewaySessionSnapshot,
HistoryMessage,
HistoryPart,
@ -36,6 +36,10 @@ import {
extractUserText,
finishVercelStream,
} from "./vercel-ai-stream.js";
import {
buildGatewaySessionStateMessages,
messageContentToHistoryParts,
} from "./session-state.js";
export {
createGatewaySessionManager,
sanitizeSessionKey,
@ -47,6 +51,7 @@ export type {
GatewayMessageResult,
GatewayRuntimeOptions,
GatewaySessionFactory,
GatewaySessionState,
GatewaySessionSnapshot,
HistoryMessage,
HistoryPart,
@ -246,6 +251,8 @@ export class GatewayRuntime {
this.rejectQueuedMessages(managedSession, "Session reset");
await managedSession.session.newSession();
managedSession.processing = false;
managedSession.activeAssistantMessage = null;
managedSession.pendingToolResults = [];
managedSession.lastActiveAt = Date.now();
this.emitState(managedSession);
return;
@ -284,6 +291,8 @@ export class GatewayRuntime {
session,
queue: [],
processing: false,
activeAssistantMessage: null,
pendingToolResults: [],
createdAt: Date.now(),
lastActiveAt: Date.now(),
listeners: new Set(),
@ -359,6 +368,8 @@ export class GatewayRuntime {
} finally {
queued.onFinish?.();
managedSession.processing = false;
managedSession.activeAssistantMessage = null;
managedSession.pendingToolResults = [];
managedSession.lastActiveAt = Date.now();
this.emitState(managedSession);
if (managedSession.queue.length > 0) {
@ -396,18 +407,24 @@ export class GatewayRuntime {
): void {
switch (event.type) {
case "turn_start":
managedSession.lastActiveAt = Date.now();
this.emit(managedSession, {
type: "turn_start",
sessionKey: managedSession.sessionKey,
});
return;
case "turn_end":
managedSession.lastActiveAt = Date.now();
this.emit(managedSession, {
type: "turn_end",
sessionKey: managedSession.sessionKey,
});
return;
case "message_start":
managedSession.lastActiveAt = Date.now();
if (event.message.role === "assistant") {
managedSession.activeAssistantMessage = event.message;
}
this.emit(managedSession, {
type: "message_start",
sessionKey: managedSession.sessionKey,
@ -415,6 +432,10 @@ export class GatewayRuntime {
});
return;
case "message_update":
managedSession.lastActiveAt = Date.now();
if (event.message.role === "assistant") {
managedSession.activeAssistantMessage = event.message;
}
switch (event.assistantMessageEvent.type) {
case "text_delta":
this.emit(managedSession, {
@ -435,15 +456,32 @@ export class GatewayRuntime {
}
return;
case "message_end":
managedSession.lastActiveAt = Date.now();
if (event.message.role === "assistant") {
managedSession.activeAssistantMessage = null;
this.emit(managedSession, {
type: "message_complete",
sessionKey: managedSession.sessionKey,
text: extractMessageText(event.message),
});
return;
}
if (event.message.role === "toolResult") {
const toolCallId =
typeof (event.message as { toolCallId?: unknown }).toolCallId ===
"string"
? ((event.message as { toolCallId: string }).toolCallId ?? "")
: "";
if (toolCallId) {
managedSession.pendingToolResults =
managedSession.pendingToolResults.filter(
(entry) => entry.toolCallId !== toolCallId,
);
}
}
return;
case "tool_execution_start":
managedSession.lastActiveAt = Date.now();
this.emit(managedSession, {
type: "tool_start",
sessionKey: managedSession.sessionKey,
@ -453,6 +491,7 @@ export class GatewayRuntime {
});
return;
case "tool_execution_update":
managedSession.lastActiveAt = Date.now();
this.emit(managedSession, {
type: "tool_update",
sessionKey: managedSession.sessionKey,
@ -462,6 +501,19 @@ export class GatewayRuntime {
});
return;
case "tool_execution_end":
managedSession.lastActiveAt = Date.now();
managedSession.pendingToolResults = [
...managedSession.pendingToolResults.filter(
(entry) => entry.toolCallId !== event.toolCallId,
),
{
toolCallId: event.toolCallId,
toolName: event.toolName,
result: event.result,
isError: event.isError,
timestamp: Date.now(),
},
];
this.emit(managedSession, {
type: "tool_complete",
sessionKey: managedSession.sessionKey,
@ -491,6 +543,20 @@ export class GatewayRuntime {
});
}
private createSessionState(
managedSession: ManagedGatewaySession,
): GatewaySessionState {
return {
session: this.createSnapshot(managedSession),
messages: buildGatewaySessionStateMessages({
sessionKey: managedSession.sessionKey,
rawMessages: managedSession.session.messages,
activeAssistantMessage: managedSession.activeAssistantMessage,
pendingToolResults: managedSession.pendingToolResults,
}),
};
}
private createSnapshot(
managedSession: ManagedGatewaySession,
): GatewaySessionSnapshot {
@ -731,7 +797,7 @@ export class GatewayRuntime {
}
const sessionMatch = path.match(
/^\/sessions\/([^/]+)(?:\/(events|messages|abort|reset|chat|history|model|reload))?$/,
/^\/sessions\/([^/]+)(?:\/(events|messages|abort|reset|chat|history|model|reload|state))?$/,
);
if (!sessionMatch) {
this.writeJson(response, 404, { error: "Not found" });
@ -809,6 +875,12 @@ export class GatewayRuntime {
return;
}
if (action === "state" && method === "GET") {
const session = await this.ensureSession(sessionKey);
this.writeJson(response, 200, this.createSessionState(session));
return;
}
if (action === "model" && method === "POST") {
const body = await this.readJsonBody(request);
const provider = typeof body.provider === "string" ? body.provider : "";
@ -1080,7 +1152,7 @@ export class GatewayRuntime {
messages.push({
id: `${msg.timestamp}-${msg.role}-${index}`,
role: msg.role,
parts: this.messageContentToParts(msg),
parts: messageContentToHistoryParts(msg),
timestamp: msg.timestamp,
});
}
@ -1173,87 +1245,6 @@ export class GatewayRuntime {
managed.session.settingsManager.reload();
}
private messageContentToParts(msg: AgentMessage): HistoryPart[] {
if (msg.role === "user") {
const content = msg.content;
if (typeof content === "string") {
return [{ type: "text", text: content }];
}
if (Array.isArray(content)) {
return content
.filter(
(c): c is { type: "text"; text: string } =>
typeof c === "object" && c !== null && c.type === "text",
)
.map((c) => ({ type: "text" as const, text: c.text }));
}
return [];
}
if (msg.role === "assistant") {
const content = msg.content;
if (!Array.isArray(content)) return [];
const parts: HistoryPart[] = [];
for (const c of content) {
if (typeof c !== "object" || c === null) continue;
if (c.type === "text") {
parts.push({
type: "text",
text: (c as { type: "text"; text: string }).text,
});
} else if (c.type === "thinking") {
parts.push({
type: "reasoning",
text: (c as { type: "thinking"; thinking: string }).thinking,
});
} else if (c.type === "toolCall") {
const tc = c as {
type: "toolCall";
id: string;
name: string;
arguments: unknown;
};
parts.push({
type: "tool-invocation",
toolCallId: tc.id,
toolName: tc.name,
args: tc.arguments,
state: "call",
});
}
}
return parts;
}
if (msg.role === "toolResult") {
const tr = msg as {
role: "toolResult";
toolCallId: string;
toolName: string;
content: unknown;
isError: boolean;
};
const textParts = Array.isArray(tr.content)
? (tr.content as { type: string; text?: string }[])
.filter((c) => c.type === "text" && typeof c.text === "string")
.map((c) => c.text as string)
.join("")
: "";
return [
{
type: "tool-invocation",
toolCallId: tr.toolCallId,
toolName: tr.toolName,
args: undefined,
state: tr.isError ? "error" : "result",
result: textParts,
},
];
}
return [];
}
getGatewaySessionDir(sessionKey: string): string {
return join(this.sessionDirRoot, sanitizeSessionKey(sessionKey));
}

View file

@ -0,0 +1,195 @@
import type { AgentMessage } from "@mariozechner/pi-agent-core";
import type { HistoryMessage, HistoryPart } from "./types.js";
export interface GatewayTransientToolResult {
toolCallId: string;
toolName: string;
result: unknown;
isError: boolean;
timestamp: number;
}
function isSupportedHistoryRole(
role: AgentMessage["role"],
): role is "user" | "assistant" | "toolResult" {
return role === "user" || role === "assistant" || role === "toolResult";
}
function historyMessageId(message: AgentMessage, index: number): string {
return `${message.timestamp}-${message.role}-${index}`;
}
function transientAssistantId(
sessionKey: string,
message: AgentMessage | null,
): string {
return `draft:${sessionKey}:${message?.timestamp ?? 0}`;
}
function transientToolResultId(sessionKey: string, toolCallId: string): string {
return `draft-tool:${sessionKey}:${toolCallId}`;
}
export function messageContentToHistoryParts(msg: AgentMessage): HistoryPart[] {
if (msg.role === "user") {
const content = msg.content;
if (typeof content === "string") {
return [{ type: "text", text: content }];
}
if (Array.isArray(content)) {
return content
.filter(
(contentPart): contentPart is { type: "text"; text: string } =>
typeof contentPart === "object" &&
contentPart !== null &&
contentPart.type === "text",
)
.map((contentPart) => ({
type: "text" as const,
text: contentPart.text,
}));
}
return [];
}
if (msg.role === "assistant") {
const content = msg.content;
if (!Array.isArray(content)) return [];
const parts: HistoryPart[] = [];
for (const contentPart of content) {
if (typeof contentPart !== "object" || contentPart === null) {
continue;
}
if (contentPart.type === "text") {
parts.push({
type: "text",
text: (contentPart as { type: "text"; text: string }).text,
});
} else if (contentPart.type === "thinking") {
parts.push({
type: "reasoning",
text: (contentPart as { type: "thinking"; thinking: string })
.thinking,
});
} else if (contentPart.type === "toolCall") {
const toolCall = contentPart as {
type: "toolCall";
id: string;
name: string;
arguments: unknown;
};
parts.push({
type: "tool-invocation",
toolCallId: toolCall.id,
toolName: toolCall.name,
args: toolCall.arguments,
state: "call",
});
}
}
return parts;
}
if (msg.role === "toolResult") {
const toolResult = msg as {
role: "toolResult";
toolCallId: string;
toolName: string;
content: unknown;
isError: boolean;
};
const textParts = Array.isArray(toolResult.content)
? (toolResult.content as { type: string; text?: string }[])
.filter((contentPart) => {
return (
contentPart.type === "text" &&
typeof contentPart.text === "string"
);
})
.map((contentPart) => contentPart.text as string)
.join("")
: "";
return [
{
type: "tool-invocation",
toolCallId: toolResult.toolCallId,
toolName: toolResult.toolName,
args: undefined,
state: toolResult.isError ? "error" : "result",
result: textParts,
},
];
}
return [];
}
export function buildGatewaySessionStateMessages(params: {
sessionKey: string;
rawMessages: AgentMessage[];
activeAssistantMessage: AgentMessage | null;
pendingToolResults: GatewayTransientToolResult[];
}): HistoryMessage[] {
const {
sessionKey,
rawMessages,
activeAssistantMessage,
pendingToolResults,
} = params;
const messages: HistoryMessage[] = [];
const persistedToolCallIds = new Set<string>();
for (const [index, message] of rawMessages.entries()) {
if (!isSupportedHistoryRole(message.role)) {
continue;
}
if (
message.role === "toolResult" &&
typeof (message as { toolCallId?: unknown }).toolCallId === "string"
) {
persistedToolCallIds.add((message as { toolCallId: string }).toolCallId);
}
messages.push({
id: historyMessageId(message, index),
role: message.role,
parts: messageContentToHistoryParts(message),
timestamp: message.timestamp,
});
}
if (activeAssistantMessage?.role === "assistant") {
messages.push({
id: transientAssistantId(sessionKey, activeAssistantMessage),
role: "assistant",
parts: messageContentToHistoryParts(activeAssistantMessage),
timestamp: activeAssistantMessage.timestamp ?? Date.now(),
});
}
for (const pendingToolResult of pendingToolResults) {
if (persistedToolCallIds.has(pendingToolResult.toolCallId)) {
continue;
}
messages.push({
id: transientToolResultId(sessionKey, pendingToolResult.toolCallId),
role: "toolResult",
parts: [
{
type: "tool-invocation",
toolCallId: pendingToolResult.toolCallId,
toolName: pendingToolResult.toolName,
args: undefined,
state: pendingToolResult.isError ? "error" : "result",
result: pendingToolResult.result,
},
],
timestamp: pendingToolResult.timestamp,
});
}
return messages;
}

View file

@ -48,6 +48,11 @@ export interface GatewaySessionSnapshot {
updatedAt: number;
}
export interface GatewaySessionState {
session: GatewaySessionSnapshot;
messages: HistoryMessage[];
}
export interface ModelInfo {
provider: string;
modelId: string;

View file

@ -0,0 +1,163 @@
import test from "node:test";
import assert from "node:assert/strict";
import {
buildGatewaySessionStateMessages,
messageContentToHistoryParts,
} from "../src/core/gateway/session-state.ts";
test("messageContentToHistoryParts converts assistant text, reasoning, and tool calls", () => {
const parts = messageContentToHistoryParts({
role: "assistant",
timestamp: 123,
content: [
{ type: "text", text: "hello" },
{ type: "thinking", thinking: "working" },
{
type: "toolCall",
id: "tool-1",
name: "bash",
arguments: { cmd: "pwd" },
},
],
});
assert.deepEqual(parts, [
{ type: "text", text: "hello" },
{ type: "reasoning", text: "working" },
{
type: "tool-invocation",
toolCallId: "tool-1",
toolName: "bash",
args: { cmd: "pwd" },
state: "call",
},
]);
});
test("buildGatewaySessionStateMessages includes the active assistant draft while a run is still processing", () => {
const messages = buildGatewaySessionStateMessages({
sessionKey: "agent:test:chat",
rawMessages: [
{
role: "user",
timestamp: 100,
content: "build a todo app",
},
],
activeAssistantMessage: {
role: "assistant",
timestamp: 200,
content: [
{ type: "text", text: "Working on it" },
{
type: "toolCall",
id: "tool-1",
name: "write",
arguments: { filePath: "app.tsx" },
},
],
},
pendingToolResults: [],
});
assert.deepEqual(messages, [
{
id: "100-user-0",
role: "user",
parts: [{ type: "text", text: "build a todo app" }],
timestamp: 100,
},
{
id: "draft:agent:test:chat:200",
role: "assistant",
parts: [
{ type: "text", text: "Working on it" },
{
type: "tool-invocation",
toolCallId: "tool-1",
toolName: "write",
args: { filePath: "app.tsx" },
state: "call",
},
],
timestamp: 200,
},
]);
});
test("buildGatewaySessionStateMessages keeps transient tool results until persisted history catches up", () => {
const messages = buildGatewaySessionStateMessages({
sessionKey: "agent:test:chat",
rawMessages: [
{
role: "user",
timestamp: 100,
content: "ship it",
},
{
role: "toolResult",
timestamp: 210,
toolCallId: "tool-1",
toolName: "bash",
content: [{ type: "text", text: "done" }],
isError: false,
},
],
activeAssistantMessage: null,
pendingToolResults: [
{
toolCallId: "tool-1",
toolName: "bash",
result: { stdout: "done" },
isError: false,
timestamp: 220,
},
{
toolCallId: "tool-2",
toolName: "write",
result: { path: "README.md" },
isError: false,
timestamp: 240,
},
],
});
assert.deepEqual(messages, [
{
id: "100-user-0",
role: "user",
parts: [{ type: "text", text: "ship it" }],
timestamp: 100,
},
{
id: "210-toolResult-1",
role: "toolResult",
parts: [
{
type: "tool-invocation",
toolCallId: "tool-1",
toolName: "bash",
args: undefined,
state: "result",
result: "done",
},
],
timestamp: 210,
},
{
id: "draft-tool:agent:test:chat:tool-2",
role: "toolResult",
parts: [
{
type: "tool-invocation",
toolCallId: "tool-2",
toolName: "write",
args: undefined,
state: "result",
result: { path: "README.md" },
},
],
timestamp: 240,
},
]);
});