From 004de3c9d0153755a3364aaee8e4a7418aeb0048 Mon Sep 17 00:00:00 2001 From: Mario Zechner Date: Tue, 2 Sep 2025 18:07:46 +0200 Subject: [PATCH] feat(ai): Add new streaming generate API with AsyncIterable interface - Implement QueuedGenerateStream class that extends AsyncIterable with finalMessage() method - Add new types: GenerateStream, GenerateOptions, GenerateOptionsUnified, GenerateFunction - Create generateAnthropic function-based implementation replacing class-based approach - Add comprehensive test suite for the new generate API - Support streaming events with text, thinking, and tool call deltas - Map ReasoningEffort to provider-specific options - Include apiKey in options instead of constructor parameter --- packages/ai/src/generate.ts | 268 +++++++++++ packages/ai/src/index.ts | 35 +- packages/ai/src/models.ts | 133 ++---- .../ai/src/providers/anthropic-generate.ts | 425 ++++++++++++++++++ packages/ai/src/types.ts | 63 ++- packages/ai/test/generate.test.ts | 311 +++++++++++++ 6 files changed, 1106 insertions(+), 129 deletions(-) create mode 100644 packages/ai/src/generate.ts create mode 100644 packages/ai/src/providers/anthropic-generate.ts create mode 100644 packages/ai/test/generate.test.ts diff --git a/packages/ai/src/generate.ts b/packages/ai/src/generate.ts new file mode 100644 index 00000000..170c16cb --- /dev/null +++ b/packages/ai/src/generate.ts @@ -0,0 +1,268 @@ +import type { + Api, + AssistantMessage, + AssistantMessageEvent, + Context, + GenerateFunction, + GenerateOptionsUnified, + GenerateStream, + KnownProvider, + Model, + ReasoningEffort, +} from "./types.js"; + +export class QueuedGenerateStream implements GenerateStream { + private queue: AssistantMessageEvent[] = []; + private waiting: ((value: IteratorResult) => void)[] = []; + private done = false; + private error?: Error; + private finalMessagePromise: Promise; + private resolveFinalMessage!: (message: AssistantMessage) => void; + private rejectFinalMessage!: (error: Error) => 
void; + + constructor() { + this.finalMessagePromise = new Promise((resolve, reject) => { + this.resolveFinalMessage = resolve; + this.rejectFinalMessage = reject; + }); + } + + push(event: AssistantMessageEvent): void { + if (this.done) return; + + // If it's the done event, resolve the final message + if (event.type === "done") { + this.done = true; + this.resolveFinalMessage(event.message); + } + + // If it's an error event, reject the final message + if (event.type === "error") { + this.error = new Error(event.error); + if (!this.done) { + this.rejectFinalMessage(this.error); + } + } + + // Deliver to waiting consumer or queue it + const waiter = this.waiting.shift(); + if (waiter) { + waiter({ value: event, done: false }); + } else { + this.queue.push(event); + } + } + + end(): void { + this.done = true; + // Notify all waiting consumers that we're done + while (this.waiting.length > 0) { + const waiter = this.waiting.shift()!; + waiter({ value: undefined as any, done: true }); + } + } + + async *[Symbol.asyncIterator](): AsyncIterator { + while (true) { + // If we have queued events, yield them + if (this.queue.length > 0) { + yield this.queue.shift()!; + } else if (this.done) { + // No more events and we're done + return; + } else { + // Wait for next event + const result = await new Promise>((resolve) => + this.waiting.push(resolve), + ); + if (result.done) return; + yield result.value; + } + } + } + + finalMessage(): Promise { + return this.finalMessagePromise; + } +} + +// API implementations registry +const apiImplementations: Map = new Map(); + +/** + * Register a custom API implementation + */ +export function registerApi(api: string, impl: GenerateFunction): void { + apiImplementations.set(api, impl); +} + +// API key storage +const apiKeys: Map = new Map(); + +/** + * Set an API key for a provider + */ +export function setApiKey(provider: KnownProvider, key: string): void; +export function setApiKey(provider: string, key: string): void; +export 
function setApiKey(provider: any, key: string): void { + apiKeys.set(provider, key); +} + +/** + * Get API key for a provider + */ +export function getApiKey(provider: KnownProvider): string | undefined; +export function getApiKey(provider: string): string | undefined; +export function getApiKey(provider: any): string | undefined { + // Check explicit keys first + const key = apiKeys.get(provider); + if (key) return key; + + // Fall back to environment variables + const envMap: Record = { + openai: "OPENAI_API_KEY", + anthropic: "ANTHROPIC_API_KEY", + google: "GEMINI_API_KEY", + groq: "GROQ_API_KEY", + cerebras: "CEREBRAS_API_KEY", + xai: "XAI_API_KEY", + openrouter: "OPENROUTER_API_KEY", + }; + + const envVar = envMap[provider]; + return envVar ? process.env[envVar] : undefined; +} + +/** + * Main generate function + */ +export function generate(model: Model, context: Context, options?: GenerateOptionsUnified): GenerateStream { + // Get implementation + const impl = apiImplementations.get(model.api); + if (!impl) { + throw new Error(`Unsupported API: ${model.api}`); + } + + // Get API key from options or environment + const apiKey = options?.apiKey || getApiKey(model.provider); + if (!apiKey) { + throw new Error(`No API key for provider: ${model.provider}`); + } + + // Map generic options to provider-specific + const providerOptions = mapOptionsForApi(model.api, model, options, apiKey); + + // Return the GenerateStream from implementation + return impl(model, context, providerOptions); +} + +/** + * Helper to generate and get complete response (no streaming) + */ +export async function generateComplete( + model: Model, + context: Context, + options?: GenerateOptionsUnified, +): Promise { + const stream = generate(model, context, options); + return stream.finalMessage(); +} + +/** + * Map generic options to provider-specific options + */ +function mapOptionsForApi(api: Api | string, model: Model, options?: GenerateOptionsUnified, apiKey?: string): any { + const 
base = { + temperature: options?.temperature, + maxTokens: options?.maxTokens, + signal: options?.signal, + apiKey: apiKey || options?.apiKey, + }; + + switch (api) { + case "openai-responses": + case "openai-completions": + return { + ...base, + reasoning_effort: options?.reasoning, + }; + + case "anthropic-messages": { + if (!options?.reasoning) return base; + + // Map effort to token budget + const anthropicBudgets = { + minimal: 1024, + low: 2048, + medium: 8192, + high: Math.min(25000, model.maxTokens - 1000), + }; + + return { + ...base, + thinking: { + enabled: true, + budgetTokens: anthropicBudgets[options.reasoning], + }, + }; + } + case "google-generative-ai": { + if (!options?.reasoning) return { ...base, thinking_budget: -1 }; + + // Model-specific mapping for Google + const googleBudget = getGoogleBudget(model, options.reasoning); + return { + ...base, + thinking_budget: googleBudget, + }; + } + default: + return base; + } +} + +/** + * Get Google thinking budget based on model and effort + */ +function getGoogleBudget(model: Model, effort: ReasoningEffort): number { + // Model-specific logic + if (model.id.includes("flash-lite")) { + const budgets = { + minimal: 512, + low: 2048, + medium: 8192, + high: 24576, + }; + return budgets[effort]; + } + + if (model.id.includes("pro")) { + const budgets = { + minimal: 128, + low: 2048, + medium: 8192, + high: Math.min(25000, 32768), + }; + return budgets[effort]; + } + + if (model.id.includes("flash")) { + const budgets = { + minimal: 0, // Disable thinking + low: 2048, + medium: 8192, + high: 24576, + }; + return budgets[effort]; + } + + // Unknown model - use dynamic + return -1; +} + +// Register built-in API implementations +// Import the new function-based implementations +import { generateAnthropic } from "./providers/anthropic-generate.js"; + +// Register Anthropic implementation +apiImplementations.set("anthropic-messages", generateAnthropic); diff --git a/packages/ai/src/index.ts 
b/packages/ai/src/index.ts index 13caf2e8..1e18b02b 100644 --- a/packages/ai/src/index.ts +++ b/packages/ai/src/index.ts @@ -3,26 +3,26 @@ export const version = "0.5.8"; +// Export generate API +export { + generate, + generateComplete, + getApiKey, + QueuedGenerateStream, + registerApi, + setApiKey, +} from "./generate.js"; // Export generated models data export { PROVIDERS } from "./models.generated.js"; - -// Export models utilities and types +// Export model utilities export { - type AnthropicModel, - type CerebrasModel, - createLLM, - type GoogleModel, - type GroqModel, - type Model, - type OpenAIModel, - type OpenRouterModel, - PROVIDER_CONFIG, - type ProviderModels, - type ProviderToLLM, - type XAIModel, + calculateCost, + getModel, + type KnownProvider, + registerModel, } from "./models.js"; -// Export providers +// Legacy providers (to be deprecated) export { AnthropicLLM } from "./providers/anthropic.js"; export { GoogleLLM } from "./providers/google.js"; export { OpenAICompletionsLLM } from "./providers/openai-completions.js"; @@ -30,3 +30,8 @@ export { OpenAIResponsesLLM } from "./providers/openai-responses.js"; // Export types export type * from "./types.js"; + +// TODO: Remove these legacy exports once consumers are updated +export function createLLM(): never { + throw new Error("createLLM is deprecated. 
Use generate() with getModel() instead."); +} diff --git a/packages/ai/src/models.ts b/packages/ai/src/models.ts index c033d399..3a202585 100644 --- a/packages/ai/src/models.ts +++ b/packages/ai/src/models.ts @@ -1,108 +1,44 @@ import { PROVIDERS } from "./models.generated.js"; -import { AnthropicLLM } from "./providers/anthropic.js"; -import { GoogleLLM } from "./providers/google.js"; -import { OpenAICompletionsLLM } from "./providers/openai-completions.js"; -import { OpenAIResponsesLLM } from "./providers/openai-responses.js"; -import type { Model, Usage } from "./types.js"; +import type { KnownProvider, Model, Usage } from "./types.js"; -// Provider configuration with factory functions -export const PROVIDER_CONFIG = { - google: { - envKey: "GEMINI_API_KEY", - create: (model: Model, apiKey: string) => new GoogleLLM(model, apiKey), - }, - openai: { - envKey: "OPENAI_API_KEY", - create: (model: Model, apiKey: string) => new OpenAIResponsesLLM(model, apiKey), - }, - anthropic: { - envKey: "ANTHROPIC_API_KEY", - create: (model: Model, apiKey: string) => new AnthropicLLM(model, apiKey), - }, - xai: { - envKey: "XAI_API_KEY", - create: (model: Model, apiKey: string) => new OpenAICompletionsLLM(model, apiKey), - }, - groq: { - envKey: "GROQ_API_KEY", - create: (model: Model, apiKey: string) => new OpenAICompletionsLLM(model, apiKey), - }, - cerebras: { - envKey: "CEREBRAS_API_KEY", - create: (model: Model, apiKey: string) => new OpenAICompletionsLLM(model, apiKey), - }, - openrouter: { - envKey: "OPENROUTER_API_KEY", - create: (model: Model, apiKey: string) => new OpenAICompletionsLLM(model, apiKey), - }, -} as const; +// Re-export Model type +export type { KnownProvider, Model } from "./types.js"; -// Type mapping from provider to LLM implementation -export type ProviderToLLM = { - google: GoogleLLM; - openai: OpenAIResponsesLLM; - anthropic: AnthropicLLM; - xai: OpenAICompletionsLLM; - groq: OpenAICompletionsLLM; - cerebras: OpenAICompletionsLLM; - openrouter: 
OpenAICompletionsLLM; -}; +// Dynamic model registry initialized from PROVIDERS +const modelRegistry: Map> = new Map(); -// Extract model types for each provider -export type GoogleModel = keyof typeof PROVIDERS.google.models; -export type OpenAIModel = keyof typeof PROVIDERS.openai.models; -export type AnthropicModel = keyof typeof PROVIDERS.anthropic.models; -export type XAIModel = keyof typeof PROVIDERS.xai.models; -export type GroqModel = keyof typeof PROVIDERS.groq.models; -export type CerebrasModel = keyof typeof PROVIDERS.cerebras.models; -export type OpenRouterModel = keyof typeof PROVIDERS.openrouter.models; - -// Map providers to their model types -export type ProviderModels = { - google: GoogleModel; - openai: OpenAIModel; - anthropic: AnthropicModel; - xai: XAIModel; - groq: GroqModel; - cerebras: CerebrasModel; - openrouter: OpenRouterModel; -}; - -// Single generic factory function -export function createLLM

( - provider: P, - model: M, - apiKey?: string, -): ProviderToLLM[P] { - const config = PROVIDER_CONFIG[provider as keyof typeof PROVIDER_CONFIG]; - if (!config) throw new Error(`Unknown provider: ${provider}`); - - const providerData = PROVIDERS[provider]; - if (!providerData) throw new Error(`Unknown provider: ${provider}`); - - // Type-safe model lookup - const models = providerData.models as Record; - const modelData = models[model as string]; - if (!modelData) throw new Error(`Unknown model: ${String(model)} for provider ${provider}`); - - const key = apiKey || process.env[config.envKey]; - if (!key) throw new Error(`No API key provided for ${provider}. Set ${config.envKey} or pass apiKey.`); - - return config.create(modelData, key) as ProviderToLLM[P]; +// Initialize registry from PROVIDERS on module load +for (const [provider, models] of Object.entries(PROVIDERS)) { + const providerModels = new Map(); + for (const [id, model] of Object.entries(models)) { + providerModels.set(id, model as Model); + } + modelRegistry.set(provider, providerModels); } -// Helper function to get model info with type-safe model IDs -export function getModel

( - provider: P, - modelId: keyof (typeof PROVIDERS)[P]["models"], -): Model | undefined { - const providerData = PROVIDERS[provider]; - if (!providerData) return undefined; - const models = providerData.models as Record; - return models[modelId as string]; +/** + * Get a model from the registry - typed overload for known providers + */ +export function getModel

(provider: P, modelId: keyof (typeof PROVIDERS)[P]): Model; +export function getModel(provider: string, modelId: string): Model | undefined; +export function getModel(provider: any, modelId: any): Model | undefined { + return modelRegistry.get(provider)?.get(modelId); } -export function calculateCost(model: Model, usage: Usage) { +/** + * Register a custom model + */ +export function registerModel(model: Model): void { + if (!modelRegistry.has(model.provider)) { + modelRegistry.set(model.provider, new Map()); + } + modelRegistry.get(model.provider)!.set(model.id, model); +} + +/** + * Calculate cost for token usage + */ +export function calculateCost(model: Model, usage: Usage): Usage["cost"] { usage.cost.input = (model.cost.input / 1000000) * usage.input; usage.cost.output = (model.cost.output / 1000000) * usage.output; usage.cost.cacheRead = (model.cost.cacheRead / 1000000) * usage.cacheRead; @@ -110,6 +46,3 @@ export function calculateCost(model: Model, usage: Usage) { usage.cost.total = usage.cost.input + usage.cost.output + usage.cost.cacheRead + usage.cost.cacheWrite; return usage.cost; } - -// Re-export Model type for convenience -export type { Model }; diff --git a/packages/ai/src/providers/anthropic-generate.ts b/packages/ai/src/providers/anthropic-generate.ts new file mode 100644 index 00000000..bb1f8043 --- /dev/null +++ b/packages/ai/src/providers/anthropic-generate.ts @@ -0,0 +1,425 @@ +import Anthropic from "@anthropic-ai/sdk"; +import type { + ContentBlockParam, + MessageCreateParamsStreaming, + MessageParam, + Tool, +} from "@anthropic-ai/sdk/resources/messages.js"; +import { QueuedGenerateStream } from "../generate.js"; +import { calculateCost } from "../models.js"; +import type { + Api, + AssistantMessage, + Context, + GenerateFunction, + GenerateOptions, + GenerateStream, + Message, + Model, + StopReason, + TextContent, + ThinkingContent, + ToolCall, +} from "../types.js"; +import { transformMessages } from "./utils.js"; + +// Anthropic-specific 
options +export interface AnthropicOptions extends GenerateOptions { + thinking?: { + enabled: boolean; + budgetTokens?: number; + }; + toolChoice?: "auto" | "any" | "none" | { type: "tool"; name: string }; +} + +/** + * Generate function for Anthropic API + */ +export const generateAnthropic: GenerateFunction = ( + model: Model, + context: Context, + options: AnthropicOptions, +): GenerateStream => { + const stream = new QueuedGenerateStream(); + + // Start async processing + (async () => { + const output: AssistantMessage = { + role: "assistant", + content: [], + api: "anthropic-messages" as Api, + provider: model.provider, + model: model.id, + usage: { + input: 0, + output: 0, + cacheRead: 0, + cacheWrite: 0, + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, + }, + stopReason: "stop", + }; + + try { + // Create Anthropic client + const client = createAnthropicClient(model, options.apiKey!); + + // Convert messages + const messages = convertMessages(context.messages, model, "anthropic-messages"); + + // Build params + const params = buildAnthropicParams(model, context, options, messages, client.isOAuthToken); + + // Create Anthropic stream + const anthropicStream = client.client.messages.stream( + { + ...params, + stream: true, + }, + { + signal: options.signal, + }, + ); + + // Emit start event + stream.push({ + type: "start", + partial: output, + }); + + // Process Anthropic events + let currentBlock: ThinkingContent | TextContent | (ToolCall & { partialJson: string }) | null = null; + + for await (const event of anthropicStream) { + if (event.type === "content_block_start") { + if (event.content_block.type === "text") { + currentBlock = { + type: "text", + text: "", + }; + output.content.push(currentBlock); + stream.push({ type: "text_start", partial: output }); + } else if (event.content_block.type === "thinking") { + currentBlock = { + type: "thinking", + thinking: "", + thinkingSignature: "", + }; + output.content.push(currentBlock); + 
stream.push({ type: "thinking_start", partial: output }); + } else if (event.content_block.type === "tool_use") { + // We wait for the full tool use to be streamed + currentBlock = { + type: "toolCall", + id: event.content_block.id, + name: event.content_block.name, + arguments: event.content_block.input as Record, + partialJson: "", + }; + } + } else if (event.type === "content_block_delta") { + if (event.delta.type === "text_delta") { + if (currentBlock && currentBlock.type === "text") { + currentBlock.text += event.delta.text; + stream.push({ + type: "text_delta", + delta: event.delta.text, + partial: output, + }); + } + } else if (event.delta.type === "thinking_delta") { + if (currentBlock && currentBlock.type === "thinking") { + currentBlock.thinking += event.delta.thinking; + stream.push({ + type: "thinking_delta", + delta: event.delta.thinking, + partial: output, + }); + } + } else if (event.delta.type === "input_json_delta") { + if (currentBlock && currentBlock.type === "toolCall") { + currentBlock.partialJson += event.delta.partial_json; + } + } else if (event.delta.type === "signature_delta") { + if (currentBlock && currentBlock.type === "thinking") { + currentBlock.thinkingSignature = currentBlock.thinkingSignature || ""; + currentBlock.thinkingSignature += event.delta.signature; + } + } + } else if (event.type === "content_block_stop") { + if (currentBlock) { + if (currentBlock.type === "text") { + stream.push({ type: "text_end", content: currentBlock.text, partial: output }); + } else if (currentBlock.type === "thinking") { + stream.push({ type: "thinking_end", content: currentBlock.thinking, partial: output }); + } else if (currentBlock.type === "toolCall") { + const finalToolCall: ToolCall = { + type: "toolCall", + id: currentBlock.id, + name: currentBlock.name, + arguments: JSON.parse(currentBlock.partialJson), + }; + output.content.push(finalToolCall); + stream.push({ type: "toolCall", toolCall: finalToolCall, partial: output }); + } + currentBlock 
= null; + } + } else if (event.type === "message_delta") { + if (event.delta.stop_reason) { + output.stopReason = mapStopReason(event.delta.stop_reason); + } + output.usage.input += event.usage.input_tokens || 0; + output.usage.output += event.usage.output_tokens || 0; + output.usage.cacheRead += event.usage.cache_read_input_tokens || 0; + output.usage.cacheWrite += event.usage.cache_creation_input_tokens || 0; + calculateCost(model, output.usage); + } + } + + // Emit done event with final message + stream.push({ type: "done", reason: output.stopReason, message: output }); + stream.end(); + } catch (error) { + output.stopReason = "error"; + output.error = error instanceof Error ? error.message : JSON.stringify(error); + stream.push({ type: "error", error: output.error, partial: output }); + stream.end(); + } + })(); + + return stream; +}; + +// Helper to create Anthropic client +interface AnthropicClientWrapper { + client: Anthropic; + isOAuthToken: boolean; +} + +function createAnthropicClient(model: Model, apiKey: string): AnthropicClientWrapper { + if (apiKey.includes("sk-ant-oat")) { + const defaultHeaders = { + accept: "application/json", + "anthropic-dangerous-direct-browser-access": "true", + "anthropic-beta": "oauth-2025-04-20,fine-grained-tool-streaming-2025-05-14", + }; + + // Clear the env var if we're in Node.js to prevent SDK from using it + if (typeof process !== "undefined" && process.env) { + process.env.ANTHROPIC_API_KEY = undefined; + } + + const client = new Anthropic({ + apiKey: null, + authToken: apiKey, + baseURL: model.baseUrl, + defaultHeaders, + dangerouslyAllowBrowser: true, + }); + + return { client, isOAuthToken: true }; + } else { + const defaultHeaders = { + accept: "application/json", + "anthropic-dangerous-direct-browser-access": "true", + "anthropic-beta": "fine-grained-tool-streaming-2025-05-14", + }; + + const client = new Anthropic({ + apiKey, + baseURL: model.baseUrl, + dangerouslyAllowBrowser: true, + defaultHeaders, + }); + + 
return { client, isOAuthToken: false }; + } +} + +// Build Anthropic API params +function buildAnthropicParams( + model: Model, + context: Context, + options: AnthropicOptions, + messages: MessageParam[], + isOAuthToken: boolean, +): MessageCreateParamsStreaming { + const params: MessageCreateParamsStreaming = { + model: model.id, + messages, + max_tokens: options.maxTokens || model.maxTokens, + stream: true, + }; + + // For OAuth tokens, we MUST include Claude Code identity + if (isOAuthToken) { + params.system = [ + { + type: "text", + text: "You are Claude Code, Anthropic's official CLI for Claude.", + cache_control: { + type: "ephemeral", + }, + }, + ]; + if (context.systemPrompt) { + params.system.push({ + type: "text", + text: context.systemPrompt, + cache_control: { + type: "ephemeral", + }, + }); + } + } else if (context.systemPrompt) { + params.system = context.systemPrompt; + } + + if (options.temperature !== undefined) { + params.temperature = options.temperature; + } + + if (context.tools) { + params.tools = convertTools(context.tools); + } + + // Only enable thinking if the model supports it + if (options.thinking?.enabled && model.reasoning) { + params.thinking = { + type: "enabled", + budget_tokens: options.thinking.budgetTokens || 1024, + }; + } + + if (options.toolChoice) { + if (typeof options.toolChoice === "string") { + params.tool_choice = { type: options.toolChoice }; + } else { + params.tool_choice = options.toolChoice; + } + } + + return params; +} + +// Convert messages to Anthropic format +function convertMessages(messages: Message[], model: Model, api: Api): MessageParam[] { + const params: MessageParam[] = []; + + // Transform messages for cross-provider compatibility + const transformedMessages = transformMessages(messages, model, api); + + for (const msg of transformedMessages) { + if (msg.role === "user") { + // Handle both string and array content + if (typeof msg.content === "string") { + params.push({ + role: "user", + content: 
msg.content, + }); + } else { + // Convert array content to Anthropic format + const blocks: ContentBlockParam[] = msg.content.map((item) => { + if (item.type === "text") { + return { + type: "text", + text: item.text, + }; + } else { + // Image content + return { + type: "image", + source: { + type: "base64", + media_type: item.mimeType as "image/jpeg" | "image/png" | "image/gif" | "image/webp", + data: item.data, + }, + }; + } + }); + const filteredBlocks = !model?.input.includes("image") ? blocks.filter((b) => b.type !== "image") : blocks; + params.push({ + role: "user", + content: filteredBlocks, + }); + } + } else if (msg.role === "assistant") { + const blocks: ContentBlockParam[] = []; + + for (const block of msg.content) { + if (block.type === "text") { + blocks.push({ + type: "text", + text: block.text, + }); + } else if (block.type === "thinking") { + blocks.push({ + type: "thinking", + thinking: block.thinking, + signature: block.thinkingSignature || "", + }); + } else if (block.type === "toolCall") { + blocks.push({ + type: "tool_use", + id: block.id, + name: block.name, + input: block.arguments, + }); + } + } + + params.push({ + role: "assistant", + content: blocks, + }); + } else if (msg.role === "toolResult") { + params.push({ + role: "user", + content: [ + { + type: "tool_result", + tool_use_id: msg.toolCallId, + content: msg.content, + is_error: msg.isError, + }, + ], + }); + } + } + return params; +} + +// Convert tools to Anthropic format +function convertTools(tools: Context["tools"]): Tool[] { + if (!tools) return []; + + return tools.map((tool) => ({ + name: tool.name, + description: tool.description, + input_schema: { + type: "object" as const, + properties: tool.parameters.properties || {}, + required: tool.parameters.required || [], + }, + })); +} + +// Map Anthropic stop reason to our StopReason type +function mapStopReason(reason: Anthropic.Messages.StopReason | null): StopReason { + switch (reason) { + case "end_turn": + return "stop"; + 
case "max_tokens": + return "length"; + case "tool_use": + return "toolUse"; + case "refusal": + return "safety"; + case "pause_turn": // Stop is good enough -> resubmit + return "stop"; + case "stop_sequence": + return "stop"; // We don't supply stop sequences, so this should never happen + default: + return "stop"; + } +} diff --git a/packages/ai/src/types.ts b/packages/ai/src/types.ts index 8a17ddd2..15f40d66 100644 --- a/packages/ai/src/types.ts +++ b/packages/ai/src/types.ts @@ -1,3 +1,38 @@ +export type KnownApi = "openai-completions" | "openai-responses" | "anthropic-messages" | "google-generative-ai"; +export type Api = KnownApi | string; + +export type KnownProvider = "anthropic" | "google" | "openai" | "xai" | "groq" | "cerebras" | "openrouter"; +export type Provider = KnownProvider | string; + +export type ReasoningEffort = "minimal" | "low" | "medium" | "high"; + +// The stream interface - what generate() returns +export interface GenerateStream extends AsyncIterable { + // Get the final message (waits for streaming to complete) + finalMessage(): Promise; +} + +// Base options all providers share +export interface GenerateOptions { + temperature?: number; + maxTokens?: number; + signal?: AbortSignal; + apiKey?: string; +} + +// Unified options with reasoning (what public generate() accepts) +export interface GenerateOptionsUnified extends GenerateOptions { + reasoning?: ReasoningEffort; +} + +// Generic GenerateFunction with typed options +export type GenerateFunction = ( + model: Model, + context: Context, + options: TOptions, +) => GenerateStream; + +// Legacy LLM interface (to be removed) export interface LLMOptions { temperature?: number; maxTokens?: number; @@ -60,11 +95,10 @@ export interface UserMessage { export interface AssistantMessage { role: "assistant"; content: (TextContent | ThinkingContent | ToolCall)[]; - api: string; - provider: string; + api: Api; + provider: Provider; model: string; usage: Usage; - stopReason: StopReason; error?: 
string | Error; } @@ -92,23 +126,24 @@ export interface Context { } export type AssistantMessageEvent = - | { type: "start"; model: string; provider: string } - | { type: "text_start" } - | { type: "text_delta"; content: string; delta: string } - | { type: "text_end"; content: string } - | { type: "thinking_start" } - | { type: "thinking_delta"; content: string; delta: string } - | { type: "thinking_end"; content: string } - | { type: "toolCall"; toolCall: ToolCall } + | { type: "start"; partial: AssistantMessage } + | { type: "text_start"; partial: AssistantMessage } + | { type: "text_delta"; delta: string; partial: AssistantMessage } + | { type: "text_end"; content: string; partial: AssistantMessage } + | { type: "thinking_start"; partial: AssistantMessage } + | { type: "thinking_delta"; delta: string; partial: AssistantMessage } + | { type: "thinking_end"; content: string; partial: AssistantMessage } + | { type: "toolCall"; toolCall: ToolCall; partial: AssistantMessage } | { type: "done"; reason: StopReason; message: AssistantMessage } - | { type: "error"; error: string }; + | { type: "error"; error: string; partial: AssistantMessage }; // Model interface for the unified model system export interface Model { id: string; name: string; - provider: string; - baseUrl?: string; + api: Api; + provider: Provider; + baseUrl: string; reasoning: boolean; input: ("text" | "image")[]; cost: { diff --git a/packages/ai/test/generate.test.ts b/packages/ai/test/generate.test.ts new file mode 100644 index 00000000..772d456e --- /dev/null +++ b/packages/ai/test/generate.test.ts @@ -0,0 +1,311 @@ +import { describe, it, beforeAll, expect } from "vitest"; +import { getModel } from "../src/models.js"; +import { generate, generateComplete } from "../src/generate.js"; +import type { Context, Tool, GenerateOptionsUnified, Model, ImageContent, GenerateStream, GenerateOptions } from "../src/types.js"; +import { readFileSync } from "fs"; +import { join, dirname } from "path"; +import { 
fileURLToPath } from "url"; + +const __filename = fileURLToPath(import.meta.url); +const __dirname = dirname(__filename); + +// Calculator tool definition (same as examples) +const calculatorTool: Tool = { + name: "calculator", + description: "Perform basic arithmetic operations", + parameters: { + type: "object", + properties: { + a: { type: "number", description: "First number" }, + b: { type: "number", description: "Second number" }, + operation: { + type: "string", + enum: ["add", "subtract", "multiply", "divide"], + description: "The operation to perform" + } + }, + required: ["a", "b", "operation"] + } +}; + +async function basicTextGeneration

(model: Model, options?: P) {
    // NOTE(review): this function's signature prefix (name/generics) is outside
    // this chunk; it is called as basicTextGeneration(model) below.
    // Two-turn conversation: verify the first reply, then feed it back into the
    // context and verify the follow-up reply.
    const context: Context = {
        systemPrompt: "You are a helpful assistant. Be concise.",
        messages: [
            { role: "user", content: "Reply with exactly: 'Hello test successful'" }
        ]
    };

    const response = await generateComplete(model, context, options);

    expect(response.role).toBe("assistant");
    expect(response.content).toBeTruthy();
    // Providers may report cached input tokens separately, so count both.
    expect(response.usage.input + response.usage.cacheRead).toBeGreaterThan(0);
    expect(response.usage.output).toBeGreaterThan(0);
    expect(response.error).toBeFalsy();
    expect(response.content.map(b => b.type === "text" ? b.text : "").join("")).toContain("Hello test successful");

    context.messages.push(response);
    context.messages.push({ role: "user", content: "Now say 'Goodbye test successful'" });

    const secondResponse = await generateComplete(model, context, options);

    expect(secondResponse.role).toBe("assistant");
    expect(secondResponse.content).toBeTruthy();
    expect(secondResponse.usage.input + secondResponse.usage.cacheRead).toBeGreaterThan(0);
    expect(secondResponse.usage.output).toBeGreaterThan(0);
    expect(secondResponse.error).toBeFalsy();
    expect(secondResponse.content.map(b => b.type === "text" ? b.text : "").join("")).toContain("Goodbye test successful");
}

/**
 * Verifies that the model emits a toolCall block with the expected tool name
 * and a non-empty id when explicitly asked to use the calculator tool.
 */
async function handleToolCall(model: Model, options?: GenerateOptionsUnified) {
    const context: Context = {
        systemPrompt: "You are a helpful assistant that uses tools when asked.",
        messages: [{
            role: "user",
            content: "Calculate 15 + 27 using the calculator tool."
        }],
        tools: [calculatorTool]
    };

    const response = await generateComplete(model, context, options);
    expect(response.stopReason).toBe("toolUse");
    expect(response.content.some(b => b.type === "toolCall")).toBeTruthy();
    const toolCall = response.content.find(b => b.type === "toolCall");
    // find() does not narrow the union type, so re-check the discriminant.
    if (toolCall && toolCall.type === "toolCall") {
        expect(toolCall.name).toBe("calculator");
        expect(toolCall.id).toBeTruthy();
    }
}

/**
 * Consumes the streaming API and checks that text events arrive in the
 * expected start -> delta -> end order and that the final message has text.
 */
async function handleStreaming(model: Model, options?: GenerateOptionsUnified) {
    let textStarted = false;
    let textChunks = "";
    let textCompleted = false;

    const context: Context = {
        messages: [{ role: "user", content: "Count from 1 to 3" }]
    };

    const stream = generate(model, context, options);

    for await (const event of stream) {
        if (event.type === "text_start") {
            textStarted = true;
        } else if (event.type === "text_delta") {
            textChunks += event.delta;
        } else if (event.type === "text_end") {
            textCompleted = true;
        }
    }

    const response = await stream.finalMessage();

    expect(textStarted).toBe(true);
    expect(textChunks.length).toBeGreaterThan(0);
    expect(textCompleted).toBe(true);
    expect(response.content.some(b => b.type === "text")).toBeTruthy();
}

/**
 * Requests reasoning/thinking and checks that thinking_* stream events fire
 * and that the final message contains a thinking block. The random operand
 * keeps the prompt unique across runs.
 */
async function handleThinking(model: Model, options: GenerateOptionsUnified) {
    let thinkingStarted = false;
    let thinkingChunks = "";
    let thinkingCompleted = false;

    const context: Context = {
        messages: [{ role: "user", content: `Think about ${(Math.random() * 255) | 0} + 27. Think step by step. 
Then output the result.` }]
    };

    const stream = generate(model, context, options);

    for await (const event of stream) {
        if (event.type === "thinking_start") {
            thinkingStarted = true;
        } else if (event.type === "thinking_delta") {
            thinkingChunks += event.delta;
        } else if (event.type === "thinking_end") {
            thinkingCompleted = true;
        }
    }

    const response = await stream.finalMessage();

    // Surface the provider error (if any) in the assertion failure message.
    expect(response.stopReason, `Error: ${response.error}`).toBe("stop");
    expect(thinkingStarted).toBe(true);
    expect(thinkingChunks.length).toBeGreaterThan(0);
    expect(thinkingCompleted).toBe(true);
    expect(response.content.some(b => b.type === "thinking")).toBeTruthy();
}

/**
 * Sends a red-circle PNG and checks the description mentions "red" and
 * "circle". Skips (with a log) when the model has no image input support.
 */
async function handleImage(model: Model, options?: GenerateOptionsUnified) {
    // Check if the model supports images
    if (!model.input.includes("image")) {
        console.log(`Skipping image test - model ${model.id} doesn't support images`);
        return;
    }

    // Read the test image
    const imagePath = join(__dirname, "data", "red-circle.png");
    const imageBuffer = readFileSync(imagePath);
    const base64Image = imageBuffer.toString("base64");

    const imageContent: ImageContent = {
        type: "image",
        data: base64Image,
        mimeType: "image/png",
    };

    const context: Context = {
        messages: [
            {
                role: "user",
                content: [
                    { type: "text", text: "What do you see in this image? Please describe the shape (circle, rectangle, square, triangle, ...) and color (red, blue, green, ...)." },
                    imageContent,
                ],
            },
        ],
    };

    const response = await generateComplete(model, context, options);

    // Check the response mentions red and circle
    expect(response.content.length > 0).toBeTruthy();
    const textContent = response.content.find(b => b.type === "text");
    if (textContent && textContent.type === "text") {
        const lowerContent = textContent.text.toLowerCase();
        expect(lowerContent).toContain("red");
        expect(lowerContent).toContain("circle");
    }
}

/**
 * Drives a full agentic loop: the model (optionally) thinks, issues
 * calculator tool calls, receives tool results, and finally reports both
 * answers in text. Loops until stopReason is "stop" or maxTurns is reached.
 */
async function multiTurn(model: Model, options?: GenerateOptionsUnified) {
    const context: Context = {
        systemPrompt: "You are a helpful assistant that can use tools to answer questions.",
        messages: [
            {
                role: "user",
                content: "Think about this briefly, then calculate 42 * 17 and 453 + 434 using the calculator tool."
            }
        ],
        tools: [calculatorTool]
    };

    // Collect all text content from all assistant responses
    let allTextContent = "";
    let hasSeenThinking = false;
    let hasSeenToolCalls = false;
    const maxTurns = 5; // Prevent infinite loops

    for (let turn = 0; turn < maxTurns; turn++) {
        const response = await generateComplete(model, context, options);

        // Add the assistant response to context
        context.messages.push(response);

        // Process content blocks
        for (const block of response.content) {
            if (block.type === "text") {
                allTextContent += block.text;
            } else if (block.type === "thinking") {
                hasSeenThinking = true;
            } else if (block.type === "toolCall") {
                hasSeenToolCalls = true;

                // Process the tool call
                expect(block.name).toBe("calculator");
                expect(block.id).toBeTruthy();
                expect(block.arguments).toBeTruthy();

                const { a, b, operation } = block.arguments;
                let result: number;
                switch (operation) {
                    case "add": result = a + b; break;
                    case "multiply": result = a * b; break;
                    // Only add/multiply are requested by the prompt; any other
                    // operation yields 0 so the loop can still make progress.
                    default: result = 0;
                }

                // Add tool result to context
                context.messages.push({
                    role: "toolResult",
                    toolCallId: block.id,
                    toolName: block.name,
                    content: `${result}`,
                    isError: false
                });
            }
        }

        // If we got a stop response with text content, we're likely done
        expect(response.stopReason).not.toBe("error");
        if (response.stopReason === "stop") {
            break;
        }
    }

    // Verify we got either thinking content or tool calls (or both)
    expect(hasSeenThinking || hasSeenToolCalls).toBe(true);

    // The accumulated text should reference both calculations
    // (42 * 17 = 714, 453 + 434 = 887).
    expect(allTextContent).toBeTruthy();
    expect(allTextContent).toContain("714");
    expect(allTextContent).toContain("887");
}

describe("Generate E2E Tests", () => {
    // Each provider suite is skipped unless its credential env var is set,
    // so the E2E tests are a no-op in environments without API access.
    describe.skipIf(!process.env.ANTHROPIC_API_KEY)("Anthropic Provider (claude-3-5-haiku-20241022)", () => {
        let model: Model;

        beforeAll(() => {
            model = getModel("anthropic", "claude-3-5-haiku-20241022");
        });

        it("should complete basic text generation", async () => {
            await basicTextGeneration(model);
        });

        it("should handle tool calling", async () => {
            await handleToolCall(model);
        });

        it("should handle streaming", async () => {
            await handleStreaming(model);
        });

        it("should handle image input", async () => {
            await handleImage(model);
        });
    });

    describe.skipIf(!process.env.ANTHROPIC_OAUTH_TOKEN)("Anthropic Provider (claude-sonnet-4-20250514)", () => {
        let model: Model;

        beforeAll(() => {
            model = getModel("anthropic", "claude-sonnet-4-20250514");
        });

        it("should complete basic text generation", async () => {
            await basicTextGeneration(model);
        });

        it("should handle tool calling", async () => {
            await handleToolCall(model);
        });

        it("should handle streaming", async () => {
            await handleStreaming(model);
        });

        it("should handle thinking mode", async () => {
            await handleThinking(model, { reasoning: "low" });
        });

        it("should handle multi-turn with thinking and tools", async () => {
            await multiTurn(model, { reasoning: "medium" });
        });

        it("should handle image input", async () => {
            await handleImage(model);
        });
    });
});