Massive refactor of API

- Switch to function based API
- Anthropic SDK style async generator
- Fully typed with escape hatches for custom models
This commit is contained in:
Mario Zechner 2025-09-02 23:59:36 +02:00
parent 004de3c9d0
commit 66cefb236e
29 changed files with 5835 additions and 6225 deletions

View file

@ -28,6 +28,6 @@
"lineWidth": 120 "lineWidth": 120
}, },
"files": { "files": {
"includes": ["packages/*/src/**/*", "*.json", "*.md"] "includes": ["packages/*/src/**/*", "packages/*/test/**/*", "*.json", "*.md"]
} }
} }

View file

@ -3,6 +3,7 @@
import { writeFileSync } from "fs"; import { writeFileSync } from "fs";
import { join, dirname } from "path"; import { join, dirname } from "path";
import { fileURLToPath } from "url"; import { fileURLToPath } from "url";
import { Api, KnownProvider, Model } from "../src/types.js";
const __filename = fileURLToPath(import.meta.url); const __filename = fileURLToPath(import.meta.url);
const __dirname = dirname(__filename); const __dirname = dirname(__filename);
@ -28,30 +29,13 @@ interface ModelsDevModel {
}; };
} }
interface NormalizedModel { async function fetchOpenRouterModels(): Promise<Model<any>[]> {
id: string;
name: string;
provider: string;
baseUrl?: string;
reasoning: boolean;
input: ("text" | "image")[];
cost: {
input: number;
output: number;
cacheRead: number;
cacheWrite: number;
};
contextWindow: number;
maxTokens: number;
}
async function fetchOpenRouterModels(): Promise<NormalizedModel[]> {
try { try {
console.log("Fetching models from OpenRouter API..."); console.log("Fetching models from OpenRouter API...");
const response = await fetch("https://openrouter.ai/api/v1/models"); const response = await fetch("https://openrouter.ai/api/v1/models");
const data = await response.json(); const data = await response.json();
const models: NormalizedModel[] = []; const models: Model<any>[] = [];
for (const model of data.data) { for (const model of data.data) {
// Only include models that support tools // Only include models that support tools
@ -59,27 +43,17 @@ async function fetchOpenRouterModels(): Promise<NormalizedModel[]> {
// Parse provider from model ID // Parse provider from model ID
const [providerPrefix] = model.id.split("/"); const [providerPrefix] = model.id.split("/");
let provider = ""; let provider: KnownProvider = "openrouter";
let modelKey = model.id; let modelKey = model.id;
// Skip models that we get from models.dev (Anthropic, Google, OpenAI) // Skip models that we get from models.dev (Anthropic, Google, OpenAI)
if (model.id.startsWith("google/") || if (model.id.startsWith("google/") ||
model.id.startsWith("openai/") || model.id.startsWith("openai/") ||
model.id.startsWith("anthropic/")) { model.id.startsWith("anthropic/") ||
continue; model.id.startsWith("x-ai/")) {
} else if (model.id.startsWith("x-ai/")) {
provider = "xai";
modelKey = model.id.replace("x-ai/", "");
} else {
// All other models go through OpenRouter
provider = "openrouter";
modelKey = model.id; // Keep full ID for OpenRouter
}
// Skip if not one of our supported providers from OpenRouter
if (!["xai", "openrouter"].includes(provider)) {
continue; continue;
} }
modelKey = model.id; // Keep full ID for OpenRouter
// Parse input modalities // Parse input modalities
const input: ("text" | "image")[] = ["text"]; const input: ("text" | "image")[] = ["text"];
@ -93,9 +67,11 @@ async function fetchOpenRouterModels(): Promise<NormalizedModel[]> {
const cacheReadCost = parseFloat(model.pricing?.input_cache_read || "0") * 1_000_000; const cacheReadCost = parseFloat(model.pricing?.input_cache_read || "0") * 1_000_000;
const cacheWriteCost = parseFloat(model.pricing?.input_cache_write || "0") * 1_000_000; const cacheWriteCost = parseFloat(model.pricing?.input_cache_write || "0") * 1_000_000;
const normalizedModel: NormalizedModel = { const normalizedModel: Model<any> = {
id: modelKey, id: modelKey,
name: model.name, name: model.name,
api: "openai-completions",
baseUrl: "https://openrouter.ai/api/v1",
provider, provider,
reasoning: model.supported_parameters?.includes("reasoning") || false, reasoning: model.supported_parameters?.includes("reasoning") || false,
input, input,
@ -108,14 +84,6 @@ async function fetchOpenRouterModels(): Promise<NormalizedModel[]> {
contextWindow: model.context_length || 4096, contextWindow: model.context_length || 4096,
maxTokens: model.top_provider?.max_completion_tokens || 4096, maxTokens: model.top_provider?.max_completion_tokens || 4096,
}; };
// Add baseUrl for providers that need it
if (provider === "xai") {
normalizedModel.baseUrl = "https://api.x.ai/v1";
} else if (provider === "openrouter") {
normalizedModel.baseUrl = "https://openrouter.ai/api/v1";
}
models.push(normalizedModel); models.push(normalizedModel);
} }
@ -127,13 +95,13 @@ async function fetchOpenRouterModels(): Promise<NormalizedModel[]> {
} }
} }
async function loadModelsDevData(): Promise<NormalizedModel[]> { async function loadModelsDevData(): Promise<Model<any>[]> {
try { try {
console.log("Fetching models from models.dev API..."); console.log("Fetching models from models.dev API...");
const response = await fetch("https://models.dev/api.json"); const response = await fetch("https://models.dev/api.json");
const data = await response.json(); const data = await response.json();
const models: NormalizedModel[] = []; const models: Model<any>[] = [];
// Process Anthropic models // Process Anthropic models
if (data.anthropic?.models) { if (data.anthropic?.models) {
@ -144,7 +112,9 @@ async function loadModelsDevData(): Promise<NormalizedModel[]> {
models.push({ models.push({
id: modelId, id: modelId,
name: m.name || modelId, name: m.name || modelId,
api: "anthropic-messages",
provider: "anthropic", provider: "anthropic",
baseUrl: "https://api.anthropic.com",
reasoning: m.reasoning === true, reasoning: m.reasoning === true,
input: m.modalities?.input?.includes("image") ? ["text", "image"] : ["text"], input: m.modalities?.input?.includes("image") ? ["text", "image"] : ["text"],
cost: { cost: {
@ -168,7 +138,9 @@ async function loadModelsDevData(): Promise<NormalizedModel[]> {
models.push({ models.push({
id: modelId, id: modelId,
name: m.name || modelId, name: m.name || modelId,
api: "google-generative-ai",
provider: "google", provider: "google",
baseUrl: "https://generativelanguage.googleapis.com/v1beta",
reasoning: m.reasoning === true, reasoning: m.reasoning === true,
input: m.modalities?.input?.includes("image") ? ["text", "image"] : ["text"], input: m.modalities?.input?.includes("image") ? ["text", "image"] : ["text"],
cost: { cost: {
@ -192,7 +164,9 @@ async function loadModelsDevData(): Promise<NormalizedModel[]> {
models.push({ models.push({
id: modelId, id: modelId,
name: m.name || modelId, name: m.name || modelId,
api: "openai-responses",
provider: "openai", provider: "openai",
baseUrl: "https://api.openai.com/v1",
reasoning: m.reasoning === true, reasoning: m.reasoning === true,
input: m.modalities?.input?.includes("image") ? ["text", "image"] : ["text"], input: m.modalities?.input?.includes("image") ? ["text", "image"] : ["text"],
cost: { cost: {
@ -216,6 +190,7 @@ async function loadModelsDevData(): Promise<NormalizedModel[]> {
models.push({ models.push({
id: modelId, id: modelId,
name: m.name || modelId, name: m.name || modelId,
api: "openai-completions",
provider: "groq", provider: "groq",
baseUrl: "https://api.groq.com/openai/v1", baseUrl: "https://api.groq.com/openai/v1",
reasoning: m.reasoning === true, reasoning: m.reasoning === true,
@ -241,6 +216,7 @@ async function loadModelsDevData(): Promise<NormalizedModel[]> {
models.push({ models.push({
id: modelId, id: modelId,
name: m.name || modelId, name: m.name || modelId,
api: "openai-completions",
provider: "cerebras", provider: "cerebras",
baseUrl: "https://api.cerebras.ai/v1", baseUrl: "https://api.cerebras.ai/v1",
reasoning: m.reasoning === true, reasoning: m.reasoning === true,
@ -257,6 +233,32 @@ async function loadModelsDevData(): Promise<NormalizedModel[]> {
} }
} }
// Process xAi models
if (data.xai?.models) {
for (const [modelId, model] of Object.entries(data.xai.models)) {
const m = model as ModelsDevModel;
if (m.tool_call !== true) continue;
models.push({
id: modelId,
name: m.name || modelId,
api: "openai-completions",
provider: "xai",
baseUrl: "https://api.x.ai/v1",
reasoning: m.reasoning === true,
input: m.modalities?.input?.includes("image") ? ["text", "image"] : ["text"],
cost: {
input: m.cost?.input || 0,
output: m.cost?.output || 0,
cacheRead: m.cost?.cache_read || 0,
cacheWrite: m.cost?.cache_write || 0,
},
contextWindow: m.limit?.context || 4096,
maxTokens: m.limit?.output || 4096,
});
}
}
console.log(`Loaded ${models.length} tool-capable models from models.dev`); console.log(`Loaded ${models.length} tool-capable models from models.dev`);
return models; return models;
} catch (error) { } catch (error) {
@ -280,6 +282,8 @@ async function generateModels() {
allModels.push({ allModels.push({
id: "gpt-5-chat-latest", id: "gpt-5-chat-latest",
name: "GPT-5 Chat Latest", name: "GPT-5 Chat Latest",
api: "openai-responses",
baseUrl: "https://api.openai.com/v1",
provider: "openai", provider: "openai",
reasoning: false, reasoning: false,
input: ["text", "image"], input: ["text", "image"],
@ -294,8 +298,29 @@ async function generateModels() {
}); });
} }
// Add missing Grok models
if (!allModels.some(m => m.provider === "xai" && m.id === "grok-code-fast-1")) {
allModels.push({
id: "grok-code-fast-1",
name: "Grok Code Fast 1",
api: "openai-completions",
baseUrl: "https://api.x.ai/v1",
provider: "xai",
reasoning: false,
input: ["text"],
cost: {
input: 0.2,
output: 1.5,
cacheRead: 0.02,
cacheWrite: 0,
},
contextWindow: 32768,
maxTokens: 8192,
});
}
// Group by provider and deduplicate by model ID // Group by provider and deduplicate by model ID
const providers: Record<string, Record<string, NormalizedModel>> = {}; const providers: Record<string, Record<string, Model<any>>> = {};
for (const model of allModels) { for (const model of allModels) {
if (!providers[model.provider]) { if (!providers[model.provider]) {
providers[model.provider] = {}; providers[model.provider] = {};
@ -319,39 +344,33 @@ export const PROVIDERS = {
// Generate provider sections // Generate provider sections
for (const [providerId, models] of Object.entries(providers)) { for (const [providerId, models] of Object.entries(providers)) {
output += `\t${providerId}: {\n`; output += `\t${providerId}: {\n`;
output += `\t\tmodels: {\n`;
for (const model of Object.values(models)) { for (const model of Object.values(models)) {
output += `\t\t\t"${model.id}": {\n`; output += `\t\t"${model.id}": {\n`;
output += `\t\t\t\tid: "${model.id}",\n`; output += `\t\t\tid: "${model.id}",\n`;
output += `\t\t\t\tname: "${model.name}",\n`; output += `\t\t\tname: "${model.name}",\n`;
output += `\t\t\t\tprovider: "${model.provider}",\n`; output += `\t\t\tapi: "${model.api}",\n`;
output += `\t\t\tprovider: "${model.provider}",\n`;
if (model.baseUrl) { if (model.baseUrl) {
output += `\t\t\t\tbaseUrl: "${model.baseUrl}",\n`; output += `\t\t\tbaseUrl: "${model.baseUrl}",\n`;
} }
output += `\t\t\t\treasoning: ${model.reasoning},\n`; output += `\t\t\treasoning: ${model.reasoning},\n`;
output += `\t\t\t\tinput: ${JSON.stringify(model.input)},\n`; output += `\t\t\tinput: [${model.input.map(i => `"${i}"`).join(", ")}],\n`;
output += `\t\t\t\tcost: {\n`; output += `\t\t\tcost: {\n`;
output += `\t\t\t\t\tinput: ${model.cost.input},\n`; output += `\t\t\t\tinput: ${model.cost.input},\n`;
output += `\t\t\t\t\toutput: ${model.cost.output},\n`; output += `\t\t\t\toutput: ${model.cost.output},\n`;
output += `\t\t\t\t\tcacheRead: ${model.cost.cacheRead},\n`; output += `\t\t\t\tcacheRead: ${model.cost.cacheRead},\n`;
output += `\t\t\t\t\tcacheWrite: ${model.cost.cacheWrite},\n`; output += `\t\t\t\tcacheWrite: ${model.cost.cacheWrite},\n`;
output += `\t\t\t\t},\n`; output += `\t\t\t},\n`;
output += `\t\t\t\tcontextWindow: ${model.contextWindow},\n`; output += `\t\t\tcontextWindow: ${model.contextWindow},\n`;
output += `\t\t\t\tmaxTokens: ${model.maxTokens},\n`; output += `\t\t\tmaxTokens: ${model.maxTokens},\n`;
output += `\t\t\t} satisfies Model,\n`; output += `\t\t} satisfies Model<"${model.api}">,\n`;
} }
output += `\t\t}\n`;
output += `\t},\n`; output += `\t},\n`;
} }
output += `} as const; output += `} as const;
// Helper type to extract models for each provider
export type ProviderModels = {
[K in keyof typeof PROVIDERS]: typeof PROVIDERS[K]["models"]
};
`; `;
// Write file // Write file

View file

@ -1,47 +1,43 @@
import { type AnthropicOptions, streamAnthropic } from "./providers/anthropic.js";
import { type GoogleOptions, streamGoogle } from "./providers/google.js";
import { type OpenAICompletionsOptions, streamOpenAICompletions } from "./providers/openai-completions.js";
import { type OpenAIResponsesOptions, streamOpenAIResponses } from "./providers/openai-responses.js";
import type { import type {
Api, Api,
AssistantMessage, AssistantMessage,
AssistantMessageEvent, AssistantMessageEvent,
Context, Context,
GenerateFunction,
GenerateOptionsUnified,
GenerateStream, GenerateStream,
KnownProvider, KnownProvider,
Model, Model,
OptionsForApi,
ReasoningEffort, ReasoningEffort,
SimpleGenerateOptions,
} from "./types.js"; } from "./types.js";
export class QueuedGenerateStream implements GenerateStream { export class QueuedGenerateStream implements GenerateStream {
private queue: AssistantMessageEvent[] = []; private queue: AssistantMessageEvent[] = [];
private waiting: ((value: IteratorResult<AssistantMessageEvent>) => void)[] = []; private waiting: ((value: IteratorResult<AssistantMessageEvent>) => void)[] = [];
private done = false; private done = false;
private error?: Error;
private finalMessagePromise: Promise<AssistantMessage>; private finalMessagePromise: Promise<AssistantMessage>;
private resolveFinalMessage!: (message: AssistantMessage) => void; private resolveFinalMessage!: (message: AssistantMessage) => void;
private rejectFinalMessage!: (error: Error) => void;
constructor() { constructor() {
this.finalMessagePromise = new Promise((resolve, reject) => { this.finalMessagePromise = new Promise((resolve) => {
this.resolveFinalMessage = resolve; this.resolveFinalMessage = resolve;
this.rejectFinalMessage = reject;
}); });
} }
push(event: AssistantMessageEvent): void { push(event: AssistantMessageEvent): void {
if (this.done) return; if (this.done) return;
// If it's the done event, resolve the final message
if (event.type === "done") { if (event.type === "done") {
this.done = true; this.done = true;
this.resolveFinalMessage(event.message); this.resolveFinalMessage(event.message);
} }
// If it's an error event, reject the final message
if (event.type === "error") { if (event.type === "error") {
this.error = new Error(event.error); this.done = true;
if (!this.done) { this.resolveFinalMessage(event.partial);
this.rejectFinalMessage(this.error);
}
} }
// Deliver to waiting consumer or queue it // Deliver to waiting consumer or queue it
@ -86,31 +82,14 @@ export class QueuedGenerateStream implements GenerateStream {
} }
} }
// API implementations registry
const apiImplementations: Map<Api | string, GenerateFunction> = new Map();
/**
* Register a custom API implementation
*/
export function registerApi(api: string, impl: GenerateFunction): void {
apiImplementations.set(api, impl);
}
// API key storage
const apiKeys: Map<string, string> = new Map(); const apiKeys: Map<string, string> = new Map();
/**
* Set an API key for a provider
*/
export function setApiKey(provider: KnownProvider, key: string): void; export function setApiKey(provider: KnownProvider, key: string): void;
export function setApiKey(provider: string, key: string): void; export function setApiKey(provider: string, key: string): void;
export function setApiKey(provider: any, key: string): void { export function setApiKey(provider: any, key: string): void {
apiKeys.set(provider, key); apiKeys.set(provider, key);
} }
/**
* Get API key for a provider
*/
export function getApiKey(provider: KnownProvider): string | undefined; export function getApiKey(provider: KnownProvider): string | undefined;
export function getApiKey(provider: string): string | undefined; export function getApiKey(provider: string): string | undefined;
export function getApiKey(provider: any): string | undefined { export function getApiKey(provider: any): string | undefined {
@ -133,45 +112,76 @@ export function getApiKey(provider: any): string | undefined {
return envVar ? process.env[envVar] : undefined; return envVar ? process.env[envVar] : undefined;
} }
/** export function stream<TApi extends Api>(
* Main generate function model: Model<TApi>,
*/ context: Context,
export function generate(model: Model, context: Context, options?: GenerateOptionsUnified): GenerateStream { options?: OptionsForApi<TApi>,
// Get implementation ): GenerateStream {
const impl = apiImplementations.get(model.api); const apiKey = options?.apiKey || getApiKey(model.provider);
if (!impl) { if (!apiKey) {
throw new Error(`Unsupported API: ${model.api}`); throw new Error(`No API key for provider: ${model.provider}`);
} }
const providerOptions = { ...options, apiKey };
// Get API key from options or environment const api: Api = model.api;
switch (api) {
case "anthropic-messages":
return streamAnthropic(model as Model<"anthropic-messages">, context, providerOptions);
case "openai-completions":
return streamOpenAICompletions(model as Model<"openai-completions">, context, providerOptions as any);
case "openai-responses":
return streamOpenAIResponses(model as Model<"openai-responses">, context, providerOptions as any);
case "google-generative-ai":
return streamGoogle(model as Model<"google-generative-ai">, context, providerOptions);
default: {
// This should never be reached if all Api cases are handled
const _exhaustive: never = api;
throw new Error(`Unhandled API: ${_exhaustive}`);
}
}
}
export async function complete<TApi extends Api>(
model: Model<TApi>,
context: Context,
options?: OptionsForApi<TApi>,
): Promise<AssistantMessage> {
const s = stream(model, context, options);
return s.finalMessage();
}
export function streamSimple<TApi extends Api>(
model: Model<TApi>,
context: Context,
options?: SimpleGenerateOptions,
): GenerateStream {
const apiKey = options?.apiKey || getApiKey(model.provider); const apiKey = options?.apiKey || getApiKey(model.provider);
if (!apiKey) { if (!apiKey) {
throw new Error(`No API key for provider: ${model.provider}`); throw new Error(`No API key for provider: ${model.provider}`);
} }
// Map generic options to provider-specific const providerOptions = mapOptionsForApi(model, options, apiKey);
const providerOptions = mapOptionsForApi(model.api, model, options, apiKey); return stream(model, context, providerOptions);
// Return the GenerateStream from implementation
return impl(model, context, providerOptions);
} }
/** export async function completeSimple<TApi extends Api>(
* Helper to generate and get complete response (no streaming) model: Model<TApi>,
*/
export async function generateComplete(
model: Model,
context: Context, context: Context,
options?: GenerateOptionsUnified, options?: SimpleGenerateOptions,
): Promise<AssistantMessage> { ): Promise<AssistantMessage> {
const stream = generate(model, context, options); const s = streamSimple(model, context, options);
return stream.finalMessage(); return s.finalMessage();
} }
/** function mapOptionsForApi<TApi extends Api>(
* Map generic options to provider-specific options model: Model<TApi>,
*/ options?: SimpleGenerateOptions,
function mapOptionsForApi(api: Api | string, model: Model, options?: GenerateOptionsUnified, apiKey?: string): any { apiKey?: string,
): OptionsForApi<TApi> {
const base = { const base = {
temperature: options?.temperature, temperature: options?.temperature,
maxTokens: options?.maxTokens, maxTokens: options?.maxTokens,
@ -179,18 +189,10 @@ function mapOptionsForApi(api: Api | string, model: Model, options?: GenerateOpt
apiKey: apiKey || options?.apiKey, apiKey: apiKey || options?.apiKey,
}; };
switch (api) { switch (model.api) {
case "openai-responses":
case "openai-completions":
return {
...base,
reasoning_effort: options?.reasoning,
};
case "anthropic-messages": { case "anthropic-messages": {
if (!options?.reasoning) return base; if (!options?.reasoning) return base satisfies AnthropicOptions;
// Map effort to token budget
const anthropicBudgets = { const anthropicBudgets = {
minimal: 1024, minimal: 1024,
low: 2048, low: 2048,
@ -200,55 +202,60 @@ function mapOptionsForApi(api: Api | string, model: Model, options?: GenerateOpt
return { return {
...base, ...base,
thinking: { thinkingEnabled: true,
enabled: true, thinkingBudgetTokens: anthropicBudgets[options.reasoning],
budgetTokens: anthropicBudgets[options.reasoning], } satisfies AnthropicOptions;
},
};
} }
case "google-generative-ai": {
if (!options?.reasoning) return { ...base, thinking_budget: -1 };
// Model-specific mapping for Google case "openai-completions":
const googleBudget = getGoogleBudget(model, options.reasoning);
return { return {
...base, ...base,
thinking_budget: googleBudget, reasoningEffort: options?.reasoning,
}; } satisfies OpenAICompletionsOptions;
case "openai-responses":
return {
...base,
reasoningEffort: options?.reasoning,
} satisfies OpenAIResponsesOptions;
case "google-generative-ai": {
if (!options?.reasoning) return base as any;
const googleBudget = getGoogleBudget(model as Model<"google-generative-ai">, options.reasoning);
return {
...base,
thinking: {
enabled: true,
budgetTokens: googleBudget,
},
} satisfies GoogleOptions;
}
default: {
// Exhaustiveness check
const _exhaustive: never = model.api;
throw new Error(`Unhandled API in mapOptionsForApi: ${_exhaustive}`);
} }
default:
return base;
} }
} }
/** function getGoogleBudget(model: Model<"google-generative-ai">, effort: ReasoningEffort): number {
* Get Google thinking budget based on model and effort // See https://ai.google.dev/gemini-api/docs/thinking#set-budget
*/ if (model.id.includes("2.5-pro")) {
function getGoogleBudget(model: Model, effort: ReasoningEffort): number {
// Model-specific logic
if (model.id.includes("flash-lite")) {
const budgets = {
minimal: 512,
low: 2048,
medium: 8192,
high: 24576,
};
return budgets[effort];
}
if (model.id.includes("pro")) {
const budgets = { const budgets = {
minimal: 128, minimal: 128,
low: 2048, low: 2048,
medium: 8192, medium: 8192,
high: Math.min(25000, 32768), high: 32768,
}; };
return budgets[effort]; return budgets[effort];
} }
if (model.id.includes("flash")) { if (model.id.includes("2.5-flash")) {
// Covers 2.5-flash-lite as well
const budgets = { const budgets = {
minimal: 0, // Disable thinking minimal: 128,
low: 2048, low: 2048,
medium: 8192, medium: 8192,
high: 24576, high: 24576,
@ -259,10 +266,3 @@ function getGoogleBudget(model: Model, effort: ReasoningEffort): number {
// Unknown model - use dynamic // Unknown model - use dynamic
return -1; return -1;
} }
// Register built-in API implementations
// Import the new function-based implementations
import { generateAnthropic } from "./providers/anthropic-generate.js";
// Register Anthropic implementation
apiImplementations.set("anthropic-messages", generateAnthropic);

View file

@ -1,37 +1,8 @@
// @mariozechner/pi-ai - Unified LLM API with automatic model discovery export * from "./generate.js";
// This package provides a common interface for working with multiple LLM providers export * from "./models.generated.js";
export * from "./models.js";
export const version = "0.5.8"; export * from "./providers/anthropic.js";
export * from "./providers/google.js";
// Export generate API export * from "./providers/openai-completions.js";
export { export * from "./providers/openai-responses.js";
generate, export * from "./types.js";
generateComplete,
getApiKey,
QueuedGenerateStream,
registerApi,
setApiKey,
} from "./generate.js";
// Export generated models data
export { PROVIDERS } from "./models.generated.js";
// Export model utilities
export {
calculateCost,
getModel,
type KnownProvider,
registerModel,
} from "./models.js";
// Legacy providers (to be deprecated)
export { AnthropicLLM } from "./providers/anthropic.js";
export { GoogleLLM } from "./providers/google.js";
export { OpenAICompletionsLLM } from "./providers/openai-completions.js";
export { OpenAIResponsesLLM } from "./providers/openai-responses.js";
// Export types
export type * from "./types.js";
// TODO: Remove these legacy exports once consumers are updated
export function createLLM(): never {
throw new Error("createLLM is deprecated. Use generate() with getModel() instead.");
}

File diff suppressed because it is too large Load diff

View file

@ -1,44 +1,39 @@
import { PROVIDERS } from "./models.generated.js"; import { PROVIDERS } from "./models.generated.js";
import type { KnownProvider, Model, Usage } from "./types.js"; import type { Api, KnownProvider, Model, Usage } from "./types.js";
// Re-export Model type const modelRegistry: Map<string, Map<string, Model<Api>>> = new Map();
export type { KnownProvider, Model } from "./types.js";
// Dynamic model registry initialized from PROVIDERS
const modelRegistry: Map<string, Map<string, Model>> = new Map();
// Initialize registry from PROVIDERS on module load // Initialize registry from PROVIDERS on module load
for (const [provider, models] of Object.entries(PROVIDERS)) { for (const [provider, models] of Object.entries(PROVIDERS)) {
const providerModels = new Map<string, Model>(); const providerModels = new Map<string, Model<Api>>();
for (const [id, model] of Object.entries(models)) { for (const [id, model] of Object.entries(models)) {
providerModels.set(id, model as Model); providerModels.set(id, model as Model<Api>);
} }
modelRegistry.set(provider, providerModels); modelRegistry.set(provider, providerModels);
} }
/** type ModelApi<
* Get a model from the registry - typed overload for known providers TProvider extends KnownProvider,
*/ TModelId extends keyof (typeof PROVIDERS)[TProvider],
export function getModel<P extends KnownProvider>(provider: P, modelId: keyof (typeof PROVIDERS)[P]): Model; > = (typeof PROVIDERS)[TProvider][TModelId] extends { api: infer TApi } ? (TApi extends Api ? TApi : never) : never;
export function getModel(provider: string, modelId: string): Model | undefined;
export function getModel(provider: any, modelId: any): Model | undefined { export function getModel<TProvider extends KnownProvider, TModelId extends keyof (typeof PROVIDERS)[TProvider]>(
return modelRegistry.get(provider)?.get(modelId); provider: TProvider,
modelId: TModelId,
): Model<ModelApi<TProvider, TModelId>>;
export function getModel<TApi extends Api>(provider: string, modelId: string): Model<TApi> | undefined;
export function getModel<TApi extends Api>(provider: any, modelId: any): Model<TApi> | undefined {
return modelRegistry.get(provider)?.get(modelId) as Model<TApi> | undefined;
} }
/** export function registerModel<TApi extends Api>(model: Model<TApi>): void {
* Register a custom model
*/
export function registerModel(model: Model): void {
if (!modelRegistry.has(model.provider)) { if (!modelRegistry.has(model.provider)) {
modelRegistry.set(model.provider, new Map()); modelRegistry.set(model.provider, new Map());
} }
modelRegistry.get(model.provider)!.set(model.id, model); modelRegistry.get(model.provider)!.set(model.id, model);
} }
/** export function calculateCost<TApi extends Api>(model: Model<TApi>, usage: Usage): Usage["cost"] {
* Calculate cost for token usage
*/
export function calculateCost(model: Model, usage: Usage): Usage["cost"] {
usage.cost.input = (model.cost.input / 1000000) * usage.input; usage.cost.input = (model.cost.input / 1000000) * usage.input;
usage.cost.output = (model.cost.output / 1000000) * usage.output; usage.cost.output = (model.cost.output / 1000000) * usage.output;
usage.cost.cacheRead = (model.cost.cacheRead / 1000000) * usage.cacheRead; usage.cost.cacheRead = (model.cost.cacheRead / 1000000) * usage.cacheRead;

View file

@ -1,425 +0,0 @@
import Anthropic from "@anthropic-ai/sdk";
import type {
ContentBlockParam,
MessageCreateParamsStreaming,
MessageParam,
Tool,
} from "@anthropic-ai/sdk/resources/messages.js";
import { QueuedGenerateStream } from "../generate.js";
import { calculateCost } from "../models.js";
import type {
Api,
AssistantMessage,
Context,
GenerateFunction,
GenerateOptions,
GenerateStream,
Message,
Model,
StopReason,
TextContent,
ThinkingContent,
ToolCall,
} from "../types.js";
import { transformMessages } from "./utils.js";
// Anthropic-specific options
export interface AnthropicOptions extends GenerateOptions {
thinking?: {
enabled: boolean;
budgetTokens?: number;
};
toolChoice?: "auto" | "any" | "none" | { type: "tool"; name: string };
}
/**
* Generate function for Anthropic API
*/
export const generateAnthropic: GenerateFunction<AnthropicOptions> = (
model: Model,
context: Context,
options: AnthropicOptions,
): GenerateStream => {
const stream = new QueuedGenerateStream();
// Start async processing
(async () => {
const output: AssistantMessage = {
role: "assistant",
content: [],
api: "anthropic-messages" as Api,
provider: model.provider,
model: model.id,
usage: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
},
stopReason: "stop",
};
try {
// Create Anthropic client
const client = createAnthropicClient(model, options.apiKey!);
// Convert messages
const messages = convertMessages(context.messages, model, "anthropic-messages");
// Build params
const params = buildAnthropicParams(model, context, options, messages, client.isOAuthToken);
// Create Anthropic stream
const anthropicStream = client.client.messages.stream(
{
...params,
stream: true,
},
{
signal: options.signal,
},
);
// Emit start event
stream.push({
type: "start",
partial: output,
});
// Process Anthropic events
let currentBlock: ThinkingContent | TextContent | (ToolCall & { partialJson: string }) | null = null;
for await (const event of anthropicStream) {
if (event.type === "content_block_start") {
if (event.content_block.type === "text") {
currentBlock = {
type: "text",
text: "",
};
output.content.push(currentBlock);
stream.push({ type: "text_start", partial: output });
} else if (event.content_block.type === "thinking") {
currentBlock = {
type: "thinking",
thinking: "",
thinkingSignature: "",
};
output.content.push(currentBlock);
stream.push({ type: "thinking_start", partial: output });
} else if (event.content_block.type === "tool_use") {
// We wait for the full tool use to be streamed
currentBlock = {
type: "toolCall",
id: event.content_block.id,
name: event.content_block.name,
arguments: event.content_block.input as Record<string, any>,
partialJson: "",
};
}
} else if (event.type === "content_block_delta") {
if (event.delta.type === "text_delta") {
if (currentBlock && currentBlock.type === "text") {
currentBlock.text += event.delta.text;
stream.push({
type: "text_delta",
delta: event.delta.text,
partial: output,
});
}
} else if (event.delta.type === "thinking_delta") {
if (currentBlock && currentBlock.type === "thinking") {
currentBlock.thinking += event.delta.thinking;
stream.push({
type: "thinking_delta",
delta: event.delta.thinking,
partial: output,
});
}
} else if (event.delta.type === "input_json_delta") {
if (currentBlock && currentBlock.type === "toolCall") {
currentBlock.partialJson += event.delta.partial_json;
}
} else if (event.delta.type === "signature_delta") {
if (currentBlock && currentBlock.type === "thinking") {
currentBlock.thinkingSignature = currentBlock.thinkingSignature || "";
currentBlock.thinkingSignature += event.delta.signature;
}
}
} else if (event.type === "content_block_stop") {
if (currentBlock) {
if (currentBlock.type === "text") {
stream.push({ type: "text_end", content: currentBlock.text, partial: output });
} else if (currentBlock.type === "thinking") {
stream.push({ type: "thinking_end", content: currentBlock.thinking, partial: output });
} else if (currentBlock.type === "toolCall") {
const finalToolCall: ToolCall = {
type: "toolCall",
id: currentBlock.id,
name: currentBlock.name,
arguments: JSON.parse(currentBlock.partialJson),
};
output.content.push(finalToolCall);
stream.push({ type: "toolCall", toolCall: finalToolCall, partial: output });
}
currentBlock = null;
}
} else if (event.type === "message_delta") {
if (event.delta.stop_reason) {
output.stopReason = mapStopReason(event.delta.stop_reason);
}
output.usage.input += event.usage.input_tokens || 0;
output.usage.output += event.usage.output_tokens || 0;
output.usage.cacheRead += event.usage.cache_read_input_tokens || 0;
output.usage.cacheWrite += event.usage.cache_creation_input_tokens || 0;
calculateCost(model, output.usage);
}
}
// Emit done event with final message
stream.push({ type: "done", reason: output.stopReason, message: output });
stream.end();
} catch (error) {
output.stopReason = "error";
output.error = error instanceof Error ? error.message : JSON.stringify(error);
stream.push({ type: "error", error: output.error, partial: output });
stream.end();
}
})();
return stream;
};
// Helper to create Anthropic client
/** Pairing of a configured Anthropic SDK client with its authentication mode. */
interface AnthropicClientWrapper {
/** Ready-to-use Anthropic SDK client instance. */
client: Anthropic;
/** True when the credential is an OAuth access token ("sk-ant-oat..."); callers must then apply OAuth-specific request rules. */
isOAuthToken: boolean;
}
/**
 * Creates an Anthropic SDK client for the given model and credential.
 *
 * OAuth access tokens (identified by the "sk-ant-oat" marker) must be sent via
 * `authToken` (Bearer auth) rather than `apiKey`, and require the OAuth beta
 * header. Regular API keys use the standard `apiKey` path.
 *
 * @param model - Model descriptor; `baseUrl` (if set) overrides the SDK default endpoint.
 * @param apiKey - Either a regular Anthropic API key or an OAuth access token.
 * @returns The configured client plus a flag indicating whether OAuth rules
 *          (e.g. the mandatory Claude Code system prompt) apply downstream.
 */
function createAnthropicClient(model: Model, apiKey: string): AnthropicClientWrapper {
  if (apiKey.includes("sk-ant-oat")) {
    const defaultHeaders = {
      accept: "application/json",
      "anthropic-dangerous-direct-browser-access": "true",
      "anthropic-beta": "oauth-2025-04-20,fine-grained-tool-streaming-2025-05-14",
    };
    // Remove the env var in Node.js so the SDK cannot pick it up and clash
    // with the OAuth token. NOTE: assigning `undefined` is wrong here —
    // process.env values are coerced to strings, so that would leave the
    // literal string "undefined" as the key; the property must be deleted.
    if (typeof process !== "undefined" && process.env) {
      delete process.env.ANTHROPIC_API_KEY;
    }
    const client = new Anthropic({
      apiKey: null,
      authToken: apiKey,
      baseURL: model.baseUrl,
      defaultHeaders,
      dangerouslyAllowBrowser: true,
    });
    return { client, isOAuthToken: true };
  } else {
    const defaultHeaders = {
      accept: "application/json",
      "anthropic-dangerous-direct-browser-access": "true",
      "anthropic-beta": "fine-grained-tool-streaming-2025-05-14",
    };
    const client = new Anthropic({
      apiKey,
      baseURL: model.baseUrl,
      dangerouslyAllowBrowser: true,
      defaultHeaders,
    });
    return { client, isOAuthToken: false };
  }
}
// Build Anthropic API params
/**
 * Assembles the streaming request parameters for the Anthropic Messages API.
 *
 * @param model - Model descriptor supplying id, token limit and reasoning capability.
 * @param context - Conversation context (system prompt, tools).
 * @param options - Per-request overrides (max tokens, temperature, thinking, tool choice).
 * @param messages - Messages already converted to Anthropic's wire format.
 * @param isOAuthToken - When true, the Claude Code identity MUST lead the system blocks.
 * @returns A fully populated `MessageCreateParamsStreaming` payload.
 */
function buildAnthropicParams(
  model: Model,
  context: Context,
  options: AnthropicOptions,
  messages: MessageParam[],
  isOAuthToken: boolean,
): MessageCreateParamsStreaming {
  const params: MessageCreateParamsStreaming = {
    model: model.id,
    messages,
    max_tokens: options.maxTokens || model.maxTokens,
    stream: true,
  };

  if (isOAuthToken) {
    // OAuth tokens require the Claude Code identity as the first (cached)
    // system block; any user-supplied prompt is appended after it.
    params.system = [
      {
        type: "text",
        text: "You are Claude Code, Anthropic's official CLI for Claude.",
        cache_control: { type: "ephemeral" },
      },
    ];
    if (context.systemPrompt) {
      params.system.push({
        type: "text",
        text: context.systemPrompt,
        cache_control: { type: "ephemeral" },
      });
    }
  } else if (context.systemPrompt) {
    params.system = context.systemPrompt;
  }

  const { temperature, toolChoice } = options;
  if (temperature !== undefined) {
    params.temperature = temperature;
  }
  if (context.tools) {
    params.tools = convertTools(context.tools);
  }

  // Thinking is opt-in and only valid for reasoning-capable models.
  if (options.thinking?.enabled && model.reasoning) {
    params.thinking = {
      type: "enabled",
      budget_tokens: options.thinking.budgetTokens || 1024,
    };
  }

  if (toolChoice) {
    params.tool_choice = typeof toolChoice === "string" ? { type: toolChoice } : toolChoice;
  }

  return params;
}
// Convert messages to Anthropic format
function convertMessages(messages: Message[], model: Model, api: Api): MessageParam[] {
const params: MessageParam[] = [];
// Transform messages for cross-provider compatibility
const transformedMessages = transformMessages(messages, model, api);
for (const msg of transformedMessages) {
if (msg.role === "user") {
// Handle both string and array content
if (typeof msg.content === "string") {
params.push({
role: "user",
content: msg.content,
});
} else {
// Convert array content to Anthropic format
const blocks: ContentBlockParam[] = msg.content.map((item) => {
if (item.type === "text") {
return {
type: "text",
text: item.text,
};
} else {
// Image content
return {
type: "image",
source: {
type: "base64",
media_type: item.mimeType as "image/jpeg" | "image/png" | "image/gif" | "image/webp",
data: item.data,
},
};
}
});
const filteredBlocks = !model?.input.includes("image") ? blocks.filter((b) => b.type !== "image") : blocks;
params.push({
role: "user",
content: filteredBlocks,
});
}
} else if (msg.role === "assistant") {
const blocks: ContentBlockParam[] = [];
for (const block of msg.content) {
if (block.type === "text") {
blocks.push({
type: "text",
text: block.text,
});
} else if (block.type === "thinking") {
blocks.push({
type: "thinking",
thinking: block.thinking,
signature: block.thinkingSignature || "",
});
} else if (block.type === "toolCall") {
blocks.push({
type: "tool_use",
id: block.id,
name: block.name,
input: block.arguments,
});
}
}
params.push({
role: "assistant",
content: blocks,
});
} else if (msg.role === "toolResult") {
params.push({
role: "user",
content: [
{
type: "tool_result",
tool_use_id: msg.toolCallId,
content: msg.content,
is_error: msg.isError,
},
],
});
}
}
return params;
}
// Convert tools to Anthropic format
/**
 * Converts our tool definitions into Anthropic's `Tool` schema.
 * Returns an empty array when no tools are configured.
 */
function convertTools(tools: Context["tools"]): Tool[] {
  if (!tools) return [];
  const result: Tool[] = [];
  for (const tool of tools) {
    result.push({
      name: tool.name,
      description: tool.description,
      input_schema: {
        type: "object" as const,
        // Fall back to empty schema pieces when the tool declares none.
        properties: tool.parameters.properties || {},
        required: tool.parameters.required || [],
      },
    });
  }
  return result;
}
/**
 * Maps an Anthropic stop reason onto our provider-agnostic `StopReason`.
 * Unknown or null reasons default to "stop".
 */
function mapStopReason(reason: Anthropic.Messages.StopReason | null): StopReason {
  const table: Record<string, StopReason> = {
    end_turn: "stop",
    max_tokens: "length",
    tool_use: "toolUse",
    refusal: "safety",
    pause_turn: "stop", // Stop is good enough -> resubmit
    stop_sequence: "stop", // We don't supply stop sequences, so this should never happen
  };
  return reason ? (table[reason] ?? "stop") : "stop";
}

View file

@ -3,91 +3,46 @@ import type {
ContentBlockParam, ContentBlockParam,
MessageCreateParamsStreaming, MessageCreateParamsStreaming,
MessageParam, MessageParam,
Tool,
} from "@anthropic-ai/sdk/resources/messages.js"; } from "@anthropic-ai/sdk/resources/messages.js";
import { QueuedGenerateStream } from "../generate.js";
import { calculateCost } from "../models.js"; import { calculateCost } from "../models.js";
import type { import type {
Api,
AssistantMessage, AssistantMessage,
Context, Context,
LLM, GenerateFunction,
LLMOptions, GenerateOptions,
GenerateStream,
Message, Message,
Model, Model,
StopReason, StopReason,
TextContent, TextContent,
ThinkingContent, ThinkingContent,
Tool,
ToolCall, ToolCall,
} from "../types.js"; } from "../types.js";
import { transformMessages } from "./utils.js"; import { transformMessages } from "./utils.js";
export interface AnthropicLLMOptions extends LLMOptions { export interface AnthropicOptions extends GenerateOptions {
thinking?: { thinkingEnabled?: boolean;
enabled: boolean; thinkingBudgetTokens?: number;
budgetTokens?: number;
};
toolChoice?: "auto" | "any" | "none" | { type: "tool"; name: string }; toolChoice?: "auto" | "any" | "none" | { type: "tool"; name: string };
} }
export class AnthropicLLM implements LLM<AnthropicLLMOptions> { export const streamAnthropic: GenerateFunction<"anthropic-messages"> = (
private client: Anthropic; model: Model<"anthropic-messages">,
private modelInfo: Model; context: Context,
private isOAuthToken: boolean = false; options?: AnthropicOptions,
): GenerateStream => {
const stream = new QueuedGenerateStream();
constructor(model: Model, apiKey?: string) { (async () => {
if (!apiKey) {
if (!process.env.ANTHROPIC_API_KEY) {
throw new Error(
"Anthropic API key is required. Set ANTHROPIC_API_KEY environment variable or pass it as an argument.",
);
}
apiKey = process.env.ANTHROPIC_API_KEY;
}
if (apiKey.includes("sk-ant-oat")) {
const defaultHeaders = {
accept: "application/json",
"anthropic-dangerous-direct-browser-access": "true",
"anthropic-beta": "oauth-2025-04-20,fine-grained-tool-streaming-2025-05-14",
};
// Clear the env var if we're in Node.js to prevent SDK from using it
if (typeof process !== "undefined" && process.env) {
process.env.ANTHROPIC_API_KEY = undefined;
}
this.client = new Anthropic({
apiKey: null,
authToken: apiKey,
baseURL: model.baseUrl,
defaultHeaders,
dangerouslyAllowBrowser: true,
});
this.isOAuthToken = true;
} else {
const defaultHeaders = {
accept: "application/json",
"anthropic-dangerous-direct-browser-access": "true",
"anthropic-beta": "fine-grained-tool-streaming-2025-05-14",
};
this.client = new Anthropic({ apiKey, baseURL: model.baseUrl, dangerouslyAllowBrowser: true, defaultHeaders });
this.isOAuthToken = false;
}
this.modelInfo = model;
}
getModel(): Model {
return this.modelInfo;
}
getApi(): string {
return "anthropic-messages";
}
async generate(context: Context, options?: AnthropicLLMOptions): Promise<AssistantMessage> {
const output: AssistantMessage = { const output: AssistantMessage = {
role: "assistant", role: "assistant",
content: [], content: [],
api: this.getApi(), api: "anthropic-messages" as Api,
provider: this.modelInfo.provider, provider: model.provider,
model: this.modelInfo.id, model: model.id,
usage: { usage: {
input: 0, input: 0,
output: 0, output: 0,
@ -99,77 +54,14 @@ export class AnthropicLLM implements LLM<AnthropicLLMOptions> {
}; };
try { try {
const messages = this.convertMessages(context.messages); const { client, isOAuthToken } = createClient(model, options?.apiKey!);
const params = buildParams(model, context, isOAuthToken, options);
const params: MessageCreateParamsStreaming = { const anthropicStream = client.messages.stream({ ...params, stream: true }, { signal: options?.signal });
model: this.modelInfo.id, stream.push({ type: "start", partial: output });
messages,
max_tokens: options?.maxTokens || 4096,
stream: true,
};
// For OAuth tokens, we MUST include Claude Code identity
if (this.isOAuthToken) {
params.system = [
{
type: "text",
text: "You are Claude Code, Anthropic's official CLI for Claude.",
cache_control: {
type: "ephemeral",
},
},
];
if (context.systemPrompt) {
params.system.push({
type: "text",
text: context.systemPrompt,
cache_control: {
type: "ephemeral",
},
});
}
} else if (context.systemPrompt) {
params.system = context.systemPrompt;
}
if (options?.temperature !== undefined) {
params.temperature = options?.temperature;
}
if (context.tools) {
params.tools = this.convertTools(context.tools);
}
// Only enable thinking if the model supports it
if (options?.thinking?.enabled && this.modelInfo.reasoning) {
params.thinking = {
type: "enabled",
budget_tokens: options.thinking.budgetTokens || 1024,
};
}
if (options?.toolChoice) {
if (typeof options.toolChoice === "string") {
params.tool_choice = { type: options.toolChoice };
} else {
params.tool_choice = options.toolChoice;
}
}
const stream = this.client.messages.stream(
{
...params,
stream: true,
},
{
signal: options?.signal,
},
);
options?.onEvent?.({ type: "start", model: this.modelInfo.id, provider: this.modelInfo.provider });
let currentBlock: ThinkingContent | TextContent | (ToolCall & { partialJson: string }) | null = null; let currentBlock: ThinkingContent | TextContent | (ToolCall & { partialJson: string }) | null = null;
for await (const event of stream) {
for await (const event of anthropicStream) {
if (event.type === "content_block_start") { if (event.type === "content_block_start") {
if (event.content_block.type === "text") { if (event.content_block.type === "text") {
currentBlock = { currentBlock = {
@ -177,7 +69,7 @@ export class AnthropicLLM implements LLM<AnthropicLLMOptions> {
text: "", text: "",
}; };
output.content.push(currentBlock); output.content.push(currentBlock);
options?.onEvent?.({ type: "text_start" }); stream.push({ type: "text_start", partial: output });
} else if (event.content_block.type === "thinking") { } else if (event.content_block.type === "thinking") {
currentBlock = { currentBlock = {
type: "thinking", type: "thinking",
@ -185,9 +77,9 @@ export class AnthropicLLM implements LLM<AnthropicLLMOptions> {
thinkingSignature: "", thinkingSignature: "",
}; };
output.content.push(currentBlock); output.content.push(currentBlock);
options?.onEvent?.({ type: "thinking_start" }); stream.push({ type: "thinking_start", partial: output });
} else if (event.content_block.type === "tool_use") { } else if (event.content_block.type === "tool_use") {
// We wait for the full tool use to be streamed to send the event // We wait for the full tool use to be streamed
currentBlock = { currentBlock = {
type: "toolCall", type: "toolCall",
id: event.content_block.id, id: event.content_block.id,
@ -200,15 +92,19 @@ export class AnthropicLLM implements LLM<AnthropicLLMOptions> {
if (event.delta.type === "text_delta") { if (event.delta.type === "text_delta") {
if (currentBlock && currentBlock.type === "text") { if (currentBlock && currentBlock.type === "text") {
currentBlock.text += event.delta.text; currentBlock.text += event.delta.text;
options?.onEvent?.({ type: "text_delta", content: currentBlock.text, delta: event.delta.text }); stream.push({
type: "text_delta",
delta: event.delta.text,
partial: output,
});
} }
} else if (event.delta.type === "thinking_delta") { } else if (event.delta.type === "thinking_delta") {
if (currentBlock && currentBlock.type === "thinking") { if (currentBlock && currentBlock.type === "thinking") {
currentBlock.thinking += event.delta.thinking; currentBlock.thinking += event.delta.thinking;
options?.onEvent?.({ stream.push({
type: "thinking_delta", type: "thinking_delta",
content: currentBlock.thinking,
delta: event.delta.thinking, delta: event.delta.thinking,
partial: output,
}); });
} }
} else if (event.delta.type === "input_json_delta") { } else if (event.delta.type === "input_json_delta") {
@ -224,9 +120,17 @@ export class AnthropicLLM implements LLM<AnthropicLLMOptions> {
} else if (event.type === "content_block_stop") { } else if (event.type === "content_block_stop") {
if (currentBlock) { if (currentBlock) {
if (currentBlock.type === "text") { if (currentBlock.type === "text") {
options?.onEvent?.({ type: "text_end", content: currentBlock.text }); stream.push({
type: "text_end",
content: currentBlock.text,
partial: output,
});
} else if (currentBlock.type === "thinking") { } else if (currentBlock.type === "thinking") {
options?.onEvent?.({ type: "thinking_end", content: currentBlock.thinking }); stream.push({
type: "thinking_end",
content: currentBlock.thinking,
partial: output,
});
} else if (currentBlock.type === "toolCall") { } else if (currentBlock.type === "toolCall") {
const finalToolCall: ToolCall = { const finalToolCall: ToolCall = {
type: "toolCall", type: "toolCall",
@ -235,150 +139,274 @@ export class AnthropicLLM implements LLM<AnthropicLLMOptions> {
arguments: JSON.parse(currentBlock.partialJson), arguments: JSON.parse(currentBlock.partialJson),
}; };
output.content.push(finalToolCall); output.content.push(finalToolCall);
options?.onEvent?.({ type: "toolCall", toolCall: finalToolCall }); stream.push({
type: "toolCall",
toolCall: finalToolCall,
partial: output,
});
} }
currentBlock = null; currentBlock = null;
} }
} else if (event.type === "message_delta") { } else if (event.type === "message_delta") {
if (event.delta.stop_reason) { if (event.delta.stop_reason) {
output.stopReason = this.mapStopReason(event.delta.stop_reason); output.stopReason = mapStopReason(event.delta.stop_reason);
} }
output.usage.input += event.usage.input_tokens || 0; output.usage.input += event.usage.input_tokens || 0;
output.usage.output += event.usage.output_tokens || 0; output.usage.output += event.usage.output_tokens || 0;
output.usage.cacheRead += event.usage.cache_read_input_tokens || 0; output.usage.cacheRead += event.usage.cache_read_input_tokens || 0;
output.usage.cacheWrite += event.usage.cache_creation_input_tokens || 0; output.usage.cacheWrite += event.usage.cache_creation_input_tokens || 0;
calculateCost(this.modelInfo, output.usage); calculateCost(model, output.usage);
} }
} }
options?.onEvent?.({ type: "done", reason: output.stopReason, message: output }); if (options?.signal?.aborted) {
return output; throw new Error("Request was aborted");
}
stream.push({ type: "done", reason: output.stopReason, message: output });
stream.end();
} catch (error) { } catch (error) {
output.stopReason = "error"; output.stopReason = "error";
output.error = error instanceof Error ? error.message : JSON.stringify(error); output.error = error instanceof Error ? error.message : JSON.stringify(error);
options?.onEvent?.({ type: "error", error: output.error }); stream.push({ type: "error", error: output.error, partial: output });
return output; stream.end();
}
})();
return stream;
};
function createClient(
model: Model<"anthropic-messages">,
apiKey: string,
): { client: Anthropic; isOAuthToken: boolean } {
if (apiKey.includes("sk-ant-oat")) {
const defaultHeaders = {
accept: "application/json",
"anthropic-dangerous-direct-browser-access": "true",
"anthropic-beta": "oauth-2025-04-20,fine-grained-tool-streaming-2025-05-14",
};
// Clear the env var if we're in Node.js to prevent SDK from using it
if (typeof process !== "undefined" && process.env) {
process.env.ANTHROPIC_API_KEY = undefined;
}
const client = new Anthropic({
apiKey: null,
authToken: apiKey,
baseURL: model.baseUrl,
defaultHeaders,
dangerouslyAllowBrowser: true,
});
return { client, isOAuthToken: true };
} else {
const defaultHeaders = {
accept: "application/json",
"anthropic-dangerous-direct-browser-access": "true",
"anthropic-beta": "fine-grained-tool-streaming-2025-05-14",
};
const client = new Anthropic({
apiKey,
baseURL: model.baseUrl,
dangerouslyAllowBrowser: true,
defaultHeaders,
});
return { client, isOAuthToken: false };
}
}
function buildParams(
model: Model<"anthropic-messages">,
context: Context,
isOAuthToken: boolean,
options?: AnthropicOptions,
): MessageCreateParamsStreaming {
const params: MessageCreateParamsStreaming = {
model: model.id,
messages: convertMessages(context.messages, model),
max_tokens: options?.maxTokens || model.maxTokens,
stream: true,
};
// For OAuth tokens, we MUST include Claude Code identity
if (isOAuthToken) {
params.system = [
{
type: "text",
text: "You are Claude Code, Anthropic's official CLI for Claude.",
cache_control: {
type: "ephemeral",
},
},
];
if (context.systemPrompt) {
params.system.push({
type: "text",
text: context.systemPrompt,
cache_control: {
type: "ephemeral",
},
});
}
} else if (context.systemPrompt) {
params.system = context.systemPrompt;
}
if (options?.temperature !== undefined) {
params.temperature = options.temperature;
}
if (context.tools) {
params.tools = convertTools(context.tools);
}
if (options?.thinkingEnabled && model.reasoning) {
params.thinking = {
type: "enabled",
budget_tokens: options.thinkingBudgetTokens || 1024,
};
}
if (options?.toolChoice) {
if (typeof options.toolChoice === "string") {
params.tool_choice = { type: options.toolChoice };
} else {
params.tool_choice = options.toolChoice;
} }
} }
private convertMessages(messages: Message[]): MessageParam[] { return params;
const params: MessageParam[] = []; }
// Transform messages for cross-provider compatibility function convertMessages(messages: Message[], model: Model<"anthropic-messages">): MessageParam[] {
const transformedMessages = transformMessages(messages, this.modelInfo, this.getApi()); const params: MessageParam[] = [];
for (const msg of transformedMessages) { // Transform messages for cross-provider compatibility
if (msg.role === "user") { const transformedMessages = transformMessages(messages, model);
// Handle both string and array content
if (typeof msg.content === "string") { for (const msg of transformedMessages) {
if (msg.role === "user") {
if (typeof msg.content === "string") {
if (msg.content.trim().length > 0) {
params.push({ params.push({
role: "user", role: "user",
content: msg.content, content: msg.content,
}); });
} else {
// Convert array content to Anthropic format
const blocks: ContentBlockParam[] = msg.content.map((item) => {
if (item.type === "text") {
return {
type: "text",
text: item.text,
};
} else {
// Image content
return {
type: "image",
source: {
type: "base64",
media_type: item.mimeType as "image/jpeg" | "image/png" | "image/gif" | "image/webp",
data: item.data,
},
};
}
});
const filteredBlocks = !this.modelInfo?.input.includes("image")
? blocks.filter((b) => b.type !== "image")
: blocks;
params.push({
role: "user",
content: filteredBlocks,
});
} }
} else if (msg.role === "assistant") { } else {
const blocks: ContentBlockParam[] = []; const blocks: ContentBlockParam[] = msg.content.map((item) => {
if (item.type === "text") {
for (const block of msg.content) { return {
if (block.type === "text") {
blocks.push({
type: "text", type: "text",
text: block.text, text: item.text,
}); };
} else if (block.type === "thinking") { } else {
blocks.push({ return {
type: "thinking", type: "image",
thinking: block.thinking, source: {
signature: block.thinkingSignature || "", type: "base64",
}); media_type: item.mimeType as "image/jpeg" | "image/png" | "image/gif" | "image/webp",
} else if (block.type === "toolCall") { data: item.data,
blocks.push({ },
type: "tool_use", };
id: block.id,
name: block.name,
input: block.arguments,
});
} }
}
params.push({
role: "assistant",
content: blocks,
}); });
} else if (msg.role === "toolResult") { let filteredBlocks = !model?.input.includes("image") ? blocks.filter((b) => b.type !== "image") : blocks;
filteredBlocks = filteredBlocks.filter((b) => {
if (b.type === "text") {
return b.text.trim().length > 0;
}
return true;
});
if (filteredBlocks.length === 0) continue;
params.push({ params.push({
role: "user", role: "user",
content: [ content: filteredBlocks,
{
type: "tool_result",
tool_use_id: msg.toolCallId,
content: msg.content,
is_error: msg.isError,
},
],
}); });
} }
} else if (msg.role === "assistant") {
const blocks: ContentBlockParam[] = [];
for (const block of msg.content) {
if (block.type === "text") {
if (block.text.trim().length === 0) continue;
blocks.push({
type: "text",
text: block.text,
});
} else if (block.type === "thinking") {
if (block.thinking.trim().length === 0) continue;
blocks.push({
type: "thinking",
thinking: block.thinking,
signature: block.thinkingSignature || "",
});
} else if (block.type === "toolCall") {
blocks.push({
type: "tool_use",
id: block.id,
name: block.name,
input: block.arguments,
});
}
}
if (blocks.length === 0) continue;
params.push({
role: "assistant",
content: blocks,
});
} else if (msg.role === "toolResult") {
params.push({
role: "user",
content: [
{
type: "tool_result",
tool_use_id: msg.toolCallId,
content: msg.content,
is_error: msg.isError,
},
],
});
} }
return params;
} }
return params;
}
private convertTools(tools: Context["tools"]): Tool[] { function convertTools(tools: Tool[]): Anthropic.Messages.Tool[] {
if (!tools) return []; if (!tools) return [];
return tools.map((tool) => ({ return tools.map((tool) => ({
name: tool.name, name: tool.name,
description: tool.description, description: tool.description,
input_schema: { input_schema: {
type: "object" as const, type: "object" as const,
properties: tool.parameters.properties || {}, properties: tool.parameters.properties || {},
required: tool.parameters.required || [], required: tool.parameters.required || [],
}, },
})); }));
} }
private mapStopReason(reason: Anthropic.Messages.StopReason | null): StopReason { function mapStopReason(reason: Anthropic.Messages.StopReason): StopReason {
switch (reason) { switch (reason) {
case "end_turn": case "end_turn":
return "stop"; return "stop";
case "max_tokens": case "max_tokens":
return "length"; return "length";
case "tool_use": case "tool_use":
return "toolUse"; return "toolUse";
case "refusal": case "refusal":
return "safety"; return "safety";
case "pause_turn": // Stop is good enough -> resubmit case "pause_turn": // Stop is good enough -> resubmit
return "stop"; return "stop";
case "stop_sequence": case "stop_sequence":
return "stop"; // We don't supply stop sequences, so this should never happen return "stop"; // We don't supply stop sequences, so this should never happen
default: default: {
return "stop"; const _exhaustive: never = reason;
throw new Error(`Unhandled stop reason: ${_exhaustive}`);
} }
} }
} }

View file

@ -1,19 +1,21 @@
import { import {
type Content, type Content,
type FinishReason, FinishReason,
FunctionCallingConfigMode, FunctionCallingConfigMode,
type GenerateContentConfig, type GenerateContentConfig,
type GenerateContentParameters, type GenerateContentParameters,
GoogleGenAI, GoogleGenAI,
type Part, type Part,
} from "@google/genai"; } from "@google/genai";
import { QueuedGenerateStream } from "../generate.js";
import { calculateCost } from "../models.js"; import { calculateCost } from "../models.js";
import type { import type {
Api,
AssistantMessage, AssistantMessage,
Context, Context,
LLM, GenerateFunction,
LLMOptions, GenerateOptions,
Message, GenerateStream,
Model, Model,
StopReason, StopReason,
TextContent, TextContent,
@ -23,7 +25,7 @@ import type {
} from "../types.js"; } from "../types.js";
import { transformMessages } from "./utils.js"; import { transformMessages } from "./utils.js";
export interface GoogleLLMOptions extends LLMOptions { export interface GoogleOptions extends GenerateOptions {
toolChoice?: "auto" | "none" | "any"; toolChoice?: "auto" | "none" | "any";
thinking?: { thinking?: {
enabled: boolean; enabled: boolean;
@ -31,38 +33,20 @@ export interface GoogleLLMOptions extends LLMOptions {
}; };
} }
export class GoogleLLM implements LLM<GoogleLLMOptions> { export const streamGoogle: GenerateFunction<"google-generative-ai"> = (
private client: GoogleGenAI; model: Model<"google-generative-ai">,
private modelInfo: Model; context: Context,
options?: GoogleOptions,
): GenerateStream => {
const stream = new QueuedGenerateStream();
constructor(model: Model, apiKey?: string) { (async () => {
if (!apiKey) {
if (!process.env.GEMINI_API_KEY) {
throw new Error(
"Gemini API key is required. Set GEMINI_API_KEY environment variable or pass it as an argument.",
);
}
apiKey = process.env.GEMINI_API_KEY;
}
this.client = new GoogleGenAI({ apiKey });
this.modelInfo = model;
}
getModel(): Model {
return this.modelInfo;
}
getApi(): string {
return "google-generative-ai";
}
async generate(context: Context, options?: GoogleLLMOptions): Promise<AssistantMessage> {
const output: AssistantMessage = { const output: AssistantMessage = {
role: "assistant", role: "assistant",
content: [], content: [],
api: this.getApi(), api: "google-generative-ai" as Api,
provider: this.modelInfo.provider, provider: model.provider,
model: this.modelInfo.id, model: model.id,
usage: { usage: {
input: 0, input: 0,
output: 0, output: 0,
@ -72,70 +56,20 @@ export class GoogleLLM implements LLM<GoogleLLMOptions> {
}, },
stopReason: "stop", stopReason: "stop",
}; };
try { try {
const contents = this.convertMessages(context.messages); const client = createClient(options?.apiKey);
const params = buildParams(model, context, options);
const googleStream = await client.models.generateContentStream(params);
// Build generation config stream.push({ type: "start", partial: output });
const generationConfig: GenerateContentConfig = {};
if (options?.temperature !== undefined) {
generationConfig.temperature = options.temperature;
}
if (options?.maxTokens !== undefined) {
generationConfig.maxOutputTokens = options.maxTokens;
}
// Build the config object
const config: GenerateContentConfig = {
...(Object.keys(generationConfig).length > 0 && generationConfig),
...(context.systemPrompt && { systemInstruction: context.systemPrompt }),
...(context.tools && { tools: this.convertTools(context.tools) }),
};
// Add tool config if needed
if (context.tools && options?.toolChoice) {
config.toolConfig = {
functionCallingConfig: {
mode: this.mapToolChoice(options.toolChoice),
},
};
}
// Add thinking config if enabled and model supports it
if (options?.thinking?.enabled && this.modelInfo.reasoning) {
config.thinkingConfig = {
includeThoughts: true,
...(options.thinking.budgetTokens !== undefined && { thinkingBudget: options.thinking.budgetTokens }),
};
}
// Abort signal
if (options?.signal) {
if (options.signal.aborted) {
throw new Error("Request aborted");
}
config.abortSignal = options.signal;
}
// Build the request parameters
const params: GenerateContentParameters = {
model: this.modelInfo.id,
contents,
config,
};
const stream = await this.client.models.generateContentStream(params);
options?.onEvent?.({ type: "start", model: this.modelInfo.id, provider: this.modelInfo.provider });
let currentBlock: TextContent | ThinkingContent | null = null; let currentBlock: TextContent | ThinkingContent | null = null;
for await (const chunk of stream) { for await (const chunk of googleStream) {
// Extract parts from the chunk
const candidate = chunk.candidates?.[0]; const candidate = chunk.candidates?.[0];
if (candidate?.content?.parts) { if (candidate?.content?.parts) {
for (const part of candidate.content.parts) { for (const part of candidate.content.parts) {
if (part.text !== undefined) { if (part.text !== undefined) {
const isThinking = part.thought === true; const isThinking = part.thought === true;
// Check if we need to switch blocks
if ( if (
!currentBlock || !currentBlock ||
(isThinking && currentBlock.type !== "thinking") || (isThinking && currentBlock.type !== "thinking") ||
@ -143,50 +77,60 @@ export class GoogleLLM implements LLM<GoogleLLMOptions> {
) { ) {
if (currentBlock) { if (currentBlock) {
if (currentBlock.type === "text") { if (currentBlock.type === "text") {
options?.onEvent?.({ type: "text_end", content: currentBlock.text }); stream.push({
type: "text_end",
content: currentBlock.text,
partial: output,
});
} else { } else {
options?.onEvent?.({ type: "thinking_end", content: currentBlock.thinking }); stream.push({
type: "thinking_end",
content: currentBlock.thinking,
partial: output,
});
} }
} }
// Start new block
if (isThinking) { if (isThinking) {
currentBlock = { type: "thinking", thinking: "", thinkingSignature: undefined }; currentBlock = { type: "thinking", thinking: "", thinkingSignature: undefined };
options?.onEvent?.({ type: "thinking_start" }); stream.push({ type: "thinking_start", partial: output });
} else { } else {
currentBlock = { type: "text", text: "" }; currentBlock = { type: "text", text: "" };
options?.onEvent?.({ type: "text_start" }); stream.push({ type: "text_start", partial: output });
} }
output.content.push(currentBlock); output.content.push(currentBlock);
} }
// Append content to current block
if (currentBlock.type === "thinking") { if (currentBlock.type === "thinking") {
currentBlock.thinking += part.text; currentBlock.thinking += part.text;
currentBlock.thinkingSignature = part.thoughtSignature; currentBlock.thinkingSignature = part.thoughtSignature;
options?.onEvent?.({ stream.push({
type: "thinking_delta", type: "thinking_delta",
content: currentBlock.thinking,
delta: part.text, delta: part.text,
partial: output,
}); });
} else { } else {
currentBlock.text += part.text; currentBlock.text += part.text;
options?.onEvent?.({ type: "text_delta", content: currentBlock.text, delta: part.text }); stream.push({ type: "text_delta", delta: part.text, partial: output });
} }
} }
// Handle function calls
if (part.functionCall) { if (part.functionCall) {
if (currentBlock) { if (currentBlock) {
if (currentBlock.type === "text") { if (currentBlock.type === "text") {
options?.onEvent?.({ type: "text_end", content: currentBlock.text }); stream.push({
type: "text_end",
content: currentBlock.text,
partial: output,
});
} else { } else {
options?.onEvent?.({ type: "thinking_end", content: currentBlock.thinking }); stream.push({
type: "thinking_end",
content: currentBlock.thinking,
partial: output,
});
} }
currentBlock = null; currentBlock = null;
} }
// Add tool call
const toolCallId = part.functionCall.id || `${part.functionCall.name}_${Date.now()}`; const toolCallId = part.functionCall.id || `${part.functionCall.name}_${Date.now()}`;
const toolCall: ToolCall = { const toolCall: ToolCall = {
type: "toolCall", type: "toolCall",
@ -195,21 +139,18 @@ export class GoogleLLM implements LLM<GoogleLLMOptions> {
arguments: part.functionCall.args as Record<string, any>, arguments: part.functionCall.args as Record<string, any>,
}; };
output.content.push(toolCall); output.content.push(toolCall);
options?.onEvent?.({ type: "toolCall", toolCall }); stream.push({ type: "toolCall", toolCall, partial: output });
} }
} }
} }
// Map finish reason
if (candidate?.finishReason) { if (candidate?.finishReason) {
output.stopReason = this.mapStopReason(candidate.finishReason); output.stopReason = mapStopReason(candidate.finishReason);
// Check if we have tool calls in blocks
if (output.content.some((b) => b.type === "toolCall")) { if (output.content.some((b) => b.type === "toolCall")) {
output.stopReason = "toolUse"; output.stopReason = "toolUse";
} }
} }
// Capture usage metadata if available
if (chunk.usageMetadata) { if (chunk.usageMetadata) {
output.usage = { output.usage = {
input: chunk.usageMetadata.promptTokenCount || 0, input: chunk.usageMetadata.promptTokenCount || 0,
@ -225,166 +166,223 @@ export class GoogleLLM implements LLM<GoogleLLMOptions> {
total: 0, total: 0,
}, },
}; };
calculateCost(this.modelInfo, output.usage); calculateCost(model, output.usage);
} }
} }
// Finalize last block
if (currentBlock) { if (currentBlock) {
if (currentBlock.type === "text") { if (currentBlock.type === "text") {
options?.onEvent?.({ type: "text_end", content: currentBlock.text }); stream.push({ type: "text_end", content: currentBlock.text, partial: output });
} else { } else {
options?.onEvent?.({ type: "thinking_end", content: currentBlock.thinking }); stream.push({ type: "thinking_end", content: currentBlock.thinking, partial: output });
} }
} }
options?.onEvent?.({ type: "done", reason: output.stopReason, message: output }); stream.push({ type: "done", reason: output.stopReason, message: output });
return output; stream.end();
} catch (error) { } catch (error) {
output.stopReason = "error"; output.stopReason = "error";
output.error = error instanceof Error ? error.message : JSON.stringify(error); output.error = error instanceof Error ? error.message : JSON.stringify(error);
options?.onEvent?.({ type: "error", error: output.error }); stream.push({ type: "error", error: output.error, partial: output });
return output; stream.end();
} }
})();
return stream;
};
function createClient(apiKey?: string): GoogleGenAI {
if (!apiKey) {
if (!process.env.GEMINI_API_KEY) {
throw new Error(
"Gemini API key is required. Set GEMINI_API_KEY environment variable or pass it as an argument.",
);
}
apiKey = process.env.GEMINI_API_KEY;
}
return new GoogleGenAI({ apiKey });
}
function buildParams(
model: Model<"google-generative-ai">,
context: Context,
options: GoogleOptions = {},
): GenerateContentParameters {
const contents = convertMessages(model, context);
const generationConfig: GenerateContentConfig = {};
if (options.temperature !== undefined) {
generationConfig.temperature = options.temperature;
}
if (options.maxTokens !== undefined) {
generationConfig.maxOutputTokens = options.maxTokens;
} }
private convertMessages(messages: Message[]): Content[] { const config: GenerateContentConfig = {
const contents: Content[] = []; ...(Object.keys(generationConfig).length > 0 && generationConfig),
...(context.systemPrompt && { systemInstruction: context.systemPrompt }),
...(context.tools && { tools: convertTools(context.tools) }),
};
// Transform messages for cross-provider compatibility if (context.tools && options.toolChoice) {
const transformedMessages = transformMessages(messages, this.modelInfo, this.getApi()); config.toolConfig = {
functionCallingConfig: {
mode: mapToolChoice(options.toolChoice),
},
};
}
for (const msg of transformedMessages) { if (options.thinking?.enabled && model.reasoning) {
if (msg.role === "user") { config.thinkingConfig = {
// Handle both string and array content includeThoughts: true,
if (typeof msg.content === "string") { ...(options.thinking.budgetTokens !== undefined && { thinkingBudget: options.thinking.budgetTokens }),
contents.push({ };
role: "user", }
parts: [{ text: msg.content }],
});
} else {
// Convert array content to Google format
const parts: Part[] = msg.content.map((item) => {
if (item.type === "text") {
return { text: item.text };
} else {
// Image content - Google uses inlineData
return {
inlineData: {
mimeType: item.mimeType,
data: item.data,
},
};
}
});
const filteredParts = !this.modelInfo?.input.includes("image")
? parts.filter((p) => p.text !== undefined)
: parts;
contents.push({
role: "user",
parts: filteredParts,
});
}
} else if (msg.role === "assistant") {
const parts: Part[] = [];
// Process content blocks if (options.signal) {
for (const block of msg.content) { if (options.signal.aborted) {
if (block.type === "text") { throw new Error("Request aborted");
parts.push({ text: block.text }); }
} else if (block.type === "thinking") { config.abortSignal = options.signal;
const thinkingPart: Part = { }
thought: true,
thoughtSignature: block.thinkingSignature,
text: block.thinking,
};
parts.push(thinkingPart);
} else if (block.type === "toolCall") {
parts.push({
functionCall: {
id: block.id,
name: block.name,
args: block.arguments,
},
});
}
}
if (parts.length > 0) { const params: GenerateContentParameters = {
contents.push({ model: model.id,
role: "model", contents,
parts, config,
}); };
}
} else if (msg.role === "toolResult") { return params;
}
function convertMessages(model: Model<"google-generative-ai">, context: Context): Content[] {
const contents: Content[] = [];
const transformedMessages = transformMessages(context.messages, model);
for (const msg of transformedMessages) {
if (msg.role === "user") {
if (typeof msg.content === "string") {
contents.push({ contents.push({
role: "user", role: "user",
parts: [ parts: [{ text: msg.content }],
{ });
functionResponse: { } else {
id: msg.toolCallId, const parts: Part[] = msg.content.map((item) => {
name: msg.toolName, if (item.type === "text") {
response: { return { text: item.text };
result: msg.content, } else {
isError: msg.isError, return {
}, inlineData: {
mimeType: item.mimeType,
data: item.data,
}, },
}, };
], }
});
const filteredParts = !model.input.includes("image") ? parts.filter((p) => p.text !== undefined) : parts;
if (filteredParts.length === 0) continue;
contents.push({
role: "user",
parts: filteredParts,
}); });
} }
} } else if (msg.role === "assistant") {
const parts: Part[] = [];
return contents; for (const block of msg.content) {
} if (block.type === "text") {
parts.push({ text: block.text });
} else if (block.type === "thinking") {
const thinkingPart: Part = {
thought: true,
thoughtSignature: block.thinkingSignature,
text: block.thinking,
};
parts.push(thinkingPart);
} else if (block.type === "toolCall") {
parts.push({
functionCall: {
id: block.id,
name: block.name,
args: block.arguments,
},
});
}
}
private convertTools(tools: Tool[]): any[] { if (parts.length === 0) continue;
return [ contents.push({
{ role: "model",
functionDeclarations: tools.map((tool) => ({ parts,
name: tool.name, });
description: tool.description, } else if (msg.role === "toolResult") {
parameters: tool.parameters, contents.push({
})), role: "user",
}, parts: [
]; {
} functionResponse: {
id: msg.toolCallId,
private mapToolChoice(choice: string): FunctionCallingConfigMode { name: msg.toolName,
switch (choice) { response: {
case "auto": result: msg.content,
return FunctionCallingConfigMode.AUTO; isError: msg.isError,
case "none": },
return FunctionCallingConfigMode.NONE; },
case "any": },
return FunctionCallingConfigMode.ANY; ],
default: });
return FunctionCallingConfigMode.AUTO;
} }
} }
private mapStopReason(reason: FinishReason): StopReason { return contents;
switch (reason) { }
case "STOP":
return "stop"; function convertTools(tools: Tool[]): any[] {
case "MAX_TOKENS": return [
return "length"; {
case "BLOCKLIST": functionDeclarations: tools.map((tool) => ({
case "PROHIBITED_CONTENT": name: tool.name,
case "SPII": description: tool.description,
case "SAFETY": parameters: tool.parameters,
case "IMAGE_SAFETY": })),
return "safety"; },
case "RECITATION": ];
return "safety"; }
case "FINISH_REASON_UNSPECIFIED":
case "OTHER": function mapToolChoice(choice: string): FunctionCallingConfigMode {
case "LANGUAGE": switch (choice) {
case "MALFORMED_FUNCTION_CALL": case "auto":
case "UNEXPECTED_TOOL_CALL": return FunctionCallingConfigMode.AUTO;
return "error"; case "none":
default: return FunctionCallingConfigMode.NONE;
return "stop"; case "any":
return FunctionCallingConfigMode.ANY;
default:
return FunctionCallingConfigMode.AUTO;
}
}
function mapStopReason(reason: FinishReason): StopReason {
switch (reason) {
case FinishReason.STOP:
return "stop";
case FinishReason.MAX_TOKENS:
return "length";
case FinishReason.BLOCKLIST:
case FinishReason.PROHIBITED_CONTENT:
case FinishReason.SPII:
case FinishReason.SAFETY:
case FinishReason.IMAGE_SAFETY:
case FinishReason.RECITATION:
return "safety";
case FinishReason.FINISH_REASON_UNSPECIFIED:
case FinishReason.OTHER:
case FinishReason.LANGUAGE:
case FinishReason.MALFORMED_FUNCTION_CALL:
case FinishReason.UNEXPECTED_TOOL_CALL:
return "error";
default: {
const _exhaustive: never = reason;
throw new Error(`Unhandled stop reason: ${_exhaustive}`);
} }
} }
} }

View file

@ -1,18 +1,20 @@
import OpenAI from "openai"; import OpenAI from "openai";
import type { import type {
ChatCompletionAssistantMessageParam,
ChatCompletionChunk, ChatCompletionChunk,
ChatCompletionContentPart, ChatCompletionContentPart,
ChatCompletionContentPartImage, ChatCompletionContentPartImage,
ChatCompletionContentPartText, ChatCompletionContentPartText,
ChatCompletionMessageParam, ChatCompletionMessageParam,
} from "openai/resources/chat/completions.js"; } from "openai/resources/chat/completions.js";
import { QueuedGenerateStream } from "../generate.js";
import { calculateCost } from "../models.js"; import { calculateCost } from "../models.js";
import type { import type {
AssistantMessage, AssistantMessage,
Context, Context,
LLM, GenerateFunction,
LLMOptions, GenerateOptions,
Message, GenerateStream,
Model, Model,
StopReason, StopReason,
TextContent, TextContent,
@ -22,43 +24,25 @@ import type {
} from "../types.js"; } from "../types.js";
import { transformMessages } from "./utils.js"; import { transformMessages } from "./utils.js";
export interface OpenAICompletionsLLMOptions extends LLMOptions { export interface OpenAICompletionsOptions extends GenerateOptions {
toolChoice?: "auto" | "none" | "required" | { type: "function"; function: { name: string } }; toolChoice?: "auto" | "none" | "required" | { type: "function"; function: { name: string } };
reasoningEffort?: "low" | "medium" | "high"; reasoningEffort?: "minimal" | "low" | "medium" | "high";
} }
export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> { export const streamOpenAICompletions: GenerateFunction<"openai-completions"> = (
private client: OpenAI; model: Model<"openai-completions">,
private modelInfo: Model; context: Context,
options?: OpenAICompletionsOptions,
): GenerateStream => {
const stream = new QueuedGenerateStream();
constructor(model: Model, apiKey?: string) { (async () => {
if (!apiKey) {
if (!process.env.OPENAI_API_KEY) {
throw new Error(
"OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass it as an argument.",
);
}
apiKey = process.env.OPENAI_API_KEY;
}
this.client = new OpenAI({ apiKey, baseURL: model.baseUrl, dangerouslyAllowBrowser: true });
this.modelInfo = model;
}
getModel(): Model {
return this.modelInfo;
}
getApi(): string {
return "openai-completions";
}
async generate(request: Context, options?: OpenAICompletionsLLMOptions): Promise<AssistantMessage> {
const output: AssistantMessage = { const output: AssistantMessage = {
role: "assistant", role: "assistant",
content: [], content: [],
api: this.getApi(), api: model.api,
provider: this.modelInfo.provider, provider: model.provider,
model: this.modelInfo.id, model: model.id,
usage: { usage: {
input: 0, input: 0,
output: 0, output: 0,
@ -70,52 +54,13 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
}; };
try { try {
const messages = this.convertMessages(request.messages, request.systemPrompt); const client = createClient(model, options?.apiKey);
const params = buildParams(model, context, options);
const params: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = { const openaiStream = await client.chat.completions.create(params, { signal: options?.signal });
model: this.modelInfo.id, stream.push({ type: "start", partial: output });
messages,
stream: true,
stream_options: { include_usage: true },
};
// Cerebras/xAI dont like the "store" field
if (!this.modelInfo.baseUrl?.includes("cerebras.ai") && !this.modelInfo.baseUrl?.includes("api.x.ai")) {
params.store = false;
}
if (options?.maxTokens) {
params.max_completion_tokens = options?.maxTokens;
}
if (options?.temperature !== undefined) {
params.temperature = options?.temperature;
}
if (request.tools) {
params.tools = this.convertTools(request.tools);
}
if (options?.toolChoice) {
params.tool_choice = options.toolChoice;
}
if (
options?.reasoningEffort &&
this.modelInfo.reasoning &&
!this.modelInfo.id.toLowerCase().includes("grok")
) {
params.reasoning_effort = options.reasoningEffort;
}
const stream = await this.client.chat.completions.create(params, {
signal: options?.signal,
});
options?.onEvent?.({ type: "start", model: this.modelInfo.id, provider: this.modelInfo.provider });
let currentBlock: TextContent | ThinkingContent | (ToolCall & { partialArgs?: string }) | null = null; let currentBlock: TextContent | ThinkingContent | (ToolCall & { partialArgs?: string }) | null = null;
for await (const chunk of stream) { for await (const chunk of openaiStream) {
if (chunk.usage) { if (chunk.usage) {
output.usage = { output.usage = {
input: chunk.usage.prompt_tokens || 0, input: chunk.usage.prompt_tokens || 0,
@ -132,137 +77,170 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
total: 0, total: 0,
}, },
}; };
calculateCost(this.modelInfo, output.usage); calculateCost(model, output.usage);
} }
const choice = chunk.choices[0]; const choice = chunk.choices[0];
if (!choice) continue; if (!choice) continue;
// Capture finish reason
if (choice.finish_reason) { if (choice.finish_reason) {
output.stopReason = this.mapStopReason(choice.finish_reason); output.stopReason = mapStopReason(choice.finish_reason);
} }
if (choice.delta) { if (choice.delta) {
// Handle text content
if ( if (
choice.delta.content !== null && choice.delta.content !== null &&
choice.delta.content !== undefined && choice.delta.content !== undefined &&
choice.delta.content.length > 0 choice.delta.content.length > 0
) { ) {
// Check if we need to switch to text block
if (!currentBlock || currentBlock.type !== "text") { if (!currentBlock || currentBlock.type !== "text") {
// Save current block if exists
if (currentBlock) { if (currentBlock) {
if (currentBlock.type === "thinking") { if (currentBlock.type === "thinking") {
options?.onEvent?.({ type: "thinking_end", content: currentBlock.thinking }); stream.push({
type: "thinking_end",
content: currentBlock.thinking,
partial: output,
});
} else if (currentBlock.type === "toolCall") { } else if (currentBlock.type === "toolCall") {
currentBlock.arguments = JSON.parse(currentBlock.partialArgs || "{}"); currentBlock.arguments = JSON.parse(currentBlock.partialArgs || "{}");
delete currentBlock.partialArgs; delete currentBlock.partialArgs;
options?.onEvent?.({ type: "toolCall", toolCall: currentBlock as ToolCall }); stream.push({
type: "toolCall",
toolCall: currentBlock as ToolCall,
partial: output,
});
} }
} }
// Start new text block
currentBlock = { type: "text", text: "" }; currentBlock = { type: "text", text: "" };
output.content.push(currentBlock); output.content.push(currentBlock);
options?.onEvent?.({ type: "text_start" }); stream.push({ type: "text_start", partial: output });
} }
// Append to text block
if (currentBlock.type === "text") { if (currentBlock.type === "text") {
currentBlock.text += choice.delta.content; currentBlock.text += choice.delta.content;
options?.onEvent?.({ stream.push({
type: "text_delta", type: "text_delta",
content: currentBlock.text,
delta: choice.delta.content, delta: choice.delta.content,
partial: output,
}); });
} }
} }
// Handle reasoning_content field // Some endpoints return reasoning in reasoning_content (llama.cpp)
if ( if (
(choice.delta as any).reasoning_content !== null && (choice.delta as any).reasoning_content !== null &&
(choice.delta as any).reasoning_content !== undefined && (choice.delta as any).reasoning_content !== undefined &&
(choice.delta as any).reasoning_content.length > 0 (choice.delta as any).reasoning_content.length > 0
) { ) {
// Check if we need to switch to thinking block
if (!currentBlock || currentBlock.type !== "thinking") { if (!currentBlock || currentBlock.type !== "thinking") {
// Save current block if exists
if (currentBlock) { if (currentBlock) {
if (currentBlock.type === "text") { if (currentBlock.type === "text") {
options?.onEvent?.({ type: "text_end", content: currentBlock.text }); stream.push({
type: "text_end",
content: currentBlock.text,
partial: output,
});
} else if (currentBlock.type === "toolCall") { } else if (currentBlock.type === "toolCall") {
currentBlock.arguments = JSON.parse(currentBlock.partialArgs || "{}"); currentBlock.arguments = JSON.parse(currentBlock.partialArgs || "{}");
delete currentBlock.partialArgs; delete currentBlock.partialArgs;
options?.onEvent?.({ type: "toolCall", toolCall: currentBlock as ToolCall }); stream.push({
type: "toolCall",
toolCall: currentBlock as ToolCall,
partial: output,
});
} }
} }
// Start new thinking block currentBlock = {
currentBlock = { type: "thinking", thinking: "", thinkingSignature: "reasoning_content" }; type: "thinking",
thinking: "",
thinkingSignature: "reasoning_content",
};
output.content.push(currentBlock); output.content.push(currentBlock);
options?.onEvent?.({ type: "thinking_start" }); stream.push({ type: "thinking_start", partial: output });
} }
// Append to thinking block
if (currentBlock.type === "thinking") { if (currentBlock.type === "thinking") {
const delta = (choice.delta as any).reasoning_content; const delta = (choice.delta as any).reasoning_content;
currentBlock.thinking += delta; currentBlock.thinking += delta;
options?.onEvent?.({ type: "thinking_delta", content: currentBlock.thinking, delta }); stream.push({
type: "thinking_delta",
delta,
partial: output,
});
} }
} }
// Handle reasoning field // Some endpoints return reasoning in reasining (ollama, xAI, ...)
if ( if (
(choice.delta as any).reasoning !== null && (choice.delta as any).reasoning !== null &&
(choice.delta as any).reasoning !== undefined && (choice.delta as any).reasoning !== undefined &&
(choice.delta as any).reasoning.length > 0 (choice.delta as any).reasoning.length > 0
) { ) {
// Check if we need to switch to thinking block
if (!currentBlock || currentBlock.type !== "thinking") { if (!currentBlock || currentBlock.type !== "thinking") {
// Save current block if exists
if (currentBlock) { if (currentBlock) {
if (currentBlock.type === "text") { if (currentBlock.type === "text") {
options?.onEvent?.({ type: "text_end", content: currentBlock.text }); stream.push({
type: "text_end",
content: currentBlock.text,
partial: output,
});
} else if (currentBlock.type === "toolCall") { } else if (currentBlock.type === "toolCall") {
currentBlock.arguments = JSON.parse(currentBlock.partialArgs || "{}"); currentBlock.arguments = JSON.parse(currentBlock.partialArgs || "{}");
delete currentBlock.partialArgs; delete currentBlock.partialArgs;
options?.onEvent?.({ type: "toolCall", toolCall: currentBlock as ToolCall }); stream.push({
type: "toolCall",
toolCall: currentBlock as ToolCall,
partial: output,
});
} }
} }
// Start new thinking block currentBlock = {
currentBlock = { type: "thinking", thinking: "", thinkingSignature: "reasoning" }; type: "thinking",
thinking: "",
thinkingSignature: "reasoning",
};
output.content.push(currentBlock); output.content.push(currentBlock);
options?.onEvent?.({ type: "thinking_start" }); stream.push({ type: "thinking_start", partial: output });
} }
// Append to thinking block
if (currentBlock.type === "thinking") { if (currentBlock.type === "thinking") {
const delta = (choice.delta as any).reasoning; const delta = (choice.delta as any).reasoning;
currentBlock.thinking += delta; currentBlock.thinking += delta;
options?.onEvent?.({ type: "thinking_delta", content: currentBlock.thinking, delta }); stream.push({ type: "thinking_delta", delta, partial: output });
} }
} }
// Handle tool calls
if (choice?.delta?.tool_calls) { if (choice?.delta?.tool_calls) {
for (const toolCall of choice.delta.tool_calls) { for (const toolCall of choice.delta.tool_calls) {
// Check if we need a new tool call block
if ( if (
!currentBlock || !currentBlock ||
currentBlock.type !== "toolCall" || currentBlock.type !== "toolCall" ||
(toolCall.id && currentBlock.id !== toolCall.id) (toolCall.id && currentBlock.id !== toolCall.id)
) { ) {
// Save current block if exists
if (currentBlock) { if (currentBlock) {
if (currentBlock.type === "text") { if (currentBlock.type === "text") {
options?.onEvent?.({ type: "text_end", content: currentBlock.text }); stream.push({
type: "text_end",
content: currentBlock.text,
partial: output,
});
} else if (currentBlock.type === "thinking") { } else if (currentBlock.type === "thinking") {
options?.onEvent?.({ type: "thinking_end", content: currentBlock.thinking }); stream.push({
type: "thinking_end",
content: currentBlock.thinking,
partial: output,
});
} else if (currentBlock.type === "toolCall") { } else if (currentBlock.type === "toolCall") {
currentBlock.arguments = JSON.parse(currentBlock.partialArgs || "{}"); currentBlock.arguments = JSON.parse(currentBlock.partialArgs || "{}");
delete currentBlock.partialArgs; delete currentBlock.partialArgs;
options?.onEvent?.({ type: "toolCall", toolCall: currentBlock as ToolCall }); stream.push({
type: "toolCall",
toolCall: currentBlock as ToolCall,
partial: output,
});
} }
} }
// Start new tool call block
currentBlock = { currentBlock = {
type: "toolCall", type: "toolCall",
id: toolCall.id || "", id: toolCall.id || "",
@ -273,7 +251,6 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
output.content.push(currentBlock); output.content.push(currentBlock);
} }
// Accumulate tool call data
if (currentBlock.type === "toolCall") { if (currentBlock.type === "toolCall") {
if (toolCall.id) currentBlock.id = toolCall.id; if (toolCall.id) currentBlock.id = toolCall.id;
if (toolCall.function?.name) currentBlock.name = toolCall.function.name; if (toolCall.function?.name) currentBlock.name = toolCall.function.name;
@ -286,16 +263,27 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
} }
} }
// Save final block if exists
if (currentBlock) { if (currentBlock) {
if (currentBlock.type === "text") { if (currentBlock.type === "text") {
options?.onEvent?.({ type: "text_end", content: currentBlock.text }); stream.push({
type: "text_end",
content: currentBlock.text,
partial: output,
});
} else if (currentBlock.type === "thinking") { } else if (currentBlock.type === "thinking") {
options?.onEvent?.({ type: "thinking_end", content: currentBlock.thinking }); stream.push({
type: "thinking_end",
content: currentBlock.thinking,
partial: output,
});
} else if (currentBlock.type === "toolCall") { } else if (currentBlock.type === "toolCall") {
currentBlock.arguments = JSON.parse(currentBlock.partialArgs || "{}"); currentBlock.arguments = JSON.parse(currentBlock.partialArgs || "{}");
delete currentBlock.partialArgs; delete currentBlock.partialArgs;
options?.onEvent?.({ type: "toolCall", toolCall: currentBlock as ToolCall }); stream.push({
type: "toolCall",
toolCall: currentBlock as ToolCall,
partial: output,
});
} }
} }
@ -303,141 +291,188 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
throw new Error("Request was aborted"); throw new Error("Request was aborted");
} }
options?.onEvent?.({ type: "done", reason: output.stopReason, message: output }); stream.push({ type: "done", reason: output.stopReason, message: output });
stream.end();
return output; return output;
} catch (error) { } catch (error) {
// Update output with error information
output.stopReason = "error"; output.stopReason = "error";
output.error = error instanceof Error ? error.message : String(error); output.error = error instanceof Error ? error.message : String(error);
options?.onEvent?.({ type: "error", error: output.error }); stream.push({ type: "error", error: output.error, partial: output });
return output; stream.end();
} }
})();
return stream;
};
function createClient(model: Model<"openai-completions">, apiKey?: string) {
if (!apiKey) {
if (!process.env.OPENAI_API_KEY) {
throw new Error(
"OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass it as an argument.",
);
}
apiKey = process.env.OPENAI_API_KEY;
}
return new OpenAI({ apiKey, baseURL: model.baseUrl, dangerouslyAllowBrowser: true });
}
function buildParams(model: Model<"openai-completions">, context: Context, options?: OpenAICompletionsOptions) {
const messages = convertMessages(model, context);
const params: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
model: model.id,
messages,
stream: true,
stream_options: { include_usage: true },
};
// Cerebras/xAI dont like the "store" field
if (!model.baseUrl.includes("cerebras.ai") && !model.baseUrl.includes("api.x.ai")) {
params.store = false;
} }
private convertMessages(messages: Message[], systemPrompt?: string): ChatCompletionMessageParam[] { if (options?.maxTokens) {
const params: ChatCompletionMessageParam[] = []; params.max_completion_tokens = options?.maxTokens;
}
// Transform messages for cross-provider compatibility if (options?.temperature !== undefined) {
const transformedMessages = transformMessages(messages, this.modelInfo, this.getApi()); params.temperature = options?.temperature;
}
// Add system prompt if provided if (context.tools) {
if (systemPrompt) { params.tools = convertTools(context.tools);
// Cerebras/xAi don't like the "developer" role }
const useDeveloperRole =
this.modelInfo.reasoning &&
!this.modelInfo.baseUrl?.includes("cerebras.ai") &&
!this.modelInfo.baseUrl?.includes("api.x.ai");
const role = useDeveloperRole ? "developer" : "system";
params.push({ role: role, content: systemPrompt });
}
// Convert messages if (options?.toolChoice) {
for (const msg of transformedMessages) { params.tool_choice = options.toolChoice;
if (msg.role === "user") { }
// Handle both string and array content
if (typeof msg.content === "string") {
params.push({
role: "user",
content: msg.content,
});
} else {
// Convert array content to OpenAI format
const content: ChatCompletionContentPart[] = msg.content.map((item): ChatCompletionContentPart => {
if (item.type === "text") {
return {
type: "text",
text: item.text,
} satisfies ChatCompletionContentPartText;
} else {
// Image content - OpenAI uses data URLs
return {
type: "image_url",
image_url: {
url: `data:${item.mimeType};base64,${item.data}`,
},
} satisfies ChatCompletionContentPartImage;
}
});
const filteredContent = !this.modelInfo?.input.includes("image")
? content.filter((c) => c.type !== "image_url")
: content;
params.push({
role: "user",
content: filteredContent,
});
}
} else if (msg.role === "assistant") {
const assistantMsg: ChatCompletionMessageParam = {
role: "assistant",
content: null,
};
// Build content from blocks // Grok models don't like reasoning_effort
const textBlocks = msg.content.filter((b) => b.type === "text") as TextContent[]; if (options?.reasoningEffort && model.reasoning && !model.id.toLowerCase().includes("grok")) {
if (textBlocks.length > 0) { params.reasoning_effort = options.reasoningEffort;
assistantMsg.content = textBlocks.map((b) => b.text).join(""); }
}
// Handle thinking blocks for llama.cpp server + gpt-oss return params;
const thinkingBlocks = msg.content.filter((b) => b.type === "thinking") as ThinkingContent[]; }
if (thinkingBlocks.length > 0) {
// Use the signature from the first thinking block if available
const signature = thinkingBlocks[0].thinkingSignature;
if (signature && signature.length > 0) {
(assistantMsg as any)[signature] = thinkingBlocks.map((b) => b.thinking).join("");
}
}
// Handle tool calls function convertMessages(model: Model<"openai-completions">, context: Context): ChatCompletionMessageParam[] {
const toolCalls = msg.content.filter((b) => b.type === "toolCall") as ToolCall[]; const params: ChatCompletionMessageParam[] = [];
if (toolCalls.length > 0) {
assistantMsg.tool_calls = toolCalls.map((tc) => ({
id: tc.id,
type: "function" as const,
function: {
name: tc.name,
arguments: JSON.stringify(tc.arguments),
},
}));
}
params.push(assistantMsg); const transformedMessages = transformMessages(context.messages, model);
} else if (msg.role === "toolResult") {
if (context.systemPrompt) {
// Cerebras/xAi don't like the "developer" role
const useDeveloperRole =
model.reasoning && !model.baseUrl.includes("cerebras.ai") && !model.baseUrl.includes("api.x.ai");
const role = useDeveloperRole ? "developer" : "system";
params.push({ role: role, content: context.systemPrompt });
}
for (const msg of transformedMessages) {
if (msg.role === "user") {
if (typeof msg.content === "string") {
params.push({ params.push({
role: "tool", role: "user",
content: msg.content, content: msg.content,
tool_call_id: msg.toolCallId, });
} else {
const content: ChatCompletionContentPart[] = msg.content.map((item): ChatCompletionContentPart => {
if (item.type === "text") {
return {
type: "text",
text: item.text,
} satisfies ChatCompletionContentPartText;
} else {
return {
type: "image_url",
image_url: {
url: `data:${item.mimeType};base64,${item.data}`,
},
} satisfies ChatCompletionContentPartImage;
}
});
const filteredContent = !model.input.includes("image")
? content.filter((c) => c.type !== "image_url")
: content;
if (filteredContent.length === 0) continue;
params.push({
role: "user",
content: filteredContent,
}); });
} }
} else if (msg.role === "assistant") {
const assistantMsg: ChatCompletionAssistantMessageParam = {
role: "assistant",
content: null,
};
const textBlocks = msg.content.filter((b) => b.type === "text") as TextContent[];
if (textBlocks.length > 0) {
assistantMsg.content = textBlocks.map((b) => b.text).join("");
}
// Handle thinking blocks for llama.cpp server + gpt-oss
const thinkingBlocks = msg.content.filter((b) => b.type === "thinking") as ThinkingContent[];
if (thinkingBlocks.length > 0) {
// Use the signature from the first thinking block if available
const signature = thinkingBlocks[0].thinkingSignature;
if (signature && signature.length > 0) {
(assistantMsg as any)[signature] = thinkingBlocks.map((b) => b.thinking).join("");
}
}
const toolCalls = msg.content.filter((b) => b.type === "toolCall") as ToolCall[];
if (toolCalls.length > 0) {
assistantMsg.tool_calls = toolCalls.map((tc) => ({
id: tc.id,
type: "function" as const,
function: {
name: tc.name,
arguments: JSON.stringify(tc.arguments),
},
}));
}
params.push(assistantMsg);
} else if (msg.role === "toolResult") {
params.push({
role: "tool",
content: msg.content,
tool_call_id: msg.toolCallId,
});
} }
return params;
} }
private convertTools(tools: Tool[]): OpenAI.Chat.Completions.ChatCompletionTool[] { return params;
return tools.map((tool) => ({ }
type: "function",
function: {
name: tool.name,
description: tool.description,
parameters: tool.parameters,
},
}));
}
private mapStopReason(reason: ChatCompletionChunk.Choice["finish_reason"] | null): StopReason { function convertTools(tools: Tool[]): OpenAI.Chat.Completions.ChatCompletionTool[] {
switch (reason) { return tools.map((tool) => ({
case "stop": type: "function",
return "stop"; function: {
case "length": name: tool.name,
return "length"; description: tool.description,
case "function_call": parameters: tool.parameters,
case "tool_calls": },
return "toolUse"; }));
case "content_filter": }
return "safety";
default: function mapStopReason(reason: ChatCompletionChunk.Choice["finish_reason"]): StopReason {
return "stop"; if (reason === null) return "stop";
switch (reason) {
case "stop":
return "stop";
case "length":
return "length";
case "function_call":
case "tool_calls":
return "toolUse";
case "content_filter":
return "safety";
default: {
const _exhaustive: never = reason;
throw new Error(`Unhandled stop reason: ${_exhaustive}`);
} }
} }
} }

View file

@ -10,58 +10,49 @@ import type {
ResponseOutputMessage, ResponseOutputMessage,
ResponseReasoningItem, ResponseReasoningItem,
} from "openai/resources/responses/responses.js"; } from "openai/resources/responses/responses.js";
import { QueuedGenerateStream } from "../generate.js";
import { calculateCost } from "../models.js"; import { calculateCost } from "../models.js";
import type { import type {
Api,
AssistantMessage, AssistantMessage,
Context, Context,
LLM, GenerateFunction,
LLMOptions, GenerateOptions,
GenerateStream,
Message, Message,
Model, Model,
StopReason, StopReason,
TextContent, TextContent,
ThinkingContent,
Tool, Tool,
ToolCall, ToolCall,
} from "../types.js"; } from "../types.js";
import { transformMessages } from "./utils.js"; import { transformMessages } from "./utils.js";
export interface OpenAIResponsesLLMOptions extends LLMOptions { // OpenAI Responses-specific options
export interface OpenAIResponsesOptions extends GenerateOptions {
reasoningEffort?: "minimal" | "low" | "medium" | "high"; reasoningEffort?: "minimal" | "low" | "medium" | "high";
reasoningSummary?: "auto" | "detailed" | "concise" | null; reasoningSummary?: "auto" | "detailed" | "concise" | null;
} }
export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> { /**
private client: OpenAI; * Generate function for OpenAI Responses API
private modelInfo: Model; */
export const streamOpenAIResponses: GenerateFunction<"openai-responses"> = (
model: Model<"openai-responses">,
context: Context,
options?: OpenAIResponsesOptions,
): GenerateStream => {
const stream = new QueuedGenerateStream();
constructor(model: Model, apiKey?: string) { // Start async processing
if (!apiKey) { (async () => {
if (!process.env.OPENAI_API_KEY) {
throw new Error(
"OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass it as an argument.",
);
}
apiKey = process.env.OPENAI_API_KEY;
}
this.client = new OpenAI({ apiKey, baseURL: model.baseUrl, dangerouslyAllowBrowser: true });
this.modelInfo = model;
}
getModel(): Model {
return this.modelInfo;
}
getApi(): string {
return "openai-responses";
}
async generate(request: Context, options?: OpenAIResponsesLLMOptions): Promise<AssistantMessage> {
const output: AssistantMessage = { const output: AssistantMessage = {
role: "assistant", role: "assistant",
content: [], content: [],
api: this.getApi(), api: "openai-responses" as Api,
provider: this.modelInfo.provider, provider: model.provider,
model: this.modelInfo.id, model: model.id,
usage: { usage: {
input: 0, input: 0,
output: 0, output: 0,
@ -71,77 +62,31 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
}, },
stopReason: "stop", stopReason: "stop",
}; };
try { try {
const input = this.convertToInput(request.messages, request.systemPrompt); // Create OpenAI client
const client = createClient(model, options?.apiKey);
const params = buildParams(model, context, options);
const openaiStream = await client.responses.create(params, { signal: options?.signal });
stream.push({ type: "start", partial: output });
const params: ResponseCreateParamsStreaming = {
model: this.modelInfo.id,
input,
stream: true,
};
if (options?.maxTokens) {
params.max_output_tokens = options?.maxTokens;
}
if (options?.temperature !== undefined) {
params.temperature = options?.temperature;
}
if (request.tools) {
params.tools = this.convertTools(request.tools);
}
// Add reasoning options for models that support it
if (this.modelInfo?.reasoning) {
if (options?.reasoningEffort || options?.reasoningSummary) {
params.reasoning = {
effort: options?.reasoningEffort || "medium",
summary: options?.reasoningSummary || "auto",
};
params.include = ["reasoning.encrypted_content"];
} else {
params.reasoning = {
effort: this.modelInfo.name.startsWith("gpt-5") ? "minimal" : null,
summary: null,
};
if (this.modelInfo.name.startsWith("gpt-5")) {
// Jesus Christ, see https://community.openai.com/t/need-reasoning-false-option-for-gpt-5/1351588/7
input.push({
role: "developer",
content: [
{
type: "input_text",
text: "# Juice: 0 !important",
},
],
});
}
}
}
const stream = await this.client.responses.create(params, {
signal: options?.signal,
});
options?.onEvent?.({ type: "start", model: this.modelInfo.id, provider: this.modelInfo.provider });
const outputItems: (ResponseReasoningItem | ResponseOutputMessage | ResponseFunctionToolCall)[] = [];
let currentItem: ResponseReasoningItem | ResponseOutputMessage | ResponseFunctionToolCall | null = null; let currentItem: ResponseReasoningItem | ResponseOutputMessage | ResponseFunctionToolCall | null = null;
let currentBlock: ThinkingContent | TextContent | ToolCall | null = null;
for await (const event of stream) { for await (const event of openaiStream) {
// Handle output item start // Handle output item start
if (event.type === "response.output_item.added") { if (event.type === "response.output_item.added") {
const item = event.item; const item = event.item;
if (item.type === "reasoning") { if (item.type === "reasoning") {
options?.onEvent?.({ type: "thinking_start" });
outputItems.push(item);
currentItem = item; currentItem = item;
currentBlock = { type: "thinking", thinking: "" };
output.content.push(currentBlock);
stream.push({ type: "thinking_start", partial: output });
} else if (item.type === "message") { } else if (item.type === "message") {
options?.onEvent?.({ type: "text_start" });
outputItems.push(item);
currentItem = item; currentItem = item;
currentBlock = { type: "text", text: "" };
output.content.push(currentBlock);
stream.push({ type: "text_start", partial: output });
} }
} }
// Handle reasoning summary deltas // Handle reasoning summary deltas
@ -151,30 +96,42 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
currentItem.summary.push(event.part); currentItem.summary.push(event.part);
} }
} else if (event.type === "response.reasoning_summary_text.delta") { } else if (event.type === "response.reasoning_summary_text.delta") {
if (currentItem && currentItem.type === "reasoning") { if (
currentItem &&
currentItem.type === "reasoning" &&
currentBlock &&
currentBlock.type === "thinking"
) {
currentItem.summary = currentItem.summary || []; currentItem.summary = currentItem.summary || [];
const lastPart = currentItem.summary[currentItem.summary.length - 1]; const lastPart = currentItem.summary[currentItem.summary.length - 1];
if (lastPart) { if (lastPart) {
currentBlock.thinking += event.delta;
lastPart.text += event.delta; lastPart.text += event.delta;
options?.onEvent?.({ stream.push({
type: "thinking_delta", type: "thinking_delta",
content: currentItem.summary.map((s) => s.text).join("\n\n"),
delta: event.delta, delta: event.delta,
partial: output,
}); });
} }
} }
} }
// Add a new line between summary parts (hack...) // Add a new line between summary parts (hack...)
else if (event.type === "response.reasoning_summary_part.done") { else if (event.type === "response.reasoning_summary_part.done") {
if (currentItem && currentItem.type === "reasoning") { if (
currentItem &&
currentItem.type === "reasoning" &&
currentBlock &&
currentBlock.type === "thinking"
) {
currentItem.summary = currentItem.summary || []; currentItem.summary = currentItem.summary || [];
const lastPart = currentItem.summary[currentItem.summary.length - 1]; const lastPart = currentItem.summary[currentItem.summary.length - 1];
if (lastPart) { if (lastPart) {
currentBlock.thinking += "\n\n";
lastPart.text += "\n\n"; lastPart.text += "\n\n";
options?.onEvent?.({ stream.push({
type: "thinking_delta", type: "thinking_delta",
content: currentItem.summary.map((s) => s.text).join("\n\n"),
delta: "\n\n", delta: "\n\n",
partial: output,
}); });
} }
} }
@ -186,30 +143,28 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
currentItem.content.push(event.part); currentItem.content.push(event.part);
} }
} else if (event.type === "response.output_text.delta") { } else if (event.type === "response.output_text.delta") {
if (currentItem && currentItem.type === "message") { if (currentItem && currentItem.type === "message" && currentBlock && currentBlock.type === "text") {
const lastPart = currentItem.content[currentItem.content.length - 1]; const lastPart = currentItem.content[currentItem.content.length - 1];
if (lastPart && lastPart.type === "output_text") { if (lastPart && lastPart.type === "output_text") {
currentBlock.text += event.delta;
lastPart.text += event.delta; lastPart.text += event.delta;
options?.onEvent?.({ stream.push({
type: "text_delta", type: "text_delta",
content: currentItem.content
.map((c) => (c.type === "output_text" ? c.text : c.refusal))
.join(""),
delta: event.delta, delta: event.delta,
partial: output,
}); });
} }
} }
} else if (event.type === "response.refusal.delta") { } else if (event.type === "response.refusal.delta") {
if (currentItem && currentItem.type === "message") { if (currentItem && currentItem.type === "message" && currentBlock && currentBlock.type === "text") {
const lastPart = currentItem.content[currentItem.content.length - 1]; const lastPart = currentItem.content[currentItem.content.length - 1];
if (lastPart && lastPart.type === "refusal") { if (lastPart && lastPart.type === "refusal") {
currentBlock.text += event.delta;
lastPart.refusal += event.delta; lastPart.refusal += event.delta;
options?.onEvent?.({ stream.push({
type: "text_delta", type: "text_delta",
content: currentItem.content
.map((c) => (c.type === "output_text" ? c.text : c.refusal))
.join(""),
delta: event.delta, delta: event.delta,
partial: output,
}); });
} }
} }
@ -218,14 +173,24 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
else if (event.type === "response.output_item.done") { else if (event.type === "response.output_item.done") {
const item = event.item; const item = event.item;
if (item.type === "reasoning") { if (item.type === "reasoning" && currentBlock && currentBlock.type === "thinking") {
outputItems[outputItems.length - 1] = item; // Update with final item currentBlock.thinking = item.summary?.map((s) => s.text).join("\n\n") || "";
const thinkingContent = item.summary?.map((s) => s.text).join("\n\n") || ""; currentBlock.thinkingSignature = JSON.stringify(item);
options?.onEvent?.({ type: "thinking_end", content: thinkingContent }); stream.push({
} else if (item.type === "message") { type: "thinking_end",
outputItems[outputItems.length - 1] = item; // Update with final item content: currentBlock.thinking,
const textContent = item.content.map((c) => (c.type === "output_text" ? c.text : c.refusal)).join(""); partial: output,
options?.onEvent?.({ type: "text_end", content: textContent }); });
currentBlock = null;
} else if (item.type === "message" && currentBlock && currentBlock.type === "text") {
currentBlock.text = item.content.map((c) => (c.type === "output_text" ? c.text : c.refusal)).join("");
currentBlock.textSignature = item.id;
stream.push({
type: "text_end",
content: currentBlock.text,
partial: output,
});
currentBlock = null;
} else if (item.type === "function_call") { } else if (item.type === "function_call") {
const toolCall: ToolCall = { const toolCall: ToolCall = {
type: "toolCall", type: "toolCall",
@ -233,8 +198,8 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
name: item.name, name: item.name,
arguments: JSON.parse(item.arguments), arguments: JSON.parse(item.arguments),
}; };
options?.onEvent?.({ type: "toolCall", toolCall }); output.content.push(toolCall);
outputItems.push(item); stream.push({ type: "toolCall", toolCall, partial: output });
} }
} }
// Handle completion // Handle completion
@ -249,10 +214,10 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
}; };
} }
calculateCost(this.modelInfo, output.usage); calculateCost(model, output.usage);
// Map status to stop reason // Map status to stop reason
output.stopReason = this.mapStopReason(response?.status); output.stopReason = mapStopReason(response?.status);
if (outputItems.some((b) => b.type === "function_call") && output.stopReason === "stop") { if (output.content.some((b) => b.type === "toolCall") && output.stopReason === "stop") {
output.stopReason = "toolUse"; output.stopReason = "toolUse";
} }
} }
@ -260,173 +225,215 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
else if (event.type === "error") { else if (event.type === "error") {
output.stopReason = "error"; output.stopReason = "error";
output.error = `Code ${event.code}: ${event.message}` || "Unknown error"; output.error = `Code ${event.code}: ${event.message}` || "Unknown error";
options?.onEvent?.({ type: "error", error: output.error }); stream.push({ type: "error", error: output.error, partial: output });
stream.end();
return output; return output;
} else if (event.type === "response.failed") { } else if (event.type === "response.failed") {
output.stopReason = "error"; output.stopReason = "error";
output.error = "Unknown error"; output.error = "Unknown error";
options?.onEvent?.({ type: "error", error: output.error }); stream.push({ type: "error", error: output.error, partial: output });
stream.end();
return output; return output;
} }
} }
// Convert output items to blocks
for (const item of outputItems) {
if (item.type === "reasoning") {
output.content.push({
type: "thinking",
thinking: item.summary?.map((s: any) => s.text).join("\n\n") || "",
thinkingSignature: JSON.stringify(item), // Full item for resubmission
});
} else if (item.type === "message") {
output.content.push({
type: "text",
text: item.content.map((c) => (c.type === "output_text" ? c.text : c.refusal)).join(""),
textSignature: item.id, // ID for resubmission
});
} else if (item.type === "function_call") {
output.content.push({
type: "toolCall",
id: item.call_id + "|" + item.id,
name: item.name,
arguments: JSON.parse(item.arguments),
});
}
}
if (options?.signal?.aborted) { if (options?.signal?.aborted) {
throw new Error("Request was aborted"); throw new Error("Request was aborted");
} }
options?.onEvent?.({ type: "done", reason: output.stopReason, message: output }); stream.push({ type: "done", reason: output.stopReason, message: output });
return output; stream.end();
} catch (error) { } catch (error) {
output.stopReason = "error"; output.stopReason = "error";
output.error = error instanceof Error ? error.message : JSON.stringify(error); output.error = error instanceof Error ? error.message : JSON.stringify(error);
options?.onEvent?.({ type: "error", error: output.error }); stream.push({ type: "error", error: output.error, partial: output });
return output; stream.end();
} }
})();
return stream;
};
function createClient(model: Model<"openai-responses">, apiKey?: string) {
if (!apiKey) {
if (!process.env.OPENAI_API_KEY) {
throw new Error(
"OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass it as an argument.",
);
}
apiKey = process.env.OPENAI_API_KEY;
}
return new OpenAI({ apiKey, baseURL: model.baseUrl, dangerouslyAllowBrowser: true });
}
function buildParams(model: Model<"openai-responses">, context: Context, options?: OpenAIResponsesOptions) {
const messages = convertMessages(model, context);
const params: ResponseCreateParamsStreaming = {
model: model.id,
input: messages,
stream: true,
};
if (options?.maxTokens) {
params.max_output_tokens = options?.maxTokens;
} }
private convertToInput(messages: Message[], systemPrompt?: string): ResponseInput { if (options?.temperature !== undefined) {
const input: ResponseInput = []; params.temperature = options?.temperature;
}
// Transform messages for cross-provider compatibility if (context.tools) {
const transformedMessages = transformMessages(messages, this.modelInfo, this.getApi()); params.tools = convertTools(context.tools);
}
// Add system prompt if provided if (model.reasoning) {
if (systemPrompt) { if (options?.reasoningEffort || options?.reasoningSummary) {
const role = this.modelInfo?.reasoning ? "developer" : "system"; params.reasoning = {
input.push({ effort: options?.reasoningEffort || "medium",
role, summary: options?.reasoningSummary || "auto",
content: systemPrompt, };
}); params.include = ["reasoning.encrypted_content"];
} } else {
params.reasoning = {
effort: model.name.startsWith("gpt-5") ? "minimal" : null,
summary: null,
};
// Convert messages if (model.name.startsWith("gpt-5")) {
for (const msg of transformedMessages) { // Jesus Christ, see https://community.openai.com/t/need-reasoning-false-option-for-gpt-5/1351588/7
if (msg.role === "user") { messages.push({
// Handle both string and array content role: "developer",
if (typeof msg.content === "string") { content: [
input.push({ {
role: "user", type: "input_text",
content: [{ type: "input_text", text: msg.content }], text: "# Juice: 0 !important",
}); },
} else { ],
// Convert array content to OpenAI Responses format
const content: ResponseInputContent[] = msg.content.map((item): ResponseInputContent => {
if (item.type === "text") {
return {
type: "input_text",
text: item.text,
} satisfies ResponseInputText;
} else {
// Image content - OpenAI Responses uses data URLs
return {
type: "input_image",
detail: "auto",
image_url: `data:${item.mimeType};base64,${item.data}`,
} satisfies ResponseInputImage;
}
});
const filteredContent = !this.modelInfo?.input.includes("image")
? content.filter((c) => c.type !== "input_image")
: content;
input.push({
role: "user",
content: filteredContent,
});
}
} else if (msg.role === "assistant") {
// Process content blocks in order
const output: ResponseInput = [];
for (const block of msg.content) {
// Do not submit thinking blocks if the completion had an error (i.e. abort)
if (block.type === "thinking" && msg.stopReason !== "error") {
// Push the full reasoning item(s) from signature
if (block.thinkingSignature) {
const reasoningItem = JSON.parse(block.thinkingSignature);
output.push(reasoningItem);
}
} else if (block.type === "text") {
const textBlock = block as TextContent;
output.push({
type: "message",
role: "assistant",
content: [{ type: "output_text", text: textBlock.text, annotations: [] }],
status: "completed",
id: textBlock.textSignature || "msg_" + Math.random().toString(36).substring(2, 15),
} satisfies ResponseOutputMessage);
// Do not submit thinking blocks if the completion had an error (i.e. abort)
} else if (block.type === "toolCall" && msg.stopReason !== "error") {
const toolCall = block as ToolCall;
output.push({
type: "function_call",
id: toolCall.id.split("|")[1], // Extract original ID
call_id: toolCall.id.split("|")[0], // Extract call session ID
name: toolCall.name,
arguments: JSON.stringify(toolCall.arguments),
});
}
}
// Add all output items to input
input.push(...output);
} else if (msg.role === "toolResult") {
// Tool results are sent as function_call_output
input.push({
type: "function_call_output",
call_id: msg.toolCallId.split("|")[0], // Extract call session ID
output: msg.content,
}); });
} }
} }
return input;
} }
private convertTools(tools: Tool[]): OpenAITool[] { return params;
return tools.map((tool) => ({ }
type: "function",
name: tool.name, function convertMessages(model: Model<"openai-responses">, context: Context): ResponseInput {
description: tool.description, const messages: ResponseInput = [];
parameters: tool.parameters,
strict: null, const transformedMessages = transformMessages(context.messages, model);
}));
if (context.systemPrompt) {
const role = model.reasoning ? "developer" : "system";
messages.push({
role,
content: context.systemPrompt,
});
} }
private mapStopReason(status: string | undefined): StopReason { for (const msg of transformedMessages) {
switch (status) { if (msg.role === "user") {
case "completed": if (typeof msg.content === "string") {
return "stop"; messages.push({
case "incomplete": role: "user",
return "length"; content: [{ type: "input_text", text: msg.content }],
case "failed": });
case "cancelled": } else {
return "error"; const content: ResponseInputContent[] = msg.content.map((item): ResponseInputContent => {
default: if (item.type === "text") {
return "stop"; return {
type: "input_text",
text: item.text,
} satisfies ResponseInputText;
} else {
return {
type: "input_image",
detail: "auto",
image_url: `data:${item.mimeType};base64,${item.data}`,
} satisfies ResponseInputImage;
}
});
const filteredContent = !model.input.includes("image")
? content.filter((c) => c.type !== "input_image")
: content;
if (filteredContent.length === 0) continue;
messages.push({
role: "user",
content: filteredContent,
});
}
} else if (msg.role === "assistant") {
const output: ResponseInput = [];
for (const block of msg.content) {
// Do not submit thinking blocks if the completion had an error (i.e. abort)
if (block.type === "thinking" && msg.stopReason !== "error") {
if (block.thinkingSignature) {
const reasoningItem = JSON.parse(block.thinkingSignature);
output.push(reasoningItem);
}
} else if (block.type === "text") {
const textBlock = block as TextContent;
output.push({
type: "message",
role: "assistant",
content: [{ type: "output_text", text: textBlock.text, annotations: [] }],
status: "completed",
id: textBlock.textSignature || "msg_" + Math.random().toString(36).substring(2, 15),
} satisfies ResponseOutputMessage);
// Do not submit toolcall blocks if the completion had an error (i.e. abort)
} else if (block.type === "toolCall" && msg.stopReason !== "error") {
const toolCall = block as ToolCall;
output.push({
type: "function_call",
id: toolCall.id.split("|")[1],
call_id: toolCall.id.split("|")[0],
name: toolCall.name,
arguments: JSON.stringify(toolCall.arguments),
});
}
}
if (output.length === 0) continue;
messages.push(...output);
} else if (msg.role === "toolResult") {
messages.push({
type: "function_call_output",
call_id: msg.toolCallId.split("|")[0],
output: msg.content,
});
}
}
return messages;
}
function convertTools(tools: Tool[]): OpenAITool[] {
return tools.map((tool) => ({
type: "function",
name: tool.name,
description: tool.description,
parameters: tool.parameters,
strict: null,
}));
}
function mapStopReason(status: OpenAI.Responses.ResponseStatus | undefined): StopReason {
if (!status) return "stop";
switch (status) {
case "completed":
return "stop";
case "incomplete":
return "length";
case "failed":
case "cancelled":
return "error";
// These two are wonky ...
case "in_progress":
case "queued":
return "stop";
default: {
const _exhaustive: never = status;
throw new Error(`Unhandled stop reason: ${_exhaustive}`);
} }
} }
} }

View file

@ -1,18 +1,6 @@
import type { AssistantMessage, Message, Model } from "../types.js"; import type { Api, AssistantMessage, Message, Model } from "../types.js";
/** export function transformMessages<TApi extends Api>(messages: Message[], model: Model<TApi>): Message[] {
* Transform messages for cross-provider compatibility.
*
* - User and toolResult messages are copied verbatim
* - Assistant messages:
* - If from the same provider/model, copied as-is
* - If from different provider/model, thinking blocks are converted to text blocks with <thinking></thinking> tags
*
* @param messages The messages to transform
* @param model The target model that will process these messages
* @returns A copy of the messages array with transformations applied
*/
export function transformMessages(messages: Message[], model: Model, api: string): Message[] {
return messages.map((msg) => { return messages.map((msg) => {
// User and toolResult messages pass through unchanged // User and toolResult messages pass through unchanged
if (msg.role === "user" || msg.role === "toolResult") { if (msg.role === "user" || msg.role === "toolResult") {
@ -24,7 +12,7 @@ export function transformMessages(messages: Message[], model: Model, api: string
const assistantMsg = msg as AssistantMessage; const assistantMsg = msg as AssistantMessage;
// If message is from the same provider and API, keep as is // If message is from the same provider and API, keep as is
if (assistantMsg.provider === model.provider && assistantMsg.api === api) { if (assistantMsg.provider === model.provider && assistantMsg.api === model.api) {
return msg; return msg;
} }
@ -47,8 +35,6 @@ export function transformMessages(messages: Message[], model: Model, api: string
content: transformedContent, content: transformedContent,
}; };
} }
// Should not reach here, but return as-is for safety
return msg; return msg;
}); });
} }

View file

@ -1,5 +1,27 @@
export type KnownApi = "openai-completions" | "openai-responses" | "anthropic-messages" | "google-generative-ai"; import type { AnthropicOptions } from "./providers/anthropic";
export type Api = KnownApi | string; import type { GoogleOptions } from "./providers/google";
import type { OpenAICompletionsOptions } from "./providers/openai-completions";
import type { OpenAIResponsesOptions } from "./providers/openai-responses";
export type Api = "openai-completions" | "openai-responses" | "anthropic-messages" | "google-generative-ai";
export interface ApiOptionsMap {
"anthropic-messages": AnthropicOptions;
"openai-completions": OpenAICompletionsOptions;
"openai-responses": OpenAIResponsesOptions;
"google-generative-ai": GoogleOptions;
}
// Compile-time exhaustiveness check - this will fail if ApiOptionsMap doesn't have all KnownApi keys
type _CheckExhaustive = ApiOptionsMap extends Record<Api, GenerateOptions>
? Record<Api, GenerateOptions> extends ApiOptionsMap
? true
: ["ApiOptionsMap is missing some KnownApi values", Exclude<Api, keyof ApiOptionsMap>]
: ["ApiOptionsMap doesn't extend Record<KnownApi, GenerateOptions>"];
const _exhaustive: _CheckExhaustive = true;
// Helper type to get options for a specific API
export type OptionsForApi<TApi extends Api> = ApiOptionsMap[TApi];
export type KnownProvider = "anthropic" | "google" | "openai" | "xai" | "groq" | "cerebras" | "openrouter"; export type KnownProvider = "anthropic" | "google" | "openai" | "xai" | "groq" | "cerebras" | "openrouter";
export type Provider = KnownProvider | string; export type Provider = KnownProvider | string;
@ -21,31 +43,17 @@ export interface GenerateOptions {
} }
// Unified options with reasoning (what public generate() accepts) // Unified options with reasoning (what public generate() accepts)
export interface GenerateOptionsUnified extends GenerateOptions { export interface SimpleGenerateOptions extends GenerateOptions {
reasoning?: ReasoningEffort; reasoning?: ReasoningEffort;
} }
// Generic GenerateFunction with typed options // Generic GenerateFunction with typed options
export type GenerateFunction<TOptions extends GenerateOptions = GenerateOptions> = ( export type GenerateFunction<TApi extends Api> = (
model: Model, model: Model<TApi>,
context: Context, context: Context,
options: TOptions, options: OptionsForApi<TApi>,
) => GenerateStream; ) => GenerateStream;
// Legacy LLM interface (to be removed)
export interface LLMOptions {
temperature?: number;
maxTokens?: number;
onEvent?: (event: AssistantMessageEvent) => void;
signal?: AbortSignal;
}
export interface LLM<T extends LLMOptions> {
generate(request: Context, options?: T): Promise<AssistantMessage>;
getModel(): Model;
getApi(): string;
}
export interface TextContent { export interface TextContent {
type: "text"; type: "text";
text: string; text: string;
@ -100,7 +108,7 @@ export interface AssistantMessage {
model: string; model: string;
usage: Usage; usage: Usage;
stopReason: StopReason; stopReason: StopReason;
error?: string | Error; error?: string;
} }
export interface ToolResultMessage { export interface ToolResultMessage {
@ -138,10 +146,10 @@ export type AssistantMessageEvent =
| { type: "error"; error: string; partial: AssistantMessage }; | { type: "error"; error: string; partial: AssistantMessage };
// Model interface for the unified model system // Model interface for the unified model system
export interface Model { export interface Model<TApi extends Api> {
id: string; id: string;
name: string; name: string;
api: Api; api: TApi;
provider: Provider; provider: Provider;
baseUrl: string; baseUrl: string;
reasoning: boolean; reasoning: boolean;

View file

@ -1,128 +1,103 @@
import { describe, it, beforeAll, expect } from "vitest"; import { beforeAll, describe, expect, it } from "vitest";
import { GoogleLLM } from "../src/providers/google.js"; import { complete, stream } from "../src/generate.js";
import { OpenAICompletionsLLM } from "../src/providers/openai-completions.js";
import { OpenAIResponsesLLM } from "../src/providers/openai-responses.js";
import { AnthropicLLM } from "../src/providers/anthropic.js";
import type { LLM, LLMOptions, Context } from "../src/types.js";
import { getModel } from "../src/models.js"; import { getModel } from "../src/models.js";
import type { Api, Context, Model, OptionsForApi } from "../src/types.js";
async function testAbortSignal<T extends LLMOptions>(llm: LLM<T>, options: T = {} as T) { async function testAbortSignal<TApi extends Api>(llm: Model<TApi>, options: OptionsForApi<TApi> = {}) {
const context: Context = { const context: Context = {
messages: [{ messages: [
role: "user", {
content: "What is 15 + 27? Think step by step. Then list 50 first names." role: "user",
}] content: "What is 15 + 27? Think step by step. Then list 50 first names.",
}; },
],
};
let abortFired = false; let abortFired = false;
const controller = new AbortController(); const controller = new AbortController();
const response = await llm.generate(context, { const response = await stream(llm, context, { ...options, signal: controller.signal });
...options, for await (const event of response) {
signal: controller.signal, if (abortFired) return;
onEvent: (event) => { setTimeout(() => controller.abort(), 3000);
// console.log(JSON.stringify(event, null, 2)); abortFired = true;
if (abortFired) return; break;
setTimeout(() => controller.abort(), 2000); }
abortFired = true; const msg = await response.finalMessage();
}
});
// If we get here without throwing, the abort didn't work // If we get here without throwing, the abort didn't work
expect(response.stopReason).toBe("error"); expect(msg.stopReason).toBe("error");
expect(response.content.length).toBeGreaterThan(0); expect(msg.content.length).toBeGreaterThan(0);
context.messages.push(response); context.messages.push(msg);
context.messages.push({ role: "user", content: "Please continue, but only generate 5 names." }); context.messages.push({ role: "user", content: "Please continue, but only generate 5 names." });
// Ensure we can still make requests after abort const followUp = await complete(llm, context, options);
const followUp = await llm.generate(context, options); expect(followUp.stopReason).toBe("stop");
expect(followUp.stopReason).toBe("stop"); expect(followUp.content.length).toBeGreaterThan(0);
expect(followUp.content.length).toBeGreaterThan(0);
} }
async function testImmediateAbort<T extends LLMOptions>(llm: LLM<T>, options: T = {} as T) { async function testImmediateAbort<TApi extends Api>(llm: Model<TApi>, options: OptionsForApi<TApi> = {}) {
const controller = new AbortController(); const controller = new AbortController();
// Abort immediately controller.abort();
controller.abort();
const context: Context = { const context: Context = {
messages: [{ role: "user", content: "Hello" }] messages: [{ role: "user", content: "Hello" }],
}; };
const response = await llm.generate(context, { const response = await complete(llm, context, { ...options, signal: controller.signal });
...options, expect(response.stopReason).toBe("error");
signal: controller.signal
});
expect(response.stopReason).toBe("error");
} }
describe("AI Providers Abort Tests", () => { describe("AI Providers Abort Tests", () => {
describe.skipIf(!process.env.GEMINI_API_KEY)("Google Provider Abort", () => { describe.skipIf(!process.env.GEMINI_API_KEY)("Google Provider Abort", () => {
let llm: GoogleLLM; const llm = getModel("google", "gemini-2.5-flash");
beforeAll(() => { it("should abort mid-stream", async () => {
llm = new GoogleLLM(getModel("google", "gemini-2.5-flash")!, process.env.GEMINI_API_KEY!); await testAbortSignal(llm, { thinking: { enabled: true } });
}); });
it("should abort mid-stream", async () => { it("should handle immediate abort", async () => {
await testAbortSignal(llm, { thinking: { enabled: true } }); await testImmediateAbort(llm, { thinking: { enabled: true } });
}); });
});
it("should handle immediate abort", async () => { describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Completions Provider Abort", () => {
await testImmediateAbort(llm, { thinking: { enabled: true } }); const llm: Model<"openai-completions"> = {
}); ...getModel("openai", "gpt-4o-mini")!,
}); api: "openai-completions",
};
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Completions Provider Abort", () => { it("should abort mid-stream", async () => {
let llm: OpenAICompletionsLLM; await testAbortSignal(llm);
});
beforeAll(() => { it("should handle immediate abort", async () => {
llm = new OpenAICompletionsLLM(getModel("openai", "gpt-4o-mini")!, process.env.OPENAI_API_KEY!); await testImmediateAbort(llm);
}); });
});
it("should abort mid-stream", async () => { describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Responses Provider Abort", () => {
await testAbortSignal(llm); const llm = getModel("openai", "gpt-5-mini");
});
it("should handle immediate abort", async () => { it("should abort mid-stream", async () => {
await testImmediateAbort(llm); await testAbortSignal(llm);
}); });
});
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Responses Provider Abort", () => { it("should handle immediate abort", async () => {
let llm: OpenAIResponsesLLM; await testImmediateAbort(llm);
});
});
beforeAll(() => { describe.skipIf(!process.env.ANTHROPIC_OAUTH_TOKEN)("Anthropic Provider Abort", () => {
const model = getModel("openai", "gpt-5-mini"); const llm = getModel("anthropic", "claude-opus-4-1-20250805");
if (!model) {
throw new Error("Model not found");
}
llm = new OpenAIResponsesLLM(model, process.env.OPENAI_API_KEY!);
});
it("should abort mid-stream", async () => { it("should abort mid-stream", async () => {
await testAbortSignal(llm, {}); await testAbortSignal(llm, { thinkingEnabled: true, thinkingBudgetTokens: 2048 });
}); });
it("should handle immediate abort", async () => { it("should handle immediate abort", async () => {
await testImmediateAbort(llm, {}); await testImmediateAbort(llm, { thinkingEnabled: true, thinkingBudgetTokens: 2048 });
}); });
}); });
describe.skipIf(!process.env.ANTHROPIC_OAUTH_TOKEN)("Anthropic Provider Abort", () => {
let llm: AnthropicLLM;
beforeAll(() => {
llm = new AnthropicLLM(getModel("anthropic", "claude-opus-4-1-20250805")!, process.env.ANTHROPIC_OAUTH_TOKEN!);
});
it("should abort mid-stream", async () => {
await testAbortSignal(llm, {thinking: { enabled: true, budgetTokens: 2048 }});
});
it("should handle immediate abort", async () => {
await testImmediateAbort(llm, {thinking: { enabled: true, budgetTokens: 2048 }});
});
});
}); });

View file

@ -1,313 +1,265 @@
import { describe, it, beforeAll, expect } from "vitest"; import { describe, expect, it } from "vitest";
import { GoogleLLM } from "../src/providers/google.js"; import { complete } from "../src/generate.js";
import { OpenAICompletionsLLM } from "../src/providers/openai-completions.js";
import { OpenAIResponsesLLM } from "../src/providers/openai-responses.js";
import { AnthropicLLM } from "../src/providers/anthropic.js";
import type { LLM, LLMOptions, Context, UserMessage, AssistantMessage } from "../src/types.js";
import { getModel } from "../src/models.js"; import { getModel } from "../src/models.js";
import type { Api, AssistantMessage, Context, Model, OptionsForApi, UserMessage } from "../src/types.js";
async function testEmptyMessage<T extends LLMOptions>(llm: LLM<T>, options: T = {} as T) { async function testEmptyMessage<TApi extends Api>(llm: Model<TApi>, options: OptionsForApi<TApi> = {}) {
// Test with completely empty content array // Test with completely empty content array
const emptyMessage: UserMessage = { const emptyMessage: UserMessage = {
role: "user", role: "user",
content: [] content: [],
}; };
const context: Context = { const context: Context = {
messages: [emptyMessage] messages: [emptyMessage],
}; };
const response = await llm.generate(context, options); const response = await complete(llm, context, options);
// Should either handle gracefully or return an error // Should either handle gracefully or return an error
expect(response).toBeDefined(); expect(response).toBeDefined();
expect(response.role).toBe("assistant"); expect(response.role).toBe("assistant");
// Should handle empty string gracefully
// Most providers should return an error or empty response if (response.stopReason === "error") {
if (response.stopReason === "error") { expect(response.error).toBeDefined();
expect(response.error).toBeDefined(); } else {
} else { expect(response.content).toBeDefined();
// If it didn't error, it should have some content or gracefully handle empty }
expect(response.content).toBeDefined();
}
} }
async function testEmptyStringMessage<T extends LLMOptions>(llm: LLM<T>, options: T = {} as T) { async function testEmptyStringMessage<TApi extends Api>(llm: Model<TApi>, options: OptionsForApi<TApi> = {}) {
// Test with empty string content // Test with empty string content
const context: Context = { const context: Context = {
messages: [{ messages: [
role: "user", {
content: "" role: "user",
}] content: "",
}; },
],
};
const response = await llm.generate(context, options); const response = await complete(llm, context, options);
expect(response).toBeDefined(); expect(response).toBeDefined();
expect(response.role).toBe("assistant"); expect(response.role).toBe("assistant");
// Should handle empty string gracefully // Should handle empty string gracefully
if (response.stopReason === "error") { if (response.stopReason === "error") {
expect(response.error).toBeDefined(); expect(response.error).toBeDefined();
} else { } else {
expect(response.content).toBeDefined(); expect(response.content).toBeDefined();
} }
} }
async function testWhitespaceOnlyMessage<T extends LLMOptions>(llm: LLM<T>, options: T = {} as T) { async function testWhitespaceOnlyMessage<TApi extends Api>(llm: Model<TApi>, options: OptionsForApi<TApi> = {}) {
// Test with whitespace-only content // Test with whitespace-only content
const context: Context = { const context: Context = {
messages: [{ messages: [
role: "user", {
content: " \n\t " role: "user",
}] content: " \n\t ",
}; },
],
};
const response = await llm.generate(context, options); const response = await complete(llm, context, options);
expect(response).toBeDefined(); expect(response).toBeDefined();
expect(response.role).toBe("assistant"); expect(response.role).toBe("assistant");
// Should handle whitespace-only gracefully // Should handle whitespace-only gracefully
if (response.stopReason === "error") { if (response.stopReason === "error") {
expect(response.error).toBeDefined(); expect(response.error).toBeDefined();
} else { } else {
expect(response.content).toBeDefined(); expect(response.content).toBeDefined();
} }
} }
async function testEmptyAssistantMessage<T extends LLMOptions>(llm: LLM<T>, options: T = {} as T) { async function testEmptyAssistantMessage<TApi extends Api>(llm: Model<TApi>, options: OptionsForApi<TApi> = {}) {
// Test with empty assistant message in conversation flow // Test with empty assistant message in conversation flow
// User -> Empty Assistant -> User // User -> Empty Assistant -> User
const emptyAssistant: AssistantMessage = { const emptyAssistant: AssistantMessage = {
role: "assistant", role: "assistant",
content: [], content: [],
api: llm.getApi(), api: llm.api,
provider: llm.getModel().provider, provider: llm.provider,
model: llm.getModel().id, model: llm.id,
usage: { usage: {
input: 10, input: 10,
output: 0, output: 0,
cacheRead: 0, cacheRead: 0,
cacheWrite: 0, cacheWrite: 0,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
}, },
stopReason: "stop" stopReason: "stop",
}; };
const context: Context = { const context: Context = {
messages: [ messages: [
{ {
role: "user", role: "user",
content: "Hello, how are you?" content: "Hello, how are you?",
}, },
emptyAssistant, emptyAssistant,
{ {
role: "user", role: "user",
content: "Please respond this time." content: "Please respond this time.",
} },
] ],
}; };
const response = await llm.generate(context, options); const response = await complete(llm, context, options);
expect(response).toBeDefined(); expect(response).toBeDefined();
expect(response.role).toBe("assistant"); expect(response.role).toBe("assistant");
// Should handle empty assistant message in context gracefully // Should handle empty assistant message in context gracefully
if (response.stopReason === "error") { if (response.stopReason === "error") {
expect(response.error).toBeDefined(); expect(response.error).toBeDefined();
} else { } else {
expect(response.content).toBeDefined(); expect(response.content).toBeDefined();
expect(response.content.length).toBeGreaterThan(0); expect(response.content.length).toBeGreaterThan(0);
} }
} }
describe("AI Providers Empty Message Tests", () => { describe("AI Providers Empty Message Tests", () => {
describe.skipIf(!process.env.GEMINI_API_KEY)("Google Provider Empty Messages", () => { describe.skipIf(!process.env.GEMINI_API_KEY)("Google Provider Empty Messages", () => {
let llm: GoogleLLM; const llm = getModel("google", "gemini-2.5-flash");
beforeAll(() => { it("should handle empty content array", async () => {
llm = new GoogleLLM(getModel("google", "gemini-2.5-flash")!, process.env.GEMINI_API_KEY!); await testEmptyMessage(llm);
}); });
it("should handle empty content array", async () => { it("should handle empty string content", async () => {
await testEmptyMessage(llm); await testEmptyStringMessage(llm);
}); });
it("should handle empty string content", async () => { it("should handle whitespace-only content", async () => {
await testEmptyStringMessage(llm); await testWhitespaceOnlyMessage(llm);
}); });
it("should handle whitespace-only content", async () => { it("should handle empty assistant message in conversation", async () => {
await testWhitespaceOnlyMessage(llm); await testEmptyAssistantMessage(llm);
}); });
});
it("should handle empty assistant message in conversation", async () => { describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Completions Provider Empty Messages", () => {
await testEmptyAssistantMessage(llm); const llm = getModel("openai", "gpt-4o-mini");
});
});
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Completions Provider Empty Messages", () => { it("should handle empty content array", async () => {
let llm: OpenAICompletionsLLM; await testEmptyMessage(llm);
});
beforeAll(() => { it("should handle empty string content", async () => {
llm = new OpenAICompletionsLLM(getModel("openai", "gpt-4o-mini")!, process.env.OPENAI_API_KEY!); await testEmptyStringMessage(llm);
}); });
it("should handle empty content array", async () => { it("should handle whitespace-only content", async () => {
await testEmptyMessage(llm); await testWhitespaceOnlyMessage(llm);
}); });
it("should handle empty string content", async () => { it("should handle empty assistant message in conversation", async () => {
await testEmptyStringMessage(llm); await testEmptyAssistantMessage(llm);
}); });
});
it("should handle whitespace-only content", async () => { describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Responses Provider Empty Messages", () => {
await testWhitespaceOnlyMessage(llm); const llm = getModel("openai", "gpt-5-mini");
});
it("should handle empty assistant message in conversation", async () => { it("should handle empty content array", async () => {
await testEmptyAssistantMessage(llm); await testEmptyMessage(llm);
}); });
});
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Responses Provider Empty Messages", () => { it("should handle empty string content", async () => {
let llm: OpenAIResponsesLLM; await testEmptyStringMessage(llm);
});
beforeAll(() => { it("should handle whitespace-only content", async () => {
const model = getModel("openai", "gpt-5-mini"); await testWhitespaceOnlyMessage(llm);
if (!model) { });
throw new Error("Model gpt-5-mini not found");
}
llm = new OpenAIResponsesLLM(model, process.env.OPENAI_API_KEY!);
});
it("should handle empty content array", async () => { it("should handle empty assistant message in conversation", async () => {
await testEmptyMessage(llm); await testEmptyAssistantMessage(llm);
}); });
});
it("should handle empty string content", async () => { describe.skipIf(!process.env.ANTHROPIC_OAUTH_TOKEN)("Anthropic Provider Empty Messages", () => {
await testEmptyStringMessage(llm); const llm = getModel("anthropic", "claude-3-5-haiku-20241022");
});
it("should handle whitespace-only content", async () => { it("should handle empty content array", async () => {
await testWhitespaceOnlyMessage(llm); await testEmptyMessage(llm);
}); });
it("should handle empty assistant message in conversation", async () => { it("should handle empty string content", async () => {
await testEmptyAssistantMessage(llm); await testEmptyStringMessage(llm);
}); });
});
describe.skipIf(!process.env.ANTHROPIC_OAUTH_TOKEN)("Anthropic Provider Empty Messages", () => { it("should handle whitespace-only content", async () => {
let llm: AnthropicLLM; await testWhitespaceOnlyMessage(llm);
});
beforeAll(() => { it("should handle empty assistant message in conversation", async () => {
llm = new AnthropicLLM(getModel("anthropic", "claude-3-5-haiku-20241022")!, process.env.ANTHROPIC_OAUTH_TOKEN!); await testEmptyAssistantMessage(llm);
}); });
});
it("should handle empty content array", async () => { describe.skipIf(!process.env.XAI_API_KEY)("xAI Provider Empty Messages", () => {
await testEmptyMessage(llm); const llm = getModel("xai", "grok-3");
});
it("should handle empty string content", async () => { it("should handle empty content array", async () => {
await testEmptyStringMessage(llm); await testEmptyMessage(llm);
}); });
it("should handle whitespace-only content", async () => { it("should handle empty string content", async () => {
await testWhitespaceOnlyMessage(llm); await testEmptyStringMessage(llm);
}); });
it("should handle empty assistant message in conversation", async () => { it("should handle whitespace-only content", async () => {
await testEmptyAssistantMessage(llm); await testWhitespaceOnlyMessage(llm);
}); });
});
// Test with xAI/Grok if available it("should handle empty assistant message in conversation", async () => {
describe.skipIf(!process.env.XAI_API_KEY)("xAI Provider Empty Messages", () => { await testEmptyAssistantMessage(llm);
let llm: OpenAICompletionsLLM; });
});
beforeAll(() => { describe.skipIf(!process.env.GROQ_API_KEY)("Groq Provider Empty Messages", () => {
const model = getModel("xai", "grok-3"); const llm = getModel("groq", "openai/gpt-oss-20b");
if (!model) {
throw new Error("Model grok-3 not found");
}
llm = new OpenAICompletionsLLM(model, process.env.XAI_API_KEY!);
});
it("should handle empty content array", async () => { it("should handle empty content array", async () => {
await testEmptyMessage(llm); await testEmptyMessage(llm);
}); });
it("should handle empty string content", async () => { it("should handle empty string content", async () => {
await testEmptyStringMessage(llm); await testEmptyStringMessage(llm);
}); });
it("should handle whitespace-only content", async () => { it("should handle whitespace-only content", async () => {
await testWhitespaceOnlyMessage(llm); await testWhitespaceOnlyMessage(llm);
}); });
it("should handle empty assistant message in conversation", async () => { it("should handle empty assistant message in conversation", async () => {
await testEmptyAssistantMessage(llm); await testEmptyAssistantMessage(llm);
}); });
}); });
// Test with Groq if available describe.skipIf(!process.env.CEREBRAS_API_KEY)("Cerebras Provider Empty Messages", () => {
describe.skipIf(!process.env.GROQ_API_KEY)("Groq Provider Empty Messages", () => { const llm = getModel("cerebras", "gpt-oss-120b");
let llm: OpenAICompletionsLLM;
beforeAll(() => { it("should handle empty content array", async () => {
const model = getModel("groq", "llama-3.3-70b-versatile"); await testEmptyMessage(llm);
if (!model) { });
throw new Error("Model llama-3.3-70b-versatile not found");
}
llm = new OpenAICompletionsLLM(model, process.env.GROQ_API_KEY!);
});
it("should handle empty content array", async () => { it("should handle empty string content", async () => {
await testEmptyMessage(llm); await testEmptyStringMessage(llm);
}); });
it("should handle empty string content", async () => { it("should handle whitespace-only content", async () => {
await testEmptyStringMessage(llm); await testWhitespaceOnlyMessage(llm);
}); });
it("should handle whitespace-only content", async () => { it("should handle empty assistant message in conversation", async () => {
await testWhitespaceOnlyMessage(llm); await testEmptyAssistantMessage(llm);
}); });
});
it("should handle empty assistant message in conversation", async () => {
await testEmptyAssistantMessage(llm);
});
});
// Test with Cerebras if available
describe.skipIf(!process.env.CEREBRAS_API_KEY)("Cerebras Provider Empty Messages", () => {
let llm: OpenAICompletionsLLM;
beforeAll(() => {
const model = getModel("cerebras", "gpt-oss-120b");
if (!model) {
throw new Error("Model gpt-oss-120b not found");
}
llm = new OpenAICompletionsLLM(model, process.env.CEREBRAS_API_KEY!);
});
it("should handle empty content array", async () => {
await testEmptyMessage(llm);
});
it("should handle empty string content", async () => {
await testEmptyStringMessage(llm);
});
it("should handle whitespace-only content", async () => {
await testWhitespaceOnlyMessage(llm);
});
it("should handle empty assistant message in conversation", async () => {
await testEmptyAssistantMessage(llm);
});
});
}); });

View file

@ -1,311 +1,612 @@
import { describe, it, beforeAll, expect } from "vitest"; import { type ChildProcess, execSync, spawn } from "child_process";
import { getModel } from "../src/models.js";
import { generate, generateComplete } from "../src/generate.js";
import type { Context, Tool, GenerateOptionsUnified, Model, ImageContent, GenerateStream, GenerateOptions } from "../src/types.js";
import { readFileSync } from "fs"; import { readFileSync } from "fs";
import { join, dirname } from "path"; import { dirname, join } from "path";
import { fileURLToPath } from "url"; import { fileURLToPath } from "url";
import { afterAll, beforeAll, describe, expect, it } from "vitest";
import { complete, stream } from "../src/generate.js";
import { getModel } from "../src/models.js";
import type { Api, Context, ImageContent, Model, OptionsForApi, Tool } from "../src/types.js";
const __filename = fileURLToPath(import.meta.url); const __filename = fileURLToPath(import.meta.url);
const __dirname = dirname(__filename); const __dirname = dirname(__filename);
// Calculator tool definition (same as examples) // Calculator tool definition (same as examples)
const calculatorTool: Tool = { const calculatorTool: Tool = {
name: "calculator", name: "calculator",
description: "Perform basic arithmetic operations", description: "Perform basic arithmetic operations",
parameters: { parameters: {
type: "object", type: "object",
properties: { properties: {
a: { type: "number", description: "First number" }, a: { type: "number", description: "First number" },
b: { type: "number", description: "Second number" }, b: { type: "number", description: "Second number" },
operation: { operation: {
type: "string", type: "string",
enum: ["add", "subtract", "multiply", "divide"], enum: ["add", "subtract", "multiply", "divide"],
description: "The operation to perform" description: "The operation to perform",
} },
}, },
required: ["a", "b", "operation"] required: ["a", "b", "operation"],
} },
}; };
async function basicTextGeneration<P extends GenerateOptions>(model: Model, options?: P) { async function basicTextGeneration<TApi extends Api>(model: Model<TApi>, options?: OptionsForApi<TApi>) {
const context: Context = { const context: Context = {
systemPrompt: "You are a helpful assistant. Be concise.", systemPrompt: "You are a helpful assistant. Be concise.",
messages: [ messages: [{ role: "user", content: "Reply with exactly: 'Hello test successful'" }],
{ role: "user", content: "Reply with exactly: 'Hello test successful'" } };
] const response = await complete(model, context, options);
};
const response = await generateComplete(model, context, options); expect(response.role).toBe("assistant");
expect(response.content).toBeTruthy();
expect(response.usage.input + response.usage.cacheRead).toBeGreaterThan(0);
expect(response.usage.output).toBeGreaterThan(0);
expect(response.error).toBeFalsy();
expect(response.content.map((b) => (b.type === "text" ? b.text : "")).join("")).toContain("Hello test successful");
expect(response.role).toBe("assistant"); context.messages.push(response);
expect(response.content).toBeTruthy(); context.messages.push({ role: "user", content: "Now say 'Goodbye test successful'" });
expect(response.usage.input + response.usage.cacheRead).toBeGreaterThan(0);
expect(response.usage.output).toBeGreaterThan(0);
expect(response.error).toBeFalsy();
expect(response.content.map(b => b.type == "text" ? b.text : "").join("")).toContain("Hello test successful");
context.messages.push(response); const secondResponse = await complete(model, context, options);
context.messages.push({ role: "user", content: "Now say 'Goodbye test successful'" });
const secondResponse = await generateComplete(model, context, options); expect(secondResponse.role).toBe("assistant");
expect(secondResponse.content).toBeTruthy();
expect(secondResponse.role).toBe("assistant"); expect(secondResponse.usage.input + secondResponse.usage.cacheRead).toBeGreaterThan(0);
expect(secondResponse.content).toBeTruthy(); expect(secondResponse.usage.output).toBeGreaterThan(0);
expect(secondResponse.usage.input + secondResponse.usage.cacheRead).toBeGreaterThan(0); expect(secondResponse.error).toBeFalsy();
expect(secondResponse.usage.output).toBeGreaterThan(0); expect(secondResponse.content.map((b) => (b.type === "text" ? b.text : "")).join("")).toContain(
expect(secondResponse.error).toBeFalsy(); "Goodbye test successful",
expect(secondResponse.content.map(b => b.type == "text" ? b.text : "").join("")).toContain("Goodbye test successful"); );
} }
async function handleToolCall(model: Model, options?: GenerateOptionsUnified) { async function handleToolCall<TApi extends Api>(model: Model<TApi>, options?: OptionsForApi<TApi>) {
const context: Context = { const context: Context = {
systemPrompt: "You are a helpful assistant that uses tools when asked.", systemPrompt: "You are a helpful assistant that uses tools when asked.",
messages: [{ messages: [
role: "user", {
content: "Calculate 15 + 27 using the calculator tool." role: "user",
}], content: "Calculate 15 + 27 using the calculator tool.",
tools: [calculatorTool] },
}; ],
tools: [calculatorTool],
};
const response = await generateComplete(model, context, options); const response = await complete(model, context, options);
expect(response.stopReason).toBe("toolUse"); expect(response.stopReason).toBe("toolUse");
expect(response.content.some(b => b.type == "toolCall")).toBeTruthy(); expect(response.content.some((b) => b.type === "toolCall")).toBeTruthy();
const toolCall = response.content.find(b => b.type == "toolCall"); const toolCall = response.content.find((b) => b.type === "toolCall");
if (toolCall && toolCall.type === "toolCall") { if (toolCall && toolCall.type === "toolCall") {
expect(toolCall.name).toBe("calculator"); expect(toolCall.name).toBe("calculator");
expect(toolCall.id).toBeTruthy(); expect(toolCall.id).toBeTruthy();
} }
} }
async function handleStreaming(model: Model, options?: GenerateOptionsUnified) { async function handleStreaming<TApi extends Api>(model: Model<TApi>, options?: OptionsForApi<TApi>) {
let textStarted = false; let textStarted = false;
let textChunks = ""; let textChunks = "";
let textCompleted = false; let textCompleted = false;
const context: Context = { const context: Context = {
messages: [{ role: "user", content: "Count from 1 to 3" }] messages: [{ role: "user", content: "Count from 1 to 3" }],
}; };
const stream = generate(model, context, options); const s = stream(model, context, options);
for await (const event of stream) { for await (const event of s) {
if (event.type === "text_start") { if (event.type === "text_start") {
textStarted = true; textStarted = true;
} else if (event.type === "text_delta") { } else if (event.type === "text_delta") {
textChunks += event.delta; textChunks += event.delta;
} else if (event.type === "text_end") { } else if (event.type === "text_end") {
textCompleted = true; textCompleted = true;
} }
} }
const response = await stream.finalMessage(); const response = await s.finalMessage();
expect(textStarted).toBe(true); expect(textStarted).toBe(true);
expect(textChunks.length).toBeGreaterThan(0); expect(textChunks.length).toBeGreaterThan(0);
expect(textCompleted).toBe(true); expect(textCompleted).toBe(true);
expect(response.content.some(b => b.type == "text")).toBeTruthy(); expect(response.content.some((b) => b.type === "text")).toBeTruthy();
} }
async function handleThinking(model: Model, options: GenerateOptionsUnified) { async function handleThinking<TApi extends Api>(model: Model<TApi>, options?: OptionsForApi<TApi>) {
let thinkingStarted = false; let thinkingStarted = false;
let thinkingChunks = ""; let thinkingChunks = "";
let thinkingCompleted = false; let thinkingCompleted = false;
const context: Context = { const context: Context = {
messages: [{ role: "user", content: `Think about ${(Math.random() * 255) | 0} + 27. Think step by step. Then output the result.` }] messages: [
}; {
role: "user",
content: `Think about ${(Math.random() * 255) | 0} + 27. Think step by step. Then output the result.`,
},
],
};
const stream = generate(model, context, options); const s = stream(model, context, options);
for await (const event of stream) { for await (const event of s) {
if (event.type === "thinking_start") { if (event.type === "thinking_start") {
thinkingStarted = true; thinkingStarted = true;
} else if (event.type === "thinking_delta") { } else if (event.type === "thinking_delta") {
thinkingChunks += event.delta; thinkingChunks += event.delta;
} else if (event.type === "thinking_end") { } else if (event.type === "thinking_end") {
thinkingCompleted = true; thinkingCompleted = true;
} }
} }
const response = await stream.finalMessage(); const response = await s.finalMessage();
expect(response.stopReason, `Error: ${response.error}`).toBe("stop"); expect(response.stopReason, `Error: ${response.error}`).toBe("stop");
expect(thinkingStarted).toBe(true); expect(thinkingStarted).toBe(true);
expect(thinkingChunks.length).toBeGreaterThan(0); expect(thinkingChunks.length).toBeGreaterThan(0);
expect(thinkingCompleted).toBe(true); expect(thinkingCompleted).toBe(true);
expect(response.content.some(b => b.type == "thinking")).toBeTruthy(); expect(response.content.some((b) => b.type === "thinking")).toBeTruthy();
} }
async function handleImage(model: Model, options?: GenerateOptionsUnified) { async function handleImage<TApi extends Api>(model: Model<TApi>, options?: OptionsForApi<TApi>) {
// Check if the model supports images // Check if the model supports images
if (!model.input.includes("image")) { if (!model.input.includes("image")) {
console.log(`Skipping image test - model ${model.id} doesn't support images`); console.log(`Skipping image test - model ${model.id} doesn't support images`);
return; return;
} }
// Read the test image // Read the test image
const imagePath = join(__dirname, "data", "red-circle.png"); const imagePath = join(__dirname, "data", "red-circle.png");
const imageBuffer = readFileSync(imagePath); const imageBuffer = readFileSync(imagePath);
const base64Image = imageBuffer.toString("base64"); const base64Image = imageBuffer.toString("base64");
const imageContent: ImageContent = { const imageContent: ImageContent = {
type: "image", type: "image",
data: base64Image, data: base64Image,
mimeType: "image/png", mimeType: "image/png",
}; };
const context: Context = { const context: Context = {
messages: [ messages: [
{ {
role: "user", role: "user",
content: [ content: [
{ type: "text", text: "What do you see in this image? Please describe the shape (circle, rectangle, square, triangle, ...) and color (red, blue, green, ...)." }, {
imageContent, type: "text",
], text: "What do you see in this image? Please describe the shape (circle, rectangle, square, triangle, ...) and color (red, blue, green, ...).",
}, },
], imageContent,
}; ],
},
],
};
const response = await generateComplete(model, context, options); const response = await complete(model, context, options);
// Check the response mentions red and circle // Check the response mentions red and circle
expect(response.content.length > 0).toBeTruthy(); expect(response.content.length > 0).toBeTruthy();
const textContent = response.content.find(b => b.type == "text"); const textContent = response.content.find((b) => b.type === "text");
if (textContent && textContent.type === "text") { if (textContent && textContent.type === "text") {
const lowerContent = textContent.text.toLowerCase(); const lowerContent = textContent.text.toLowerCase();
expect(lowerContent).toContain("red"); expect(lowerContent).toContain("red");
expect(lowerContent).toContain("circle"); expect(lowerContent).toContain("circle");
} }
} }
async function multiTurn(model: Model, options?: GenerateOptionsUnified) { async function multiTurn<TApi extends Api>(model: Model<TApi>, options?: OptionsForApi<TApi>) {
const context: Context = { const context: Context = {
systemPrompt: "You are a helpful assistant that can use tools to answer questions.", systemPrompt: "You are a helpful assistant that can use tools to answer questions.",
messages: [ messages: [
{ {
role: "user", role: "user",
content: "Think about this briefly, then calculate 42 * 17 and 453 + 434 using the calculator tool." content: "Think about this briefly, then calculate 42 * 17 and 453 + 434 using the calculator tool.",
} },
], ],
tools: [calculatorTool] tools: [calculatorTool],
}; };
// Collect all text content from all assistant responses // Collect all text content from all assistant responses
let allTextContent = ""; let allTextContent = "";
let hasSeenThinking = false; let hasSeenThinking = false;
let hasSeenToolCalls = false; let hasSeenToolCalls = false;
const maxTurns = 5; // Prevent infinite loops const maxTurns = 5; // Prevent infinite loops
for (let turn = 0; turn < maxTurns; turn++) { for (let turn = 0; turn < maxTurns; turn++) {
const response = await generateComplete(model, context, options); const response = await complete(model, context, options);
// Add the assistant response to context // Add the assistant response to context
context.messages.push(response); context.messages.push(response);
// Process content blocks // Process content blocks
for (const block of response.content) { for (const block of response.content) {
if (block.type === "text") { if (block.type === "text") {
allTextContent += block.text; allTextContent += block.text;
} else if (block.type === "thinking") { } else if (block.type === "thinking") {
hasSeenThinking = true; hasSeenThinking = true;
} else if (block.type === "toolCall") { } else if (block.type === "toolCall") {
hasSeenToolCalls = true; hasSeenToolCalls = true;
// Process the tool call // Process the tool call
expect(block.name).toBe("calculator"); expect(block.name).toBe("calculator");
expect(block.id).toBeTruthy(); expect(block.id).toBeTruthy();
expect(block.arguments).toBeTruthy(); expect(block.arguments).toBeTruthy();
const { a, b, operation } = block.arguments; const { a, b, operation } = block.arguments;
let result: number; let result: number;
switch (operation) { switch (operation) {
case "add": result = a + b; break; case "add":
case "multiply": result = a * b; break; result = a + b;
default: result = 0; break;
} case "multiply":
result = a * b;
break;
default:
result = 0;
}
// Add tool result to context // Add tool result to context
context.messages.push({ context.messages.push({
role: "toolResult", role: "toolResult",
toolCallId: block.id, toolCallId: block.id,
toolName: block.name, toolName: block.name,
content: `${result}`, content: `${result}`,
isError: false isError: false,
}); });
} }
} }
// If we got a stop response with text content, we're likely done // If we got a stop response with text content, we're likely done
expect(response.stopReason).not.toBe("error"); expect(response.stopReason).not.toBe("error");
if (response.stopReason === "stop") { if (response.stopReason === "stop") {
break; break;
} }
} }
// Verify we got either thinking content or tool calls (or both) // Verify we got either thinking content or tool calls (or both)
expect(hasSeenThinking || hasSeenToolCalls).toBe(true); expect(hasSeenThinking || hasSeenToolCalls).toBe(true);
// The accumulated text should reference both calculations // The accumulated text should reference both calculations
expect(allTextContent).toBeTruthy(); expect(allTextContent).toBeTruthy();
expect(allTextContent.includes("714")).toBe(true); expect(allTextContent.includes("714")).toBe(true);
expect(allTextContent.includes("887")).toBe(true); expect(allTextContent.includes("887")).toBe(true);
} }
describe("Generate E2E Tests", () => { describe("Generate E2E Tests", () => {
describe.skipIf(!process.env.ANTHROPIC_API_KEY)("Anthropic Provider (claude-3-5-haiku-20241022)", () => { describe.skipIf(!process.env.GEMINI_API_KEY)("Gemini Provider (gemini-2.5-flash)", () => {
let model: Model; const llm = getModel("google", "gemini-2.5-flash");
beforeAll(() => { it("should complete basic text generation", async () => {
model = getModel("anthropic", "claude-3-5-haiku-20241022"); await basicTextGeneration(llm);
}); });
it("should complete basic text generation", async () => { it("should handle tool calling", async () => {
await basicTextGeneration(model); await handleToolCall(llm);
}); });
it("should handle tool calling", async () => { it("should handle streaming", async () => {
await handleToolCall(model); await handleStreaming(llm);
}); });
it("should handle streaming", async () => { it("should handle ", async () => {
await handleStreaming(model); await handleThinking(llm, { thinking: { enabled: true, budgetTokens: 1024 } });
}); });
it("should handle image input", async () => { it("should handle multi-turn with thinking and tools", async () => {
await handleImage(model); await multiTurn(llm, { thinking: { enabled: true, budgetTokens: 2048 } });
}); });
});
describe.skipIf(!process.env.ANTHROPIC_OAUTH_TOKEN)("Anthropic Provider (claude-sonnet-4-20250514)", () => { it("should handle image input", async () => {
let model: Model; await handleImage(llm);
});
});
beforeAll(() => { describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Completions Provider (gpt-4o-mini)", () => {
model = getModel("anthropic", "claude-sonnet-4-20250514"); const llm: Model<"openai-completions"> = { ...getModel("openai", "gpt-4o-mini"), api: "openai-completions" };
});
it("should complete basic text generation", async () => { it("should complete basic text generation", async () => {
await basicTextGeneration(model); await basicTextGeneration(llm);
}); });
it("should handle tool calling", async () => { it("should handle tool calling", async () => {
await handleToolCall(model); await handleToolCall(llm);
}); });
it("should handle streaming", async () => { it("should handle streaming", async () => {
await handleStreaming(model); await handleStreaming(llm);
}); });
it("should handle thinking mode", async () => { it("should handle image input", async () => {
await handleThinking(model, { reasoning: "low" }); await handleImage(llm);
}); });
});
it("should handle multi-turn with thinking and tools", async () => { describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Responses Provider (gpt-5-mini)", () => {
await multiTurn(model, { reasoning: "medium" }); const llm = getModel("openai", "gpt-5-mini");
});
it("should handle image input", async () => { it("should complete basic text generation", async () => {
await handleImage(model); await basicTextGeneration(llm);
}); });
});
it("should handle tool calling", async () => {
await handleToolCall(llm);
});
it("should handle streaming", async () => {
await handleStreaming(llm);
});
it("should handle ", { retry: 2 }, async () => {
await handleThinking(llm, { reasoningEffort: "medium" });
});
it("should handle multi-turn with thinking and tools", async () => {
await multiTurn(llm, { reasoningEffort: "medium" });
});
it("should handle image input", async () => {
await handleImage(llm);
});
});
describe.skipIf(!process.env.ANTHROPIC_API_KEY)("Anthropic Provider (claude-3-5-haiku-20241022)", () => {
const model = getModel("anthropic", "claude-3-5-haiku-20241022");
it("should complete basic text generation", async () => {
await basicTextGeneration(model, { thinkingEnabled: true });
});
it("should handle tool calling", async () => {
await handleToolCall(model);
});
it("should handle streaming", async () => {
await handleStreaming(model);
});
it("should handle image input", async () => {
await handleImage(model);
});
});
describe.skipIf(!process.env.ANTHROPIC_OAUTH_TOKEN)("Anthropic Provider (claude-sonnet-4-20250514)", () => {
const model = getModel("anthropic", "claude-sonnet-4-20250514");
it("should complete basic text generation", async () => {
await basicTextGeneration(model, { thinkingEnabled: true });
});
it("should handle tool calling", async () => {
await handleToolCall(model);
});
it("should handle streaming", async () => {
await handleStreaming(model);
});
it("should handle thinking", async () => {
await handleThinking(model, { thinkingEnabled: true });
});
it("should handle multi-turn with thinking and tools", async () => {
await multiTurn(model, { thinkingEnabled: true });
});
it("should handle image input", async () => {
await handleImage(model);
});
});
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Responses Provider (gpt-5-mini)", () => {
const model = getModel("openai", "gpt-5-mini");
it("should complete basic text generation", async () => {
await basicTextGeneration(model);
});
it("should handle tool calling", async () => {
await handleToolCall(model);
});
it("should handle streaming", async () => {
await handleStreaming(model);
});
it("should handle image input", async () => {
await handleImage(model);
});
});
describe.skipIf(!process.env.XAI_API_KEY)("xAI Provider (grok-code-fast-1 via OpenAI Completions)", () => {
const llm = getModel("xai", "grok-code-fast-1");
it("should complete basic text generation", async () => {
await basicTextGeneration(llm);
});
it("should handle tool calling", async () => {
await handleToolCall(llm);
});
it("should handle streaming", async () => {
await handleStreaming(llm);
});
it("should handle thinking mode", async () => {
await handleThinking(llm, { reasoningEffort: "medium" });
});
it("should handle multi-turn with thinking and tools", async () => {
await multiTurn(llm, { reasoningEffort: "medium" });
});
});
describe.skipIf(!process.env.GROQ_API_KEY)("Groq Provider (gpt-oss-20b via OpenAI Completions)", () => {
const llm = getModel("groq", "openai/gpt-oss-20b");
it("should complete basic text generation", async () => {
await basicTextGeneration(llm);
});
it("should handle tool calling", async () => {
await handleToolCall(llm);
});
it("should handle streaming", async () => {
await handleStreaming(llm);
});
it("should handle thinking mode", async () => {
await handleThinking(llm, { reasoningEffort: "medium" });
});
it("should handle multi-turn with thinking and tools", async () => {
await multiTurn(llm, { reasoningEffort: "medium" });
});
});
describe.skipIf(!process.env.CEREBRAS_API_KEY)("Cerebras Provider (gpt-oss-120b via OpenAI Completions)", () => {
const llm = getModel("cerebras", "gpt-oss-120b");
it("should complete basic text generation", async () => {
await basicTextGeneration(llm);
});
it("should handle tool calling", async () => {
await handleToolCall(llm);
});
it("should handle streaming", async () => {
await handleStreaming(llm);
});
it("should handle thinking mode", async () => {
await handleThinking(llm, { reasoningEffort: "medium" });
});
it("should handle multi-turn with thinking and tools", async () => {
await multiTurn(llm, { reasoningEffort: "medium" });
});
});
describe.skipIf(!process.env.OPENROUTER_API_KEY)("OpenRouter Provider (glm-4.5v via OpenAI Completions)", () => {
const llm = getModel("openrouter", "z-ai/glm-4.5v");
it("should complete basic text generation", async () => {
await basicTextGeneration(llm);
});
it("should handle tool calling", async () => {
await handleToolCall(llm);
});
it("should handle streaming", async () => {
await handleStreaming(llm);
});
it("should handle thinking mode", async () => {
await handleThinking(llm, { reasoningEffort: "medium" });
});
it("should handle multi-turn with thinking and tools", { retry: 2 }, async () => {
await multiTurn(llm, { reasoningEffort: "medium" });
});
it("should handle image input", async () => {
await handleImage(llm);
});
});
// Check if ollama is installed
let ollamaInstalled = false;
try {
execSync("which ollama", { stdio: "ignore" });
ollamaInstalled = true;
} catch {
ollamaInstalled = false;
}
describe.skipIf(!ollamaInstalled)("Ollama Provider (gpt-oss-20b via OpenAI Completions)", () => {
let llm: Model<"openai-completions">;
let ollamaProcess: ChildProcess | null = null;
beforeAll(async () => {
// Check if model is available, if not pull it
try {
execSync("ollama list | grep -q 'gpt-oss:20b'", { stdio: "ignore" });
} catch {
console.log("Pulling gpt-oss:20b model for Ollama tests...");
try {
execSync("ollama pull gpt-oss:20b", { stdio: "inherit" });
} catch (e) {
console.warn("Failed to pull gpt-oss:20b model, tests will be skipped");
return;
}
}
// Start ollama server
ollamaProcess = spawn("ollama", ["serve"], {
detached: false,
stdio: "ignore",
});
// Wait for server to be ready
await new Promise<void>((resolve) => {
const checkServer = async () => {
try {
const response = await fetch("http://localhost:11434/api/tags");
if (response.ok) {
resolve();
} else {
setTimeout(checkServer, 500);
}
} catch {
setTimeout(checkServer, 500);
}
};
setTimeout(checkServer, 1000); // Initial delay
});
llm = {
id: "gpt-oss:20b",
api: "openai-completions",
provider: "ollama",
baseUrl: "http://localhost:11434/v1",
reasoning: true,
input: ["text"],
contextWindow: 128000,
maxTokens: 16000,
cost: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
},
name: "Ollama GPT-OSS 20B",
};
}, 30000); // 30 second timeout for setup
afterAll(() => {
// Kill ollama server
if (ollamaProcess) {
ollamaProcess.kill("SIGTERM");
ollamaProcess = null;
}
});
it("should complete basic text generation", async () => {
await basicTextGeneration(llm, { apiKey: "test" });
});
it("should handle tool calling", async () => {
await handleToolCall(llm, { apiKey: "test" });
});
it("should handle streaming", async () => {
await handleStreaming(llm, { apiKey: "test" });
});
it("should handle thinking mode", async () => {
await handleThinking(llm, { apiKey: "test", reasoningEffort: "medium" });
});
it("should handle multi-turn with thinking and tools", async () => {
await multiTurn(llm, { apiKey: "test", reasoningEffort: "medium" });
});
});
}); });

View file

@ -1,503 +1,489 @@
import { describe, it, expect, beforeAll } from "vitest"; import { describe, expect, it } from "vitest";
import { GoogleLLM } from "../src/providers/google.js"; import { complete } from "../src/generate.js";
import { OpenAICompletionsLLM } from "../src/providers/openai-completions.js"; import { getModel } from "../src/models.js";
import { OpenAIResponsesLLM } from "../src/providers/openai-responses.js"; import type { Api, AssistantMessage, Context, Message, Model, Tool } from "../src/types.js";
import { AnthropicLLM } from "../src/providers/anthropic.js";
import type { LLM, Context, AssistantMessage, Tool, Message } from "../src/types.js";
import { createLLM, getModel } from "../src/models.js";
// Tool for testing // Tool for testing
const weatherTool: Tool = { const weatherTool: Tool = {
name: "get_weather", name: "get_weather",
description: "Get the weather for a location", description: "Get the weather for a location",
parameters: { parameters: {
type: "object", type: "object",
properties: { properties: {
location: { type: "string", description: "City name" } location: { type: "string", description: "City name" },
}, },
required: ["location"] required: ["location"],
} },
}; };
// Pre-built contexts representing typical outputs from each provider // Pre-built contexts representing typical outputs from each provider
const providerContexts = { const providerContexts = {
// Anthropic-style message with thinking block // Anthropic-style message with thinking block
anthropic: { anthropic: {
message: { message: {
role: "assistant", role: "assistant",
content: [ content: [
{ {
type: "thinking", type: "thinking",
thinking: "Let me calculate 17 * 23. That's 17 * 20 + 17 * 3 = 340 + 51 = 391", thinking: "Let me calculate 17 * 23. That's 17 * 20 + 17 * 3 = 340 + 51 = 391",
thinkingSignature: "signature_abc123" thinkingSignature: "signature_abc123",
}, },
{ {
type: "text", type: "text",
text: "I'll help you with the calculation and check the weather. The result of 17 × 23 is 391. The capital of Austria is Vienna. Now let me check the weather for you." text: "I'll help you with the calculation and check the weather. The result of 17 × 23 is 391. The capital of Austria is Vienna. Now let me check the weather for you.",
}, },
{ {
type: "toolCall", type: "toolCall",
id: "toolu_01abc123", id: "toolu_01abc123",
name: "get_weather", name: "get_weather",
arguments: { location: "Tokyo" } arguments: { location: "Tokyo" },
} },
], ],
provider: "anthropic", provider: "anthropic",
model: "claude-3-5-haiku-latest", model: "claude-3-5-haiku-latest",
usage: { input: 100, output: 50, cacheRead: 0, cacheWrite: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } }, usage: {
stopReason: "toolUse" input: 100,
} as AssistantMessage, output: 50,
toolResult: { cacheRead: 0,
role: "toolResult" as const, cacheWrite: 0,
toolCallId: "toolu_01abc123", cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
toolName: "get_weather", },
content: "Weather in Tokyo: 18°C, partly cloudy", stopReason: "toolUse",
isError: false } as AssistantMessage,
}, toolResult: {
facts: { role: "toolResult" as const,
calculation: 391, toolCallId: "toolu_01abc123",
city: "Tokyo", toolName: "get_weather",
temperature: 18, content: "Weather in Tokyo: 18°C, partly cloudy",
capital: "Vienna" isError: false,
} },
}, facts: {
calculation: 391,
city: "Tokyo",
temperature: 18,
capital: "Vienna",
},
},
// Google-style message with thinking // Google-style message with thinking
google: { google: {
message: { message: {
role: "assistant", role: "assistant",
content: [ content: [
{ {
type: "thinking", type: "thinking",
thinking: "I need to multiply 19 * 24. Let me work through this: 19 * 24 = 19 * 20 + 19 * 4 = 380 + 76 = 456", thinking:
thinkingSignature: undefined "I need to multiply 19 * 24. Let me work through this: 19 * 24 = 19 * 20 + 19 * 4 = 380 + 76 = 456",
}, thinkingSignature: undefined,
{ },
type: "text", {
text: "The multiplication of 19 × 24 equals 456. The capital of France is Paris. Let me check the weather in Berlin for you." type: "text",
}, text: "The multiplication of 19 × 24 equals 456. The capital of France is Paris. Let me check the weather in Berlin for you.",
{ },
type: "toolCall", {
id: "call_gemini_123", type: "toolCall",
name: "get_weather", id: "call_gemini_123",
arguments: { location: "Berlin" } name: "get_weather",
} arguments: { location: "Berlin" },
], },
provider: "google", ],
model: "gemini-2.5-flash", provider: "google",
usage: { input: 120, output: 60, cacheRead: 0, cacheWrite: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } }, model: "gemini-2.5-flash",
stopReason: "toolUse" usage: {
} as AssistantMessage, input: 120,
toolResult: { output: 60,
role: "toolResult" as const, cacheRead: 0,
toolCallId: "call_gemini_123", cacheWrite: 0,
toolName: "get_weather", cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
content: "Weather in Berlin: 22°C, sunny", },
isError: false stopReason: "toolUse",
}, } as AssistantMessage,
facts: { toolResult: {
calculation: 456, role: "toolResult" as const,
city: "Berlin", toolCallId: "call_gemini_123",
temperature: 22, toolName: "get_weather",
capital: "Paris" content: "Weather in Berlin: 22°C, sunny",
} isError: false,
}, },
facts: {
calculation: 456,
city: "Berlin",
temperature: 22,
capital: "Paris",
},
},
// OpenAI Completions style (with reasoning_content) // OpenAI Completions style (with reasoning_content)
openaiCompletions: { openaiCompletions: {
message: { message: {
role: "assistant", role: "assistant",
content: [ content: [
{ {
type: "thinking", type: "thinking",
thinking: "Let me calculate 21 * 25. That's 21 * 25 = 525", thinking: "Let me calculate 21 * 25. That's 21 * 25 = 525",
thinkingSignature: "reasoning_content" thinkingSignature: "reasoning_content",
}, },
{ {
type: "text", type: "text",
text: "The result of 21 × 25 is 525. The capital of Spain is Madrid. I'll check the weather in London now." text: "The result of 21 × 25 is 525. The capital of Spain is Madrid. I'll check the weather in London now.",
}, },
{ {
type: "toolCall", type: "toolCall",
id: "call_abc123", id: "call_abc123",
name: "get_weather", name: "get_weather",
arguments: { location: "London" } arguments: { location: "London" },
} },
], ],
provider: "openai", provider: "openai",
model: "gpt-4o-mini", model: "gpt-4o-mini",
usage: { input: 110, output: 55, cacheRead: 0, cacheWrite: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } }, usage: {
stopReason: "toolUse" input: 110,
} as AssistantMessage, output: 55,
toolResult: { cacheRead: 0,
role: "toolResult" as const, cacheWrite: 0,
toolCallId: "call_abc123", cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
toolName: "get_weather", },
content: "Weather in London: 15°C, rainy", stopReason: "toolUse",
isError: false } as AssistantMessage,
}, toolResult: {
facts: { role: "toolResult" as const,
calculation: 525, toolCallId: "call_abc123",
city: "London", toolName: "get_weather",
temperature: 15, content: "Weather in London: 15°C, rainy",
capital: "Madrid" isError: false,
} },
}, facts: {
calculation: 525,
city: "London",
temperature: 15,
capital: "Madrid",
},
},
// OpenAI Responses style (with complex tool call IDs) // OpenAI Responses style (with complex tool call IDs)
openaiResponses: { openaiResponses: {
message: { message: {
role: "assistant", role: "assistant",
content: [ content: [
{ {
type: "thinking", type: "thinking",
thinking: "Calculating 18 * 27: 18 * 27 = 486", thinking: "Calculating 18 * 27: 18 * 27 = 486",
thinkingSignature: '{"type":"reasoning","id":"rs_2b2342acdde","summary":[{"type":"summary_text","text":"Calculating 18 * 27: 18 * 27 = 486"}]}' thinkingSignature:
}, '{"type":"reasoning","id":"rs_2b2342acdde","summary":[{"type":"summary_text","text":"Calculating 18 * 27: 18 * 27 = 486"}]}',
{ },
type: "text", {
text: "The calculation of 18 × 27 gives us 486. The capital of Italy is Rome. Let me check Sydney's weather.", type: "text",
textSignature: "msg_response_456" text: "The calculation of 18 × 27 gives us 486. The capital of Italy is Rome. Let me check Sydney's weather.",
}, textSignature: "msg_response_456",
{ },
type: "toolCall", {
id: "call_789_item_012", // Anthropic requires alphanumeric, dash, and underscore only type: "toolCall",
name: "get_weather", id: "call_789_item_012", // Anthropic requires alphanumeric, dash, and underscore only
arguments: { location: "Sydney" } name: "get_weather",
} arguments: { location: "Sydney" },
], },
provider: "openai", ],
model: "gpt-5-mini", provider: "openai",
usage: { input: 115, output: 58, cacheRead: 0, cacheWrite: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } }, model: "gpt-5-mini",
stopReason: "toolUse" usage: {
} as AssistantMessage, input: 115,
toolResult: { output: 58,
role: "toolResult" as const, cacheRead: 0,
toolCallId: "call_789_item_012", // Match the updated ID format cacheWrite: 0,
toolName: "get_weather", cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
content: "Weather in Sydney: 25°C, clear", },
isError: false stopReason: "toolUse",
}, } as AssistantMessage,
facts: { toolResult: {
calculation: 486, role: "toolResult" as const,
city: "Sydney", toolCallId: "call_789_item_012", // Match the updated ID format
temperature: 25, toolName: "get_weather",
capital: "Rome" content: "Weather in Sydney: 25°C, clear",
} isError: false,
}, },
facts: {
calculation: 486,
city: "Sydney",
temperature: 25,
capital: "Rome",
},
},
// Aborted message (stopReason: 'error') // Aborted message (stopReason: 'error')
aborted: { aborted: {
message: { message: {
role: "assistant", role: "assistant",
content: [ content: [
{ {
type: "thinking", type: "thinking",
thinking: "Let me start calculating 20 * 30...", thinking: "Let me start calculating 20 * 30...",
thinkingSignature: "partial_sig" thinkingSignature: "partial_sig",
}, },
{ {
type: "text", type: "text",
text: "I was about to calculate 20 × 30 which is" text: "I was about to calculate 20 × 30 which is",
} },
], ],
provider: "test", provider: "test",
model: "test-model", model: "test-model",
usage: { input: 50, output: 25, cacheRead: 0, cacheWrite: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } }, usage: {
stopReason: "error", input: 50,
error: "Request was aborted" output: 25,
} as AssistantMessage, cacheRead: 0,
toolResult: null, cacheWrite: 0,
facts: { cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
calculation: 600, },
city: "none", stopReason: "error",
temperature: 0, error: "Request was aborted",
capital: "none" } as AssistantMessage,
} toolResult: null,
} facts: {
calculation: 600,
city: "none",
temperature: 0,
capital: "none",
},
},
}; };
/** /**
* Test that a provider can handle contexts from different sources * Test that a provider can handle contexts from different sources
*/ */
async function testProviderHandoff( async function testProviderHandoff<TApi extends Api>(
targetProvider: LLM<any>, targetModel: Model<TApi>,
sourceLabel: string, sourceLabel: string,
sourceContext: typeof providerContexts[keyof typeof providerContexts] sourceContext: (typeof providerContexts)[keyof typeof providerContexts],
): Promise<boolean> { ): Promise<boolean> {
// Build conversation context // Build conversation context
const messages: Message[] = [ const messages: Message[] = [
{ {
role: "user", role: "user",
content: "Please do some calculations, tell me about capitals, and check the weather." content: "Please do some calculations, tell me about capitals, and check the weather.",
}, },
sourceContext.message sourceContext.message,
]; ];
// Add tool result if present // Add tool result if present
if (sourceContext.toolResult) { if (sourceContext.toolResult) {
messages.push(sourceContext.toolResult); messages.push(sourceContext.toolResult);
} }
// Ask follow-up question // Ask follow-up question
messages.push({ messages.push({
role: "user", role: "user",
content: `Based on our conversation, please answer: content: `Based on our conversation, please answer:
1) What was the multiplication result? 1) What was the multiplication result?
2) Which city's weather did we check? 2) Which city's weather did we check?
3) What was the temperature? 3) What was the temperature?
4) What capital city was mentioned? 4) What capital city was mentioned?
Please include the specific numbers and names.` Please include the specific numbers and names.`,
}); });
const context: Context = { const context: Context = {
messages, messages,
tools: [weatherTool] tools: [weatherTool],
}; };
try { try {
const response = await targetProvider.generate(context, {}); const response = await complete(targetModel, context, {});
// Check for error // Check for error
if (response.stopReason === "error") { if (response.stopReason === "error") {
console.log(`[${sourceLabel}${targetProvider.getModel().provider}] Failed with error: ${response.error}`); console.log(`[${sourceLabel}${targetModel.provider}] Failed with error: ${response.error}`);
return false; return false;
} }
// Extract text from response // Extract text from response
const responseText = response.content const responseText = response.content
.filter(b => b.type === "text") .filter((b) => b.type === "text")
.map(b => b.text) .map((b) => b.text)
.join(" ") .join(" ")
.toLowerCase(); .toLowerCase();
// For aborted messages, we don't expect to find the facts // For aborted messages, we don't expect to find the facts
if (sourceContext.message.stopReason === "error") { if (sourceContext.message.stopReason === "error") {
const hasToolCalls = response.content.some(b => b.type === "toolCall"); const hasToolCalls = response.content.some((b) => b.type === "toolCall");
const hasThinking = response.content.some(b => b.type === "thinking"); const hasThinking = response.content.some((b) => b.type === "thinking");
const hasText = response.content.some(b => b.type === "text"); const hasText = response.content.some((b) => b.type === "text");
expect(response.stopReason === "stop" || response.stopReason === "toolUse").toBe(true); expect(response.stopReason === "stop" || response.stopReason === "toolUse").toBe(true);
expect(hasThinking || hasText || hasToolCalls).toBe(true); expect(hasThinking || hasText || hasToolCalls).toBe(true);
console.log(`[${sourceLabel}${targetProvider.getModel().provider}] Handled aborted message successfully, tool calls: ${hasToolCalls}, thinking: ${hasThinking}, text: ${hasText}`); console.log(
return true; `[${sourceLabel}${targetModel.provider}] Handled aborted message successfully, tool calls: ${hasToolCalls}, thinking: ${hasThinking}, text: ${hasText}`,
} );
return true;
}
// Check if response contains our facts // Check if response contains our facts
const hasCalculation = responseText.includes(sourceContext.facts.calculation.toString()); const hasCalculation = responseText.includes(sourceContext.facts.calculation.toString());
const hasCity = sourceContext.facts.city !== "none" && responseText.includes(sourceContext.facts.city.toLowerCase()); const hasCity =
const hasTemperature = sourceContext.facts.temperature > 0 && responseText.includes(sourceContext.facts.temperature.toString()); sourceContext.facts.city !== "none" && responseText.includes(sourceContext.facts.city.toLowerCase());
const hasCapital = sourceContext.facts.capital !== "none" && responseText.includes(sourceContext.facts.capital.toLowerCase()); const hasTemperature =
sourceContext.facts.temperature > 0 && responseText.includes(sourceContext.facts.temperature.toString());
const hasCapital =
sourceContext.facts.capital !== "none" && responseText.includes(sourceContext.facts.capital.toLowerCase());
const success = hasCalculation && hasCity && hasTemperature && hasCapital; const success = hasCalculation && hasCity && hasTemperature && hasCapital;
console.log(`[${sourceLabel}${targetProvider.getModel().provider}] Handoff test:`); console.log(`[${sourceLabel}${targetModel.provider}] Handoff test:`);
if (!success) { if (!success) {
console.log(` Calculation (${sourceContext.facts.calculation}): ${hasCalculation ? '✓' : '✗'}`); console.log(` Calculation (${sourceContext.facts.calculation}): ${hasCalculation ? "✓" : "✗"}`);
console.log(` City (${sourceContext.facts.city}): ${hasCity ? '✓' : '✗'}`); console.log(` City (${sourceContext.facts.city}): ${hasCity ? "✓" : "✗"}`);
console.log(` Temperature (${sourceContext.facts.temperature}): ${hasTemperature ? '✓' : '✗'}`); console.log(` Temperature (${sourceContext.facts.temperature}): ${hasTemperature ? "✓" : "✗"}`);
console.log(` Capital (${sourceContext.facts.capital}): ${hasCapital ? '✓' : '✗'}`); console.log(` Capital (${sourceContext.facts.capital}): ${hasCapital ? "✓" : "✗"}`);
} else { } else {
console.log(` ✓ All facts found`); console.log(` ✓ All facts found`);
} }
return success; return success;
} catch (error) { } catch (error) {
console.error(`[${sourceLabel}${targetProvider.getModel().provider}] Exception:`, error); console.error(`[${sourceLabel}${targetModel.provider}] Exception:`, error);
return false; return false;
} }
} }
describe("Cross-Provider Handoff Tests", () => { describe("Cross-Provider Handoff Tests", () => {
describe.skipIf(!process.env.ANTHROPIC_API_KEY)("Anthropic Provider Handoff", () => { describe.skipIf(!process.env.ANTHROPIC_API_KEY)("Anthropic Provider Handoff", () => {
let provider: AnthropicLLM; const model = getModel("anthropic", "claude-3-5-haiku-20241022");
beforeAll(() => { it("should handle contexts from all providers", async () => {
const model = getModel("anthropic", "claude-3-5-haiku-20241022"); console.log("\nTesting Anthropic with pre-built contexts:\n");
if (model) {
provider = new AnthropicLLM(model, process.env.ANTHROPIC_API_KEY!);
}
});
it("should handle contexts from all providers", async () => { const contextTests = [
if (!provider) { { label: "Anthropic-style", context: providerContexts.anthropic, sourceModel: "claude-3-5-haiku-20241022" },
console.log("Anthropic provider not available, skipping"); { label: "Google-style", context: providerContexts.google, sourceModel: "gemini-2.5-flash" },
return; { label: "OpenAI-Completions", context: providerContexts.openaiCompletions, sourceModel: "gpt-4o-mini" },
} { label: "OpenAI-Responses", context: providerContexts.openaiResponses, sourceModel: "gpt-5-mini" },
{ label: "Aborted", context: providerContexts.aborted, sourceModel: null },
];
console.log("\nTesting Anthropic with pre-built contexts:\n"); let successCount = 0;
let skippedCount = 0;
const contextTests = [ for (const { label, context, sourceModel } of contextTests) {
{ label: "Anthropic-style", context: providerContexts.anthropic, sourceModel: "claude-3-5-haiku-20241022" }, // Skip testing same model against itself
{ label: "Google-style", context: providerContexts.google, sourceModel: "gemini-2.5-flash" }, if (sourceModel && sourceModel === model.id) {
{ label: "OpenAI-Completions", context: providerContexts.openaiCompletions, sourceModel: "gpt-4o-mini" }, console.log(`[${label}${model.provider}] Skipping same-model test`);
{ label: "OpenAI-Responses", context: providerContexts.openaiResponses, sourceModel: "gpt-5-mini" }, skippedCount++;
{ label: "Aborted", context: providerContexts.aborted, sourceModel: null } continue;
]; }
const success = await testProviderHandoff(model, label, context);
if (success) successCount++;
}
let successCount = 0; const totalTests = contextTests.length - skippedCount;
let skippedCount = 0; console.log(`\nAnthropic success rate: ${successCount}/${totalTests} (${skippedCount} skipped)\n`);
for (const { label, context, sourceModel } of contextTests) { // All non-skipped handoffs should succeed
// Skip testing same model against itself expect(successCount).toBe(totalTests);
if (sourceModel && sourceModel === provider.getModel().id) { });
console.log(`[${label}${provider.getModel().provider}] Skipping same-model test`); });
skippedCount++;
continue;
}
const success = await testProviderHandoff(provider, label, context);
if (success) successCount++;
}
const totalTests = contextTests.length - skippedCount; describe.skipIf(!process.env.GEMINI_API_KEY)("Google Provider Handoff", () => {
console.log(`\nAnthropic success rate: ${successCount}/${totalTests} (${skippedCount} skipped)\n`); const model = getModel("google", "gemini-2.5-flash");
// All non-skipped handoffs should succeed it("should handle contexts from all providers", async () => {
expect(successCount).toBe(totalTests); console.log("\nTesting Google with pre-built contexts:\n");
});
});
describe.skipIf(!process.env.GEMINI_API_KEY)("Google Provider Handoff", () => { const contextTests = [
let provider: GoogleLLM; { label: "Anthropic-style", context: providerContexts.anthropic, sourceModel: "claude-3-5-haiku-20241022" },
{ label: "Google-style", context: providerContexts.google, sourceModel: "gemini-2.5-flash" },
{ label: "OpenAI-Completions", context: providerContexts.openaiCompletions, sourceModel: "gpt-4o-mini" },
{ label: "OpenAI-Responses", context: providerContexts.openaiResponses, sourceModel: "gpt-5-mini" },
{ label: "Aborted", context: providerContexts.aborted, sourceModel: null },
];
beforeAll(() => { let successCount = 0;
const model = getModel("google", "gemini-2.5-flash"); let skippedCount = 0;
if (model) {
provider = new GoogleLLM(model, process.env.GEMINI_API_KEY!);
}
});
it("should handle contexts from all providers", async () => { for (const { label, context, sourceModel } of contextTests) {
if (!provider) { // Skip testing same model against itself
console.log("Google provider not available, skipping"); if (sourceModel && sourceModel === model.id) {
return; console.log(`[${label}${model.provider}] Skipping same-model test`);
} skippedCount++;
continue;
}
const success = await testProviderHandoff(model, label, context);
if (success) successCount++;
}
console.log("\nTesting Google with pre-built contexts:\n"); const totalTests = contextTests.length - skippedCount;
console.log(`\nGoogle success rate: ${successCount}/${totalTests} (${skippedCount} skipped)\n`);
const contextTests = [ // All non-skipped handoffs should succeed
{ label: "Anthropic-style", context: providerContexts.anthropic, sourceModel: "claude-3-5-haiku-20241022" }, expect(successCount).toBe(totalTests);
{ label: "Google-style", context: providerContexts.google, sourceModel: "gemini-2.5-flash" }, });
{ label: "OpenAI-Completions", context: providerContexts.openaiCompletions, sourceModel: "gpt-4o-mini" }, });
{ label: "OpenAI-Responses", context: providerContexts.openaiResponses, sourceModel: "gpt-5-mini" },
{ label: "Aborted", context: providerContexts.aborted, sourceModel: null }
];
let successCount = 0; describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Completions Provider Handoff", () => {
let skippedCount = 0; const model: Model<"openai-completions"> = { ...getModel("openai", "gpt-4o-mini"), api: "openai-completions" };
for (const { label, context, sourceModel } of contextTests) { it("should handle contexts from all providers", async () => {
// Skip testing same model against itself console.log("\nTesting OpenAI Completions with pre-built contexts:\n");
if (sourceModel && sourceModel === provider.getModel().id) {
console.log(`[${label}${provider.getModel().provider}] Skipping same-model test`);
skippedCount++;
continue;
}
const success = await testProviderHandoff(provider, label, context);
if (success) successCount++;
}
const totalTests = contextTests.length - skippedCount; const contextTests = [
console.log(`\nGoogle success rate: ${successCount}/${totalTests} (${skippedCount} skipped)\n`); { label: "Anthropic-style", context: providerContexts.anthropic, sourceModel: "claude-3-5-haiku-20241022" },
{ label: "Google-style", context: providerContexts.google, sourceModel: "gemini-2.5-flash" },
{ label: "OpenAI-Completions", context: providerContexts.openaiCompletions, sourceModel: "gpt-4o-mini" },
{ label: "OpenAI-Responses", context: providerContexts.openaiResponses, sourceModel: "gpt-5-mini" },
{ label: "Aborted", context: providerContexts.aborted, sourceModel: null },
];
// All non-skipped handoffs should succeed let successCount = 0;
expect(successCount).toBe(totalTests); let skippedCount = 0;
});
});
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Completions Provider Handoff", () => { for (const { label, context, sourceModel } of contextTests) {
let provider: OpenAICompletionsLLM; // Skip testing same model against itself
if (sourceModel && sourceModel === model.id) {
console.log(`[${label}${model.provider}] Skipping same-model test`);
skippedCount++;
continue;
}
const success = await testProviderHandoff(model, label, context);
if (success) successCount++;
}
beforeAll(() => { const totalTests = contextTests.length - skippedCount;
const model = getModel("openai", "gpt-4o-mini"); console.log(`\nOpenAI Completions success rate: ${successCount}/${totalTests} (${skippedCount} skipped)\n`);
if (model) {
provider = new OpenAICompletionsLLM(model, process.env.OPENAI_API_KEY!);
}
});
it("should handle contexts from all providers", async () => { // All non-skipped handoffs should succeed
if (!provider) { expect(successCount).toBe(totalTests);
console.log("OpenAI Completions provider not available, skipping"); });
return; });
}
console.log("\nTesting OpenAI Completions with pre-built contexts:\n"); describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Responses Provider Handoff", () => {
const model = getModel("openai", "gpt-5-mini");
const contextTests = [ it("should handle contexts from all providers", async () => {
{ label: "Anthropic-style", context: providerContexts.anthropic, sourceModel: "claude-3-5-haiku-20241022" }, console.log("\nTesting OpenAI Responses with pre-built contexts:\n");
{ label: "Google-style", context: providerContexts.google, sourceModel: "gemini-2.5-flash" },
{ label: "OpenAI-Completions", context: providerContexts.openaiCompletions, sourceModel: "gpt-4o-mini" },
{ label: "OpenAI-Responses", context: providerContexts.openaiResponses, sourceModel: "gpt-5-mini" },
{ label: "Aborted", context: providerContexts.aborted, sourceModel: null }
];
let successCount = 0; const contextTests = [
let skippedCount = 0; { label: "Anthropic-style", context: providerContexts.anthropic, sourceModel: "claude-3-5-haiku-20241022" },
{ label: "Google-style", context: providerContexts.google, sourceModel: "gemini-2.5-flash" },
{ label: "OpenAI-Completions", context: providerContexts.openaiCompletions, sourceModel: "gpt-4o-mini" },
{ label: "OpenAI-Responses", context: providerContexts.openaiResponses, sourceModel: "gpt-5-mini" },
{ label: "Aborted", context: providerContexts.aborted, sourceModel: null },
];
for (const { label, context, sourceModel } of contextTests) { let successCount = 0;
// Skip testing same model against itself let skippedCount = 0;
if (sourceModel && sourceModel === provider.getModel().id) {
console.log(`[${label}${provider.getModel().provider}] Skipping same-model test`);
skippedCount++;
continue;
}
const success = await testProviderHandoff(provider, label, context);
if (success) successCount++;
}
const totalTests = contextTests.length - skippedCount; for (const { label, context, sourceModel } of contextTests) {
console.log(`\nOpenAI Completions success rate: ${successCount}/${totalTests} (${skippedCount} skipped)\n`); // Skip testing same model against itself
if (sourceModel && sourceModel === model.id) {
console.log(`[${label}${model.provider}] Skipping same-model test`);
skippedCount++;
continue;
}
const success = await testProviderHandoff(model, label, context);
if (success) successCount++;
}
// All non-skipped handoffs should succeed const totalTests = contextTests.length - skippedCount;
expect(successCount).toBe(totalTests); console.log(`\nOpenAI Responses success rate: ${successCount}/${totalTests} (${skippedCount} skipped)\n`);
});
});
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Responses Provider Handoff", () => { // All non-skipped handoffs should succeed
let provider: OpenAIResponsesLLM; expect(successCount).toBe(totalTests);
});
beforeAll(() => { });
const model = getModel("openai", "gpt-5-mini");
if (model) {
provider = new OpenAIResponsesLLM(model, process.env.OPENAI_API_KEY!);
}
});
it("should handle contexts from all providers", async () => {
if (!provider) {
console.log("OpenAI Responses provider not available, skipping");
return;
}
console.log("\nTesting OpenAI Responses with pre-built contexts:\n");
const contextTests = [
{ label: "Anthropic-style", context: providerContexts.anthropic, sourceModel: "claude-3-5-haiku-20241022" },
{ label: "Google-style", context: providerContexts.google, sourceModel: "gemini-2.5-flash" },
{ label: "OpenAI-Completions", context: providerContexts.openaiCompletions, sourceModel: "gpt-4o-mini" },
{ label: "OpenAI-Responses", context: providerContexts.openaiResponses, sourceModel: "gpt-5-mini" },
{ label: "Aborted", context: providerContexts.aborted, sourceModel: null }
];
let successCount = 0;
let skippedCount = 0;
for (const { label, context, sourceModel } of contextTests) {
// Skip testing same model against itself
if (sourceModel && sourceModel === provider.getModel().id) {
console.log(`[${label}${provider.getModel().provider}] Skipping same-model test`);
skippedCount++;
continue;
}
const success = await testProviderHandoff(provider, label, context);
if (success) successCount++;
}
const totalTests = contextTests.length - skippedCount;
console.log(`\nOpenAI Responses success rate: ${successCount}/${totalTests} (${skippedCount} skipped)\n`);
// All non-skipped handoffs should succeed
expect(successCount).toBe(totalTests);
});
});
}); });

View file

@ -1,31 +0,0 @@
import { GoogleGenAI } from "@google/genai";
import OpenAI from "openai";
const ai = new GoogleGenAI({});
async function main() {
/*let pager = await ai.models.list();
do {
for (const model of pager.page) {
console.log(JSON.stringify(model, null, 2));
console.log("---");
}
if (!pager.hasNextPage()) break;
await pager.nextPage();
} while (true);*/
const openai = new OpenAI();
const response = await openai.models.list();
do {
const page = response.data;
for (const model of page) {
const info = await openai.models.retrieve(model.id);
console.log(JSON.stringify(model, null, 2));
console.log("---");
}
if (!response.hasNextPage()) break;
await response.getNextPage();
} while (true);
}
await main();

View file

@ -1,618 +0,0 @@
import { describe, it, beforeAll, afterAll, expect } from "vitest";
import { GoogleLLM } from "../src/providers/google.js";
import { OpenAICompletionsLLM } from "../src/providers/openai-completions.js";
import { OpenAIResponsesLLM } from "../src/providers/openai-responses.js";
import { AnthropicLLM } from "../src/providers/anthropic.js";
import type { LLM, LLMOptions, Context, Tool, AssistantMessage, Model, ImageContent } from "../src/types.js";
import { spawn, ChildProcess, execSync } from "child_process";
import { createLLM, getModel } from "../src/models.js";
import { readFileSync } from "fs";
import { join, dirname } from "path";
import { fileURLToPath } from "url";
const __filename = fileURLToPath(import.meta.url);
const __dirname = dirname(__filename);
// Calculator tool definition (same as examples)
const calculatorTool: Tool = {
name: "calculator",
description: "Perform basic arithmetic operations",
parameters: {
type: "object",
properties: {
a: { type: "number", description: "First number" },
b: { type: "number", description: "Second number" },
operation: {
type: "string",
enum: ["add", "subtract", "multiply", "divide"],
description: "The operation to perform"
}
},
required: ["a", "b", "operation"]
}
};
async function basicTextGeneration<T extends LLMOptions>(llm: LLM<T>) {
const context: Context = {
systemPrompt: "You are a helpful assistant. Be concise.",
messages: [
{ role: "user", content: "Reply with exactly: 'Hello test successful'" }
]
};
const response = await llm.generate(context);
expect(response.role).toBe("assistant");
expect(response.content).toBeTruthy();
expect(response.usage.input + response.usage.cacheRead).toBeGreaterThan(0);
expect(response.usage.output).toBeGreaterThan(0);
expect(response.error).toBeFalsy();
expect(response.content.map(b => b.type == "text" ? b.text : "").join("")).toContain("Hello test successful");
context.messages.push(response);
context.messages.push({ role: "user", content: "Now say 'Goodbye test successful'" });
const secondResponse = await llm.generate(context);
expect(secondResponse.role).toBe("assistant");
expect(secondResponse.content).toBeTruthy();
expect(secondResponse.usage.input + secondResponse.usage.cacheRead).toBeGreaterThan(0);
expect(secondResponse.usage.output).toBeGreaterThan(0);
expect(secondResponse.error).toBeFalsy();
expect(secondResponse.content.map(b => b.type == "text" ? b.text : "").join("")).toContain("Goodbye test successful");
}
async function handleToolCall<T extends LLMOptions>(llm: LLM<T>) {
const context: Context = {
systemPrompt: "You are a helpful assistant that uses tools when asked.",
messages: [{
role: "user",
content: "Calculate 15 + 27 using the calculator tool."
}],
tools: [calculatorTool]
};
const response = await llm.generate(context);
expect(response.stopReason).toBe("toolUse");
expect(response.content.some(b => b.type == "toolCall")).toBeTruthy();
const toolCall = response.content.find(b => b.type == "toolCall")!;
expect(toolCall.name).toBe("calculator");
expect(toolCall.id).toBeTruthy();
}
async function handleStreaming<T extends LLMOptions>(llm: LLM<T>) {
let textStarted = false;
let textChunks = "";
let textCompleted = false;
const context: Context = {
messages: [{ role: "user", content: "Count from 1 to 3" }]
};
const response = await llm.generate(context, {
onEvent: (event) => {
if (event.type === "text_start") {
textStarted = true;
} else if (event.type === "text_delta") {
textChunks += event.delta;
} else if (event.type === "text_end") {
textCompleted = true;
}
}
} as T);
expect(textStarted).toBe(true);
expect(textChunks.length).toBeGreaterThan(0);
expect(textCompleted).toBe(true);
expect(response.content.some(b => b.type == "text")).toBeTruthy();
}
async function handleThinking<T extends LLMOptions>(llm: LLM<T>, options: T) {
let thinkingStarted = false;
let thinkingChunks = "";
let thinkingCompleted = false;
const context: Context = {
messages: [{ role: "user", content: `Think about ${(Math.random() * 255) | 0} + 27. Think step by step. Then output the result.` }]
};
const response = await llm.generate(context, {
onEvent: (event) => {
if (event.type === "thinking_start") {
thinkingStarted = true;
} else if (event.type === "thinking_delta") {
expect(event.content.endsWith(event.delta)).toBe(true);
thinkingChunks += event.delta;
} else if (event.type === "thinking_end") {
thinkingCompleted = true;
}
},
...options
});
expect(response.stopReason, `Error: ${(response as any).error}`).toBe("stop");
expect(thinkingStarted).toBe(true);
expect(thinkingChunks.length).toBeGreaterThan(0);
expect(thinkingCompleted).toBe(true);
expect(response.content.some(b => b.type == "thinking")).toBeTruthy();
}
async function handleImage<T extends LLMOptions>(llm: LLM<T>) {
// Check if the model supports images
const model = llm.getModel();
if (!model.input.includes("image")) {
console.log(`Skipping image test - model ${model.id} doesn't support images`);
return;
}
// Read the test image
const imagePath = join(__dirname, "data", "red-circle.png");
const imageBuffer = readFileSync(imagePath);
const base64Image = imageBuffer.toString("base64");
const imageContent: ImageContent = {
type: "image",
data: base64Image,
mimeType: "image/png",
};
const context: Context = {
messages: [
{
role: "user",
content: [
{ type: "text", text: "What do you see in this image? Please describe the shape (circle, rectangle, square, triangle, ...) and color (red, blue, green, ...)." },
imageContent,
],
},
],
};
const response = await llm.generate(context);
// Check the response mentions red and circle
expect(response.content.length > 0).toBeTruthy();
const lowerContent = response.content.find(b => b.type == "text")?.text || "";
expect(lowerContent).toContain("red");
expect(lowerContent).toContain("circle");
}
async function multiTurn<T extends LLMOptions>(llm: LLM<T>, thinkingOptions: T) {
const context: Context = {
systemPrompt: "You are a helpful assistant that can use tools to answer questions.",
messages: [
{
role: "user",
content: "Think about this briefly, then calculate 42 * 17 and 453 + 434 using the calculator tool."
}
],
tools: [calculatorTool]
};
// Collect all text content from all assistant responses
let allTextContent = "";
let hasSeenThinking = false;
let hasSeenToolCalls = false;
const maxTurns = 5; // Prevent infinite loops
for (let turn = 0; turn < maxTurns; turn++) {
const response = await llm.generate(context, thinkingOptions);
// Add the assistant response to context
context.messages.push(response);
// Process content blocks
for (const block of response.content) {
if (block.type === "text") {
allTextContent += block.text;
} else if (block.type === "thinking") {
hasSeenThinking = true;
} else if (block.type === "toolCall") {
hasSeenToolCalls = true;
// Process the tool call
expect(block.name).toBe("calculator");
expect(block.id).toBeTruthy();
expect(block.arguments).toBeTruthy();
const { a, b, operation } = block.arguments;
let result: number;
switch (operation) {
case "add": result = a + b; break;
case "multiply": result = a * b; break;
default: result = 0;
}
// Add tool result to context
context.messages.push({
role: "toolResult",
toolCallId: block.id,
toolName: block.name,
content: `${result}`,
isError: false
});
}
}
// If we got a stop response with text content, we're likely done
expect(response.stopReason).not.toBe("error");
if (response.stopReason === "stop") {
break;
}
}
// Verify we got either thinking content or tool calls (or both)
expect(hasSeenThinking || hasSeenToolCalls).toBe(true);
// The accumulated text should reference both calculations
expect(allTextContent).toBeTruthy();
expect(allTextContent.includes("714")).toBe(true);
expect(allTextContent.includes("887")).toBe(true);
}
describe("AI Providers E2E Tests", () => {
describe.skipIf(!process.env.GEMINI_API_KEY)("Gemini Provider (gemini-2.5-flash)", () => {
let llm: GoogleLLM;
beforeAll(() => {
llm = new GoogleLLM(getModel("google", "gemini-2.5-flash")!, process.env.GEMINI_API_KEY!);
});
it("should complete basic text generation", async () => {
await basicTextGeneration(llm);
});
it("should handle tool calling", async () => {
await handleToolCall(llm);
});
it("should handle streaming", async () => {
await handleStreaming(llm);
});
it("should handle thinking mode", async () => {
await handleThinking(llm, {thinking: { enabled: true, budgetTokens: 1024 }});
});
it("should handle multi-turn with thinking and tools", async () => {
await multiTurn(llm, {thinking: { enabled: true, budgetTokens: 2048 }});
});
it("should handle image input", async () => {
await handleImage(llm);
});
});
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Completions Provider (gpt-4o-mini)", () => {
let llm: OpenAICompletionsLLM;
beforeAll(() => {
llm = new OpenAICompletionsLLM(getModel("openai", "gpt-4o-mini")!, process.env.OPENAI_API_KEY!);
});
it("should complete basic text generation", async () => {
await basicTextGeneration(llm);
});
it("should handle tool calling", async () => {
await handleToolCall(llm);
});
it("should handle streaming", async () => {
await handleStreaming(llm);
});
it("should handle image input", async () => {
await handleImage(llm);
});
});
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Responses Provider (gpt-5-mini)", () => {
let llm: OpenAIResponsesLLM;
beforeAll(() => {
llm = new OpenAIResponsesLLM(getModel("openai", "gpt-5-mini")!, process.env.OPENAI_API_KEY!);
});
it("should complete basic text generation", async () => {
await basicTextGeneration(llm);
});
it("should handle tool calling", async () => {
await handleToolCall(llm);
});
it("should handle streaming", async () => {
await handleStreaming(llm);
});
it("should handle thinking mode", {retry: 2}, async () => {
await handleThinking(llm, {reasoningEffort: "high"});
});
it("should handle multi-turn with thinking and tools", async () => {
await multiTurn(llm, {reasoningEffort: "high"});
});
it("should handle image input", async () => {
await handleImage(llm);
});
});
describe.skipIf(!process.env.ANTHROPIC_OAUTH_TOKEN)("Anthropic Provider (claude-sonnet-4-20250514)", () => {
let llm: AnthropicLLM;
beforeAll(() => {
llm = new AnthropicLLM(getModel("anthropic", "claude-sonnet-4-20250514")!, process.env.ANTHROPIC_OAUTH_TOKEN!);
});
it("should complete basic text generation", async () => {
await basicTextGeneration(llm);
});
it("should handle tool calling", async () => {
await handleToolCall(llm);
});
it("should handle streaming", async () => {
await handleStreaming(llm);
});
it("should handle thinking mode", async () => {
await handleThinking(llm, {thinking: { enabled: true } });
});
it("should handle multi-turn with thinking and tools", async () => {
await multiTurn(llm, {thinking: { enabled: true, budgetTokens: 2048 }});
});
it("should handle image input", async () => {
await handleImage(llm);
});
});
describe.skipIf(!process.env.XAI_API_KEY)("xAI Provider (grok-code-fast-1 via OpenAI Completions)", () => {
let llm: OpenAICompletionsLLM;
beforeAll(() => {
llm = new OpenAICompletionsLLM(getModel("xai", "grok-code-fast-1")!, process.env.XAI_API_KEY!);
});
it("should complete basic text generation", async () => {
await basicTextGeneration(llm);
});
it("should handle tool calling", async () => {
await handleToolCall(llm);
});
it("should handle streaming", async () => {
await handleStreaming(llm);
});
it("should handle thinking mode", async () => {
await handleThinking(llm, {reasoningEffort: "medium"});
});
it("should handle multi-turn with thinking and tools", async () => {
await multiTurn(llm, {reasoningEffort: "medium"});
});
});
describe.skipIf(!process.env.GROQ_API_KEY)("Groq Provider (gpt-oss-20b via OpenAI Completions)", () => {
let llm: OpenAICompletionsLLM;
beforeAll(() => {
llm = new OpenAICompletionsLLM(getModel("groq", "openai/gpt-oss-20b")!, process.env.GROQ_API_KEY!);
});
it("should complete basic text generation", async () => {
await basicTextGeneration(llm);
});
it("should handle tool calling", async () => {
await handleToolCall(llm);
});
it("should handle streaming", async () => {
await handleStreaming(llm);
});
it("should handle thinking mode", async () => {
await handleThinking(llm, {reasoningEffort: "medium"});
});
it("should handle multi-turn with thinking and tools", async () => {
await multiTurn(llm, {reasoningEffort: "medium"});
});
});
describe.skipIf(!process.env.CEREBRAS_API_KEY)("Cerebras Provider (gpt-oss-120b via OpenAI Completions)", () => {
let llm: OpenAICompletionsLLM;
beforeAll(() => {
llm = new OpenAICompletionsLLM(getModel("cerebras", "gpt-oss-120b")!, process.env.CEREBRAS_API_KEY!);
});
it("should complete basic text generation", async () => {
await basicTextGeneration(llm);
});
it("should handle tool calling", async () => {
await handleToolCall(llm);
});
it("should handle streaming", async () => {
await handleStreaming(llm);
});
it("should handle thinking mode", async () => {
await handleThinking(llm, {reasoningEffort: "medium"});
});
it("should handle multi-turn with thinking and tools", async () => {
await multiTurn(llm, {reasoningEffort: "medium"});
});
});
describe.skipIf(!process.env.OPENROUTER_API_KEY)("OpenRouter Provider (glm-4.5v via OpenAI Completions)", () => {
let llm: OpenAICompletionsLLM;
beforeAll(() => {
llm = new OpenAICompletionsLLM(getModel("openrouter", "z-ai/glm-4.5v")!, process.env.OPENROUTER_API_KEY!);;
});
it("should complete basic text generation", async () => {
await basicTextGeneration(llm);
});
it("should handle tool calling", async () => {
await handleToolCall(llm);
});
it("should handle streaming", async () => {
await handleStreaming(llm);
});
it("should handle thinking mode", async () => {
await handleThinking(llm, {reasoningEffort: "medium"});
});
it("should handle multi-turn with thinking and tools", { retry: 2 }, async () => {
await multiTurn(llm, {reasoningEffort: "medium"});
});
it("should handle image input", async () => {
await handleImage(llm);
});
});
// Check if ollama is installed
let ollamaInstalled = false;
try {
execSync("which ollama", { stdio: "ignore" });
ollamaInstalled = true;
} catch {
ollamaInstalled = false;
}
describe.skipIf(!ollamaInstalled)("Ollama Provider (gpt-oss-20b via OpenAI Completions)", () => {
let llm: OpenAICompletionsLLM;
let ollamaProcess: ChildProcess | null = null;
beforeAll(async () => {
// Check if model is available, if not pull it
try {
execSync("ollama list | grep -q 'gpt-oss:20b'", { stdio: "ignore" });
} catch {
console.log("Pulling gpt-oss:20b model for Ollama tests...");
try {
execSync("ollama pull gpt-oss:20b", { stdio: "inherit" });
} catch (e) {
console.warn("Failed to pull gpt-oss:20b model, tests will be skipped");
return;
}
}
// Start ollama server
ollamaProcess = spawn("ollama", ["serve"], {
detached: false,
stdio: "ignore"
});
// Wait for server to be ready
await new Promise<void>((resolve) => {
const checkServer = async () => {
try {
const response = await fetch("http://localhost:11434/api/tags");
if (response.ok) {
resolve();
} else {
setTimeout(checkServer, 500);
}
} catch {
setTimeout(checkServer, 500);
}
};
setTimeout(checkServer, 1000); // Initial delay
});
const model: Model = {
id: "gpt-oss:20b",
provider: "ollama",
baseUrl: "http://localhost:11434/v1",
reasoning: true,
input: ["text"],
contextWindow: 128000,
maxTokens: 16000,
cost: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
},
name: "Ollama GPT-OSS 20B"
}
llm = new OpenAICompletionsLLM(model, "dummy");
}, 30000); // 30 second timeout for setup
afterAll(() => {
// Kill ollama server
if (ollamaProcess) {
ollamaProcess.kill("SIGTERM");
ollamaProcess = null;
}
});
it("should complete basic text generation", async () => {
await basicTextGeneration(llm);
});
it("should handle tool calling", async () => {
await handleToolCall(llm);
});
it("should handle streaming", async () => {
await handleStreaming(llm);
});
it("should handle thinking mode", async () => {
await handleThinking(llm, {reasoningEffort: "medium"});
});
it("should handle multi-turn with thinking and tools", async () => {
await multiTurn(llm, {reasoningEffort: "medium"});
});
});
/*
describe.skipIf(!process.env.ANTHROPIC_API_KEY)("Anthropic Provider (Haiku 3.5)", () => {
let llm: AnthropicLLM;
beforeAll(() => {
llm = createLLM("anthropic", "claude-3-5-haiku-latest");
});
it("should complete basic text generation", async () => {
await basicTextGeneration(llm);
});
it("should handle tool calling", async () => {
await handleToolCall(llm);
});
it("should handle streaming", async () => {
await handleStreaming(llm);
});
it("should handle multi-turn with thinking and tools", async () => {
await multiTurn(llm, {thinking: {enabled: true}});
});
it("should handle image input", async () => {
await handleImage(llm);
});
});
*/
});

View file

@ -1,15 +1,6 @@
#!/usr/bin/env npx tsx #!/usr/bin/env npx tsx
import {
Container,
LoadingAnimation,
TextComponent,
TextEditor,
TUI,
WhitespaceComponent,
} from "../src/index.js";
import chalk from "chalk"; import chalk from "chalk";
import { Container, LoadingAnimation, TextComponent, TextEditor, TUI, WhitespaceComponent } from "../src/index.js";
/** /**
* Test the new smart double-buffered TUI implementation * Test the new smart double-buffered TUI implementation
@ -24,7 +15,7 @@ async function main() {
// Monkey-patch requestRender to measure performance // Monkey-patch requestRender to measure performance
const originalRequestRender = ui.requestRender.bind(ui); const originalRequestRender = ui.requestRender.bind(ui);
ui.requestRender = function() { ui.requestRender = () => {
const startTime = process.hrtime.bigint(); const startTime = process.hrtime.bigint();
originalRequestRender(); originalRequestRender();
process.nextTick(() => { process.nextTick(() => {
@ -38,10 +29,12 @@ async function main() {
// Add header // Add header
const header = new TextComponent( const header = new TextComponent(
chalk.bold.green("Smart Double Buffer TUI Test") + "\n" + chalk.bold.green("Smart Double Buffer TUI Test") +
chalk.dim("Testing new implementation with component-level caching and smart diffing") + "\n" + "\n" +
chalk.dim("Press CTRL+C to exit"), chalk.dim("Testing new implementation with component-level caching and smart diffing") +
{ bottom: 1 } "\n" +
chalk.dim("Press CTRL+C to exit"),
{ bottom: 1 },
); );
ui.addChild(header); ui.addChild(header);
@ -57,7 +50,9 @@ async function main() {
// Add text editor // Add text editor
const editor = new TextEditor(); const editor = new TextEditor();
editor.setText("Type here to test the text editor.\n\nWith smart diffing, only changed lines are redrawn!\n\nThe animation above updates every 80ms but the editor stays perfectly still."); editor.setText(
"Type here to test the text editor.\n\nWith smart diffing, only changed lines are redrawn!\n\nThe animation above updates every 80ms but the editor stays perfectly still.",
);
container.addChild(editor); container.addChild(editor);
// Add the container to UI // Add the container to UI
@ -71,15 +66,20 @@ async function main() {
const statsInterval = setInterval(() => { const statsInterval = setInterval(() => {
if (renderCount > 0) { if (renderCount > 0) {
const avgRenderTime = Number(totalRenderTime / BigInt(renderCount)) / 1_000_000; // Convert to ms const avgRenderTime = Number(totalRenderTime / BigInt(renderCount)) / 1_000_000; // Convert to ms
const lastRenderTime = renderTimings.length > 0 const lastRenderTime =
? Number(renderTimings[renderTimings.length - 1]) / 1_000_000 renderTimings.length > 0 ? Number(renderTimings[renderTimings.length - 1]) / 1_000_000 : 0;
: 0;
const avgLinesRedrawn = ui.getAverageLinesRedrawn(); const avgLinesRedrawn = ui.getAverageLinesRedrawn();
statsComponent.setText( statsComponent.setText(
chalk.yellow(`Performance Stats:`) + "\n" + chalk.yellow(`Performance Stats:`) +
chalk.dim(`Renders: ${renderCount} | Avg Time: ${avgRenderTime.toFixed(2)}ms | Last: ${lastRenderTime.toFixed(2)}ms`) + "\n" + "\n" +
chalk.dim(`Lines Redrawn: ${ui.getLinesRedrawn()} total | Avg per render: ${avgLinesRedrawn.toFixed(1)}`) chalk.dim(
`Renders: ${renderCount} | Avg Time: ${avgRenderTime.toFixed(2)}ms | Last: ${lastRenderTime.toFixed(2)}ms`,
) +
"\n" +
chalk.dim(
`Lines Redrawn: ${ui.getLinesRedrawn()} total | Avg per render: ${avgLinesRedrawn.toFixed(1)}`,
),
); );
} }
}, 1000); }, 1000);
@ -96,7 +96,11 @@ async function main() {
ui.stop(); ui.stop();
console.log("\n" + chalk.green("Exited double-buffer test")); console.log("\n" + chalk.green("Exited double-buffer test"));
console.log(chalk.dim(`Total renders: ${renderCount}`)); console.log(chalk.dim(`Total renders: ${renderCount}`));
console.log(chalk.dim(`Average render time: ${renderCount > 0 ? (Number(totalRenderTime / BigInt(renderCount)) / 1_000_000).toFixed(2) : 0}ms`)); console.log(
chalk.dim(
`Average render time: ${renderCount > 0 ? (Number(totalRenderTime / BigInt(renderCount)) / 1_000_000).toFixed(2) : 0}ms`,
),
);
console.log(chalk.dim(`Total lines redrawn: ${ui.getLinesRedrawn()}`)); console.log(chalk.dim(`Total lines redrawn: ${ui.getLinesRedrawn()}`));
console.log(chalk.dim(`Average lines redrawn per render: ${ui.getAverageLinesRedrawn().toFixed(1)}`)); console.log(chalk.dim(`Average lines redrawn per render: ${ui.getAverageLinesRedrawn().toFixed(1)}`));
process.exit(0); process.exit(0);

View file

@ -1,5 +1,12 @@
#!/usr/bin/env npx tsx #!/usr/bin/env npx tsx
import { TUI, Container, TextEditor, TextComponent, MarkdownComponent, CombinedAutocompleteProvider } from "../src/index.js"; import {
CombinedAutocompleteProvider,
Container,
MarkdownComponent,
TextComponent,
TextEditor,
TUI,
} from "../src/index.js";
/** /**
* Chat Application with Autocomplete * Chat Application with Autocomplete
@ -16,7 +23,7 @@ const ui = new TUI();
// Add header with instructions // Add header with instructions
const header = new TextComponent( const header = new TextComponent(
"💬 Chat Demo | Type '/' for commands | Start typing a filename + Tab to autocomplete | Ctrl+C to exit", "💬 Chat Demo | Type '/' for commands | Start typing a filename + Tab to autocomplete | Ctrl+C to exit",
{ bottom: 1 } { bottom: 1 },
); );
const chatHistory = new Container(); const chatHistory = new Container();
@ -82,7 +89,8 @@ ui.onGlobalKeyPress = (data: string) => {
}; };
// Add initial welcome message to chat history // Add initial welcome message to chat history
chatHistory.addChild(new MarkdownComponent(` chatHistory.addChild(
new MarkdownComponent(`
## Welcome to the Chat Demo! ## Welcome to the Chat Demo!
**Available slash commands:** **Available slash commands:**
@ -96,7 +104,8 @@ chatHistory.addChild(new MarkdownComponent(`
- Works with home directory (\`~/\`) - Works with home directory (\`~/\`)
Try it out! Type a message or command below. Try it out! Type a message or command below.
`)); `),
);
ui.addChild(header); ui.addChild(header);
ui.addChild(chatHistory); ui.addChild(chatHistory);

View file

@ -1,7 +1,7 @@
import { test, describe } from "node:test";
import assert from "node:assert"; import assert from "node:assert";
import { describe, test } from "node:test";
import { Container, TextComponent, TextEditor, TUI } from "../src/index.js";
import { VirtualTerminal } from "./virtual-terminal.js"; import { VirtualTerminal } from "./virtual-terminal.js";
import { TUI, Container, TextComponent, TextEditor } from "../src/index.js";
describe("Differential Rendering - Dynamic Content", () => { describe("Differential Rendering - Dynamic Content", () => {
test("handles static text, dynamic container, and text editor correctly", async () => { test("handles static text, dynamic container, and text editor correctly", async () => {
@ -23,7 +23,7 @@ describe("Differential Rendering - Dynamic Content", () => {
ui.setFocus(editor); ui.setFocus(editor);
// Wait for next tick to complete and flush virtual terminal // Wait for next tick to complete and flush virtual terminal
await new Promise(resolve => process.nextTick(resolve)); await new Promise((resolve) => process.nextTick(resolve));
await terminal.flush(); await terminal.flush();
// Step 4: Check initial output in scrollbuffer // Step 4: Check initial output in scrollbuffer
@ -35,12 +35,14 @@ describe("Differential Rendering - Dynamic Content", () => {
console.log("ScrollBuffer lines:", scrollBuffer.length); console.log("ScrollBuffer lines:", scrollBuffer.length);
// Count non-empty lines in scrollbuffer // Count non-empty lines in scrollbuffer
let nonEmptyInBuffer = scrollBuffer.filter(line => line.trim() !== "").length; const nonEmptyInBuffer = scrollBuffer.filter((line) => line.trim() !== "").length;
console.log("Non-empty lines in scrollbuffer:", nonEmptyInBuffer); console.log("Non-empty lines in scrollbuffer:", nonEmptyInBuffer);
// Verify initial render has static text in scrollbuffer // Verify initial render has static text in scrollbuffer
assert.ok(scrollBuffer.some(line => line.includes("Static Header Text")), assert.ok(
`Expected static text in scrollbuffer`); scrollBuffer.some((line) => line.includes("Static Header Text")),
`Expected static text in scrollbuffer`,
);
// Step 5: Add 100 text components to container // Step 5: Add 100 text components to container
console.log("\nAdding 100 components to container..."); console.log("\nAdding 100 components to container...");
@ -52,7 +54,7 @@ describe("Differential Rendering - Dynamic Content", () => {
ui.requestRender(); ui.requestRender();
// Wait for next tick to complete and flush // Wait for next tick to complete and flush
await new Promise(resolve => process.nextTick(resolve)); await new Promise((resolve) => process.nextTick(resolve));
await terminal.flush(); await terminal.flush();
// Step 6: Check output after adding 100 components // Step 6: Check output after adding 100 components
@ -65,7 +67,7 @@ describe("Differential Rendering - Dynamic Content", () => {
// Count all dynamic items in scrollbuffer // Count all dynamic items in scrollbuffer
let dynamicItemsInBuffer = 0; let dynamicItemsInBuffer = 0;
let allItemNumbers = new Set<number>(); const allItemNumbers = new Set<number>();
for (const line of scrollBuffer) { for (const line of scrollBuffer) {
const match = line.match(/Dynamic Item (\d+)/); const match = line.match(/Dynamic Item (\d+)/);
if (match) { if (match) {
@ -80,8 +82,11 @@ describe("Differential Rendering - Dynamic Content", () => {
// CRITICAL TEST: The scrollbuffer should contain ALL 100 items // CRITICAL TEST: The scrollbuffer should contain ALL 100 items
// This is what the differential render should preserve! // This is what the differential render should preserve!
assert.strictEqual(allItemNumbers.size, 100, assert.strictEqual(
`Expected all 100 unique items in scrollbuffer, but found ${allItemNumbers.size}`); allItemNumbers.size,
100,
`Expected all 100 unique items in scrollbuffer, but found ${allItemNumbers.size}`,
);
// Verify items are 1-100 // Verify items are 1-100
for (let i = 1; i <= 100; i++) { for (let i = 1; i <= 100; i++) {
@ -89,15 +94,20 @@ describe("Differential Rendering - Dynamic Content", () => {
} }
// Also verify the static header is still in scrollbuffer // Also verify the static header is still in scrollbuffer
assert.ok(scrollBuffer.some(line => line.includes("Static Header Text")), assert.ok(
"Static header should still be in scrollbuffer"); scrollBuffer.some((line) => line.includes("Static Header Text")),
"Static header should still be in scrollbuffer",
);
// And the editor should be there too // And the editor should be there too
assert.ok(scrollBuffer.some(line => line.includes("╭") && line.includes("╮")), assert.ok(
"Editor top border should be in scrollbuffer"); scrollBuffer.some((line) => line.includes("╭") && line.includes("╮")),
assert.ok(scrollBuffer.some(line => line.includes("╰") && line.includes("╯")), "Editor top border should be in scrollbuffer",
"Editor bottom border should be in scrollbuffer"); );
assert.ok(
scrollBuffer.some((line) => line.includes("╰") && line.includes("╯")),
"Editor bottom border should be in scrollbuffer",
);
ui.stop(); ui.stop();
}); });
@ -124,7 +134,7 @@ describe("Differential Rendering - Dynamic Content", () => {
contentContainer.addChild(new TextComponent("Content Line 2")); contentContainer.addChild(new TextComponent("Content Line 2"));
// Initial render // Initial render
await new Promise(resolve => process.nextTick(resolve)); await new Promise((resolve) => process.nextTick(resolve));
await terminal.flush(); await terminal.flush();
let viewport = terminal.getViewport(); let viewport = terminal.getViewport();
@ -142,7 +152,7 @@ describe("Differential Rendering - Dynamic Content", () => {
statusContainer.addChild(new TextComponent("Status: Processing...")); statusContainer.addChild(new TextComponent("Status: Processing..."));
ui.requestRender(); ui.requestRender();
await new Promise(resolve => process.nextTick(resolve)); await new Promise((resolve) => process.nextTick(resolve));
await terminal.flush(); await terminal.flush();
viewport = terminal.getViewport(); viewport = terminal.getViewport();
@ -162,7 +172,7 @@ describe("Differential Rendering - Dynamic Content", () => {
} }
ui.requestRender(); ui.requestRender();
await new Promise(resolve => process.nextTick(resolve)); await new Promise((resolve) => process.nextTick(resolve));
await terminal.flush(); await terminal.flush();
viewport = terminal.getViewport(); viewport = terminal.getViewport();
@ -180,7 +190,7 @@ describe("Differential Rendering - Dynamic Content", () => {
contentLine10.setText("Content Line 10 - MODIFIED"); contentLine10.setText("Content Line 10 - MODIFIED");
ui.requestRender(); ui.requestRender();
await new Promise(resolve => process.nextTick(resolve)); await new Promise((resolve) => process.nextTick(resolve));
await terminal.flush(); await terminal.flush();
viewport = terminal.getViewport(); viewport = terminal.getViewport();

View file

@ -1,6 +1,6 @@
import { TUI, SelectList } from "../src/index.js";
import { readdirSync, statSync } from "fs"; import { readdirSync, statSync } from "fs";
import { join } from "path"; import { join } from "path";
import { SelectList, TUI } from "../src/index.js";
const ui = new TUI(); const ui = new TUI();
ui.start(); ui.start();

View file

@ -1,6 +1,6 @@
import { describe, test } from "node:test";
import assert from "node:assert"; import assert from "node:assert";
import { TextEditor, TextComponent, Container, TUI } from "../src/index.js"; import { describe, test } from "node:test";
import { Container, TextComponent, TextEditor, TUI } from "../src/index.js";
import { VirtualTerminal } from "./virtual-terminal.js"; import { VirtualTerminal } from "./virtual-terminal.js";
describe("Layout shift artifacts", () => { describe("Layout shift artifacts", () => {
@ -27,7 +27,7 @@ describe("Layout shift artifacts", () => {
// Initial render // Initial render
ui.start(); ui.start();
await new Promise(resolve => process.nextTick(resolve)); await new Promise((resolve) => process.nextTick(resolve));
await term.flush(); await term.flush();
// Capture initial state // Capture initial state
@ -40,7 +40,7 @@ describe("Layout shift artifacts", () => {
ui.requestRender(); ui.requestRender();
// Wait for render // Wait for render
await new Promise(resolve => process.nextTick(resolve)); await new Promise((resolve) => process.nextTick(resolve));
await term.flush(); await term.flush();
// Capture state with status message // Capture state with status message
@ -51,7 +51,7 @@ describe("Layout shift artifacts", () => {
ui.requestRender(); ui.requestRender();
// Wait for render // Wait for render
await new Promise(resolve => process.nextTick(resolve)); await new Promise((resolve) => process.nextTick(resolve));
await term.flush(); await term.flush();
// Capture final state // Capture final state
@ -64,8 +64,12 @@ describe("Layout shift artifacts", () => {
const nextLine = finalViewport[i + 1]; const nextLine = finalViewport[i + 1];
// Check if we have duplicate bottom borders (the artifact) // Check if we have duplicate bottom borders (the artifact)
if (currentLine.includes("╰") && currentLine.includes("╯") && if (
nextLine.includes("╰") && nextLine.includes("╯")) { currentLine.includes("╰") &&
currentLine.includes("╯") &&
nextLine.includes("╰") &&
nextLine.includes("╯")
) {
foundDuplicateBorder = true; foundDuplicateBorder = true;
} }
} }
@ -74,18 +78,12 @@ describe("Layout shift artifacts", () => {
assert.strictEqual(foundDuplicateBorder, false, "Found duplicate bottom borders - rendering artifact detected!"); assert.strictEqual(foundDuplicateBorder, false, "Found duplicate bottom borders - rendering artifact detected!");
// Also check that there's only one bottom border total // Also check that there's only one bottom border total
const bottomBorderCount = finalViewport.filter((line) => const bottomBorderCount = finalViewport.filter((line) => line.includes("╰")).length;
line.includes("╰")
).length;
assert.strictEqual(bottomBorderCount, 1, `Expected 1 bottom border, found ${bottomBorderCount}`); assert.strictEqual(bottomBorderCount, 1, `Expected 1 bottom border, found ${bottomBorderCount}`);
// Verify the editor is back in its original position // Verify the editor is back in its original position
const finalEditorStartLine = finalViewport.findIndex((line) => const finalEditorStartLine = finalViewport.findIndex((line) => line.includes("╭"));
line.includes("╭") const initialEditorStartLine = initialViewport.findIndex((line) => line.includes("╭"));
);
const initialEditorStartLine = initialViewport.findIndex((line) =>
line.includes("╭")
);
assert.strictEqual(finalEditorStartLine, initialEditorStartLine); assert.strictEqual(finalEditorStartLine, initialEditorStartLine);
ui.stop(); ui.stop();
@ -103,7 +101,7 @@ describe("Layout shift artifacts", () => {
// Initial render // Initial render
ui.start(); ui.start();
await new Promise(resolve => process.nextTick(resolve)); await new Promise((resolve) => process.nextTick(resolve));
await term.flush(); await term.flush();
// Rapidly add and remove a status message // Rapidly add and remove a status message
@ -112,25 +110,21 @@ describe("Layout shift artifacts", () => {
// Add status // Add status
ui.children.splice(1, 0, status); ui.children.splice(1, 0, status);
ui.requestRender(); ui.requestRender();
await new Promise(resolve => process.nextTick(resolve)); await new Promise((resolve) => process.nextTick(resolve));
await term.flush(); await term.flush();
// Remove status immediately // Remove status immediately
ui.children.splice(1, 1); ui.children.splice(1, 1);
ui.requestRender(); ui.requestRender();
await new Promise(resolve => process.nextTick(resolve)); await new Promise((resolve) => process.nextTick(resolve));
await term.flush(); await term.flush();
// Final output check // Final output check
const finalViewport = term.getViewport(); const finalViewport = term.getViewport();
// Should only have one set of borders for the editor // Should only have one set of borders for the editor
const topBorderCount = finalViewport.filter((line) => const topBorderCount = finalViewport.filter((line) => line.includes("╭") && line.includes("╮")).length;
line.includes("╭") && line.includes("╮") const bottomBorderCount = finalViewport.filter((line) => line.includes("╰") && line.includes("╯")).length;
).length;
const bottomBorderCount = finalViewport.filter((line) =>
line.includes("╰") && line.includes("╯")
).length;
assert.strictEqual(topBorderCount, 1); assert.strictEqual(topBorderCount, 1);
assert.strictEqual(bottomBorderCount, 1); assert.strictEqual(bottomBorderCount, 1);

View file

@ -1,5 +1,5 @@
#!/usr/bin/env npx tsx #!/usr/bin/env npx tsx
import { TUI, Container, TextComponent, TextEditor, MarkdownComponent } from "../src/index.js"; import { Container, MarkdownComponent, TextComponent, TextEditor, TUI } from "../src/index.js";
/** /**
* Multi-Component Layout Demo * Multi-Component Layout Demo

View file

@ -1,7 +1,7 @@
import { test, describe } from "node:test";
import assert from "node:assert"; import assert from "node:assert";
import { describe, test } from "node:test";
import { Container, LoadingAnimation, MarkdownComponent, TextComponent, TextEditor, TUI } from "../src/index.js";
import { VirtualTerminal } from "./virtual-terminal.js"; import { VirtualTerminal } from "./virtual-terminal.js";
import { TUI, Container, TextComponent, MarkdownComponent, TextEditor, LoadingAnimation } from "../src/index.js";
describe("Multi-Message Garbled Output Reproduction", () => { describe("Multi-Message Garbled Output Reproduction", () => {
test("handles rapid message additions with large content without garbling", async () => { test("handles rapid message additions with large content without garbling", async () => {
@ -20,7 +20,7 @@ describe("Multi-Message Garbled Output Reproduction", () => {
ui.setFocus(editor); ui.setFocus(editor);
// Initial render // Initial render
await new Promise(resolve => process.nextTick(resolve)); await new Promise((resolve) => process.nextTick(resolve));
await terminal.flush(); await terminal.flush();
// Step 1: Simulate user message // Step 1: Simulate user message
@ -32,7 +32,7 @@ describe("Multi-Message Garbled Output Reproduction", () => {
statusContainer.addChild(loadingAnim); statusContainer.addChild(loadingAnim);
ui.requestRender(); ui.requestRender();
await new Promise(resolve => process.nextTick(resolve)); await new Promise((resolve) => process.nextTick(resolve));
await terminal.flush(); await terminal.flush();
// Step 3: Simulate rapid tool calls with large outputs // Step 3: Simulate rapid tool calls with large outputs
@ -54,7 +54,7 @@ node_modules/get-tsconfig/README.md
chatContainer.addChild(new TextComponent(globResult)); chatContainer.addChild(new TextComponent(globResult));
ui.requestRender(); ui.requestRender();
await new Promise(resolve => process.nextTick(resolve)); await new Promise((resolve) => process.nextTick(resolve));
await terminal.flush(); await terminal.flush();
// Simulate multiple read tool calls with long content // Simulate multiple read tool calls with long content
@ -74,7 +74,7 @@ A collection of tools for managing LLM deployments and building AI agents.
chatContainer.addChild(new MarkdownComponent(readmeContent)); chatContainer.addChild(new MarkdownComponent(readmeContent));
ui.requestRender(); ui.requestRender();
await new Promise(resolve => process.nextTick(resolve)); await new Promise((resolve) => process.nextTick(resolve));
await terminal.flush(); await terminal.flush();
// Second read with even more content // Second read with even more content
@ -94,7 +94,7 @@ Terminal UI framework with surgical differential rendering for building flicker-
chatContainer.addChild(new MarkdownComponent(tuiReadmeContent)); chatContainer.addChild(new MarkdownComponent(tuiReadmeContent));
ui.requestRender(); ui.requestRender();
await new Promise(resolve => process.nextTick(resolve)); await new Promise((resolve) => process.nextTick(resolve));
await terminal.flush(); await terminal.flush();
// Step 4: Stop loading animation and add assistant response // Step 4: Stop loading animation and add assistant response
@ -114,7 +114,7 @@ The TUI library features surgical differential rendering that minimizes screen u
chatContainer.addChild(new MarkdownComponent(assistantResponse)); chatContainer.addChild(new MarkdownComponent(assistantResponse));
ui.requestRender(); ui.requestRender();
await new Promise(resolve => process.nextTick(resolve)); await new Promise((resolve) => process.nextTick(resolve));
await terminal.flush(); await terminal.flush();
// Step 5: CRITICAL - Send a new message while previous content is displayed // Step 5: CRITICAL - Send a new message while previous content is displayed
@ -126,7 +126,7 @@ The TUI library features surgical differential rendering that minimizes screen u
statusContainer.addChild(loadingAnim2); statusContainer.addChild(loadingAnim2);
ui.requestRender(); ui.requestRender();
await new Promise(resolve => process.nextTick(resolve)); await new Promise((resolve) => process.nextTick(resolve));
await terminal.flush(); await terminal.flush();
// Add assistant response // Add assistant response
@ -144,7 +144,7 @@ Key aspects:
chatContainer.addChild(new MarkdownComponent(secondResponse)); chatContainer.addChild(new MarkdownComponent(secondResponse));
ui.requestRender(); ui.requestRender();
await new Promise(resolve => process.nextTick(resolve)); await new Promise((resolve) => process.nextTick(resolve));
await terminal.flush(); await terminal.flush();
// Debug: Show the garbled output after the problematic step // Debug: Show the garbled output after the problematic step
@ -158,14 +158,20 @@ Key aspects:
const finalOutput = terminal.getScrollBuffer(); const finalOutput = terminal.getScrollBuffer();
// Check that first user message is NOT garbled // Check that first user message is NOT garbled
const userLine1 = finalOutput.find(line => line.includes("read all README.md files")); const userLine1 = finalOutput.find((line) => line.includes("read all README.md files"));
assert.strictEqual(userLine1, "read all README.md files except in node_modules", assert.strictEqual(
`First user message is garbled: "${userLine1}"`); userLine1,
"read all README.md files except in node_modules",
`First user message is garbled: "${userLine1}"`,
);
// Check that second user message is clean // Check that second user message is clean
const userLine2 = finalOutput.find(line => line.includes("What is the main purpose")); const userLine2 = finalOutput.find((line) => line.includes("What is the main purpose"));
assert.strictEqual(userLine2, "What is the main purpose of the TUI library?", assert.strictEqual(
`Second user message is garbled: "${userLine2}"`); userLine2,
"What is the main purpose of the TUI library?",
`Second user message is garbled: "${userLine2}"`,
);
// Check for common garbling patterns // Check for common garbling patterns
const garbledPatterns = [ const garbledPatterns = [
@ -173,11 +179,11 @@ Key aspects:
"README.mdectly", "README.mdectly",
"modulesl rendering", "modulesl rendering",
"[assistant]ns.", "[assistant]ns.",
"node_modules/@esbuild/darwin-arm64/README.mdategy" "node_modules/@esbuild/darwin-arm64/README.mdategy",
]; ];
for (const pattern of garbledPatterns) { for (const pattern of garbledPatterns) {
const hasGarbled = finalOutput.some(line => line.includes(pattern)); const hasGarbled = finalOutput.some((line) => line.includes(pattern));
assert.ok(!hasGarbled, `Found garbled pattern "${pattern}" in output`); assert.ok(!hasGarbled, `Found garbled pattern "${pattern}" in output`);
} }

View file

@ -1,18 +1,17 @@
import { test, describe } from "node:test";
import assert from "node:assert"; import assert from "node:assert";
import { VirtualTerminal } from "./virtual-terminal.js"; import { describe, test } from "node:test";
import { import {
TUI,
Container, Container,
TextComponent,
TextEditor,
WhitespaceComponent,
MarkdownComponent, MarkdownComponent,
SelectList, SelectList,
TextComponent,
TextEditor,
TUI,
WhitespaceComponent,
} from "../src/index.js"; } from "../src/index.js";
import { VirtualTerminal } from "./virtual-terminal.js";
describe("TUI Rendering", () => { describe("TUI Rendering", () => {
test("renders single text component", async () => { test("renders single text component", async () => {
const terminal = new VirtualTerminal(80, 24); const terminal = new VirtualTerminal(80, 24);
const ui = new TUI(terminal); const ui = new TUI(terminal);
@ -22,7 +21,7 @@ describe("TUI Rendering", () => {
ui.addChild(text); ui.addChild(text);
// Wait for next tick for render to complete // Wait for next tick for render to complete
await new Promise(resolve => process.nextTick(resolve)); await new Promise((resolve) => process.nextTick(resolve));
// Wait for writes to complete and get the rendered output // Wait for writes to complete and get the rendered output
const output = await terminal.flushAndGetViewport(); const output = await terminal.flushAndGetViewport();
@ -48,7 +47,7 @@ describe("TUI Rendering", () => {
ui.addChild(new TextComponent("Line 3")); ui.addChild(new TextComponent("Line 3"));
// Wait for next tick for render to complete // Wait for next tick for render to complete
await new Promise(resolve => process.nextTick(resolve)); await new Promise((resolve) => process.nextTick(resolve));
const output = await terminal.flushAndGetViewport(); const output = await terminal.flushAndGetViewport();
assert.strictEqual(output[0], "Line 1"); assert.strictEqual(output[0], "Line 1");
@ -68,7 +67,7 @@ describe("TUI Rendering", () => {
ui.addChild(new TextComponent("Bottom text")); ui.addChild(new TextComponent("Bottom text"));
// Wait for next tick for render to complete // Wait for next tick for render to complete
await new Promise(resolve => process.nextTick(resolve)); await new Promise((resolve) => process.nextTick(resolve));
const output = await terminal.flushAndGetViewport(); const output = await terminal.flushAndGetViewport();
assert.strictEqual(output[0], "Top text"); assert.strictEqual(output[0], "Top text");
@ -96,7 +95,7 @@ describe("TUI Rendering", () => {
ui.addChild(new TextComponent("After container")); ui.addChild(new TextComponent("After container"));
// Wait for next tick for render to complete // Wait for next tick for render to complete
await new Promise(resolve => process.nextTick(resolve)); await new Promise((resolve) => process.nextTick(resolve));
const output = await terminal.flushAndGetViewport(); const output = await terminal.flushAndGetViewport();
assert.strictEqual(output[0], "Before container"); assert.strictEqual(output[0], "Before container");
@ -117,7 +116,7 @@ describe("TUI Rendering", () => {
ui.setFocus(editor); ui.setFocus(editor);
// Wait for next tick for render to complete // Wait for next tick for render to complete
await new Promise(resolve => process.nextTick(resolve)); await new Promise((resolve) => process.nextTick(resolve));
// Initial state - empty editor with cursor // Initial state - empty editor with cursor
const output = await terminal.flushAndGetViewport(); const output = await terminal.flushAndGetViewport();
@ -142,7 +141,7 @@ describe("TUI Rendering", () => {
ui.addChild(dynamicText); ui.addChild(dynamicText);
// Wait for initial render // Wait for initial render
await new Promise(resolve => process.nextTick(resolve)); await new Promise((resolve) => process.nextTick(resolve));
await terminal.flush(); await terminal.flush();
// Save initial state // Save initial state
@ -153,7 +152,7 @@ describe("TUI Rendering", () => {
ui.requestRender(); ui.requestRender();
// Wait for render // Wait for render
await new Promise(resolve => process.nextTick(resolve)); await new Promise((resolve) => process.nextTick(resolve));
// Flush terminal buffer // Flush terminal buffer
await terminal.flush(); await terminal.flush();
@ -180,7 +179,7 @@ describe("TUI Rendering", () => {
ui.addChild(text3); ui.addChild(text3);
// Wait for initial render // Wait for initial render
await new Promise(resolve => process.nextTick(resolve)); await new Promise((resolve) => process.nextTick(resolve));
let output = await terminal.flushAndGetViewport(); let output = await terminal.flushAndGetViewport();
assert.strictEqual(output[0], "Line 1"); assert.strictEqual(output[0], "Line 1");
@ -191,7 +190,7 @@ describe("TUI Rendering", () => {
ui.removeChild(text2); ui.removeChild(text2);
ui.requestRender(); ui.requestRender();
await new Promise(resolve => setImmediate(resolve)); await new Promise((resolve) => setImmediate(resolve));
output = await terminal.flushAndGetViewport(); output = await terminal.flushAndGetViewport();
assert.strictEqual(output[0], "Line 1"); assert.strictEqual(output[0], "Line 1");
@ -212,7 +211,7 @@ describe("TUI Rendering", () => {
} }
// Wait for next tick for render to complete // Wait for next tick for render to complete
await new Promise(resolve => process.nextTick(resolve)); await new Promise((resolve) => process.nextTick(resolve));
const output = await terminal.flushAndGetViewport(); const output = await terminal.flushAndGetViewport();
@ -241,7 +240,7 @@ describe("TUI Rendering", () => {
ui.addChild(new TextComponent("After")); ui.addChild(new TextComponent("After"));
// Wait for next tick for render to complete // Wait for next tick for render to complete
await new Promise(resolve => process.nextTick(resolve)); await new Promise((resolve) => process.nextTick(resolve));
const output = await terminal.flushAndGetViewport(); const output = await terminal.flushAndGetViewport();
assert.strictEqual(output[0], "Before"); assert.strictEqual(output[0], "Before");
@ -262,7 +261,7 @@ describe("TUI Rendering", () => {
ui.addChild(markdown); ui.addChild(markdown);
// Wait for next tick for render to complete // Wait for next tick for render to complete
await new Promise(resolve => process.nextTick(resolve)); await new Promise((resolve) => process.nextTick(resolve));
const output = await terminal.flushAndGetViewport(); const output = await terminal.flushAndGetViewport();
// Should have formatted markdown // Should have formatted markdown
@ -289,7 +288,7 @@ describe("TUI Rendering", () => {
ui.setFocus(selectList); ui.setFocus(selectList);
// Wait for next tick for render to complete // Wait for next tick for render to complete
await new Promise(resolve => process.nextTick(resolve)); await new Promise((resolve) => process.nextTick(resolve));
const output = await terminal.flushAndGetViewport(); const output = await terminal.flushAndGetViewport();
// First option should be selected (has → indicator) // First option should be selected (has → indicator)
@ -334,7 +333,7 @@ describe("TUI Rendering", () => {
ui.setFocus(editor); ui.setFocus(editor);
// Wait for initial render // Wait for initial render
await new Promise(resolve => process.nextTick(resolve)); await new Promise((resolve) => process.nextTick(resolve));
await terminal.flush(); await terminal.flush();
// Check that the editor is rendered after the existing content // Check that the editor is rendered after the existing content
@ -365,7 +364,7 @@ describe("TUI Rendering", () => {
terminal.sendInput("Hello World"); terminal.sendInput("Hello World");
// Wait for the input to be processed // Wait for the input to be processed
await new Promise(resolve => process.nextTick(resolve)); await new Promise((resolve) => process.nextTick(resolve));
await terminal.flush(); await terminal.flush();
// Check that text appears in the editor // Check that text appears in the editor
@ -383,7 +382,7 @@ describe("TUI Rendering", () => {
terminal.sendInput("\n"); terminal.sendInput("\n");
// Wait for the input to be processed // Wait for the input to be processed
await new Promise(resolve => process.nextTick(resolve)); await new Promise((resolve) => process.nextTick(resolve));
await terminal.flush(); await terminal.flush();
// Check that existing content is still preserved after adding new line // Check that existing content is still preserved after adding new line

View file

@ -1,5 +1,5 @@
import { test, describe } from "node:test";
import assert from "node:assert"; import assert from "node:assert";
import { describe, test } from "node:test";
import { VirtualTerminal } from "./virtual-terminal.js"; import { VirtualTerminal } from "./virtual-terminal.js";
describe("VirtualTerminal", () => { describe("VirtualTerminal", () => {
@ -86,13 +86,13 @@ describe("VirtualTerminal", () => {
assert.strictEqual(viewport.length, 10); assert.strictEqual(viewport.length, 10);
assert.strictEqual(viewport[0], "Line 7"); assert.strictEqual(viewport[0], "Line 7");
assert.strictEqual(viewport[8], "Line 15"); assert.strictEqual(viewport[8], "Line 15");
assert.strictEqual(viewport[9], ""); // Last line is empty after the final \r\n assert.strictEqual(viewport[9], ""); // Last line is empty after the final \r\n
// Scroll buffer should have all lines // Scroll buffer should have all lines
assert.ok(scrollBuffer.length >= 15); assert.ok(scrollBuffer.length >= 15);
// Check specific lines exist in the buffer // Check specific lines exist in the buffer
const hasLine1 = scrollBuffer.some(line => line === "Line 1"); const hasLine1 = scrollBuffer.some((line) => line === "Line 1");
const hasLine15 = scrollBuffer.some(line => line === "Line 15"); const hasLine15 = scrollBuffer.some((line) => line === "Line 15");
assert.ok(hasLine1, "Buffer should contain 'Line 1'"); assert.ok(hasLine1, "Buffer should contain 'Line 1'");
assert.ok(hasLine15, "Buffer should contain 'Line 15'"); assert.ok(hasLine15, "Buffer should contain 'Line 15'");
}); });
@ -129,9 +129,12 @@ describe("VirtualTerminal", () => {
const terminal = new VirtualTerminal(80, 24); const terminal = new VirtualTerminal(80, 24);
let received = ""; let received = "";
terminal.start((data) => { terminal.start(
received = data; (data) => {
}, () => {}); received = data;
},
() => {},
);
terminal.sendInput("a"); terminal.sendInput("a");
assert.strictEqual(received, "a"); assert.strictEqual(received, "a");
@ -146,9 +149,12 @@ describe("VirtualTerminal", () => {
const terminal = new VirtualTerminal(80, 24); const terminal = new VirtualTerminal(80, 24);
let resized = false; let resized = false;
terminal.start(() => {}, () => { terminal.start(
resized = true; () => {},
}); () => {
resized = true;
},
);
terminal.resize(100, 30); terminal.resize(100, 30);
assert.strictEqual(resized, true); assert.strictEqual(resized, true);

View file

@ -1,6 +1,6 @@
import xterm from '@xterm/headless'; import type { Terminal as XtermTerminalType } from "@xterm/headless";
import type { Terminal as XtermTerminalType } from '@xterm/headless'; import xterm from "@xterm/headless";
import { Terminal } from '../src/terminal.js'; import type { Terminal } from "../src/terminal.js";
// Extract Terminal class from the module // Extract Terminal class from the module
const XtermTerminal = xterm.Terminal; const XtermTerminal = xterm.Terminal;
@ -81,7 +81,7 @@ export class VirtualTerminal implements Terminal {
async flush(): Promise<void> { async flush(): Promise<void> {
// Write an empty string to ensure all previous writes are flushed // Write an empty string to ensure all previous writes are flushed
return new Promise<void>((resolve) => { return new Promise<void>((resolve) => {
this.xterm.write('', () => resolve()); this.xterm.write("", () => resolve());
}); });
} }
@ -107,7 +107,7 @@ export class VirtualTerminal implements Terminal {
if (line) { if (line) {
lines.push(line.translateToString(true)); lines.push(line.translateToString(true));
} else { } else {
lines.push(''); lines.push("");
} }
} }
@ -127,7 +127,7 @@ export class VirtualTerminal implements Terminal {
if (line) { if (line) {
lines.push(line.translateToString(true)); lines.push(line.translateToString(true));
} else { } else {
lines.push(''); lines.push("");
} }
} }
@ -155,7 +155,7 @@ export class VirtualTerminal implements Terminal {
const buffer = this.xterm.buffer.active; const buffer = this.xterm.buffer.active;
return { return {
x: buffer.cursorX, x: buffer.cursorX,
y: buffer.cursorY y: buffer.cursorY,
}; };
} }
} }