diff --git a/packages/agent/CHANGELOG.md b/packages/agent/CHANGELOG.md
index 991be7cb..23d4c637 100644
--- a/packages/agent/CHANGELOG.md
+++ b/packages/agent/CHANGELOG.md
@@ -2,6 +2,10 @@
 
 ## [Unreleased]
 
+### Fixed
+
+- `prompt()` and `continue()` now throw if called while the agent is already streaming, preventing race conditions and corrupted state. Use `queueMessage()` to queue messages during streaming, or `await` the previous call.
+
 ## [0.31.1] - 2026-01-02
 
 ## [0.31.0] - 2026-01-02
diff --git a/packages/agent/src/agent.ts b/packages/agent/src/agent.ts
index 078b707e..851b70b2 100644
--- a/packages/agent/src/agent.ts
+++ b/packages/agent/src/agent.ts
@@ -171,6 +171,10 @@ export class Agent {
 	async prompt(message: AgentMessage | AgentMessage[]): Promise<void>;
 	async prompt(input: string, images?: ImageContent[]): Promise<void>;
 	async prompt(input: string | AgentMessage | AgentMessage[], images?: ImageContent[]) {
+		if (this._state.isStreaming) {
+			throw new Error("Agent is already processing a prompt. Use queueMessage() or wait for completion.");
+		}
+
 		const model = this._state.model;
 		if (!model) throw new Error("No model configured");
 
@@ -199,6 +203,10 @@
 
 	/** Continue from current context (for retry after overflow) */
 	async continue() {
+		if (this._state.isStreaming) {
+			throw new Error("Agent is already processing. Wait for completion before continuing.");
+		}
+
 		const messages = this._state.messages;
 		if (messages.length === 0) {
 			throw new Error("No messages to continue from");
diff --git a/packages/agent/test/agent.test.ts b/packages/agent/test/agent.test.ts
index 8fee033f..21012841 100644
--- a/packages/agent/test/agent.test.ts
+++ b/packages/agent/test/agent.test.ts
@@ -1,7 +1,41 @@
-import { getModel } from "@mariozechner/pi-ai";
+import { type AssistantMessage, type AssistantMessageEvent, EventStream, getModel } from "@mariozechner/pi-ai";
 import { describe, expect, it } from "vitest";
 import { Agent } from "../src/index.js";
 
+// Mock stream that mimics AssistantMessageEventStream
+class MockAssistantStream extends EventStream<AssistantMessageEvent, AssistantMessage> {
+	constructor() {
+		super(
+			(event) => event.type === "done" || event.type === "error",
+			(event) => {
+				if (event.type === "done") return event.message;
+				if (event.type === "error") return event.error;
+				throw new Error("Unexpected event type");
+			},
+		);
+	}
+}
+
+function createAssistantMessage(text: string): AssistantMessage {
+	return {
+		role: "assistant",
+		content: [{ type: "text", text }],
+		api: "openai-responses",
+		provider: "openai",
+		model: "mock",
+		usage: {
+			input: 0,
+			output: 0,
+			cacheRead: 0,
+			cacheWrite: 0,
+			totalTokens: 0,
+			cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
+		},
+		stopReason: "stop",
+		timestamp: Date.now(),
+	};
+}
+
 describe("Agent", () => {
 	it("should create an agent instance with default state", () => {
 		const agent = new Agent();
@@ -109,4 +143,80 @@
 		// Should not throw even if nothing is running
 		expect(() => agent.abort()).not.toThrow();
 	});
+
+	it("should throw when prompt() called while streaming", async () => {
+		let abortSignal: AbortSignal | undefined;
+		const agent = new Agent({
+			// Use a stream function that responds to abort
+			streamFn: (_model, _context, options) => {
+				abortSignal = options?.signal;
+				const stream = new MockAssistantStream();
+				queueMicrotask(() => {
+					stream.push({ type: "start", partial: createAssistantMessage("") });
+					// Check abort signal periodically
+					const checkAbort = () => {
+						if (abortSignal?.aborted) {
+							stream.push({ type: "error", reason: "aborted", error: createAssistantMessage("Aborted") });
+						} else {
+							setTimeout(checkAbort, 5);
+						}
+					};
+					checkAbort();
+				});
+				return stream;
+			},
+		});
+
+		// Start first prompt (don't await, it will block until abort)
+		const firstPrompt = agent.prompt("First message");
+
+		// Wait a tick for isStreaming to be set
+		await new Promise((resolve) => setTimeout(resolve, 10));
+		expect(agent.state.isStreaming).toBe(true);
+
+		// Second prompt should reject
+		await expect(agent.prompt("Second message")).rejects.toThrow(
+			"Agent is already processing a prompt. Use queueMessage() or wait for completion.",
+		);
+
+		// Cleanup - abort to stop the stream
+		agent.abort();
+		await firstPrompt.catch(() => {}); // Ignore abort error
+	});
+
+	it("should throw when continue() called while streaming", async () => {
+		let abortSignal: AbortSignal | undefined;
+		const agent = new Agent({
+			streamFn: (_model, _context, options) => {
+				abortSignal = options?.signal;
+				const stream = new MockAssistantStream();
+				queueMicrotask(() => {
+					stream.push({ type: "start", partial: createAssistantMessage("") });
+					const checkAbort = () => {
+						if (abortSignal?.aborted) {
+							stream.push({ type: "error", reason: "aborted", error: createAssistantMessage("Aborted") });
+						} else {
+							setTimeout(checkAbort, 5);
+						}
+					};
+					checkAbort();
+				});
+				return stream;
+			},
+		});
+
+		// Start first prompt
+		const firstPrompt = agent.prompt("First message");
+		await new Promise((resolve) => setTimeout(resolve, 10));
+		expect(agent.state.isStreaming).toBe(true);
+
+		// continue() should reject
+		await expect(agent.continue()).rejects.toThrow(
+			"Agent is already processing. Wait for completion before continuing.",
+		);
+
+		// Cleanup
+		agent.abort();
+		await firstPrompt.catch(() => {});
+	});
 });
diff --git a/packages/coding-agent/CHANGELOG.md b/packages/coding-agent/CHANGELOG.md
index dda47046..8698ac07 100644
--- a/packages/coding-agent/CHANGELOG.md
+++ b/packages/coding-agent/CHANGELOG.md
@@ -2,6 +2,10 @@
 
 ## [Unreleased]
 
+### Fixed
+
+- `AgentSession.prompt()` now throws if called while the agent is already streaming, preventing race conditions. Use `queueMessage()` to queue messages during streaming.
+
 ## [0.31.1] - 2026-01-02
 
 ### Fixed
diff --git a/packages/coding-agent/src/core/agent-session.ts b/packages/coding-agent/src/core/agent-session.ts
index 9a4d752c..06601e3e 100644
--- a/packages/coding-agent/src/core/agent-session.ts
+++ b/packages/coding-agent/src/core/agent-session.ts
@@ -112,7 +112,6 @@ export interface SessionStats {
 	cost: number;
 }
 
-/** Internal marker for hook messages queued through the agent loop */
 // ============================================================================
 // Constants
 // ============================================================================
@@ -456,6 +455,10 @@
 	 * @throws Error if no model selected or no API key available
 	 */
 	async prompt(text: string, options?: PromptOptions): Promise<void> {
+		if (this.isStreaming) {
+			throw new Error("Agent is already processing. Use queueMessage() to queue messages during streaming.");
+		}
+
 		// Flush any pending bash messages before the new prompt
 		this._flushPendingBashMessages();
 
diff --git a/packages/coding-agent/test/agent-session-concurrent.test.ts b/packages/coding-agent/test/agent-session-concurrent.test.ts
new file mode 100644
index 00000000..75c68c2e
--- /dev/null
+++ b/packages/coding-agent/test/agent-session-concurrent.test.ts
@@ -0,0 +1,196 @@
+/**
+ * Tests for AgentSession concurrent prompt guard.
+ */
+
+import { existsSync, mkdirSync, rmSync } from "node:fs";
+import { tmpdir } from "node:os";
+import { join } from "node:path";
+import { Agent } from "@mariozechner/pi-agent-core";
+import { type AssistantMessage, type AssistantMessageEvent, EventStream, getModel } from "@mariozechner/pi-ai";
+import { afterEach, beforeEach, describe, expect, it } from "vitest";
+import { AgentSession } from "../src/core/agent-session.js";
+import { AuthStorage } from "../src/core/auth-storage.js";
+import { ModelRegistry } from "../src/core/model-registry.js";
+import { SessionManager } from "../src/core/session-manager.js";
+import { SettingsManager } from "../src/core/settings-manager.js";
+
+// Mock stream that mimics AssistantMessageEventStream
+class MockAssistantStream extends EventStream<AssistantMessageEvent, AssistantMessage> {
+	constructor() {
+		super(
+			(event) => event.type === "done" || event.type === "error",
+			(event) => {
+				if (event.type === "done") return event.message;
+				if (event.type === "error") return event.error;
+				throw new Error("Unexpected event type");
+			},
+		);
+	}
+}
+
+function createAssistantMessage(text: string): AssistantMessage {
+	return {
+		role: "assistant",
+		content: [{ type: "text", text }],
+		api: "anthropic-messages",
+		provider: "anthropic",
+		model: "mock",
+		usage: {
+			input: 0,
+			output: 0,
+			cacheRead: 0,
+			cacheWrite: 0,
+			totalTokens: 0,
+			cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
+		},
+		stopReason: "stop",
+		timestamp: Date.now(),
+	};
+}
+
+describe("AgentSession concurrent prompt guard", () => {
+	let session: AgentSession;
+	let tempDir: string;
+
+	beforeEach(() => {
+		tempDir = join(tmpdir(), `pi-concurrent-test-${Date.now()}`);
+		mkdirSync(tempDir, { recursive: true });
+	});
+
+	afterEach(async () => {
+		if (session) {
+			session.dispose();
+		}
+		if (tempDir && existsSync(tempDir)) {
+			rmSync(tempDir, { recursive: true });
+		}
+	});
+
+	function createSession() {
+		const model = getModel("anthropic", "claude-sonnet-4-5")!;
+		let abortSignal: AbortSignal | undefined;
+
+		// Use a stream function that responds to abort
+		const agent = new Agent({
+			getApiKey: () => "test-key",
+			initialState: {
+				model,
+				systemPrompt: "Test",
+				tools: [],
+			},
+			streamFn: (_model, _context, options) => {
+				abortSignal = options?.signal;
+				const stream = new MockAssistantStream();
+				queueMicrotask(() => {
+					stream.push({ type: "start", partial: createAssistantMessage("") });
+					const checkAbort = () => {
+						if (abortSignal?.aborted) {
+							stream.push({ type: "error", reason: "aborted", error: createAssistantMessage("Aborted") });
+						} else {
+							setTimeout(checkAbort, 5);
+						}
+					};
+					checkAbort();
+				});
+				return stream;
+			},
+		});
+
+		const sessionManager = SessionManager.inMemory();
+		const settingsManager = SettingsManager.create(tempDir, tempDir);
+		const authStorage = new AuthStorage(join(tempDir, "auth.json"));
+		const modelRegistry = new ModelRegistry(authStorage, tempDir);
+		// Set a runtime API key so validation passes
+		authStorage.setRuntimeApiKey("anthropic", "test-key");
+
+		session = new AgentSession({
+			agent,
+			sessionManager,
+			settingsManager,
+			modelRegistry,
+		});
+
+		return session;
+	}
+
+	it("should throw when prompt() called while streaming", async () => {
+		createSession();
+
+		// Start first prompt (don't await, it will block until abort)
+		const firstPrompt = session.prompt("First message");
+
+		// Wait a tick for isStreaming to be set
+		await new Promise((resolve) => setTimeout(resolve, 10));
+
+		// Verify we're streaming
+		expect(session.isStreaming).toBe(true);
+
+		// Second prompt should reject
+		await expect(session.prompt("Second message")).rejects.toThrow(
+			"Agent is already processing. Use queueMessage() to queue messages during streaming.",
+		);
+
+		// Cleanup
+		await session.abort();
+		await firstPrompt.catch(() => {}); // Ignore abort error
+	});
+
+	it("should allow queueMessage() while streaming", async () => {
+		createSession();
+
+		// Start first prompt
+		const firstPrompt = session.prompt("First message");
+		await new Promise((resolve) => setTimeout(resolve, 10));
+
+		// queueMessage should work while streaming
+		expect(() => session.queueMessage("Queued message")).not.toThrow();
+		expect(session.queuedMessageCount).toBe(1);
+
+		// Cleanup
+		await session.abort();
+		await firstPrompt.catch(() => {});
+	});
+
+	it("should allow prompt() after previous completes", async () => {
+		// Create session with a stream that completes immediately
+		const model = getModel("anthropic", "claude-sonnet-4-5")!;
+		const agent = new Agent({
+			getApiKey: () => "test-key",
+			initialState: {
+				model,
+				systemPrompt: "Test",
+				tools: [],
+			},
+			streamFn: () => {
+				const stream = new MockAssistantStream();
+				queueMicrotask(() => {
+					stream.push({ type: "start", partial: createAssistantMessage("") });
+					stream.push({ type: "done", reason: "stop", message: createAssistantMessage("Done") });
+				});
+				return stream;
+			},
+		});
+
+		const sessionManager = SessionManager.inMemory();
+		const settingsManager = SettingsManager.create(tempDir, tempDir);
+		const authStorage = new AuthStorage(join(tempDir, "auth.json"));
+		const modelRegistry = new ModelRegistry(authStorage, tempDir);
+		authStorage.setRuntimeApiKey("anthropic", "test-key");
+
+		session = new AgentSession({
+			agent,
+			sessionManager,
+			settingsManager,
+			modelRegistry,
+		});
+
+		// First prompt completes
+		await session.prompt("First message");
+
+		// Should not be streaming anymore
+		expect(session.isStreaming).toBe(false);
+
+		// Second prompt should work
+		await expect(session.prompt("Second message")).resolves.not.toThrow();
+	});
+});
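
Usage note (illustrative, not part of the patch): a minimal sketch of how a caller might adapt to the new guard, assuming an AgentSession constructed as in the tests above. The helper name handleUserInput is hypothetical and the import of AgentSession is omitted; prompt(), queueMessage(), and isStreaming are the existing members exercised by this change.

async function handleUserInput(session: AgentSession, text: string): Promise<void> {
	if (session.isStreaming) {
		// prompt() now rejects while a turn is streaming; queueMessage() is the supported path here.
		session.queueMessage(text);
		return;
	}
	await session.prompt(text);
}

Per the changelog, the other supported pattern is to await the in-flight prompt() before issuing the next one.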