Fix streaming for z-ai in anthropic provider, add preliminary support for tool call streaming. Only reporting argument string deltas, not partial JSON objects

This commit is contained in:
Mario Zechner 2025-09-09 04:26:56 +02:00
parent 2bdb87dfe7
commit 98a876f3a0
21 changed files with 784 additions and 448 deletions

View file

@ -5,9 +5,8 @@ import { type OpenAIResponsesOptions, streamOpenAIResponses } from "./providers/
import type {
Api,
AssistantMessage,
AssistantMessageEvent,
AssistantMessageEventStream,
Context,
GenerateStream,
KnownProvider,
Model,
OptionsForApi,
@ -15,73 +14,6 @@ import type {
SimpleGenerateOptions,
} from "./types.js";
export class QueuedGenerateStream implements GenerateStream {
private queue: AssistantMessageEvent[] = [];
private waiting: ((value: IteratorResult<AssistantMessageEvent>) => void)[] = [];
private done = false;
private finalMessagePromise: Promise<AssistantMessage>;
private resolveFinalMessage!: (message: AssistantMessage) => void;
constructor() {
this.finalMessagePromise = new Promise((resolve) => {
this.resolveFinalMessage = resolve;
});
}
push(event: AssistantMessageEvent): void {
if (this.done) return;
if (event.type === "done") {
this.done = true;
this.resolveFinalMessage(event.message);
}
if (event.type === "error") {
this.done = true;
this.resolveFinalMessage(event.partial);
}
// Deliver to waiting consumer or queue it
const waiter = this.waiting.shift();
if (waiter) {
waiter({ value: event, done: false });
} else {
this.queue.push(event);
}
}
end(): void {
this.done = true;
// Notify all waiting consumers that we're done
while (this.waiting.length > 0) {
const waiter = this.waiting.shift()!;
waiter({ value: undefined as any, done: true });
}
}
async *[Symbol.asyncIterator](): AsyncIterator<AssistantMessageEvent> {
while (true) {
// If we have queued events, yield them
if (this.queue.length > 0) {
yield this.queue.shift()!;
} else if (this.done) {
// No more events and we're done
return;
} else {
// Wait for next event
const result = await new Promise<IteratorResult<AssistantMessageEvent>>((resolve) =>
this.waiting.push(resolve),
);
if (result.done) return;
yield result.value;
}
}
}
finalMessage(): Promise<AssistantMessage> {
return this.finalMessagePromise;
}
}
const apiKeys: Map<string, string> = new Map();
export function setApiKey(provider: KnownProvider, key: string): void;
@ -117,7 +49,7 @@ export function stream<TApi extends Api>(
model: Model<TApi>,
context: Context,
options?: OptionsForApi<TApi>,
): GenerateStream {
): AssistantMessageEventStream {
const apiKey = options?.apiKey || getApiKey(model.provider);
if (!apiKey) {
throw new Error(`No API key for provider: ${model.provider}`);
@ -152,14 +84,14 @@ export async function complete<TApi extends Api>(
options?: OptionsForApi<TApi>,
): Promise<AssistantMessage> {
const s = stream(model, context, options);
return s.finalMessage();
return s.result();
}
export function streamSimple<TApi extends Api>(
model: Model<TApi>,
context: Context,
options?: SimpleGenerateOptions,
): GenerateStream {
): AssistantMessageEventStream {
const apiKey = options?.apiKey || getApiKey(model.provider);
if (!apiKey) {
throw new Error(`No API key for provider: ${model.provider}`);
@ -175,7 +107,7 @@ export async function completeSimple<TApi extends Api>(
options?: SimpleGenerateOptions,
): Promise<AssistantMessage> {
const s = streamSimple(model, context, options);
return s.finalMessage();
return s.result();
}
function mapOptionsForApi<TApi extends Api>(