mirror of
https://github.com/getcompanion-ai/co-mono.git
synced 2026-04-17 04:02:21 +00:00
Fix streaming for z-ai in the Anthropic provider and add preliminary support for tool-call streaming. Only argument string deltas are reported, not partial JSON objects.
This commit is contained in:
parent
2bdb87dfe7
commit
98a876f3a0
21 changed files with 784 additions and 448 deletions
|
|
@ -22,7 +22,7 @@ async function testAbortSignal<TApi extends Api>(llm: Model<TApi>, options: Opti
|
|||
abortFired = true;
|
||||
break;
|
||||
}
|
||||
const msg = await response.finalMessage();
|
||||
const msg = await response.result();
|
||||
|
||||
// If we get here without throwing, the abort didn't work
|
||||
expect(msg.stopReason).toBe("error");
|
||||
|
|
|
|||
|
|
@ -1,113 +0,0 @@
|
|||
import { type Context, complete, getModel } from "../src/index.js";
|
||||
|
||||
async function testCrossProviderToolCall() {
|
||||
console.log("Testing cross-provider tool call handoff...\n");
|
||||
|
||||
// Define a simple tool
|
||||
const tools = [
|
||||
{
|
||||
name: "get_weather",
|
||||
description: "Get current weather for a location",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
location: { type: "string", description: "City name" },
|
||||
},
|
||||
required: ["location"],
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
// Create context with tools
|
||||
const context: Context = {
|
||||
systemPrompt: "You are a helpful assistant. Use the get_weather tool when asked about weather.",
|
||||
messages: [{ role: "user", content: "What is the weather in Paris?" }],
|
||||
tools,
|
||||
};
|
||||
|
||||
try {
|
||||
// Step 1: Get tool call from GPT-5
|
||||
console.log("Step 1: Getting tool call from GPT-5...");
|
||||
const gpt5 = getModel("openai", "gpt-5-mini");
|
||||
const gpt5Response = await complete(gpt5, context);
|
||||
context.messages.push(gpt5Response);
|
||||
|
||||
// Check for tool calls
|
||||
const toolCalls = gpt5Response.content.filter((b) => b.type === "toolCall");
|
||||
console.log(`GPT-5 made ${toolCalls.length} tool call(s)`);
|
||||
|
||||
if (toolCalls.length > 0) {
|
||||
const toolCall = toolCalls[0];
|
||||
console.log(`Tool call ID: ${toolCall.id}`);
|
||||
console.log(`Tool call contains pipe: ${toolCall.id.includes("|")}`);
|
||||
console.log(`Tool: ${toolCall.name}(${JSON.stringify(toolCall.arguments)})\n`);
|
||||
|
||||
// Add tool result
|
||||
context.messages.push({
|
||||
role: "toolResult",
|
||||
toolCallId: toolCall.id,
|
||||
toolName: toolCall.name,
|
||||
content: JSON.stringify({
|
||||
location: "Paris",
|
||||
temperature: "22°C",
|
||||
conditions: "Partly cloudy",
|
||||
}),
|
||||
isError: false,
|
||||
});
|
||||
|
||||
// Step 2: Send to Claude Haiku for follow-up
|
||||
console.log("Step 2: Sending to Claude Haiku for follow-up...");
|
||||
const haiku = getModel("anthropic", "claude-3-5-haiku-20241022");
|
||||
|
||||
try {
|
||||
const haikuResponse = await complete(haiku, context);
|
||||
console.log("✅ Claude Haiku successfully processed the conversation!");
|
||||
console.log("Response content types:", haikuResponse.content.map((b) => b.type).join(", "));
|
||||
console.log("Number of content blocks:", haikuResponse.content.length);
|
||||
console.log("Stop reason:", haikuResponse.stopReason);
|
||||
if (haikuResponse.error) {
|
||||
console.log("Error message:", haikuResponse.error);
|
||||
}
|
||||
|
||||
// Print all response content
|
||||
for (const block of haikuResponse.content) {
|
||||
if (block.type === "text") {
|
||||
console.log("\nClaude text response:", block.text);
|
||||
} else if (block.type === "thinking") {
|
||||
console.log("\nClaude thinking:", block.thinking);
|
||||
} else if (block.type === "toolCall") {
|
||||
console.log("\nClaude tool call:", block.name, block.arguments);
|
||||
}
|
||||
}
|
||||
|
||||
if (haikuResponse.content.length === 0) {
|
||||
console.log("⚠️ Claude returned an empty response!");
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("❌ Claude Haiku failed to process the conversation:");
|
||||
console.error("Error:", error);
|
||||
|
||||
// Check if it's related to the tool call ID
|
||||
if (error instanceof Error && error.message.includes("tool")) {
|
||||
console.error("\n⚠️ This appears to be a tool call ID issue!");
|
||||
console.error("The pipe character (|) in OpenAI Response API tool IDs might be causing problems.");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
console.log("No tool calls were made by GPT-5");
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Test failed:", error);
|
||||
}
|
||||
}
|
||||
|
||||
// Set API keys from environment or pass them explicitly
|
||||
const openaiKey = process.env.OPENAI_API_KEY;
|
||||
const anthropicKey = process.env.ANTHROPIC_API_KEY;
|
||||
|
||||
if (!openaiKey || !anthropicKey) {
|
||||
console.error("Please set OPENAI_API_KEY and ANTHROPIC_API_KEY environment variables");
|
||||
process.exit(1);
|
||||
}
|
||||
|
||||
testCrossProviderToolCall().catch(console.error);
|
||||
|
|
@ -5,7 +5,7 @@ import { fileURLToPath } from "url";
|
|||
import { afterAll, beforeAll, describe, expect, it } from "vitest";
|
||||
import { complete, stream } from "../src/generate.js";
|
||||
import { getModel } from "../src/models.js";
|
||||
import type { Api, Context, ImageContent, Model, OptionsForApi, Tool } from "../src/types.js";
|
||||
import type { Api, Context, ImageContent, Model, OptionsForApi, Tool, ToolResultMessage } from "../src/types.js";
|
||||
|
||||
const __filename = fileURLToPath(import.meta.url);
|
||||
const __dirname = dirname(__filename);
|
||||
|
|
@ -70,13 +70,62 @@ async function handleToolCall<TApi extends Api>(model: Model<TApi>, options?: Op
|
|||
tools: [calculatorTool],
|
||||
};
|
||||
|
||||
const response = await complete(model, context, options);
|
||||
const s = await stream(model, context, options);
|
||||
let hasToolStart = false;
|
||||
let hasToolDelta = false;
|
||||
let hasToolEnd = false;
|
||||
let accumulatedToolArgs = "";
|
||||
let index = 0;
|
||||
for await (const event of s) {
|
||||
if (event.type === "toolcall_start") {
|
||||
hasToolStart = true;
|
||||
const toolCall = event.partial.content[event.contentIndex];
|
||||
index = event.contentIndex;
|
||||
expect(toolCall.type).toBe("toolCall");
|
||||
if (toolCall.type === "toolCall") {
|
||||
expect(toolCall.name).toBe("calculator");
|
||||
expect(toolCall.id).toBeTruthy();
|
||||
}
|
||||
}
|
||||
if (event.type === "toolcall_delta") {
|
||||
hasToolDelta = true;
|
||||
const toolCall = event.partial.content[event.contentIndex];
|
||||
expect(event.contentIndex).toBe(index);
|
||||
expect(toolCall.type).toBe("toolCall");
|
||||
if (toolCall.type === "toolCall") {
|
||||
expect(toolCall.name).toBe("calculator");
|
||||
accumulatedToolArgs += event.delta;
|
||||
}
|
||||
}
|
||||
if (event.type === "toolcall_end") {
|
||||
hasToolEnd = true;
|
||||
const toolCall = event.partial.content[event.contentIndex];
|
||||
expect(event.contentIndex).toBe(index);
|
||||
expect(toolCall.type).toBe("toolCall");
|
||||
if (toolCall.type === "toolCall") {
|
||||
expect(toolCall.name).toBe("calculator");
|
||||
JSON.parse(accumulatedToolArgs);
|
||||
expect(toolCall.arguments).not.toBeUndefined();
|
||||
expect((toolCall.arguments as any).a).toBe(15);
|
||||
expect((toolCall.arguments as any).b).toBe(27);
|
||||
expect((toolCall.arguments as any).operation).oneOf(["add", "subtract", "multiply", "divide"]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
expect(hasToolStart).toBe(true);
|
||||
expect(hasToolDelta).toBe(true);
|
||||
expect(hasToolEnd).toBe(true);
|
||||
|
||||
const response = await s.result();
|
||||
expect(response.stopReason).toBe("toolUse");
|
||||
expect(response.content.some((b) => b.type === "toolCall")).toBeTruthy();
|
||||
const toolCall = response.content.find((b) => b.type === "toolCall");
|
||||
if (toolCall && toolCall.type === "toolCall") {
|
||||
expect(toolCall.name).toBe("calculator");
|
||||
expect(toolCall.id).toBeTruthy();
|
||||
} else {
|
||||
throw new Error("No tool call found in response");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -101,7 +150,7 @@ async function handleStreaming<TApi extends Api>(model: Model<TApi>, options?: O
|
|||
}
|
||||
}
|
||||
|
||||
const response = await s.finalMessage();
|
||||
const response = await s.result();
|
||||
|
||||
expect(textStarted).toBe(true);
|
||||
expect(textChunks.length).toBeGreaterThan(0);
|
||||
|
|
@ -135,7 +184,7 @@ async function handleThinking<TApi extends Api>(model: Model<TApi>, options?: Op
|
|||
}
|
||||
}
|
||||
|
||||
const response = await s.finalMessage();
|
||||
const response = await s.result();
|
||||
|
||||
expect(response.stopReason, `Error: ${response.error}`).toBe("stop");
|
||||
expect(thinkingStarted).toBe(true);
|
||||
|
|
@ -214,6 +263,7 @@ async function multiTurn<TApi extends Api>(model: Model<TApi>, options?: Options
|
|||
context.messages.push(response);
|
||||
|
||||
// Process content blocks
|
||||
const results: ToolResultMessage[] = [];
|
||||
for (const block of response.content) {
|
||||
if (block.type === "text") {
|
||||
allTextContent += block.text;
|
||||
|
|
@ -241,15 +291,16 @@ async function multiTurn<TApi extends Api>(model: Model<TApi>, options?: Options
|
|||
}
|
||||
|
||||
// Add tool result to context
|
||||
context.messages.push({
|
||||
results.push({
|
||||
role: "toolResult",
|
||||
toolCallId: block.id,
|
||||
toolName: block.name,
|
||||
content: `${result}`,
|
||||
output: `${result}`,
|
||||
isError: false,
|
||||
});
|
||||
}
|
||||
}
|
||||
context.messages.push(...results);
|
||||
|
||||
// If we got a stop response with text content, we're likely done
|
||||
expect(response.stopReason).not.toBe("error");
|
||||
|
|
@ -331,12 +382,12 @@ describe("Generate E2E Tests", () => {
|
|||
await handleStreaming(llm);
|
||||
});
|
||||
|
||||
it("should handle ", { retry: 2 }, async () => {
|
||||
await handleThinking(llm, { reasoningEffort: "medium" });
|
||||
it("should handle thinking", { retry: 2 }, async () => {
|
||||
await handleThinking(llm, { reasoningEffort: "high" });
|
||||
});
|
||||
|
||||
it("should handle multi-turn with thinking and tools", async () => {
|
||||
await multiTurn(llm, { reasoningEffort: "medium" });
|
||||
await multiTurn(llm, { reasoningEffort: "high" });
|
||||
});
|
||||
|
||||
it("should handle image input", async () => {
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import { describe, expect, it } from "vitest";
|
||||
import { complete } from "../src/generate.js";
|
||||
import { getModel } from "../src/models.js";
|
||||
import type { Api, AssistantMessage, Context, Message, Model, Tool } from "../src/types.js";
|
||||
import type { Api, AssistantMessage, Context, Message, Model, Tool, ToolResultMessage } from "../src/types.js";
|
||||
|
||||
// Tool for testing
|
||||
const weatherTool: Tool = {
|
||||
|
|
@ -22,6 +22,7 @@ const providerContexts = {
|
|||
anthropic: {
|
||||
message: {
|
||||
role: "assistant",
|
||||
api: "anthropic-messages",
|
||||
content: [
|
||||
{
|
||||
type: "thinking",
|
||||
|
|
@ -49,14 +50,14 @@ const providerContexts = {
|
|||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "toolUse",
|
||||
} as AssistantMessage,
|
||||
} satisfies AssistantMessage,
|
||||
toolResult: {
|
||||
role: "toolResult" as const,
|
||||
toolCallId: "toolu_01abc123",
|
||||
toolName: "get_weather",
|
||||
content: "Weather in Tokyo: 18°C, partly cloudy",
|
||||
output: "Weather in Tokyo: 18°C, partly cloudy",
|
||||
isError: false,
|
||||
},
|
||||
} satisfies ToolResultMessage,
|
||||
facts: {
|
||||
calculation: 391,
|
||||
city: "Tokyo",
|
||||
|
|
@ -69,6 +70,7 @@ const providerContexts = {
|
|||
google: {
|
||||
message: {
|
||||
role: "assistant",
|
||||
api: "google-generative-ai",
|
||||
content: [
|
||||
{
|
||||
type: "thinking",
|
||||
|
|
@ -97,14 +99,14 @@ const providerContexts = {
|
|||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "toolUse",
|
||||
} as AssistantMessage,
|
||||
} satisfies AssistantMessage,
|
||||
toolResult: {
|
||||
role: "toolResult" as const,
|
||||
toolCallId: "call_gemini_123",
|
||||
toolName: "get_weather",
|
||||
content: "Weather in Berlin: 22°C, sunny",
|
||||
output: "Weather in Berlin: 22°C, sunny",
|
||||
isError: false,
|
||||
},
|
||||
} satisfies ToolResultMessage,
|
||||
facts: {
|
||||
calculation: 456,
|
||||
city: "Berlin",
|
||||
|
|
@ -117,6 +119,7 @@ const providerContexts = {
|
|||
openaiCompletions: {
|
||||
message: {
|
||||
role: "assistant",
|
||||
api: "openai-completions",
|
||||
content: [
|
||||
{
|
||||
type: "thinking",
|
||||
|
|
@ -144,14 +147,14 @@ const providerContexts = {
|
|||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "toolUse",
|
||||
} as AssistantMessage,
|
||||
} satisfies AssistantMessage,
|
||||
toolResult: {
|
||||
role: "toolResult" as const,
|
||||
toolCallId: "call_abc123",
|
||||
toolName: "get_weather",
|
||||
content: "Weather in London: 15°C, rainy",
|
||||
output: "Weather in London: 15°C, rainy",
|
||||
isError: false,
|
||||
},
|
||||
} satisfies ToolResultMessage,
|
||||
facts: {
|
||||
calculation: 525,
|
||||
city: "London",
|
||||
|
|
@ -164,6 +167,7 @@ const providerContexts = {
|
|||
openaiResponses: {
|
||||
message: {
|
||||
role: "assistant",
|
||||
api: "openai-responses",
|
||||
content: [
|
||||
{
|
||||
type: "thinking",
|
||||
|
|
@ -193,14 +197,14 @@ const providerContexts = {
|
|||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "toolUse",
|
||||
} as AssistantMessage,
|
||||
} satisfies AssistantMessage,
|
||||
toolResult: {
|
||||
role: "toolResult" as const,
|
||||
toolCallId: "call_789_item_012", // Match the updated ID format
|
||||
toolName: "get_weather",
|
||||
content: "Weather in Sydney: 25°C, clear",
|
||||
output: "Weather in Sydney: 25°C, clear",
|
||||
isError: false,
|
||||
},
|
||||
} satisfies ToolResultMessage,
|
||||
facts: {
|
||||
calculation: 486,
|
||||
city: "Sydney",
|
||||
|
|
@ -213,6 +217,7 @@ const providerContexts = {
|
|||
aborted: {
|
||||
message: {
|
||||
role: "assistant",
|
||||
api: "anthropic-messages",
|
||||
content: [
|
||||
{
|
||||
type: "thinking",
|
||||
|
|
@ -235,7 +240,7 @@ const providerContexts = {
|
|||
},
|
||||
stopReason: "error",
|
||||
error: "Request was aborted",
|
||||
} as AssistantMessage,
|
||||
} satisfies AssistantMessage,
|
||||
toolResult: null,
|
||||
facts: {
|
||||
calculation: 600,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue