mirror of
https://github.com/getcompanion-ai/co-mono.git
synced 2026-04-15 23:01:30 +00:00
feat(ai): Add new streaming generate API with AsyncIterable interface
- Implement QueuedGenerateStream class that extends AsyncIterable with finalMessage() method - Add new types: GenerateStream, GenerateOptions, GenerateOptionsUnified, GenerateFunction - Create generateAnthropic function-based implementation replacing class-based approach - Add comprehensive test suite for the new generate API - Support streaming events with text, thinking, and tool call deltas - Map ReasoningEffort to provider-specific options - Include apiKey in options instead of constructor parameter
This commit is contained in:
parent
be07c08a75
commit
004de3c9d0
6 changed files with 1106 additions and 129 deletions
268
packages/ai/src/generate.ts
Normal file
268
packages/ai/src/generate.ts
Normal file
|
|
@ -0,0 +1,268 @@
|
|||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
AssistantMessageEvent,
|
||||
Context,
|
||||
GenerateFunction,
|
||||
GenerateOptionsUnified,
|
||||
GenerateStream,
|
||||
KnownProvider,
|
||||
Model,
|
||||
ReasoningEffort,
|
||||
} from "./types.js";
|
||||
|
||||
export class QueuedGenerateStream implements GenerateStream {
|
||||
private queue: AssistantMessageEvent[] = [];
|
||||
private waiting: ((value: IteratorResult<AssistantMessageEvent>) => void)[] = [];
|
||||
private done = false;
|
||||
private error?: Error;
|
||||
private finalMessagePromise: Promise<AssistantMessage>;
|
||||
private resolveFinalMessage!: (message: AssistantMessage) => void;
|
||||
private rejectFinalMessage!: (error: Error) => void;
|
||||
|
||||
constructor() {
|
||||
this.finalMessagePromise = new Promise((resolve, reject) => {
|
||||
this.resolveFinalMessage = resolve;
|
||||
this.rejectFinalMessage = reject;
|
||||
});
|
||||
}
|
||||
|
||||
push(event: AssistantMessageEvent): void {
|
||||
if (this.done) return;
|
||||
|
||||
// If it's the done event, resolve the final message
|
||||
if (event.type === "done") {
|
||||
this.done = true;
|
||||
this.resolveFinalMessage(event.message);
|
||||
}
|
||||
|
||||
// If it's an error event, reject the final message
|
||||
if (event.type === "error") {
|
||||
this.error = new Error(event.error);
|
||||
if (!this.done) {
|
||||
this.rejectFinalMessage(this.error);
|
||||
}
|
||||
}
|
||||
|
||||
// Deliver to waiting consumer or queue it
|
||||
const waiter = this.waiting.shift();
|
||||
if (waiter) {
|
||||
waiter({ value: event, done: false });
|
||||
} else {
|
||||
this.queue.push(event);
|
||||
}
|
||||
}
|
||||
|
||||
end(): void {
|
||||
this.done = true;
|
||||
// Notify all waiting consumers that we're done
|
||||
while (this.waiting.length > 0) {
|
||||
const waiter = this.waiting.shift()!;
|
||||
waiter({ value: undefined as any, done: true });
|
||||
}
|
||||
}
|
||||
|
||||
async *[Symbol.asyncIterator](): AsyncIterator<AssistantMessageEvent> {
|
||||
while (true) {
|
||||
// If we have queued events, yield them
|
||||
if (this.queue.length > 0) {
|
||||
yield this.queue.shift()!;
|
||||
} else if (this.done) {
|
||||
// No more events and we're done
|
||||
return;
|
||||
} else {
|
||||
// Wait for next event
|
||||
const result = await new Promise<IteratorResult<AssistantMessageEvent>>((resolve) =>
|
||||
this.waiting.push(resolve),
|
||||
);
|
||||
if (result.done) return;
|
||||
yield result.value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
finalMessage(): Promise<AssistantMessage> {
|
||||
return this.finalMessagePromise;
|
||||
}
|
||||
}
|
||||
|
||||
// API implementations registry
|
||||
const apiImplementations: Map<Api | string, GenerateFunction> = new Map();
|
||||
|
||||
/**
|
||||
* Register a custom API implementation
|
||||
*/
|
||||
export function registerApi(api: string, impl: GenerateFunction): void {
|
||||
apiImplementations.set(api, impl);
|
||||
}
|
||||
|
||||
// API key storage
|
||||
const apiKeys: Map<string, string> = new Map();
|
||||
|
||||
/**
|
||||
* Set an API key for a provider
|
||||
*/
|
||||
export function setApiKey(provider: KnownProvider, key: string): void;
|
||||
export function setApiKey(provider: string, key: string): void;
|
||||
export function setApiKey(provider: any, key: string): void {
|
||||
apiKeys.set(provider, key);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get API key for a provider
|
||||
*/
|
||||
export function getApiKey(provider: KnownProvider): string | undefined;
|
||||
export function getApiKey(provider: string): string | undefined;
|
||||
export function getApiKey(provider: any): string | undefined {
|
||||
// Check explicit keys first
|
||||
const key = apiKeys.get(provider);
|
||||
if (key) return key;
|
||||
|
||||
// Fall back to environment variables
|
||||
const envMap: Record<string, string> = {
|
||||
openai: "OPENAI_API_KEY",
|
||||
anthropic: "ANTHROPIC_API_KEY",
|
||||
google: "GEMINI_API_KEY",
|
||||
groq: "GROQ_API_KEY",
|
||||
cerebras: "CEREBRAS_API_KEY",
|
||||
xai: "XAI_API_KEY",
|
||||
openrouter: "OPENROUTER_API_KEY",
|
||||
};
|
||||
|
||||
const envVar = envMap[provider];
|
||||
return envVar ? process.env[envVar] : undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Main generate function
|
||||
*/
|
||||
export function generate(model: Model, context: Context, options?: GenerateOptionsUnified): GenerateStream {
|
||||
// Get implementation
|
||||
const impl = apiImplementations.get(model.api);
|
||||
if (!impl) {
|
||||
throw new Error(`Unsupported API: ${model.api}`);
|
||||
}
|
||||
|
||||
// Get API key from options or environment
|
||||
const apiKey = options?.apiKey || getApiKey(model.provider);
|
||||
if (!apiKey) {
|
||||
throw new Error(`No API key for provider: ${model.provider}`);
|
||||
}
|
||||
|
||||
// Map generic options to provider-specific
|
||||
const providerOptions = mapOptionsForApi(model.api, model, options, apiKey);
|
||||
|
||||
// Return the GenerateStream from implementation
|
||||
return impl(model, context, providerOptions);
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper to generate and get complete response (no streaming)
|
||||
*/
|
||||
export async function generateComplete(
|
||||
model: Model,
|
||||
context: Context,
|
||||
options?: GenerateOptionsUnified,
|
||||
): Promise<AssistantMessage> {
|
||||
const stream = generate(model, context, options);
|
||||
return stream.finalMessage();
|
||||
}
|
||||
|
||||
/**
|
||||
* Map generic options to provider-specific options
|
||||
*/
|
||||
function mapOptionsForApi(api: Api | string, model: Model, options?: GenerateOptionsUnified, apiKey?: string): any {
|
||||
const base = {
|
||||
temperature: options?.temperature,
|
||||
maxTokens: options?.maxTokens,
|
||||
signal: options?.signal,
|
||||
apiKey: apiKey || options?.apiKey,
|
||||
};
|
||||
|
||||
switch (api) {
|
||||
case "openai-responses":
|
||||
case "openai-completions":
|
||||
return {
|
||||
...base,
|
||||
reasoning_effort: options?.reasoning,
|
||||
};
|
||||
|
||||
case "anthropic-messages": {
|
||||
if (!options?.reasoning) return base;
|
||||
|
||||
// Map effort to token budget
|
||||
const anthropicBudgets = {
|
||||
minimal: 1024,
|
||||
low: 2048,
|
||||
medium: 8192,
|
||||
high: Math.min(25000, model.maxTokens - 1000),
|
||||
};
|
||||
|
||||
return {
|
||||
...base,
|
||||
thinking: {
|
||||
enabled: true,
|
||||
budgetTokens: anthropicBudgets[options.reasoning],
|
||||
},
|
||||
};
|
||||
}
|
||||
case "google-generative-ai": {
|
||||
if (!options?.reasoning) return { ...base, thinking_budget: -1 };
|
||||
|
||||
// Model-specific mapping for Google
|
||||
const googleBudget = getGoogleBudget(model, options.reasoning);
|
||||
return {
|
||||
...base,
|
||||
thinking_budget: googleBudget,
|
||||
};
|
||||
}
|
||||
default:
|
||||
return base;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get Google thinking budget based on model and effort
|
||||
*/
|
||||
function getGoogleBudget(model: Model, effort: ReasoningEffort): number {
|
||||
// Model-specific logic
|
||||
if (model.id.includes("flash-lite")) {
|
||||
const budgets = {
|
||||
minimal: 512,
|
||||
low: 2048,
|
||||
medium: 8192,
|
||||
high: 24576,
|
||||
};
|
||||
return budgets[effort];
|
||||
}
|
||||
|
||||
if (model.id.includes("pro")) {
|
||||
const budgets = {
|
||||
minimal: 128,
|
||||
low: 2048,
|
||||
medium: 8192,
|
||||
high: Math.min(25000, 32768),
|
||||
};
|
||||
return budgets[effort];
|
||||
}
|
||||
|
||||
if (model.id.includes("flash")) {
|
||||
const budgets = {
|
||||
minimal: 0, // Disable thinking
|
||||
low: 2048,
|
||||
medium: 8192,
|
||||
high: 24576,
|
||||
};
|
||||
return budgets[effort];
|
||||
}
|
||||
|
||||
// Unknown model - use dynamic
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Register built-in API implementations.
// This import deliberately sits at the bottom of the file:
// anthropic-generate.js imports QueuedGenerateStream back from this module,
// and ES-module hoisting resolves that circular dependency cleanly.
import { generateAnthropic } from "./providers/anthropic-generate.js";

// Register the Anthropic implementation under its API id.
apiImplementations.set("anthropic-messages", generateAnthropic);
|
||||
|
|
@ -3,26 +3,26 @@
|
|||
|
||||
export const version = "0.5.8";
|
||||
|
||||
// Export generate API
|
||||
export {
|
||||
generate,
|
||||
generateComplete,
|
||||
getApiKey,
|
||||
QueuedGenerateStream,
|
||||
registerApi,
|
||||
setApiKey,
|
||||
} from "./generate.js";
|
||||
// Export generated models data
|
||||
export { PROVIDERS } from "./models.generated.js";
|
||||
|
||||
// Export models utilities and types
|
||||
// Export model utilities
|
||||
export {
|
||||
type AnthropicModel,
|
||||
type CerebrasModel,
|
||||
createLLM,
|
||||
type GoogleModel,
|
||||
type GroqModel,
|
||||
type Model,
|
||||
type OpenAIModel,
|
||||
type OpenRouterModel,
|
||||
PROVIDER_CONFIG,
|
||||
type ProviderModels,
|
||||
type ProviderToLLM,
|
||||
type XAIModel,
|
||||
calculateCost,
|
||||
getModel,
|
||||
type KnownProvider,
|
||||
registerModel,
|
||||
} from "./models.js";
|
||||
|
||||
// Export providers
|
||||
// Legacy providers (to be deprecated)
|
||||
export { AnthropicLLM } from "./providers/anthropic.js";
|
||||
export { GoogleLLM } from "./providers/google.js";
|
||||
export { OpenAICompletionsLLM } from "./providers/openai-completions.js";
|
||||
|
|
@ -30,3 +30,8 @@ export { OpenAIResponsesLLM } from "./providers/openai-responses.js";
|
|||
|
||||
// Export types
|
||||
export type * from "./types.js";
|
||||
|
||||
// TODO: Remove these legacy exports once consumers are updated
|
||||
export function createLLM(): never {
|
||||
throw new Error("createLLM is deprecated. Use generate() with getModel() instead.");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,108 +1,44 @@
|
|||
import { PROVIDERS } from "./models.generated.js";
|
||||
import { AnthropicLLM } from "./providers/anthropic.js";
|
||||
import { GoogleLLM } from "./providers/google.js";
|
||||
import { OpenAICompletionsLLM } from "./providers/openai-completions.js";
|
||||
import { OpenAIResponsesLLM } from "./providers/openai-responses.js";
|
||||
import type { Model, Usage } from "./types.js";
|
||||
import type { KnownProvider, Model, Usage } from "./types.js";
|
||||
|
||||
// Provider configuration with factory functions
|
||||
export const PROVIDER_CONFIG = {
|
||||
google: {
|
||||
envKey: "GEMINI_API_KEY",
|
||||
create: (model: Model, apiKey: string) => new GoogleLLM(model, apiKey),
|
||||
},
|
||||
openai: {
|
||||
envKey: "OPENAI_API_KEY",
|
||||
create: (model: Model, apiKey: string) => new OpenAIResponsesLLM(model, apiKey),
|
||||
},
|
||||
anthropic: {
|
||||
envKey: "ANTHROPIC_API_KEY",
|
||||
create: (model: Model, apiKey: string) => new AnthropicLLM(model, apiKey),
|
||||
},
|
||||
xai: {
|
||||
envKey: "XAI_API_KEY",
|
||||
create: (model: Model, apiKey: string) => new OpenAICompletionsLLM(model, apiKey),
|
||||
},
|
||||
groq: {
|
||||
envKey: "GROQ_API_KEY",
|
||||
create: (model: Model, apiKey: string) => new OpenAICompletionsLLM(model, apiKey),
|
||||
},
|
||||
cerebras: {
|
||||
envKey: "CEREBRAS_API_KEY",
|
||||
create: (model: Model, apiKey: string) => new OpenAICompletionsLLM(model, apiKey),
|
||||
},
|
||||
openrouter: {
|
||||
envKey: "OPENROUTER_API_KEY",
|
||||
create: (model: Model, apiKey: string) => new OpenAICompletionsLLM(model, apiKey),
|
||||
},
|
||||
} as const;
|
||||
// Re-export Model type
|
||||
export type { KnownProvider, Model } from "./types.js";
|
||||
|
||||
// Type mapping from provider to LLM implementation
|
||||
export type ProviderToLLM = {
|
||||
google: GoogleLLM;
|
||||
openai: OpenAIResponsesLLM;
|
||||
anthropic: AnthropicLLM;
|
||||
xai: OpenAICompletionsLLM;
|
||||
groq: OpenAICompletionsLLM;
|
||||
cerebras: OpenAICompletionsLLM;
|
||||
openrouter: OpenAICompletionsLLM;
|
||||
};
|
||||
// Dynamic model registry initialized from PROVIDERS
|
||||
const modelRegistry: Map<string, Map<string, Model>> = new Map();
|
||||
|
||||
// Extract model types for each provider
|
||||
export type GoogleModel = keyof typeof PROVIDERS.google.models;
|
||||
export type OpenAIModel = keyof typeof PROVIDERS.openai.models;
|
||||
export type AnthropicModel = keyof typeof PROVIDERS.anthropic.models;
|
||||
export type XAIModel = keyof typeof PROVIDERS.xai.models;
|
||||
export type GroqModel = keyof typeof PROVIDERS.groq.models;
|
||||
export type CerebrasModel = keyof typeof PROVIDERS.cerebras.models;
|
||||
export type OpenRouterModel = keyof typeof PROVIDERS.openrouter.models;
|
||||
|
||||
// Map providers to their model types
|
||||
export type ProviderModels = {
|
||||
google: GoogleModel;
|
||||
openai: OpenAIModel;
|
||||
anthropic: AnthropicModel;
|
||||
xai: XAIModel;
|
||||
groq: GroqModel;
|
||||
cerebras: CerebrasModel;
|
||||
openrouter: OpenRouterModel;
|
||||
};
|
||||
|
||||
// Single generic factory function
|
||||
export function createLLM<P extends keyof typeof PROVIDERS, M extends keyof (typeof PROVIDERS)[P]["models"]>(
|
||||
provider: P,
|
||||
model: M,
|
||||
apiKey?: string,
|
||||
): ProviderToLLM[P] {
|
||||
const config = PROVIDER_CONFIG[provider as keyof typeof PROVIDER_CONFIG];
|
||||
if (!config) throw new Error(`Unknown provider: ${provider}`);
|
||||
|
||||
const providerData = PROVIDERS[provider];
|
||||
if (!providerData) throw new Error(`Unknown provider: ${provider}`);
|
||||
|
||||
// Type-safe model lookup
|
||||
const models = providerData.models as Record<string, Model>;
|
||||
const modelData = models[model as string];
|
||||
if (!modelData) throw new Error(`Unknown model: ${String(model)} for provider ${provider}`);
|
||||
|
||||
const key = apiKey || process.env[config.envKey];
|
||||
if (!key) throw new Error(`No API key provided for ${provider}. Set ${config.envKey} or pass apiKey.`);
|
||||
|
||||
return config.create(modelData, key) as ProviderToLLM[P];
|
||||
// Initialize registry from PROVIDERS on module load
|
||||
for (const [provider, models] of Object.entries(PROVIDERS)) {
|
||||
const providerModels = new Map<string, Model>();
|
||||
for (const [id, model] of Object.entries(models)) {
|
||||
providerModels.set(id, model as Model);
|
||||
}
|
||||
modelRegistry.set(provider, providerModels);
|
||||
}
|
||||
|
||||
// Helper function to get model info with type-safe model IDs
|
||||
export function getModel<P extends keyof typeof PROVIDERS>(
|
||||
provider: P,
|
||||
modelId: keyof (typeof PROVIDERS)[P]["models"],
|
||||
): Model | undefined {
|
||||
const providerData = PROVIDERS[provider];
|
||||
if (!providerData) return undefined;
|
||||
const models = providerData.models as Record<string, Model>;
|
||||
return models[modelId as string];
|
||||
/**
|
||||
* Get a model from the registry - typed overload for known providers
|
||||
*/
|
||||
export function getModel<P extends KnownProvider>(provider: P, modelId: keyof (typeof PROVIDERS)[P]): Model;
|
||||
export function getModel(provider: string, modelId: string): Model | undefined;
|
||||
export function getModel(provider: any, modelId: any): Model | undefined {
|
||||
return modelRegistry.get(provider)?.get(modelId);
|
||||
}
|
||||
|
||||
export function calculateCost(model: Model, usage: Usage) {
|
||||
/**
|
||||
* Register a custom model
|
||||
*/
|
||||
export function registerModel(model: Model): void {
|
||||
if (!modelRegistry.has(model.provider)) {
|
||||
modelRegistry.set(model.provider, new Map());
|
||||
}
|
||||
modelRegistry.get(model.provider)!.set(model.id, model);
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate cost for token usage
|
||||
*/
|
||||
export function calculateCost(model: Model, usage: Usage): Usage["cost"] {
|
||||
usage.cost.input = (model.cost.input / 1000000) * usage.input;
|
||||
usage.cost.output = (model.cost.output / 1000000) * usage.output;
|
||||
usage.cost.cacheRead = (model.cost.cacheRead / 1000000) * usage.cacheRead;
|
||||
|
|
@ -110,6 +46,3 @@ export function calculateCost(model: Model, usage: Usage) {
|
|||
usage.cost.total = usage.cost.input + usage.cost.output + usage.cost.cacheRead + usage.cost.cacheWrite;
|
||||
return usage.cost;
|
||||
}
|
||||
|
||||
// Re-export Model type for convenience
|
||||
export type { Model };
|
||||
|
|
|
|||
425
packages/ai/src/providers/anthropic-generate.ts
Normal file
425
packages/ai/src/providers/anthropic-generate.ts
Normal file
|
|
@ -0,0 +1,425 @@
|
|||
import Anthropic from "@anthropic-ai/sdk";
|
||||
import type {
|
||||
ContentBlockParam,
|
||||
MessageCreateParamsStreaming,
|
||||
MessageParam,
|
||||
Tool,
|
||||
} from "@anthropic-ai/sdk/resources/messages.js";
|
||||
import { QueuedGenerateStream } from "../generate.js";
|
||||
import { calculateCost } from "../models.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
Context,
|
||||
GenerateFunction,
|
||||
GenerateOptions,
|
||||
GenerateStream,
|
||||
Message,
|
||||
Model,
|
||||
StopReason,
|
||||
TextContent,
|
||||
ThinkingContent,
|
||||
ToolCall,
|
||||
} from "../types.js";
|
||||
import { transformMessages } from "./utils.js";
|
||||
|
||||
// Anthropic-specific options
|
||||
export interface AnthropicOptions extends GenerateOptions {
|
||||
thinking?: {
|
||||
enabled: boolean;
|
||||
budgetTokens?: number;
|
||||
};
|
||||
toolChoice?: "auto" | "any" | "none" | { type: "tool"; name: string };
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate function for Anthropic API
|
||||
*/
|
||||
export const generateAnthropic: GenerateFunction<AnthropicOptions> = (
|
||||
model: Model,
|
||||
context: Context,
|
||||
options: AnthropicOptions,
|
||||
): GenerateStream => {
|
||||
const stream = new QueuedGenerateStream();
|
||||
|
||||
// Start async processing
|
||||
(async () => {
|
||||
const output: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: "anthropic-messages" as Api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "stop",
|
||||
};
|
||||
|
||||
try {
|
||||
// Create Anthropic client
|
||||
const client = createAnthropicClient(model, options.apiKey!);
|
||||
|
||||
// Convert messages
|
||||
const messages = convertMessages(context.messages, model, "anthropic-messages");
|
||||
|
||||
// Build params
|
||||
const params = buildAnthropicParams(model, context, options, messages, client.isOAuthToken);
|
||||
|
||||
// Create Anthropic stream
|
||||
const anthropicStream = client.client.messages.stream(
|
||||
{
|
||||
...params,
|
||||
stream: true,
|
||||
},
|
||||
{
|
||||
signal: options.signal,
|
||||
},
|
||||
);
|
||||
|
||||
// Emit start event
|
||||
stream.push({
|
||||
type: "start",
|
||||
partial: output,
|
||||
});
|
||||
|
||||
// Process Anthropic events
|
||||
let currentBlock: ThinkingContent | TextContent | (ToolCall & { partialJson: string }) | null = null;
|
||||
|
||||
for await (const event of anthropicStream) {
|
||||
if (event.type === "content_block_start") {
|
||||
if (event.content_block.type === "text") {
|
||||
currentBlock = {
|
||||
type: "text",
|
||||
text: "",
|
||||
};
|
||||
output.content.push(currentBlock);
|
||||
stream.push({ type: "text_start", partial: output });
|
||||
} else if (event.content_block.type === "thinking") {
|
||||
currentBlock = {
|
||||
type: "thinking",
|
||||
thinking: "",
|
||||
thinkingSignature: "",
|
||||
};
|
||||
output.content.push(currentBlock);
|
||||
stream.push({ type: "thinking_start", partial: output });
|
||||
} else if (event.content_block.type === "tool_use") {
|
||||
// We wait for the full tool use to be streamed
|
||||
currentBlock = {
|
||||
type: "toolCall",
|
||||
id: event.content_block.id,
|
||||
name: event.content_block.name,
|
||||
arguments: event.content_block.input as Record<string, any>,
|
||||
partialJson: "",
|
||||
};
|
||||
}
|
||||
} else if (event.type === "content_block_delta") {
|
||||
if (event.delta.type === "text_delta") {
|
||||
if (currentBlock && currentBlock.type === "text") {
|
||||
currentBlock.text += event.delta.text;
|
||||
stream.push({
|
||||
type: "text_delta",
|
||||
delta: event.delta.text,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
} else if (event.delta.type === "thinking_delta") {
|
||||
if (currentBlock && currentBlock.type === "thinking") {
|
||||
currentBlock.thinking += event.delta.thinking;
|
||||
stream.push({
|
||||
type: "thinking_delta",
|
||||
delta: event.delta.thinking,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
} else if (event.delta.type === "input_json_delta") {
|
||||
if (currentBlock && currentBlock.type === "toolCall") {
|
||||
currentBlock.partialJson += event.delta.partial_json;
|
||||
}
|
||||
} else if (event.delta.type === "signature_delta") {
|
||||
if (currentBlock && currentBlock.type === "thinking") {
|
||||
currentBlock.thinkingSignature = currentBlock.thinkingSignature || "";
|
||||
currentBlock.thinkingSignature += event.delta.signature;
|
||||
}
|
||||
}
|
||||
} else if (event.type === "content_block_stop") {
|
||||
if (currentBlock) {
|
||||
if (currentBlock.type === "text") {
|
||||
stream.push({ type: "text_end", content: currentBlock.text, partial: output });
|
||||
} else if (currentBlock.type === "thinking") {
|
||||
stream.push({ type: "thinking_end", content: currentBlock.thinking, partial: output });
|
||||
} else if (currentBlock.type === "toolCall") {
|
||||
const finalToolCall: ToolCall = {
|
||||
type: "toolCall",
|
||||
id: currentBlock.id,
|
||||
name: currentBlock.name,
|
||||
arguments: JSON.parse(currentBlock.partialJson),
|
||||
};
|
||||
output.content.push(finalToolCall);
|
||||
stream.push({ type: "toolCall", toolCall: finalToolCall, partial: output });
|
||||
}
|
||||
currentBlock = null;
|
||||
}
|
||||
} else if (event.type === "message_delta") {
|
||||
if (event.delta.stop_reason) {
|
||||
output.stopReason = mapStopReason(event.delta.stop_reason);
|
||||
}
|
||||
output.usage.input += event.usage.input_tokens || 0;
|
||||
output.usage.output += event.usage.output_tokens || 0;
|
||||
output.usage.cacheRead += event.usage.cache_read_input_tokens || 0;
|
||||
output.usage.cacheWrite += event.usage.cache_creation_input_tokens || 0;
|
||||
calculateCost(model, output.usage);
|
||||
}
|
||||
}
|
||||
|
||||
// Emit done event with final message
|
||||
stream.push({ type: "done", reason: output.stopReason, message: output });
|
||||
stream.end();
|
||||
} catch (error) {
|
||||
output.stopReason = "error";
|
||||
output.error = error instanceof Error ? error.message : JSON.stringify(error);
|
||||
stream.push({ type: "error", error: output.error, partial: output });
|
||||
stream.end();
|
||||
}
|
||||
})();
|
||||
|
||||
return stream;
|
||||
};
|
||||
|
||||
// Helper to create Anthropic client
|
||||
interface AnthropicClientWrapper {
|
||||
client: Anthropic;
|
||||
isOAuthToken: boolean;
|
||||
}
|
||||
|
||||
function createAnthropicClient(model: Model, apiKey: string): AnthropicClientWrapper {
|
||||
if (apiKey.includes("sk-ant-oat")) {
|
||||
const defaultHeaders = {
|
||||
accept: "application/json",
|
||||
"anthropic-dangerous-direct-browser-access": "true",
|
||||
"anthropic-beta": "oauth-2025-04-20,fine-grained-tool-streaming-2025-05-14",
|
||||
};
|
||||
|
||||
// Clear the env var if we're in Node.js to prevent SDK from using it
|
||||
if (typeof process !== "undefined" && process.env) {
|
||||
process.env.ANTHROPIC_API_KEY = undefined;
|
||||
}
|
||||
|
||||
const client = new Anthropic({
|
||||
apiKey: null,
|
||||
authToken: apiKey,
|
||||
baseURL: model.baseUrl,
|
||||
defaultHeaders,
|
||||
dangerouslyAllowBrowser: true,
|
||||
});
|
||||
|
||||
return { client, isOAuthToken: true };
|
||||
} else {
|
||||
const defaultHeaders = {
|
||||
accept: "application/json",
|
||||
"anthropic-dangerous-direct-browser-access": "true",
|
||||
"anthropic-beta": "fine-grained-tool-streaming-2025-05-14",
|
||||
};
|
||||
|
||||
const client = new Anthropic({
|
||||
apiKey,
|
||||
baseURL: model.baseUrl,
|
||||
dangerouslyAllowBrowser: true,
|
||||
defaultHeaders,
|
||||
});
|
||||
|
||||
return { client, isOAuthToken: false };
|
||||
}
|
||||
}
|
||||
|
||||
// Build Anthropic API params
|
||||
function buildAnthropicParams(
|
||||
model: Model,
|
||||
context: Context,
|
||||
options: AnthropicOptions,
|
||||
messages: MessageParam[],
|
||||
isOAuthToken: boolean,
|
||||
): MessageCreateParamsStreaming {
|
||||
const params: MessageCreateParamsStreaming = {
|
||||
model: model.id,
|
||||
messages,
|
||||
max_tokens: options.maxTokens || model.maxTokens,
|
||||
stream: true,
|
||||
};
|
||||
|
||||
// For OAuth tokens, we MUST include Claude Code identity
|
||||
if (isOAuthToken) {
|
||||
params.system = [
|
||||
{
|
||||
type: "text",
|
||||
text: "You are Claude Code, Anthropic's official CLI for Claude.",
|
||||
cache_control: {
|
||||
type: "ephemeral",
|
||||
},
|
||||
},
|
||||
];
|
||||
if (context.systemPrompt) {
|
||||
params.system.push({
|
||||
type: "text",
|
||||
text: context.systemPrompt,
|
||||
cache_control: {
|
||||
type: "ephemeral",
|
||||
},
|
||||
});
|
||||
}
|
||||
} else if (context.systemPrompt) {
|
||||
params.system = context.systemPrompt;
|
||||
}
|
||||
|
||||
if (options.temperature !== undefined) {
|
||||
params.temperature = options.temperature;
|
||||
}
|
||||
|
||||
if (context.tools) {
|
||||
params.tools = convertTools(context.tools);
|
||||
}
|
||||
|
||||
// Only enable thinking if the model supports it
|
||||
if (options.thinking?.enabled && model.reasoning) {
|
||||
params.thinking = {
|
||||
type: "enabled",
|
||||
budget_tokens: options.thinking.budgetTokens || 1024,
|
||||
};
|
||||
}
|
||||
|
||||
if (options.toolChoice) {
|
||||
if (typeof options.toolChoice === "string") {
|
||||
params.tool_choice = { type: options.toolChoice };
|
||||
} else {
|
||||
params.tool_choice = options.toolChoice;
|
||||
}
|
||||
}
|
||||
|
||||
return params;
|
||||
}
|
||||
|
||||
// Convert our generic Message list into Anthropic MessageParam format.
// Note: tool results are sent as "user" messages containing a tool_result
// block, and assistant history replays text/thinking/tool_use blocks.
function convertMessages(messages: Message[], model: Model, api: Api): MessageParam[] {
  const params: MessageParam[] = [];

  // Transform messages for cross-provider compatibility first.
  // (Semantics of transformMessages live in ./utils.js — not visible here.)
  const transformedMessages = transformMessages(messages, model, api);

  for (const msg of transformedMessages) {
    if (msg.role === "user") {
      // Handle both string and array content.
      if (typeof msg.content === "string") {
        params.push({
          role: "user",
          content: msg.content,
        });
      } else {
        // Convert array content to Anthropic content blocks.
        const blocks: ContentBlockParam[] = msg.content.map((item) => {
          if (item.type === "text") {
            return {
              type: "text",
              text: item.text,
            };
          } else {
            // Image content, forwarded as base64.
            return {
              type: "image",
              source: {
                type: "base64",
                media_type: item.mimeType as "image/jpeg" | "image/png" | "image/gif" | "image/webp",
                data: item.data,
              },
            };
          }
        });
        // Drop image blocks when the model does not accept image input.
        const filteredBlocks = !model?.input.includes("image") ? blocks.filter((b) => b.type !== "image") : blocks;
        params.push({
          role: "user",
          content: filteredBlocks,
        });
      }
    } else if (msg.role === "assistant") {
      const blocks: ContentBlockParam[] = [];

      for (const block of msg.content) {
        if (block.type === "text") {
          blocks.push({
            type: "text",
            text: block.text,
          });
        } else if (block.type === "thinking") {
          blocks.push({
            type: "thinking",
            thinking: block.thinking,
            // Replayed thinking blocks carry their signature (empty string
            // when one was never recorded).
            signature: block.thinkingSignature || "",
          });
        } else if (block.type === "toolCall") {
          blocks.push({
            type: "tool_use",
            id: block.id,
            name: block.name,
            input: block.arguments,
          });
        }
      }

      params.push({
        role: "assistant",
        content: blocks,
      });
    } else if (msg.role === "toolResult") {
      // Tool results go back to the API as user-authored tool_result blocks
      // tied to the originating call via tool_use_id.
      params.push({
        role: "user",
        content: [
          {
            type: "tool_result",
            tool_use_id: msg.toolCallId,
            content: msg.content,
            is_error: msg.isError,
          },
        ],
      });
    }
  }
  return params;
}
|
||||
|
||||
// Convert tools to Anthropic format
|
||||
function convertTools(tools: Context["tools"]): Tool[] {
|
||||
if (!tools) return [];
|
||||
|
||||
return tools.map((tool) => ({
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
input_schema: {
|
||||
type: "object" as const,
|
||||
properties: tool.parameters.properties || {},
|
||||
required: tool.parameters.required || [],
|
||||
},
|
||||
}));
|
||||
}
|
||||
|
||||
// Map Anthropic stop reason to our StopReason type
|
||||
function mapStopReason(reason: Anthropic.Messages.StopReason | null): StopReason {
|
||||
switch (reason) {
|
||||
case "end_turn":
|
||||
return "stop";
|
||||
case "max_tokens":
|
||||
return "length";
|
||||
case "tool_use":
|
||||
return "toolUse";
|
||||
case "refusal":
|
||||
return "safety";
|
||||
case "pause_turn": // Stop is good enough -> resubmit
|
||||
return "stop";
|
||||
case "stop_sequence":
|
||||
return "stop"; // We don't supply stop sequences, so this should never happen
|
||||
default:
|
||||
return "stop";
|
||||
}
|
||||
}
|
||||
|
|
@ -1,3 +1,38 @@
|
|||
/** API ids with a built-in implementation. */
export type KnownApi = "openai-completions" | "openai-responses" | "anthropic-messages" | "google-generative-ai";
// NOTE(review): "KnownApi | string" collapses to plain string, which loses
// editor completions for the known ids; "KnownApi | (string & {})" would
// keep them while staying assignable. Same applies to Provider below.
export type Api = KnownApi | string;

/** Providers with a built-in key/env-var mapping. */
export type KnownProvider = "anthropic" | "google" | "openai" | "xai" | "groq" | "cerebras" | "openrouter";
export type Provider = KnownProvider | string;

/** Relative reasoning effort; mapped per provider to concrete token budgets. */
export type ReasoningEffort = "minimal" | "low" | "medium" | "high";

// The stream interface — what generate() returns.
export interface GenerateStream extends AsyncIterable<AssistantMessageEvent> {
  /** Resolves with the completed assistant message once streaming finishes. */
  finalMessage(): Promise<AssistantMessage>;
}

// Base options all providers share.
export interface GenerateOptions {
  temperature?: number;
  maxTokens?: number;
  // Aborts the in-flight request.
  signal?: AbortSignal;
  // Per-call key; generate() falls back to setApiKey()/env vars when omitted.
  apiKey?: string;
}

// Unified options with reasoning (what the public generate() accepts).
export interface GenerateOptionsUnified extends GenerateOptions {
  reasoning?: ReasoningEffort;
}

// Generic provider implementation signature with typed options.
export type GenerateFunction<TOptions extends GenerateOptions = GenerateOptions> = (
  model: Model,
  context: Context,
  options: TOptions,
) => GenerateStream;
|
||||
|
||||
// Legacy LLM interface (to be removed)
|
||||
export interface LLMOptions {
|
||||
temperature?: number;
|
||||
maxTokens?: number;
|
||||
|
|
@ -60,11 +95,10 @@ export interface UserMessage {
|
|||
export interface AssistantMessage {
|
||||
role: "assistant";
|
||||
content: (TextContent | ThinkingContent | ToolCall)[];
|
||||
api: string;
|
||||
provider: string;
|
||||
api: Api;
|
||||
provider: Provider;
|
||||
model: string;
|
||||
usage: Usage;
|
||||
|
||||
stopReason: StopReason;
|
||||
error?: string | Error;
|
||||
}
|
||||
|
|
@ -92,23 +126,24 @@ export interface Context {
|
|||
}
|
||||
|
||||
export type AssistantMessageEvent =
|
||||
| { type: "start"; model: string; provider: string }
|
||||
| { type: "text_start" }
|
||||
| { type: "text_delta"; content: string; delta: string }
|
||||
| { type: "text_end"; content: string }
|
||||
| { type: "thinking_start" }
|
||||
| { type: "thinking_delta"; content: string; delta: string }
|
||||
| { type: "thinking_end"; content: string }
|
||||
| { type: "toolCall"; toolCall: ToolCall }
|
||||
| { type: "start"; partial: AssistantMessage }
|
||||
| { type: "text_start"; partial: AssistantMessage }
|
||||
| { type: "text_delta"; delta: string; partial: AssistantMessage }
|
||||
| { type: "text_end"; content: string; partial: AssistantMessage }
|
||||
| { type: "thinking_start"; partial: AssistantMessage }
|
||||
| { type: "thinking_delta"; delta: string; partial: AssistantMessage }
|
||||
| { type: "thinking_end"; content: string; partial: AssistantMessage }
|
||||
| { type: "toolCall"; toolCall: ToolCall; partial: AssistantMessage }
|
||||
| { type: "done"; reason: StopReason; message: AssistantMessage }
|
||||
| { type: "error"; error: string };
|
||||
| { type: "error"; error: string; partial: AssistantMessage };
|
||||
|
||||
// Model interface for the unified model system
|
||||
export interface Model {
|
||||
id: string;
|
||||
name: string;
|
||||
provider: string;
|
||||
baseUrl?: string;
|
||||
api: Api;
|
||||
provider: Provider;
|
||||
baseUrl: string;
|
||||
reasoning: boolean;
|
||||
input: ("text" | "image")[];
|
||||
cost: {
|
||||
|
|
|
|||
311
packages/ai/test/generate.test.ts
Normal file
311
packages/ai/test/generate.test.ts
Normal file
|
|
@ -0,0 +1,311 @@
|
|||
import { describe, it, beforeAll, expect } from "vitest";
|
||||
import { getModel } from "../src/models.js";
|
||||
import { generate, generateComplete } from "../src/generate.js";
|
||||
import type { Context, Tool, GenerateOptionsUnified, Model, ImageContent, GenerateStream, GenerateOptions } from "../src/types.js";
|
||||
import { readFileSync } from "fs";
|
||||
import { join, dirname } from "path";
|
||||
import { fileURLToPath } from "url";
|
||||
|
||||
const __filename = fileURLToPath(import.meta.url);
|
||||
const __dirname = dirname(__filename);
|
||||
|
||||
// Calculator tool definition (same as examples)
// Minimal JSON-Schema tool shared by the tool-calling and multi-turn tests below.
const calculatorTool: Tool = {
  name: "calculator",
  description: "Perform basic arithmetic operations",
  parameters: {
    type: "object",
    properties: {
      a: { type: "number", description: "First number" },
      b: { type: "number", description: "Second number" },
      operation: {
        type: "string",
        enum: ["add", "subtract", "multiply", "divide"],
        description: "The operation to perform"
      }
    },
    required: ["a", "b", "operation"]
  }
};
|
||||
|
||||
async function basicTextGeneration<P extends GenerateOptions>(model: Model, options?: P) {
|
||||
const context: Context = {
|
||||
systemPrompt: "You are a helpful assistant. Be concise.",
|
||||
messages: [
|
||||
{ role: "user", content: "Reply with exactly: 'Hello test successful'" }
|
||||
]
|
||||
};
|
||||
|
||||
const response = await generateComplete(model, context, options);
|
||||
|
||||
expect(response.role).toBe("assistant");
|
||||
expect(response.content).toBeTruthy();
|
||||
expect(response.usage.input + response.usage.cacheRead).toBeGreaterThan(0);
|
||||
expect(response.usage.output).toBeGreaterThan(0);
|
||||
expect(response.error).toBeFalsy();
|
||||
expect(response.content.map(b => b.type == "text" ? b.text : "").join("")).toContain("Hello test successful");
|
||||
|
||||
context.messages.push(response);
|
||||
context.messages.push({ role: "user", content: "Now say 'Goodbye test successful'" });
|
||||
|
||||
const secondResponse = await generateComplete(model, context, options);
|
||||
|
||||
expect(secondResponse.role).toBe("assistant");
|
||||
expect(secondResponse.content).toBeTruthy();
|
||||
expect(secondResponse.usage.input + secondResponse.usage.cacheRead).toBeGreaterThan(0);
|
||||
expect(secondResponse.usage.output).toBeGreaterThan(0);
|
||||
expect(secondResponse.error).toBeFalsy();
|
||||
expect(secondResponse.content.map(b => b.type == "text" ? b.text : "").join("")).toContain("Goodbye test successful");
|
||||
}
|
||||
|
||||
async function handleToolCall(model: Model, options?: GenerateOptionsUnified) {
|
||||
const context: Context = {
|
||||
systemPrompt: "You are a helpful assistant that uses tools when asked.",
|
||||
messages: [{
|
||||
role: "user",
|
||||
content: "Calculate 15 + 27 using the calculator tool."
|
||||
}],
|
||||
tools: [calculatorTool]
|
||||
};
|
||||
|
||||
const response = await generateComplete(model, context, options);
|
||||
expect(response.stopReason).toBe("toolUse");
|
||||
expect(response.content.some(b => b.type == "toolCall")).toBeTruthy();
|
||||
const toolCall = response.content.find(b => b.type == "toolCall");
|
||||
if (toolCall && toolCall.type === "toolCall") {
|
||||
expect(toolCall.name).toBe("calculator");
|
||||
expect(toolCall.id).toBeTruthy();
|
||||
}
|
||||
}
|
||||
|
||||
async function handleStreaming(model: Model, options?: GenerateOptionsUnified) {
|
||||
let textStarted = false;
|
||||
let textChunks = "";
|
||||
let textCompleted = false;
|
||||
|
||||
const context: Context = {
|
||||
messages: [{ role: "user", content: "Count from 1 to 3" }]
|
||||
};
|
||||
|
||||
const stream = generate(model, context, options);
|
||||
|
||||
for await (const event of stream) {
|
||||
if (event.type === "text_start") {
|
||||
textStarted = true;
|
||||
} else if (event.type === "text_delta") {
|
||||
textChunks += event.delta;
|
||||
} else if (event.type === "text_end") {
|
||||
textCompleted = true;
|
||||
}
|
||||
}
|
||||
|
||||
const response = await stream.finalMessage();
|
||||
|
||||
expect(textStarted).toBe(true);
|
||||
expect(textChunks.length).toBeGreaterThan(0);
|
||||
expect(textCompleted).toBe(true);
|
||||
expect(response.content.some(b => b.type == "text")).toBeTruthy();
|
||||
}
|
||||
|
||||
async function handleThinking(model: Model, options: GenerateOptionsUnified) {
|
||||
let thinkingStarted = false;
|
||||
let thinkingChunks = "";
|
||||
let thinkingCompleted = false;
|
||||
|
||||
const context: Context = {
|
||||
messages: [{ role: "user", content: `Think about ${(Math.random() * 255) | 0} + 27. Think step by step. Then output the result.` }]
|
||||
};
|
||||
|
||||
const stream = generate(model, context, options);
|
||||
|
||||
for await (const event of stream) {
|
||||
if (event.type === "thinking_start") {
|
||||
thinkingStarted = true;
|
||||
} else if (event.type === "thinking_delta") {
|
||||
thinkingChunks += event.delta;
|
||||
} else if (event.type === "thinking_end") {
|
||||
thinkingCompleted = true;
|
||||
}
|
||||
}
|
||||
|
||||
const response = await stream.finalMessage();
|
||||
|
||||
expect(response.stopReason, `Error: ${response.error}`).toBe("stop");
|
||||
expect(thinkingStarted).toBe(true);
|
||||
expect(thinkingChunks.length).toBeGreaterThan(0);
|
||||
expect(thinkingCompleted).toBe(true);
|
||||
expect(response.content.some(b => b.type == "thinking")).toBeTruthy();
|
||||
}
|
||||
|
||||
async function handleImage(model: Model, options?: GenerateOptionsUnified) {
|
||||
// Check if the model supports images
|
||||
if (!model.input.includes("image")) {
|
||||
console.log(`Skipping image test - model ${model.id} doesn't support images`);
|
||||
return;
|
||||
}
|
||||
|
||||
// Read the test image
|
||||
const imagePath = join(__dirname, "data", "red-circle.png");
|
||||
const imageBuffer = readFileSync(imagePath);
|
||||
const base64Image = imageBuffer.toString("base64");
|
||||
|
||||
const imageContent: ImageContent = {
|
||||
type: "image",
|
||||
data: base64Image,
|
||||
mimeType: "image/png",
|
||||
};
|
||||
|
||||
const context: Context = {
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: [
|
||||
{ type: "text", text: "What do you see in this image? Please describe the shape (circle, rectangle, square, triangle, ...) and color (red, blue, green, ...)." },
|
||||
imageContent,
|
||||
],
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const response = await generateComplete(model, context, options);
|
||||
|
||||
// Check the response mentions red and circle
|
||||
expect(response.content.length > 0).toBeTruthy();
|
||||
const textContent = response.content.find(b => b.type == "text");
|
||||
if (textContent && textContent.type === "text") {
|
||||
const lowerContent = textContent.text.toLowerCase();
|
||||
expect(lowerContent).toContain("red");
|
||||
expect(lowerContent).toContain("circle");
|
||||
}
|
||||
}
|
||||
|
||||
async function multiTurn(model: Model, options?: GenerateOptionsUnified) {
|
||||
const context: Context = {
|
||||
systemPrompt: "You are a helpful assistant that can use tools to answer questions.",
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: "Think about this briefly, then calculate 42 * 17 and 453 + 434 using the calculator tool."
|
||||
}
|
||||
],
|
||||
tools: [calculatorTool]
|
||||
};
|
||||
|
||||
// Collect all text content from all assistant responses
|
||||
let allTextContent = "";
|
||||
let hasSeenThinking = false;
|
||||
let hasSeenToolCalls = false;
|
||||
const maxTurns = 5; // Prevent infinite loops
|
||||
|
||||
for (let turn = 0; turn < maxTurns; turn++) {
|
||||
const response = await generateComplete(model, context, options);
|
||||
|
||||
// Add the assistant response to context
|
||||
context.messages.push(response);
|
||||
|
||||
// Process content blocks
|
||||
for (const block of response.content) {
|
||||
if (block.type === "text") {
|
||||
allTextContent += block.text;
|
||||
} else if (block.type === "thinking") {
|
||||
hasSeenThinking = true;
|
||||
} else if (block.type === "toolCall") {
|
||||
hasSeenToolCalls = true;
|
||||
|
||||
// Process the tool call
|
||||
expect(block.name).toBe("calculator");
|
||||
expect(block.id).toBeTruthy();
|
||||
expect(block.arguments).toBeTruthy();
|
||||
|
||||
const { a, b, operation } = block.arguments;
|
||||
let result: number;
|
||||
switch (operation) {
|
||||
case "add": result = a + b; break;
|
||||
case "multiply": result = a * b; break;
|
||||
default: result = 0;
|
||||
}
|
||||
|
||||
// Add tool result to context
|
||||
context.messages.push({
|
||||
role: "toolResult",
|
||||
toolCallId: block.id,
|
||||
toolName: block.name,
|
||||
content: `${result}`,
|
||||
isError: false
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// If we got a stop response with text content, we're likely done
|
||||
expect(response.stopReason).not.toBe("error");
|
||||
if (response.stopReason === "stop") {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Verify we got either thinking content or tool calls (or both)
|
||||
expect(hasSeenThinking || hasSeenToolCalls).toBe(true);
|
||||
|
||||
// The accumulated text should reference both calculations
|
||||
expect(allTextContent).toBeTruthy();
|
||||
expect(allTextContent.includes("714")).toBe(true);
|
||||
expect(allTextContent.includes("887")).toBe(true);
|
||||
}
|
||||
|
||||
// End-to-end suites against live provider APIs. Each suite is skipped unless
// the corresponding credential is present in the environment.
describe("Generate E2E Tests", () => {
  // Cheap non-reasoning model: text, tools, streaming, and image input.
  describe.skipIf(!process.env.ANTHROPIC_API_KEY)("Anthropic Provider (claude-3-5-haiku-20241022)", () => {
    let model: Model;

    beforeAll(() => {
      model = getModel("anthropic", "claude-3-5-haiku-20241022");
    });

    it("should complete basic text generation", async () => {
      await basicTextGeneration(model);
    });

    it("should handle tool calling", async () => {
      await handleToolCall(model);
    });

    it("should handle streaming", async () => {
      await handleStreaming(model);
    });

    it("should handle image input", async () => {
      await handleImage(model);
    });
  });

  // Reasoning-capable model authenticated via OAuth token; additionally
  // exercises thinking mode and multi-turn tool use with reasoning enabled.
  describe.skipIf(!process.env.ANTHROPIC_OAUTH_TOKEN)("Anthropic Provider (claude-sonnet-4-20250514)", () => {
    let model: Model;

    beforeAll(() => {
      model = getModel("anthropic", "claude-sonnet-4-20250514");
    });

    it("should complete basic text generation", async () => {
      await basicTextGeneration(model);
    });

    it("should handle tool calling", async () => {
      await handleToolCall(model);
    });

    it("should handle streaming", async () => {
      await handleStreaming(model);
    });

    it("should handle thinking mode", async () => {
      await handleThinking(model, { reasoning: "low" });
    });

    it("should handle multi-turn with thinking and tools", async () => {
      await multiTurn(model, { reasoning: "medium" });
    });

    it("should handle image input", async () => {
      await handleImage(model);
    });
  });
});
|
||||
Loading…
Add table
Add a link
Reference in a new issue