mirror of
https://github.com/getcompanion-ai/co-mono.git
synced 2026-04-15 22:03:45 +00:00
feat(ai): add cacheRetention stream option
This commit is contained in:
parent
e9ca0be769
commit
abfd04b5c5
5 changed files with 174 additions and 39 deletions
|
|
@ -5,6 +5,7 @@
|
|||
### Added
|
||||
|
||||
- Added `PI_AI_ANTIGRAVITY_VERSION` environment variable to override the Antigravity User-Agent version when Google updates their version requirements ([#1129](https://github.com/badlogic/pi-mono/issues/1129))
|
||||
- Added `cacheRetention` stream option with provider-specific mappings for prompt cache controls, defaulting to short retention ([#1134](https://github.com/badlogic/pi-mono/issues/1134))
|
||||
|
||||
## [0.50.8] - 2026-02-01
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import { calculateCost } from "../models.js";
|
|||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
CacheRetention,
|
||||
Context,
|
||||
ImageContent,
|
||||
Message,
|
||||
|
|
@ -31,19 +32,32 @@ import { adjustMaxTokensForThinking, buildBaseOptions } from "./simple-options.j
|
|||
import { transformMessages } from "./transform-messages.js";
|
||||
|
||||
/**
|
||||
* Get cache TTL based on PI_CACHE_RETENTION env var.
|
||||
* Only applies to direct Anthropic API calls (api.anthropic.com).
|
||||
* Returns '1h' for long retention, undefined for default (5m).
|
||||
* Resolve cache retention preference.
|
||||
* Defaults to "short" and uses PI_CACHE_RETENTION for backward compatibility.
|
||||
*/
|
||||
function getCacheTtl(baseUrl: string): "1h" | undefined {
|
||||
if (
|
||||
typeof process !== "undefined" &&
|
||||
process.env.PI_CACHE_RETENTION === "long" &&
|
||||
baseUrl.includes("api.anthropic.com")
|
||||
) {
|
||||
return "1h";
|
||||
function resolveCacheRetention(cacheRetention?: CacheRetention): CacheRetention {
|
||||
if (cacheRetention) {
|
||||
return cacheRetention;
|
||||
}
|
||||
return undefined;
|
||||
if (typeof process !== "undefined" && process.env.PI_CACHE_RETENTION === "long") {
|
||||
return "long";
|
||||
}
|
||||
return "short";
|
||||
}
|
||||
|
||||
function getCacheControl(
|
||||
baseUrl: string,
|
||||
cacheRetention?: CacheRetention,
|
||||
): { retention: CacheRetention; cacheControl?: { type: "ephemeral"; ttl?: "1h" } } {
|
||||
const retention = resolveCacheRetention(cacheRetention);
|
||||
if (retention === "none") {
|
||||
return { retention };
|
||||
}
|
||||
const ttl = retention === "long" && baseUrl.includes("api.anthropic.com") ? "1h" : undefined;
|
||||
return {
|
||||
retention,
|
||||
cacheControl: { type: "ephemeral", ...(ttl && { ttl }) },
|
||||
};
|
||||
}
|
||||
|
||||
// Stealth mode: Mimic Claude Code's tool naming exactly
|
||||
|
|
@ -460,34 +474,28 @@ function buildParams(
|
|||
isOAuthToken: boolean,
|
||||
options?: AnthropicOptions,
|
||||
): MessageCreateParamsStreaming {
|
||||
const { cacheControl } = getCacheControl(model.baseUrl, options?.cacheRetention);
|
||||
const params: MessageCreateParamsStreaming = {
|
||||
model: model.id,
|
||||
messages: convertMessages(context.messages, model, isOAuthToken),
|
||||
messages: convertMessages(context.messages, model, isOAuthToken, cacheControl),
|
||||
max_tokens: options?.maxTokens || (model.maxTokens / 3) | 0,
|
||||
stream: true,
|
||||
};
|
||||
|
||||
// For OAuth tokens, we MUST include Claude Code identity
|
||||
const cacheTtl = getCacheTtl(model.baseUrl);
|
||||
if (isOAuthToken) {
|
||||
params.system = [
|
||||
{
|
||||
type: "text",
|
||||
text: "You are Claude Code, Anthropic's official CLI for Claude.",
|
||||
cache_control: {
|
||||
type: "ephemeral",
|
||||
...(cacheTtl && { ttl: cacheTtl }),
|
||||
},
|
||||
...(cacheControl ? { cache_control: cacheControl } : {}),
|
||||
},
|
||||
];
|
||||
if (context.systemPrompt) {
|
||||
params.system.push({
|
||||
type: "text",
|
||||
text: sanitizeSurrogates(context.systemPrompt),
|
||||
cache_control: {
|
||||
type: "ephemeral",
|
||||
...(cacheTtl && { ttl: cacheTtl }),
|
||||
},
|
||||
...(cacheControl ? { cache_control: cacheControl } : {}),
|
||||
});
|
||||
}
|
||||
} else if (context.systemPrompt) {
|
||||
|
|
@ -496,10 +504,7 @@ function buildParams(
|
|||
{
|
||||
type: "text",
|
||||
text: sanitizeSurrogates(context.systemPrompt),
|
||||
cache_control: {
|
||||
type: "ephemeral",
|
||||
...(cacheTtl && { ttl: cacheTtl }),
|
||||
},
|
||||
...(cacheControl ? { cache_control: cacheControl } : {}),
|
||||
},
|
||||
];
|
||||
}
|
||||
|
|
@ -539,6 +544,7 @@ function convertMessages(
|
|||
messages: Message[],
|
||||
model: Model<"anthropic-messages">,
|
||||
isOAuthToken: boolean,
|
||||
cacheControl?: { type: "ephemeral"; ttl?: "1h" },
|
||||
): MessageParam[] {
|
||||
const params: MessageParam[] = [];
|
||||
|
||||
|
|
@ -665,7 +671,7 @@ function convertMessages(
|
|||
}
|
||||
|
||||
// Add cache_control to the last user message to cache conversation history
|
||||
if (params.length > 0) {
|
||||
if (cacheControl && params.length > 0) {
|
||||
const lastMessage = params[params.length - 1];
|
||||
if (lastMessage.role === "user") {
|
||||
// Add cache control to the last content block
|
||||
|
|
@ -675,8 +681,7 @@ function convertMessages(
|
|||
lastBlock &&
|
||||
(lastBlock.type === "text" || lastBlock.type === "image" || lastBlock.type === "tool_result")
|
||||
) {
|
||||
const cacheTtl = getCacheTtl(model.baseUrl);
|
||||
(lastBlock as any).cache_control = { type: "ephemeral", ...(cacheTtl && { ttl: cacheTtl }) };
|
||||
(lastBlock as any).cache_control = cacheControl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import { supportsXhigh } from "../models.js";
|
|||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
CacheRetention,
|
||||
Context,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
|
|
@ -19,16 +20,28 @@ import { buildBaseOptions, clampReasoning } from "./simple-options.js";
|
|||
const OPENAI_TOOL_CALL_PROVIDERS = new Set(["openai", "openai-codex", "opencode"]);
|
||||
|
||||
/**
|
||||
* Get prompt cache retention based on PI_CACHE_RETENTION env var.
|
||||
* Only applies to direct OpenAI API calls (api.openai.com).
|
||||
* Returns '24h' for long retention, undefined for default (in-memory).
|
||||
* Resolve cache retention preference.
|
||||
* Defaults to "short" and uses PI_CACHE_RETENTION for backward compatibility.
|
||||
*/
|
||||
function getPromptCacheRetention(baseUrl: string): "24h" | undefined {
|
||||
if (
|
||||
typeof process !== "undefined" &&
|
||||
process.env.PI_CACHE_RETENTION === "long" &&
|
||||
baseUrl.includes("api.openai.com")
|
||||
) {
|
||||
function resolveCacheRetention(cacheRetention?: CacheRetention): CacheRetention {
|
||||
if (cacheRetention) {
|
||||
return cacheRetention;
|
||||
}
|
||||
if (typeof process !== "undefined" && process.env.PI_CACHE_RETENTION === "long") {
|
||||
return "long";
|
||||
}
|
||||
return "short";
|
||||
}
|
||||
|
||||
/**
|
||||
* Get prompt cache retention based on cacheRetention and base URL.
|
||||
* Only applies to direct OpenAI API calls (api.openai.com).
|
||||
*/
|
||||
function getPromptCacheRetention(baseUrl: string, cacheRetention: CacheRetention): "24h" | undefined {
|
||||
if (cacheRetention !== "long") {
|
||||
return undefined;
|
||||
}
|
||||
if (baseUrl.includes("api.openai.com")) {
|
||||
return "24h";
|
||||
}
|
||||
return undefined;
|
||||
|
|
@ -186,12 +199,13 @@ function createClient(
|
|||
function buildParams(model: Model<"openai-responses">, context: Context, options?: OpenAIResponsesOptions) {
|
||||
const messages = convertResponsesMessages(model, context, OPENAI_TOOL_CALL_PROVIDERS);
|
||||
|
||||
const cacheRetention = resolveCacheRetention(options?.cacheRetention);
|
||||
const params: ResponseCreateParamsStreaming = {
|
||||
model: model.id,
|
||||
input: messages,
|
||||
stream: true,
|
||||
prompt_cache_key: options?.sessionId,
|
||||
prompt_cache_retention: getPromptCacheRetention(model.baseUrl),
|
||||
prompt_cache_key: cacheRetention === "none" ? undefined : options?.sessionId,
|
||||
prompt_cache_retention: getPromptCacheRetention(model.baseUrl, cacheRetention),
|
||||
};
|
||||
|
||||
if (options?.maxTokens) {
|
||||
|
|
|
|||
|
|
@ -51,11 +51,18 @@ export interface ThinkingBudgets {
|
|||
}
|
||||
|
||||
// Base options all providers share
|
||||
export type CacheRetention = "none" | "short" | "long";
|
||||
|
||||
export interface StreamOptions {
|
||||
temperature?: number;
|
||||
maxTokens?: number;
|
||||
signal?: AbortSignal;
|
||||
apiKey?: string;
|
||||
/**
|
||||
* Prompt cache retention preference. Providers map this to their supported values.
|
||||
* Default: "short".
|
||||
*/
|
||||
cacheRetention?: CacheRetention;
|
||||
/**
|
||||
* Optional session identifier for providers that support session-based caching.
|
||||
* Providers can use this to enable prompt caching, request routing, or other
|
||||
|
|
|
|||
|
|
@ -112,6 +112,58 @@ describe("Cache Retention (PI_CACHE_RETENTION)", () => {
|
|||
expect(capturedPayload.system[0].cache_control).toEqual({ type: "ephemeral" });
|
||||
}
|
||||
});
|
||||
|
||||
it("should omit cache_control when cacheRetention is none", async () => {
|
||||
const baseModel = getModel("anthropic", "claude-3-5-haiku-20241022");
|
||||
let capturedPayload: any = null;
|
||||
|
||||
const { streamAnthropic } = await import("../src/providers/anthropic.js");
|
||||
|
||||
try {
|
||||
const s = streamAnthropic(baseModel, context, {
|
||||
apiKey: "fake-key",
|
||||
cacheRetention: "none",
|
||||
onPayload: (payload) => {
|
||||
capturedPayload = payload;
|
||||
},
|
||||
});
|
||||
|
||||
for await (const event of s) {
|
||||
if (event.type === "error") break;
|
||||
}
|
||||
} catch {
|
||||
// Expected to fail
|
||||
}
|
||||
|
||||
expect(capturedPayload).not.toBeNull();
|
||||
expect(capturedPayload.system[0].cache_control).toBeUndefined();
|
||||
});
|
||||
|
||||
it("should set 1h cache TTL when cacheRetention is long", async () => {
|
||||
const baseModel = getModel("anthropic", "claude-3-5-haiku-20241022");
|
||||
let capturedPayload: any = null;
|
||||
|
||||
const { streamAnthropic } = await import("../src/providers/anthropic.js");
|
||||
|
||||
try {
|
||||
const s = streamAnthropic(baseModel, context, {
|
||||
apiKey: "fake-key",
|
||||
cacheRetention: "long",
|
||||
onPayload: (payload) => {
|
||||
capturedPayload = payload;
|
||||
},
|
||||
});
|
||||
|
||||
for await (const event of s) {
|
||||
if (event.type === "error") break;
|
||||
}
|
||||
} catch {
|
||||
// Expected to fail
|
||||
}
|
||||
|
||||
expect(capturedPayload).not.toBeNull();
|
||||
expect(capturedPayload.system[0].cache_control).toEqual({ type: "ephemeral", ttl: "1h" });
|
||||
});
|
||||
});
|
||||
|
||||
describe("OpenAI Responses Provider", () => {
|
||||
|
|
@ -195,5 +247,61 @@ describe("Cache Retention (PI_CACHE_RETENTION)", () => {
|
|||
expect(capturedPayload.prompt_cache_retention).toBeUndefined();
|
||||
}
|
||||
});
|
||||
|
||||
it("should omit prompt_cache_key when cacheRetention is none", async () => {
|
||||
const model = getModel("openai", "gpt-4o-mini");
|
||||
let capturedPayload: any = null;
|
||||
|
||||
const { streamOpenAIResponses } = await import("../src/providers/openai-responses.js");
|
||||
|
||||
try {
|
||||
const s = streamOpenAIResponses(model, context, {
|
||||
apiKey: "fake-key",
|
||||
cacheRetention: "none",
|
||||
sessionId: "session-1",
|
||||
onPayload: (payload) => {
|
||||
capturedPayload = payload;
|
||||
},
|
||||
});
|
||||
|
||||
for await (const event of s) {
|
||||
if (event.type === "error") break;
|
||||
}
|
||||
} catch {
|
||||
// Expected to fail
|
||||
}
|
||||
|
||||
expect(capturedPayload).not.toBeNull();
|
||||
expect(capturedPayload.prompt_cache_key).toBeUndefined();
|
||||
expect(capturedPayload.prompt_cache_retention).toBeUndefined();
|
||||
});
|
||||
|
||||
it("should set prompt_cache_retention when cacheRetention is long", async () => {
|
||||
const model = getModel("openai", "gpt-4o-mini");
|
||||
let capturedPayload: any = null;
|
||||
|
||||
const { streamOpenAIResponses } = await import("../src/providers/openai-responses.js");
|
||||
|
||||
try {
|
||||
const s = streamOpenAIResponses(model, context, {
|
||||
apiKey: "fake-key",
|
||||
cacheRetention: "long",
|
||||
sessionId: "session-2",
|
||||
onPayload: (payload) => {
|
||||
capturedPayload = payload;
|
||||
},
|
||||
});
|
||||
|
||||
for await (const event of s) {
|
||||
if (event.type === "error") break;
|
||||
}
|
||||
} catch {
|
||||
// Expected to fail
|
||||
}
|
||||
|
||||
expect(capturedPayload).not.toBeNull();
|
||||
expect(capturedPayload.prompt_cache_key).toBe("session-2");
|
||||
expect(capturedPayload.prompt_cache_retention).toBe("24h");
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue