feat(ai): Add new streaming generate API with AsyncIterable interface

- Implement QueuedGenerateStream class that extends AsyncIterable with finalMessage() method
- Add new types: GenerateStream, GenerateOptions, GenerateOptionsUnified, GenerateFunction
- Create generateAnthropic function-based implementation replacing class-based approach
- Add comprehensive test suite for the new generate API
- Support streaming events with text, thinking, and tool call deltas
- Map ReasoningEffort to provider-specific options
- Include apiKey in options instead of constructor parameter
This commit is contained in:
Mario Zechner 2025-09-02 18:07:46 +02:00
parent be07c08a75
commit 004de3c9d0
6 changed files with 1106 additions and 129 deletions

268
packages/ai/src/generate.ts Normal file
View file

@ -0,0 +1,268 @@
import type {
Api,
AssistantMessage,
AssistantMessageEvent,
Context,
GenerateFunction,
GenerateOptionsUnified,
GenerateStream,
KnownProvider,
Model,
ReasoningEffort,
} from "./types.js";
/**
 * A GenerateStream implementation backed by an in-memory event queue.
 * A producer calls push()/end(); consumers iterate events with `for await`
 * and/or await finalMessage() for the completed AssistantMessage.
 */
export class QueuedGenerateStream implements GenerateStream {
    private queue: AssistantMessageEvent[] = [];
    private waiting: ((value: IteratorResult<AssistantMessageEvent>) => void)[] = [];
    private done = false;
    private error?: Error;
    // Tracks whether finalMessagePromise has been settled, so end() can detect
    // a stream that terminated without ever emitting a "done"/"error" event.
    private finalSettled = false;
    private finalMessagePromise: Promise<AssistantMessage>;
    private resolveFinalMessage!: (message: AssistantMessage) => void;
    private rejectFinalMessage!: (error: Error) => void;

    constructor() {
        this.finalMessagePromise = new Promise((resolve, reject) => {
            this.resolveFinalMessage = resolve;
            this.rejectFinalMessage = reject;
        });
        // Mark the promise as handled: a rejected stream whose caller never
        // awaits finalMessage() must not crash Node with an unhandled
        // rejection. Callers that do await finalMessage() still see the error.
        this.finalMessagePromise.catch(() => {});
    }

    /** Enqueue an event for consumers. Ignored once the stream is done. */
    push(event: AssistantMessageEvent): void {
        if (this.done) return;
        // A "done" event terminates the stream and carries the final message.
        if (event.type === "done") {
            this.done = true;
            if (!this.finalSettled) {
                this.finalSettled = true;
                this.resolveFinalMessage(event.message);
            }
        }
        // An "error" event rejects finalMessage(); the event itself is still
        // delivered to iterating consumers below.
        if (event.type === "error") {
            this.error = new Error(event.error);
            if (!this.finalSettled) {
                this.finalSettled = true;
                this.rejectFinalMessage(this.error);
            }
        }
        // Deliver to a waiting consumer if one is parked, otherwise buffer.
        const waiter = this.waiting.shift();
        if (waiter) {
            waiter({ value: event, done: false });
        } else {
            this.queue.push(event);
        }
    }

    /** Signal that no more events will arrive and release waiting consumers. */
    end(): void {
        this.done = true;
        // A stream ending without a "done"/"error" event would otherwise leave
        // finalMessage() pending forever; fail it explicitly instead.
        if (!this.finalSettled) {
            this.finalSettled = true;
            this.rejectFinalMessage(this.error ?? new Error("Stream ended without a final message"));
        }
        // Notify all waiting consumers that we're done
        while (this.waiting.length > 0) {
            const waiter = this.waiting.shift()!;
            waiter({ value: undefined as any, done: true });
        }
    }

    async *[Symbol.asyncIterator](): AsyncIterator<AssistantMessageEvent> {
        while (true) {
            // Drain buffered events first to preserve ordering.
            if (this.queue.length > 0) {
                yield this.queue.shift()!;
            } else if (this.done) {
                // No more events and the stream has ended.
                return;
            } else {
                // Park until the producer pushes the next event or ends.
                const result = await new Promise<IteratorResult<AssistantMessageEvent>>((resolve) =>
                    this.waiting.push(resolve),
                );
                if (result.done) return;
                yield result.value;
            }
        }
    }

    /** Resolves with the final message, or rejects if the stream errored. */
    finalMessage(): Promise<AssistantMessage> {
        return this.finalMessagePromise;
    }
}
// Registry mapping an API identifier to its generate implementation.
const apiImplementations = new Map<Api | string, GenerateFunction>();

/**
 * Register a custom API implementation under the given identifier.
 * Later registrations for the same identifier replace earlier ones.
 */
export function registerApi(api: string, impl: GenerateFunction): void {
    apiImplementations.set(api, impl);
}
// Explicitly-set API keys, keyed by provider id. Takes precedence over env vars.
const apiKeys: Map<string, string> = new Map();

/**
 * Set an API key for a provider, overriding any environment variable fallback.
 */
export function setApiKey(provider: KnownProvider, key: string): void;
export function setApiKey(provider: string, key: string): void;
export function setApiKey(provider: string, key: string): void {
    apiKeys.set(provider, key);
}

// Environment-variable fallback per known provider. Hoisted to module scope so
// the lookup table is not rebuilt on every getApiKey() call.
const PROVIDER_ENV_VARS: Record<string, string> = {
    openai: "OPENAI_API_KEY",
    anthropic: "ANTHROPIC_API_KEY",
    google: "GEMINI_API_KEY",
    groq: "GROQ_API_KEY",
    cerebras: "CEREBRAS_API_KEY",
    xai: "XAI_API_KEY",
    openrouter: "OPENROUTER_API_KEY",
};

/**
 * Get the API key for a provider: explicit keys set via setApiKey() first,
 * then the provider's environment variable, otherwise undefined.
 */
export function getApiKey(provider: KnownProvider): string | undefined;
export function getApiKey(provider: string): string | undefined;
export function getApiKey(provider: string): string | undefined {
    // Check explicit keys first
    const key = apiKeys.get(provider);
    if (key) return key;
    // Fall back to environment variables for known providers
    const envVar = PROVIDER_ENV_VARS[provider];
    return envVar ? process.env[envVar] : undefined;
}
/**
 * Main generate function: dispatches to the registered implementation for the
 * model's API, resolves an API key, maps generic options to provider-specific
 * ones, and returns the implementation's event stream.
 */
export function generate(model: Model, context: Context, options?: GenerateOptionsUnified): GenerateStream {
    // Look up the implementation registered for this API.
    const impl = apiImplementations.get(model.api);
    if (impl === undefined) {
        throw new Error(`Unsupported API: ${model.api}`);
    }
    // Prefer an explicit per-call key, falling back to setApiKey()/env.
    const apiKey = options?.apiKey || getApiKey(model.provider);
    if (!apiKey) {
        throw new Error(`No API key for provider: ${model.provider}`);
    }
    // Translate unified options into the provider-specific shape and run.
    return impl(model, context, mapOptionsForApi(model.api, model, options, apiKey));
}
/**
 * Helper to generate and await the complete response (no streaming):
 * starts a stream and resolves with its final message.
 */
export async function generateComplete(
    model: Model,
    context: Context,
    options?: GenerateOptionsUnified,
): Promise<AssistantMessage> {
    return generate(model, context, options).finalMessage();
}
/**
* Map generic options to provider-specific options
*/
function mapOptionsForApi(api: Api | string, model: Model, options?: GenerateOptionsUnified, apiKey?: string): any {
const base = {
temperature: options?.temperature,
maxTokens: options?.maxTokens,
signal: options?.signal,
apiKey: apiKey || options?.apiKey,
};
switch (api) {
case "openai-responses":
case "openai-completions":
return {
...base,
reasoning_effort: options?.reasoning,
};
case "anthropic-messages": {
if (!options?.reasoning) return base;
// Map effort to token budget
const anthropicBudgets = {
minimal: 1024,
low: 2048,
medium: 8192,
high: Math.min(25000, model.maxTokens - 1000),
};
return {
...base,
thinking: {
enabled: true,
budgetTokens: anthropicBudgets[options.reasoning],
},
};
}
case "google-generative-ai": {
if (!options?.reasoning) return { ...base, thinking_budget: -1 };
// Model-specific mapping for Google
const googleBudget = getGoogleBudget(model, options.reasoning);
return {
...base,
thinking_budget: googleBudget,
};
}
default:
return base;
}
}
/**
 * Get Google thinking budget based on model and effort.
 * Substring tables are checked in order; "flash-lite" must precede "flash"
 * because the former contains the latter.
 */
function getGoogleBudget(model: Model, effort: ReasoningEffort): number {
    const tables: [string, Record<ReasoningEffort, number>][] = [
        ["flash-lite", { minimal: 512, low: 2048, medium: 8192, high: 24576 }],
        ["pro", { minimal: 128, low: 2048, medium: 8192, high: Math.min(25000, 32768) }],
        // minimal: 0 disables thinking on flash models.
        ["flash", { minimal: 0, low: 2048, medium: 8192, high: 24576 }],
    ];
    for (const [needle, budgets] of tables) {
        if (model.id.includes(needle)) {
            return budgets[effort];
        }
    }
    // Unknown model - use dynamic budget.
    return -1;
}
// Register built-in API implementations
// Import the new function-based implementations
import { generateAnthropic } from "./providers/anthropic-generate.js";
// Register Anthropic implementation. Runs at module load, so importing this
// module is enough for generate() to dispatch "anthropic-messages" models.
apiImplementations.set("anthropic-messages", generateAnthropic);

View file

@ -3,26 +3,26 @@
export const version = "0.5.8";
// Export generate API
export {
generate,
generateComplete,
getApiKey,
QueuedGenerateStream,
registerApi,
setApiKey,
} from "./generate.js";
// Export generated models data
export { PROVIDERS } from "./models.generated.js";
// Export models utilities and types
// Export model utilities
export {
type AnthropicModel,
type CerebrasModel,
createLLM,
type GoogleModel,
type GroqModel,
type Model,
type OpenAIModel,
type OpenRouterModel,
PROVIDER_CONFIG,
type ProviderModels,
type ProviderToLLM,
type XAIModel,
calculateCost,
getModel,
type KnownProvider,
registerModel,
} from "./models.js";
// Export providers
// Legacy providers (to be deprecated)
export { AnthropicLLM } from "./providers/anthropic.js";
export { GoogleLLM } from "./providers/google.js";
export { OpenAICompletionsLLM } from "./providers/openai-completions.js";
@ -30,3 +30,8 @@ export { OpenAIResponsesLLM } from "./providers/openai-responses.js";
// Export types
export type * from "./types.js";
// TODO: Remove these legacy exports once consumers are updated
/** @deprecated Use generate() with getModel() instead; this always throws. */
export function createLLM(): never {
    const message = "createLLM is deprecated. Use generate() with getModel() instead.";
    throw new Error(message);
}

View file

@ -1,108 +1,44 @@
import { PROVIDERS } from "./models.generated.js";
import { AnthropicLLM } from "./providers/anthropic.js";
import { GoogleLLM } from "./providers/google.js";
import { OpenAICompletionsLLM } from "./providers/openai-completions.js";
import { OpenAIResponsesLLM } from "./providers/openai-responses.js";
import type { Model, Usage } from "./types.js";
import type { KnownProvider, Model, Usage } from "./types.js";
// Provider configuration with factory functions
export const PROVIDER_CONFIG = {
google: {
envKey: "GEMINI_API_KEY",
create: (model: Model, apiKey: string) => new GoogleLLM(model, apiKey),
},
openai: {
envKey: "OPENAI_API_KEY",
create: (model: Model, apiKey: string) => new OpenAIResponsesLLM(model, apiKey),
},
anthropic: {
envKey: "ANTHROPIC_API_KEY",
create: (model: Model, apiKey: string) => new AnthropicLLM(model, apiKey),
},
xai: {
envKey: "XAI_API_KEY",
create: (model: Model, apiKey: string) => new OpenAICompletionsLLM(model, apiKey),
},
groq: {
envKey: "GROQ_API_KEY",
create: (model: Model, apiKey: string) => new OpenAICompletionsLLM(model, apiKey),
},
cerebras: {
envKey: "CEREBRAS_API_KEY",
create: (model: Model, apiKey: string) => new OpenAICompletionsLLM(model, apiKey),
},
openrouter: {
envKey: "OPENROUTER_API_KEY",
create: (model: Model, apiKey: string) => new OpenAICompletionsLLM(model, apiKey),
},
} as const;
// Re-export Model type
export type { KnownProvider, Model } from "./types.js";
// Type mapping from provider to LLM implementation
export type ProviderToLLM = {
google: GoogleLLM;
openai: OpenAIResponsesLLM;
anthropic: AnthropicLLM;
xai: OpenAICompletionsLLM;
groq: OpenAICompletionsLLM;
cerebras: OpenAICompletionsLLM;
openrouter: OpenAICompletionsLLM;
};
// Dynamic model registry initialized from PROVIDERS
const modelRegistry: Map<string, Map<string, Model>> = new Map();
// Extract model types for each provider
export type GoogleModel = keyof typeof PROVIDERS.google.models;
export type OpenAIModel = keyof typeof PROVIDERS.openai.models;
export type AnthropicModel = keyof typeof PROVIDERS.anthropic.models;
export type XAIModel = keyof typeof PROVIDERS.xai.models;
export type GroqModel = keyof typeof PROVIDERS.groq.models;
export type CerebrasModel = keyof typeof PROVIDERS.cerebras.models;
export type OpenRouterModel = keyof typeof PROVIDERS.openrouter.models;
// Map providers to their model types
export type ProviderModels = {
google: GoogleModel;
openai: OpenAIModel;
anthropic: AnthropicModel;
xai: XAIModel;
groq: GroqModel;
cerebras: CerebrasModel;
openrouter: OpenRouterModel;
};
// Single generic factory function
export function createLLM<P extends keyof typeof PROVIDERS, M extends keyof (typeof PROVIDERS)[P]["models"]>(
provider: P,
model: M,
apiKey?: string,
): ProviderToLLM[P] {
const config = PROVIDER_CONFIG[provider as keyof typeof PROVIDER_CONFIG];
if (!config) throw new Error(`Unknown provider: ${provider}`);
const providerData = PROVIDERS[provider];
if (!providerData) throw new Error(`Unknown provider: ${provider}`);
// Type-safe model lookup
const models = providerData.models as Record<string, Model>;
const modelData = models[model as string];
if (!modelData) throw new Error(`Unknown model: ${String(model)} for provider ${provider}`);
const key = apiKey || process.env[config.envKey];
if (!key) throw new Error(`No API key provided for ${provider}. Set ${config.envKey} or pass apiKey.`);
return config.create(modelData, key) as ProviderToLLM[P];
// Initialize registry from PROVIDERS on module load
for (const [provider, models] of Object.entries(PROVIDERS)) {
const providerModels = new Map<string, Model>();
for (const [id, model] of Object.entries(models)) {
providerModels.set(id, model as Model);
}
modelRegistry.set(provider, providerModels);
}
// Helper function to get model info with type-safe model IDs
export function getModel<P extends keyof typeof PROVIDERS>(
provider: P,
modelId: keyof (typeof PROVIDERS)[P]["models"],
): Model | undefined {
const providerData = PROVIDERS[provider];
if (!providerData) return undefined;
const models = providerData.models as Record<string, Model>;
return models[modelId as string];
/**
 * Get a model from the registry - typed overload for known providers.
 * For unknown/custom providers, returns undefined when not registered.
 */
export function getModel<P extends KnownProvider>(provider: P, modelId: keyof (typeof PROVIDERS)[P]): Model;
export function getModel(provider: string, modelId: string): Model | undefined;
export function getModel(provider: string, modelId: string): Model | undefined {
    // Plain registry lookup; optional chaining covers unknown providers.
    return modelRegistry.get(provider)?.get(modelId);
}
export function calculateCost(model: Model, usage: Usage) {
/**
 * Register a custom model. Creates the provider's registry bucket on first
 * use; re-registering an existing model id replaces it.
 */
export function registerModel(model: Model): void {
    let providerModels = modelRegistry.get(model.provider);
    if (providerModels === undefined) {
        providerModels = new Map<string, Model>();
        modelRegistry.set(model.provider, providerModels);
    }
    providerModels.set(model.id, model);
}
/**
* Calculate cost for token usage
*/
export function calculateCost(model: Model, usage: Usage): Usage["cost"] {
usage.cost.input = (model.cost.input / 1000000) * usage.input;
usage.cost.output = (model.cost.output / 1000000) * usage.output;
usage.cost.cacheRead = (model.cost.cacheRead / 1000000) * usage.cacheRead;
@ -110,6 +46,3 @@ export function calculateCost(model: Model, usage: Usage) {
usage.cost.total = usage.cost.input + usage.cost.output + usage.cost.cacheRead + usage.cost.cacheWrite;
return usage.cost;
}
// Re-export Model type for convenience
export type { Model };

View file

@ -0,0 +1,425 @@
import Anthropic from "@anthropic-ai/sdk";
import type {
ContentBlockParam,
MessageCreateParamsStreaming,
MessageParam,
Tool,
} from "@anthropic-ai/sdk/resources/messages.js";
import { QueuedGenerateStream } from "../generate.js";
import { calculateCost } from "../models.js";
import type {
Api,
AssistantMessage,
Context,
GenerateFunction,
GenerateOptions,
GenerateStream,
Message,
Model,
StopReason,
TextContent,
ThinkingContent,
ToolCall,
} from "../types.js";
import { transformMessages } from "./utils.js";
// Anthropic-specific options.
// Produced by the generic option mapping for "anthropic-messages"; `thinking`
// carries the token budget derived from ReasoningEffort.
export interface AnthropicOptions extends GenerateOptions {
    thinking?: {
        enabled: boolean;
        budgetTokens?: number; // defaults to 1024 when omitted (see buildAnthropicParams)
    };
    // Mirrors Anthropic's tool_choice parameter shapes.
    toolChoice?: "auto" | "any" | "none" | { type: "tool"; name: string };
}
/**
 * Generate function for Anthropic API.
 *
 * Returns immediately with a QueuedGenerateStream and feeds it from a
 * detached async task that consumes the Anthropic SDK stream, translating SDK
 * events into AssistantMessageEvents while accumulating the final message.
 * Errors are reported through the stream (error event + rejected
 * finalMessage()), never thrown from this function.
 */
export const generateAnthropic: GenerateFunction<AnthropicOptions> = (
    model: Model,
    context: Context,
    options: AnthropicOptions,
): GenerateStream => {
    const stream = new QueuedGenerateStream();
    // Start async processing
    (async () => {
        // The message being accumulated; also exposed as `partial` on events.
        const output: AssistantMessage = {
            role: "assistant",
            content: [],
            api: "anthropic-messages" as Api,
            provider: model.provider,
            model: model.id,
            usage: {
                input: 0,
                output: 0,
                cacheRead: 0,
                cacheWrite: 0,
                cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
            },
            stopReason: "stop",
        };
        try {
            // Create Anthropic client (API key presence was validated by generate()).
            const client = createAnthropicClient(model, options.apiKey!);
            // Convert messages to Anthropic's wire format.
            const messages = convertMessages(context.messages, model, "anthropic-messages");
            // Build request params (system prompt, tools, thinking, tool choice).
            const params = buildAnthropicParams(model, context, options, messages, client.isOAuthToken);
            // Create the SDK stream; the abort signal propagates cancellation.
            const anthropicStream = client.client.messages.stream(
                {
                    ...params,
                    stream: true,
                },
                {
                    signal: options.signal,
                },
            );
            // Emit start event
            stream.push({
                type: "start",
                partial: output,
            });
            // The content block currently being streamed, if any. Tool calls
            // additionally accumulate their raw JSON arguments.
            let currentBlock: ThinkingContent | TextContent | (ToolCall & { partialJson: string }) | null = null;
            for await (const event of anthropicStream) {
                if (event.type === "content_block_start") {
                    if (event.content_block.type === "text") {
                        currentBlock = {
                            type: "text",
                            text: "",
                        };
                        output.content.push(currentBlock);
                        stream.push({ type: "text_start", partial: output });
                    } else if (event.content_block.type === "thinking") {
                        currentBlock = {
                            type: "thinking",
                            thinking: "",
                            thinkingSignature: "",
                        };
                        output.content.push(currentBlock);
                        stream.push({ type: "thinking_start", partial: output });
                    } else if (event.content_block.type === "tool_use") {
                        // Tool calls are only emitted once fully streamed
                        // (on content_block_stop, after all JSON deltas).
                        currentBlock = {
                            type: "toolCall",
                            id: event.content_block.id,
                            name: event.content_block.name,
                            arguments: event.content_block.input as Record<string, any>,
                            partialJson: "",
                        };
                    }
                } else if (event.type === "content_block_delta") {
                    if (event.delta.type === "text_delta") {
                        if (currentBlock && currentBlock.type === "text") {
                            currentBlock.text += event.delta.text;
                            stream.push({
                                type: "text_delta",
                                delta: event.delta.text,
                                partial: output,
                            });
                        }
                    } else if (event.delta.type === "thinking_delta") {
                        if (currentBlock && currentBlock.type === "thinking") {
                            currentBlock.thinking += event.delta.thinking;
                            stream.push({
                                type: "thinking_delta",
                                delta: event.delta.thinking,
                                partial: output,
                            });
                        }
                    } else if (event.delta.type === "input_json_delta") {
                        if (currentBlock && currentBlock.type === "toolCall") {
                            // Accumulate raw JSON; parsed once the block completes.
                            currentBlock.partialJson += event.delta.partial_json;
                        }
                    } else if (event.delta.type === "signature_delta") {
                        if (currentBlock && currentBlock.type === "thinking") {
                            currentBlock.thinkingSignature = currentBlock.thinkingSignature || "";
                            currentBlock.thinkingSignature += event.delta.signature;
                        }
                    }
                } else if (event.type === "content_block_stop") {
                    if (currentBlock) {
                        if (currentBlock.type === "text") {
                            stream.push({ type: "text_end", content: currentBlock.text, partial: output });
                        } else if (currentBlock.type === "thinking") {
                            stream.push({ type: "thinking_end", content: currentBlock.thinking, partial: output });
                        } else if (currentBlock.type === "toolCall") {
                            // A tool call with no input deltas leaves partialJson
                            // empty; JSON.parse("") would throw, so fall back to
                            // the (possibly empty) object from content_block_start.
                            const args: Record<string, any> = currentBlock.partialJson
                                ? JSON.parse(currentBlock.partialJson)
                                : currentBlock.arguments || {};
                            const finalToolCall: ToolCall = {
                                type: "toolCall",
                                id: currentBlock.id,
                                name: currentBlock.name,
                                arguments: args,
                            };
                            output.content.push(finalToolCall);
                            stream.push({ type: "toolCall", toolCall: finalToolCall, partial: output });
                        }
                        currentBlock = null;
                    }
                } else if (event.type === "message_delta") {
                    if (event.delta.stop_reason) {
                        output.stopReason = mapStopReason(event.delta.stop_reason);
                    }
                    // Token counters are cumulative across delta events.
                    output.usage.input += event.usage.input_tokens || 0;
                    output.usage.output += event.usage.output_tokens || 0;
                    output.usage.cacheRead += event.usage.cache_read_input_tokens || 0;
                    output.usage.cacheWrite += event.usage.cache_creation_input_tokens || 0;
                    calculateCost(model, output.usage);
                }
            }
            // Emit done event with final message
            stream.push({ type: "done", reason: output.stopReason, message: output });
            stream.end();
        } catch (error) {
            // Surface any failure (network, abort, JSON parse) as an error event.
            output.stopReason = "error";
            output.error = error instanceof Error ? error.message : JSON.stringify(error);
            stream.push({ type: "error", error: output.error, partial: output });
            stream.end();
        }
    })();
    return stream;
};
// Pairs the SDK client with whether the credential was an OAuth token, which
// changes the required system prompt handling (see buildAnthropicParams).
interface AnthropicClientWrapper {
    client: Anthropic;
    isOAuthToken: boolean;
}
/**
 * Create an Anthropic SDK client for the given credential.
 * OAuth access tokens (containing "sk-ant-oat") are sent as a bearer token
 * with the oauth beta header; regular API keys use the standard key header.
 */
function createAnthropicClient(model: Model, apiKey: string): AnthropicClientWrapper {
    if (apiKey.includes("sk-ant-oat")) {
        const defaultHeaders = {
            accept: "application/json",
            "anthropic-dangerous-direct-browser-access": "true",
            "anthropic-beta": "oauth-2025-04-20,fine-grained-tool-streaming-2025-05-14",
        };
        // Remove the env var in Node.js so the SDK cannot pick it up instead
        // of the OAuth token. Note: assigning `undefined` to a process.env
        // property stores the string "undefined", so it must be deleted.
        if (typeof process !== "undefined" && process.env) {
            delete process.env.ANTHROPIC_API_KEY;
        }
        const client = new Anthropic({
            apiKey: null, // disable API-key auth; authToken is used instead
            authToken: apiKey,
            baseURL: model.baseUrl,
            defaultHeaders,
            dangerouslyAllowBrowser: true,
        });
        return { client, isOAuthToken: true };
    } else {
        const defaultHeaders = {
            accept: "application/json",
            "anthropic-dangerous-direct-browser-access": "true",
            "anthropic-beta": "fine-grained-tool-streaming-2025-05-14",
        };
        const client = new Anthropic({
            apiKey,
            baseURL: model.baseUrl,
            dangerouslyAllowBrowser: true,
            defaultHeaders,
        });
        return { client, isOAuthToken: false };
    }
}
/**
 * Assemble the streaming request parameters for the Anthropic messages API.
 * OAuth tokens require the Claude Code identity as the first system block;
 * otherwise the system prompt (if any) is passed through as a plain string.
 */
function buildAnthropicParams(
    model: Model,
    context: Context,
    options: AnthropicOptions,
    messages: MessageParam[],
    isOAuthToken: boolean,
): MessageCreateParamsStreaming {
    const params: MessageCreateParamsStreaming = {
        model: model.id,
        messages,
        max_tokens: options.maxTokens || model.maxTokens,
        stream: true,
    };

    if (isOAuthToken) {
        // OAuth tokens MUST lead with the Claude Code identity block.
        const identity = {
            type: "text" as const,
            text: "You are Claude Code, Anthropic's official CLI for Claude.",
            cache_control: { type: "ephemeral" as const },
        };
        params.system = context.systemPrompt
            ? [
                  identity,
                  {
                      type: "text" as const,
                      text: context.systemPrompt,
                      cache_control: { type: "ephemeral" as const },
                  },
              ]
            : [identity];
    } else if (context.systemPrompt) {
        params.system = context.systemPrompt;
    }

    if (options.temperature !== undefined) {
        params.temperature = options.temperature;
    }
    if (context.tools) {
        params.tools = convertTools(context.tools);
    }
    // Thinking is only requested when both caller and model support it.
    if (options.thinking?.enabled && model.reasoning) {
        params.thinking = {
            type: "enabled",
            budget_tokens: options.thinking.budgetTokens || 1024,
        };
    }
    if (options.toolChoice) {
        params.tool_choice =
            typeof options.toolChoice === "string" ? { type: options.toolChoice } : options.toolChoice;
    }
    return params;
}
// Convert messages to Anthropic's wire format.
// User content becomes text/image blocks (images stripped for models without
// image input), assistant content becomes text/thinking/tool_use blocks, and
// tool results are sent back as user-role tool_result blocks.
function convertMessages(messages: Message[], model: Model, api: Api): MessageParam[] {
    const params: MessageParam[] = [];
    // Transform messages for cross-provider compatibility
    // (delegates to transformMessages in ./utils.js; exact semantics live there)
    const transformedMessages = transformMessages(messages, model, api);
    for (const msg of transformedMessages) {
        if (msg.role === "user") {
            // Handle both string and array content
            if (typeof msg.content === "string") {
                params.push({
                    role: "user",
                    content: msg.content,
                });
            } else {
                // Convert array content (text and image items) to Anthropic blocks
                const blocks: ContentBlockParam[] = msg.content.map((item) => {
                    if (item.type === "text") {
                        return {
                            type: "text",
                            text: item.text,
                        };
                    } else {
                        // Image content, sent inline as base64
                        return {
                            type: "image",
                            source: {
                                type: "base64",
                                media_type: item.mimeType as "image/jpeg" | "image/png" | "image/gif" | "image/webp",
                                data: item.data,
                            },
                        };
                    }
                });
                // Models without image input get image blocks silently dropped.
                const filteredBlocks = !model?.input.includes("image") ? blocks.filter((b) => b.type !== "image") : blocks;
                params.push({
                    role: "user",
                    content: filteredBlocks,
                });
            }
        } else if (msg.role === "assistant") {
            const blocks: ContentBlockParam[] = [];
            for (const block of msg.content) {
                if (block.type === "text") {
                    blocks.push({
                        type: "text",
                        text: block.text,
                    });
                } else if (block.type === "thinking") {
                    blocks.push({
                        type: "thinking",
                        thinking: block.thinking,
                        // Replay the captured signature; empty string when absent
                        signature: block.thinkingSignature || "",
                    });
                } else if (block.type === "toolCall") {
                    blocks.push({
                        type: "tool_use",
                        id: block.id,
                        name: block.name,
                        input: block.arguments,
                    });
                }
            }
            params.push({
                role: "assistant",
                content: blocks,
            });
        } else if (msg.role === "toolResult") {
            // Tool results go back as a user-role message referencing the call id.
            params.push({
                role: "user",
                content: [
                    {
                        type: "tool_result",
                        tool_use_id: msg.toolCallId,
                        content: msg.content,
                        is_error: msg.isError,
                    },
                ],
            });
        }
    }
    return params;
}
/**
 * Convert our generic tool definitions to Anthropic's tool format.
 * Missing properties/required fields default to an empty schema.
 */
function convertTools(tools: Context["tools"]): Tool[] {
    if (!tools) return [];
    const converted: Tool[] = [];
    for (const tool of tools) {
        converted.push({
            name: tool.name,
            description: tool.description,
            input_schema: {
                type: "object" as const,
                properties: tool.parameters.properties || {},
                required: tool.parameters.required || [],
            },
        });
    }
    return converted;
}
/**
 * Map an Anthropic stop reason to our StopReason type.
 * "end_turn", "pause_turn" (a resubmit is fine), "stop_sequence" (we never
 * supply stop sequences), null, and anything unknown all map to "stop".
 */
function mapStopReason(reason: Anthropic.Messages.StopReason | null): StopReason {
    if (reason === "max_tokens") return "length";
    if (reason === "tool_use") return "toolUse";
    if (reason === "refusal") return "safety";
    return "stop";
}

View file

@ -1,3 +1,38 @@
export type KnownApi = "openai-completions" | "openai-responses" | "anthropic-messages" | "google-generative-ai";
// Open-ended: custom APIs registered via registerApi() use arbitrary strings.
export type Api = KnownApi | string;
export type KnownProvider = "anthropic" | "google" | "openai" | "xai" | "groq" | "cerebras" | "openrouter";
// Open-ended for custom providers registered at runtime.
export type Provider = KnownProvider | string;
// Generic reasoning effort level; mapped to provider-specific budgets/options.
export type ReasoningEffort = "minimal" | "low" | "medium" | "high";

// The stream interface - what generate() returns
export interface GenerateStream extends AsyncIterable<AssistantMessageEvent> {
    // Get the final message (waits for streaming to complete)
    finalMessage(): Promise<AssistantMessage>;
}

// Base options all providers share
export interface GenerateOptions {
    temperature?: number;
    maxTokens?: number; // maximum output tokens
    signal?: AbortSignal; // abort/cancel the in-flight request
    apiKey?: string; // per-call key; takes precedence over setApiKey()/env
}

// Unified options with reasoning (what public generate() accepts)
export interface GenerateOptionsUnified extends GenerateOptions {
    reasoning?: ReasoningEffort;
}

// Generic GenerateFunction with typed options.
// Implemented once per API backend and registered via registerApi().
export type GenerateFunction<TOptions extends GenerateOptions = GenerateOptions> = (
    model: Model,
    context: Context,
    options: TOptions,
) => GenerateStream;
// Legacy LLM interface (to be removed)
export interface LLMOptions {
temperature?: number;
maxTokens?: number;
@ -60,11 +95,10 @@ export interface UserMessage {
export interface AssistantMessage {
role: "assistant";
content: (TextContent | ThinkingContent | ToolCall)[];
api: string;
provider: string;
api: Api;
provider: Provider;
model: string;
usage: Usage;
stopReason: StopReason;
error?: string | Error;
}
@ -92,23 +126,24 @@ export interface Context {
}
export type AssistantMessageEvent =
| { type: "start"; model: string; provider: string }
| { type: "text_start" }
| { type: "text_delta"; content: string; delta: string }
| { type: "text_end"; content: string }
| { type: "thinking_start" }
| { type: "thinking_delta"; content: string; delta: string }
| { type: "thinking_end"; content: string }
| { type: "toolCall"; toolCall: ToolCall }
| { type: "start"; partial: AssistantMessage }
| { type: "text_start"; partial: AssistantMessage }
| { type: "text_delta"; delta: string; partial: AssistantMessage }
| { type: "text_end"; content: string; partial: AssistantMessage }
| { type: "thinking_start"; partial: AssistantMessage }
| { type: "thinking_delta"; delta: string; partial: AssistantMessage }
| { type: "thinking_end"; content: string; partial: AssistantMessage }
| { type: "toolCall"; toolCall: ToolCall; partial: AssistantMessage }
| { type: "done"; reason: StopReason; message: AssistantMessage }
| { type: "error"; error: string };
| { type: "error"; error: string; partial: AssistantMessage };
// Model interface for the unified model system
export interface Model {
id: string;
name: string;
provider: string;
baseUrl?: string;
api: Api;
provider: Provider;
baseUrl: string;
reasoning: boolean;
input: ("text" | "image")[];
cost: {

View file

@ -0,0 +1,311 @@
import { describe, it, beforeAll, expect } from "vitest";
import { getModel } from "../src/models.js";
import { generate, generateComplete } from "../src/generate.js";
import type { Context, Tool, GenerateOptionsUnified, Model, ImageContent, GenerateStream, GenerateOptions } from "../src/types.js";
import { readFileSync } from "fs";
import { join, dirname } from "path";
import { fileURLToPath } from "url";
const __filename = fileURLToPath(import.meta.url);
const __dirname = dirname(__filename);
// Calculator tool definition (same as examples).
// Declares a four-operation arithmetic tool; models under test are expected
// to call it with { a, b, operation }.
const calculatorTool: Tool = {
    name: "calculator",
    description: "Perform basic arithmetic operations",
    parameters: {
        type: "object",
        properties: {
            a: { type: "number", description: "First number" },
            b: { type: "number", description: "Second number" },
            operation: {
                type: "string",
                enum: ["add", "subtract", "multiply", "divide"],
                description: "The operation to perform"
            }
        },
        required: ["a", "b", "operation"]
    }
};
async function basicTextGeneration<P extends GenerateOptions>(model: Model, options?: P) {
const context: Context = {
systemPrompt: "You are a helpful assistant. Be concise.",
messages: [
{ role: "user", content: "Reply with exactly: 'Hello test successful'" }
]
};
const response = await generateComplete(model, context, options);
expect(response.role).toBe("assistant");
expect(response.content).toBeTruthy();
expect(response.usage.input + response.usage.cacheRead).toBeGreaterThan(0);
expect(response.usage.output).toBeGreaterThan(0);
expect(response.error).toBeFalsy();
expect(response.content.map(b => b.type == "text" ? b.text : "").join("")).toContain("Hello test successful");
context.messages.push(response);
context.messages.push({ role: "user", content: "Now say 'Goodbye test successful'" });
const secondResponse = await generateComplete(model, context, options);
expect(secondResponse.role).toBe("assistant");
expect(secondResponse.content).toBeTruthy();
expect(secondResponse.usage.input + secondResponse.usage.cacheRead).toBeGreaterThan(0);
expect(secondResponse.usage.output).toBeGreaterThan(0);
expect(secondResponse.error).toBeFalsy();
expect(secondResponse.content.map(b => b.type == "text" ? b.text : "").join("")).toContain("Goodbye test successful");
}
/**
 * E2E check: the model must invoke the calculator tool when instructed,
 * producing a "toolUse" stop reason and a well-formed tool call block.
 */
async function handleToolCall(model: Model, options?: GenerateOptionsUnified) {
    const context: Context = {
        systemPrompt: "You are a helpful assistant that uses tools when asked.",
        messages: [{
            role: "user",
            content: "Calculate 15 + 27 using the calculator tool."
        }],
        tools: [calculatorTool]
    };
    const response = await generateComplete(model, context, options);
    expect(response.stopReason).toBe("toolUse");
    expect(response.content.some(b => b.type == "toolCall")).toBeTruthy();
    const toolCall = response.content.find(b => b.type == "toolCall");
    // Narrow the content union before touching tool-call-specific fields.
    if (toolCall && toolCall.type === "toolCall") {
        expect(toolCall.name).toBe("calculator");
        expect(toolCall.id).toBeTruthy();
    }
}
/**
 * E2E check: streaming must emit text_start, at least one text_delta, and
 * text_end, and the final message must contain a text block.
 */
async function handleStreaming(model: Model, options?: GenerateOptionsUnified) {
    let textStarted = false;
    let textChunks = "";
    let textCompleted = false;
    const context: Context = {
        messages: [{ role: "user", content: "Count from 1 to 3" }]
    };
    const stream = generate(model, context, options);
    // Consume the full event stream, tracking the text lifecycle events.
    for await (const event of stream) {
        if (event.type === "text_start") {
            textStarted = true;
        } else if (event.type === "text_delta") {
            textChunks += event.delta;
        } else if (event.type === "text_end") {
            textCompleted = true;
        }
    }
    // finalMessage() resolves once streaming completes (already consumed above).
    const response = await stream.finalMessage();
    expect(textStarted).toBe(true);
    expect(textChunks.length).toBeGreaterThan(0);
    expect(textCompleted).toBe(true);
    expect(response.content.some(b => b.type == "text")).toBeTruthy();
}
/**
 * E2E check: with reasoning enabled, the stream must emit thinking_start,
 * thinking_delta, and thinking_end, and the final message must contain a
 * thinking block. Options are required here since reasoning must be on.
 */
async function handleThinking(model: Model, options: GenerateOptionsUnified) {
    let thinkingStarted = false;
    let thinkingChunks = "";
    let thinkingCompleted = false;
    const context: Context = {
        // Randomized operand keeps the prompt unique across runs.
        messages: [{ role: "user", content: `Think about ${(Math.random() * 255) | 0} + 27. Think step by step. Then output the result.` }]
    };
    const stream = generate(model, context, options);
    for await (const event of stream) {
        if (event.type === "thinking_start") {
            thinkingStarted = true;
        } else if (event.type === "thinking_delta") {
            thinkingChunks += event.delta;
        } else if (event.type === "thinking_end") {
            thinkingCompleted = true;
        }
    }
    const response = await stream.finalMessage();
    // Surface the provider error message in the assertion output if the run failed.
    expect(response.stopReason, `Error: ${response.error}`).toBe("stop");
    expect(thinkingStarted).toBe(true);
    expect(thinkingChunks.length).toBeGreaterThan(0);
    expect(thinkingCompleted).toBe(true);
    expect(response.content.some(b => b.type == "thinking")).toBeTruthy();
}
/**
 * E2E check: vision input. Sends a red-circle PNG and expects the reply to
 * mention both the color and the shape. No-ops for text-only models.
 */
async function handleImage(model: Model, options?: GenerateOptionsUnified) {
    // Check if the model supports images
    if (!model.input.includes("image")) {
        console.log(`Skipping image test - model ${model.id} doesn't support images`);
        return;
    }
    // Read the test image fixture and encode it for inline transport
    const imagePath = join(__dirname, "data", "red-circle.png");
    const imageBuffer = readFileSync(imagePath);
    const base64Image = imageBuffer.toString("base64");
    const imageContent: ImageContent = {
        type: "image",
        data: base64Image,
        mimeType: "image/png",
    };
    const context: Context = {
        messages: [
            {
                role: "user",
                content: [
                    { type: "text", text: "What do you see in this image? Please describe the shape (circle, rectangle, square, triangle, ...) and color (red, blue, green, ...)." },
                    imageContent,
                ],
            },
        ],
    };
    const response = await generateComplete(model, context, options);
    // Check the response mentions red and circle
    expect(response.content.length > 0).toBeTruthy();
    const textContent = response.content.find(b => b.type == "text");
    if (textContent && textContent.type === "text") {
        const lowerContent = textContent.text.toLowerCase();
        expect(lowerContent).toContain("red");
        expect(lowerContent).toContain("circle");
    }
}
/**
 * E2E check: multi-turn agent loop with thinking and tool use. Repeatedly
 * calls the model, executes calculator tool calls, and feeds results back
 * until the model stops (or maxTurns is reached), then verifies both results
 * (42 * 17 = 714 and 453 + 434 = 887) appear in the accumulated text.
 */
async function multiTurn(model: Model, options?: GenerateOptionsUnified) {
    const context: Context = {
        systemPrompt: "You are a helpful assistant that can use tools to answer questions.",
        messages: [
            {
                role: "user",
                content: "Think about this briefly, then calculate 42 * 17 and 453 + 434 using the calculator tool."
            }
        ],
        tools: [calculatorTool]
    };
    // Collect all text content from all assistant responses
    let allTextContent = "";
    let hasSeenThinking = false;
    let hasSeenToolCalls = false;
    const maxTurns = 5; // Prevent infinite loops
    for (let turn = 0; turn < maxTurns; turn++) {
        const response = await generateComplete(model, context, options);
        // Add the assistant response to context
        context.messages.push(response);
        // Process content blocks
        for (const block of response.content) {
            if (block.type === "text") {
                allTextContent += block.text;
            } else if (block.type === "thinking") {
                hasSeenThinking = true;
            } else if (block.type === "toolCall") {
                hasSeenToolCalls = true;
                // Validate the tool call shape before executing it
                expect(block.name).toBe("calculator");
                expect(block.id).toBeTruthy();
                expect(block.arguments).toBeTruthy();
                const { a, b, operation } = block.arguments;
                // Implement all four operations the tool schema advertises so
                // the model receives correct results whatever it calls.
                let result: number;
                switch (operation) {
                    case "add": result = a + b; break;
                    case "subtract": result = a - b; break;
                    case "multiply": result = a * b; break;
                    case "divide": result = a / b; break;
                    default: result = 0;
                }
                // Add tool result to context
                context.messages.push({
                    role: "toolResult",
                    toolCallId: block.id,
                    toolName: block.name,
                    content: `${result}`,
                    isError: false
                });
            }
        }
        // If we got a stop response with text content, we're likely done
        expect(response.stopReason).not.toBe("error");
        if (response.stopReason === "stop") {
            break;
        }
    }
    // Verify we got either thinking content or tool calls (or both)
    expect(hasSeenThinking || hasSeenToolCalls).toBe(true);
    // The accumulated text should reference both calculations
    expect(allTextContent).toBeTruthy();
    expect(allTextContent.includes("714")).toBe(true);
    expect(allTextContent.includes("887")).toBe(true);
}
// E2E suites only run when the corresponding credential is present in the
// environment; otherwise the whole provider suite is skipped.
describe("Generate E2E Tests", () => {
    describe.skipIf(!process.env.ANTHROPIC_API_KEY)("Anthropic Provider (claude-3-5-haiku-20241022)", () => {
        let model: Model;
        beforeAll(() => {
            model = getModel("anthropic", "claude-3-5-haiku-20241022");
        });
        it("should complete basic text generation", async () => {
            await basicTextGeneration(model);
        });
        it("should handle tool calling", async () => {
            await handleToolCall(model);
        });
        it("should handle streaming", async () => {
            await handleStreaming(model);
        });
        it("should handle image input", async () => {
            await handleImage(model);
        });
    });
    // The OAuth-token suite additionally exercises thinking mode and the
    // multi-turn tool loop on a reasoning-capable model.
    describe.skipIf(!process.env.ANTHROPIC_OAUTH_TOKEN)("Anthropic Provider (claude-sonnet-4-20250514)", () => {
        let model: Model;
        beforeAll(() => {
            model = getModel("anthropic", "claude-sonnet-4-20250514");
        });
        it("should complete basic text generation", async () => {
            await basicTextGeneration(model);
        });
        it("should handle tool calling", async () => {
            await handleToolCall(model);
        });
        it("should handle streaming", async () => {
            await handleStreaming(model);
        });
        it("should handle thinking mode", async () => {
            await handleThinking(model, { reasoning: "low" });
        });
        it("should handle multi-turn with thinking and tools", async () => {
            await multiTurn(model, { reasoning: "medium" });
        });
        it("should handle image input", async () => {
            await handleImage(model);
        });
    });
});