feat(ai): add cacheRetention stream option

This commit is contained in:
Mario Zechner 2026-02-01 09:30:23 +01:00
parent e9ca0be769
commit abfd04b5c5
5 changed files with 174 additions and 39 deletions

View file

@ -5,6 +5,7 @@
### Added
- Added `PI_AI_ANTIGRAVITY_VERSION` environment variable to override the Antigravity User-Agent version when Google updates their version requirements ([#1129](https://github.com/badlogic/pi-mono/issues/1129))
- Added `cacheRetention` stream option with provider-specific mappings for prompt cache controls, defaulting to short retention ([#1134](https://github.com/badlogic/pi-mono/issues/1134))
## [0.50.8] - 2026-02-01

View file

@ -9,6 +9,7 @@ import { calculateCost } from "../models.js";
import type {
Api,
AssistantMessage,
CacheRetention,
Context,
ImageContent,
Message,
@ -31,19 +32,32 @@ import { adjustMaxTokensForThinking, buildBaseOptions } from "./simple-options.j
import { transformMessages } from "./transform-messages.js";
/**
* Get cache TTL based on PI_CACHE_RETENTION env var.
* Only applies to direct Anthropic API calls (api.anthropic.com).
* Returns '1h' for long retention, undefined for default (5m).
* Resolve cache retention preference.
* Defaults to "short" and uses PI_CACHE_RETENTION for backward compatibility.
*/
function getCacheTtl(baseUrl: string): "1h" | undefined {
if (
typeof process !== "undefined" &&
process.env.PI_CACHE_RETENTION === "long" &&
baseUrl.includes("api.anthropic.com")
) {
return "1h";
/**
 * Resolve the effective cache retention preference.
 *
 * An explicit `cacheRetention` option always wins; otherwise the legacy
 * PI_CACHE_RETENTION env var ("long") is honored for backward compatibility,
 * and the default is "short".
 */
function resolveCacheRetention(cacheRetention?: CacheRetention): CacheRetention {
	if (cacheRetention) {
		return cacheRetention;
	}
	// Guard `process` for non-Node runtimes (e.g. browsers) where it is undefined.
	if (typeof process !== "undefined" && process.env.PI_CACHE_RETENTION === "long") {
		return "long";
	}
	return "short";
}
/**
 * Map the resolved retention preference to an Anthropic cache_control block.
 *
 * "none" yields no cacheControl at all (prompt caching disabled). Otherwise an
 * ephemeral cache_control is returned; the explicit "1h" TTL is attached only
 * for "long" retention against the first-party API (api.anthropic.com).
 */
function getCacheControl(
	baseUrl: string,
	cacheRetention?: CacheRetention,
): { retention: CacheRetention; cacheControl?: { type: "ephemeral"; ttl?: "1h" } } {
	const retention = resolveCacheRetention(cacheRetention);
	// Caching disabled: callers must not emit any cache_control field.
	if (retention === "none") {
		return { retention };
	}
	const wantsLongTtl = retention === "long" && baseUrl.includes("api.anthropic.com");
	return wantsLongTtl
		? { retention, cacheControl: { type: "ephemeral", ttl: "1h" } }
		: { retention, cacheControl: { type: "ephemeral" } };
}
// Stealth mode: Mimic Claude Code's tool naming exactly
@ -460,34 +474,28 @@ function buildParams(
isOAuthToken: boolean,
options?: AnthropicOptions,
): MessageCreateParamsStreaming {
const { cacheControl } = getCacheControl(model.baseUrl, options?.cacheRetention);
const params: MessageCreateParamsStreaming = {
model: model.id,
messages: convertMessages(context.messages, model, isOAuthToken),
messages: convertMessages(context.messages, model, isOAuthToken, cacheControl),
max_tokens: options?.maxTokens || (model.maxTokens / 3) | 0,
stream: true,
};
// For OAuth tokens, we MUST include Claude Code identity
const cacheTtl = getCacheTtl(model.baseUrl);
if (isOAuthToken) {
params.system = [
{
type: "text",
text: "You are Claude Code, Anthropic's official CLI for Claude.",
cache_control: {
type: "ephemeral",
...(cacheTtl && { ttl: cacheTtl }),
},
...(cacheControl ? { cache_control: cacheControl } : {}),
},
];
if (context.systemPrompt) {
params.system.push({
type: "text",
text: sanitizeSurrogates(context.systemPrompt),
cache_control: {
type: "ephemeral",
...(cacheTtl && { ttl: cacheTtl }),
},
...(cacheControl ? { cache_control: cacheControl } : {}),
});
}
} else if (context.systemPrompt) {
@ -496,10 +504,7 @@ function buildParams(
{
type: "text",
text: sanitizeSurrogates(context.systemPrompt),
cache_control: {
type: "ephemeral",
...(cacheTtl && { ttl: cacheTtl }),
},
...(cacheControl ? { cache_control: cacheControl } : {}),
},
];
}
@ -539,6 +544,7 @@ function convertMessages(
messages: Message[],
model: Model<"anthropic-messages">,
isOAuthToken: boolean,
cacheControl?: { type: "ephemeral"; ttl?: "1h" },
): MessageParam[] {
const params: MessageParam[] = [];
@ -665,7 +671,7 @@ function convertMessages(
}
// Add cache_control to the last user message to cache conversation history
if (params.length > 0) {
if (cacheControl && params.length > 0) {
const lastMessage = params[params.length - 1];
if (lastMessage.role === "user") {
// Add cache control to the last content block
@ -675,8 +681,7 @@ function convertMessages(
lastBlock &&
(lastBlock.type === "text" || lastBlock.type === "image" || lastBlock.type === "tool_result")
) {
const cacheTtl = getCacheTtl(model.baseUrl);
(lastBlock as any).cache_control = { type: "ephemeral", ...(cacheTtl && { ttl: cacheTtl }) };
(lastBlock as any).cache_control = cacheControl;
}
}
}

View file

@ -5,6 +5,7 @@ import { supportsXhigh } from "../models.js";
import type {
Api,
AssistantMessage,
CacheRetention,
Context,
Model,
SimpleStreamOptions,
@ -19,16 +20,28 @@ import { buildBaseOptions, clampReasoning } from "./simple-options.js";
const OPENAI_TOOL_CALL_PROVIDERS = new Set(["openai", "openai-codex", "opencode"]);
/**
* Get prompt cache retention based on PI_CACHE_RETENTION env var.
* Only applies to direct OpenAI API calls (api.openai.com).
* Returns '24h' for long retention, undefined for default (in-memory).
* Resolve cache retention preference.
* Defaults to "short" and uses PI_CACHE_RETENTION for backward compatibility.
*/
function getPromptCacheRetention(baseUrl: string): "24h" | undefined {
if (
typeof process !== "undefined" &&
process.env.PI_CACHE_RETENTION === "long" &&
baseUrl.includes("api.openai.com")
) {
/**
 * Resolve the effective cache retention preference.
 *
 * An explicit option always wins; otherwise the legacy PI_CACHE_RETENTION env
 * var ("long") is honored for backward compatibility, defaulting to "short".
 */
function resolveCacheRetention(cacheRetention?: CacheRetention): CacheRetention {
	if (cacheRetention) {
		return cacheRetention;
	}
	// Guard `process` for non-Node runtimes where it is undefined.
	const envIsLong = typeof process !== "undefined" && process.env.PI_CACHE_RETENTION === "long";
	return envIsLong ? "long" : "short";
}
/**
* Get prompt cache retention based on cacheRetention and base URL.
* Only applies to direct OpenAI API calls (api.openai.com).
*/
function getPromptCacheRetention(baseUrl: string, cacheRetention: CacheRetention): "24h" | undefined {
if (cacheRetention !== "long") {
return undefined;
}
if (baseUrl.includes("api.openai.com")) {
return "24h";
}
return undefined;
@ -186,12 +199,13 @@ function createClient(
function buildParams(model: Model<"openai-responses">, context: Context, options?: OpenAIResponsesOptions) {
const messages = convertResponsesMessages(model, context, OPENAI_TOOL_CALL_PROVIDERS);
const cacheRetention = resolveCacheRetention(options?.cacheRetention);
const params: ResponseCreateParamsStreaming = {
model: model.id,
input: messages,
stream: true,
prompt_cache_key: options?.sessionId,
prompt_cache_retention: getPromptCacheRetention(model.baseUrl),
prompt_cache_key: cacheRetention === "none" ? undefined : options?.sessionId,
prompt_cache_retention: getPromptCacheRetention(model.baseUrl, cacheRetention),
};
if (options?.maxTokens) {

View file

@ -51,11 +51,18 @@ export interface ThinkingBudgets {
}
// Base options all providers share
export type CacheRetention = "none" | "short" | "long";
export interface StreamOptions {
temperature?: number;
maxTokens?: number;
signal?: AbortSignal;
apiKey?: string;
/**
* Prompt cache retention preference. Providers map this to their supported values.
* Default: "short".
*/
cacheRetention?: CacheRetention;
/**
* Optional session identifier for providers that support session-based caching.
* Providers can use this to enable prompt caching, request routing, or other

View file

@ -112,6 +112,58 @@ describe("Cache Retention (PI_CACHE_RETENTION)", () => {
expect(capturedPayload.system[0].cache_control).toEqual({ type: "ephemeral" });
}
});
// cacheRetention: "none" must suppress the cache_control block entirely.
it("should omit cache_control when cacheRetention is none", async () => {
const baseModel = getModel("anthropic", "claude-3-5-haiku-20241022");
let capturedPayload: any = null;
const { streamAnthropic } = await import("../src/providers/anthropic.js");
try {
const s = streamAnthropic(baseModel, context, {
apiKey: "fake-key",
cacheRetention: "none",
// Capture the outgoing request body before the fake-key call fails.
onPayload: (payload) => {
capturedPayload = payload;
},
});
// Drain the stream; an error event is expected because the key is fake.
for await (const event of s) {
if (event.type === "error") break;
}
} catch {
// Expected to fail
}
expect(capturedPayload).not.toBeNull();
expect(capturedPayload.system[0].cache_control).toBeUndefined();
});
// "long" retention against api.anthropic.com must produce an explicit 1h TTL.
it("should set 1h cache TTL when cacheRetention is long", async () => {
const baseModel = getModel("anthropic", "claude-3-5-haiku-20241022");
let capturedPayload: any = null;
const { streamAnthropic } = await import("../src/providers/anthropic.js");
try {
const s = streamAnthropic(baseModel, context, {
apiKey: "fake-key",
cacheRetention: "long",
// Capture the outgoing request body before the fake-key call fails.
onPayload: (payload) => {
capturedPayload = payload;
},
});
// Drain the stream; an error event is expected because the key is fake.
for await (const event of s) {
if (event.type === "error") break;
}
} catch {
// Expected to fail
}
expect(capturedPayload).not.toBeNull();
expect(capturedPayload.system[0].cache_control).toEqual({ type: "ephemeral", ttl: "1h" });
});
});
describe("OpenAI Responses Provider", () => {
@ -195,5 +247,61 @@ describe("Cache Retention (PI_CACHE_RETENTION)", () => {
expect(capturedPayload.prompt_cache_retention).toBeUndefined();
}
});
// "none" must drop both prompt_cache_key and prompt_cache_retention,
// even when a sessionId is supplied.
it("should omit prompt_cache_key when cacheRetention is none", async () => {
const model = getModel("openai", "gpt-4o-mini");
let capturedPayload: any = null;
const { streamOpenAIResponses } = await import("../src/providers/openai-responses.js");
try {
const s = streamOpenAIResponses(model, context, {
apiKey: "fake-key",
cacheRetention: "none",
sessionId: "session-1",
// Capture the outgoing request body before the fake-key call fails.
onPayload: (payload) => {
capturedPayload = payload;
},
});
// Drain the stream; an error event is expected because the key is fake.
for await (const event of s) {
if (event.type === "error") break;
}
} catch {
// Expected to fail
}
expect(capturedPayload).not.toBeNull();
expect(capturedPayload.prompt_cache_key).toBeUndefined();
expect(capturedPayload.prompt_cache_retention).toBeUndefined();
});
// "long" retention against api.openai.com must keep the prompt_cache_key
// (sessionId) and request the extended "24h" prompt_cache_retention.
it("should set prompt_cache_retention when cacheRetention is long", async () => {
const model = getModel("openai", "gpt-4o-mini");
let capturedPayload: any = null;
const { streamOpenAIResponses } = await import("../src/providers/openai-responses.js");
try {
const s = streamOpenAIResponses(model, context, {
apiKey: "fake-key",
cacheRetention: "long",
sessionId: "session-2",
// Capture the outgoing request body before the fake-key call fails.
onPayload: (payload) => {
capturedPayload = payload;
},
});
// Drain the stream; an error event is expected because the key is fake.
for await (const event of s) {
if (event.type === "error") break;
}
} catch {
// Expected to fail
}
expect(capturedPayload).not.toBeNull();
expect(capturedPayload.prompt_cache_key).toBe("session-2");
expect(capturedPayload.prompt_cache_retention).toBe("24h");
});
});
});