From 4a29c13e0d604605c12dff76bb5c4cb70bd0b6ff Mon Sep 17 00:00:00 2001 From: Harivansh Rathi Date: Mon, 9 Mar 2026 16:56:24 -0700 Subject: [PATCH] feat: steer active chat messages Co-authored-by: Codex --- .../coding-agent/src/core/gateway/runtime.ts | 55 +++++++- .../coding-agent/test/gateway-steer.test.ts | 118 ++++++++++++++++++ 2 files changed, 172 insertions(+), 1 deletion(-) create mode 100644 packages/coding-agent/test/gateway-steer.test.ts diff --git a/packages/coding-agent/src/core/gateway/runtime.ts b/packages/coding-agent/src/core/gateway/runtime.ts index 0161700..136a32c 100644 --- a/packages/coding-agent/src/core/gateway/runtime.ts +++ b/packages/coding-agent/src/core/gateway/runtime.ts @@ -849,7 +849,7 @@ export class GatewayRuntime { } const sessionMatch = path.match( - /^\/sessions\/([^/]+)(?:\/(events|messages|abort|reset|chat|history|model|reload|state))?$/, + /^\/sessions\/([^/]+)(?:\/(events|messages|abort|reset|chat|history|model|reload|state|steer))?$/, ); if (!sessionMatch) { this.writeJson(response, 404, { error: "Not found" }); @@ -910,6 +910,18 @@ export class GatewayRuntime { return; } + if (action === "steer" && method === "POST") { + const body = await this.readJsonBody(request); + const text = extractUserText(body); + if (!text) { + this.writeJson(response, 400, { error: "Missing user message text" }); + return; + } + const result = await this.handleSteer(sessionKey, text); + this.writeJson(response, 200, result); + return; + } + if (action === "reset" && method === "POST") { await this.requireExistingSession(sessionKey); await this.resetSession(sessionKey); @@ -1099,6 +1111,47 @@ export class GatewayRuntime { } } + private async handleSteer( + sessionKey: string, + text: string, + ): Promise<{ ok: true; mode: "steer" | "queued"; sessionKey: string }> { + const managedSession = await this.requireExistingSession(sessionKey); + const preview = text.length > 80 ? `${text.slice(0, 80)}...` : text; + + if (managedSession.processing) { + this.logSession(managedSession, `steer text="${preview}"`); + await managedSession.session.steer(text); + return { ok: true, mode: "steer", sessionKey }; + } + + if (managedSession.queue.length >= this.config.session.maxQueuePerSession) { + throw new HttpError( + 409, + `Queue full (${this.config.session.maxQueuePerSession} pending).`, + ); + } + + this.logSession( + managedSession, + `steer-fallback queue text="${preview}" depth=${managedSession.queue.length + 1}`, + ); + void this.enqueueManagedMessage({ + request: { + sessionKey, + text, + source: "extension", + }, + }).then((result) => { + if (!result.ok) { + this.log( + `[steer] session=${sessionKey} queued fallback failed: ${result.error ?? "Unknown error"}`, + ); + } + }); + + return { ok: true, mode: "queued", sessionKey }; + } + private requireAuth( request: IncomingMessage, response: ServerResponse, diff --git a/packages/coding-agent/test/gateway-steer.test.ts b/packages/coding-agent/test/gateway-steer.test.ts new file mode 100644 index 0000000..8327cb9 --- /dev/null +++ b/packages/coding-agent/test/gateway-steer.test.ts @@ -0,0 +1,118 @@ +import { describe, expect, it, vi } from "vitest"; +import { GatewayRuntime } from "../src/core/gateway/runtime.js"; + +function createMockSession() { + return { + sessionId: "session-1", + messages: [], + prompt: vi.fn().mockResolvedValue(undefined), + steer: vi.fn().mockResolvedValue(undefined), + abort: vi.fn().mockResolvedValue(undefined), + dispose: vi.fn(), + subscribe: vi.fn(() => () => {}), + sessionManager: { + getSessionDir: () => "/tmp/pi-gateway-test", + }, + }; +} + +function createRuntime(session = createMockSession()) { + return new GatewayRuntime({ + config: { + bind: "127.0.0.1", + port: 0, + session: { + idleMinutes: 5, + maxQueuePerSession: 4, + }, + webhook: { + enabled: false, + basePath: "/webhooks", + }, + }, + primarySessionKey: "primary", + primarySession: session as never, + createSession: async () => session as never, + }); +} + +function addManagedSession( + runtime: GatewayRuntime, + sessionKey: string, + session: ReturnType, + processing: boolean, +) { + const managedSession = { + sessionKey, + session, + queue: [], + processing, + activeAssistantMessage: null, + pendingToolResults: [], + createdAt: Date.now(), + lastActiveAt: Date.now(), + listeners: new Set(), + unsubscribe: () => {}, + }; + + (runtime as unknown as { sessions: Map }).sessions.set( + sessionKey, + managedSession, + ); + + return managedSession; +} + +describe("GatewayRuntime steer handling", () => { + it("steers the active session instead of queueing a second prompt", async () => { + const session = createMockSession(); + const runtime = createRuntime(session); + addManagedSession(runtime, "chat", session, true); + + const result = await ( + runtime as unknown as { + handleSteer: ( + sessionKey: string, + text: string, + ) => Promise<{ ok: true; mode: "steer" | "queued"; sessionKey: string }>; + } + ).handleSteer("chat", "keep going"); + + expect(result).toEqual({ + ok: true, + mode: "steer", + sessionKey: "chat", + }); + expect(session.steer).toHaveBeenCalledWith("keep going"); + expect(session.prompt).not.toHaveBeenCalled(); + }); + + it("queues a prompt immediately when steer races an idle session", async () => { + const session = createMockSession(); + const runtime = createRuntime(session); + addManagedSession(runtime, "chat", session, false); + + const result = await ( + runtime as unknown as { + handleSteer: ( + sessionKey: string, + text: string, + ) => Promise<{ ok: true; mode: "steer" | "queued"; sessionKey: string }>; + } + ).handleSteer("chat", "pick this up next"); + + expect(result).toEqual({ + ok: true, + mode: "queued", + sessionKey: "chat", + }); + + await new Promise((resolve) => setTimeout(resolve, 0)); + + expect(session.steer).not.toHaveBeenCalled(); + expect(session.prompt).toHaveBeenCalledWith("pick this up next", { + images: undefined, + source: "extension", + }); + }); +});