import { BedrockRuntimeClient, StopReason as BedrockStopReason, type Tool as BedrockTool, CachePointType, type ContentBlock, type ContentBlockDeltaEvent, type ContentBlockStartEvent, type ContentBlockStopEvent, ConversationRole, ConverseStreamCommand, type ConverseStreamMetadataEvent, ImageFormat, type Message, type SystemContentBlock, type ToolChoice, type ToolConfiguration, ToolResultStatus, } from "@aws-sdk/client-bedrock-runtime"; import { calculateCost } from "../models.js"; import type { Api, AssistantMessage, Context, Model, StopReason, StreamFunction, StreamOptions, TextContent, ThinkingBudgets, ThinkingContent, ThinkingLevel, Tool, ToolCall, ToolResultMessage, } from "../types.js"; import { AssistantMessageEventStream } from "../utils/event-stream.js"; import { parseStreamingJson } from "../utils/json-parse.js"; import { sanitizeSurrogates } from "../utils/sanitize-unicode.js"; import { transformMessages } from "./transform-messages.js"; export interface BedrockOptions extends StreamOptions { region?: string; profile?: string; toolChoice?: "auto" | "any" | "none" | { type: "tool"; name: string }; /* See https://docs.aws.amazon.com/bedrock/latest/userguide/inference-reasoning.html for supported models. */ reasoning?: ThinkingLevel; /* Custom token budgets per thinking level. Overrides default budgets. */ thinkingBudgets?: ThinkingBudgets; /* Only supported by Claude 4.x models, see https://docs.aws.amazon.com/bedrock/latest/userguide/claude-messages-extended-thinking.html#claude-messages-extended-thinking-tool-use-interleaved */ interleavedThinking?: boolean; } type Block = (TextContent | ThinkingContent | ToolCall) & { index?: number; partialJson?: string }; export const streamBedrock: StreamFunction<"bedrock-converse-stream"> = ( model: Model<"bedrock-converse-stream">, context: Context, options: BedrockOptions, ): AssistantMessageEventStream => { const stream = new AssistantMessageEventStream(); (async () => { const output: AssistantMessage = { role: "assistant", content: [], api: "bedrock-converse-stream" as Api, provider: model.provider, model: model.id, usage: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, totalTokens: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, }, stopReason: "stop", timestamp: Date.now(), }; const blocks = output.content as Block[]; try { const client = new BedrockRuntimeClient({ region: options.region || process.env.AWS_REGION || process.env.AWS_DEFAULT_REGION || "us-east-1", profile: options.profile, }); const command = new ConverseStreamCommand({ modelId: model.id, messages: convertMessages(context, model), system: buildSystemPrompt(context.systemPrompt, model), inferenceConfig: { maxTokens: options.maxTokens, temperature: options.temperature }, toolConfig: convertToolConfig(context.tools, options.toolChoice), additionalModelRequestFields: buildAdditionalModelRequestFields(model, options), }); const response = await client.send(command, { abortSignal: options.signal }); for await (const item of response.stream!) { if (item.messageStart) { if (item.messageStart.role !== ConversationRole.ASSISTANT) { throw new Error("Unexpected assistant message start but got user message start instead"); } stream.push({ type: "start", partial: output }); } else if (item.contentBlockStart) { handleContentBlockStart(item.contentBlockStart, blocks, output, stream); } else if (item.contentBlockDelta) { handleContentBlockDelta(item.contentBlockDelta, blocks, output, stream); } else if (item.contentBlockStop) { handleContentBlockStop(item.contentBlockStop, blocks, output, stream); } else if (item.messageStop) { output.stopReason = mapStopReason(item.messageStop.stopReason); } else if (item.metadata) { handleMetadata(item.metadata, model, output); } else if (item.internalServerException) { throw new Error(`Internal server error: ${item.internalServerException.message}`); } else if (item.modelStreamErrorException) { throw new Error(`Model stream error: ${item.modelStreamErrorException.message}`); } else if (item.validationException) { throw new Error(`Validation error: ${item.validationException.message}`); } else if (item.throttlingException) { throw new Error(`Throttling error: ${item.throttlingException.message}`); } else if (item.serviceUnavailableException) { throw new Error(`Service unavailable: ${item.serviceUnavailableException.message}`); } } if (options.signal?.aborted) { throw new Error("Request was aborted"); } if (output.stopReason === "error" || output.stopReason === "aborted") { throw new Error("An unknown error occurred"); } stream.push({ type: "done", reason: output.stopReason, message: output }); stream.end(); } catch (error) { for (const block of output.content) { delete (block as Block).index; delete (block as Block).partialJson; } output.stopReason = options.signal?.aborted ? "aborted" : "error"; output.errorMessage = error instanceof Error ? error.message : JSON.stringify(error); stream.push({ type: "error", reason: output.stopReason, error: output }); stream.end(); } })(); return stream; }; function handleContentBlockStart( event: ContentBlockStartEvent, blocks: Block[], output: AssistantMessage, stream: AssistantMessageEventStream, ): void { const index = event.contentBlockIndex!; const start = event.start; if (start?.toolUse) { const block: Block = { type: "toolCall", id: start.toolUse.toolUseId || "", name: start.toolUse.name || "", arguments: {}, partialJson: "", index, }; output.content.push(block); stream.push({ type: "toolcall_start", contentIndex: blocks.length - 1, partial: output }); } } function handleContentBlockDelta( event: ContentBlockDeltaEvent, blocks: Block[], output: AssistantMessage, stream: AssistantMessageEventStream, ): void { const contentBlockIndex = event.contentBlockIndex!; const delta = event.delta; let index = blocks.findIndex((b) => b.index === contentBlockIndex); let block = blocks[index]; if (delta?.text !== undefined) { // If no text block exists yet, create one, as `handleContentBlockStart` is not sent for text blocks if (!block) { const newBlock: Block = { type: "text", text: "", index: contentBlockIndex }; output.content.push(newBlock); index = blocks.length - 1; block = blocks[index]; stream.push({ type: "text_start", contentIndex: index, partial: output }); } if (block.type === "text") { block.text += delta.text; stream.push({ type: "text_delta", contentIndex: index, delta: delta.text, partial: output }); } } else if (delta?.toolUse && block?.type === "toolCall") { block.partialJson = (block.partialJson || "") + (delta.toolUse.input || ""); block.arguments = parseStreamingJson(block.partialJson); stream.push({ type: "toolcall_delta", contentIndex: index, delta: delta.toolUse.input || "", partial: output }); } else if (delta?.reasoningContent) { let thinkingBlock = block; let thinkingIndex = index; if (!thinkingBlock) { const newBlock: Block = { type: "thinking", thinking: "", thinkingSignature: "", index: contentBlockIndex }; output.content.push(newBlock); thinkingIndex = blocks.length - 1; thinkingBlock = blocks[thinkingIndex]; stream.push({ type: "thinking_start", contentIndex: thinkingIndex, partial: output }); } if (thinkingBlock?.type === "thinking") { if (delta.reasoningContent.text) { thinkingBlock.thinking += delta.reasoningContent.text; stream.push({ type: "thinking_delta", contentIndex: thinkingIndex, delta: delta.reasoningContent.text, partial: output, }); } if (delta.reasoningContent.signature) { thinkingBlock.thinkingSignature = (thinkingBlock.thinkingSignature || "") + delta.reasoningContent.signature; } } } } function handleMetadata( event: ConverseStreamMetadataEvent, model: Model<"bedrock-converse-stream">, output: AssistantMessage, ): void { if (event.usage) { output.usage.input = event.usage.inputTokens || 0; output.usage.output = event.usage.outputTokens || 0; output.usage.cacheRead = event.usage.cacheReadInputTokens || 0; output.usage.cacheWrite = event.usage.cacheWriteInputTokens || 0; output.usage.totalTokens = event.usage.totalTokens || output.usage.input + output.usage.output; calculateCost(model, output.usage); } } function handleContentBlockStop( event: ContentBlockStopEvent, blocks: Block[], output: AssistantMessage, stream: AssistantMessageEventStream, ): void { const index = blocks.findIndex((b) => b.index === event.contentBlockIndex); const block = blocks[index]; if (!block) return; delete (block as Block).index; switch (block.type) { case "text": stream.push({ type: "text_end", contentIndex: index, content: block.text, partial: output }); break; case "thinking": stream.push({ type: "thinking_end", contentIndex: index, content: block.thinking, partial: output }); break; case "toolCall": block.arguments = parseStreamingJson(block.partialJson); delete (block as Block).partialJson; stream.push({ type: "toolcall_end", contentIndex: index, toolCall: block, partial: output }); break; } } /** * Check if the model supports prompt caching. * Supported: Claude 3.5 Haiku, Claude 3.7 Sonnet, Claude 4.x models */ function supportsPromptCaching(model: Model<"bedrock-converse-stream">): boolean { const id = model.id.toLowerCase(); // Claude 4.x models (opus-4, sonnet-4, haiku-4) if (id.includes("claude") && (id.includes("-4-") || id.includes("-4."))) return true; // Claude 3.7 Sonnet if (id.includes("claude-3-7-sonnet")) return true; // Claude 3.5 Haiku if (id.includes("claude-3-5-haiku")) return true; return false; } /** * Check if the model supports thinking signatures in reasoningContent. * Only Anthropic Claude models support the signature field. * Other models (OpenAI, Qwen, Minimax, Moonshot, etc.) reject it with: * "This model doesn't support the reasoningContent.reasoningText.signature field" */ function supportsThinkingSignature(model: Model<"bedrock-converse-stream">): boolean { const id = model.id.toLowerCase(); return id.includes("anthropic.claude") || id.includes("anthropic/claude"); } function buildSystemPrompt( systemPrompt: string | undefined, model: Model<"bedrock-converse-stream">, ): SystemContentBlock[] | undefined { if (!systemPrompt) return undefined; const blocks: SystemContentBlock[] = [{ text: sanitizeSurrogates(systemPrompt) }]; // Add cache point for supported Claude models if (supportsPromptCaching(model)) { blocks.push({ cachePoint: { type: CachePointType.DEFAULT } }); } return blocks; } function convertMessages(context: Context, model: Model<"bedrock-converse-stream">): Message[] { const result: Message[] = []; const transformedMessages = transformMessages(context.messages, model); for (let i = 0; i < transformedMessages.length; i++) { const m = transformedMessages[i]; switch (m.role) { case "user": result.push({ role: ConversationRole.USER, content: typeof m.content === "string" ? [{ text: sanitizeSurrogates(m.content) }] : m.content.map((c) => { switch (c.type) { case "text": return { text: sanitizeSurrogates(c.text) }; case "image": return { image: createImageBlock(c.mimeType, c.data) }; default: throw new Error("Unknown user content type"); } }), }); break; case "assistant": { // Skip assistant messages with empty content (e.g., from aborted requests) // Bedrock rejects messages with empty content arrays if (m.content.length === 0) { continue; } const contentBlocks: ContentBlock[] = []; for (const c of m.content) { switch (c.type) { case "text": // Skip empty text blocks if (c.text.trim().length === 0) continue; contentBlocks.push({ text: sanitizeSurrogates(c.text) }); break; case "toolCall": contentBlocks.push({ toolUse: { toolUseId: c.id, name: c.name, input: c.arguments }, }); break; case "thinking": // Skip empty thinking blocks if (c.thinking.trim().length === 0) continue; // Only Anthropic models support the signature field in reasoningText. // For other models, we omit the signature to avoid errors like: // "This model doesn't support the reasoningContent.reasoningText.signature field" if (supportsThinkingSignature(model)) { contentBlocks.push({ reasoningContent: { reasoningText: { text: sanitizeSurrogates(c.thinking), signature: c.thinkingSignature }, }, }); } else { contentBlocks.push({ reasoningContent: { reasoningText: { text: sanitizeSurrogates(c.thinking) }, }, }); } break; default: throw new Error("Unknown assistant content type"); } } // Skip if all content blocks were filtered out if (contentBlocks.length === 0) { continue; } result.push({ role: ConversationRole.ASSISTANT, content: contentBlocks, }); break; } case "toolResult": { // Collect all consecutive toolResult messages into a single user message // Bedrock requires all tool results to be in one message const toolResults: ContentBlock.ToolResultMember[] = []; // Add current tool result with all content blocks combined toolResults.push({ toolResult: { toolUseId: m.toolCallId, content: m.content.map((c) => c.type === "image" ? { image: createImageBlock(c.mimeType, c.data) } : { text: sanitizeSurrogates(c.text) }, ), status: m.isError ? ToolResultStatus.ERROR : ToolResultStatus.SUCCESS, }, }); // Look ahead for consecutive toolResult messages let j = i + 1; while (j < transformedMessages.length && transformedMessages[j].role === "toolResult") { const nextMsg = transformedMessages[j] as ToolResultMessage; toolResults.push({ toolResult: { toolUseId: nextMsg.toolCallId, content: nextMsg.content.map((c) => c.type === "image" ? { image: createImageBlock(c.mimeType, c.data) } : { text: sanitizeSurrogates(c.text) }, ), status: nextMsg.isError ? ToolResultStatus.ERROR : ToolResultStatus.SUCCESS, }, }); j++; } // Skip the messages we've already processed i = j - 1; result.push({ role: ConversationRole.USER, content: toolResults, }); break; } default: throw new Error("Unknown message role"); } } // Add cache point to the last user message for supported Claude models if (supportsPromptCaching(model) && result.length > 0) { const lastMessage = result[result.length - 1]; if (lastMessage.role === ConversationRole.USER && lastMessage.content) { (lastMessage.content as ContentBlock[]).push({ cachePoint: { type: CachePointType.DEFAULT } }); } } return result; } function convertToolConfig( tools: Tool[] | undefined, toolChoice: BedrockOptions["toolChoice"], ): ToolConfiguration | undefined { if (!tools?.length || toolChoice === "none") return undefined; const bedrockTools: BedrockTool[] = tools.map((tool) => ({ toolSpec: { name: tool.name, description: tool.description, inputSchema: { json: tool.parameters }, }, })); let bedrockToolChoice: ToolChoice | undefined; switch (toolChoice) { case "auto": bedrockToolChoice = { auto: {} }; break; case "any": bedrockToolChoice = { any: {} }; break; default: if (toolChoice?.type === "tool") { bedrockToolChoice = { tool: { name: toolChoice.name } }; } } return { tools: bedrockTools, toolChoice: bedrockToolChoice }; } function mapStopReason(reason: string | undefined): StopReason { switch (reason) { case BedrockStopReason.END_TURN: case BedrockStopReason.STOP_SEQUENCE: return "stop"; case BedrockStopReason.MAX_TOKENS: case BedrockStopReason.MODEL_CONTEXT_WINDOW_EXCEEDED: return "length"; case BedrockStopReason.TOOL_USE: return "toolUse"; default: return "error"; } } function buildAdditionalModelRequestFields( model: Model<"bedrock-converse-stream">, options: BedrockOptions, ): Record | undefined { if (!options.reasoning || !model.reasoning) { return undefined; } if (model.id.includes("anthropic.claude")) { const defaultBudgets: Record = { minimal: 1024, low: 2048, medium: 8192, high: 16384, xhigh: 16384, // Claude doesn't support xhigh, clamp to high }; // Custom budgets override defaults (xhigh not in ThinkingBudgets, use high) const level = options.reasoning === "xhigh" ? "high" : options.reasoning; const budget = options.thinkingBudgets?.[level] ?? defaultBudgets[options.reasoning]; const result: Record = { thinking: { type: "enabled", budget_tokens: budget, }, }; if (options.interleavedThinking) { result.anthropic_beta = ["interleaved-thinking-2025-05-14"]; } return result; } return undefined; } function createImageBlock(mimeType: string, data: string) { let format: ImageFormat; switch (mimeType) { case "image/jpeg": case "image/jpg": format = ImageFormat.JPEG; break; case "image/png": format = ImageFormat.PNG; break; case "image/gif": format = ImageFormat.GIF; break; case "image/webp": format = ImageFormat.WEBP; break; default: throw new Error(`Unknown image type: ${mimeType}`); } const binaryString = atob(data); const bytes = new Uint8Array(binaryString.length); for (let i = 0; i < binaryString.length; i++) { bytes[i] = binaryString.charCodeAt(i); } return { source: { bytes }, format }; }