mirror of
https://github.com/getcompanion-ai/co-mono.git
synced 2026-04-19 17:04:41 +00:00
feat: better cache support in bedrock (#1326)
This commit is contained in:
parent
b170341b14
commit
e9f94ba6c3
1 changed files with 42 additions and 9 deletions
|
|
@ -4,6 +4,7 @@ import {
|
||||||
StopReason as BedrockStopReason,
|
StopReason as BedrockStopReason,
|
||||||
type Tool as BedrockTool,
|
type Tool as BedrockTool,
|
||||||
CachePointType,
|
CachePointType,
|
||||||
|
CacheTTL,
|
||||||
type ContentBlock,
|
type ContentBlock,
|
||||||
type ContentBlockDeltaEvent,
|
type ContentBlockDeltaEvent,
|
||||||
type ContentBlockStartEvent,
|
type ContentBlockStartEvent,
|
||||||
|
|
@ -23,6 +24,7 @@ import { calculateCost } from "../models.js";
|
||||||
import type {
|
import type {
|
||||||
Api,
|
Api,
|
||||||
AssistantMessage,
|
AssistantMessage,
|
||||||
|
CacheRetention,
|
||||||
Context,
|
Context,
|
||||||
Model,
|
Model,
|
||||||
SimpleStreamOptions,
|
SimpleStreamOptions,
|
||||||
|
|
@ -134,10 +136,11 @@ export const streamBedrock: StreamFunction<"bedrock-converse-stream", BedrockOpt
|
||||||
try {
|
try {
|
||||||
const client = new BedrockRuntimeClient(config);
|
const client = new BedrockRuntimeClient(config);
|
||||||
|
|
||||||
|
const cacheRetention = resolveCacheRetention(options.cacheRetention);
|
||||||
const commandInput = {
|
const commandInput = {
|
||||||
modelId: model.id,
|
modelId: model.id,
|
||||||
messages: convertMessages(context, model),
|
messages: convertMessages(context, model, cacheRetention),
|
||||||
system: buildSystemPrompt(context.systemPrompt, model),
|
system: buildSystemPrompt(context.systemPrompt, model, cacheRetention),
|
||||||
inferenceConfig: { maxTokens: options.maxTokens, temperature: options.temperature },
|
inferenceConfig: { maxTokens: options.maxTokens, temperature: options.temperature },
|
||||||
toolConfig: convertToolConfig(context.tools, options.toolChoice),
|
toolConfig: convertToolConfig(context.tools, options.toolChoice),
|
||||||
additionalModelRequestFields: buildAdditionalModelRequestFields(model, options),
|
additionalModelRequestFields: buildAdditionalModelRequestFields(model, options),
|
||||||
|
|
@ -390,11 +393,29 @@ function mapThinkingLevelToEffort(level: SimpleStreamOptions["reasoning"]): "low
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Resolve the effective prompt-cache retention for a request.
 *
 * Precedence:
 * 1. A truthy `cacheRetention` argument is returned unchanged.
 * 2. Otherwise, if the `PI_CACHE_RETENTION` environment variable is exactly
 *    "long", the result is "long" (kept for backward compatibility).
 * 3. Otherwise the result is "short".
 *
 * @param cacheRetention - Caller-supplied preference, if any.
 * @returns The retention setting to apply.
 */
function resolveCacheRetention(cacheRetention?: CacheRetention): CacheRetention {
	if (cacheRetention) {
		return cacheRetention;
	}
	// `process` may be undefined outside Node. It may also exist without an `env`
	// object (e.g. a minimal shim), so guard the property read with optional
	// chaining instead of letting it throw.
	if (typeof process !== "undefined" && process.env?.PI_CACHE_RETENTION === "long") {
		return "long";
	}
	return "short";
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Check if the model supports prompt caching.
|
* Check if the model supports prompt caching.
|
||||||
* Supported: Claude 3.5 Haiku, Claude 3.7 Sonnet, Claude 4.x models
|
* Supported: Claude 3.5 Haiku, Claude 3.7 Sonnet, Claude 4.x models
|
||||||
*/
|
*/
|
||||||
function supportsPromptCaching(model: Model<"bedrock-converse-stream">): boolean {
|
function supportsPromptCaching(model: Model<"bedrock-converse-stream">): boolean {
|
||||||
|
if (model.cost.cacheRead || model.cost.cacheWrite) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
const id = model.id.toLowerCase();
|
const id = model.id.toLowerCase();
|
||||||
// Claude 4.x models (opus-4, sonnet-4, haiku-4)
|
// Claude 4.x models (opus-4, sonnet-4, haiku-4)
|
||||||
if (id.includes("claude") && (id.includes("-4-") || id.includes("-4."))) return true;
|
if (id.includes("claude") && (id.includes("-4-") || id.includes("-4."))) return true;
|
||||||
|
|
@ -419,14 +440,17 @@ function supportsThinkingSignature(model: Model<"bedrock-converse-stream">): boo
|
||||||
function buildSystemPrompt(
|
function buildSystemPrompt(
|
||||||
systemPrompt: string | undefined,
|
systemPrompt: string | undefined,
|
||||||
model: Model<"bedrock-converse-stream">,
|
model: Model<"bedrock-converse-stream">,
|
||||||
|
cacheRetention: CacheRetention,
|
||||||
): SystemContentBlock[] | undefined {
|
): SystemContentBlock[] | undefined {
|
||||||
if (!systemPrompt) return undefined;
|
if (!systemPrompt) return undefined;
|
||||||
|
|
||||||
const blocks: SystemContentBlock[] = [{ text: sanitizeSurrogates(systemPrompt) }];
|
const blocks: SystemContentBlock[] = [{ text: sanitizeSurrogates(systemPrompt) }];
|
||||||
|
|
||||||
// Add cache point for supported Claude models
|
// Add cache point for supported Claude models when caching is enabled
|
||||||
if (supportsPromptCaching(model)) {
|
if (cacheRetention !== "none" && supportsPromptCaching(model)) {
|
||||||
blocks.push({ cachePoint: { type: CachePointType.DEFAULT } });
|
blocks.push({
|
||||||
|
cachePoint: { type: CachePointType.DEFAULT, ...(cacheRetention === "long" ? { ttl: CacheTTL.ONE_HOUR } : {}) },
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
return blocks;
|
return blocks;
|
||||||
|
|
@ -437,7 +461,11 @@ function normalizeToolCallId(id: string): string {
|
||||||
return sanitized.length > 64 ? sanitized.slice(0, 64) : sanitized;
|
return sanitized.length > 64 ? sanitized.slice(0, 64) : sanitized;
|
||||||
}
|
}
|
||||||
|
|
||||||
function convertMessages(context: Context, model: Model<"bedrock-converse-stream">): Message[] {
|
function convertMessages(
|
||||||
|
context: Context,
|
||||||
|
model: Model<"bedrock-converse-stream">,
|
||||||
|
cacheRetention: CacheRetention,
|
||||||
|
): Message[] {
|
||||||
const result: Message[] = [];
|
const result: Message[] = [];
|
||||||
const transformedMessages = transformMessages(context.messages, model, normalizeToolCallId);
|
const transformedMessages = transformMessages(context.messages, model, normalizeToolCallId);
|
||||||
|
|
||||||
|
|
@ -566,11 +594,16 @@ function convertMessages(context: Context, model: Model<"bedrock-converse-stream
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add cache point to the last user message for supported Claude models
|
// Add cache point to the last user message for supported Claude models when caching is enabled
|
||||||
if (supportsPromptCaching(model) && result.length > 0) {
|
if (cacheRetention !== "none" && supportsPromptCaching(model) && result.length > 0) {
|
||||||
const lastMessage = result[result.length - 1];
|
const lastMessage = result[result.length - 1];
|
||||||
if (lastMessage.role === ConversationRole.USER && lastMessage.content) {
|
if (lastMessage.role === ConversationRole.USER && lastMessage.content) {
|
||||||
(lastMessage.content as ContentBlock[]).push({ cachePoint: { type: CachePointType.DEFAULT } });
|
(lastMessage.content as ContentBlock[]).push({
|
||||||
|
cachePoint: {
|
||||||
|
type: CachePointType.DEFAULT,
|
||||||
|
...(cacheRetention === "long" ? { ttl: CacheTTL.ONE_HOUR } : {}),
|
||||||
|
},
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue