diff --git a/packages/ai/CHANGELOG.md b/packages/ai/CHANGELOG.md index f8bf2565..9add6cce 100644 --- a/packages/ai/CHANGELOG.md +++ b/packages/ai/CHANGELOG.md @@ -5,6 +5,7 @@ ### Added - Added `PI_AI_ANTIGRAVITY_VERSION` environment variable to override the Antigravity User-Agent version when Google updates their version requirements ([#1129](https://github.com/badlogic/pi-mono/issues/1129)) +- Added `cacheRetention` stream option with provider-specific mappings for prompt cache controls, defaulting to short retention ([#1134](https://github.com/badlogic/pi-mono/issues/1134)) ## [0.50.8] - 2026-02-01 diff --git a/packages/ai/src/providers/anthropic.ts b/packages/ai/src/providers/anthropic.ts index a52f7322..c2ce38b4 100644 --- a/packages/ai/src/providers/anthropic.ts +++ b/packages/ai/src/providers/anthropic.ts @@ -9,6 +9,7 @@ import { calculateCost } from "../models.js"; import type { Api, AssistantMessage, + CacheRetention, Context, ImageContent, Message, @@ -31,19 +32,32 @@ import { adjustMaxTokensForThinking, buildBaseOptions } from "./simple-options.j import { transformMessages } from "./transform-messages.js"; /** - * Get cache TTL based on PI_CACHE_RETENTION env var. - * Only applies to direct Anthropic API calls (api.anthropic.com). - * Returns '1h' for long retention, undefined for default (5m). + * Resolve cache retention preference. + * Defaults to "short" and uses PI_CACHE_RETENTION for backward compatibility. */ -function getCacheTtl(baseUrl: string): "1h" | undefined { - if ( - typeof process !== "undefined" && - process.env.PI_CACHE_RETENTION === "long" && - baseUrl.includes("api.anthropic.com") - ) { - return "1h"; +function resolveCacheRetention(cacheRetention?: CacheRetention): CacheRetention { + if (cacheRetention) { + return cacheRetention; } - return undefined; + if (typeof process !== "undefined" && process.env.PI_CACHE_RETENTION === "long") { + return "long"; + } + return "short"; +} + +function getCacheControl( + baseUrl: string, + cacheRetention?: CacheRetention, +): { retention: CacheRetention; cacheControl?: { type: "ephemeral"; ttl?: "1h" } } { + const retention = resolveCacheRetention(cacheRetention); + if (retention === "none") { + return { retention }; + } + const ttl = retention === "long" && baseUrl.includes("api.anthropic.com") ? "1h" : undefined; + return { + retention, + cacheControl: { type: "ephemeral", ...(ttl && { ttl }) }, + }; } // Stealth mode: Mimic Claude Code's tool naming exactly @@ -460,34 +474,28 @@ function buildParams( isOAuthToken: boolean, options?: AnthropicOptions, ): MessageCreateParamsStreaming { + const { cacheControl } = getCacheControl(model.baseUrl, options?.cacheRetention); const params: MessageCreateParamsStreaming = { model: model.id, - messages: convertMessages(context.messages, model, isOAuthToken), + messages: convertMessages(context.messages, model, isOAuthToken, cacheControl), max_tokens: options?.maxTokens || (model.maxTokens / 3) | 0, stream: true, }; // For OAuth tokens, we MUST include Claude Code identity - const cacheTtl = getCacheTtl(model.baseUrl); if (isOAuthToken) { params.system = [ { type: "text", text: "You are Claude Code, Anthropic's official CLI for Claude.", - cache_control: { - type: "ephemeral", - ...(cacheTtl && { ttl: cacheTtl }), - }, + ...(cacheControl ? { cache_control: cacheControl } : {}), }, ]; if (context.systemPrompt) { params.system.push({ type: "text", text: sanitizeSurrogates(context.systemPrompt), - cache_control: { - type: "ephemeral", - ...(cacheTtl && { ttl: cacheTtl }), - }, + ...(cacheControl ? { cache_control: cacheControl } : {}), }); } } else if (context.systemPrompt) { @@ -496,10 +504,7 @@ function buildParams( { type: "text", text: sanitizeSurrogates(context.systemPrompt), - cache_control: { - type: "ephemeral", - ...(cacheTtl && { ttl: cacheTtl }), - }, + ...(cacheControl ? { cache_control: cacheControl } : {}), }, ]; } @@ -539,6 +544,7 @@ function convertMessages( messages: Message[], model: Model<"anthropic-messages">, isOAuthToken: boolean, + cacheControl?: { type: "ephemeral"; ttl?: "1h" }, ): MessageParam[] { const params: MessageParam[] = []; @@ -665,7 +671,7 @@ function convertMessages( } // Add cache_control to the last user message to cache conversation history - if (params.length > 0) { + if (cacheControl && params.length > 0) { const lastMessage = params[params.length - 1]; if (lastMessage.role === "user") { // Add cache control to the last content block @@ -675,8 +681,7 @@ function convertMessages( lastBlock && (lastBlock.type === "text" || lastBlock.type === "image" || lastBlock.type === "tool_result") ) { - const cacheTtl = getCacheTtl(model.baseUrl); - (lastBlock as any).cache_control = { type: "ephemeral", ...(cacheTtl && { ttl: cacheTtl }) }; + (lastBlock as any).cache_control = cacheControl; } } } diff --git a/packages/ai/src/providers/openai-responses.ts b/packages/ai/src/providers/openai-responses.ts index 21db9d60..f96215a1 100644 --- a/packages/ai/src/providers/openai-responses.ts +++ b/packages/ai/src/providers/openai-responses.ts @@ -5,6 +5,7 @@ import { supportsXhigh } from "../models.js"; import type { Api, AssistantMessage, + CacheRetention, Context, Model, SimpleStreamOptions, @@ -19,16 +20,28 @@ import { buildBaseOptions, clampReasoning } from "./simple-options.js"; const OPENAI_TOOL_CALL_PROVIDERS = new Set(["openai", "openai-codex", "opencode"]); /** - * Get prompt cache retention based on PI_CACHE_RETENTION env var. - * Only applies to direct OpenAI API calls (api.openai.com). - * Returns '24h' for long retention, undefined for default (in-memory). + * Resolve cache retention preference. + * Defaults to "short" and uses PI_CACHE_RETENTION for backward compatibility. */ -function getPromptCacheRetention(baseUrl: string): "24h" | undefined { - if ( - typeof process !== "undefined" && - process.env.PI_CACHE_RETENTION === "long" && - baseUrl.includes("api.openai.com") - ) { +function resolveCacheRetention(cacheRetention?: CacheRetention): CacheRetention { + if (cacheRetention) { + return cacheRetention; + } + if (typeof process !== "undefined" && process.env.PI_CACHE_RETENTION === "long") { + return "long"; + } + return "short"; +} + +/** + * Get prompt cache retention based on cacheRetention and base URL. + * Only applies to direct OpenAI API calls (api.openai.com). + */ +function getPromptCacheRetention(baseUrl: string, cacheRetention: CacheRetention): "24h" | undefined { + if (cacheRetention !== "long") { + return undefined; + } + if (baseUrl.includes("api.openai.com")) { return "24h"; } return undefined; @@ -186,12 +199,13 @@ function createClient( function buildParams(model: Model<"openai-responses">, context: Context, options?: OpenAIResponsesOptions) { const messages = convertResponsesMessages(model, context, OPENAI_TOOL_CALL_PROVIDERS); + const cacheRetention = resolveCacheRetention(options?.cacheRetention); const params: ResponseCreateParamsStreaming = { model: model.id, input: messages, stream: true, - prompt_cache_key: options?.sessionId, - prompt_cache_retention: getPromptCacheRetention(model.baseUrl), + prompt_cache_key: cacheRetention === "none" ? undefined : options?.sessionId, + prompt_cache_retention: getPromptCacheRetention(model.baseUrl, cacheRetention), }; if (options?.maxTokens) { diff --git a/packages/ai/src/types.ts b/packages/ai/src/types.ts index f8c8921f..3485fd3b 100644 --- a/packages/ai/src/types.ts +++ b/packages/ai/src/types.ts @@ -51,11 +51,18 @@ export interface ThinkingBudgets { } // Base options all providers share +export type CacheRetention = "none" | "short" | "long"; + export interface StreamOptions { temperature?: number; maxTokens?: number; signal?: AbortSignal; apiKey?: string; + /** + * Prompt cache retention preference. Providers map this to their supported values. + * Default: "short". + */ + cacheRetention?: CacheRetention; /** * Optional session identifier for providers that support session-based caching. * Providers can use this to enable prompt caching, request routing, or other diff --git a/packages/ai/test/cache-retention.test.ts b/packages/ai/test/cache-retention.test.ts index 2e585a47..8fdcadb9 100644 --- a/packages/ai/test/cache-retention.test.ts +++ b/packages/ai/test/cache-retention.test.ts @@ -112,6 +112,58 @@ describe("Cache Retention (PI_CACHE_RETENTION)", () => { expect(capturedPayload.system[0].cache_control).toEqual({ type: "ephemeral" }); } }); + + it("should omit cache_control when cacheRetention is none", async () => { + const baseModel = getModel("anthropic", "claude-3-5-haiku-20241022"); + let capturedPayload: any = null; + + const { streamAnthropic } = await import("../src/providers/anthropic.js"); + + try { + const s = streamAnthropic(baseModel, context, { + apiKey: "fake-key", + cacheRetention: "none", + onPayload: (payload) => { + capturedPayload = payload; + }, + }); + + for await (const event of s) { + if (event.type === "error") break; + } + } catch { + // Expected to fail + } + + expect(capturedPayload).not.toBeNull(); + expect(capturedPayload.system[0].cache_control).toBeUndefined(); + }); + + it("should set 1h cache TTL when cacheRetention is long", async () => { + const baseModel = getModel("anthropic", "claude-3-5-haiku-20241022"); + let capturedPayload: any = null; + + const { streamAnthropic } = await import("../src/providers/anthropic.js"); + + try { + const s = streamAnthropic(baseModel, context, { + apiKey: "fake-key", + cacheRetention: "long", + onPayload: (payload) => { + capturedPayload = payload; + }, + }); + + for await (const event of s) { + if (event.type === "error") break; + } + } catch { + // Expected to fail + } + + expect(capturedPayload).not.toBeNull(); + expect(capturedPayload.system[0].cache_control).toEqual({ type: "ephemeral", ttl: "1h" }); + }); }); describe("OpenAI Responses Provider", () => { @@ -195,5 +247,61 @@ describe("Cache Retention (PI_CACHE_RETENTION)", () => { expect(capturedPayload.prompt_cache_retention).toBeUndefined(); } }); + + it("should omit prompt_cache_key when cacheRetention is none", async () => { + const model = getModel("openai", "gpt-4o-mini"); + let capturedPayload: any = null; + + const { streamOpenAIResponses } = await import("../src/providers/openai-responses.js"); + + try { + const s = streamOpenAIResponses(model, context, { + apiKey: "fake-key", + cacheRetention: "none", + sessionId: "session-1", + onPayload: (payload) => { + capturedPayload = payload; + }, + }); + + for await (const event of s) { + if (event.type === "error") break; + } + } catch { + // Expected to fail + } + + expect(capturedPayload).not.toBeNull(); + expect(capturedPayload.prompt_cache_key).toBeUndefined(); + expect(capturedPayload.prompt_cache_retention).toBeUndefined(); + }); + + it("should set prompt_cache_retention when cacheRetention is long", async () => { + const model = getModel("openai", "gpt-4o-mini"); + let capturedPayload: any = null; + + const { streamOpenAIResponses } = await import("../src/providers/openai-responses.js"); + + try { + const s = streamOpenAIResponses(model, context, { + apiKey: "fake-key", + cacheRetention: "long", + sessionId: "session-2", + onPayload: (payload) => { + capturedPayload = payload; + }, + }); + + for await (const event of s) { + if (event.type === "error") break; + } + } catch { + // Expected to fail + } + + expect(capturedPayload).not.toBeNull(); + expect(capturedPayload.prompt_cache_key).toBe("session-2"); + expect(capturedPayload.prompt_cache_retention).toBe("24h"); + }); }); });