Mirror of https://github.com/getcompanion-ai/co-mono.git (synced 2026-04-15 11:02:17 +00:00)
feat: better cache support in bedrock (#1326)
commit e9f94ba6c3
parent b170341b14
1 changed file with 42 additions and 9 deletions
@@ -4,6 +4,7 @@ import {
 	StopReason as BedrockStopReason,
 	type Tool as BedrockTool,
 	CachePointType,
+	CacheTTL,
 	type ContentBlock,
 	type ContentBlockDeltaEvent,
 	type ContentBlockStartEvent,
@@ -23,6 +24,7 @@ import { calculateCost } from "../models.js";
 import type {
 	Api,
 	AssistantMessage,
+	CacheRetention,
 	Context,
 	Model,
 	SimpleStreamOptions,
@@ -134,10 +136,11 @@ export const streamBedrock: StreamFunction<"bedrock-converse-stream", BedrockOpt
 	try {
 		const client = new BedrockRuntimeClient(config);
 
+		const cacheRetention = resolveCacheRetention(options.cacheRetention);
 		const commandInput = {
 			modelId: model.id,
-			messages: convertMessages(context, model),
-			system: buildSystemPrompt(context.systemPrompt, model),
+			messages: convertMessages(context, model, cacheRetention),
+			system: buildSystemPrompt(context.systemPrompt, model, cacheRetention),
 			inferenceConfig: { maxTokens: options.maxTokens, temperature: options.temperature },
 			toolConfig: convertToolConfig(context.tools, options.toolChoice),
 			additionalModelRequestFields: buildAdditionalModelRequestFields(model, options),
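For context, a minimal caller sketch (not part of this diff): the exact StreamFunction call shape is not shown here, so the positional arguments are an assumption; only the option names used in the hunk above (cacheRetention, maxTokens, temperature) come from the change itself.

// Hypothetical usage sketch -- call shape and values are illustrative, not from this commit.
declare const model: Model<"bedrock-converse-stream">;
declare const context: Context;

const events = streamBedrock(model, context, {
	maxTokens: 4096,
	temperature: 0,
	// "short" is the default; "long" requests 1-hour cache points; "none" disables cache points.
	cacheRetention: "long",
});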
@@ -390,11 +393,29 @@ function mapThinkingLevelToEffort(level: SimpleStreamOptions["reasoning"]): "low
 	}
 }
 
+/**
+ * Resolve cache retention preference.
+ * Defaults to "short" and uses PI_CACHE_RETENTION for backward compatibility.
+ */
+function resolveCacheRetention(cacheRetention?: CacheRetention): CacheRetention {
+	if (cacheRetention) {
+		return cacheRetention;
+	}
+	if (typeof process !== "undefined" && process.env.PI_CACHE_RETENTION === "long") {
+		return "long";
+	}
+	return "short";
+}
+
 /**
  * Check if the model supports prompt caching.
  * Supported: Claude 3.5 Haiku, Claude 3.7 Sonnet, Claude 4.x models
  */
 function supportsPromptCaching(model: Model<"bedrock-converse-stream">): boolean {
+	if (model.cost.cacheRead || model.cost.cacheWrite) {
+		return true;
+	}
+
 	const id = model.id.toLowerCase();
 	// Claude 4.x models (opus-4, sonnet-4, haiku-4)
 	if (id.includes("claude") && (id.includes("-4-") || id.includes("-4."))) return true;
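The resolution order introduced above is: explicit option first, then the PI_CACHE_RETENTION environment variable, then the "short" default. A small sketch of the resulting behaviour (the values are the ones the diff compares against; the full CacheRetention union is not shown in this file):

// Explicit option always wins, including "none", which later disables cache points.
resolveCacheRetention("none"); // -> "none"
resolveCacheRetention("long"); // -> "long"

// No option passed: the legacy environment variable is honoured.
process.env.PI_CACHE_RETENTION = "long";
resolveCacheRetention(undefined); // -> "long"

// No option and no env var: default.
delete process.env.PI_CACHE_RETENTION;
resolveCacheRetention(undefined); // -> "short"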
@@ -419,14 +440,17 @@ function supportsThinkingSignature(model: Model<"bedrock-converse-stream">): boo
 function buildSystemPrompt(
 	systemPrompt: string | undefined,
 	model: Model<"bedrock-converse-stream">,
+	cacheRetention: CacheRetention,
 ): SystemContentBlock[] | undefined {
 	if (!systemPrompt) return undefined;
 
 	const blocks: SystemContentBlock[] = [{ text: sanitizeSurrogates(systemPrompt) }];
 
-	// Add cache point for supported Claude models
-	if (supportsPromptCaching(model)) {
-		blocks.push({ cachePoint: { type: CachePointType.DEFAULT } });
+	// Add cache point for supported Claude models when caching is enabled
+	if (cacheRetention !== "none" && supportsPromptCaching(model)) {
+		blocks.push({
+			cachePoint: { type: CachePointType.DEFAULT, ...(cacheRetention === "long" ? { ttl: CacheTTL.ONE_HOUR } : {}) },
+		});
 	}
 
 	return blocks;
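A sketch of what buildSystemPrompt now returns for a caching-capable Claude model (shapes taken from the hunk above; the prompt string and model variable are placeholders):

// someClaudeModel: placeholder for a model where supportsPromptCaching() returns true.
const blocks = buildSystemPrompt("You are a helpful assistant.", someClaudeModel, "long");
// -> [
//      { text: "You are a helpful assistant." },
//      { cachePoint: { type: CachePointType.DEFAULT, ttl: CacheTTL.ONE_HOUR } },
//    ]
// With "short" the cachePoint carries no ttl; with "none" no cachePoint is appended at all.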
@@ -437,7 +461,11 @@ function normalizeToolCallId(id: string): string {
 	return sanitized.length > 64 ? sanitized.slice(0, 64) : sanitized;
 }
 
-function convertMessages(context: Context, model: Model<"bedrock-converse-stream">): Message[] {
+function convertMessages(
+	context: Context,
+	model: Model<"bedrock-converse-stream">,
+	cacheRetention: CacheRetention,
+): Message[] {
 	const result: Message[] = [];
 	const transformedMessages = transformMessages(context.messages, model, normalizeToolCallId);
 
@@ -566,11 +594,16 @@ function convertMessages(context: Context, model: Model<"bedrock-converse-stream
 		}
 	}
 
-	// Add cache point to the last user message for supported Claude models
-	if (supportsPromptCaching(model) && result.length > 0) {
+	// Add cache point to the last user message for supported Claude models when caching is enabled
+	if (cacheRetention !== "none" && supportsPromptCaching(model) && result.length > 0) {
 		const lastMessage = result[result.length - 1];
 		if (lastMessage.role === ConversationRole.USER && lastMessage.content) {
-			(lastMessage.content as ContentBlock[]).push({ cachePoint: { type: CachePointType.DEFAULT } });
+			(lastMessage.content as ContentBlock[]).push({
+				cachePoint: {
+					type: CachePointType.DEFAULT,
+					...(cacheRetention === "long" ? { ttl: CacheTTL.ONE_HOUR } : {}),
+				},
+			});
 		}
 	}
 
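Taken together, a long-retention request now carries up to two cache points: one after the system prompt and one on the last user message. A rough sketch of the resulting ConverseStream input (field values are placeholders; the real blocks come from buildSystemPrompt and convertMessages above):

const commandInput = {
	modelId: "anthropic.claude-sonnet-4-20250514-v1:0", // placeholder id
	system: [
		{ text: "..." },
		{ cachePoint: { type: CachePointType.DEFAULT, ttl: CacheTTL.ONE_HOUR } },
	],
	messages: [
		// ...earlier turns...
		{
			role: ConversationRole.USER,
			content: [
				{ text: "..." },
				{ cachePoint: { type: CachePointType.DEFAULT, ttl: CacheTTL.ONE_HOUR } },
			],
		},
	],
	// inferenceConfig, toolConfig, additionalModelRequestFields as in the earlier hunk.
};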