mirror of
https://github.com/getcompanion-ai/co-mono.git
synced 2026-04-20 22:02:38 +00:00
Fix streaming for z-ai in anthropic provider, add preliminary support for tool call streaming. Only reporting argument string deltas, not partial JSON objects
This commit is contained in:
parent
2bdb87dfe7
commit
98a876f3a0
21 changed files with 784 additions and 448 deletions
|
|
@ -5,7 +5,7 @@ import { fileURLToPath } from "url";
|
|||
import { afterAll, beforeAll, describe, expect, it } from "vitest";
|
||||
import { complete, stream } from "../src/generate.js";
|
||||
import { getModel } from "../src/models.js";
|
||||
import type { Api, Context, ImageContent, Model, OptionsForApi, Tool } from "../src/types.js";
|
||||
import type { Api, Context, ImageContent, Model, OptionsForApi, Tool, ToolResultMessage } from "../src/types.js";
|
||||
|
||||
const __filename = fileURLToPath(import.meta.url);
|
||||
const __dirname = dirname(__filename);
|
||||
|
|
@ -70,13 +70,62 @@ async function handleToolCall<TApi extends Api>(model: Model<TApi>, options?: Op
|
|||
tools: [calculatorTool],
|
||||
};
|
||||
|
||||
const response = await complete(model, context, options);
|
||||
const s = await stream(model, context, options);
|
||||
let hasToolStart = false;
|
||||
let hasToolDelta = false;
|
||||
let hasToolEnd = false;
|
||||
let accumulatedToolArgs = "";
|
||||
let index = 0;
|
||||
for await (const event of s) {
|
||||
if (event.type === "toolcall_start") {
|
||||
hasToolStart = true;
|
||||
const toolCall = event.partial.content[event.contentIndex];
|
||||
index = event.contentIndex;
|
||||
expect(toolCall.type).toBe("toolCall");
|
||||
if (toolCall.type === "toolCall") {
|
||||
expect(toolCall.name).toBe("calculator");
|
||||
expect(toolCall.id).toBeTruthy();
|
||||
}
|
||||
}
|
||||
if (event.type === "toolcall_delta") {
|
||||
hasToolDelta = true;
|
||||
const toolCall = event.partial.content[event.contentIndex];
|
||||
expect(event.contentIndex).toBe(index);
|
||||
expect(toolCall.type).toBe("toolCall");
|
||||
if (toolCall.type === "toolCall") {
|
||||
expect(toolCall.name).toBe("calculator");
|
||||
accumulatedToolArgs += event.delta;
|
||||
}
|
||||
}
|
||||
if (event.type === "toolcall_end") {
|
||||
hasToolEnd = true;
|
||||
const toolCall = event.partial.content[event.contentIndex];
|
||||
expect(event.contentIndex).toBe(index);
|
||||
expect(toolCall.type).toBe("toolCall");
|
||||
if (toolCall.type === "toolCall") {
|
||||
expect(toolCall.name).toBe("calculator");
|
||||
JSON.parse(accumulatedToolArgs);
|
||||
expect(toolCall.arguments).not.toBeUndefined();
|
||||
expect((toolCall.arguments as any).a).toBe(15);
|
||||
expect((toolCall.arguments as any).b).toBe(27);
|
||||
expect((toolCall.arguments as any).operation).oneOf(["add", "subtract", "multiply", "divide"]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
expect(hasToolStart).toBe(true);
|
||||
expect(hasToolDelta).toBe(true);
|
||||
expect(hasToolEnd).toBe(true);
|
||||
|
||||
const response = await s.result();
|
||||
expect(response.stopReason).toBe("toolUse");
|
||||
expect(response.content.some((b) => b.type === "toolCall")).toBeTruthy();
|
||||
const toolCall = response.content.find((b) => b.type === "toolCall");
|
||||
if (toolCall && toolCall.type === "toolCall") {
|
||||
expect(toolCall.name).toBe("calculator");
|
||||
expect(toolCall.id).toBeTruthy();
|
||||
} else {
|
||||
throw new Error("No tool call found in response");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -101,7 +150,7 @@ async function handleStreaming<TApi extends Api>(model: Model<TApi>, options?: O
|
|||
}
|
||||
}
|
||||
|
||||
const response = await s.finalMessage();
|
||||
const response = await s.result();
|
||||
|
||||
expect(textStarted).toBe(true);
|
||||
expect(textChunks.length).toBeGreaterThan(0);
|
||||
|
|
@ -135,7 +184,7 @@ async function handleThinking<TApi extends Api>(model: Model<TApi>, options?: Op
|
|||
}
|
||||
}
|
||||
|
||||
const response = await s.finalMessage();
|
||||
const response = await s.result();
|
||||
|
||||
expect(response.stopReason, `Error: ${response.error}`).toBe("stop");
|
||||
expect(thinkingStarted).toBe(true);
|
||||
|
|
@ -214,6 +263,7 @@ async function multiTurn<TApi extends Api>(model: Model<TApi>, options?: Options
|
|||
context.messages.push(response);
|
||||
|
||||
// Process content blocks
|
||||
const results: ToolResultMessage[] = [];
|
||||
for (const block of response.content) {
|
||||
if (block.type === "text") {
|
||||
allTextContent += block.text;
|
||||
|
|
@ -241,15 +291,16 @@ async function multiTurn<TApi extends Api>(model: Model<TApi>, options?: Options
|
|||
}
|
||||
|
||||
// Add tool result to context
|
||||
context.messages.push({
|
||||
results.push({
|
||||
role: "toolResult",
|
||||
toolCallId: block.id,
|
||||
toolName: block.name,
|
||||
content: `${result}`,
|
||||
output: `${result}`,
|
||||
isError: false,
|
||||
});
|
||||
}
|
||||
}
|
||||
context.messages.push(...results);
|
||||
|
||||
// If we got a stop response with text content, we're likely done
|
||||
expect(response.stopReason).not.toBe("error");
|
||||
|
|
@ -331,12 +382,12 @@ describe("Generate E2E Tests", () => {
|
|||
await handleStreaming(llm);
|
||||
});
|
||||
|
||||
it("should handle ", { retry: 2 }, async () => {
|
||||
await handleThinking(llm, { reasoningEffort: "medium" });
|
||||
it("should handle thinking", { retry: 2 }, async () => {
|
||||
await handleThinking(llm, { reasoningEffort: "high" });
|
||||
});
|
||||
|
||||
it("should handle multi-turn with thinking and tools", async () => {
|
||||
await multiTurn(llm, { reasoningEffort: "medium" });
|
||||
await multiTurn(llm, { reasoningEffort: "high" });
|
||||
});
|
||||
|
||||
it("should handle image input", async () => {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue