Fix streaming for z-ai in the Anthropic provider and add preliminary support for tool call streaming. Only argument string deltas are reported, not partial JSON objects.

This commit is contained in:
Mario Zechner 2025-09-09 04:26:56 +02:00
parent 2bdb87dfe7
commit 98a876f3a0
21 changed files with 784 additions and 448 deletions

View file

@ -5,7 +5,7 @@ import { fileURLToPath } from "url";
import { afterAll, beforeAll, describe, expect, it } from "vitest";
import { complete, stream } from "../src/generate.js";
import { getModel } from "../src/models.js";
import type { Api, Context, ImageContent, Model, OptionsForApi, Tool } from "../src/types.js";
import type { Api, Context, ImageContent, Model, OptionsForApi, Tool, ToolResultMessage } from "../src/types.js";
const __filename = fileURLToPath(import.meta.url);
const __dirname = dirname(__filename);
@ -70,13 +70,62 @@ async function handleToolCall<TApi extends Api>(model: Model<TApi>, options?: Op
tools: [calculatorTool],
};
const response = await complete(model, context, options);
const s = await stream(model, context, options);
let hasToolStart = false;
let hasToolDelta = false;
let hasToolEnd = false;
let accumulatedToolArgs = "";
let index = 0;
for await (const event of s) {
if (event.type === "toolcall_start") {
hasToolStart = true;
const toolCall = event.partial.content[event.contentIndex];
index = event.contentIndex;
expect(toolCall.type).toBe("toolCall");
if (toolCall.type === "toolCall") {
expect(toolCall.name).toBe("calculator");
expect(toolCall.id).toBeTruthy();
}
}
if (event.type === "toolcall_delta") {
hasToolDelta = true;
const toolCall = event.partial.content[event.contentIndex];
expect(event.contentIndex).toBe(index);
expect(toolCall.type).toBe("toolCall");
if (toolCall.type === "toolCall") {
expect(toolCall.name).toBe("calculator");
accumulatedToolArgs += event.delta;
}
}
if (event.type === "toolcall_end") {
hasToolEnd = true;
const toolCall = event.partial.content[event.contentIndex];
expect(event.contentIndex).toBe(index);
expect(toolCall.type).toBe("toolCall");
if (toolCall.type === "toolCall") {
expect(toolCall.name).toBe("calculator");
JSON.parse(accumulatedToolArgs);
expect(toolCall.arguments).not.toBeUndefined();
expect((toolCall.arguments as any).a).toBe(15);
expect((toolCall.arguments as any).b).toBe(27);
expect((toolCall.arguments as any).operation).oneOf(["add", "subtract", "multiply", "divide"]);
}
}
}
expect(hasToolStart).toBe(true);
expect(hasToolDelta).toBe(true);
expect(hasToolEnd).toBe(true);
const response = await s.result();
expect(response.stopReason).toBe("toolUse");
expect(response.content.some((b) => b.type === "toolCall")).toBeTruthy();
const toolCall = response.content.find((b) => b.type === "toolCall");
if (toolCall && toolCall.type === "toolCall") {
expect(toolCall.name).toBe("calculator");
expect(toolCall.id).toBeTruthy();
} else {
throw new Error("No tool call found in response");
}
}
@ -101,7 +150,7 @@ async function handleStreaming<TApi extends Api>(model: Model<TApi>, options?: O
}
}
const response = await s.finalMessage();
const response = await s.result();
expect(textStarted).toBe(true);
expect(textChunks.length).toBeGreaterThan(0);
@ -135,7 +184,7 @@ async function handleThinking<TApi extends Api>(model: Model<TApi>, options?: Op
}
}
const response = await s.finalMessage();
const response = await s.result();
expect(response.stopReason, `Error: ${response.error}`).toBe("stop");
expect(thinkingStarted).toBe(true);
@ -214,6 +263,7 @@ async function multiTurn<TApi extends Api>(model: Model<TApi>, options?: Options
context.messages.push(response);
// Process content blocks
const results: ToolResultMessage[] = [];
for (const block of response.content) {
if (block.type === "text") {
allTextContent += block.text;
@ -241,15 +291,16 @@ async function multiTurn<TApi extends Api>(model: Model<TApi>, options?: Options
}
// Add tool result to context
context.messages.push({
results.push({
role: "toolResult",
toolCallId: block.id,
toolName: block.name,
content: `${result}`,
output: `${result}`,
isError: false,
});
}
}
context.messages.push(...results);
// If we got a stop response with text content, we're likely done
expect(response.stopReason).not.toBe("error");
@ -331,12 +382,12 @@ describe("Generate E2E Tests", () => {
await handleStreaming(llm);
});
it("should handle ", { retry: 2 }, async () => {
await handleThinking(llm, { reasoningEffort: "medium" });
it("should handle thinking", { retry: 2 }, async () => {
await handleThinking(llm, { reasoningEffort: "high" });
});
it("should handle multi-turn with thinking and tools", async () => {
await multiTurn(llm, { reasoningEffort: "medium" });
await multiTurn(llm, { reasoningEffort: "high" });
});
it("should handle image input", async () => {