Fix streaming for z-ai in anthropic provider, add preliminary support for tool call streaming. Only reporting argument string deltas, not partial JSON objects

This commit is contained in:
Mario Zechner 2025-09-09 04:26:56 +02:00
parent 2bdb87dfe7
commit 98a876f3a0
21 changed files with 784 additions and 448 deletions

View file

@ -4,7 +4,7 @@ import type {
MessageCreateParamsStreaming,
MessageParam,
} from "@anthropic-ai/sdk/resources/messages.js";
import { QueuedGenerateStream } from "../generate.js";
import { AssistantMessageEventStream } from "../event-stream.js";
import { calculateCost } from "../models.js";
import type {
Api,
@ -12,7 +12,6 @@ import type {
Context,
GenerateFunction,
GenerateOptions,
GenerateStream,
Message,
Model,
StopReason,
@ -20,8 +19,9 @@ import type {
ThinkingContent,
Tool,
ToolCall,
ToolResultMessage,
} from "../types.js";
import { transformMessages } from "./utils.js";
import { transformMessages } from "./transorm-messages.js";
export interface AnthropicOptions extends GenerateOptions {
thinkingEnabled?: boolean;
@ -33,8 +33,8 @@ export const streamAnthropic: GenerateFunction<"anthropic-messages"> = (
model: Model<"anthropic-messages">,
context: Context,
options?: AnthropicOptions,
): GenerateStream => {
const stream = new QueuedGenerateStream();
): AssistantMessageEventStream => {
const stream = new AssistantMessageEventStream();
(async () => {
const output: AssistantMessage = {
@ -59,93 +59,114 @@ export const streamAnthropic: GenerateFunction<"anthropic-messages"> = (
const anthropicStream = client.messages.stream({ ...params, stream: true }, { signal: options?.signal });
stream.push({ type: "start", partial: output });
let currentBlock: ThinkingContent | TextContent | (ToolCall & { partialJson: string }) | null = null;
type Block = (ThinkingContent | TextContent | (ToolCall & { partialJson: string })) & { index: number };
const blocks = output.content as Block[];
for await (const event of anthropicStream) {
if (event.type === "content_block_start") {
if (event.content_block.type === "text") {
currentBlock = {
const block: Block = {
type: "text",
text: "",
index: event.index,
};
output.content.push(currentBlock);
stream.push({ type: "text_start", partial: output });
output.content.push(block);
stream.push({ type: "text_start", contentIndex: output.content.length - 1, partial: output });
} else if (event.content_block.type === "thinking") {
currentBlock = {
const block: Block = {
type: "thinking",
thinking: "",
thinkingSignature: "",
index: event.index,
};
output.content.push(currentBlock);
stream.push({ type: "thinking_start", partial: output });
output.content.push(block);
stream.push({ type: "thinking_start", contentIndex: output.content.length - 1, partial: output });
} else if (event.content_block.type === "tool_use") {
// We wait for the full tool use to be streamed
currentBlock = {
const block: Block = {
type: "toolCall",
id: event.content_block.id,
name: event.content_block.name,
arguments: event.content_block.input as Record<string, any>,
partialJson: "",
index: event.index,
};
output.content.push(block);
stream.push({ type: "toolcall_start", contentIndex: output.content.length - 1, partial: output });
}
} else if (event.type === "content_block_delta") {
if (event.delta.type === "text_delta") {
if (currentBlock && currentBlock.type === "text") {
currentBlock.text += event.delta.text;
const index = blocks.findIndex((b) => b.index === event.index);
const block = blocks[index];
if (block && block.type === "text") {
block.text += event.delta.text;
stream.push({
type: "text_delta",
contentIndex: index,
delta: event.delta.text,
partial: output,
});
}
} else if (event.delta.type === "thinking_delta") {
if (currentBlock && currentBlock.type === "thinking") {
currentBlock.thinking += event.delta.thinking;
const index = blocks.findIndex((b) => b.index === event.index);
const block = blocks[index];
if (block && block.type === "thinking") {
block.thinking += event.delta.thinking;
stream.push({
type: "thinking_delta",
contentIndex: index,
delta: event.delta.thinking,
partial: output,
});
}
} else if (event.delta.type === "input_json_delta") {
if (currentBlock && currentBlock.type === "toolCall") {
currentBlock.partialJson += event.delta.partial_json;
const index = blocks.findIndex((b) => b.index === event.index);
const block = blocks[index];
if (block && block.type === "toolCall") {
block.partialJson += event.delta.partial_json;
stream.push({
type: "toolcall_delta",
contentIndex: index,
delta: event.delta.partial_json,
partial: output,
});
}
} else if (event.delta.type === "signature_delta") {
if (currentBlock && currentBlock.type === "thinking") {
currentBlock.thinkingSignature = currentBlock.thinkingSignature || "";
currentBlock.thinkingSignature += event.delta.signature;
const index = blocks.findIndex((b) => b.index === event.index);
const block = blocks[index];
if (block && block.type === "thinking") {
block.thinkingSignature = block.thinkingSignature || "";
block.thinkingSignature += event.delta.signature;
}
}
} else if (event.type === "content_block_stop") {
if (currentBlock) {
if (currentBlock.type === "text") {
const index = blocks.findIndex((b) => b.index === event.index);
const block = blocks[index];
if (block) {
delete (block as any).index;
if (block.type === "text") {
stream.push({
type: "text_end",
content: currentBlock.text,
contentIndex: index,
content: block.text,
partial: output,
});
} else if (currentBlock.type === "thinking") {
} else if (block.type === "thinking") {
stream.push({
type: "thinking_end",
content: currentBlock.thinking,
contentIndex: index,
content: block.thinking,
partial: output,
});
} else if (currentBlock.type === "toolCall") {
const finalToolCall: ToolCall = {
type: "toolCall",
id: currentBlock.id,
name: currentBlock.name,
arguments: JSON.parse(currentBlock.partialJson),
};
output.content.push(finalToolCall);
} else if (block.type === "toolCall") {
block.arguments = JSON.parse(block.partialJson);
delete (block as any).partialJson;
stream.push({
type: "toolCall",
toolCall: finalToolCall,
type: "toolcall_end",
contentIndex: index,
toolCall: block,
partial: output,
});
}
currentBlock = null;
}
} else if (event.type === "message_delta") {
if (event.delta.stop_reason) {
@ -166,6 +187,7 @@ export const streamAnthropic: GenerateFunction<"anthropic-messages"> = (
stream.push({ type: "done", reason: output.stopReason, message: output });
stream.end();
} catch (error) {
for (const block of output.content) delete (block as any).index;
output.stopReason = "error";
output.error = error instanceof Error ? error.message : JSON.stringify(error);
stream.push({ type: "error", error: output.error, partial: output });
@ -294,7 +316,9 @@ function convertMessages(messages: Message[], model: Model<"anthropic-messages">
// Transform messages for cross-provider compatibility
const transformedMessages = transformMessages(messages, model);
for (const msg of transformedMessages) {
for (let i = 0; i < transformedMessages.length; i++) {
const msg = transformedMessages[i];
if (msg.role === "user") {
if (typeof msg.content === "string") {
if (msg.content.trim().length > 0) {
@ -366,16 +390,37 @@ function convertMessages(messages: Message[], model: Model<"anthropic-messages">
content: blocks,
});
} else if (msg.role === "toolResult") {
// Collect all consecutive toolResult messages
const toolResults: ContentBlockParam[] = [];
// Add the current tool result
toolResults.push({
type: "tool_result",
tool_use_id: sanitizeToolCallId(msg.toolCallId),
content: msg.output,
is_error: msg.isError,
});
// Look ahead for consecutive toolResult messages
let j = i + 1;
while (j < transformedMessages.length && transformedMessages[j].role === "toolResult") {
const nextMsg = transformedMessages[j] as ToolResultMessage; // We know it's a toolResult
toolResults.push({
type: "tool_result",
tool_use_id: sanitizeToolCallId(nextMsg.toolCallId),
content: nextMsg.output,
is_error: nextMsg.isError,
});
j++;
}
// Skip the messages we've already processed
i = j - 1;
// Add a single user message with all tool results
params.push({
role: "user",
content: [
{
type: "tool_result",
tool_use_id: sanitizeToolCallId(msg.toolCallId),
content: msg.content,
is_error: msg.isError,
},
],
content: toolResults,
});
}
}