diff --git a/packages/ai/src/providers/amazon-bedrock.ts b/packages/ai/src/providers/amazon-bedrock.ts index d74081d3..72bb8203 100644 --- a/packages/ai/src/providers/amazon-bedrock.ts +++ b/packages/ai/src/providers/amazon-bedrock.ts @@ -2,6 +2,7 @@ import { BedrockRuntimeClient, StopReason as BedrockStopReason, type Tool as BedrockTool, + CachePointType, type ContentBlock, type ContentBlockDeltaEvent, type ContentBlockStartEvent, @@ -11,6 +12,7 @@ import { type ConverseStreamMetadataEvent, ImageFormat, type Message, + type SystemContentBlock, type ToolChoice, type ToolConfiguration, ToolResultStatus, @@ -87,8 +89,8 @@ export const streamBedrock: StreamFunction<"bedrock-converse-stream"> = ( const command = new ConverseStreamCommand({ modelId: model.id, - messages: convertMessages(context), - system: context.systemPrompt ? [{ text: sanitizeSurrogates(context.systemPrompt) }] : undefined, + messages: convertMessages(context, model), + system: buildSystemPrompt(context.systemPrompt, model), inferenceConfig: { maxTokens: options.maxTokens, temperature: options.temperature }, toolConfig: convertToolConfig(context.tools, options.toolChoice), additionalModelRequestFields: buildAdditionalModelRequestFields(model, options), @@ -272,7 +274,38 @@ function handleContentBlockStop( } } -function convertMessages(context: Context): Message[] { +/** + * Check if the model supports prompt caching. + * Supported: Claude 3.5 Haiku, Claude 3.7 Sonnet, Claude 4.x models + */ +function supportsPromptCaching(model: Model<"bedrock-converse-stream">): boolean { + const id = model.id.toLowerCase(); + // Claude 4.x models (opus-4, sonnet-4, haiku-4) + if (id.includes("claude") && (id.includes("-4-") || id.includes("-4."))) return true; + // Claude 3.7 Sonnet + if (id.includes("claude-3-7-sonnet")) return true; + // Claude 3.5 Haiku + if (id.includes("claude-3-5-haiku")) return true; + return false; +} + +function buildSystemPrompt( + systemPrompt: string | undefined, + model: Model<"bedrock-converse-stream">, +): SystemContentBlock[] | undefined { + if (!systemPrompt) return undefined; + + const blocks: SystemContentBlock[] = [{ text: sanitizeSurrogates(systemPrompt) }]; + + // Add cache point for supported Claude models + if (supportsPromptCaching(model)) { + blocks.push({ cachePoint: { type: CachePointType.DEFAULT } }); + } + + return blocks; +} + +function convertMessages(context: Context, model: Model<"bedrock-converse-stream">): Message[] { const result: Message[] = []; const messages = context.messages; @@ -394,6 +427,14 @@ function convertMessages(context: Context): Message[] { } } + // Add cache point to the last user message for supported Claude models + if (supportsPromptCaching(model) && result.length > 0) { + const lastMessage = result[result.length - 1]; + if (lastMessage.role === ConversationRole.USER && lastMessage.content) { + (lastMessage.content as ContentBlock[]).push({ cachePoint: { type: CachePointType.DEFAULT } }); + } + } + return result; }