import { existsSync, mkdirSync, rmSync } from "node:fs";
import { tmpdir } from "node:os";
import { join } from "node:path";
import { Agent, type AgentEvent } from "@mariozechner/pi-agent-core";
import { type AssistantMessage, type AssistantMessageEvent, EventStream, getModel } from "@mariozechner/pi-ai";
import { afterEach, beforeEach, describe, expect, it } from "vitest";
import { AgentSession } from "../src/core/agent-session.js";
import { AuthStorage } from "../src/core/auth-storage.js";
import { ModelRegistry } from "../src/core/model-registry.js";
import { SessionManager } from "../src/core/session-manager.js";
import { SettingsManager } from "../src/core/settings-manager.js";
import { createTestResourceLoader } from "./utilities.js";

/**
 * Scripted assistant stream used in place of a real provider stream.
 *
 * The first callback treats "done" and "error" events as terminal; the second
 * maps a terminal event to the value handed back for the stream (the final
 * message for "done", the error payload for "error").
 */
class MockAssistantStream extends EventStream<AssistantMessageEvent, AssistantMessage> {
  constructor() {
    super(
      (event) => event.type === "done" || event.type === "error",
      (event) => {
        if (event.type === "done") return event.message;
        if (event.type === "error") return event.error;
        throw new Error("Unexpected event type");
      },
    );
  }
}

/**
 * Builds an assistant message with a single text content part, zeroed usage
 * and `stopReason: "stop"`.
 *
 * @param text - Text placed in the single content part.
 * @param overrides - Fields spread last, so they replace the defaults above.
 * @returns The assembled message.
 */
function createAssistantMessage(text: string, overrides?: Partial<AssistantMessage>): AssistantMessage {
  return {
    role: "assistant",
    content: [{ type: "text", text }],
    api: "anthropic-messages",
    provider: "anthropic",
    model: "mock",
    usage: {
      input: 0,
      output: 0,
      cacheRead: 0,
      cacheWrite: 0,
      totalTokens: 0,
      cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
    },
    stopReason: "stop",
    timestamp: Date.now(),
    ...overrides,
  };
}

/** Structural view of the session's extension-event hook, used only so the test can wrap it. */
type SessionWithExtensionEmitHook = {
  _emitExtensionEvent: (event: AgentEvent) => Promise<void>;
};

describe("AgentSession retry", () => {
  let session: AgentSession;
  let tempDir: string;

  beforeEach(() => {
    tempDir = join(tmpdir(), `pi-retry-test-${Date.now()}`);
    mkdirSync(tempDir, { recursive: true });
  });

  afterEach(() => {
    if (session) {
      session.dispose();
    }
    if (tempDir && existsSync(tempDir)) {
      rmSync(tempDir, { recursive: true });
    }
  });

  /**
   * Creates an AgentSession whose stream function fails a fixed number of
   * times before succeeding.
   *
   * @param options.failCount - Number of leading calls that emit an
   *   "overloaded_error" error event (default 1).
   * @param options.maxRetries - Value applied to the retry settings (default 3).
   * @param options.delayAssistantMessageEndMs - When greater than 0, the
   *   session's `_emitExtensionEvent` is wrapped so assistant `message_end`
   *   events wait this many milliseconds before being forwarded (default 0).
   * @returns The session plus a getter for how many times the stream function ran.
   */
  function createSession(options?: {
    failCount?: number;
    maxRetries?: number;
    delayAssistantMessageEndMs?: number;
  }) {
    const failCount = options?.failCount ?? 1;
    const maxRetries = options?.maxRetries ?? 3;
    const delayAssistantMessageEndMs = options?.delayAssistantMessageEndMs ?? 0;

    let callCount = 0;
    const model = getModel("anthropic", "claude-sonnet-4-5")!;

    const agent = new Agent({
      getApiKey: () => "test-key",
      initialState: { model, systemPrompt: "Test", tools: [] },
      streamFn: () => {
        callCount++;
        const stream = new MockAssistantStream();
        // Events are pushed from a microtask, i.e. after the stream has been returned.
        queueMicrotask(() => {
          if (callCount <= failCount) {
            const msg = createAssistantMessage("", {
              stopReason: "error",
              errorMessage: "overloaded_error",
            });
            stream.push({ type: "start", partial: msg });
            stream.push({ type: "error", reason: "error", error: msg });
          } else {
            const msg = createAssistantMessage("Success");
            stream.push({ type: "start", partial: msg });
            stream.push({ type: "done", reason: "stop", message: msg });
          }
        });
        return stream;
      },
    });

    const sessionManager = SessionManager.inMemory();
    const settingsManager = SettingsManager.create(tempDir, tempDir);
    const authStorage = AuthStorage.create(join(tempDir, "auth.json"));
    const modelRegistry = new ModelRegistry(authStorage, tempDir);
    authStorage.setRuntimeApiKey("anthropic", "test-key");
    // baseDelayMs of 1 keeps the retry waits short for the test run.
    settingsManager.applyOverrides({ retry: { enabled: true, maxRetries, baseDelayMs: 1 } });

    session = new AgentSession({
      agent,
      sessionManager,
      settingsManager,
      cwd: tempDir,
      modelRegistry,
      resourceLoader: createTestResourceLoader(),
    });

    if (delayAssistantMessageEndMs > 0) {
      // The hook is not part of the public type, hence the structural cast.
      const sessionWithHook = session as unknown as SessionWithExtensionEmitHook;
      const original = sessionWithHook._emitExtensionEvent.bind(sessionWithHook);
      sessionWithHook._emitExtensionEvent = async (event: AgentEvent) => {
        if (event.type === "message_end" && event.message.role === "assistant") {
          await new Promise<void>((resolve) => setTimeout(resolve, delayAssistantMessageEndMs));
        }
        await original(event);
      };
    }

    return { session, getCallCount: () => callCount };
  }

  /**
   * Subscribes to the session and records auto-retry events as short strings
   * (`start:<attempt>` and `end:success=<bool>`) in the returned array.
   */
  function recordRetryEvents(target: AgentSession): string[] {
    const events: string[] = [];
    target.subscribe((event) => {
      if (event.type === "auto_retry_start") events.push(`start:${event.attempt}`);
      if (event.type === "auto_retry_end") events.push(`end:success=${event.success}`);
    });
    return events;
  }

  it("retries after a transient error and succeeds", async () => {
    const created = createSession({ failCount: 1 });
    const events = recordRetryEvents(created.session);

    await created.session.prompt("Test");

    expect(created.getCallCount()).toBe(2);
    expect(events).toEqual(["start:1", "end:success=true"]);
    expect(created.session.isRetrying).toBe(false);
  });

  it("exhausts max retries and emits failure", async () => {
    const created = createSession({ failCount: 99, maxRetries: 2 });
    const events = recordRetryEvents(created.session);

    await created.session.prompt("Test");

    // One initial call plus two retries.
    expect(created.getCallCount()).toBe(3);
    expect(events).toContain("start:1");
    expect(events).toContain("start:2");
    expect(events).toContain("end:success=false");
    expect(created.session.isRetrying).toBe(false);
  });

  it("prompt waits for retry completion even when assistant message_end handling is delayed", async () => {
    const created = createSession({ failCount: 1, delayAssistantMessageEndMs: 40 });

    await created.session.prompt("Test");

    expect(created.getCallCount()).toBe(2);
    expect(created.session.isRetrying).toBe(false);
  });
});