diff --git a/packages/ai/test/bedrock-interleaved-thinking.test.ts b/packages/ai/test/bedrock-interleaved-thinking.test.ts new file mode 100644 index 00000000..dedb94a8 --- /dev/null +++ b/packages/ai/test/bedrock-interleaved-thinking.test.ts @@ -0,0 +1,138 @@ +import { Type } from "@sinclair/typebox"; +import { describe, expect, it } from "vitest"; +import { getModel } from "../src/models.js"; +import { complete } from "../src/stream.js"; +import type { Context, StopReason, Tool, ToolCall, ToolResultMessage } from "../src/types.js"; +import { StringEnum } from "../src/utils/typebox-helpers.js"; +import { hasBedrockCredentials } from "./bedrock-utils.js"; + +const calculatorSchema = Type.Object({ + a: Type.Number({ description: "First number" }), + b: Type.Number({ description: "Second number" }), + operation: StringEnum(["add", "subtract", "multiply", "divide"], { + description: "The operation to perform.", + }), +}); + +const calculatorTool: Tool = { + name: "calculator", + description: "Perform basic arithmetic operations", + parameters: calculatorSchema, +}; + +type CalculatorOperation = "add" | "subtract" | "multiply" | "divide"; + +type CalculatorArguments = { + a: number; + b: number; + operation: CalculatorOperation; +}; + +function asCalculatorArguments(args: ToolCall["arguments"]): CalculatorArguments { + if (typeof args !== "object" || args === null) { + throw new Error("Tool arguments must be an object"); + } + + const value = args as Record; + const operation = value.operation; + if ( + typeof value.a !== "number" || + typeof value.b !== "number" || + (operation !== "add" && operation !== "subtract" && operation !== "multiply" && operation !== "divide") + ) { + throw new Error("Invalid calculator arguments"); + } + + return { a: value.a, b: value.b, operation }; +} + +function evaluateCalculatorCall(toolCall: ToolCall): number { + const { a, b, operation } = asCalculatorArguments(toolCall.arguments); + switch (operation) { + case "add": + return a + b; + case "subtract": + return a - b; + case "multiply": + return a * b; + case "divide": + return a / b; + } +} + +type BedrockInterleavedModelId = + | "global.anthropic.claude-opus-4-5-20251101-v1:0" + | "global.anthropic.claude-opus-4-6-v1"; + +async function assertSecondToolCallWithInterleavedThinking( + modelId: BedrockInterleavedModelId, + reasoning: "high" | "xhigh", +) { + const llm = getModel("amazon-bedrock", modelId); + const context: Context = { + systemPrompt: [ + "You are a helpful assistant that must use tools for arithmetic.", + "Always think before every tool call, not just the first one.", + "Do not answer with plain text when a tool call is required.", + ].join(" "), + messages: [ + { + role: "user", + content: [ + "Use calculator to calculate 328 * 29.", + "You must call the calculator tool exactly once.", + "Provide the final answer based on the best guess given the tool result, even if it seems unreliable.", + ].join(" "), + timestamp: Date.now(), + }, + ], + tools: [calculatorTool], + }; + + const firstResponse = await complete(llm, context, { + reasoning, + interleavedThinking: true, + }); + + expect(firstResponse.stopReason, `Error: ${firstResponse.errorMessage}`).toBe("toolUse" satisfies StopReason); + expect(firstResponse.content.some((block) => block.type === "thinking")).toBe(true); + expect(firstResponse.content.some((block) => block.type === "toolCall")).toBe(true); + + const firstToolCall = firstResponse.content.find((block) => block.type === "toolCall"); + expect(firstToolCall?.type).toBe("toolCall"); + if (!firstToolCall || firstToolCall.type !== "toolCall") { + throw new Error("Expected first response to include a tool call"); + } + + context.messages.push(firstResponse); + + const correctAnswer = evaluateCalculatorCall(firstToolCall); + const firstToolResult: ToolResultMessage = { + role: "toolResult", + toolCallId: firstToolCall.id, + toolName: firstToolCall.name, + content: [{ type: "text", text: `The answer is ${correctAnswer} or ${correctAnswer * 2}.` }], + isError: false, + timestamp: Date.now(), + }; + context.messages.push(firstToolResult); + + const secondResponse = await complete(llm, context, { + reasoning, + interleavedThinking: true, + }); + + expect(secondResponse.stopReason, `Error: ${secondResponse.errorMessage}`).toBe("stop" satisfies StopReason); + expect(secondResponse.content.some((block) => block.type === "thinking")).toBe(true); + expect(secondResponse.content.some((block) => block.type === "text")).toBe(true); +} + +describe.skipIf(!hasBedrockCredentials())("Amazon Bedrock interleaved thinking", () => { + it("should do interleaved thinking on Claude Opus 4.5", { retry: 3 }, async () => { + await assertSecondToolCallWithInterleavedThinking("global.anthropic.claude-opus-4-5-20251101-v1:0", "high"); + }); + + it("should do interleaved thinking on Claude Opus 4.6", { retry: 3 }, async () => { + await assertSecondToolCallWithInterleavedThinking("global.anthropic.claude-opus-4-6-v1", "xhigh"); + }); +});