Review copy of the patch for pi-mono PR #1726: session.prompt() returning before retry completion.

What the patch does:
  * CHANGELOG.md: adds the release note for the fix.
  * src/core/agent-session.ts: _handleAgentEvent now calls _createRetryPromiseForAgentEnd()
    synchronously, before _processAgentEvent is queued, so _retryPromise exists as soon as a
    retryable agent_end is dispatched. _handleRetryableError calls _resolveRetry() when
    retries are disabled, keeps a fallback that creates the promise if it is missing, and
    increments _retryAttempt only after the promise is in place.
  * test/agent-session-retry.test.ts: new vitest suite covering a successful retry, retry
    exhaustion, and a delayed assistant message_end extension handler.

Reviewer notes:
  * NOTE(review): this copy was rebuilt from a whitespace-collapsed paste. Hunk line counts
    match the headers, but indentation is assumed to be tabs - verify against the
    repository before applying (or apply with --ignore-whitespace).
  * NOTE(review): the generic type arguments Promise<void>, Promise<boolean>,
    Partial<AssistantMessage> and EventStream<AssistantMessageEvent, AssistantMessage> were
    missing from the paste and are restored from usage - confirm against the repository.
  * NOTE(review): the promise created at dispatch is settled in this patch only through
    _resolveRetry() in the retries-disabled branch. The rest of _processAgentEvent is not
    shown here; confirm that every agent_end path (including one where a queued handler
    throws) settles it, otherwise prompt() could presumably wait indefinitely.
  * NOTE(review): the test derives its temp dir name from Date.now() alone and removes it
    without `force`; consider mkdtempSync if these tests can run in parallel.

diff --git a/packages/coding-agent/CHANGELOG.md b/packages/coding-agent/CHANGELOG.md
index 2e467e24..74307eff 100644
--- a/packages/coding-agent/CHANGELOG.md
+++ b/packages/coding-agent/CHANGELOG.md
@@ -18,6 +18,7 @@
 - Fixed `pi.registerTool()` dynamic registration after session initialization. Tools registered in `session_start` and later handlers now refresh immediately, become active, and are visible to the LLM without `/reload` ([#1720](https://github.com/badlogic/pi-mono/issues/1720))
 - Fixed session message persistence ordering by serializing `AgentSession` event processing, preventing `toolResult` entries from being written before their corresponding assistant tool-call messages when extension handlers are asynchronous ([#1717](https://github.com/badlogic/pi-mono/issues/1717))
 - Fixed spacing artifacts when custom tool renderers intentionally suppress per-call transcript output, including extra blank rows in interactive streaming and non-zero transcript footprint for empty custom renders ([#1719](https://github.com/badlogic/pi-mono/pull/1719) by [@alasano](https://github.com/alasano))
+- Fixed `session.prompt()` returning before retry completion by creating the retry promise synchronously at `agent_end` dispatch, which closes a race when earlier queued event handlers are async ([#1726](https://github.com/badlogic/pi-mono/pull/1726) by [@pasky](https://github.com/pasky))
 
 ## [0.55.3] - 2026-02-27
 
diff --git a/packages/coding-agent/src/core/agent-session.ts b/packages/coding-agent/src/core/agent-session.ts
index 029ce440..2a600ca9 100644
--- a/packages/coding-agent/src/core/agent-session.ts
+++ b/packages/coding-agent/src/core/agent-session.ts
@@ -318,6 +318,13 @@ export class AgentSession {
 
 	/** Internal handler for agent events - shared by subscribe and reconnect */
 	private _handleAgentEvent = (event: AgentEvent): void => {
+		// Create retry promise synchronously before queueing async processing.
+		// Agent.emit() calls this handler synchronously, and prompt() calls waitForRetry()
+		// as soon as agent.prompt() resolves. If _retryPromise is created only inside
+		// _processAgentEvent, slow earlier queued events can delay agent_end processing
+		// and waitForRetry() can miss the in-flight retry.
+		this._createRetryPromiseForAgentEnd(event);
+
 		this._agentEventQueue = this._agentEventQueue.then(
 			() => this._processAgentEvent(event),
 			() => this._processAgentEvent(event),
@@ -327,6 +334,36 @@ export class AgentSession {
 		this._agentEventQueue.catch(() => {});
 	};
 
+	private _createRetryPromiseForAgentEnd(event: AgentEvent): void {
+		if (event.type !== "agent_end" || this._retryPromise) {
+			return;
+		}
+
+		const settings = this.settingsManager.getRetrySettings();
+		if (!settings.enabled) {
+			return;
+		}
+
+		const lastAssistant = this._findLastAssistantInMessages(event.messages);
+		if (!lastAssistant || !this._isRetryableError(lastAssistant)) {
+			return;
+		}
+
+		this._retryPromise = new Promise((resolve) => {
+			this._retryResolve = resolve;
+		});
+	}
+
+	private _findLastAssistantInMessages(messages: AgentMessage[]): AssistantMessage | undefined {
+		for (let i = messages.length - 1; i >= 0; i--) {
+			const message = messages[i];
+			if (message.role === "assistant") {
+				return message as AssistantMessage;
+			}
+		}
+		return undefined;
+	}
+
 	private async _processAgentEvent(event: AgentEvent): Promise<void> {
 		// When a user message starts, check if it's from either queue and remove it BEFORE emitting
 		// This ensures the UI sees the updated queue state
@@ -2178,17 +2215,21 @@ export class AgentSession {
 	 */
 	private async _handleRetryableError(message: AssistantMessage): Promise<boolean> {
 		const settings = this.settingsManager.getRetrySettings();
-		if (!settings.enabled) return false;
+		if (!settings.enabled) {
+			this._resolveRetry();
+			return false;
+		}
 
-		this._retryAttempt++;
-
-		// Create retry promise on first attempt so waitForRetry() can await it
-		if (this._retryAttempt === 1 && !this._retryPromise) {
+		// Retry promise is created synchronously in _handleAgentEvent for agent_end.
+		// Keep a defensive fallback here in case a future refactor bypasses that path.
+		if (!this._retryPromise) {
 			this._retryPromise = new Promise((resolve) => {
 				this._retryResolve = resolve;
 			});
 		}
 
+		this._retryAttempt++;
+
 		if (this._retryAttempt > settings.maxRetries) {
 			// Max retries exceeded, emit final failure and reset
 			this._emit({
diff --git a/packages/coding-agent/test/agent-session-retry.test.ts b/packages/coding-agent/test/agent-session-retry.test.ts
new file mode 100644
index 00000000..91bd799c
--- /dev/null
+++ b/packages/coding-agent/test/agent-session-retry.test.ts
@@ -0,0 +1,171 @@
+import { existsSync, mkdirSync, rmSync } from "node:fs";
+import { tmpdir } from "node:os";
+import { join } from "node:path";
+import { Agent, type AgentEvent } from "@mariozechner/pi-agent-core";
+import { type AssistantMessage, type AssistantMessageEvent, EventStream, getModel } from "@mariozechner/pi-ai";
+import { afterEach, beforeEach, describe, expect, it } from "vitest";
+import { AgentSession } from "../src/core/agent-session.js";
+import { AuthStorage } from "../src/core/auth-storage.js";
+import { ModelRegistry } from "../src/core/model-registry.js";
+import { SessionManager } from "../src/core/session-manager.js";
+import { SettingsManager } from "../src/core/settings-manager.js";
+import { createTestResourceLoader } from "./utilities.js";
+
+class MockAssistantStream extends EventStream<AssistantMessageEvent, AssistantMessage> {
+	constructor() {
+		super(
+			(event) => event.type === "done" || event.type === "error",
+			(event) => {
+				if (event.type === "done") return event.message;
+				if (event.type === "error") return event.error;
+				throw new Error("Unexpected event type");
+			},
+		);
+	}
+}
+
+function createAssistantMessage(text: string, overrides?: Partial<AssistantMessage>): AssistantMessage {
+	return {
+		role: "assistant",
+		content: [{ type: "text", text }],
+		api: "anthropic-messages",
+		provider: "anthropic",
+		model: "mock",
+		usage: {
+			input: 0,
+			output: 0,
+			cacheRead: 0,
+			cacheWrite: 0,
+			totalTokens: 0,
+			cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
+		},
+		stopReason: "stop",
+		timestamp: Date.now(),
+		...overrides,
+	};
+}
+
+type SessionWithExtensionEmitHook = {
+	_emitExtensionEvent: (event: AgentEvent) => Promise<void>;
+};
+
+describe("AgentSession retry", () => {
+	let session: AgentSession;
+	let tempDir: string;
+
+	beforeEach(() => {
+		tempDir = join(tmpdir(), `pi-retry-test-${Date.now()}`);
+		mkdirSync(tempDir, { recursive: true });
+	});
+
+	afterEach(() => {
+		if (session) {
+			session.dispose();
+		}
+		if (tempDir && existsSync(tempDir)) {
+			rmSync(tempDir, { recursive: true });
+		}
+	});
+
+	function createSession(options?: { failCount?: number; maxRetries?: number; delayAssistantMessageEndMs?: number }) {
+		const failCount = options?.failCount ?? 1;
+		const maxRetries = options?.maxRetries ?? 3;
+		const delayAssistantMessageEndMs = options?.delayAssistantMessageEndMs ?? 0;
+		let callCount = 0;
+
+		const model = getModel("anthropic", "claude-sonnet-4-5")!;
+		const agent = new Agent({
+			getApiKey: () => "test-key",
+			initialState: { model, systemPrompt: "Test", tools: [] },
+			streamFn: () => {
+				callCount++;
+				const stream = new MockAssistantStream();
+				queueMicrotask(() => {
+					if (callCount <= failCount) {
+						const msg = createAssistantMessage("", {
+							stopReason: "error",
+							errorMessage: "overloaded_error",
+						});
+						stream.push({ type: "start", partial: msg });
+						stream.push({ type: "error", reason: "error", error: msg });
+					} else {
+						const msg = createAssistantMessage("Success");
+						stream.push({ type: "start", partial: msg });
+						stream.push({ type: "done", reason: "stop", message: msg });
+					}
+				});
+				return stream;
+			},
+		});
+
+		const sessionManager = SessionManager.inMemory();
+		const settingsManager = SettingsManager.create(tempDir, tempDir);
+		const authStorage = AuthStorage.create(join(tempDir, "auth.json"));
+		const modelRegistry = new ModelRegistry(authStorage, tempDir);
+		authStorage.setRuntimeApiKey("anthropic", "test-key");
+		settingsManager.applyOverrides({ retry: { enabled: true, maxRetries, baseDelayMs: 1 } });
+
+		session = new AgentSession({
+			agent,
+			sessionManager,
+			settingsManager,
+			cwd: tempDir,
+			modelRegistry,
+			resourceLoader: createTestResourceLoader(),
+		});
+
+		if (delayAssistantMessageEndMs > 0) {
+			const sessionWithHook = session as unknown as SessionWithExtensionEmitHook;
+			const original = sessionWithHook._emitExtensionEvent.bind(sessionWithHook);
+			sessionWithHook._emitExtensionEvent = async (event: AgentEvent) => {
+				if (event.type === "message_end" && event.message.role === "assistant") {
+					await new Promise((resolve) => setTimeout(resolve, delayAssistantMessageEndMs));
+				}
+				await original(event);
+			};
+		}
+
+		return { session, getCallCount: () => callCount };
+	}
+
+	it("retries after a transient error and succeeds", async () => {
+		const created = createSession({ failCount: 1 });
+		const events: string[] = [];
+		created.session.subscribe((event) => {
+			if (event.type === "auto_retry_start") events.push(`start:${event.attempt}`);
+			if (event.type === "auto_retry_end") events.push(`end:success=${event.success}`);
+		});
+
+		await created.session.prompt("Test");
+
+		expect(created.getCallCount()).toBe(2);
+		expect(events).toEqual(["start:1", "end:success=true"]);
+		expect(created.session.isRetrying).toBe(false);
+	});
+
+	it("exhausts max retries and emits failure", async () => {
+		const created = createSession({ failCount: 99, maxRetries: 2 });
+		const events: string[] = [];
+		created.session.subscribe((event) => {
+			if (event.type === "auto_retry_start") events.push(`start:${event.attempt}`);
+			if (event.type === "auto_retry_end") events.push(`end:success=${event.success}`);
+		});
+
+		await created.session.prompt("Test");
+
+		expect(created.getCallCount()).toBe(3);
+		expect(events).toContain("start:1");
+		expect(events).toContain("start:2");
+		expect(events).toContain("end:success=false");
+		expect(created.session.isRetrying).toBe(false);
+	});
+
+	it("prompt waits for retry completion even when assistant message_end handling is delayed", async () => {
+		const created = createSession({ failCount: 1, delayAssistantMessageEndMs: 40 });
+
+		await created.session.prompt("Test");
+
+		expect(created.getCallCount()).toBe(2);
+		expect(created.session.isRetrying).toBe(false);
+	});
+});