mirror of
https://github.com/getcompanion-ai/co-mono.git
synced 2026-04-20 05:04:44 +00:00
Massive refactor of API
- Switch to function based API - Anthropic SDK style async generator - Fully typed with escape hatches for custom models
This commit is contained in:
parent
004de3c9d0
commit
66cefb236e
29 changed files with 5835 additions and 6225 deletions
|
|
@ -28,6 +28,6 @@
|
|||
"lineWidth": 120
|
||||
},
|
||||
"files": {
|
||||
"includes": ["packages/*/src/**/*", "*.json", "*.md"]
|
||||
"includes": ["packages/*/src/**/*", "packages/*/test/**/*", "*.json", "*.md"]
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
import { writeFileSync } from "fs";
|
||||
import { join, dirname } from "path";
|
||||
import { fileURLToPath } from "url";
|
||||
import { Api, KnownProvider, Model } from "../src/types.js";
|
||||
|
||||
const __filename = fileURLToPath(import.meta.url);
|
||||
const __dirname = dirname(__filename);
|
||||
|
|
@ -28,30 +29,13 @@ interface ModelsDevModel {
|
|||
};
|
||||
}
|
||||
|
||||
interface NormalizedModel {
|
||||
id: string;
|
||||
name: string;
|
||||
provider: string;
|
||||
baseUrl?: string;
|
||||
reasoning: boolean;
|
||||
input: ("text" | "image")[];
|
||||
cost: {
|
||||
input: number;
|
||||
output: number;
|
||||
cacheRead: number;
|
||||
cacheWrite: number;
|
||||
};
|
||||
contextWindow: number;
|
||||
maxTokens: number;
|
||||
}
|
||||
|
||||
async function fetchOpenRouterModels(): Promise<NormalizedModel[]> {
|
||||
async function fetchOpenRouterModels(): Promise<Model<any>[]> {
|
||||
try {
|
||||
console.log("Fetching models from OpenRouter API...");
|
||||
const response = await fetch("https://openrouter.ai/api/v1/models");
|
||||
const data = await response.json();
|
||||
|
||||
const models: NormalizedModel[] = [];
|
||||
const models: Model<any>[] = [];
|
||||
|
||||
for (const model of data.data) {
|
||||
// Only include models that support tools
|
||||
|
|
@ -59,27 +43,17 @@ async function fetchOpenRouterModels(): Promise<NormalizedModel[]> {
|
|||
|
||||
// Parse provider from model ID
|
||||
const [providerPrefix] = model.id.split("/");
|
||||
let provider = "";
|
||||
let provider: KnownProvider = "openrouter";
|
||||
let modelKey = model.id;
|
||||
|
||||
// Skip models that we get from models.dev (Anthropic, Google, OpenAI)
|
||||
if (model.id.startsWith("google/") ||
|
||||
model.id.startsWith("openai/") ||
|
||||
model.id.startsWith("anthropic/")) {
|
||||
continue;
|
||||
} else if (model.id.startsWith("x-ai/")) {
|
||||
provider = "xai";
|
||||
modelKey = model.id.replace("x-ai/", "");
|
||||
} else {
|
||||
// All other models go through OpenRouter
|
||||
provider = "openrouter";
|
||||
modelKey = model.id; // Keep full ID for OpenRouter
|
||||
}
|
||||
|
||||
// Skip if not one of our supported providers from OpenRouter
|
||||
if (!["xai", "openrouter"].includes(provider)) {
|
||||
model.id.startsWith("anthropic/") ||
|
||||
model.id.startsWith("x-ai/")) {
|
||||
continue;
|
||||
}
|
||||
modelKey = model.id; // Keep full ID for OpenRouter
|
||||
|
||||
// Parse input modalities
|
||||
const input: ("text" | "image")[] = ["text"];
|
||||
|
|
@ -93,9 +67,11 @@ async function fetchOpenRouterModels(): Promise<NormalizedModel[]> {
|
|||
const cacheReadCost = parseFloat(model.pricing?.input_cache_read || "0") * 1_000_000;
|
||||
const cacheWriteCost = parseFloat(model.pricing?.input_cache_write || "0") * 1_000_000;
|
||||
|
||||
const normalizedModel: NormalizedModel = {
|
||||
const normalizedModel: Model<any> = {
|
||||
id: modelKey,
|
||||
name: model.name,
|
||||
api: "openai-completions",
|
||||
baseUrl: "https://openrouter.ai/api/v1",
|
||||
provider,
|
||||
reasoning: model.supported_parameters?.includes("reasoning") || false,
|
||||
input,
|
||||
|
|
@ -108,14 +84,6 @@ async function fetchOpenRouterModels(): Promise<NormalizedModel[]> {
|
|||
contextWindow: model.context_length || 4096,
|
||||
maxTokens: model.top_provider?.max_completion_tokens || 4096,
|
||||
};
|
||||
|
||||
// Add baseUrl for providers that need it
|
||||
if (provider === "xai") {
|
||||
normalizedModel.baseUrl = "https://api.x.ai/v1";
|
||||
} else if (provider === "openrouter") {
|
||||
normalizedModel.baseUrl = "https://openrouter.ai/api/v1";
|
||||
}
|
||||
|
||||
models.push(normalizedModel);
|
||||
}
|
||||
|
||||
|
|
@ -127,13 +95,13 @@ async function fetchOpenRouterModels(): Promise<NormalizedModel[]> {
|
|||
}
|
||||
}
|
||||
|
||||
async function loadModelsDevData(): Promise<NormalizedModel[]> {
|
||||
async function loadModelsDevData(): Promise<Model<any>[]> {
|
||||
try {
|
||||
console.log("Fetching models from models.dev API...");
|
||||
const response = await fetch("https://models.dev/api.json");
|
||||
const data = await response.json();
|
||||
|
||||
const models: NormalizedModel[] = [];
|
||||
const models: Model<any>[] = [];
|
||||
|
||||
// Process Anthropic models
|
||||
if (data.anthropic?.models) {
|
||||
|
|
@ -144,7 +112,9 @@ async function loadModelsDevData(): Promise<NormalizedModel[]> {
|
|||
models.push({
|
||||
id: modelId,
|
||||
name: m.name || modelId,
|
||||
api: "anthropic-messages",
|
||||
provider: "anthropic",
|
||||
baseUrl: "https://api.anthropic.com",
|
||||
reasoning: m.reasoning === true,
|
||||
input: m.modalities?.input?.includes("image") ? ["text", "image"] : ["text"],
|
||||
cost: {
|
||||
|
|
@ -168,7 +138,9 @@ async function loadModelsDevData(): Promise<NormalizedModel[]> {
|
|||
models.push({
|
||||
id: modelId,
|
||||
name: m.name || modelId,
|
||||
api: "google-generative-ai",
|
||||
provider: "google",
|
||||
baseUrl: "https://generativelanguage.googleapis.com/v1beta",
|
||||
reasoning: m.reasoning === true,
|
||||
input: m.modalities?.input?.includes("image") ? ["text", "image"] : ["text"],
|
||||
cost: {
|
||||
|
|
@ -192,7 +164,9 @@ async function loadModelsDevData(): Promise<NormalizedModel[]> {
|
|||
models.push({
|
||||
id: modelId,
|
||||
name: m.name || modelId,
|
||||
api: "openai-responses",
|
||||
provider: "openai",
|
||||
baseUrl: "https://api.openai.com/v1",
|
||||
reasoning: m.reasoning === true,
|
||||
input: m.modalities?.input?.includes("image") ? ["text", "image"] : ["text"],
|
||||
cost: {
|
||||
|
|
@ -216,6 +190,7 @@ async function loadModelsDevData(): Promise<NormalizedModel[]> {
|
|||
models.push({
|
||||
id: modelId,
|
||||
name: m.name || modelId,
|
||||
api: "openai-completions",
|
||||
provider: "groq",
|
||||
baseUrl: "https://api.groq.com/openai/v1",
|
||||
reasoning: m.reasoning === true,
|
||||
|
|
@ -241,6 +216,7 @@ async function loadModelsDevData(): Promise<NormalizedModel[]> {
|
|||
models.push({
|
||||
id: modelId,
|
||||
name: m.name || modelId,
|
||||
api: "openai-completions",
|
||||
provider: "cerebras",
|
||||
baseUrl: "https://api.cerebras.ai/v1",
|
||||
reasoning: m.reasoning === true,
|
||||
|
|
@ -257,6 +233,32 @@ async function loadModelsDevData(): Promise<NormalizedModel[]> {
|
|||
}
|
||||
}
|
||||
|
||||
// Process xAi models
|
||||
if (data.xai?.models) {
|
||||
for (const [modelId, model] of Object.entries(data.xai.models)) {
|
||||
const m = model as ModelsDevModel;
|
||||
if (m.tool_call !== true) continue;
|
||||
|
||||
models.push({
|
||||
id: modelId,
|
||||
name: m.name || modelId,
|
||||
api: "openai-completions",
|
||||
provider: "xai",
|
||||
baseUrl: "https://api.x.ai/v1",
|
||||
reasoning: m.reasoning === true,
|
||||
input: m.modalities?.input?.includes("image") ? ["text", "image"] : ["text"],
|
||||
cost: {
|
||||
input: m.cost?.input || 0,
|
||||
output: m.cost?.output || 0,
|
||||
cacheRead: m.cost?.cache_read || 0,
|
||||
cacheWrite: m.cost?.cache_write || 0,
|
||||
},
|
||||
contextWindow: m.limit?.context || 4096,
|
||||
maxTokens: m.limit?.output || 4096,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
console.log(`Loaded ${models.length} tool-capable models from models.dev`);
|
||||
return models;
|
||||
} catch (error) {
|
||||
|
|
@ -280,6 +282,8 @@ async function generateModels() {
|
|||
allModels.push({
|
||||
id: "gpt-5-chat-latest",
|
||||
name: "GPT-5 Chat Latest",
|
||||
api: "openai-responses",
|
||||
baseUrl: "https://api.openai.com/v1",
|
||||
provider: "openai",
|
||||
reasoning: false,
|
||||
input: ["text", "image"],
|
||||
|
|
@ -294,8 +298,29 @@ async function generateModels() {
|
|||
});
|
||||
}
|
||||
|
||||
// Add missing Grok models
|
||||
if (!allModels.some(m => m.provider === "xai" && m.id === "grok-code-fast-1")) {
|
||||
allModels.push({
|
||||
id: "grok-code-fast-1",
|
||||
name: "Grok Code Fast 1",
|
||||
api: "openai-completions",
|
||||
baseUrl: "https://api.x.ai/v1",
|
||||
provider: "xai",
|
||||
reasoning: false,
|
||||
input: ["text"],
|
||||
cost: {
|
||||
input: 0.2,
|
||||
output: 1.5,
|
||||
cacheRead: 0.02,
|
||||
cacheWrite: 0,
|
||||
},
|
||||
contextWindow: 32768,
|
||||
maxTokens: 8192,
|
||||
});
|
||||
}
|
||||
|
||||
// Group by provider and deduplicate by model ID
|
||||
const providers: Record<string, Record<string, NormalizedModel>> = {};
|
||||
const providers: Record<string, Record<string, Model<any>>> = {};
|
||||
for (const model of allModels) {
|
||||
if (!providers[model.provider]) {
|
||||
providers[model.provider] = {};
|
||||
|
|
@ -319,39 +344,33 @@ export const PROVIDERS = {
|
|||
// Generate provider sections
|
||||
for (const [providerId, models] of Object.entries(providers)) {
|
||||
output += `\t${providerId}: {\n`;
|
||||
output += `\t\tmodels: {\n`;
|
||||
|
||||
for (const model of Object.values(models)) {
|
||||
output += `\t\t\t"${model.id}": {\n`;
|
||||
output += `\t\t\t\tid: "${model.id}",\n`;
|
||||
output += `\t\t\t\tname: "${model.name}",\n`;
|
||||
output += `\t\t\t\tprovider: "${model.provider}",\n`;
|
||||
output += `\t\t"${model.id}": {\n`;
|
||||
output += `\t\t\tid: "${model.id}",\n`;
|
||||
output += `\t\t\tname: "${model.name}",\n`;
|
||||
output += `\t\t\tapi: "${model.api}",\n`;
|
||||
output += `\t\t\tprovider: "${model.provider}",\n`;
|
||||
if (model.baseUrl) {
|
||||
output += `\t\t\t\tbaseUrl: "${model.baseUrl}",\n`;
|
||||
output += `\t\t\tbaseUrl: "${model.baseUrl}",\n`;
|
||||
}
|
||||
output += `\t\t\t\treasoning: ${model.reasoning},\n`;
|
||||
output += `\t\t\t\tinput: ${JSON.stringify(model.input)},\n`;
|
||||
output += `\t\t\t\tcost: {\n`;
|
||||
output += `\t\t\t\t\tinput: ${model.cost.input},\n`;
|
||||
output += `\t\t\t\t\toutput: ${model.cost.output},\n`;
|
||||
output += `\t\t\t\t\tcacheRead: ${model.cost.cacheRead},\n`;
|
||||
output += `\t\t\t\t\tcacheWrite: ${model.cost.cacheWrite},\n`;
|
||||
output += `\t\t\t\t},\n`;
|
||||
output += `\t\t\t\tcontextWindow: ${model.contextWindow},\n`;
|
||||
output += `\t\t\t\tmaxTokens: ${model.maxTokens},\n`;
|
||||
output += `\t\t\t} satisfies Model,\n`;
|
||||
output += `\t\t\treasoning: ${model.reasoning},\n`;
|
||||
output += `\t\t\tinput: [${model.input.map(i => `"${i}"`).join(", ")}],\n`;
|
||||
output += `\t\t\tcost: {\n`;
|
||||
output += `\t\t\t\tinput: ${model.cost.input},\n`;
|
||||
output += `\t\t\t\toutput: ${model.cost.output},\n`;
|
||||
output += `\t\t\t\tcacheRead: ${model.cost.cacheRead},\n`;
|
||||
output += `\t\t\t\tcacheWrite: ${model.cost.cacheWrite},\n`;
|
||||
output += `\t\t\t},\n`;
|
||||
output += `\t\t\tcontextWindow: ${model.contextWindow},\n`;
|
||||
output += `\t\t\tmaxTokens: ${model.maxTokens},\n`;
|
||||
output += `\t\t} satisfies Model<"${model.api}">,\n`;
|
||||
}
|
||||
|
||||
output += `\t\t}\n`;
|
||||
output += `\t},\n`;
|
||||
}
|
||||
|
||||
output += `} as const;
|
||||
|
||||
// Helper type to extract models for each provider
|
||||
export type ProviderModels = {
|
||||
[K in keyof typeof PROVIDERS]: typeof PROVIDERS[K]["models"]
|
||||
};
|
||||
`;
|
||||
|
||||
// Write file
|
||||
|
|
|
|||
|
|
@ -1,47 +1,43 @@
|
|||
import { type AnthropicOptions, streamAnthropic } from "./providers/anthropic.js";
|
||||
import { type GoogleOptions, streamGoogle } from "./providers/google.js";
|
||||
import { type OpenAICompletionsOptions, streamOpenAICompletions } from "./providers/openai-completions.js";
|
||||
import { type OpenAIResponsesOptions, streamOpenAIResponses } from "./providers/openai-responses.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
AssistantMessageEvent,
|
||||
Context,
|
||||
GenerateFunction,
|
||||
GenerateOptionsUnified,
|
||||
GenerateStream,
|
||||
KnownProvider,
|
||||
Model,
|
||||
OptionsForApi,
|
||||
ReasoningEffort,
|
||||
SimpleGenerateOptions,
|
||||
} from "./types.js";
|
||||
|
||||
export class QueuedGenerateStream implements GenerateStream {
|
||||
private queue: AssistantMessageEvent[] = [];
|
||||
private waiting: ((value: IteratorResult<AssistantMessageEvent>) => void)[] = [];
|
||||
private done = false;
|
||||
private error?: Error;
|
||||
private finalMessagePromise: Promise<AssistantMessage>;
|
||||
private resolveFinalMessage!: (message: AssistantMessage) => void;
|
||||
private rejectFinalMessage!: (error: Error) => void;
|
||||
|
||||
constructor() {
|
||||
this.finalMessagePromise = new Promise((resolve, reject) => {
|
||||
this.finalMessagePromise = new Promise((resolve) => {
|
||||
this.resolveFinalMessage = resolve;
|
||||
this.rejectFinalMessage = reject;
|
||||
});
|
||||
}
|
||||
|
||||
push(event: AssistantMessageEvent): void {
|
||||
if (this.done) return;
|
||||
|
||||
// If it's the done event, resolve the final message
|
||||
if (event.type === "done") {
|
||||
this.done = true;
|
||||
this.resolveFinalMessage(event.message);
|
||||
}
|
||||
|
||||
// If it's an error event, reject the final message
|
||||
if (event.type === "error") {
|
||||
this.error = new Error(event.error);
|
||||
if (!this.done) {
|
||||
this.rejectFinalMessage(this.error);
|
||||
}
|
||||
this.done = true;
|
||||
this.resolveFinalMessage(event.partial);
|
||||
}
|
||||
|
||||
// Deliver to waiting consumer or queue it
|
||||
|
|
@ -86,31 +82,14 @@ export class QueuedGenerateStream implements GenerateStream {
|
|||
}
|
||||
}
|
||||
|
||||
// API implementations registry
|
||||
const apiImplementations: Map<Api | string, GenerateFunction> = new Map();
|
||||
|
||||
/**
|
||||
* Register a custom API implementation
|
||||
*/
|
||||
export function registerApi(api: string, impl: GenerateFunction): void {
|
||||
apiImplementations.set(api, impl);
|
||||
}
|
||||
|
||||
// API key storage
|
||||
const apiKeys: Map<string, string> = new Map();
|
||||
|
||||
/**
|
||||
* Set an API key for a provider
|
||||
*/
|
||||
export function setApiKey(provider: KnownProvider, key: string): void;
|
||||
export function setApiKey(provider: string, key: string): void;
|
||||
export function setApiKey(provider: any, key: string): void {
|
||||
apiKeys.set(provider, key);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get API key for a provider
|
||||
*/
|
||||
export function getApiKey(provider: KnownProvider): string | undefined;
|
||||
export function getApiKey(provider: string): string | undefined;
|
||||
export function getApiKey(provider: any): string | undefined {
|
||||
|
|
@ -133,45 +112,76 @@ export function getApiKey(provider: any): string | undefined {
|
|||
return envVar ? process.env[envVar] : undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Main generate function
|
||||
*/
|
||||
export function generate(model: Model, context: Context, options?: GenerateOptionsUnified): GenerateStream {
|
||||
// Get implementation
|
||||
const impl = apiImplementations.get(model.api);
|
||||
if (!impl) {
|
||||
throw new Error(`Unsupported API: ${model.api}`);
|
||||
export function stream<TApi extends Api>(
|
||||
model: Model<TApi>,
|
||||
context: Context,
|
||||
options?: OptionsForApi<TApi>,
|
||||
): GenerateStream {
|
||||
const apiKey = options?.apiKey || getApiKey(model.provider);
|
||||
if (!apiKey) {
|
||||
throw new Error(`No API key for provider: ${model.provider}`);
|
||||
}
|
||||
const providerOptions = { ...options, apiKey };
|
||||
|
||||
// Get API key from options or environment
|
||||
const api: Api = model.api;
|
||||
switch (api) {
|
||||
case "anthropic-messages":
|
||||
return streamAnthropic(model as Model<"anthropic-messages">, context, providerOptions);
|
||||
|
||||
case "openai-completions":
|
||||
return streamOpenAICompletions(model as Model<"openai-completions">, context, providerOptions as any);
|
||||
|
||||
case "openai-responses":
|
||||
return streamOpenAIResponses(model as Model<"openai-responses">, context, providerOptions as any);
|
||||
|
||||
case "google-generative-ai":
|
||||
return streamGoogle(model as Model<"google-generative-ai">, context, providerOptions);
|
||||
|
||||
default: {
|
||||
// This should never be reached if all Api cases are handled
|
||||
const _exhaustive: never = api;
|
||||
throw new Error(`Unhandled API: ${_exhaustive}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export async function complete<TApi extends Api>(
|
||||
model: Model<TApi>,
|
||||
context: Context,
|
||||
options?: OptionsForApi<TApi>,
|
||||
): Promise<AssistantMessage> {
|
||||
const s = stream(model, context, options);
|
||||
return s.finalMessage();
|
||||
}
|
||||
|
||||
export function streamSimple<TApi extends Api>(
|
||||
model: Model<TApi>,
|
||||
context: Context,
|
||||
options?: SimpleGenerateOptions,
|
||||
): GenerateStream {
|
||||
const apiKey = options?.apiKey || getApiKey(model.provider);
|
||||
if (!apiKey) {
|
||||
throw new Error(`No API key for provider: ${model.provider}`);
|
||||
}
|
||||
|
||||
// Map generic options to provider-specific
|
||||
const providerOptions = mapOptionsForApi(model.api, model, options, apiKey);
|
||||
|
||||
// Return the GenerateStream from implementation
|
||||
return impl(model, context, providerOptions);
|
||||
const providerOptions = mapOptionsForApi(model, options, apiKey);
|
||||
return stream(model, context, providerOptions);
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper to generate and get complete response (no streaming)
|
||||
*/
|
||||
export async function generateComplete(
|
||||
model: Model,
|
||||
export async function completeSimple<TApi extends Api>(
|
||||
model: Model<TApi>,
|
||||
context: Context,
|
||||
options?: GenerateOptionsUnified,
|
||||
options?: SimpleGenerateOptions,
|
||||
): Promise<AssistantMessage> {
|
||||
const stream = generate(model, context, options);
|
||||
return stream.finalMessage();
|
||||
const s = streamSimple(model, context, options);
|
||||
return s.finalMessage();
|
||||
}
|
||||
|
||||
/**
|
||||
* Map generic options to provider-specific options
|
||||
*/
|
||||
function mapOptionsForApi(api: Api | string, model: Model, options?: GenerateOptionsUnified, apiKey?: string): any {
|
||||
function mapOptionsForApi<TApi extends Api>(
|
||||
model: Model<TApi>,
|
||||
options?: SimpleGenerateOptions,
|
||||
apiKey?: string,
|
||||
): OptionsForApi<TApi> {
|
||||
const base = {
|
||||
temperature: options?.temperature,
|
||||
maxTokens: options?.maxTokens,
|
||||
|
|
@ -179,18 +189,10 @@ function mapOptionsForApi(api: Api | string, model: Model, options?: GenerateOpt
|
|||
apiKey: apiKey || options?.apiKey,
|
||||
};
|
||||
|
||||
switch (api) {
|
||||
case "openai-responses":
|
||||
case "openai-completions":
|
||||
return {
|
||||
...base,
|
||||
reasoning_effort: options?.reasoning,
|
||||
};
|
||||
|
||||
switch (model.api) {
|
||||
case "anthropic-messages": {
|
||||
if (!options?.reasoning) return base;
|
||||
if (!options?.reasoning) return base satisfies AnthropicOptions;
|
||||
|
||||
// Map effort to token budget
|
||||
const anthropicBudgets = {
|
||||
minimal: 1024,
|
||||
low: 2048,
|
||||
|
|
@ -200,55 +202,60 @@ function mapOptionsForApi(api: Api | string, model: Model, options?: GenerateOpt
|
|||
|
||||
return {
|
||||
...base,
|
||||
thinking: {
|
||||
enabled: true,
|
||||
budgetTokens: anthropicBudgets[options.reasoning],
|
||||
},
|
||||
};
|
||||
thinkingEnabled: true,
|
||||
thinkingBudgetTokens: anthropicBudgets[options.reasoning],
|
||||
} satisfies AnthropicOptions;
|
||||
}
|
||||
case "google-generative-ai": {
|
||||
if (!options?.reasoning) return { ...base, thinking_budget: -1 };
|
||||
|
||||
// Model-specific mapping for Google
|
||||
const googleBudget = getGoogleBudget(model, options.reasoning);
|
||||
case "openai-completions":
|
||||
return {
|
||||
...base,
|
||||
thinking_budget: googleBudget,
|
||||
};
|
||||
reasoningEffort: options?.reasoning,
|
||||
} satisfies OpenAICompletionsOptions;
|
||||
|
||||
case "openai-responses":
|
||||
return {
|
||||
...base,
|
||||
reasoningEffort: options?.reasoning,
|
||||
} satisfies OpenAIResponsesOptions;
|
||||
|
||||
case "google-generative-ai": {
|
||||
if (!options?.reasoning) return base as any;
|
||||
|
||||
const googleBudget = getGoogleBudget(model as Model<"google-generative-ai">, options.reasoning);
|
||||
return {
|
||||
...base,
|
||||
thinking: {
|
||||
enabled: true,
|
||||
budgetTokens: googleBudget,
|
||||
},
|
||||
} satisfies GoogleOptions;
|
||||
}
|
||||
|
||||
default: {
|
||||
// Exhaustiveness check
|
||||
const _exhaustive: never = model.api;
|
||||
throw new Error(`Unhandled API in mapOptionsForApi: ${_exhaustive}`);
|
||||
}
|
||||
default:
|
||||
return base;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get Google thinking budget based on model and effort
|
||||
*/
|
||||
function getGoogleBudget(model: Model, effort: ReasoningEffort): number {
|
||||
// Model-specific logic
|
||||
if (model.id.includes("flash-lite")) {
|
||||
const budgets = {
|
||||
minimal: 512,
|
||||
low: 2048,
|
||||
medium: 8192,
|
||||
high: 24576,
|
||||
};
|
||||
return budgets[effort];
|
||||
}
|
||||
|
||||
if (model.id.includes("pro")) {
|
||||
function getGoogleBudget(model: Model<"google-generative-ai">, effort: ReasoningEffort): number {
|
||||
// See https://ai.google.dev/gemini-api/docs/thinking#set-budget
|
||||
if (model.id.includes("2.5-pro")) {
|
||||
const budgets = {
|
||||
minimal: 128,
|
||||
low: 2048,
|
||||
medium: 8192,
|
||||
high: Math.min(25000, 32768),
|
||||
high: 32768,
|
||||
};
|
||||
return budgets[effort];
|
||||
}
|
||||
|
||||
if (model.id.includes("flash")) {
|
||||
if (model.id.includes("2.5-flash")) {
|
||||
// Covers 2.5-flash-lite as well
|
||||
const budgets = {
|
||||
minimal: 0, // Disable thinking
|
||||
minimal: 128,
|
||||
low: 2048,
|
||||
medium: 8192,
|
||||
high: 24576,
|
||||
|
|
@ -259,10 +266,3 @@ function getGoogleBudget(model: Model, effort: ReasoningEffort): number {
|
|||
// Unknown model - use dynamic
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Register built-in API implementations
|
||||
// Import the new function-based implementations
|
||||
import { generateAnthropic } from "./providers/anthropic-generate.js";
|
||||
|
||||
// Register Anthropic implementation
|
||||
apiImplementations.set("anthropic-messages", generateAnthropic);
|
||||
|
|
|
|||
|
|
@ -1,37 +1,8 @@
|
|||
// @mariozechner/pi-ai - Unified LLM API with automatic model discovery
|
||||
// This package provides a common interface for working with multiple LLM providers
|
||||
|
||||
export const version = "0.5.8";
|
||||
|
||||
// Export generate API
|
||||
export {
|
||||
generate,
|
||||
generateComplete,
|
||||
getApiKey,
|
||||
QueuedGenerateStream,
|
||||
registerApi,
|
||||
setApiKey,
|
||||
} from "./generate.js";
|
||||
// Export generated models data
|
||||
export { PROVIDERS } from "./models.generated.js";
|
||||
// Export model utilities
|
||||
export {
|
||||
calculateCost,
|
||||
getModel,
|
||||
type KnownProvider,
|
||||
registerModel,
|
||||
} from "./models.js";
|
||||
|
||||
// Legacy providers (to be deprecated)
|
||||
export { AnthropicLLM } from "./providers/anthropic.js";
|
||||
export { GoogleLLM } from "./providers/google.js";
|
||||
export { OpenAICompletionsLLM } from "./providers/openai-completions.js";
|
||||
export { OpenAIResponsesLLM } from "./providers/openai-responses.js";
|
||||
|
||||
// Export types
|
||||
export type * from "./types.js";
|
||||
|
||||
// TODO: Remove these legacy exports once consumers are updated
|
||||
export function createLLM(): never {
|
||||
throw new Error("createLLM is deprecated. Use generate() with getModel() instead.");
|
||||
}
|
||||
export * from "./generate.js";
|
||||
export * from "./models.generated.js";
|
||||
export * from "./models.js";
|
||||
export * from "./providers/anthropic.js";
|
||||
export * from "./providers/google.js";
|
||||
export * from "./providers/openai-completions.js";
|
||||
export * from "./providers/openai-responses.js";
|
||||
export * from "./types.js";
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -1,44 +1,39 @@
|
|||
import { PROVIDERS } from "./models.generated.js";
|
||||
import type { KnownProvider, Model, Usage } from "./types.js";
|
||||
import type { Api, KnownProvider, Model, Usage } from "./types.js";
|
||||
|
||||
// Re-export Model type
|
||||
export type { KnownProvider, Model } from "./types.js";
|
||||
|
||||
// Dynamic model registry initialized from PROVIDERS
|
||||
const modelRegistry: Map<string, Map<string, Model>> = new Map();
|
||||
const modelRegistry: Map<string, Map<string, Model<Api>>> = new Map();
|
||||
|
||||
// Initialize registry from PROVIDERS on module load
|
||||
for (const [provider, models] of Object.entries(PROVIDERS)) {
|
||||
const providerModels = new Map<string, Model>();
|
||||
const providerModels = new Map<string, Model<Api>>();
|
||||
for (const [id, model] of Object.entries(models)) {
|
||||
providerModels.set(id, model as Model);
|
||||
providerModels.set(id, model as Model<Api>);
|
||||
}
|
||||
modelRegistry.set(provider, providerModels);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a model from the registry - typed overload for known providers
|
||||
*/
|
||||
export function getModel<P extends KnownProvider>(provider: P, modelId: keyof (typeof PROVIDERS)[P]): Model;
|
||||
export function getModel(provider: string, modelId: string): Model | undefined;
|
||||
export function getModel(provider: any, modelId: any): Model | undefined {
|
||||
return modelRegistry.get(provider)?.get(modelId);
|
||||
type ModelApi<
|
||||
TProvider extends KnownProvider,
|
||||
TModelId extends keyof (typeof PROVIDERS)[TProvider],
|
||||
> = (typeof PROVIDERS)[TProvider][TModelId] extends { api: infer TApi } ? (TApi extends Api ? TApi : never) : never;
|
||||
|
||||
export function getModel<TProvider extends KnownProvider, TModelId extends keyof (typeof PROVIDERS)[TProvider]>(
|
||||
provider: TProvider,
|
||||
modelId: TModelId,
|
||||
): Model<ModelApi<TProvider, TModelId>>;
|
||||
export function getModel<TApi extends Api>(provider: string, modelId: string): Model<TApi> | undefined;
|
||||
export function getModel<TApi extends Api>(provider: any, modelId: any): Model<TApi> | undefined {
|
||||
return modelRegistry.get(provider)?.get(modelId) as Model<TApi> | undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Register a custom model
|
||||
*/
|
||||
export function registerModel(model: Model): void {
|
||||
export function registerModel<TApi extends Api>(model: Model<TApi>): void {
|
||||
if (!modelRegistry.has(model.provider)) {
|
||||
modelRegistry.set(model.provider, new Map());
|
||||
}
|
||||
modelRegistry.get(model.provider)!.set(model.id, model);
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate cost for token usage
|
||||
*/
|
||||
export function calculateCost(model: Model, usage: Usage): Usage["cost"] {
|
||||
export function calculateCost<TApi extends Api>(model: Model<TApi>, usage: Usage): Usage["cost"] {
|
||||
usage.cost.input = (model.cost.input / 1000000) * usage.input;
|
||||
usage.cost.output = (model.cost.output / 1000000) * usage.output;
|
||||
usage.cost.cacheRead = (model.cost.cacheRead / 1000000) * usage.cacheRead;
|
||||
|
|
|
|||
|
|
@ -1,425 +0,0 @@
|
|||
import Anthropic from "@anthropic-ai/sdk";
|
||||
import type {
|
||||
ContentBlockParam,
|
||||
MessageCreateParamsStreaming,
|
||||
MessageParam,
|
||||
Tool,
|
||||
} from "@anthropic-ai/sdk/resources/messages.js";
|
||||
import { QueuedGenerateStream } from "../generate.js";
|
||||
import { calculateCost } from "../models.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
Context,
|
||||
GenerateFunction,
|
||||
GenerateOptions,
|
||||
GenerateStream,
|
||||
Message,
|
||||
Model,
|
||||
StopReason,
|
||||
TextContent,
|
||||
ThinkingContent,
|
||||
ToolCall,
|
||||
} from "../types.js";
|
||||
import { transformMessages } from "./utils.js";
|
||||
|
||||
// Anthropic-specific options
|
||||
export interface AnthropicOptions extends GenerateOptions {
|
||||
thinking?: {
|
||||
enabled: boolean;
|
||||
budgetTokens?: number;
|
||||
};
|
||||
toolChoice?: "auto" | "any" | "none" | { type: "tool"; name: string };
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate function for Anthropic API
|
||||
*/
|
||||
export const generateAnthropic: GenerateFunction<AnthropicOptions> = (
|
||||
model: Model,
|
||||
context: Context,
|
||||
options: AnthropicOptions,
|
||||
): GenerateStream => {
|
||||
const stream = new QueuedGenerateStream();
|
||||
|
||||
// Start async processing
|
||||
(async () => {
|
||||
const output: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: "anthropic-messages" as Api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "stop",
|
||||
};
|
||||
|
||||
try {
|
||||
// Create Anthropic client
|
||||
const client = createAnthropicClient(model, options.apiKey!);
|
||||
|
||||
// Convert messages
|
||||
const messages = convertMessages(context.messages, model, "anthropic-messages");
|
||||
|
||||
// Build params
|
||||
const params = buildAnthropicParams(model, context, options, messages, client.isOAuthToken);
|
||||
|
||||
// Create Anthropic stream
|
||||
const anthropicStream = client.client.messages.stream(
|
||||
{
|
||||
...params,
|
||||
stream: true,
|
||||
},
|
||||
{
|
||||
signal: options.signal,
|
||||
},
|
||||
);
|
||||
|
||||
// Emit start event
|
||||
stream.push({
|
||||
type: "start",
|
||||
partial: output,
|
||||
});
|
||||
|
||||
// Process Anthropic events
|
||||
let currentBlock: ThinkingContent | TextContent | (ToolCall & { partialJson: string }) | null = null;
|
||||
|
||||
for await (const event of anthropicStream) {
|
||||
if (event.type === "content_block_start") {
|
||||
if (event.content_block.type === "text") {
|
||||
currentBlock = {
|
||||
type: "text",
|
||||
text: "",
|
||||
};
|
||||
output.content.push(currentBlock);
|
||||
stream.push({ type: "text_start", partial: output });
|
||||
} else if (event.content_block.type === "thinking") {
|
||||
currentBlock = {
|
||||
type: "thinking",
|
||||
thinking: "",
|
||||
thinkingSignature: "",
|
||||
};
|
||||
output.content.push(currentBlock);
|
||||
stream.push({ type: "thinking_start", partial: output });
|
||||
} else if (event.content_block.type === "tool_use") {
|
||||
// We wait for the full tool use to be streamed
|
||||
currentBlock = {
|
||||
type: "toolCall",
|
||||
id: event.content_block.id,
|
||||
name: event.content_block.name,
|
||||
arguments: event.content_block.input as Record<string, any>,
|
||||
partialJson: "",
|
||||
};
|
||||
}
|
||||
} else if (event.type === "content_block_delta") {
|
||||
if (event.delta.type === "text_delta") {
|
||||
if (currentBlock && currentBlock.type === "text") {
|
||||
currentBlock.text += event.delta.text;
|
||||
stream.push({
|
||||
type: "text_delta",
|
||||
delta: event.delta.text,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
} else if (event.delta.type === "thinking_delta") {
|
||||
if (currentBlock && currentBlock.type === "thinking") {
|
||||
currentBlock.thinking += event.delta.thinking;
|
||||
stream.push({
|
||||
type: "thinking_delta",
|
||||
delta: event.delta.thinking,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
} else if (event.delta.type === "input_json_delta") {
|
||||
if (currentBlock && currentBlock.type === "toolCall") {
|
||||
currentBlock.partialJson += event.delta.partial_json;
|
||||
}
|
||||
} else if (event.delta.type === "signature_delta") {
|
||||
if (currentBlock && currentBlock.type === "thinking") {
|
||||
currentBlock.thinkingSignature = currentBlock.thinkingSignature || "";
|
||||
currentBlock.thinkingSignature += event.delta.signature;
|
||||
}
|
||||
}
|
||||
} else if (event.type === "content_block_stop") {
|
||||
if (currentBlock) {
|
||||
if (currentBlock.type === "text") {
|
||||
stream.push({ type: "text_end", content: currentBlock.text, partial: output });
|
||||
} else if (currentBlock.type === "thinking") {
|
||||
stream.push({ type: "thinking_end", content: currentBlock.thinking, partial: output });
|
||||
} else if (currentBlock.type === "toolCall") {
|
||||
const finalToolCall: ToolCall = {
|
||||
type: "toolCall",
|
||||
id: currentBlock.id,
|
||||
name: currentBlock.name,
|
||||
arguments: JSON.parse(currentBlock.partialJson),
|
||||
};
|
||||
output.content.push(finalToolCall);
|
||||
stream.push({ type: "toolCall", toolCall: finalToolCall, partial: output });
|
||||
}
|
||||
currentBlock = null;
|
||||
}
|
||||
} else if (event.type === "message_delta") {
|
||||
if (event.delta.stop_reason) {
|
||||
output.stopReason = mapStopReason(event.delta.stop_reason);
|
||||
}
|
||||
output.usage.input += event.usage.input_tokens || 0;
|
||||
output.usage.output += event.usage.output_tokens || 0;
|
||||
output.usage.cacheRead += event.usage.cache_read_input_tokens || 0;
|
||||
output.usage.cacheWrite += event.usage.cache_creation_input_tokens || 0;
|
||||
calculateCost(model, output.usage);
|
||||
}
|
||||
}
|
||||
|
||||
// Emit done event with final message
|
||||
stream.push({ type: "done", reason: output.stopReason, message: output });
|
||||
stream.end();
|
||||
} catch (error) {
|
||||
output.stopReason = "error";
|
||||
output.error = error instanceof Error ? error.message : JSON.stringify(error);
|
||||
stream.push({ type: "error", error: output.error, partial: output });
|
||||
stream.end();
|
||||
}
|
||||
})();
|
||||
|
||||
return stream;
|
||||
};
|
||||
|
||||
/** Pairs a configured Anthropic SDK client with the kind of credential it was built from. */
interface AnthropicClientWrapper {
	/** SDK client bound to the model's base URL and the supplied credential. */
	client: Anthropic;
	/** True when the credential contained the "sk-ant-oat" marker and was passed to the SDK as `authToken`. */
	isOAuthToken: boolean;
}

/**
 * Creates an Anthropic SDK client for the given model and credential.
 *
 * Credentials containing "sk-ant-oat" are treated as OAuth tokens: they are handed to the SDK as
 * `authToken` (with `apiKey` explicitly null) and the OAuth beta header is requested in addition to
 * fine-grained tool streaming. Any other credential is passed as a regular `apiKey`.
 *
 * @param model - Model descriptor; only `baseUrl` is read here.
 * @param apiKey - API key or OAuth token.
 * @returns The client together with a flag telling callers whether OAuth mode is active.
 */
function createAnthropicClient(model: Model, apiKey: string): AnthropicClientWrapper {
	const isOAuthToken = apiKey.includes("sk-ant-oat");

	// Both credential kinds share every header except the beta feature list.
	const defaultHeaders = {
		accept: "application/json",
		"anthropic-dangerous-direct-browser-access": "true",
		"anthropic-beta": isOAuthToken
			? "oauth-2025-04-20,fine-grained-tool-streaming-2025-05-14"
			: "fine-grained-tool-streaming-2025-05-14",
	};

	if (!isOAuthToken) {
		const client = new Anthropic({
			apiKey,
			baseURL: model.baseUrl,
			dangerouslyAllowBrowser: true,
			defaultHeaders,
		});
		return { client, isOAuthToken: false };
	}

	// Clear the env var if we're in Node.js to prevent the SDK from using it.
	// `delete` is required: assigning `undefined` to a process.env entry stores the literal
	// string "undefined" because Node coerces assigned values to strings.
	if (typeof process !== "undefined" && process.env) {
		delete process.env.ANTHROPIC_API_KEY;
	}

	const client = new Anthropic({
		apiKey: null,
		authToken: apiKey,
		baseURL: model.baseUrl,
		defaultHeaders,
		dangerouslyAllowBrowser: true,
	});
	return { client, isOAuthToken: true };
}
|
||||
|
||||
/**
 * Assembles the streaming request body for the Anthropic Messages API.
 *
 * @param model - Supplies the model id, the fallback `max_tokens`, and the `reasoning` capability flag.
 * @param context - Supplies the optional system prompt and tool definitions.
 * @param options - Per-call overrides: max tokens, temperature, thinking and tool choice.
 * @param messages - Conversation already converted to Anthropic message params.
 * @param isOAuthToken - When true, the Claude Code identity block is placed first in `system`.
 * @returns Request params with `stream: true`.
 */
function buildAnthropicParams(
	model: Model,
	context: Context,
	options: AnthropicOptions,
	messages: MessageParam[],
	isOAuthToken: boolean,
): MessageCreateParamsStreaming {
	const { systemPrompt, tools } = context;

	const request: MessageCreateParamsStreaming = {
		model: model.id,
		messages,
		max_tokens: options.maxTokens || model.maxTokens,
		stream: true,
	};

	if (isOAuthToken) {
		// For OAuth tokens, we MUST include Claude Code identity; the caller's prompt (if any) follows it.
		request.system = [
			{
				type: "text",
				text: "You are Claude Code, Anthropic's official CLI for Claude.",
				cache_control: { type: "ephemeral" },
			},
			...(systemPrompt
				? [{ type: "text" as const, text: systemPrompt, cache_control: { type: "ephemeral" as const } }]
				: []),
		];
	} else if (systemPrompt) {
		request.system = systemPrompt;
	}

	if (options.temperature !== undefined) {
		request.temperature = options.temperature;
	}

	if (tools) {
		request.tools = convertTools(tools);
	}

	// Thinking is only requested when the caller enabled it AND the model is flagged as reasoning-capable.
	const thinking = options.thinking;
	if (thinking?.enabled && model.reasoning) {
		request.thinking = { type: "enabled", budget_tokens: thinking.budgetTokens || 1024 };
	}

	// String shorthands ("auto", "any", ...) are wrapped into the object form; object values pass through.
	const choice = options.toolChoice;
	if (choice) {
		request.tool_choice = typeof choice === "string" ? { type: choice } : choice;
	}

	return request;
}
|
||||
|
||||
/**
 * Converts provider-neutral messages into Anthropic `MessageParam`s.
 *
 * Messages are first passed through `transformMessages` for cross-provider compatibility. Then:
 * - user messages keep their text; image blocks are kept only when `model.input` lists "image";
 * - assistant messages map text / thinking / toolCall blocks to text / thinking / tool_use blocks;
 * - tool results are sent as a user message carrying a single `tool_result` block.
 *
 * Empty or whitespace-only text and thinking blocks are dropped, and a user or assistant message
 * that ends up with no content is omitted entirely, so that no empty content is sent to the API.
 *
 * @param messages - Conversation history in the library's own format.
 * @param model - Target model; forwarded to `transformMessages` and used for the image capability check.
 * @param api - Target API identifier, forwarded to `transformMessages`.
 * @returns Messages in the order produced by `transformMessages`, minus any that became empty.
 */
function convertMessages(messages: Message[], model: Model, api: Api): MessageParam[] {
	const params: MessageParam[] = [];

	// Transform messages for cross-provider compatibility
	const transformedMessages = transformMessages(messages, model, api);

	// Images are stripped when the model does not advertise image input.
	const dropImages = !model?.input.includes("image");

	for (const msg of transformedMessages) {
		if (msg.role === "user") {
			if (typeof msg.content === "string") {
				// Skip blank prompts instead of sending an empty user turn.
				if (msg.content.trim().length > 0) {
					params.push({ role: "user", content: msg.content });
				}
				continue;
			}

			const blocks: ContentBlockParam[] = [];
			for (const item of msg.content) {
				if (item.type === "text") {
					if (item.text.trim().length > 0) {
						blocks.push({ type: "text", text: item.text });
					}
				} else if (!dropImages) {
					blocks.push({
						type: "image",
						source: {
							type: "base64",
							media_type: item.mimeType as "image/jpeg" | "image/png" | "image/gif" | "image/webp",
							data: item.data,
						},
					});
				}
			}
			// Filtering may have removed everything (e.g. an image-only message for a text-only model).
			if (blocks.length === 0) continue;
			params.push({ role: "user", content: blocks });
		} else if (msg.role === "assistant") {
			const blocks: ContentBlockParam[] = [];

			for (const block of msg.content) {
				if (block.type === "text") {
					if (block.text.trim().length === 0) continue;
					blocks.push({ type: "text", text: block.text });
				} else if (block.type === "thinking") {
					if (block.thinking.trim().length === 0) continue;
					blocks.push({
						type: "thinking",
						thinking: block.thinking,
						signature: block.thinkingSignature || "",
					});
				} else if (block.type === "toolCall") {
					blocks.push({
						type: "tool_use",
						id: block.id,
						name: block.name,
						input: block.arguments,
					});
				}
			}

			if (blocks.length === 0) continue;
			params.push({ role: "assistant", content: blocks });
		} else if (msg.role === "toolResult") {
			params.push({
				role: "user",
				content: [
					{
						type: "tool_result",
						tool_use_id: msg.toolCallId,
						content: msg.content,
						is_error: msg.isError,
					},
				],
			});
		}
	}

	return params;
}
|
||||
|
||||
/**
 * Converts the context's tool definitions into Anthropic tool params.
 *
 * Each tool keeps its name and description; its parameters become an object-typed
 * `input_schema` whose `properties` and `required` default to `{}` and `[]` when absent.
 *
 * @param tools - Tool definitions from the context; a falsy value yields an empty list.
 * @returns One Anthropic tool per input tool, in the same order.
 */
function convertTools(tools: Context["tools"]): Tool[] {
	const converted: Tool[] = [];

	for (const { name, description, parameters } of tools || []) {
		converted.push({
			name,
			description,
			input_schema: {
				type: "object" as const,
				properties: parameters.properties || {},
				required: parameters.required || [],
			},
		});
	}

	return converted;
}
|
||||
|
||||
/**
 * Maps an Anthropic stop reason to the library's `StopReason`.
 *
 * - "max_tokens" -> "length"
 * - "tool_use"   -> "toolUse"
 * - "refusal"    -> "safety"
 * - everything else -> "stop". This covers "end_turn", "pause_turn" (stop is good enough, the
 *   caller resubmits), "stop_sequence" (we don't supply stop sequences, so it should never
 *   happen), `null`, and any value not listed here.
 *
 * @param reason - Stop reason reported by the API, or null.
 * @returns The corresponding library stop reason.
 */
function mapStopReason(reason: Anthropic.Messages.StopReason | null): StopReason {
	if (reason === "max_tokens") return "length";
	if (reason === "tool_use") return "toolUse";
	if (reason === "refusal") return "safety";
	return "stop";
}
|
||||
|
|
@ -3,91 +3,46 @@ import type {
|
|||
ContentBlockParam,
|
||||
MessageCreateParamsStreaming,
|
||||
MessageParam,
|
||||
Tool,
|
||||
} from "@anthropic-ai/sdk/resources/messages.js";
|
||||
import { QueuedGenerateStream } from "../generate.js";
|
||||
import { calculateCost } from "../models.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
Context,
|
||||
LLM,
|
||||
LLMOptions,
|
||||
GenerateFunction,
|
||||
GenerateOptions,
|
||||
GenerateStream,
|
||||
Message,
|
||||
Model,
|
||||
StopReason,
|
||||
TextContent,
|
||||
ThinkingContent,
|
||||
Tool,
|
||||
ToolCall,
|
||||
} from "../types.js";
|
||||
import { transformMessages } from "./utils.js";
|
||||
|
||||
export interface AnthropicLLMOptions extends LLMOptions {
|
||||
thinking?: {
|
||||
enabled: boolean;
|
||||
budgetTokens?: number;
|
||||
};
|
||||
export interface AnthropicOptions extends GenerateOptions {
|
||||
thinkingEnabled?: boolean;
|
||||
thinkingBudgetTokens?: number;
|
||||
toolChoice?: "auto" | "any" | "none" | { type: "tool"; name: string };
|
||||
}
|
||||
|
||||
export class AnthropicLLM implements LLM<AnthropicLLMOptions> {
|
||||
private client: Anthropic;
|
||||
private modelInfo: Model;
|
||||
private isOAuthToken: boolean = false;
|
||||
export const streamAnthropic: GenerateFunction<"anthropic-messages"> = (
|
||||
model: Model<"anthropic-messages">,
|
||||
context: Context,
|
||||
options?: AnthropicOptions,
|
||||
): GenerateStream => {
|
||||
const stream = new QueuedGenerateStream();
|
||||
|
||||
constructor(model: Model, apiKey?: string) {
|
||||
if (!apiKey) {
|
||||
if (!process.env.ANTHROPIC_API_KEY) {
|
||||
throw new Error(
|
||||
"Anthropic API key is required. Set ANTHROPIC_API_KEY environment variable or pass it as an argument.",
|
||||
);
|
||||
}
|
||||
apiKey = process.env.ANTHROPIC_API_KEY;
|
||||
}
|
||||
if (apiKey.includes("sk-ant-oat")) {
|
||||
const defaultHeaders = {
|
||||
accept: "application/json",
|
||||
"anthropic-dangerous-direct-browser-access": "true",
|
||||
"anthropic-beta": "oauth-2025-04-20,fine-grained-tool-streaming-2025-05-14",
|
||||
};
|
||||
|
||||
// Clear the env var if we're in Node.js to prevent SDK from using it
|
||||
if (typeof process !== "undefined" && process.env) {
|
||||
process.env.ANTHROPIC_API_KEY = undefined;
|
||||
}
|
||||
this.client = new Anthropic({
|
||||
apiKey: null,
|
||||
authToken: apiKey,
|
||||
baseURL: model.baseUrl,
|
||||
defaultHeaders,
|
||||
dangerouslyAllowBrowser: true,
|
||||
});
|
||||
this.isOAuthToken = true;
|
||||
} else {
|
||||
const defaultHeaders = {
|
||||
accept: "application/json",
|
||||
"anthropic-dangerous-direct-browser-access": "true",
|
||||
"anthropic-beta": "fine-grained-tool-streaming-2025-05-14",
|
||||
};
|
||||
this.client = new Anthropic({ apiKey, baseURL: model.baseUrl, dangerouslyAllowBrowser: true, defaultHeaders });
|
||||
this.isOAuthToken = false;
|
||||
}
|
||||
this.modelInfo = model;
|
||||
}
|
||||
|
||||
getModel(): Model {
|
||||
return this.modelInfo;
|
||||
}
|
||||
|
||||
getApi(): string {
|
||||
return "anthropic-messages";
|
||||
}
|
||||
|
||||
async generate(context: Context, options?: AnthropicLLMOptions): Promise<AssistantMessage> {
|
||||
(async () => {
|
||||
const output: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: this.getApi(),
|
||||
provider: this.modelInfo.provider,
|
||||
model: this.modelInfo.id,
|
||||
api: "anthropic-messages" as Api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
|
|
@ -99,77 +54,14 @@ export class AnthropicLLM implements LLM<AnthropicLLMOptions> {
|
|||
};
|
||||
|
||||
try {
|
||||
const messages = this.convertMessages(context.messages);
|
||||
|
||||
const params: MessageCreateParamsStreaming = {
|
||||
model: this.modelInfo.id,
|
||||
messages,
|
||||
max_tokens: options?.maxTokens || 4096,
|
||||
stream: true,
|
||||
};
|
||||
|
||||
// For OAuth tokens, we MUST include Claude Code identity
|
||||
if (this.isOAuthToken) {
|
||||
params.system = [
|
||||
{
|
||||
type: "text",
|
||||
text: "You are Claude Code, Anthropic's official CLI for Claude.",
|
||||
cache_control: {
|
||||
type: "ephemeral",
|
||||
},
|
||||
},
|
||||
];
|
||||
if (context.systemPrompt) {
|
||||
params.system.push({
|
||||
type: "text",
|
||||
text: context.systemPrompt,
|
||||
cache_control: {
|
||||
type: "ephemeral",
|
||||
},
|
||||
});
|
||||
}
|
||||
} else if (context.systemPrompt) {
|
||||
params.system = context.systemPrompt;
|
||||
}
|
||||
|
||||
if (options?.temperature !== undefined) {
|
||||
params.temperature = options?.temperature;
|
||||
}
|
||||
|
||||
if (context.tools) {
|
||||
params.tools = this.convertTools(context.tools);
|
||||
}
|
||||
|
||||
// Only enable thinking if the model supports it
|
||||
if (options?.thinking?.enabled && this.modelInfo.reasoning) {
|
||||
params.thinking = {
|
||||
type: "enabled",
|
||||
budget_tokens: options.thinking.budgetTokens || 1024,
|
||||
};
|
||||
}
|
||||
|
||||
if (options?.toolChoice) {
|
||||
if (typeof options.toolChoice === "string") {
|
||||
params.tool_choice = { type: options.toolChoice };
|
||||
} else {
|
||||
params.tool_choice = options.toolChoice;
|
||||
}
|
||||
}
|
||||
|
||||
const stream = this.client.messages.stream(
|
||||
{
|
||||
...params,
|
||||
stream: true,
|
||||
},
|
||||
{
|
||||
signal: options?.signal,
|
||||
},
|
||||
);
|
||||
|
||||
options?.onEvent?.({ type: "start", model: this.modelInfo.id, provider: this.modelInfo.provider });
|
||||
const { client, isOAuthToken } = createClient(model, options?.apiKey!);
|
||||
const params = buildParams(model, context, isOAuthToken, options);
|
||||
const anthropicStream = client.messages.stream({ ...params, stream: true }, { signal: options?.signal });
|
||||
stream.push({ type: "start", partial: output });
|
||||
|
||||
let currentBlock: ThinkingContent | TextContent | (ToolCall & { partialJson: string }) | null = null;
|
||||
for await (const event of stream) {
|
||||
|
||||
for await (const event of anthropicStream) {
|
||||
if (event.type === "content_block_start") {
|
||||
if (event.content_block.type === "text") {
|
||||
currentBlock = {
|
||||
|
|
@ -177,7 +69,7 @@ export class AnthropicLLM implements LLM<AnthropicLLMOptions> {
|
|||
text: "",
|
||||
};
|
||||
output.content.push(currentBlock);
|
||||
options?.onEvent?.({ type: "text_start" });
|
||||
stream.push({ type: "text_start", partial: output });
|
||||
} else if (event.content_block.type === "thinking") {
|
||||
currentBlock = {
|
||||
type: "thinking",
|
||||
|
|
@ -185,9 +77,9 @@ export class AnthropicLLM implements LLM<AnthropicLLMOptions> {
|
|||
thinkingSignature: "",
|
||||
};
|
||||
output.content.push(currentBlock);
|
||||
options?.onEvent?.({ type: "thinking_start" });
|
||||
stream.push({ type: "thinking_start", partial: output });
|
||||
} else if (event.content_block.type === "tool_use") {
|
||||
// We wait for the full tool use to be streamed to send the event
|
||||
// We wait for the full tool use to be streamed
|
||||
currentBlock = {
|
||||
type: "toolCall",
|
||||
id: event.content_block.id,
|
||||
|
|
@ -200,15 +92,19 @@ export class AnthropicLLM implements LLM<AnthropicLLMOptions> {
|
|||
if (event.delta.type === "text_delta") {
|
||||
if (currentBlock && currentBlock.type === "text") {
|
||||
currentBlock.text += event.delta.text;
|
||||
options?.onEvent?.({ type: "text_delta", content: currentBlock.text, delta: event.delta.text });
|
||||
stream.push({
|
||||
type: "text_delta",
|
||||
delta: event.delta.text,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
} else if (event.delta.type === "thinking_delta") {
|
||||
if (currentBlock && currentBlock.type === "thinking") {
|
||||
currentBlock.thinking += event.delta.thinking;
|
||||
options?.onEvent?.({
|
||||
stream.push({
|
||||
type: "thinking_delta",
|
||||
content: currentBlock.thinking,
|
||||
delta: event.delta.thinking,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
} else if (event.delta.type === "input_json_delta") {
|
||||
|
|
@ -224,9 +120,17 @@ export class AnthropicLLM implements LLM<AnthropicLLMOptions> {
|
|||
} else if (event.type === "content_block_stop") {
|
||||
if (currentBlock) {
|
||||
if (currentBlock.type === "text") {
|
||||
options?.onEvent?.({ type: "text_end", content: currentBlock.text });
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
content: currentBlock.text,
|
||||
partial: output,
|
||||
});
|
||||
} else if (currentBlock.type === "thinking") {
|
||||
options?.onEvent?.({ type: "thinking_end", content: currentBlock.thinking });
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
content: currentBlock.thinking,
|
||||
partial: output,
|
||||
});
|
||||
} else if (currentBlock.type === "toolCall") {
|
||||
const finalToolCall: ToolCall = {
|
||||
type: "toolCall",
|
||||
|
|
@ -235,150 +139,274 @@ export class AnthropicLLM implements LLM<AnthropicLLMOptions> {
|
|||
arguments: JSON.parse(currentBlock.partialJson),
|
||||
};
|
||||
output.content.push(finalToolCall);
|
||||
options?.onEvent?.({ type: "toolCall", toolCall: finalToolCall });
|
||||
stream.push({
|
||||
type: "toolCall",
|
||||
toolCall: finalToolCall,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
currentBlock = null;
|
||||
}
|
||||
} else if (event.type === "message_delta") {
|
||||
if (event.delta.stop_reason) {
|
||||
output.stopReason = this.mapStopReason(event.delta.stop_reason);
|
||||
output.stopReason = mapStopReason(event.delta.stop_reason);
|
||||
}
|
||||
output.usage.input += event.usage.input_tokens || 0;
|
||||
output.usage.output += event.usage.output_tokens || 0;
|
||||
output.usage.cacheRead += event.usage.cache_read_input_tokens || 0;
|
||||
output.usage.cacheWrite += event.usage.cache_creation_input_tokens || 0;
|
||||
calculateCost(this.modelInfo, output.usage);
|
||||
calculateCost(model, output.usage);
|
||||
}
|
||||
}
|
||||
|
||||
options?.onEvent?.({ type: "done", reason: output.stopReason, message: output });
|
||||
return output;
|
||||
if (options?.signal?.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
|
||||
stream.push({ type: "done", reason: output.stopReason, message: output });
|
||||
stream.end();
|
||||
} catch (error) {
|
||||
output.stopReason = "error";
|
||||
output.error = error instanceof Error ? error.message : JSON.stringify(error);
|
||||
options?.onEvent?.({ type: "error", error: output.error });
|
||||
return output;
|
||||
stream.push({ type: "error", error: output.error, partial: output });
|
||||
stream.end();
|
||||
}
|
||||
})();
|
||||
|
||||
return stream;
|
||||
};
|
||||
|
||||
/**
 * Creates an Anthropic SDK client for the given model and credential.
 *
 * Credentials containing "sk-ant-oat" are treated as OAuth tokens: they are handed to the SDK as
 * `authToken` (with `apiKey` explicitly null) and the OAuth beta header is requested in addition to
 * fine-grained tool streaming. Any other credential is passed as a regular `apiKey`.
 *
 * @param model - Model descriptor; only `baseUrl` is read here.
 * @param apiKey - API key or OAuth token.
 * @returns The client and a flag telling callers whether OAuth mode is active.
 * @throws Error when no credential is supplied.
 */
function createClient(
	model: Model<"anthropic-messages">,
	apiKey: string,
): { client: Anthropic; isOAuthToken: boolean } {
	// The signature promises a string, but a caller using a non-null assertion on an optional key
	// can still pass undefined at runtime; fail with a clear message instead of a TypeError.
	if (!apiKey) {
		throw new Error("No Anthropic API key or OAuth token was provided.");
	}

	const isOAuthToken = apiKey.includes("sk-ant-oat");

	// Both credential kinds share every header except the beta feature list.
	const defaultHeaders = {
		accept: "application/json",
		"anthropic-dangerous-direct-browser-access": "true",
		"anthropic-beta": isOAuthToken
			? "oauth-2025-04-20,fine-grained-tool-streaming-2025-05-14"
			: "fine-grained-tool-streaming-2025-05-14",
	};

	if (!isOAuthToken) {
		const client = new Anthropic({
			apiKey,
			baseURL: model.baseUrl,
			dangerouslyAllowBrowser: true,
			defaultHeaders,
		});
		return { client, isOAuthToken: false };
	}

	// Clear the env var if we're in Node.js to prevent the SDK from using it.
	// `delete` is required: assigning `undefined` to a process.env entry stores the literal
	// string "undefined" because Node coerces assigned values to strings.
	if (typeof process !== "undefined" && process.env) {
		delete process.env.ANTHROPIC_API_KEY;
	}

	const client = new Anthropic({
		apiKey: null,
		authToken: apiKey,
		baseURL: model.baseUrl,
		defaultHeaders,
		dangerouslyAllowBrowser: true,
	});
	return { client, isOAuthToken: true };
}
|
||||
|
||||
function buildParams(
|
||||
model: Model<"anthropic-messages">,
|
||||
context: Context,
|
||||
isOAuthToken: boolean,
|
||||
options?: AnthropicOptions,
|
||||
): MessageCreateParamsStreaming {
|
||||
const params: MessageCreateParamsStreaming = {
|
||||
model: model.id,
|
||||
messages: convertMessages(context.messages, model),
|
||||
max_tokens: options?.maxTokens || model.maxTokens,
|
||||
stream: true,
|
||||
};
|
||||
|
||||
// For OAuth tokens, we MUST include Claude Code identity
|
||||
if (isOAuthToken) {
|
||||
params.system = [
|
||||
{
|
||||
type: "text",
|
||||
text: "You are Claude Code, Anthropic's official CLI for Claude.",
|
||||
cache_control: {
|
||||
type: "ephemeral",
|
||||
},
|
||||
},
|
||||
];
|
||||
if (context.systemPrompt) {
|
||||
params.system.push({
|
||||
type: "text",
|
||||
text: context.systemPrompt,
|
||||
cache_control: {
|
||||
type: "ephemeral",
|
||||
},
|
||||
});
|
||||
}
|
||||
} else if (context.systemPrompt) {
|
||||
params.system = context.systemPrompt;
|
||||
}
|
||||
|
||||
if (options?.temperature !== undefined) {
|
||||
params.temperature = options.temperature;
|
||||
}
|
||||
|
||||
if (context.tools) {
|
||||
params.tools = convertTools(context.tools);
|
||||
}
|
||||
|
||||
if (options?.thinkingEnabled && model.reasoning) {
|
||||
params.thinking = {
|
||||
type: "enabled",
|
||||
budget_tokens: options.thinkingBudgetTokens || 1024,
|
||||
};
|
||||
}
|
||||
|
||||
if (options?.toolChoice) {
|
||||
if (typeof options.toolChoice === "string") {
|
||||
params.tool_choice = { type: options.toolChoice };
|
||||
} else {
|
||||
params.tool_choice = options.toolChoice;
|
||||
}
|
||||
}
|
||||
|
||||
private convertMessages(messages: Message[]): MessageParam[] {
|
||||
const params: MessageParam[] = [];
|
||||
return params;
|
||||
}
|
||||
|
||||
// Transform messages for cross-provider compatibility
|
||||
const transformedMessages = transformMessages(messages, this.modelInfo, this.getApi());
|
||||
function convertMessages(messages: Message[], model: Model<"anthropic-messages">): MessageParam[] {
|
||||
const params: MessageParam[] = [];
|
||||
|
||||
for (const msg of transformedMessages) {
|
||||
if (msg.role === "user") {
|
||||
// Handle both string and array content
|
||||
if (typeof msg.content === "string") {
|
||||
// Transform messages for cross-provider compatibility
|
||||
const transformedMessages = transformMessages(messages, model);
|
||||
|
||||
for (const msg of transformedMessages) {
|
||||
if (msg.role === "user") {
|
||||
if (typeof msg.content === "string") {
|
||||
if (msg.content.trim().length > 0) {
|
||||
params.push({
|
||||
role: "user",
|
||||
content: msg.content,
|
||||
});
|
||||
} else {
|
||||
// Convert array content to Anthropic format
|
||||
const blocks: ContentBlockParam[] = msg.content.map((item) => {
|
||||
if (item.type === "text") {
|
||||
return {
|
||||
type: "text",
|
||||
text: item.text,
|
||||
};
|
||||
} else {
|
||||
// Image content
|
||||
return {
|
||||
type: "image",
|
||||
source: {
|
||||
type: "base64",
|
||||
media_type: item.mimeType as "image/jpeg" | "image/png" | "image/gif" | "image/webp",
|
||||
data: item.data,
|
||||
},
|
||||
};
|
||||
}
|
||||
});
|
||||
const filteredBlocks = !this.modelInfo?.input.includes("image")
|
||||
? blocks.filter((b) => b.type !== "image")
|
||||
: blocks;
|
||||
params.push({
|
||||
role: "user",
|
||||
content: filteredBlocks,
|
||||
});
|
||||
}
|
||||
} else if (msg.role === "assistant") {
|
||||
const blocks: ContentBlockParam[] = [];
|
||||
|
||||
for (const block of msg.content) {
|
||||
if (block.type === "text") {
|
||||
blocks.push({
|
||||
} else {
|
||||
const blocks: ContentBlockParam[] = msg.content.map((item) => {
|
||||
if (item.type === "text") {
|
||||
return {
|
||||
type: "text",
|
||||
text: block.text,
|
||||
});
|
||||
} else if (block.type === "thinking") {
|
||||
blocks.push({
|
||||
type: "thinking",
|
||||
thinking: block.thinking,
|
||||
signature: block.thinkingSignature || "",
|
||||
});
|
||||
} else if (block.type === "toolCall") {
|
||||
blocks.push({
|
||||
type: "tool_use",
|
||||
id: block.id,
|
||||
name: block.name,
|
||||
input: block.arguments,
|
||||
});
|
||||
text: item.text,
|
||||
};
|
||||
} else {
|
||||
return {
|
||||
type: "image",
|
||||
source: {
|
||||
type: "base64",
|
||||
media_type: item.mimeType as "image/jpeg" | "image/png" | "image/gif" | "image/webp",
|
||||
data: item.data,
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
params.push({
|
||||
role: "assistant",
|
||||
content: blocks,
|
||||
});
|
||||
} else if (msg.role === "toolResult") {
|
||||
let filteredBlocks = !model?.input.includes("image") ? blocks.filter((b) => b.type !== "image") : blocks;
|
||||
filteredBlocks = filteredBlocks.filter((b) => {
|
||||
if (b.type === "text") {
|
||||
return b.text.trim().length > 0;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
if (filteredBlocks.length === 0) continue;
|
||||
params.push({
|
||||
role: "user",
|
||||
content: [
|
||||
{
|
||||
type: "tool_result",
|
||||
tool_use_id: msg.toolCallId,
|
||||
content: msg.content,
|
||||
is_error: msg.isError,
|
||||
},
|
||||
],
|
||||
content: filteredBlocks,
|
||||
});
|
||||
}
|
||||
} else if (msg.role === "assistant") {
|
||||
const blocks: ContentBlockParam[] = [];
|
||||
|
||||
for (const block of msg.content) {
|
||||
if (block.type === "text") {
|
||||
if (block.text.trim().length === 0) continue;
|
||||
blocks.push({
|
||||
type: "text",
|
||||
text: block.text,
|
||||
});
|
||||
} else if (block.type === "thinking") {
|
||||
if (block.thinking.trim().length === 0) continue;
|
||||
blocks.push({
|
||||
type: "thinking",
|
||||
thinking: block.thinking,
|
||||
signature: block.thinkingSignature || "",
|
||||
});
|
||||
} else if (block.type === "toolCall") {
|
||||
blocks.push({
|
||||
type: "tool_use",
|
||||
id: block.id,
|
||||
name: block.name,
|
||||
input: block.arguments,
|
||||
});
|
||||
}
|
||||
}
|
||||
if (blocks.length === 0) continue;
|
||||
params.push({
|
||||
role: "assistant",
|
||||
content: blocks,
|
||||
});
|
||||
} else if (msg.role === "toolResult") {
|
||||
params.push({
|
||||
role: "user",
|
||||
content: [
|
||||
{
|
||||
type: "tool_result",
|
||||
tool_use_id: msg.toolCallId,
|
||||
content: msg.content,
|
||||
is_error: msg.isError,
|
||||
},
|
||||
],
|
||||
});
|
||||
}
|
||||
return params;
|
||||
}
|
||||
return params;
|
||||
}
|
||||
|
||||
private convertTools(tools: Context["tools"]): Tool[] {
|
||||
if (!tools) return [];
|
||||
function convertTools(tools: Tool[]): Anthropic.Messages.Tool[] {
|
||||
if (!tools) return [];
|
||||
|
||||
return tools.map((tool) => ({
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
input_schema: {
|
||||
type: "object" as const,
|
||||
properties: tool.parameters.properties || {},
|
||||
required: tool.parameters.required || [],
|
||||
},
|
||||
}));
|
||||
}
|
||||
return tools.map((tool) => ({
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
input_schema: {
|
||||
type: "object" as const,
|
||||
properties: tool.parameters.properties || {},
|
||||
required: tool.parameters.required || [],
|
||||
},
|
||||
}));
|
||||
}
|
||||
|
||||
private mapStopReason(reason: Anthropic.Messages.StopReason | null): StopReason {
|
||||
switch (reason) {
|
||||
case "end_turn":
|
||||
return "stop";
|
||||
case "max_tokens":
|
||||
return "length";
|
||||
case "tool_use":
|
||||
return "toolUse";
|
||||
case "refusal":
|
||||
return "safety";
|
||||
case "pause_turn": // Stop is good enough -> resubmit
|
||||
return "stop";
|
||||
case "stop_sequence":
|
||||
return "stop"; // We don't supply stop sequences, so this should never happen
|
||||
default:
|
||||
return "stop";
|
||||
/**
 * Maps an Anthropic stop reason to the library's `StopReason`.
 *
 * The lookup table is typed as a `Record` over the SDK's stop-reason union, so a stop reason
 * added to that union without an entry here is a compile error.
 *
 * @param reason - Stop reason reported by the API.
 * @returns The corresponding library stop reason.
 * @throws Error for a value that has no entry in the table.
 */
function mapStopReason(reason: Anthropic.Messages.StopReason): StopReason {
	const byReason: Record<Anthropic.Messages.StopReason, StopReason> = {
		end_turn: "stop",
		max_tokens: "length",
		tool_use: "toolUse",
		refusal: "safety",
		pause_turn: "stop", // Stop is good enough -> resubmit
		stop_sequence: "stop", // We don't supply stop sequences, so this should never happen
	};

	// Own-property check so inherited keys such as "toString" are not mistaken for table entries.
	if (!Object.prototype.hasOwnProperty.call(byReason, reason)) {
		throw new Error(`Unhandled stop reason: ${reason}`);
	}
	return byReason[reason];
}
|
||||
|
|
|
|||
|
|
@ -1,19 +1,21 @@
|
|||
import {
|
||||
type Content,
|
||||
type FinishReason,
|
||||
FinishReason,
|
||||
FunctionCallingConfigMode,
|
||||
type GenerateContentConfig,
|
||||
type GenerateContentParameters,
|
||||
GoogleGenAI,
|
||||
type Part,
|
||||
} from "@google/genai";
|
||||
import { QueuedGenerateStream } from "../generate.js";
|
||||
import { calculateCost } from "../models.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
Context,
|
||||
LLM,
|
||||
LLMOptions,
|
||||
Message,
|
||||
GenerateFunction,
|
||||
GenerateOptions,
|
||||
GenerateStream,
|
||||
Model,
|
||||
StopReason,
|
||||
TextContent,
|
||||
|
|
@ -23,7 +25,7 @@ import type {
|
|||
} from "../types.js";
|
||||
import { transformMessages } from "./utils.js";
|
||||
|
||||
export interface GoogleLLMOptions extends LLMOptions {
|
||||
export interface GoogleOptions extends GenerateOptions {
|
||||
toolChoice?: "auto" | "none" | "any";
|
||||
thinking?: {
|
||||
enabled: boolean;
|
||||
|
|
@ -31,38 +33,20 @@ export interface GoogleLLMOptions extends LLMOptions {
|
|||
};
|
||||
}
|
||||
|
||||
export class GoogleLLM implements LLM<GoogleLLMOptions> {
|
||||
private client: GoogleGenAI;
|
||||
private modelInfo: Model;
|
||||
export const streamGoogle: GenerateFunction<"google-generative-ai"> = (
|
||||
model: Model<"google-generative-ai">,
|
||||
context: Context,
|
||||
options?: GoogleOptions,
|
||||
): GenerateStream => {
|
||||
const stream = new QueuedGenerateStream();
|
||||
|
||||
constructor(model: Model, apiKey?: string) {
|
||||
if (!apiKey) {
|
||||
if (!process.env.GEMINI_API_KEY) {
|
||||
throw new Error(
|
||||
"Gemini API key is required. Set GEMINI_API_KEY environment variable or pass it as an argument.",
|
||||
);
|
||||
}
|
||||
apiKey = process.env.GEMINI_API_KEY;
|
||||
}
|
||||
this.client = new GoogleGenAI({ apiKey });
|
||||
this.modelInfo = model;
|
||||
}
|
||||
|
||||
getModel(): Model {
|
||||
return this.modelInfo;
|
||||
}
|
||||
|
||||
getApi(): string {
|
||||
return "google-generative-ai";
|
||||
}
|
||||
|
||||
async generate(context: Context, options?: GoogleLLMOptions): Promise<AssistantMessage> {
|
||||
(async () => {
|
||||
const output: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: this.getApi(),
|
||||
provider: this.modelInfo.provider,
|
||||
model: this.modelInfo.id,
|
||||
api: "google-generative-ai" as Api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
|
|
@ -72,70 +56,20 @@ export class GoogleLLM implements LLM<GoogleLLMOptions> {
|
|||
},
|
||||
stopReason: "stop",
|
||||
};
|
||||
|
||||
try {
|
||||
const contents = this.convertMessages(context.messages);
|
||||
const client = createClient(options?.apiKey);
|
||||
const params = buildParams(model, context, options);
|
||||
const googleStream = await client.models.generateContentStream(params);
|
||||
|
||||
// Build generation config
|
||||
const generationConfig: GenerateContentConfig = {};
|
||||
if (options?.temperature !== undefined) {
|
||||
generationConfig.temperature = options.temperature;
|
||||
}
|
||||
if (options?.maxTokens !== undefined) {
|
||||
generationConfig.maxOutputTokens = options.maxTokens;
|
||||
}
|
||||
|
||||
// Build the config object
|
||||
const config: GenerateContentConfig = {
|
||||
...(Object.keys(generationConfig).length > 0 && generationConfig),
|
||||
...(context.systemPrompt && { systemInstruction: context.systemPrompt }),
|
||||
...(context.tools && { tools: this.convertTools(context.tools) }),
|
||||
};
|
||||
|
||||
// Add tool config if needed
|
||||
if (context.tools && options?.toolChoice) {
|
||||
config.toolConfig = {
|
||||
functionCallingConfig: {
|
||||
mode: this.mapToolChoice(options.toolChoice),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
// Add thinking config if enabled and model supports it
|
||||
if (options?.thinking?.enabled && this.modelInfo.reasoning) {
|
||||
config.thinkingConfig = {
|
||||
includeThoughts: true,
|
||||
...(options.thinking.budgetTokens !== undefined && { thinkingBudget: options.thinking.budgetTokens }),
|
||||
};
|
||||
}
|
||||
|
||||
// Abort signal
|
||||
if (options?.signal) {
|
||||
if (options.signal.aborted) {
|
||||
throw new Error("Request aborted");
|
||||
}
|
||||
config.abortSignal = options.signal;
|
||||
}
|
||||
|
||||
// Build the request parameters
|
||||
const params: GenerateContentParameters = {
|
||||
model: this.modelInfo.id,
|
||||
contents,
|
||||
config,
|
||||
};
|
||||
|
||||
const stream = await this.client.models.generateContentStream(params);
|
||||
|
||||
options?.onEvent?.({ type: "start", model: this.modelInfo.id, provider: this.modelInfo.provider });
|
||||
stream.push({ type: "start", partial: output });
|
||||
let currentBlock: TextContent | ThinkingContent | null = null;
|
||||
for await (const chunk of stream) {
|
||||
// Extract parts from the chunk
|
||||
for await (const chunk of googleStream) {
|
||||
const candidate = chunk.candidates?.[0];
|
||||
if (candidate?.content?.parts) {
|
||||
for (const part of candidate.content.parts) {
|
||||
if (part.text !== undefined) {
|
||||
const isThinking = part.thought === true;
|
||||
|
||||
// Check if we need to switch blocks
|
||||
if (
|
||||
!currentBlock ||
|
||||
(isThinking && currentBlock.type !== "thinking") ||
|
||||
|
|
@ -143,50 +77,60 @@ export class GoogleLLM implements LLM<GoogleLLMOptions> {
|
|||
) {
|
||||
if (currentBlock) {
|
||||
if (currentBlock.type === "text") {
|
||||
options?.onEvent?.({ type: "text_end", content: currentBlock.text });
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
content: currentBlock.text,
|
||||
partial: output,
|
||||
});
|
||||
} else {
|
||||
options?.onEvent?.({ type: "thinking_end", content: currentBlock.thinking });
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
content: currentBlock.thinking,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Start new block
|
||||
if (isThinking) {
|
||||
currentBlock = { type: "thinking", thinking: "", thinkingSignature: undefined };
|
||||
options?.onEvent?.({ type: "thinking_start" });
|
||||
stream.push({ type: "thinking_start", partial: output });
|
||||
} else {
|
||||
currentBlock = { type: "text", text: "" };
|
||||
options?.onEvent?.({ type: "text_start" });
|
||||
stream.push({ type: "text_start", partial: output });
|
||||
}
|
||||
output.content.push(currentBlock);
|
||||
}
|
||||
|
||||
// Append content to current block
|
||||
if (currentBlock.type === "thinking") {
|
||||
currentBlock.thinking += part.text;
|
||||
currentBlock.thinkingSignature = part.thoughtSignature;
|
||||
options?.onEvent?.({
|
||||
stream.push({
|
||||
type: "thinking_delta",
|
||||
content: currentBlock.thinking,
|
||||
delta: part.text,
|
||||
partial: output,
|
||||
});
|
||||
} else {
|
||||
currentBlock.text += part.text;
|
||||
options?.onEvent?.({ type: "text_delta", content: currentBlock.text, delta: part.text });
|
||||
stream.push({ type: "text_delta", delta: part.text, partial: output });
|
||||
}
|
||||
}
|
||||
|
||||
// Handle function calls
|
||||
if (part.functionCall) {
|
||||
if (currentBlock) {
|
||||
if (currentBlock.type === "text") {
|
||||
options?.onEvent?.({ type: "text_end", content: currentBlock.text });
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
content: currentBlock.text,
|
||||
partial: output,
|
||||
});
|
||||
} else {
|
||||
options?.onEvent?.({ type: "thinking_end", content: currentBlock.thinking });
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
content: currentBlock.thinking,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
currentBlock = null;
|
||||
}
|
||||
|
||||
// Add tool call
|
||||
const toolCallId = part.functionCall.id || `${part.functionCall.name}_${Date.now()}`;
|
||||
const toolCall: ToolCall = {
|
||||
type: "toolCall",
|
||||
|
|
@ -195,21 +139,18 @@ export class GoogleLLM implements LLM<GoogleLLMOptions> {
|
|||
arguments: part.functionCall.args as Record<string, any>,
|
||||
};
|
||||
output.content.push(toolCall);
|
||||
options?.onEvent?.({ type: "toolCall", toolCall });
|
||||
stream.push({ type: "toolCall", toolCall, partial: output });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Map finish reason
|
||||
if (candidate?.finishReason) {
|
||||
output.stopReason = this.mapStopReason(candidate.finishReason);
|
||||
// Check if we have tool calls in blocks
|
||||
output.stopReason = mapStopReason(candidate.finishReason);
|
||||
if (output.content.some((b) => b.type === "toolCall")) {
|
||||
output.stopReason = "toolUse";
|
||||
}
|
||||
}
|
||||
|
||||
// Capture usage metadata if available
|
||||
if (chunk.usageMetadata) {
|
||||
output.usage = {
|
||||
input: chunk.usageMetadata.promptTokenCount || 0,
|
||||
|
|
@ -225,166 +166,223 @@ export class GoogleLLM implements LLM<GoogleLLMOptions> {
|
|||
total: 0,
|
||||
},
|
||||
};
|
||||
calculateCost(this.modelInfo, output.usage);
|
||||
calculateCost(model, output.usage);
|
||||
}
|
||||
}
|
||||
|
||||
// Finalize last block
|
||||
if (currentBlock) {
|
||||
if (currentBlock.type === "text") {
|
||||
options?.onEvent?.({ type: "text_end", content: currentBlock.text });
|
||||
stream.push({ type: "text_end", content: currentBlock.text, partial: output });
|
||||
} else {
|
||||
options?.onEvent?.({ type: "thinking_end", content: currentBlock.thinking });
|
||||
stream.push({ type: "thinking_end", content: currentBlock.thinking, partial: output });
|
||||
}
|
||||
}
|
||||
|
||||
options?.onEvent?.({ type: "done", reason: output.stopReason, message: output });
|
||||
return output;
|
||||
stream.push({ type: "done", reason: output.stopReason, message: output });
|
||||
stream.end();
|
||||
} catch (error) {
|
||||
output.stopReason = "error";
|
||||
output.error = error instanceof Error ? error.message : JSON.stringify(error);
|
||||
options?.onEvent?.({ type: "error", error: output.error });
|
||||
return output;
|
||||
stream.push({ type: "error", error: output.error, partial: output });
|
||||
stream.end();
|
||||
}
|
||||
})();
|
||||
|
||||
return stream;
|
||||
};
|
||||
|
||||
/**
 * Creates a GoogleGenAI client for a single request.
 *
 * The key is taken from `apiKey` when it is non-empty; otherwise the
 * GEMINI_API_KEY environment variable is used.
 *
 * @param apiKey - Optional explicit API key.
 * @returns A GoogleGenAI instance constructed with the resolved key.
 * @throws Error when neither `apiKey` nor GEMINI_API_KEY provides a value.
 */
function createClient(apiKey?: string): GoogleGenAI {
	// An explicit key wins; the environment variable is only the fallback.
	const resolvedKey = apiKey || process.env.GEMINI_API_KEY;
	if (!resolvedKey) {
		throw new Error(
			"Gemini API key is required. Set GEMINI_API_KEY environment variable or pass it as an argument.",
		);
	}
	return new GoogleGenAI({ apiKey: resolvedKey });
}
|
||||
|
||||
function buildParams(
|
||||
model: Model<"google-generative-ai">,
|
||||
context: Context,
|
||||
options: GoogleOptions = {},
|
||||
): GenerateContentParameters {
|
||||
const contents = convertMessages(model, context);
|
||||
|
||||
const generationConfig: GenerateContentConfig = {};
|
||||
if (options.temperature !== undefined) {
|
||||
generationConfig.temperature = options.temperature;
|
||||
}
|
||||
if (options.maxTokens !== undefined) {
|
||||
generationConfig.maxOutputTokens = options.maxTokens;
|
||||
}
|
||||
|
||||
private convertMessages(messages: Message[]): Content[] {
|
||||
const contents: Content[] = [];
|
||||
const config: GenerateContentConfig = {
|
||||
...(Object.keys(generationConfig).length > 0 && generationConfig),
|
||||
...(context.systemPrompt && { systemInstruction: context.systemPrompt }),
|
||||
...(context.tools && { tools: convertTools(context.tools) }),
|
||||
};
|
||||
|
||||
// Transform messages for cross-provider compatibility
|
||||
const transformedMessages = transformMessages(messages, this.modelInfo, this.getApi());
|
||||
if (context.tools && options.toolChoice) {
|
||||
config.toolConfig = {
|
||||
functionCallingConfig: {
|
||||
mode: mapToolChoice(options.toolChoice),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
for (const msg of transformedMessages) {
|
||||
if (msg.role === "user") {
|
||||
// Handle both string and array content
|
||||
if (typeof msg.content === "string") {
|
||||
contents.push({
|
||||
role: "user",
|
||||
parts: [{ text: msg.content }],
|
||||
});
|
||||
} else {
|
||||
// Convert array content to Google format
|
||||
const parts: Part[] = msg.content.map((item) => {
|
||||
if (item.type === "text") {
|
||||
return { text: item.text };
|
||||
} else {
|
||||
// Image content - Google uses inlineData
|
||||
return {
|
||||
inlineData: {
|
||||
mimeType: item.mimeType,
|
||||
data: item.data,
|
||||
},
|
||||
};
|
||||
}
|
||||
});
|
||||
const filteredParts = !this.modelInfo?.input.includes("image")
|
||||
? parts.filter((p) => p.text !== undefined)
|
||||
: parts;
|
||||
contents.push({
|
||||
role: "user",
|
||||
parts: filteredParts,
|
||||
});
|
||||
}
|
||||
} else if (msg.role === "assistant") {
|
||||
const parts: Part[] = [];
|
||||
if (options.thinking?.enabled && model.reasoning) {
|
||||
config.thinkingConfig = {
|
||||
includeThoughts: true,
|
||||
...(options.thinking.budgetTokens !== undefined && { thinkingBudget: options.thinking.budgetTokens }),
|
||||
};
|
||||
}
|
||||
|
||||
// Process content blocks
|
||||
for (const block of msg.content) {
|
||||
if (block.type === "text") {
|
||||
parts.push({ text: block.text });
|
||||
} else if (block.type === "thinking") {
|
||||
const thinkingPart: Part = {
|
||||
thought: true,
|
||||
thoughtSignature: block.thinkingSignature,
|
||||
text: block.thinking,
|
||||
};
|
||||
parts.push(thinkingPart);
|
||||
} else if (block.type === "toolCall") {
|
||||
parts.push({
|
||||
functionCall: {
|
||||
id: block.id,
|
||||
name: block.name,
|
||||
args: block.arguments,
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
if (options.signal) {
|
||||
if (options.signal.aborted) {
|
||||
throw new Error("Request aborted");
|
||||
}
|
||||
config.abortSignal = options.signal;
|
||||
}
|
||||
|
||||
if (parts.length > 0) {
|
||||
contents.push({
|
||||
role: "model",
|
||||
parts,
|
||||
});
|
||||
}
|
||||
} else if (msg.role === "toolResult") {
|
||||
const params: GenerateContentParameters = {
|
||||
model: model.id,
|
||||
contents,
|
||||
config,
|
||||
};
|
||||
|
||||
return params;
|
||||
}
|
||||
function convertMessages(model: Model<"google-generative-ai">, context: Context): Content[] {
|
||||
const contents: Content[] = [];
|
||||
const transformedMessages = transformMessages(context.messages, model);
|
||||
|
||||
for (const msg of transformedMessages) {
|
||||
if (msg.role === "user") {
|
||||
if (typeof msg.content === "string") {
|
||||
contents.push({
|
||||
role: "user",
|
||||
parts: [
|
||||
{
|
||||
functionResponse: {
|
||||
id: msg.toolCallId,
|
||||
name: msg.toolName,
|
||||
response: {
|
||||
result: msg.content,
|
||||
isError: msg.isError,
|
||||
},
|
||||
parts: [{ text: msg.content }],
|
||||
});
|
||||
} else {
|
||||
const parts: Part[] = msg.content.map((item) => {
|
||||
if (item.type === "text") {
|
||||
return { text: item.text };
|
||||
} else {
|
||||
return {
|
||||
inlineData: {
|
||||
mimeType: item.mimeType,
|
||||
data: item.data,
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
||||
}
|
||||
});
|
||||
const filteredParts = !model.input.includes("image") ? parts.filter((p) => p.text !== undefined) : parts;
|
||||
if (filteredParts.length === 0) continue;
|
||||
contents.push({
|
||||
role: "user",
|
||||
parts: filteredParts,
|
||||
});
|
||||
}
|
||||
}
|
||||
} else if (msg.role === "assistant") {
|
||||
const parts: Part[] = [];
|
||||
|
||||
return contents;
|
||||
}
|
||||
for (const block of msg.content) {
|
||||
if (block.type === "text") {
|
||||
parts.push({ text: block.text });
|
||||
} else if (block.type === "thinking") {
|
||||
const thinkingPart: Part = {
|
||||
thought: true,
|
||||
thoughtSignature: block.thinkingSignature,
|
||||
text: block.thinking,
|
||||
};
|
||||
parts.push(thinkingPart);
|
||||
} else if (block.type === "toolCall") {
|
||||
parts.push({
|
||||
functionCall: {
|
||||
id: block.id,
|
||||
name: block.name,
|
||||
args: block.arguments,
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
private convertTools(tools: Tool[]): any[] {
|
||||
return [
|
||||
{
|
||||
functionDeclarations: tools.map((tool) => ({
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
parameters: tool.parameters,
|
||||
})),
|
||||
},
|
||||
];
|
||||
}
|
||||
|
||||
private mapToolChoice(choice: string): FunctionCallingConfigMode {
|
||||
switch (choice) {
|
||||
case "auto":
|
||||
return FunctionCallingConfigMode.AUTO;
|
||||
case "none":
|
||||
return FunctionCallingConfigMode.NONE;
|
||||
case "any":
|
||||
return FunctionCallingConfigMode.ANY;
|
||||
default:
|
||||
return FunctionCallingConfigMode.AUTO;
|
||||
if (parts.length === 0) continue;
|
||||
contents.push({
|
||||
role: "model",
|
||||
parts,
|
||||
});
|
||||
} else if (msg.role === "toolResult") {
|
||||
contents.push({
|
||||
role: "user",
|
||||
parts: [
|
||||
{
|
||||
functionResponse: {
|
||||
id: msg.toolCallId,
|
||||
name: msg.toolName,
|
||||
response: {
|
||||
result: msg.content,
|
||||
isError: msg.isError,
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
private mapStopReason(reason: FinishReason): StopReason {
|
||||
switch (reason) {
|
||||
case "STOP":
|
||||
return "stop";
|
||||
case "MAX_TOKENS":
|
||||
return "length";
|
||||
case "BLOCKLIST":
|
||||
case "PROHIBITED_CONTENT":
|
||||
case "SPII":
|
||||
case "SAFETY":
|
||||
case "IMAGE_SAFETY":
|
||||
return "safety";
|
||||
case "RECITATION":
|
||||
return "safety";
|
||||
case "FINISH_REASON_UNSPECIFIED":
|
||||
case "OTHER":
|
||||
case "LANGUAGE":
|
||||
case "MALFORMED_FUNCTION_CALL":
|
||||
case "UNEXPECTED_TOOL_CALL":
|
||||
return "error";
|
||||
default:
|
||||
return "stop";
|
||||
return contents;
|
||||
}
|
||||
|
||||
/**
 * Converts the tool definitions into the array passed as `tools` in the
 * request config.
 *
 * All tools are placed in a single entry whose `functionDeclarations` list
 * carries each tool's `name`, `description` and `parameters`; any other
 * properties on a tool are not copied.
 *
 * @param tools - Tool definitions to expose to the model.
 * @returns A one-element array wrapping the function declarations.
 */
function convertTools(tools: Tool[]): any[] {
	const functionDeclarations = tools.map(({ name, description, parameters }) => {
		return { name, description, parameters };
	});
	return [{ functionDeclarations }];
}
|
||||
|
||||
/**
 * Translates a tool-choice string into a FunctionCallingConfigMode value.
 *
 * @param choice - "none" and "any" map to their matching modes; "auto" and
 *   every other string map to AUTO.
 * @returns The function calling mode for the request's tool config.
 */
function mapToolChoice(choice: string): FunctionCallingConfigMode {
	if (choice === "none") {
		return FunctionCallingConfigMode.NONE;
	}
	if (choice === "any") {
		return FunctionCallingConfigMode.ANY;
	}
	// "auto" and unrecognized values share the same result.
	return FunctionCallingConfigMode.AUTO;
}
|
||||
|
||||
/**
 * Maps a FinishReason onto this library's StopReason.
 *
 * - STOP -> "stop"
 * - MAX_TOKENS -> "length"
 * - BLOCKLIST, PROHIBITED_CONTENT, SPII, SAFETY, IMAGE_SAFETY, RECITATION -> "safety"
 * - FINISH_REASON_UNSPECIFIED, OTHER, LANGUAGE, MALFORMED_FUNCTION_CALL,
 *   UNEXPECTED_TOOL_CALL -> "error"
 *
 * @param reason - Finish reason reported for a candidate.
 * @returns The corresponding StopReason.
 * @throws Error when `reason` matches none of the handled values.
 */
function mapStopReason(reason: FinishReason): StopReason {
	let mapped: StopReason;
	switch (reason) {
		case FinishReason.STOP:
			mapped = "stop";
			break;
		case FinishReason.MAX_TOKENS:
			mapped = "length";
			break;
		case FinishReason.SAFETY:
		case FinishReason.IMAGE_SAFETY:
		case FinishReason.BLOCKLIST:
		case FinishReason.PROHIBITED_CONTENT:
		case FinishReason.SPII:
		case FinishReason.RECITATION:
			mapped = "safety";
			break;
		case FinishReason.FINISH_REASON_UNSPECIFIED:
		case FinishReason.OTHER:
		case FinishReason.LANGUAGE:
		case FinishReason.MALFORMED_FUNCTION_CALL:
		case FinishReason.UNEXPECTED_TOOL_CALL:
			mapped = "error";
			break;
		default: {
			// Compile-time exhaustiveness check: `reason` must be `never` here.
			const unhandled: never = reason;
			throw new Error(`Unhandled stop reason: ${unhandled}`);
		}
	}
	return mapped;
}
|
||||
|
|
|
|||
|
|
@ -1,18 +1,20 @@
|
|||
import OpenAI from "openai";
|
||||
import type {
|
||||
ChatCompletionAssistantMessageParam,
|
||||
ChatCompletionChunk,
|
||||
ChatCompletionContentPart,
|
||||
ChatCompletionContentPartImage,
|
||||
ChatCompletionContentPartText,
|
||||
ChatCompletionMessageParam,
|
||||
} from "openai/resources/chat/completions.js";
|
||||
import { QueuedGenerateStream } from "../generate.js";
|
||||
import { calculateCost } from "../models.js";
|
||||
import type {
|
||||
AssistantMessage,
|
||||
Context,
|
||||
LLM,
|
||||
LLMOptions,
|
||||
Message,
|
||||
GenerateFunction,
|
||||
GenerateOptions,
|
||||
GenerateStream,
|
||||
Model,
|
||||
StopReason,
|
||||
TextContent,
|
||||
|
|
@ -22,43 +24,25 @@ import type {
|
|||
} from "../types.js";
|
||||
import { transformMessages } from "./utils.js";
|
||||
|
||||
export interface OpenAICompletionsLLMOptions extends LLMOptions {
|
||||
export interface OpenAICompletionsOptions extends GenerateOptions {
|
||||
toolChoice?: "auto" | "none" | "required" | { type: "function"; function: { name: string } };
|
||||
reasoningEffort?: "low" | "medium" | "high";
|
||||
reasoningEffort?: "minimal" | "low" | "medium" | "high";
|
||||
}
|
||||
|
||||
export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
|
||||
private client: OpenAI;
|
||||
private modelInfo: Model;
|
||||
export const streamOpenAICompletions: GenerateFunction<"openai-completions"> = (
|
||||
model: Model<"openai-completions">,
|
||||
context: Context,
|
||||
options?: OpenAICompletionsOptions,
|
||||
): GenerateStream => {
|
||||
const stream = new QueuedGenerateStream();
|
||||
|
||||
constructor(model: Model, apiKey?: string) {
|
||||
if (!apiKey) {
|
||||
if (!process.env.OPENAI_API_KEY) {
|
||||
throw new Error(
|
||||
"OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass it as an argument.",
|
||||
);
|
||||
}
|
||||
apiKey = process.env.OPENAI_API_KEY;
|
||||
}
|
||||
this.client = new OpenAI({ apiKey, baseURL: model.baseUrl, dangerouslyAllowBrowser: true });
|
||||
this.modelInfo = model;
|
||||
}
|
||||
|
||||
getModel(): Model {
|
||||
return this.modelInfo;
|
||||
}
|
||||
|
||||
getApi(): string {
|
||||
return "openai-completions";
|
||||
}
|
||||
|
||||
async generate(request: Context, options?: OpenAICompletionsLLMOptions): Promise<AssistantMessage> {
|
||||
(async () => {
|
||||
const output: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: this.getApi(),
|
||||
provider: this.modelInfo.provider,
|
||||
model: this.modelInfo.id,
|
||||
api: model.api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
|
|
@ -70,52 +54,13 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
|
|||
};
|
||||
|
||||
try {
|
||||
const messages = this.convertMessages(request.messages, request.systemPrompt);
|
||||
|
||||
const params: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
|
||||
model: this.modelInfo.id,
|
||||
messages,
|
||||
stream: true,
|
||||
stream_options: { include_usage: true },
|
||||
};
|
||||
|
||||
// Cerebras/xAI dont like the "store" field
|
||||
if (!this.modelInfo.baseUrl?.includes("cerebras.ai") && !this.modelInfo.baseUrl?.includes("api.x.ai")) {
|
||||
params.store = false;
|
||||
}
|
||||
|
||||
if (options?.maxTokens) {
|
||||
params.max_completion_tokens = options?.maxTokens;
|
||||
}
|
||||
|
||||
if (options?.temperature !== undefined) {
|
||||
params.temperature = options?.temperature;
|
||||
}
|
||||
|
||||
if (request.tools) {
|
||||
params.tools = this.convertTools(request.tools);
|
||||
}
|
||||
|
||||
if (options?.toolChoice) {
|
||||
params.tool_choice = options.toolChoice;
|
||||
}
|
||||
|
||||
if (
|
||||
options?.reasoningEffort &&
|
||||
this.modelInfo.reasoning &&
|
||||
!this.modelInfo.id.toLowerCase().includes("grok")
|
||||
) {
|
||||
params.reasoning_effort = options.reasoningEffort;
|
||||
}
|
||||
|
||||
const stream = await this.client.chat.completions.create(params, {
|
||||
signal: options?.signal,
|
||||
});
|
||||
|
||||
options?.onEvent?.({ type: "start", model: this.modelInfo.id, provider: this.modelInfo.provider });
|
||||
const client = createClient(model, options?.apiKey);
|
||||
const params = buildParams(model, context, options);
|
||||
const openaiStream = await client.chat.completions.create(params, { signal: options?.signal });
|
||||
stream.push({ type: "start", partial: output });
|
||||
|
||||
let currentBlock: TextContent | ThinkingContent | (ToolCall & { partialArgs?: string }) | null = null;
|
||||
for await (const chunk of stream) {
|
||||
for await (const chunk of openaiStream) {
|
||||
if (chunk.usage) {
|
||||
output.usage = {
|
||||
input: chunk.usage.prompt_tokens || 0,
|
||||
|
|
@ -132,137 +77,170 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
|
|||
total: 0,
|
||||
},
|
||||
};
|
||||
calculateCost(this.modelInfo, output.usage);
|
||||
calculateCost(model, output.usage);
|
||||
}
|
||||
|
||||
const choice = chunk.choices[0];
|
||||
if (!choice) continue;
|
||||
|
||||
// Capture finish reason
|
||||
if (choice.finish_reason) {
|
||||
output.stopReason = this.mapStopReason(choice.finish_reason);
|
||||
output.stopReason = mapStopReason(choice.finish_reason);
|
||||
}
|
||||
|
||||
if (choice.delta) {
|
||||
// Handle text content
|
||||
if (
|
||||
choice.delta.content !== null &&
|
||||
choice.delta.content !== undefined &&
|
||||
choice.delta.content.length > 0
|
||||
) {
|
||||
// Check if we need to switch to text block
|
||||
if (!currentBlock || currentBlock.type !== "text") {
|
||||
// Save current block if exists
|
||||
if (currentBlock) {
|
||||
if (currentBlock.type === "thinking") {
|
||||
options?.onEvent?.({ type: "thinking_end", content: currentBlock.thinking });
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
content: currentBlock.thinking,
|
||||
partial: output,
|
||||
});
|
||||
} else if (currentBlock.type === "toolCall") {
|
||||
currentBlock.arguments = JSON.parse(currentBlock.partialArgs || "{}");
|
||||
delete currentBlock.partialArgs;
|
||||
options?.onEvent?.({ type: "toolCall", toolCall: currentBlock as ToolCall });
|
||||
stream.push({
|
||||
type: "toolCall",
|
||||
toolCall: currentBlock as ToolCall,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
// Start new text block
|
||||
currentBlock = { type: "text", text: "" };
|
||||
output.content.push(currentBlock);
|
||||
options?.onEvent?.({ type: "text_start" });
|
||||
stream.push({ type: "text_start", partial: output });
|
||||
}
|
||||
// Append to text block
|
||||
|
||||
if (currentBlock.type === "text") {
|
||||
currentBlock.text += choice.delta.content;
|
||||
options?.onEvent?.({
|
||||
stream.push({
|
||||
type: "text_delta",
|
||||
content: currentBlock.text,
|
||||
delta: choice.delta.content,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Handle reasoning_content field
|
||||
// Some endpoints return reasoning in reasoning_content (llama.cpp)
|
||||
if (
|
||||
(choice.delta as any).reasoning_content !== null &&
|
||||
(choice.delta as any).reasoning_content !== undefined &&
|
||||
(choice.delta as any).reasoning_content.length > 0
|
||||
) {
|
||||
// Check if we need to switch to thinking block
|
||||
if (!currentBlock || currentBlock.type !== "thinking") {
|
||||
// Save current block if exists
|
||||
if (currentBlock) {
|
||||
if (currentBlock.type === "text") {
|
||||
options?.onEvent?.({ type: "text_end", content: currentBlock.text });
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
content: currentBlock.text,
|
||||
partial: output,
|
||||
});
|
||||
} else if (currentBlock.type === "toolCall") {
|
||||
currentBlock.arguments = JSON.parse(currentBlock.partialArgs || "{}");
|
||||
delete currentBlock.partialArgs;
|
||||
options?.onEvent?.({ type: "toolCall", toolCall: currentBlock as ToolCall });
|
||||
stream.push({
|
||||
type: "toolCall",
|
||||
toolCall: currentBlock as ToolCall,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
// Start new thinking block
|
||||
currentBlock = { type: "thinking", thinking: "", thinkingSignature: "reasoning_content" };
|
||||
currentBlock = {
|
||||
type: "thinking",
|
||||
thinking: "",
|
||||
thinkingSignature: "reasoning_content",
|
||||
};
|
||||
output.content.push(currentBlock);
|
||||
options?.onEvent?.({ type: "thinking_start" });
|
||||
stream.push({ type: "thinking_start", partial: output });
|
||||
}
|
||||
// Append to thinking block
|
||||
|
||||
if (currentBlock.type === "thinking") {
|
||||
const delta = (choice.delta as any).reasoning_content;
|
||||
currentBlock.thinking += delta;
|
||||
options?.onEvent?.({ type: "thinking_delta", content: currentBlock.thinking, delta });
|
||||
stream.push({
|
||||
type: "thinking_delta",
|
||||
delta,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Handle reasoning field
|
||||
// Some endpoints return reasoning in reasoning (ollama, xAI, ...)
|
||||
if (
|
||||
(choice.delta as any).reasoning !== null &&
|
||||
(choice.delta as any).reasoning !== undefined &&
|
||||
(choice.delta as any).reasoning.length > 0
|
||||
) {
|
||||
// Check if we need to switch to thinking block
|
||||
if (!currentBlock || currentBlock.type !== "thinking") {
|
||||
// Save current block if exists
|
||||
if (currentBlock) {
|
||||
if (currentBlock.type === "text") {
|
||||
options?.onEvent?.({ type: "text_end", content: currentBlock.text });
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
content: currentBlock.text,
|
||||
partial: output,
|
||||
});
|
||||
} else if (currentBlock.type === "toolCall") {
|
||||
currentBlock.arguments = JSON.parse(currentBlock.partialArgs || "{}");
|
||||
delete currentBlock.partialArgs;
|
||||
options?.onEvent?.({ type: "toolCall", toolCall: currentBlock as ToolCall });
|
||||
stream.push({
|
||||
type: "toolCall",
|
||||
toolCall: currentBlock as ToolCall,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
// Start new thinking block
|
||||
currentBlock = { type: "thinking", thinking: "", thinkingSignature: "reasoning" };
|
||||
currentBlock = {
|
||||
type: "thinking",
|
||||
thinking: "",
|
||||
thinkingSignature: "reasoning",
|
||||
};
|
||||
output.content.push(currentBlock);
|
||||
options?.onEvent?.({ type: "thinking_start" });
|
||||
stream.push({ type: "thinking_start", partial: output });
|
||||
}
|
||||
// Append to thinking block
|
||||
|
||||
if (currentBlock.type === "thinking") {
|
||||
const delta = (choice.delta as any).reasoning;
|
||||
currentBlock.thinking += delta;
|
||||
options?.onEvent?.({ type: "thinking_delta", content: currentBlock.thinking, delta });
|
||||
stream.push({ type: "thinking_delta", delta, partial: output });
|
||||
}
|
||||
}
|
||||
|
||||
// Handle tool calls
|
||||
if (choice?.delta?.tool_calls) {
|
||||
for (const toolCall of choice.delta.tool_calls) {
|
||||
// Check if we need a new tool call block
|
||||
if (
|
||||
!currentBlock ||
|
||||
currentBlock.type !== "toolCall" ||
|
||||
(toolCall.id && currentBlock.id !== toolCall.id)
|
||||
) {
|
||||
// Save current block if exists
|
||||
if (currentBlock) {
|
||||
if (currentBlock.type === "text") {
|
||||
options?.onEvent?.({ type: "text_end", content: currentBlock.text });
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
content: currentBlock.text,
|
||||
partial: output,
|
||||
});
|
||||
} else if (currentBlock.type === "thinking") {
|
||||
options?.onEvent?.({ type: "thinking_end", content: currentBlock.thinking });
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
content: currentBlock.thinking,
|
||||
partial: output,
|
||||
});
|
||||
} else if (currentBlock.type === "toolCall") {
|
||||
currentBlock.arguments = JSON.parse(currentBlock.partialArgs || "{}");
|
||||
delete currentBlock.partialArgs;
|
||||
options?.onEvent?.({ type: "toolCall", toolCall: currentBlock as ToolCall });
|
||||
stream.push({
|
||||
type: "toolCall",
|
||||
toolCall: currentBlock as ToolCall,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Start new tool call block
|
||||
currentBlock = {
|
||||
type: "toolCall",
|
||||
id: toolCall.id || "",
|
||||
|
|
@ -273,7 +251,6 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
|
|||
output.content.push(currentBlock);
|
||||
}
|
||||
|
||||
// Accumulate tool call data
|
||||
if (currentBlock.type === "toolCall") {
|
||||
if (toolCall.id) currentBlock.id = toolCall.id;
|
||||
if (toolCall.function?.name) currentBlock.name = toolCall.function.name;
|
||||
|
|
@ -286,16 +263,27 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
|
|||
}
|
||||
}
|
||||
|
||||
// Save final block if exists
|
||||
if (currentBlock) {
|
||||
if (currentBlock.type === "text") {
|
||||
options?.onEvent?.({ type: "text_end", content: currentBlock.text });
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
content: currentBlock.text,
|
||||
partial: output,
|
||||
});
|
||||
} else if (currentBlock.type === "thinking") {
|
||||
options?.onEvent?.({ type: "thinking_end", content: currentBlock.thinking });
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
content: currentBlock.thinking,
|
||||
partial: output,
|
||||
});
|
||||
} else if (currentBlock.type === "toolCall") {
|
||||
currentBlock.arguments = JSON.parse(currentBlock.partialArgs || "{}");
|
||||
delete currentBlock.partialArgs;
|
||||
options?.onEvent?.({ type: "toolCall", toolCall: currentBlock as ToolCall });
|
||||
stream.push({
|
||||
type: "toolCall",
|
||||
toolCall: currentBlock as ToolCall,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -303,141 +291,188 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
|
|||
throw new Error("Request was aborted");
|
||||
}
|
||||
|
||||
options?.onEvent?.({ type: "done", reason: output.stopReason, message: output });
|
||||
stream.push({ type: "done", reason: output.stopReason, message: output });
|
||||
stream.end();
|
||||
return output;
|
||||
} catch (error) {
|
||||
// Update output with error information
|
||||
output.stopReason = "error";
|
||||
output.error = error instanceof Error ? error.message : String(error);
|
||||
options?.onEvent?.({ type: "error", error: output.error });
|
||||
return output;
|
||||
stream.push({ type: "error", error: output.error, partial: output });
|
||||
stream.end();
|
||||
}
|
||||
})();
|
||||
|
||||
return stream;
|
||||
};
|
||||
|
||||
/**
 * Creates an OpenAI client for the given model.
 *
 * The key is taken from `apiKey` when it is non-empty; otherwise the
 * OPENAI_API_KEY environment variable is used. The client is constructed
 * with the model's `baseUrl` as `baseURL` and with `dangerouslyAllowBrowser`
 * set to true.
 *
 * @param model - Model whose `baseUrl` selects the endpoint.
 * @param apiKey - Optional explicit API key.
 * @returns The configured OpenAI client.
 * @throws Error when neither `apiKey` nor OPENAI_API_KEY provides a value.
 */
function createClient(model: Model<"openai-completions">, apiKey?: string) {
	// An explicit key wins; the environment variable is only the fallback.
	const resolvedKey = apiKey || process.env.OPENAI_API_KEY;
	if (!resolvedKey) {
		throw new Error(
			"OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass it as an argument.",
		);
	}
	return new OpenAI({ apiKey: resolvedKey, baseURL: model.baseUrl, dangerouslyAllowBrowser: true });
}
|
||||
|
||||
function buildParams(model: Model<"openai-completions">, context: Context, options?: OpenAICompletionsOptions) {
|
||||
const messages = convertMessages(model, context);
|
||||
|
||||
const params: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
|
||||
model: model.id,
|
||||
messages,
|
||||
stream: true,
|
||||
stream_options: { include_usage: true },
|
||||
};
|
||||
|
||||
// Cerebras/xAI don't like the "store" field
|
||||
if (!model.baseUrl.includes("cerebras.ai") && !model.baseUrl.includes("api.x.ai")) {
|
||||
params.store = false;
|
||||
}
|
||||
|
||||
private convertMessages(messages: Message[], systemPrompt?: string): ChatCompletionMessageParam[] {
|
||||
const params: ChatCompletionMessageParam[] = [];
|
||||
if (options?.maxTokens) {
|
||||
params.max_completion_tokens = options?.maxTokens;
|
||||
}
|
||||
|
||||
// Transform messages for cross-provider compatibility
|
||||
const transformedMessages = transformMessages(messages, this.modelInfo, this.getApi());
|
||||
if (options?.temperature !== undefined) {
|
||||
params.temperature = options?.temperature;
|
||||
}
|
||||
|
||||
// Add system prompt if provided
|
||||
if (systemPrompt) {
|
||||
// Cerebras/xAi don't like the "developer" role
|
||||
const useDeveloperRole =
|
||||
this.modelInfo.reasoning &&
|
||||
!this.modelInfo.baseUrl?.includes("cerebras.ai") &&
|
||||
!this.modelInfo.baseUrl?.includes("api.x.ai");
|
||||
const role = useDeveloperRole ? "developer" : "system";
|
||||
params.push({ role: role, content: systemPrompt });
|
||||
}
|
||||
if (context.tools) {
|
||||
params.tools = convertTools(context.tools);
|
||||
}
|
||||
|
||||
// Convert messages
|
||||
for (const msg of transformedMessages) {
|
||||
if (msg.role === "user") {
|
||||
// Handle both string and array content
|
||||
if (typeof msg.content === "string") {
|
||||
params.push({
|
||||
role: "user",
|
||||
content: msg.content,
|
||||
});
|
||||
} else {
|
||||
// Convert array content to OpenAI format
|
||||
const content: ChatCompletionContentPart[] = msg.content.map((item): ChatCompletionContentPart => {
|
||||
if (item.type === "text") {
|
||||
return {
|
||||
type: "text",
|
||||
text: item.text,
|
||||
} satisfies ChatCompletionContentPartText;
|
||||
} else {
|
||||
// Image content - OpenAI uses data URLs
|
||||
return {
|
||||
type: "image_url",
|
||||
image_url: {
|
||||
url: `data:${item.mimeType};base64,${item.data}`,
|
||||
},
|
||||
} satisfies ChatCompletionContentPartImage;
|
||||
}
|
||||
});
|
||||
const filteredContent = !this.modelInfo?.input.includes("image")
|
||||
? content.filter((c) => c.type !== "image_url")
|
||||
: content;
|
||||
params.push({
|
||||
role: "user",
|
||||
content: filteredContent,
|
||||
});
|
||||
}
|
||||
} else if (msg.role === "assistant") {
|
||||
const assistantMsg: ChatCompletionMessageParam = {
|
||||
role: "assistant",
|
||||
content: null,
|
||||
};
|
||||
if (options?.toolChoice) {
|
||||
params.tool_choice = options.toolChoice;
|
||||
}
|
||||
|
||||
// Build content from blocks
|
||||
const textBlocks = msg.content.filter((b) => b.type === "text") as TextContent[];
|
||||
if (textBlocks.length > 0) {
|
||||
assistantMsg.content = textBlocks.map((b) => b.text).join("");
|
||||
}
|
||||
// Grok models don't like reasoning_effort
|
||||
if (options?.reasoningEffort && model.reasoning && !model.id.toLowerCase().includes("grok")) {
|
||||
params.reasoning_effort = options.reasoningEffort;
|
||||
}
|
||||
|
||||
// Handle thinking blocks for llama.cpp server + gpt-oss
|
||||
const thinkingBlocks = msg.content.filter((b) => b.type === "thinking") as ThinkingContent[];
|
||||
if (thinkingBlocks.length > 0) {
|
||||
// Use the signature from the first thinking block if available
|
||||
const signature = thinkingBlocks[0].thinkingSignature;
|
||||
if (signature && signature.length > 0) {
|
||||
(assistantMsg as any)[signature] = thinkingBlocks.map((b) => b.thinking).join("");
|
||||
}
|
||||
}
|
||||
return params;
|
||||
}
|
||||
|
||||
// Handle tool calls
|
||||
const toolCalls = msg.content.filter((b) => b.type === "toolCall") as ToolCall[];
|
||||
if (toolCalls.length > 0) {
|
||||
assistantMsg.tool_calls = toolCalls.map((tc) => ({
|
||||
id: tc.id,
|
||||
type: "function" as const,
|
||||
function: {
|
||||
name: tc.name,
|
||||
arguments: JSON.stringify(tc.arguments),
|
||||
},
|
||||
}));
|
||||
}
|
||||
function convertMessages(model: Model<"openai-completions">, context: Context): ChatCompletionMessageParam[] {
|
||||
const params: ChatCompletionMessageParam[] = [];
|
||||
|
||||
params.push(assistantMsg);
|
||||
} else if (msg.role === "toolResult") {
|
||||
const transformedMessages = transformMessages(context.messages, model);
|
||||
|
||||
if (context.systemPrompt) {
|
||||
// Cerebras/xAi don't like the "developer" role
|
||||
const useDeveloperRole =
|
||||
model.reasoning && !model.baseUrl.includes("cerebras.ai") && !model.baseUrl.includes("api.x.ai");
|
||||
const role = useDeveloperRole ? "developer" : "system";
|
||||
params.push({ role: role, content: context.systemPrompt });
|
||||
}
|
||||
|
||||
for (const msg of transformedMessages) {
|
||||
if (msg.role === "user") {
|
||||
if (typeof msg.content === "string") {
|
||||
params.push({
|
||||
role: "tool",
|
||||
role: "user",
|
||||
content: msg.content,
|
||||
tool_call_id: msg.toolCallId,
|
||||
});
|
||||
} else {
|
||||
const content: ChatCompletionContentPart[] = msg.content.map((item): ChatCompletionContentPart => {
|
||||
if (item.type === "text") {
|
||||
return {
|
||||
type: "text",
|
||||
text: item.text,
|
||||
} satisfies ChatCompletionContentPartText;
|
||||
} else {
|
||||
return {
|
||||
type: "image_url",
|
||||
image_url: {
|
||||
url: `data:${item.mimeType};base64,${item.data}`,
|
||||
},
|
||||
} satisfies ChatCompletionContentPartImage;
|
||||
}
|
||||
});
|
||||
const filteredContent = !model.input.includes("image")
|
||||
? content.filter((c) => c.type !== "image_url")
|
||||
: content;
|
||||
if (filteredContent.length === 0) continue;
|
||||
params.push({
|
||||
role: "user",
|
||||
content: filteredContent,
|
||||
});
|
||||
}
|
||||
} else if (msg.role === "assistant") {
|
||||
const assistantMsg: ChatCompletionAssistantMessageParam = {
|
||||
role: "assistant",
|
||||
content: null,
|
||||
};
|
||||
|
||||
const textBlocks = msg.content.filter((b) => b.type === "text") as TextContent[];
|
||||
if (textBlocks.length > 0) {
|
||||
assistantMsg.content = textBlocks.map((b) => b.text).join("");
|
||||
}
|
||||
|
||||
// Handle thinking blocks for llama.cpp server + gpt-oss
|
||||
const thinkingBlocks = msg.content.filter((b) => b.type === "thinking") as ThinkingContent[];
|
||||
if (thinkingBlocks.length > 0) {
|
||||
// Use the signature from the first thinking block if available
|
||||
const signature = thinkingBlocks[0].thinkingSignature;
|
||||
if (signature && signature.length > 0) {
|
||||
(assistantMsg as any)[signature] = thinkingBlocks.map((b) => b.thinking).join("");
|
||||
}
|
||||
}
|
||||
|
||||
const toolCalls = msg.content.filter((b) => b.type === "toolCall") as ToolCall[];
|
||||
if (toolCalls.length > 0) {
|
||||
assistantMsg.tool_calls = toolCalls.map((tc) => ({
|
||||
id: tc.id,
|
||||
type: "function" as const,
|
||||
function: {
|
||||
name: tc.name,
|
||||
arguments: JSON.stringify(tc.arguments),
|
||||
},
|
||||
}));
|
||||
}
|
||||
|
||||
params.push(assistantMsg);
|
||||
} else if (msg.role === "toolResult") {
|
||||
params.push({
|
||||
role: "tool",
|
||||
content: msg.content,
|
||||
tool_call_id: msg.toolCallId,
|
||||
});
|
||||
}
|
||||
|
||||
return params;
|
||||
}
|
||||
|
||||
private convertTools(tools: Tool[]): OpenAI.Chat.Completions.ChatCompletionTool[] {
|
||||
return tools.map((tool) => ({
|
||||
type: "function",
|
||||
function: {
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
parameters: tool.parameters,
|
||||
},
|
||||
}));
|
||||
}
|
||||
return params;
|
||||
}
|
||||
|
||||
private mapStopReason(reason: ChatCompletionChunk.Choice["finish_reason"] | null): StopReason {
|
||||
switch (reason) {
|
||||
case "stop":
|
||||
return "stop";
|
||||
case "length":
|
||||
return "length";
|
||||
case "function_call":
|
||||
case "tool_calls":
|
||||
return "toolUse";
|
||||
case "content_filter":
|
||||
return "safety";
|
||||
default:
|
||||
return "stop";
|
||||
/**
 * Converts tool definitions into Chat Completions `function` tool entries.
 *
 * Each tool's `name`, `description` and `parameters` are copied under the
 * `function` key of the resulting entry; input order is preserved.
 */
function convertTools(tools: Tool[]): OpenAI.Chat.Completions.ChatCompletionTool[] {
	const converted: OpenAI.Chat.Completions.ChatCompletionTool[] = [];
	for (const { name, description, parameters } of tools) {
		converted.push({
			type: "function",
			function: { name, description, parameters },
		});
	}
	return converted;
}
|
||||
|
||||
/**
 * Maps a Chat Completions `finish_reason` to the library's `StopReason`.
 *
 * - `"stop"` -> `"stop"`
 * - `"length"` -> `"length"`
 * - `"function_call"` / `"tool_calls"` -> `"toolUse"`
 * - `"content_filter"` -> `"safety"`
 * - `null` or a missing value -> `"stop"`
 *
 * @param reason The `finish_reason` reported on a streamed choice.
 * @returns The corresponding `StopReason`.
 * @throws Error if `reason` is a value not covered above.
 */
function mapStopReason(reason: ChatCompletionChunk.Choice["finish_reason"]): StopReason {
	// Treat a missing value (`undefined`) like `null`, so an absent finish_reason
	// maps to "stop" instead of falling into the exhaustiveness branch and throwing.
	if (reason === null || reason === undefined) return "stop";
	switch (reason) {
		case "stop":
			return "stop";
		case "length":
			return "length";
		case "function_call":
		case "tool_calls":
			return "toolUse";
		case "content_filter":
			return "safety";
		default: {
			// Compile-time exhaustiveness check: fails to type-check if the SDK adds a new reason.
			const _exhaustive: never = reason;
			throw new Error(`Unhandled stop reason: ${_exhaustive}`);
		}
	}
}
|
||||
|
|
|
|||
|
|
@ -10,58 +10,49 @@ import type {
|
|||
ResponseOutputMessage,
|
||||
ResponseReasoningItem,
|
||||
} from "openai/resources/responses/responses.js";
|
||||
import { QueuedGenerateStream } from "../generate.js";
|
||||
import { calculateCost } from "../models.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
Context,
|
||||
LLM,
|
||||
LLMOptions,
|
||||
GenerateFunction,
|
||||
GenerateOptions,
|
||||
GenerateStream,
|
||||
Message,
|
||||
Model,
|
||||
StopReason,
|
||||
TextContent,
|
||||
ThinkingContent,
|
||||
Tool,
|
||||
ToolCall,
|
||||
} from "../types.js";
|
||||
import { transformMessages } from "./utils.js";
|
||||
|
||||
export interface OpenAIResponsesLLMOptions extends LLMOptions {
|
||||
// OpenAI Responses-specific options
|
||||
/**
 * Options accepted by the OpenAI Responses generate function, in addition to
 * the shared `GenerateOptions`.
 */
export interface OpenAIResponsesOptions extends GenerateOptions {
	/** Requested style of reasoning summary; `null` is also accepted. */
	reasoningSummary?: "auto" | "detailed" | "concise" | null;
	/** Requested reasoning effort level. */
	reasoningEffort?: "minimal" | "low" | "medium" | "high";
}
|
||||
|
||||
export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
|
||||
private client: OpenAI;
|
||||
private modelInfo: Model;
|
||||
/**
|
||||
* Generate function for OpenAI Responses API
|
||||
*/
|
||||
export const streamOpenAIResponses: GenerateFunction<"openai-responses"> = (
|
||||
model: Model<"openai-responses">,
|
||||
context: Context,
|
||||
options?: OpenAIResponsesOptions,
|
||||
): GenerateStream => {
|
||||
const stream = new QueuedGenerateStream();
|
||||
|
||||
constructor(model: Model, apiKey?: string) {
|
||||
if (!apiKey) {
|
||||
if (!process.env.OPENAI_API_KEY) {
|
||||
throw new Error(
|
||||
"OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass it as an argument.",
|
||||
);
|
||||
}
|
||||
apiKey = process.env.OPENAI_API_KEY;
|
||||
}
|
||||
this.client = new OpenAI({ apiKey, baseURL: model.baseUrl, dangerouslyAllowBrowser: true });
|
||||
this.modelInfo = model;
|
||||
}
|
||||
|
||||
getModel(): Model {
|
||||
return this.modelInfo;
|
||||
}
|
||||
|
||||
getApi(): string {
|
||||
return "openai-responses";
|
||||
}
|
||||
|
||||
async generate(request: Context, options?: OpenAIResponsesLLMOptions): Promise<AssistantMessage> {
|
||||
// Start async processing
|
||||
(async () => {
|
||||
const output: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: this.getApi(),
|
||||
provider: this.modelInfo.provider,
|
||||
model: this.modelInfo.id,
|
||||
api: "openai-responses" as Api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
|
|
@ -71,77 +62,31 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
|
|||
},
|
||||
stopReason: "stop",
|
||||
};
|
||||
|
||||
try {
|
||||
const input = this.convertToInput(request.messages, request.systemPrompt);
|
||||
// Create OpenAI client
|
||||
const client = createClient(model, options?.apiKey);
|
||||
const params = buildParams(model, context, options);
|
||||
const openaiStream = await client.responses.create(params, { signal: options?.signal });
|
||||
stream.push({ type: "start", partial: output });
|
||||
|
||||
const params: ResponseCreateParamsStreaming = {
|
||||
model: this.modelInfo.id,
|
||||
input,
|
||||
stream: true,
|
||||
};
|
||||
|
||||
if (options?.maxTokens) {
|
||||
params.max_output_tokens = options?.maxTokens;
|
||||
}
|
||||
|
||||
if (options?.temperature !== undefined) {
|
||||
params.temperature = options?.temperature;
|
||||
}
|
||||
|
||||
if (request.tools) {
|
||||
params.tools = this.convertTools(request.tools);
|
||||
}
|
||||
|
||||
// Add reasoning options for models that support it
|
||||
if (this.modelInfo?.reasoning) {
|
||||
if (options?.reasoningEffort || options?.reasoningSummary) {
|
||||
params.reasoning = {
|
||||
effort: options?.reasoningEffort || "medium",
|
||||
summary: options?.reasoningSummary || "auto",
|
||||
};
|
||||
params.include = ["reasoning.encrypted_content"];
|
||||
} else {
|
||||
params.reasoning = {
|
||||
effort: this.modelInfo.name.startsWith("gpt-5") ? "minimal" : null,
|
||||
summary: null,
|
||||
};
|
||||
|
||||
if (this.modelInfo.name.startsWith("gpt-5")) {
|
||||
// Jesus Christ, see https://community.openai.com/t/need-reasoning-false-option-for-gpt-5/1351588/7
|
||||
input.push({
|
||||
role: "developer",
|
||||
content: [
|
||||
{
|
||||
type: "input_text",
|
||||
text: "# Juice: 0 !important",
|
||||
},
|
||||
],
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const stream = await this.client.responses.create(params, {
|
||||
signal: options?.signal,
|
||||
});
|
||||
|
||||
options?.onEvent?.({ type: "start", model: this.modelInfo.id, provider: this.modelInfo.provider });
|
||||
|
||||
const outputItems: (ResponseReasoningItem | ResponseOutputMessage | ResponseFunctionToolCall)[] = [];
|
||||
let currentItem: ResponseReasoningItem | ResponseOutputMessage | ResponseFunctionToolCall | null = null;
|
||||
let currentBlock: ThinkingContent | TextContent | ToolCall | null = null;
|
||||
|
||||
for await (const event of stream) {
|
||||
for await (const event of openaiStream) {
|
||||
// Handle output item start
|
||||
if (event.type === "response.output_item.added") {
|
||||
const item = event.item;
|
||||
if (item.type === "reasoning") {
|
||||
options?.onEvent?.({ type: "thinking_start" });
|
||||
outputItems.push(item);
|
||||
currentItem = item;
|
||||
currentBlock = { type: "thinking", thinking: "" };
|
||||
output.content.push(currentBlock);
|
||||
stream.push({ type: "thinking_start", partial: output });
|
||||
} else if (item.type === "message") {
|
||||
options?.onEvent?.({ type: "text_start" });
|
||||
outputItems.push(item);
|
||||
currentItem = item;
|
||||
currentBlock = { type: "text", text: "" };
|
||||
output.content.push(currentBlock);
|
||||
stream.push({ type: "text_start", partial: output });
|
||||
}
|
||||
}
|
||||
// Handle reasoning summary deltas
|
||||
|
|
@ -151,30 +96,42 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
|
|||
currentItem.summary.push(event.part);
|
||||
}
|
||||
} else if (event.type === "response.reasoning_summary_text.delta") {
|
||||
if (currentItem && currentItem.type === "reasoning") {
|
||||
if (
|
||||
currentItem &&
|
||||
currentItem.type === "reasoning" &&
|
||||
currentBlock &&
|
||||
currentBlock.type === "thinking"
|
||||
) {
|
||||
currentItem.summary = currentItem.summary || [];
|
||||
const lastPart = currentItem.summary[currentItem.summary.length - 1];
|
||||
if (lastPart) {
|
||||
currentBlock.thinking += event.delta;
|
||||
lastPart.text += event.delta;
|
||||
options?.onEvent?.({
|
||||
stream.push({
|
||||
type: "thinking_delta",
|
||||
content: currentItem.summary.map((s) => s.text).join("\n\n"),
|
||||
delta: event.delta,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
// Add a new line between summary parts (hack...)
|
||||
else if (event.type === "response.reasoning_summary_part.done") {
|
||||
if (currentItem && currentItem.type === "reasoning") {
|
||||
if (
|
||||
currentItem &&
|
||||
currentItem.type === "reasoning" &&
|
||||
currentBlock &&
|
||||
currentBlock.type === "thinking"
|
||||
) {
|
||||
currentItem.summary = currentItem.summary || [];
|
||||
const lastPart = currentItem.summary[currentItem.summary.length - 1];
|
||||
if (lastPart) {
|
||||
currentBlock.thinking += "\n\n";
|
||||
lastPart.text += "\n\n";
|
||||
options?.onEvent?.({
|
||||
stream.push({
|
||||
type: "thinking_delta",
|
||||
content: currentItem.summary.map((s) => s.text).join("\n\n"),
|
||||
delta: "\n\n",
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
@ -186,30 +143,28 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
|
|||
currentItem.content.push(event.part);
|
||||
}
|
||||
} else if (event.type === "response.output_text.delta") {
|
||||
if (currentItem && currentItem.type === "message") {
|
||||
if (currentItem && currentItem.type === "message" && currentBlock && currentBlock.type === "text") {
|
||||
const lastPart = currentItem.content[currentItem.content.length - 1];
|
||||
if (lastPart && lastPart.type === "output_text") {
|
||||
currentBlock.text += event.delta;
|
||||
lastPart.text += event.delta;
|
||||
options?.onEvent?.({
|
||||
stream.push({
|
||||
type: "text_delta",
|
||||
content: currentItem.content
|
||||
.map((c) => (c.type === "output_text" ? c.text : c.refusal))
|
||||
.join(""),
|
||||
delta: event.delta,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
} else if (event.type === "response.refusal.delta") {
|
||||
if (currentItem && currentItem.type === "message") {
|
||||
if (currentItem && currentItem.type === "message" && currentBlock && currentBlock.type === "text") {
|
||||
const lastPart = currentItem.content[currentItem.content.length - 1];
|
||||
if (lastPart && lastPart.type === "refusal") {
|
||||
currentBlock.text += event.delta;
|
||||
lastPart.refusal += event.delta;
|
||||
options?.onEvent?.({
|
||||
stream.push({
|
||||
type: "text_delta",
|
||||
content: currentItem.content
|
||||
.map((c) => (c.type === "output_text" ? c.text : c.refusal))
|
||||
.join(""),
|
||||
delta: event.delta,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
@ -218,14 +173,24 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
|
|||
else if (event.type === "response.output_item.done") {
|
||||
const item = event.item;
|
||||
|
||||
if (item.type === "reasoning") {
|
||||
outputItems[outputItems.length - 1] = item; // Update with final item
|
||||
const thinkingContent = item.summary?.map((s) => s.text).join("\n\n") || "";
|
||||
options?.onEvent?.({ type: "thinking_end", content: thinkingContent });
|
||||
} else if (item.type === "message") {
|
||||
outputItems[outputItems.length - 1] = item; // Update with final item
|
||||
const textContent = item.content.map((c) => (c.type === "output_text" ? c.text : c.refusal)).join("");
|
||||
options?.onEvent?.({ type: "text_end", content: textContent });
|
||||
if (item.type === "reasoning" && currentBlock && currentBlock.type === "thinking") {
|
||||
currentBlock.thinking = item.summary?.map((s) => s.text).join("\n\n") || "";
|
||||
currentBlock.thinkingSignature = JSON.stringify(item);
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
content: currentBlock.thinking,
|
||||
partial: output,
|
||||
});
|
||||
currentBlock = null;
|
||||
} else if (item.type === "message" && currentBlock && currentBlock.type === "text") {
|
||||
currentBlock.text = item.content.map((c) => (c.type === "output_text" ? c.text : c.refusal)).join("");
|
||||
currentBlock.textSignature = item.id;
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
content: currentBlock.text,
|
||||
partial: output,
|
||||
});
|
||||
currentBlock = null;
|
||||
} else if (item.type === "function_call") {
|
||||
const toolCall: ToolCall = {
|
||||
type: "toolCall",
|
||||
|
|
@ -233,8 +198,8 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
|
|||
name: item.name,
|
||||
arguments: JSON.parse(item.arguments),
|
||||
};
|
||||
options?.onEvent?.({ type: "toolCall", toolCall });
|
||||
outputItems.push(item);
|
||||
output.content.push(toolCall);
|
||||
stream.push({ type: "toolCall", toolCall, partial: output });
|
||||
}
|
||||
}
|
||||
// Handle completion
|
||||
|
|
@ -249,10 +214,10 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
|
|||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
};
|
||||
}
|
||||
calculateCost(this.modelInfo, output.usage);
|
||||
calculateCost(model, output.usage);
|
||||
// Map status to stop reason
|
||||
output.stopReason = this.mapStopReason(response?.status);
|
||||
if (outputItems.some((b) => b.type === "function_call") && output.stopReason === "stop") {
|
||||
output.stopReason = mapStopReason(response?.status);
|
||||
if (output.content.some((b) => b.type === "toolCall") && output.stopReason === "stop") {
|
||||
output.stopReason = "toolUse";
|
||||
}
|
||||
}
|
||||
|
|
@ -260,173 +225,215 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
|
|||
else if (event.type === "error") {
|
||||
output.stopReason = "error";
|
||||
output.error = `Code ${event.code}: ${event.message}` || "Unknown error";
|
||||
options?.onEvent?.({ type: "error", error: output.error });
|
||||
stream.push({ type: "error", error: output.error, partial: output });
|
||||
stream.end();
|
||||
return output;
|
||||
} else if (event.type === "response.failed") {
|
||||
output.stopReason = "error";
|
||||
output.error = "Unknown error";
|
||||
options?.onEvent?.({ type: "error", error: output.error });
|
||||
stream.push({ type: "error", error: output.error, partial: output });
|
||||
stream.end();
|
||||
return output;
|
||||
}
|
||||
}
|
||||
|
||||
// Convert output items to blocks
|
||||
for (const item of outputItems) {
|
||||
if (item.type === "reasoning") {
|
||||
output.content.push({
|
||||
type: "thinking",
|
||||
thinking: item.summary?.map((s: any) => s.text).join("\n\n") || "",
|
||||
thinkingSignature: JSON.stringify(item), // Full item for resubmission
|
||||
});
|
||||
} else if (item.type === "message") {
|
||||
output.content.push({
|
||||
type: "text",
|
||||
text: item.content.map((c) => (c.type === "output_text" ? c.text : c.refusal)).join(""),
|
||||
textSignature: item.id, // ID for resubmission
|
||||
});
|
||||
} else if (item.type === "function_call") {
|
||||
output.content.push({
|
||||
type: "toolCall",
|
||||
id: item.call_id + "|" + item.id,
|
||||
name: item.name,
|
||||
arguments: JSON.parse(item.arguments),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (options?.signal?.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
|
||||
options?.onEvent?.({ type: "done", reason: output.stopReason, message: output });
|
||||
return output;
|
||||
stream.push({ type: "done", reason: output.stopReason, message: output });
|
||||
stream.end();
|
||||
} catch (error) {
|
||||
output.stopReason = "error";
|
||||
output.error = error instanceof Error ? error.message : JSON.stringify(error);
|
||||
options?.onEvent?.({ type: "error", error: output.error });
|
||||
return output;
|
||||
stream.push({ type: "error", error: output.error, partial: output });
|
||||
stream.end();
|
||||
}
|
||||
})();
|
||||
|
||||
return stream;
|
||||
};
|
||||
|
||||
/**
 * Builds the OpenAI SDK client used for Responses API requests.
 *
 * The key is taken from `apiKey` when it is non-empty, otherwise from the
 * `OPENAI_API_KEY` environment variable.
 *
 * @param model Target model; its `baseUrl` is passed to the SDK as `baseURL`.
 * @param apiKey Optional explicit API key. An empty string is treated as absent.
 * @returns A configured `OpenAI` client instance.
 * @throws Error when neither `apiKey` nor `OPENAI_API_KEY` provides a key.
 */
function createClient(model: Model<"openai-responses">, apiKey?: string) {
	// The client is created with `dangerouslyAllowBrowser: true`, so this code may run where
	// `process` does not exist. Guard the lookup so a missing key surfaces as the descriptive
	// error below instead of a ReferenceError.
	const envKey = typeof process !== "undefined" ? process.env?.OPENAI_API_KEY : undefined;
	const resolvedKey = apiKey || envKey;
	if (!resolvedKey) {
		throw new Error(
			"OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass it as an argument.",
		);
	}
	return new OpenAI({ apiKey: resolvedKey, baseURL: model.baseUrl, dangerouslyAllowBrowser: true });
}
|
||||
|
||||
function buildParams(model: Model<"openai-responses">, context: Context, options?: OpenAIResponsesOptions) {
|
||||
const messages = convertMessages(model, context);
|
||||
|
||||
const params: ResponseCreateParamsStreaming = {
|
||||
model: model.id,
|
||||
input: messages,
|
||||
stream: true,
|
||||
};
|
||||
|
||||
if (options?.maxTokens) {
|
||||
params.max_output_tokens = options?.maxTokens;
|
||||
}
|
||||
|
||||
private convertToInput(messages: Message[], systemPrompt?: string): ResponseInput {
|
||||
const input: ResponseInput = [];
|
||||
if (options?.temperature !== undefined) {
|
||||
params.temperature = options?.temperature;
|
||||
}
|
||||
|
||||
// Transform messages for cross-provider compatibility
|
||||
const transformedMessages = transformMessages(messages, this.modelInfo, this.getApi());
|
||||
if (context.tools) {
|
||||
params.tools = convertTools(context.tools);
|
||||
}
|
||||
|
||||
// Add system prompt if provided
|
||||
if (systemPrompt) {
|
||||
const role = this.modelInfo?.reasoning ? "developer" : "system";
|
||||
input.push({
|
||||
role,
|
||||
content: systemPrompt,
|
||||
});
|
||||
}
|
||||
if (model.reasoning) {
|
||||
if (options?.reasoningEffort || options?.reasoningSummary) {
|
||||
params.reasoning = {
|
||||
effort: options?.reasoningEffort || "medium",
|
||||
summary: options?.reasoningSummary || "auto",
|
||||
};
|
||||
params.include = ["reasoning.encrypted_content"];
|
||||
} else {
|
||||
params.reasoning = {
|
||||
effort: model.name.startsWith("gpt-5") ? "minimal" : null,
|
||||
summary: null,
|
||||
};
|
||||
|
||||
// Convert messages
|
||||
for (const msg of transformedMessages) {
|
||||
if (msg.role === "user") {
|
||||
// Handle both string and array content
|
||||
if (typeof msg.content === "string") {
|
||||
input.push({
|
||||
role: "user",
|
||||
content: [{ type: "input_text", text: msg.content }],
|
||||
});
|
||||
} else {
|
||||
// Convert array content to OpenAI Responses format
|
||||
const content: ResponseInputContent[] = msg.content.map((item): ResponseInputContent => {
|
||||
if (item.type === "text") {
|
||||
return {
|
||||
type: "input_text",
|
||||
text: item.text,
|
||||
} satisfies ResponseInputText;
|
||||
} else {
|
||||
// Image content - OpenAI Responses uses data URLs
|
||||
return {
|
||||
type: "input_image",
|
||||
detail: "auto",
|
||||
image_url: `data:${item.mimeType};base64,${item.data}`,
|
||||
} satisfies ResponseInputImage;
|
||||
}
|
||||
});
|
||||
const filteredContent = !this.modelInfo?.input.includes("image")
|
||||
? content.filter((c) => c.type !== "input_image")
|
||||
: content;
|
||||
input.push({
|
||||
role: "user",
|
||||
content: filteredContent,
|
||||
});
|
||||
}
|
||||
} else if (msg.role === "assistant") {
|
||||
// Process content blocks in order
|
||||
const output: ResponseInput = [];
|
||||
|
||||
for (const block of msg.content) {
|
||||
// Do not submit thinking blocks if the completion had an error (i.e. abort)
|
||||
if (block.type === "thinking" && msg.stopReason !== "error") {
|
||||
// Push the full reasoning item(s) from signature
|
||||
if (block.thinkingSignature) {
|
||||
const reasoningItem = JSON.parse(block.thinkingSignature);
|
||||
output.push(reasoningItem);
|
||||
}
|
||||
} else if (block.type === "text") {
|
||||
const textBlock = block as TextContent;
|
||||
output.push({
|
||||
type: "message",
|
||||
role: "assistant",
|
||||
content: [{ type: "output_text", text: textBlock.text, annotations: [] }],
|
||||
status: "completed",
|
||||
id: textBlock.textSignature || "msg_" + Math.random().toString(36).substring(2, 15),
|
||||
} satisfies ResponseOutputMessage);
|
||||
// Do not submit thinking blocks if the completion had an error (i.e. abort)
|
||||
} else if (block.type === "toolCall" && msg.stopReason !== "error") {
|
||||
const toolCall = block as ToolCall;
|
||||
output.push({
|
||||
type: "function_call",
|
||||
id: toolCall.id.split("|")[1], // Extract original ID
|
||||
call_id: toolCall.id.split("|")[0], // Extract call session ID
|
||||
name: toolCall.name,
|
||||
arguments: JSON.stringify(toolCall.arguments),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Add all output items to input
|
||||
input.push(...output);
|
||||
} else if (msg.role === "toolResult") {
|
||||
// Tool results are sent as function_call_output
|
||||
input.push({
|
||||
type: "function_call_output",
|
||||
call_id: msg.toolCallId.split("|")[0], // Extract call session ID
|
||||
output: msg.content,
|
||||
if (model.name.startsWith("gpt-5")) {
|
||||
// Jesus Christ, see https://community.openai.com/t/need-reasoning-false-option-for-gpt-5/1351588/7
|
||||
messages.push({
|
||||
role: "developer",
|
||||
content: [
|
||||
{
|
||||
type: "input_text",
|
||||
text: "# Juice: 0 !important",
|
||||
},
|
||||
],
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return input;
|
||||
}
|
||||
|
||||
private convertTools(tools: Tool[]): OpenAITool[] {
|
||||
return tools.map((tool) => ({
|
||||
type: "function",
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
parameters: tool.parameters,
|
||||
strict: null,
|
||||
}));
|
||||
return params;
|
||||
}
|
||||
|
||||
function convertMessages(model: Model<"openai-responses">, context: Context): ResponseInput {
|
||||
const messages: ResponseInput = [];
|
||||
|
||||
const transformedMessages = transformMessages(context.messages, model);
|
||||
|
||||
if (context.systemPrompt) {
|
||||
const role = model.reasoning ? "developer" : "system";
|
||||
messages.push({
|
||||
role,
|
||||
content: context.systemPrompt,
|
||||
});
|
||||
}
|
||||
|
||||
private mapStopReason(status: string | undefined): StopReason {
|
||||
switch (status) {
|
||||
case "completed":
|
||||
return "stop";
|
||||
case "incomplete":
|
||||
return "length";
|
||||
case "failed":
|
||||
case "cancelled":
|
||||
return "error";
|
||||
default:
|
||||
return "stop";
|
||||
for (const msg of transformedMessages) {
|
||||
if (msg.role === "user") {
|
||||
if (typeof msg.content === "string") {
|
||||
messages.push({
|
||||
role: "user",
|
||||
content: [{ type: "input_text", text: msg.content }],
|
||||
});
|
||||
} else {
|
||||
const content: ResponseInputContent[] = msg.content.map((item): ResponseInputContent => {
|
||||
if (item.type === "text") {
|
||||
return {
|
||||
type: "input_text",
|
||||
text: item.text,
|
||||
} satisfies ResponseInputText;
|
||||
} else {
|
||||
return {
|
||||
type: "input_image",
|
||||
detail: "auto",
|
||||
image_url: `data:${item.mimeType};base64,${item.data}`,
|
||||
} satisfies ResponseInputImage;
|
||||
}
|
||||
});
|
||||
const filteredContent = !model.input.includes("image")
|
||||
? content.filter((c) => c.type !== "input_image")
|
||||
: content;
|
||||
if (filteredContent.length === 0) continue;
|
||||
messages.push({
|
||||
role: "user",
|
||||
content: filteredContent,
|
||||
});
|
||||
}
|
||||
} else if (msg.role === "assistant") {
|
||||
const output: ResponseInput = [];
|
||||
|
||||
for (const block of msg.content) {
|
||||
// Do not submit thinking blocks if the completion had an error (i.e. abort)
|
||||
if (block.type === "thinking" && msg.stopReason !== "error") {
|
||||
if (block.thinkingSignature) {
|
||||
const reasoningItem = JSON.parse(block.thinkingSignature);
|
||||
output.push(reasoningItem);
|
||||
}
|
||||
} else if (block.type === "text") {
|
||||
const textBlock = block as TextContent;
|
||||
output.push({
|
||||
type: "message",
|
||||
role: "assistant",
|
||||
content: [{ type: "output_text", text: textBlock.text, annotations: [] }],
|
||||
status: "completed",
|
||||
id: textBlock.textSignature || "msg_" + Math.random().toString(36).substring(2, 15),
|
||||
} satisfies ResponseOutputMessage);
|
||||
// Do not submit toolcall blocks if the completion had an error (i.e. abort)
|
||||
} else if (block.type === "toolCall" && msg.stopReason !== "error") {
|
||||
const toolCall = block as ToolCall;
|
||||
output.push({
|
||||
type: "function_call",
|
||||
id: toolCall.id.split("|")[1],
|
||||
call_id: toolCall.id.split("|")[0],
|
||||
name: toolCall.name,
|
||||
arguments: JSON.stringify(toolCall.arguments),
|
||||
});
|
||||
}
|
||||
}
|
||||
if (output.length === 0) continue;
|
||||
messages.push(...output);
|
||||
} else if (msg.role === "toolResult") {
|
||||
messages.push({
|
||||
type: "function_call_output",
|
||||
call_id: msg.toolCallId.split("|")[0],
|
||||
output: msg.content,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return messages;
|
||||
}
|
||||
|
||||
/**
 * Converts tool definitions into Responses API `function` tool entries.
 *
 * Each tool's `name`, `description` and `parameters` are copied to the top
 * level of the entry and `strict` is set to `null`; input order is preserved.
 */
function convertTools(tools: Tool[]): OpenAITool[] {
	const converted: OpenAITool[] = [];
	for (const { name, description, parameters } of tools) {
		converted.push({ type: "function", name, description, parameters, strict: null });
	}
	return converted;
}
|
||||
|
||||
/**
 * Maps a Responses API response status to the library's `StopReason`.
 *
 * - missing/falsy status, `"completed"`, `"in_progress"`, `"queued"` -> `"stop"`
 * - `"incomplete"` -> `"length"`
 * - `"failed"` / `"cancelled"` -> `"error"`
 *
 * @param status The status reported on the response, if any.
 * @returns The corresponding `StopReason`.
 * @throws Error if `status` is a value not covered above.
 */
function mapStopReason(status: OpenAI.Responses.ResponseStatus | undefined): StopReason {
	if (!status) {
		return "stop";
	}
	if (status === "failed" || status === "cancelled") {
		return "error";
	}
	if (status === "incomplete") {
		return "length";
	}
	// "in_progress" and "queued" are not terminal states, but are reported as a normal stop.
	if (status === "completed" || status === "in_progress" || status === "queued") {
		return "stop";
	}
	// Compile-time exhaustiveness check: fails to type-check if the SDK adds a new status.
	const _exhaustive: never = status;
	throw new Error(`Unhandled stop reason: ${_exhaustive}`);
}
|
||||
|
|
|
|||
|
|
@ -1,18 +1,6 @@
|
|||
import type { AssistantMessage, Message, Model } from "../types.js";
|
||||
import type { Api, AssistantMessage, Message, Model } from "../types.js";
|
||||
|
||||
/**
|
||||
* Transform messages for cross-provider compatibility.
|
||||
*
|
||||
* - User and toolResult messages are copied verbatim
|
||||
* - Assistant messages:
|
||||
* - If from the same provider/model, copied as-is
|
||||
* - If from different provider/model, thinking blocks are converted to text blocks with <thinking></thinking> tags
|
||||
*
|
||||
* @param messages The messages to transform
|
||||
* @param model The target model that will process these messages
|
||||
* @returns A copy of the messages array with transformations applied
|
||||
*/
|
||||
export function transformMessages(messages: Message[], model: Model, api: string): Message[] {
|
||||
export function transformMessages<TApi extends Api>(messages: Message[], model: Model<TApi>): Message[] {
|
||||
return messages.map((msg) => {
|
||||
// User and toolResult messages pass through unchanged
|
||||
if (msg.role === "user" || msg.role === "toolResult") {
|
||||
|
|
@ -24,7 +12,7 @@ export function transformMessages(messages: Message[], model: Model, api: string
|
|||
const assistantMsg = msg as AssistantMessage;
|
||||
|
||||
// If message is from the same provider and API, keep as is
|
||||
if (assistantMsg.provider === model.provider && assistantMsg.api === api) {
|
||||
if (assistantMsg.provider === model.provider && assistantMsg.api === model.api) {
|
||||
return msg;
|
||||
}
|
||||
|
||||
|
|
@ -47,8 +35,6 @@ export function transformMessages(messages: Message[], model: Model, api: string
|
|||
content: transformedContent,
|
||||
};
|
||||
}
|
||||
|
||||
// Should not reach here, but return as-is for safety
|
||||
return msg;
|
||||
});
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,27 @@
|
|||
export type KnownApi = "openai-completions" | "openai-responses" | "anthropic-messages" | "google-generative-ai";
|
||||
export type Api = KnownApi | string;
|
||||
import type { AnthropicOptions } from "./providers/anthropic";
|
||||
import type { GoogleOptions } from "./providers/google";
|
||||
import type { OpenAICompletionsOptions } from "./providers/openai-completions";
|
||||
import type { OpenAIResponsesOptions } from "./providers/openai-responses";
|
||||
|
||||
export type Api = "openai-completions" | "openai-responses" | "anthropic-messages" | "google-generative-ai";
|
||||
|
||||
/**
 * Associates each API identifier with the options type of the provider module that
 * implements it. Indexed by `OptionsForApi` to look up the options for a given API.
 */
export interface ApiOptionsMap {
	"anthropic-messages": AnthropicOptions;
	"openai-completions": OpenAICompletionsOptions;
	"openai-responses": OpenAIResponsesOptions;
	"google-generative-ai": GoogleOptions;
}
|
||||
|
||||
// Compile-time exhaustiveness check: resolves to `true` only when ApiOptionsMap and
// Record<Api, GenerateOptions> are assignable to each other, i.e. every member of the
// Api union has an options entry. Otherwise it resolves to a descriptive tuple, and the
// assignment of `true` below fails to type-check, surfacing that tuple in the error.
type _CheckExhaustive = ApiOptionsMap extends Record<Api, GenerateOptions>
	? Record<Api, GenerateOptions> extends ApiOptionsMap
		? true
		: ["ApiOptionsMap is missing some Api values", Exclude<Api, keyof ApiOptionsMap>]
	: ["ApiOptionsMap doesn't extend Record<Api, GenerateOptions>"];
const _exhaustive: _CheckExhaustive = true;
|
||||
|
||||
// Helper type to get options for a specific API
|
||||
export type OptionsForApi<TApi extends Api> = ApiOptionsMap[TApi];
|
||||
|
||||
export type KnownProvider = "anthropic" | "google" | "openai" | "xai" | "groq" | "cerebras" | "openrouter";
|
||||
export type Provider = KnownProvider | string;
|
||||
|
|
@ -21,31 +43,17 @@ export interface GenerateOptions {
|
|||
}
|
||||
|
||||
// Unified options with reasoning (what public generate() accepts)
|
||||
export interface GenerateOptionsUnified extends GenerateOptions {
|
||||
export interface SimpleGenerateOptions extends GenerateOptions {
|
||||
reasoning?: ReasoningEffort;
|
||||
}
|
||||
|
||||
// Generic GenerateFunction with typed options
|
||||
export type GenerateFunction<TOptions extends GenerateOptions = GenerateOptions> = (
|
||||
model: Model,
|
||||
export type GenerateFunction<TApi extends Api> = (
|
||||
model: Model<TApi>,
|
||||
context: Context,
|
||||
options: TOptions,
|
||||
options: OptionsForApi<TApi>,
|
||||
) => GenerateStream;
|
||||
|
||||
// Legacy LLM interface (to be removed)
|
||||
export interface LLMOptions {
|
||||
temperature?: number;
|
||||
maxTokens?: number;
|
||||
onEvent?: (event: AssistantMessageEvent) => void;
|
||||
signal?: AbortSignal;
|
||||
}
|
||||
|
||||
export interface LLM<T extends LLMOptions> {
|
||||
generate(request: Context, options?: T): Promise<AssistantMessage>;
|
||||
getModel(): Model;
|
||||
getApi(): string;
|
||||
}
|
||||
|
||||
export interface TextContent {
|
||||
type: "text";
|
||||
text: string;
|
||||
|
|
@ -100,7 +108,7 @@ export interface AssistantMessage {
|
|||
model: string;
|
||||
usage: Usage;
|
||||
stopReason: StopReason;
|
||||
error?: string | Error;
|
||||
error?: string;
|
||||
}
|
||||
|
||||
export interface ToolResultMessage {
|
||||
|
|
@ -138,10 +146,10 @@ export type AssistantMessageEvent =
|
|||
| { type: "error"; error: string; partial: AssistantMessage };
|
||||
|
||||
// Model interface for the unified model system
|
||||
export interface Model {
|
||||
export interface Model<TApi extends Api> {
|
||||
id: string;
|
||||
name: string;
|
||||
api: Api;
|
||||
api: TApi;
|
||||
provider: Provider;
|
||||
baseUrl: string;
|
||||
reasoning: boolean;
|
||||
|
|
|
|||
|
|
@ -1,128 +1,103 @@
|
|||
import { describe, it, beforeAll, expect } from "vitest";
|
||||
import { GoogleLLM } from "../src/providers/google.js";
|
||||
import { OpenAICompletionsLLM } from "../src/providers/openai-completions.js";
|
||||
import { OpenAIResponsesLLM } from "../src/providers/openai-responses.js";
|
||||
import { AnthropicLLM } from "../src/providers/anthropic.js";
|
||||
import type { LLM, LLMOptions, Context } from "../src/types.js";
|
||||
import { beforeAll, describe, expect, it } from "vitest";
|
||||
import { complete, stream } from "../src/generate.js";
|
||||
import { getModel } from "../src/models.js";
|
||||
import type { Api, Context, Model, OptionsForApi } from "../src/types.js";
|
||||
|
||||
async function testAbortSignal<T extends LLMOptions>(llm: LLM<T>, options: T = {} as T) {
|
||||
const context: Context = {
|
||||
messages: [{
|
||||
role: "user",
|
||||
content: "What is 15 + 27? Think step by step. Then list 50 first names."
|
||||
}]
|
||||
};
|
||||
async function testAbortSignal<TApi extends Api>(llm: Model<TApi>, options: OptionsForApi<TApi> = {}) {
|
||||
const context: Context = {
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: "What is 15 + 27? Think step by step. Then list 50 first names.",
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
let abortFired = false;
|
||||
const controller = new AbortController();
|
||||
const response = await llm.generate(context, {
|
||||
...options,
|
||||
signal: controller.signal,
|
||||
onEvent: (event) => {
|
||||
// console.log(JSON.stringify(event, null, 2));
|
||||
if (abortFired) return;
|
||||
setTimeout(() => controller.abort(), 2000);
|
||||
abortFired = true;
|
||||
}
|
||||
});
|
||||
let abortFired = false;
|
||||
const controller = new AbortController();
|
||||
const response = await stream(llm, context, { ...options, signal: controller.signal });
|
||||
for await (const event of response) {
|
||||
if (abortFired) return;
|
||||
setTimeout(() => controller.abort(), 3000);
|
||||
abortFired = true;
|
||||
break;
|
||||
}
|
||||
const msg = await response.finalMessage();
|
||||
|
||||
// If we get here without throwing, the abort didn't work
|
||||
expect(response.stopReason).toBe("error");
|
||||
expect(response.content.length).toBeGreaterThan(0);
|
||||
// If we get here without throwing, the abort didn't work
|
||||
expect(msg.stopReason).toBe("error");
|
||||
expect(msg.content.length).toBeGreaterThan(0);
|
||||
|
||||
context.messages.push(response);
|
||||
context.messages.push({ role: "user", content: "Please continue, but only generate 5 names." });
|
||||
context.messages.push(msg);
|
||||
context.messages.push({ role: "user", content: "Please continue, but only generate 5 names." });
|
||||
|
||||
// Ensure we can still make requests after abort
|
||||
const followUp = await llm.generate(context, options);
|
||||
expect(followUp.stopReason).toBe("stop");
|
||||
expect(followUp.content.length).toBeGreaterThan(0);
|
||||
const followUp = await complete(llm, context, options);
|
||||
expect(followUp.stopReason).toBe("stop");
|
||||
expect(followUp.content.length).toBeGreaterThan(0);
|
||||
}
|
||||
|
||||
async function testImmediateAbort<T extends LLMOptions>(llm: LLM<T>, options: T = {} as T) {
|
||||
const controller = new AbortController();
|
||||
async function testImmediateAbort<TApi extends Api>(llm: Model<TApi>, options: OptionsForApi<TApi> = {}) {
|
||||
const controller = new AbortController();
|
||||
|
||||
// Abort immediately
|
||||
controller.abort();
|
||||
controller.abort();
|
||||
|
||||
const context: Context = {
|
||||
messages: [{ role: "user", content: "Hello" }]
|
||||
};
|
||||
const context: Context = {
|
||||
messages: [{ role: "user", content: "Hello" }],
|
||||
};
|
||||
|
||||
const response = await llm.generate(context, {
|
||||
...options,
|
||||
signal: controller.signal
|
||||
});
|
||||
expect(response.stopReason).toBe("error");
|
||||
const response = await complete(llm, context, { ...options, signal: controller.signal });
|
||||
expect(response.stopReason).toBe("error");
|
||||
}
|
||||
|
||||
describe("AI Providers Abort Tests", () => {
|
||||
describe.skipIf(!process.env.GEMINI_API_KEY)("Google Provider Abort", () => {
|
||||
let llm: GoogleLLM;
|
||||
describe.skipIf(!process.env.GEMINI_API_KEY)("Google Provider Abort", () => {
|
||||
const llm = getModel("google", "gemini-2.5-flash");
|
||||
|
||||
beforeAll(() => {
|
||||
llm = new GoogleLLM(getModel("google", "gemini-2.5-flash")!, process.env.GEMINI_API_KEY!);
|
||||
});
|
||||
it("should abort mid-stream", async () => {
|
||||
await testAbortSignal(llm, { thinking: { enabled: true } });
|
||||
});
|
||||
|
||||
it("should abort mid-stream", async () => {
|
||||
await testAbortSignal(llm, { thinking: { enabled: true } });
|
||||
});
|
||||
it("should handle immediate abort", async () => {
|
||||
await testImmediateAbort(llm, { thinking: { enabled: true } });
|
||||
});
|
||||
});
|
||||
|
||||
it("should handle immediate abort", async () => {
|
||||
await testImmediateAbort(llm, { thinking: { enabled: true } });
|
||||
});
|
||||
});
|
||||
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Completions Provider Abort", () => {
|
||||
const llm: Model<"openai-completions"> = {
|
||||
...getModel("openai", "gpt-4o-mini")!,
|
||||
api: "openai-completions",
|
||||
};
|
||||
|
||||
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Completions Provider Abort", () => {
|
||||
let llm: OpenAICompletionsLLM;
|
||||
it("should abort mid-stream", async () => {
|
||||
await testAbortSignal(llm);
|
||||
});
|
||||
|
||||
beforeAll(() => {
|
||||
llm = new OpenAICompletionsLLM(getModel("openai", "gpt-4o-mini")!, process.env.OPENAI_API_KEY!);
|
||||
});
|
||||
it("should handle immediate abort", async () => {
|
||||
await testImmediateAbort(llm);
|
||||
});
|
||||
});
|
||||
|
||||
it("should abort mid-stream", async () => {
|
||||
await testAbortSignal(llm);
|
||||
});
|
||||
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Responses Provider Abort", () => {
|
||||
const llm = getModel("openai", "gpt-5-mini");
|
||||
|
||||
it("should handle immediate abort", async () => {
|
||||
await testImmediateAbort(llm);
|
||||
});
|
||||
});
|
||||
it("should abort mid-stream", async () => {
|
||||
await testAbortSignal(llm);
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Responses Provider Abort", () => {
|
||||
let llm: OpenAIResponsesLLM;
|
||||
it("should handle immediate abort", async () => {
|
||||
await testImmediateAbort(llm);
|
||||
});
|
||||
});
|
||||
|
||||
beforeAll(() => {
|
||||
const model = getModel("openai", "gpt-5-mini");
|
||||
if (!model) {
|
||||
throw new Error("Model not found");
|
||||
}
|
||||
llm = new OpenAIResponsesLLM(model, process.env.OPENAI_API_KEY!);
|
||||
});
|
||||
describe.skipIf(!process.env.ANTHROPIC_OAUTH_TOKEN)("Anthropic Provider Abort", () => {
|
||||
const llm = getModel("anthropic", "claude-opus-4-1-20250805");
|
||||
|
||||
it("should abort mid-stream", async () => {
|
||||
await testAbortSignal(llm, {});
|
||||
});
|
||||
it("should abort mid-stream", async () => {
|
||||
await testAbortSignal(llm, { thinkingEnabled: true, thinkingBudgetTokens: 2048 });
|
||||
});
|
||||
|
||||
it("should handle immediate abort", async () => {
|
||||
await testImmediateAbort(llm, {});
|
||||
});
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.ANTHROPIC_OAUTH_TOKEN)("Anthropic Provider Abort", () => {
|
||||
let llm: AnthropicLLM;
|
||||
|
||||
beforeAll(() => {
|
||||
llm = new AnthropicLLM(getModel("anthropic", "claude-opus-4-1-20250805")!, process.env.ANTHROPIC_OAUTH_TOKEN!);
|
||||
});
|
||||
|
||||
it("should abort mid-stream", async () => {
|
||||
await testAbortSignal(llm, {thinking: { enabled: true, budgetTokens: 2048 }});
|
||||
});
|
||||
|
||||
it("should handle immediate abort", async () => {
|
||||
await testImmediateAbort(llm, {thinking: { enabled: true, budgetTokens: 2048 }});
|
||||
});
|
||||
});
|
||||
});
|
||||
it("should handle immediate abort", async () => {
|
||||
await testImmediateAbort(llm, { thinkingEnabled: true, thinkingBudgetTokens: 2048 });
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -1,313 +1,265 @@
|
|||
import { describe, it, beforeAll, expect } from "vitest";
|
||||
import { GoogleLLM } from "../src/providers/google.js";
|
||||
import { OpenAICompletionsLLM } from "../src/providers/openai-completions.js";
|
||||
import { OpenAIResponsesLLM } from "../src/providers/openai-responses.js";
|
||||
import { AnthropicLLM } from "../src/providers/anthropic.js";
|
||||
import type { LLM, LLMOptions, Context, UserMessage, AssistantMessage } from "../src/types.js";
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { complete } from "../src/generate.js";
|
||||
import { getModel } from "../src/models.js";
|
||||
import type { Api, AssistantMessage, Context, Model, OptionsForApi, UserMessage } from "../src/types.js";
|
||||
|
||||
async function testEmptyMessage<T extends LLMOptions>(llm: LLM<T>, options: T = {} as T) {
|
||||
// Test with completely empty content array
|
||||
const emptyMessage: UserMessage = {
|
||||
role: "user",
|
||||
content: []
|
||||
};
|
||||
async function testEmptyMessage<TApi extends Api>(llm: Model<TApi>, options: OptionsForApi<TApi> = {}) {
|
||||
// Test with completely empty content array
|
||||
const emptyMessage: UserMessage = {
|
||||
role: "user",
|
||||
content: [],
|
||||
};
|
||||
|
||||
const context: Context = {
|
||||
messages: [emptyMessage]
|
||||
};
|
||||
const context: Context = {
|
||||
messages: [emptyMessage],
|
||||
};
|
||||
|
||||
const response = await llm.generate(context, options);
|
||||
|
||||
// Should either handle gracefully or return an error
|
||||
expect(response).toBeDefined();
|
||||
expect(response.role).toBe("assistant");
|
||||
|
||||
// Most providers should return an error or empty response
|
||||
if (response.stopReason === "error") {
|
||||
expect(response.error).toBeDefined();
|
||||
} else {
|
||||
// If it didn't error, it should have some content or gracefully handle empty
|
||||
expect(response.content).toBeDefined();
|
||||
}
|
||||
const response = await complete(llm, context, options);
|
||||
|
||||
// Should either handle gracefully or return an error
|
||||
expect(response).toBeDefined();
|
||||
expect(response.role).toBe("assistant");
|
||||
// Should handle empty string gracefully
|
||||
if (response.stopReason === "error") {
|
||||
expect(response.error).toBeDefined();
|
||||
} else {
|
||||
expect(response.content).toBeDefined();
|
||||
}
|
||||
}
|
||||
|
||||
async function testEmptyStringMessage<T extends LLMOptions>(llm: LLM<T>, options: T = {} as T) {
|
||||
// Test with empty string content
|
||||
const context: Context = {
|
||||
messages: [{
|
||||
role: "user",
|
||||
content: ""
|
||||
}]
|
||||
};
|
||||
async function testEmptyStringMessage<TApi extends Api>(llm: Model<TApi>, options: OptionsForApi<TApi> = {}) {
|
||||
// Test with empty string content
|
||||
const context: Context = {
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: "",
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const response = await llm.generate(context, options);
|
||||
|
||||
expect(response).toBeDefined();
|
||||
expect(response.role).toBe("assistant");
|
||||
|
||||
// Should handle empty string gracefully
|
||||
if (response.stopReason === "error") {
|
||||
expect(response.error).toBeDefined();
|
||||
} else {
|
||||
expect(response.content).toBeDefined();
|
||||
}
|
||||
const response = await complete(llm, context, options);
|
||||
|
||||
expect(response).toBeDefined();
|
||||
expect(response.role).toBe("assistant");
|
||||
|
||||
// Should handle empty string gracefully
|
||||
if (response.stopReason === "error") {
|
||||
expect(response.error).toBeDefined();
|
||||
} else {
|
||||
expect(response.content).toBeDefined();
|
||||
}
|
||||
}
|
||||
|
||||
async function testWhitespaceOnlyMessage<T extends LLMOptions>(llm: LLM<T>, options: T = {} as T) {
|
||||
// Test with whitespace-only content
|
||||
const context: Context = {
|
||||
messages: [{
|
||||
role: "user",
|
||||
content: " \n\t "
|
||||
}]
|
||||
};
|
||||
async function testWhitespaceOnlyMessage<TApi extends Api>(llm: Model<TApi>, options: OptionsForApi<TApi> = {}) {
|
||||
// Test with whitespace-only content
|
||||
const context: Context = {
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: " \n\t ",
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const response = await llm.generate(context, options);
|
||||
|
||||
expect(response).toBeDefined();
|
||||
expect(response.role).toBe("assistant");
|
||||
|
||||
// Should handle whitespace-only gracefully
|
||||
if (response.stopReason === "error") {
|
||||
expect(response.error).toBeDefined();
|
||||
} else {
|
||||
expect(response.content).toBeDefined();
|
||||
}
|
||||
const response = await complete(llm, context, options);
|
||||
|
||||
expect(response).toBeDefined();
|
||||
expect(response.role).toBe("assistant");
|
||||
|
||||
// Should handle whitespace-only gracefully
|
||||
if (response.stopReason === "error") {
|
||||
expect(response.error).toBeDefined();
|
||||
} else {
|
||||
expect(response.content).toBeDefined();
|
||||
}
|
||||
}
|
||||
|
||||
async function testEmptyAssistantMessage<T extends LLMOptions>(llm: LLM<T>, options: T = {} as T) {
|
||||
// Test with empty assistant message in conversation flow
|
||||
// User -> Empty Assistant -> User
|
||||
const emptyAssistant: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: llm.getApi(),
|
||||
provider: llm.getModel().provider,
|
||||
model: llm.getModel().id,
|
||||
usage: {
|
||||
input: 10,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }
|
||||
},
|
||||
stopReason: "stop"
|
||||
};
|
||||
async function testEmptyAssistantMessage<TApi extends Api>(llm: Model<TApi>, options: OptionsForApi<TApi> = {}) {
|
||||
// Test with empty assistant message in conversation flow
|
||||
// User -> Empty Assistant -> User
|
||||
const emptyAssistant: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: llm.api,
|
||||
provider: llm.provider,
|
||||
model: llm.id,
|
||||
usage: {
|
||||
input: 10,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "stop",
|
||||
};
|
||||
|
||||
const context: Context = {
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: "Hello, how are you?"
|
||||
},
|
||||
emptyAssistant,
|
||||
{
|
||||
role: "user",
|
||||
content: "Please respond this time."
|
||||
}
|
||||
]
|
||||
};
|
||||
const context: Context = {
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: "Hello, how are you?",
|
||||
},
|
||||
emptyAssistant,
|
||||
{
|
||||
role: "user",
|
||||
content: "Please respond this time.",
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const response = await llm.generate(context, options);
|
||||
|
||||
expect(response).toBeDefined();
|
||||
expect(response.role).toBe("assistant");
|
||||
|
||||
// Should handle empty assistant message in context gracefully
|
||||
if (response.stopReason === "error") {
|
||||
expect(response.error).toBeDefined();
|
||||
} else {
|
||||
expect(response.content).toBeDefined();
|
||||
expect(response.content.length).toBeGreaterThan(0);
|
||||
}
|
||||
const response = await complete(llm, context, options);
|
||||
|
||||
expect(response).toBeDefined();
|
||||
expect(response.role).toBe("assistant");
|
||||
|
||||
// Should handle empty assistant message in context gracefully
|
||||
if (response.stopReason === "error") {
|
||||
expect(response.error).toBeDefined();
|
||||
} else {
|
||||
expect(response.content).toBeDefined();
|
||||
expect(response.content.length).toBeGreaterThan(0);
|
||||
}
|
||||
}
|
||||
|
||||
describe("AI Providers Empty Message Tests", () => {
|
||||
describe.skipIf(!process.env.GEMINI_API_KEY)("Google Provider Empty Messages", () => {
|
||||
let llm: GoogleLLM;
|
||||
describe.skipIf(!process.env.GEMINI_API_KEY)("Google Provider Empty Messages", () => {
|
||||
const llm = getModel("google", "gemini-2.5-flash");
|
||||
|
||||
beforeAll(() => {
|
||||
llm = new GoogleLLM(getModel("google", "gemini-2.5-flash")!, process.env.GEMINI_API_KEY!);
|
||||
});
|
||||
it("should handle empty content array", async () => {
|
||||
await testEmptyMessage(llm);
|
||||
});
|
||||
|
||||
it("should handle empty content array", async () => {
|
||||
await testEmptyMessage(llm);
|
||||
});
|
||||
it("should handle empty string content", async () => {
|
||||
await testEmptyStringMessage(llm);
|
||||
});
|
||||
|
||||
it("should handle empty string content", async () => {
|
||||
await testEmptyStringMessage(llm);
|
||||
});
|
||||
it("should handle whitespace-only content", async () => {
|
||||
await testWhitespaceOnlyMessage(llm);
|
||||
});
|
||||
|
||||
it("should handle whitespace-only content", async () => {
|
||||
await testWhitespaceOnlyMessage(llm);
|
||||
});
|
||||
it("should handle empty assistant message in conversation", async () => {
|
||||
await testEmptyAssistantMessage(llm);
|
||||
});
|
||||
});
|
||||
|
||||
it("should handle empty assistant message in conversation", async () => {
|
||||
await testEmptyAssistantMessage(llm);
|
||||
});
|
||||
});
|
||||
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Completions Provider Empty Messages", () => {
|
||||
const llm = getModel("openai", "gpt-4o-mini");
|
||||
|
||||
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Completions Provider Empty Messages", () => {
|
||||
let llm: OpenAICompletionsLLM;
|
||||
it("should handle empty content array", async () => {
|
||||
await testEmptyMessage(llm);
|
||||
});
|
||||
|
||||
beforeAll(() => {
|
||||
llm = new OpenAICompletionsLLM(getModel("openai", "gpt-4o-mini")!, process.env.OPENAI_API_KEY!);
|
||||
});
|
||||
it("should handle empty string content", async () => {
|
||||
await testEmptyStringMessage(llm);
|
||||
});
|
||||
|
||||
it("should handle empty content array", async () => {
|
||||
await testEmptyMessage(llm);
|
||||
});
|
||||
it("should handle whitespace-only content", async () => {
|
||||
await testWhitespaceOnlyMessage(llm);
|
||||
});
|
||||
|
||||
it("should handle empty string content", async () => {
|
||||
await testEmptyStringMessage(llm);
|
||||
});
|
||||
it("should handle empty assistant message in conversation", async () => {
|
||||
await testEmptyAssistantMessage(llm);
|
||||
});
|
||||
});
|
||||
|
||||
it("should handle whitespace-only content", async () => {
|
||||
await testWhitespaceOnlyMessage(llm);
|
||||
});
|
||||
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Responses Provider Empty Messages", () => {
|
||||
const llm = getModel("openai", "gpt-5-mini");
|
||||
|
||||
it("should handle empty assistant message in conversation", async () => {
|
||||
await testEmptyAssistantMessage(llm);
|
||||
});
|
||||
});
|
||||
it("should handle empty content array", async () => {
|
||||
await testEmptyMessage(llm);
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Responses Provider Empty Messages", () => {
|
||||
let llm: OpenAIResponsesLLM;
|
||||
it("should handle empty string content", async () => {
|
||||
await testEmptyStringMessage(llm);
|
||||
});
|
||||
|
||||
beforeAll(() => {
|
||||
const model = getModel("openai", "gpt-5-mini");
|
||||
if (!model) {
|
||||
throw new Error("Model gpt-5-mini not found");
|
||||
}
|
||||
llm = new OpenAIResponsesLLM(model, process.env.OPENAI_API_KEY!);
|
||||
});
|
||||
it("should handle whitespace-only content", async () => {
|
||||
await testWhitespaceOnlyMessage(llm);
|
||||
});
|
||||
|
||||
it("should handle empty content array", async () => {
|
||||
await testEmptyMessage(llm);
|
||||
});
|
||||
it("should handle empty assistant message in conversation", async () => {
|
||||
await testEmptyAssistantMessage(llm);
|
||||
});
|
||||
});
|
||||
|
||||
it("should handle empty string content", async () => {
|
||||
await testEmptyStringMessage(llm);
|
||||
});
|
||||
describe.skipIf(!process.env.ANTHROPIC_OAUTH_TOKEN)("Anthropic Provider Empty Messages", () => {
|
||||
const llm = getModel("anthropic", "claude-3-5-haiku-20241022");
|
||||
|
||||
it("should handle whitespace-only content", async () => {
|
||||
await testWhitespaceOnlyMessage(llm);
|
||||
});
|
||||
it("should handle empty content array", async () => {
|
||||
await testEmptyMessage(llm);
|
||||
});
|
||||
|
||||
it("should handle empty assistant message in conversation", async () => {
|
||||
await testEmptyAssistantMessage(llm);
|
||||
});
|
||||
});
|
||||
it("should handle empty string content", async () => {
|
||||
await testEmptyStringMessage(llm);
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.ANTHROPIC_OAUTH_TOKEN)("Anthropic Provider Empty Messages", () => {
|
||||
let llm: AnthropicLLM;
|
||||
it("should handle whitespace-only content", async () => {
|
||||
await testWhitespaceOnlyMessage(llm);
|
||||
});
|
||||
|
||||
beforeAll(() => {
|
||||
llm = new AnthropicLLM(getModel("anthropic", "claude-3-5-haiku-20241022")!, process.env.ANTHROPIC_OAUTH_TOKEN!);
|
||||
});
|
||||
it("should handle empty assistant message in conversation", async () => {
|
||||
await testEmptyAssistantMessage(llm);
|
||||
});
|
||||
});
|
||||
|
||||
it("should handle empty content array", async () => {
|
||||
await testEmptyMessage(llm);
|
||||
});
|
||||
describe.skipIf(!process.env.XAI_API_KEY)("xAI Provider Empty Messages", () => {
|
||||
const llm = getModel("xai", "grok-3");
|
||||
|
||||
it("should handle empty string content", async () => {
|
||||
await testEmptyStringMessage(llm);
|
||||
});
|
||||
it("should handle empty content array", async () => {
|
||||
await testEmptyMessage(llm);
|
||||
});
|
||||
|
||||
it("should handle whitespace-only content", async () => {
|
||||
await testWhitespaceOnlyMessage(llm);
|
||||
});
|
||||
it("should handle empty string content", async () => {
|
||||
await testEmptyStringMessage(llm);
|
||||
});
|
||||
|
||||
it("should handle empty assistant message in conversation", async () => {
|
||||
await testEmptyAssistantMessage(llm);
|
||||
});
|
||||
});
|
||||
it("should handle whitespace-only content", async () => {
|
||||
await testWhitespaceOnlyMessage(llm);
|
||||
});
|
||||
|
||||
// Test with xAI/Grok if available
|
||||
describe.skipIf(!process.env.XAI_API_KEY)("xAI Provider Empty Messages", () => {
|
||||
let llm: OpenAICompletionsLLM;
|
||||
it("should handle empty assistant message in conversation", async () => {
|
||||
await testEmptyAssistantMessage(llm);
|
||||
});
|
||||
});
|
||||
|
||||
beforeAll(() => {
|
||||
const model = getModel("xai", "grok-3");
|
||||
if (!model) {
|
||||
throw new Error("Model grok-3 not found");
|
||||
}
|
||||
llm = new OpenAICompletionsLLM(model, process.env.XAI_API_KEY!);
|
||||
});
|
||||
describe.skipIf(!process.env.GROQ_API_KEY)("Groq Provider Empty Messages", () => {
|
||||
const llm = getModel("groq", "openai/gpt-oss-20b");
|
||||
|
||||
it("should handle empty content array", async () => {
|
||||
await testEmptyMessage(llm);
|
||||
});
|
||||
it("should handle empty content array", async () => {
|
||||
await testEmptyMessage(llm);
|
||||
});
|
||||
|
||||
it("should handle empty string content", async () => {
|
||||
await testEmptyStringMessage(llm);
|
||||
});
|
||||
it("should handle empty string content", async () => {
|
||||
await testEmptyStringMessage(llm);
|
||||
});
|
||||
|
||||
it("should handle whitespace-only content", async () => {
|
||||
await testWhitespaceOnlyMessage(llm);
|
||||
});
|
||||
it("should handle whitespace-only content", async () => {
|
||||
await testWhitespaceOnlyMessage(llm);
|
||||
});
|
||||
|
||||
it("should handle empty assistant message in conversation", async () => {
|
||||
await testEmptyAssistantMessage(llm);
|
||||
});
|
||||
});
|
||||
it("should handle empty assistant message in conversation", async () => {
|
||||
await testEmptyAssistantMessage(llm);
|
||||
});
|
||||
});
|
||||
|
||||
// Test with Groq if available
|
||||
describe.skipIf(!process.env.GROQ_API_KEY)("Groq Provider Empty Messages", () => {
|
||||
let llm: OpenAICompletionsLLM;
|
||||
describe.skipIf(!process.env.CEREBRAS_API_KEY)("Cerebras Provider Empty Messages", () => {
|
||||
const llm = getModel("cerebras", "gpt-oss-120b");
|
||||
|
||||
beforeAll(() => {
|
||||
const model = getModel("groq", "llama-3.3-70b-versatile");
|
||||
if (!model) {
|
||||
throw new Error("Model llama-3.3-70b-versatile not found");
|
||||
}
|
||||
llm = new OpenAICompletionsLLM(model, process.env.GROQ_API_KEY!);
|
||||
});
|
||||
it("should handle empty content array", async () => {
|
||||
await testEmptyMessage(llm);
|
||||
});
|
||||
|
||||
it("should handle empty content array", async () => {
|
||||
await testEmptyMessage(llm);
|
||||
});
|
||||
it("should handle empty string content", async () => {
|
||||
await testEmptyStringMessage(llm);
|
||||
});
|
||||
|
||||
it("should handle empty string content", async () => {
|
||||
await testEmptyStringMessage(llm);
|
||||
});
|
||||
it("should handle whitespace-only content", async () => {
|
||||
await testWhitespaceOnlyMessage(llm);
|
||||
});
|
||||
|
||||
it("should handle whitespace-only content", async () => {
|
||||
await testWhitespaceOnlyMessage(llm);
|
||||
});
|
||||
|
||||
it("should handle empty assistant message in conversation", async () => {
|
||||
await testEmptyAssistantMessage(llm);
|
||||
});
|
||||
});
|
||||
|
||||
// Test with Cerebras if available
|
||||
describe.skipIf(!process.env.CEREBRAS_API_KEY)("Cerebras Provider Empty Messages", () => {
|
||||
let llm: OpenAICompletionsLLM;
|
||||
|
||||
beforeAll(() => {
|
||||
const model = getModel("cerebras", "gpt-oss-120b");
|
||||
if (!model) {
|
||||
throw new Error("Model gpt-oss-120b not found");
|
||||
}
|
||||
llm = new OpenAICompletionsLLM(model, process.env.CEREBRAS_API_KEY!);
|
||||
});
|
||||
|
||||
it("should handle empty content array", async () => {
|
||||
await testEmptyMessage(llm);
|
||||
});
|
||||
|
||||
it("should handle empty string content", async () => {
|
||||
await testEmptyStringMessage(llm);
|
||||
});
|
||||
|
||||
it("should handle whitespace-only content", async () => {
|
||||
await testWhitespaceOnlyMessage(llm);
|
||||
});
|
||||
|
||||
it("should handle empty assistant message in conversation", async () => {
|
||||
await testEmptyAssistantMessage(llm);
|
||||
});
|
||||
});
|
||||
});
|
||||
it("should handle empty assistant message in conversation", async () => {
|
||||
await testEmptyAssistantMessage(llm);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -1,311 +1,612 @@
|
|||
import { describe, it, beforeAll, expect } from "vitest";
|
||||
import { getModel } from "../src/models.js";
|
||||
import { generate, generateComplete } from "../src/generate.js";
|
||||
import type { Context, Tool, GenerateOptionsUnified, Model, ImageContent, GenerateStream, GenerateOptions } from "../src/types.js";
|
||||
import { type ChildProcess, execSync, spawn } from "child_process";
|
||||
import { readFileSync } from "fs";
|
||||
import { join, dirname } from "path";
|
||||
import { dirname, join } from "path";
|
||||
import { fileURLToPath } from "url";
|
||||
import { afterAll, beforeAll, describe, expect, it } from "vitest";
|
||||
import { complete, stream } from "../src/generate.js";
|
||||
import { getModel } from "../src/models.js";
|
||||
import type { Api, Context, ImageContent, Model, OptionsForApi, Tool } from "../src/types.js";
|
||||
|
||||
const __filename = fileURLToPath(import.meta.url);
|
||||
const __dirname = dirname(__filename);
|
||||
|
||||
// Calculator tool definition (same as examples)
|
||||
const calculatorTool: Tool = {
|
||||
name: "calculator",
|
||||
description: "Perform basic arithmetic operations",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
a: { type: "number", description: "First number" },
|
||||
b: { type: "number", description: "Second number" },
|
||||
operation: {
|
||||
type: "string",
|
||||
enum: ["add", "subtract", "multiply", "divide"],
|
||||
description: "The operation to perform"
|
||||
}
|
||||
},
|
||||
required: ["a", "b", "operation"]
|
||||
}
|
||||
name: "calculator",
|
||||
description: "Perform basic arithmetic operations",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
a: { type: "number", description: "First number" },
|
||||
b: { type: "number", description: "Second number" },
|
||||
operation: {
|
||||
type: "string",
|
||||
enum: ["add", "subtract", "multiply", "divide"],
|
||||
description: "The operation to perform",
|
||||
},
|
||||
},
|
||||
required: ["a", "b", "operation"],
|
||||
},
|
||||
};
|
||||
|
||||
async function basicTextGeneration<P extends GenerateOptions>(model: Model, options?: P) {
|
||||
const context: Context = {
|
||||
systemPrompt: "You are a helpful assistant. Be concise.",
|
||||
messages: [
|
||||
{ role: "user", content: "Reply with exactly: 'Hello test successful'" }
|
||||
]
|
||||
};
|
||||
async function basicTextGeneration<TApi extends Api>(model: Model<TApi>, options?: OptionsForApi<TApi>) {
|
||||
const context: Context = {
|
||||
systemPrompt: "You are a helpful assistant. Be concise.",
|
||||
messages: [{ role: "user", content: "Reply with exactly: 'Hello test successful'" }],
|
||||
};
|
||||
const response = await complete(model, context, options);
|
||||
|
||||
const response = await generateComplete(model, context, options);
|
||||
expect(response.role).toBe("assistant");
|
||||
expect(response.content).toBeTruthy();
|
||||
expect(response.usage.input + response.usage.cacheRead).toBeGreaterThan(0);
|
||||
expect(response.usage.output).toBeGreaterThan(0);
|
||||
expect(response.error).toBeFalsy();
|
||||
expect(response.content.map((b) => (b.type === "text" ? b.text : "")).join("")).toContain("Hello test successful");
|
||||
|
||||
expect(response.role).toBe("assistant");
|
||||
expect(response.content).toBeTruthy();
|
||||
expect(response.usage.input + response.usage.cacheRead).toBeGreaterThan(0);
|
||||
expect(response.usage.output).toBeGreaterThan(0);
|
||||
expect(response.error).toBeFalsy();
|
||||
expect(response.content.map(b => b.type == "text" ? b.text : "").join("")).toContain("Hello test successful");
|
||||
context.messages.push(response);
|
||||
context.messages.push({ role: "user", content: "Now say 'Goodbye test successful'" });
|
||||
|
||||
context.messages.push(response);
|
||||
context.messages.push({ role: "user", content: "Now say 'Goodbye test successful'" });
|
||||
const secondResponse = await complete(model, context, options);
|
||||
|
||||
const secondResponse = await generateComplete(model, context, options);
|
||||
|
||||
expect(secondResponse.role).toBe("assistant");
|
||||
expect(secondResponse.content).toBeTruthy();
|
||||
expect(secondResponse.usage.input + secondResponse.usage.cacheRead).toBeGreaterThan(0);
|
||||
expect(secondResponse.usage.output).toBeGreaterThan(0);
|
||||
expect(secondResponse.error).toBeFalsy();
|
||||
expect(secondResponse.content.map(b => b.type == "text" ? b.text : "").join("")).toContain("Goodbye test successful");
|
||||
expect(secondResponse.role).toBe("assistant");
|
||||
expect(secondResponse.content).toBeTruthy();
|
||||
expect(secondResponse.usage.input + secondResponse.usage.cacheRead).toBeGreaterThan(0);
|
||||
expect(secondResponse.usage.output).toBeGreaterThan(0);
|
||||
expect(secondResponse.error).toBeFalsy();
|
||||
expect(secondResponse.content.map((b) => (b.type === "text" ? b.text : "")).join("")).toContain(
|
||||
"Goodbye test successful",
|
||||
);
|
||||
}
|
||||
|
||||
async function handleToolCall(model: Model, options?: GenerateOptionsUnified) {
|
||||
const context: Context = {
|
||||
systemPrompt: "You are a helpful assistant that uses tools when asked.",
|
||||
messages: [{
|
||||
role: "user",
|
||||
content: "Calculate 15 + 27 using the calculator tool."
|
||||
}],
|
||||
tools: [calculatorTool]
|
||||
};
|
||||
async function handleToolCall<TApi extends Api>(model: Model<TApi>, options?: OptionsForApi<TApi>) {
|
||||
const context: Context = {
|
||||
systemPrompt: "You are a helpful assistant that uses tools when asked.",
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: "Calculate 15 + 27 using the calculator tool.",
|
||||
},
|
||||
],
|
||||
tools: [calculatorTool],
|
||||
};
|
||||
|
||||
const response = await generateComplete(model, context, options);
|
||||
expect(response.stopReason).toBe("toolUse");
|
||||
expect(response.content.some(b => b.type == "toolCall")).toBeTruthy();
|
||||
const toolCall = response.content.find(b => b.type == "toolCall");
|
||||
if (toolCall && toolCall.type === "toolCall") {
|
||||
expect(toolCall.name).toBe("calculator");
|
||||
expect(toolCall.id).toBeTruthy();
|
||||
}
|
||||
const response = await complete(model, context, options);
|
||||
expect(response.stopReason).toBe("toolUse");
|
||||
expect(response.content.some((b) => b.type === "toolCall")).toBeTruthy();
|
||||
const toolCall = response.content.find((b) => b.type === "toolCall");
|
||||
if (toolCall && toolCall.type === "toolCall") {
|
||||
expect(toolCall.name).toBe("calculator");
|
||||
expect(toolCall.id).toBeTruthy();
|
||||
}
|
||||
}
|
||||
|
||||
async function handleStreaming(model: Model, options?: GenerateOptionsUnified) {
|
||||
let textStarted = false;
|
||||
let textChunks = "";
|
||||
let textCompleted = false;
|
||||
async function handleStreaming<TApi extends Api>(model: Model<TApi>, options?: OptionsForApi<TApi>) {
|
||||
let textStarted = false;
|
||||
let textChunks = "";
|
||||
let textCompleted = false;
|
||||
|
||||
const context: Context = {
|
||||
messages: [{ role: "user", content: "Count from 1 to 3" }]
|
||||
};
|
||||
const context: Context = {
|
||||
messages: [{ role: "user", content: "Count from 1 to 3" }],
|
||||
};
|
||||
|
||||
const stream = generate(model, context, options);
|
||||
const s = stream(model, context, options);
|
||||
|
||||
for await (const event of stream) {
|
||||
if (event.type === "text_start") {
|
||||
textStarted = true;
|
||||
} else if (event.type === "text_delta") {
|
||||
textChunks += event.delta;
|
||||
} else if (event.type === "text_end") {
|
||||
textCompleted = true;
|
||||
}
|
||||
}
|
||||
for await (const event of s) {
|
||||
if (event.type === "text_start") {
|
||||
textStarted = true;
|
||||
} else if (event.type === "text_delta") {
|
||||
textChunks += event.delta;
|
||||
} else if (event.type === "text_end") {
|
||||
textCompleted = true;
|
||||
}
|
||||
}
|
||||
|
||||
const response = await stream.finalMessage();
|
||||
const response = await s.finalMessage();
|
||||
|
||||
expect(textStarted).toBe(true);
|
||||
expect(textChunks.length).toBeGreaterThan(0);
|
||||
expect(textCompleted).toBe(true);
|
||||
expect(response.content.some(b => b.type == "text")).toBeTruthy();
|
||||
expect(textStarted).toBe(true);
|
||||
expect(textChunks.length).toBeGreaterThan(0);
|
||||
expect(textCompleted).toBe(true);
|
||||
expect(response.content.some((b) => b.type === "text")).toBeTruthy();
|
||||
}
|
||||
|
||||
async function handleThinking(model: Model, options: GenerateOptionsUnified) {
|
||||
let thinkingStarted = false;
|
||||
let thinkingChunks = "";
|
||||
let thinkingCompleted = false;
|
||||
async function handleThinking<TApi extends Api>(model: Model<TApi>, options?: OptionsForApi<TApi>) {
|
||||
let thinkingStarted = false;
|
||||
let thinkingChunks = "";
|
||||
let thinkingCompleted = false;
|
||||
|
||||
const context: Context = {
|
||||
messages: [{ role: "user", content: `Think about ${(Math.random() * 255) | 0} + 27. Think step by step. Then output the result.` }]
|
||||
};
|
||||
const context: Context = {
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: `Think about ${(Math.random() * 255) | 0} + 27. Think step by step. Then output the result.`,
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const stream = generate(model, context, options);
|
||||
const s = stream(model, context, options);
|
||||
|
||||
for await (const event of stream) {
|
||||
if (event.type === "thinking_start") {
|
||||
thinkingStarted = true;
|
||||
} else if (event.type === "thinking_delta") {
|
||||
thinkingChunks += event.delta;
|
||||
} else if (event.type === "thinking_end") {
|
||||
thinkingCompleted = true;
|
||||
}
|
||||
}
|
||||
for await (const event of s) {
|
||||
if (event.type === "thinking_start") {
|
||||
thinkingStarted = true;
|
||||
} else if (event.type === "thinking_delta") {
|
||||
thinkingChunks += event.delta;
|
||||
} else if (event.type === "thinking_end") {
|
||||
thinkingCompleted = true;
|
||||
}
|
||||
}
|
||||
|
||||
const response = await stream.finalMessage();
|
||||
const response = await s.finalMessage();
|
||||
|
||||
expect(response.stopReason, `Error: ${response.error}`).toBe("stop");
|
||||
expect(thinkingStarted).toBe(true);
|
||||
expect(thinkingChunks.length).toBeGreaterThan(0);
|
||||
expect(thinkingCompleted).toBe(true);
|
||||
expect(response.content.some(b => b.type == "thinking")).toBeTruthy();
|
||||
expect(response.stopReason, `Error: ${response.error}`).toBe("stop");
|
||||
expect(thinkingStarted).toBe(true);
|
||||
expect(thinkingChunks.length).toBeGreaterThan(0);
|
||||
expect(thinkingCompleted).toBe(true);
|
||||
expect(response.content.some((b) => b.type === "thinking")).toBeTruthy();
|
||||
}
|
||||
|
||||
async function handleImage(model: Model, options?: GenerateOptionsUnified) {
|
||||
// Check if the model supports images
|
||||
if (!model.input.includes("image")) {
|
||||
console.log(`Skipping image test - model ${model.id} doesn't support images`);
|
||||
return;
|
||||
}
|
||||
async function handleImage<TApi extends Api>(model: Model<TApi>, options?: OptionsForApi<TApi>) {
|
||||
// Check if the model supports images
|
||||
if (!model.input.includes("image")) {
|
||||
console.log(`Skipping image test - model ${model.id} doesn't support images`);
|
||||
return;
|
||||
}
|
||||
|
||||
// Read the test image
|
||||
const imagePath = join(__dirname, "data", "red-circle.png");
|
||||
const imageBuffer = readFileSync(imagePath);
|
||||
const base64Image = imageBuffer.toString("base64");
|
||||
// Read the test image
|
||||
const imagePath = join(__dirname, "data", "red-circle.png");
|
||||
const imageBuffer = readFileSync(imagePath);
|
||||
const base64Image = imageBuffer.toString("base64");
|
||||
|
||||
const imageContent: ImageContent = {
|
||||
type: "image",
|
||||
data: base64Image,
|
||||
mimeType: "image/png",
|
||||
};
|
||||
const imageContent: ImageContent = {
|
||||
type: "image",
|
||||
data: base64Image,
|
||||
mimeType: "image/png",
|
||||
};
|
||||
|
||||
const context: Context = {
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: [
|
||||
{ type: "text", text: "What do you see in this image? Please describe the shape (circle, rectangle, square, triangle, ...) and color (red, blue, green, ...)." },
|
||||
imageContent,
|
||||
],
|
||||
},
|
||||
],
|
||||
};
|
||||
const context: Context = {
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: "What do you see in this image? Please describe the shape (circle, rectangle, square, triangle, ...) and color (red, blue, green, ...).",
|
||||
},
|
||||
imageContent,
|
||||
],
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const response = await generateComplete(model, context, options);
|
||||
const response = await complete(model, context, options);
|
||||
|
||||
// Check the response mentions red and circle
|
||||
expect(response.content.length > 0).toBeTruthy();
|
||||
const textContent = response.content.find(b => b.type == "text");
|
||||
if (textContent && textContent.type === "text") {
|
||||
const lowerContent = textContent.text.toLowerCase();
|
||||
expect(lowerContent).toContain("red");
|
||||
expect(lowerContent).toContain("circle");
|
||||
}
|
||||
// Check the response mentions red and circle
|
||||
expect(response.content.length > 0).toBeTruthy();
|
||||
const textContent = response.content.find((b) => b.type === "text");
|
||||
if (textContent && textContent.type === "text") {
|
||||
const lowerContent = textContent.text.toLowerCase();
|
||||
expect(lowerContent).toContain("red");
|
||||
expect(lowerContent).toContain("circle");
|
||||
}
|
||||
}
|
||||
|
||||
async function multiTurn(model: Model, options?: GenerateOptionsUnified) {
|
||||
const context: Context = {
|
||||
systemPrompt: "You are a helpful assistant that can use tools to answer questions.",
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: "Think about this briefly, then calculate 42 * 17 and 453 + 434 using the calculator tool."
|
||||
}
|
||||
],
|
||||
tools: [calculatorTool]
|
||||
};
|
||||
async function multiTurn<TApi extends Api>(model: Model<TApi>, options?: OptionsForApi<TApi>) {
|
||||
const context: Context = {
|
||||
systemPrompt: "You are a helpful assistant that can use tools to answer questions.",
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: "Think about this briefly, then calculate 42 * 17 and 453 + 434 using the calculator tool.",
|
||||
},
|
||||
],
|
||||
tools: [calculatorTool],
|
||||
};
|
||||
|
||||
// Collect all text content from all assistant responses
|
||||
let allTextContent = "";
|
||||
let hasSeenThinking = false;
|
||||
let hasSeenToolCalls = false;
|
||||
const maxTurns = 5; // Prevent infinite loops
|
||||
// Collect all text content from all assistant responses
|
||||
let allTextContent = "";
|
||||
let hasSeenThinking = false;
|
||||
let hasSeenToolCalls = false;
|
||||
const maxTurns = 5; // Prevent infinite loops
|
||||
|
||||
for (let turn = 0; turn < maxTurns; turn++) {
|
||||
const response = await generateComplete(model, context, options);
|
||||
for (let turn = 0; turn < maxTurns; turn++) {
|
||||
const response = await complete(model, context, options);
|
||||
|
||||
// Add the assistant response to context
|
||||
context.messages.push(response);
|
||||
// Add the assistant response to context
|
||||
context.messages.push(response);
|
||||
|
||||
// Process content blocks
|
||||
for (const block of response.content) {
|
||||
if (block.type === "text") {
|
||||
allTextContent += block.text;
|
||||
} else if (block.type === "thinking") {
|
||||
hasSeenThinking = true;
|
||||
} else if (block.type === "toolCall") {
|
||||
hasSeenToolCalls = true;
|
||||
// Process content blocks
|
||||
for (const block of response.content) {
|
||||
if (block.type === "text") {
|
||||
allTextContent += block.text;
|
||||
} else if (block.type === "thinking") {
|
||||
hasSeenThinking = true;
|
||||
} else if (block.type === "toolCall") {
|
||||
hasSeenToolCalls = true;
|
||||
|
||||
// Process the tool call
|
||||
expect(block.name).toBe("calculator");
|
||||
expect(block.id).toBeTruthy();
|
||||
expect(block.arguments).toBeTruthy();
|
||||
// Process the tool call
|
||||
expect(block.name).toBe("calculator");
|
||||
expect(block.id).toBeTruthy();
|
||||
expect(block.arguments).toBeTruthy();
|
||||
|
||||
const { a, b, operation } = block.arguments;
|
||||
let result: number;
|
||||
switch (operation) {
|
||||
case "add": result = a + b; break;
|
||||
case "multiply": result = a * b; break;
|
||||
default: result = 0;
|
||||
}
|
||||
const { a, b, operation } = block.arguments;
|
||||
let result: number;
|
||||
switch (operation) {
|
||||
case "add":
|
||||
result = a + b;
|
||||
break;
|
||||
case "multiply":
|
||||
result = a * b;
|
||||
break;
|
||||
default:
|
||||
result = 0;
|
||||
}
|
||||
|
||||
// Add tool result to context
|
||||
context.messages.push({
|
||||
role: "toolResult",
|
||||
toolCallId: block.id,
|
||||
toolName: block.name,
|
||||
content: `${result}`,
|
||||
isError: false
|
||||
});
|
||||
}
|
||||
}
|
||||
// Add tool result to context
|
||||
context.messages.push({
|
||||
role: "toolResult",
|
||||
toolCallId: block.id,
|
||||
toolName: block.name,
|
||||
content: `${result}`,
|
||||
isError: false,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// If we got a stop response with text content, we're likely done
|
||||
expect(response.stopReason).not.toBe("error");
|
||||
if (response.stopReason === "stop") {
|
||||
break;
|
||||
}
|
||||
}
|
||||
// If we got a stop response with text content, we're likely done
|
||||
expect(response.stopReason).not.toBe("error");
|
||||
if (response.stopReason === "stop") {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Verify we got either thinking content or tool calls (or both)
|
||||
expect(hasSeenThinking || hasSeenToolCalls).toBe(true);
|
||||
// Verify we got either thinking content or tool calls (or both)
|
||||
expect(hasSeenThinking || hasSeenToolCalls).toBe(true);
|
||||
|
||||
// The accumulated text should reference both calculations
|
||||
expect(allTextContent).toBeTruthy();
|
||||
expect(allTextContent.includes("714")).toBe(true);
|
||||
expect(allTextContent.includes("887")).toBe(true);
|
||||
// The accumulated text should reference both calculations
|
||||
expect(allTextContent).toBeTruthy();
|
||||
expect(allTextContent.includes("714")).toBe(true);
|
||||
expect(allTextContent.includes("887")).toBe(true);
|
||||
}
|
||||
|
||||
describe("Generate E2E Tests", () => {
|
||||
describe.skipIf(!process.env.ANTHROPIC_API_KEY)("Anthropic Provider (claude-3-5-haiku-20241022)", () => {
|
||||
let model: Model;
|
||||
describe.skipIf(!process.env.GEMINI_API_KEY)("Gemini Provider (gemini-2.5-flash)", () => {
|
||||
const llm = getModel("google", "gemini-2.5-flash");
|
||||
|
||||
beforeAll(() => {
|
||||
model = getModel("anthropic", "claude-3-5-haiku-20241022");
|
||||
});
|
||||
it("should complete basic text generation", async () => {
|
||||
await basicTextGeneration(llm);
|
||||
});
|
||||
|
||||
it("should complete basic text generation", async () => {
|
||||
await basicTextGeneration(model);
|
||||
});
|
||||
it("should handle tool calling", async () => {
|
||||
await handleToolCall(llm);
|
||||
});
|
||||
|
||||
it("should handle tool calling", async () => {
|
||||
await handleToolCall(model);
|
||||
});
|
||||
it("should handle streaming", async () => {
|
||||
await handleStreaming(llm);
|
||||
});
|
||||
|
||||
it("should handle streaming", async () => {
|
||||
await handleStreaming(model);
|
||||
});
|
||||
it("should handle ", async () => {
|
||||
await handleThinking(llm, { thinking: { enabled: true, budgetTokens: 1024 } });
|
||||
});
|
||||
|
||||
it("should handle image input", async () => {
|
||||
await handleImage(model);
|
||||
});
|
||||
});
|
||||
it("should handle multi-turn with thinking and tools", async () => {
|
||||
await multiTurn(llm, { thinking: { enabled: true, budgetTokens: 2048 } });
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.ANTHROPIC_OAUTH_TOKEN)("Anthropic Provider (claude-sonnet-4-20250514)", () => {
|
||||
let model: Model;
|
||||
it("should handle image input", async () => {
|
||||
await handleImage(llm);
|
||||
});
|
||||
});
|
||||
|
||||
beforeAll(() => {
|
||||
model = getModel("anthropic", "claude-sonnet-4-20250514");
|
||||
});
|
||||
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Completions Provider (gpt-4o-mini)", () => {
|
||||
const llm: Model<"openai-completions"> = { ...getModel("openai", "gpt-4o-mini"), api: "openai-completions" };
|
||||
|
||||
it("should complete basic text generation", async () => {
|
||||
await basicTextGeneration(model);
|
||||
});
|
||||
it("should complete basic text generation", async () => {
|
||||
await basicTextGeneration(llm);
|
||||
});
|
||||
|
||||
it("should handle tool calling", async () => {
|
||||
await handleToolCall(model);
|
||||
});
|
||||
it("should handle tool calling", async () => {
|
||||
await handleToolCall(llm);
|
||||
});
|
||||
|
||||
it("should handle streaming", async () => {
|
||||
await handleStreaming(model);
|
||||
});
|
||||
it("should handle streaming", async () => {
|
||||
await handleStreaming(llm);
|
||||
});
|
||||
|
||||
it("should handle thinking mode", async () => {
|
||||
await handleThinking(model, { reasoning: "low" });
|
||||
});
|
||||
it("should handle image input", async () => {
|
||||
await handleImage(llm);
|
||||
});
|
||||
});
|
||||
|
||||
it("should handle multi-turn with thinking and tools", async () => {
|
||||
await multiTurn(model, { reasoning: "medium" });
|
||||
});
|
||||
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Responses Provider (gpt-5-mini)", () => {
|
||||
const llm = getModel("openai", "gpt-5-mini");
|
||||
|
||||
it("should handle image input", async () => {
|
||||
await handleImage(model);
|
||||
});
|
||||
});
|
||||
});
|
||||
it("should complete basic text generation", async () => {
|
||||
await basicTextGeneration(llm);
|
||||
});
|
||||
|
||||
it("should handle tool calling", async () => {
|
||||
await handleToolCall(llm);
|
||||
});
|
||||
|
||||
it("should handle streaming", async () => {
|
||||
await handleStreaming(llm);
|
||||
});
|
||||
|
||||
it("should handle ", { retry: 2 }, async () => {
|
||||
await handleThinking(llm, { reasoningEffort: "medium" });
|
||||
});
|
||||
|
||||
it("should handle multi-turn with thinking and tools", async () => {
|
||||
await multiTurn(llm, { reasoningEffort: "medium" });
|
||||
});
|
||||
|
||||
it("should handle image input", async () => {
|
||||
await handleImage(llm);
|
||||
});
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.ANTHROPIC_API_KEY)("Anthropic Provider (claude-3-5-haiku-20241022)", () => {
|
||||
const model = getModel("anthropic", "claude-3-5-haiku-20241022");
|
||||
|
||||
it("should complete basic text generation", async () => {
|
||||
await basicTextGeneration(model, { thinkingEnabled: true });
|
||||
});
|
||||
|
||||
it("should handle tool calling", async () => {
|
||||
await handleToolCall(model);
|
||||
});
|
||||
|
||||
it("should handle streaming", async () => {
|
||||
await handleStreaming(model);
|
||||
});
|
||||
|
||||
it("should handle image input", async () => {
|
||||
await handleImage(model);
|
||||
});
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.ANTHROPIC_OAUTH_TOKEN)("Anthropic Provider (claude-sonnet-4-20250514)", () => {
|
||||
const model = getModel("anthropic", "claude-sonnet-4-20250514");
|
||||
|
||||
it("should complete basic text generation", async () => {
|
||||
await basicTextGeneration(model, { thinkingEnabled: true });
|
||||
});
|
||||
|
||||
it("should handle tool calling", async () => {
|
||||
await handleToolCall(model);
|
||||
});
|
||||
|
||||
it("should handle streaming", async () => {
|
||||
await handleStreaming(model);
|
||||
});
|
||||
|
||||
it("should handle thinking", async () => {
|
||||
await handleThinking(model, { thinkingEnabled: true });
|
||||
});
|
||||
|
||||
it("should handle multi-turn with thinking and tools", async () => {
|
||||
await multiTurn(model, { thinkingEnabled: true });
|
||||
});
|
||||
|
||||
it("should handle image input", async () => {
|
||||
await handleImage(model);
|
||||
});
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Responses Provider (gpt-5-mini)", () => {
|
||||
const model = getModel("openai", "gpt-5-mini");
|
||||
|
||||
it("should complete basic text generation", async () => {
|
||||
await basicTextGeneration(model);
|
||||
});
|
||||
|
||||
it("should handle tool calling", async () => {
|
||||
await handleToolCall(model);
|
||||
});
|
||||
|
||||
it("should handle streaming", async () => {
|
||||
await handleStreaming(model);
|
||||
});
|
||||
|
||||
it("should handle image input", async () => {
|
||||
await handleImage(model);
|
||||
});
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.XAI_API_KEY)("xAI Provider (grok-code-fast-1 via OpenAI Completions)", () => {
|
||||
const llm = getModel("xai", "grok-code-fast-1");
|
||||
|
||||
it("should complete basic text generation", async () => {
|
||||
await basicTextGeneration(llm);
|
||||
});
|
||||
|
||||
it("should handle tool calling", async () => {
|
||||
await handleToolCall(llm);
|
||||
});
|
||||
|
||||
it("should handle streaming", async () => {
|
||||
await handleStreaming(llm);
|
||||
});
|
||||
|
||||
it("should handle thinking mode", async () => {
|
||||
await handleThinking(llm, { reasoningEffort: "medium" });
|
||||
});
|
||||
|
||||
it("should handle multi-turn with thinking and tools", async () => {
|
||||
await multiTurn(llm, { reasoningEffort: "medium" });
|
||||
});
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.GROQ_API_KEY)("Groq Provider (gpt-oss-20b via OpenAI Completions)", () => {
|
||||
const llm = getModel("groq", "openai/gpt-oss-20b");
|
||||
|
||||
it("should complete basic text generation", async () => {
|
||||
await basicTextGeneration(llm);
|
||||
});
|
||||
|
||||
it("should handle tool calling", async () => {
|
||||
await handleToolCall(llm);
|
||||
});
|
||||
|
||||
it("should handle streaming", async () => {
|
||||
await handleStreaming(llm);
|
||||
});
|
||||
|
||||
it("should handle thinking mode", async () => {
|
||||
await handleThinking(llm, { reasoningEffort: "medium" });
|
||||
});
|
||||
|
||||
it("should handle multi-turn with thinking and tools", async () => {
|
||||
await multiTurn(llm, { reasoningEffort: "medium" });
|
||||
});
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.CEREBRAS_API_KEY)("Cerebras Provider (gpt-oss-120b via OpenAI Completions)", () => {
|
||||
const llm = getModel("cerebras", "gpt-oss-120b");
|
||||
|
||||
it("should complete basic text generation", async () => {
|
||||
await basicTextGeneration(llm);
|
||||
});
|
||||
|
||||
it("should handle tool calling", async () => {
|
||||
await handleToolCall(llm);
|
||||
});
|
||||
|
||||
it("should handle streaming", async () => {
|
||||
await handleStreaming(llm);
|
||||
});
|
||||
|
||||
it("should handle thinking mode", async () => {
|
||||
await handleThinking(llm, { reasoningEffort: "medium" });
|
||||
});
|
||||
|
||||
it("should handle multi-turn with thinking and tools", async () => {
|
||||
await multiTurn(llm, { reasoningEffort: "medium" });
|
||||
});
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.OPENROUTER_API_KEY)("OpenRouter Provider (glm-4.5v via OpenAI Completions)", () => {
|
||||
const llm = getModel("openrouter", "z-ai/glm-4.5v");
|
||||
|
||||
it("should complete basic text generation", async () => {
|
||||
await basicTextGeneration(llm);
|
||||
});
|
||||
|
||||
it("should handle tool calling", async () => {
|
||||
await handleToolCall(llm);
|
||||
});
|
||||
|
||||
it("should handle streaming", async () => {
|
||||
await handleStreaming(llm);
|
||||
});
|
||||
|
||||
it("should handle thinking mode", async () => {
|
||||
await handleThinking(llm, { reasoningEffort: "medium" });
|
||||
});
|
||||
|
||||
it("should handle multi-turn with thinking and tools", { retry: 2 }, async () => {
|
||||
await multiTurn(llm, { reasoningEffort: "medium" });
|
||||
});
|
||||
|
||||
it("should handle image input", async () => {
|
||||
await handleImage(llm);
|
||||
});
|
||||
});
|
||||
|
||||
// Check if ollama is installed
|
||||
let ollamaInstalled = false;
|
||||
try {
|
||||
execSync("which ollama", { stdio: "ignore" });
|
||||
ollamaInstalled = true;
|
||||
} catch {
|
||||
ollamaInstalled = false;
|
||||
}
|
||||
|
||||
describe.skipIf(!ollamaInstalled)("Ollama Provider (gpt-oss-20b via OpenAI Completions)", () => {
|
||||
let llm: Model<"openai-completions">;
|
||||
let ollamaProcess: ChildProcess | null = null;
|
||||
|
||||
beforeAll(async () => {
|
||||
// Check if model is available, if not pull it
|
||||
try {
|
||||
execSync("ollama list | grep -q 'gpt-oss:20b'", { stdio: "ignore" });
|
||||
} catch {
|
||||
console.log("Pulling gpt-oss:20b model for Ollama tests...");
|
||||
try {
|
||||
execSync("ollama pull gpt-oss:20b", { stdio: "inherit" });
|
||||
} catch (e) {
|
||||
console.warn("Failed to pull gpt-oss:20b model, tests will be skipped");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Start ollama server
|
||||
ollamaProcess = spawn("ollama", ["serve"], {
|
||||
detached: false,
|
||||
stdio: "ignore",
|
||||
});
|
||||
|
||||
// Wait for server to be ready
|
||||
await new Promise<void>((resolve) => {
|
||||
const checkServer = async () => {
|
||||
try {
|
||||
const response = await fetch("http://localhost:11434/api/tags");
|
||||
if (response.ok) {
|
||||
resolve();
|
||||
} else {
|
||||
setTimeout(checkServer, 500);
|
||||
}
|
||||
} catch {
|
||||
setTimeout(checkServer, 500);
|
||||
}
|
||||
};
|
||||
setTimeout(checkServer, 1000); // Initial delay
|
||||
});
|
||||
|
||||
llm = {
|
||||
id: "gpt-oss:20b",
|
||||
api: "openai-completions",
|
||||
provider: "ollama",
|
||||
baseUrl: "http://localhost:11434/v1",
|
||||
reasoning: true,
|
||||
input: ["text"],
|
||||
contextWindow: 128000,
|
||||
maxTokens: 16000,
|
||||
cost: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
},
|
||||
name: "Ollama GPT-OSS 20B",
|
||||
};
|
||||
}, 30000); // 30 second timeout for setup
|
||||
|
||||
afterAll(() => {
|
||||
// Kill ollama server
|
||||
if (ollamaProcess) {
|
||||
ollamaProcess.kill("SIGTERM");
|
||||
ollamaProcess = null;
|
||||
}
|
||||
});
|
||||
|
||||
it("should complete basic text generation", async () => {
|
||||
await basicTextGeneration(llm, { apiKey: "test" });
|
||||
});
|
||||
|
||||
it("should handle tool calling", async () => {
|
||||
await handleToolCall(llm, { apiKey: "test" });
|
||||
});
|
||||
|
||||
it("should handle streaming", async () => {
|
||||
await handleStreaming(llm, { apiKey: "test" });
|
||||
});
|
||||
|
||||
it("should handle thinking mode", async () => {
|
||||
await handleThinking(llm, { apiKey: "test", reasoningEffort: "medium" });
|
||||
});
|
||||
|
||||
it("should handle multi-turn with thinking and tools", async () => {
|
||||
await multiTurn(llm, { apiKey: "test", reasoningEffort: "medium" });
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -1,503 +1,489 @@
|
|||
import { describe, it, expect, beforeAll } from "vitest";
|
||||
import { GoogleLLM } from "../src/providers/google.js";
|
||||
import { OpenAICompletionsLLM } from "../src/providers/openai-completions.js";
|
||||
import { OpenAIResponsesLLM } from "../src/providers/openai-responses.js";
|
||||
import { AnthropicLLM } from "../src/providers/anthropic.js";
|
||||
import type { LLM, Context, AssistantMessage, Tool, Message } from "../src/types.js";
|
||||
import { createLLM, getModel } from "../src/models.js";
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { complete } from "../src/generate.js";
|
||||
import { getModel } from "../src/models.js";
|
||||
import type { Api, AssistantMessage, Context, Message, Model, Tool } from "../src/types.js";
|
||||
|
||||
// Tool for testing
|
||||
const weatherTool: Tool = {
|
||||
name: "get_weather",
|
||||
description: "Get the weather for a location",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
location: { type: "string", description: "City name" }
|
||||
},
|
||||
required: ["location"]
|
||||
}
|
||||
name: "get_weather",
|
||||
description: "Get the weather for a location",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
location: { type: "string", description: "City name" },
|
||||
},
|
||||
required: ["location"],
|
||||
},
|
||||
};
|
||||
|
||||
// Pre-built contexts representing typical outputs from each provider
|
||||
const providerContexts = {
|
||||
// Anthropic-style message with thinking block
|
||||
anthropic: {
|
||||
message: {
|
||||
role: "assistant",
|
||||
content: [
|
||||
{
|
||||
type: "thinking",
|
||||
thinking: "Let me calculate 17 * 23. That's 17 * 20 + 17 * 3 = 340 + 51 = 391",
|
||||
thinkingSignature: "signature_abc123"
|
||||
},
|
||||
{
|
||||
type: "text",
|
||||
text: "I'll help you with the calculation and check the weather. The result of 17 × 23 is 391. The capital of Austria is Vienna. Now let me check the weather for you."
|
||||
},
|
||||
{
|
||||
type: "toolCall",
|
||||
id: "toolu_01abc123",
|
||||
name: "get_weather",
|
||||
arguments: { location: "Tokyo" }
|
||||
}
|
||||
],
|
||||
provider: "anthropic",
|
||||
model: "claude-3-5-haiku-latest",
|
||||
usage: { input: 100, output: 50, cacheRead: 0, cacheWrite: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } },
|
||||
stopReason: "toolUse"
|
||||
} as AssistantMessage,
|
||||
toolResult: {
|
||||
role: "toolResult" as const,
|
||||
toolCallId: "toolu_01abc123",
|
||||
toolName: "get_weather",
|
||||
content: "Weather in Tokyo: 18°C, partly cloudy",
|
||||
isError: false
|
||||
},
|
||||
facts: {
|
||||
calculation: 391,
|
||||
city: "Tokyo",
|
||||
temperature: 18,
|
||||
capital: "Vienna"
|
||||
}
|
||||
},
|
||||
// Anthropic-style message with thinking block
|
||||
anthropic: {
|
||||
message: {
|
||||
role: "assistant",
|
||||
content: [
|
||||
{
|
||||
type: "thinking",
|
||||
thinking: "Let me calculate 17 * 23. That's 17 * 20 + 17 * 3 = 340 + 51 = 391",
|
||||
thinkingSignature: "signature_abc123",
|
||||
},
|
||||
{
|
||||
type: "text",
|
||||
text: "I'll help you with the calculation and check the weather. The result of 17 × 23 is 391. The capital of Austria is Vienna. Now let me check the weather for you.",
|
||||
},
|
||||
{
|
||||
type: "toolCall",
|
||||
id: "toolu_01abc123",
|
||||
name: "get_weather",
|
||||
arguments: { location: "Tokyo" },
|
||||
},
|
||||
],
|
||||
provider: "anthropic",
|
||||
model: "claude-3-5-haiku-latest",
|
||||
usage: {
|
||||
input: 100,
|
||||
output: 50,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "toolUse",
|
||||
} as AssistantMessage,
|
||||
toolResult: {
|
||||
role: "toolResult" as const,
|
||||
toolCallId: "toolu_01abc123",
|
||||
toolName: "get_weather",
|
||||
content: "Weather in Tokyo: 18°C, partly cloudy",
|
||||
isError: false,
|
||||
},
|
||||
facts: {
|
||||
calculation: 391,
|
||||
city: "Tokyo",
|
||||
temperature: 18,
|
||||
capital: "Vienna",
|
||||
},
|
||||
},
|
||||
|
||||
// Google-style message with thinking
|
||||
google: {
|
||||
message: {
|
||||
role: "assistant",
|
||||
content: [
|
||||
{
|
||||
type: "thinking",
|
||||
thinking: "I need to multiply 19 * 24. Let me work through this: 19 * 24 = 19 * 20 + 19 * 4 = 380 + 76 = 456",
|
||||
thinkingSignature: undefined
|
||||
},
|
||||
{
|
||||
type: "text",
|
||||
text: "The multiplication of 19 × 24 equals 456. The capital of France is Paris. Let me check the weather in Berlin for you."
|
||||
},
|
||||
{
|
||||
type: "toolCall",
|
||||
id: "call_gemini_123",
|
||||
name: "get_weather",
|
||||
arguments: { location: "Berlin" }
|
||||
}
|
||||
],
|
||||
provider: "google",
|
||||
model: "gemini-2.5-flash",
|
||||
usage: { input: 120, output: 60, cacheRead: 0, cacheWrite: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } },
|
||||
stopReason: "toolUse"
|
||||
} as AssistantMessage,
|
||||
toolResult: {
|
||||
role: "toolResult" as const,
|
||||
toolCallId: "call_gemini_123",
|
||||
toolName: "get_weather",
|
||||
content: "Weather in Berlin: 22°C, sunny",
|
||||
isError: false
|
||||
},
|
||||
facts: {
|
||||
calculation: 456,
|
||||
city: "Berlin",
|
||||
temperature: 22,
|
||||
capital: "Paris"
|
||||
}
|
||||
},
|
||||
// Google-style message with thinking
|
||||
google: {
|
||||
message: {
|
||||
role: "assistant",
|
||||
content: [
|
||||
{
|
||||
type: "thinking",
|
||||
thinking:
|
||||
"I need to multiply 19 * 24. Let me work through this: 19 * 24 = 19 * 20 + 19 * 4 = 380 + 76 = 456",
|
||||
thinkingSignature: undefined,
|
||||
},
|
||||
{
|
||||
type: "text",
|
||||
text: "The multiplication of 19 × 24 equals 456. The capital of France is Paris. Let me check the weather in Berlin for you.",
|
||||
},
|
||||
{
|
||||
type: "toolCall",
|
||||
id: "call_gemini_123",
|
||||
name: "get_weather",
|
||||
arguments: { location: "Berlin" },
|
||||
},
|
||||
],
|
||||
provider: "google",
|
||||
model: "gemini-2.5-flash",
|
||||
usage: {
|
||||
input: 120,
|
||||
output: 60,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "toolUse",
|
||||
} as AssistantMessage,
|
||||
toolResult: {
|
||||
role: "toolResult" as const,
|
||||
toolCallId: "call_gemini_123",
|
||||
toolName: "get_weather",
|
||||
content: "Weather in Berlin: 22°C, sunny",
|
||||
isError: false,
|
||||
},
|
||||
facts: {
|
||||
calculation: 456,
|
||||
city: "Berlin",
|
||||
temperature: 22,
|
||||
capital: "Paris",
|
||||
},
|
||||
},
|
||||
|
||||
// OpenAI Completions style (with reasoning_content)
|
||||
openaiCompletions: {
|
||||
message: {
|
||||
role: "assistant",
|
||||
content: [
|
||||
{
|
||||
type: "thinking",
|
||||
thinking: "Let me calculate 21 * 25. That's 21 * 25 = 525",
|
||||
thinkingSignature: "reasoning_content"
|
||||
},
|
||||
{
|
||||
type: "text",
|
||||
text: "The result of 21 × 25 is 525. The capital of Spain is Madrid. I'll check the weather in London now."
|
||||
},
|
||||
{
|
||||
type: "toolCall",
|
||||
id: "call_abc123",
|
||||
name: "get_weather",
|
||||
arguments: { location: "London" }
|
||||
}
|
||||
],
|
||||
provider: "openai",
|
||||
model: "gpt-4o-mini",
|
||||
usage: { input: 110, output: 55, cacheRead: 0, cacheWrite: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } },
|
||||
stopReason: "toolUse"
|
||||
} as AssistantMessage,
|
||||
toolResult: {
|
||||
role: "toolResult" as const,
|
||||
toolCallId: "call_abc123",
|
||||
toolName: "get_weather",
|
||||
content: "Weather in London: 15°C, rainy",
|
||||
isError: false
|
||||
},
|
||||
facts: {
|
||||
calculation: 525,
|
||||
city: "London",
|
||||
temperature: 15,
|
||||
capital: "Madrid"
|
||||
}
|
||||
},
|
||||
// OpenAI Completions style (with reasoning_content)
|
||||
openaiCompletions: {
|
||||
message: {
|
||||
role: "assistant",
|
||||
content: [
|
||||
{
|
||||
type: "thinking",
|
||||
thinking: "Let me calculate 21 * 25. That's 21 * 25 = 525",
|
||||
thinkingSignature: "reasoning_content",
|
||||
},
|
||||
{
|
||||
type: "text",
|
||||
text: "The result of 21 × 25 is 525. The capital of Spain is Madrid. I'll check the weather in London now.",
|
||||
},
|
||||
{
|
||||
type: "toolCall",
|
||||
id: "call_abc123",
|
||||
name: "get_weather",
|
||||
arguments: { location: "London" },
|
||||
},
|
||||
],
|
||||
provider: "openai",
|
||||
model: "gpt-4o-mini",
|
||||
usage: {
|
||||
input: 110,
|
||||
output: 55,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "toolUse",
|
||||
} as AssistantMessage,
|
||||
toolResult: {
|
||||
role: "toolResult" as const,
|
||||
toolCallId: "call_abc123",
|
||||
toolName: "get_weather",
|
||||
content: "Weather in London: 15°C, rainy",
|
||||
isError: false,
|
||||
},
|
||||
facts: {
|
||||
calculation: 525,
|
||||
city: "London",
|
||||
temperature: 15,
|
||||
capital: "Madrid",
|
||||
},
|
||||
},
|
||||
|
||||
// OpenAI Responses style (with complex tool call IDs)
|
||||
openaiResponses: {
|
||||
message: {
|
||||
role: "assistant",
|
||||
content: [
|
||||
{
|
||||
type: "thinking",
|
||||
thinking: "Calculating 18 * 27: 18 * 27 = 486",
|
||||
thinkingSignature: '{"type":"reasoning","id":"rs_2b2342acdde","summary":[{"type":"summary_text","text":"Calculating 18 * 27: 18 * 27 = 486"}]}'
|
||||
},
|
||||
{
|
||||
type: "text",
|
||||
text: "The calculation of 18 × 27 gives us 486. The capital of Italy is Rome. Let me check Sydney's weather.",
|
||||
textSignature: "msg_response_456"
|
||||
},
|
||||
{
|
||||
type: "toolCall",
|
||||
id: "call_789_item_012", // Anthropic requires alphanumeric, dash, and underscore only
|
||||
name: "get_weather",
|
||||
arguments: { location: "Sydney" }
|
||||
}
|
||||
],
|
||||
provider: "openai",
|
||||
model: "gpt-5-mini",
|
||||
usage: { input: 115, output: 58, cacheRead: 0, cacheWrite: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } },
|
||||
stopReason: "toolUse"
|
||||
} as AssistantMessage,
|
||||
toolResult: {
|
||||
role: "toolResult" as const,
|
||||
toolCallId: "call_789_item_012", // Match the updated ID format
|
||||
toolName: "get_weather",
|
||||
content: "Weather in Sydney: 25°C, clear",
|
||||
isError: false
|
||||
},
|
||||
facts: {
|
||||
calculation: 486,
|
||||
city: "Sydney",
|
||||
temperature: 25,
|
||||
capital: "Rome"
|
||||
}
|
||||
},
|
||||
// OpenAI Responses style (with complex tool call IDs)
|
||||
openaiResponses: {
|
||||
message: {
|
||||
role: "assistant",
|
||||
content: [
|
||||
{
|
||||
type: "thinking",
|
||||
thinking: "Calculating 18 * 27: 18 * 27 = 486",
|
||||
thinkingSignature:
|
||||
'{"type":"reasoning","id":"rs_2b2342acdde","summary":[{"type":"summary_text","text":"Calculating 18 * 27: 18 * 27 = 486"}]}',
|
||||
},
|
||||
{
|
||||
type: "text",
|
||||
text: "The calculation of 18 × 27 gives us 486. The capital of Italy is Rome. Let me check Sydney's weather.",
|
||||
textSignature: "msg_response_456",
|
||||
},
|
||||
{
|
||||
type: "toolCall",
|
||||
id: "call_789_item_012", // Anthropic requires alphanumeric, dash, and underscore only
|
||||
name: "get_weather",
|
||||
arguments: { location: "Sydney" },
|
||||
},
|
||||
],
|
||||
provider: "openai",
|
||||
model: "gpt-5-mini",
|
||||
usage: {
|
||||
input: 115,
|
||||
output: 58,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "toolUse",
|
||||
} as AssistantMessage,
|
||||
toolResult: {
|
||||
role: "toolResult" as const,
|
||||
toolCallId: "call_789_item_012", // Match the updated ID format
|
||||
toolName: "get_weather",
|
||||
content: "Weather in Sydney: 25°C, clear",
|
||||
isError: false,
|
||||
},
|
||||
facts: {
|
||||
calculation: 486,
|
||||
city: "Sydney",
|
||||
temperature: 25,
|
||||
capital: "Rome",
|
||||
},
|
||||
},
|
||||
|
||||
// Aborted message (stopReason: 'error')
|
||||
aborted: {
|
||||
message: {
|
||||
role: "assistant",
|
||||
content: [
|
||||
{
|
||||
type: "thinking",
|
||||
thinking: "Let me start calculating 20 * 30...",
|
||||
thinkingSignature: "partial_sig"
|
||||
},
|
||||
{
|
||||
type: "text",
|
||||
text: "I was about to calculate 20 × 30 which is"
|
||||
}
|
||||
],
|
||||
provider: "test",
|
||||
model: "test-model",
|
||||
usage: { input: 50, output: 25, cacheRead: 0, cacheWrite: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } },
|
||||
stopReason: "error",
|
||||
error: "Request was aborted"
|
||||
} as AssistantMessage,
|
||||
toolResult: null,
|
||||
facts: {
|
||||
calculation: 600,
|
||||
city: "none",
|
||||
temperature: 0,
|
||||
capital: "none"
|
||||
}
|
||||
}
|
||||
// Aborted message (stopReason: 'error')
|
||||
aborted: {
|
||||
message: {
|
||||
role: "assistant",
|
||||
content: [
|
||||
{
|
||||
type: "thinking",
|
||||
thinking: "Let me start calculating 20 * 30...",
|
||||
thinkingSignature: "partial_sig",
|
||||
},
|
||||
{
|
||||
type: "text",
|
||||
text: "I was about to calculate 20 × 30 which is",
|
||||
},
|
||||
],
|
||||
provider: "test",
|
||||
model: "test-model",
|
||||
usage: {
|
||||
input: 50,
|
||||
output: 25,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "error",
|
||||
error: "Request was aborted",
|
||||
} as AssistantMessage,
|
||||
toolResult: null,
|
||||
facts: {
|
||||
calculation: 600,
|
||||
city: "none",
|
||||
temperature: 0,
|
||||
capital: "none",
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
/**
|
||||
* Test that a provider can handle contexts from different sources
|
||||
*/
|
||||
async function testProviderHandoff(
|
||||
targetProvider: LLM<any>,
|
||||
sourceLabel: string,
|
||||
sourceContext: typeof providerContexts[keyof typeof providerContexts]
|
||||
async function testProviderHandoff<TApi extends Api>(
|
||||
targetModel: Model<TApi>,
|
||||
sourceLabel: string,
|
||||
sourceContext: (typeof providerContexts)[keyof typeof providerContexts],
|
||||
): Promise<boolean> {
|
||||
// Build conversation context
|
||||
const messages: Message[] = [
|
||||
{
|
||||
role: "user",
|
||||
content: "Please do some calculations, tell me about capitals, and check the weather."
|
||||
},
|
||||
sourceContext.message
|
||||
];
|
||||
// Build conversation context
|
||||
const messages: Message[] = [
|
||||
{
|
||||
role: "user",
|
||||
content: "Please do some calculations, tell me about capitals, and check the weather.",
|
||||
},
|
||||
sourceContext.message,
|
||||
];
|
||||
|
||||
// Add tool result if present
|
||||
if (sourceContext.toolResult) {
|
||||
messages.push(sourceContext.toolResult);
|
||||
}
|
||||
// Add tool result if present
|
||||
if (sourceContext.toolResult) {
|
||||
messages.push(sourceContext.toolResult);
|
||||
}
|
||||
|
||||
// Ask follow-up question
|
||||
messages.push({
|
||||
role: "user",
|
||||
content: `Based on our conversation, please answer:
|
||||
// Ask follow-up question
|
||||
messages.push({
|
||||
role: "user",
|
||||
content: `Based on our conversation, please answer:
|
||||
1) What was the multiplication result?
|
||||
2) Which city's weather did we check?
|
||||
3) What was the temperature?
|
||||
4) What capital city was mentioned?
|
||||
Please include the specific numbers and names.`
|
||||
});
|
||||
Please include the specific numbers and names.`,
|
||||
});
|
||||
|
||||
const context: Context = {
|
||||
messages,
|
||||
tools: [weatherTool]
|
||||
};
|
||||
const context: Context = {
|
||||
messages,
|
||||
tools: [weatherTool],
|
||||
};
|
||||
|
||||
try {
|
||||
const response = await targetProvider.generate(context, {});
|
||||
try {
|
||||
const response = await complete(targetModel, context, {});
|
||||
|
||||
// Check for error
|
||||
if (response.stopReason === "error") {
|
||||
console.log(`[${sourceLabel} → ${targetProvider.getModel().provider}] Failed with error: ${response.error}`);
|
||||
return false;
|
||||
}
|
||||
// Check for error
|
||||
if (response.stopReason === "error") {
|
||||
console.log(`[${sourceLabel} → ${targetModel.provider}] Failed with error: ${response.error}`);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Extract text from response
|
||||
const responseText = response.content
|
||||
.filter(b => b.type === "text")
|
||||
.map(b => b.text)
|
||||
.join(" ")
|
||||
.toLowerCase();
|
||||
// Extract text from response
|
||||
const responseText = response.content
|
||||
.filter((b) => b.type === "text")
|
||||
.map((b) => b.text)
|
||||
.join(" ")
|
||||
.toLowerCase();
|
||||
|
||||
// For aborted messages, we don't expect to find the facts
|
||||
if (sourceContext.message.stopReason === "error") {
|
||||
const hasToolCalls = response.content.some(b => b.type === "toolCall");
|
||||
const hasThinking = response.content.some(b => b.type === "thinking");
|
||||
const hasText = response.content.some(b => b.type === "text");
|
||||
// For aborted messages, we don't expect to find the facts
|
||||
if (sourceContext.message.stopReason === "error") {
|
||||
const hasToolCalls = response.content.some((b) => b.type === "toolCall");
|
||||
const hasThinking = response.content.some((b) => b.type === "thinking");
|
||||
const hasText = response.content.some((b) => b.type === "text");
|
||||
|
||||
expect(response.stopReason === "stop" || response.stopReason === "toolUse").toBe(true);
|
||||
expect(hasThinking || hasText || hasToolCalls).toBe(true);
|
||||
console.log(`[${sourceLabel} → ${targetProvider.getModel().provider}] Handled aborted message successfully, tool calls: ${hasToolCalls}, thinking: ${hasThinking}, text: ${hasText}`);
|
||||
return true;
|
||||
}
|
||||
expect(response.stopReason === "stop" || response.stopReason === "toolUse").toBe(true);
|
||||
expect(hasThinking || hasText || hasToolCalls).toBe(true);
|
||||
console.log(
|
||||
`[${sourceLabel} → ${targetModel.provider}] Handled aborted message successfully, tool calls: ${hasToolCalls}, thinking: ${hasThinking}, text: ${hasText}`,
|
||||
);
|
||||
return true;
|
||||
}
|
||||
|
||||
// Check if response contains our facts
|
||||
const hasCalculation = responseText.includes(sourceContext.facts.calculation.toString());
|
||||
const hasCity = sourceContext.facts.city !== "none" && responseText.includes(sourceContext.facts.city.toLowerCase());
|
||||
const hasTemperature = sourceContext.facts.temperature > 0 && responseText.includes(sourceContext.facts.temperature.toString());
|
||||
const hasCapital = sourceContext.facts.capital !== "none" && responseText.includes(sourceContext.facts.capital.toLowerCase());
|
||||
// Check if response contains our facts
|
||||
const hasCalculation = responseText.includes(sourceContext.facts.calculation.toString());
|
||||
const hasCity =
|
||||
sourceContext.facts.city !== "none" && responseText.includes(sourceContext.facts.city.toLowerCase());
|
||||
const hasTemperature =
|
||||
sourceContext.facts.temperature > 0 && responseText.includes(sourceContext.facts.temperature.toString());
|
||||
const hasCapital =
|
||||
sourceContext.facts.capital !== "none" && responseText.includes(sourceContext.facts.capital.toLowerCase());
|
||||
|
||||
const success = hasCalculation && hasCity && hasTemperature && hasCapital;
|
||||
const success = hasCalculation && hasCity && hasTemperature && hasCapital;
|
||||
|
||||
console.log(`[${sourceLabel} → ${targetProvider.getModel().provider}] Handoff test:`);
|
||||
if (!success) {
|
||||
console.log(` Calculation (${sourceContext.facts.calculation}): ${hasCalculation ? '✓' : '✗'}`);
|
||||
console.log(` City (${sourceContext.facts.city}): ${hasCity ? '✓' : '✗'}`);
|
||||
console.log(` Temperature (${sourceContext.facts.temperature}): ${hasTemperature ? '✓' : '✗'}`);
|
||||
console.log(` Capital (${sourceContext.facts.capital}): ${hasCapital ? '✓' : '✗'}`);
|
||||
} else {
|
||||
console.log(` ✓ All facts found`);
|
||||
}
|
||||
console.log(`[${sourceLabel} → ${targetModel.provider}] Handoff test:`);
|
||||
if (!success) {
|
||||
console.log(` Calculation (${sourceContext.facts.calculation}): ${hasCalculation ? "✓" : "✗"}`);
|
||||
console.log(` City (${sourceContext.facts.city}): ${hasCity ? "✓" : "✗"}`);
|
||||
console.log(` Temperature (${sourceContext.facts.temperature}): ${hasTemperature ? "✓" : "✗"}`);
|
||||
console.log(` Capital (${sourceContext.facts.capital}): ${hasCapital ? "✓" : "✗"}`);
|
||||
} else {
|
||||
console.log(` ✓ All facts found`);
|
||||
}
|
||||
|
||||
return success;
|
||||
} catch (error) {
|
||||
console.error(`[${sourceLabel} → ${targetProvider.getModel().provider}] Exception:`, error);
|
||||
return false;
|
||||
}
|
||||
return success;
|
||||
} catch (error) {
|
||||
console.error(`[${sourceLabel} → ${targetModel.provider}] Exception:`, error);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
describe("Cross-Provider Handoff Tests", () => {
|
||||
describe.skipIf(!process.env.ANTHROPIC_API_KEY)("Anthropic Provider Handoff", () => {
|
||||
let provider: AnthropicLLM;
|
||||
describe.skipIf(!process.env.ANTHROPIC_API_KEY)("Anthropic Provider Handoff", () => {
|
||||
const model = getModel("anthropic", "claude-3-5-haiku-20241022");
|
||||
|
||||
beforeAll(() => {
|
||||
const model = getModel("anthropic", "claude-3-5-haiku-20241022");
|
||||
if (model) {
|
||||
provider = new AnthropicLLM(model, process.env.ANTHROPIC_API_KEY!);
|
||||
}
|
||||
});
|
||||
it("should handle contexts from all providers", async () => {
|
||||
console.log("\nTesting Anthropic with pre-built contexts:\n");
|
||||
|
||||
it("should handle contexts from all providers", async () => {
|
||||
if (!provider) {
|
||||
console.log("Anthropic provider not available, skipping");
|
||||
return;
|
||||
}
|
||||
const contextTests = [
|
||||
{ label: "Anthropic-style", context: providerContexts.anthropic, sourceModel: "claude-3-5-haiku-20241022" },
|
||||
{ label: "Google-style", context: providerContexts.google, sourceModel: "gemini-2.5-flash" },
|
||||
{ label: "OpenAI-Completions", context: providerContexts.openaiCompletions, sourceModel: "gpt-4o-mini" },
|
||||
{ label: "OpenAI-Responses", context: providerContexts.openaiResponses, sourceModel: "gpt-5-mini" },
|
||||
{ label: "Aborted", context: providerContexts.aborted, sourceModel: null },
|
||||
];
|
||||
|
||||
console.log("\nTesting Anthropic with pre-built contexts:\n");
|
||||
let successCount = 0;
|
||||
let skippedCount = 0;
|
||||
|
||||
const contextTests = [
|
||||
{ label: "Anthropic-style", context: providerContexts.anthropic, sourceModel: "claude-3-5-haiku-20241022" },
|
||||
{ label: "Google-style", context: providerContexts.google, sourceModel: "gemini-2.5-flash" },
|
||||
{ label: "OpenAI-Completions", context: providerContexts.openaiCompletions, sourceModel: "gpt-4o-mini" },
|
||||
{ label: "OpenAI-Responses", context: providerContexts.openaiResponses, sourceModel: "gpt-5-mini" },
|
||||
{ label: "Aborted", context: providerContexts.aborted, sourceModel: null }
|
||||
];
|
||||
for (const { label, context, sourceModel } of contextTests) {
|
||||
// Skip testing same model against itself
|
||||
if (sourceModel && sourceModel === model.id) {
|
||||
console.log(`[${label} → ${model.provider}] Skipping same-model test`);
|
||||
skippedCount++;
|
||||
continue;
|
||||
}
|
||||
const success = await testProviderHandoff(model, label, context);
|
||||
if (success) successCount++;
|
||||
}
|
||||
|
||||
let successCount = 0;
|
||||
let skippedCount = 0;
|
||||
const totalTests = contextTests.length - skippedCount;
|
||||
console.log(`\nAnthropic success rate: ${successCount}/${totalTests} (${skippedCount} skipped)\n`);
|
||||
|
||||
for (const { label, context, sourceModel } of contextTests) {
|
||||
// Skip testing same model against itself
|
||||
if (sourceModel && sourceModel === provider.getModel().id) {
|
||||
console.log(`[${label} → ${provider.getModel().provider}] Skipping same-model test`);
|
||||
skippedCount++;
|
||||
continue;
|
||||
}
|
||||
const success = await testProviderHandoff(provider, label, context);
|
||||
if (success) successCount++;
|
||||
}
|
||||
// All non-skipped handoffs should succeed
|
||||
expect(successCount).toBe(totalTests);
|
||||
});
|
||||
});
|
||||
|
||||
const totalTests = contextTests.length - skippedCount;
|
||||
console.log(`\nAnthropic success rate: ${successCount}/${totalTests} (${skippedCount} skipped)\n`);
|
||||
describe.skipIf(!process.env.GEMINI_API_KEY)("Google Provider Handoff", () => {
|
||||
const model = getModel("google", "gemini-2.5-flash");
|
||||
|
||||
// All non-skipped handoffs should succeed
|
||||
expect(successCount).toBe(totalTests);
|
||||
});
|
||||
});
|
||||
it("should handle contexts from all providers", async () => {
|
||||
console.log("\nTesting Google with pre-built contexts:\n");
|
||||
|
||||
describe.skipIf(!process.env.GEMINI_API_KEY)("Google Provider Handoff", () => {
|
||||
let provider: GoogleLLM;
|
||||
const contextTests = [
|
||||
{ label: "Anthropic-style", context: providerContexts.anthropic, sourceModel: "claude-3-5-haiku-20241022" },
|
||||
{ label: "Google-style", context: providerContexts.google, sourceModel: "gemini-2.5-flash" },
|
||||
{ label: "OpenAI-Completions", context: providerContexts.openaiCompletions, sourceModel: "gpt-4o-mini" },
|
||||
{ label: "OpenAI-Responses", context: providerContexts.openaiResponses, sourceModel: "gpt-5-mini" },
|
||||
{ label: "Aborted", context: providerContexts.aborted, sourceModel: null },
|
||||
];
|
||||
|
||||
beforeAll(() => {
|
||||
const model = getModel("google", "gemini-2.5-flash");
|
||||
if (model) {
|
||||
provider = new GoogleLLM(model, process.env.GEMINI_API_KEY!);
|
||||
}
|
||||
});
|
||||
let successCount = 0;
|
||||
let skippedCount = 0;
|
||||
|
||||
it("should handle contexts from all providers", async () => {
|
||||
if (!provider) {
|
||||
console.log("Google provider not available, skipping");
|
||||
return;
|
||||
}
|
||||
for (const { label, context, sourceModel } of contextTests) {
|
||||
// Skip testing same model against itself
|
||||
if (sourceModel && sourceModel === model.id) {
|
||||
console.log(`[${label} → ${model.provider}] Skipping same-model test`);
|
||||
skippedCount++;
|
||||
continue;
|
||||
}
|
||||
const success = await testProviderHandoff(model, label, context);
|
||||
if (success) successCount++;
|
||||
}
|
||||
|
||||
console.log("\nTesting Google with pre-built contexts:\n");
|
||||
const totalTests = contextTests.length - skippedCount;
|
||||
console.log(`\nGoogle success rate: ${successCount}/${totalTests} (${skippedCount} skipped)\n`);
|
||||
|
||||
const contextTests = [
|
||||
{ label: "Anthropic-style", context: providerContexts.anthropic, sourceModel: "claude-3-5-haiku-20241022" },
|
||||
{ label: "Google-style", context: providerContexts.google, sourceModel: "gemini-2.5-flash" },
|
||||
{ label: "OpenAI-Completions", context: providerContexts.openaiCompletions, sourceModel: "gpt-4o-mini" },
|
||||
{ label: "OpenAI-Responses", context: providerContexts.openaiResponses, sourceModel: "gpt-5-mini" },
|
||||
{ label: "Aborted", context: providerContexts.aborted, sourceModel: null }
|
||||
];
|
||||
// All non-skipped handoffs should succeed
|
||||
expect(successCount).toBe(totalTests);
|
||||
});
|
||||
});
|
||||
|
||||
let successCount = 0;
|
||||
let skippedCount = 0;
|
||||
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Completions Provider Handoff", () => {
|
||||
const model: Model<"openai-completions"> = { ...getModel("openai", "gpt-4o-mini"), api: "openai-completions" };
|
||||
|
||||
for (const { label, context, sourceModel } of contextTests) {
|
||||
// Skip testing same model against itself
|
||||
if (sourceModel && sourceModel === provider.getModel().id) {
|
||||
console.log(`[${label} → ${provider.getModel().provider}] Skipping same-model test`);
|
||||
skippedCount++;
|
||||
continue;
|
||||
}
|
||||
const success = await testProviderHandoff(provider, label, context);
|
||||
if (success) successCount++;
|
||||
}
|
||||
it("should handle contexts from all providers", async () => {
|
||||
console.log("\nTesting OpenAI Completions with pre-built contexts:\n");
|
||||
|
||||
const totalTests = contextTests.length - skippedCount;
|
||||
console.log(`\nGoogle success rate: ${successCount}/${totalTests} (${skippedCount} skipped)\n`);
|
||||
const contextTests = [
|
||||
{ label: "Anthropic-style", context: providerContexts.anthropic, sourceModel: "claude-3-5-haiku-20241022" },
|
||||
{ label: "Google-style", context: providerContexts.google, sourceModel: "gemini-2.5-flash" },
|
||||
{ label: "OpenAI-Completions", context: providerContexts.openaiCompletions, sourceModel: "gpt-4o-mini" },
|
||||
{ label: "OpenAI-Responses", context: providerContexts.openaiResponses, sourceModel: "gpt-5-mini" },
|
||||
{ label: "Aborted", context: providerContexts.aborted, sourceModel: null },
|
||||
];
|
||||
|
||||
// All non-skipped handoffs should succeed
|
||||
expect(successCount).toBe(totalTests);
|
||||
});
|
||||
});
|
||||
let successCount = 0;
|
||||
let skippedCount = 0;
|
||||
|
||||
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Completions Provider Handoff", () => {
|
||||
let provider: OpenAICompletionsLLM;
|
||||
for (const { label, context, sourceModel } of contextTests) {
|
||||
// Skip testing same model against itself
|
||||
if (sourceModel && sourceModel === model.id) {
|
||||
console.log(`[${label} → ${model.provider}] Skipping same-model test`);
|
||||
skippedCount++;
|
||||
continue;
|
||||
}
|
||||
const success = await testProviderHandoff(model, label, context);
|
||||
if (success) successCount++;
|
||||
}
|
||||
|
||||
beforeAll(() => {
|
||||
const model = getModel("openai", "gpt-4o-mini");
|
||||
if (model) {
|
||||
provider = new OpenAICompletionsLLM(model, process.env.OPENAI_API_KEY!);
|
||||
}
|
||||
});
|
||||
const totalTests = contextTests.length - skippedCount;
|
||||
console.log(`\nOpenAI Completions success rate: ${successCount}/${totalTests} (${skippedCount} skipped)\n`);
|
||||
|
||||
it("should handle contexts from all providers", async () => {
|
||||
if (!provider) {
|
||||
console.log("OpenAI Completions provider not available, skipping");
|
||||
return;
|
||||
}
|
||||
// All non-skipped handoffs should succeed
|
||||
expect(successCount).toBe(totalTests);
|
||||
});
|
||||
});
|
||||
|
||||
console.log("\nTesting OpenAI Completions with pre-built contexts:\n");
|
||||
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Responses Provider Handoff", () => {
|
||||
const model = getModel("openai", "gpt-5-mini");
|
||||
|
||||
const contextTests = [
|
||||
{ label: "Anthropic-style", context: providerContexts.anthropic, sourceModel: "claude-3-5-haiku-20241022" },
|
||||
{ label: "Google-style", context: providerContexts.google, sourceModel: "gemini-2.5-flash" },
|
||||
{ label: "OpenAI-Completions", context: providerContexts.openaiCompletions, sourceModel: "gpt-4o-mini" },
|
||||
{ label: "OpenAI-Responses", context: providerContexts.openaiResponses, sourceModel: "gpt-5-mini" },
|
||||
{ label: "Aborted", context: providerContexts.aborted, sourceModel: null }
|
||||
];
|
||||
it("should handle contexts from all providers", async () => {
|
||||
console.log("\nTesting OpenAI Responses with pre-built contexts:\n");
|
||||
|
||||
let successCount = 0;
|
||||
let skippedCount = 0;
|
||||
const contextTests = [
|
||||
{ label: "Anthropic-style", context: providerContexts.anthropic, sourceModel: "claude-3-5-haiku-20241022" },
|
||||
{ label: "Google-style", context: providerContexts.google, sourceModel: "gemini-2.5-flash" },
|
||||
{ label: "OpenAI-Completions", context: providerContexts.openaiCompletions, sourceModel: "gpt-4o-mini" },
|
||||
{ label: "OpenAI-Responses", context: providerContexts.openaiResponses, sourceModel: "gpt-5-mini" },
|
||||
{ label: "Aborted", context: providerContexts.aborted, sourceModel: null },
|
||||
];
|
||||
|
||||
for (const { label, context, sourceModel } of contextTests) {
|
||||
// Skip testing same model against itself
|
||||
if (sourceModel && sourceModel === provider.getModel().id) {
|
||||
console.log(`[${label} → ${provider.getModel().provider}] Skipping same-model test`);
|
||||
skippedCount++;
|
||||
continue;
|
||||
}
|
||||
const success = await testProviderHandoff(provider, label, context);
|
||||
if (success) successCount++;
|
||||
}
|
||||
let successCount = 0;
|
||||
let skippedCount = 0;
|
||||
|
||||
const totalTests = contextTests.length - skippedCount;
|
||||
console.log(`\nOpenAI Completions success rate: ${successCount}/${totalTests} (${skippedCount} skipped)\n`);
|
||||
for (const { label, context, sourceModel } of contextTests) {
|
||||
// Skip testing same model against itself
|
||||
if (sourceModel && sourceModel === model.id) {
|
||||
console.log(`[${label} → ${model.provider}] Skipping same-model test`);
|
||||
skippedCount++;
|
||||
continue;
|
||||
}
|
||||
const success = await testProviderHandoff(model, label, context);
|
||||
if (success) successCount++;
|
||||
}
|
||||
|
||||
// All non-skipped handoffs should succeed
|
||||
expect(successCount).toBe(totalTests);
|
||||
});
|
||||
});
|
||||
const totalTests = contextTests.length - skippedCount;
|
||||
console.log(`\nOpenAI Responses success rate: ${successCount}/${totalTests} (${skippedCount} skipped)\n`);
|
||||
|
||||
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Responses Provider Handoff", () => {
|
||||
let provider: OpenAIResponsesLLM;
|
||||
|
||||
beforeAll(() => {
|
||||
const model = getModel("openai", "gpt-5-mini");
|
||||
if (model) {
|
||||
provider = new OpenAIResponsesLLM(model, process.env.OPENAI_API_KEY!);
|
||||
}
|
||||
});
|
||||
|
||||
it("should handle contexts from all providers", async () => {
|
||||
if (!provider) {
|
||||
console.log("OpenAI Responses provider not available, skipping");
|
||||
return;
|
||||
}
|
||||
|
||||
console.log("\nTesting OpenAI Responses with pre-built contexts:\n");
|
||||
|
||||
const contextTests = [
|
||||
{ label: "Anthropic-style", context: providerContexts.anthropic, sourceModel: "claude-3-5-haiku-20241022" },
|
||||
{ label: "Google-style", context: providerContexts.google, sourceModel: "gemini-2.5-flash" },
|
||||
{ label: "OpenAI-Completions", context: providerContexts.openaiCompletions, sourceModel: "gpt-4o-mini" },
|
||||
{ label: "OpenAI-Responses", context: providerContexts.openaiResponses, sourceModel: "gpt-5-mini" },
|
||||
{ label: "Aborted", context: providerContexts.aborted, sourceModel: null }
|
||||
];
|
||||
|
||||
let successCount = 0;
|
||||
let skippedCount = 0;
|
||||
|
||||
for (const { label, context, sourceModel } of contextTests) {
|
||||
// Skip testing same model against itself
|
||||
if (sourceModel && sourceModel === provider.getModel().id) {
|
||||
console.log(`[${label} → ${provider.getModel().provider}] Skipping same-model test`);
|
||||
skippedCount++;
|
||||
continue;
|
||||
}
|
||||
const success = await testProviderHandoff(provider, label, context);
|
||||
if (success) successCount++;
|
||||
}
|
||||
|
||||
const totalTests = contextTests.length - skippedCount;
|
||||
console.log(`\nOpenAI Responses success rate: ${successCount}/${totalTests} (${skippedCount} skipped)\n`);
|
||||
|
||||
// All non-skipped handoffs should succeed
|
||||
expect(successCount).toBe(totalTests);
|
||||
});
|
||||
});
|
||||
});
|
||||
// All non-skipped handoffs should succeed
|
||||
expect(successCount).toBe(totalTests);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -1,31 +0,0 @@
|
|||
import { GoogleGenAI } from "@google/genai";
|
||||
import OpenAI from "openai";
|
||||
|
||||
const ai = new GoogleGenAI({});
|
||||
|
||||
/**
 * Debug helper: dumps every model visible to the configured OpenAI account.
 *
 * Walks all result pages of `openai.models.list()`, retrieves the detail record
 * for each model and prints it as pretty-printed JSON followed by a `---` line.
 */
async function main() {
	/*let pager = await ai.models.list();
	do {
		for (const model of pager.page) {
			console.log(JSON.stringify(model, null, 2));
			console.log("---");
		}
		if (!pager.hasNextPage()) break;
		await pager.nextPage();
	} while (true);*/

	const openai = new OpenAI();

	// The page cursor has to be reassigned: the previous version awaited
	// `getNextPage()` but dropped its result, so whenever a second page existed
	// the loop printed the first page over and over and never terminated.
	let page = await openai.models.list();
	while (true) {
		for (const model of page.data) {
			// Print the retrieved record; it was previously fetched and then ignored.
			const info = await openai.models.retrieve(model.id);
			console.log(JSON.stringify(info, null, 2));
			console.log("---");
		}
		if (!page.hasNextPage()) break;
		page = await page.getNextPage();
	}
}
|
||||
|
||||
await main();
|
||||
|
|
@ -1,618 +0,0 @@
|
|||
import { describe, it, beforeAll, afterAll, expect } from "vitest";
|
||||
import { GoogleLLM } from "../src/providers/google.js";
|
||||
import { OpenAICompletionsLLM } from "../src/providers/openai-completions.js";
|
||||
import { OpenAIResponsesLLM } from "../src/providers/openai-responses.js";
|
||||
import { AnthropicLLM } from "../src/providers/anthropic.js";
|
||||
import type { LLM, LLMOptions, Context, Tool, AssistantMessage, Model, ImageContent } from "../src/types.js";
|
||||
import { spawn, ChildProcess, execSync } from "child_process";
|
||||
import { createLLM, getModel } from "../src/models.js";
|
||||
import { readFileSync } from "fs";
|
||||
import { join, dirname } from "path";
|
||||
import { fileURLToPath } from "url";
|
||||
|
||||
// ES modules do not provide __filename/__dirname, so derive both from import.meta.url.
const __filename = fileURLToPath(import.meta.url);
// Directory containing this test file; fixture paths are resolved against it.
const __dirname = dirname(__filename);
|
||||
|
||||
// Calculator tool definition (same as examples)
// Both operands share the same JSON-schema shape and differ only in their description.
const numericOperand = (description: string) => ({ type: "number" as const, description });

const calculatorTool: Tool = {
	name: "calculator",
	description: "Perform basic arithmetic operations",
	parameters: {
		type: "object",
		properties: {
			a: numericOperand("First number"),
			b: numericOperand("Second number"),
			operation: {
				type: "string",
				enum: ["add", "subtract", "multiply", "divide"],
				description: "The operation to perform",
			},
		},
		required: ["a", "b", "operation"],
	},
};
|
||||
|
||||
/**
 * Two-turn smoke test for plain text generation.
 *
 * Asks the model to echo a fixed phrase, appends its reply to the conversation,
 * then asks for a second phrase. Each reply must be an assistant message with
 * content, non-zero token usage, no error, and the requested phrase in its text.
 */
async function basicTextGeneration<T extends LLMOptions>(llm: LLM<T>) {
	// Shared assertions for one reply; `phrase` is the text the model was told to produce.
	const verifyReply = (reply: Awaited<ReturnType<typeof llm.generate>>, phrase: string) => {
		expect(reply.role).toBe("assistant");
		expect(reply.content).toBeTruthy();
		// Cached prompt tokens count as input too.
		expect(reply.usage.input + reply.usage.cacheRead).toBeGreaterThan(0);
		expect(reply.usage.output).toBeGreaterThan(0);
		expect(reply.error).toBeFalsy();
		const replyText = reply.content.map((block) => (block.type == "text" ? block.text : "")).join("");
		expect(replyText).toContain(phrase);
	};

	const context: Context = {
		systemPrompt: "You are a helpful assistant. Be concise.",
		messages: [{ role: "user", content: "Reply with exactly: 'Hello test successful'" }],
	};

	const firstReply = await llm.generate(context);
	verifyReply(firstReply, "Hello test successful");

	// Continue the same conversation so the second request carries the first exchange.
	context.messages.push(firstReply, { role: "user", content: "Now say 'Goodbye test successful'" });

	const followUpReply = await llm.generate(context);
	verifyReply(followUpReply, "Goodbye test successful");
}
|
||||
|
||||
/**
 * Checks that a prompt which explicitly asks for the calculator yields a tool call.
 *
 * The reply must stop with reason "toolUse" and contain a `toolCall` block that
 * names the calculator and carries an id.
 */
async function handleToolCall<T extends LLMOptions>(llm: LLM<T>) {
	const request: Context = {
		systemPrompt: "You are a helpful assistant that uses tools when asked.",
		messages: [{ role: "user", content: "Calculate 15 + 27 using the calculator tool." }],
		tools: [calculatorTool],
	};

	const reply = await llm.generate(request);
	expect(reply.stopReason).toBe("toolUse");

	const hasToolCall = reply.content.some((block) => block.type == "toolCall");
	expect(hasToolCall).toBeTruthy();

	// Safe to assert non-null: the expectation above fails first when no tool call exists.
	const firstCall = reply.content.find((block) => block.type == "toolCall")!;
	expect(firstCall.name).toBe("calculator");
	expect(firstCall.id).toBeTruthy();
}
|
||||
|
||||
/**
 * Checks that text streaming events are delivered through `onEvent`.
 *
 * Records whether `text_start` and `text_end` were seen and accumulates every
 * `text_delta`; all three must have happened and the final reply must contain
 * at least one text block.
 */
async function handleStreaming<T extends LLMOptions>(llm: LLM<T>) {
	const seen = { started: false, completed: false, streamedText: "" };

	const context: Context = {
		messages: [{ role: "user", content: "Count from 1 to 3" }],
	};

	const response = await llm.generate(context, {
		onEvent: (event) => {
			switch (event.type) {
				case "text_start":
					seen.started = true;
					break;
				case "text_delta":
					seen.streamedText += event.delta;
					break;
				case "text_end":
					seen.completed = true;
					break;
			}
		},
	} as T);

	expect(seen.started).toBe(true);
	expect(seen.streamedText.length).toBeGreaterThan(0);
	expect(seen.completed).toBe(true);
	expect(response.content.some((block) => block.type == "text")).toBeTruthy();
}
|
||||
|
||||
/**
 * Checks that reasoning ("thinking") events are streamed when thinking is enabled.
 *
 * @param llm Provider under test.
 * @param options Provider-specific options that switch thinking on. They are
 *   spread after `onEvent`, so an `onEvent` supplied here replaces the recorder.
 */
async function handleThinking<T extends LLMOptions>(llm: LLM<T>, options: T) {
	const seen = { started: false, completed: false, streamedThinking: "" };

	// The random operand varies the prompt from run to run.
	const context: Context = {
		messages: [{ role: "user", content: `Think about ${(Math.random() * 255) | 0} + 27. Think step by step. Then output the result.` }],
	};

	const response = await llm.generate(context, {
		onEvent: (event) => {
			switch (event.type) {
				case "thinking_start":
					seen.started = true;
					break;
				case "thinking_delta":
					// The accumulated content must always end with the newest delta.
					expect(event.content.endsWith(event.delta)).toBe(true);
					seen.streamedThinking += event.delta;
					break;
				case "thinking_end":
					seen.completed = true;
					break;
			}
		},
		...options
	});

	expect(response.stopReason, `Error: ${(response as any).error}`).toBe("stop");
	expect(seen.started).toBe(true);
	expect(seen.streamedThinking.length).toBeGreaterThan(0);
	expect(seen.completed).toBe(true);
	expect(response.content.some((block) => block.type == "thinking")).toBeTruthy();
}
|
||||
|
||||
/**
 * Checks that an image-capable model can describe the red-circle fixture.
 *
 * Returns early (with a log line) when the model does not list "image" among
 * its inputs. Otherwise sends `data/red-circle.png` as base64 together with a
 * question and expects the answer to mention both "red" and "circle".
 */
async function handleImage<T extends LLMOptions>(llm: LLM<T>) {
	// Check if the model supports images
	const model = llm.getModel();
	if (!model.input.includes("image")) {
		console.log(`Skipping image test - model ${model.id} doesn't support images`);
		return;
	}

	// Read the test image
	const imagePath = join(__dirname, "data", "red-circle.png");
	const imageContent: ImageContent = {
		type: "image",
		data: readFileSync(imagePath).toString("base64"),
		mimeType: "image/png",
	};

	const context: Context = {
		messages: [
			{
				role: "user",
				content: [
					{ type: "text", text: "What do you see in this image? Please describe the shape (circle, rectangle, square, triangle, ...) and color (red, blue, green, ...)." },
					imageContent,
				],
			},
		],
	};

	const response = await llm.generate(context);

	// Check the response mentions red and circle
	expect(response.content.length > 0).toBeTruthy();
	// Compare case-insensitively and across all text blocks: the previous check
	// looked only at the first text block and never lowercased it, so an answer
	// such as "A Red Circle" failed even though it was correct.
	const lowerContent = response.content
		.map((block) => (block.type == "text" ? block.text : ""))
		.join("")
		.toLowerCase();
	expect(lowerContent).toContain("red");
	expect(lowerContent).toContain("circle");
}
|
||||
|
||||
/**
 * Runs a short agent loop that mixes thinking, tool calls and tool results.
 *
 * The model is asked to compute 42 * 17 and 453 + 434 with the calculator tool.
 * Every assistant reply is appended to the context; each tool call is evaluated
 * locally and answered with a `toolResult` message. The loop ends on a "stop"
 * reply or after `maxTurns` requests.
 *
 * @param llm Provider under test.
 * @param thinkingOptions Provider-specific options passed to every request.
 */
async function multiTurn<T extends LLMOptions>(llm: LLM<T>, thinkingOptions: T) {
	const context: Context = {
		systemPrompt: "You are a helpful assistant that can use tools to answer questions.",
		messages: [
			{
				role: "user",
				content: "Think about this briefly, then calculate 42 * 17 and 453 + 434 using the calculator tool.",
			},
		],
		tools: [calculatorTool],
	};

	// Collect all text content from all assistant responses
	let allTextContent = "";
	let hasSeenThinking = false;
	let hasSeenToolCalls = false;
	const maxTurns = 5; // Prevent infinite loops

	for (let turn = 0; turn < maxTurns; turn++) {
		const response = await llm.generate(context, thinkingOptions);

		// Add the assistant response to context
		context.messages.push(response);

		// Process content blocks
		for (const block of response.content) {
			if (block.type === "text") {
				allTextContent += block.text;
			} else if (block.type === "thinking") {
				hasSeenThinking = true;
			} else if (block.type === "toolCall") {
				hasSeenToolCalls = true;

				// Process the tool call
				expect(block.name).toBe("calculator");
				expect(block.id).toBeTruthy();
				expect(block.arguments).toBeTruthy();

				const { a, b, operation } = block.arguments;
				// Evaluate every operation the tool schema advertises. Previously only
				// "add" and "multiply" were handled, so a "subtract" or "divide" call
				// was answered with 0.
				let result: number;
				switch (operation) {
					case "add":
						result = a + b;
						break;
					case "subtract":
						result = a - b;
						break;
					case "multiply":
						result = a * b;
						break;
					case "divide":
						result = a / b;
						break;
					default:
						// Unknown operation name: keep the historical fallback value.
						result = 0;
				}

				// Add tool result to context
				context.messages.push({
					role: "toolResult",
					toolCallId: block.id,
					toolName: block.name,
					content: `${result}`,
					isError: false,
				});
			}
		}

		// If we got a stop response with text content, we're likely done
		expect(response.stopReason).not.toBe("error");
		if (response.stopReason === "stop") {
			break;
		}
	}

	// Verify we got either thinking content or tool calls (or both)
	expect(hasSeenThinking || hasSeenToolCalls).toBe(true);

	// The accumulated text should reference both calculations
	expect(allTextContent).toBeTruthy();
	expect(allTextContent.includes("714")).toBe(true);
	expect(allTextContent.includes("887")).toBe(true);
}
|
||||
|
||||
describe("AI Providers E2E Tests", () => {
|
||||
// Gemini suite: only runs when GEMINI_API_KEY is set.
describe.skipIf(!process.env.GEMINI_API_KEY)("Gemini Provider (gemini-2.5-flash)", () => {
	// Provider under test; created once in beforeAll.
	let gemini: GoogleLLM;

	beforeAll(() => {
		const model = getModel("google", "gemini-2.5-flash")!;
		gemini = new GoogleLLM(model, process.env.GEMINI_API_KEY!);
	});

	it("should complete basic text generation", () => basicTextGeneration(gemini));

	it("should handle tool calling", () => handleToolCall(gemini));

	it("should handle streaming", () => handleStreaming(gemini));

	it("should handle thinking mode", () => handleThinking(gemini, { thinking: { enabled: true, budgetTokens: 1024 } }));

	it("should handle multi-turn with thinking and tools", () =>
		multiTurn(gemini, { thinking: { enabled: true, budgetTokens: 2048 } }));

	it("should handle image input", () => handleImage(gemini));
});
|
||||
|
||||
// OpenAI chat-completions suite: only runs when OPENAI_API_KEY is set.
// It has no thinking or multi-turn cases.
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Completions Provider (gpt-4o-mini)", () => {
	// Provider under test; created once in beforeAll.
	let completions: OpenAICompletionsLLM;

	beforeAll(() => {
		const model = getModel("openai", "gpt-4o-mini")!;
		completions = new OpenAICompletionsLLM(model, process.env.OPENAI_API_KEY!);
	});

	it("should complete basic text generation", () => basicTextGeneration(completions));

	it("should handle tool calling", () => handleToolCall(completions));

	it("should handle streaming", () => handleStreaming(completions));

	it("should handle image input", () => handleImage(completions));
});
|
||||
|
||||
// OpenAI Responses-API suite: only runs when OPENAI_API_KEY is set.
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Responses Provider (gpt-5-mini)", () => {
	// Provider under test; created once in beforeAll.
	let responses: OpenAIResponsesLLM;

	beforeAll(() => {
		const model = getModel("openai", "gpt-5-mini")!;
		responses = new OpenAIResponsesLLM(model, process.env.OPENAI_API_KEY!);
	});

	it("should complete basic text generation", () => basicTextGeneration(responses));

	it("should handle tool calling", () => handleToolCall(responses));

	it("should handle streaming", () => handleStreaming(responses));

	// This case is allowed up to two retries.
	it("should handle thinking mode", { retry: 2 }, () => handleThinking(responses, { reasoningEffort: "high" }));

	it("should handle multi-turn with thinking and tools", () => multiTurn(responses, { reasoningEffort: "high" }));

	it("should handle image input", () => handleImage(responses));
});
|
||||
|
||||
// Anthropic suite: only runs when ANTHROPIC_OAUTH_TOKEN is set.
describe.skipIf(!process.env.ANTHROPIC_OAUTH_TOKEN)("Anthropic Provider (claude-sonnet-4-20250514)", () => {
	// Provider under test; created once in beforeAll.
	let claude: AnthropicLLM;

	beforeAll(() => {
		const model = getModel("anthropic", "claude-sonnet-4-20250514")!;
		claude = new AnthropicLLM(model, process.env.ANTHROPIC_OAUTH_TOKEN!);
	});

	it("should complete basic text generation", () => basicTextGeneration(claude));

	it("should handle tool calling", () => handleToolCall(claude));

	it("should handle streaming", () => handleStreaming(claude));

	// No explicit budgetTokens is passed for this case.
	it("should handle thinking mode", () => handleThinking(claude, { thinking: { enabled: true } }));

	it("should handle multi-turn with thinking and tools", () =>
		multiTurn(claude, { thinking: { enabled: true, budgetTokens: 2048 } }));

	it("should handle image input", () => handleImage(claude));
});
|
||||
|
||||
// xAI suite (through the OpenAI-completions provider): only runs when XAI_API_KEY is set.
describe.skipIf(!process.env.XAI_API_KEY)("xAI Provider (grok-code-fast-1 via OpenAI Completions)", () => {
	// Provider under test; created once in beforeAll.
	let grok: OpenAICompletionsLLM;

	beforeAll(() => {
		const model = getModel("xai", "grok-code-fast-1")!;
		grok = new OpenAICompletionsLLM(model, process.env.XAI_API_KEY!);
	});

	it("should complete basic text generation", () => basicTextGeneration(grok));

	it("should handle tool calling", () => handleToolCall(grok));

	it("should handle streaming", () => handleStreaming(grok));

	it("should handle thinking mode", () => handleThinking(grok, { reasoningEffort: "medium" }));

	it("should handle multi-turn with thinking and tools", () => multiTurn(grok, { reasoningEffort: "medium" }));
});
|
||||
|
||||
// Groq suite (through the OpenAI-completions provider): only runs when GROQ_API_KEY is set.
describe.skipIf(!process.env.GROQ_API_KEY)("Groq Provider (gpt-oss-20b via OpenAI Completions)", () => {
	// Provider under test; created once in beforeAll.
	let groq: OpenAICompletionsLLM;

	beforeAll(() => {
		const model = getModel("groq", "openai/gpt-oss-20b")!;
		groq = new OpenAICompletionsLLM(model, process.env.GROQ_API_KEY!);
	});

	it("should complete basic text generation", () => basicTextGeneration(groq));

	it("should handle tool calling", () => handleToolCall(groq));

	it("should handle streaming", () => handleStreaming(groq));

	it("should handle thinking mode", () => handleThinking(groq, { reasoningEffort: "medium" }));

	it("should handle multi-turn with thinking and tools", () => multiTurn(groq, { reasoningEffort: "medium" }));
});
|
||||
|
||||
// Cerebras suite (through the OpenAI-completions provider): only runs when CEREBRAS_API_KEY is set.
describe.skipIf(!process.env.CEREBRAS_API_KEY)("Cerebras Provider (gpt-oss-120b via OpenAI Completions)", () => {
	// Provider under test; created once in beforeAll.
	let cerebras: OpenAICompletionsLLM;

	beforeAll(() => {
		const model = getModel("cerebras", "gpt-oss-120b")!;
		cerebras = new OpenAICompletionsLLM(model, process.env.CEREBRAS_API_KEY!);
	});

	it("should complete basic text generation", () => basicTextGeneration(cerebras));

	it("should handle tool calling", () => handleToolCall(cerebras));

	it("should handle streaming", () => handleStreaming(cerebras));

	it("should handle thinking mode", () => handleThinking(cerebras, { reasoningEffort: "medium" }));

	it("should handle multi-turn with thinking and tools", () => multiTurn(cerebras, { reasoningEffort: "medium" }));
});
|
||||
|
||||
// OpenRouter suite (through the OpenAI-completions provider): only runs when OPENROUTER_API_KEY is set.
describe.skipIf(!process.env.OPENROUTER_API_KEY)("OpenRouter Provider (glm-4.5v via OpenAI Completions)", () => {
	// Provider under test; created once in beforeAll.
	let openRouter: OpenAICompletionsLLM;

	beforeAll(() => {
		const model = getModel("openrouter", "z-ai/glm-4.5v")!;
		openRouter = new OpenAICompletionsLLM(model, process.env.OPENROUTER_API_KEY!);
	});

	it("should complete basic text generation", () => basicTextGeneration(openRouter));

	it("should handle tool calling", () => handleToolCall(openRouter));

	it("should handle streaming", () => handleStreaming(openRouter));

	it("should handle thinking mode", () => handleThinking(openRouter, { reasoningEffort: "medium" }));

	// This case is allowed up to two retries.
	it("should handle multi-turn with thinking and tools", { retry: 2 }, () =>
		multiTurn(openRouter, { reasoningEffort: "medium" }));

	it("should handle image input", () => handleImage(openRouter));
});
|
||||
|
||||
// Check if ollama is installed
let ollamaInstalled = false;
try {
	execSync("which ollama", { stdio: "ignore" });
	ollamaInstalled = true;
} catch {
	ollamaInstalled = false;
}

// Local Ollama suite: runs only when the `ollama` binary is found. Setup pulls the
// model if `ollama list` does not show it, starts `ollama serve`, waits for the HTTP
// API to answer and then points the OpenAI-completions provider at it.
describe.skipIf(!ollamaInstalled)("Ollama Provider (gpt-oss-20b via OpenAI Completions)", () => {
	// Stays undefined when the model could not be pulled; tests then return early.
	let llm: OpenAICompletionsLLM | undefined;
	let ollamaProcess: ChildProcess | null = null;

	// Upper bound for the readiness poll, kept below the 30 s hook timeout.
	const serverReadyTimeoutMs = 20000;

	// Runs a shared test body only when setup produced a provider. Previously a failed
	// pull logged "tests will be skipped" but every test still ran and crashed on the
	// undefined `llm`.
	const withLlm = async (run: (ready: OpenAICompletionsLLM) => Promise<void>) => {
		if (!llm) {
			console.log("Ollama provider not available, skipping");
			return;
		}
		await run(llm);
	};

	beforeAll(async () => {
		// Check if model is available, if not pull it
		try {
			execSync("ollama list | grep -q 'gpt-oss:20b'", { stdio: "ignore" });
		} catch {
			console.log("Pulling gpt-oss:20b model for Ollama tests...");
			try {
				execSync("ollama pull gpt-oss:20b", { stdio: "inherit" });
			} catch (e) {
				console.warn("Failed to pull gpt-oss:20b model, tests will be skipped");
				return;
			}
		}

		// Start ollama server
		ollamaProcess = spawn("ollama", ["serve"], {
			detached: false,
			stdio: "ignore",
		});

		// Wait for server to be ready. The poll now gives up after a deadline instead
		// of rescheduling itself forever once the hook has already timed out.
		const deadline = Date.now() + serverReadyTimeoutMs;
		try {
			await new Promise<void>((resolve, reject) => {
				const checkServer = async () => {
					try {
						const response = await fetch("http://localhost:11434/api/tags");
						if (response.ok) {
							resolve();
							return;
						}
					} catch {
						// Server is not listening yet; fall through and retry.
					}
					if (Date.now() >= deadline) {
						reject(new Error("Ollama server did not become ready in time"));
						return;
					}
					setTimeout(checkServer, 500);
				};
				setTimeout(checkServer, 1000); // Initial delay
			});
		} catch (e) {
			// Do not leave the spawned server behind when setup fails.
			ollamaProcess.kill("SIGTERM");
			ollamaProcess = null;
			throw e;
		}

		const model: Model = {
			id: "gpt-oss:20b",
			provider: "ollama",
			baseUrl: "http://localhost:11434/v1",
			reasoning: true,
			input: ["text"],
			contextWindow: 128000,
			maxTokens: 16000,
			cost: {
				input: 0,
				output: 0,
				cacheRead: 0,
				cacheWrite: 0,
			},
			name: "Ollama GPT-OSS 20B",
		};
		// "dummy" is passed as the API key for the local server.
		llm = new OpenAICompletionsLLM(model, "dummy");
	}, 30000); // 30 second timeout for setup

	afterAll(() => {
		// Kill ollama server
		if (ollamaProcess) {
			ollamaProcess.kill("SIGTERM");
			ollamaProcess = null;
		}
	});

	it("should complete basic text generation", async () => {
		await withLlm((ready) => basicTextGeneration(ready));
	});

	it("should handle tool calling", async () => {
		await withLlm((ready) => handleToolCall(ready));
	});

	it("should handle streaming", async () => {
		await withLlm((ready) => handleStreaming(ready));
	});

	it("should handle thinking mode", async () => {
		await withLlm((ready) => handleThinking(ready, { reasoningEffort: "medium" }));
	});

	it("should handle multi-turn with thinking and tools", async () => {
		await withLlm((ready) => multiTurn(ready, { reasoningEffort: "medium" }));
	});
});
|
||||
|
||||
/*
|
||||
describe.skipIf(!process.env.ANTHROPIC_API_KEY)("Anthropic Provider (Haiku 3.5)", () => {
|
||||
let llm: AnthropicLLM;
|
||||
|
||||
beforeAll(() => {
|
||||
llm = createLLM("anthropic", "claude-3-5-haiku-latest");
|
||||
});
|
||||
|
||||
it("should complete basic text generation", async () => {
|
||||
await basicTextGeneration(llm);
|
||||
});
|
||||
|
||||
it("should handle tool calling", async () => {
|
||||
await handleToolCall(llm);
|
||||
});
|
||||
|
||||
it("should handle streaming", async () => {
|
||||
await handleStreaming(llm);
|
||||
});
|
||||
|
||||
it("should handle multi-turn with thinking and tools", async () => {
|
||||
await multiTurn(llm, {thinking: {enabled: true}});
|
||||
});
|
||||
|
||||
it("should handle image input", async () => {
|
||||
await handleImage(llm);
|
||||
});
|
||||
});
|
||||
*/
|
||||
});
|
||||
|
|
@ -1,15 +1,6 @@
|
|||
#!/usr/bin/env npx tsx
|
||||
import {
|
||||
Container,
|
||||
LoadingAnimation,
|
||||
TextComponent,
|
||||
TextEditor,
|
||||
TUI,
|
||||
WhitespaceComponent,
|
||||
} from "../src/index.js";
|
||||
import chalk from "chalk";
|
||||
|
||||
|
||||
import { Container, LoadingAnimation, TextComponent, TextEditor, TUI, WhitespaceComponent } from "../src/index.js";
|
||||
|
||||
/**
|
||||
* Test the new smart double-buffered TUI implementation
|
||||
|
|
@ -24,7 +15,7 @@ async function main() {
|
|||
|
||||
// Monkey-patch requestRender to measure performance
|
||||
const originalRequestRender = ui.requestRender.bind(ui);
|
||||
ui.requestRender = function() {
|
||||
ui.requestRender = () => {
|
||||
const startTime = process.hrtime.bigint();
|
||||
originalRequestRender();
|
||||
process.nextTick(() => {
|
||||
|
|
@ -38,10 +29,12 @@ async function main() {
|
|||
|
||||
// Add header
|
||||
const header = new TextComponent(
|
||||
chalk.bold.green("Smart Double Buffer TUI Test") + "\n" +
|
||||
chalk.dim("Testing new implementation with component-level caching and smart diffing") + "\n" +
|
||||
chalk.dim("Press CTRL+C to exit"),
|
||||
{ bottom: 1 }
|
||||
chalk.bold.green("Smart Double Buffer TUI Test") +
|
||||
"\n" +
|
||||
chalk.dim("Testing new implementation with component-level caching and smart diffing") +
|
||||
"\n" +
|
||||
chalk.dim("Press CTRL+C to exit"),
|
||||
{ bottom: 1 },
|
||||
);
|
||||
ui.addChild(header);
|
||||
|
||||
|
|
@ -57,7 +50,9 @@ async function main() {
|
|||
|
||||
// Add text editor
|
||||
const editor = new TextEditor();
|
||||
editor.setText("Type here to test the text editor.\n\nWith smart diffing, only changed lines are redrawn!\n\nThe animation above updates every 80ms but the editor stays perfectly still.");
|
||||
editor.setText(
|
||||
"Type here to test the text editor.\n\nWith smart diffing, only changed lines are redrawn!\n\nThe animation above updates every 80ms but the editor stays perfectly still.",
|
||||
);
|
||||
container.addChild(editor);
|
||||
|
||||
// Add the container to UI
|
||||
|
|
@ -71,15 +66,20 @@ async function main() {
|
|||
const statsInterval = setInterval(() => {
|
||||
if (renderCount > 0) {
|
||||
const avgRenderTime = Number(totalRenderTime / BigInt(renderCount)) / 1_000_000; // Convert to ms
|
||||
const lastRenderTime = renderTimings.length > 0
|
||||
? Number(renderTimings[renderTimings.length - 1]) / 1_000_000
|
||||
: 0;
|
||||
const lastRenderTime =
|
||||
renderTimings.length > 0 ? Number(renderTimings[renderTimings.length - 1]) / 1_000_000 : 0;
|
||||
const avgLinesRedrawn = ui.getAverageLinesRedrawn();
|
||||
|
||||
statsComponent.setText(
|
||||
chalk.yellow(`Performance Stats:`) + "\n" +
|
||||
chalk.dim(`Renders: ${renderCount} | Avg Time: ${avgRenderTime.toFixed(2)}ms | Last: ${lastRenderTime.toFixed(2)}ms`) + "\n" +
|
||||
chalk.dim(`Lines Redrawn: ${ui.getLinesRedrawn()} total | Avg per render: ${avgLinesRedrawn.toFixed(1)}`)
|
||||
chalk.yellow(`Performance Stats:`) +
|
||||
"\n" +
|
||||
chalk.dim(
|
||||
`Renders: ${renderCount} | Avg Time: ${avgRenderTime.toFixed(2)}ms | Last: ${lastRenderTime.toFixed(2)}ms`,
|
||||
) +
|
||||
"\n" +
|
||||
chalk.dim(
|
||||
`Lines Redrawn: ${ui.getLinesRedrawn()} total | Avg per render: ${avgLinesRedrawn.toFixed(1)}`,
|
||||
),
|
||||
);
|
||||
}
|
||||
}, 1000);
|
||||
|
|
@ -96,7 +96,11 @@ async function main() {
|
|||
ui.stop();
|
||||
console.log("\n" + chalk.green("Exited double-buffer test"));
|
||||
console.log(chalk.dim(`Total renders: ${renderCount}`));
|
||||
console.log(chalk.dim(`Average render time: ${renderCount > 0 ? (Number(totalRenderTime / BigInt(renderCount)) / 1_000_000).toFixed(2) : 0}ms`));
|
||||
console.log(
|
||||
chalk.dim(
|
||||
`Average render time: ${renderCount > 0 ? (Number(totalRenderTime / BigInt(renderCount)) / 1_000_000).toFixed(2) : 0}ms`,
|
||||
),
|
||||
);
|
||||
console.log(chalk.dim(`Total lines redrawn: ${ui.getLinesRedrawn()}`));
|
||||
console.log(chalk.dim(`Average lines redrawn per render: ${ui.getAverageLinesRedrawn().toFixed(1)}`));
|
||||
process.exit(0);
|
||||
|
|
@ -112,4 +116,4 @@ async function main() {
|
|||
main().catch((error) => {
|
||||
console.error("Error:", error);
|
||||
process.exit(1);
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -1,9 +1,16 @@
|
|||
#!/usr/bin/env npx tsx
|
||||
import { TUI, Container, TextEditor, TextComponent, MarkdownComponent, CombinedAutocompleteProvider } from "../src/index.js";
|
||||
import {
|
||||
CombinedAutocompleteProvider,
|
||||
Container,
|
||||
MarkdownComponent,
|
||||
TextComponent,
|
||||
TextEditor,
|
||||
TUI,
|
||||
} from "../src/index.js";
|
||||
|
||||
/**
|
||||
* Chat Application with Autocomplete
|
||||
*
|
||||
*
|
||||
* Demonstrates:
|
||||
* - Slash command system with autocomplete
|
||||
* - Dynamic message history
|
||||
|
|
@ -16,7 +23,7 @@ const ui = new TUI();
|
|||
// Add header with instructions
|
||||
const header = new TextComponent(
|
||||
"💬 Chat Demo | Type '/' for commands | Start typing a filename + Tab to autocomplete | Ctrl+C to exit",
|
||||
{ bottom: 1 }
|
||||
{ bottom: 1 },
|
||||
);
|
||||
|
||||
const chatHistory = new Container();
|
||||
|
|
@ -82,7 +89,8 @@ ui.onGlobalKeyPress = (data: string) => {
|
|||
};
|
||||
|
||||
// Add initial welcome message to chat history
|
||||
chatHistory.addChild(new MarkdownComponent(`
|
||||
chatHistory.addChild(
|
||||
new MarkdownComponent(`
|
||||
## Welcome to the Chat Demo!
|
||||
|
||||
**Available slash commands:**
|
||||
|
|
@ -96,10 +104,11 @@ chatHistory.addChild(new MarkdownComponent(`
|
|||
- Works with home directory (\`~/\`)
|
||||
|
||||
Try it out! Type a message or command below.
|
||||
`));
|
||||
`),
|
||||
);
|
||||
|
||||
ui.addChild(header);
|
||||
ui.addChild(chatHistory);
|
||||
ui.addChild(editor);
|
||||
ui.setFocus(editor);
|
||||
ui.start();
|
||||
ui.start();
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import { test, describe } from "node:test";
|
||||
import assert from "node:assert";
|
||||
import { describe, test } from "node:test";
|
||||
import { Container, TextComponent, TextEditor, TUI } from "../src/index.js";
|
||||
import { VirtualTerminal } from "./virtual-terminal.js";
|
||||
import { TUI, Container, TextComponent, TextEditor } from "../src/index.js";
|
||||
|
||||
describe("Differential Rendering - Dynamic Content", () => {
|
||||
test("handles static text, dynamic container, and text editor correctly", async () => {
|
||||
|
|
@ -23,7 +23,7 @@ describe("Differential Rendering - Dynamic Content", () => {
|
|||
ui.setFocus(editor);
|
||||
|
||||
// Wait for next tick to complete and flush virtual terminal
|
||||
await new Promise(resolve => process.nextTick(resolve));
|
||||
await new Promise((resolve) => process.nextTick(resolve));
|
||||
await terminal.flush();
|
||||
|
||||
// Step 4: Check initial output in scrollbuffer
|
||||
|
|
@ -33,14 +33,16 @@ describe("Differential Rendering - Dynamic Content", () => {
|
|||
console.log("Initial render:");
|
||||
console.log("Viewport lines:", viewport.length);
|
||||
console.log("ScrollBuffer lines:", scrollBuffer.length);
|
||||
|
||||
|
||||
// Count non-empty lines in scrollbuffer
|
||||
let nonEmptyInBuffer = scrollBuffer.filter(line => line.trim() !== "").length;
|
||||
const nonEmptyInBuffer = scrollBuffer.filter((line) => line.trim() !== "").length;
|
||||
console.log("Non-empty lines in scrollbuffer:", nonEmptyInBuffer);
|
||||
|
||||
// Verify initial render has static text in scrollbuffer
|
||||
assert.ok(scrollBuffer.some(line => line.includes("Static Header Text")),
|
||||
`Expected static text in scrollbuffer`);
|
||||
assert.ok(
|
||||
scrollBuffer.some((line) => line.includes("Static Header Text")),
|
||||
`Expected static text in scrollbuffer`,
|
||||
);
|
||||
|
||||
// Step 5: Add 100 text components to container
|
||||
console.log("\nAdding 100 components to container...");
|
||||
|
|
@ -52,7 +54,7 @@ describe("Differential Rendering - Dynamic Content", () => {
|
|||
ui.requestRender();
|
||||
|
||||
// Wait for next tick to complete and flush
|
||||
await new Promise(resolve => process.nextTick(resolve));
|
||||
await new Promise((resolve) => process.nextTick(resolve));
|
||||
await terminal.flush();
|
||||
|
||||
// Step 6: Check output after adding 100 components
|
||||
|
|
@ -62,10 +64,10 @@ describe("Differential Rendering - Dynamic Content", () => {
|
|||
console.log("\nAfter adding 100 items:");
|
||||
console.log("Viewport lines:", viewport.length);
|
||||
console.log("ScrollBuffer lines:", scrollBuffer.length);
|
||||
|
||||
|
||||
// Count all dynamic items in scrollbuffer
|
||||
let dynamicItemsInBuffer = 0;
|
||||
let allItemNumbers = new Set<number>();
|
||||
const allItemNumbers = new Set<number>();
|
||||
for (const line of scrollBuffer) {
|
||||
const match = line.match(/Dynamic Item (\d+)/);
|
||||
if (match) {
|
||||
|
|
@ -73,31 +75,39 @@ describe("Differential Rendering - Dynamic Content", () => {
|
|||
allItemNumbers.add(parseInt(match[1]));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
console.log("Dynamic items found in scrollbuffer:", dynamicItemsInBuffer);
|
||||
console.log("Unique item numbers:", allItemNumbers.size);
|
||||
console.log("Item range:", Math.min(...allItemNumbers), "-", Math.max(...allItemNumbers));
|
||||
|
||||
|
||||
// CRITICAL TEST: The scrollbuffer should contain ALL 100 items
|
||||
// This is what the differential render should preserve!
|
||||
assert.strictEqual(allItemNumbers.size, 100,
|
||||
`Expected all 100 unique items in scrollbuffer, but found ${allItemNumbers.size}`);
|
||||
|
||||
assert.strictEqual(
|
||||
allItemNumbers.size,
|
||||
100,
|
||||
`Expected all 100 unique items in scrollbuffer, but found ${allItemNumbers.size}`,
|
||||
);
|
||||
|
||||
// Verify items are 1-100
|
||||
for (let i = 1; i <= 100; i++) {
|
||||
assert.ok(allItemNumbers.has(i), `Missing Dynamic Item ${i} in scrollbuffer`);
|
||||
}
|
||||
|
||||
// Also verify the static header is still in scrollbuffer
|
||||
assert.ok(scrollBuffer.some(line => line.includes("Static Header Text")),
|
||||
"Static header should still be in scrollbuffer");
|
||||
|
||||
// And the editor should be there too
|
||||
assert.ok(scrollBuffer.some(line => line.includes("╭") && line.includes("╮")),
|
||||
"Editor top border should be in scrollbuffer");
|
||||
assert.ok(scrollBuffer.some(line => line.includes("╰") && line.includes("╯")),
|
||||
"Editor bottom border should be in scrollbuffer");
|
||||
|
||||
// Also verify the static header is still in scrollbuffer
|
||||
assert.ok(
|
||||
scrollBuffer.some((line) => line.includes("Static Header Text")),
|
||||
"Static header should still be in scrollbuffer",
|
||||
);
|
||||
|
||||
// And the editor should be there too
|
||||
assert.ok(
|
||||
scrollBuffer.some((line) => line.includes("╭") && line.includes("╮")),
|
||||
"Editor top border should be in scrollbuffer",
|
||||
);
|
||||
assert.ok(
|
||||
scrollBuffer.some((line) => line.includes("╰") && line.includes("╯")),
|
||||
"Editor bottom border should be in scrollbuffer",
|
||||
);
|
||||
|
||||
ui.stop();
|
||||
});
|
||||
|
|
@ -124,7 +134,7 @@ describe("Differential Rendering - Dynamic Content", () => {
|
|||
contentContainer.addChild(new TextComponent("Content Line 2"));
|
||||
|
||||
// Initial render
|
||||
await new Promise(resolve => process.nextTick(resolve));
|
||||
await new Promise((resolve) => process.nextTick(resolve));
|
||||
await terminal.flush();
|
||||
|
||||
let viewport = terminal.getViewport();
|
||||
|
|
@ -142,7 +152,7 @@ describe("Differential Rendering - Dynamic Content", () => {
|
|||
statusContainer.addChild(new TextComponent("Status: Processing..."));
|
||||
ui.requestRender();
|
||||
|
||||
await new Promise(resolve => process.nextTick(resolve));
|
||||
await new Promise((resolve) => process.nextTick(resolve));
|
||||
await terminal.flush();
|
||||
|
||||
viewport = terminal.getViewport();
|
||||
|
|
@ -162,7 +172,7 @@ describe("Differential Rendering - Dynamic Content", () => {
|
|||
}
|
||||
ui.requestRender();
|
||||
|
||||
await new Promise(resolve => process.nextTick(resolve));
|
||||
await new Promise((resolve) => process.nextTick(resolve));
|
||||
await terminal.flush();
|
||||
|
||||
viewport = terminal.getViewport();
|
||||
|
|
@ -180,7 +190,7 @@ describe("Differential Rendering - Dynamic Content", () => {
|
|||
contentLine10.setText("Content Line 10 - MODIFIED");
|
||||
ui.requestRender();
|
||||
|
||||
await new Promise(resolve => process.nextTick(resolve));
|
||||
await new Promise((resolve) => process.nextTick(resolve));
|
||||
await terminal.flush();
|
||||
|
||||
viewport = terminal.getViewport();
|
||||
|
|
@ -190,4 +200,4 @@ describe("Differential Rendering - Dynamic Content", () => {
|
|||
|
||||
ui.stop();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import { TUI, SelectList } from "../src/index.js";
|
||||
import { readdirSync, statSync } from "fs";
|
||||
import { join } from "path";
|
||||
import { SelectList, TUI } from "../src/index.js";
|
||||
|
||||
const ui = new TUI();
|
||||
ui.start();
|
||||
|
|
@ -52,4 +52,4 @@ function showDirectory(path: string) {
|
|||
ui.setFocus(fileList);
|
||||
}
|
||||
|
||||
showDirectory(currentPath);
|
||||
showDirectory(currentPath);
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import { describe, test } from "node:test";
|
||||
import assert from "node:assert";
|
||||
import { TextEditor, TextComponent, Container, TUI } from "../src/index.js";
|
||||
import { describe, test } from "node:test";
|
||||
import { Container, TextComponent, TextEditor, TUI } from "../src/index.js";
|
||||
import { VirtualTerminal } from "./virtual-terminal.js";
|
||||
|
||||
describe("Layout shift artifacts", () => {
|
||||
|
|
@ -27,7 +27,7 @@ describe("Layout shift artifacts", () => {
|
|||
|
||||
// Initial render
|
||||
ui.start();
|
||||
await new Promise(resolve => process.nextTick(resolve));
|
||||
await new Promise((resolve) => process.nextTick(resolve));
|
||||
await term.flush();
|
||||
|
||||
// Capture initial state
|
||||
|
|
@ -40,7 +40,7 @@ describe("Layout shift artifacts", () => {
|
|||
ui.requestRender();
|
||||
|
||||
// Wait for render
|
||||
await new Promise(resolve => process.nextTick(resolve));
|
||||
await new Promise((resolve) => process.nextTick(resolve));
|
||||
await term.flush();
|
||||
|
||||
// Capture state with status message
|
||||
|
|
@ -51,7 +51,7 @@ describe("Layout shift artifacts", () => {
|
|||
ui.requestRender();
|
||||
|
||||
// Wait for render
|
||||
await new Promise(resolve => process.nextTick(resolve));
|
||||
await new Promise((resolve) => process.nextTick(resolve));
|
||||
await term.flush();
|
||||
|
||||
// Capture final state
|
||||
|
|
@ -64,8 +64,12 @@ describe("Layout shift artifacts", () => {
|
|||
const nextLine = finalViewport[i + 1];
|
||||
|
||||
// Check if we have duplicate bottom borders (the artifact)
|
||||
if (currentLine.includes("╰") && currentLine.includes("╯") &&
|
||||
nextLine.includes("╰") && nextLine.includes("╯")) {
|
||||
if (
|
||||
currentLine.includes("╰") &&
|
||||
currentLine.includes("╯") &&
|
||||
nextLine.includes("╰") &&
|
||||
nextLine.includes("╯")
|
||||
) {
|
||||
foundDuplicateBorder = true;
|
||||
}
|
||||
}
|
||||
|
|
@ -74,18 +78,12 @@ describe("Layout shift artifacts", () => {
|
|||
assert.strictEqual(foundDuplicateBorder, false, "Found duplicate bottom borders - rendering artifact detected!");
|
||||
|
||||
// Also check that there's only one bottom border total
|
||||
const bottomBorderCount = finalViewport.filter((line) =>
|
||||
line.includes("╰")
|
||||
).length;
|
||||
const bottomBorderCount = finalViewport.filter((line) => line.includes("╰")).length;
|
||||
assert.strictEqual(bottomBorderCount, 1, `Expected 1 bottom border, found ${bottomBorderCount}`);
|
||||
|
||||
// Verify the editor is back in its original position
|
||||
const finalEditorStartLine = finalViewport.findIndex((line) =>
|
||||
line.includes("╭")
|
||||
);
|
||||
const initialEditorStartLine = initialViewport.findIndex((line) =>
|
||||
line.includes("╭")
|
||||
);
|
||||
const finalEditorStartLine = finalViewport.findIndex((line) => line.includes("╭"));
|
||||
const initialEditorStartLine = initialViewport.findIndex((line) => line.includes("╭"));
|
||||
assert.strictEqual(finalEditorStartLine, initialEditorStartLine);
|
||||
|
||||
ui.stop();
|
||||
|
|
@ -103,7 +101,7 @@ describe("Layout shift artifacts", () => {
|
|||
|
||||
// Initial render
|
||||
ui.start();
|
||||
await new Promise(resolve => process.nextTick(resolve));
|
||||
await new Promise((resolve) => process.nextTick(resolve));
|
||||
await term.flush();
|
||||
|
||||
// Rapidly add and remove a status message
|
||||
|
|
@ -112,25 +110,21 @@ describe("Layout shift artifacts", () => {
|
|||
// Add status
|
||||
ui.children.splice(1, 0, status);
|
||||
ui.requestRender();
|
||||
await new Promise(resolve => process.nextTick(resolve));
|
||||
await new Promise((resolve) => process.nextTick(resolve));
|
||||
await term.flush();
|
||||
|
||||
// Remove status immediately
|
||||
ui.children.splice(1, 1);
|
||||
ui.requestRender();
|
||||
await new Promise(resolve => process.nextTick(resolve));
|
||||
await new Promise((resolve) => process.nextTick(resolve));
|
||||
await term.flush();
|
||||
|
||||
// Final output check
|
||||
const finalViewport = term.getViewport();
|
||||
|
||||
// Should only have one set of borders for the editor
|
||||
const topBorderCount = finalViewport.filter((line) =>
|
||||
line.includes("╭") && line.includes("╮")
|
||||
).length;
|
||||
const bottomBorderCount = finalViewport.filter((line) =>
|
||||
line.includes("╰") && line.includes("╯")
|
||||
).length;
|
||||
const topBorderCount = finalViewport.filter((line) => line.includes("╭") && line.includes("╮")).length;
|
||||
const bottomBorderCount = finalViewport.filter((line) => line.includes("╰") && line.includes("╯")).length;
|
||||
|
||||
assert.strictEqual(topBorderCount, 1);
|
||||
assert.strictEqual(bottomBorderCount, 1);
|
||||
|
|
@ -148,4 +142,4 @@ describe("Layout shift artifacts", () => {
|
|||
|
||||
ui.stop();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
#!/usr/bin/env npx tsx
|
||||
import { TUI, Container, TextComponent, TextEditor, MarkdownComponent } from "../src/index.js";
|
||||
import { Container, MarkdownComponent, TextComponent, TextEditor, TUI } from "../src/index.js";
|
||||
|
||||
/**
|
||||
* Multi-Component Layout Demo
|
||||
|
|
@ -75,4 +75,4 @@ ui.onGlobalKeyPress = (data: string) => {
|
|||
return true;
|
||||
};
|
||||
|
||||
ui.start();
|
||||
ui.start();
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import { test, describe } from "node:test";
|
||||
import assert from "node:assert";
|
||||
import { describe, test } from "node:test";
|
||||
import { Container, LoadingAnimation, MarkdownComponent, TextComponent, TextEditor, TUI } from "../src/index.js";
|
||||
import { VirtualTerminal } from "./virtual-terminal.js";
|
||||
import { TUI, Container, TextComponent, MarkdownComponent, TextEditor, LoadingAnimation } from "../src/index.js";
|
||||
|
||||
describe("Multi-Message Garbled Output Reproduction", () => {
|
||||
test("handles rapid message additions with large content without garbling", async () => {
|
||||
|
|
@ -20,7 +20,7 @@ describe("Multi-Message Garbled Output Reproduction", () => {
|
|||
ui.setFocus(editor);
|
||||
|
||||
// Initial render
|
||||
await new Promise(resolve => process.nextTick(resolve));
|
||||
await new Promise((resolve) => process.nextTick(resolve));
|
||||
await terminal.flush();
|
||||
|
||||
// Step 1: Simulate user message
|
||||
|
|
@ -32,7 +32,7 @@ describe("Multi-Message Garbled Output Reproduction", () => {
|
|||
statusContainer.addChild(loadingAnim);
|
||||
|
||||
ui.requestRender();
|
||||
await new Promise(resolve => process.nextTick(resolve));
|
||||
await new Promise((resolve) => process.nextTick(resolve));
|
||||
await terminal.flush();
|
||||
|
||||
// Step 3: Simulate rapid tool calls with large outputs
|
||||
|
|
@ -54,7 +54,7 @@ node_modules/get-tsconfig/README.md
|
|||
chatContainer.addChild(new TextComponent(globResult));
|
||||
|
||||
ui.requestRender();
|
||||
await new Promise(resolve => process.nextTick(resolve));
|
||||
await new Promise((resolve) => process.nextTick(resolve));
|
||||
await terminal.flush();
|
||||
|
||||
// Simulate multiple read tool calls with long content
|
||||
|
|
@ -74,7 +74,7 @@ A collection of tools for managing LLM deployments and building AI agents.
|
|||
chatContainer.addChild(new MarkdownComponent(readmeContent));
|
||||
|
||||
ui.requestRender();
|
||||
await new Promise(resolve => process.nextTick(resolve));
|
||||
await new Promise((resolve) => process.nextTick(resolve));
|
||||
await terminal.flush();
|
||||
|
||||
// Second read with even more content
|
||||
|
|
@ -94,7 +94,7 @@ Terminal UI framework with surgical differential rendering for building flicker-
|
|||
chatContainer.addChild(new MarkdownComponent(tuiReadmeContent));
|
||||
|
||||
ui.requestRender();
|
||||
await new Promise(resolve => process.nextTick(resolve));
|
||||
await new Promise((resolve) => process.nextTick(resolve));
|
||||
await terminal.flush();
|
||||
|
||||
// Step 4: Stop loading animation and add assistant response
|
||||
|
|
@ -114,7 +114,7 @@ The TUI library features surgical differential rendering that minimizes screen u
|
|||
chatContainer.addChild(new MarkdownComponent(assistantResponse));
|
||||
|
||||
ui.requestRender();
|
||||
await new Promise(resolve => process.nextTick(resolve));
|
||||
await new Promise((resolve) => process.nextTick(resolve));
|
||||
await terminal.flush();
|
||||
|
||||
// Step 5: CRITICAL - Send a new message while previous content is displayed
|
||||
|
|
@ -126,7 +126,7 @@ The TUI library features surgical differential rendering that minimizes screen u
|
|||
statusContainer.addChild(loadingAnim2);
|
||||
|
||||
ui.requestRender();
|
||||
await new Promise(resolve => process.nextTick(resolve));
|
||||
await new Promise((resolve) => process.nextTick(resolve));
|
||||
await terminal.flush();
|
||||
|
||||
// Add assistant response
|
||||
|
|
@ -144,7 +144,7 @@ Key aspects:
|
|||
chatContainer.addChild(new MarkdownComponent(secondResponse));
|
||||
|
||||
ui.requestRender();
|
||||
await new Promise(resolve => process.nextTick(resolve));
|
||||
await new Promise((resolve) => process.nextTick(resolve));
|
||||
await terminal.flush();
|
||||
|
||||
// Debug: Show the garbled output after the problematic step
|
||||
|
|
@ -153,19 +153,25 @@ Key aspects:
|
|||
debugOutput.forEach((line, i) => {
|
||||
if (line.trim()) console.log(`${i}: "${line}"`);
|
||||
});
|
||||
|
||||
|
||||
// Step 6: Check final output
|
||||
const finalOutput = terminal.getScrollBuffer();
|
||||
|
||||
// Check that first user message is NOT garbled
|
||||
const userLine1 = finalOutput.find(line => line.includes("read all README.md files"));
|
||||
assert.strictEqual(userLine1, "read all README.md files except in node_modules",
|
||||
`First user message is garbled: "${userLine1}"`);
|
||||
const userLine1 = finalOutput.find((line) => line.includes("read all README.md files"));
|
||||
assert.strictEqual(
|
||||
userLine1,
|
||||
"read all README.md files except in node_modules",
|
||||
`First user message is garbled: "${userLine1}"`,
|
||||
);
|
||||
|
||||
// Check that second user message is clean
|
||||
const userLine2 = finalOutput.find(line => line.includes("What is the main purpose"));
|
||||
assert.strictEqual(userLine2, "What is the main purpose of the TUI library?",
|
||||
`Second user message is garbled: "${userLine2}"`);
|
||||
const userLine2 = finalOutput.find((line) => line.includes("What is the main purpose"));
|
||||
assert.strictEqual(
|
||||
userLine2,
|
||||
"What is the main purpose of the TUI library?",
|
||||
`Second user message is garbled: "${userLine2}"`,
|
||||
);
|
||||
|
||||
// Check for common garbling patterns
|
||||
const garbledPatterns = [
|
||||
|
|
@ -173,14 +179,14 @@ Key aspects:
|
|||
"README.mdectly",
|
||||
"modulesl rendering",
|
||||
"[assistant]ns.",
|
||||
"node_modules/@esbuild/darwin-arm64/README.mdategy"
|
||||
"node_modules/@esbuild/darwin-arm64/README.mdategy",
|
||||
];
|
||||
|
||||
for (const pattern of garbledPatterns) {
|
||||
const hasGarbled = finalOutput.some(line => line.includes(pattern));
|
||||
const hasGarbled = finalOutput.some((line) => line.includes(pattern));
|
||||
assert.ok(!hasGarbled, `Found garbled pattern "${pattern}" in output`);
|
||||
}
|
||||
|
||||
ui.stop();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -1,18 +1,17 @@
|
|||
import { test, describe } from "node:test";
|
||||
import assert from "node:assert";
|
||||
import { VirtualTerminal } from "./virtual-terminal.js";
|
||||
import { describe, test } from "node:test";
|
||||
import {
|
||||
TUI,
|
||||
Container,
|
||||
TextComponent,
|
||||
TextEditor,
|
||||
WhitespaceComponent,
|
||||
MarkdownComponent,
|
||||
SelectList,
|
||||
TextComponent,
|
||||
TextEditor,
|
||||
TUI,
|
||||
WhitespaceComponent,
|
||||
} from "../src/index.js";
|
||||
import { VirtualTerminal } from "./virtual-terminal.js";
|
||||
|
||||
describe("TUI Rendering", () => {
|
||||
|
||||
test("renders single text component", async () => {
|
||||
const terminal = new VirtualTerminal(80, 24);
|
||||
const ui = new TUI(terminal);
|
||||
|
|
@ -22,7 +21,7 @@ describe("TUI Rendering", () => {
|
|||
ui.addChild(text);
|
||||
|
||||
// Wait for next tick for render to complete
|
||||
await new Promise(resolve => process.nextTick(resolve));
|
||||
await new Promise((resolve) => process.nextTick(resolve));
|
||||
|
||||
// Wait for writes to complete and get the rendered output
|
||||
const output = await terminal.flushAndGetViewport();
|
||||
|
|
@ -48,7 +47,7 @@ describe("TUI Rendering", () => {
|
|||
ui.addChild(new TextComponent("Line 3"));
|
||||
|
||||
// Wait for next tick for render to complete
|
||||
await new Promise(resolve => process.nextTick(resolve));
|
||||
await new Promise((resolve) => process.nextTick(resolve));
|
||||
|
||||
const output = await terminal.flushAndGetViewport();
|
||||
assert.strictEqual(output[0], "Line 1");
|
||||
|
|
@ -68,7 +67,7 @@ describe("TUI Rendering", () => {
|
|||
ui.addChild(new TextComponent("Bottom text"));
|
||||
|
||||
// Wait for next tick for render to complete
|
||||
await new Promise(resolve => process.nextTick(resolve));
|
||||
await new Promise((resolve) => process.nextTick(resolve));
|
||||
|
||||
const output = await terminal.flushAndGetViewport();
|
||||
assert.strictEqual(output[0], "Top text");
|
||||
|
|
@ -96,7 +95,7 @@ describe("TUI Rendering", () => {
|
|||
ui.addChild(new TextComponent("After container"));
|
||||
|
||||
// Wait for next tick for render to complete
|
||||
await new Promise(resolve => process.nextTick(resolve));
|
||||
await new Promise((resolve) => process.nextTick(resolve));
|
||||
|
||||
const output = await terminal.flushAndGetViewport();
|
||||
assert.strictEqual(output[0], "Before container");
|
||||
|
|
@ -117,11 +116,11 @@ describe("TUI Rendering", () => {
|
|||
ui.setFocus(editor);
|
||||
|
||||
// Wait for next tick for render to complete
|
||||
await new Promise(resolve => process.nextTick(resolve));
|
||||
await new Promise((resolve) => process.nextTick(resolve));
|
||||
|
||||
// Initial state - empty editor with cursor
|
||||
const output = await terminal.flushAndGetViewport();
|
||||
|
||||
|
||||
// Check that we have the border characters
|
||||
assert.ok(output[0].includes("╭"));
|
||||
assert.ok(output[0].includes("╮"));
|
||||
|
|
@ -142,7 +141,7 @@ describe("TUI Rendering", () => {
|
|||
ui.addChild(dynamicText);
|
||||
|
||||
// Wait for initial render
|
||||
await new Promise(resolve => process.nextTick(resolve));
|
||||
await new Promise((resolve) => process.nextTick(resolve));
|
||||
await terminal.flush();
|
||||
|
||||
// Save initial state
|
||||
|
|
@ -153,8 +152,8 @@ describe("TUI Rendering", () => {
|
|||
ui.requestRender();
|
||||
|
||||
// Wait for render
|
||||
await new Promise(resolve => process.nextTick(resolve));
|
||||
|
||||
await new Promise((resolve) => process.nextTick(resolve));
|
||||
|
||||
// Flush terminal buffer
|
||||
await terminal.flush();
|
||||
|
||||
|
|
@ -180,7 +179,7 @@ describe("TUI Rendering", () => {
|
|||
ui.addChild(text3);
|
||||
|
||||
// Wait for initial render
|
||||
await new Promise(resolve => process.nextTick(resolve));
|
||||
await new Promise((resolve) => process.nextTick(resolve));
|
||||
|
||||
let output = await terminal.flushAndGetViewport();
|
||||
assert.strictEqual(output[0], "Line 1");
|
||||
|
|
@ -191,7 +190,7 @@ describe("TUI Rendering", () => {
|
|||
ui.removeChild(text2);
|
||||
ui.requestRender();
|
||||
|
||||
await new Promise(resolve => setImmediate(resolve));
|
||||
await new Promise((resolve) => setImmediate(resolve));
|
||||
|
||||
output = await terminal.flushAndGetViewport();
|
||||
assert.strictEqual(output[0], "Line 1");
|
||||
|
|
@ -212,7 +211,7 @@ describe("TUI Rendering", () => {
|
|||
}
|
||||
|
||||
// Wait for next tick for render to complete
|
||||
await new Promise(resolve => process.nextTick(resolve));
|
||||
await new Promise((resolve) => process.nextTick(resolve));
|
||||
|
||||
const output = await terminal.flushAndGetViewport();
|
||||
|
||||
|
|
@ -241,7 +240,7 @@ describe("TUI Rendering", () => {
|
|||
ui.addChild(new TextComponent("After"));
|
||||
|
||||
// Wait for next tick for render to complete
|
||||
await new Promise(resolve => process.nextTick(resolve));
|
||||
await new Promise((resolve) => process.nextTick(resolve));
|
||||
|
||||
const output = await terminal.flushAndGetViewport();
|
||||
assert.strictEqual(output[0], "Before");
|
||||
|
|
@ -262,7 +261,7 @@ describe("TUI Rendering", () => {
|
|||
ui.addChild(markdown);
|
||||
|
||||
// Wait for next tick for render to complete
|
||||
await new Promise(resolve => process.nextTick(resolve));
|
||||
await new Promise((resolve) => process.nextTick(resolve));
|
||||
|
||||
const output = await terminal.flushAndGetViewport();
|
||||
// Should have formatted markdown
|
||||
|
|
@ -289,7 +288,7 @@ describe("TUI Rendering", () => {
|
|||
ui.setFocus(selectList);
|
||||
|
||||
// Wait for next tick for render to complete
|
||||
await new Promise(resolve => process.nextTick(resolve));
|
||||
await new Promise((resolve) => process.nextTick(resolve));
|
||||
|
||||
const output = await terminal.flushAndGetViewport();
|
||||
// First option should be selected (has → indicator)
|
||||
|
|
@ -303,28 +302,28 @@ describe("TUI Rendering", () => {
|
|||
|
||||
test("preserves existing terminal content when rendering", async () => {
|
||||
const terminal = new VirtualTerminal(80, 24);
|
||||
|
||||
|
||||
// Write some content to the terminal before starting TUI
|
||||
// This simulates having existing content in the scrollback buffer
|
||||
terminal.write("Previous command output line 1\r\n");
|
||||
terminal.write("Previous command output line 2\r\n");
|
||||
terminal.write("Some important information\r\n");
|
||||
terminal.write("Last line before TUI starts\r\n");
|
||||
|
||||
|
||||
// Flush to ensure writes are complete
|
||||
await terminal.flush();
|
||||
|
||||
|
||||
// Get the initial state with existing content
|
||||
const initialOutput = [...terminal.getViewport()];
|
||||
assert.strictEqual(initialOutput[0], "Previous command output line 1");
|
||||
assert.strictEqual(initialOutput[1], "Previous command output line 2");
|
||||
assert.strictEqual(initialOutput[2], "Some important information");
|
||||
assert.strictEqual(initialOutput[3], "Last line before TUI starts");
|
||||
|
||||
|
||||
// Now start the TUI with a text editor
|
||||
const ui = new TUI(terminal);
|
||||
ui.start();
|
||||
|
||||
|
||||
const editor = new TextEditor();
|
||||
let submittedText = "";
|
||||
editor.onSubmit = (text) => {
|
||||
|
|
@ -332,87 +331,87 @@ describe("TUI Rendering", () => {
|
|||
};
|
||||
ui.addChild(editor);
|
||||
ui.setFocus(editor);
|
||||
|
||||
|
||||
// Wait for initial render
|
||||
await new Promise(resolve => process.nextTick(resolve));
|
||||
await new Promise((resolve) => process.nextTick(resolve));
|
||||
await terminal.flush();
|
||||
|
||||
|
||||
// Check that the editor is rendered after the existing content
|
||||
const afterTuiStart = terminal.getViewport();
|
||||
|
||||
|
||||
// The existing content should still be visible above the editor
|
||||
assert.strictEqual(afterTuiStart[0], "Previous command output line 1");
|
||||
assert.strictEqual(afterTuiStart[1], "Previous command output line 2");
|
||||
assert.strictEqual(afterTuiStart[2], "Some important information");
|
||||
assert.strictEqual(afterTuiStart[3], "Last line before TUI starts");
|
||||
|
||||
|
||||
// The editor should appear after the existing content
|
||||
// The editor is 3 lines tall (top border, content line, bottom border)
|
||||
// Top border with box drawing characters filling the width (80 chars)
|
||||
assert.strictEqual(afterTuiStart[4][0], "╭");
|
||||
assert.strictEqual(afterTuiStart[4][78], "╮");
|
||||
|
||||
|
||||
// Content line should have the prompt
|
||||
assert.strictEqual(afterTuiStart[5].substring(0, 4), "│ > ");
|
||||
// And should end with vertical bar
|
||||
assert.strictEqual(afterTuiStart[5][78], "│");
|
||||
|
||||
|
||||
// Bottom border
|
||||
assert.strictEqual(afterTuiStart[6][0], "╰");
|
||||
assert.strictEqual(afterTuiStart[6][78], "╯");
|
||||
|
||||
|
||||
// Type some text into the editor
|
||||
terminal.sendInput("Hello World");
|
||||
|
||||
|
||||
// Wait for the input to be processed
|
||||
await new Promise(resolve => process.nextTick(resolve));
|
||||
await new Promise((resolve) => process.nextTick(resolve));
|
||||
await terminal.flush();
|
||||
|
||||
|
||||
// Check that text appears in the editor
|
||||
const afterTyping = terminal.getViewport();
|
||||
assert.strictEqual(afterTyping[0], "Previous command output line 1");
|
||||
assert.strictEqual(afterTyping[1], "Previous command output line 2");
|
||||
assert.strictEqual(afterTyping[2], "Some important information");
|
||||
assert.strictEqual(afterTyping[3], "Last line before TUI starts");
|
||||
|
||||
|
||||
// The editor content should show the typed text with the prompt ">"
|
||||
assert.strictEqual(afterTyping[5].substring(0, 15), "│ > Hello World");
|
||||
|
||||
|
||||
// Send SHIFT+ENTER to the editor (adds a new line)
|
||||
// According to text-editor.ts line 251, SHIFT+ENTER is detected as "\n" which calls addNewLine()
|
||||
terminal.sendInput("\n");
|
||||
|
||||
|
||||
// Wait for the input to be processed
|
||||
await new Promise(resolve => process.nextTick(resolve));
|
||||
await new Promise((resolve) => process.nextTick(resolve));
|
||||
await terminal.flush();
|
||||
|
||||
|
||||
// Check that existing content is still preserved after adding new line
|
||||
const afterNewLine = terminal.getViewport();
|
||||
assert.strictEqual(afterNewLine[0], "Previous command output line 1");
|
||||
assert.strictEqual(afterNewLine[1], "Previous command output line 2");
|
||||
assert.strictEqual(afterNewLine[2], "Some important information");
|
||||
assert.strictEqual(afterNewLine[3], "Last line before TUI starts");
|
||||
|
||||
|
||||
// Editor should now be 4 lines tall (top border, first line, second line, bottom border)
|
||||
// Top border at line 4
|
||||
assert.strictEqual(afterNewLine[4][0], "╭");
|
||||
assert.strictEqual(afterNewLine[4][78], "╮");
|
||||
|
||||
|
||||
// First line with text at line 5
|
||||
assert.strictEqual(afterNewLine[5].substring(0, 15), "│ > Hello World");
|
||||
assert.strictEqual(afterNewLine[5][78], "│");
|
||||
|
||||
|
||||
// Second line (empty, with continuation prompt " ") at line 6
|
||||
assert.strictEqual(afterNewLine[6].substring(0, 4), "│ ");
|
||||
assert.strictEqual(afterNewLine[6][78], "│");
|
||||
|
||||
|
||||
// Bottom border at line 7
|
||||
assert.strictEqual(afterNewLine[7][0], "╰");
|
||||
assert.strictEqual(afterNewLine[7][78], "╯");
|
||||
|
||||
|
||||
// Verify that onSubmit was NOT called (since we pressed SHIFT+ENTER, not plain ENTER)
|
||||
assert.strictEqual(submittedText, "");
|
||||
|
||||
|
||||
ui.stop();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import { test, describe } from "node:test";
|
||||
import assert from "node:assert";
|
||||
import { describe, test } from "node:test";
|
||||
import { VirtualTerminal } from "./virtual-terminal.js";
|
||||
|
||||
describe("VirtualTerminal", () => {
|
||||
|
|
@ -86,13 +86,13 @@ describe("VirtualTerminal", () => {
|
|||
assert.strictEqual(viewport.length, 10);
|
||||
assert.strictEqual(viewport[0], "Line 7");
|
||||
assert.strictEqual(viewport[8], "Line 15");
|
||||
assert.strictEqual(viewport[9], ""); // Last line is empty after the final \r\n
|
||||
assert.strictEqual(viewport[9], ""); // Last line is empty after the final \r\n
|
||||
|
||||
// Scroll buffer should have all lines
|
||||
assert.ok(scrollBuffer.length >= 15);
|
||||
// Check specific lines exist in the buffer
|
||||
const hasLine1 = scrollBuffer.some(line => line === "Line 1");
|
||||
const hasLine15 = scrollBuffer.some(line => line === "Line 15");
|
||||
const hasLine1 = scrollBuffer.some((line) => line === "Line 1");
|
||||
const hasLine15 = scrollBuffer.some((line) => line === "Line 15");
|
||||
assert.ok(hasLine1, "Buffer should contain 'Line 1'");
|
||||
assert.ok(hasLine15, "Buffer should contain 'Line 15'");
|
||||
});
|
||||
|
|
@ -129,9 +129,12 @@ describe("VirtualTerminal", () => {
|
|||
const terminal = new VirtualTerminal(80, 24);
|
||||
|
||||
let received = "";
|
||||
terminal.start((data) => {
|
||||
received = data;
|
||||
}, () => {});
|
||||
terminal.start(
|
||||
(data) => {
|
||||
received = data;
|
||||
},
|
||||
() => {},
|
||||
);
|
||||
|
||||
terminal.sendInput("a");
|
||||
assert.strictEqual(received, "a");
|
||||
|
|
@ -146,13 +149,16 @@ describe("VirtualTerminal", () => {
|
|||
const terminal = new VirtualTerminal(80, 24);
|
||||
|
||||
let resized = false;
|
||||
terminal.start(() => {}, () => {
|
||||
resized = true;
|
||||
});
|
||||
terminal.start(
|
||||
() => {},
|
||||
() => {
|
||||
resized = true;
|
||||
},
|
||||
);
|
||||
|
||||
terminal.resize(100, 30);
|
||||
assert.strictEqual(resized, true);
|
||||
|
||||
terminal.stop();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import xterm from '@xterm/headless';
|
||||
import type { Terminal as XtermTerminalType } from '@xterm/headless';
|
||||
import { Terminal } from '../src/terminal.js';
|
||||
import type { Terminal as XtermTerminalType } from "@xterm/headless";
|
||||
import xterm from "@xterm/headless";
|
||||
import type { Terminal } from "../src/terminal.js";
|
||||
|
||||
// Extract Terminal class from the module
|
||||
const XtermTerminal = xterm.Terminal;
|
||||
|
|
@ -81,7 +81,7 @@ export class VirtualTerminal implements Terminal {
|
|||
async flush(): Promise<void> {
|
||||
// Write an empty string to ensure all previous writes are flushed
|
||||
return new Promise<void>((resolve) => {
|
||||
this.xterm.write('', () => resolve());
|
||||
this.xterm.write("", () => resolve());
|
||||
});
|
||||
}
|
||||
|
||||
|
|
@ -107,7 +107,7 @@ export class VirtualTerminal implements Terminal {
|
|||
if (line) {
|
||||
lines.push(line.translateToString(true));
|
||||
} else {
|
||||
lines.push('');
|
||||
lines.push("");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -127,7 +127,7 @@ export class VirtualTerminal implements Terminal {
|
|||
if (line) {
|
||||
lines.push(line.translateToString(true));
|
||||
} else {
|
||||
lines.push('');
|
||||
lines.push("");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -155,7 +155,7 @@ export class VirtualTerminal implements Terminal {
|
|||
const buffer = this.xterm.buffer.active;
|
||||
return {
|
||||
x: buffer.cursorX,
|
||||
y: buffer.cursorY
|
||||
y: buffer.cursorY,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue