mirror of
https://github.com/getcompanion-ai/co-mono.git
synced 2026-04-20 20:01:06 +00:00
feat(ai): add cacheRetention stream option
This commit is contained in:
parent
e9ca0be769
commit
abfd04b5c5
5 changed files with 174 additions and 39 deletions
|
|
@ -5,6 +5,7 @@
|
||||||
### Added
|
### Added
|
||||||
|
|
||||||
- Added `PI_AI_ANTIGRAVITY_VERSION` environment variable to override the Antigravity User-Agent version when Google updates their version requirements ([#1129](https://github.com/badlogic/pi-mono/issues/1129))
|
- Added `PI_AI_ANTIGRAVITY_VERSION` environment variable to override the Antigravity User-Agent version when Google updates their version requirements ([#1129](https://github.com/badlogic/pi-mono/issues/1129))
|
||||||
|
- Added `cacheRetention` stream option with provider-specific mappings for prompt cache controls, defaulting to short retention ([#1134](https://github.com/badlogic/pi-mono/issues/1134))
|
||||||
|
|
||||||
## [0.50.8] - 2026-02-01
|
## [0.50.8] - 2026-02-01
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ import { calculateCost } from "../models.js";
|
||||||
import type {
|
import type {
|
||||||
Api,
|
Api,
|
||||||
AssistantMessage,
|
AssistantMessage,
|
||||||
|
CacheRetention,
|
||||||
Context,
|
Context,
|
||||||
ImageContent,
|
ImageContent,
|
||||||
Message,
|
Message,
|
||||||
|
|
@ -31,19 +32,32 @@ import { adjustMaxTokensForThinking, buildBaseOptions } from "./simple-options.j
|
||||||
import { transformMessages } from "./transform-messages.js";
|
import { transformMessages } from "./transform-messages.js";
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get cache TTL based on PI_CACHE_RETENTION env var.
|
* Resolve cache retention preference.
|
||||||
* Only applies to direct Anthropic API calls (api.anthropic.com).
|
* Defaults to "short" and uses PI_CACHE_RETENTION for backward compatibility.
|
||||||
* Returns '1h' for long retention, undefined for default (5m).
|
|
||||||
*/
|
*/
|
||||||
function getCacheTtl(baseUrl: string): "1h" | undefined {
|
function resolveCacheRetention(cacheRetention?: CacheRetention): CacheRetention {
|
||||||
if (
|
if (cacheRetention) {
|
||||||
typeof process !== "undefined" &&
|
return cacheRetention;
|
||||||
process.env.PI_CACHE_RETENTION === "long" &&
|
|
||||||
baseUrl.includes("api.anthropic.com")
|
|
||||||
) {
|
|
||||||
return "1h";
|
|
||||||
}
|
}
|
||||||
return undefined;
|
if (typeof process !== "undefined" && process.env.PI_CACHE_RETENTION === "long") {
|
||||||
|
return "long";
|
||||||
|
}
|
||||||
|
return "short";
|
||||||
|
}
|
||||||
|
|
||||||
|
function getCacheControl(
|
||||||
|
baseUrl: string,
|
||||||
|
cacheRetention?: CacheRetention,
|
||||||
|
): { retention: CacheRetention; cacheControl?: { type: "ephemeral"; ttl?: "1h" } } {
|
||||||
|
const retention = resolveCacheRetention(cacheRetention);
|
||||||
|
if (retention === "none") {
|
||||||
|
return { retention };
|
||||||
|
}
|
||||||
|
const ttl = retention === "long" && baseUrl.includes("api.anthropic.com") ? "1h" : undefined;
|
||||||
|
return {
|
||||||
|
retention,
|
||||||
|
cacheControl: { type: "ephemeral", ...(ttl && { ttl }) },
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stealth mode: Mimic Claude Code's tool naming exactly
|
// Stealth mode: Mimic Claude Code's tool naming exactly
|
||||||
|
|
@ -460,34 +474,28 @@ function buildParams(
|
||||||
isOAuthToken: boolean,
|
isOAuthToken: boolean,
|
||||||
options?: AnthropicOptions,
|
options?: AnthropicOptions,
|
||||||
): MessageCreateParamsStreaming {
|
): MessageCreateParamsStreaming {
|
||||||
|
const { cacheControl } = getCacheControl(model.baseUrl, options?.cacheRetention);
|
||||||
const params: MessageCreateParamsStreaming = {
|
const params: MessageCreateParamsStreaming = {
|
||||||
model: model.id,
|
model: model.id,
|
||||||
messages: convertMessages(context.messages, model, isOAuthToken),
|
messages: convertMessages(context.messages, model, isOAuthToken, cacheControl),
|
||||||
max_tokens: options?.maxTokens || (model.maxTokens / 3) | 0,
|
max_tokens: options?.maxTokens || (model.maxTokens / 3) | 0,
|
||||||
stream: true,
|
stream: true,
|
||||||
};
|
};
|
||||||
|
|
||||||
// For OAuth tokens, we MUST include Claude Code identity
|
// For OAuth tokens, we MUST include Claude Code identity
|
||||||
const cacheTtl = getCacheTtl(model.baseUrl);
|
|
||||||
if (isOAuthToken) {
|
if (isOAuthToken) {
|
||||||
params.system = [
|
params.system = [
|
||||||
{
|
{
|
||||||
type: "text",
|
type: "text",
|
||||||
text: "You are Claude Code, Anthropic's official CLI for Claude.",
|
text: "You are Claude Code, Anthropic's official CLI for Claude.",
|
||||||
cache_control: {
|
...(cacheControl ? { cache_control: cacheControl } : {}),
|
||||||
type: "ephemeral",
|
|
||||||
...(cacheTtl && { ttl: cacheTtl }),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
if (context.systemPrompt) {
|
if (context.systemPrompt) {
|
||||||
params.system.push({
|
params.system.push({
|
||||||
type: "text",
|
type: "text",
|
||||||
text: sanitizeSurrogates(context.systemPrompt),
|
text: sanitizeSurrogates(context.systemPrompt),
|
||||||
cache_control: {
|
...(cacheControl ? { cache_control: cacheControl } : {}),
|
||||||
type: "ephemeral",
|
|
||||||
...(cacheTtl && { ttl: cacheTtl }),
|
|
||||||
},
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
} else if (context.systemPrompt) {
|
} else if (context.systemPrompt) {
|
||||||
|
|
@ -496,10 +504,7 @@ function buildParams(
|
||||||
{
|
{
|
||||||
type: "text",
|
type: "text",
|
||||||
text: sanitizeSurrogates(context.systemPrompt),
|
text: sanitizeSurrogates(context.systemPrompt),
|
||||||
cache_control: {
|
...(cacheControl ? { cache_control: cacheControl } : {}),
|
||||||
type: "ephemeral",
|
|
||||||
...(cacheTtl && { ttl: cacheTtl }),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
}
|
}
|
||||||
|
|
@ -539,6 +544,7 @@ function convertMessages(
|
||||||
messages: Message[],
|
messages: Message[],
|
||||||
model: Model<"anthropic-messages">,
|
model: Model<"anthropic-messages">,
|
||||||
isOAuthToken: boolean,
|
isOAuthToken: boolean,
|
||||||
|
cacheControl?: { type: "ephemeral"; ttl?: "1h" },
|
||||||
): MessageParam[] {
|
): MessageParam[] {
|
||||||
const params: MessageParam[] = [];
|
const params: MessageParam[] = [];
|
||||||
|
|
||||||
|
|
@ -665,7 +671,7 @@ function convertMessages(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add cache_control to the last user message to cache conversation history
|
// Add cache_control to the last user message to cache conversation history
|
||||||
if (params.length > 0) {
|
if (cacheControl && params.length > 0) {
|
||||||
const lastMessage = params[params.length - 1];
|
const lastMessage = params[params.length - 1];
|
||||||
if (lastMessage.role === "user") {
|
if (lastMessage.role === "user") {
|
||||||
// Add cache control to the last content block
|
// Add cache control to the last content block
|
||||||
|
|
@ -675,8 +681,7 @@ function convertMessages(
|
||||||
lastBlock &&
|
lastBlock &&
|
||||||
(lastBlock.type === "text" || lastBlock.type === "image" || lastBlock.type === "tool_result")
|
(lastBlock.type === "text" || lastBlock.type === "image" || lastBlock.type === "tool_result")
|
||||||
) {
|
) {
|
||||||
const cacheTtl = getCacheTtl(model.baseUrl);
|
(lastBlock as any).cache_control = cacheControl;
|
||||||
(lastBlock as any).cache_control = { type: "ephemeral", ...(cacheTtl && { ttl: cacheTtl }) };
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ import { supportsXhigh } from "../models.js";
|
||||||
import type {
|
import type {
|
||||||
Api,
|
Api,
|
||||||
AssistantMessage,
|
AssistantMessage,
|
||||||
|
CacheRetention,
|
||||||
Context,
|
Context,
|
||||||
Model,
|
Model,
|
||||||
SimpleStreamOptions,
|
SimpleStreamOptions,
|
||||||
|
|
@ -19,16 +20,28 @@ import { buildBaseOptions, clampReasoning } from "./simple-options.js";
|
||||||
const OPENAI_TOOL_CALL_PROVIDERS = new Set(["openai", "openai-codex", "opencode"]);
|
const OPENAI_TOOL_CALL_PROVIDERS = new Set(["openai", "openai-codex", "opencode"]);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get prompt cache retention based on PI_CACHE_RETENTION env var.
|
* Resolve cache retention preference.
|
||||||
* Only applies to direct OpenAI API calls (api.openai.com).
|
* Defaults to "short" and uses PI_CACHE_RETENTION for backward compatibility.
|
||||||
* Returns '24h' for long retention, undefined for default (in-memory).
|
|
||||||
*/
|
*/
|
||||||
function getPromptCacheRetention(baseUrl: string): "24h" | undefined {
|
function resolveCacheRetention(cacheRetention?: CacheRetention): CacheRetention {
|
||||||
if (
|
if (cacheRetention) {
|
||||||
typeof process !== "undefined" &&
|
return cacheRetention;
|
||||||
process.env.PI_CACHE_RETENTION === "long" &&
|
}
|
||||||
baseUrl.includes("api.openai.com")
|
if (typeof process !== "undefined" && process.env.PI_CACHE_RETENTION === "long") {
|
||||||
) {
|
return "long";
|
||||||
|
}
|
||||||
|
return "short";
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get prompt cache retention based on cacheRetention and base URL.
|
||||||
|
* Only applies to direct OpenAI API calls (api.openai.com).
|
||||||
|
*/
|
||||||
|
function getPromptCacheRetention(baseUrl: string, cacheRetention: CacheRetention): "24h" | undefined {
|
||||||
|
if (cacheRetention !== "long") {
|
||||||
|
return undefined;
|
||||||
|
}
|
||||||
|
if (baseUrl.includes("api.openai.com")) {
|
||||||
return "24h";
|
return "24h";
|
||||||
}
|
}
|
||||||
return undefined;
|
return undefined;
|
||||||
|
|
@ -186,12 +199,13 @@ function createClient(
|
||||||
function buildParams(model: Model<"openai-responses">, context: Context, options?: OpenAIResponsesOptions) {
|
function buildParams(model: Model<"openai-responses">, context: Context, options?: OpenAIResponsesOptions) {
|
||||||
const messages = convertResponsesMessages(model, context, OPENAI_TOOL_CALL_PROVIDERS);
|
const messages = convertResponsesMessages(model, context, OPENAI_TOOL_CALL_PROVIDERS);
|
||||||
|
|
||||||
|
const cacheRetention = resolveCacheRetention(options?.cacheRetention);
|
||||||
const params: ResponseCreateParamsStreaming = {
|
const params: ResponseCreateParamsStreaming = {
|
||||||
model: model.id,
|
model: model.id,
|
||||||
input: messages,
|
input: messages,
|
||||||
stream: true,
|
stream: true,
|
||||||
prompt_cache_key: options?.sessionId,
|
prompt_cache_key: cacheRetention === "none" ? undefined : options?.sessionId,
|
||||||
prompt_cache_retention: getPromptCacheRetention(model.baseUrl),
|
prompt_cache_retention: getPromptCacheRetention(model.baseUrl, cacheRetention),
|
||||||
};
|
};
|
||||||
|
|
||||||
if (options?.maxTokens) {
|
if (options?.maxTokens) {
|
||||||
|
|
|
||||||
|
|
@ -51,11 +51,18 @@ export interface ThinkingBudgets {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Base options all providers share
|
// Base options all providers share
|
||||||
|
export type CacheRetention = "none" | "short" | "long";
|
||||||
|
|
||||||
export interface StreamOptions {
|
export interface StreamOptions {
|
||||||
temperature?: number;
|
temperature?: number;
|
||||||
maxTokens?: number;
|
maxTokens?: number;
|
||||||
signal?: AbortSignal;
|
signal?: AbortSignal;
|
||||||
apiKey?: string;
|
apiKey?: string;
|
||||||
|
/**
|
||||||
|
* Prompt cache retention preference. Providers map this to their supported values.
|
||||||
|
* Default: "short".
|
||||||
|
*/
|
||||||
|
cacheRetention?: CacheRetention;
|
||||||
/**
|
/**
|
||||||
* Optional session identifier for providers that support session-based caching.
|
* Optional session identifier for providers that support session-based caching.
|
||||||
* Providers can use this to enable prompt caching, request routing, or other
|
* Providers can use this to enable prompt caching, request routing, or other
|
||||||
|
|
|
||||||
|
|
@ -112,6 +112,58 @@ describe("Cache Retention (PI_CACHE_RETENTION)", () => {
|
||||||
expect(capturedPayload.system[0].cache_control).toEqual({ type: "ephemeral" });
|
expect(capturedPayload.system[0].cache_control).toEqual({ type: "ephemeral" });
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it("should omit cache_control when cacheRetention is none", async () => {
|
||||||
|
const baseModel = getModel("anthropic", "claude-3-5-haiku-20241022");
|
||||||
|
let capturedPayload: any = null;
|
||||||
|
|
||||||
|
const { streamAnthropic } = await import("../src/providers/anthropic.js");
|
||||||
|
|
||||||
|
try {
|
||||||
|
const s = streamAnthropic(baseModel, context, {
|
||||||
|
apiKey: "fake-key",
|
||||||
|
cacheRetention: "none",
|
||||||
|
onPayload: (payload) => {
|
||||||
|
capturedPayload = payload;
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
for await (const event of s) {
|
||||||
|
if (event.type === "error") break;
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
// Expected to fail
|
||||||
|
}
|
||||||
|
|
||||||
|
expect(capturedPayload).not.toBeNull();
|
||||||
|
expect(capturedPayload.system[0].cache_control).toBeUndefined();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should set 1h cache TTL when cacheRetention is long", async () => {
|
||||||
|
const baseModel = getModel("anthropic", "claude-3-5-haiku-20241022");
|
||||||
|
let capturedPayload: any = null;
|
||||||
|
|
||||||
|
const { streamAnthropic } = await import("../src/providers/anthropic.js");
|
||||||
|
|
||||||
|
try {
|
||||||
|
const s = streamAnthropic(baseModel, context, {
|
||||||
|
apiKey: "fake-key",
|
||||||
|
cacheRetention: "long",
|
||||||
|
onPayload: (payload) => {
|
||||||
|
capturedPayload = payload;
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
for await (const event of s) {
|
||||||
|
if (event.type === "error") break;
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
// Expected to fail
|
||||||
|
}
|
||||||
|
|
||||||
|
expect(capturedPayload).not.toBeNull();
|
||||||
|
expect(capturedPayload.system[0].cache_control).toEqual({ type: "ephemeral", ttl: "1h" });
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
describe("OpenAI Responses Provider", () => {
|
describe("OpenAI Responses Provider", () => {
|
||||||
|
|
@ -195,5 +247,61 @@ describe("Cache Retention (PI_CACHE_RETENTION)", () => {
|
||||||
expect(capturedPayload.prompt_cache_retention).toBeUndefined();
|
expect(capturedPayload.prompt_cache_retention).toBeUndefined();
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it("should omit prompt_cache_key when cacheRetention is none", async () => {
|
||||||
|
const model = getModel("openai", "gpt-4o-mini");
|
||||||
|
let capturedPayload: any = null;
|
||||||
|
|
||||||
|
const { streamOpenAIResponses } = await import("../src/providers/openai-responses.js");
|
||||||
|
|
||||||
|
try {
|
||||||
|
const s = streamOpenAIResponses(model, context, {
|
||||||
|
apiKey: "fake-key",
|
||||||
|
cacheRetention: "none",
|
||||||
|
sessionId: "session-1",
|
||||||
|
onPayload: (payload) => {
|
||||||
|
capturedPayload = payload;
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
for await (const event of s) {
|
||||||
|
if (event.type === "error") break;
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
// Expected to fail
|
||||||
|
}
|
||||||
|
|
||||||
|
expect(capturedPayload).not.toBeNull();
|
||||||
|
expect(capturedPayload.prompt_cache_key).toBeUndefined();
|
||||||
|
expect(capturedPayload.prompt_cache_retention).toBeUndefined();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should set prompt_cache_retention when cacheRetention is long", async () => {
|
||||||
|
const model = getModel("openai", "gpt-4o-mini");
|
||||||
|
let capturedPayload: any = null;
|
||||||
|
|
||||||
|
const { streamOpenAIResponses } = await import("../src/providers/openai-responses.js");
|
||||||
|
|
||||||
|
try {
|
||||||
|
const s = streamOpenAIResponses(model, context, {
|
||||||
|
apiKey: "fake-key",
|
||||||
|
cacheRetention: "long",
|
||||||
|
sessionId: "session-2",
|
||||||
|
onPayload: (payload) => {
|
||||||
|
capturedPayload = payload;
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
for await (const event of s) {
|
||||||
|
if (event.type === "error") break;
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
// Expected to fail
|
||||||
|
}
|
||||||
|
|
||||||
|
expect(capturedPayload).not.toBeNull();
|
||||||
|
expect(capturedPayload.prompt_cache_key).toBe("session-2");
|
||||||
|
expect(capturedPayload.prompt_cache_retention).toBe("24h");
|
||||||
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue