diff --git a/packages/coding-agent/src/core/gateway-runtime.ts b/packages/coding-agent/src/core/gateway-runtime.ts index 570942b1..55e069c4 100644 --- a/packages/coding-agent/src/core/gateway-runtime.ts +++ b/packages/coding-agent/src/core/gateway-runtime.ts @@ -1,8 +1,8 @@ import { createServer, type IncomingMessage, type Server, type ServerResponse } from "node:http"; import { join } from "node:path"; import { URL } from "node:url"; -import type { ImageContent } from "@mariozechner/pi-ai"; import type { AgentMessage } from "@mariozechner/pi-agent-core"; +import type { ImageContent } from "@mariozechner/pi-ai"; import type { AgentSession, AgentSessionEvent } from "./agent-session.js"; import { SessionManager } from "./session-manager.js"; import type { Settings } from "./settings-manager.js"; @@ -132,6 +132,15 @@ interface ManagedGatewaySession { unsubscribe: () => void; } +class HttpError extends Error { + constructor( + public readonly statusCode: number, + message: string, + ) { + super(message); + } +} + let activeGatewayRuntime: GatewayRuntime | null = null; export function setActiveGatewayRuntime(runtime: GatewayRuntime | null): void { @@ -179,7 +188,10 @@ export class GatewayRuntime { this.server = createServer((request, response) => { void this.handleHttpRequest(request, response).catch((error) => { const message = error instanceof Error ? error.message : String(error); - this.writeJson(response, 500, { error: message }); + const statusCode = error instanceof HttpError ? error.statusCode : 500; + if (!response.writableEnded) { + this.writeJson(response, statusCode, { error: message }); + } }); }); @@ -293,18 +305,20 @@ export class GatewayRuntime { const managedSession = this.sessions.get(sessionKey); if (!managedSession) return; + if (managedSession.processing) { + await managedSession.session.abort(); + } + if (sessionKey === this.primarySessionKey) { + this.rejectQueuedMessages(managedSession, "Session reset"); await managedSession.session.newSession(); - managedSession.queue.length = 0; managedSession.processing = false; managedSession.lastActiveAt = Date.now(); this.emitState(managedSession); return; } - if (managedSession.processing) { - await managedSession.session.abort(); - } + this.rejectQueuedMessages(managedSession, "Session reset"); managedSession.unsubscribe(); managedSession.session.dispose(); this.sessions.delete(sessionKey); @@ -393,6 +407,26 @@ export class GatewayRuntime { } } + private getManagedSessionOrThrow(sessionKey: string): ManagedGatewaySession { + const managedSession = this.sessions.get(sessionKey); + if (!managedSession) { + throw new HttpError(404, `Session not found: ${sessionKey}`); + } + return managedSession; + } + + private rejectQueuedMessages(managedSession: ManagedGatewaySession, error: string): void { + const queuedMessages = managedSession.queue.splice(0); + for (const queuedMessage of queuedMessages) { + queuedMessage.resolve({ + ok: false, + response: "", + error, + sessionKey: managedSession.sessionKey, + }); + } + } + private handleSessionEvent(managedSession: ManagedGatewaySession, event: AgentSessionEvent): void { switch (event.type) { case "turn_start": @@ -587,7 +621,7 @@ export class GatewayRuntime { return; } - if (method === "POST" && path === "/config") { + if (method === "PATCH" && path === "/config") { const body = await this.readJsonBody(request); await this.handlePatchConfig(body); this.writeJson(response, 200, { ok: true }); @@ -606,7 +640,9 @@ export class GatewayRuntime { return; } - const sessionMatch = path.match(/^\/sessions\/([^/]+)(?:\/(events|messages|abort|reset|chat|history|model|reload))?$/); + const sessionMatch = path.match( + /^\/sessions\/([^/]+)(?:\/(events|messages|abort|reset|chat|history|model|reload))?$/, + ); if (!sessionMatch) { this.writeJson(response, 404, { error: "Not found" }); return; @@ -616,7 +652,7 @@ export class GatewayRuntime { const action = sessionMatch[2]; if (!action && method === "GET") { - const session = await this.ensureSession(sessionKey); + const session = this.getManagedSessionOrThrow(sessionKey); this.writeJson(response, 200, { session: this.createSnapshot(session) }); return; } @@ -661,11 +697,13 @@ export class GatewayRuntime { } if (action === "abort" && method === "POST") { + this.getManagedSessionOrThrow(sessionKey); this.writeJson(response, 200, { ok: this.abortSession(sessionKey) }); return; } if (action === "reset" && method === "POST") { + this.getManagedSessionOrThrow(sessionKey); await this.resetSession(sessionKey); this.writeJson(response, 200, { ok: true }); return; @@ -837,7 +875,11 @@ export class GatewayRuntime { return {}; } const body = Buffer.concat(chunks).toString("utf8"); - return JSON.parse(body) as Record; + try { + return JSON.parse(body) as Record; + } catch { + throw new HttpError(400, "Invalid JSON body"); + } } private writeJson(response: ServerResponse, statusCode: number, payload: unknown): void { @@ -850,16 +892,16 @@ export class GatewayRuntime { // New handler methods added for companion-cloud web app integration // --------------------------------------------------------------------------- - private async handleGetModels(): Promise<{ models: ModelInfo[]; current: { provider: string; modelId: string } | null }> { + private async handleGetModels(): Promise<{ + models: ModelInfo[]; + current: { provider: string; modelId: string } | null; + }> { const available = this.primarySession.modelRegistry.getAvailable(); const models: ModelInfo[] = available.map((m) => ({ provider: m.provider, modelId: m.id, displayName: m.name, - capabilities: [ - ...(m.reasoning ? ["reasoning"] : []), - ...(m.input.includes("image") ? ["vision"] : []), - ], + capabilities: [...(m.reasoning ? ["reasoning"] : []), ...(m.input.includes("image") ? ["vision"] : [])], })); const currentModel = this.primarySession.model; const current = currentModel ? { provider: currentModel.provider, modelId: currentModel.id } : null; @@ -871,23 +913,20 @@ export class GatewayRuntime { provider: string, modelId: string, ): Promise<{ ok: true; model: { provider: string; modelId: string } }> { - const managed = this.sessions.get(sessionKey); - if (!managed) { - throw new Error(`Session not found: ${sessionKey}`); - } + const managed = this.getManagedSessionOrThrow(sessionKey); const found = managed.session.modelRegistry.find(provider, modelId); if (!found) { - throw new Error(`Model not found: ${provider}/${modelId}`); + throw new HttpError(404, `Model not found: ${provider}/${modelId}`); } await managed.session.setModel(found); return { ok: true, model: { provider, modelId } }; } private handleGetHistory(sessionKey: string, limit?: number): HistoryMessage[] { - const managed = this.sessions.get(sessionKey); - if (!managed) { - return []; + if (limit !== undefined && (!Number.isFinite(limit) || limit < 1)) { + throw new HttpError(400, "History limit must be a positive integer"); } + const managed = this.getManagedSessionOrThrow(sessionKey); const rawMessages = managed.session.messages; const messages: HistoryMessage[] = []; for (const msg of rawMessages) { @@ -905,27 +944,26 @@ export class GatewayRuntime { } private async handlePatchSession(sessionKey: string, patch: { name?: string }): Promise { - const managed = this.sessions.get(sessionKey); - if (!managed) { - throw new Error(`Session not found: ${sessionKey}`); - } + const managed = this.getManagedSessionOrThrow(sessionKey); if (patch.name !== undefined) { // Labels in pi-mono are per-entry; we label the current leaf entry - const leafId = managed.session.sessionManager.getLeafId?.(); - if (leafId) { - managed.session.sessionManager.appendLabelChange(leafId, patch.name); + const leafId = managed.session.sessionManager.getLeafId(); + if (!leafId) { + throw new HttpError(409, `Cannot rename session without an active leaf entry: ${sessionKey}`); } + managed.session.sessionManager.appendLabelChange(leafId, patch.name); } } private async handleDeleteSession(sessionKey: string): Promise { if (sessionKey === this.primarySessionKey) { - throw new Error("Cannot delete primary session"); + throw new HttpError(400, "Cannot delete primary session"); } - const managed = this.sessions.get(sessionKey); - if (!managed) { - throw new Error(`Session not found: ${sessionKey}`); + const managed = this.getManagedSessionOrThrow(sessionKey); + if (managed.processing) { + await managed.session.abort(); } + this.rejectQueuedMessages(managed, `Session deleted: ${sessionKey}`); managed.unsubscribe(); managed.session.dispose(); this.sessions.delete(sessionKey); @@ -934,8 +972,16 @@ export class GatewayRuntime { private getPublicConfig(): Record { const settings = this.primarySession.settingsManager.getGlobalSettings(); const { gateway, ...rest } = settings as Record & { gateway?: Record }; - const { bearerToken: _bearerToken, ...safeGateway } = gateway ?? {}; - return { ...rest, gateway: safeGateway }; + const { bearerToken: _bearerToken, webhook, ...safeGatewayRest } = gateway ?? {}; + const { secret: _secret, ...safeWebhook } = + webhook && typeof webhook === "object" ? (webhook as Record) : {}; + return { + ...rest, + gateway: { + ...safeGatewayRest, + ...(webhook && typeof webhook === "object" ? { webhook: safeWebhook } : {}), + }, + }; } private async handlePatchConfig(patch: Record): Promise { @@ -954,10 +1000,7 @@ export class GatewayRuntime { } private async handleReloadSession(sessionKey: string): Promise { - const managed = this.sessions.get(sessionKey); - if (!managed) { - throw new Error(`Session not found: ${sessionKey}`); - } + const managed = this.getManagedSessionOrThrow(sessionKey); // Reloading config by calling settingsManager.reload() on the session managed.session.settingsManager.reload(); } @@ -970,7 +1013,9 @@ export class GatewayRuntime { } if (Array.isArray(content)) { return content - .filter((c): c is { type: "text"; text: string } => typeof c === "object" && c !== null && c.type === "text") + .filter( + (c): c is { type: "text"; text: string } => typeof c === "object" && c !== null && c.type === "text", + ) .map((c) => ({ type: "text" as const, text: c.text })); } return []; @@ -1001,7 +1046,13 @@ export class GatewayRuntime { } if (msg.role === "toolResult") { - const tr = msg as { role: "toolResult"; toolCallId: string; toolName: string; content: unknown; isError: boolean }; + const tr = msg as { + role: "toolResult"; + toolCallId: string; + toolName: string; + content: unknown; + isError: boolean; + }; const textParts = Array.isArray(tr.content) ? (tr.content as { type: string; text?: string }[]) .filter((c) => c.type === "text" && typeof c.text === "string")