feat(ai): add cacheRetention stream option

This commit is contained in:
Mario Zechner 2026-02-01 09:30:23 +01:00
parent e9ca0be769
commit abfd04b5c5
5 changed files with 174 additions and 39 deletions

View file

@@ -5,6 +5,7 @@
### Added ### Added
- Added `PI_AI_ANTIGRAVITY_VERSION` environment variable to override the Antigravity User-Agent version when Google updates their version requirements ([#1129](https://github.com/badlogic/pi-mono/issues/1129)) - Added `PI_AI_ANTIGRAVITY_VERSION` environment variable to override the Antigravity User-Agent version when Google updates their version requirements ([#1129](https://github.com/badlogic/pi-mono/issues/1129))
- Added `cacheRetention` stream option with provider-specific mappings for prompt cache controls, defaulting to short retention ([#1134](https://github.com/badlogic/pi-mono/issues/1134))
## [0.50.8] - 2026-02-01 ## [0.50.8] - 2026-02-01

View file

@@ -9,6 +9,7 @@ import { calculateCost } from "../models.js";
import type { import type {
Api, Api,
AssistantMessage, AssistantMessage,
CacheRetention,
Context, Context,
ImageContent, ImageContent,
Message, Message,
@@ -31,19 +32,32 @@ import { adjustMaxTokensForThinking, buildBaseOptions } from "./simple-options.j
import { transformMessages } from "./transform-messages.js"; import { transformMessages } from "./transform-messages.js";
/** /**
* Get cache TTL based on PI_CACHE_RETENTION env var. * Resolve cache retention preference.
* Only applies to direct Anthropic API calls (api.anthropic.com). * Defaults to "short" and uses PI_CACHE_RETENTION for backward compatibility.
* Returns '1h' for long retention, undefined for default (5m).
*/ */
function getCacheTtl(baseUrl: string): "1h" | undefined { function resolveCacheRetention(cacheRetention?: CacheRetention): CacheRetention {
if ( if (cacheRetention) {
typeof process !== "undefined" && return cacheRetention;
process.env.PI_CACHE_RETENTION === "long" &&
baseUrl.includes("api.anthropic.com")
) {
return "1h";
} }
return undefined; if (typeof process !== "undefined" && process.env.PI_CACHE_RETENTION === "long") {
return "long";
}
return "short";
}
function getCacheControl(
baseUrl: string,
cacheRetention?: CacheRetention,
): { retention: CacheRetention; cacheControl?: { type: "ephemeral"; ttl?: "1h" } } {
const retention = resolveCacheRetention(cacheRetention);
if (retention === "none") {
return { retention };
}
const ttl = retention === "long" && baseUrl.includes("api.anthropic.com") ? "1h" : undefined;
return {
retention,
cacheControl: { type: "ephemeral", ...(ttl && { ttl }) },
};
} }
// Stealth mode: Mimic Claude Code's tool naming exactly // Stealth mode: Mimic Claude Code's tool naming exactly
@@ -460,34 +474,28 @@ function buildParams(
isOAuthToken: boolean, isOAuthToken: boolean,
options?: AnthropicOptions, options?: AnthropicOptions,
): MessageCreateParamsStreaming { ): MessageCreateParamsStreaming {
const { cacheControl } = getCacheControl(model.baseUrl, options?.cacheRetention);
const params: MessageCreateParamsStreaming = { const params: MessageCreateParamsStreaming = {
model: model.id, model: model.id,
messages: convertMessages(context.messages, model, isOAuthToken), messages: convertMessages(context.messages, model, isOAuthToken, cacheControl),
max_tokens: options?.maxTokens || (model.maxTokens / 3) | 0, max_tokens: options?.maxTokens || (model.maxTokens / 3) | 0,
stream: true, stream: true,
}; };
// For OAuth tokens, we MUST include Claude Code identity // For OAuth tokens, we MUST include Claude Code identity
const cacheTtl = getCacheTtl(model.baseUrl);
if (isOAuthToken) { if (isOAuthToken) {
params.system = [ params.system = [
{ {
type: "text", type: "text",
text: "You are Claude Code, Anthropic's official CLI for Claude.", text: "You are Claude Code, Anthropic's official CLI for Claude.",
cache_control: { ...(cacheControl ? { cache_control: cacheControl } : {}),
type: "ephemeral",
...(cacheTtl && { ttl: cacheTtl }),
},
}, },
]; ];
if (context.systemPrompt) { if (context.systemPrompt) {
params.system.push({ params.system.push({
type: "text", type: "text",
text: sanitizeSurrogates(context.systemPrompt), text: sanitizeSurrogates(context.systemPrompt),
cache_control: { ...(cacheControl ? { cache_control: cacheControl } : {}),
type: "ephemeral",
...(cacheTtl && { ttl: cacheTtl }),
},
}); });
} }
} else if (context.systemPrompt) { } else if (context.systemPrompt) {
@@ -496,10 +504,7 @@ function buildParams(
{ {
type: "text", type: "text",
text: sanitizeSurrogates(context.systemPrompt), text: sanitizeSurrogates(context.systemPrompt),
cache_control: { ...(cacheControl ? { cache_control: cacheControl } : {}),
type: "ephemeral",
...(cacheTtl && { ttl: cacheTtl }),
},
}, },
]; ];
} }
@@ -539,6 +544,7 @@ function convertMessages(
messages: Message[], messages: Message[],
model: Model<"anthropic-messages">, model: Model<"anthropic-messages">,
isOAuthToken: boolean, isOAuthToken: boolean,
cacheControl?: { type: "ephemeral"; ttl?: "1h" },
): MessageParam[] { ): MessageParam[] {
const params: MessageParam[] = []; const params: MessageParam[] = [];
@@ -665,7 +671,7 @@ function convertMessages(
} }
// Add cache_control to the last user message to cache conversation history // Add cache_control to the last user message to cache conversation history
if (params.length > 0) { if (cacheControl && params.length > 0) {
const lastMessage = params[params.length - 1]; const lastMessage = params[params.length - 1];
if (lastMessage.role === "user") { if (lastMessage.role === "user") {
// Add cache control to the last content block // Add cache control to the last content block
@@ -675,8 +681,7 @@ function convertMessages(
lastBlock && lastBlock &&
(lastBlock.type === "text" || lastBlock.type === "image" || lastBlock.type === "tool_result") (lastBlock.type === "text" || lastBlock.type === "image" || lastBlock.type === "tool_result")
) { ) {
const cacheTtl = getCacheTtl(model.baseUrl); (lastBlock as any).cache_control = cacheControl;
(lastBlock as any).cache_control = { type: "ephemeral", ...(cacheTtl && { ttl: cacheTtl }) };
} }
} }
} }

View file

@@ -5,6 +5,7 @@ import { supportsXhigh } from "../models.js";
import type { import type {
Api, Api,
AssistantMessage, AssistantMessage,
CacheRetention,
Context, Context,
Model, Model,
SimpleStreamOptions, SimpleStreamOptions,
@@ -19,16 +20,28 @@ import { buildBaseOptions, clampReasoning } from "./simple-options.js";
const OPENAI_TOOL_CALL_PROVIDERS = new Set(["openai", "openai-codex", "opencode"]); const OPENAI_TOOL_CALL_PROVIDERS = new Set(["openai", "openai-codex", "opencode"]);
/** /**
* Get prompt cache retention based on PI_CACHE_RETENTION env var. * Resolve cache retention preference.
* Only applies to direct OpenAI API calls (api.openai.com). * Defaults to "short" and uses PI_CACHE_RETENTION for backward compatibility.
* Returns '24h' for long retention, undefined for default (in-memory).
*/ */
function getPromptCacheRetention(baseUrl: string): "24h" | undefined { function resolveCacheRetention(cacheRetention?: CacheRetention): CacheRetention {
if ( if (cacheRetention) {
typeof process !== "undefined" && return cacheRetention;
process.env.PI_CACHE_RETENTION === "long" && }
baseUrl.includes("api.openai.com") if (typeof process !== "undefined" && process.env.PI_CACHE_RETENTION === "long") {
) { return "long";
}
return "short";
}
/**
* Get prompt cache retention based on cacheRetention and base URL.
* Only applies to direct OpenAI API calls (api.openai.com).
*/
function getPromptCacheRetention(baseUrl: string, cacheRetention: CacheRetention): "24h" | undefined {
if (cacheRetention !== "long") {
return undefined;
}
if (baseUrl.includes("api.openai.com")) {
return "24h"; return "24h";
} }
return undefined; return undefined;
@@ -186,12 +199,13 @@ function createClient(
function buildParams(model: Model<"openai-responses">, context: Context, options?: OpenAIResponsesOptions) { function buildParams(model: Model<"openai-responses">, context: Context, options?: OpenAIResponsesOptions) {
const messages = convertResponsesMessages(model, context, OPENAI_TOOL_CALL_PROVIDERS); const messages = convertResponsesMessages(model, context, OPENAI_TOOL_CALL_PROVIDERS);
const cacheRetention = resolveCacheRetention(options?.cacheRetention);
const params: ResponseCreateParamsStreaming = { const params: ResponseCreateParamsStreaming = {
model: model.id, model: model.id,
input: messages, input: messages,
stream: true, stream: true,
prompt_cache_key: options?.sessionId, prompt_cache_key: cacheRetention === "none" ? undefined : options?.sessionId,
prompt_cache_retention: getPromptCacheRetention(model.baseUrl), prompt_cache_retention: getPromptCacheRetention(model.baseUrl, cacheRetention),
}; };
if (options?.maxTokens) { if (options?.maxTokens) {

View file

@@ -51,11 +51,18 @@ export interface ThinkingBudgets {
} }
// Base options all providers share // Base options all providers share
export type CacheRetention = "none" | "short" | "long";
export interface StreamOptions { export interface StreamOptions {
temperature?: number; temperature?: number;
maxTokens?: number; maxTokens?: number;
signal?: AbortSignal; signal?: AbortSignal;
apiKey?: string; apiKey?: string;
/**
* Prompt cache retention preference. Providers map this to their supported values.
* Default: "short".
*/
cacheRetention?: CacheRetention;
/** /**
* Optional session identifier for providers that support session-based caching. * Optional session identifier for providers that support session-based caching.
* Providers can use this to enable prompt caching, request routing, or other * Providers can use this to enable prompt caching, request routing, or other

View file

@@ -112,6 +112,58 @@ describe("Cache Retention (PI_CACHE_RETENTION)", () => {
expect(capturedPayload.system[0].cache_control).toEqual({ type: "ephemeral" }); expect(capturedPayload.system[0].cache_control).toEqual({ type: "ephemeral" });
} }
}); });
it("should omit cache_control when cacheRetention is none", async () => {
const baseModel = getModel("anthropic", "claude-3-5-haiku-20241022");
let capturedPayload: any = null;
const { streamAnthropic } = await import("../src/providers/anthropic.js");
try {
const s = streamAnthropic(baseModel, context, {
apiKey: "fake-key",
cacheRetention: "none",
onPayload: (payload) => {
capturedPayload = payload;
},
});
for await (const event of s) {
if (event.type === "error") break;
}
} catch {
// Expected to fail
}
expect(capturedPayload).not.toBeNull();
expect(capturedPayload.system[0].cache_control).toBeUndefined();
});
it("should set 1h cache TTL when cacheRetention is long", async () => {
const baseModel = getModel("anthropic", "claude-3-5-haiku-20241022");
let capturedPayload: any = null;
const { streamAnthropic } = await import("../src/providers/anthropic.js");
try {
const s = streamAnthropic(baseModel, context, {
apiKey: "fake-key",
cacheRetention: "long",
onPayload: (payload) => {
capturedPayload = payload;
},
});
for await (const event of s) {
if (event.type === "error") break;
}
} catch {
// Expected to fail
}
expect(capturedPayload).not.toBeNull();
expect(capturedPayload.system[0].cache_control).toEqual({ type: "ephemeral", ttl: "1h" });
});
}); });
describe("OpenAI Responses Provider", () => { describe("OpenAI Responses Provider", () => {
@@ -195,5 +247,61 @@ describe("Cache Retention (PI_CACHE_RETENTION)", () => {
expect(capturedPayload.prompt_cache_retention).toBeUndefined(); expect(capturedPayload.prompt_cache_retention).toBeUndefined();
} }
}); });
it("should omit prompt_cache_key when cacheRetention is none", async () => {
const model = getModel("openai", "gpt-4o-mini");
let capturedPayload: any = null;
const { streamOpenAIResponses } = await import("../src/providers/openai-responses.js");
try {
const s = streamOpenAIResponses(model, context, {
apiKey: "fake-key",
cacheRetention: "none",
sessionId: "session-1",
onPayload: (payload) => {
capturedPayload = payload;
},
});
for await (const event of s) {
if (event.type === "error") break;
}
} catch {
// Expected to fail
}
expect(capturedPayload).not.toBeNull();
expect(capturedPayload.prompt_cache_key).toBeUndefined();
expect(capturedPayload.prompt_cache_retention).toBeUndefined();
});
it("should set prompt_cache_retention when cacheRetention is long", async () => {
const model = getModel("openai", "gpt-4o-mini");
let capturedPayload: any = null;
const { streamOpenAIResponses } = await import("../src/providers/openai-responses.js");
try {
const s = streamOpenAIResponses(model, context, {
apiKey: "fake-key",
cacheRetention: "long",
sessionId: "session-2",
onPayload: (payload) => {
capturedPayload = payload;
},
});
for await (const event of s) {
if (event.type === "error") break;
}
} catch {
// Expected to fail
}
expect(capturedPayload).not.toBeNull();
expect(capturedPayload.prompt_cache_key).toBe("session-2");
expect(capturedPayload.prompt_cache_retention).toBe("24h");
});
}); });
}); });