feat: better cache support in bedrock (#1326)

This commit is contained in:
xu0o0 2026-02-07 01:05:46 +08:00 committed by GitHub
parent b170341b14
commit e9f94ba6c3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -4,6 +4,7 @@ import {
StopReason as BedrockStopReason,
type Tool as BedrockTool,
CachePointType,
CacheTTL,
type ContentBlock,
type ContentBlockDeltaEvent,
type ContentBlockStartEvent,
@ -23,6 +24,7 @@ import { calculateCost } from "../models.js";
import type {
Api,
AssistantMessage,
CacheRetention,
Context,
Model,
SimpleStreamOptions,
@ -134,10 +136,11 @@ export const streamBedrock: StreamFunction<"bedrock-converse-stream", BedrockOpt
try {
const client = new BedrockRuntimeClient(config);
const cacheRetention = resolveCacheRetention(options.cacheRetention);
const commandInput = {
modelId: model.id,
messages: convertMessages(context, model),
system: buildSystemPrompt(context.systemPrompt, model),
messages: convertMessages(context, model, cacheRetention),
system: buildSystemPrompt(context.systemPrompt, model, cacheRetention),
inferenceConfig: { maxTokens: options.maxTokens, temperature: options.temperature },
toolConfig: convertToolConfig(context.tools, options.toolChoice),
additionalModelRequestFields: buildAdditionalModelRequestFields(model, options),
@ -390,11 +393,29 @@ function mapThinkingLevelToEffort(level: SimpleStreamOptions["reasoning"]): "low
}
}
/**
* Resolve cache retention preference.
* Defaults to "short" and uses PI_CACHE_RETENTION for backward compatibility.
*/
function resolveCacheRetention(cacheRetention?: CacheRetention): CacheRetention {
if (cacheRetention) {
return cacheRetention;
}
if (typeof process !== "undefined" && process.env.PI_CACHE_RETENTION === "long") {
return "long";
}
return "short";
}
/**
* Check if the model supports prompt caching.
* Supported: Claude 3.5 Haiku, Claude 3.7 Sonnet, Claude 4.x models
*/
function supportsPromptCaching(model: Model<"bedrock-converse-stream">): boolean {
if (model.cost.cacheRead || model.cost.cacheWrite) {
return true;
}
const id = model.id.toLowerCase();
// Claude 4.x models (opus-4, sonnet-4, haiku-4)
if (id.includes("claude") && (id.includes("-4-") || id.includes("-4."))) return true;
@ -419,14 +440,17 @@ function supportsThinkingSignature(model: Model<"bedrock-converse-stream">): boo
function buildSystemPrompt(
systemPrompt: string | undefined,
model: Model<"bedrock-converse-stream">,
cacheRetention: CacheRetention,
): SystemContentBlock[] | undefined {
if (!systemPrompt) return undefined;
const blocks: SystemContentBlock[] = [{ text: sanitizeSurrogates(systemPrompt) }];
// Add cache point for supported Claude models
if (supportsPromptCaching(model)) {
blocks.push({ cachePoint: { type: CachePointType.DEFAULT } });
// Add cache point for supported Claude models when caching is enabled
if (cacheRetention !== "none" && supportsPromptCaching(model)) {
blocks.push({
cachePoint: { type: CachePointType.DEFAULT, ...(cacheRetention === "long" ? { ttl: CacheTTL.ONE_HOUR } : {}) },
});
}
return blocks;
@ -437,7 +461,11 @@ function normalizeToolCallId(id: string): string {
return sanitized.length > 64 ? sanitized.slice(0, 64) : sanitized;
}
function convertMessages(context: Context, model: Model<"bedrock-converse-stream">): Message[] {
function convertMessages(
context: Context,
model: Model<"bedrock-converse-stream">,
cacheRetention: CacheRetention,
): Message[] {
const result: Message[] = [];
const transformedMessages = transformMessages(context.messages, model, normalizeToolCallId);
@ -566,11 +594,16 @@ function convertMessages(context: Context, model: Model<"bedrock-converse-stream
}
}
// Add cache point to the last user message for supported Claude models
if (supportsPromptCaching(model) && result.length > 0) {
// Add cache point to the last user message for supported Claude models when caching is enabled
if (cacheRetention !== "none" && supportsPromptCaching(model) && result.length > 0) {
const lastMessage = result[result.length - 1];
if (lastMessage.role === ConversationRole.USER && lastMessage.content) {
(lastMessage.content as ContentBlock[]).push({ cachePoint: { type: CachePointType.DEFAULT } });
(lastMessage.content as ContentBlock[]).push({
cachePoint: {
type: CachePointType.DEFAULT,
...(cacheRetention === "long" ? { ttl: CacheTTL.ONE_HOUR } : {}),
},
});
}
}