mirror of
https://github.com/getcompanion-ai/co-mono.git
synced 2026-04-22 06:00:26 +00:00
Massive refactor of API
- Switch to function based API - Anthropic SDK style async generator - Fully typed with escape hatches for custom models
This commit is contained in:
parent
004de3c9d0
commit
66cefb236e
29 changed files with 5835 additions and 6225 deletions
|
|
@ -3,6 +3,7 @@
|
|||
import { writeFileSync } from "fs";
|
||||
import { join, dirname } from "path";
|
||||
import { fileURLToPath } from "url";
|
||||
import { Api, KnownProvider, Model } from "../src/types.js";
|
||||
|
||||
const __filename = fileURLToPath(import.meta.url);
|
||||
const __dirname = dirname(__filename);
|
||||
|
|
@ -28,30 +29,13 @@ interface ModelsDevModel {
|
|||
};
|
||||
}
|
||||
|
||||
interface NormalizedModel {
|
||||
id: string;
|
||||
name: string;
|
||||
provider: string;
|
||||
baseUrl?: string;
|
||||
reasoning: boolean;
|
||||
input: ("text" | "image")[];
|
||||
cost: {
|
||||
input: number;
|
||||
output: number;
|
||||
cacheRead: number;
|
||||
cacheWrite: number;
|
||||
};
|
||||
contextWindow: number;
|
||||
maxTokens: number;
|
||||
}
|
||||
|
||||
async function fetchOpenRouterModels(): Promise<NormalizedModel[]> {
|
||||
async function fetchOpenRouterModels(): Promise<Model<any>[]> {
|
||||
try {
|
||||
console.log("Fetching models from OpenRouter API...");
|
||||
const response = await fetch("https://openrouter.ai/api/v1/models");
|
||||
const data = await response.json();
|
||||
|
||||
const models: NormalizedModel[] = [];
|
||||
const models: Model<any>[] = [];
|
||||
|
||||
for (const model of data.data) {
|
||||
// Only include models that support tools
|
||||
|
|
@ -59,27 +43,17 @@ async function fetchOpenRouterModels(): Promise<NormalizedModel[]> {
|
|||
|
||||
// Parse provider from model ID
|
||||
const [providerPrefix] = model.id.split("/");
|
||||
let provider = "";
|
||||
let provider: KnownProvider = "openrouter";
|
||||
let modelKey = model.id;
|
||||
|
||||
// Skip models that we get from models.dev (Anthropic, Google, OpenAI)
|
||||
if (model.id.startsWith("google/") ||
|
||||
model.id.startsWith("openai/") ||
|
||||
model.id.startsWith("anthropic/")) {
|
||||
continue;
|
||||
} else if (model.id.startsWith("x-ai/")) {
|
||||
provider = "xai";
|
||||
modelKey = model.id.replace("x-ai/", "");
|
||||
} else {
|
||||
// All other models go through OpenRouter
|
||||
provider = "openrouter";
|
||||
modelKey = model.id; // Keep full ID for OpenRouter
|
||||
}
|
||||
|
||||
// Skip if not one of our supported providers from OpenRouter
|
||||
if (!["xai", "openrouter"].includes(provider)) {
|
||||
model.id.startsWith("anthropic/") ||
|
||||
model.id.startsWith("x-ai/")) {
|
||||
continue;
|
||||
}
|
||||
modelKey = model.id; // Keep full ID for OpenRouter
|
||||
|
||||
// Parse input modalities
|
||||
const input: ("text" | "image")[] = ["text"];
|
||||
|
|
@ -93,9 +67,11 @@ async function fetchOpenRouterModels(): Promise<NormalizedModel[]> {
|
|||
const cacheReadCost = parseFloat(model.pricing?.input_cache_read || "0") * 1_000_000;
|
||||
const cacheWriteCost = parseFloat(model.pricing?.input_cache_write || "0") * 1_000_000;
|
||||
|
||||
const normalizedModel: NormalizedModel = {
|
||||
const normalizedModel: Model<any> = {
|
||||
id: modelKey,
|
||||
name: model.name,
|
||||
api: "openai-completions",
|
||||
baseUrl: "https://openrouter.ai/api/v1",
|
||||
provider,
|
||||
reasoning: model.supported_parameters?.includes("reasoning") || false,
|
||||
input,
|
||||
|
|
@ -108,14 +84,6 @@ async function fetchOpenRouterModels(): Promise<NormalizedModel[]> {
|
|||
contextWindow: model.context_length || 4096,
|
||||
maxTokens: model.top_provider?.max_completion_tokens || 4096,
|
||||
};
|
||||
|
||||
// Add baseUrl for providers that need it
|
||||
if (provider === "xai") {
|
||||
normalizedModel.baseUrl = "https://api.x.ai/v1";
|
||||
} else if (provider === "openrouter") {
|
||||
normalizedModel.baseUrl = "https://openrouter.ai/api/v1";
|
||||
}
|
||||
|
||||
models.push(normalizedModel);
|
||||
}
|
||||
|
||||
|
|
@ -127,13 +95,13 @@ async function fetchOpenRouterModels(): Promise<NormalizedModel[]> {
|
|||
}
|
||||
}
|
||||
|
||||
async function loadModelsDevData(): Promise<NormalizedModel[]> {
|
||||
async function loadModelsDevData(): Promise<Model<any>[]> {
|
||||
try {
|
||||
console.log("Fetching models from models.dev API...");
|
||||
const response = await fetch("https://models.dev/api.json");
|
||||
const data = await response.json();
|
||||
|
||||
const models: NormalizedModel[] = [];
|
||||
const models: Model<any>[] = [];
|
||||
|
||||
// Process Anthropic models
|
||||
if (data.anthropic?.models) {
|
||||
|
|
@ -144,7 +112,9 @@ async function loadModelsDevData(): Promise<NormalizedModel[]> {
|
|||
models.push({
|
||||
id: modelId,
|
||||
name: m.name || modelId,
|
||||
api: "anthropic-messages",
|
||||
provider: "anthropic",
|
||||
baseUrl: "https://api.anthropic.com",
|
||||
reasoning: m.reasoning === true,
|
||||
input: m.modalities?.input?.includes("image") ? ["text", "image"] : ["text"],
|
||||
cost: {
|
||||
|
|
@ -168,7 +138,9 @@ async function loadModelsDevData(): Promise<NormalizedModel[]> {
|
|||
models.push({
|
||||
id: modelId,
|
||||
name: m.name || modelId,
|
||||
api: "google-generative-ai",
|
||||
provider: "google",
|
||||
baseUrl: "https://generativelanguage.googleapis.com/v1beta",
|
||||
reasoning: m.reasoning === true,
|
||||
input: m.modalities?.input?.includes("image") ? ["text", "image"] : ["text"],
|
||||
cost: {
|
||||
|
|
@ -192,7 +164,9 @@ async function loadModelsDevData(): Promise<NormalizedModel[]> {
|
|||
models.push({
|
||||
id: modelId,
|
||||
name: m.name || modelId,
|
||||
api: "openai-responses",
|
||||
provider: "openai",
|
||||
baseUrl: "https://api.openai.com/v1",
|
||||
reasoning: m.reasoning === true,
|
||||
input: m.modalities?.input?.includes("image") ? ["text", "image"] : ["text"],
|
||||
cost: {
|
||||
|
|
@ -216,6 +190,7 @@ async function loadModelsDevData(): Promise<NormalizedModel[]> {
|
|||
models.push({
|
||||
id: modelId,
|
||||
name: m.name || modelId,
|
||||
api: "openai-completions",
|
||||
provider: "groq",
|
||||
baseUrl: "https://api.groq.com/openai/v1",
|
||||
reasoning: m.reasoning === true,
|
||||
|
|
@ -241,6 +216,7 @@ async function loadModelsDevData(): Promise<NormalizedModel[]> {
|
|||
models.push({
|
||||
id: modelId,
|
||||
name: m.name || modelId,
|
||||
api: "openai-completions",
|
||||
provider: "cerebras",
|
||||
baseUrl: "https://api.cerebras.ai/v1",
|
||||
reasoning: m.reasoning === true,
|
||||
|
|
@ -257,6 +233,32 @@ async function loadModelsDevData(): Promise<NormalizedModel[]> {
|
|||
}
|
||||
}
|
||||
|
||||
// Process xAi models
|
||||
if (data.xai?.models) {
|
||||
for (const [modelId, model] of Object.entries(data.xai.models)) {
|
||||
const m = model as ModelsDevModel;
|
||||
if (m.tool_call !== true) continue;
|
||||
|
||||
models.push({
|
||||
id: modelId,
|
||||
name: m.name || modelId,
|
||||
api: "openai-completions",
|
||||
provider: "xai",
|
||||
baseUrl: "https://api.x.ai/v1",
|
||||
reasoning: m.reasoning === true,
|
||||
input: m.modalities?.input?.includes("image") ? ["text", "image"] : ["text"],
|
||||
cost: {
|
||||
input: m.cost?.input || 0,
|
||||
output: m.cost?.output || 0,
|
||||
cacheRead: m.cost?.cache_read || 0,
|
||||
cacheWrite: m.cost?.cache_write || 0,
|
||||
},
|
||||
contextWindow: m.limit?.context || 4096,
|
||||
maxTokens: m.limit?.output || 4096,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
console.log(`Loaded ${models.length} tool-capable models from models.dev`);
|
||||
return models;
|
||||
} catch (error) {
|
||||
|
|
@ -280,6 +282,8 @@ async function generateModels() {
|
|||
allModels.push({
|
||||
id: "gpt-5-chat-latest",
|
||||
name: "GPT-5 Chat Latest",
|
||||
api: "openai-responses",
|
||||
baseUrl: "https://api.openai.com/v1",
|
||||
provider: "openai",
|
||||
reasoning: false,
|
||||
input: ["text", "image"],
|
||||
|
|
@ -294,8 +298,29 @@ async function generateModels() {
|
|||
});
|
||||
}
|
||||
|
||||
// Add missing Grok models
|
||||
if (!allModels.some(m => m.provider === "xai" && m.id === "grok-code-fast-1")) {
|
||||
allModels.push({
|
||||
id: "grok-code-fast-1",
|
||||
name: "Grok Code Fast 1",
|
||||
api: "openai-completions",
|
||||
baseUrl: "https://api.x.ai/v1",
|
||||
provider: "xai",
|
||||
reasoning: false,
|
||||
input: ["text"],
|
||||
cost: {
|
||||
input: 0.2,
|
||||
output: 1.5,
|
||||
cacheRead: 0.02,
|
||||
cacheWrite: 0,
|
||||
},
|
||||
contextWindow: 32768,
|
||||
maxTokens: 8192,
|
||||
});
|
||||
}
|
||||
|
||||
// Group by provider and deduplicate by model ID
|
||||
const providers: Record<string, Record<string, NormalizedModel>> = {};
|
||||
const providers: Record<string, Record<string, Model<any>>> = {};
|
||||
for (const model of allModels) {
|
||||
if (!providers[model.provider]) {
|
||||
providers[model.provider] = {};
|
||||
|
|
@ -319,39 +344,33 @@ export const PROVIDERS = {
|
|||
// Generate provider sections
|
||||
for (const [providerId, models] of Object.entries(providers)) {
|
||||
output += `\t${providerId}: {\n`;
|
||||
output += `\t\tmodels: {\n`;
|
||||
|
||||
for (const model of Object.values(models)) {
|
||||
output += `\t\t\t"${model.id}": {\n`;
|
||||
output += `\t\t\t\tid: "${model.id}",\n`;
|
||||
output += `\t\t\t\tname: "${model.name}",\n`;
|
||||
output += `\t\t\t\tprovider: "${model.provider}",\n`;
|
||||
output += `\t\t"${model.id}": {\n`;
|
||||
output += `\t\t\tid: "${model.id}",\n`;
|
||||
output += `\t\t\tname: "${model.name}",\n`;
|
||||
output += `\t\t\tapi: "${model.api}",\n`;
|
||||
output += `\t\t\tprovider: "${model.provider}",\n`;
|
||||
if (model.baseUrl) {
|
||||
output += `\t\t\t\tbaseUrl: "${model.baseUrl}",\n`;
|
||||
output += `\t\t\tbaseUrl: "${model.baseUrl}",\n`;
|
||||
}
|
||||
output += `\t\t\t\treasoning: ${model.reasoning},\n`;
|
||||
output += `\t\t\t\tinput: ${JSON.stringify(model.input)},\n`;
|
||||
output += `\t\t\t\tcost: {\n`;
|
||||
output += `\t\t\t\t\tinput: ${model.cost.input},\n`;
|
||||
output += `\t\t\t\t\toutput: ${model.cost.output},\n`;
|
||||
output += `\t\t\t\t\tcacheRead: ${model.cost.cacheRead},\n`;
|
||||
output += `\t\t\t\t\tcacheWrite: ${model.cost.cacheWrite},\n`;
|
||||
output += `\t\t\t\t},\n`;
|
||||
output += `\t\t\t\tcontextWindow: ${model.contextWindow},\n`;
|
||||
output += `\t\t\t\tmaxTokens: ${model.maxTokens},\n`;
|
||||
output += `\t\t\t} satisfies Model,\n`;
|
||||
output += `\t\t\treasoning: ${model.reasoning},\n`;
|
||||
output += `\t\t\tinput: [${model.input.map(i => `"${i}"`).join(", ")}],\n`;
|
||||
output += `\t\t\tcost: {\n`;
|
||||
output += `\t\t\t\tinput: ${model.cost.input},\n`;
|
||||
output += `\t\t\t\toutput: ${model.cost.output},\n`;
|
||||
output += `\t\t\t\tcacheRead: ${model.cost.cacheRead},\n`;
|
||||
output += `\t\t\t\tcacheWrite: ${model.cost.cacheWrite},\n`;
|
||||
output += `\t\t\t},\n`;
|
||||
output += `\t\t\tcontextWindow: ${model.contextWindow},\n`;
|
||||
output += `\t\t\tmaxTokens: ${model.maxTokens},\n`;
|
||||
output += `\t\t} satisfies Model<"${model.api}">,\n`;
|
||||
}
|
||||
|
||||
output += `\t\t}\n`;
|
||||
output += `\t},\n`;
|
||||
}
|
||||
|
||||
output += `} as const;
|
||||
|
||||
// Helper type to extract models for each provider
|
||||
export type ProviderModels = {
|
||||
[K in keyof typeof PROVIDERS]: typeof PROVIDERS[K]["models"]
|
||||
};
|
||||
`;
|
||||
|
||||
// Write file
|
||||
|
|
|
|||
|
|
@ -1,47 +1,43 @@
|
|||
import { type AnthropicOptions, streamAnthropic } from "./providers/anthropic.js";
|
||||
import { type GoogleOptions, streamGoogle } from "./providers/google.js";
|
||||
import { type OpenAICompletionsOptions, streamOpenAICompletions } from "./providers/openai-completions.js";
|
||||
import { type OpenAIResponsesOptions, streamOpenAIResponses } from "./providers/openai-responses.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
AssistantMessageEvent,
|
||||
Context,
|
||||
GenerateFunction,
|
||||
GenerateOptionsUnified,
|
||||
GenerateStream,
|
||||
KnownProvider,
|
||||
Model,
|
||||
OptionsForApi,
|
||||
ReasoningEffort,
|
||||
SimpleGenerateOptions,
|
||||
} from "./types.js";
|
||||
|
||||
export class QueuedGenerateStream implements GenerateStream {
|
||||
private queue: AssistantMessageEvent[] = [];
|
||||
private waiting: ((value: IteratorResult<AssistantMessageEvent>) => void)[] = [];
|
||||
private done = false;
|
||||
private error?: Error;
|
||||
private finalMessagePromise: Promise<AssistantMessage>;
|
||||
private resolveFinalMessage!: (message: AssistantMessage) => void;
|
||||
private rejectFinalMessage!: (error: Error) => void;
|
||||
|
||||
constructor() {
|
||||
this.finalMessagePromise = new Promise((resolve, reject) => {
|
||||
this.finalMessagePromise = new Promise((resolve) => {
|
||||
this.resolveFinalMessage = resolve;
|
||||
this.rejectFinalMessage = reject;
|
||||
});
|
||||
}
|
||||
|
||||
push(event: AssistantMessageEvent): void {
|
||||
if (this.done) return;
|
||||
|
||||
// If it's the done event, resolve the final message
|
||||
if (event.type === "done") {
|
||||
this.done = true;
|
||||
this.resolveFinalMessage(event.message);
|
||||
}
|
||||
|
||||
// If it's an error event, reject the final message
|
||||
if (event.type === "error") {
|
||||
this.error = new Error(event.error);
|
||||
if (!this.done) {
|
||||
this.rejectFinalMessage(this.error);
|
||||
}
|
||||
this.done = true;
|
||||
this.resolveFinalMessage(event.partial);
|
||||
}
|
||||
|
||||
// Deliver to waiting consumer or queue it
|
||||
|
|
@ -86,31 +82,14 @@ export class QueuedGenerateStream implements GenerateStream {
|
|||
}
|
||||
}
|
||||
|
||||
// API implementations registry
|
||||
const apiImplementations: Map<Api | string, GenerateFunction> = new Map();
|
||||
|
||||
/**
|
||||
* Register a custom API implementation
|
||||
*/
|
||||
export function registerApi(api: string, impl: GenerateFunction): void {
|
||||
apiImplementations.set(api, impl);
|
||||
}
|
||||
|
||||
// API key storage
|
||||
const apiKeys: Map<string, string> = new Map();
|
||||
|
||||
/**
|
||||
* Set an API key for a provider
|
||||
*/
|
||||
export function setApiKey(provider: KnownProvider, key: string): void;
|
||||
export function setApiKey(provider: string, key: string): void;
|
||||
export function setApiKey(provider: any, key: string): void {
|
||||
apiKeys.set(provider, key);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get API key for a provider
|
||||
*/
|
||||
export function getApiKey(provider: KnownProvider): string | undefined;
|
||||
export function getApiKey(provider: string): string | undefined;
|
||||
export function getApiKey(provider: any): string | undefined {
|
||||
|
|
@ -133,45 +112,76 @@ export function getApiKey(provider: any): string | undefined {
|
|||
return envVar ? process.env[envVar] : undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Main generate function
|
||||
*/
|
||||
export function generate(model: Model, context: Context, options?: GenerateOptionsUnified): GenerateStream {
|
||||
// Get implementation
|
||||
const impl = apiImplementations.get(model.api);
|
||||
if (!impl) {
|
||||
throw new Error(`Unsupported API: ${model.api}`);
|
||||
export function stream<TApi extends Api>(
|
||||
model: Model<TApi>,
|
||||
context: Context,
|
||||
options?: OptionsForApi<TApi>,
|
||||
): GenerateStream {
|
||||
const apiKey = options?.apiKey || getApiKey(model.provider);
|
||||
if (!apiKey) {
|
||||
throw new Error(`No API key for provider: ${model.provider}`);
|
||||
}
|
||||
const providerOptions = { ...options, apiKey };
|
||||
|
||||
// Get API key from options or environment
|
||||
const api: Api = model.api;
|
||||
switch (api) {
|
||||
case "anthropic-messages":
|
||||
return streamAnthropic(model as Model<"anthropic-messages">, context, providerOptions);
|
||||
|
||||
case "openai-completions":
|
||||
return streamOpenAICompletions(model as Model<"openai-completions">, context, providerOptions as any);
|
||||
|
||||
case "openai-responses":
|
||||
return streamOpenAIResponses(model as Model<"openai-responses">, context, providerOptions as any);
|
||||
|
||||
case "google-generative-ai":
|
||||
return streamGoogle(model as Model<"google-generative-ai">, context, providerOptions);
|
||||
|
||||
default: {
|
||||
// This should never be reached if all Api cases are handled
|
||||
const _exhaustive: never = api;
|
||||
throw new Error(`Unhandled API: ${_exhaustive}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export async function complete<TApi extends Api>(
|
||||
model: Model<TApi>,
|
||||
context: Context,
|
||||
options?: OptionsForApi<TApi>,
|
||||
): Promise<AssistantMessage> {
|
||||
const s = stream(model, context, options);
|
||||
return s.finalMessage();
|
||||
}
|
||||
|
||||
export function streamSimple<TApi extends Api>(
|
||||
model: Model<TApi>,
|
||||
context: Context,
|
||||
options?: SimpleGenerateOptions,
|
||||
): GenerateStream {
|
||||
const apiKey = options?.apiKey || getApiKey(model.provider);
|
||||
if (!apiKey) {
|
||||
throw new Error(`No API key for provider: ${model.provider}`);
|
||||
}
|
||||
|
||||
// Map generic options to provider-specific
|
||||
const providerOptions = mapOptionsForApi(model.api, model, options, apiKey);
|
||||
|
||||
// Return the GenerateStream from implementation
|
||||
return impl(model, context, providerOptions);
|
||||
const providerOptions = mapOptionsForApi(model, options, apiKey);
|
||||
return stream(model, context, providerOptions);
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper to generate and get complete response (no streaming)
|
||||
*/
|
||||
export async function generateComplete(
|
||||
model: Model,
|
||||
export async function completeSimple<TApi extends Api>(
|
||||
model: Model<TApi>,
|
||||
context: Context,
|
||||
options?: GenerateOptionsUnified,
|
||||
options?: SimpleGenerateOptions,
|
||||
): Promise<AssistantMessage> {
|
||||
const stream = generate(model, context, options);
|
||||
return stream.finalMessage();
|
||||
const s = streamSimple(model, context, options);
|
||||
return s.finalMessage();
|
||||
}
|
||||
|
||||
/**
|
||||
* Map generic options to provider-specific options
|
||||
*/
|
||||
function mapOptionsForApi(api: Api | string, model: Model, options?: GenerateOptionsUnified, apiKey?: string): any {
|
||||
function mapOptionsForApi<TApi extends Api>(
|
||||
model: Model<TApi>,
|
||||
options?: SimpleGenerateOptions,
|
||||
apiKey?: string,
|
||||
): OptionsForApi<TApi> {
|
||||
const base = {
|
||||
temperature: options?.temperature,
|
||||
maxTokens: options?.maxTokens,
|
||||
|
|
@ -179,18 +189,10 @@ function mapOptionsForApi(api: Api | string, model: Model, options?: GenerateOpt
|
|||
apiKey: apiKey || options?.apiKey,
|
||||
};
|
||||
|
||||
switch (api) {
|
||||
case "openai-responses":
|
||||
case "openai-completions":
|
||||
return {
|
||||
...base,
|
||||
reasoning_effort: options?.reasoning,
|
||||
};
|
||||
|
||||
switch (model.api) {
|
||||
case "anthropic-messages": {
|
||||
if (!options?.reasoning) return base;
|
||||
if (!options?.reasoning) return base satisfies AnthropicOptions;
|
||||
|
||||
// Map effort to token budget
|
||||
const anthropicBudgets = {
|
||||
minimal: 1024,
|
||||
low: 2048,
|
||||
|
|
@ -200,55 +202,60 @@ function mapOptionsForApi(api: Api | string, model: Model, options?: GenerateOpt
|
|||
|
||||
return {
|
||||
...base,
|
||||
thinking: {
|
||||
enabled: true,
|
||||
budgetTokens: anthropicBudgets[options.reasoning],
|
||||
},
|
||||
};
|
||||
thinkingEnabled: true,
|
||||
thinkingBudgetTokens: anthropicBudgets[options.reasoning],
|
||||
} satisfies AnthropicOptions;
|
||||
}
|
||||
case "google-generative-ai": {
|
||||
if (!options?.reasoning) return { ...base, thinking_budget: -1 };
|
||||
|
||||
// Model-specific mapping for Google
|
||||
const googleBudget = getGoogleBudget(model, options.reasoning);
|
||||
case "openai-completions":
|
||||
return {
|
||||
...base,
|
||||
thinking_budget: googleBudget,
|
||||
};
|
||||
reasoningEffort: options?.reasoning,
|
||||
} satisfies OpenAICompletionsOptions;
|
||||
|
||||
case "openai-responses":
|
||||
return {
|
||||
...base,
|
||||
reasoningEffort: options?.reasoning,
|
||||
} satisfies OpenAIResponsesOptions;
|
||||
|
||||
case "google-generative-ai": {
|
||||
if (!options?.reasoning) return base as any;
|
||||
|
||||
const googleBudget = getGoogleBudget(model as Model<"google-generative-ai">, options.reasoning);
|
||||
return {
|
||||
...base,
|
||||
thinking: {
|
||||
enabled: true,
|
||||
budgetTokens: googleBudget,
|
||||
},
|
||||
} satisfies GoogleOptions;
|
||||
}
|
||||
|
||||
default: {
|
||||
// Exhaustiveness check
|
||||
const _exhaustive: never = model.api;
|
||||
throw new Error(`Unhandled API in mapOptionsForApi: ${_exhaustive}`);
|
||||
}
|
||||
default:
|
||||
return base;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get Google thinking budget based on model and effort
|
||||
*/
|
||||
function getGoogleBudget(model: Model, effort: ReasoningEffort): number {
|
||||
// Model-specific logic
|
||||
if (model.id.includes("flash-lite")) {
|
||||
const budgets = {
|
||||
minimal: 512,
|
||||
low: 2048,
|
||||
medium: 8192,
|
||||
high: 24576,
|
||||
};
|
||||
return budgets[effort];
|
||||
}
|
||||
|
||||
if (model.id.includes("pro")) {
|
||||
function getGoogleBudget(model: Model<"google-generative-ai">, effort: ReasoningEffort): number {
|
||||
// See https://ai.google.dev/gemini-api/docs/thinking#set-budget
|
||||
if (model.id.includes("2.5-pro")) {
|
||||
const budgets = {
|
||||
minimal: 128,
|
||||
low: 2048,
|
||||
medium: 8192,
|
||||
high: Math.min(25000, 32768),
|
||||
high: 32768,
|
||||
};
|
||||
return budgets[effort];
|
||||
}
|
||||
|
||||
if (model.id.includes("flash")) {
|
||||
if (model.id.includes("2.5-flash")) {
|
||||
// Covers 2.5-flash-lite as well
|
||||
const budgets = {
|
||||
minimal: 0, // Disable thinking
|
||||
minimal: 128,
|
||||
low: 2048,
|
||||
medium: 8192,
|
||||
high: 24576,
|
||||
|
|
@ -259,10 +266,3 @@ function getGoogleBudget(model: Model, effort: ReasoningEffort): number {
|
|||
// Unknown model - use dynamic
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Register built-in API implementations
|
||||
// Import the new function-based implementations
|
||||
import { generateAnthropic } from "./providers/anthropic-generate.js";
|
||||
|
||||
// Register Anthropic implementation
|
||||
apiImplementations.set("anthropic-messages", generateAnthropic);
|
||||
|
|
|
|||
|
|
@ -1,37 +1,8 @@
|
|||
// @mariozechner/pi-ai - Unified LLM API with automatic model discovery
|
||||
// This package provides a common interface for working with multiple LLM providers
|
||||
|
||||
export const version = "0.5.8";
|
||||
|
||||
// Export generate API
|
||||
export {
|
||||
generate,
|
||||
generateComplete,
|
||||
getApiKey,
|
||||
QueuedGenerateStream,
|
||||
registerApi,
|
||||
setApiKey,
|
||||
} from "./generate.js";
|
||||
// Export generated models data
|
||||
export { PROVIDERS } from "./models.generated.js";
|
||||
// Export model utilities
|
||||
export {
|
||||
calculateCost,
|
||||
getModel,
|
||||
type KnownProvider,
|
||||
registerModel,
|
||||
} from "./models.js";
|
||||
|
||||
// Legacy providers (to be deprecated)
|
||||
export { AnthropicLLM } from "./providers/anthropic.js";
|
||||
export { GoogleLLM } from "./providers/google.js";
|
||||
export { OpenAICompletionsLLM } from "./providers/openai-completions.js";
|
||||
export { OpenAIResponsesLLM } from "./providers/openai-responses.js";
|
||||
|
||||
// Export types
|
||||
export type * from "./types.js";
|
||||
|
||||
// TODO: Remove these legacy exports once consumers are updated
|
||||
export function createLLM(): never {
|
||||
throw new Error("createLLM is deprecated. Use generate() with getModel() instead.");
|
||||
}
|
||||
export * from "./generate.js";
|
||||
export * from "./models.generated.js";
|
||||
export * from "./models.js";
|
||||
export * from "./providers/anthropic.js";
|
||||
export * from "./providers/google.js";
|
||||
export * from "./providers/openai-completions.js";
|
||||
export * from "./providers/openai-responses.js";
|
||||
export * from "./types.js";
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -1,44 +1,39 @@
|
|||
import { PROVIDERS } from "./models.generated.js";
|
||||
import type { KnownProvider, Model, Usage } from "./types.js";
|
||||
import type { Api, KnownProvider, Model, Usage } from "./types.js";
|
||||
|
||||
// Re-export Model type
|
||||
export type { KnownProvider, Model } from "./types.js";
|
||||
|
||||
// Dynamic model registry initialized from PROVIDERS
|
||||
const modelRegistry: Map<string, Map<string, Model>> = new Map();
|
||||
const modelRegistry: Map<string, Map<string, Model<Api>>> = new Map();
|
||||
|
||||
// Initialize registry from PROVIDERS on module load
|
||||
for (const [provider, models] of Object.entries(PROVIDERS)) {
|
||||
const providerModels = new Map<string, Model>();
|
||||
const providerModels = new Map<string, Model<Api>>();
|
||||
for (const [id, model] of Object.entries(models)) {
|
||||
providerModels.set(id, model as Model);
|
||||
providerModels.set(id, model as Model<Api>);
|
||||
}
|
||||
modelRegistry.set(provider, providerModels);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a model from the registry - typed overload for known providers
|
||||
*/
|
||||
export function getModel<P extends KnownProvider>(provider: P, modelId: keyof (typeof PROVIDERS)[P]): Model;
|
||||
export function getModel(provider: string, modelId: string): Model | undefined;
|
||||
export function getModel(provider: any, modelId: any): Model | undefined {
|
||||
return modelRegistry.get(provider)?.get(modelId);
|
||||
type ModelApi<
|
||||
TProvider extends KnownProvider,
|
||||
TModelId extends keyof (typeof PROVIDERS)[TProvider],
|
||||
> = (typeof PROVIDERS)[TProvider][TModelId] extends { api: infer TApi } ? (TApi extends Api ? TApi : never) : never;
|
||||
|
||||
export function getModel<TProvider extends KnownProvider, TModelId extends keyof (typeof PROVIDERS)[TProvider]>(
|
||||
provider: TProvider,
|
||||
modelId: TModelId,
|
||||
): Model<ModelApi<TProvider, TModelId>>;
|
||||
export function getModel<TApi extends Api>(provider: string, modelId: string): Model<TApi> | undefined;
|
||||
export function getModel<TApi extends Api>(provider: any, modelId: any): Model<TApi> | undefined {
|
||||
return modelRegistry.get(provider)?.get(modelId) as Model<TApi> | undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Register a custom model
|
||||
*/
|
||||
export function registerModel(model: Model): void {
|
||||
export function registerModel<TApi extends Api>(model: Model<TApi>): void {
|
||||
if (!modelRegistry.has(model.provider)) {
|
||||
modelRegistry.set(model.provider, new Map());
|
||||
}
|
||||
modelRegistry.get(model.provider)!.set(model.id, model);
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate cost for token usage
|
||||
*/
|
||||
export function calculateCost(model: Model, usage: Usage): Usage["cost"] {
|
||||
export function calculateCost<TApi extends Api>(model: Model<TApi>, usage: Usage): Usage["cost"] {
|
||||
usage.cost.input = (model.cost.input / 1000000) * usage.input;
|
||||
usage.cost.output = (model.cost.output / 1000000) * usage.output;
|
||||
usage.cost.cacheRead = (model.cost.cacheRead / 1000000) * usage.cacheRead;
|
||||
|
|
|
|||
|
|
@ -1,425 +0,0 @@
|
|||
import Anthropic from "@anthropic-ai/sdk";
|
||||
import type {
|
||||
ContentBlockParam,
|
||||
MessageCreateParamsStreaming,
|
||||
MessageParam,
|
||||
Tool,
|
||||
} from "@anthropic-ai/sdk/resources/messages.js";
|
||||
import { QueuedGenerateStream } from "../generate.js";
|
||||
import { calculateCost } from "../models.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
Context,
|
||||
GenerateFunction,
|
||||
GenerateOptions,
|
||||
GenerateStream,
|
||||
Message,
|
||||
Model,
|
||||
StopReason,
|
||||
TextContent,
|
||||
ThinkingContent,
|
||||
ToolCall,
|
||||
} from "../types.js";
|
||||
import { transformMessages } from "./utils.js";
|
||||
|
||||
// Anthropic-specific options
|
||||
export interface AnthropicOptions extends GenerateOptions {
|
||||
thinking?: {
|
||||
enabled: boolean;
|
||||
budgetTokens?: number;
|
||||
};
|
||||
toolChoice?: "auto" | "any" | "none" | { type: "tool"; name: string };
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate function for Anthropic API
|
||||
*/
|
||||
export const generateAnthropic: GenerateFunction<AnthropicOptions> = (
|
||||
model: Model,
|
||||
context: Context,
|
||||
options: AnthropicOptions,
|
||||
): GenerateStream => {
|
||||
const stream = new QueuedGenerateStream();
|
||||
|
||||
// Start async processing
|
||||
(async () => {
|
||||
const output: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: "anthropic-messages" as Api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "stop",
|
||||
};
|
||||
|
||||
try {
|
||||
// Create Anthropic client
|
||||
const client = createAnthropicClient(model, options.apiKey!);
|
||||
|
||||
// Convert messages
|
||||
const messages = convertMessages(context.messages, model, "anthropic-messages");
|
||||
|
||||
// Build params
|
||||
const params = buildAnthropicParams(model, context, options, messages, client.isOAuthToken);
|
||||
|
||||
// Create Anthropic stream
|
||||
const anthropicStream = client.client.messages.stream(
|
||||
{
|
||||
...params,
|
||||
stream: true,
|
||||
},
|
||||
{
|
||||
signal: options.signal,
|
||||
},
|
||||
);
|
||||
|
||||
// Emit start event
|
||||
stream.push({
|
||||
type: "start",
|
||||
partial: output,
|
||||
});
|
||||
|
||||
// Process Anthropic events
|
||||
let currentBlock: ThinkingContent | TextContent | (ToolCall & { partialJson: string }) | null = null;
|
||||
|
||||
for await (const event of anthropicStream) {
|
||||
if (event.type === "content_block_start") {
|
||||
if (event.content_block.type === "text") {
|
||||
currentBlock = {
|
||||
type: "text",
|
||||
text: "",
|
||||
};
|
||||
output.content.push(currentBlock);
|
||||
stream.push({ type: "text_start", partial: output });
|
||||
} else if (event.content_block.type === "thinking") {
|
||||
currentBlock = {
|
||||
type: "thinking",
|
||||
thinking: "",
|
||||
thinkingSignature: "",
|
||||
};
|
||||
output.content.push(currentBlock);
|
||||
stream.push({ type: "thinking_start", partial: output });
|
||||
} else if (event.content_block.type === "tool_use") {
|
||||
// We wait for the full tool use to be streamed
|
||||
currentBlock = {
|
||||
type: "toolCall",
|
||||
id: event.content_block.id,
|
||||
name: event.content_block.name,
|
||||
arguments: event.content_block.input as Record<string, any>,
|
||||
partialJson: "",
|
||||
};
|
||||
}
|
||||
} else if (event.type === "content_block_delta") {
|
||||
if (event.delta.type === "text_delta") {
|
||||
if (currentBlock && currentBlock.type === "text") {
|
||||
currentBlock.text += event.delta.text;
|
||||
stream.push({
|
||||
type: "text_delta",
|
||||
delta: event.delta.text,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
} else if (event.delta.type === "thinking_delta") {
|
||||
if (currentBlock && currentBlock.type === "thinking") {
|
||||
currentBlock.thinking += event.delta.thinking;
|
||||
stream.push({
|
||||
type: "thinking_delta",
|
||||
delta: event.delta.thinking,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
} else if (event.delta.type === "input_json_delta") {
|
||||
if (currentBlock && currentBlock.type === "toolCall") {
|
||||
currentBlock.partialJson += event.delta.partial_json;
|
||||
}
|
||||
} else if (event.delta.type === "signature_delta") {
|
||||
if (currentBlock && currentBlock.type === "thinking") {
|
||||
currentBlock.thinkingSignature = currentBlock.thinkingSignature || "";
|
||||
currentBlock.thinkingSignature += event.delta.signature;
|
||||
}
|
||||
}
|
||||
} else if (event.type === "content_block_stop") {
|
||||
if (currentBlock) {
|
||||
if (currentBlock.type === "text") {
|
||||
stream.push({ type: "text_end", content: currentBlock.text, partial: output });
|
||||
} else if (currentBlock.type === "thinking") {
|
||||
stream.push({ type: "thinking_end", content: currentBlock.thinking, partial: output });
|
||||
} else if (currentBlock.type === "toolCall") {
|
||||
const finalToolCall: ToolCall = {
|
||||
type: "toolCall",
|
||||
id: currentBlock.id,
|
||||
name: currentBlock.name,
|
||||
arguments: JSON.parse(currentBlock.partialJson),
|
||||
};
|
||||
output.content.push(finalToolCall);
|
||||
stream.push({ type: "toolCall", toolCall: finalToolCall, partial: output });
|
||||
}
|
||||
currentBlock = null;
|
||||
}
|
||||
} else if (event.type === "message_delta") {
|
||||
if (event.delta.stop_reason) {
|
||||
output.stopReason = mapStopReason(event.delta.stop_reason);
|
||||
}
|
||||
output.usage.input += event.usage.input_tokens || 0;
|
||||
output.usage.output += event.usage.output_tokens || 0;
|
||||
output.usage.cacheRead += event.usage.cache_read_input_tokens || 0;
|
||||
output.usage.cacheWrite += event.usage.cache_creation_input_tokens || 0;
|
||||
calculateCost(model, output.usage);
|
||||
}
|
||||
}
|
||||
|
||||
// Emit done event with final message
|
||||
stream.push({ type: "done", reason: output.stopReason, message: output });
|
||||
stream.end();
|
||||
} catch (error) {
|
||||
output.stopReason = "error";
|
||||
output.error = error instanceof Error ? error.message : JSON.stringify(error);
|
||||
stream.push({ type: "error", error: output.error, partial: output });
|
||||
stream.end();
|
||||
}
|
||||
})();
|
||||
|
||||
return stream;
|
||||
};
|
||||
|
||||
// Helper to create Anthropic client
|
||||
interface AnthropicClientWrapper {
|
||||
client: Anthropic;
|
||||
isOAuthToken: boolean;
|
||||
}
|
||||
|
||||
function createAnthropicClient(model: Model, apiKey: string): AnthropicClientWrapper {
|
||||
if (apiKey.includes("sk-ant-oat")) {
|
||||
const defaultHeaders = {
|
||||
accept: "application/json",
|
||||
"anthropic-dangerous-direct-browser-access": "true",
|
||||
"anthropic-beta": "oauth-2025-04-20,fine-grained-tool-streaming-2025-05-14",
|
||||
};
|
||||
|
||||
// Clear the env var if we're in Node.js to prevent SDK from using it
|
||||
if (typeof process !== "undefined" && process.env) {
|
||||
process.env.ANTHROPIC_API_KEY = undefined;
|
||||
}
|
||||
|
||||
const client = new Anthropic({
|
||||
apiKey: null,
|
||||
authToken: apiKey,
|
||||
baseURL: model.baseUrl,
|
||||
defaultHeaders,
|
||||
dangerouslyAllowBrowser: true,
|
||||
});
|
||||
|
||||
return { client, isOAuthToken: true };
|
||||
} else {
|
||||
const defaultHeaders = {
|
||||
accept: "application/json",
|
||||
"anthropic-dangerous-direct-browser-access": "true",
|
||||
"anthropic-beta": "fine-grained-tool-streaming-2025-05-14",
|
||||
};
|
||||
|
||||
const client = new Anthropic({
|
||||
apiKey,
|
||||
baseURL: model.baseUrl,
|
||||
dangerouslyAllowBrowser: true,
|
||||
defaultHeaders,
|
||||
});
|
||||
|
||||
return { client, isOAuthToken: false };
|
||||
}
|
||||
}
|
||||
|
||||
// Build Anthropic API params
|
||||
function buildAnthropicParams(
|
||||
model: Model,
|
||||
context: Context,
|
||||
options: AnthropicOptions,
|
||||
messages: MessageParam[],
|
||||
isOAuthToken: boolean,
|
||||
): MessageCreateParamsStreaming {
|
||||
const params: MessageCreateParamsStreaming = {
|
||||
model: model.id,
|
||||
messages,
|
||||
max_tokens: options.maxTokens || model.maxTokens,
|
||||
stream: true,
|
||||
};
|
||||
|
||||
// For OAuth tokens, we MUST include Claude Code identity
|
||||
if (isOAuthToken) {
|
||||
params.system = [
|
||||
{
|
||||
type: "text",
|
||||
text: "You are Claude Code, Anthropic's official CLI for Claude.",
|
||||
cache_control: {
|
||||
type: "ephemeral",
|
||||
},
|
||||
},
|
||||
];
|
||||
if (context.systemPrompt) {
|
||||
params.system.push({
|
||||
type: "text",
|
||||
text: context.systemPrompt,
|
||||
cache_control: {
|
||||
type: "ephemeral",
|
||||
},
|
||||
});
|
||||
}
|
||||
} else if (context.systemPrompt) {
|
||||
params.system = context.systemPrompt;
|
||||
}
|
||||
|
||||
if (options.temperature !== undefined) {
|
||||
params.temperature = options.temperature;
|
||||
}
|
||||
|
||||
if (context.tools) {
|
||||
params.tools = convertTools(context.tools);
|
||||
}
|
||||
|
||||
// Only enable thinking if the model supports it
|
||||
if (options.thinking?.enabled && model.reasoning) {
|
||||
params.thinking = {
|
||||
type: "enabled",
|
||||
budget_tokens: options.thinking.budgetTokens || 1024,
|
||||
};
|
||||
}
|
||||
|
||||
if (options.toolChoice) {
|
||||
if (typeof options.toolChoice === "string") {
|
||||
params.tool_choice = { type: options.toolChoice };
|
||||
} else {
|
||||
params.tool_choice = options.toolChoice;
|
||||
}
|
||||
}
|
||||
|
||||
return params;
|
||||
}
|
||||
|
||||
// Convert messages to Anthropic format
|
||||
function convertMessages(messages: Message[], model: Model, api: Api): MessageParam[] {
|
||||
const params: MessageParam[] = [];
|
||||
|
||||
// Transform messages for cross-provider compatibility
|
||||
const transformedMessages = transformMessages(messages, model, api);
|
||||
|
||||
for (const msg of transformedMessages) {
|
||||
if (msg.role === "user") {
|
||||
// Handle both string and array content
|
||||
if (typeof msg.content === "string") {
|
||||
params.push({
|
||||
role: "user",
|
||||
content: msg.content,
|
||||
});
|
||||
} else {
|
||||
// Convert array content to Anthropic format
|
||||
const blocks: ContentBlockParam[] = msg.content.map((item) => {
|
||||
if (item.type === "text") {
|
||||
return {
|
||||
type: "text",
|
||||
text: item.text,
|
||||
};
|
||||
} else {
|
||||
// Image content
|
||||
return {
|
||||
type: "image",
|
||||
source: {
|
||||
type: "base64",
|
||||
media_type: item.mimeType as "image/jpeg" | "image/png" | "image/gif" | "image/webp",
|
||||
data: item.data,
|
||||
},
|
||||
};
|
||||
}
|
||||
});
|
||||
const filteredBlocks = !model?.input.includes("image") ? blocks.filter((b) => b.type !== "image") : blocks;
|
||||
params.push({
|
||||
role: "user",
|
||||
content: filteredBlocks,
|
||||
});
|
||||
}
|
||||
} else if (msg.role === "assistant") {
|
||||
const blocks: ContentBlockParam[] = [];
|
||||
|
||||
for (const block of msg.content) {
|
||||
if (block.type === "text") {
|
||||
blocks.push({
|
||||
type: "text",
|
||||
text: block.text,
|
||||
});
|
||||
} else if (block.type === "thinking") {
|
||||
blocks.push({
|
||||
type: "thinking",
|
||||
thinking: block.thinking,
|
||||
signature: block.thinkingSignature || "",
|
||||
});
|
||||
} else if (block.type === "toolCall") {
|
||||
blocks.push({
|
||||
type: "tool_use",
|
||||
id: block.id,
|
||||
name: block.name,
|
||||
input: block.arguments,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
params.push({
|
||||
role: "assistant",
|
||||
content: blocks,
|
||||
});
|
||||
} else if (msg.role === "toolResult") {
|
||||
params.push({
|
||||
role: "user",
|
||||
content: [
|
||||
{
|
||||
type: "tool_result",
|
||||
tool_use_id: msg.toolCallId,
|
||||
content: msg.content,
|
||||
is_error: msg.isError,
|
||||
},
|
||||
],
|
||||
});
|
||||
}
|
||||
}
|
||||
return params;
|
||||
}
|
||||
|
||||
// Convert tools to Anthropic format
|
||||
function convertTools(tools: Context["tools"]): Tool[] {
|
||||
if (!tools) return [];
|
||||
|
||||
return tools.map((tool) => ({
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
input_schema: {
|
||||
type: "object" as const,
|
||||
properties: tool.parameters.properties || {},
|
||||
required: tool.parameters.required || [],
|
||||
},
|
||||
}));
|
||||
}
|
||||
|
||||
// Map Anthropic stop reason to our StopReason type
|
||||
function mapStopReason(reason: Anthropic.Messages.StopReason | null): StopReason {
|
||||
switch (reason) {
|
||||
case "end_turn":
|
||||
return "stop";
|
||||
case "max_tokens":
|
||||
return "length";
|
||||
case "tool_use":
|
||||
return "toolUse";
|
||||
case "refusal":
|
||||
return "safety";
|
||||
case "pause_turn": // Stop is good enough -> resubmit
|
||||
return "stop";
|
||||
case "stop_sequence":
|
||||
return "stop"; // We don't supply stop sequences, so this should never happen
|
||||
default:
|
||||
return "stop";
|
||||
}
|
||||
}
|
||||
|
|
@ -3,91 +3,46 @@ import type {
|
|||
ContentBlockParam,
|
||||
MessageCreateParamsStreaming,
|
||||
MessageParam,
|
||||
Tool,
|
||||
} from "@anthropic-ai/sdk/resources/messages.js";
|
||||
import { QueuedGenerateStream } from "../generate.js";
|
||||
import { calculateCost } from "../models.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
Context,
|
||||
LLM,
|
||||
LLMOptions,
|
||||
GenerateFunction,
|
||||
GenerateOptions,
|
||||
GenerateStream,
|
||||
Message,
|
||||
Model,
|
||||
StopReason,
|
||||
TextContent,
|
||||
ThinkingContent,
|
||||
Tool,
|
||||
ToolCall,
|
||||
} from "../types.js";
|
||||
import { transformMessages } from "./utils.js";
|
||||
|
||||
export interface AnthropicLLMOptions extends LLMOptions {
|
||||
thinking?: {
|
||||
enabled: boolean;
|
||||
budgetTokens?: number;
|
||||
};
|
||||
export interface AnthropicOptions extends GenerateOptions {
|
||||
thinkingEnabled?: boolean;
|
||||
thinkingBudgetTokens?: number;
|
||||
toolChoice?: "auto" | "any" | "none" | { type: "tool"; name: string };
|
||||
}
|
||||
|
||||
export class AnthropicLLM implements LLM<AnthropicLLMOptions> {
|
||||
private client: Anthropic;
|
||||
private modelInfo: Model;
|
||||
private isOAuthToken: boolean = false;
|
||||
export const streamAnthropic: GenerateFunction<"anthropic-messages"> = (
|
||||
model: Model<"anthropic-messages">,
|
||||
context: Context,
|
||||
options?: AnthropicOptions,
|
||||
): GenerateStream => {
|
||||
const stream = new QueuedGenerateStream();
|
||||
|
||||
constructor(model: Model, apiKey?: string) {
|
||||
if (!apiKey) {
|
||||
if (!process.env.ANTHROPIC_API_KEY) {
|
||||
throw new Error(
|
||||
"Anthropic API key is required. Set ANTHROPIC_API_KEY environment variable or pass it as an argument.",
|
||||
);
|
||||
}
|
||||
apiKey = process.env.ANTHROPIC_API_KEY;
|
||||
}
|
||||
if (apiKey.includes("sk-ant-oat")) {
|
||||
const defaultHeaders = {
|
||||
accept: "application/json",
|
||||
"anthropic-dangerous-direct-browser-access": "true",
|
||||
"anthropic-beta": "oauth-2025-04-20,fine-grained-tool-streaming-2025-05-14",
|
||||
};
|
||||
|
||||
// Clear the env var if we're in Node.js to prevent SDK from using it
|
||||
if (typeof process !== "undefined" && process.env) {
|
||||
process.env.ANTHROPIC_API_KEY = undefined;
|
||||
}
|
||||
this.client = new Anthropic({
|
||||
apiKey: null,
|
||||
authToken: apiKey,
|
||||
baseURL: model.baseUrl,
|
||||
defaultHeaders,
|
||||
dangerouslyAllowBrowser: true,
|
||||
});
|
||||
this.isOAuthToken = true;
|
||||
} else {
|
||||
const defaultHeaders = {
|
||||
accept: "application/json",
|
||||
"anthropic-dangerous-direct-browser-access": "true",
|
||||
"anthropic-beta": "fine-grained-tool-streaming-2025-05-14",
|
||||
};
|
||||
this.client = new Anthropic({ apiKey, baseURL: model.baseUrl, dangerouslyAllowBrowser: true, defaultHeaders });
|
||||
this.isOAuthToken = false;
|
||||
}
|
||||
this.modelInfo = model;
|
||||
}
|
||||
|
||||
getModel(): Model {
|
||||
return this.modelInfo;
|
||||
}
|
||||
|
||||
getApi(): string {
|
||||
return "anthropic-messages";
|
||||
}
|
||||
|
||||
async generate(context: Context, options?: AnthropicLLMOptions): Promise<AssistantMessage> {
|
||||
(async () => {
|
||||
const output: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: this.getApi(),
|
||||
provider: this.modelInfo.provider,
|
||||
model: this.modelInfo.id,
|
||||
api: "anthropic-messages" as Api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
|
|
@ -99,77 +54,14 @@ export class AnthropicLLM implements LLM<AnthropicLLMOptions> {
|
|||
};
|
||||
|
||||
try {
|
||||
const messages = this.convertMessages(context.messages);
|
||||
|
||||
const params: MessageCreateParamsStreaming = {
|
||||
model: this.modelInfo.id,
|
||||
messages,
|
||||
max_tokens: options?.maxTokens || 4096,
|
||||
stream: true,
|
||||
};
|
||||
|
||||
// For OAuth tokens, we MUST include Claude Code identity
|
||||
if (this.isOAuthToken) {
|
||||
params.system = [
|
||||
{
|
||||
type: "text",
|
||||
text: "You are Claude Code, Anthropic's official CLI for Claude.",
|
||||
cache_control: {
|
||||
type: "ephemeral",
|
||||
},
|
||||
},
|
||||
];
|
||||
if (context.systemPrompt) {
|
||||
params.system.push({
|
||||
type: "text",
|
||||
text: context.systemPrompt,
|
||||
cache_control: {
|
||||
type: "ephemeral",
|
||||
},
|
||||
});
|
||||
}
|
||||
} else if (context.systemPrompt) {
|
||||
params.system = context.systemPrompt;
|
||||
}
|
||||
|
||||
if (options?.temperature !== undefined) {
|
||||
params.temperature = options?.temperature;
|
||||
}
|
||||
|
||||
if (context.tools) {
|
||||
params.tools = this.convertTools(context.tools);
|
||||
}
|
||||
|
||||
// Only enable thinking if the model supports it
|
||||
if (options?.thinking?.enabled && this.modelInfo.reasoning) {
|
||||
params.thinking = {
|
||||
type: "enabled",
|
||||
budget_tokens: options.thinking.budgetTokens || 1024,
|
||||
};
|
||||
}
|
||||
|
||||
if (options?.toolChoice) {
|
||||
if (typeof options.toolChoice === "string") {
|
||||
params.tool_choice = { type: options.toolChoice };
|
||||
} else {
|
||||
params.tool_choice = options.toolChoice;
|
||||
}
|
||||
}
|
||||
|
||||
const stream = this.client.messages.stream(
|
||||
{
|
||||
...params,
|
||||
stream: true,
|
||||
},
|
||||
{
|
||||
signal: options?.signal,
|
||||
},
|
||||
);
|
||||
|
||||
options?.onEvent?.({ type: "start", model: this.modelInfo.id, provider: this.modelInfo.provider });
|
||||
const { client, isOAuthToken } = createClient(model, options?.apiKey!);
|
||||
const params = buildParams(model, context, isOAuthToken, options);
|
||||
const anthropicStream = client.messages.stream({ ...params, stream: true }, { signal: options?.signal });
|
||||
stream.push({ type: "start", partial: output });
|
||||
|
||||
let currentBlock: ThinkingContent | TextContent | (ToolCall & { partialJson: string }) | null = null;
|
||||
for await (const event of stream) {
|
||||
|
||||
for await (const event of anthropicStream) {
|
||||
if (event.type === "content_block_start") {
|
||||
if (event.content_block.type === "text") {
|
||||
currentBlock = {
|
||||
|
|
@ -177,7 +69,7 @@ export class AnthropicLLM implements LLM<AnthropicLLMOptions> {
|
|||
text: "",
|
||||
};
|
||||
output.content.push(currentBlock);
|
||||
options?.onEvent?.({ type: "text_start" });
|
||||
stream.push({ type: "text_start", partial: output });
|
||||
} else if (event.content_block.type === "thinking") {
|
||||
currentBlock = {
|
||||
type: "thinking",
|
||||
|
|
@ -185,9 +77,9 @@ export class AnthropicLLM implements LLM<AnthropicLLMOptions> {
|
|||
thinkingSignature: "",
|
||||
};
|
||||
output.content.push(currentBlock);
|
||||
options?.onEvent?.({ type: "thinking_start" });
|
||||
stream.push({ type: "thinking_start", partial: output });
|
||||
} else if (event.content_block.type === "tool_use") {
|
||||
// We wait for the full tool use to be streamed to send the event
|
||||
// We wait for the full tool use to be streamed
|
||||
currentBlock = {
|
||||
type: "toolCall",
|
||||
id: event.content_block.id,
|
||||
|
|
@ -200,15 +92,19 @@ export class AnthropicLLM implements LLM<AnthropicLLMOptions> {
|
|||
if (event.delta.type === "text_delta") {
|
||||
if (currentBlock && currentBlock.type === "text") {
|
||||
currentBlock.text += event.delta.text;
|
||||
options?.onEvent?.({ type: "text_delta", content: currentBlock.text, delta: event.delta.text });
|
||||
stream.push({
|
||||
type: "text_delta",
|
||||
delta: event.delta.text,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
} else if (event.delta.type === "thinking_delta") {
|
||||
if (currentBlock && currentBlock.type === "thinking") {
|
||||
currentBlock.thinking += event.delta.thinking;
|
||||
options?.onEvent?.({
|
||||
stream.push({
|
||||
type: "thinking_delta",
|
||||
content: currentBlock.thinking,
|
||||
delta: event.delta.thinking,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
} else if (event.delta.type === "input_json_delta") {
|
||||
|
|
@ -224,9 +120,17 @@ export class AnthropicLLM implements LLM<AnthropicLLMOptions> {
|
|||
} else if (event.type === "content_block_stop") {
|
||||
if (currentBlock) {
|
||||
if (currentBlock.type === "text") {
|
||||
options?.onEvent?.({ type: "text_end", content: currentBlock.text });
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
content: currentBlock.text,
|
||||
partial: output,
|
||||
});
|
||||
} else if (currentBlock.type === "thinking") {
|
||||
options?.onEvent?.({ type: "thinking_end", content: currentBlock.thinking });
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
content: currentBlock.thinking,
|
||||
partial: output,
|
||||
});
|
||||
} else if (currentBlock.type === "toolCall") {
|
||||
const finalToolCall: ToolCall = {
|
||||
type: "toolCall",
|
||||
|
|
@ -235,150 +139,274 @@ export class AnthropicLLM implements LLM<AnthropicLLMOptions> {
|
|||
arguments: JSON.parse(currentBlock.partialJson),
|
||||
};
|
||||
output.content.push(finalToolCall);
|
||||
options?.onEvent?.({ type: "toolCall", toolCall: finalToolCall });
|
||||
stream.push({
|
||||
type: "toolCall",
|
||||
toolCall: finalToolCall,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
currentBlock = null;
|
||||
}
|
||||
} else if (event.type === "message_delta") {
|
||||
if (event.delta.stop_reason) {
|
||||
output.stopReason = this.mapStopReason(event.delta.stop_reason);
|
||||
output.stopReason = mapStopReason(event.delta.stop_reason);
|
||||
}
|
||||
output.usage.input += event.usage.input_tokens || 0;
|
||||
output.usage.output += event.usage.output_tokens || 0;
|
||||
output.usage.cacheRead += event.usage.cache_read_input_tokens || 0;
|
||||
output.usage.cacheWrite += event.usage.cache_creation_input_tokens || 0;
|
||||
calculateCost(this.modelInfo, output.usage);
|
||||
calculateCost(model, output.usage);
|
||||
}
|
||||
}
|
||||
|
||||
options?.onEvent?.({ type: "done", reason: output.stopReason, message: output });
|
||||
return output;
|
||||
if (options?.signal?.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
|
||||
stream.push({ type: "done", reason: output.stopReason, message: output });
|
||||
stream.end();
|
||||
} catch (error) {
|
||||
output.stopReason = "error";
|
||||
output.error = error instanceof Error ? error.message : JSON.stringify(error);
|
||||
options?.onEvent?.({ type: "error", error: output.error });
|
||||
return output;
|
||||
stream.push({ type: "error", error: output.error, partial: output });
|
||||
stream.end();
|
||||
}
|
||||
})();
|
||||
|
||||
return stream;
|
||||
};
|
||||
|
||||
function createClient(
|
||||
model: Model<"anthropic-messages">,
|
||||
apiKey: string,
|
||||
): { client: Anthropic; isOAuthToken: boolean } {
|
||||
if (apiKey.includes("sk-ant-oat")) {
|
||||
const defaultHeaders = {
|
||||
accept: "application/json",
|
||||
"anthropic-dangerous-direct-browser-access": "true",
|
||||
"anthropic-beta": "oauth-2025-04-20,fine-grained-tool-streaming-2025-05-14",
|
||||
};
|
||||
|
||||
// Clear the env var if we're in Node.js to prevent SDK from using it
|
||||
if (typeof process !== "undefined" && process.env) {
|
||||
process.env.ANTHROPIC_API_KEY = undefined;
|
||||
}
|
||||
|
||||
const client = new Anthropic({
|
||||
apiKey: null,
|
||||
authToken: apiKey,
|
||||
baseURL: model.baseUrl,
|
||||
defaultHeaders,
|
||||
dangerouslyAllowBrowser: true,
|
||||
});
|
||||
|
||||
return { client, isOAuthToken: true };
|
||||
} else {
|
||||
const defaultHeaders = {
|
||||
accept: "application/json",
|
||||
"anthropic-dangerous-direct-browser-access": "true",
|
||||
"anthropic-beta": "fine-grained-tool-streaming-2025-05-14",
|
||||
};
|
||||
|
||||
const client = new Anthropic({
|
||||
apiKey,
|
||||
baseURL: model.baseUrl,
|
||||
dangerouslyAllowBrowser: true,
|
||||
defaultHeaders,
|
||||
});
|
||||
|
||||
return { client, isOAuthToken: false };
|
||||
}
|
||||
}
|
||||
|
||||
function buildParams(
|
||||
model: Model<"anthropic-messages">,
|
||||
context: Context,
|
||||
isOAuthToken: boolean,
|
||||
options?: AnthropicOptions,
|
||||
): MessageCreateParamsStreaming {
|
||||
const params: MessageCreateParamsStreaming = {
|
||||
model: model.id,
|
||||
messages: convertMessages(context.messages, model),
|
||||
max_tokens: options?.maxTokens || model.maxTokens,
|
||||
stream: true,
|
||||
};
|
||||
|
||||
// For OAuth tokens, we MUST include Claude Code identity
|
||||
if (isOAuthToken) {
|
||||
params.system = [
|
||||
{
|
||||
type: "text",
|
||||
text: "You are Claude Code, Anthropic's official CLI for Claude.",
|
||||
cache_control: {
|
||||
type: "ephemeral",
|
||||
},
|
||||
},
|
||||
];
|
||||
if (context.systemPrompt) {
|
||||
params.system.push({
|
||||
type: "text",
|
||||
text: context.systemPrompt,
|
||||
cache_control: {
|
||||
type: "ephemeral",
|
||||
},
|
||||
});
|
||||
}
|
||||
} else if (context.systemPrompt) {
|
||||
params.system = context.systemPrompt;
|
||||
}
|
||||
|
||||
if (options?.temperature !== undefined) {
|
||||
params.temperature = options.temperature;
|
||||
}
|
||||
|
||||
if (context.tools) {
|
||||
params.tools = convertTools(context.tools);
|
||||
}
|
||||
|
||||
if (options?.thinkingEnabled && model.reasoning) {
|
||||
params.thinking = {
|
||||
type: "enabled",
|
||||
budget_tokens: options.thinkingBudgetTokens || 1024,
|
||||
};
|
||||
}
|
||||
|
||||
if (options?.toolChoice) {
|
||||
if (typeof options.toolChoice === "string") {
|
||||
params.tool_choice = { type: options.toolChoice };
|
||||
} else {
|
||||
params.tool_choice = options.toolChoice;
|
||||
}
|
||||
}
|
||||
|
||||
private convertMessages(messages: Message[]): MessageParam[] {
|
||||
const params: MessageParam[] = [];
|
||||
return params;
|
||||
}
|
||||
|
||||
// Transform messages for cross-provider compatibility
|
||||
const transformedMessages = transformMessages(messages, this.modelInfo, this.getApi());
|
||||
function convertMessages(messages: Message[], model: Model<"anthropic-messages">): MessageParam[] {
|
||||
const params: MessageParam[] = [];
|
||||
|
||||
for (const msg of transformedMessages) {
|
||||
if (msg.role === "user") {
|
||||
// Handle both string and array content
|
||||
if (typeof msg.content === "string") {
|
||||
// Transform messages for cross-provider compatibility
|
||||
const transformedMessages = transformMessages(messages, model);
|
||||
|
||||
for (const msg of transformedMessages) {
|
||||
if (msg.role === "user") {
|
||||
if (typeof msg.content === "string") {
|
||||
if (msg.content.trim().length > 0) {
|
||||
params.push({
|
||||
role: "user",
|
||||
content: msg.content,
|
||||
});
|
||||
} else {
|
||||
// Convert array content to Anthropic format
|
||||
const blocks: ContentBlockParam[] = msg.content.map((item) => {
|
||||
if (item.type === "text") {
|
||||
return {
|
||||
type: "text",
|
||||
text: item.text,
|
||||
};
|
||||
} else {
|
||||
// Image content
|
||||
return {
|
||||
type: "image",
|
||||
source: {
|
||||
type: "base64",
|
||||
media_type: item.mimeType as "image/jpeg" | "image/png" | "image/gif" | "image/webp",
|
||||
data: item.data,
|
||||
},
|
||||
};
|
||||
}
|
||||
});
|
||||
const filteredBlocks = !this.modelInfo?.input.includes("image")
|
||||
? blocks.filter((b) => b.type !== "image")
|
||||
: blocks;
|
||||
params.push({
|
||||
role: "user",
|
||||
content: filteredBlocks,
|
||||
});
|
||||
}
|
||||
} else if (msg.role === "assistant") {
|
||||
const blocks: ContentBlockParam[] = [];
|
||||
|
||||
for (const block of msg.content) {
|
||||
if (block.type === "text") {
|
||||
blocks.push({
|
||||
} else {
|
||||
const blocks: ContentBlockParam[] = msg.content.map((item) => {
|
||||
if (item.type === "text") {
|
||||
return {
|
||||
type: "text",
|
||||
text: block.text,
|
||||
});
|
||||
} else if (block.type === "thinking") {
|
||||
blocks.push({
|
||||
type: "thinking",
|
||||
thinking: block.thinking,
|
||||
signature: block.thinkingSignature || "",
|
||||
});
|
||||
} else if (block.type === "toolCall") {
|
||||
blocks.push({
|
||||
type: "tool_use",
|
||||
id: block.id,
|
||||
name: block.name,
|
||||
input: block.arguments,
|
||||
});
|
||||
text: item.text,
|
||||
};
|
||||
} else {
|
||||
return {
|
||||
type: "image",
|
||||
source: {
|
||||
type: "base64",
|
||||
media_type: item.mimeType as "image/jpeg" | "image/png" | "image/gif" | "image/webp",
|
||||
data: item.data,
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
params.push({
|
||||
role: "assistant",
|
||||
content: blocks,
|
||||
});
|
||||
} else if (msg.role === "toolResult") {
|
||||
let filteredBlocks = !model?.input.includes("image") ? blocks.filter((b) => b.type !== "image") : blocks;
|
||||
filteredBlocks = filteredBlocks.filter((b) => {
|
||||
if (b.type === "text") {
|
||||
return b.text.trim().length > 0;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
if (filteredBlocks.length === 0) continue;
|
||||
params.push({
|
||||
role: "user",
|
||||
content: [
|
||||
{
|
||||
type: "tool_result",
|
||||
tool_use_id: msg.toolCallId,
|
||||
content: msg.content,
|
||||
is_error: msg.isError,
|
||||
},
|
||||
],
|
||||
content: filteredBlocks,
|
||||
});
|
||||
}
|
||||
} else if (msg.role === "assistant") {
|
||||
const blocks: ContentBlockParam[] = [];
|
||||
|
||||
for (const block of msg.content) {
|
||||
if (block.type === "text") {
|
||||
if (block.text.trim().length === 0) continue;
|
||||
blocks.push({
|
||||
type: "text",
|
||||
text: block.text,
|
||||
});
|
||||
} else if (block.type === "thinking") {
|
||||
if (block.thinking.trim().length === 0) continue;
|
||||
blocks.push({
|
||||
type: "thinking",
|
||||
thinking: block.thinking,
|
||||
signature: block.thinkingSignature || "",
|
||||
});
|
||||
} else if (block.type === "toolCall") {
|
||||
blocks.push({
|
||||
type: "tool_use",
|
||||
id: block.id,
|
||||
name: block.name,
|
||||
input: block.arguments,
|
||||
});
|
||||
}
|
||||
}
|
||||
if (blocks.length === 0) continue;
|
||||
params.push({
|
||||
role: "assistant",
|
||||
content: blocks,
|
||||
});
|
||||
} else if (msg.role === "toolResult") {
|
||||
params.push({
|
||||
role: "user",
|
||||
content: [
|
||||
{
|
||||
type: "tool_result",
|
||||
tool_use_id: msg.toolCallId,
|
||||
content: msg.content,
|
||||
is_error: msg.isError,
|
||||
},
|
||||
],
|
||||
});
|
||||
}
|
||||
return params;
|
||||
}
|
||||
return params;
|
||||
}
|
||||
|
||||
private convertTools(tools: Context["tools"]): Tool[] {
|
||||
if (!tools) return [];
|
||||
function convertTools(tools: Tool[]): Anthropic.Messages.Tool[] {
|
||||
if (!tools) return [];
|
||||
|
||||
return tools.map((tool) => ({
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
input_schema: {
|
||||
type: "object" as const,
|
||||
properties: tool.parameters.properties || {},
|
||||
required: tool.parameters.required || [],
|
||||
},
|
||||
}));
|
||||
}
|
||||
return tools.map((tool) => ({
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
input_schema: {
|
||||
type: "object" as const,
|
||||
properties: tool.parameters.properties || {},
|
||||
required: tool.parameters.required || [],
|
||||
},
|
||||
}));
|
||||
}
|
||||
|
||||
private mapStopReason(reason: Anthropic.Messages.StopReason | null): StopReason {
|
||||
switch (reason) {
|
||||
case "end_turn":
|
||||
return "stop";
|
||||
case "max_tokens":
|
||||
return "length";
|
||||
case "tool_use":
|
||||
return "toolUse";
|
||||
case "refusal":
|
||||
return "safety";
|
||||
case "pause_turn": // Stop is good enough -> resubmit
|
||||
return "stop";
|
||||
case "stop_sequence":
|
||||
return "stop"; // We don't supply stop sequences, so this should never happen
|
||||
default:
|
||||
return "stop";
|
||||
function mapStopReason(reason: Anthropic.Messages.StopReason): StopReason {
|
||||
switch (reason) {
|
||||
case "end_turn":
|
||||
return "stop";
|
||||
case "max_tokens":
|
||||
return "length";
|
||||
case "tool_use":
|
||||
return "toolUse";
|
||||
case "refusal":
|
||||
return "safety";
|
||||
case "pause_turn": // Stop is good enough -> resubmit
|
||||
return "stop";
|
||||
case "stop_sequence":
|
||||
return "stop"; // We don't supply stop sequences, so this should never happen
|
||||
default: {
|
||||
const _exhaustive: never = reason;
|
||||
throw new Error(`Unhandled stop reason: ${_exhaustive}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,19 +1,21 @@
|
|||
import {
|
||||
type Content,
|
||||
type FinishReason,
|
||||
FinishReason,
|
||||
FunctionCallingConfigMode,
|
||||
type GenerateContentConfig,
|
||||
type GenerateContentParameters,
|
||||
GoogleGenAI,
|
||||
type Part,
|
||||
} from "@google/genai";
|
||||
import { QueuedGenerateStream } from "../generate.js";
|
||||
import { calculateCost } from "../models.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
Context,
|
||||
LLM,
|
||||
LLMOptions,
|
||||
Message,
|
||||
GenerateFunction,
|
||||
GenerateOptions,
|
||||
GenerateStream,
|
||||
Model,
|
||||
StopReason,
|
||||
TextContent,
|
||||
|
|
@ -23,7 +25,7 @@ import type {
|
|||
} from "../types.js";
|
||||
import { transformMessages } from "./utils.js";
|
||||
|
||||
export interface GoogleLLMOptions extends LLMOptions {
|
||||
export interface GoogleOptions extends GenerateOptions {
|
||||
toolChoice?: "auto" | "none" | "any";
|
||||
thinking?: {
|
||||
enabled: boolean;
|
||||
|
|
@ -31,38 +33,20 @@ export interface GoogleLLMOptions extends LLMOptions {
|
|||
};
|
||||
}
|
||||
|
||||
export class GoogleLLM implements LLM<GoogleLLMOptions> {
|
||||
private client: GoogleGenAI;
|
||||
private modelInfo: Model;
|
||||
export const streamGoogle: GenerateFunction<"google-generative-ai"> = (
|
||||
model: Model<"google-generative-ai">,
|
||||
context: Context,
|
||||
options?: GoogleOptions,
|
||||
): GenerateStream => {
|
||||
const stream = new QueuedGenerateStream();
|
||||
|
||||
constructor(model: Model, apiKey?: string) {
|
||||
if (!apiKey) {
|
||||
if (!process.env.GEMINI_API_KEY) {
|
||||
throw new Error(
|
||||
"Gemini API key is required. Set GEMINI_API_KEY environment variable or pass it as an argument.",
|
||||
);
|
||||
}
|
||||
apiKey = process.env.GEMINI_API_KEY;
|
||||
}
|
||||
this.client = new GoogleGenAI({ apiKey });
|
||||
this.modelInfo = model;
|
||||
}
|
||||
|
||||
getModel(): Model {
|
||||
return this.modelInfo;
|
||||
}
|
||||
|
||||
getApi(): string {
|
||||
return "google-generative-ai";
|
||||
}
|
||||
|
||||
async generate(context: Context, options?: GoogleLLMOptions): Promise<AssistantMessage> {
|
||||
(async () => {
|
||||
const output: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: this.getApi(),
|
||||
provider: this.modelInfo.provider,
|
||||
model: this.modelInfo.id,
|
||||
api: "google-generative-ai" as Api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
|
|
@ -72,70 +56,20 @@ export class GoogleLLM implements LLM<GoogleLLMOptions> {
|
|||
},
|
||||
stopReason: "stop",
|
||||
};
|
||||
|
||||
try {
|
||||
const contents = this.convertMessages(context.messages);
|
||||
const client = createClient(options?.apiKey);
|
||||
const params = buildParams(model, context, options);
|
||||
const googleStream = await client.models.generateContentStream(params);
|
||||
|
||||
// Build generation config
|
||||
const generationConfig: GenerateContentConfig = {};
|
||||
if (options?.temperature !== undefined) {
|
||||
generationConfig.temperature = options.temperature;
|
||||
}
|
||||
if (options?.maxTokens !== undefined) {
|
||||
generationConfig.maxOutputTokens = options.maxTokens;
|
||||
}
|
||||
|
||||
// Build the config object
|
||||
const config: GenerateContentConfig = {
|
||||
...(Object.keys(generationConfig).length > 0 && generationConfig),
|
||||
...(context.systemPrompt && { systemInstruction: context.systemPrompt }),
|
||||
...(context.tools && { tools: this.convertTools(context.tools) }),
|
||||
};
|
||||
|
||||
// Add tool config if needed
|
||||
if (context.tools && options?.toolChoice) {
|
||||
config.toolConfig = {
|
||||
functionCallingConfig: {
|
||||
mode: this.mapToolChoice(options.toolChoice),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
// Add thinking config if enabled and model supports it
|
||||
if (options?.thinking?.enabled && this.modelInfo.reasoning) {
|
||||
config.thinkingConfig = {
|
||||
includeThoughts: true,
|
||||
...(options.thinking.budgetTokens !== undefined && { thinkingBudget: options.thinking.budgetTokens }),
|
||||
};
|
||||
}
|
||||
|
||||
// Abort signal
|
||||
if (options?.signal) {
|
||||
if (options.signal.aborted) {
|
||||
throw new Error("Request aborted");
|
||||
}
|
||||
config.abortSignal = options.signal;
|
||||
}
|
||||
|
||||
// Build the request parameters
|
||||
const params: GenerateContentParameters = {
|
||||
model: this.modelInfo.id,
|
||||
contents,
|
||||
config,
|
||||
};
|
||||
|
||||
const stream = await this.client.models.generateContentStream(params);
|
||||
|
||||
options?.onEvent?.({ type: "start", model: this.modelInfo.id, provider: this.modelInfo.provider });
|
||||
stream.push({ type: "start", partial: output });
|
||||
let currentBlock: TextContent | ThinkingContent | null = null;
|
||||
for await (const chunk of stream) {
|
||||
// Extract parts from the chunk
|
||||
for await (const chunk of googleStream) {
|
||||
const candidate = chunk.candidates?.[0];
|
||||
if (candidate?.content?.parts) {
|
||||
for (const part of candidate.content.parts) {
|
||||
if (part.text !== undefined) {
|
||||
const isThinking = part.thought === true;
|
||||
|
||||
// Check if we need to switch blocks
|
||||
if (
|
||||
!currentBlock ||
|
||||
(isThinking && currentBlock.type !== "thinking") ||
|
||||
|
|
@ -143,50 +77,60 @@ export class GoogleLLM implements LLM<GoogleLLMOptions> {
|
|||
) {
|
||||
if (currentBlock) {
|
||||
if (currentBlock.type === "text") {
|
||||
options?.onEvent?.({ type: "text_end", content: currentBlock.text });
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
content: currentBlock.text,
|
||||
partial: output,
|
||||
});
|
||||
} else {
|
||||
options?.onEvent?.({ type: "thinking_end", content: currentBlock.thinking });
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
content: currentBlock.thinking,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Start new block
|
||||
if (isThinking) {
|
||||
currentBlock = { type: "thinking", thinking: "", thinkingSignature: undefined };
|
||||
options?.onEvent?.({ type: "thinking_start" });
|
||||
stream.push({ type: "thinking_start", partial: output });
|
||||
} else {
|
||||
currentBlock = { type: "text", text: "" };
|
||||
options?.onEvent?.({ type: "text_start" });
|
||||
stream.push({ type: "text_start", partial: output });
|
||||
}
|
||||
output.content.push(currentBlock);
|
||||
}
|
||||
|
||||
// Append content to current block
|
||||
if (currentBlock.type === "thinking") {
|
||||
currentBlock.thinking += part.text;
|
||||
currentBlock.thinkingSignature = part.thoughtSignature;
|
||||
options?.onEvent?.({
|
||||
stream.push({
|
||||
type: "thinking_delta",
|
||||
content: currentBlock.thinking,
|
||||
delta: part.text,
|
||||
partial: output,
|
||||
});
|
||||
} else {
|
||||
currentBlock.text += part.text;
|
||||
options?.onEvent?.({ type: "text_delta", content: currentBlock.text, delta: part.text });
|
||||
stream.push({ type: "text_delta", delta: part.text, partial: output });
|
||||
}
|
||||
}
|
||||
|
||||
// Handle function calls
|
||||
if (part.functionCall) {
|
||||
if (currentBlock) {
|
||||
if (currentBlock.type === "text") {
|
||||
options?.onEvent?.({ type: "text_end", content: currentBlock.text });
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
content: currentBlock.text,
|
||||
partial: output,
|
||||
});
|
||||
} else {
|
||||
options?.onEvent?.({ type: "thinking_end", content: currentBlock.thinking });
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
content: currentBlock.thinking,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
currentBlock = null;
|
||||
}
|
||||
|
||||
// Add tool call
|
||||
const toolCallId = part.functionCall.id || `${part.functionCall.name}_${Date.now()}`;
|
||||
const toolCall: ToolCall = {
|
||||
type: "toolCall",
|
||||
|
|
@ -195,21 +139,18 @@ export class GoogleLLM implements LLM<GoogleLLMOptions> {
|
|||
arguments: part.functionCall.args as Record<string, any>,
|
||||
};
|
||||
output.content.push(toolCall);
|
||||
options?.onEvent?.({ type: "toolCall", toolCall });
|
||||
stream.push({ type: "toolCall", toolCall, partial: output });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Map finish reason
|
||||
if (candidate?.finishReason) {
|
||||
output.stopReason = this.mapStopReason(candidate.finishReason);
|
||||
// Check if we have tool calls in blocks
|
||||
output.stopReason = mapStopReason(candidate.finishReason);
|
||||
if (output.content.some((b) => b.type === "toolCall")) {
|
||||
output.stopReason = "toolUse";
|
||||
}
|
||||
}
|
||||
|
||||
// Capture usage metadata if available
|
||||
if (chunk.usageMetadata) {
|
||||
output.usage = {
|
||||
input: chunk.usageMetadata.promptTokenCount || 0,
|
||||
|
|
@ -225,166 +166,223 @@ export class GoogleLLM implements LLM<GoogleLLMOptions> {
|
|||
total: 0,
|
||||
},
|
||||
};
|
||||
calculateCost(this.modelInfo, output.usage);
|
||||
calculateCost(model, output.usage);
|
||||
}
|
||||
}
|
||||
|
||||
// Finalize last block
|
||||
if (currentBlock) {
|
||||
if (currentBlock.type === "text") {
|
||||
options?.onEvent?.({ type: "text_end", content: currentBlock.text });
|
||||
stream.push({ type: "text_end", content: currentBlock.text, partial: output });
|
||||
} else {
|
||||
options?.onEvent?.({ type: "thinking_end", content: currentBlock.thinking });
|
||||
stream.push({ type: "thinking_end", content: currentBlock.thinking, partial: output });
|
||||
}
|
||||
}
|
||||
|
||||
options?.onEvent?.({ type: "done", reason: output.stopReason, message: output });
|
||||
return output;
|
||||
stream.push({ type: "done", reason: output.stopReason, message: output });
|
||||
stream.end();
|
||||
} catch (error) {
|
||||
output.stopReason = "error";
|
||||
output.error = error instanceof Error ? error.message : JSON.stringify(error);
|
||||
options?.onEvent?.({ type: "error", error: output.error });
|
||||
return output;
|
||||
stream.push({ type: "error", error: output.error, partial: output });
|
||||
stream.end();
|
||||
}
|
||||
})();
|
||||
|
||||
return stream;
|
||||
};
|
||||
|
||||
function createClient(apiKey?: string): GoogleGenAI {
|
||||
if (!apiKey) {
|
||||
if (!process.env.GEMINI_API_KEY) {
|
||||
throw new Error(
|
||||
"Gemini API key is required. Set GEMINI_API_KEY environment variable or pass it as an argument.",
|
||||
);
|
||||
}
|
||||
apiKey = process.env.GEMINI_API_KEY;
|
||||
}
|
||||
return new GoogleGenAI({ apiKey });
|
||||
}
|
||||
|
||||
function buildParams(
|
||||
model: Model<"google-generative-ai">,
|
||||
context: Context,
|
||||
options: GoogleOptions = {},
|
||||
): GenerateContentParameters {
|
||||
const contents = convertMessages(model, context);
|
||||
|
||||
const generationConfig: GenerateContentConfig = {};
|
||||
if (options.temperature !== undefined) {
|
||||
generationConfig.temperature = options.temperature;
|
||||
}
|
||||
if (options.maxTokens !== undefined) {
|
||||
generationConfig.maxOutputTokens = options.maxTokens;
|
||||
}
|
||||
|
||||
private convertMessages(messages: Message[]): Content[] {
|
||||
const contents: Content[] = [];
|
||||
const config: GenerateContentConfig = {
|
||||
...(Object.keys(generationConfig).length > 0 && generationConfig),
|
||||
...(context.systemPrompt && { systemInstruction: context.systemPrompt }),
|
||||
...(context.tools && { tools: convertTools(context.tools) }),
|
||||
};
|
||||
|
||||
// Transform messages for cross-provider compatibility
|
||||
const transformedMessages = transformMessages(messages, this.modelInfo, this.getApi());
|
||||
if (context.tools && options.toolChoice) {
|
||||
config.toolConfig = {
|
||||
functionCallingConfig: {
|
||||
mode: mapToolChoice(options.toolChoice),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
for (const msg of transformedMessages) {
|
||||
if (msg.role === "user") {
|
||||
// Handle both string and array content
|
||||
if (typeof msg.content === "string") {
|
||||
contents.push({
|
||||
role: "user",
|
||||
parts: [{ text: msg.content }],
|
||||
});
|
||||
} else {
|
||||
// Convert array content to Google format
|
||||
const parts: Part[] = msg.content.map((item) => {
|
||||
if (item.type === "text") {
|
||||
return { text: item.text };
|
||||
} else {
|
||||
// Image content - Google uses inlineData
|
||||
return {
|
||||
inlineData: {
|
||||
mimeType: item.mimeType,
|
||||
data: item.data,
|
||||
},
|
||||
};
|
||||
}
|
||||
});
|
||||
const filteredParts = !this.modelInfo?.input.includes("image")
|
||||
? parts.filter((p) => p.text !== undefined)
|
||||
: parts;
|
||||
contents.push({
|
||||
role: "user",
|
||||
parts: filteredParts,
|
||||
});
|
||||
}
|
||||
} else if (msg.role === "assistant") {
|
||||
const parts: Part[] = [];
|
||||
if (options.thinking?.enabled && model.reasoning) {
|
||||
config.thinkingConfig = {
|
||||
includeThoughts: true,
|
||||
...(options.thinking.budgetTokens !== undefined && { thinkingBudget: options.thinking.budgetTokens }),
|
||||
};
|
||||
}
|
||||
|
||||
// Process content blocks
|
||||
for (const block of msg.content) {
|
||||
if (block.type === "text") {
|
||||
parts.push({ text: block.text });
|
||||
} else if (block.type === "thinking") {
|
||||
const thinkingPart: Part = {
|
||||
thought: true,
|
||||
thoughtSignature: block.thinkingSignature,
|
||||
text: block.thinking,
|
||||
};
|
||||
parts.push(thinkingPart);
|
||||
} else if (block.type === "toolCall") {
|
||||
parts.push({
|
||||
functionCall: {
|
||||
id: block.id,
|
||||
name: block.name,
|
||||
args: block.arguments,
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
if (options.signal) {
|
||||
if (options.signal.aborted) {
|
||||
throw new Error("Request aborted");
|
||||
}
|
||||
config.abortSignal = options.signal;
|
||||
}
|
||||
|
||||
if (parts.length > 0) {
|
||||
contents.push({
|
||||
role: "model",
|
||||
parts,
|
||||
});
|
||||
}
|
||||
} else if (msg.role === "toolResult") {
|
||||
const params: GenerateContentParameters = {
|
||||
model: model.id,
|
||||
contents,
|
||||
config,
|
||||
};
|
||||
|
||||
return params;
|
||||
}
|
||||
function convertMessages(model: Model<"google-generative-ai">, context: Context): Content[] {
|
||||
const contents: Content[] = [];
|
||||
const transformedMessages = transformMessages(context.messages, model);
|
||||
|
||||
for (const msg of transformedMessages) {
|
||||
if (msg.role === "user") {
|
||||
if (typeof msg.content === "string") {
|
||||
contents.push({
|
||||
role: "user",
|
||||
parts: [
|
||||
{
|
||||
functionResponse: {
|
||||
id: msg.toolCallId,
|
||||
name: msg.toolName,
|
||||
response: {
|
||||
result: msg.content,
|
||||
isError: msg.isError,
|
||||
},
|
||||
parts: [{ text: msg.content }],
|
||||
});
|
||||
} else {
|
||||
const parts: Part[] = msg.content.map((item) => {
|
||||
if (item.type === "text") {
|
||||
return { text: item.text };
|
||||
} else {
|
||||
return {
|
||||
inlineData: {
|
||||
mimeType: item.mimeType,
|
||||
data: item.data,
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
||||
}
|
||||
});
|
||||
const filteredParts = !model.input.includes("image") ? parts.filter((p) => p.text !== undefined) : parts;
|
||||
if (filteredParts.length === 0) continue;
|
||||
contents.push({
|
||||
role: "user",
|
||||
parts: filteredParts,
|
||||
});
|
||||
}
|
||||
}
|
||||
} else if (msg.role === "assistant") {
|
||||
const parts: Part[] = [];
|
||||
|
||||
return contents;
|
||||
}
|
||||
for (const block of msg.content) {
|
||||
if (block.type === "text") {
|
||||
parts.push({ text: block.text });
|
||||
} else if (block.type === "thinking") {
|
||||
const thinkingPart: Part = {
|
||||
thought: true,
|
||||
thoughtSignature: block.thinkingSignature,
|
||||
text: block.thinking,
|
||||
};
|
||||
parts.push(thinkingPart);
|
||||
} else if (block.type === "toolCall") {
|
||||
parts.push({
|
||||
functionCall: {
|
||||
id: block.id,
|
||||
name: block.name,
|
||||
args: block.arguments,
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
private convertTools(tools: Tool[]): any[] {
|
||||
return [
|
||||
{
|
||||
functionDeclarations: tools.map((tool) => ({
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
parameters: tool.parameters,
|
||||
})),
|
||||
},
|
||||
];
|
||||
}
|
||||
|
||||
private mapToolChoice(choice: string): FunctionCallingConfigMode {
|
||||
switch (choice) {
|
||||
case "auto":
|
||||
return FunctionCallingConfigMode.AUTO;
|
||||
case "none":
|
||||
return FunctionCallingConfigMode.NONE;
|
||||
case "any":
|
||||
return FunctionCallingConfigMode.ANY;
|
||||
default:
|
||||
return FunctionCallingConfigMode.AUTO;
|
||||
if (parts.length === 0) continue;
|
||||
contents.push({
|
||||
role: "model",
|
||||
parts,
|
||||
});
|
||||
} else if (msg.role === "toolResult") {
|
||||
contents.push({
|
||||
role: "user",
|
||||
parts: [
|
||||
{
|
||||
functionResponse: {
|
||||
id: msg.toolCallId,
|
||||
name: msg.toolName,
|
||||
response: {
|
||||
result: msg.content,
|
||||
isError: msg.isError,
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
private mapStopReason(reason: FinishReason): StopReason {
|
||||
switch (reason) {
|
||||
case "STOP":
|
||||
return "stop";
|
||||
case "MAX_TOKENS":
|
||||
return "length";
|
||||
case "BLOCKLIST":
|
||||
case "PROHIBITED_CONTENT":
|
||||
case "SPII":
|
||||
case "SAFETY":
|
||||
case "IMAGE_SAFETY":
|
||||
return "safety";
|
||||
case "RECITATION":
|
||||
return "safety";
|
||||
case "FINISH_REASON_UNSPECIFIED":
|
||||
case "OTHER":
|
||||
case "LANGUAGE":
|
||||
case "MALFORMED_FUNCTION_CALL":
|
||||
case "UNEXPECTED_TOOL_CALL":
|
||||
return "error";
|
||||
default:
|
||||
return "stop";
|
||||
return contents;
|
||||
}
|
||||
|
||||
function convertTools(tools: Tool[]): any[] {
|
||||
return [
|
||||
{
|
||||
functionDeclarations: tools.map((tool) => ({
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
parameters: tool.parameters,
|
||||
})),
|
||||
},
|
||||
];
|
||||
}
|
||||
|
||||
function mapToolChoice(choice: string): FunctionCallingConfigMode {
|
||||
switch (choice) {
|
||||
case "auto":
|
||||
return FunctionCallingConfigMode.AUTO;
|
||||
case "none":
|
||||
return FunctionCallingConfigMode.NONE;
|
||||
case "any":
|
||||
return FunctionCallingConfigMode.ANY;
|
||||
default:
|
||||
return FunctionCallingConfigMode.AUTO;
|
||||
}
|
||||
}
|
||||
|
||||
function mapStopReason(reason: FinishReason): StopReason {
|
||||
switch (reason) {
|
||||
case FinishReason.STOP:
|
||||
return "stop";
|
||||
case FinishReason.MAX_TOKENS:
|
||||
return "length";
|
||||
case FinishReason.BLOCKLIST:
|
||||
case FinishReason.PROHIBITED_CONTENT:
|
||||
case FinishReason.SPII:
|
||||
case FinishReason.SAFETY:
|
||||
case FinishReason.IMAGE_SAFETY:
|
||||
case FinishReason.RECITATION:
|
||||
return "safety";
|
||||
case FinishReason.FINISH_REASON_UNSPECIFIED:
|
||||
case FinishReason.OTHER:
|
||||
case FinishReason.LANGUAGE:
|
||||
case FinishReason.MALFORMED_FUNCTION_CALL:
|
||||
case FinishReason.UNEXPECTED_TOOL_CALL:
|
||||
return "error";
|
||||
default: {
|
||||
const _exhaustive: never = reason;
|
||||
throw new Error(`Unhandled stop reason: ${_exhaustive}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,18 +1,20 @@
|
|||
import OpenAI from "openai";
|
||||
import type {
|
||||
ChatCompletionAssistantMessageParam,
|
||||
ChatCompletionChunk,
|
||||
ChatCompletionContentPart,
|
||||
ChatCompletionContentPartImage,
|
||||
ChatCompletionContentPartText,
|
||||
ChatCompletionMessageParam,
|
||||
} from "openai/resources/chat/completions.js";
|
||||
import { QueuedGenerateStream } from "../generate.js";
|
||||
import { calculateCost } from "../models.js";
|
||||
import type {
|
||||
AssistantMessage,
|
||||
Context,
|
||||
LLM,
|
||||
LLMOptions,
|
||||
Message,
|
||||
GenerateFunction,
|
||||
GenerateOptions,
|
||||
GenerateStream,
|
||||
Model,
|
||||
StopReason,
|
||||
TextContent,
|
||||
|
|
@ -22,43 +24,25 @@ import type {
|
|||
} from "../types.js";
|
||||
import { transformMessages } from "./utils.js";
|
||||
|
||||
export interface OpenAICompletionsLLMOptions extends LLMOptions {
|
||||
export interface OpenAICompletionsOptions extends GenerateOptions {
|
||||
toolChoice?: "auto" | "none" | "required" | { type: "function"; function: { name: string } };
|
||||
reasoningEffort?: "low" | "medium" | "high";
|
||||
reasoningEffort?: "minimal" | "low" | "medium" | "high";
|
||||
}
|
||||
|
||||
export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
|
||||
private client: OpenAI;
|
||||
private modelInfo: Model;
|
||||
export const streamOpenAICompletions: GenerateFunction<"openai-completions"> = (
|
||||
model: Model<"openai-completions">,
|
||||
context: Context,
|
||||
options?: OpenAICompletionsOptions,
|
||||
): GenerateStream => {
|
||||
const stream = new QueuedGenerateStream();
|
||||
|
||||
constructor(model: Model, apiKey?: string) {
|
||||
if (!apiKey) {
|
||||
if (!process.env.OPENAI_API_KEY) {
|
||||
throw new Error(
|
||||
"OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass it as an argument.",
|
||||
);
|
||||
}
|
||||
apiKey = process.env.OPENAI_API_KEY;
|
||||
}
|
||||
this.client = new OpenAI({ apiKey, baseURL: model.baseUrl, dangerouslyAllowBrowser: true });
|
||||
this.modelInfo = model;
|
||||
}
|
||||
|
||||
getModel(): Model {
|
||||
return this.modelInfo;
|
||||
}
|
||||
|
||||
getApi(): string {
|
||||
return "openai-completions";
|
||||
}
|
||||
|
||||
async generate(request: Context, options?: OpenAICompletionsLLMOptions): Promise<AssistantMessage> {
|
||||
(async () => {
|
||||
const output: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: this.getApi(),
|
||||
provider: this.modelInfo.provider,
|
||||
model: this.modelInfo.id,
|
||||
api: model.api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
|
|
@ -70,52 +54,13 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
|
|||
};
|
||||
|
||||
try {
|
||||
const messages = this.convertMessages(request.messages, request.systemPrompt);
|
||||
|
||||
const params: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
|
||||
model: this.modelInfo.id,
|
||||
messages,
|
||||
stream: true,
|
||||
stream_options: { include_usage: true },
|
||||
};
|
||||
|
||||
// Cerebras/xAI dont like the "store" field
|
||||
if (!this.modelInfo.baseUrl?.includes("cerebras.ai") && !this.modelInfo.baseUrl?.includes("api.x.ai")) {
|
||||
params.store = false;
|
||||
}
|
||||
|
||||
if (options?.maxTokens) {
|
||||
params.max_completion_tokens = options?.maxTokens;
|
||||
}
|
||||
|
||||
if (options?.temperature !== undefined) {
|
||||
params.temperature = options?.temperature;
|
||||
}
|
||||
|
||||
if (request.tools) {
|
||||
params.tools = this.convertTools(request.tools);
|
||||
}
|
||||
|
||||
if (options?.toolChoice) {
|
||||
params.tool_choice = options.toolChoice;
|
||||
}
|
||||
|
||||
if (
|
||||
options?.reasoningEffort &&
|
||||
this.modelInfo.reasoning &&
|
||||
!this.modelInfo.id.toLowerCase().includes("grok")
|
||||
) {
|
||||
params.reasoning_effort = options.reasoningEffort;
|
||||
}
|
||||
|
||||
const stream = await this.client.chat.completions.create(params, {
|
||||
signal: options?.signal,
|
||||
});
|
||||
|
||||
options?.onEvent?.({ type: "start", model: this.modelInfo.id, provider: this.modelInfo.provider });
|
||||
const client = createClient(model, options?.apiKey);
|
||||
const params = buildParams(model, context, options);
|
||||
const openaiStream = await client.chat.completions.create(params, { signal: options?.signal });
|
||||
stream.push({ type: "start", partial: output });
|
||||
|
||||
let currentBlock: TextContent | ThinkingContent | (ToolCall & { partialArgs?: string }) | null = null;
|
||||
for await (const chunk of stream) {
|
||||
for await (const chunk of openaiStream) {
|
||||
if (chunk.usage) {
|
||||
output.usage = {
|
||||
input: chunk.usage.prompt_tokens || 0,
|
||||
|
|
@ -132,137 +77,170 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
|
|||
total: 0,
|
||||
},
|
||||
};
|
||||
calculateCost(this.modelInfo, output.usage);
|
||||
calculateCost(model, output.usage);
|
||||
}
|
||||
|
||||
const choice = chunk.choices[0];
|
||||
if (!choice) continue;
|
||||
|
||||
// Capture finish reason
|
||||
if (choice.finish_reason) {
|
||||
output.stopReason = this.mapStopReason(choice.finish_reason);
|
||||
output.stopReason = mapStopReason(choice.finish_reason);
|
||||
}
|
||||
|
||||
if (choice.delta) {
|
||||
// Handle text content
|
||||
if (
|
||||
choice.delta.content !== null &&
|
||||
choice.delta.content !== undefined &&
|
||||
choice.delta.content.length > 0
|
||||
) {
|
||||
// Check if we need to switch to text block
|
||||
if (!currentBlock || currentBlock.type !== "text") {
|
||||
// Save current block if exists
|
||||
if (currentBlock) {
|
||||
if (currentBlock.type === "thinking") {
|
||||
options?.onEvent?.({ type: "thinking_end", content: currentBlock.thinking });
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
content: currentBlock.thinking,
|
||||
partial: output,
|
||||
});
|
||||
} else if (currentBlock.type === "toolCall") {
|
||||
currentBlock.arguments = JSON.parse(currentBlock.partialArgs || "{}");
|
||||
delete currentBlock.partialArgs;
|
||||
options?.onEvent?.({ type: "toolCall", toolCall: currentBlock as ToolCall });
|
||||
stream.push({
|
||||
type: "toolCall",
|
||||
toolCall: currentBlock as ToolCall,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
// Start new text block
|
||||
currentBlock = { type: "text", text: "" };
|
||||
output.content.push(currentBlock);
|
||||
options?.onEvent?.({ type: "text_start" });
|
||||
stream.push({ type: "text_start", partial: output });
|
||||
}
|
||||
// Append to text block
|
||||
|
||||
if (currentBlock.type === "text") {
|
||||
currentBlock.text += choice.delta.content;
|
||||
options?.onEvent?.({
|
||||
stream.push({
|
||||
type: "text_delta",
|
||||
content: currentBlock.text,
|
||||
delta: choice.delta.content,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Handle reasoning_content field
|
||||
// Some endpoints return reasoning in reasoning_content (llama.cpp)
|
||||
if (
|
||||
(choice.delta as any).reasoning_content !== null &&
|
||||
(choice.delta as any).reasoning_content !== undefined &&
|
||||
(choice.delta as any).reasoning_content.length > 0
|
||||
) {
|
||||
// Check if we need to switch to thinking block
|
||||
if (!currentBlock || currentBlock.type !== "thinking") {
|
||||
// Save current block if exists
|
||||
if (currentBlock) {
|
||||
if (currentBlock.type === "text") {
|
||||
options?.onEvent?.({ type: "text_end", content: currentBlock.text });
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
content: currentBlock.text,
|
||||
partial: output,
|
||||
});
|
||||
} else if (currentBlock.type === "toolCall") {
|
||||
currentBlock.arguments = JSON.parse(currentBlock.partialArgs || "{}");
|
||||
delete currentBlock.partialArgs;
|
||||
options?.onEvent?.({ type: "toolCall", toolCall: currentBlock as ToolCall });
|
||||
stream.push({
|
||||
type: "toolCall",
|
||||
toolCall: currentBlock as ToolCall,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
// Start new thinking block
|
||||
currentBlock = { type: "thinking", thinking: "", thinkingSignature: "reasoning_content" };
|
||||
currentBlock = {
|
||||
type: "thinking",
|
||||
thinking: "",
|
||||
thinkingSignature: "reasoning_content",
|
||||
};
|
||||
output.content.push(currentBlock);
|
||||
options?.onEvent?.({ type: "thinking_start" });
|
||||
stream.push({ type: "thinking_start", partial: output });
|
||||
}
|
||||
// Append to thinking block
|
||||
|
||||
if (currentBlock.type === "thinking") {
|
||||
const delta = (choice.delta as any).reasoning_content;
|
||||
currentBlock.thinking += delta;
|
||||
options?.onEvent?.({ type: "thinking_delta", content: currentBlock.thinking, delta });
|
||||
stream.push({
|
||||
type: "thinking_delta",
|
||||
delta,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Handle reasoning field
|
||||
// Some endpoints return reasoning in reasining (ollama, xAI, ...)
|
||||
if (
|
||||
(choice.delta as any).reasoning !== null &&
|
||||
(choice.delta as any).reasoning !== undefined &&
|
||||
(choice.delta as any).reasoning.length > 0
|
||||
) {
|
||||
// Check if we need to switch to thinking block
|
||||
if (!currentBlock || currentBlock.type !== "thinking") {
|
||||
// Save current block if exists
|
||||
if (currentBlock) {
|
||||
if (currentBlock.type === "text") {
|
||||
options?.onEvent?.({ type: "text_end", content: currentBlock.text });
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
content: currentBlock.text,
|
||||
partial: output,
|
||||
});
|
||||
} else if (currentBlock.type === "toolCall") {
|
||||
currentBlock.arguments = JSON.parse(currentBlock.partialArgs || "{}");
|
||||
delete currentBlock.partialArgs;
|
||||
options?.onEvent?.({ type: "toolCall", toolCall: currentBlock as ToolCall });
|
||||
stream.push({
|
||||
type: "toolCall",
|
||||
toolCall: currentBlock as ToolCall,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
// Start new thinking block
|
||||
currentBlock = { type: "thinking", thinking: "", thinkingSignature: "reasoning" };
|
||||
currentBlock = {
|
||||
type: "thinking",
|
||||
thinking: "",
|
||||
thinkingSignature: "reasoning",
|
||||
};
|
||||
output.content.push(currentBlock);
|
||||
options?.onEvent?.({ type: "thinking_start" });
|
||||
stream.push({ type: "thinking_start", partial: output });
|
||||
}
|
||||
// Append to thinking block
|
||||
|
||||
if (currentBlock.type === "thinking") {
|
||||
const delta = (choice.delta as any).reasoning;
|
||||
currentBlock.thinking += delta;
|
||||
options?.onEvent?.({ type: "thinking_delta", content: currentBlock.thinking, delta });
|
||||
stream.push({ type: "thinking_delta", delta, partial: output });
|
||||
}
|
||||
}
|
||||
|
||||
// Handle tool calls
|
||||
if (choice?.delta?.tool_calls) {
|
||||
for (const toolCall of choice.delta.tool_calls) {
|
||||
// Check if we need a new tool call block
|
||||
if (
|
||||
!currentBlock ||
|
||||
currentBlock.type !== "toolCall" ||
|
||||
(toolCall.id && currentBlock.id !== toolCall.id)
|
||||
) {
|
||||
// Save current block if exists
|
||||
if (currentBlock) {
|
||||
if (currentBlock.type === "text") {
|
||||
options?.onEvent?.({ type: "text_end", content: currentBlock.text });
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
content: currentBlock.text,
|
||||
partial: output,
|
||||
});
|
||||
} else if (currentBlock.type === "thinking") {
|
||||
options?.onEvent?.({ type: "thinking_end", content: currentBlock.thinking });
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
content: currentBlock.thinking,
|
||||
partial: output,
|
||||
});
|
||||
} else if (currentBlock.type === "toolCall") {
|
||||
currentBlock.arguments = JSON.parse(currentBlock.partialArgs || "{}");
|
||||
delete currentBlock.partialArgs;
|
||||
options?.onEvent?.({ type: "toolCall", toolCall: currentBlock as ToolCall });
|
||||
stream.push({
|
||||
type: "toolCall",
|
||||
toolCall: currentBlock as ToolCall,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Start new tool call block
|
||||
currentBlock = {
|
||||
type: "toolCall",
|
||||
id: toolCall.id || "",
|
||||
|
|
@ -273,7 +251,6 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
|
|||
output.content.push(currentBlock);
|
||||
}
|
||||
|
||||
// Accumulate tool call data
|
||||
if (currentBlock.type === "toolCall") {
|
||||
if (toolCall.id) currentBlock.id = toolCall.id;
|
||||
if (toolCall.function?.name) currentBlock.name = toolCall.function.name;
|
||||
|
|
@ -286,16 +263,27 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
|
|||
}
|
||||
}
|
||||
|
||||
// Save final block if exists
|
||||
if (currentBlock) {
|
||||
if (currentBlock.type === "text") {
|
||||
options?.onEvent?.({ type: "text_end", content: currentBlock.text });
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
content: currentBlock.text,
|
||||
partial: output,
|
||||
});
|
||||
} else if (currentBlock.type === "thinking") {
|
||||
options?.onEvent?.({ type: "thinking_end", content: currentBlock.thinking });
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
content: currentBlock.thinking,
|
||||
partial: output,
|
||||
});
|
||||
} else if (currentBlock.type === "toolCall") {
|
||||
currentBlock.arguments = JSON.parse(currentBlock.partialArgs || "{}");
|
||||
delete currentBlock.partialArgs;
|
||||
options?.onEvent?.({ type: "toolCall", toolCall: currentBlock as ToolCall });
|
||||
stream.push({
|
||||
type: "toolCall",
|
||||
toolCall: currentBlock as ToolCall,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -303,141 +291,188 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
|
|||
throw new Error("Request was aborted");
|
||||
}
|
||||
|
||||
options?.onEvent?.({ type: "done", reason: output.stopReason, message: output });
|
||||
stream.push({ type: "done", reason: output.stopReason, message: output });
|
||||
stream.end();
|
||||
return output;
|
||||
} catch (error) {
|
||||
// Update output with error information
|
||||
output.stopReason = "error";
|
||||
output.error = error instanceof Error ? error.message : String(error);
|
||||
options?.onEvent?.({ type: "error", error: output.error });
|
||||
return output;
|
||||
stream.push({ type: "error", error: output.error, partial: output });
|
||||
stream.end();
|
||||
}
|
||||
})();
|
||||
|
||||
return stream;
|
||||
};
|
||||
|
||||
function createClient(model: Model<"openai-completions">, apiKey?: string) {
|
||||
if (!apiKey) {
|
||||
if (!process.env.OPENAI_API_KEY) {
|
||||
throw new Error(
|
||||
"OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass it as an argument.",
|
||||
);
|
||||
}
|
||||
apiKey = process.env.OPENAI_API_KEY;
|
||||
}
|
||||
return new OpenAI({ apiKey, baseURL: model.baseUrl, dangerouslyAllowBrowser: true });
|
||||
}
|
||||
|
||||
function buildParams(model: Model<"openai-completions">, context: Context, options?: OpenAICompletionsOptions) {
|
||||
const messages = convertMessages(model, context);
|
||||
|
||||
const params: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
|
||||
model: model.id,
|
||||
messages,
|
||||
stream: true,
|
||||
stream_options: { include_usage: true },
|
||||
};
|
||||
|
||||
// Cerebras/xAI dont like the "store" field
|
||||
if (!model.baseUrl.includes("cerebras.ai") && !model.baseUrl.includes("api.x.ai")) {
|
||||
params.store = false;
|
||||
}
|
||||
|
||||
private convertMessages(messages: Message[], systemPrompt?: string): ChatCompletionMessageParam[] {
|
||||
const params: ChatCompletionMessageParam[] = [];
|
||||
if (options?.maxTokens) {
|
||||
params.max_completion_tokens = options?.maxTokens;
|
||||
}
|
||||
|
||||
// Transform messages for cross-provider compatibility
|
||||
const transformedMessages = transformMessages(messages, this.modelInfo, this.getApi());
|
||||
if (options?.temperature !== undefined) {
|
||||
params.temperature = options?.temperature;
|
||||
}
|
||||
|
||||
// Add system prompt if provided
|
||||
if (systemPrompt) {
|
||||
// Cerebras/xAi don't like the "developer" role
|
||||
const useDeveloperRole =
|
||||
this.modelInfo.reasoning &&
|
||||
!this.modelInfo.baseUrl?.includes("cerebras.ai") &&
|
||||
!this.modelInfo.baseUrl?.includes("api.x.ai");
|
||||
const role = useDeveloperRole ? "developer" : "system";
|
||||
params.push({ role: role, content: systemPrompt });
|
||||
}
|
||||
if (context.tools) {
|
||||
params.tools = convertTools(context.tools);
|
||||
}
|
||||
|
||||
// Convert messages
|
||||
for (const msg of transformedMessages) {
|
||||
if (msg.role === "user") {
|
||||
// Handle both string and array content
|
||||
if (typeof msg.content === "string") {
|
||||
params.push({
|
||||
role: "user",
|
||||
content: msg.content,
|
||||
});
|
||||
} else {
|
||||
// Convert array content to OpenAI format
|
||||
const content: ChatCompletionContentPart[] = msg.content.map((item): ChatCompletionContentPart => {
|
||||
if (item.type === "text") {
|
||||
return {
|
||||
type: "text",
|
||||
text: item.text,
|
||||
} satisfies ChatCompletionContentPartText;
|
||||
} else {
|
||||
// Image content - OpenAI uses data URLs
|
||||
return {
|
||||
type: "image_url",
|
||||
image_url: {
|
||||
url: `data:${item.mimeType};base64,${item.data}`,
|
||||
},
|
||||
} satisfies ChatCompletionContentPartImage;
|
||||
}
|
||||
});
|
||||
const filteredContent = !this.modelInfo?.input.includes("image")
|
||||
? content.filter((c) => c.type !== "image_url")
|
||||
: content;
|
||||
params.push({
|
||||
role: "user",
|
||||
content: filteredContent,
|
||||
});
|
||||
}
|
||||
} else if (msg.role === "assistant") {
|
||||
const assistantMsg: ChatCompletionMessageParam = {
|
||||
role: "assistant",
|
||||
content: null,
|
||||
};
|
||||
if (options?.toolChoice) {
|
||||
params.tool_choice = options.toolChoice;
|
||||
}
|
||||
|
||||
// Build content from blocks
|
||||
const textBlocks = msg.content.filter((b) => b.type === "text") as TextContent[];
|
||||
if (textBlocks.length > 0) {
|
||||
assistantMsg.content = textBlocks.map((b) => b.text).join("");
|
||||
}
|
||||
// Grok models don't like reasoning_effort
|
||||
if (options?.reasoningEffort && model.reasoning && !model.id.toLowerCase().includes("grok")) {
|
||||
params.reasoning_effort = options.reasoningEffort;
|
||||
}
|
||||
|
||||
// Handle thinking blocks for llama.cpp server + gpt-oss
|
||||
const thinkingBlocks = msg.content.filter((b) => b.type === "thinking") as ThinkingContent[];
|
||||
if (thinkingBlocks.length > 0) {
|
||||
// Use the signature from the first thinking block if available
|
||||
const signature = thinkingBlocks[0].thinkingSignature;
|
||||
if (signature && signature.length > 0) {
|
||||
(assistantMsg as any)[signature] = thinkingBlocks.map((b) => b.thinking).join("");
|
||||
}
|
||||
}
|
||||
return params;
|
||||
}
|
||||
|
||||
// Handle tool calls
|
||||
const toolCalls = msg.content.filter((b) => b.type === "toolCall") as ToolCall[];
|
||||
if (toolCalls.length > 0) {
|
||||
assistantMsg.tool_calls = toolCalls.map((tc) => ({
|
||||
id: tc.id,
|
||||
type: "function" as const,
|
||||
function: {
|
||||
name: tc.name,
|
||||
arguments: JSON.stringify(tc.arguments),
|
||||
},
|
||||
}));
|
||||
}
|
||||
function convertMessages(model: Model<"openai-completions">, context: Context): ChatCompletionMessageParam[] {
|
||||
const params: ChatCompletionMessageParam[] = [];
|
||||
|
||||
params.push(assistantMsg);
|
||||
} else if (msg.role === "toolResult") {
|
||||
const transformedMessages = transformMessages(context.messages, model);
|
||||
|
||||
if (context.systemPrompt) {
|
||||
// Cerebras/xAi don't like the "developer" role
|
||||
const useDeveloperRole =
|
||||
model.reasoning && !model.baseUrl.includes("cerebras.ai") && !model.baseUrl.includes("api.x.ai");
|
||||
const role = useDeveloperRole ? "developer" : "system";
|
||||
params.push({ role: role, content: context.systemPrompt });
|
||||
}
|
||||
|
||||
for (const msg of transformedMessages) {
|
||||
if (msg.role === "user") {
|
||||
if (typeof msg.content === "string") {
|
||||
params.push({
|
||||
role: "tool",
|
||||
role: "user",
|
||||
content: msg.content,
|
||||
tool_call_id: msg.toolCallId,
|
||||
});
|
||||
} else {
|
||||
const content: ChatCompletionContentPart[] = msg.content.map((item): ChatCompletionContentPart => {
|
||||
if (item.type === "text") {
|
||||
return {
|
||||
type: "text",
|
||||
text: item.text,
|
||||
} satisfies ChatCompletionContentPartText;
|
||||
} else {
|
||||
return {
|
||||
type: "image_url",
|
||||
image_url: {
|
||||
url: `data:${item.mimeType};base64,${item.data}`,
|
||||
},
|
||||
} satisfies ChatCompletionContentPartImage;
|
||||
}
|
||||
});
|
||||
const filteredContent = !model.input.includes("image")
|
||||
? content.filter((c) => c.type !== "image_url")
|
||||
: content;
|
||||
if (filteredContent.length === 0) continue;
|
||||
params.push({
|
||||
role: "user",
|
||||
content: filteredContent,
|
||||
});
|
||||
}
|
||||
} else if (msg.role === "assistant") {
|
||||
const assistantMsg: ChatCompletionAssistantMessageParam = {
|
||||
role: "assistant",
|
||||
content: null,
|
||||
};
|
||||
|
||||
const textBlocks = msg.content.filter((b) => b.type === "text") as TextContent[];
|
||||
if (textBlocks.length > 0) {
|
||||
assistantMsg.content = textBlocks.map((b) => b.text).join("");
|
||||
}
|
||||
|
||||
// Handle thinking blocks for llama.cpp server + gpt-oss
|
||||
const thinkingBlocks = msg.content.filter((b) => b.type === "thinking") as ThinkingContent[];
|
||||
if (thinkingBlocks.length > 0) {
|
||||
// Use the signature from the first thinking block if available
|
||||
const signature = thinkingBlocks[0].thinkingSignature;
|
||||
if (signature && signature.length > 0) {
|
||||
(assistantMsg as any)[signature] = thinkingBlocks.map((b) => b.thinking).join("");
|
||||
}
|
||||
}
|
||||
|
||||
const toolCalls = msg.content.filter((b) => b.type === "toolCall") as ToolCall[];
|
||||
if (toolCalls.length > 0) {
|
||||
assistantMsg.tool_calls = toolCalls.map((tc) => ({
|
||||
id: tc.id,
|
||||
type: "function" as const,
|
||||
function: {
|
||||
name: tc.name,
|
||||
arguments: JSON.stringify(tc.arguments),
|
||||
},
|
||||
}));
|
||||
}
|
||||
|
||||
params.push(assistantMsg);
|
||||
} else if (msg.role === "toolResult") {
|
||||
params.push({
|
||||
role: "tool",
|
||||
content: msg.content,
|
||||
tool_call_id: msg.toolCallId,
|
||||
});
|
||||
}
|
||||
|
||||
return params;
|
||||
}
|
||||
|
||||
private convertTools(tools: Tool[]): OpenAI.Chat.Completions.ChatCompletionTool[] {
|
||||
return tools.map((tool) => ({
|
||||
type: "function",
|
||||
function: {
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
parameters: tool.parameters,
|
||||
},
|
||||
}));
|
||||
}
|
||||
return params;
|
||||
}
|
||||
|
||||
private mapStopReason(reason: ChatCompletionChunk.Choice["finish_reason"] | null): StopReason {
|
||||
switch (reason) {
|
||||
case "stop":
|
||||
return "stop";
|
||||
case "length":
|
||||
return "length";
|
||||
case "function_call":
|
||||
case "tool_calls":
|
||||
return "toolUse";
|
||||
case "content_filter":
|
||||
return "safety";
|
||||
default:
|
||||
return "stop";
|
||||
function convertTools(tools: Tool[]): OpenAI.Chat.Completions.ChatCompletionTool[] {
|
||||
return tools.map((tool) => ({
|
||||
type: "function",
|
||||
function: {
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
parameters: tool.parameters,
|
||||
},
|
||||
}));
|
||||
}
|
||||
|
||||
function mapStopReason(reason: ChatCompletionChunk.Choice["finish_reason"]): StopReason {
|
||||
if (reason === null) return "stop";
|
||||
switch (reason) {
|
||||
case "stop":
|
||||
return "stop";
|
||||
case "length":
|
||||
return "length";
|
||||
case "function_call":
|
||||
case "tool_calls":
|
||||
return "toolUse";
|
||||
case "content_filter":
|
||||
return "safety";
|
||||
default: {
|
||||
const _exhaustive: never = reason;
|
||||
throw new Error(`Unhandled stop reason: ${_exhaustive}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,58 +10,49 @@ import type {
|
|||
ResponseOutputMessage,
|
||||
ResponseReasoningItem,
|
||||
} from "openai/resources/responses/responses.js";
|
||||
import { QueuedGenerateStream } from "../generate.js";
|
||||
import { calculateCost } from "../models.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
Context,
|
||||
LLM,
|
||||
LLMOptions,
|
||||
GenerateFunction,
|
||||
GenerateOptions,
|
||||
GenerateStream,
|
||||
Message,
|
||||
Model,
|
||||
StopReason,
|
||||
TextContent,
|
||||
ThinkingContent,
|
||||
Tool,
|
||||
ToolCall,
|
||||
} from "../types.js";
|
||||
import { transformMessages } from "./utils.js";
|
||||
|
||||
export interface OpenAIResponsesLLMOptions extends LLMOptions {
|
||||
// OpenAI Responses-specific options
|
||||
export interface OpenAIResponsesOptions extends GenerateOptions {
|
||||
reasoningEffort?: "minimal" | "low" | "medium" | "high";
|
||||
reasoningSummary?: "auto" | "detailed" | "concise" | null;
|
||||
}
|
||||
|
||||
export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
|
||||
private client: OpenAI;
|
||||
private modelInfo: Model;
|
||||
/**
|
||||
* Generate function for OpenAI Responses API
|
||||
*/
|
||||
export const streamOpenAIResponses: GenerateFunction<"openai-responses"> = (
|
||||
model: Model<"openai-responses">,
|
||||
context: Context,
|
||||
options?: OpenAIResponsesOptions,
|
||||
): GenerateStream => {
|
||||
const stream = new QueuedGenerateStream();
|
||||
|
||||
constructor(model: Model, apiKey?: string) {
|
||||
if (!apiKey) {
|
||||
if (!process.env.OPENAI_API_KEY) {
|
||||
throw new Error(
|
||||
"OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass it as an argument.",
|
||||
);
|
||||
}
|
||||
apiKey = process.env.OPENAI_API_KEY;
|
||||
}
|
||||
this.client = new OpenAI({ apiKey, baseURL: model.baseUrl, dangerouslyAllowBrowser: true });
|
||||
this.modelInfo = model;
|
||||
}
|
||||
|
||||
getModel(): Model {
|
||||
return this.modelInfo;
|
||||
}
|
||||
|
||||
getApi(): string {
|
||||
return "openai-responses";
|
||||
}
|
||||
|
||||
async generate(request: Context, options?: OpenAIResponsesLLMOptions): Promise<AssistantMessage> {
|
||||
// Start async processing
|
||||
(async () => {
|
||||
const output: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: this.getApi(),
|
||||
provider: this.modelInfo.provider,
|
||||
model: this.modelInfo.id,
|
||||
api: "openai-responses" as Api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
|
|
@ -71,77 +62,31 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
|
|||
},
|
||||
stopReason: "stop",
|
||||
};
|
||||
|
||||
try {
|
||||
const input = this.convertToInput(request.messages, request.systemPrompt);
|
||||
// Create OpenAI client
|
||||
const client = createClient(model, options?.apiKey);
|
||||
const params = buildParams(model, context, options);
|
||||
const openaiStream = await client.responses.create(params, { signal: options?.signal });
|
||||
stream.push({ type: "start", partial: output });
|
||||
|
||||
const params: ResponseCreateParamsStreaming = {
|
||||
model: this.modelInfo.id,
|
||||
input,
|
||||
stream: true,
|
||||
};
|
||||
|
||||
if (options?.maxTokens) {
|
||||
params.max_output_tokens = options?.maxTokens;
|
||||
}
|
||||
|
||||
if (options?.temperature !== undefined) {
|
||||
params.temperature = options?.temperature;
|
||||
}
|
||||
|
||||
if (request.tools) {
|
||||
params.tools = this.convertTools(request.tools);
|
||||
}
|
||||
|
||||
// Add reasoning options for models that support it
|
||||
if (this.modelInfo?.reasoning) {
|
||||
if (options?.reasoningEffort || options?.reasoningSummary) {
|
||||
params.reasoning = {
|
||||
effort: options?.reasoningEffort || "medium",
|
||||
summary: options?.reasoningSummary || "auto",
|
||||
};
|
||||
params.include = ["reasoning.encrypted_content"];
|
||||
} else {
|
||||
params.reasoning = {
|
||||
effort: this.modelInfo.name.startsWith("gpt-5") ? "minimal" : null,
|
||||
summary: null,
|
||||
};
|
||||
|
||||
if (this.modelInfo.name.startsWith("gpt-5")) {
|
||||
// Jesus Christ, see https://community.openai.com/t/need-reasoning-false-option-for-gpt-5/1351588/7
|
||||
input.push({
|
||||
role: "developer",
|
||||
content: [
|
||||
{
|
||||
type: "input_text",
|
||||
text: "# Juice: 0 !important",
|
||||
},
|
||||
],
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const stream = await this.client.responses.create(params, {
|
||||
signal: options?.signal,
|
||||
});
|
||||
|
||||
options?.onEvent?.({ type: "start", model: this.modelInfo.id, provider: this.modelInfo.provider });
|
||||
|
||||
const outputItems: (ResponseReasoningItem | ResponseOutputMessage | ResponseFunctionToolCall)[] = [];
|
||||
let currentItem: ResponseReasoningItem | ResponseOutputMessage | ResponseFunctionToolCall | null = null;
|
||||
let currentBlock: ThinkingContent | TextContent | ToolCall | null = null;
|
||||
|
||||
for await (const event of stream) {
|
||||
for await (const event of openaiStream) {
|
||||
// Handle output item start
|
||||
if (event.type === "response.output_item.added") {
|
||||
const item = event.item;
|
||||
if (item.type === "reasoning") {
|
||||
options?.onEvent?.({ type: "thinking_start" });
|
||||
outputItems.push(item);
|
||||
currentItem = item;
|
||||
currentBlock = { type: "thinking", thinking: "" };
|
||||
output.content.push(currentBlock);
|
||||
stream.push({ type: "thinking_start", partial: output });
|
||||
} else if (item.type === "message") {
|
||||
options?.onEvent?.({ type: "text_start" });
|
||||
outputItems.push(item);
|
||||
currentItem = item;
|
||||
currentBlock = { type: "text", text: "" };
|
||||
output.content.push(currentBlock);
|
||||
stream.push({ type: "text_start", partial: output });
|
||||
}
|
||||
}
|
||||
// Handle reasoning summary deltas
|
||||
|
|
@ -151,30 +96,42 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
|
|||
currentItem.summary.push(event.part);
|
||||
}
|
||||
} else if (event.type === "response.reasoning_summary_text.delta") {
|
||||
if (currentItem && currentItem.type === "reasoning") {
|
||||
if (
|
||||
currentItem &&
|
||||
currentItem.type === "reasoning" &&
|
||||
currentBlock &&
|
||||
currentBlock.type === "thinking"
|
||||
) {
|
||||
currentItem.summary = currentItem.summary || [];
|
||||
const lastPart = currentItem.summary[currentItem.summary.length - 1];
|
||||
if (lastPart) {
|
||||
currentBlock.thinking += event.delta;
|
||||
lastPart.text += event.delta;
|
||||
options?.onEvent?.({
|
||||
stream.push({
|
||||
type: "thinking_delta",
|
||||
content: currentItem.summary.map((s) => s.text).join("\n\n"),
|
||||
delta: event.delta,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
// Add a new line between summary parts (hack...)
|
||||
else if (event.type === "response.reasoning_summary_part.done") {
|
||||
if (currentItem && currentItem.type === "reasoning") {
|
||||
if (
|
||||
currentItem &&
|
||||
currentItem.type === "reasoning" &&
|
||||
currentBlock &&
|
||||
currentBlock.type === "thinking"
|
||||
) {
|
||||
currentItem.summary = currentItem.summary || [];
|
||||
const lastPart = currentItem.summary[currentItem.summary.length - 1];
|
||||
if (lastPart) {
|
||||
currentBlock.thinking += "\n\n";
|
||||
lastPart.text += "\n\n";
|
||||
options?.onEvent?.({
|
||||
stream.push({
|
||||
type: "thinking_delta",
|
||||
content: currentItem.summary.map((s) => s.text).join("\n\n"),
|
||||
delta: "\n\n",
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
@ -186,30 +143,28 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
|
|||
currentItem.content.push(event.part);
|
||||
}
|
||||
} else if (event.type === "response.output_text.delta") {
|
||||
if (currentItem && currentItem.type === "message") {
|
||||
if (currentItem && currentItem.type === "message" && currentBlock && currentBlock.type === "text") {
|
||||
const lastPart = currentItem.content[currentItem.content.length - 1];
|
||||
if (lastPart && lastPart.type === "output_text") {
|
||||
currentBlock.text += event.delta;
|
||||
lastPart.text += event.delta;
|
||||
options?.onEvent?.({
|
||||
stream.push({
|
||||
type: "text_delta",
|
||||
content: currentItem.content
|
||||
.map((c) => (c.type === "output_text" ? c.text : c.refusal))
|
||||
.join(""),
|
||||
delta: event.delta,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
} else if (event.type === "response.refusal.delta") {
|
||||
if (currentItem && currentItem.type === "message") {
|
||||
if (currentItem && currentItem.type === "message" && currentBlock && currentBlock.type === "text") {
|
||||
const lastPart = currentItem.content[currentItem.content.length - 1];
|
||||
if (lastPart && lastPart.type === "refusal") {
|
||||
currentBlock.text += event.delta;
|
||||
lastPart.refusal += event.delta;
|
||||
options?.onEvent?.({
|
||||
stream.push({
|
||||
type: "text_delta",
|
||||
content: currentItem.content
|
||||
.map((c) => (c.type === "output_text" ? c.text : c.refusal))
|
||||
.join(""),
|
||||
delta: event.delta,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
@ -218,14 +173,24 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
|
|||
else if (event.type === "response.output_item.done") {
|
||||
const item = event.item;
|
||||
|
||||
if (item.type === "reasoning") {
|
||||
outputItems[outputItems.length - 1] = item; // Update with final item
|
||||
const thinkingContent = item.summary?.map((s) => s.text).join("\n\n") || "";
|
||||
options?.onEvent?.({ type: "thinking_end", content: thinkingContent });
|
||||
} else if (item.type === "message") {
|
||||
outputItems[outputItems.length - 1] = item; // Update with final item
|
||||
const textContent = item.content.map((c) => (c.type === "output_text" ? c.text : c.refusal)).join("");
|
||||
options?.onEvent?.({ type: "text_end", content: textContent });
|
||||
if (item.type === "reasoning" && currentBlock && currentBlock.type === "thinking") {
|
||||
currentBlock.thinking = item.summary?.map((s) => s.text).join("\n\n") || "";
|
||||
currentBlock.thinkingSignature = JSON.stringify(item);
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
content: currentBlock.thinking,
|
||||
partial: output,
|
||||
});
|
||||
currentBlock = null;
|
||||
} else if (item.type === "message" && currentBlock && currentBlock.type === "text") {
|
||||
currentBlock.text = item.content.map((c) => (c.type === "output_text" ? c.text : c.refusal)).join("");
|
||||
currentBlock.textSignature = item.id;
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
content: currentBlock.text,
|
||||
partial: output,
|
||||
});
|
||||
currentBlock = null;
|
||||
} else if (item.type === "function_call") {
|
||||
const toolCall: ToolCall = {
|
||||
type: "toolCall",
|
||||
|
|
@ -233,8 +198,8 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
|
|||
name: item.name,
|
||||
arguments: JSON.parse(item.arguments),
|
||||
};
|
||||
options?.onEvent?.({ type: "toolCall", toolCall });
|
||||
outputItems.push(item);
|
||||
output.content.push(toolCall);
|
||||
stream.push({ type: "toolCall", toolCall, partial: output });
|
||||
}
|
||||
}
|
||||
// Handle completion
|
||||
|
|
@ -249,10 +214,10 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
|
|||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
};
|
||||
}
|
||||
calculateCost(this.modelInfo, output.usage);
|
||||
calculateCost(model, output.usage);
|
||||
// Map status to stop reason
|
||||
output.stopReason = this.mapStopReason(response?.status);
|
||||
if (outputItems.some((b) => b.type === "function_call") && output.stopReason === "stop") {
|
||||
output.stopReason = mapStopReason(response?.status);
|
||||
if (output.content.some((b) => b.type === "toolCall") && output.stopReason === "stop") {
|
||||
output.stopReason = "toolUse";
|
||||
}
|
||||
}
|
||||
|
|
@ -260,173 +225,215 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
|
|||
else if (event.type === "error") {
|
||||
output.stopReason = "error";
|
||||
output.error = `Code ${event.code}: ${event.message}` || "Unknown error";
|
||||
options?.onEvent?.({ type: "error", error: output.error });
|
||||
stream.push({ type: "error", error: output.error, partial: output });
|
||||
stream.end();
|
||||
return output;
|
||||
} else if (event.type === "response.failed") {
|
||||
output.stopReason = "error";
|
||||
output.error = "Unknown error";
|
||||
options?.onEvent?.({ type: "error", error: output.error });
|
||||
stream.push({ type: "error", error: output.error, partial: output });
|
||||
stream.end();
|
||||
return output;
|
||||
}
|
||||
}
|
||||
|
||||
// Convert output items to blocks
|
||||
for (const item of outputItems) {
|
||||
if (item.type === "reasoning") {
|
||||
output.content.push({
|
||||
type: "thinking",
|
||||
thinking: item.summary?.map((s: any) => s.text).join("\n\n") || "",
|
||||
thinkingSignature: JSON.stringify(item), // Full item for resubmission
|
||||
});
|
||||
} else if (item.type === "message") {
|
||||
output.content.push({
|
||||
type: "text",
|
||||
text: item.content.map((c) => (c.type === "output_text" ? c.text : c.refusal)).join(""),
|
||||
textSignature: item.id, // ID for resubmission
|
||||
});
|
||||
} else if (item.type === "function_call") {
|
||||
output.content.push({
|
||||
type: "toolCall",
|
||||
id: item.call_id + "|" + item.id,
|
||||
name: item.name,
|
||||
arguments: JSON.parse(item.arguments),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (options?.signal?.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
|
||||
options?.onEvent?.({ type: "done", reason: output.stopReason, message: output });
|
||||
return output;
|
||||
stream.push({ type: "done", reason: output.stopReason, message: output });
|
||||
stream.end();
|
||||
} catch (error) {
|
||||
output.stopReason = "error";
|
||||
output.error = error instanceof Error ? error.message : JSON.stringify(error);
|
||||
options?.onEvent?.({ type: "error", error: output.error });
|
||||
return output;
|
||||
stream.push({ type: "error", error: output.error, partial: output });
|
||||
stream.end();
|
||||
}
|
||||
})();
|
||||
|
||||
return stream;
|
||||
};
|
||||
|
||||
function createClient(model: Model<"openai-responses">, apiKey?: string) {
|
||||
if (!apiKey) {
|
||||
if (!process.env.OPENAI_API_KEY) {
|
||||
throw new Error(
|
||||
"OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass it as an argument.",
|
||||
);
|
||||
}
|
||||
apiKey = process.env.OPENAI_API_KEY;
|
||||
}
|
||||
return new OpenAI({ apiKey, baseURL: model.baseUrl, dangerouslyAllowBrowser: true });
|
||||
}
|
||||
|
||||
function buildParams(model: Model<"openai-responses">, context: Context, options?: OpenAIResponsesOptions) {
|
||||
const messages = convertMessages(model, context);
|
||||
|
||||
const params: ResponseCreateParamsStreaming = {
|
||||
model: model.id,
|
||||
input: messages,
|
||||
stream: true,
|
||||
};
|
||||
|
||||
if (options?.maxTokens) {
|
||||
params.max_output_tokens = options?.maxTokens;
|
||||
}
|
||||
|
||||
private convertToInput(messages: Message[], systemPrompt?: string): ResponseInput {
|
||||
const input: ResponseInput = [];
|
||||
if (options?.temperature !== undefined) {
|
||||
params.temperature = options?.temperature;
|
||||
}
|
||||
|
||||
// Transform messages for cross-provider compatibility
|
||||
const transformedMessages = transformMessages(messages, this.modelInfo, this.getApi());
|
||||
if (context.tools) {
|
||||
params.tools = convertTools(context.tools);
|
||||
}
|
||||
|
||||
// Add system prompt if provided
|
||||
if (systemPrompt) {
|
||||
const role = this.modelInfo?.reasoning ? "developer" : "system";
|
||||
input.push({
|
||||
role,
|
||||
content: systemPrompt,
|
||||
});
|
||||
}
|
||||
if (model.reasoning) {
|
||||
if (options?.reasoningEffort || options?.reasoningSummary) {
|
||||
params.reasoning = {
|
||||
effort: options?.reasoningEffort || "medium",
|
||||
summary: options?.reasoningSummary || "auto",
|
||||
};
|
||||
params.include = ["reasoning.encrypted_content"];
|
||||
} else {
|
||||
params.reasoning = {
|
||||
effort: model.name.startsWith("gpt-5") ? "minimal" : null,
|
||||
summary: null,
|
||||
};
|
||||
|
||||
// Convert messages
|
||||
for (const msg of transformedMessages) {
|
||||
if (msg.role === "user") {
|
||||
// Handle both string and array content
|
||||
if (typeof msg.content === "string") {
|
||||
input.push({
|
||||
role: "user",
|
||||
content: [{ type: "input_text", text: msg.content }],
|
||||
});
|
||||
} else {
|
||||
// Convert array content to OpenAI Responses format
|
||||
const content: ResponseInputContent[] = msg.content.map((item): ResponseInputContent => {
|
||||
if (item.type === "text") {
|
||||
return {
|
||||
type: "input_text",
|
||||
text: item.text,
|
||||
} satisfies ResponseInputText;
|
||||
} else {
|
||||
// Image content - OpenAI Responses uses data URLs
|
||||
return {
|
||||
type: "input_image",
|
||||
detail: "auto",
|
||||
image_url: `data:${item.mimeType};base64,${item.data}`,
|
||||
} satisfies ResponseInputImage;
|
||||
}
|
||||
});
|
||||
const filteredContent = !this.modelInfo?.input.includes("image")
|
||||
? content.filter((c) => c.type !== "input_image")
|
||||
: content;
|
||||
input.push({
|
||||
role: "user",
|
||||
content: filteredContent,
|
||||
});
|
||||
}
|
||||
} else if (msg.role === "assistant") {
|
||||
// Process content blocks in order
|
||||
const output: ResponseInput = [];
|
||||
|
||||
for (const block of msg.content) {
|
||||
// Do not submit thinking blocks if the completion had an error (i.e. abort)
|
||||
if (block.type === "thinking" && msg.stopReason !== "error") {
|
||||
// Push the full reasoning item(s) from signature
|
||||
if (block.thinkingSignature) {
|
||||
const reasoningItem = JSON.parse(block.thinkingSignature);
|
||||
output.push(reasoningItem);
|
||||
}
|
||||
} else if (block.type === "text") {
|
||||
const textBlock = block as TextContent;
|
||||
output.push({
|
||||
type: "message",
|
||||
role: "assistant",
|
||||
content: [{ type: "output_text", text: textBlock.text, annotations: [] }],
|
||||
status: "completed",
|
||||
id: textBlock.textSignature || "msg_" + Math.random().toString(36).substring(2, 15),
|
||||
} satisfies ResponseOutputMessage);
|
||||
// Do not submit thinking blocks if the completion had an error (i.e. abort)
|
||||
} else if (block.type === "toolCall" && msg.stopReason !== "error") {
|
||||
const toolCall = block as ToolCall;
|
||||
output.push({
|
||||
type: "function_call",
|
||||
id: toolCall.id.split("|")[1], // Extract original ID
|
||||
call_id: toolCall.id.split("|")[0], // Extract call session ID
|
||||
name: toolCall.name,
|
||||
arguments: JSON.stringify(toolCall.arguments),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Add all output items to input
|
||||
input.push(...output);
|
||||
} else if (msg.role === "toolResult") {
|
||||
// Tool results are sent as function_call_output
|
||||
input.push({
|
||||
type: "function_call_output",
|
||||
call_id: msg.toolCallId.split("|")[0], // Extract call session ID
|
||||
output: msg.content,
|
||||
if (model.name.startsWith("gpt-5")) {
|
||||
// Jesus Christ, see https://community.openai.com/t/need-reasoning-false-option-for-gpt-5/1351588/7
|
||||
messages.push({
|
||||
role: "developer",
|
||||
content: [
|
||||
{
|
||||
type: "input_text",
|
||||
text: "# Juice: 0 !important",
|
||||
},
|
||||
],
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return input;
|
||||
}
|
||||
|
||||
private convertTools(tools: Tool[]): OpenAITool[] {
|
||||
return tools.map((tool) => ({
|
||||
type: "function",
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
parameters: tool.parameters,
|
||||
strict: null,
|
||||
}));
|
||||
return params;
|
||||
}
|
||||
|
||||
function convertMessages(model: Model<"openai-responses">, context: Context): ResponseInput {
|
||||
const messages: ResponseInput = [];
|
||||
|
||||
const transformedMessages = transformMessages(context.messages, model);
|
||||
|
||||
if (context.systemPrompt) {
|
||||
const role = model.reasoning ? "developer" : "system";
|
||||
messages.push({
|
||||
role,
|
||||
content: context.systemPrompt,
|
||||
});
|
||||
}
|
||||
|
||||
private mapStopReason(status: string | undefined): StopReason {
|
||||
switch (status) {
|
||||
case "completed":
|
||||
return "stop";
|
||||
case "incomplete":
|
||||
return "length";
|
||||
case "failed":
|
||||
case "cancelled":
|
||||
return "error";
|
||||
default:
|
||||
return "stop";
|
||||
for (const msg of transformedMessages) {
|
||||
if (msg.role === "user") {
|
||||
if (typeof msg.content === "string") {
|
||||
messages.push({
|
||||
role: "user",
|
||||
content: [{ type: "input_text", text: msg.content }],
|
||||
});
|
||||
} else {
|
||||
const content: ResponseInputContent[] = msg.content.map((item): ResponseInputContent => {
|
||||
if (item.type === "text") {
|
||||
return {
|
||||
type: "input_text",
|
||||
text: item.text,
|
||||
} satisfies ResponseInputText;
|
||||
} else {
|
||||
return {
|
||||
type: "input_image",
|
||||
detail: "auto",
|
||||
image_url: `data:${item.mimeType};base64,${item.data}`,
|
||||
} satisfies ResponseInputImage;
|
||||
}
|
||||
});
|
||||
const filteredContent = !model.input.includes("image")
|
||||
? content.filter((c) => c.type !== "input_image")
|
||||
: content;
|
||||
if (filteredContent.length === 0) continue;
|
||||
messages.push({
|
||||
role: "user",
|
||||
content: filteredContent,
|
||||
});
|
||||
}
|
||||
} else if (msg.role === "assistant") {
|
||||
const output: ResponseInput = [];
|
||||
|
||||
for (const block of msg.content) {
|
||||
// Do not submit thinking blocks if the completion had an error (i.e. abort)
|
||||
if (block.type === "thinking" && msg.stopReason !== "error") {
|
||||
if (block.thinkingSignature) {
|
||||
const reasoningItem = JSON.parse(block.thinkingSignature);
|
||||
output.push(reasoningItem);
|
||||
}
|
||||
} else if (block.type === "text") {
|
||||
const textBlock = block as TextContent;
|
||||
output.push({
|
||||
type: "message",
|
||||
role: "assistant",
|
||||
content: [{ type: "output_text", text: textBlock.text, annotations: [] }],
|
||||
status: "completed",
|
||||
id: textBlock.textSignature || "msg_" + Math.random().toString(36).substring(2, 15),
|
||||
} satisfies ResponseOutputMessage);
|
||||
// Do not submit toolcall blocks if the completion had an error (i.e. abort)
|
||||
} else if (block.type === "toolCall" && msg.stopReason !== "error") {
|
||||
const toolCall = block as ToolCall;
|
||||
output.push({
|
||||
type: "function_call",
|
||||
id: toolCall.id.split("|")[1],
|
||||
call_id: toolCall.id.split("|")[0],
|
||||
name: toolCall.name,
|
||||
arguments: JSON.stringify(toolCall.arguments),
|
||||
});
|
||||
}
|
||||
}
|
||||
if (output.length === 0) continue;
|
||||
messages.push(...output);
|
||||
} else if (msg.role === "toolResult") {
|
||||
messages.push({
|
||||
type: "function_call_output",
|
||||
call_id: msg.toolCallId.split("|")[0],
|
||||
output: msg.content,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return messages;
|
||||
}
|
||||
|
||||
function convertTools(tools: Tool[]): OpenAITool[] {
|
||||
return tools.map((tool) => ({
|
||||
type: "function",
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
parameters: tool.parameters,
|
||||
strict: null,
|
||||
}));
|
||||
}
|
||||
|
||||
function mapStopReason(status: OpenAI.Responses.ResponseStatus | undefined): StopReason {
|
||||
if (!status) return "stop";
|
||||
switch (status) {
|
||||
case "completed":
|
||||
return "stop";
|
||||
case "incomplete":
|
||||
return "length";
|
||||
case "failed":
|
||||
case "cancelled":
|
||||
return "error";
|
||||
// These two are wonky ...
|
||||
case "in_progress":
|
||||
case "queued":
|
||||
return "stop";
|
||||
default: {
|
||||
const _exhaustive: never = status;
|
||||
throw new Error(`Unhandled stop reason: ${_exhaustive}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,18 +1,6 @@
|
|||
import type { AssistantMessage, Message, Model } from "../types.js";
|
||||
import type { Api, AssistantMessage, Message, Model } from "../types.js";
|
||||
|
||||
/**
|
||||
* Transform messages for cross-provider compatibility.
|
||||
*
|
||||
* - User and toolResult messages are copied verbatim
|
||||
* - Assistant messages:
|
||||
* - If from the same provider/model, copied as-is
|
||||
* - If from different provider/model, thinking blocks are converted to text blocks with <thinking></thinking> tags
|
||||
*
|
||||
* @param messages The messages to transform
|
||||
* @param model The target model that will process these messages
|
||||
* @returns A copy of the messages array with transformations applied
|
||||
*/
|
||||
export function transformMessages(messages: Message[], model: Model, api: string): Message[] {
|
||||
export function transformMessages<TApi extends Api>(messages: Message[], model: Model<TApi>): Message[] {
|
||||
return messages.map((msg) => {
|
||||
// User and toolResult messages pass through unchanged
|
||||
if (msg.role === "user" || msg.role === "toolResult") {
|
||||
|
|
@ -24,7 +12,7 @@ export function transformMessages(messages: Message[], model: Model, api: string
|
|||
const assistantMsg = msg as AssistantMessage;
|
||||
|
||||
// If message is from the same provider and API, keep as is
|
||||
if (assistantMsg.provider === model.provider && assistantMsg.api === api) {
|
||||
if (assistantMsg.provider === model.provider && assistantMsg.api === model.api) {
|
||||
return msg;
|
||||
}
|
||||
|
||||
|
|
@ -47,8 +35,6 @@ export function transformMessages(messages: Message[], model: Model, api: string
|
|||
content: transformedContent,
|
||||
};
|
||||
}
|
||||
|
||||
// Should not reach here, but return as-is for safety
|
||||
return msg;
|
||||
});
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,27 @@
|
|||
export type KnownApi = "openai-completions" | "openai-responses" | "anthropic-messages" | "google-generative-ai";
|
||||
export type Api = KnownApi | string;
|
||||
import type { AnthropicOptions } from "./providers/anthropic";
|
||||
import type { GoogleOptions } from "./providers/google";
|
||||
import type { OpenAICompletionsOptions } from "./providers/openai-completions";
|
||||
import type { OpenAIResponsesOptions } from "./providers/openai-responses";
|
||||
|
||||
export type Api = "openai-completions" | "openai-responses" | "anthropic-messages" | "google-generative-ai";
|
||||
|
||||
export interface ApiOptionsMap {
|
||||
"anthropic-messages": AnthropicOptions;
|
||||
"openai-completions": OpenAICompletionsOptions;
|
||||
"openai-responses": OpenAIResponsesOptions;
|
||||
"google-generative-ai": GoogleOptions;
|
||||
}
|
||||
|
||||
// Compile-time exhaustiveness check - this will fail if ApiOptionsMap doesn't have all KnownApi keys
|
||||
type _CheckExhaustive = ApiOptionsMap extends Record<Api, GenerateOptions>
|
||||
? Record<Api, GenerateOptions> extends ApiOptionsMap
|
||||
? true
|
||||
: ["ApiOptionsMap is missing some KnownApi values", Exclude<Api, keyof ApiOptionsMap>]
|
||||
: ["ApiOptionsMap doesn't extend Record<KnownApi, GenerateOptions>"];
|
||||
const _exhaustive: _CheckExhaustive = true;
|
||||
|
||||
// Helper type to get options for a specific API
|
||||
export type OptionsForApi<TApi extends Api> = ApiOptionsMap[TApi];
|
||||
|
||||
export type KnownProvider = "anthropic" | "google" | "openai" | "xai" | "groq" | "cerebras" | "openrouter";
|
||||
export type Provider = KnownProvider | string;
|
||||
|
|
@ -21,31 +43,17 @@ export interface GenerateOptions {
|
|||
}
|
||||
|
||||
// Unified options with reasoning (what public generate() accepts)
|
||||
export interface GenerateOptionsUnified extends GenerateOptions {
|
||||
export interface SimpleGenerateOptions extends GenerateOptions {
|
||||
reasoning?: ReasoningEffort;
|
||||
}
|
||||
|
||||
// Generic GenerateFunction with typed options
|
||||
export type GenerateFunction<TOptions extends GenerateOptions = GenerateOptions> = (
|
||||
model: Model,
|
||||
export type GenerateFunction<TApi extends Api> = (
|
||||
model: Model<TApi>,
|
||||
context: Context,
|
||||
options: TOptions,
|
||||
options: OptionsForApi<TApi>,
|
||||
) => GenerateStream;
|
||||
|
||||
// Legacy LLM interface (to be removed)
|
||||
export interface LLMOptions {
|
||||
temperature?: number;
|
||||
maxTokens?: number;
|
||||
onEvent?: (event: AssistantMessageEvent) => void;
|
||||
signal?: AbortSignal;
|
||||
}
|
||||
|
||||
export interface LLM<T extends LLMOptions> {
|
||||
generate(request: Context, options?: T): Promise<AssistantMessage>;
|
||||
getModel(): Model;
|
||||
getApi(): string;
|
||||
}
|
||||
|
||||
export interface TextContent {
|
||||
type: "text";
|
||||
text: string;
|
||||
|
|
@ -100,7 +108,7 @@ export interface AssistantMessage {
|
|||
model: string;
|
||||
usage: Usage;
|
||||
stopReason: StopReason;
|
||||
error?: string | Error;
|
||||
error?: string;
|
||||
}
|
||||
|
||||
export interface ToolResultMessage {
|
||||
|
|
@ -138,10 +146,10 @@ export type AssistantMessageEvent =
|
|||
| { type: "error"; error: string; partial: AssistantMessage };
|
||||
|
||||
// Model interface for the unified model system
|
||||
export interface Model {
|
||||
export interface Model<TApi extends Api> {
|
||||
id: string;
|
||||
name: string;
|
||||
api: Api;
|
||||
api: TApi;
|
||||
provider: Provider;
|
||||
baseUrl: string;
|
||||
reasoning: boolean;
|
||||
|
|
|
|||
|
|
@ -1,128 +1,103 @@
|
|||
import { describe, it, beforeAll, expect } from "vitest";
|
||||
import { GoogleLLM } from "../src/providers/google.js";
|
||||
import { OpenAICompletionsLLM } from "../src/providers/openai-completions.js";
|
||||
import { OpenAIResponsesLLM } from "../src/providers/openai-responses.js";
|
||||
import { AnthropicLLM } from "../src/providers/anthropic.js";
|
||||
import type { LLM, LLMOptions, Context } from "../src/types.js";
|
||||
import { beforeAll, describe, expect, it } from "vitest";
|
||||
import { complete, stream } from "../src/generate.js";
|
||||
import { getModel } from "../src/models.js";
|
||||
import type { Api, Context, Model, OptionsForApi } from "../src/types.js";
|
||||
|
||||
async function testAbortSignal<T extends LLMOptions>(llm: LLM<T>, options: T = {} as T) {
|
||||
const context: Context = {
|
||||
messages: [{
|
||||
role: "user",
|
||||
content: "What is 15 + 27? Think step by step. Then list 50 first names."
|
||||
}]
|
||||
};
|
||||
async function testAbortSignal<TApi extends Api>(llm: Model<TApi>, options: OptionsForApi<TApi> = {}) {
|
||||
const context: Context = {
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: "What is 15 + 27? Think step by step. Then list 50 first names.",
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
let abortFired = false;
|
||||
const controller = new AbortController();
|
||||
const response = await llm.generate(context, {
|
||||
...options,
|
||||
signal: controller.signal,
|
||||
onEvent: (event) => {
|
||||
// console.log(JSON.stringify(event, null, 2));
|
||||
if (abortFired) return;
|
||||
setTimeout(() => controller.abort(), 2000);
|
||||
abortFired = true;
|
||||
}
|
||||
});
|
||||
let abortFired = false;
|
||||
const controller = new AbortController();
|
||||
const response = await stream(llm, context, { ...options, signal: controller.signal });
|
||||
for await (const event of response) {
|
||||
if (abortFired) return;
|
||||
setTimeout(() => controller.abort(), 3000);
|
||||
abortFired = true;
|
||||
break;
|
||||
}
|
||||
const msg = await response.finalMessage();
|
||||
|
||||
// If we get here without throwing, the abort didn't work
|
||||
expect(response.stopReason).toBe("error");
|
||||
expect(response.content.length).toBeGreaterThan(0);
|
||||
// If we get here without throwing, the abort didn't work
|
||||
expect(msg.stopReason).toBe("error");
|
||||
expect(msg.content.length).toBeGreaterThan(0);
|
||||
|
||||
context.messages.push(response);
|
||||
context.messages.push({ role: "user", content: "Please continue, but only generate 5 names." });
|
||||
context.messages.push(msg);
|
||||
context.messages.push({ role: "user", content: "Please continue, but only generate 5 names." });
|
||||
|
||||
// Ensure we can still make requests after abort
|
||||
const followUp = await llm.generate(context, options);
|
||||
expect(followUp.stopReason).toBe("stop");
|
||||
expect(followUp.content.length).toBeGreaterThan(0);
|
||||
const followUp = await complete(llm, context, options);
|
||||
expect(followUp.stopReason).toBe("stop");
|
||||
expect(followUp.content.length).toBeGreaterThan(0);
|
||||
}
|
||||
|
||||
async function testImmediateAbort<T extends LLMOptions>(llm: LLM<T>, options: T = {} as T) {
|
||||
const controller = new AbortController();
|
||||
async function testImmediateAbort<TApi extends Api>(llm: Model<TApi>, options: OptionsForApi<TApi> = {}) {
|
||||
const controller = new AbortController();
|
||||
|
||||
// Abort immediately
|
||||
controller.abort();
|
||||
controller.abort();
|
||||
|
||||
const context: Context = {
|
||||
messages: [{ role: "user", content: "Hello" }]
|
||||
};
|
||||
const context: Context = {
|
||||
messages: [{ role: "user", content: "Hello" }],
|
||||
};
|
||||
|
||||
const response = await llm.generate(context, {
|
||||
...options,
|
||||
signal: controller.signal
|
||||
});
|
||||
expect(response.stopReason).toBe("error");
|
||||
const response = await complete(llm, context, { ...options, signal: controller.signal });
|
||||
expect(response.stopReason).toBe("error");
|
||||
}
|
||||
|
||||
describe("AI Providers Abort Tests", () => {
|
||||
describe.skipIf(!process.env.GEMINI_API_KEY)("Google Provider Abort", () => {
|
||||
let llm: GoogleLLM;
|
||||
describe.skipIf(!process.env.GEMINI_API_KEY)("Google Provider Abort", () => {
|
||||
const llm = getModel("google", "gemini-2.5-flash");
|
||||
|
||||
beforeAll(() => {
|
||||
llm = new GoogleLLM(getModel("google", "gemini-2.5-flash")!, process.env.GEMINI_API_KEY!);
|
||||
});
|
||||
it("should abort mid-stream", async () => {
|
||||
await testAbortSignal(llm, { thinking: { enabled: true } });
|
||||
});
|
||||
|
||||
it("should abort mid-stream", async () => {
|
||||
await testAbortSignal(llm, { thinking: { enabled: true } });
|
||||
});
|
||||
it("should handle immediate abort", async () => {
|
||||
await testImmediateAbort(llm, { thinking: { enabled: true } });
|
||||
});
|
||||
});
|
||||
|
||||
it("should handle immediate abort", async () => {
|
||||
await testImmediateAbort(llm, { thinking: { enabled: true } });
|
||||
});
|
||||
});
|
||||
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Completions Provider Abort", () => {
|
||||
const llm: Model<"openai-completions"> = {
|
||||
...getModel("openai", "gpt-4o-mini")!,
|
||||
api: "openai-completions",
|
||||
};
|
||||
|
||||
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Completions Provider Abort", () => {
|
||||
let llm: OpenAICompletionsLLM;
|
||||
it("should abort mid-stream", async () => {
|
||||
await testAbortSignal(llm);
|
||||
});
|
||||
|
||||
beforeAll(() => {
|
||||
llm = new OpenAICompletionsLLM(getModel("openai", "gpt-4o-mini")!, process.env.OPENAI_API_KEY!);
|
||||
});
|
||||
it("should handle immediate abort", async () => {
|
||||
await testImmediateAbort(llm);
|
||||
});
|
||||
});
|
||||
|
||||
it("should abort mid-stream", async () => {
|
||||
await testAbortSignal(llm);
|
||||
});
|
||||
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Responses Provider Abort", () => {
|
||||
const llm = getModel("openai", "gpt-5-mini");
|
||||
|
||||
it("should handle immediate abort", async () => {
|
||||
await testImmediateAbort(llm);
|
||||
});
|
||||
});
|
||||
it("should abort mid-stream", async () => {
|
||||
await testAbortSignal(llm);
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Responses Provider Abort", () => {
|
||||
let llm: OpenAIResponsesLLM;
|
||||
it("should handle immediate abort", async () => {
|
||||
await testImmediateAbort(llm);
|
||||
});
|
||||
});
|
||||
|
||||
beforeAll(() => {
|
||||
const model = getModel("openai", "gpt-5-mini");
|
||||
if (!model) {
|
||||
throw new Error("Model not found");
|
||||
}
|
||||
llm = new OpenAIResponsesLLM(model, process.env.OPENAI_API_KEY!);
|
||||
});
|
||||
describe.skipIf(!process.env.ANTHROPIC_OAUTH_TOKEN)("Anthropic Provider Abort", () => {
|
||||
const llm = getModel("anthropic", "claude-opus-4-1-20250805");
|
||||
|
||||
it("should abort mid-stream", async () => {
|
||||
await testAbortSignal(llm, {});
|
||||
});
|
||||
it("should abort mid-stream", async () => {
|
||||
await testAbortSignal(llm, { thinkingEnabled: true, thinkingBudgetTokens: 2048 });
|
||||
});
|
||||
|
||||
it("should handle immediate abort", async () => {
|
||||
await testImmediateAbort(llm, {});
|
||||
});
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.ANTHROPIC_OAUTH_TOKEN)("Anthropic Provider Abort", () => {
|
||||
let llm: AnthropicLLM;
|
||||
|
||||
beforeAll(() => {
|
||||
llm = new AnthropicLLM(getModel("anthropic", "claude-opus-4-1-20250805")!, process.env.ANTHROPIC_OAUTH_TOKEN!);
|
||||
});
|
||||
|
||||
it("should abort mid-stream", async () => {
|
||||
await testAbortSignal(llm, {thinking: { enabled: true, budgetTokens: 2048 }});
|
||||
});
|
||||
|
||||
it("should handle immediate abort", async () => {
|
||||
await testImmediateAbort(llm, {thinking: { enabled: true, budgetTokens: 2048 }});
|
||||
});
|
||||
});
|
||||
});
|
||||
it("should handle immediate abort", async () => {
|
||||
await testImmediateAbort(llm, { thinkingEnabled: true, thinkingBudgetTokens: 2048 });
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -1,313 +1,265 @@
|
|||
import { describe, it, beforeAll, expect } from "vitest";
|
||||
import { GoogleLLM } from "../src/providers/google.js";
|
||||
import { OpenAICompletionsLLM } from "../src/providers/openai-completions.js";
|
||||
import { OpenAIResponsesLLM } from "../src/providers/openai-responses.js";
|
||||
import { AnthropicLLM } from "../src/providers/anthropic.js";
|
||||
import type { LLM, LLMOptions, Context, UserMessage, AssistantMessage } from "../src/types.js";
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { complete } from "../src/generate.js";
|
||||
import { getModel } from "../src/models.js";
|
||||
import type { Api, AssistantMessage, Context, Model, OptionsForApi, UserMessage } from "../src/types.js";
|
||||
|
||||
async function testEmptyMessage<T extends LLMOptions>(llm: LLM<T>, options: T = {} as T) {
|
||||
// Test with completely empty content array
|
||||
const emptyMessage: UserMessage = {
|
||||
role: "user",
|
||||
content: []
|
||||
};
|
||||
async function testEmptyMessage<TApi extends Api>(llm: Model<TApi>, options: OptionsForApi<TApi> = {}) {
|
||||
// Test with completely empty content array
|
||||
const emptyMessage: UserMessage = {
|
||||
role: "user",
|
||||
content: [],
|
||||
};
|
||||
|
||||
const context: Context = {
|
||||
messages: [emptyMessage]
|
||||
};
|
||||
const context: Context = {
|
||||
messages: [emptyMessage],
|
||||
};
|
||||
|
||||
const response = await llm.generate(context, options);
|
||||
|
||||
// Should either handle gracefully or return an error
|
||||
expect(response).toBeDefined();
|
||||
expect(response.role).toBe("assistant");
|
||||
|
||||
// Most providers should return an error or empty response
|
||||
if (response.stopReason === "error") {
|
||||
expect(response.error).toBeDefined();
|
||||
} else {
|
||||
// If it didn't error, it should have some content or gracefully handle empty
|
||||
expect(response.content).toBeDefined();
|
||||
}
|
||||
const response = await complete(llm, context, options);
|
||||
|
||||
// Should either handle gracefully or return an error
|
||||
expect(response).toBeDefined();
|
||||
expect(response.role).toBe("assistant");
|
||||
// Should handle empty string gracefully
|
||||
if (response.stopReason === "error") {
|
||||
expect(response.error).toBeDefined();
|
||||
} else {
|
||||
expect(response.content).toBeDefined();
|
||||
}
|
||||
}
|
||||
|
||||
async function testEmptyStringMessage<T extends LLMOptions>(llm: LLM<T>, options: T = {} as T) {
|
||||
// Test with empty string content
|
||||
const context: Context = {
|
||||
messages: [{
|
||||
role: "user",
|
||||
content: ""
|
||||
}]
|
||||
};
|
||||
async function testEmptyStringMessage<TApi extends Api>(llm: Model<TApi>, options: OptionsForApi<TApi> = {}) {
|
||||
// Test with empty string content
|
||||
const context: Context = {
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: "",
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const response = await llm.generate(context, options);
|
||||
|
||||
expect(response).toBeDefined();
|
||||
expect(response.role).toBe("assistant");
|
||||
|
||||
// Should handle empty string gracefully
|
||||
if (response.stopReason === "error") {
|
||||
expect(response.error).toBeDefined();
|
||||
} else {
|
||||
expect(response.content).toBeDefined();
|
||||
}
|
||||
const response = await complete(llm, context, options);
|
||||
|
||||
expect(response).toBeDefined();
|
||||
expect(response.role).toBe("assistant");
|
||||
|
||||
// Should handle empty string gracefully
|
||||
if (response.stopReason === "error") {
|
||||
expect(response.error).toBeDefined();
|
||||
} else {
|
||||
expect(response.content).toBeDefined();
|
||||
}
|
||||
}
|
||||
|
||||
async function testWhitespaceOnlyMessage<T extends LLMOptions>(llm: LLM<T>, options: T = {} as T) {
|
||||
// Test with whitespace-only content
|
||||
const context: Context = {
|
||||
messages: [{
|
||||
role: "user",
|
||||
content: " \n\t "
|
||||
}]
|
||||
};
|
||||
async function testWhitespaceOnlyMessage<TApi extends Api>(llm: Model<TApi>, options: OptionsForApi<TApi> = {}) {
|
||||
// Test with whitespace-only content
|
||||
const context: Context = {
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: " \n\t ",
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const response = await llm.generate(context, options);
|
||||
|
||||
expect(response).toBeDefined();
|
||||
expect(response.role).toBe("assistant");
|
||||
|
||||
// Should handle whitespace-only gracefully
|
||||
if (response.stopReason === "error") {
|
||||
expect(response.error).toBeDefined();
|
||||
} else {
|
||||
expect(response.content).toBeDefined();
|
||||
}
|
||||
const response = await complete(llm, context, options);
|
||||
|
||||
expect(response).toBeDefined();
|
||||
expect(response.role).toBe("assistant");
|
||||
|
||||
// Should handle whitespace-only gracefully
|
||||
if (response.stopReason === "error") {
|
||||
expect(response.error).toBeDefined();
|
||||
} else {
|
||||
expect(response.content).toBeDefined();
|
||||
}
|
||||
}
|
||||
|
||||
async function testEmptyAssistantMessage<T extends LLMOptions>(llm: LLM<T>, options: T = {} as T) {
|
||||
// Test with empty assistant message in conversation flow
|
||||
// User -> Empty Assistant -> User
|
||||
const emptyAssistant: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: llm.getApi(),
|
||||
provider: llm.getModel().provider,
|
||||
model: llm.getModel().id,
|
||||
usage: {
|
||||
input: 10,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }
|
||||
},
|
||||
stopReason: "stop"
|
||||
};
|
||||
async function testEmptyAssistantMessage<TApi extends Api>(llm: Model<TApi>, options: OptionsForApi<TApi> = {}) {
|
||||
// Test with empty assistant message in conversation flow
|
||||
// User -> Empty Assistant -> User
|
||||
const emptyAssistant: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: llm.api,
|
||||
provider: llm.provider,
|
||||
model: llm.id,
|
||||
usage: {
|
||||
input: 10,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "stop",
|
||||
};
|
||||
|
||||
const context: Context = {
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: "Hello, how are you?"
|
||||
},
|
||||
emptyAssistant,
|
||||
{
|
||||
role: "user",
|
||||
content: "Please respond this time."
|
||||
}
|
||||
]
|
||||
};
|
||||
const context: Context = {
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: "Hello, how are you?",
|
||||
},
|
||||
emptyAssistant,
|
||||
{
|
||||
role: "user",
|
||||
content: "Please respond this time.",
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const response = await llm.generate(context, options);
|
||||
|
||||
expect(response).toBeDefined();
|
||||
expect(response.role).toBe("assistant");
|
||||
|
||||
// Should handle empty assistant message in context gracefully
|
||||
if (response.stopReason === "error") {
|
||||
expect(response.error).toBeDefined();
|
||||
} else {
|
||||
expect(response.content).toBeDefined();
|
||||
expect(response.content.length).toBeGreaterThan(0);
|
||||
}
|
||||
const response = await complete(llm, context, options);
|
||||
|
||||
expect(response).toBeDefined();
|
||||
expect(response.role).toBe("assistant");
|
||||
|
||||
// Should handle empty assistant message in context gracefully
|
||||
if (response.stopReason === "error") {
|
||||
expect(response.error).toBeDefined();
|
||||
} else {
|
||||
expect(response.content).toBeDefined();
|
||||
expect(response.content.length).toBeGreaterThan(0);
|
||||
}
|
||||
}
|
||||
|
||||
describe("AI Providers Empty Message Tests", () => {
|
||||
describe.skipIf(!process.env.GEMINI_API_KEY)("Google Provider Empty Messages", () => {
|
||||
let llm: GoogleLLM;
|
||||
describe.skipIf(!process.env.GEMINI_API_KEY)("Google Provider Empty Messages", () => {
|
||||
const llm = getModel("google", "gemini-2.5-flash");
|
||||
|
||||
beforeAll(() => {
|
||||
llm = new GoogleLLM(getModel("google", "gemini-2.5-flash")!, process.env.GEMINI_API_KEY!);
|
||||
});
|
||||
it("should handle empty content array", async () => {
|
||||
await testEmptyMessage(llm);
|
||||
});
|
||||
|
||||
it("should handle empty content array", async () => {
|
||||
await testEmptyMessage(llm);
|
||||
});
|
||||
it("should handle empty string content", async () => {
|
||||
await testEmptyStringMessage(llm);
|
||||
});
|
||||
|
||||
it("should handle empty string content", async () => {
|
||||
await testEmptyStringMessage(llm);
|
||||
});
|
||||
it("should handle whitespace-only content", async () => {
|
||||
await testWhitespaceOnlyMessage(llm);
|
||||
});
|
||||
|
||||
it("should handle whitespace-only content", async () => {
|
||||
await testWhitespaceOnlyMessage(llm);
|
||||
});
|
||||
it("should handle empty assistant message in conversation", async () => {
|
||||
await testEmptyAssistantMessage(llm);
|
||||
});
|
||||
});
|
||||
|
||||
it("should handle empty assistant message in conversation", async () => {
|
||||
await testEmptyAssistantMessage(llm);
|
||||
});
|
||||
});
|
||||
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Completions Provider Empty Messages", () => {
|
||||
const llm = getModel("openai", "gpt-4o-mini");
|
||||
|
||||
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Completions Provider Empty Messages", () => {
|
||||
let llm: OpenAICompletionsLLM;
|
||||
it("should handle empty content array", async () => {
|
||||
await testEmptyMessage(llm);
|
||||
});
|
||||
|
||||
beforeAll(() => {
|
||||
llm = new OpenAICompletionsLLM(getModel("openai", "gpt-4o-mini")!, process.env.OPENAI_API_KEY!);
|
||||
});
|
||||
it("should handle empty string content", async () => {
|
||||
await testEmptyStringMessage(llm);
|
||||
});
|
||||
|
||||
it("should handle empty content array", async () => {
|
||||
await testEmptyMessage(llm);
|
||||
});
|
||||
it("should handle whitespace-only content", async () => {
|
||||
await testWhitespaceOnlyMessage(llm);
|
||||
});
|
||||
|
||||
it("should handle empty string content", async () => {
|
||||
await testEmptyStringMessage(llm);
|
||||
});
|
||||
it("should handle empty assistant message in conversation", async () => {
|
||||
await testEmptyAssistantMessage(llm);
|
||||
});
|
||||
});
|
||||
|
||||
it("should handle whitespace-only content", async () => {
|
||||
await testWhitespaceOnlyMessage(llm);
|
||||
});
|
||||
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Responses Provider Empty Messages", () => {
|
||||
const llm = getModel("openai", "gpt-5-mini");
|
||||
|
||||
it("should handle empty assistant message in conversation", async () => {
|
||||
await testEmptyAssistantMessage(llm);
|
||||
});
|
||||
});
|
||||
it("should handle empty content array", async () => {
|
||||
await testEmptyMessage(llm);
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Responses Provider Empty Messages", () => {
|
||||
let llm: OpenAIResponsesLLM;
|
||||
it("should handle empty string content", async () => {
|
||||
await testEmptyStringMessage(llm);
|
||||
});
|
||||
|
||||
beforeAll(() => {
|
||||
const model = getModel("openai", "gpt-5-mini");
|
||||
if (!model) {
|
||||
throw new Error("Model gpt-5-mini not found");
|
||||
}
|
||||
llm = new OpenAIResponsesLLM(model, process.env.OPENAI_API_KEY!);
|
||||
});
|
||||
it("should handle whitespace-only content", async () => {
|
||||
await testWhitespaceOnlyMessage(llm);
|
||||
});
|
||||
|
||||
it("should handle empty content array", async () => {
|
||||
await testEmptyMessage(llm);
|
||||
});
|
||||
it("should handle empty assistant message in conversation", async () => {
|
||||
await testEmptyAssistantMessage(llm);
|
||||
});
|
||||
});
|
||||
|
||||
it("should handle empty string content", async () => {
|
||||
await testEmptyStringMessage(llm);
|
||||
});
|
||||
describe.skipIf(!process.env.ANTHROPIC_OAUTH_TOKEN)("Anthropic Provider Empty Messages", () => {
|
||||
const llm = getModel("anthropic", "claude-3-5-haiku-20241022");
|
||||
|
||||
it("should handle whitespace-only content", async () => {
|
||||
await testWhitespaceOnlyMessage(llm);
|
||||
});
|
||||
it("should handle empty content array", async () => {
|
||||
await testEmptyMessage(llm);
|
||||
});
|
||||
|
||||
it("should handle empty assistant message in conversation", async () => {
|
||||
await testEmptyAssistantMessage(llm);
|
||||
});
|
||||
});
|
||||
it("should handle empty string content", async () => {
|
||||
await testEmptyStringMessage(llm);
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.ANTHROPIC_OAUTH_TOKEN)("Anthropic Provider Empty Messages", () => {
|
||||
let llm: AnthropicLLM;
|
||||
it("should handle whitespace-only content", async () => {
|
||||
await testWhitespaceOnlyMessage(llm);
|
||||
});
|
||||
|
||||
beforeAll(() => {
|
||||
llm = new AnthropicLLM(getModel("anthropic", "claude-3-5-haiku-20241022")!, process.env.ANTHROPIC_OAUTH_TOKEN!);
|
||||
});
|
||||
it("should handle empty assistant message in conversation", async () => {
|
||||
await testEmptyAssistantMessage(llm);
|
||||
});
|
||||
});
|
||||
|
||||
it("should handle empty content array", async () => {
|
||||
await testEmptyMessage(llm);
|
||||
});
|
||||
describe.skipIf(!process.env.XAI_API_KEY)("xAI Provider Empty Messages", () => {
|
||||
const llm = getModel("xai", "grok-3");
|
||||
|
||||
it("should handle empty string content", async () => {
|
||||
await testEmptyStringMessage(llm);
|
||||
});
|
||||
it("should handle empty content array", async () => {
|
||||
await testEmptyMessage(llm);
|
||||
});
|
||||
|
||||
it("should handle whitespace-only content", async () => {
|
||||
await testWhitespaceOnlyMessage(llm);
|
||||
});
|
||||
it("should handle empty string content", async () => {
|
||||
await testEmptyStringMessage(llm);
|
||||
});
|
||||
|
||||
it("should handle empty assistant message in conversation", async () => {
|
||||
await testEmptyAssistantMessage(llm);
|
||||
});
|
||||
});
|
||||
it("should handle whitespace-only content", async () => {
|
||||
await testWhitespaceOnlyMessage(llm);
|
||||
});
|
||||
|
||||
// Test with xAI/Grok if available
|
||||
describe.skipIf(!process.env.XAI_API_KEY)("xAI Provider Empty Messages", () => {
|
||||
let llm: OpenAICompletionsLLM;
|
||||
it("should handle empty assistant message in conversation", async () => {
|
||||
await testEmptyAssistantMessage(llm);
|
||||
});
|
||||
});
|
||||
|
||||
beforeAll(() => {
|
||||
const model = getModel("xai", "grok-3");
|
||||
if (!model) {
|
||||
throw new Error("Model grok-3 not found");
|
||||
}
|
||||
llm = new OpenAICompletionsLLM(model, process.env.XAI_API_KEY!);
|
||||
});
|
||||
describe.skipIf(!process.env.GROQ_API_KEY)("Groq Provider Empty Messages", () => {
|
||||
const llm = getModel("groq", "openai/gpt-oss-20b");
|
||||
|
||||
it("should handle empty content array", async () => {
|
||||
await testEmptyMessage(llm);
|
||||
});
|
||||
it("should handle empty content array", async () => {
|
||||
await testEmptyMessage(llm);
|
||||
});
|
||||
|
||||
it("should handle empty string content", async () => {
|
||||
await testEmptyStringMessage(llm);
|
||||
});
|
||||
it("should handle empty string content", async () => {
|
||||
await testEmptyStringMessage(llm);
|
||||
});
|
||||
|
||||
it("should handle whitespace-only content", async () => {
|
||||
await testWhitespaceOnlyMessage(llm);
|
||||
});
|
||||
it("should handle whitespace-only content", async () => {
|
||||
await testWhitespaceOnlyMessage(llm);
|
||||
});
|
||||
|
||||
it("should handle empty assistant message in conversation", async () => {
|
||||
await testEmptyAssistantMessage(llm);
|
||||
});
|
||||
});
|
||||
it("should handle empty assistant message in conversation", async () => {
|
||||
await testEmptyAssistantMessage(llm);
|
||||
});
|
||||
});
|
||||
|
||||
// Test with Groq if available
|
||||
describe.skipIf(!process.env.GROQ_API_KEY)("Groq Provider Empty Messages", () => {
|
||||
let llm: OpenAICompletionsLLM;
|
||||
describe.skipIf(!process.env.CEREBRAS_API_KEY)("Cerebras Provider Empty Messages", () => {
|
||||
const llm = getModel("cerebras", "gpt-oss-120b");
|
||||
|
||||
beforeAll(() => {
|
||||
const model = getModel("groq", "llama-3.3-70b-versatile");
|
||||
if (!model) {
|
||||
throw new Error("Model llama-3.3-70b-versatile not found");
|
||||
}
|
||||
llm = new OpenAICompletionsLLM(model, process.env.GROQ_API_KEY!);
|
||||
});
|
||||
it("should handle empty content array", async () => {
|
||||
await testEmptyMessage(llm);
|
||||
});
|
||||
|
||||
it("should handle empty content array", async () => {
|
||||
await testEmptyMessage(llm);
|
||||
});
|
||||
it("should handle empty string content", async () => {
|
||||
await testEmptyStringMessage(llm);
|
||||
});
|
||||
|
||||
it("should handle empty string content", async () => {
|
||||
await testEmptyStringMessage(llm);
|
||||
});
|
||||
it("should handle whitespace-only content", async () => {
|
||||
await testWhitespaceOnlyMessage(llm);
|
||||
});
|
||||
|
||||
it("should handle whitespace-only content", async () => {
|
||||
await testWhitespaceOnlyMessage(llm);
|
||||
});
|
||||
|
||||
it("should handle empty assistant message in conversation", async () => {
|
||||
await testEmptyAssistantMessage(llm);
|
||||
});
|
||||
});
|
||||
|
||||
// Test with Cerebras if available
|
||||
describe.skipIf(!process.env.CEREBRAS_API_KEY)("Cerebras Provider Empty Messages", () => {
|
||||
let llm: OpenAICompletionsLLM;
|
||||
|
||||
beforeAll(() => {
|
||||
const model = getModel("cerebras", "gpt-oss-120b");
|
||||
if (!model) {
|
||||
throw new Error("Model gpt-oss-120b not found");
|
||||
}
|
||||
llm = new OpenAICompletionsLLM(model, process.env.CEREBRAS_API_KEY!);
|
||||
});
|
||||
|
||||
it("should handle empty content array", async () => {
|
||||
await testEmptyMessage(llm);
|
||||
});
|
||||
|
||||
it("should handle empty string content", async () => {
|
||||
await testEmptyStringMessage(llm);
|
||||
});
|
||||
|
||||
it("should handle whitespace-only content", async () => {
|
||||
await testWhitespaceOnlyMessage(llm);
|
||||
});
|
||||
|
||||
it("should handle empty assistant message in conversation", async () => {
|
||||
await testEmptyAssistantMessage(llm);
|
||||
});
|
||||
});
|
||||
});
|
||||
it("should handle empty assistant message in conversation", async () => {
|
||||
await testEmptyAssistantMessage(llm);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -1,311 +1,612 @@
|
|||
import { describe, it, beforeAll, expect } from "vitest";
|
||||
import { getModel } from "../src/models.js";
|
||||
import { generate, generateComplete } from "../src/generate.js";
|
||||
import type { Context, Tool, GenerateOptionsUnified, Model, ImageContent, GenerateStream, GenerateOptions } from "../src/types.js";
|
||||
import { type ChildProcess, execSync, spawn } from "child_process";
|
||||
import { readFileSync } from "fs";
|
||||
import { join, dirname } from "path";
|
||||
import { dirname, join } from "path";
|
||||
import { fileURLToPath } from "url";
|
||||
import { afterAll, beforeAll, describe, expect, it } from "vitest";
|
||||
import { complete, stream } from "../src/generate.js";
|
||||
import { getModel } from "../src/models.js";
|
||||
import type { Api, Context, ImageContent, Model, OptionsForApi, Tool } from "../src/types.js";
|
||||
|
||||
const __filename = fileURLToPath(import.meta.url);
|
||||
const __dirname = dirname(__filename);
|
||||
|
||||
// Calculator tool definition (same as examples)
|
||||
const calculatorTool: Tool = {
|
||||
name: "calculator",
|
||||
description: "Perform basic arithmetic operations",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
a: { type: "number", description: "First number" },
|
||||
b: { type: "number", description: "Second number" },
|
||||
operation: {
|
||||
type: "string",
|
||||
enum: ["add", "subtract", "multiply", "divide"],
|
||||
description: "The operation to perform"
|
||||
}
|
||||
},
|
||||
required: ["a", "b", "operation"]
|
||||
}
|
||||
name: "calculator",
|
||||
description: "Perform basic arithmetic operations",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
a: { type: "number", description: "First number" },
|
||||
b: { type: "number", description: "Second number" },
|
||||
operation: {
|
||||
type: "string",
|
||||
enum: ["add", "subtract", "multiply", "divide"],
|
||||
description: "The operation to perform",
|
||||
},
|
||||
},
|
||||
required: ["a", "b", "operation"],
|
||||
},
|
||||
};
|
||||
|
||||
async function basicTextGeneration<P extends GenerateOptions>(model: Model, options?: P) {
|
||||
const context: Context = {
|
||||
systemPrompt: "You are a helpful assistant. Be concise.",
|
||||
messages: [
|
||||
{ role: "user", content: "Reply with exactly: 'Hello test successful'" }
|
||||
]
|
||||
};
|
||||
async function basicTextGeneration<TApi extends Api>(model: Model<TApi>, options?: OptionsForApi<TApi>) {
|
||||
const context: Context = {
|
||||
systemPrompt: "You are a helpful assistant. Be concise.",
|
||||
messages: [{ role: "user", content: "Reply with exactly: 'Hello test successful'" }],
|
||||
};
|
||||
const response = await complete(model, context, options);
|
||||
|
||||
const response = await generateComplete(model, context, options);
|
||||
expect(response.role).toBe("assistant");
|
||||
expect(response.content).toBeTruthy();
|
||||
expect(response.usage.input + response.usage.cacheRead).toBeGreaterThan(0);
|
||||
expect(response.usage.output).toBeGreaterThan(0);
|
||||
expect(response.error).toBeFalsy();
|
||||
expect(response.content.map((b) => (b.type === "text" ? b.text : "")).join("")).toContain("Hello test successful");
|
||||
|
||||
expect(response.role).toBe("assistant");
|
||||
expect(response.content).toBeTruthy();
|
||||
expect(response.usage.input + response.usage.cacheRead).toBeGreaterThan(0);
|
||||
expect(response.usage.output).toBeGreaterThan(0);
|
||||
expect(response.error).toBeFalsy();
|
||||
expect(response.content.map(b => b.type == "text" ? b.text : "").join("")).toContain("Hello test successful");
|
||||
context.messages.push(response);
|
||||
context.messages.push({ role: "user", content: "Now say 'Goodbye test successful'" });
|
||||
|
||||
context.messages.push(response);
|
||||
context.messages.push({ role: "user", content: "Now say 'Goodbye test successful'" });
|
||||
const secondResponse = await complete(model, context, options);
|
||||
|
||||
const secondResponse = await generateComplete(model, context, options);
|
||||
|
||||
expect(secondResponse.role).toBe("assistant");
|
||||
expect(secondResponse.content).toBeTruthy();
|
||||
expect(secondResponse.usage.input + secondResponse.usage.cacheRead).toBeGreaterThan(0);
|
||||
expect(secondResponse.usage.output).toBeGreaterThan(0);
|
||||
expect(secondResponse.error).toBeFalsy();
|
||||
expect(secondResponse.content.map(b => b.type == "text" ? b.text : "").join("")).toContain("Goodbye test successful");
|
||||
expect(secondResponse.role).toBe("assistant");
|
||||
expect(secondResponse.content).toBeTruthy();
|
||||
expect(secondResponse.usage.input + secondResponse.usage.cacheRead).toBeGreaterThan(0);
|
||||
expect(secondResponse.usage.output).toBeGreaterThan(0);
|
||||
expect(secondResponse.error).toBeFalsy();
|
||||
expect(secondResponse.content.map((b) => (b.type === "text" ? b.text : "")).join("")).toContain(
|
||||
"Goodbye test successful",
|
||||
);
|
||||
}
|
||||
|
||||
async function handleToolCall(model: Model, options?: GenerateOptionsUnified) {
|
||||
const context: Context = {
|
||||
systemPrompt: "You are a helpful assistant that uses tools when asked.",
|
||||
messages: [{
|
||||
role: "user",
|
||||
content: "Calculate 15 + 27 using the calculator tool."
|
||||
}],
|
||||
tools: [calculatorTool]
|
||||
};
|
||||
async function handleToolCall<TApi extends Api>(model: Model<TApi>, options?: OptionsForApi<TApi>) {
|
||||
const context: Context = {
|
||||
systemPrompt: "You are a helpful assistant that uses tools when asked.",
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: "Calculate 15 + 27 using the calculator tool.",
|
||||
},
|
||||
],
|
||||
tools: [calculatorTool],
|
||||
};
|
||||
|
||||
const response = await generateComplete(model, context, options);
|
||||
expect(response.stopReason).toBe("toolUse");
|
||||
expect(response.content.some(b => b.type == "toolCall")).toBeTruthy();
|
||||
const toolCall = response.content.find(b => b.type == "toolCall");
|
||||
if (toolCall && toolCall.type === "toolCall") {
|
||||
expect(toolCall.name).toBe("calculator");
|
||||
expect(toolCall.id).toBeTruthy();
|
||||
}
|
||||
const response = await complete(model, context, options);
|
||||
expect(response.stopReason).toBe("toolUse");
|
||||
expect(response.content.some((b) => b.type === "toolCall")).toBeTruthy();
|
||||
const toolCall = response.content.find((b) => b.type === "toolCall");
|
||||
if (toolCall && toolCall.type === "toolCall") {
|
||||
expect(toolCall.name).toBe("calculator");
|
||||
expect(toolCall.id).toBeTruthy();
|
||||
}
|
||||
}
|
||||
|
||||
async function handleStreaming(model: Model, options?: GenerateOptionsUnified) {
|
||||
let textStarted = false;
|
||||
let textChunks = "";
|
||||
let textCompleted = false;
|
||||
async function handleStreaming<TApi extends Api>(model: Model<TApi>, options?: OptionsForApi<TApi>) {
|
||||
let textStarted = false;
|
||||
let textChunks = "";
|
||||
let textCompleted = false;
|
||||
|
||||
const context: Context = {
|
||||
messages: [{ role: "user", content: "Count from 1 to 3" }]
|
||||
};
|
||||
const context: Context = {
|
||||
messages: [{ role: "user", content: "Count from 1 to 3" }],
|
||||
};
|
||||
|
||||
const stream = generate(model, context, options);
|
||||
const s = stream(model, context, options);
|
||||
|
||||
for await (const event of stream) {
|
||||
if (event.type === "text_start") {
|
||||
textStarted = true;
|
||||
} else if (event.type === "text_delta") {
|
||||
textChunks += event.delta;
|
||||
} else if (event.type === "text_end") {
|
||||
textCompleted = true;
|
||||
}
|
||||
}
|
||||
for await (const event of s) {
|
||||
if (event.type === "text_start") {
|
||||
textStarted = true;
|
||||
} else if (event.type === "text_delta") {
|
||||
textChunks += event.delta;
|
||||
} else if (event.type === "text_end") {
|
||||
textCompleted = true;
|
||||
}
|
||||
}
|
||||
|
||||
const response = await stream.finalMessage();
|
||||
const response = await s.finalMessage();
|
||||
|
||||
expect(textStarted).toBe(true);
|
||||
expect(textChunks.length).toBeGreaterThan(0);
|
||||
expect(textCompleted).toBe(true);
|
||||
expect(response.content.some(b => b.type == "text")).toBeTruthy();
|
||||
expect(textStarted).toBe(true);
|
||||
expect(textChunks.length).toBeGreaterThan(0);
|
||||
expect(textCompleted).toBe(true);
|
||||
expect(response.content.some((b) => b.type === "text")).toBeTruthy();
|
||||
}
|
||||
|
||||
async function handleThinking(model: Model, options: GenerateOptionsUnified) {
|
||||
let thinkingStarted = false;
|
||||
let thinkingChunks = "";
|
||||
let thinkingCompleted = false;
|
||||
async function handleThinking<TApi extends Api>(model: Model<TApi>, options?: OptionsForApi<TApi>) {
|
||||
let thinkingStarted = false;
|
||||
let thinkingChunks = "";
|
||||
let thinkingCompleted = false;
|
||||
|
||||
const context: Context = {
|
||||
messages: [{ role: "user", content: `Think about ${(Math.random() * 255) | 0} + 27. Think step by step. Then output the result.` }]
|
||||
};
|
||||
const context: Context = {
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: `Think about ${(Math.random() * 255) | 0} + 27. Think step by step. Then output the result.`,
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const stream = generate(model, context, options);
|
||||
const s = stream(model, context, options);
|
||||
|
||||
for await (const event of stream) {
|
||||
if (event.type === "thinking_start") {
|
||||
thinkingStarted = true;
|
||||
} else if (event.type === "thinking_delta") {
|
||||
thinkingChunks += event.delta;
|
||||
} else if (event.type === "thinking_end") {
|
||||
thinkingCompleted = true;
|
||||
}
|
||||
}
|
||||
for await (const event of s) {
|
||||
if (event.type === "thinking_start") {
|
||||
thinkingStarted = true;
|
||||
} else if (event.type === "thinking_delta") {
|
||||
thinkingChunks += event.delta;
|
||||
} else if (event.type === "thinking_end") {
|
||||
thinkingCompleted = true;
|
||||
}
|
||||
}
|
||||
|
||||
const response = await stream.finalMessage();
|
||||
const response = await s.finalMessage();
|
||||
|
||||
expect(response.stopReason, `Error: ${response.error}`).toBe("stop");
|
||||
expect(thinkingStarted).toBe(true);
|
||||
expect(thinkingChunks.length).toBeGreaterThan(0);
|
||||
expect(thinkingCompleted).toBe(true);
|
||||
expect(response.content.some(b => b.type == "thinking")).toBeTruthy();
|
||||
expect(response.stopReason, `Error: ${response.error}`).toBe("stop");
|
||||
expect(thinkingStarted).toBe(true);
|
||||
expect(thinkingChunks.length).toBeGreaterThan(0);
|
||||
expect(thinkingCompleted).toBe(true);
|
||||
expect(response.content.some((b) => b.type === "thinking")).toBeTruthy();
|
||||
}
|
||||
|
||||
async function handleImage(model: Model, options?: GenerateOptionsUnified) {
|
||||
// Check if the model supports images
|
||||
if (!model.input.includes("image")) {
|
||||
console.log(`Skipping image test - model ${model.id} doesn't support images`);
|
||||
return;
|
||||
}
|
||||
async function handleImage<TApi extends Api>(model: Model<TApi>, options?: OptionsForApi<TApi>) {
|
||||
// Check if the model supports images
|
||||
if (!model.input.includes("image")) {
|
||||
console.log(`Skipping image test - model ${model.id} doesn't support images`);
|
||||
return;
|
||||
}
|
||||
|
||||
// Read the test image
|
||||
const imagePath = join(__dirname, "data", "red-circle.png");
|
||||
const imageBuffer = readFileSync(imagePath);
|
||||
const base64Image = imageBuffer.toString("base64");
|
||||
// Read the test image
|
||||
const imagePath = join(__dirname, "data", "red-circle.png");
|
||||
const imageBuffer = readFileSync(imagePath);
|
||||
const base64Image = imageBuffer.toString("base64");
|
||||
|
||||
const imageContent: ImageContent = {
|
||||
type: "image",
|
||||
data: base64Image,
|
||||
mimeType: "image/png",
|
||||
};
|
||||
const imageContent: ImageContent = {
|
||||
type: "image",
|
||||
data: base64Image,
|
||||
mimeType: "image/png",
|
||||
};
|
||||
|
||||
const context: Context = {
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: [
|
||||
{ type: "text", text: "What do you see in this image? Please describe the shape (circle, rectangle, square, triangle, ...) and color (red, blue, green, ...)." },
|
||||
imageContent,
|
||||
],
|
||||
},
|
||||
],
|
||||
};
|
||||
const context: Context = {
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: "What do you see in this image? Please describe the shape (circle, rectangle, square, triangle, ...) and color (red, blue, green, ...).",
|
||||
},
|
||||
imageContent,
|
||||
],
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const response = await generateComplete(model, context, options);
|
||||
const response = await complete(model, context, options);
|
||||
|
||||
// Check the response mentions red and circle
|
||||
expect(response.content.length > 0).toBeTruthy();
|
||||
const textContent = response.content.find(b => b.type == "text");
|
||||
if (textContent && textContent.type === "text") {
|
||||
const lowerContent = textContent.text.toLowerCase();
|
||||
expect(lowerContent).toContain("red");
|
||||
expect(lowerContent).toContain("circle");
|
||||
}
|
||||
// Check the response mentions red and circle
|
||||
expect(response.content.length > 0).toBeTruthy();
|
||||
const textContent = response.content.find((b) => b.type === "text");
|
||||
if (textContent && textContent.type === "text") {
|
||||
const lowerContent = textContent.text.toLowerCase();
|
||||
expect(lowerContent).toContain("red");
|
||||
expect(lowerContent).toContain("circle");
|
||||
}
|
||||
}
|
||||
|
||||
async function multiTurn(model: Model, options?: GenerateOptionsUnified) {
|
||||
const context: Context = {
|
||||
systemPrompt: "You are a helpful assistant that can use tools to answer questions.",
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: "Think about this briefly, then calculate 42 * 17 and 453 + 434 using the calculator tool."
|
||||
}
|
||||
],
|
||||
tools: [calculatorTool]
|
||||
};
|
||||
async function multiTurn<TApi extends Api>(model: Model<TApi>, options?: OptionsForApi<TApi>) {
|
||||
const context: Context = {
|
||||
systemPrompt: "You are a helpful assistant that can use tools to answer questions.",
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: "Think about this briefly, then calculate 42 * 17 and 453 + 434 using the calculator tool.",
|
||||
},
|
||||
],
|
||||
tools: [calculatorTool],
|
||||
};
|
||||
|
||||
// Collect all text content from all assistant responses
|
||||
let allTextContent = "";
|
||||
let hasSeenThinking = false;
|
||||
let hasSeenToolCalls = false;
|
||||
const maxTurns = 5; // Prevent infinite loops
|
||||
// Collect all text content from all assistant responses
|
||||
let allTextContent = "";
|
||||
let hasSeenThinking = false;
|
||||
let hasSeenToolCalls = false;
|
||||
const maxTurns = 5; // Prevent infinite loops
|
||||
|
||||
for (let turn = 0; turn < maxTurns; turn++) {
|
||||
const response = await generateComplete(model, context, options);
|
||||
for (let turn = 0; turn < maxTurns; turn++) {
|
||||
const response = await complete(model, context, options);
|
||||
|
||||
// Add the assistant response to context
|
||||
context.messages.push(response);
|
||||
// Add the assistant response to context
|
||||
context.messages.push(response);
|
||||
|
||||
// Process content blocks
|
||||
for (const block of response.content) {
|
||||
if (block.type === "text") {
|
||||
allTextContent += block.text;
|
||||
} else if (block.type === "thinking") {
|
||||
hasSeenThinking = true;
|
||||
} else if (block.type === "toolCall") {
|
||||
hasSeenToolCalls = true;
|
||||
// Process content blocks
|
||||
for (const block of response.content) {
|
||||
if (block.type === "text") {
|
||||
allTextContent += block.text;
|
||||
} else if (block.type === "thinking") {
|
||||
hasSeenThinking = true;
|
||||
} else if (block.type === "toolCall") {
|
||||
hasSeenToolCalls = true;
|
||||
|
||||
// Process the tool call
|
||||
expect(block.name).toBe("calculator");
|
||||
expect(block.id).toBeTruthy();
|
||||
expect(block.arguments).toBeTruthy();
|
||||
// Process the tool call
|
||||
expect(block.name).toBe("calculator");
|
||||
expect(block.id).toBeTruthy();
|
||||
expect(block.arguments).toBeTruthy();
|
||||
|
||||
const { a, b, operation } = block.arguments;
|
||||
let result: number;
|
||||
switch (operation) {
|
||||
case "add": result = a + b; break;
|
||||
case "multiply": result = a * b; break;
|
||||
default: result = 0;
|
||||
}
|
||||
const { a, b, operation } = block.arguments;
|
||||
let result: number;
|
||||
switch (operation) {
|
||||
case "add":
|
||||
result = a + b;
|
||||
break;
|
||||
case "multiply":
|
||||
result = a * b;
|
||||
break;
|
||||
default:
|
||||
result = 0;
|
||||
}
|
||||
|
||||
// Add tool result to context
|
||||
context.messages.push({
|
||||
role: "toolResult",
|
||||
toolCallId: block.id,
|
||||
toolName: block.name,
|
||||
content: `${result}`,
|
||||
isError: false
|
||||
});
|
||||
}
|
||||
}
|
||||
// Add tool result to context
|
||||
context.messages.push({
|
||||
role: "toolResult",
|
||||
toolCallId: block.id,
|
||||
toolName: block.name,
|
||||
content: `${result}`,
|
||||
isError: false,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// If we got a stop response with text content, we're likely done
|
||||
expect(response.stopReason).not.toBe("error");
|
||||
if (response.stopReason === "stop") {
|
||||
break;
|
||||
}
|
||||
}
|
||||
// If we got a stop response with text content, we're likely done
|
||||
expect(response.stopReason).not.toBe("error");
|
||||
if (response.stopReason === "stop") {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Verify we got either thinking content or tool calls (or both)
|
||||
expect(hasSeenThinking || hasSeenToolCalls).toBe(true);
|
||||
// Verify we got either thinking content or tool calls (or both)
|
||||
expect(hasSeenThinking || hasSeenToolCalls).toBe(true);
|
||||
|
||||
// The accumulated text should reference both calculations
|
||||
expect(allTextContent).toBeTruthy();
|
||||
expect(allTextContent.includes("714")).toBe(true);
|
||||
expect(allTextContent.includes("887")).toBe(true);
|
||||
// The accumulated text should reference both calculations
|
||||
expect(allTextContent).toBeTruthy();
|
||||
expect(allTextContent.includes("714")).toBe(true);
|
||||
expect(allTextContent.includes("887")).toBe(true);
|
||||
}
|
||||
|
||||
describe("Generate E2E Tests", () => {
|
||||
describe.skipIf(!process.env.ANTHROPIC_API_KEY)("Anthropic Provider (claude-3-5-haiku-20241022)", () => {
|
||||
let model: Model;
|
||||
describe.skipIf(!process.env.GEMINI_API_KEY)("Gemini Provider (gemini-2.5-flash)", () => {
|
||||
const llm = getModel("google", "gemini-2.5-flash");
|
||||
|
||||
beforeAll(() => {
|
||||
model = getModel("anthropic", "claude-3-5-haiku-20241022");
|
||||
});
|
||||
it("should complete basic text generation", async () => {
|
||||
await basicTextGeneration(llm);
|
||||
});
|
||||
|
||||
it("should complete basic text generation", async () => {
|
||||
await basicTextGeneration(model);
|
||||
});
|
||||
it("should handle tool calling", async () => {
|
||||
await handleToolCall(llm);
|
||||
});
|
||||
|
||||
it("should handle tool calling", async () => {
|
||||
await handleToolCall(model);
|
||||
});
|
||||
it("should handle streaming", async () => {
|
||||
await handleStreaming(llm);
|
||||
});
|
||||
|
||||
it("should handle streaming", async () => {
|
||||
await handleStreaming(model);
|
||||
});
|
||||
it("should handle ", async () => {
|
||||
await handleThinking(llm, { thinking: { enabled: true, budgetTokens: 1024 } });
|
||||
});
|
||||
|
||||
it("should handle image input", async () => {
|
||||
await handleImage(model);
|
||||
});
|
||||
});
|
||||
it("should handle multi-turn with thinking and tools", async () => {
|
||||
await multiTurn(llm, { thinking: { enabled: true, budgetTokens: 2048 } });
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.ANTHROPIC_OAUTH_TOKEN)("Anthropic Provider (claude-sonnet-4-20250514)", () => {
|
||||
let model: Model;
|
||||
it("should handle image input", async () => {
|
||||
await handleImage(llm);
|
||||
});
|
||||
});
|
||||
|
||||
beforeAll(() => {
|
||||
model = getModel("anthropic", "claude-sonnet-4-20250514");
|
||||
});
|
||||
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Completions Provider (gpt-4o-mini)", () => {
|
||||
const llm: Model<"openai-completions"> = { ...getModel("openai", "gpt-4o-mini"), api: "openai-completions" };
|
||||
|
||||
it("should complete basic text generation", async () => {
|
||||
await basicTextGeneration(model);
|
||||
});
|
||||
it("should complete basic text generation", async () => {
|
||||
await basicTextGeneration(llm);
|
||||
});
|
||||
|
||||
it("should handle tool calling", async () => {
|
||||
await handleToolCall(model);
|
||||
});
|
||||
it("should handle tool calling", async () => {
|
||||
await handleToolCall(llm);
|
||||
});
|
||||
|
||||
it("should handle streaming", async () => {
|
||||
await handleStreaming(model);
|
||||
});
|
||||
it("should handle streaming", async () => {
|
||||
await handleStreaming(llm);
|
||||
});
|
||||
|
||||
it("should handle thinking mode", async () => {
|
||||
await handleThinking(model, { reasoning: "low" });
|
||||
});
|
||||
it("should handle image input", async () => {
|
||||
await handleImage(llm);
|
||||
});
|
||||
});
|
||||
|
||||
it("should handle multi-turn with thinking and tools", async () => {
|
||||
await multiTurn(model, { reasoning: "medium" });
|
||||
});
|
||||
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Responses Provider (gpt-5-mini)", () => {
|
||||
const llm = getModel("openai", "gpt-5-mini");
|
||||
|
||||
it("should handle image input", async () => {
|
||||
await handleImage(model);
|
||||
});
|
||||
});
|
||||
});
|
||||
it("should complete basic text generation", async () => {
|
||||
await basicTextGeneration(llm);
|
||||
});
|
||||
|
||||
it("should handle tool calling", async () => {
|
||||
await handleToolCall(llm);
|
||||
});
|
||||
|
||||
it("should handle streaming", async () => {
|
||||
await handleStreaming(llm);
|
||||
});
|
||||
|
||||
it("should handle ", { retry: 2 }, async () => {
|
||||
await handleThinking(llm, { reasoningEffort: "medium" });
|
||||
});
|
||||
|
||||
it("should handle multi-turn with thinking and tools", async () => {
|
||||
await multiTurn(llm, { reasoningEffort: "medium" });
|
||||
});
|
||||
|
||||
it("should handle image input", async () => {
|
||||
await handleImage(llm);
|
||||
});
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.ANTHROPIC_API_KEY)("Anthropic Provider (claude-3-5-haiku-20241022)", () => {
|
||||
const model = getModel("anthropic", "claude-3-5-haiku-20241022");
|
||||
|
||||
it("should complete basic text generation", async () => {
|
||||
await basicTextGeneration(model, { thinkingEnabled: true });
|
||||
});
|
||||
|
||||
it("should handle tool calling", async () => {
|
||||
await handleToolCall(model);
|
||||
});
|
||||
|
||||
it("should handle streaming", async () => {
|
||||
await handleStreaming(model);
|
||||
});
|
||||
|
||||
it("should handle image input", async () => {
|
||||
await handleImage(model);
|
||||
});
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.ANTHROPIC_OAUTH_TOKEN)("Anthropic Provider (claude-sonnet-4-20250514)", () => {
|
||||
const model = getModel("anthropic", "claude-sonnet-4-20250514");
|
||||
|
||||
it("should complete basic text generation", async () => {
|
||||
await basicTextGeneration(model, { thinkingEnabled: true });
|
||||
});
|
||||
|
||||
it("should handle tool calling", async () => {
|
||||
await handleToolCall(model);
|
||||
});
|
||||
|
||||
it("should handle streaming", async () => {
|
||||
await handleStreaming(model);
|
||||
});
|
||||
|
||||
it("should handle thinking", async () => {
|
||||
await handleThinking(model, { thinkingEnabled: true });
|
||||
});
|
||||
|
||||
it("should handle multi-turn with thinking and tools", async () => {
|
||||
await multiTurn(model, { thinkingEnabled: true });
|
||||
});
|
||||
|
||||
it("should handle image input", async () => {
|
||||
await handleImage(model);
|
||||
});
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Responses Provider (gpt-5-mini)", () => {
|
||||
const model = getModel("openai", "gpt-5-mini");
|
||||
|
||||
it("should complete basic text generation", async () => {
|
||||
await basicTextGeneration(model);
|
||||
});
|
||||
|
||||
it("should handle tool calling", async () => {
|
||||
await handleToolCall(model);
|
||||
});
|
||||
|
||||
it("should handle streaming", async () => {
|
||||
await handleStreaming(model);
|
||||
});
|
||||
|
||||
it("should handle image input", async () => {
|
||||
await handleImage(model);
|
||||
});
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.XAI_API_KEY)("xAI Provider (grok-code-fast-1 via OpenAI Completions)", () => {
|
||||
const llm = getModel("xai", "grok-code-fast-1");
|
||||
|
||||
it("should complete basic text generation", async () => {
|
||||
await basicTextGeneration(llm);
|
||||
});
|
||||
|
||||
it("should handle tool calling", async () => {
|
||||
await handleToolCall(llm);
|
||||
});
|
||||
|
||||
it("should handle streaming", async () => {
|
||||
await handleStreaming(llm);
|
||||
});
|
||||
|
||||
it("should handle thinking mode", async () => {
|
||||
await handleThinking(llm, { reasoningEffort: "medium" });
|
||||
});
|
||||
|
||||
it("should handle multi-turn with thinking and tools", async () => {
|
||||
await multiTurn(llm, { reasoningEffort: "medium" });
|
||||
});
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.GROQ_API_KEY)("Groq Provider (gpt-oss-20b via OpenAI Completions)", () => {
|
||||
const llm = getModel("groq", "openai/gpt-oss-20b");
|
||||
|
||||
it("should complete basic text generation", async () => {
|
||||
await basicTextGeneration(llm);
|
||||
});
|
||||
|
||||
it("should handle tool calling", async () => {
|
||||
await handleToolCall(llm);
|
||||
});
|
||||
|
||||
it("should handle streaming", async () => {
|
||||
await handleStreaming(llm);
|
||||
});
|
||||
|
||||
it("should handle thinking mode", async () => {
|
||||
await handleThinking(llm, { reasoningEffort: "medium" });
|
||||
});
|
||||
|
||||
it("should handle multi-turn with thinking and tools", async () => {
|
||||
await multiTurn(llm, { reasoningEffort: "medium" });
|
||||
});
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.CEREBRAS_API_KEY)("Cerebras Provider (gpt-oss-120b via OpenAI Completions)", () => {
|
||||
const llm = getModel("cerebras", "gpt-oss-120b");
|
||||
|
||||
it("should complete basic text generation", async () => {
|
||||
await basicTextGeneration(llm);
|
||||
});
|
||||
|
||||
it("should handle tool calling", async () => {
|
||||
await handleToolCall(llm);
|
||||
});
|
||||
|
||||
it("should handle streaming", async () => {
|
||||
await handleStreaming(llm);
|
||||
});
|
||||
|
||||
it("should handle thinking mode", async () => {
|
||||
await handleThinking(llm, { reasoningEffort: "medium" });
|
||||
});
|
||||
|
||||
it("should handle multi-turn with thinking and tools", async () => {
|
||||
await multiTurn(llm, { reasoningEffort: "medium" });
|
||||
});
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.OPENROUTER_API_KEY)("OpenRouter Provider (glm-4.5v via OpenAI Completions)", () => {
|
||||
const llm = getModel("openrouter", "z-ai/glm-4.5v");
|
||||
|
||||
it("should complete basic text generation", async () => {
|
||||
await basicTextGeneration(llm);
|
||||
});
|
||||
|
||||
it("should handle tool calling", async () => {
|
||||
await handleToolCall(llm);
|
||||
});
|
||||
|
||||
it("should handle streaming", async () => {
|
||||
await handleStreaming(llm);
|
||||
});
|
||||
|
||||
it("should handle thinking mode", async () => {
|
||||
await handleThinking(llm, { reasoningEffort: "medium" });
|
||||
});
|
||||
|
||||
it("should handle multi-turn with thinking and tools", { retry: 2 }, async () => {
|
||||
await multiTurn(llm, { reasoningEffort: "medium" });
|
||||
});
|
||||
|
||||
it("should handle image input", async () => {
|
||||
await handleImage(llm);
|
||||
});
|
||||
});
|
||||
|
||||
// Check if ollama is installed
|
||||
let ollamaInstalled = false;
|
||||
try {
|
||||
execSync("which ollama", { stdio: "ignore" });
|
||||
ollamaInstalled = true;
|
||||
} catch {
|
||||
ollamaInstalled = false;
|
||||
}
|
||||
|
||||
describe.skipIf(!ollamaInstalled)("Ollama Provider (gpt-oss-20b via OpenAI Completions)", () => {
|
||||
let llm: Model<"openai-completions">;
|
||||
let ollamaProcess: ChildProcess | null = null;
|
||||
|
||||
beforeAll(async () => {
|
||||
// Check if model is available, if not pull it
|
||||
try {
|
||||
execSync("ollama list | grep -q 'gpt-oss:20b'", { stdio: "ignore" });
|
||||
} catch {
|
||||
console.log("Pulling gpt-oss:20b model for Ollama tests...");
|
||||
try {
|
||||
execSync("ollama pull gpt-oss:20b", { stdio: "inherit" });
|
||||
} catch (e) {
|
||||
console.warn("Failed to pull gpt-oss:20b model, tests will be skipped");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Start ollama server
|
||||
ollamaProcess = spawn("ollama", ["serve"], {
|
||||
detached: false,
|
||||
stdio: "ignore",
|
||||
});
|
||||
|
||||
// Wait for server to be ready
|
||||
await new Promise<void>((resolve) => {
|
||||
const checkServer = async () => {
|
||||
try {
|
||||
const response = await fetch("http://localhost:11434/api/tags");
|
||||
if (response.ok) {
|
||||
resolve();
|
||||
} else {
|
||||
setTimeout(checkServer, 500);
|
||||
}
|
||||
} catch {
|
||||
setTimeout(checkServer, 500);
|
||||
}
|
||||
};
|
||||
setTimeout(checkServer, 1000); // Initial delay
|
||||
});
|
||||
|
||||
llm = {
|
||||
id: "gpt-oss:20b",
|
||||
api: "openai-completions",
|
||||
provider: "ollama",
|
||||
baseUrl: "http://localhost:11434/v1",
|
||||
reasoning: true,
|
||||
input: ["text"],
|
||||
contextWindow: 128000,
|
||||
maxTokens: 16000,
|
||||
cost: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
},
|
||||
name: "Ollama GPT-OSS 20B",
|
||||
};
|
||||
}, 30000); // 30 second timeout for setup
|
||||
|
||||
afterAll(() => {
|
||||
// Kill ollama server
|
||||
if (ollamaProcess) {
|
||||
ollamaProcess.kill("SIGTERM");
|
||||
ollamaProcess = null;
|
||||
}
|
||||
});
|
||||
|
||||
it("should complete basic text generation", async () => {
|
||||
await basicTextGeneration(llm, { apiKey: "test" });
|
||||
});
|
||||
|
||||
it("should handle tool calling", async () => {
|
||||
await handleToolCall(llm, { apiKey: "test" });
|
||||
});
|
||||
|
||||
it("should handle streaming", async () => {
|
||||
await handleStreaming(llm, { apiKey: "test" });
|
||||
});
|
||||
|
||||
it("should handle thinking mode", async () => {
|
||||
await handleThinking(llm, { apiKey: "test", reasoningEffort: "medium" });
|
||||
});
|
||||
|
||||
it("should handle multi-turn with thinking and tools", async () => {
|
||||
await multiTurn(llm, { apiKey: "test", reasoningEffort: "medium" });
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -1,503 +1,489 @@
|
|||
import { describe, it, expect, beforeAll } from "vitest";
|
||||
import { GoogleLLM } from "../src/providers/google.js";
|
||||
import { OpenAICompletionsLLM } from "../src/providers/openai-completions.js";
|
||||
import { OpenAIResponsesLLM } from "../src/providers/openai-responses.js";
|
||||
import { AnthropicLLM } from "../src/providers/anthropic.js";
|
||||
import type { LLM, Context, AssistantMessage, Tool, Message } from "../src/types.js";
|
||||
import { createLLM, getModel } from "../src/models.js";
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { complete } from "../src/generate.js";
|
||||
import { getModel } from "../src/models.js";
|
||||
import type { Api, AssistantMessage, Context, Message, Model, Tool } from "../src/types.js";
|
||||
|
||||
// Tool for testing
|
||||
const weatherTool: Tool = {
|
||||
name: "get_weather",
|
||||
description: "Get the weather for a location",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
location: { type: "string", description: "City name" }
|
||||
},
|
||||
required: ["location"]
|
||||
}
|
||||
name: "get_weather",
|
||||
description: "Get the weather for a location",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
location: { type: "string", description: "City name" },
|
||||
},
|
||||
required: ["location"],
|
||||
},
|
||||
};
|
||||
|
||||
// Pre-built contexts representing typical outputs from each provider
|
||||
const providerContexts = {
|
||||
// Anthropic-style message with thinking block
|
||||
anthropic: {
|
||||
message: {
|
||||
role: "assistant",
|
||||
content: [
|
||||
{
|
||||
type: "thinking",
|
||||
thinking: "Let me calculate 17 * 23. That's 17 * 20 + 17 * 3 = 340 + 51 = 391",
|
||||
thinkingSignature: "signature_abc123"
|
||||
},
|
||||
{
|
||||
type: "text",
|
||||
text: "I'll help you with the calculation and check the weather. The result of 17 × 23 is 391. The capital of Austria is Vienna. Now let me check the weather for you."
|
||||
},
|
||||
{
|
||||
type: "toolCall",
|
||||
id: "toolu_01abc123",
|
||||
name: "get_weather",
|
||||
arguments: { location: "Tokyo" }
|
||||
}
|
||||
],
|
||||
provider: "anthropic",
|
||||
model: "claude-3-5-haiku-latest",
|
||||
usage: { input: 100, output: 50, cacheRead: 0, cacheWrite: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } },
|
||||
stopReason: "toolUse"
|
||||
} as AssistantMessage,
|
||||
toolResult: {
|
||||
role: "toolResult" as const,
|
||||
toolCallId: "toolu_01abc123",
|
||||
toolName: "get_weather",
|
||||
content: "Weather in Tokyo: 18°C, partly cloudy",
|
||||
isError: false
|
||||
},
|
||||
facts: {
|
||||
calculation: 391,
|
||||
city: "Tokyo",
|
||||
temperature: 18,
|
||||
capital: "Vienna"
|
||||
}
|
||||
},
|
||||
// Anthropic-style message with thinking block
|
||||
anthropic: {
|
||||
message: {
|
||||
role: "assistant",
|
||||
content: [
|
||||
{
|
||||
type: "thinking",
|
||||
thinking: "Let me calculate 17 * 23. That's 17 * 20 + 17 * 3 = 340 + 51 = 391",
|
||||
thinkingSignature: "signature_abc123",
|
||||
},
|
||||
{
|
||||
type: "text",
|
||||
text: "I'll help you with the calculation and check the weather. The result of 17 × 23 is 391. The capital of Austria is Vienna. Now let me check the weather for you.",
|
||||
},
|
||||
{
|
||||
type: "toolCall",
|
||||
id: "toolu_01abc123",
|
||||
name: "get_weather",
|
||||
arguments: { location: "Tokyo" },
|
||||
},
|
||||
],
|
||||
provider: "anthropic",
|
||||
model: "claude-3-5-haiku-latest",
|
||||
usage: {
|
||||
input: 100,
|
||||
output: 50,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "toolUse",
|
||||
} as AssistantMessage,
|
||||
toolResult: {
|
||||
role: "toolResult" as const,
|
||||
toolCallId: "toolu_01abc123",
|
||||
toolName: "get_weather",
|
||||
content: "Weather in Tokyo: 18°C, partly cloudy",
|
||||
isError: false,
|
||||
},
|
||||
facts: {
|
||||
calculation: 391,
|
||||
city: "Tokyo",
|
||||
temperature: 18,
|
||||
capital: "Vienna",
|
||||
},
|
||||
},
|
||||
|
||||
// Google-style message with thinking
|
||||
google: {
|
||||
message: {
|
||||
role: "assistant",
|
||||
content: [
|
||||
{
|
||||
type: "thinking",
|
||||
thinking: "I need to multiply 19 * 24. Let me work through this: 19 * 24 = 19 * 20 + 19 * 4 = 380 + 76 = 456",
|
||||
thinkingSignature: undefined
|
||||
},
|
||||
{
|
||||
type: "text",
|
||||
text: "The multiplication of 19 × 24 equals 456. The capital of France is Paris. Let me check the weather in Berlin for you."
|
||||
},
|
||||
{
|
||||
type: "toolCall",
|
||||
id: "call_gemini_123",
|
||||
name: "get_weather",
|
||||
arguments: { location: "Berlin" }
|
||||
}
|
||||
],
|
||||
provider: "google",
|
||||
model: "gemini-2.5-flash",
|
||||
usage: { input: 120, output: 60, cacheRead: 0, cacheWrite: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } },
|
||||
stopReason: "toolUse"
|
||||
} as AssistantMessage,
|
||||
toolResult: {
|
||||
role: "toolResult" as const,
|
||||
toolCallId: "call_gemini_123",
|
||||
toolName: "get_weather",
|
||||
content: "Weather in Berlin: 22°C, sunny",
|
||||
isError: false
|
||||
},
|
||||
facts: {
|
||||
calculation: 456,
|
||||
city: "Berlin",
|
||||
temperature: 22,
|
||||
capital: "Paris"
|
||||
}
|
||||
},
|
||||
// Google-style message with thinking
|
||||
google: {
|
||||
message: {
|
||||
role: "assistant",
|
||||
content: [
|
||||
{
|
||||
type: "thinking",
|
||||
thinking:
|
||||
"I need to multiply 19 * 24. Let me work through this: 19 * 24 = 19 * 20 + 19 * 4 = 380 + 76 = 456",
|
||||
thinkingSignature: undefined,
|
||||
},
|
||||
{
|
||||
type: "text",
|
||||
text: "The multiplication of 19 × 24 equals 456. The capital of France is Paris. Let me check the weather in Berlin for you.",
|
||||
},
|
||||
{
|
||||
type: "toolCall",
|
||||
id: "call_gemini_123",
|
||||
name: "get_weather",
|
||||
arguments: { location: "Berlin" },
|
||||
},
|
||||
],
|
||||
provider: "google",
|
||||
model: "gemini-2.5-flash",
|
||||
usage: {
|
||||
input: 120,
|
||||
output: 60,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "toolUse",
|
||||
} as AssistantMessage,
|
||||
toolResult: {
|
||||
role: "toolResult" as const,
|
||||
toolCallId: "call_gemini_123",
|
||||
toolName: "get_weather",
|
||||
content: "Weather in Berlin: 22°C, sunny",
|
||||
isError: false,
|
||||
},
|
||||
facts: {
|
||||
calculation: 456,
|
||||
city: "Berlin",
|
||||
temperature: 22,
|
||||
capital: "Paris",
|
||||
},
|
||||
},
|
||||
|
||||
// OpenAI Completions style (with reasoning_content)
|
||||
openaiCompletions: {
|
||||
message: {
|
||||
role: "assistant",
|
||||
content: [
|
||||
{
|
||||
type: "thinking",
|
||||
thinking: "Let me calculate 21 * 25. That's 21 * 25 = 525",
|
||||
thinkingSignature: "reasoning_content"
|
||||
},
|
||||
{
|
||||
type: "text",
|
||||
text: "The result of 21 × 25 is 525. The capital of Spain is Madrid. I'll check the weather in London now."
|
||||
},
|
||||
{
|
||||
type: "toolCall",
|
||||
id: "call_abc123",
|
||||
name: "get_weather",
|
||||
arguments: { location: "London" }
|
||||
}
|
||||
],
|
||||
provider: "openai",
|
||||
model: "gpt-4o-mini",
|
||||
usage: { input: 110, output: 55, cacheRead: 0, cacheWrite: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } },
|
||||
stopReason: "toolUse"
|
||||
} as AssistantMessage,
|
||||
toolResult: {
|
||||
role: "toolResult" as const,
|
||||
toolCallId: "call_abc123",
|
||||
toolName: "get_weather",
|
||||
content: "Weather in London: 15°C, rainy",
|
||||
isError: false
|
||||
},
|
||||
facts: {
|
||||
calculation: 525,
|
||||
city: "London",
|
||||
temperature: 15,
|
||||
capital: "Madrid"
|
||||
}
|
||||
},
|
||||
// OpenAI Completions style (with reasoning_content)
|
||||
openaiCompletions: {
|
||||
message: {
|
||||
role: "assistant",
|
||||
content: [
|
||||
{
|
||||
type: "thinking",
|
||||
thinking: "Let me calculate 21 * 25. That's 21 * 25 = 525",
|
||||
thinkingSignature: "reasoning_content",
|
||||
},
|
||||
{
|
||||
type: "text",
|
||||
text: "The result of 21 × 25 is 525. The capital of Spain is Madrid. I'll check the weather in London now.",
|
||||
},
|
||||
{
|
||||
type: "toolCall",
|
||||
id: "call_abc123",
|
||||
name: "get_weather",
|
||||
arguments: { location: "London" },
|
||||
},
|
||||
],
|
||||
provider: "openai",
|
||||
model: "gpt-4o-mini",
|
||||
usage: {
|
||||
input: 110,
|
||||
output: 55,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "toolUse",
|
||||
} as AssistantMessage,
|
||||
toolResult: {
|
||||
role: "toolResult" as const,
|
||||
toolCallId: "call_abc123",
|
||||
toolName: "get_weather",
|
||||
content: "Weather in London: 15°C, rainy",
|
||||
isError: false,
|
||||
},
|
||||
facts: {
|
||||
calculation: 525,
|
||||
city: "London",
|
||||
temperature: 15,
|
||||
capital: "Madrid",
|
||||
},
|
||||
},
|
||||
|
||||
// OpenAI Responses style (with complex tool call IDs)
|
||||
openaiResponses: {
|
||||
message: {
|
||||
role: "assistant",
|
||||
content: [
|
||||
{
|
||||
type: "thinking",
|
||||
thinking: "Calculating 18 * 27: 18 * 27 = 486",
|
||||
thinkingSignature: '{"type":"reasoning","id":"rs_2b2342acdde","summary":[{"type":"summary_text","text":"Calculating 18 * 27: 18 * 27 = 486"}]}'
|
||||
},
|
||||
{
|
||||
type: "text",
|
||||
text: "The calculation of 18 × 27 gives us 486. The capital of Italy is Rome. Let me check Sydney's weather.",
|
||||
textSignature: "msg_response_456"
|
||||
},
|
||||
{
|
||||
type: "toolCall",
|
||||
id: "call_789_item_012", // Anthropic requires alphanumeric, dash, and underscore only
|
||||
name: "get_weather",
|
||||
arguments: { location: "Sydney" }
|
||||
}
|
||||
],
|
||||
provider: "openai",
|
||||
model: "gpt-5-mini",
|
||||
usage: { input: 115, output: 58, cacheRead: 0, cacheWrite: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } },
|
||||
stopReason: "toolUse"
|
||||
} as AssistantMessage,
|
||||
toolResult: {
|
||||
role: "toolResult" as const,
|
||||
toolCallId: "call_789_item_012", // Match the updated ID format
|
||||
toolName: "get_weather",
|
||||
content: "Weather in Sydney: 25°C, clear",
|
||||
isError: false
|
||||
},
|
||||
facts: {
|
||||
calculation: 486,
|
||||
city: "Sydney",
|
||||
temperature: 25,
|
||||
capital: "Rome"
|
||||
}
|
||||
},
|
||||
// OpenAI Responses style (with complex tool call IDs)
|
||||
openaiResponses: {
|
||||
message: {
|
||||
role: "assistant",
|
||||
content: [
|
||||
{
|
||||
type: "thinking",
|
||||
thinking: "Calculating 18 * 27: 18 * 27 = 486",
|
||||
thinkingSignature:
|
||||
'{"type":"reasoning","id":"rs_2b2342acdde","summary":[{"type":"summary_text","text":"Calculating 18 * 27: 18 * 27 = 486"}]}',
|
||||
},
|
||||
{
|
||||
type: "text",
|
||||
text: "The calculation of 18 × 27 gives us 486. The capital of Italy is Rome. Let me check Sydney's weather.",
|
||||
textSignature: "msg_response_456",
|
||||
},
|
||||
{
|
||||
type: "toolCall",
|
||||
id: "call_789_item_012", // Anthropic requires alphanumeric, dash, and underscore only
|
||||
name: "get_weather",
|
||||
arguments: { location: "Sydney" },
|
||||
},
|
||||
],
|
||||
provider: "openai",
|
||||
model: "gpt-5-mini",
|
||||
usage: {
|
||||
input: 115,
|
||||
output: 58,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "toolUse",
|
||||
} as AssistantMessage,
|
||||
toolResult: {
|
||||
role: "toolResult" as const,
|
||||
toolCallId: "call_789_item_012", // Match the updated ID format
|
||||
toolName: "get_weather",
|
||||
content: "Weather in Sydney: 25°C, clear",
|
||||
isError: false,
|
||||
},
|
||||
facts: {
|
||||
calculation: 486,
|
||||
city: "Sydney",
|
||||
temperature: 25,
|
||||
capital: "Rome",
|
||||
},
|
||||
},
|
||||
|
||||
// Aborted message (stopReason: 'error')
|
||||
aborted: {
|
||||
message: {
|
||||
role: "assistant",
|
||||
content: [
|
||||
{
|
||||
type: "thinking",
|
||||
thinking: "Let me start calculating 20 * 30...",
|
||||
thinkingSignature: "partial_sig"
|
||||
},
|
||||
{
|
||||
type: "text",
|
||||
text: "I was about to calculate 20 × 30 which is"
|
||||
}
|
||||
],
|
||||
provider: "test",
|
||||
model: "test-model",
|
||||
usage: { input: 50, output: 25, cacheRead: 0, cacheWrite: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } },
|
||||
stopReason: "error",
|
||||
error: "Request was aborted"
|
||||
} as AssistantMessage,
|
||||
toolResult: null,
|
||||
facts: {
|
||||
calculation: 600,
|
||||
city: "none",
|
||||
temperature: 0,
|
||||
capital: "none"
|
||||
}
|
||||
}
|
||||
// Aborted message (stopReason: 'error')
|
||||
aborted: {
|
||||
message: {
|
||||
role: "assistant",
|
||||
content: [
|
||||
{
|
||||
type: "thinking",
|
||||
thinking: "Let me start calculating 20 * 30...",
|
||||
thinkingSignature: "partial_sig",
|
||||
},
|
||||
{
|
||||
type: "text",
|
||||
text: "I was about to calculate 20 × 30 which is",
|
||||
},
|
||||
],
|
||||
provider: "test",
|
||||
model: "test-model",
|
||||
usage: {
|
||||
input: 50,
|
||||
output: 25,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "error",
|
||||
error: "Request was aborted",
|
||||
} as AssistantMessage,
|
||||
toolResult: null,
|
||||
facts: {
|
||||
calculation: 600,
|
||||
city: "none",
|
||||
temperature: 0,
|
||||
capital: "none",
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
/**
|
||||
* Test that a provider can handle contexts from different sources
|
||||
*/
|
||||
async function testProviderHandoff(
|
||||
targetProvider: LLM<any>,
|
||||
sourceLabel: string,
|
||||
sourceContext: typeof providerContexts[keyof typeof providerContexts]
|
||||
async function testProviderHandoff<TApi extends Api>(
|
||||
targetModel: Model<TApi>,
|
||||
sourceLabel: string,
|
||||
sourceContext: (typeof providerContexts)[keyof typeof providerContexts],
|
||||
): Promise<boolean> {
|
||||
// Build conversation context
|
||||
const messages: Message[] = [
|
||||
{
|
||||
role: "user",
|
||||
content: "Please do some calculations, tell me about capitals, and check the weather."
|
||||
},
|
||||
sourceContext.message
|
||||
];
|
||||
// Build conversation context
|
||||
const messages: Message[] = [
|
||||
{
|
||||
role: "user",
|
||||
content: "Please do some calculations, tell me about capitals, and check the weather.",
|
||||
},
|
||||
sourceContext.message,
|
||||
];
|
||||
|
||||
// Add tool result if present
|
||||
if (sourceContext.toolResult) {
|
||||
messages.push(sourceContext.toolResult);
|
||||
}
|
||||
// Add tool result if present
|
||||
if (sourceContext.toolResult) {
|
||||
messages.push(sourceContext.toolResult);
|
||||
}
|
||||
|
||||
// Ask follow-up question
|
||||
messages.push({
|
||||
role: "user",
|
||||
content: `Based on our conversation, please answer:
|
||||
// Ask follow-up question
|
||||
messages.push({
|
||||
role: "user",
|
||||
content: `Based on our conversation, please answer:
|
||||
1) What was the multiplication result?
|
||||
2) Which city's weather did we check?
|
||||
3) What was the temperature?
|
||||
4) What capital city was mentioned?
|
||||
Please include the specific numbers and names.`
|
||||
});
|
||||
Please include the specific numbers and names.`,
|
||||
});
|
||||
|
||||
const context: Context = {
|
||||
messages,
|
||||
tools: [weatherTool]
|
||||
};
|
||||
const context: Context = {
|
||||
messages,
|
||||
tools: [weatherTool],
|
||||
};
|
||||
|
||||
try {
|
||||
const response = await targetProvider.generate(context, {});
|
||||
try {
|
||||
const response = await complete(targetModel, context, {});
|
||||
|
||||
// Check for error
|
||||
if (response.stopReason === "error") {
|
||||
console.log(`[${sourceLabel} → ${targetProvider.getModel().provider}] Failed with error: ${response.error}`);
|
||||
return false;
|
||||
}
|
||||
// Check for error
|
||||
if (response.stopReason === "error") {
|
||||
console.log(`[${sourceLabel} → ${targetModel.provider}] Failed with error: ${response.error}`);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Extract text from response
|
||||
const responseText = response.content
|
||||
.filter(b => b.type === "text")
|
||||
.map(b => b.text)
|
||||
.join(" ")
|
||||
.toLowerCase();
|
||||
// Extract text from response
|
||||
const responseText = response.content
|
||||
.filter((b) => b.type === "text")
|
||||
.map((b) => b.text)
|
||||
.join(" ")
|
||||
.toLowerCase();
|
||||
|
||||
// For aborted messages, we don't expect to find the facts
|
||||
if (sourceContext.message.stopReason === "error") {
|
||||
const hasToolCalls = response.content.some(b => b.type === "toolCall");
|
||||
const hasThinking = response.content.some(b => b.type === "thinking");
|
||||
const hasText = response.content.some(b => b.type === "text");
|
||||
// For aborted messages, we don't expect to find the facts
|
||||
if (sourceContext.message.stopReason === "error") {
|
||||
const hasToolCalls = response.content.some((b) => b.type === "toolCall");
|
||||
const hasThinking = response.content.some((b) => b.type === "thinking");
|
||||
const hasText = response.content.some((b) => b.type === "text");
|
||||
|
||||
expect(response.stopReason === "stop" || response.stopReason === "toolUse").toBe(true);
|
||||
expect(hasThinking || hasText || hasToolCalls).toBe(true);
|
||||
console.log(`[${sourceLabel} → ${targetProvider.getModel().provider}] Handled aborted message successfully, tool calls: ${hasToolCalls}, thinking: ${hasThinking}, text: ${hasText}`);
|
||||
return true;
|
||||
}
|
||||
expect(response.stopReason === "stop" || response.stopReason === "toolUse").toBe(true);
|
||||
expect(hasThinking || hasText || hasToolCalls).toBe(true);
|
||||
console.log(
|
||||
`[${sourceLabel} → ${targetModel.provider}] Handled aborted message successfully, tool calls: ${hasToolCalls}, thinking: ${hasThinking}, text: ${hasText}`,
|
||||
);
|
||||
return true;
|
||||
}
|
||||
|
||||
// Check if response contains our facts
|
||||
const hasCalculation = responseText.includes(sourceContext.facts.calculation.toString());
|
||||
const hasCity = sourceContext.facts.city !== "none" && responseText.includes(sourceContext.facts.city.toLowerCase());
|
||||
const hasTemperature = sourceContext.facts.temperature > 0 && responseText.includes(sourceContext.facts.temperature.toString());
|
||||
const hasCapital = sourceContext.facts.capital !== "none" && responseText.includes(sourceContext.facts.capital.toLowerCase());
|
||||
// Check if response contains our facts
|
||||
const hasCalculation = responseText.includes(sourceContext.facts.calculation.toString());
|
||||
const hasCity =
|
||||
sourceContext.facts.city !== "none" && responseText.includes(sourceContext.facts.city.toLowerCase());
|
||||
const hasTemperature =
|
||||
sourceContext.facts.temperature > 0 && responseText.includes(sourceContext.facts.temperature.toString());
|
||||
const hasCapital =
|
||||
sourceContext.facts.capital !== "none" && responseText.includes(sourceContext.facts.capital.toLowerCase());
|
||||
|
||||
const success = hasCalculation && hasCity && hasTemperature && hasCapital;
|
||||
const success = hasCalculation && hasCity && hasTemperature && hasCapital;
|
||||
|
||||
console.log(`[${sourceLabel} → ${targetProvider.getModel().provider}] Handoff test:`);
|
||||
if (!success) {
|
||||
console.log(` Calculation (${sourceContext.facts.calculation}): ${hasCalculation ? '✓' : '✗'}`);
|
||||
console.log(` City (${sourceContext.facts.city}): ${hasCity ? '✓' : '✗'}`);
|
||||
console.log(` Temperature (${sourceContext.facts.temperature}): ${hasTemperature ? '✓' : '✗'}`);
|
||||
console.log(` Capital (${sourceContext.facts.capital}): ${hasCapital ? '✓' : '✗'}`);
|
||||
} else {
|
||||
console.log(` ✓ All facts found`);
|
||||
}
|
||||
console.log(`[${sourceLabel} → ${targetModel.provider}] Handoff test:`);
|
||||
if (!success) {
|
||||
console.log(` Calculation (${sourceContext.facts.calculation}): ${hasCalculation ? "✓" : "✗"}`);
|
||||
console.log(` City (${sourceContext.facts.city}): ${hasCity ? "✓" : "✗"}`);
|
||||
console.log(` Temperature (${sourceContext.facts.temperature}): ${hasTemperature ? "✓" : "✗"}`);
|
||||
console.log(` Capital (${sourceContext.facts.capital}): ${hasCapital ? "✓" : "✗"}`);
|
||||
} else {
|
||||
console.log(` ✓ All facts found`);
|
||||
}
|
||||
|
||||
return success;
|
||||
} catch (error) {
|
||||
console.error(`[${sourceLabel} → ${targetProvider.getModel().provider}] Exception:`, error);
|
||||
return false;
|
||||
}
|
||||
return success;
|
||||
} catch (error) {
|
||||
console.error(`[${sourceLabel} → ${targetModel.provider}] Exception:`, error);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
describe("Cross-Provider Handoff Tests", () => {
|
||||
describe.skipIf(!process.env.ANTHROPIC_API_KEY)("Anthropic Provider Handoff", () => {
|
||||
let provider: AnthropicLLM;
|
||||
describe.skipIf(!process.env.ANTHROPIC_API_KEY)("Anthropic Provider Handoff", () => {
|
||||
const model = getModel("anthropic", "claude-3-5-haiku-20241022");
|
||||
|
||||
beforeAll(() => {
|
||||
const model = getModel("anthropic", "claude-3-5-haiku-20241022");
|
||||
if (model) {
|
||||
provider = new AnthropicLLM(model, process.env.ANTHROPIC_API_KEY!);
|
||||
}
|
||||
});
|
||||
it("should handle contexts from all providers", async () => {
|
||||
console.log("\nTesting Anthropic with pre-built contexts:\n");
|
||||
|
||||
it("should handle contexts from all providers", async () => {
|
||||
if (!provider) {
|
||||
console.log("Anthropic provider not available, skipping");
|
||||
return;
|
||||
}
|
||||
const contextTests = [
|
||||
{ label: "Anthropic-style", context: providerContexts.anthropic, sourceModel: "claude-3-5-haiku-20241022" },
|
||||
{ label: "Google-style", context: providerContexts.google, sourceModel: "gemini-2.5-flash" },
|
||||
{ label: "OpenAI-Completions", context: providerContexts.openaiCompletions, sourceModel: "gpt-4o-mini" },
|
||||
{ label: "OpenAI-Responses", context: providerContexts.openaiResponses, sourceModel: "gpt-5-mini" },
|
||||
{ label: "Aborted", context: providerContexts.aborted, sourceModel: null },
|
||||
];
|
||||
|
||||
console.log("\nTesting Anthropic with pre-built contexts:\n");
|
||||
let successCount = 0;
|
||||
let skippedCount = 0;
|
||||
|
||||
const contextTests = [
|
||||
{ label: "Anthropic-style", context: providerContexts.anthropic, sourceModel: "claude-3-5-haiku-20241022" },
|
||||
{ label: "Google-style", context: providerContexts.google, sourceModel: "gemini-2.5-flash" },
|
||||
{ label: "OpenAI-Completions", context: providerContexts.openaiCompletions, sourceModel: "gpt-4o-mini" },
|
||||
{ label: "OpenAI-Responses", context: providerContexts.openaiResponses, sourceModel: "gpt-5-mini" },
|
||||
{ label: "Aborted", context: providerContexts.aborted, sourceModel: null }
|
||||
];
|
||||
for (const { label, context, sourceModel } of contextTests) {
|
||||
// Skip testing same model against itself
|
||||
if (sourceModel && sourceModel === model.id) {
|
||||
console.log(`[${label} → ${model.provider}] Skipping same-model test`);
|
||||
skippedCount++;
|
||||
continue;
|
||||
}
|
||||
const success = await testProviderHandoff(model, label, context);
|
||||
if (success) successCount++;
|
||||
}
|
||||
|
||||
let successCount = 0;
|
||||
let skippedCount = 0;
|
||||
const totalTests = contextTests.length - skippedCount;
|
||||
console.log(`\nAnthropic success rate: ${successCount}/${totalTests} (${skippedCount} skipped)\n`);
|
||||
|
||||
for (const { label, context, sourceModel } of contextTests) {
|
||||
// Skip testing same model against itself
|
||||
if (sourceModel && sourceModel === provider.getModel().id) {
|
||||
console.log(`[${label} → ${provider.getModel().provider}] Skipping same-model test`);
|
||||
skippedCount++;
|
||||
continue;
|
||||
}
|
||||
const success = await testProviderHandoff(provider, label, context);
|
||||
if (success) successCount++;
|
||||
}
|
||||
// All non-skipped handoffs should succeed
|
||||
expect(successCount).toBe(totalTests);
|
||||
});
|
||||
});
|
||||
|
||||
const totalTests = contextTests.length - skippedCount;
|
||||
console.log(`\nAnthropic success rate: ${successCount}/${totalTests} (${skippedCount} skipped)\n`);
|
||||
describe.skipIf(!process.env.GEMINI_API_KEY)("Google Provider Handoff", () => {
|
||||
const model = getModel("google", "gemini-2.5-flash");
|
||||
|
||||
// All non-skipped handoffs should succeed
|
||||
expect(successCount).toBe(totalTests);
|
||||
});
|
||||
});
|
||||
it("should handle contexts from all providers", async () => {
|
||||
console.log("\nTesting Google with pre-built contexts:\n");
|
||||
|
||||
describe.skipIf(!process.env.GEMINI_API_KEY)("Google Provider Handoff", () => {
|
||||
let provider: GoogleLLM;
|
||||
const contextTests = [
|
||||
{ label: "Anthropic-style", context: providerContexts.anthropic, sourceModel: "claude-3-5-haiku-20241022" },
|
||||
{ label: "Google-style", context: providerContexts.google, sourceModel: "gemini-2.5-flash" },
|
||||
{ label: "OpenAI-Completions", context: providerContexts.openaiCompletions, sourceModel: "gpt-4o-mini" },
|
||||
{ label: "OpenAI-Responses", context: providerContexts.openaiResponses, sourceModel: "gpt-5-mini" },
|
||||
{ label: "Aborted", context: providerContexts.aborted, sourceModel: null },
|
||||
];
|
||||
|
||||
beforeAll(() => {
|
||||
const model = getModel("google", "gemini-2.5-flash");
|
||||
if (model) {
|
||||
provider = new GoogleLLM(model, process.env.GEMINI_API_KEY!);
|
||||
}
|
||||
});
|
||||
let successCount = 0;
|
||||
let skippedCount = 0;
|
||||
|
||||
it("should handle contexts from all providers", async () => {
|
||||
if (!provider) {
|
||||
console.log("Google provider not available, skipping");
|
||||
return;
|
||||
}
|
||||
for (const { label, context, sourceModel } of contextTests) {
|
||||
// Skip testing same model against itself
|
||||
if (sourceModel && sourceModel === model.id) {
|
||||
console.log(`[${label} → ${model.provider}] Skipping same-model test`);
|
||||
skippedCount++;
|
||||
continue;
|
||||
}
|
||||
const success = await testProviderHandoff(model, label, context);
|
||||
if (success) successCount++;
|
||||
}
|
||||
|
||||
console.log("\nTesting Google with pre-built contexts:\n");
|
||||
const totalTests = contextTests.length - skippedCount;
|
||||
console.log(`\nGoogle success rate: ${successCount}/${totalTests} (${skippedCount} skipped)\n`);
|
||||
|
||||
const contextTests = [
|
||||
{ label: "Anthropic-style", context: providerContexts.anthropic, sourceModel: "claude-3-5-haiku-20241022" },
|
||||
{ label: "Google-style", context: providerContexts.google, sourceModel: "gemini-2.5-flash" },
|
||||
{ label: "OpenAI-Completions", context: providerContexts.openaiCompletions, sourceModel: "gpt-4o-mini" },
|
||||
{ label: "OpenAI-Responses", context: providerContexts.openaiResponses, sourceModel: "gpt-5-mini" },
|
||||
{ label: "Aborted", context: providerContexts.aborted, sourceModel: null }
|
||||
];
|
||||
// All non-skipped handoffs should succeed
|
||||
expect(successCount).toBe(totalTests);
|
||||
});
|
||||
});
|
||||
|
||||
let successCount = 0;
|
||||
let skippedCount = 0;
|
||||
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Completions Provider Handoff", () => {
|
||||
const model: Model<"openai-completions"> = { ...getModel("openai", "gpt-4o-mini"), api: "openai-completions" };
|
||||
|
||||
for (const { label, context, sourceModel } of contextTests) {
|
||||
// Skip testing same model against itself
|
||||
if (sourceModel && sourceModel === provider.getModel().id) {
|
||||
console.log(`[${label} → ${provider.getModel().provider}] Skipping same-model test`);
|
||||
skippedCount++;
|
||||
continue;
|
||||
}
|
||||
const success = await testProviderHandoff(provider, label, context);
|
||||
if (success) successCount++;
|
||||
}
|
||||
it("should handle contexts from all providers", async () => {
|
||||
console.log("\nTesting OpenAI Completions with pre-built contexts:\n");
|
||||
|
||||
const totalTests = contextTests.length - skippedCount;
|
||||
console.log(`\nGoogle success rate: ${successCount}/${totalTests} (${skippedCount} skipped)\n`);
|
||||
const contextTests = [
|
||||
{ label: "Anthropic-style", context: providerContexts.anthropic, sourceModel: "claude-3-5-haiku-20241022" },
|
||||
{ label: "Google-style", context: providerContexts.google, sourceModel: "gemini-2.5-flash" },
|
||||
{ label: "OpenAI-Completions", context: providerContexts.openaiCompletions, sourceModel: "gpt-4o-mini" },
|
||||
{ label: "OpenAI-Responses", context: providerContexts.openaiResponses, sourceModel: "gpt-5-mini" },
|
||||
{ label: "Aborted", context: providerContexts.aborted, sourceModel: null },
|
||||
];
|
||||
|
||||
// All non-skipped handoffs should succeed
|
||||
expect(successCount).toBe(totalTests);
|
||||
});
|
||||
});
|
||||
let successCount = 0;
|
||||
let skippedCount = 0;
|
||||
|
||||
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Completions Provider Handoff", () => {
|
||||
let provider: OpenAICompletionsLLM;
|
||||
for (const { label, context, sourceModel } of contextTests) {
|
||||
// Skip testing same model against itself
|
||||
if (sourceModel && sourceModel === model.id) {
|
||||
console.log(`[${label} → ${model.provider}] Skipping same-model test`);
|
||||
skippedCount++;
|
||||
continue;
|
||||
}
|
||||
const success = await testProviderHandoff(model, label, context);
|
||||
if (success) successCount++;
|
||||
}
|
||||
|
||||
beforeAll(() => {
|
||||
const model = getModel("openai", "gpt-4o-mini");
|
||||
if (model) {
|
||||
provider = new OpenAICompletionsLLM(model, process.env.OPENAI_API_KEY!);
|
||||
}
|
||||
});
|
||||
const totalTests = contextTests.length - skippedCount;
|
||||
console.log(`\nOpenAI Completions success rate: ${successCount}/${totalTests} (${skippedCount} skipped)\n`);
|
||||
|
||||
it("should handle contexts from all providers", async () => {
|
||||
if (!provider) {
|
||||
console.log("OpenAI Completions provider not available, skipping");
|
||||
return;
|
||||
}
|
||||
// All non-skipped handoffs should succeed
|
||||
expect(successCount).toBe(totalTests);
|
||||
});
|
||||
});
|
||||
|
||||
console.log("\nTesting OpenAI Completions with pre-built contexts:\n");
|
||||
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Responses Provider Handoff", () => {
|
||||
const model = getModel("openai", "gpt-5-mini");
|
||||
|
||||
const contextTests = [
|
||||
{ label: "Anthropic-style", context: providerContexts.anthropic, sourceModel: "claude-3-5-haiku-20241022" },
|
||||
{ label: "Google-style", context: providerContexts.google, sourceModel: "gemini-2.5-flash" },
|
||||
{ label: "OpenAI-Completions", context: providerContexts.openaiCompletions, sourceModel: "gpt-4o-mini" },
|
||||
{ label: "OpenAI-Responses", context: providerContexts.openaiResponses, sourceModel: "gpt-5-mini" },
|
||||
{ label: "Aborted", context: providerContexts.aborted, sourceModel: null }
|
||||
];
|
||||
it("should handle contexts from all providers", async () => {
|
||||
console.log("\nTesting OpenAI Responses with pre-built contexts:\n");
|
||||
|
||||
let successCount = 0;
|
||||
let skippedCount = 0;
|
||||
const contextTests = [
|
||||
{ label: "Anthropic-style", context: providerContexts.anthropic, sourceModel: "claude-3-5-haiku-20241022" },
|
||||
{ label: "Google-style", context: providerContexts.google, sourceModel: "gemini-2.5-flash" },
|
||||
{ label: "OpenAI-Completions", context: providerContexts.openaiCompletions, sourceModel: "gpt-4o-mini" },
|
||||
{ label: "OpenAI-Responses", context: providerContexts.openaiResponses, sourceModel: "gpt-5-mini" },
|
||||
{ label: "Aborted", context: providerContexts.aborted, sourceModel: null },
|
||||
];
|
||||
|
||||
for (const { label, context, sourceModel } of contextTests) {
|
||||
// Skip testing same model against itself
|
||||
if (sourceModel && sourceModel === provider.getModel().id) {
|
||||
console.log(`[${label} → ${provider.getModel().provider}] Skipping same-model test`);
|
||||
skippedCount++;
|
||||
continue;
|
||||
}
|
||||
const success = await testProviderHandoff(provider, label, context);
|
||||
if (success) successCount++;
|
||||
}
|
||||
let successCount = 0;
|
||||
let skippedCount = 0;
|
||||
|
||||
const totalTests = contextTests.length - skippedCount;
|
||||
console.log(`\nOpenAI Completions success rate: ${successCount}/${totalTests} (${skippedCount} skipped)\n`);
|
||||
for (const { label, context, sourceModel } of contextTests) {
|
||||
// Skip testing same model against itself
|
||||
if (sourceModel && sourceModel === model.id) {
|
||||
console.log(`[${label} → ${model.provider}] Skipping same-model test`);
|
||||
skippedCount++;
|
||||
continue;
|
||||
}
|
||||
const success = await testProviderHandoff(model, label, context);
|
||||
if (success) successCount++;
|
||||
}
|
||||
|
||||
// All non-skipped handoffs should succeed
|
||||
expect(successCount).toBe(totalTests);
|
||||
});
|
||||
});
|
||||
const totalTests = contextTests.length - skippedCount;
|
||||
console.log(`\nOpenAI Responses success rate: ${successCount}/${totalTests} (${skippedCount} skipped)\n`);
|
||||
|
||||
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Responses Provider Handoff", () => {
|
||||
let provider: OpenAIResponsesLLM;
|
||||
|
||||
beforeAll(() => {
|
||||
const model = getModel("openai", "gpt-5-mini");
|
||||
if (model) {
|
||||
provider = new OpenAIResponsesLLM(model, process.env.OPENAI_API_KEY!);
|
||||
}
|
||||
});
|
||||
|
||||
it("should handle contexts from all providers", async () => {
|
||||
if (!provider) {
|
||||
console.log("OpenAI Responses provider not available, skipping");
|
||||
return;
|
||||
}
|
||||
|
||||
console.log("\nTesting OpenAI Responses with pre-built contexts:\n");
|
||||
|
||||
const contextTests = [
|
||||
{ label: "Anthropic-style", context: providerContexts.anthropic, sourceModel: "claude-3-5-haiku-20241022" },
|
||||
{ label: "Google-style", context: providerContexts.google, sourceModel: "gemini-2.5-flash" },
|
||||
{ label: "OpenAI-Completions", context: providerContexts.openaiCompletions, sourceModel: "gpt-4o-mini" },
|
||||
{ label: "OpenAI-Responses", context: providerContexts.openaiResponses, sourceModel: "gpt-5-mini" },
|
||||
{ label: "Aborted", context: providerContexts.aborted, sourceModel: null }
|
||||
];
|
||||
|
||||
let successCount = 0;
|
||||
let skippedCount = 0;
|
||||
|
||||
for (const { label, context, sourceModel } of contextTests) {
|
||||
// Skip testing same model against itself
|
||||
if (sourceModel && sourceModel === provider.getModel().id) {
|
||||
console.log(`[${label} → ${provider.getModel().provider}] Skipping same-model test`);
|
||||
skippedCount++;
|
||||
continue;
|
||||
}
|
||||
const success = await testProviderHandoff(provider, label, context);
|
||||
if (success) successCount++;
|
||||
}
|
||||
|
||||
const totalTests = contextTests.length - skippedCount;
|
||||
console.log(`\nOpenAI Responses success rate: ${successCount}/${totalTests} (${skippedCount} skipped)\n`);
|
||||
|
||||
// All non-skipped handoffs should succeed
|
||||
expect(successCount).toBe(totalTests);
|
||||
});
|
||||
});
|
||||
});
|
||||
// All non-skipped handoffs should succeed
|
||||
expect(successCount).toBe(totalTests);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -1,31 +0,0 @@
|
|||
import { GoogleGenAI } from "@google/genai";
|
||||
import OpenAI from "openai";
|
||||
|
||||
const ai = new GoogleGenAI({});
|
||||
|
||||
async function main() {
|
||||
/*let pager = await ai.models.list();
|
||||
do {
|
||||
for (const model of pager.page) {
|
||||
console.log(JSON.stringify(model, null, 2));
|
||||
console.log("---");
|
||||
}
|
||||
if (!pager.hasNextPage()) break;
|
||||
await pager.nextPage();
|
||||
} while (true);*/
|
||||
|
||||
const openai = new OpenAI();
|
||||
const response = await openai.models.list();
|
||||
do {
|
||||
const page = response.data;
|
||||
for (const model of page) {
|
||||
const info = await openai.models.retrieve(model.id);
|
||||
console.log(JSON.stringify(model, null, 2));
|
||||
console.log("---");
|
||||
}
|
||||
if (!response.hasNextPage()) break;
|
||||
await response.getNextPage();
|
||||
} while (true);
|
||||
}
|
||||
|
||||
await main();
|
||||
|
|
@ -1,618 +0,0 @@
|
|||
import { describe, it, beforeAll, afterAll, expect } from "vitest";
|
||||
import { GoogleLLM } from "../src/providers/google.js";
|
||||
import { OpenAICompletionsLLM } from "../src/providers/openai-completions.js";
|
||||
import { OpenAIResponsesLLM } from "../src/providers/openai-responses.js";
|
||||
import { AnthropicLLM } from "../src/providers/anthropic.js";
|
||||
import type { LLM, LLMOptions, Context, Tool, AssistantMessage, Model, ImageContent } from "../src/types.js";
|
||||
import { spawn, ChildProcess, execSync } from "child_process";
|
||||
import { createLLM, getModel } from "../src/models.js";
|
||||
import { readFileSync } from "fs";
|
||||
import { join, dirname } from "path";
|
||||
import { fileURLToPath } from "url";
|
||||
|
||||
const __filename = fileURLToPath(import.meta.url);
|
||||
const __dirname = dirname(__filename);
|
||||
|
||||
// Calculator tool definition (same as examples)
|
||||
const calculatorTool: Tool = {
|
||||
name: "calculator",
|
||||
description: "Perform basic arithmetic operations",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
a: { type: "number", description: "First number" },
|
||||
b: { type: "number", description: "Second number" },
|
||||
operation: {
|
||||
type: "string",
|
||||
enum: ["add", "subtract", "multiply", "divide"],
|
||||
description: "The operation to perform"
|
||||
}
|
||||
},
|
||||
required: ["a", "b", "operation"]
|
||||
}
|
||||
};
|
||||
|
||||
async function basicTextGeneration<T extends LLMOptions>(llm: LLM<T>) {
|
||||
const context: Context = {
|
||||
systemPrompt: "You are a helpful assistant. Be concise.",
|
||||
messages: [
|
||||
{ role: "user", content: "Reply with exactly: 'Hello test successful'" }
|
||||
]
|
||||
};
|
||||
|
||||
const response = await llm.generate(context);
|
||||
|
||||
expect(response.role).toBe("assistant");
|
||||
expect(response.content).toBeTruthy();
|
||||
expect(response.usage.input + response.usage.cacheRead).toBeGreaterThan(0);
|
||||
expect(response.usage.output).toBeGreaterThan(0);
|
||||
expect(response.error).toBeFalsy();
|
||||
expect(response.content.map(b => b.type == "text" ? b.text : "").join("")).toContain("Hello test successful");
|
||||
|
||||
context.messages.push(response);
|
||||
context.messages.push({ role: "user", content: "Now say 'Goodbye test successful'" });
|
||||
|
||||
const secondResponse = await llm.generate(context);
|
||||
|
||||
expect(secondResponse.role).toBe("assistant");
|
||||
expect(secondResponse.content).toBeTruthy();
|
||||
expect(secondResponse.usage.input + secondResponse.usage.cacheRead).toBeGreaterThan(0);
|
||||
expect(secondResponse.usage.output).toBeGreaterThan(0);
|
||||
expect(secondResponse.error).toBeFalsy();
|
||||
expect(secondResponse.content.map(b => b.type == "text" ? b.text : "").join("")).toContain("Goodbye test successful");
|
||||
}
|
||||
|
||||
async function handleToolCall<T extends LLMOptions>(llm: LLM<T>) {
|
||||
const context: Context = {
|
||||
systemPrompt: "You are a helpful assistant that uses tools when asked.",
|
||||
messages: [{
|
||||
role: "user",
|
||||
content: "Calculate 15 + 27 using the calculator tool."
|
||||
}],
|
||||
tools: [calculatorTool]
|
||||
};
|
||||
|
||||
const response = await llm.generate(context);
|
||||
expect(response.stopReason).toBe("toolUse");
|
||||
expect(response.content.some(b => b.type == "toolCall")).toBeTruthy();
|
||||
const toolCall = response.content.find(b => b.type == "toolCall")!;
|
||||
expect(toolCall.name).toBe("calculator");
|
||||
expect(toolCall.id).toBeTruthy();
|
||||
}
|
||||
|
||||
async function handleStreaming<T extends LLMOptions>(llm: LLM<T>) {
|
||||
let textStarted = false;
|
||||
let textChunks = "";
|
||||
let textCompleted = false;
|
||||
|
||||
const context: Context = {
|
||||
messages: [{ role: "user", content: "Count from 1 to 3" }]
|
||||
};
|
||||
|
||||
const response = await llm.generate(context, {
|
||||
onEvent: (event) => {
|
||||
if (event.type === "text_start") {
|
||||
textStarted = true;
|
||||
} else if (event.type === "text_delta") {
|
||||
textChunks += event.delta;
|
||||
} else if (event.type === "text_end") {
|
||||
textCompleted = true;
|
||||
}
|
||||
}
|
||||
} as T);
|
||||
|
||||
expect(textStarted).toBe(true);
|
||||
expect(textChunks.length).toBeGreaterThan(0);
|
||||
expect(textCompleted).toBe(true);
|
||||
expect(response.content.some(b => b.type == "text")).toBeTruthy();
|
||||
}
|
||||
|
||||
async function handleThinking<T extends LLMOptions>(llm: LLM<T>, options: T) {
|
||||
let thinkingStarted = false;
|
||||
let thinkingChunks = "";
|
||||
let thinkingCompleted = false;
|
||||
|
||||
const context: Context = {
|
||||
messages: [{ role: "user", content: `Think about ${(Math.random() * 255) | 0} + 27. Think step by step. Then output the result.` }]
|
||||
};
|
||||
|
||||
const response = await llm.generate(context, {
|
||||
onEvent: (event) => {
|
||||
if (event.type === "thinking_start") {
|
||||
thinkingStarted = true;
|
||||
} else if (event.type === "thinking_delta") {
|
||||
expect(event.content.endsWith(event.delta)).toBe(true);
|
||||
thinkingChunks += event.delta;
|
||||
} else if (event.type === "thinking_end") {
|
||||
thinkingCompleted = true;
|
||||
}
|
||||
},
|
||||
...options
|
||||
});
|
||||
|
||||
|
||||
expect(response.stopReason, `Error: ${(response as any).error}`).toBe("stop");
|
||||
expect(thinkingStarted).toBe(true);
|
||||
expect(thinkingChunks.length).toBeGreaterThan(0);
|
||||
expect(thinkingCompleted).toBe(true);
|
||||
expect(response.content.some(b => b.type == "thinking")).toBeTruthy();
|
||||
}
|
||||
|
||||
async function handleImage<T extends LLMOptions>(llm: LLM<T>) {
|
||||
// Check if the model supports images
|
||||
const model = llm.getModel();
|
||||
if (!model.input.includes("image")) {
|
||||
console.log(`Skipping image test - model ${model.id} doesn't support images`);
|
||||
return;
|
||||
}
|
||||
|
||||
// Read the test image
|
||||
const imagePath = join(__dirname, "data", "red-circle.png");
|
||||
const imageBuffer = readFileSync(imagePath);
|
||||
const base64Image = imageBuffer.toString("base64");
|
||||
|
||||
const imageContent: ImageContent = {
|
||||
type: "image",
|
||||
data: base64Image,
|
||||
mimeType: "image/png",
|
||||
};
|
||||
|
||||
const context: Context = {
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: [
|
||||
{ type: "text", text: "What do you see in this image? Please describe the shape (circle, rectangle, square, triangle, ...) and color (red, blue, green, ...)." },
|
||||
imageContent,
|
||||
],
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const response = await llm.generate(context);
|
||||
|
||||
// Check the response mentions red and circle
|
||||
expect(response.content.length > 0).toBeTruthy();
|
||||
const lowerContent = response.content.find(b => b.type == "text")?.text || "";
|
||||
expect(lowerContent).toContain("red");
|
||||
expect(lowerContent).toContain("circle");
|
||||
}
|
||||
|
||||
async function multiTurn<T extends LLMOptions>(llm: LLM<T>, thinkingOptions: T) {
|
||||
const context: Context = {
|
||||
systemPrompt: "You are a helpful assistant that can use tools to answer questions.",
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: "Think about this briefly, then calculate 42 * 17 and 453 + 434 using the calculator tool."
|
||||
}
|
||||
],
|
||||
tools: [calculatorTool]
|
||||
};
|
||||
|
||||
// Collect all text content from all assistant responses
|
||||
let allTextContent = "";
|
||||
let hasSeenThinking = false;
|
||||
let hasSeenToolCalls = false;
|
||||
const maxTurns = 5; // Prevent infinite loops
|
||||
|
||||
for (let turn = 0; turn < maxTurns; turn++) {
|
||||
const response = await llm.generate(context, thinkingOptions);
|
||||
|
||||
// Add the assistant response to context
|
||||
context.messages.push(response);
|
||||
|
||||
// Process content blocks
|
||||
for (const block of response.content) {
|
||||
if (block.type === "text") {
|
||||
allTextContent += block.text;
|
||||
} else if (block.type === "thinking") {
|
||||
hasSeenThinking = true;
|
||||
} else if (block.type === "toolCall") {
|
||||
hasSeenToolCalls = true;
|
||||
|
||||
// Process the tool call
|
||||
expect(block.name).toBe("calculator");
|
||||
expect(block.id).toBeTruthy();
|
||||
expect(block.arguments).toBeTruthy();
|
||||
|
||||
const { a, b, operation } = block.arguments;
|
||||
let result: number;
|
||||
switch (operation) {
|
||||
case "add": result = a + b; break;
|
||||
case "multiply": result = a * b; break;
|
||||
default: result = 0;
|
||||
}
|
||||
|
||||
// Add tool result to context
|
||||
context.messages.push({
|
||||
role: "toolResult",
|
||||
toolCallId: block.id,
|
||||
toolName: block.name,
|
||||
content: `${result}`,
|
||||
isError: false
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// If we got a stop response with text content, we're likely done
|
||||
expect(response.stopReason).not.toBe("error");
|
||||
if (response.stopReason === "stop") {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Verify we got either thinking content or tool calls (or both)
|
||||
expect(hasSeenThinking || hasSeenToolCalls).toBe(true);
|
||||
|
||||
// The accumulated text should reference both calculations
|
||||
expect(allTextContent).toBeTruthy();
|
||||
expect(allTextContent.includes("714")).toBe(true);
|
||||
expect(allTextContent.includes("887")).toBe(true);
|
||||
}
|
||||
|
||||
describe("AI Providers E2E Tests", () => {
|
||||
describe.skipIf(!process.env.GEMINI_API_KEY)("Gemini Provider (gemini-2.5-flash)", () => {
|
||||
let llm: GoogleLLM;
|
||||
|
||||
beforeAll(() => {
|
||||
llm = new GoogleLLM(getModel("google", "gemini-2.5-flash")!, process.env.GEMINI_API_KEY!);
|
||||
});
|
||||
|
||||
it("should complete basic text generation", async () => {
|
||||
await basicTextGeneration(llm);
|
||||
});
|
||||
|
||||
it("should handle tool calling", async () => {
|
||||
await handleToolCall(llm);
|
||||
});
|
||||
|
||||
it("should handle streaming", async () => {
|
||||
await handleStreaming(llm);
|
||||
});
|
||||
|
||||
it("should handle thinking mode", async () => {
|
||||
await handleThinking(llm, {thinking: { enabled: true, budgetTokens: 1024 }});
|
||||
});
|
||||
|
||||
it("should handle multi-turn with thinking and tools", async () => {
|
||||
await multiTurn(llm, {thinking: { enabled: true, budgetTokens: 2048 }});
|
||||
});
|
||||
|
||||
it("should handle image input", async () => {
|
||||
await handleImage(llm);
|
||||
});
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Completions Provider (gpt-4o-mini)", () => {
|
||||
let llm: OpenAICompletionsLLM;
|
||||
|
||||
beforeAll(() => {
|
||||
llm = new OpenAICompletionsLLM(getModel("openai", "gpt-4o-mini")!, process.env.OPENAI_API_KEY!);
|
||||
});
|
||||
|
||||
it("should complete basic text generation", async () => {
|
||||
await basicTextGeneration(llm);
|
||||
});
|
||||
|
||||
it("should handle tool calling", async () => {
|
||||
await handleToolCall(llm);
|
||||
});
|
||||
|
||||
it("should handle streaming", async () => {
|
||||
await handleStreaming(llm);
|
||||
});
|
||||
|
||||
it("should handle image input", async () => {
|
||||
await handleImage(llm);
|
||||
});
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Responses Provider (gpt-5-mini)", () => {
|
||||
let llm: OpenAIResponsesLLM;
|
||||
|
||||
beforeAll(() => {
|
||||
llm = new OpenAIResponsesLLM(getModel("openai", "gpt-5-mini")!, process.env.OPENAI_API_KEY!);
|
||||
});
|
||||
|
||||
it("should complete basic text generation", async () => {
|
||||
await basicTextGeneration(llm);
|
||||
});
|
||||
|
||||
it("should handle tool calling", async () => {
|
||||
await handleToolCall(llm);
|
||||
});
|
||||
|
||||
it("should handle streaming", async () => {
|
||||
await handleStreaming(llm);
|
||||
});
|
||||
|
||||
it("should handle thinking mode", {retry: 2}, async () => {
|
||||
await handleThinking(llm, {reasoningEffort: "high"});
|
||||
});
|
||||
|
||||
it("should handle multi-turn with thinking and tools", async () => {
|
||||
await multiTurn(llm, {reasoningEffort: "high"});
|
||||
});
|
||||
|
||||
it("should handle image input", async () => {
|
||||
await handleImage(llm);
|
||||
});
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.ANTHROPIC_OAUTH_TOKEN)("Anthropic Provider (claude-sonnet-4-20250514)", () => {
|
||||
let llm: AnthropicLLM;
|
||||
|
||||
beforeAll(() => {
|
||||
llm = new AnthropicLLM(getModel("anthropic", "claude-sonnet-4-20250514")!, process.env.ANTHROPIC_OAUTH_TOKEN!);
|
||||
});
|
||||
|
||||
it("should complete basic text generation", async () => {
|
||||
await basicTextGeneration(llm);
|
||||
});
|
||||
|
||||
it("should handle tool calling", async () => {
|
||||
await handleToolCall(llm);
|
||||
});
|
||||
|
||||
it("should handle streaming", async () => {
|
||||
await handleStreaming(llm);
|
||||
});
|
||||
|
||||
it("should handle thinking mode", async () => {
|
||||
await handleThinking(llm, {thinking: { enabled: true } });
|
||||
});
|
||||
|
||||
it("should handle multi-turn with thinking and tools", async () => {
|
||||
await multiTurn(llm, {thinking: { enabled: true, budgetTokens: 2048 }});
|
||||
});
|
||||
|
||||
it("should handle image input", async () => {
|
||||
await handleImage(llm);
|
||||
});
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.XAI_API_KEY)("xAI Provider (grok-code-fast-1 via OpenAI Completions)", () => {
|
||||
let llm: OpenAICompletionsLLM;
|
||||
|
||||
beforeAll(() => {
|
||||
llm = new OpenAICompletionsLLM(getModel("xai", "grok-code-fast-1")!, process.env.XAI_API_KEY!);
|
||||
});
|
||||
|
||||
it("should complete basic text generation", async () => {
|
||||
await basicTextGeneration(llm);
|
||||
});
|
||||
|
||||
it("should handle tool calling", async () => {
|
||||
await handleToolCall(llm);
|
||||
});
|
||||
|
||||
it("should handle streaming", async () => {
|
||||
await handleStreaming(llm);
|
||||
});
|
||||
|
||||
it("should handle thinking mode", async () => {
|
||||
await handleThinking(llm, {reasoningEffort: "medium"});
|
||||
});
|
||||
|
||||
it("should handle multi-turn with thinking and tools", async () => {
|
||||
await multiTurn(llm, {reasoningEffort: "medium"});
|
||||
});
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.GROQ_API_KEY)("Groq Provider (gpt-oss-20b via OpenAI Completions)", () => {
|
||||
let llm: OpenAICompletionsLLM;
|
||||
|
||||
beforeAll(() => {
|
||||
llm = new OpenAICompletionsLLM(getModel("groq", "openai/gpt-oss-20b")!, process.env.GROQ_API_KEY!);
|
||||
});
|
||||
|
||||
it("should complete basic text generation", async () => {
|
||||
await basicTextGeneration(llm);
|
||||
});
|
||||
|
||||
it("should handle tool calling", async () => {
|
||||
await handleToolCall(llm);
|
||||
});
|
||||
|
||||
it("should handle streaming", async () => {
|
||||
await handleStreaming(llm);
|
||||
});
|
||||
|
||||
it("should handle thinking mode", async () => {
|
||||
await handleThinking(llm, {reasoningEffort: "medium"});
|
||||
});
|
||||
|
||||
it("should handle multi-turn with thinking and tools", async () => {
|
||||
await multiTurn(llm, {reasoningEffort: "medium"});
|
||||
});
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.CEREBRAS_API_KEY)("Cerebras Provider (gpt-oss-120b via OpenAI Completions)", () => {
|
||||
let llm: OpenAICompletionsLLM;
|
||||
|
||||
beforeAll(() => {
|
||||
llm = new OpenAICompletionsLLM(getModel("cerebras", "gpt-oss-120b")!, process.env.CEREBRAS_API_KEY!);
|
||||
});
|
||||
|
||||
it("should complete basic text generation", async () => {
|
||||
await basicTextGeneration(llm);
|
||||
});
|
||||
|
||||
it("should handle tool calling", async () => {
|
||||
await handleToolCall(llm);
|
||||
});
|
||||
|
||||
it("should handle streaming", async () => {
|
||||
await handleStreaming(llm);
|
||||
});
|
||||
|
||||
it("should handle thinking mode", async () => {
|
||||
await handleThinking(llm, {reasoningEffort: "medium"});
|
||||
});
|
||||
|
||||
it("should handle multi-turn with thinking and tools", async () => {
|
||||
await multiTurn(llm, {reasoningEffort: "medium"});
|
||||
});
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.OPENROUTER_API_KEY)("OpenRouter Provider (glm-4.5v via OpenAI Completions)", () => {
|
||||
let llm: OpenAICompletionsLLM;
|
||||
|
||||
beforeAll(() => {
|
||||
llm = new OpenAICompletionsLLM(getModel("openrouter", "z-ai/glm-4.5v")!, process.env.OPENROUTER_API_KEY!);;
|
||||
});
|
||||
|
||||
it("should complete basic text generation", async () => {
|
||||
await basicTextGeneration(llm);
|
||||
});
|
||||
|
||||
it("should handle tool calling", async () => {
|
||||
await handleToolCall(llm);
|
||||
});
|
||||
|
||||
it("should handle streaming", async () => {
|
||||
await handleStreaming(llm);
|
||||
});
|
||||
|
||||
it("should handle thinking mode", async () => {
|
||||
await handleThinking(llm, {reasoningEffort: "medium"});
|
||||
});
|
||||
|
||||
it("should handle multi-turn with thinking and tools", { retry: 2 }, async () => {
|
||||
await multiTurn(llm, {reasoningEffort: "medium"});
|
||||
});
|
||||
|
||||
it("should handle image input", async () => {
|
||||
await handleImage(llm);
|
||||
});
|
||||
});
|
||||
|
||||
// Check if ollama is installed
|
||||
let ollamaInstalled = false;
|
||||
try {
|
||||
execSync("which ollama", { stdio: "ignore" });
|
||||
ollamaInstalled = true;
|
||||
} catch {
|
||||
ollamaInstalled = false;
|
||||
}
|
||||
|
||||
describe.skipIf(!ollamaInstalled)("Ollama Provider (gpt-oss-20b via OpenAI Completions)", () => {
|
||||
let llm: OpenAICompletionsLLM;
|
||||
let ollamaProcess: ChildProcess | null = null;
|
||||
|
||||
beforeAll(async () => {
|
||||
// Check if model is available, if not pull it
|
||||
try {
|
||||
execSync("ollama list | grep -q 'gpt-oss:20b'", { stdio: "ignore" });
|
||||
} catch {
|
||||
console.log("Pulling gpt-oss:20b model for Ollama tests...");
|
||||
try {
|
||||
execSync("ollama pull gpt-oss:20b", { stdio: "inherit" });
|
||||
} catch (e) {
|
||||
console.warn("Failed to pull gpt-oss:20b model, tests will be skipped");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Start ollama server
|
||||
ollamaProcess = spawn("ollama", ["serve"], {
|
||||
detached: false,
|
||||
stdio: "ignore"
|
||||
});
|
||||
|
||||
// Wait for server to be ready
|
||||
await new Promise<void>((resolve) => {
|
||||
const checkServer = async () => {
|
||||
try {
|
||||
const response = await fetch("http://localhost:11434/api/tags");
|
||||
if (response.ok) {
|
||||
resolve();
|
||||
} else {
|
||||
setTimeout(checkServer, 500);
|
||||
}
|
||||
} catch {
|
||||
setTimeout(checkServer, 500);
|
||||
}
|
||||
};
|
||||
setTimeout(checkServer, 1000); // Initial delay
|
||||
});
|
||||
|
||||
const model: Model = {
|
||||
id: "gpt-oss:20b",
|
||||
provider: "ollama",
|
||||
baseUrl: "http://localhost:11434/v1",
|
||||
reasoning: true,
|
||||
input: ["text"],
|
||||
contextWindow: 128000,
|
||||
maxTokens: 16000,
|
||||
cost: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
},
|
||||
name: "Ollama GPT-OSS 20B"
|
||||
}
|
||||
llm = new OpenAICompletionsLLM(model, "dummy");
|
||||
}, 30000); // 30 second timeout for setup
|
||||
|
||||
afterAll(() => {
|
||||
// Kill ollama server
|
||||
if (ollamaProcess) {
|
||||
ollamaProcess.kill("SIGTERM");
|
||||
ollamaProcess = null;
|
||||
}
|
||||
});
|
||||
|
||||
it("should complete basic text generation", async () => {
|
||||
await basicTextGeneration(llm);
|
||||
});
|
||||
|
||||
it("should handle tool calling", async () => {
|
||||
await handleToolCall(llm);
|
||||
});
|
||||
|
||||
it("should handle streaming", async () => {
|
||||
await handleStreaming(llm);
|
||||
});
|
||||
|
||||
it("should handle thinking mode", async () => {
|
||||
await handleThinking(llm, {reasoningEffort: "medium"});
|
||||
});
|
||||
|
||||
it("should handle multi-turn with thinking and tools", async () => {
|
||||
await multiTurn(llm, {reasoningEffort: "medium"});
|
||||
});
|
||||
});
|
||||
|
||||
/*
|
||||
describe.skipIf(!process.env.ANTHROPIC_API_KEY)("Anthropic Provider (Haiku 3.5)", () => {
|
||||
let llm: AnthropicLLM;
|
||||
|
||||
beforeAll(() => {
|
||||
llm = createLLM("anthropic", "claude-3-5-haiku-latest");
|
||||
});
|
||||
|
||||
it("should complete basic text generation", async () => {
|
||||
await basicTextGeneration(llm);
|
||||
});
|
||||
|
||||
it("should handle tool calling", async () => {
|
||||
await handleToolCall(llm);
|
||||
});
|
||||
|
||||
it("should handle streaming", async () => {
|
||||
await handleStreaming(llm);
|
||||
});
|
||||
|
||||
it("should handle multi-turn with thinking and tools", async () => {
|
||||
await multiTurn(llm, {thinking: {enabled: true}});
|
||||
});
|
||||
|
||||
it("should handle image input", async () => {
|
||||
await handleImage(llm);
|
||||
});
|
||||
});
|
||||
*/
|
||||
});
|
||||
Loading…
Add table
Add a link
Reference in a new issue