From 5ef3cc90d1bdea1c4207e0c353a1f7c22dbb5388 Mon Sep 17 00:00:00 2001
From: Mario Zechner
Date: Fri, 2 Jan 2026 21:52:45 +0100
Subject: [PATCH] Add guard against concurrent prompt() calls

Agent.prompt() and Agent.continue() now throw if called while already
streaming, preventing race conditions and corrupted state. Use
queueMessage() to queue messages during streaming, or await the
previous call. AgentSession.prompt() has the same guard with a message
directing users to queueMessage().
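Illustrative sketch of the new contract, assuming an `agent` instance
with queueMessage() available (wired up as in the tests below):

    const first = agent.prompt("one");       // starts streaming
    agent.queueMessage("follow-up");         // still allowed mid-stream
    try {
        await agent.prompt("two");           // rejects while streaming
    } catch (err) {
        // "Agent is already processing a prompt. Use queueMessage()
        // or wait for completion."
    }
    await first;                             // let the first call settle
    await agent.prompt("two");               // fine again once idle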
Ref #403
---
 packages/agent/CHANGELOG.md                   |   4 +
 packages/agent/src/agent.ts                   |   8 +
 packages/agent/test/agent.test.ts             | 112 +++++++++-
 packages/coding-agent/CHANGELOG.md            |   4 +
 .../coding-agent/src/core/agent-session.ts    |   5 +-
 .../test/agent-session-concurrent.test.ts     | 196 ++++++++++++++++++
 6 files changed, 327 insertions(+), 2 deletions(-)
 create mode 100644 packages/coding-agent/test/agent-session-concurrent.test.ts

diff --git a/packages/agent/CHANGELOG.md b/packages/agent/CHANGELOG.md
index 991be7cb..23d4c637 100644
--- a/packages/agent/CHANGELOG.md
+++ b/packages/agent/CHANGELOG.md
@@ -2,6 +2,10 @@
 
 ## [Unreleased]
 
+### Fixed
+
+- `prompt()` and `continue()` now throw if called while the agent is already streaming, preventing race conditions and corrupted state. Use `queueMessage()` to queue messages during streaming, or `await` the previous call.
+
 ## [0.31.1] - 2026-01-02
 
 ## [0.31.0] - 2026-01-02
diff --git a/packages/agent/src/agent.ts b/packages/agent/src/agent.ts
index 078b707e..851b70b2 100644
--- a/packages/agent/src/agent.ts
+++ b/packages/agent/src/agent.ts
@@ -171,6 +171,10 @@ export class Agent {
 	async prompt(message: AgentMessage | AgentMessage[]): Promise<void>;
 	async prompt(input: string, images?: ImageContent[]): Promise<void>;
 	async prompt(input: string | AgentMessage | AgentMessage[], images?: ImageContent[]) {
+		if (this._state.isStreaming) {
+			throw new Error("Agent is already processing a prompt. Use queueMessage() or wait for completion.");
+		}
+
 		const model = this._state.model;
 		if (!model) throw new Error("No model configured");
 
@@ -199,6 +203,10 @@
 
 	/** Continue from current context (for retry after overflow) */
 	async continue() {
+		if (this._state.isStreaming) {
+			throw new Error("Agent is already processing. Wait for completion before continuing.");
+		}
+
 		const messages = this._state.messages;
 		if (messages.length === 0) {
 			throw new Error("No messages to continue from");
diff --git a/packages/agent/test/agent.test.ts b/packages/agent/test/agent.test.ts
index 8fee033f..21012841 100644
--- a/packages/agent/test/agent.test.ts
+++ b/packages/agent/test/agent.test.ts
@@ -1,7 +1,41 @@
-import { getModel } from "@mariozechner/pi-ai";
+import { type AssistantMessage, type AssistantMessageEvent, EventStream, getModel } from "@mariozechner/pi-ai";
 import { describe, expect, it } from "vitest";
 import { Agent } from "../src/index.js";
 
+// Mock stream that mimics AssistantMessageEventStream
+class MockAssistantStream extends EventStream<AssistantMessageEvent, AssistantMessage> {
+	constructor() {
+		super(
+			(event) => event.type === "done" || event.type === "error",
+			(event) => {
+				if (event.type === "done") return event.message;
+				if (event.type === "error") return event.error;
+				throw new Error("Unexpected event type");
+			},
+		);
+	}
+}
+
+function createAssistantMessage(text: string): AssistantMessage {
+	return {
+		role: "assistant",
+		content: [{ type: "text", text }],
+		api: "openai-responses",
+		provider: "openai",
+		model: "mock",
+		usage: {
+			input: 0,
+			output: 0,
+			cacheRead: 0,
+			cacheWrite: 0,
+			totalTokens: 0,
+			cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
+		},
+		stopReason: "stop",
+		timestamp: Date.now(),
+	};
+}
+
 describe("Agent", () => {
 	it("should create an agent instance with default state", () => {
 		const agent = new Agent();
@@ -109,4 +143,80 @@
 		// Should not throw even if nothing is running
 		expect(() => agent.abort()).not.toThrow();
 	});
+
+	it("should throw when prompt() called while streaming", async () => {
+		let abortSignal: AbortSignal | undefined;
+		const agent = new Agent({
+			// Use a stream function that responds to abort
+			streamFn: (_model, _context, options) => {
+				abortSignal = options?.signal;
+				const stream = new MockAssistantStream();
+				queueMicrotask(() => {
+					stream.push({ type: "start", partial: createAssistantMessage("") });
+					// Check abort signal periodically
+					const checkAbort = () => {
+						if (abortSignal?.aborted) {
+							stream.push({ type: "error", reason: "aborted", error: createAssistantMessage("Aborted") });
+						} else {
+							setTimeout(checkAbort, 5);
+						}
+					};
+					checkAbort();
+				});
+				return stream;
+			},
+		});
+
+		// Start first prompt (don't await, it will block until abort)
+		const firstPrompt = agent.prompt("First message");
+
+		// Wait a tick for isStreaming to be set
+		await new Promise((resolve) => setTimeout(resolve, 10));
+		expect(agent.state.isStreaming).toBe(true);
+
+		// Second prompt should reject
+		await expect(agent.prompt("Second message")).rejects.toThrow(
+			"Agent is already processing a prompt. Use queueMessage() or wait for completion.",
+		);
+
+		// Cleanup - abort to stop the stream
+		agent.abort();
+		await firstPrompt.catch(() => {}); // Ignore abort error
+	});
+
+	it("should throw when continue() called while streaming", async () => {
+		let abortSignal: AbortSignal | undefined;
+		const agent = new Agent({
+			streamFn: (_model, _context, options) => {
+				abortSignal = options?.signal;
+				const stream = new MockAssistantStream();
+				queueMicrotask(() => {
+					stream.push({ type: "start", partial: createAssistantMessage("") });
+					const checkAbort = () => {
+						if (abortSignal?.aborted) {
+							stream.push({ type: "error", reason: "aborted", error: createAssistantMessage("Aborted") });
+						} else {
+							setTimeout(checkAbort, 5);
+						}
+					};
+					checkAbort();
+				});
+				return stream;
+			},
+		});
+
+		// Start first prompt
+		const firstPrompt = agent.prompt("First message");
+		await new Promise((resolve) => setTimeout(resolve, 10));
+		expect(agent.state.isStreaming).toBe(true);
+
+		// continue() should reject
+		await expect(agent.continue()).rejects.toThrow(
+			"Agent is already processing. Wait for completion before continuing.",
+		);
+
+		// Cleanup
+		agent.abort();
+		await firstPrompt.catch(() => {});
+	});
 });
diff --git a/packages/coding-agent/CHANGELOG.md b/packages/coding-agent/CHANGELOG.md
index dda47046..8698ac07 100644
--- a/packages/coding-agent/CHANGELOG.md
+++ b/packages/coding-agent/CHANGELOG.md
@@ -2,6 +2,10 @@
 
 ## [Unreleased]
 
+### Fixed
+
+- `AgentSession.prompt()` now throws if called while the agent is already streaming, preventing race conditions. Use `queueMessage()` to queue messages during streaming.
+
 ## [0.31.1] - 2026-01-02
 
 ### Fixed
diff --git a/packages/coding-agent/src/core/agent-session.ts b/packages/coding-agent/src/core/agent-session.ts
index 9a4d752c..06601e3e 100644
--- a/packages/coding-agent/src/core/agent-session.ts
+++ b/packages/coding-agent/src/core/agent-session.ts
@@ -112,7 +112,6 @@ export interface SessionStats {
 	cost: number;
 }
 
-/** Internal marker for hook messages queued through the agent loop */
 // ============================================================================
 // Constants
 // ============================================================================
@@ -456,6 +455,10 @@
 	 * @throws Error if no model selected or no API key available
 	 */
 	async prompt(text: string, options?: PromptOptions): Promise<void> {
+		if (this.isStreaming) {
+			throw new Error("Agent is already processing. Use queueMessage() to queue messages during streaming.");
+		}
+
 		// Flush any pending bash messages before the new prompt
 		this._flushPendingBashMessages();
 
diff --git a/packages/coding-agent/test/agent-session-concurrent.test.ts b/packages/coding-agent/test/agent-session-concurrent.test.ts
new file mode 100644
index 00000000..75c68c2e
--- /dev/null
+++ b/packages/coding-agent/test/agent-session-concurrent.test.ts
@@ -0,0 +1,196 @@
+/**
+ * Tests for AgentSession concurrent prompt guard.
+ */
+
+import { existsSync, mkdirSync, rmSync } from "node:fs";
+import { tmpdir } from "node:os";
+import { join } from "node:path";
+import { Agent } from "@mariozechner/pi-agent-core";
+import { type AssistantMessage, type AssistantMessageEvent, EventStream, getModel } from "@mariozechner/pi-ai";
+import { afterEach, beforeEach, describe, expect, it } from "vitest";
+import { AgentSession } from "../src/core/agent-session.js";
+import { AuthStorage } from "../src/core/auth-storage.js";
+import { ModelRegistry } from "../src/core/model-registry.js";
+import { SessionManager } from "../src/core/session-manager.js";
+import { SettingsManager } from "../src/core/settings-manager.js";
+
+// Mock stream that mimics AssistantMessageEventStream
+class MockAssistantStream extends EventStream<AssistantMessageEvent, AssistantMessage> {
+	constructor() {
+		super(
+			(event) => event.type === "done" || event.type === "error",
+			(event) => {
+				if (event.type === "done") return event.message;
+				if (event.type === "error") return event.error;
+				throw new Error("Unexpected event type");
+			},
+		);
+	}
+}
+
+function createAssistantMessage(text: string): AssistantMessage {
+	return {
+		role: "assistant",
+		content: [{ type: "text", text }],
+		api: "anthropic-messages",
+		provider: "anthropic",
+		model: "mock",
+		usage: {
+			input: 0,
+			output: 0,
+			cacheRead: 0,
+			cacheWrite: 0,
+			totalTokens: 0,
+			cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
+		},
+		stopReason: "stop",
+		timestamp: Date.now(),
+	};
+}
+
+describe("AgentSession concurrent prompt guard", () => {
+	let session: AgentSession;
+	let tempDir: string;
+
+	beforeEach(() => {
+		tempDir = join(tmpdir(), `pi-concurrent-test-${Date.now()}`);
+		mkdirSync(tempDir, { recursive: true });
+	});
+
+	afterEach(async () => {
+		if (session) {
+			session.dispose();
+		}
+		if (tempDir && existsSync(tempDir)) {
+			rmSync(tempDir, { recursive: true });
+		}
+	});
+
+	function createSession() {
+		const model = getModel("anthropic", "claude-sonnet-4-5")!;
+		let abortSignal: AbortSignal | undefined;
+
+		// Use a stream function that responds to abort
+		const agent = new Agent({
+			getApiKey: () => "test-key",
+			initialState: {
+				model,
+				systemPrompt: "Test",
+				tools: [],
+			},
+			streamFn: (_model, _context, options) => {
+				abortSignal = options?.signal;
+				const stream = new MockAssistantStream();
+				queueMicrotask(() => {
+					stream.push({ type: "start", partial: createAssistantMessage("") });
+					const checkAbort = () => {
+						if (abortSignal?.aborted) {
+							stream.push({ type: "error", reason: "aborted", error: createAssistantMessage("Aborted") });
+						} else {
+							setTimeout(checkAbort, 5);
+						}
+					};
+					checkAbort();
+				});
+				return stream;
+			},
+		});
+
+		const sessionManager = SessionManager.inMemory();
+		const settingsManager = SettingsManager.create(tempDir, tempDir);
+		const authStorage = new AuthStorage(join(tempDir, "auth.json"));
+		const modelRegistry = new ModelRegistry(authStorage, tempDir);
+		// Set a runtime API key so validation passes
+		authStorage.setRuntimeApiKey("anthropic", "test-key");
+
+		session = new AgentSession({
+			agent,
+			sessionManager,
+			settingsManager,
+			modelRegistry,
+		});
+
+		return session;
+	}
+
+	it("should throw when prompt() called while streaming", async () => {
+		createSession();
+
+		// Start first prompt (don't await, it will block until abort)
+		const firstPrompt = session.prompt("First message");
+
+		// Wait a tick for isStreaming to be set
+		await new Promise((resolve) => setTimeout(resolve, 10));
+
+		// Verify we're streaming
+		expect(session.isStreaming).toBe(true);
+
+		// Second prompt should reject
+		await expect(session.prompt("Second message")).rejects.toThrow(
+			"Agent is already processing. Use queueMessage() to queue messages during streaming.",
+		);
+
+		// Cleanup
+		await session.abort();
+		await firstPrompt.catch(() => {}); // Ignore abort error
+	});
+
+	it("should allow queueMessage() while streaming", async () => {
+		createSession();
+
+		// Start first prompt
+		const firstPrompt = session.prompt("First message");
+		await new Promise((resolve) => setTimeout(resolve, 10));
+
+		// queueMessage should work while streaming
+		expect(() => session.queueMessage("Queued message")).not.toThrow();
+		expect(session.queuedMessageCount).toBe(1);
+
+		// Cleanup
+		await session.abort();
+		await firstPrompt.catch(() => {});
+	});
+
+	it("should allow prompt() after previous completes", async () => {
+		// Create session with a stream that completes immediately
+		const model = getModel("anthropic", "claude-sonnet-4-5")!;
+		const agent = new Agent({
+			getApiKey: () => "test-key",
+			initialState: {
+				model,
+				systemPrompt: "Test",
+				tools: [],
+			},
+			streamFn: () => {
+				const stream = new MockAssistantStream();
+				queueMicrotask(() => {
+					stream.push({ type: "start", partial: createAssistantMessage("") });
+					stream.push({ type: "done", reason: "stop", message: createAssistantMessage("Done") });
+				});
+				return stream;
+			},
+		});
+
+		const sessionManager = SessionManager.inMemory();
+		const settingsManager = SettingsManager.create(tempDir, tempDir);
+		const authStorage = new AuthStorage(join(tempDir, "auth.json"));
+		const modelRegistry = new ModelRegistry(authStorage, tempDir);
+		authStorage.setRuntimeApiKey("anthropic", "test-key");
+
+		session = new AgentSession({
+			agent,
+			sessionManager,
+			settingsManager,
+			modelRegistry,
+		});
+
+		// First prompt completes
+		await session.prompt("First message");
+
+		// Should not be streaming anymore
+		expect(session.isStreaming).toBe(false);
+
+		// Second prompt should work
+		await expect(session.prompt("Second message")).resolves.not.toThrow();
+	});
+});