This commit is contained in:
Harivansh Rathi 2026-03-06 17:40:33 -08:00
parent 48d3e90b8c
commit 6bdf0ec058

View file

@ -1,8 +1,8 @@
import { createServer, type IncomingMessage, type Server, type ServerResponse } from "node:http"; import { createServer, type IncomingMessage, type Server, type ServerResponse } from "node:http";
import { join } from "node:path"; import { join } from "node:path";
import { URL } from "node:url"; import { URL } from "node:url";
import type { ImageContent } from "@mariozechner/pi-ai";
import type { AgentMessage } from "@mariozechner/pi-agent-core"; import type { AgentMessage } from "@mariozechner/pi-agent-core";
import type { ImageContent } from "@mariozechner/pi-ai";
import type { AgentSession, AgentSessionEvent } from "./agent-session.js"; import type { AgentSession, AgentSessionEvent } from "./agent-session.js";
import { SessionManager } from "./session-manager.js"; import { SessionManager } from "./session-manager.js";
import type { Settings } from "./settings-manager.js"; import type { Settings } from "./settings-manager.js";
@ -132,6 +132,15 @@ interface ManagedGatewaySession {
unsubscribe: () => void; unsubscribe: () => void;
} }
class HttpError extends Error {
constructor(
public readonly statusCode: number,
message: string,
) {
super(message);
}
}
let activeGatewayRuntime: GatewayRuntime | null = null; let activeGatewayRuntime: GatewayRuntime | null = null;
export function setActiveGatewayRuntime(runtime: GatewayRuntime | null): void { export function setActiveGatewayRuntime(runtime: GatewayRuntime | null): void {
@ -179,7 +188,10 @@ export class GatewayRuntime {
this.server = createServer((request, response) => { this.server = createServer((request, response) => {
void this.handleHttpRequest(request, response).catch((error) => { void this.handleHttpRequest(request, response).catch((error) => {
const message = error instanceof Error ? error.message : String(error); const message = error instanceof Error ? error.message : String(error);
this.writeJson(response, 500, { error: message }); const statusCode = error instanceof HttpError ? error.statusCode : 500;
if (!response.writableEnded) {
this.writeJson(response, statusCode, { error: message });
}
}); });
}); });
@ -293,18 +305,20 @@ export class GatewayRuntime {
const managedSession = this.sessions.get(sessionKey); const managedSession = this.sessions.get(sessionKey);
if (!managedSession) return; if (!managedSession) return;
if (managedSession.processing) {
await managedSession.session.abort();
}
if (sessionKey === this.primarySessionKey) { if (sessionKey === this.primarySessionKey) {
this.rejectQueuedMessages(managedSession, "Session reset");
await managedSession.session.newSession(); await managedSession.session.newSession();
managedSession.queue.length = 0;
managedSession.processing = false; managedSession.processing = false;
managedSession.lastActiveAt = Date.now(); managedSession.lastActiveAt = Date.now();
this.emitState(managedSession); this.emitState(managedSession);
return; return;
} }
if (managedSession.processing) { this.rejectQueuedMessages(managedSession, "Session reset");
await managedSession.session.abort();
}
managedSession.unsubscribe(); managedSession.unsubscribe();
managedSession.session.dispose(); managedSession.session.dispose();
this.sessions.delete(sessionKey); this.sessions.delete(sessionKey);
@ -393,6 +407,26 @@ export class GatewayRuntime {
} }
} }
private getManagedSessionOrThrow(sessionKey: string): ManagedGatewaySession {
const managedSession = this.sessions.get(sessionKey);
if (!managedSession) {
throw new HttpError(404, `Session not found: ${sessionKey}`);
}
return managedSession;
}
private rejectQueuedMessages(managedSession: ManagedGatewaySession, error: string): void {
const queuedMessages = managedSession.queue.splice(0);
for (const queuedMessage of queuedMessages) {
queuedMessage.resolve({
ok: false,
response: "",
error,
sessionKey: managedSession.sessionKey,
});
}
}
private handleSessionEvent(managedSession: ManagedGatewaySession, event: AgentSessionEvent): void { private handleSessionEvent(managedSession: ManagedGatewaySession, event: AgentSessionEvent): void {
switch (event.type) { switch (event.type) {
case "turn_start": case "turn_start":
@ -587,7 +621,7 @@ export class GatewayRuntime {
return; return;
} }
if (method === "POST" && path === "/config") { if (method === "PATCH" && path === "/config") {
const body = await this.readJsonBody(request); const body = await this.readJsonBody(request);
await this.handlePatchConfig(body); await this.handlePatchConfig(body);
this.writeJson(response, 200, { ok: true }); this.writeJson(response, 200, { ok: true });
@ -606,7 +640,9 @@ export class GatewayRuntime {
return; return;
} }
const sessionMatch = path.match(/^\/sessions\/([^/]+)(?:\/(events|messages|abort|reset|chat|history|model|reload))?$/); const sessionMatch = path.match(
/^\/sessions\/([^/]+)(?:\/(events|messages|abort|reset|chat|history|model|reload))?$/,
);
if (!sessionMatch) { if (!sessionMatch) {
this.writeJson(response, 404, { error: "Not found" }); this.writeJson(response, 404, { error: "Not found" });
return; return;
@ -616,7 +652,7 @@ export class GatewayRuntime {
const action = sessionMatch[2]; const action = sessionMatch[2];
if (!action && method === "GET") { if (!action && method === "GET") {
const session = await this.ensureSession(sessionKey); const session = this.getManagedSessionOrThrow(sessionKey);
this.writeJson(response, 200, { session: this.createSnapshot(session) }); this.writeJson(response, 200, { session: this.createSnapshot(session) });
return; return;
} }
@ -661,11 +697,13 @@ export class GatewayRuntime {
} }
if (action === "abort" && method === "POST") { if (action === "abort" && method === "POST") {
this.getManagedSessionOrThrow(sessionKey);
this.writeJson(response, 200, { ok: this.abortSession(sessionKey) }); this.writeJson(response, 200, { ok: this.abortSession(sessionKey) });
return; return;
} }
if (action === "reset" && method === "POST") { if (action === "reset" && method === "POST") {
this.getManagedSessionOrThrow(sessionKey);
await this.resetSession(sessionKey); await this.resetSession(sessionKey);
this.writeJson(response, 200, { ok: true }); this.writeJson(response, 200, { ok: true });
return; return;
@ -837,7 +875,11 @@ export class GatewayRuntime {
return {}; return {};
} }
const body = Buffer.concat(chunks).toString("utf8"); const body = Buffer.concat(chunks).toString("utf8");
return JSON.parse(body) as Record<string, unknown>; try {
return JSON.parse(body) as Record<string, unknown>;
} catch {
throw new HttpError(400, "Invalid JSON body");
}
} }
private writeJson(response: ServerResponse, statusCode: number, payload: unknown): void { private writeJson(response: ServerResponse, statusCode: number, payload: unknown): void {
@ -850,16 +892,16 @@ export class GatewayRuntime {
// New handler methods added for companion-cloud web app integration // New handler methods added for companion-cloud web app integration
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
private async handleGetModels(): Promise<{ models: ModelInfo[]; current: { provider: string; modelId: string } | null }> { private async handleGetModels(): Promise<{
models: ModelInfo[];
current: { provider: string; modelId: string } | null;
}> {
const available = this.primarySession.modelRegistry.getAvailable(); const available = this.primarySession.modelRegistry.getAvailable();
const models: ModelInfo[] = available.map((m) => ({ const models: ModelInfo[] = available.map((m) => ({
provider: m.provider, provider: m.provider,
modelId: m.id, modelId: m.id,
displayName: m.name, displayName: m.name,
capabilities: [ capabilities: [...(m.reasoning ? ["reasoning"] : []), ...(m.input.includes("image") ? ["vision"] : [])],
...(m.reasoning ? ["reasoning"] : []),
...(m.input.includes("image") ? ["vision"] : []),
],
})); }));
const currentModel = this.primarySession.model; const currentModel = this.primarySession.model;
const current = currentModel ? { provider: currentModel.provider, modelId: currentModel.id } : null; const current = currentModel ? { provider: currentModel.provider, modelId: currentModel.id } : null;
@ -871,23 +913,20 @@ export class GatewayRuntime {
provider: string, provider: string,
modelId: string, modelId: string,
): Promise<{ ok: true; model: { provider: string; modelId: string } }> { ): Promise<{ ok: true; model: { provider: string; modelId: string } }> {
const managed = this.sessions.get(sessionKey); const managed = this.getManagedSessionOrThrow(sessionKey);
if (!managed) {
throw new Error(`Session not found: ${sessionKey}`);
}
const found = managed.session.modelRegistry.find(provider, modelId); const found = managed.session.modelRegistry.find(provider, modelId);
if (!found) { if (!found) {
throw new Error(`Model not found: ${provider}/${modelId}`); throw new HttpError(404, `Model not found: ${provider}/${modelId}`);
} }
await managed.session.setModel(found); await managed.session.setModel(found);
return { ok: true, model: { provider, modelId } }; return { ok: true, model: { provider, modelId } };
} }
private handleGetHistory(sessionKey: string, limit?: number): HistoryMessage[] { private handleGetHistory(sessionKey: string, limit?: number): HistoryMessage[] {
const managed = this.sessions.get(sessionKey); if (limit !== undefined && (!Number.isFinite(limit) || limit < 1)) {
if (!managed) { throw new HttpError(400, "History limit must be a positive integer");
return [];
} }
const managed = this.getManagedSessionOrThrow(sessionKey);
const rawMessages = managed.session.messages; const rawMessages = managed.session.messages;
const messages: HistoryMessage[] = []; const messages: HistoryMessage[] = [];
for (const msg of rawMessages) { for (const msg of rawMessages) {
@ -905,27 +944,26 @@ export class GatewayRuntime {
} }
private async handlePatchSession(sessionKey: string, patch: { name?: string }): Promise<void> { private async handlePatchSession(sessionKey: string, patch: { name?: string }): Promise<void> {
const managed = this.sessions.get(sessionKey); const managed = this.getManagedSessionOrThrow(sessionKey);
if (!managed) {
throw new Error(`Session not found: ${sessionKey}`);
}
if (patch.name !== undefined) { if (patch.name !== undefined) {
// Labels in pi-mono are per-entry; we label the current leaf entry // Labels in pi-mono are per-entry; we label the current leaf entry
const leafId = managed.session.sessionManager.getLeafId?.(); const leafId = managed.session.sessionManager.getLeafId();
if (leafId) { if (!leafId) {
managed.session.sessionManager.appendLabelChange(leafId, patch.name); throw new HttpError(409, `Cannot rename session without an active leaf entry: ${sessionKey}`);
} }
managed.session.sessionManager.appendLabelChange(leafId, patch.name);
} }
} }
private async handleDeleteSession(sessionKey: string): Promise<void> { private async handleDeleteSession(sessionKey: string): Promise<void> {
if (sessionKey === this.primarySessionKey) { if (sessionKey === this.primarySessionKey) {
throw new Error("Cannot delete primary session"); throw new HttpError(400, "Cannot delete primary session");
} }
const managed = this.sessions.get(sessionKey); const managed = this.getManagedSessionOrThrow(sessionKey);
if (!managed) { if (managed.processing) {
throw new Error(`Session not found: ${sessionKey}`); await managed.session.abort();
} }
this.rejectQueuedMessages(managed, `Session deleted: ${sessionKey}`);
managed.unsubscribe(); managed.unsubscribe();
managed.session.dispose(); managed.session.dispose();
this.sessions.delete(sessionKey); this.sessions.delete(sessionKey);
@ -934,8 +972,16 @@ export class GatewayRuntime {
private getPublicConfig(): Record<string, unknown> { private getPublicConfig(): Record<string, unknown> {
const settings = this.primarySession.settingsManager.getGlobalSettings(); const settings = this.primarySession.settingsManager.getGlobalSettings();
const { gateway, ...rest } = settings as Record<string, unknown> & { gateway?: Record<string, unknown> }; const { gateway, ...rest } = settings as Record<string, unknown> & { gateway?: Record<string, unknown> };
const { bearerToken: _bearerToken, ...safeGateway } = gateway ?? {}; const { bearerToken: _bearerToken, webhook, ...safeGatewayRest } = gateway ?? {};
return { ...rest, gateway: safeGateway }; const { secret: _secret, ...safeWebhook } =
webhook && typeof webhook === "object" ? (webhook as Record<string, unknown>) : {};
return {
...rest,
gateway: {
...safeGatewayRest,
...(webhook && typeof webhook === "object" ? { webhook: safeWebhook } : {}),
},
};
} }
private async handlePatchConfig(patch: Record<string, unknown>): Promise<void> { private async handlePatchConfig(patch: Record<string, unknown>): Promise<void> {
@ -954,10 +1000,7 @@ export class GatewayRuntime {
} }
private async handleReloadSession(sessionKey: string): Promise<void> { private async handleReloadSession(sessionKey: string): Promise<void> {
const managed = this.sessions.get(sessionKey); const managed = this.getManagedSessionOrThrow(sessionKey);
if (!managed) {
throw new Error(`Session not found: ${sessionKey}`);
}
// Reloading config by calling settingsManager.reload() on the session // Reloading config by calling settingsManager.reload() on the session
managed.session.settingsManager.reload(); managed.session.settingsManager.reload();
} }
@ -970,7 +1013,9 @@ export class GatewayRuntime {
} }
if (Array.isArray(content)) { if (Array.isArray(content)) {
return content return content
.filter((c): c is { type: "text"; text: string } => typeof c === "object" && c !== null && c.type === "text") .filter(
(c): c is { type: "text"; text: string } => typeof c === "object" && c !== null && c.type === "text",
)
.map((c) => ({ type: "text" as const, text: c.text })); .map((c) => ({ type: "text" as const, text: c.text }));
} }
return []; return [];
@ -1001,7 +1046,13 @@ export class GatewayRuntime {
} }
if (msg.role === "toolResult") { if (msg.role === "toolResult") {
const tr = msg as { role: "toolResult"; toolCallId: string; toolName: string; content: unknown; isError: boolean }; const tr = msg as {
role: "toolResult";
toolCallId: string;
toolName: string;
content: unknown;
isError: boolean;
};
const textParts = Array.isArray(tr.content) const textParts = Array.isArray(tr.content)
? (tr.content as { type: string; text?: string }[]) ? (tr.content as { type: string; text?: string }[])
.filter((c) => c.type === "text" && typeof c.text === "string") .filter((c) => c.type === "text" && typeof c.text === "string")