mirror of
https://github.com/getcompanion-ai/co-mono.git
synced 2026-04-21 23:04:41 +00:00
Massive refactor of API
- Switch to function based API - Anthropic SDK style async generator - Fully typed with escape hatches for custom models
This commit is contained in:
parent
004de3c9d0
commit
66cefb236e
29 changed files with 5835 additions and 6225 deletions
|
|
@ -28,6 +28,6 @@
|
||||||
"lineWidth": 120
|
"lineWidth": 120
|
||||||
},
|
},
|
||||||
"files": {
|
"files": {
|
||||||
"includes": ["packages/*/src/**/*", "*.json", "*.md"]
|
"includes": ["packages/*/src/**/*", "packages/*/test/**/*", "*.json", "*.md"]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@
|
||||||
import { writeFileSync } from "fs";
|
import { writeFileSync } from "fs";
|
||||||
import { join, dirname } from "path";
|
import { join, dirname } from "path";
|
||||||
import { fileURLToPath } from "url";
|
import { fileURLToPath } from "url";
|
||||||
|
import { Api, KnownProvider, Model } from "../src/types.js";
|
||||||
|
|
||||||
const __filename = fileURLToPath(import.meta.url);
|
const __filename = fileURLToPath(import.meta.url);
|
||||||
const __dirname = dirname(__filename);
|
const __dirname = dirname(__filename);
|
||||||
|
|
@ -28,30 +29,13 @@ interface ModelsDevModel {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
interface NormalizedModel {
|
async function fetchOpenRouterModels(): Promise<Model<any>[]> {
|
||||||
id: string;
|
|
||||||
name: string;
|
|
||||||
provider: string;
|
|
||||||
baseUrl?: string;
|
|
||||||
reasoning: boolean;
|
|
||||||
input: ("text" | "image")[];
|
|
||||||
cost: {
|
|
||||||
input: number;
|
|
||||||
output: number;
|
|
||||||
cacheRead: number;
|
|
||||||
cacheWrite: number;
|
|
||||||
};
|
|
||||||
contextWindow: number;
|
|
||||||
maxTokens: number;
|
|
||||||
}
|
|
||||||
|
|
||||||
async function fetchOpenRouterModels(): Promise<NormalizedModel[]> {
|
|
||||||
try {
|
try {
|
||||||
console.log("Fetching models from OpenRouter API...");
|
console.log("Fetching models from OpenRouter API...");
|
||||||
const response = await fetch("https://openrouter.ai/api/v1/models");
|
const response = await fetch("https://openrouter.ai/api/v1/models");
|
||||||
const data = await response.json();
|
const data = await response.json();
|
||||||
|
|
||||||
const models: NormalizedModel[] = [];
|
const models: Model<any>[] = [];
|
||||||
|
|
||||||
for (const model of data.data) {
|
for (const model of data.data) {
|
||||||
// Only include models that support tools
|
// Only include models that support tools
|
||||||
|
|
@ -59,27 +43,17 @@ async function fetchOpenRouterModels(): Promise<NormalizedModel[]> {
|
||||||
|
|
||||||
// Parse provider from model ID
|
// Parse provider from model ID
|
||||||
const [providerPrefix] = model.id.split("/");
|
const [providerPrefix] = model.id.split("/");
|
||||||
let provider = "";
|
let provider: KnownProvider = "openrouter";
|
||||||
let modelKey = model.id;
|
let modelKey = model.id;
|
||||||
|
|
||||||
// Skip models that we get from models.dev (Anthropic, Google, OpenAI)
|
// Skip models that we get from models.dev (Anthropic, Google, OpenAI)
|
||||||
if (model.id.startsWith("google/") ||
|
if (model.id.startsWith("google/") ||
|
||||||
model.id.startsWith("openai/") ||
|
model.id.startsWith("openai/") ||
|
||||||
model.id.startsWith("anthropic/")) {
|
model.id.startsWith("anthropic/") ||
|
||||||
continue;
|
model.id.startsWith("x-ai/")) {
|
||||||
} else if (model.id.startsWith("x-ai/")) {
|
|
||||||
provider = "xai";
|
|
||||||
modelKey = model.id.replace("x-ai/", "");
|
|
||||||
} else {
|
|
||||||
// All other models go through OpenRouter
|
|
||||||
provider = "openrouter";
|
|
||||||
modelKey = model.id; // Keep full ID for OpenRouter
|
|
||||||
}
|
|
||||||
|
|
||||||
// Skip if not one of our supported providers from OpenRouter
|
|
||||||
if (!["xai", "openrouter"].includes(provider)) {
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
modelKey = model.id; // Keep full ID for OpenRouter
|
||||||
|
|
||||||
// Parse input modalities
|
// Parse input modalities
|
||||||
const input: ("text" | "image")[] = ["text"];
|
const input: ("text" | "image")[] = ["text"];
|
||||||
|
|
@ -93,9 +67,11 @@ async function fetchOpenRouterModels(): Promise<NormalizedModel[]> {
|
||||||
const cacheReadCost = parseFloat(model.pricing?.input_cache_read || "0") * 1_000_000;
|
const cacheReadCost = parseFloat(model.pricing?.input_cache_read || "0") * 1_000_000;
|
||||||
const cacheWriteCost = parseFloat(model.pricing?.input_cache_write || "0") * 1_000_000;
|
const cacheWriteCost = parseFloat(model.pricing?.input_cache_write || "0") * 1_000_000;
|
||||||
|
|
||||||
const normalizedModel: NormalizedModel = {
|
const normalizedModel: Model<any> = {
|
||||||
id: modelKey,
|
id: modelKey,
|
||||||
name: model.name,
|
name: model.name,
|
||||||
|
api: "openai-completions",
|
||||||
|
baseUrl: "https://openrouter.ai/api/v1",
|
||||||
provider,
|
provider,
|
||||||
reasoning: model.supported_parameters?.includes("reasoning") || false,
|
reasoning: model.supported_parameters?.includes("reasoning") || false,
|
||||||
input,
|
input,
|
||||||
|
|
@ -108,14 +84,6 @@ async function fetchOpenRouterModels(): Promise<NormalizedModel[]> {
|
||||||
contextWindow: model.context_length || 4096,
|
contextWindow: model.context_length || 4096,
|
||||||
maxTokens: model.top_provider?.max_completion_tokens || 4096,
|
maxTokens: model.top_provider?.max_completion_tokens || 4096,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Add baseUrl for providers that need it
|
|
||||||
if (provider === "xai") {
|
|
||||||
normalizedModel.baseUrl = "https://api.x.ai/v1";
|
|
||||||
} else if (provider === "openrouter") {
|
|
||||||
normalizedModel.baseUrl = "https://openrouter.ai/api/v1";
|
|
||||||
}
|
|
||||||
|
|
||||||
models.push(normalizedModel);
|
models.push(normalizedModel);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -127,13 +95,13 @@ async function fetchOpenRouterModels(): Promise<NormalizedModel[]> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async function loadModelsDevData(): Promise<NormalizedModel[]> {
|
async function loadModelsDevData(): Promise<Model<any>[]> {
|
||||||
try {
|
try {
|
||||||
console.log("Fetching models from models.dev API...");
|
console.log("Fetching models from models.dev API...");
|
||||||
const response = await fetch("https://models.dev/api.json");
|
const response = await fetch("https://models.dev/api.json");
|
||||||
const data = await response.json();
|
const data = await response.json();
|
||||||
|
|
||||||
const models: NormalizedModel[] = [];
|
const models: Model<any>[] = [];
|
||||||
|
|
||||||
// Process Anthropic models
|
// Process Anthropic models
|
||||||
if (data.anthropic?.models) {
|
if (data.anthropic?.models) {
|
||||||
|
|
@ -144,7 +112,9 @@ async function loadModelsDevData(): Promise<NormalizedModel[]> {
|
||||||
models.push({
|
models.push({
|
||||||
id: modelId,
|
id: modelId,
|
||||||
name: m.name || modelId,
|
name: m.name || modelId,
|
||||||
|
api: "anthropic-messages",
|
||||||
provider: "anthropic",
|
provider: "anthropic",
|
||||||
|
baseUrl: "https://api.anthropic.com",
|
||||||
reasoning: m.reasoning === true,
|
reasoning: m.reasoning === true,
|
||||||
input: m.modalities?.input?.includes("image") ? ["text", "image"] : ["text"],
|
input: m.modalities?.input?.includes("image") ? ["text", "image"] : ["text"],
|
||||||
cost: {
|
cost: {
|
||||||
|
|
@ -168,7 +138,9 @@ async function loadModelsDevData(): Promise<NormalizedModel[]> {
|
||||||
models.push({
|
models.push({
|
||||||
id: modelId,
|
id: modelId,
|
||||||
name: m.name || modelId,
|
name: m.name || modelId,
|
||||||
|
api: "google-generative-ai",
|
||||||
provider: "google",
|
provider: "google",
|
||||||
|
baseUrl: "https://generativelanguage.googleapis.com/v1beta",
|
||||||
reasoning: m.reasoning === true,
|
reasoning: m.reasoning === true,
|
||||||
input: m.modalities?.input?.includes("image") ? ["text", "image"] : ["text"],
|
input: m.modalities?.input?.includes("image") ? ["text", "image"] : ["text"],
|
||||||
cost: {
|
cost: {
|
||||||
|
|
@ -192,7 +164,9 @@ async function loadModelsDevData(): Promise<NormalizedModel[]> {
|
||||||
models.push({
|
models.push({
|
||||||
id: modelId,
|
id: modelId,
|
||||||
name: m.name || modelId,
|
name: m.name || modelId,
|
||||||
|
api: "openai-responses",
|
||||||
provider: "openai",
|
provider: "openai",
|
||||||
|
baseUrl: "https://api.openai.com/v1",
|
||||||
reasoning: m.reasoning === true,
|
reasoning: m.reasoning === true,
|
||||||
input: m.modalities?.input?.includes("image") ? ["text", "image"] : ["text"],
|
input: m.modalities?.input?.includes("image") ? ["text", "image"] : ["text"],
|
||||||
cost: {
|
cost: {
|
||||||
|
|
@ -216,6 +190,7 @@ async function loadModelsDevData(): Promise<NormalizedModel[]> {
|
||||||
models.push({
|
models.push({
|
||||||
id: modelId,
|
id: modelId,
|
||||||
name: m.name || modelId,
|
name: m.name || modelId,
|
||||||
|
api: "openai-completions",
|
||||||
provider: "groq",
|
provider: "groq",
|
||||||
baseUrl: "https://api.groq.com/openai/v1",
|
baseUrl: "https://api.groq.com/openai/v1",
|
||||||
reasoning: m.reasoning === true,
|
reasoning: m.reasoning === true,
|
||||||
|
|
@ -241,6 +216,7 @@ async function loadModelsDevData(): Promise<NormalizedModel[]> {
|
||||||
models.push({
|
models.push({
|
||||||
id: modelId,
|
id: modelId,
|
||||||
name: m.name || modelId,
|
name: m.name || modelId,
|
||||||
|
api: "openai-completions",
|
||||||
provider: "cerebras",
|
provider: "cerebras",
|
||||||
baseUrl: "https://api.cerebras.ai/v1",
|
baseUrl: "https://api.cerebras.ai/v1",
|
||||||
reasoning: m.reasoning === true,
|
reasoning: m.reasoning === true,
|
||||||
|
|
@ -257,6 +233,32 @@ async function loadModelsDevData(): Promise<NormalizedModel[]> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Process xAi models
|
||||||
|
if (data.xai?.models) {
|
||||||
|
for (const [modelId, model] of Object.entries(data.xai.models)) {
|
||||||
|
const m = model as ModelsDevModel;
|
||||||
|
if (m.tool_call !== true) continue;
|
||||||
|
|
||||||
|
models.push({
|
||||||
|
id: modelId,
|
||||||
|
name: m.name || modelId,
|
||||||
|
api: "openai-completions",
|
||||||
|
provider: "xai",
|
||||||
|
baseUrl: "https://api.x.ai/v1",
|
||||||
|
reasoning: m.reasoning === true,
|
||||||
|
input: m.modalities?.input?.includes("image") ? ["text", "image"] : ["text"],
|
||||||
|
cost: {
|
||||||
|
input: m.cost?.input || 0,
|
||||||
|
output: m.cost?.output || 0,
|
||||||
|
cacheRead: m.cost?.cache_read || 0,
|
||||||
|
cacheWrite: m.cost?.cache_write || 0,
|
||||||
|
},
|
||||||
|
contextWindow: m.limit?.context || 4096,
|
||||||
|
maxTokens: m.limit?.output || 4096,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
console.log(`Loaded ${models.length} tool-capable models from models.dev`);
|
console.log(`Loaded ${models.length} tool-capable models from models.dev`);
|
||||||
return models;
|
return models;
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
|
@ -280,6 +282,8 @@ async function generateModels() {
|
||||||
allModels.push({
|
allModels.push({
|
||||||
id: "gpt-5-chat-latest",
|
id: "gpt-5-chat-latest",
|
||||||
name: "GPT-5 Chat Latest",
|
name: "GPT-5 Chat Latest",
|
||||||
|
api: "openai-responses",
|
||||||
|
baseUrl: "https://api.openai.com/v1",
|
||||||
provider: "openai",
|
provider: "openai",
|
||||||
reasoning: false,
|
reasoning: false,
|
||||||
input: ["text", "image"],
|
input: ["text", "image"],
|
||||||
|
|
@ -294,8 +298,29 @@ async function generateModels() {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add missing Grok models
|
||||||
|
if (!allModels.some(m => m.provider === "xai" && m.id === "grok-code-fast-1")) {
|
||||||
|
allModels.push({
|
||||||
|
id: "grok-code-fast-1",
|
||||||
|
name: "Grok Code Fast 1",
|
||||||
|
api: "openai-completions",
|
||||||
|
baseUrl: "https://api.x.ai/v1",
|
||||||
|
provider: "xai",
|
||||||
|
reasoning: false,
|
||||||
|
input: ["text"],
|
||||||
|
cost: {
|
||||||
|
input: 0.2,
|
||||||
|
output: 1.5,
|
||||||
|
cacheRead: 0.02,
|
||||||
|
cacheWrite: 0,
|
||||||
|
},
|
||||||
|
contextWindow: 32768,
|
||||||
|
maxTokens: 8192,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
// Group by provider and deduplicate by model ID
|
// Group by provider and deduplicate by model ID
|
||||||
const providers: Record<string, Record<string, NormalizedModel>> = {};
|
const providers: Record<string, Record<string, Model<any>>> = {};
|
||||||
for (const model of allModels) {
|
for (const model of allModels) {
|
||||||
if (!providers[model.provider]) {
|
if (!providers[model.provider]) {
|
||||||
providers[model.provider] = {};
|
providers[model.provider] = {};
|
||||||
|
|
@ -319,39 +344,33 @@ export const PROVIDERS = {
|
||||||
// Generate provider sections
|
// Generate provider sections
|
||||||
for (const [providerId, models] of Object.entries(providers)) {
|
for (const [providerId, models] of Object.entries(providers)) {
|
||||||
output += `\t${providerId}: {\n`;
|
output += `\t${providerId}: {\n`;
|
||||||
output += `\t\tmodels: {\n`;
|
|
||||||
|
|
||||||
for (const model of Object.values(models)) {
|
for (const model of Object.values(models)) {
|
||||||
output += `\t\t\t"${model.id}": {\n`;
|
output += `\t\t"${model.id}": {\n`;
|
||||||
output += `\t\t\t\tid: "${model.id}",\n`;
|
output += `\t\t\tid: "${model.id}",\n`;
|
||||||
output += `\t\t\t\tname: "${model.name}",\n`;
|
output += `\t\t\tname: "${model.name}",\n`;
|
||||||
output += `\t\t\t\tprovider: "${model.provider}",\n`;
|
output += `\t\t\tapi: "${model.api}",\n`;
|
||||||
|
output += `\t\t\tprovider: "${model.provider}",\n`;
|
||||||
if (model.baseUrl) {
|
if (model.baseUrl) {
|
||||||
output += `\t\t\t\tbaseUrl: "${model.baseUrl}",\n`;
|
output += `\t\t\tbaseUrl: "${model.baseUrl}",\n`;
|
||||||
}
|
}
|
||||||
output += `\t\t\t\treasoning: ${model.reasoning},\n`;
|
output += `\t\t\treasoning: ${model.reasoning},\n`;
|
||||||
output += `\t\t\t\tinput: ${JSON.stringify(model.input)},\n`;
|
output += `\t\t\tinput: [${model.input.map(i => `"${i}"`).join(", ")}],\n`;
|
||||||
output += `\t\t\t\tcost: {\n`;
|
output += `\t\t\tcost: {\n`;
|
||||||
output += `\t\t\t\t\tinput: ${model.cost.input},\n`;
|
output += `\t\t\t\tinput: ${model.cost.input},\n`;
|
||||||
output += `\t\t\t\t\toutput: ${model.cost.output},\n`;
|
output += `\t\t\t\toutput: ${model.cost.output},\n`;
|
||||||
output += `\t\t\t\t\tcacheRead: ${model.cost.cacheRead},\n`;
|
output += `\t\t\t\tcacheRead: ${model.cost.cacheRead},\n`;
|
||||||
output += `\t\t\t\t\tcacheWrite: ${model.cost.cacheWrite},\n`;
|
output += `\t\t\t\tcacheWrite: ${model.cost.cacheWrite},\n`;
|
||||||
output += `\t\t\t\t},\n`;
|
output += `\t\t\t},\n`;
|
||||||
output += `\t\t\t\tcontextWindow: ${model.contextWindow},\n`;
|
output += `\t\t\tcontextWindow: ${model.contextWindow},\n`;
|
||||||
output += `\t\t\t\tmaxTokens: ${model.maxTokens},\n`;
|
output += `\t\t\tmaxTokens: ${model.maxTokens},\n`;
|
||||||
output += `\t\t\t} satisfies Model,\n`;
|
output += `\t\t} satisfies Model<"${model.api}">,\n`;
|
||||||
}
|
}
|
||||||
|
|
||||||
output += `\t\t}\n`;
|
|
||||||
output += `\t},\n`;
|
output += `\t},\n`;
|
||||||
}
|
}
|
||||||
|
|
||||||
output += `} as const;
|
output += `} as const;
|
||||||
|
|
||||||
// Helper type to extract models for each provider
|
|
||||||
export type ProviderModels = {
|
|
||||||
[K in keyof typeof PROVIDERS]: typeof PROVIDERS[K]["models"]
|
|
||||||
};
|
|
||||||
`;
|
`;
|
||||||
|
|
||||||
// Write file
|
// Write file
|
||||||
|
|
|
||||||
|
|
@ -1,47 +1,43 @@
|
||||||
|
import { type AnthropicOptions, streamAnthropic } from "./providers/anthropic.js";
|
||||||
|
import { type GoogleOptions, streamGoogle } from "./providers/google.js";
|
||||||
|
import { type OpenAICompletionsOptions, streamOpenAICompletions } from "./providers/openai-completions.js";
|
||||||
|
import { type OpenAIResponsesOptions, streamOpenAIResponses } from "./providers/openai-responses.js";
|
||||||
import type {
|
import type {
|
||||||
Api,
|
Api,
|
||||||
AssistantMessage,
|
AssistantMessage,
|
||||||
AssistantMessageEvent,
|
AssistantMessageEvent,
|
||||||
Context,
|
Context,
|
||||||
GenerateFunction,
|
|
||||||
GenerateOptionsUnified,
|
|
||||||
GenerateStream,
|
GenerateStream,
|
||||||
KnownProvider,
|
KnownProvider,
|
||||||
Model,
|
Model,
|
||||||
|
OptionsForApi,
|
||||||
ReasoningEffort,
|
ReasoningEffort,
|
||||||
|
SimpleGenerateOptions,
|
||||||
} from "./types.js";
|
} from "./types.js";
|
||||||
|
|
||||||
export class QueuedGenerateStream implements GenerateStream {
|
export class QueuedGenerateStream implements GenerateStream {
|
||||||
private queue: AssistantMessageEvent[] = [];
|
private queue: AssistantMessageEvent[] = [];
|
||||||
private waiting: ((value: IteratorResult<AssistantMessageEvent>) => void)[] = [];
|
private waiting: ((value: IteratorResult<AssistantMessageEvent>) => void)[] = [];
|
||||||
private done = false;
|
private done = false;
|
||||||
private error?: Error;
|
|
||||||
private finalMessagePromise: Promise<AssistantMessage>;
|
private finalMessagePromise: Promise<AssistantMessage>;
|
||||||
private resolveFinalMessage!: (message: AssistantMessage) => void;
|
private resolveFinalMessage!: (message: AssistantMessage) => void;
|
||||||
private rejectFinalMessage!: (error: Error) => void;
|
|
||||||
|
|
||||||
constructor() {
|
constructor() {
|
||||||
this.finalMessagePromise = new Promise((resolve, reject) => {
|
this.finalMessagePromise = new Promise((resolve) => {
|
||||||
this.resolveFinalMessage = resolve;
|
this.resolveFinalMessage = resolve;
|
||||||
this.rejectFinalMessage = reject;
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
push(event: AssistantMessageEvent): void {
|
push(event: AssistantMessageEvent): void {
|
||||||
if (this.done) return;
|
if (this.done) return;
|
||||||
|
|
||||||
// If it's the done event, resolve the final message
|
|
||||||
if (event.type === "done") {
|
if (event.type === "done") {
|
||||||
this.done = true;
|
this.done = true;
|
||||||
this.resolveFinalMessage(event.message);
|
this.resolveFinalMessage(event.message);
|
||||||
}
|
}
|
||||||
|
|
||||||
// If it's an error event, reject the final message
|
|
||||||
if (event.type === "error") {
|
if (event.type === "error") {
|
||||||
this.error = new Error(event.error);
|
this.done = true;
|
||||||
if (!this.done) {
|
this.resolveFinalMessage(event.partial);
|
||||||
this.rejectFinalMessage(this.error);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deliver to waiting consumer or queue it
|
// Deliver to waiting consumer or queue it
|
||||||
|
|
@ -86,31 +82,14 @@ export class QueuedGenerateStream implements GenerateStream {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// API implementations registry
|
|
||||||
const apiImplementations: Map<Api | string, GenerateFunction> = new Map();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Register a custom API implementation
|
|
||||||
*/
|
|
||||||
export function registerApi(api: string, impl: GenerateFunction): void {
|
|
||||||
apiImplementations.set(api, impl);
|
|
||||||
}
|
|
||||||
|
|
||||||
// API key storage
|
|
||||||
const apiKeys: Map<string, string> = new Map();
|
const apiKeys: Map<string, string> = new Map();
|
||||||
|
|
||||||
/**
|
|
||||||
* Set an API key for a provider
|
|
||||||
*/
|
|
||||||
export function setApiKey(provider: KnownProvider, key: string): void;
|
export function setApiKey(provider: KnownProvider, key: string): void;
|
||||||
export function setApiKey(provider: string, key: string): void;
|
export function setApiKey(provider: string, key: string): void;
|
||||||
export function setApiKey(provider: any, key: string): void {
|
export function setApiKey(provider: any, key: string): void {
|
||||||
apiKeys.set(provider, key);
|
apiKeys.set(provider, key);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Get API key for a provider
|
|
||||||
*/
|
|
||||||
export function getApiKey(provider: KnownProvider): string | undefined;
|
export function getApiKey(provider: KnownProvider): string | undefined;
|
||||||
export function getApiKey(provider: string): string | undefined;
|
export function getApiKey(provider: string): string | undefined;
|
||||||
export function getApiKey(provider: any): string | undefined {
|
export function getApiKey(provider: any): string | undefined {
|
||||||
|
|
@ -133,45 +112,76 @@ export function getApiKey(provider: any): string | undefined {
|
||||||
return envVar ? process.env[envVar] : undefined;
|
return envVar ? process.env[envVar] : undefined;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
export function stream<TApi extends Api>(
|
||||||
* Main generate function
|
model: Model<TApi>,
|
||||||
*/
|
context: Context,
|
||||||
export function generate(model: Model, context: Context, options?: GenerateOptionsUnified): GenerateStream {
|
options?: OptionsForApi<TApi>,
|
||||||
// Get implementation
|
): GenerateStream {
|
||||||
const impl = apiImplementations.get(model.api);
|
const apiKey = options?.apiKey || getApiKey(model.provider);
|
||||||
if (!impl) {
|
if (!apiKey) {
|
||||||
throw new Error(`Unsupported API: ${model.api}`);
|
throw new Error(`No API key for provider: ${model.provider}`);
|
||||||
}
|
}
|
||||||
|
const providerOptions = { ...options, apiKey };
|
||||||
|
|
||||||
// Get API key from options or environment
|
const api: Api = model.api;
|
||||||
|
switch (api) {
|
||||||
|
case "anthropic-messages":
|
||||||
|
return streamAnthropic(model as Model<"anthropic-messages">, context, providerOptions);
|
||||||
|
|
||||||
|
case "openai-completions":
|
||||||
|
return streamOpenAICompletions(model as Model<"openai-completions">, context, providerOptions as any);
|
||||||
|
|
||||||
|
case "openai-responses":
|
||||||
|
return streamOpenAIResponses(model as Model<"openai-responses">, context, providerOptions as any);
|
||||||
|
|
||||||
|
case "google-generative-ai":
|
||||||
|
return streamGoogle(model as Model<"google-generative-ai">, context, providerOptions);
|
||||||
|
|
||||||
|
default: {
|
||||||
|
// This should never be reached if all Api cases are handled
|
||||||
|
const _exhaustive: never = api;
|
||||||
|
throw new Error(`Unhandled API: ${_exhaustive}`);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function complete<TApi extends Api>(
|
||||||
|
model: Model<TApi>,
|
||||||
|
context: Context,
|
||||||
|
options?: OptionsForApi<TApi>,
|
||||||
|
): Promise<AssistantMessage> {
|
||||||
|
const s = stream(model, context, options);
|
||||||
|
return s.finalMessage();
|
||||||
|
}
|
||||||
|
|
||||||
|
export function streamSimple<TApi extends Api>(
|
||||||
|
model: Model<TApi>,
|
||||||
|
context: Context,
|
||||||
|
options?: SimpleGenerateOptions,
|
||||||
|
): GenerateStream {
|
||||||
const apiKey = options?.apiKey || getApiKey(model.provider);
|
const apiKey = options?.apiKey || getApiKey(model.provider);
|
||||||
if (!apiKey) {
|
if (!apiKey) {
|
||||||
throw new Error(`No API key for provider: ${model.provider}`);
|
throw new Error(`No API key for provider: ${model.provider}`);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Map generic options to provider-specific
|
const providerOptions = mapOptionsForApi(model, options, apiKey);
|
||||||
const providerOptions = mapOptionsForApi(model.api, model, options, apiKey);
|
return stream(model, context, providerOptions);
|
||||||
|
|
||||||
// Return the GenerateStream from implementation
|
|
||||||
return impl(model, context, providerOptions);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
export async function completeSimple<TApi extends Api>(
|
||||||
* Helper to generate and get complete response (no streaming)
|
model: Model<TApi>,
|
||||||
*/
|
|
||||||
export async function generateComplete(
|
|
||||||
model: Model,
|
|
||||||
context: Context,
|
context: Context,
|
||||||
options?: GenerateOptionsUnified,
|
options?: SimpleGenerateOptions,
|
||||||
): Promise<AssistantMessage> {
|
): Promise<AssistantMessage> {
|
||||||
const stream = generate(model, context, options);
|
const s = streamSimple(model, context, options);
|
||||||
return stream.finalMessage();
|
return s.finalMessage();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
function mapOptionsForApi<TApi extends Api>(
|
||||||
* Map generic options to provider-specific options
|
model: Model<TApi>,
|
||||||
*/
|
options?: SimpleGenerateOptions,
|
||||||
function mapOptionsForApi(api: Api | string, model: Model, options?: GenerateOptionsUnified, apiKey?: string): any {
|
apiKey?: string,
|
||||||
|
): OptionsForApi<TApi> {
|
||||||
const base = {
|
const base = {
|
||||||
temperature: options?.temperature,
|
temperature: options?.temperature,
|
||||||
maxTokens: options?.maxTokens,
|
maxTokens: options?.maxTokens,
|
||||||
|
|
@ -179,18 +189,10 @@ function mapOptionsForApi(api: Api | string, model: Model, options?: GenerateOpt
|
||||||
apiKey: apiKey || options?.apiKey,
|
apiKey: apiKey || options?.apiKey,
|
||||||
};
|
};
|
||||||
|
|
||||||
switch (api) {
|
switch (model.api) {
|
||||||
case "openai-responses":
|
|
||||||
case "openai-completions":
|
|
||||||
return {
|
|
||||||
...base,
|
|
||||||
reasoning_effort: options?.reasoning,
|
|
||||||
};
|
|
||||||
|
|
||||||
case "anthropic-messages": {
|
case "anthropic-messages": {
|
||||||
if (!options?.reasoning) return base;
|
if (!options?.reasoning) return base satisfies AnthropicOptions;
|
||||||
|
|
||||||
// Map effort to token budget
|
|
||||||
const anthropicBudgets = {
|
const anthropicBudgets = {
|
||||||
minimal: 1024,
|
minimal: 1024,
|
||||||
low: 2048,
|
low: 2048,
|
||||||
|
|
@ -200,55 +202,60 @@ function mapOptionsForApi(api: Api | string, model: Model, options?: GenerateOpt
|
||||||
|
|
||||||
return {
|
return {
|
||||||
...base,
|
...base,
|
||||||
thinking: {
|
thinkingEnabled: true,
|
||||||
enabled: true,
|
thinkingBudgetTokens: anthropicBudgets[options.reasoning],
|
||||||
budgetTokens: anthropicBudgets[options.reasoning],
|
} satisfies AnthropicOptions;
|
||||||
},
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
case "google-generative-ai": {
|
|
||||||
if (!options?.reasoning) return { ...base, thinking_budget: -1 };
|
|
||||||
|
|
||||||
// Model-specific mapping for Google
|
case "openai-completions":
|
||||||
const googleBudget = getGoogleBudget(model, options.reasoning);
|
|
||||||
return {
|
return {
|
||||||
...base,
|
...base,
|
||||||
thinking_budget: googleBudget,
|
reasoningEffort: options?.reasoning,
|
||||||
};
|
} satisfies OpenAICompletionsOptions;
|
||||||
|
|
||||||
|
case "openai-responses":
|
||||||
|
return {
|
||||||
|
...base,
|
||||||
|
reasoningEffort: options?.reasoning,
|
||||||
|
} satisfies OpenAIResponsesOptions;
|
||||||
|
|
||||||
|
case "google-generative-ai": {
|
||||||
|
if (!options?.reasoning) return base as any;
|
||||||
|
|
||||||
|
const googleBudget = getGoogleBudget(model as Model<"google-generative-ai">, options.reasoning);
|
||||||
|
return {
|
||||||
|
...base,
|
||||||
|
thinking: {
|
||||||
|
enabled: true,
|
||||||
|
budgetTokens: googleBudget,
|
||||||
|
},
|
||||||
|
} satisfies GoogleOptions;
|
||||||
|
}
|
||||||
|
|
||||||
|
default: {
|
||||||
|
// Exhaustiveness check
|
||||||
|
const _exhaustive: never = model.api;
|
||||||
|
throw new Error(`Unhandled API in mapOptionsForApi: ${_exhaustive}`);
|
||||||
}
|
}
|
||||||
default:
|
|
||||||
return base;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
function getGoogleBudget(model: Model<"google-generative-ai">, effort: ReasoningEffort): number {
|
||||||
* Get Google thinking budget based on model and effort
|
// See https://ai.google.dev/gemini-api/docs/thinking#set-budget
|
||||||
*/
|
if (model.id.includes("2.5-pro")) {
|
||||||
function getGoogleBudget(model: Model, effort: ReasoningEffort): number {
|
|
||||||
// Model-specific logic
|
|
||||||
if (model.id.includes("flash-lite")) {
|
|
||||||
const budgets = {
|
|
||||||
minimal: 512,
|
|
||||||
low: 2048,
|
|
||||||
medium: 8192,
|
|
||||||
high: 24576,
|
|
||||||
};
|
|
||||||
return budgets[effort];
|
|
||||||
}
|
|
||||||
|
|
||||||
if (model.id.includes("pro")) {
|
|
||||||
const budgets = {
|
const budgets = {
|
||||||
minimal: 128,
|
minimal: 128,
|
||||||
low: 2048,
|
low: 2048,
|
||||||
medium: 8192,
|
medium: 8192,
|
||||||
high: Math.min(25000, 32768),
|
high: 32768,
|
||||||
};
|
};
|
||||||
return budgets[effort];
|
return budgets[effort];
|
||||||
}
|
}
|
||||||
|
|
||||||
if (model.id.includes("flash")) {
|
if (model.id.includes("2.5-flash")) {
|
||||||
|
// Covers 2.5-flash-lite as well
|
||||||
const budgets = {
|
const budgets = {
|
||||||
minimal: 0, // Disable thinking
|
minimal: 128,
|
||||||
low: 2048,
|
low: 2048,
|
||||||
medium: 8192,
|
medium: 8192,
|
||||||
high: 24576,
|
high: 24576,
|
||||||
|
|
@ -259,10 +266,3 @@ function getGoogleBudget(model: Model, effort: ReasoningEffort): number {
|
||||||
// Unknown model - use dynamic
|
// Unknown model - use dynamic
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register built-in API implementations
|
|
||||||
// Import the new function-based implementations
|
|
||||||
import { generateAnthropic } from "./providers/anthropic-generate.js";
|
|
||||||
|
|
||||||
// Register Anthropic implementation
|
|
||||||
apiImplementations.set("anthropic-messages", generateAnthropic);
|
|
||||||
|
|
|
||||||
|
|
@ -1,37 +1,8 @@
|
||||||
// @mariozechner/pi-ai - Unified LLM API with automatic model discovery
|
export * from "./generate.js";
|
||||||
// This package provides a common interface for working with multiple LLM providers
|
export * from "./models.generated.js";
|
||||||
|
export * from "./models.js";
|
||||||
export const version = "0.5.8";
|
export * from "./providers/anthropic.js";
|
||||||
|
export * from "./providers/google.js";
|
||||||
// Export generate API
|
export * from "./providers/openai-completions.js";
|
||||||
export {
|
export * from "./providers/openai-responses.js";
|
||||||
generate,
|
export * from "./types.js";
|
||||||
generateComplete,
|
|
||||||
getApiKey,
|
|
||||||
QueuedGenerateStream,
|
|
||||||
registerApi,
|
|
||||||
setApiKey,
|
|
||||||
} from "./generate.js";
|
|
||||||
// Export generated models data
|
|
||||||
export { PROVIDERS } from "./models.generated.js";
|
|
||||||
// Export model utilities
|
|
||||||
export {
|
|
||||||
calculateCost,
|
|
||||||
getModel,
|
|
||||||
type KnownProvider,
|
|
||||||
registerModel,
|
|
||||||
} from "./models.js";
|
|
||||||
|
|
||||||
// Legacy providers (to be deprecated)
|
|
||||||
export { AnthropicLLM } from "./providers/anthropic.js";
|
|
||||||
export { GoogleLLM } from "./providers/google.js";
|
|
||||||
export { OpenAICompletionsLLM } from "./providers/openai-completions.js";
|
|
||||||
export { OpenAIResponsesLLM } from "./providers/openai-responses.js";
|
|
||||||
|
|
||||||
// Export types
|
|
||||||
export type * from "./types.js";
|
|
||||||
|
|
||||||
// TODO: Remove these legacy exports once consumers are updated
|
|
||||||
export function createLLM(): never {
|
|
||||||
throw new Error("createLLM is deprecated. Use generate() with getModel() instead.");
|
|
||||||
}
|
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load diff
|
|
@ -1,44 +1,39 @@
|
||||||
import { PROVIDERS } from "./models.generated.js";
|
import { PROVIDERS } from "./models.generated.js";
|
||||||
import type { KnownProvider, Model, Usage } from "./types.js";
|
import type { Api, KnownProvider, Model, Usage } from "./types.js";
|
||||||
|
|
||||||
// Re-export Model type
|
// Provider name -> (model id -> model definition).
const modelRegistry = new Map<string, Map<string, Model<Api>>>();

// Seed the registry once, at module load, from the generated PROVIDERS table.
for (const [providerName, providerCatalog] of Object.entries(PROVIDERS)) {
	const entries = Object.entries(providerCatalog).map(
		([modelId, definition]): [string, Model<Api>] => [modelId, definition as Model<Api>],
	);
	modelRegistry.set(providerName, new Map(entries));
}
|
||||||
|
|
||||||
/**
 * Resolves the `api` literal declared by a generated PROVIDERS entry.
 * Evaluates to `never` when the entry has no `api` field or the field is not an `Api`.
 */
type ModelApi<
	TProvider extends KnownProvider,
	TModelId extends keyof (typeof PROVIDERS)[TProvider],
> = (typeof PROVIDERS)[TProvider][TModelId] extends { api: infer TDeclared }
	? TDeclared extends Api
		? TDeclared
		: never
	: never;

/**
 * Looks up a model in the registry.
 *
 * Known provider / model id pairs return a `Model` typed with that entry's API.
 * Arbitrary string pairs return `Model<TApi> | undefined`; the caller picks `TApi`,
 * which is not checked at runtime.
 */
export function getModel<TProvider extends KnownProvider, TModelId extends keyof (typeof PROVIDERS)[TProvider]>(
	provider: TProvider,
	modelId: TModelId,
): Model<ModelApi<TProvider, TModelId>>;
export function getModel<TApi extends Api>(provider: string, modelId: string): Model<TApi> | undefined;
export function getModel<TApi extends Api>(provider: any, modelId: any): Model<TApi> | undefined {
	const providerModels = modelRegistry.get(provider);
	const match = providerModels?.get(modelId);
	return match as Model<TApi> | undefined;
}
|
||||||
|
|
||||||
/**
 * Adds (or replaces) a model in the registry under `model.provider` / `model.id`.
 * A provider bucket is created on first use.
 */
export function registerModel<TApi extends Api>(model: Model<TApi>): void {
	let providerModels = modelRegistry.get(model.provider);
	if (providerModels === undefined) {
		providerModels = new Map();
		modelRegistry.set(model.provider, providerModels);
	}
	providerModels.set(model.id, model);
}
|
||||||
|
|
||||||
/**
|
export function calculateCost<TApi extends Api>(model: Model<TApi>, usage: Usage): Usage["cost"] {
|
||||||
* Calculate cost for token usage
|
|
||||||
*/
|
|
||||||
export function calculateCost(model: Model, usage: Usage): Usage["cost"] {
|
|
||||||
usage.cost.input = (model.cost.input / 1000000) * usage.input;
|
usage.cost.input = (model.cost.input / 1000000) * usage.input;
|
||||||
usage.cost.output = (model.cost.output / 1000000) * usage.output;
|
usage.cost.output = (model.cost.output / 1000000) * usage.output;
|
||||||
usage.cost.cacheRead = (model.cost.cacheRead / 1000000) * usage.cacheRead;
|
usage.cost.cacheRead = (model.cost.cacheRead / 1000000) * usage.cacheRead;
|
||||||
|
|
|
||||||
|
|
@ -1,425 +0,0 @@
|
||||||
import Anthropic from "@anthropic-ai/sdk";
|
|
||||||
import type {
|
|
||||||
ContentBlockParam,
|
|
||||||
MessageCreateParamsStreaming,
|
|
||||||
MessageParam,
|
|
||||||
Tool,
|
|
||||||
} from "@anthropic-ai/sdk/resources/messages.js";
|
|
||||||
import { QueuedGenerateStream } from "../generate.js";
|
|
||||||
import { calculateCost } from "../models.js";
|
|
||||||
import type {
|
|
||||||
Api,
|
|
||||||
AssistantMessage,
|
|
||||||
Context,
|
|
||||||
GenerateFunction,
|
|
||||||
GenerateOptions,
|
|
||||||
GenerateStream,
|
|
||||||
Message,
|
|
||||||
Model,
|
|
||||||
StopReason,
|
|
||||||
TextContent,
|
|
||||||
ThinkingContent,
|
|
||||||
ToolCall,
|
|
||||||
} from "../types.js";
|
|
||||||
import { transformMessages } from "./utils.js";
|
|
||||||
|
|
||||||
/** Request options accepted by the Anthropic generate function, on top of the shared GenerateOptions. */
export interface AnthropicOptions extends GenerateOptions {
	/** Extended-thinking settings; `budgetTokens` is optional. */
	thinking?: { enabled: boolean; budgetTokens?: number };
	/** Either a mode name or an object naming one specific tool. */
	toolChoice?: "auto" | "any" | "none" | { type: "tool"; name: string };
}
|
|
||||||
|
|
||||||
/**
 * Generate function for the Anthropic Messages API.
 *
 * Returns a QueuedGenerateStream immediately and fills it from a detached async task.
 * The task pushes `start`, then `text_*` / `thinking_*` events as content streams in,
 * one `toolCall` event per completed tool call, and finally `done`. Any thrown error is
 * reported as a single `error` event instead of rejecting; the stream is ended in both cases.
 *
 * @param model   Model to call; `id`, `provider` and `baseUrl` are read from it.
 * @param context Conversation (messages, optional system prompt and tools).
 * @param options Request options; `apiKey` is required in practice (it is non-null asserted).
 * @returns The stream that receives the events described above.
 */
export const generateAnthropic: GenerateFunction<AnthropicOptions> = (
	model: Model,
	context: Context,
	options: AnthropicOptions,
): GenerateStream => {
	const stream = new QueuedGenerateStream();

	// Start async processing; the caller consumes `stream` while this runs.
	(async () => {
		// The same object is attached to every event as `partial` and is mutated in place.
		const output: AssistantMessage = {
			role: "assistant",
			content: [],
			api: "anthropic-messages" as Api,
			provider: model.provider,
			model: model.id,
			usage: {
				input: 0,
				output: 0,
				cacheRead: 0,
				cacheWrite: 0,
				cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
			},
			stopReason: "stop",
		};

		try {
			const client = createAnthropicClient(model, options.apiKey!);
			const messages = convertMessages(context.messages, model, "anthropic-messages");
			const params = buildAnthropicParams(model, context, options, messages, client.isOAuthToken);

			const anthropicStream = client.client.messages.stream(
				{
					...params,
					stream: true,
				},
				{
					signal: options.signal,
				},
			);

			stream.push({
				type: "start",
				partial: output,
			});

			// Block currently being streamed. Tool calls carry the raw JSON received so far in `partialJson`.
			let currentBlock: ThinkingContent | TextContent | (ToolCall & { partialJson: string }) | null = null;

			for await (const event of anthropicStream) {
				if (event.type === "content_block_start") {
					if (event.content_block.type === "text") {
						currentBlock = {
							type: "text",
							text: "",
						};
						output.content.push(currentBlock);
						stream.push({ type: "text_start", partial: output });
					} else if (event.content_block.type === "thinking") {
						currentBlock = {
							type: "thinking",
							thinking: "",
							thinkingSignature: "",
						};
						output.content.push(currentBlock);
						stream.push({ type: "thinking_start", partial: output });
					} else if (event.content_block.type === "tool_use") {
						// We wait for the full tool use to be streamed, so nothing is pushed
						// to `output.content` or to the stream until content_block_stop.
						currentBlock = {
							type: "toolCall",
							id: event.content_block.id,
							name: event.content_block.name,
							arguments: event.content_block.input as Record<string, any>,
							partialJson: "",
						};
					}
				} else if (event.type === "content_block_delta") {
					if (event.delta.type === "text_delta") {
						if (currentBlock && currentBlock.type === "text") {
							currentBlock.text += event.delta.text;
							stream.push({
								type: "text_delta",
								delta: event.delta.text,
								partial: output,
							});
						}
					} else if (event.delta.type === "thinking_delta") {
						if (currentBlock && currentBlock.type === "thinking") {
							currentBlock.thinking += event.delta.thinking;
							stream.push({
								type: "thinking_delta",
								delta: event.delta.thinking,
								partial: output,
							});
						}
					} else if (event.delta.type === "input_json_delta") {
						if (currentBlock && currentBlock.type === "toolCall") {
							currentBlock.partialJson += event.delta.partial_json;
						}
					} else if (event.delta.type === "signature_delta") {
						if (currentBlock && currentBlock.type === "thinking") {
							currentBlock.thinkingSignature = currentBlock.thinkingSignature || "";
							currentBlock.thinkingSignature += event.delta.signature;
						}
					}
				} else if (event.type === "content_block_stop") {
					if (currentBlock) {
						if (currentBlock.type === "text") {
							stream.push({ type: "text_end", content: currentBlock.text, partial: output });
						} else if (currentBlock.type === "thinking") {
							stream.push({ type: "thinking_end", content: currentBlock.thinking, partial: output });
						} else if (currentBlock.type === "toolCall") {
							// A tool invoked without arguments may stream no JSON at all. JSON.parse("")
							// would throw and turn the whole response into an error, so fall back to
							// the input captured at content_block_start.
							const rawArguments = currentBlock.partialJson.trim();
							const finalToolCall: ToolCall = {
								type: "toolCall",
								id: currentBlock.id,
								name: currentBlock.name,
								arguments: rawArguments.length > 0 ? JSON.parse(rawArguments) : currentBlock.arguments,
							};
							output.content.push(finalToolCall);
							stream.push({ type: "toolCall", toolCall: finalToolCall, partial: output });
						}
						currentBlock = null;
					}
				} else if (event.type === "message_delta") {
					if (event.delta.stop_reason) {
						output.stopReason = mapStopReason(event.delta.stop_reason);
					}
					output.usage.input += event.usage.input_tokens || 0;
					output.usage.output += event.usage.output_tokens || 0;
					output.usage.cacheRead += event.usage.cache_read_input_tokens || 0;
					output.usage.cacheWrite += event.usage.cache_creation_input_tokens || 0;
					calculateCost(model, output.usage);
				}
			}

			// Emit done event with the final message.
			stream.push({ type: "done", reason: output.stopReason, message: output });
			stream.end();
		} catch (error) {
			// Failures are surfaced through the stream rather than as a rejected promise.
			output.stopReason = "error";
			output.error = error instanceof Error ? error.message : JSON.stringify(error);
			stream.push({ type: "error", error: output.error, partial: output });
			stream.end();
		}
	})();

	return stream;
};
|
|
||||||
|
|
||||||
// Helper to create Anthropic client
|
|
||||||
/** An Anthropic SDK client paired with a flag for whether it authenticates via OAuth. */
interface AnthropicClientWrapper {
	/** Anthropic SDK instance. */
	client: Anthropic;
	/** Whether `client` uses OAuth-token authentication. */
	isOAuthToken: boolean;
}
|
|
||||||
|
|
||||||
/**
 * Builds an Anthropic SDK client for `model`, picking the auth scheme from the key's shape.
 *
 * Keys containing "sk-ant-oat" are treated as OAuth tokens: they are sent as `authToken`
 * (with `apiKey: null`) and the OAuth beta flag is added to the `anthropic-beta` header.
 * Any other key is passed as a regular `apiKey`.
 *
 * Side effect: in the OAuth case, ANTHROPIC_API_KEY is removed from `process.env`
 * when a `process` global exists.
 *
 * @param model  Model whose `baseUrl` becomes the client's base URL.
 * @param apiKey API key or OAuth token.
 * @returns The client plus a flag telling callers whether OAuth auth is in use.
 */
function createAnthropicClient(model: Model, apiKey: string): AnthropicClientWrapper {
	const isOAuthToken = apiKey.includes("sk-ant-oat");

	const defaultHeaders = {
		accept: "application/json",
		"anthropic-dangerous-direct-browser-access": "true",
		"anthropic-beta": isOAuthToken
			? "oauth-2025-04-20,fine-grained-tool-streaming-2025-05-14"
			: "fine-grained-tool-streaming-2025-05-14",
	};

	if (isOAuthToken) {
		// Clear the env var if we're in Node.js to prevent the SDK from using it.
		// It must be deleted: assigning `undefined` to a process.env property stores
		// the string "undefined" instead of removing the variable.
		if (typeof process !== "undefined" && process.env) {
			delete process.env.ANTHROPIC_API_KEY;
		}

		const client = new Anthropic({
			apiKey: null,
			authToken: apiKey,
			baseURL: model.baseUrl,
			defaultHeaders,
			dangerouslyAllowBrowser: true,
		});

		return { client, isOAuthToken: true };
	}

	const client = new Anthropic({
		apiKey,
		baseURL: model.baseUrl,
		dangerouslyAllowBrowser: true,
		defaultHeaders,
	});

	return { client, isOAuthToken: false };
}
|
|
||||||
|
|
||||||
/**
 * Assembles the streaming request body for the Anthropic Messages API.
 *
 * @param model        Supplies the model id, the fallback `max_tokens` and the `reasoning` flag.
 * @param context      Supplies the optional system prompt and tools.
 * @param options      Per-request overrides (max tokens, temperature, thinking, tool choice).
 * @param messages     Messages already converted to Anthropic's format.
 * @param isOAuthToken When true, the Claude Code identity block is placed first in `system`.
 * @returns Parameters with `stream: true`.
 */
function buildAnthropicParams(
	model: Model,
	context: Context,
	options: AnthropicOptions,
	messages: MessageParam[],
	isOAuthToken: boolean,
): MessageCreateParamsStreaming {
	// Every system block sent in the OAuth case has the same shape and cache setting.
	const cachedText = (text: string) => ({
		type: "text" as const,
		text,
		cache_control: { type: "ephemeral" as const },
	});

	const params: MessageCreateParamsStreaming = {
		model: model.id,
		messages,
		max_tokens: options.maxTokens || model.maxTokens,
		stream: true,
	};

	const { systemPrompt, tools } = context;
	if (isOAuthToken) {
		// For OAuth tokens, we MUST include Claude Code identity
		const identity = cachedText("You are Claude Code, Anthropic's official CLI for Claude.");
		params.system = systemPrompt ? [identity, cachedText(systemPrompt)] : [identity];
	} else if (systemPrompt) {
		params.system = systemPrompt;
	}

	const { temperature, thinking, toolChoice } = options;
	if (temperature !== undefined) {
		params.temperature = temperature;
	}

	if (tools) {
		params.tools = convertTools(tools);
	}

	// Only enable thinking if the model supports it
	if (thinking?.enabled && model.reasoning) {
		params.thinking = { type: "enabled", budget_tokens: thinking.budgetTokens || 1024 };
	}

	if (toolChoice) {
		// A bare mode name is wrapped into the object form the API expects.
		params.tool_choice = typeof toolChoice === "string" ? { type: toolChoice } : toolChoice;
	}

	return params;
}
|
|
||||||
|
|
||||||
/**
 * Converts provider-neutral messages into Anthropic `MessageParam`s.
 *
 * Messages first go through `transformMessages` for cross-provider compatibility. Then:
 * - user messages keep string content as is; array content becomes text/image blocks,
 *   with image blocks dropped when `model.input` does not include "image";
 * - assistant messages become text / thinking / tool_use blocks;
 * - tool results become a user message holding a single `tool_result` block.
 */
function convertMessages(messages: Message[], model: Model, api: Api): MessageParam[] {
	const result: MessageParam[] = [];

	for (const message of transformMessages(messages, model, api)) {
		switch (message.role) {
			case "user": {
				if (typeof message.content === "string") {
					result.push({ role: "user", content: message.content });
					break;
				}

				const converted: ContentBlockParam[] = message.content.map(
					(part): ContentBlockParam =>
						part.type === "text"
							? { type: "text", text: part.text }
							: {
									type: "image",
									source: {
										type: "base64",
										media_type: part.mimeType as "image/jpeg" | "image/png" | "image/gif" | "image/webp",
										data: part.data,
									},
								},
				);
				const dropImages = !model?.input.includes("image");
				result.push({
					role: "user",
					content: dropImages ? converted.filter((entry) => entry.type !== "image") : converted,
				});
				break;
			}

			case "assistant": {
				const converted: ContentBlockParam[] = [];
				for (const part of message.content) {
					switch (part.type) {
						case "text":
							converted.push({ type: "text", text: part.text });
							break;
						case "thinking":
							converted.push({
								type: "thinking",
								thinking: part.thinking,
								signature: part.thinkingSignature || "",
							});
							break;
						case "toolCall":
							converted.push({ type: "tool_use", id: part.id, name: part.name, input: part.arguments });
							break;
					}
				}
				result.push({ role: "assistant", content: converted });
				break;
			}

			case "toolResult":
				// Tool results are sent back on the user side of the conversation.
				result.push({
					role: "user",
					content: [
						{
							type: "tool_result",
							tool_use_id: message.toolCallId,
							content: message.content,
							is_error: message.isError,
						},
					],
				});
				break;
		}
	}

	return result;
}
|
|
||||||
|
|
||||||
/**
 * Converts context tool definitions into Anthropic `Tool`s.
 *
 * Only `properties` and `required` are copied from each tool's parameter schema; missing
 * values default to `{}` and `[]`. Returns an empty list when no tools are given.
 */
function convertTools(tools: Context["tools"]): Tool[] {
	const converted: Tool[] = [];

	for (const { name, description, parameters } of tools ?? []) {
		converted.push({
			name,
			description,
			input_schema: {
				type: "object" as const,
				properties: parameters.properties || {},
				required: parameters.required || [],
			},
		});
	}

	return converted;
}
|
|
||||||
|
|
||||||
/**
 * Maps an Anthropic stop reason onto the provider-neutral StopReason.
 *
 * Only three reasons map to something other than "stop":
 * "max_tokens" -> "length", "tool_use" -> "toolUse", "refusal" -> "safety".
 */
function mapStopReason(reason: Anthropic.Messages.StopReason | null): StopReason {
	if (reason === "max_tokens") return "length";
	if (reason === "tool_use") return "toolUse";
	if (reason === "refusal") return "safety";

	// "end_turn" is a normal stop.
	// "pause_turn": stop is good enough -> resubmit.
	// "stop_sequence": we don't supply stop sequences, so this should never happen.
	// null and any other reason are likewise reported as "stop".
	return "stop";
}
|
|
||||||
|
|
@ -3,91 +3,46 @@ import type {
|
||||||
ContentBlockParam,
|
ContentBlockParam,
|
||||||
MessageCreateParamsStreaming,
|
MessageCreateParamsStreaming,
|
||||||
MessageParam,
|
MessageParam,
|
||||||
Tool,
|
|
||||||
} from "@anthropic-ai/sdk/resources/messages.js";
|
} from "@anthropic-ai/sdk/resources/messages.js";
|
||||||
|
import { QueuedGenerateStream } from "../generate.js";
|
||||||
import { calculateCost } from "../models.js";
|
import { calculateCost } from "../models.js";
|
||||||
import type {
|
import type {
|
||||||
|
Api,
|
||||||
AssistantMessage,
|
AssistantMessage,
|
||||||
Context,
|
Context,
|
||||||
LLM,
|
GenerateFunction,
|
||||||
LLMOptions,
|
GenerateOptions,
|
||||||
|
GenerateStream,
|
||||||
Message,
|
Message,
|
||||||
Model,
|
Model,
|
||||||
StopReason,
|
StopReason,
|
||||||
TextContent,
|
TextContent,
|
||||||
ThinkingContent,
|
ThinkingContent,
|
||||||
|
Tool,
|
||||||
ToolCall,
|
ToolCall,
|
||||||
} from "../types.js";
|
} from "../types.js";
|
||||||
import { transformMessages } from "./utils.js";
|
import { transformMessages } from "./utils.js";
|
||||||
|
|
||||||
/** Anthropic-specific request options layered on top of the shared GenerateOptions. */
export interface AnthropicOptions extends GenerateOptions {
	// NOTE(review): presumably toggles extended thinking; the consumer is outside this view, so confirm.
	thinkingEnabled?: boolean;
	// NOTE(review): presumably the thinking token budget; confirm against the params builder.
	thinkingBudgetTokens?: number;
	/** Either a mode name or an object naming one specific tool. */
	toolChoice?: "auto" | "any" | "none" | { type: "tool"; name: string };
}
|
||||||
|
|
||||||
export class AnthropicLLM implements LLM<AnthropicLLMOptions> {
|
export const streamAnthropic: GenerateFunction<"anthropic-messages"> = (
|
||||||
private client: Anthropic;
|
model: Model<"anthropic-messages">,
|
||||||
private modelInfo: Model;
|
context: Context,
|
||||||
private isOAuthToken: boolean = false;
|
options?: AnthropicOptions,
|
||||||
|
): GenerateStream => {
|
||||||
|
const stream = new QueuedGenerateStream();
|
||||||
|
|
||||||
constructor(model: Model, apiKey?: string) {
|
(async () => {
|
||||||
if (!apiKey) {
|
|
||||||
if (!process.env.ANTHROPIC_API_KEY) {
|
|
||||||
throw new Error(
|
|
||||||
"Anthropic API key is required. Set ANTHROPIC_API_KEY environment variable or pass it as an argument.",
|
|
||||||
);
|
|
||||||
}
|
|
||||||
apiKey = process.env.ANTHROPIC_API_KEY;
|
|
||||||
}
|
|
||||||
if (apiKey.includes("sk-ant-oat")) {
|
|
||||||
const defaultHeaders = {
|
|
||||||
accept: "application/json",
|
|
||||||
"anthropic-dangerous-direct-browser-access": "true",
|
|
||||||
"anthropic-beta": "oauth-2025-04-20,fine-grained-tool-streaming-2025-05-14",
|
|
||||||
};
|
|
||||||
|
|
||||||
// Clear the env var if we're in Node.js to prevent SDK from using it
|
|
||||||
if (typeof process !== "undefined" && process.env) {
|
|
||||||
process.env.ANTHROPIC_API_KEY = undefined;
|
|
||||||
}
|
|
||||||
this.client = new Anthropic({
|
|
||||||
apiKey: null,
|
|
||||||
authToken: apiKey,
|
|
||||||
baseURL: model.baseUrl,
|
|
||||||
defaultHeaders,
|
|
||||||
dangerouslyAllowBrowser: true,
|
|
||||||
});
|
|
||||||
this.isOAuthToken = true;
|
|
||||||
} else {
|
|
||||||
const defaultHeaders = {
|
|
||||||
accept: "application/json",
|
|
||||||
"anthropic-dangerous-direct-browser-access": "true",
|
|
||||||
"anthropic-beta": "fine-grained-tool-streaming-2025-05-14",
|
|
||||||
};
|
|
||||||
this.client = new Anthropic({ apiKey, baseURL: model.baseUrl, dangerouslyAllowBrowser: true, defaultHeaders });
|
|
||||||
this.isOAuthToken = false;
|
|
||||||
}
|
|
||||||
this.modelInfo = model;
|
|
||||||
}
|
|
||||||
|
|
||||||
getModel(): Model {
|
|
||||||
return this.modelInfo;
|
|
||||||
}
|
|
||||||
|
|
||||||
getApi(): string {
|
|
||||||
return "anthropic-messages";
|
|
||||||
}
|
|
||||||
|
|
||||||
async generate(context: Context, options?: AnthropicLLMOptions): Promise<AssistantMessage> {
|
|
||||||
const output: AssistantMessage = {
|
const output: AssistantMessage = {
|
||||||
role: "assistant",
|
role: "assistant",
|
||||||
content: [],
|
content: [],
|
||||||
api: this.getApi(),
|
api: "anthropic-messages" as Api,
|
||||||
provider: this.modelInfo.provider,
|
provider: model.provider,
|
||||||
model: this.modelInfo.id,
|
model: model.id,
|
||||||
usage: {
|
usage: {
|
||||||
input: 0,
|
input: 0,
|
||||||
output: 0,
|
output: 0,
|
||||||
|
|
@ -99,77 +54,14 @@ export class AnthropicLLM implements LLM<AnthropicLLMOptions> {
|
||||||
};
|
};
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const messages = this.convertMessages(context.messages);
|
const { client, isOAuthToken } = createClient(model, options?.apiKey!);
|
||||||
|
const params = buildParams(model, context, isOAuthToken, options);
|
||||||
const params: MessageCreateParamsStreaming = {
|
const anthropicStream = client.messages.stream({ ...params, stream: true }, { signal: options?.signal });
|
||||||
model: this.modelInfo.id,
|
stream.push({ type: "start", partial: output });
|
||||||
messages,
|
|
||||||
max_tokens: options?.maxTokens || 4096,
|
|
||||||
stream: true,
|
|
||||||
};
|
|
||||||
|
|
||||||
// For OAuth tokens, we MUST include Claude Code identity
|
|
||||||
if (this.isOAuthToken) {
|
|
||||||
params.system = [
|
|
||||||
{
|
|
||||||
type: "text",
|
|
||||||
text: "You are Claude Code, Anthropic's official CLI for Claude.",
|
|
||||||
cache_control: {
|
|
||||||
type: "ephemeral",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
];
|
|
||||||
if (context.systemPrompt) {
|
|
||||||
params.system.push({
|
|
||||||
type: "text",
|
|
||||||
text: context.systemPrompt,
|
|
||||||
cache_control: {
|
|
||||||
type: "ephemeral",
|
|
||||||
},
|
|
||||||
});
|
|
||||||
}
|
|
||||||
} else if (context.systemPrompt) {
|
|
||||||
params.system = context.systemPrompt;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (options?.temperature !== undefined) {
|
|
||||||
params.temperature = options?.temperature;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (context.tools) {
|
|
||||||
params.tools = this.convertTools(context.tools);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Only enable thinking if the model supports it
|
|
||||||
if (options?.thinking?.enabled && this.modelInfo.reasoning) {
|
|
||||||
params.thinking = {
|
|
||||||
type: "enabled",
|
|
||||||
budget_tokens: options.thinking.budgetTokens || 1024,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
if (options?.toolChoice) {
|
|
||||||
if (typeof options.toolChoice === "string") {
|
|
||||||
params.tool_choice = { type: options.toolChoice };
|
|
||||||
} else {
|
|
||||||
params.tool_choice = options.toolChoice;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const stream = this.client.messages.stream(
|
|
||||||
{
|
|
||||||
...params,
|
|
||||||
stream: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
signal: options?.signal,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
options?.onEvent?.({ type: "start", model: this.modelInfo.id, provider: this.modelInfo.provider });
|
|
||||||
|
|
||||||
let currentBlock: ThinkingContent | TextContent | (ToolCall & { partialJson: string }) | null = null;
|
let currentBlock: ThinkingContent | TextContent | (ToolCall & { partialJson: string }) | null = null;
|
||||||
for await (const event of stream) {
|
|
||||||
|
for await (const event of anthropicStream) {
|
||||||
if (event.type === "content_block_start") {
|
if (event.type === "content_block_start") {
|
||||||
if (event.content_block.type === "text") {
|
if (event.content_block.type === "text") {
|
||||||
currentBlock = {
|
currentBlock = {
|
||||||
|
|
@ -177,7 +69,7 @@ export class AnthropicLLM implements LLM<AnthropicLLMOptions> {
|
||||||
text: "",
|
text: "",
|
||||||
};
|
};
|
||||||
output.content.push(currentBlock);
|
output.content.push(currentBlock);
|
||||||
options?.onEvent?.({ type: "text_start" });
|
stream.push({ type: "text_start", partial: output });
|
||||||
} else if (event.content_block.type === "thinking") {
|
} else if (event.content_block.type === "thinking") {
|
||||||
currentBlock = {
|
currentBlock = {
|
||||||
type: "thinking",
|
type: "thinking",
|
||||||
|
|
@ -185,9 +77,9 @@ export class AnthropicLLM implements LLM<AnthropicLLMOptions> {
|
||||||
thinkingSignature: "",
|
thinkingSignature: "",
|
||||||
};
|
};
|
||||||
output.content.push(currentBlock);
|
output.content.push(currentBlock);
|
||||||
options?.onEvent?.({ type: "thinking_start" });
|
stream.push({ type: "thinking_start", partial: output });
|
||||||
} else if (event.content_block.type === "tool_use") {
|
} else if (event.content_block.type === "tool_use") {
|
||||||
// We wait for the full tool use to be streamed to send the event
|
// We wait for the full tool use to be streamed
|
||||||
currentBlock = {
|
currentBlock = {
|
||||||
type: "toolCall",
|
type: "toolCall",
|
||||||
id: event.content_block.id,
|
id: event.content_block.id,
|
||||||
|
|
@ -200,15 +92,19 @@ export class AnthropicLLM implements LLM<AnthropicLLMOptions> {
|
||||||
if (event.delta.type === "text_delta") {
|
if (event.delta.type === "text_delta") {
|
||||||
if (currentBlock && currentBlock.type === "text") {
|
if (currentBlock && currentBlock.type === "text") {
|
||||||
currentBlock.text += event.delta.text;
|
currentBlock.text += event.delta.text;
|
||||||
options?.onEvent?.({ type: "text_delta", content: currentBlock.text, delta: event.delta.text });
|
stream.push({
|
||||||
|
type: "text_delta",
|
||||||
|
delta: event.delta.text,
|
||||||
|
partial: output,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
} else if (event.delta.type === "thinking_delta") {
|
} else if (event.delta.type === "thinking_delta") {
|
||||||
if (currentBlock && currentBlock.type === "thinking") {
|
if (currentBlock && currentBlock.type === "thinking") {
|
||||||
currentBlock.thinking += event.delta.thinking;
|
currentBlock.thinking += event.delta.thinking;
|
||||||
options?.onEvent?.({
|
stream.push({
|
||||||
type: "thinking_delta",
|
type: "thinking_delta",
|
||||||
content: currentBlock.thinking,
|
|
||||||
delta: event.delta.thinking,
|
delta: event.delta.thinking,
|
||||||
|
partial: output,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
} else if (event.delta.type === "input_json_delta") {
|
} else if (event.delta.type === "input_json_delta") {
|
||||||
|
|
@ -224,9 +120,17 @@ export class AnthropicLLM implements LLM<AnthropicLLMOptions> {
|
||||||
} else if (event.type === "content_block_stop") {
|
} else if (event.type === "content_block_stop") {
|
||||||
if (currentBlock) {
|
if (currentBlock) {
|
||||||
if (currentBlock.type === "text") {
|
if (currentBlock.type === "text") {
|
||||||
options?.onEvent?.({ type: "text_end", content: currentBlock.text });
|
stream.push({
|
||||||
|
type: "text_end",
|
||||||
|
content: currentBlock.text,
|
||||||
|
partial: output,
|
||||||
|
});
|
||||||
} else if (currentBlock.type === "thinking") {
|
} else if (currentBlock.type === "thinking") {
|
||||||
options?.onEvent?.({ type: "thinking_end", content: currentBlock.thinking });
|
stream.push({
|
||||||
|
type: "thinking_end",
|
||||||
|
content: currentBlock.thinking,
|
||||||
|
partial: output,
|
||||||
|
});
|
||||||
} else if (currentBlock.type === "toolCall") {
|
} else if (currentBlock.type === "toolCall") {
|
||||||
const finalToolCall: ToolCall = {
|
const finalToolCall: ToolCall = {
|
||||||
type: "toolCall",
|
type: "toolCall",
|
||||||
|
|
@ -235,150 +139,274 @@ export class AnthropicLLM implements LLM<AnthropicLLMOptions> {
|
||||||
arguments: JSON.parse(currentBlock.partialJson),
|
arguments: JSON.parse(currentBlock.partialJson),
|
||||||
};
|
};
|
||||||
output.content.push(finalToolCall);
|
output.content.push(finalToolCall);
|
||||||
options?.onEvent?.({ type: "toolCall", toolCall: finalToolCall });
|
stream.push({
|
||||||
|
type: "toolCall",
|
||||||
|
toolCall: finalToolCall,
|
||||||
|
partial: output,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
currentBlock = null;
|
currentBlock = null;
|
||||||
}
|
}
|
||||||
} else if (event.type === "message_delta") {
|
} else if (event.type === "message_delta") {
|
||||||
if (event.delta.stop_reason) {
|
if (event.delta.stop_reason) {
|
||||||
output.stopReason = this.mapStopReason(event.delta.stop_reason);
|
output.stopReason = mapStopReason(event.delta.stop_reason);
|
||||||
}
|
}
|
||||||
output.usage.input += event.usage.input_tokens || 0;
|
output.usage.input += event.usage.input_tokens || 0;
|
||||||
output.usage.output += event.usage.output_tokens || 0;
|
output.usage.output += event.usage.output_tokens || 0;
|
||||||
output.usage.cacheRead += event.usage.cache_read_input_tokens || 0;
|
output.usage.cacheRead += event.usage.cache_read_input_tokens || 0;
|
||||||
output.usage.cacheWrite += event.usage.cache_creation_input_tokens || 0;
|
output.usage.cacheWrite += event.usage.cache_creation_input_tokens || 0;
|
||||||
calculateCost(this.modelInfo, output.usage);
|
calculateCost(model, output.usage);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
options?.onEvent?.({ type: "done", reason: output.stopReason, message: output });
|
if (options?.signal?.aborted) {
|
||||||
return output;
|
throw new Error("Request was aborted");
|
||||||
|
}
|
||||||
|
|
||||||
|
stream.push({ type: "done", reason: output.stopReason, message: output });
|
||||||
|
stream.end();
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
output.stopReason = "error";
|
output.stopReason = "error";
|
||||||
output.error = error instanceof Error ? error.message : JSON.stringify(error);
|
output.error = error instanceof Error ? error.message : JSON.stringify(error);
|
||||||
options?.onEvent?.({ type: "error", error: output.error });
|
stream.push({ type: "error", error: output.error, partial: output });
|
||||||
return output;
|
stream.end();
|
||||||
|
}
|
||||||
|
})();
|
||||||
|
|
||||||
|
return stream;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
 * Builds an Anthropic SDK client for `model`, picking the auth scheme from the key's shape.
 *
 * Keys containing "sk-ant-oat" are treated as OAuth tokens: they are sent as `authToken`
 * (with `apiKey: null`) and the OAuth beta flag is added to the `anthropic-beta` header.
 * Any other key is passed as a regular `apiKey`.
 *
 * Side effect: in the OAuth case, ANTHROPIC_API_KEY is removed from `process.env`
 * when a `process` global exists.
 *
 * @param model  Model whose `baseUrl` becomes the client's base URL.
 * @param apiKey API key or OAuth token.
 * @returns The client plus a flag telling callers whether OAuth auth is in use.
 */
function createClient(
	model: Model<"anthropic-messages">,
	apiKey: string,
): { client: Anthropic; isOAuthToken: boolean } {
	const isOAuthToken = apiKey.includes("sk-ant-oat");

	const defaultHeaders = {
		accept: "application/json",
		"anthropic-dangerous-direct-browser-access": "true",
		"anthropic-beta": isOAuthToken
			? "oauth-2025-04-20,fine-grained-tool-streaming-2025-05-14"
			: "fine-grained-tool-streaming-2025-05-14",
	};

	if (isOAuthToken) {
		// Clear the env var if we're in Node.js to prevent the SDK from using it.
		// It must be deleted: assigning `undefined` to a process.env property stores
		// the string "undefined" instead of removing the variable.
		if (typeof process !== "undefined" && process.env) {
			delete process.env.ANTHROPIC_API_KEY;
		}

		const client = new Anthropic({
			apiKey: null,
			authToken: apiKey,
			baseURL: model.baseUrl,
			defaultHeaders,
			dangerouslyAllowBrowser: true,
		});

		return { client, isOAuthToken: true };
	}

	const client = new Anthropic({
		apiKey,
		baseURL: model.baseUrl,
		dangerouslyAllowBrowser: true,
		defaultHeaders,
	});

	return { client, isOAuthToken: false };
}
|
||||||
|
|
||||||
|
function buildParams(
|
||||||
|
model: Model<"anthropic-messages">,
|
||||||
|
context: Context,
|
||||||
|
isOAuthToken: boolean,
|
||||||
|
options?: AnthropicOptions,
|
||||||
|
): MessageCreateParamsStreaming {
|
||||||
|
const params: MessageCreateParamsStreaming = {
|
||||||
|
model: model.id,
|
||||||
|
messages: convertMessages(context.messages, model),
|
||||||
|
max_tokens: options?.maxTokens || model.maxTokens,
|
||||||
|
stream: true,
|
||||||
|
};
|
||||||
|
|
||||||
|
// For OAuth tokens, we MUST include Claude Code identity
|
||||||
|
if (isOAuthToken) {
|
||||||
|
params.system = [
|
||||||
|
{
|
||||||
|
type: "text",
|
||||||
|
text: "You are Claude Code, Anthropic's official CLI for Claude.",
|
||||||
|
cache_control: {
|
||||||
|
type: "ephemeral",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
];
|
||||||
|
if (context.systemPrompt) {
|
||||||
|
params.system.push({
|
||||||
|
type: "text",
|
||||||
|
text: context.systemPrompt,
|
||||||
|
cache_control: {
|
||||||
|
type: "ephemeral",
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
} else if (context.systemPrompt) {
|
||||||
|
params.system = context.systemPrompt;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (options?.temperature !== undefined) {
|
||||||
|
params.temperature = options.temperature;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (context.tools) {
|
||||||
|
params.tools = convertTools(context.tools);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (options?.thinkingEnabled && model.reasoning) {
|
||||||
|
params.thinking = {
|
||||||
|
type: "enabled",
|
||||||
|
budget_tokens: options.thinkingBudgetTokens || 1024,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
if (options?.toolChoice) {
|
||||||
|
if (typeof options.toolChoice === "string") {
|
||||||
|
params.tool_choice = { type: options.toolChoice };
|
||||||
|
} else {
|
||||||
|
params.tool_choice = options.toolChoice;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private convertMessages(messages: Message[]): MessageParam[] {
|
return params;
|
||||||
const params: MessageParam[] = [];
|
}
|
||||||
|
|
||||||
// Transform messages for cross-provider compatibility
|
function convertMessages(messages: Message[], model: Model<"anthropic-messages">): MessageParam[] {
|
||||||
const transformedMessages = transformMessages(messages, this.modelInfo, this.getApi());
|
const params: MessageParam[] = [];
|
||||||
|
|
||||||
for (const msg of transformedMessages) {
|
// Transform messages for cross-provider compatibility
|
||||||
if (msg.role === "user") {
|
const transformedMessages = transformMessages(messages, model);
|
||||||
// Handle both string and array content
|
|
||||||
if (typeof msg.content === "string") {
|
for (const msg of transformedMessages) {
|
||||||
|
if (msg.role === "user") {
|
||||||
|
if (typeof msg.content === "string") {
|
||||||
|
if (msg.content.trim().length > 0) {
|
||||||
params.push({
|
params.push({
|
||||||
role: "user",
|
role: "user",
|
||||||
content: msg.content,
|
content: msg.content,
|
||||||
});
|
});
|
||||||
} else {
|
|
||||||
// Convert array content to Anthropic format
|
|
||||||
const blocks: ContentBlockParam[] = msg.content.map((item) => {
|
|
||||||
if (item.type === "text") {
|
|
||||||
return {
|
|
||||||
type: "text",
|
|
||||||
text: item.text,
|
|
||||||
};
|
|
||||||
} else {
|
|
||||||
// Image content
|
|
||||||
return {
|
|
||||||
type: "image",
|
|
||||||
source: {
|
|
||||||
type: "base64",
|
|
||||||
media_type: item.mimeType as "image/jpeg" | "image/png" | "image/gif" | "image/webp",
|
|
||||||
data: item.data,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
}
|
|
||||||
});
|
|
||||||
const filteredBlocks = !this.modelInfo?.input.includes("image")
|
|
||||||
? blocks.filter((b) => b.type !== "image")
|
|
||||||
: blocks;
|
|
||||||
params.push({
|
|
||||||
role: "user",
|
|
||||||
content: filteredBlocks,
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
} else if (msg.role === "assistant") {
|
} else {
|
||||||
const blocks: ContentBlockParam[] = [];
|
const blocks: ContentBlockParam[] = msg.content.map((item) => {
|
||||||
|
if (item.type === "text") {
|
||||||
for (const block of msg.content) {
|
return {
|
||||||
if (block.type === "text") {
|
|
||||||
blocks.push({
|
|
||||||
type: "text",
|
type: "text",
|
||||||
text: block.text,
|
text: item.text,
|
||||||
});
|
};
|
||||||
} else if (block.type === "thinking") {
|
} else {
|
||||||
blocks.push({
|
return {
|
||||||
type: "thinking",
|
type: "image",
|
||||||
thinking: block.thinking,
|
source: {
|
||||||
signature: block.thinkingSignature || "",
|
type: "base64",
|
||||||
});
|
media_type: item.mimeType as "image/jpeg" | "image/png" | "image/gif" | "image/webp",
|
||||||
} else if (block.type === "toolCall") {
|
data: item.data,
|
||||||
blocks.push({
|
},
|
||||||
type: "tool_use",
|
};
|
||||||
id: block.id,
|
|
||||||
name: block.name,
|
|
||||||
input: block.arguments,
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
params.push({
|
|
||||||
role: "assistant",
|
|
||||||
content: blocks,
|
|
||||||
});
|
});
|
||||||
} else if (msg.role === "toolResult") {
|
let filteredBlocks = !model?.input.includes("image") ? blocks.filter((b) => b.type !== "image") : blocks;
|
||||||
|
filteredBlocks = filteredBlocks.filter((b) => {
|
||||||
|
if (b.type === "text") {
|
||||||
|
return b.text.trim().length > 0;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
});
|
||||||
|
if (filteredBlocks.length === 0) continue;
|
||||||
params.push({
|
params.push({
|
||||||
role: "user",
|
role: "user",
|
||||||
content: [
|
content: filteredBlocks,
|
||||||
{
|
|
||||||
type: "tool_result",
|
|
||||||
tool_use_id: msg.toolCallId,
|
|
||||||
content: msg.content,
|
|
||||||
is_error: msg.isError,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
} else if (msg.role === "assistant") {
|
||||||
|
const blocks: ContentBlockParam[] = [];
|
||||||
|
|
||||||
|
for (const block of msg.content) {
|
||||||
|
if (block.type === "text") {
|
||||||
|
if (block.text.trim().length === 0) continue;
|
||||||
|
blocks.push({
|
||||||
|
type: "text",
|
||||||
|
text: block.text,
|
||||||
|
});
|
||||||
|
} else if (block.type === "thinking") {
|
||||||
|
if (block.thinking.trim().length === 0) continue;
|
||||||
|
blocks.push({
|
||||||
|
type: "thinking",
|
||||||
|
thinking: block.thinking,
|
||||||
|
signature: block.thinkingSignature || "",
|
||||||
|
});
|
||||||
|
} else if (block.type === "toolCall") {
|
||||||
|
blocks.push({
|
||||||
|
type: "tool_use",
|
||||||
|
id: block.id,
|
||||||
|
name: block.name,
|
||||||
|
input: block.arguments,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (blocks.length === 0) continue;
|
||||||
|
params.push({
|
||||||
|
role: "assistant",
|
||||||
|
content: blocks,
|
||||||
|
});
|
||||||
|
} else if (msg.role === "toolResult") {
|
||||||
|
params.push({
|
||||||
|
role: "user",
|
||||||
|
content: [
|
||||||
|
{
|
||||||
|
type: "tool_result",
|
||||||
|
tool_use_id: msg.toolCallId,
|
||||||
|
content: msg.content,
|
||||||
|
is_error: msg.isError,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
});
|
||||||
}
|
}
|
||||||
return params;
|
|
||||||
}
|
}
|
||||||
|
return params;
|
||||||
|
}
|
||||||
|
|
||||||
private convertTools(tools: Context["tools"]): Tool[] {
|
function convertTools(tools: Tool[]): Anthropic.Messages.Tool[] {
|
||||||
if (!tools) return [];
|
if (!tools) return [];
|
||||||
|
|
||||||
return tools.map((tool) => ({
|
return tools.map((tool) => ({
|
||||||
name: tool.name,
|
name: tool.name,
|
||||||
description: tool.description,
|
description: tool.description,
|
||||||
input_schema: {
|
input_schema: {
|
||||||
type: "object" as const,
|
type: "object" as const,
|
||||||
properties: tool.parameters.properties || {},
|
properties: tool.parameters.properties || {},
|
||||||
required: tool.parameters.required || [],
|
required: tool.parameters.required || [],
|
||||||
},
|
},
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
private mapStopReason(reason: Anthropic.Messages.StopReason | null): StopReason {
|
function mapStopReason(reason: Anthropic.Messages.StopReason): StopReason {
|
||||||
switch (reason) {
|
switch (reason) {
|
||||||
case "end_turn":
|
case "end_turn":
|
||||||
return "stop";
|
return "stop";
|
||||||
case "max_tokens":
|
case "max_tokens":
|
||||||
return "length";
|
return "length";
|
||||||
case "tool_use":
|
case "tool_use":
|
||||||
return "toolUse";
|
return "toolUse";
|
||||||
case "refusal":
|
case "refusal":
|
||||||
return "safety";
|
return "safety";
|
||||||
case "pause_turn": // Stop is good enough -> resubmit
|
case "pause_turn": // Stop is good enough -> resubmit
|
||||||
return "stop";
|
return "stop";
|
||||||
case "stop_sequence":
|
case "stop_sequence":
|
||||||
return "stop"; // We don't supply stop sequences, so this should never happen
|
return "stop"; // We don't supply stop sequences, so this should never happen
|
||||||
default:
|
default: {
|
||||||
return "stop";
|
const _exhaustive: never = reason;
|
||||||
|
throw new Error(`Unhandled stop reason: ${_exhaustive}`);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,19 +1,21 @@
|
||||||
import {
|
import {
|
||||||
type Content,
|
type Content,
|
||||||
type FinishReason,
|
FinishReason,
|
||||||
FunctionCallingConfigMode,
|
FunctionCallingConfigMode,
|
||||||
type GenerateContentConfig,
|
type GenerateContentConfig,
|
||||||
type GenerateContentParameters,
|
type GenerateContentParameters,
|
||||||
GoogleGenAI,
|
GoogleGenAI,
|
||||||
type Part,
|
type Part,
|
||||||
} from "@google/genai";
|
} from "@google/genai";
|
||||||
|
import { QueuedGenerateStream } from "../generate.js";
|
||||||
import { calculateCost } from "../models.js";
|
import { calculateCost } from "../models.js";
|
||||||
import type {
|
import type {
|
||||||
|
Api,
|
||||||
AssistantMessage,
|
AssistantMessage,
|
||||||
Context,
|
Context,
|
||||||
LLM,
|
GenerateFunction,
|
||||||
LLMOptions,
|
GenerateOptions,
|
||||||
Message,
|
GenerateStream,
|
||||||
Model,
|
Model,
|
||||||
StopReason,
|
StopReason,
|
||||||
TextContent,
|
TextContent,
|
||||||
|
|
@ -23,7 +25,7 @@ import type {
|
||||||
} from "../types.js";
|
} from "../types.js";
|
||||||
import { transformMessages } from "./utils.js";
|
import { transformMessages } from "./utils.js";
|
||||||
|
|
||||||
export interface GoogleLLMOptions extends LLMOptions {
|
export interface GoogleOptions extends GenerateOptions {
|
||||||
toolChoice?: "auto" | "none" | "any";
|
toolChoice?: "auto" | "none" | "any";
|
||||||
thinking?: {
|
thinking?: {
|
||||||
enabled: boolean;
|
enabled: boolean;
|
||||||
|
|
@ -31,38 +33,20 @@ export interface GoogleLLMOptions extends LLMOptions {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
export class GoogleLLM implements LLM<GoogleLLMOptions> {
|
export const streamGoogle: GenerateFunction<"google-generative-ai"> = (
|
||||||
private client: GoogleGenAI;
|
model: Model<"google-generative-ai">,
|
||||||
private modelInfo: Model;
|
context: Context,
|
||||||
|
options?: GoogleOptions,
|
||||||
|
): GenerateStream => {
|
||||||
|
const stream = new QueuedGenerateStream();
|
||||||
|
|
||||||
constructor(model: Model, apiKey?: string) {
|
(async () => {
|
||||||
if (!apiKey) {
|
|
||||||
if (!process.env.GEMINI_API_KEY) {
|
|
||||||
throw new Error(
|
|
||||||
"Gemini API key is required. Set GEMINI_API_KEY environment variable or pass it as an argument.",
|
|
||||||
);
|
|
||||||
}
|
|
||||||
apiKey = process.env.GEMINI_API_KEY;
|
|
||||||
}
|
|
||||||
this.client = new GoogleGenAI({ apiKey });
|
|
||||||
this.modelInfo = model;
|
|
||||||
}
|
|
||||||
|
|
||||||
getModel(): Model {
|
|
||||||
return this.modelInfo;
|
|
||||||
}
|
|
||||||
|
|
||||||
getApi(): string {
|
|
||||||
return "google-generative-ai";
|
|
||||||
}
|
|
||||||
|
|
||||||
async generate(context: Context, options?: GoogleLLMOptions): Promise<AssistantMessage> {
|
|
||||||
const output: AssistantMessage = {
|
const output: AssistantMessage = {
|
||||||
role: "assistant",
|
role: "assistant",
|
||||||
content: [],
|
content: [],
|
||||||
api: this.getApi(),
|
api: "google-generative-ai" as Api,
|
||||||
provider: this.modelInfo.provider,
|
provider: model.provider,
|
||||||
model: this.modelInfo.id,
|
model: model.id,
|
||||||
usage: {
|
usage: {
|
||||||
input: 0,
|
input: 0,
|
||||||
output: 0,
|
output: 0,
|
||||||
|
|
@ -72,70 +56,20 @@ export class GoogleLLM implements LLM<GoogleLLMOptions> {
|
||||||
},
|
},
|
||||||
stopReason: "stop",
|
stopReason: "stop",
|
||||||
};
|
};
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const contents = this.convertMessages(context.messages);
|
const client = createClient(options?.apiKey);
|
||||||
|
const params = buildParams(model, context, options);
|
||||||
|
const googleStream = await client.models.generateContentStream(params);
|
||||||
|
|
||||||
// Build generation config
|
stream.push({ type: "start", partial: output });
|
||||||
const generationConfig: GenerateContentConfig = {};
|
|
||||||
if (options?.temperature !== undefined) {
|
|
||||||
generationConfig.temperature = options.temperature;
|
|
||||||
}
|
|
||||||
if (options?.maxTokens !== undefined) {
|
|
||||||
generationConfig.maxOutputTokens = options.maxTokens;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build the config object
|
|
||||||
const config: GenerateContentConfig = {
|
|
||||||
...(Object.keys(generationConfig).length > 0 && generationConfig),
|
|
||||||
...(context.systemPrompt && { systemInstruction: context.systemPrompt }),
|
|
||||||
...(context.tools && { tools: this.convertTools(context.tools) }),
|
|
||||||
};
|
|
||||||
|
|
||||||
// Add tool config if needed
|
|
||||||
if (context.tools && options?.toolChoice) {
|
|
||||||
config.toolConfig = {
|
|
||||||
functionCallingConfig: {
|
|
||||||
mode: this.mapToolChoice(options.toolChoice),
|
|
||||||
},
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add thinking config if enabled and model supports it
|
|
||||||
if (options?.thinking?.enabled && this.modelInfo.reasoning) {
|
|
||||||
config.thinkingConfig = {
|
|
||||||
includeThoughts: true,
|
|
||||||
...(options.thinking.budgetTokens !== undefined && { thinkingBudget: options.thinking.budgetTokens }),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
// Abort signal
|
|
||||||
if (options?.signal) {
|
|
||||||
if (options.signal.aborted) {
|
|
||||||
throw new Error("Request aborted");
|
|
||||||
}
|
|
||||||
config.abortSignal = options.signal;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build the request parameters
|
|
||||||
const params: GenerateContentParameters = {
|
|
||||||
model: this.modelInfo.id,
|
|
||||||
contents,
|
|
||||||
config,
|
|
||||||
};
|
|
||||||
|
|
||||||
const stream = await this.client.models.generateContentStream(params);
|
|
||||||
|
|
||||||
options?.onEvent?.({ type: "start", model: this.modelInfo.id, provider: this.modelInfo.provider });
|
|
||||||
let currentBlock: TextContent | ThinkingContent | null = null;
|
let currentBlock: TextContent | ThinkingContent | null = null;
|
||||||
for await (const chunk of stream) {
|
for await (const chunk of googleStream) {
|
||||||
// Extract parts from the chunk
|
|
||||||
const candidate = chunk.candidates?.[0];
|
const candidate = chunk.candidates?.[0];
|
||||||
if (candidate?.content?.parts) {
|
if (candidate?.content?.parts) {
|
||||||
for (const part of candidate.content.parts) {
|
for (const part of candidate.content.parts) {
|
||||||
if (part.text !== undefined) {
|
if (part.text !== undefined) {
|
||||||
const isThinking = part.thought === true;
|
const isThinking = part.thought === true;
|
||||||
|
|
||||||
// Check if we need to switch blocks
|
|
||||||
if (
|
if (
|
||||||
!currentBlock ||
|
!currentBlock ||
|
||||||
(isThinking && currentBlock.type !== "thinking") ||
|
(isThinking && currentBlock.type !== "thinking") ||
|
||||||
|
|
@ -143,50 +77,60 @@ export class GoogleLLM implements LLM<GoogleLLMOptions> {
|
||||||
) {
|
) {
|
||||||
if (currentBlock) {
|
if (currentBlock) {
|
||||||
if (currentBlock.type === "text") {
|
if (currentBlock.type === "text") {
|
||||||
options?.onEvent?.({ type: "text_end", content: currentBlock.text });
|
stream.push({
|
||||||
|
type: "text_end",
|
||||||
|
content: currentBlock.text,
|
||||||
|
partial: output,
|
||||||
|
});
|
||||||
} else {
|
} else {
|
||||||
options?.onEvent?.({ type: "thinking_end", content: currentBlock.thinking });
|
stream.push({
|
||||||
|
type: "thinking_end",
|
||||||
|
content: currentBlock.thinking,
|
||||||
|
partial: output,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start new block
|
|
||||||
if (isThinking) {
|
if (isThinking) {
|
||||||
currentBlock = { type: "thinking", thinking: "", thinkingSignature: undefined };
|
currentBlock = { type: "thinking", thinking: "", thinkingSignature: undefined };
|
||||||
options?.onEvent?.({ type: "thinking_start" });
|
stream.push({ type: "thinking_start", partial: output });
|
||||||
} else {
|
} else {
|
||||||
currentBlock = { type: "text", text: "" };
|
currentBlock = { type: "text", text: "" };
|
||||||
options?.onEvent?.({ type: "text_start" });
|
stream.push({ type: "text_start", partial: output });
|
||||||
}
|
}
|
||||||
output.content.push(currentBlock);
|
output.content.push(currentBlock);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Append content to current block
|
|
||||||
if (currentBlock.type === "thinking") {
|
if (currentBlock.type === "thinking") {
|
||||||
currentBlock.thinking += part.text;
|
currentBlock.thinking += part.text;
|
||||||
currentBlock.thinkingSignature = part.thoughtSignature;
|
currentBlock.thinkingSignature = part.thoughtSignature;
|
||||||
options?.onEvent?.({
|
stream.push({
|
||||||
type: "thinking_delta",
|
type: "thinking_delta",
|
||||||
content: currentBlock.thinking,
|
|
||||||
delta: part.text,
|
delta: part.text,
|
||||||
|
partial: output,
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
currentBlock.text += part.text;
|
currentBlock.text += part.text;
|
||||||
options?.onEvent?.({ type: "text_delta", content: currentBlock.text, delta: part.text });
|
stream.push({ type: "text_delta", delta: part.text, partial: output });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle function calls
|
|
||||||
if (part.functionCall) {
|
if (part.functionCall) {
|
||||||
if (currentBlock) {
|
if (currentBlock) {
|
||||||
if (currentBlock.type === "text") {
|
if (currentBlock.type === "text") {
|
||||||
options?.onEvent?.({ type: "text_end", content: currentBlock.text });
|
stream.push({
|
||||||
|
type: "text_end",
|
||||||
|
content: currentBlock.text,
|
||||||
|
partial: output,
|
||||||
|
});
|
||||||
} else {
|
} else {
|
||||||
options?.onEvent?.({ type: "thinking_end", content: currentBlock.thinking });
|
stream.push({
|
||||||
|
type: "thinking_end",
|
||||||
|
content: currentBlock.thinking,
|
||||||
|
partial: output,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
currentBlock = null;
|
currentBlock = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add tool call
|
|
||||||
const toolCallId = part.functionCall.id || `${part.functionCall.name}_${Date.now()}`;
|
const toolCallId = part.functionCall.id || `${part.functionCall.name}_${Date.now()}`;
|
||||||
const toolCall: ToolCall = {
|
const toolCall: ToolCall = {
|
||||||
type: "toolCall",
|
type: "toolCall",
|
||||||
|
|
@ -195,21 +139,18 @@ export class GoogleLLM implements LLM<GoogleLLMOptions> {
|
||||||
arguments: part.functionCall.args as Record<string, any>,
|
arguments: part.functionCall.args as Record<string, any>,
|
||||||
};
|
};
|
||||||
output.content.push(toolCall);
|
output.content.push(toolCall);
|
||||||
options?.onEvent?.({ type: "toolCall", toolCall });
|
stream.push({ type: "toolCall", toolCall, partial: output });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Map finish reason
|
|
||||||
if (candidate?.finishReason) {
|
if (candidate?.finishReason) {
|
||||||
output.stopReason = this.mapStopReason(candidate.finishReason);
|
output.stopReason = mapStopReason(candidate.finishReason);
|
||||||
// Check if we have tool calls in blocks
|
|
||||||
if (output.content.some((b) => b.type === "toolCall")) {
|
if (output.content.some((b) => b.type === "toolCall")) {
|
||||||
output.stopReason = "toolUse";
|
output.stopReason = "toolUse";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Capture usage metadata if available
|
|
||||||
if (chunk.usageMetadata) {
|
if (chunk.usageMetadata) {
|
||||||
output.usage = {
|
output.usage = {
|
||||||
input: chunk.usageMetadata.promptTokenCount || 0,
|
input: chunk.usageMetadata.promptTokenCount || 0,
|
||||||
|
|
@ -225,166 +166,223 @@ export class GoogleLLM implements LLM<GoogleLLMOptions> {
|
||||||
total: 0,
|
total: 0,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
calculateCost(this.modelInfo, output.usage);
|
calculateCost(model, output.usage);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Finalize last block
|
|
||||||
if (currentBlock) {
|
if (currentBlock) {
|
||||||
if (currentBlock.type === "text") {
|
if (currentBlock.type === "text") {
|
||||||
options?.onEvent?.({ type: "text_end", content: currentBlock.text });
|
stream.push({ type: "text_end", content: currentBlock.text, partial: output });
|
||||||
} else {
|
} else {
|
||||||
options?.onEvent?.({ type: "thinking_end", content: currentBlock.thinking });
|
stream.push({ type: "thinking_end", content: currentBlock.thinking, partial: output });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
options?.onEvent?.({ type: "done", reason: output.stopReason, message: output });
|
stream.push({ type: "done", reason: output.stopReason, message: output });
|
||||||
return output;
|
stream.end();
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
output.stopReason = "error";
|
output.stopReason = "error";
|
||||||
output.error = error instanceof Error ? error.message : JSON.stringify(error);
|
output.error = error instanceof Error ? error.message : JSON.stringify(error);
|
||||||
options?.onEvent?.({ type: "error", error: output.error });
|
stream.push({ type: "error", error: output.error, partial: output });
|
||||||
return output;
|
stream.end();
|
||||||
}
|
}
|
||||||
|
})();
|
||||||
|
|
||||||
|
return stream;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
 * Creates the Google GenAI client for a request.
 *
 * An explicitly passed non-empty key is used as-is; otherwise the
 * GEMINI_API_KEY environment variable is used when it is available.
 *
 * @param apiKey Optional API key.
 * @returns A client constructed with the resolved key.
 * @throws Error when neither the argument nor the environment provides a key.
 */
function createClient(apiKey?: string): GoogleGenAI {
	// `process` is not defined outside Node.js; guard the lookup so a missing key
	// surfaces as the descriptive error below rather than a ReferenceError.
	const envKey = typeof process !== "undefined" && process.env ? process.env.GEMINI_API_KEY : undefined;
	const resolvedKey = apiKey || envKey;

	if (!resolvedKey) {
		throw new Error(
			"Gemini API key is required. Set GEMINI_API_KEY environment variable or pass it as an argument.",
		);
	}

	return new GoogleGenAI({ apiKey: resolvedKey });
}
|
||||||
|
|
||||||
|
function buildParams(
|
||||||
|
model: Model<"google-generative-ai">,
|
||||||
|
context: Context,
|
||||||
|
options: GoogleOptions = {},
|
||||||
|
): GenerateContentParameters {
|
||||||
|
const contents = convertMessages(model, context);
|
||||||
|
|
||||||
|
const generationConfig: GenerateContentConfig = {};
|
||||||
|
if (options.temperature !== undefined) {
|
||||||
|
generationConfig.temperature = options.temperature;
|
||||||
|
}
|
||||||
|
if (options.maxTokens !== undefined) {
|
||||||
|
generationConfig.maxOutputTokens = options.maxTokens;
|
||||||
}
|
}
|
||||||
|
|
||||||
private convertMessages(messages: Message[]): Content[] {
|
const config: GenerateContentConfig = {
|
||||||
const contents: Content[] = [];
|
...(Object.keys(generationConfig).length > 0 && generationConfig),
|
||||||
|
...(context.systemPrompt && { systemInstruction: context.systemPrompt }),
|
||||||
|
...(context.tools && { tools: convertTools(context.tools) }),
|
||||||
|
};
|
||||||
|
|
||||||
// Transform messages for cross-provider compatibility
|
if (context.tools && options.toolChoice) {
|
||||||
const transformedMessages = transformMessages(messages, this.modelInfo, this.getApi());
|
config.toolConfig = {
|
||||||
|
functionCallingConfig: {
|
||||||
|
mode: mapToolChoice(options.toolChoice),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
for (const msg of transformedMessages) {
|
if (options.thinking?.enabled && model.reasoning) {
|
||||||
if (msg.role === "user") {
|
config.thinkingConfig = {
|
||||||
// Handle both string and array content
|
includeThoughts: true,
|
||||||
if (typeof msg.content === "string") {
|
...(options.thinking.budgetTokens !== undefined && { thinkingBudget: options.thinking.budgetTokens }),
|
||||||
contents.push({
|
};
|
||||||
role: "user",
|
}
|
||||||
parts: [{ text: msg.content }],
|
|
||||||
});
|
|
||||||
} else {
|
|
||||||
// Convert array content to Google format
|
|
||||||
const parts: Part[] = msg.content.map((item) => {
|
|
||||||
if (item.type === "text") {
|
|
||||||
return { text: item.text };
|
|
||||||
} else {
|
|
||||||
// Image content - Google uses inlineData
|
|
||||||
return {
|
|
||||||
inlineData: {
|
|
||||||
mimeType: item.mimeType,
|
|
||||||
data: item.data,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
}
|
|
||||||
});
|
|
||||||
const filteredParts = !this.modelInfo?.input.includes("image")
|
|
||||||
? parts.filter((p) => p.text !== undefined)
|
|
||||||
: parts;
|
|
||||||
contents.push({
|
|
||||||
role: "user",
|
|
||||||
parts: filteredParts,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
} else if (msg.role === "assistant") {
|
|
||||||
const parts: Part[] = [];
|
|
||||||
|
|
||||||
// Process content blocks
|
if (options.signal) {
|
||||||
for (const block of msg.content) {
|
if (options.signal.aborted) {
|
||||||
if (block.type === "text") {
|
throw new Error("Request aborted");
|
||||||
parts.push({ text: block.text });
|
}
|
||||||
} else if (block.type === "thinking") {
|
config.abortSignal = options.signal;
|
||||||
const thinkingPart: Part = {
|
}
|
||||||
thought: true,
|
|
||||||
thoughtSignature: block.thinkingSignature,
|
|
||||||
text: block.thinking,
|
|
||||||
};
|
|
||||||
parts.push(thinkingPart);
|
|
||||||
} else if (block.type === "toolCall") {
|
|
||||||
parts.push({
|
|
||||||
functionCall: {
|
|
||||||
id: block.id,
|
|
||||||
name: block.name,
|
|
||||||
args: block.arguments,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (parts.length > 0) {
|
const params: GenerateContentParameters = {
|
||||||
contents.push({
|
model: model.id,
|
||||||
role: "model",
|
contents,
|
||||||
parts,
|
config,
|
||||||
});
|
};
|
||||||
}
|
|
||||||
} else if (msg.role === "toolResult") {
|
return params;
|
||||||
|
}
|
||||||
|
function convertMessages(model: Model<"google-generative-ai">, context: Context): Content[] {
|
||||||
|
const contents: Content[] = [];
|
||||||
|
const transformedMessages = transformMessages(context.messages, model);
|
||||||
|
|
||||||
|
for (const msg of transformedMessages) {
|
||||||
|
if (msg.role === "user") {
|
||||||
|
if (typeof msg.content === "string") {
|
||||||
contents.push({
|
contents.push({
|
||||||
role: "user",
|
role: "user",
|
||||||
parts: [
|
parts: [{ text: msg.content }],
|
||||||
{
|
});
|
||||||
functionResponse: {
|
} else {
|
||||||
id: msg.toolCallId,
|
const parts: Part[] = msg.content.map((item) => {
|
||||||
name: msg.toolName,
|
if (item.type === "text") {
|
||||||
response: {
|
return { text: item.text };
|
||||||
result: msg.content,
|
} else {
|
||||||
isError: msg.isError,
|
return {
|
||||||
},
|
inlineData: {
|
||||||
|
mimeType: item.mimeType,
|
||||||
|
data: item.data,
|
||||||
},
|
},
|
||||||
},
|
};
|
||||||
],
|
}
|
||||||
|
});
|
||||||
|
const filteredParts = !model.input.includes("image") ? parts.filter((p) => p.text !== undefined) : parts;
|
||||||
|
if (filteredParts.length === 0) continue;
|
||||||
|
contents.push({
|
||||||
|
role: "user",
|
||||||
|
parts: filteredParts,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
} else if (msg.role === "assistant") {
|
||||||
|
const parts: Part[] = [];
|
||||||
|
|
||||||
return contents;
|
for (const block of msg.content) {
|
||||||
}
|
if (block.type === "text") {
|
||||||
|
parts.push({ text: block.text });
|
||||||
|
} else if (block.type === "thinking") {
|
||||||
|
const thinkingPart: Part = {
|
||||||
|
thought: true,
|
||||||
|
thoughtSignature: block.thinkingSignature,
|
||||||
|
text: block.thinking,
|
||||||
|
};
|
||||||
|
parts.push(thinkingPart);
|
||||||
|
} else if (block.type === "toolCall") {
|
||||||
|
parts.push({
|
||||||
|
functionCall: {
|
||||||
|
id: block.id,
|
||||||
|
name: block.name,
|
||||||
|
args: block.arguments,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private convertTools(tools: Tool[]): any[] {
|
if (parts.length === 0) continue;
|
||||||
return [
|
contents.push({
|
||||||
{
|
role: "model",
|
||||||
functionDeclarations: tools.map((tool) => ({
|
parts,
|
||||||
name: tool.name,
|
});
|
||||||
description: tool.description,
|
} else if (msg.role === "toolResult") {
|
||||||
parameters: tool.parameters,
|
contents.push({
|
||||||
})),
|
role: "user",
|
||||||
},
|
parts: [
|
||||||
];
|
{
|
||||||
}
|
functionResponse: {
|
||||||
|
id: msg.toolCallId,
|
||||||
private mapToolChoice(choice: string): FunctionCallingConfigMode {
|
name: msg.toolName,
|
||||||
switch (choice) {
|
response: {
|
||||||
case "auto":
|
result: msg.content,
|
||||||
return FunctionCallingConfigMode.AUTO;
|
isError: msg.isError,
|
||||||
case "none":
|
},
|
||||||
return FunctionCallingConfigMode.NONE;
|
},
|
||||||
case "any":
|
},
|
||||||
return FunctionCallingConfigMode.ANY;
|
],
|
||||||
default:
|
});
|
||||||
return FunctionCallingConfigMode.AUTO;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private mapStopReason(reason: FinishReason): StopReason {
|
return contents;
|
||||||
switch (reason) {
|
}
|
||||||
case "STOP":
|
|
||||||
return "stop";
|
function convertTools(tools: Tool[]): any[] {
|
||||||
case "MAX_TOKENS":
|
return [
|
||||||
return "length";
|
{
|
||||||
case "BLOCKLIST":
|
functionDeclarations: tools.map((tool) => ({
|
||||||
case "PROHIBITED_CONTENT":
|
name: tool.name,
|
||||||
case "SPII":
|
description: tool.description,
|
||||||
case "SAFETY":
|
parameters: tool.parameters,
|
||||||
case "IMAGE_SAFETY":
|
})),
|
||||||
return "safety";
|
},
|
||||||
case "RECITATION":
|
];
|
||||||
return "safety";
|
}
|
||||||
case "FINISH_REASON_UNSPECIFIED":
|
|
||||||
case "OTHER":
|
function mapToolChoice(choice: string): FunctionCallingConfigMode {
|
||||||
case "LANGUAGE":
|
switch (choice) {
|
||||||
case "MALFORMED_FUNCTION_CALL":
|
case "auto":
|
||||||
case "UNEXPECTED_TOOL_CALL":
|
return FunctionCallingConfigMode.AUTO;
|
||||||
return "error";
|
case "none":
|
||||||
default:
|
return FunctionCallingConfigMode.NONE;
|
||||||
return "stop";
|
case "any":
|
||||||
|
return FunctionCallingConfigMode.ANY;
|
||||||
|
default:
|
||||||
|
return FunctionCallingConfigMode.AUTO;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function mapStopReason(reason: FinishReason): StopReason {
|
||||||
|
switch (reason) {
|
||||||
|
case FinishReason.STOP:
|
||||||
|
return "stop";
|
||||||
|
case FinishReason.MAX_TOKENS:
|
||||||
|
return "length";
|
||||||
|
case FinishReason.BLOCKLIST:
|
||||||
|
case FinishReason.PROHIBITED_CONTENT:
|
||||||
|
case FinishReason.SPII:
|
||||||
|
case FinishReason.SAFETY:
|
||||||
|
case FinishReason.IMAGE_SAFETY:
|
||||||
|
case FinishReason.RECITATION:
|
||||||
|
return "safety";
|
||||||
|
case FinishReason.FINISH_REASON_UNSPECIFIED:
|
||||||
|
case FinishReason.OTHER:
|
||||||
|
case FinishReason.LANGUAGE:
|
||||||
|
case FinishReason.MALFORMED_FUNCTION_CALL:
|
||||||
|
case FinishReason.UNEXPECTED_TOOL_CALL:
|
||||||
|
return "error";
|
||||||
|
default: {
|
||||||
|
const _exhaustive: never = reason;
|
||||||
|
throw new Error(`Unhandled stop reason: ${_exhaustive}`);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,18 +1,20 @@
|
||||||
import OpenAI from "openai";
|
import OpenAI from "openai";
|
||||||
import type {
|
import type {
|
||||||
|
ChatCompletionAssistantMessageParam,
|
||||||
ChatCompletionChunk,
|
ChatCompletionChunk,
|
||||||
ChatCompletionContentPart,
|
ChatCompletionContentPart,
|
||||||
ChatCompletionContentPartImage,
|
ChatCompletionContentPartImage,
|
||||||
ChatCompletionContentPartText,
|
ChatCompletionContentPartText,
|
||||||
ChatCompletionMessageParam,
|
ChatCompletionMessageParam,
|
||||||
} from "openai/resources/chat/completions.js";
|
} from "openai/resources/chat/completions.js";
|
||||||
|
import { QueuedGenerateStream } from "../generate.js";
|
||||||
import { calculateCost } from "../models.js";
|
import { calculateCost } from "../models.js";
|
||||||
import type {
|
import type {
|
||||||
AssistantMessage,
|
AssistantMessage,
|
||||||
Context,
|
Context,
|
||||||
LLM,
|
GenerateFunction,
|
||||||
LLMOptions,
|
GenerateOptions,
|
||||||
Message,
|
GenerateStream,
|
||||||
Model,
|
Model,
|
||||||
StopReason,
|
StopReason,
|
||||||
TextContent,
|
TextContent,
|
||||||
|
|
@ -22,43 +24,25 @@ import type {
|
||||||
} from "../types.js";
|
} from "../types.js";
|
||||||
import { transformMessages } from "./utils.js";
|
import { transformMessages } from "./utils.js";
|
||||||
|
|
||||||
export interface OpenAICompletionsLLMOptions extends LLMOptions {
|
export interface OpenAICompletionsOptions extends GenerateOptions {
|
||||||
toolChoice?: "auto" | "none" | "required" | { type: "function"; function: { name: string } };
|
toolChoice?: "auto" | "none" | "required" | { type: "function"; function: { name: string } };
|
||||||
reasoningEffort?: "low" | "medium" | "high";
|
reasoningEffort?: "minimal" | "low" | "medium" | "high";
|
||||||
}
|
}
|
||||||
|
|
||||||
export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
|
export const streamOpenAICompletions: GenerateFunction<"openai-completions"> = (
|
||||||
private client: OpenAI;
|
model: Model<"openai-completions">,
|
||||||
private modelInfo: Model;
|
context: Context,
|
||||||
|
options?: OpenAICompletionsOptions,
|
||||||
|
): GenerateStream => {
|
||||||
|
const stream = new QueuedGenerateStream();
|
||||||
|
|
||||||
constructor(model: Model, apiKey?: string) {
|
(async () => {
|
||||||
if (!apiKey) {
|
|
||||||
if (!process.env.OPENAI_API_KEY) {
|
|
||||||
throw new Error(
|
|
||||||
"OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass it as an argument.",
|
|
||||||
);
|
|
||||||
}
|
|
||||||
apiKey = process.env.OPENAI_API_KEY;
|
|
||||||
}
|
|
||||||
this.client = new OpenAI({ apiKey, baseURL: model.baseUrl, dangerouslyAllowBrowser: true });
|
|
||||||
this.modelInfo = model;
|
|
||||||
}
|
|
||||||
|
|
||||||
getModel(): Model {
|
|
||||||
return this.modelInfo;
|
|
||||||
}
|
|
||||||
|
|
||||||
getApi(): string {
|
|
||||||
return "openai-completions";
|
|
||||||
}
|
|
||||||
|
|
||||||
async generate(request: Context, options?: OpenAICompletionsLLMOptions): Promise<AssistantMessage> {
|
|
||||||
const output: AssistantMessage = {
|
const output: AssistantMessage = {
|
||||||
role: "assistant",
|
role: "assistant",
|
||||||
content: [],
|
content: [],
|
||||||
api: this.getApi(),
|
api: model.api,
|
||||||
provider: this.modelInfo.provider,
|
provider: model.provider,
|
||||||
model: this.modelInfo.id,
|
model: model.id,
|
||||||
usage: {
|
usage: {
|
||||||
input: 0,
|
input: 0,
|
||||||
output: 0,
|
output: 0,
|
||||||
|
|
@ -70,52 +54,13 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
|
||||||
};
|
};
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const messages = this.convertMessages(request.messages, request.systemPrompt);
|
const client = createClient(model, options?.apiKey);
|
||||||
|
const params = buildParams(model, context, options);
|
||||||
const params: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
|
const openaiStream = await client.chat.completions.create(params, { signal: options?.signal });
|
||||||
model: this.modelInfo.id,
|
stream.push({ type: "start", partial: output });
|
||||||
messages,
|
|
||||||
stream: true,
|
|
||||||
stream_options: { include_usage: true },
|
|
||||||
};
|
|
||||||
|
|
||||||
// Cerebras/xAI dont like the "store" field
|
|
||||||
if (!this.modelInfo.baseUrl?.includes("cerebras.ai") && !this.modelInfo.baseUrl?.includes("api.x.ai")) {
|
|
||||||
params.store = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (options?.maxTokens) {
|
|
||||||
params.max_completion_tokens = options?.maxTokens;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (options?.temperature !== undefined) {
|
|
||||||
params.temperature = options?.temperature;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (request.tools) {
|
|
||||||
params.tools = this.convertTools(request.tools);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (options?.toolChoice) {
|
|
||||||
params.tool_choice = options.toolChoice;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (
|
|
||||||
options?.reasoningEffort &&
|
|
||||||
this.modelInfo.reasoning &&
|
|
||||||
!this.modelInfo.id.toLowerCase().includes("grok")
|
|
||||||
) {
|
|
||||||
params.reasoning_effort = options.reasoningEffort;
|
|
||||||
}
|
|
||||||
|
|
||||||
const stream = await this.client.chat.completions.create(params, {
|
|
||||||
signal: options?.signal,
|
|
||||||
});
|
|
||||||
|
|
||||||
options?.onEvent?.({ type: "start", model: this.modelInfo.id, provider: this.modelInfo.provider });
|
|
||||||
|
|
||||||
let currentBlock: TextContent | ThinkingContent | (ToolCall & { partialArgs?: string }) | null = null;
|
let currentBlock: TextContent | ThinkingContent | (ToolCall & { partialArgs?: string }) | null = null;
|
||||||
for await (const chunk of stream) {
|
for await (const chunk of openaiStream) {
|
||||||
if (chunk.usage) {
|
if (chunk.usage) {
|
||||||
output.usage = {
|
output.usage = {
|
||||||
input: chunk.usage.prompt_tokens || 0,
|
input: chunk.usage.prompt_tokens || 0,
|
||||||
|
|
@ -132,137 +77,170 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
|
||||||
total: 0,
|
total: 0,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
calculateCost(this.modelInfo, output.usage);
|
calculateCost(model, output.usage);
|
||||||
}
|
}
|
||||||
|
|
||||||
const choice = chunk.choices[0];
|
const choice = chunk.choices[0];
|
||||||
if (!choice) continue;
|
if (!choice) continue;
|
||||||
|
|
||||||
// Capture finish reason
|
|
||||||
if (choice.finish_reason) {
|
if (choice.finish_reason) {
|
||||||
output.stopReason = this.mapStopReason(choice.finish_reason);
|
output.stopReason = mapStopReason(choice.finish_reason);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (choice.delta) {
|
if (choice.delta) {
|
||||||
// Handle text content
|
|
||||||
if (
|
if (
|
||||||
choice.delta.content !== null &&
|
choice.delta.content !== null &&
|
||||||
choice.delta.content !== undefined &&
|
choice.delta.content !== undefined &&
|
||||||
choice.delta.content.length > 0
|
choice.delta.content.length > 0
|
||||||
) {
|
) {
|
||||||
// Check if we need to switch to text block
|
|
||||||
if (!currentBlock || currentBlock.type !== "text") {
|
if (!currentBlock || currentBlock.type !== "text") {
|
||||||
// Save current block if exists
|
|
||||||
if (currentBlock) {
|
if (currentBlock) {
|
||||||
if (currentBlock.type === "thinking") {
|
if (currentBlock.type === "thinking") {
|
||||||
options?.onEvent?.({ type: "thinking_end", content: currentBlock.thinking });
|
stream.push({
|
||||||
|
type: "thinking_end",
|
||||||
|
content: currentBlock.thinking,
|
||||||
|
partial: output,
|
||||||
|
});
|
||||||
} else if (currentBlock.type === "toolCall") {
|
} else if (currentBlock.type === "toolCall") {
|
||||||
currentBlock.arguments = JSON.parse(currentBlock.partialArgs || "{}");
|
currentBlock.arguments = JSON.parse(currentBlock.partialArgs || "{}");
|
||||||
delete currentBlock.partialArgs;
|
delete currentBlock.partialArgs;
|
||||||
options?.onEvent?.({ type: "toolCall", toolCall: currentBlock as ToolCall });
|
stream.push({
|
||||||
|
type: "toolCall",
|
||||||
|
toolCall: currentBlock as ToolCall,
|
||||||
|
partial: output,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Start new text block
|
|
||||||
currentBlock = { type: "text", text: "" };
|
currentBlock = { type: "text", text: "" };
|
||||||
output.content.push(currentBlock);
|
output.content.push(currentBlock);
|
||||||
options?.onEvent?.({ type: "text_start" });
|
stream.push({ type: "text_start", partial: output });
|
||||||
}
|
}
|
||||||
// Append to text block
|
|
||||||
if (currentBlock.type === "text") {
|
if (currentBlock.type === "text") {
|
||||||
currentBlock.text += choice.delta.content;
|
currentBlock.text += choice.delta.content;
|
||||||
options?.onEvent?.({
|
stream.push({
|
||||||
type: "text_delta",
|
type: "text_delta",
|
||||||
content: currentBlock.text,
|
|
||||||
delta: choice.delta.content,
|
delta: choice.delta.content,
|
||||||
|
partial: output,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle reasoning_content field
|
// Some endpoints return reasoning in reasoning_content (llama.cpp)
|
||||||
if (
|
if (
|
||||||
(choice.delta as any).reasoning_content !== null &&
|
(choice.delta as any).reasoning_content !== null &&
|
||||||
(choice.delta as any).reasoning_content !== undefined &&
|
(choice.delta as any).reasoning_content !== undefined &&
|
||||||
(choice.delta as any).reasoning_content.length > 0
|
(choice.delta as any).reasoning_content.length > 0
|
||||||
) {
|
) {
|
||||||
// Check if we need to switch to thinking block
|
|
||||||
if (!currentBlock || currentBlock.type !== "thinking") {
|
if (!currentBlock || currentBlock.type !== "thinking") {
|
||||||
// Save current block if exists
|
|
||||||
if (currentBlock) {
|
if (currentBlock) {
|
||||||
if (currentBlock.type === "text") {
|
if (currentBlock.type === "text") {
|
||||||
options?.onEvent?.({ type: "text_end", content: currentBlock.text });
|
stream.push({
|
||||||
|
type: "text_end",
|
||||||
|
content: currentBlock.text,
|
||||||
|
partial: output,
|
||||||
|
});
|
||||||
} else if (currentBlock.type === "toolCall") {
|
} else if (currentBlock.type === "toolCall") {
|
||||||
currentBlock.arguments = JSON.parse(currentBlock.partialArgs || "{}");
|
currentBlock.arguments = JSON.parse(currentBlock.partialArgs || "{}");
|
||||||
delete currentBlock.partialArgs;
|
delete currentBlock.partialArgs;
|
||||||
options?.onEvent?.({ type: "toolCall", toolCall: currentBlock as ToolCall });
|
stream.push({
|
||||||
|
type: "toolCall",
|
||||||
|
toolCall: currentBlock as ToolCall,
|
||||||
|
partial: output,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Start new thinking block
|
currentBlock = {
|
||||||
currentBlock = { type: "thinking", thinking: "", thinkingSignature: "reasoning_content" };
|
type: "thinking",
|
||||||
|
thinking: "",
|
||||||
|
thinkingSignature: "reasoning_content",
|
||||||
|
};
|
||||||
output.content.push(currentBlock);
|
output.content.push(currentBlock);
|
||||||
options?.onEvent?.({ type: "thinking_start" });
|
stream.push({ type: "thinking_start", partial: output });
|
||||||
}
|
}
|
||||||
// Append to thinking block
|
|
||||||
if (currentBlock.type === "thinking") {
|
if (currentBlock.type === "thinking") {
|
||||||
const delta = (choice.delta as any).reasoning_content;
|
const delta = (choice.delta as any).reasoning_content;
|
||||||
currentBlock.thinking += delta;
|
currentBlock.thinking += delta;
|
||||||
options?.onEvent?.({ type: "thinking_delta", content: currentBlock.thinking, delta });
|
stream.push({
|
||||||
|
type: "thinking_delta",
|
||||||
|
delta,
|
||||||
|
partial: output,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle reasoning field
|
// Some endpoints return reasoning in reasining (ollama, xAI, ...)
|
||||||
if (
|
if (
|
||||||
(choice.delta as any).reasoning !== null &&
|
(choice.delta as any).reasoning !== null &&
|
||||||
(choice.delta as any).reasoning !== undefined &&
|
(choice.delta as any).reasoning !== undefined &&
|
||||||
(choice.delta as any).reasoning.length > 0
|
(choice.delta as any).reasoning.length > 0
|
||||||
) {
|
) {
|
||||||
// Check if we need to switch to thinking block
|
|
||||||
if (!currentBlock || currentBlock.type !== "thinking") {
|
if (!currentBlock || currentBlock.type !== "thinking") {
|
||||||
// Save current block if exists
|
|
||||||
if (currentBlock) {
|
if (currentBlock) {
|
||||||
if (currentBlock.type === "text") {
|
if (currentBlock.type === "text") {
|
||||||
options?.onEvent?.({ type: "text_end", content: currentBlock.text });
|
stream.push({
|
||||||
|
type: "text_end",
|
||||||
|
content: currentBlock.text,
|
||||||
|
partial: output,
|
||||||
|
});
|
||||||
} else if (currentBlock.type === "toolCall") {
|
} else if (currentBlock.type === "toolCall") {
|
||||||
currentBlock.arguments = JSON.parse(currentBlock.partialArgs || "{}");
|
currentBlock.arguments = JSON.parse(currentBlock.partialArgs || "{}");
|
||||||
delete currentBlock.partialArgs;
|
delete currentBlock.partialArgs;
|
||||||
options?.onEvent?.({ type: "toolCall", toolCall: currentBlock as ToolCall });
|
stream.push({
|
||||||
|
type: "toolCall",
|
||||||
|
toolCall: currentBlock as ToolCall,
|
||||||
|
partial: output,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Start new thinking block
|
currentBlock = {
|
||||||
currentBlock = { type: "thinking", thinking: "", thinkingSignature: "reasoning" };
|
type: "thinking",
|
||||||
|
thinking: "",
|
||||||
|
thinkingSignature: "reasoning",
|
||||||
|
};
|
||||||
output.content.push(currentBlock);
|
output.content.push(currentBlock);
|
||||||
options?.onEvent?.({ type: "thinking_start" });
|
stream.push({ type: "thinking_start", partial: output });
|
||||||
}
|
}
|
||||||
// Append to thinking block
|
|
||||||
if (currentBlock.type === "thinking") {
|
if (currentBlock.type === "thinking") {
|
||||||
const delta = (choice.delta as any).reasoning;
|
const delta = (choice.delta as any).reasoning;
|
||||||
currentBlock.thinking += delta;
|
currentBlock.thinking += delta;
|
||||||
options?.onEvent?.({ type: "thinking_delta", content: currentBlock.thinking, delta });
|
stream.push({ type: "thinking_delta", delta, partial: output });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle tool calls
|
|
||||||
if (choice?.delta?.tool_calls) {
|
if (choice?.delta?.tool_calls) {
|
||||||
for (const toolCall of choice.delta.tool_calls) {
|
for (const toolCall of choice.delta.tool_calls) {
|
||||||
// Check if we need a new tool call block
|
|
||||||
if (
|
if (
|
||||||
!currentBlock ||
|
!currentBlock ||
|
||||||
currentBlock.type !== "toolCall" ||
|
currentBlock.type !== "toolCall" ||
|
||||||
(toolCall.id && currentBlock.id !== toolCall.id)
|
(toolCall.id && currentBlock.id !== toolCall.id)
|
||||||
) {
|
) {
|
||||||
// Save current block if exists
|
|
||||||
if (currentBlock) {
|
if (currentBlock) {
|
||||||
if (currentBlock.type === "text") {
|
if (currentBlock.type === "text") {
|
||||||
options?.onEvent?.({ type: "text_end", content: currentBlock.text });
|
stream.push({
|
||||||
|
type: "text_end",
|
||||||
|
content: currentBlock.text,
|
||||||
|
partial: output,
|
||||||
|
});
|
||||||
} else if (currentBlock.type === "thinking") {
|
} else if (currentBlock.type === "thinking") {
|
||||||
options?.onEvent?.({ type: "thinking_end", content: currentBlock.thinking });
|
stream.push({
|
||||||
|
type: "thinking_end",
|
||||||
|
content: currentBlock.thinking,
|
||||||
|
partial: output,
|
||||||
|
});
|
||||||
} else if (currentBlock.type === "toolCall") {
|
} else if (currentBlock.type === "toolCall") {
|
||||||
currentBlock.arguments = JSON.parse(currentBlock.partialArgs || "{}");
|
currentBlock.arguments = JSON.parse(currentBlock.partialArgs || "{}");
|
||||||
delete currentBlock.partialArgs;
|
delete currentBlock.partialArgs;
|
||||||
options?.onEvent?.({ type: "toolCall", toolCall: currentBlock as ToolCall });
|
stream.push({
|
||||||
|
type: "toolCall",
|
||||||
|
toolCall: currentBlock as ToolCall,
|
||||||
|
partial: output,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start new tool call block
|
|
||||||
currentBlock = {
|
currentBlock = {
|
||||||
type: "toolCall",
|
type: "toolCall",
|
||||||
id: toolCall.id || "",
|
id: toolCall.id || "",
|
||||||
|
|
@ -273,7 +251,6 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
|
||||||
output.content.push(currentBlock);
|
output.content.push(currentBlock);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Accumulate tool call data
|
|
||||||
if (currentBlock.type === "toolCall") {
|
if (currentBlock.type === "toolCall") {
|
||||||
if (toolCall.id) currentBlock.id = toolCall.id;
|
if (toolCall.id) currentBlock.id = toolCall.id;
|
||||||
if (toolCall.function?.name) currentBlock.name = toolCall.function.name;
|
if (toolCall.function?.name) currentBlock.name = toolCall.function.name;
|
||||||
|
|
@ -286,16 +263,27 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save final block if exists
|
|
||||||
if (currentBlock) {
|
if (currentBlock) {
|
||||||
if (currentBlock.type === "text") {
|
if (currentBlock.type === "text") {
|
||||||
options?.onEvent?.({ type: "text_end", content: currentBlock.text });
|
stream.push({
|
||||||
|
type: "text_end",
|
||||||
|
content: currentBlock.text,
|
||||||
|
partial: output,
|
||||||
|
});
|
||||||
} else if (currentBlock.type === "thinking") {
|
} else if (currentBlock.type === "thinking") {
|
||||||
options?.onEvent?.({ type: "thinking_end", content: currentBlock.thinking });
|
stream.push({
|
||||||
|
type: "thinking_end",
|
||||||
|
content: currentBlock.thinking,
|
||||||
|
partial: output,
|
||||||
|
});
|
||||||
} else if (currentBlock.type === "toolCall") {
|
} else if (currentBlock.type === "toolCall") {
|
||||||
currentBlock.arguments = JSON.parse(currentBlock.partialArgs || "{}");
|
currentBlock.arguments = JSON.parse(currentBlock.partialArgs || "{}");
|
||||||
delete currentBlock.partialArgs;
|
delete currentBlock.partialArgs;
|
||||||
options?.onEvent?.({ type: "toolCall", toolCall: currentBlock as ToolCall });
|
stream.push({
|
||||||
|
type: "toolCall",
|
||||||
|
toolCall: currentBlock as ToolCall,
|
||||||
|
partial: output,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -303,141 +291,188 @@ export class OpenAICompletionsLLM implements LLM<OpenAICompletionsLLMOptions> {
|
||||||
throw new Error("Request was aborted");
|
throw new Error("Request was aborted");
|
||||||
}
|
}
|
||||||
|
|
||||||
options?.onEvent?.({ type: "done", reason: output.stopReason, message: output });
|
stream.push({ type: "done", reason: output.stopReason, message: output });
|
||||||
|
stream.end();
|
||||||
return output;
|
return output;
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
// Update output with error information
|
|
||||||
output.stopReason = "error";
|
output.stopReason = "error";
|
||||||
output.error = error instanceof Error ? error.message : String(error);
|
output.error = error instanceof Error ? error.message : String(error);
|
||||||
options?.onEvent?.({ type: "error", error: output.error });
|
stream.push({ type: "error", error: output.error, partial: output });
|
||||||
return output;
|
stream.end();
|
||||||
}
|
}
|
||||||
|
})();
|
||||||
|
|
||||||
|
return stream;
|
||||||
|
};
|
||||||
|
|
||||||
|
function createClient(model: Model<"openai-completions">, apiKey?: string) {
|
||||||
|
if (!apiKey) {
|
||||||
|
if (!process.env.OPENAI_API_KEY) {
|
||||||
|
throw new Error(
|
||||||
|
"OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass it as an argument.",
|
||||||
|
);
|
||||||
|
}
|
||||||
|
apiKey = process.env.OPENAI_API_KEY;
|
||||||
|
}
|
||||||
|
return new OpenAI({ apiKey, baseURL: model.baseUrl, dangerouslyAllowBrowser: true });
|
||||||
|
}
|
||||||
|
|
||||||
|
function buildParams(model: Model<"openai-completions">, context: Context, options?: OpenAICompletionsOptions) {
|
||||||
|
const messages = convertMessages(model, context);
|
||||||
|
|
||||||
|
const params: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
|
||||||
|
model: model.id,
|
||||||
|
messages,
|
||||||
|
stream: true,
|
||||||
|
stream_options: { include_usage: true },
|
||||||
|
};
|
||||||
|
|
||||||
|
// Cerebras/xAI dont like the "store" field
|
||||||
|
if (!model.baseUrl.includes("cerebras.ai") && !model.baseUrl.includes("api.x.ai")) {
|
||||||
|
params.store = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
private convertMessages(messages: Message[], systemPrompt?: string): ChatCompletionMessageParam[] {
|
if (options?.maxTokens) {
|
||||||
const params: ChatCompletionMessageParam[] = [];
|
params.max_completion_tokens = options?.maxTokens;
|
||||||
|
}
|
||||||
|
|
||||||
// Transform messages for cross-provider compatibility
|
if (options?.temperature !== undefined) {
|
||||||
const transformedMessages = transformMessages(messages, this.modelInfo, this.getApi());
|
params.temperature = options?.temperature;
|
||||||
|
}
|
||||||
|
|
||||||
// Add system prompt if provided
|
if (context.tools) {
|
||||||
if (systemPrompt) {
|
params.tools = convertTools(context.tools);
|
||||||
// Cerebras/xAi don't like the "developer" role
|
}
|
||||||
const useDeveloperRole =
|
|
||||||
this.modelInfo.reasoning &&
|
|
||||||
!this.modelInfo.baseUrl?.includes("cerebras.ai") &&
|
|
||||||
!this.modelInfo.baseUrl?.includes("api.x.ai");
|
|
||||||
const role = useDeveloperRole ? "developer" : "system";
|
|
||||||
params.push({ role: role, content: systemPrompt });
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert messages
|
if (options?.toolChoice) {
|
||||||
for (const msg of transformedMessages) {
|
params.tool_choice = options.toolChoice;
|
||||||
if (msg.role === "user") {
|
}
|
||||||
// Handle both string and array content
|
|
||||||
if (typeof msg.content === "string") {
|
|
||||||
params.push({
|
|
||||||
role: "user",
|
|
||||||
content: msg.content,
|
|
||||||
});
|
|
||||||
} else {
|
|
||||||
// Convert array content to OpenAI format
|
|
||||||
const content: ChatCompletionContentPart[] = msg.content.map((item): ChatCompletionContentPart => {
|
|
||||||
if (item.type === "text") {
|
|
||||||
return {
|
|
||||||
type: "text",
|
|
||||||
text: item.text,
|
|
||||||
} satisfies ChatCompletionContentPartText;
|
|
||||||
} else {
|
|
||||||
// Image content - OpenAI uses data URLs
|
|
||||||
return {
|
|
||||||
type: "image_url",
|
|
||||||
image_url: {
|
|
||||||
url: `data:${item.mimeType};base64,${item.data}`,
|
|
||||||
},
|
|
||||||
} satisfies ChatCompletionContentPartImage;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
const filteredContent = !this.modelInfo?.input.includes("image")
|
|
||||||
? content.filter((c) => c.type !== "image_url")
|
|
||||||
: content;
|
|
||||||
params.push({
|
|
||||||
role: "user",
|
|
||||||
content: filteredContent,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
} else if (msg.role === "assistant") {
|
|
||||||
const assistantMsg: ChatCompletionMessageParam = {
|
|
||||||
role: "assistant",
|
|
||||||
content: null,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Build content from blocks
|
// Grok models don't like reasoning_effort
|
||||||
const textBlocks = msg.content.filter((b) => b.type === "text") as TextContent[];
|
if (options?.reasoningEffort && model.reasoning && !model.id.toLowerCase().includes("grok")) {
|
||||||
if (textBlocks.length > 0) {
|
params.reasoning_effort = options.reasoningEffort;
|
||||||
assistantMsg.content = textBlocks.map((b) => b.text).join("");
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Handle thinking blocks for llama.cpp server + gpt-oss
|
return params;
|
||||||
const thinkingBlocks = msg.content.filter((b) => b.type === "thinking") as ThinkingContent[];
|
}
|
||||||
if (thinkingBlocks.length > 0) {
|
|
||||||
// Use the signature from the first thinking block if available
|
|
||||||
const signature = thinkingBlocks[0].thinkingSignature;
|
|
||||||
if (signature && signature.length > 0) {
|
|
||||||
(assistantMsg as any)[signature] = thinkingBlocks.map((b) => b.thinking).join("");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle tool calls
|
function convertMessages(model: Model<"openai-completions">, context: Context): ChatCompletionMessageParam[] {
|
||||||
const toolCalls = msg.content.filter((b) => b.type === "toolCall") as ToolCall[];
|
const params: ChatCompletionMessageParam[] = [];
|
||||||
if (toolCalls.length > 0) {
|
|
||||||
assistantMsg.tool_calls = toolCalls.map((tc) => ({
|
|
||||||
id: tc.id,
|
|
||||||
type: "function" as const,
|
|
||||||
function: {
|
|
||||||
name: tc.name,
|
|
||||||
arguments: JSON.stringify(tc.arguments),
|
|
||||||
},
|
|
||||||
}));
|
|
||||||
}
|
|
||||||
|
|
||||||
params.push(assistantMsg);
|
const transformedMessages = transformMessages(context.messages, model);
|
||||||
} else if (msg.role === "toolResult") {
|
|
||||||
|
if (context.systemPrompt) {
|
||||||
|
// Cerebras/xAi don't like the "developer" role
|
||||||
|
const useDeveloperRole =
|
||||||
|
model.reasoning && !model.baseUrl.includes("cerebras.ai") && !model.baseUrl.includes("api.x.ai");
|
||||||
|
const role = useDeveloperRole ? "developer" : "system";
|
||||||
|
params.push({ role: role, content: context.systemPrompt });
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const msg of transformedMessages) {
|
||||||
|
if (msg.role === "user") {
|
||||||
|
if (typeof msg.content === "string") {
|
||||||
params.push({
|
params.push({
|
||||||
role: "tool",
|
role: "user",
|
||||||
content: msg.content,
|
content: msg.content,
|
||||||
tool_call_id: msg.toolCallId,
|
});
|
||||||
|
} else {
|
||||||
|
const content: ChatCompletionContentPart[] = msg.content.map((item): ChatCompletionContentPart => {
|
||||||
|
if (item.type === "text") {
|
||||||
|
return {
|
||||||
|
type: "text",
|
||||||
|
text: item.text,
|
||||||
|
} satisfies ChatCompletionContentPartText;
|
||||||
|
} else {
|
||||||
|
return {
|
||||||
|
type: "image_url",
|
||||||
|
image_url: {
|
||||||
|
url: `data:${item.mimeType};base64,${item.data}`,
|
||||||
|
},
|
||||||
|
} satisfies ChatCompletionContentPartImage;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
const filteredContent = !model.input.includes("image")
|
||||||
|
? content.filter((c) => c.type !== "image_url")
|
||||||
|
: content;
|
||||||
|
if (filteredContent.length === 0) continue;
|
||||||
|
params.push({
|
||||||
|
role: "user",
|
||||||
|
content: filteredContent,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
} else if (msg.role === "assistant") {
|
||||||
|
const assistantMsg: ChatCompletionAssistantMessageParam = {
|
||||||
|
role: "assistant",
|
||||||
|
content: null,
|
||||||
|
};
|
||||||
|
|
||||||
|
const textBlocks = msg.content.filter((b) => b.type === "text") as TextContent[];
|
||||||
|
if (textBlocks.length > 0) {
|
||||||
|
assistantMsg.content = textBlocks.map((b) => b.text).join("");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle thinking blocks for llama.cpp server + gpt-oss
|
||||||
|
const thinkingBlocks = msg.content.filter((b) => b.type === "thinking") as ThinkingContent[];
|
||||||
|
if (thinkingBlocks.length > 0) {
|
||||||
|
// Use the signature from the first thinking block if available
|
||||||
|
const signature = thinkingBlocks[0].thinkingSignature;
|
||||||
|
if (signature && signature.length > 0) {
|
||||||
|
(assistantMsg as any)[signature] = thinkingBlocks.map((b) => b.thinking).join("");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const toolCalls = msg.content.filter((b) => b.type === "toolCall") as ToolCall[];
|
||||||
|
if (toolCalls.length > 0) {
|
||||||
|
assistantMsg.tool_calls = toolCalls.map((tc) => ({
|
||||||
|
id: tc.id,
|
||||||
|
type: "function" as const,
|
||||||
|
function: {
|
||||||
|
name: tc.name,
|
||||||
|
arguments: JSON.stringify(tc.arguments),
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
params.push(assistantMsg);
|
||||||
|
} else if (msg.role === "toolResult") {
|
||||||
|
params.push({
|
||||||
|
role: "tool",
|
||||||
|
content: msg.content,
|
||||||
|
tool_call_id: msg.toolCallId,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
return params;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private convertTools(tools: Tool[]): OpenAI.Chat.Completions.ChatCompletionTool[] {
|
return params;
|
||||||
return tools.map((tool) => ({
|
}
|
||||||
type: "function",
|
|
||||||
function: {
|
|
||||||
name: tool.name,
|
|
||||||
description: tool.description,
|
|
||||||
parameters: tool.parameters,
|
|
||||||
},
|
|
||||||
}));
|
|
||||||
}
|
|
||||||
|
|
||||||
private mapStopReason(reason: ChatCompletionChunk.Choice["finish_reason"] | null): StopReason {
|
function convertTools(tools: Tool[]): OpenAI.Chat.Completions.ChatCompletionTool[] {
|
||||||
switch (reason) {
|
return tools.map((tool) => ({
|
||||||
case "stop":
|
type: "function",
|
||||||
return "stop";
|
function: {
|
||||||
case "length":
|
name: tool.name,
|
||||||
return "length";
|
description: tool.description,
|
||||||
case "function_call":
|
parameters: tool.parameters,
|
||||||
case "tool_calls":
|
},
|
||||||
return "toolUse";
|
}));
|
||||||
case "content_filter":
|
}
|
||||||
return "safety";
|
|
||||||
default:
|
function mapStopReason(reason: ChatCompletionChunk.Choice["finish_reason"]): StopReason {
|
||||||
return "stop";
|
if (reason === null) return "stop";
|
||||||
|
switch (reason) {
|
||||||
|
case "stop":
|
||||||
|
return "stop";
|
||||||
|
case "length":
|
||||||
|
return "length";
|
||||||
|
case "function_call":
|
||||||
|
case "tool_calls":
|
||||||
|
return "toolUse";
|
||||||
|
case "content_filter":
|
||||||
|
return "safety";
|
||||||
|
default: {
|
||||||
|
const _exhaustive: never = reason;
|
||||||
|
throw new Error(`Unhandled stop reason: ${_exhaustive}`);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -10,58 +10,49 @@ import type {
|
||||||
ResponseOutputMessage,
|
ResponseOutputMessage,
|
||||||
ResponseReasoningItem,
|
ResponseReasoningItem,
|
||||||
} from "openai/resources/responses/responses.js";
|
} from "openai/resources/responses/responses.js";
|
||||||
|
import { QueuedGenerateStream } from "../generate.js";
|
||||||
import { calculateCost } from "../models.js";
|
import { calculateCost } from "../models.js";
|
||||||
import type {
|
import type {
|
||||||
|
Api,
|
||||||
AssistantMessage,
|
AssistantMessage,
|
||||||
Context,
|
Context,
|
||||||
LLM,
|
GenerateFunction,
|
||||||
LLMOptions,
|
GenerateOptions,
|
||||||
|
GenerateStream,
|
||||||
Message,
|
Message,
|
||||||
Model,
|
Model,
|
||||||
StopReason,
|
StopReason,
|
||||||
TextContent,
|
TextContent,
|
||||||
|
ThinkingContent,
|
||||||
Tool,
|
Tool,
|
||||||
ToolCall,
|
ToolCall,
|
||||||
} from "../types.js";
|
} from "../types.js";
|
||||||
import { transformMessages } from "./utils.js";
|
import { transformMessages } from "./utils.js";
|
||||||
|
|
||||||
export interface OpenAIResponsesLLMOptions extends LLMOptions {
|
// OpenAI Responses-specific options
|
||||||
|
export interface OpenAIResponsesOptions extends GenerateOptions {
|
||||||
reasoningEffort?: "minimal" | "low" | "medium" | "high";
|
reasoningEffort?: "minimal" | "low" | "medium" | "high";
|
||||||
reasoningSummary?: "auto" | "detailed" | "concise" | null;
|
reasoningSummary?: "auto" | "detailed" | "concise" | null;
|
||||||
}
|
}
|
||||||
|
|
||||||
export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
|
/**
|
||||||
private client: OpenAI;
|
* Generate function for OpenAI Responses API
|
||||||
private modelInfo: Model;
|
*/
|
||||||
|
export const streamOpenAIResponses: GenerateFunction<"openai-responses"> = (
|
||||||
|
model: Model<"openai-responses">,
|
||||||
|
context: Context,
|
||||||
|
options?: OpenAIResponsesOptions,
|
||||||
|
): GenerateStream => {
|
||||||
|
const stream = new QueuedGenerateStream();
|
||||||
|
|
||||||
constructor(model: Model, apiKey?: string) {
|
// Start async processing
|
||||||
if (!apiKey) {
|
(async () => {
|
||||||
if (!process.env.OPENAI_API_KEY) {
|
|
||||||
throw new Error(
|
|
||||||
"OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass it as an argument.",
|
|
||||||
);
|
|
||||||
}
|
|
||||||
apiKey = process.env.OPENAI_API_KEY;
|
|
||||||
}
|
|
||||||
this.client = new OpenAI({ apiKey, baseURL: model.baseUrl, dangerouslyAllowBrowser: true });
|
|
||||||
this.modelInfo = model;
|
|
||||||
}
|
|
||||||
|
|
||||||
getModel(): Model {
|
|
||||||
return this.modelInfo;
|
|
||||||
}
|
|
||||||
|
|
||||||
getApi(): string {
|
|
||||||
return "openai-responses";
|
|
||||||
}
|
|
||||||
|
|
||||||
async generate(request: Context, options?: OpenAIResponsesLLMOptions): Promise<AssistantMessage> {
|
|
||||||
const output: AssistantMessage = {
|
const output: AssistantMessage = {
|
||||||
role: "assistant",
|
role: "assistant",
|
||||||
content: [],
|
content: [],
|
||||||
api: this.getApi(),
|
api: "openai-responses" as Api,
|
||||||
provider: this.modelInfo.provider,
|
provider: model.provider,
|
||||||
model: this.modelInfo.id,
|
model: model.id,
|
||||||
usage: {
|
usage: {
|
||||||
input: 0,
|
input: 0,
|
||||||
output: 0,
|
output: 0,
|
||||||
|
|
@ -71,77 +62,31 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
|
||||||
},
|
},
|
||||||
stopReason: "stop",
|
stopReason: "stop",
|
||||||
};
|
};
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const input = this.convertToInput(request.messages, request.systemPrompt);
|
// Create OpenAI client
|
||||||
|
const client = createClient(model, options?.apiKey);
|
||||||
|
const params = buildParams(model, context, options);
|
||||||
|
const openaiStream = await client.responses.create(params, { signal: options?.signal });
|
||||||
|
stream.push({ type: "start", partial: output });
|
||||||
|
|
||||||
const params: ResponseCreateParamsStreaming = {
|
|
||||||
model: this.modelInfo.id,
|
|
||||||
input,
|
|
||||||
stream: true,
|
|
||||||
};
|
|
||||||
|
|
||||||
if (options?.maxTokens) {
|
|
||||||
params.max_output_tokens = options?.maxTokens;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (options?.temperature !== undefined) {
|
|
||||||
params.temperature = options?.temperature;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (request.tools) {
|
|
||||||
params.tools = this.convertTools(request.tools);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add reasoning options for models that support it
|
|
||||||
if (this.modelInfo?.reasoning) {
|
|
||||||
if (options?.reasoningEffort || options?.reasoningSummary) {
|
|
||||||
params.reasoning = {
|
|
||||||
effort: options?.reasoningEffort || "medium",
|
|
||||||
summary: options?.reasoningSummary || "auto",
|
|
||||||
};
|
|
||||||
params.include = ["reasoning.encrypted_content"];
|
|
||||||
} else {
|
|
||||||
params.reasoning = {
|
|
||||||
effort: this.modelInfo.name.startsWith("gpt-5") ? "minimal" : null,
|
|
||||||
summary: null,
|
|
||||||
};
|
|
||||||
|
|
||||||
if (this.modelInfo.name.startsWith("gpt-5")) {
|
|
||||||
// Jesus Christ, see https://community.openai.com/t/need-reasoning-false-option-for-gpt-5/1351588/7
|
|
||||||
input.push({
|
|
||||||
role: "developer",
|
|
||||||
content: [
|
|
||||||
{
|
|
||||||
type: "input_text",
|
|
||||||
text: "# Juice: 0 !important",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const stream = await this.client.responses.create(params, {
|
|
||||||
signal: options?.signal,
|
|
||||||
});
|
|
||||||
|
|
||||||
options?.onEvent?.({ type: "start", model: this.modelInfo.id, provider: this.modelInfo.provider });
|
|
||||||
|
|
||||||
const outputItems: (ResponseReasoningItem | ResponseOutputMessage | ResponseFunctionToolCall)[] = [];
|
|
||||||
let currentItem: ResponseReasoningItem | ResponseOutputMessage | ResponseFunctionToolCall | null = null;
|
let currentItem: ResponseReasoningItem | ResponseOutputMessage | ResponseFunctionToolCall | null = null;
|
||||||
|
let currentBlock: ThinkingContent | TextContent | ToolCall | null = null;
|
||||||
|
|
||||||
for await (const event of stream) {
|
for await (const event of openaiStream) {
|
||||||
// Handle output item start
|
// Handle output item start
|
||||||
if (event.type === "response.output_item.added") {
|
if (event.type === "response.output_item.added") {
|
||||||
const item = event.item;
|
const item = event.item;
|
||||||
if (item.type === "reasoning") {
|
if (item.type === "reasoning") {
|
||||||
options?.onEvent?.({ type: "thinking_start" });
|
|
||||||
outputItems.push(item);
|
|
||||||
currentItem = item;
|
currentItem = item;
|
||||||
|
currentBlock = { type: "thinking", thinking: "" };
|
||||||
|
output.content.push(currentBlock);
|
||||||
|
stream.push({ type: "thinking_start", partial: output });
|
||||||
} else if (item.type === "message") {
|
} else if (item.type === "message") {
|
||||||
options?.onEvent?.({ type: "text_start" });
|
|
||||||
outputItems.push(item);
|
|
||||||
currentItem = item;
|
currentItem = item;
|
||||||
|
currentBlock = { type: "text", text: "" };
|
||||||
|
output.content.push(currentBlock);
|
||||||
|
stream.push({ type: "text_start", partial: output });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Handle reasoning summary deltas
|
// Handle reasoning summary deltas
|
||||||
|
|
@ -151,30 +96,42 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
|
||||||
currentItem.summary.push(event.part);
|
currentItem.summary.push(event.part);
|
||||||
}
|
}
|
||||||
} else if (event.type === "response.reasoning_summary_text.delta") {
|
} else if (event.type === "response.reasoning_summary_text.delta") {
|
||||||
if (currentItem && currentItem.type === "reasoning") {
|
if (
|
||||||
|
currentItem &&
|
||||||
|
currentItem.type === "reasoning" &&
|
||||||
|
currentBlock &&
|
||||||
|
currentBlock.type === "thinking"
|
||||||
|
) {
|
||||||
currentItem.summary = currentItem.summary || [];
|
currentItem.summary = currentItem.summary || [];
|
||||||
const lastPart = currentItem.summary[currentItem.summary.length - 1];
|
const lastPart = currentItem.summary[currentItem.summary.length - 1];
|
||||||
if (lastPart) {
|
if (lastPart) {
|
||||||
|
currentBlock.thinking += event.delta;
|
||||||
lastPart.text += event.delta;
|
lastPart.text += event.delta;
|
||||||
options?.onEvent?.({
|
stream.push({
|
||||||
type: "thinking_delta",
|
type: "thinking_delta",
|
||||||
content: currentItem.summary.map((s) => s.text).join("\n\n"),
|
|
||||||
delta: event.delta,
|
delta: event.delta,
|
||||||
|
partial: output,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Add a new line between summary parts (hack...)
|
// Add a new line between summary parts (hack...)
|
||||||
else if (event.type === "response.reasoning_summary_part.done") {
|
else if (event.type === "response.reasoning_summary_part.done") {
|
||||||
if (currentItem && currentItem.type === "reasoning") {
|
if (
|
||||||
|
currentItem &&
|
||||||
|
currentItem.type === "reasoning" &&
|
||||||
|
currentBlock &&
|
||||||
|
currentBlock.type === "thinking"
|
||||||
|
) {
|
||||||
currentItem.summary = currentItem.summary || [];
|
currentItem.summary = currentItem.summary || [];
|
||||||
const lastPart = currentItem.summary[currentItem.summary.length - 1];
|
const lastPart = currentItem.summary[currentItem.summary.length - 1];
|
||||||
if (lastPart) {
|
if (lastPart) {
|
||||||
|
currentBlock.thinking += "\n\n";
|
||||||
lastPart.text += "\n\n";
|
lastPart.text += "\n\n";
|
||||||
options?.onEvent?.({
|
stream.push({
|
||||||
type: "thinking_delta",
|
type: "thinking_delta",
|
||||||
content: currentItem.summary.map((s) => s.text).join("\n\n"),
|
|
||||||
delta: "\n\n",
|
delta: "\n\n",
|
||||||
|
partial: output,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -186,30 +143,28 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
|
||||||
currentItem.content.push(event.part);
|
currentItem.content.push(event.part);
|
||||||
}
|
}
|
||||||
} else if (event.type === "response.output_text.delta") {
|
} else if (event.type === "response.output_text.delta") {
|
||||||
if (currentItem && currentItem.type === "message") {
|
if (currentItem && currentItem.type === "message" && currentBlock && currentBlock.type === "text") {
|
||||||
const lastPart = currentItem.content[currentItem.content.length - 1];
|
const lastPart = currentItem.content[currentItem.content.length - 1];
|
||||||
if (lastPart && lastPart.type === "output_text") {
|
if (lastPart && lastPart.type === "output_text") {
|
||||||
|
currentBlock.text += event.delta;
|
||||||
lastPart.text += event.delta;
|
lastPart.text += event.delta;
|
||||||
options?.onEvent?.({
|
stream.push({
|
||||||
type: "text_delta",
|
type: "text_delta",
|
||||||
content: currentItem.content
|
|
||||||
.map((c) => (c.type === "output_text" ? c.text : c.refusal))
|
|
||||||
.join(""),
|
|
||||||
delta: event.delta,
|
delta: event.delta,
|
||||||
|
partial: output,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if (event.type === "response.refusal.delta") {
|
} else if (event.type === "response.refusal.delta") {
|
||||||
if (currentItem && currentItem.type === "message") {
|
if (currentItem && currentItem.type === "message" && currentBlock && currentBlock.type === "text") {
|
||||||
const lastPart = currentItem.content[currentItem.content.length - 1];
|
const lastPart = currentItem.content[currentItem.content.length - 1];
|
||||||
if (lastPart && lastPart.type === "refusal") {
|
if (lastPart && lastPart.type === "refusal") {
|
||||||
|
currentBlock.text += event.delta;
|
||||||
lastPart.refusal += event.delta;
|
lastPart.refusal += event.delta;
|
||||||
options?.onEvent?.({
|
stream.push({
|
||||||
type: "text_delta",
|
type: "text_delta",
|
||||||
content: currentItem.content
|
|
||||||
.map((c) => (c.type === "output_text" ? c.text : c.refusal))
|
|
||||||
.join(""),
|
|
||||||
delta: event.delta,
|
delta: event.delta,
|
||||||
|
partial: output,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -218,14 +173,24 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
|
||||||
else if (event.type === "response.output_item.done") {
|
else if (event.type === "response.output_item.done") {
|
||||||
const item = event.item;
|
const item = event.item;
|
||||||
|
|
||||||
if (item.type === "reasoning") {
|
if (item.type === "reasoning" && currentBlock && currentBlock.type === "thinking") {
|
||||||
outputItems[outputItems.length - 1] = item; // Update with final item
|
currentBlock.thinking = item.summary?.map((s) => s.text).join("\n\n") || "";
|
||||||
const thinkingContent = item.summary?.map((s) => s.text).join("\n\n") || "";
|
currentBlock.thinkingSignature = JSON.stringify(item);
|
||||||
options?.onEvent?.({ type: "thinking_end", content: thinkingContent });
|
stream.push({
|
||||||
} else if (item.type === "message") {
|
type: "thinking_end",
|
||||||
outputItems[outputItems.length - 1] = item; // Update with final item
|
content: currentBlock.thinking,
|
||||||
const textContent = item.content.map((c) => (c.type === "output_text" ? c.text : c.refusal)).join("");
|
partial: output,
|
||||||
options?.onEvent?.({ type: "text_end", content: textContent });
|
});
|
||||||
|
currentBlock = null;
|
||||||
|
} else if (item.type === "message" && currentBlock && currentBlock.type === "text") {
|
||||||
|
currentBlock.text = item.content.map((c) => (c.type === "output_text" ? c.text : c.refusal)).join("");
|
||||||
|
currentBlock.textSignature = item.id;
|
||||||
|
stream.push({
|
||||||
|
type: "text_end",
|
||||||
|
content: currentBlock.text,
|
||||||
|
partial: output,
|
||||||
|
});
|
||||||
|
currentBlock = null;
|
||||||
} else if (item.type === "function_call") {
|
} else if (item.type === "function_call") {
|
||||||
const toolCall: ToolCall = {
|
const toolCall: ToolCall = {
|
||||||
type: "toolCall",
|
type: "toolCall",
|
||||||
|
|
@ -233,8 +198,8 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
|
||||||
name: item.name,
|
name: item.name,
|
||||||
arguments: JSON.parse(item.arguments),
|
arguments: JSON.parse(item.arguments),
|
||||||
};
|
};
|
||||||
options?.onEvent?.({ type: "toolCall", toolCall });
|
output.content.push(toolCall);
|
||||||
outputItems.push(item);
|
stream.push({ type: "toolCall", toolCall, partial: output });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Handle completion
|
// Handle completion
|
||||||
|
|
@ -249,10 +214,10 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
|
||||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
calculateCost(this.modelInfo, output.usage);
|
calculateCost(model, output.usage);
|
||||||
// Map status to stop reason
|
// Map status to stop reason
|
||||||
output.stopReason = this.mapStopReason(response?.status);
|
output.stopReason = mapStopReason(response?.status);
|
||||||
if (outputItems.some((b) => b.type === "function_call") && output.stopReason === "stop") {
|
if (output.content.some((b) => b.type === "toolCall") && output.stopReason === "stop") {
|
||||||
output.stopReason = "toolUse";
|
output.stopReason = "toolUse";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -260,173 +225,215 @@ export class OpenAIResponsesLLM implements LLM<OpenAIResponsesLLMOptions> {
|
||||||
else if (event.type === "error") {
|
else if (event.type === "error") {
|
||||||
output.stopReason = "error";
|
output.stopReason = "error";
|
||||||
output.error = `Code ${event.code}: ${event.message}` || "Unknown error";
|
output.error = `Code ${event.code}: ${event.message}` || "Unknown error";
|
||||||
options?.onEvent?.({ type: "error", error: output.error });
|
stream.push({ type: "error", error: output.error, partial: output });
|
||||||
|
stream.end();
|
||||||
return output;
|
return output;
|
||||||
} else if (event.type === "response.failed") {
|
} else if (event.type === "response.failed") {
|
||||||
output.stopReason = "error";
|
output.stopReason = "error";
|
||||||
output.error = "Unknown error";
|
output.error = "Unknown error";
|
||||||
options?.onEvent?.({ type: "error", error: output.error });
|
stream.push({ type: "error", error: output.error, partial: output });
|
||||||
|
stream.end();
|
||||||
return output;
|
return output;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert output items to blocks
|
|
||||||
for (const item of outputItems) {
|
|
||||||
if (item.type === "reasoning") {
|
|
||||||
output.content.push({
|
|
||||||
type: "thinking",
|
|
||||||
thinking: item.summary?.map((s: any) => s.text).join("\n\n") || "",
|
|
||||||
thinkingSignature: JSON.stringify(item), // Full item for resubmission
|
|
||||||
});
|
|
||||||
} else if (item.type === "message") {
|
|
||||||
output.content.push({
|
|
||||||
type: "text",
|
|
||||||
text: item.content.map((c) => (c.type === "output_text" ? c.text : c.refusal)).join(""),
|
|
||||||
textSignature: item.id, // ID for resubmission
|
|
||||||
});
|
|
||||||
} else if (item.type === "function_call") {
|
|
||||||
output.content.push({
|
|
||||||
type: "toolCall",
|
|
||||||
id: item.call_id + "|" + item.id,
|
|
||||||
name: item.name,
|
|
||||||
arguments: JSON.parse(item.arguments),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (options?.signal?.aborted) {
|
if (options?.signal?.aborted) {
|
||||||
throw new Error("Request was aborted");
|
throw new Error("Request was aborted");
|
||||||
}
|
}
|
||||||
|
|
||||||
options?.onEvent?.({ type: "done", reason: output.stopReason, message: output });
|
stream.push({ type: "done", reason: output.stopReason, message: output });
|
||||||
return output;
|
stream.end();
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
output.stopReason = "error";
|
output.stopReason = "error";
|
||||||
output.error = error instanceof Error ? error.message : JSON.stringify(error);
|
output.error = error instanceof Error ? error.message : JSON.stringify(error);
|
||||||
options?.onEvent?.({ type: "error", error: output.error });
|
stream.push({ type: "error", error: output.error, partial: output });
|
||||||
return output;
|
stream.end();
|
||||||
}
|
}
|
||||||
|
})();
|
||||||
|
|
||||||
|
return stream;
|
||||||
|
};
|
||||||
|
|
||||||
|
function createClient(model: Model<"openai-responses">, apiKey?: string) {
|
||||||
|
if (!apiKey) {
|
||||||
|
if (!process.env.OPENAI_API_KEY) {
|
||||||
|
throw new Error(
|
||||||
|
"OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass it as an argument.",
|
||||||
|
);
|
||||||
|
}
|
||||||
|
apiKey = process.env.OPENAI_API_KEY;
|
||||||
|
}
|
||||||
|
return new OpenAI({ apiKey, baseURL: model.baseUrl, dangerouslyAllowBrowser: true });
|
||||||
|
}
|
||||||
|
|
||||||
|
function buildParams(model: Model<"openai-responses">, context: Context, options?: OpenAIResponsesOptions) {
|
||||||
|
const messages = convertMessages(model, context);
|
||||||
|
|
||||||
|
const params: ResponseCreateParamsStreaming = {
|
||||||
|
model: model.id,
|
||||||
|
input: messages,
|
||||||
|
stream: true,
|
||||||
|
};
|
||||||
|
|
||||||
|
if (options?.maxTokens) {
|
||||||
|
params.max_output_tokens = options?.maxTokens;
|
||||||
}
|
}
|
||||||
|
|
||||||
private convertToInput(messages: Message[], systemPrompt?: string): ResponseInput {
|
if (options?.temperature !== undefined) {
|
||||||
const input: ResponseInput = [];
|
params.temperature = options?.temperature;
|
||||||
|
}
|
||||||
|
|
||||||
// Transform messages for cross-provider compatibility
|
if (context.tools) {
|
||||||
const transformedMessages = transformMessages(messages, this.modelInfo, this.getApi());
|
params.tools = convertTools(context.tools);
|
||||||
|
}
|
||||||
|
|
||||||
// Add system prompt if provided
|
if (model.reasoning) {
|
||||||
if (systemPrompt) {
|
if (options?.reasoningEffort || options?.reasoningSummary) {
|
||||||
const role = this.modelInfo?.reasoning ? "developer" : "system";
|
params.reasoning = {
|
||||||
input.push({
|
effort: options?.reasoningEffort || "medium",
|
||||||
role,
|
summary: options?.reasoningSummary || "auto",
|
||||||
content: systemPrompt,
|
};
|
||||||
});
|
params.include = ["reasoning.encrypted_content"];
|
||||||
}
|
} else {
|
||||||
|
params.reasoning = {
|
||||||
|
effort: model.name.startsWith("gpt-5") ? "minimal" : null,
|
||||||
|
summary: null,
|
||||||
|
};
|
||||||
|
|
||||||
// Convert messages
|
if (model.name.startsWith("gpt-5")) {
|
||||||
for (const msg of transformedMessages) {
|
// Jesus Christ, see https://community.openai.com/t/need-reasoning-false-option-for-gpt-5/1351588/7
|
||||||
if (msg.role === "user") {
|
messages.push({
|
||||||
// Handle both string and array content
|
role: "developer",
|
||||||
if (typeof msg.content === "string") {
|
content: [
|
||||||
input.push({
|
{
|
||||||
role: "user",
|
type: "input_text",
|
||||||
content: [{ type: "input_text", text: msg.content }],
|
text: "# Juice: 0 !important",
|
||||||
});
|
},
|
||||||
} else {
|
],
|
||||||
// Convert array content to OpenAI Responses format
|
|
||||||
const content: ResponseInputContent[] = msg.content.map((item): ResponseInputContent => {
|
|
||||||
if (item.type === "text") {
|
|
||||||
return {
|
|
||||||
type: "input_text",
|
|
||||||
text: item.text,
|
|
||||||
} satisfies ResponseInputText;
|
|
||||||
} else {
|
|
||||||
// Image content - OpenAI Responses uses data URLs
|
|
||||||
return {
|
|
||||||
type: "input_image",
|
|
||||||
detail: "auto",
|
|
||||||
image_url: `data:${item.mimeType};base64,${item.data}`,
|
|
||||||
} satisfies ResponseInputImage;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
const filteredContent = !this.modelInfo?.input.includes("image")
|
|
||||||
? content.filter((c) => c.type !== "input_image")
|
|
||||||
: content;
|
|
||||||
input.push({
|
|
||||||
role: "user",
|
|
||||||
content: filteredContent,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
} else if (msg.role === "assistant") {
|
|
||||||
// Process content blocks in order
|
|
||||||
const output: ResponseInput = [];
|
|
||||||
|
|
||||||
for (const block of msg.content) {
|
|
||||||
// Do not submit thinking blocks if the completion had an error (i.e. abort)
|
|
||||||
if (block.type === "thinking" && msg.stopReason !== "error") {
|
|
||||||
// Push the full reasoning item(s) from signature
|
|
||||||
if (block.thinkingSignature) {
|
|
||||||
const reasoningItem = JSON.parse(block.thinkingSignature);
|
|
||||||
output.push(reasoningItem);
|
|
||||||
}
|
|
||||||
} else if (block.type === "text") {
|
|
||||||
const textBlock = block as TextContent;
|
|
||||||
output.push({
|
|
||||||
type: "message",
|
|
||||||
role: "assistant",
|
|
||||||
content: [{ type: "output_text", text: textBlock.text, annotations: [] }],
|
|
||||||
status: "completed",
|
|
||||||
id: textBlock.textSignature || "msg_" + Math.random().toString(36).substring(2, 15),
|
|
||||||
} satisfies ResponseOutputMessage);
|
|
||||||
// Do not submit thinking blocks if the completion had an error (i.e. abort)
|
|
||||||
} else if (block.type === "toolCall" && msg.stopReason !== "error") {
|
|
||||||
const toolCall = block as ToolCall;
|
|
||||||
output.push({
|
|
||||||
type: "function_call",
|
|
||||||
id: toolCall.id.split("|")[1], // Extract original ID
|
|
||||||
call_id: toolCall.id.split("|")[0], // Extract call session ID
|
|
||||||
name: toolCall.name,
|
|
||||||
arguments: JSON.stringify(toolCall.arguments),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add all output items to input
|
|
||||||
input.push(...output);
|
|
||||||
} else if (msg.role === "toolResult") {
|
|
||||||
// Tool results are sent as function_call_output
|
|
||||||
input.push({
|
|
||||||
type: "function_call_output",
|
|
||||||
call_id: msg.toolCallId.split("|")[0], // Extract call session ID
|
|
||||||
output: msg.content,
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return input;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private convertTools(tools: Tool[]): OpenAITool[] {
|
return params;
|
||||||
return tools.map((tool) => ({
|
}
|
||||||
type: "function",
|
|
||||||
name: tool.name,
|
function convertMessages(model: Model<"openai-responses">, context: Context): ResponseInput {
|
||||||
description: tool.description,
|
const messages: ResponseInput = [];
|
||||||
parameters: tool.parameters,
|
|
||||||
strict: null,
|
const transformedMessages = transformMessages(context.messages, model);
|
||||||
}));
|
|
||||||
|
if (context.systemPrompt) {
|
||||||
|
const role = model.reasoning ? "developer" : "system";
|
||||||
|
messages.push({
|
||||||
|
role,
|
||||||
|
content: context.systemPrompt,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
private mapStopReason(status: string | undefined): StopReason {
|
for (const msg of transformedMessages) {
|
||||||
switch (status) {
|
if (msg.role === "user") {
|
||||||
case "completed":
|
if (typeof msg.content === "string") {
|
||||||
return "stop";
|
messages.push({
|
||||||
case "incomplete":
|
role: "user",
|
||||||
return "length";
|
content: [{ type: "input_text", text: msg.content }],
|
||||||
case "failed":
|
});
|
||||||
case "cancelled":
|
} else {
|
||||||
return "error";
|
const content: ResponseInputContent[] = msg.content.map((item): ResponseInputContent => {
|
||||||
default:
|
if (item.type === "text") {
|
||||||
return "stop";
|
return {
|
||||||
|
type: "input_text",
|
||||||
|
text: item.text,
|
||||||
|
} satisfies ResponseInputText;
|
||||||
|
} else {
|
||||||
|
return {
|
||||||
|
type: "input_image",
|
||||||
|
detail: "auto",
|
||||||
|
image_url: `data:${item.mimeType};base64,${item.data}`,
|
||||||
|
} satisfies ResponseInputImage;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
const filteredContent = !model.input.includes("image")
|
||||||
|
? content.filter((c) => c.type !== "input_image")
|
||||||
|
: content;
|
||||||
|
if (filteredContent.length === 0) continue;
|
||||||
|
messages.push({
|
||||||
|
role: "user",
|
||||||
|
content: filteredContent,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
} else if (msg.role === "assistant") {
|
||||||
|
const output: ResponseInput = [];
|
||||||
|
|
||||||
|
for (const block of msg.content) {
|
||||||
|
// Do not submit thinking blocks if the completion had an error (i.e. abort)
|
||||||
|
if (block.type === "thinking" && msg.stopReason !== "error") {
|
||||||
|
if (block.thinkingSignature) {
|
||||||
|
const reasoningItem = JSON.parse(block.thinkingSignature);
|
||||||
|
output.push(reasoningItem);
|
||||||
|
}
|
||||||
|
} else if (block.type === "text") {
|
||||||
|
const textBlock = block as TextContent;
|
||||||
|
output.push({
|
||||||
|
type: "message",
|
||||||
|
role: "assistant",
|
||||||
|
content: [{ type: "output_text", text: textBlock.text, annotations: [] }],
|
||||||
|
status: "completed",
|
||||||
|
id: textBlock.textSignature || "msg_" + Math.random().toString(36).substring(2, 15),
|
||||||
|
} satisfies ResponseOutputMessage);
|
||||||
|
// Do not submit toolcall blocks if the completion had an error (i.e. abort)
|
||||||
|
} else if (block.type === "toolCall" && msg.stopReason !== "error") {
|
||||||
|
const toolCall = block as ToolCall;
|
||||||
|
output.push({
|
||||||
|
type: "function_call",
|
||||||
|
id: toolCall.id.split("|")[1],
|
||||||
|
call_id: toolCall.id.split("|")[0],
|
||||||
|
name: toolCall.name,
|
||||||
|
arguments: JSON.stringify(toolCall.arguments),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (output.length === 0) continue;
|
||||||
|
messages.push(...output);
|
||||||
|
} else if (msg.role === "toolResult") {
|
||||||
|
messages.push({
|
||||||
|
type: "function_call_output",
|
||||||
|
call_id: msg.toolCallId.split("|")[0],
|
||||||
|
output: msg.content,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return messages;
|
||||||
|
}
|
||||||
|
|
||||||
|
function convertTools(tools: Tool[]): OpenAITool[] {
|
||||||
|
return tools.map((tool) => ({
|
||||||
|
type: "function",
|
||||||
|
name: tool.name,
|
||||||
|
description: tool.description,
|
||||||
|
parameters: tool.parameters,
|
||||||
|
strict: null,
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
function mapStopReason(status: OpenAI.Responses.ResponseStatus | undefined): StopReason {
|
||||||
|
if (!status) return "stop";
|
||||||
|
switch (status) {
|
||||||
|
case "completed":
|
||||||
|
return "stop";
|
||||||
|
case "incomplete":
|
||||||
|
return "length";
|
||||||
|
case "failed":
|
||||||
|
case "cancelled":
|
||||||
|
return "error";
|
||||||
|
// These two are wonky ...
|
||||||
|
case "in_progress":
|
||||||
|
case "queued":
|
||||||
|
return "stop";
|
||||||
|
default: {
|
||||||
|
const _exhaustive: never = status;
|
||||||
|
throw new Error(`Unhandled stop reason: ${_exhaustive}`);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,18 +1,6 @@
|
||||||
import type { AssistantMessage, Message, Model } from "../types.js";
|
import type { Api, AssistantMessage, Message, Model } from "../types.js";
|
||||||
|
|
||||||
/**
|
export function transformMessages<TApi extends Api>(messages: Message[], model: Model<TApi>): Message[] {
|
||||||
* Transform messages for cross-provider compatibility.
|
|
||||||
*
|
|
||||||
* - User and toolResult messages are copied verbatim
|
|
||||||
* - Assistant messages:
|
|
||||||
* - If from the same provider/model, copied as-is
|
|
||||||
* - If from different provider/model, thinking blocks are converted to text blocks with <thinking></thinking> tags
|
|
||||||
*
|
|
||||||
* @param messages The messages to transform
|
|
||||||
* @param model The target model that will process these messages
|
|
||||||
* @returns A copy of the messages array with transformations applied
|
|
||||||
*/
|
|
||||||
export function transformMessages(messages: Message[], model: Model, api: string): Message[] {
|
|
||||||
return messages.map((msg) => {
|
return messages.map((msg) => {
|
||||||
// User and toolResult messages pass through unchanged
|
// User and toolResult messages pass through unchanged
|
||||||
if (msg.role === "user" || msg.role === "toolResult") {
|
if (msg.role === "user" || msg.role === "toolResult") {
|
||||||
|
|
@ -24,7 +12,7 @@ export function transformMessages(messages: Message[], model: Model, api: string
|
||||||
const assistantMsg = msg as AssistantMessage;
|
const assistantMsg = msg as AssistantMessage;
|
||||||
|
|
||||||
// If message is from the same provider and API, keep as is
|
// If message is from the same provider and API, keep as is
|
||||||
if (assistantMsg.provider === model.provider && assistantMsg.api === api) {
|
if (assistantMsg.provider === model.provider && assistantMsg.api === model.api) {
|
||||||
return msg;
|
return msg;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -47,8 +35,6 @@ export function transformMessages(messages: Message[], model: Model, api: string
|
||||||
content: transformedContent,
|
content: transformedContent,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should not reach here, but return as-is for safety
|
|
||||||
return msg;
|
return msg;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,27 @@
|
||||||
export type KnownApi = "openai-completions" | "openai-responses" | "anthropic-messages" | "google-generative-ai";
|
import type { AnthropicOptions } from "./providers/anthropic";
|
||||||
export type Api = KnownApi | string;
|
import type { GoogleOptions } from "./providers/google";
|
||||||
|
import type { OpenAICompletionsOptions } from "./providers/openai-completions";
|
||||||
|
import type { OpenAIResponsesOptions } from "./providers/openai-responses";
|
||||||
|
|
||||||
|
export type Api = "openai-completions" | "openai-responses" | "anthropic-messages" | "google-generative-ai";
|
||||||
|
|
||||||
|
export interface ApiOptionsMap {
|
||||||
|
"anthropic-messages": AnthropicOptions;
|
||||||
|
"openai-completions": OpenAICompletionsOptions;
|
||||||
|
"openai-responses": OpenAIResponsesOptions;
|
||||||
|
"google-generative-ai": GoogleOptions;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compile-time exhaustiveness check - this will fail if ApiOptionsMap doesn't have all KnownApi keys
|
||||||
|
type _CheckExhaustive = ApiOptionsMap extends Record<Api, GenerateOptions>
|
||||||
|
? Record<Api, GenerateOptions> extends ApiOptionsMap
|
||||||
|
? true
|
||||||
|
: ["ApiOptionsMap is missing some KnownApi values", Exclude<Api, keyof ApiOptionsMap>]
|
||||||
|
: ["ApiOptionsMap doesn't extend Record<KnownApi, GenerateOptions>"];
|
||||||
|
const _exhaustive: _CheckExhaustive = true;
|
||||||
|
|
||||||
|
// Helper type to get options for a specific API
|
||||||
|
export type OptionsForApi<TApi extends Api> = ApiOptionsMap[TApi];
|
||||||
|
|
||||||
export type KnownProvider = "anthropic" | "google" | "openai" | "xai" | "groq" | "cerebras" | "openrouter";
|
export type KnownProvider = "anthropic" | "google" | "openai" | "xai" | "groq" | "cerebras" | "openrouter";
|
||||||
export type Provider = KnownProvider | string;
|
export type Provider = KnownProvider | string;
|
||||||
|
|
@ -21,31 +43,17 @@ export interface GenerateOptions {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Unified options with reasoning (what public generate() accepts)
|
// Unified options with reasoning (what public generate() accepts)
|
||||||
export interface GenerateOptionsUnified extends GenerateOptions {
|
export interface SimpleGenerateOptions extends GenerateOptions {
|
||||||
reasoning?: ReasoningEffort;
|
reasoning?: ReasoningEffort;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generic GenerateFunction with typed options
|
// Generic GenerateFunction with typed options
|
||||||
export type GenerateFunction<TOptions extends GenerateOptions = GenerateOptions> = (
|
export type GenerateFunction<TApi extends Api> = (
|
||||||
model: Model,
|
model: Model<TApi>,
|
||||||
context: Context,
|
context: Context,
|
||||||
options: TOptions,
|
options: OptionsForApi<TApi>,
|
||||||
) => GenerateStream;
|
) => GenerateStream;
|
||||||
|
|
||||||
// Legacy LLM interface (to be removed)
|
|
||||||
export interface LLMOptions {
|
|
||||||
temperature?: number;
|
|
||||||
maxTokens?: number;
|
|
||||||
onEvent?: (event: AssistantMessageEvent) => void;
|
|
||||||
signal?: AbortSignal;
|
|
||||||
}
|
|
||||||
|
|
||||||
export interface LLM<T extends LLMOptions> {
|
|
||||||
generate(request: Context, options?: T): Promise<AssistantMessage>;
|
|
||||||
getModel(): Model;
|
|
||||||
getApi(): string;
|
|
||||||
}
|
|
||||||
|
|
||||||
export interface TextContent {
|
export interface TextContent {
|
||||||
type: "text";
|
type: "text";
|
||||||
text: string;
|
text: string;
|
||||||
|
|
@ -100,7 +108,7 @@ export interface AssistantMessage {
|
||||||
model: string;
|
model: string;
|
||||||
usage: Usage;
|
usage: Usage;
|
||||||
stopReason: StopReason;
|
stopReason: StopReason;
|
||||||
error?: string | Error;
|
error?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface ToolResultMessage {
|
export interface ToolResultMessage {
|
||||||
|
|
@ -138,10 +146,10 @@ export type AssistantMessageEvent =
|
||||||
| { type: "error"; error: string; partial: AssistantMessage };
|
| { type: "error"; error: string; partial: AssistantMessage };
|
||||||
|
|
||||||
// Model interface for the unified model system
|
// Model interface for the unified model system
|
||||||
export interface Model {
|
export interface Model<TApi extends Api> {
|
||||||
id: string;
|
id: string;
|
||||||
name: string;
|
name: string;
|
||||||
api: Api;
|
api: TApi;
|
||||||
provider: Provider;
|
provider: Provider;
|
||||||
baseUrl: string;
|
baseUrl: string;
|
||||||
reasoning: boolean;
|
reasoning: boolean;
|
||||||
|
|
|
||||||
|
|
@ -1,128 +1,103 @@
|
||||||
import { describe, it, beforeAll, expect } from "vitest";
|
import { beforeAll, describe, expect, it } from "vitest";
|
||||||
import { GoogleLLM } from "../src/providers/google.js";
|
import { complete, stream } from "../src/generate.js";
|
||||||
import { OpenAICompletionsLLM } from "../src/providers/openai-completions.js";
|
|
||||||
import { OpenAIResponsesLLM } from "../src/providers/openai-responses.js";
|
|
||||||
import { AnthropicLLM } from "../src/providers/anthropic.js";
|
|
||||||
import type { LLM, LLMOptions, Context } from "../src/types.js";
|
|
||||||
import { getModel } from "../src/models.js";
|
import { getModel } from "../src/models.js";
|
||||||
|
import type { Api, Context, Model, OptionsForApi } from "../src/types.js";
|
||||||
|
|
||||||
async function testAbortSignal<T extends LLMOptions>(llm: LLM<T>, options: T = {} as T) {
|
async function testAbortSignal<TApi extends Api>(llm: Model<TApi>, options: OptionsForApi<TApi> = {}) {
|
||||||
const context: Context = {
|
const context: Context = {
|
||||||
messages: [{
|
messages: [
|
||||||
role: "user",
|
{
|
||||||
content: "What is 15 + 27? Think step by step. Then list 50 first names."
|
role: "user",
|
||||||
}]
|
content: "What is 15 + 27? Think step by step. Then list 50 first names.",
|
||||||
};
|
},
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
let abortFired = false;
|
let abortFired = false;
|
||||||
const controller = new AbortController();
|
const controller = new AbortController();
|
||||||
const response = await llm.generate(context, {
|
const response = await stream(llm, context, { ...options, signal: controller.signal });
|
||||||
...options,
|
for await (const event of response) {
|
||||||
signal: controller.signal,
|
if (abortFired) return;
|
||||||
onEvent: (event) => {
|
setTimeout(() => controller.abort(), 3000);
|
||||||
// console.log(JSON.stringify(event, null, 2));
|
abortFired = true;
|
||||||
if (abortFired) return;
|
break;
|
||||||
setTimeout(() => controller.abort(), 2000);
|
}
|
||||||
abortFired = true;
|
const msg = await response.finalMessage();
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// If we get here without throwing, the abort didn't work
|
// If we get here without throwing, the abort didn't work
|
||||||
expect(response.stopReason).toBe("error");
|
expect(msg.stopReason).toBe("error");
|
||||||
expect(response.content.length).toBeGreaterThan(0);
|
expect(msg.content.length).toBeGreaterThan(0);
|
||||||
|
|
||||||
context.messages.push(response);
|
context.messages.push(msg);
|
||||||
context.messages.push({ role: "user", content: "Please continue, but only generate 5 names." });
|
context.messages.push({ role: "user", content: "Please continue, but only generate 5 names." });
|
||||||
|
|
||||||
// Ensure we can still make requests after abort
|
const followUp = await complete(llm, context, options);
|
||||||
const followUp = await llm.generate(context, options);
|
expect(followUp.stopReason).toBe("stop");
|
||||||
expect(followUp.stopReason).toBe("stop");
|
expect(followUp.content.length).toBeGreaterThan(0);
|
||||||
expect(followUp.content.length).toBeGreaterThan(0);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async function testImmediateAbort<T extends LLMOptions>(llm: LLM<T>, options: T = {} as T) {
|
async function testImmediateAbort<TApi extends Api>(llm: Model<TApi>, options: OptionsForApi<TApi> = {}) {
|
||||||
const controller = new AbortController();
|
const controller = new AbortController();
|
||||||
|
|
||||||
// Abort immediately
|
controller.abort();
|
||||||
controller.abort();
|
|
||||||
|
|
||||||
const context: Context = {
|
const context: Context = {
|
||||||
messages: [{ role: "user", content: "Hello" }]
|
messages: [{ role: "user", content: "Hello" }],
|
||||||
};
|
};
|
||||||
|
|
||||||
const response = await llm.generate(context, {
|
const response = await complete(llm, context, { ...options, signal: controller.signal });
|
||||||
...options,
|
expect(response.stopReason).toBe("error");
|
||||||
signal: controller.signal
|
|
||||||
});
|
|
||||||
expect(response.stopReason).toBe("error");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
describe("AI Providers Abort Tests", () => {
|
describe("AI Providers Abort Tests", () => {
|
||||||
describe.skipIf(!process.env.GEMINI_API_KEY)("Google Provider Abort", () => {
|
describe.skipIf(!process.env.GEMINI_API_KEY)("Google Provider Abort", () => {
|
||||||
let llm: GoogleLLM;
|
const llm = getModel("google", "gemini-2.5-flash");
|
||||||
|
|
||||||
beforeAll(() => {
|
it("should abort mid-stream", async () => {
|
||||||
llm = new GoogleLLM(getModel("google", "gemini-2.5-flash")!, process.env.GEMINI_API_KEY!);
|
await testAbortSignal(llm, { thinking: { enabled: true } });
|
||||||
});
|
});
|
||||||
|
|
||||||
it("should abort mid-stream", async () => {
|
it("should handle immediate abort", async () => {
|
||||||
await testAbortSignal(llm, { thinking: { enabled: true } });
|
await testImmediateAbort(llm, { thinking: { enabled: true } });
|
||||||
});
|
});
|
||||||
|
});
|
||||||
|
|
||||||
it("should handle immediate abort", async () => {
|
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Completions Provider Abort", () => {
|
||||||
await testImmediateAbort(llm, { thinking: { enabled: true } });
|
const llm: Model<"openai-completions"> = {
|
||||||
});
|
...getModel("openai", "gpt-4o-mini")!,
|
||||||
});
|
api: "openai-completions",
|
||||||
|
};
|
||||||
|
|
||||||
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Completions Provider Abort", () => {
|
it("should abort mid-stream", async () => {
|
||||||
let llm: OpenAICompletionsLLM;
|
await testAbortSignal(llm);
|
||||||
|
});
|
||||||
|
|
||||||
beforeAll(() => {
|
it("should handle immediate abort", async () => {
|
||||||
llm = new OpenAICompletionsLLM(getModel("openai", "gpt-4o-mini")!, process.env.OPENAI_API_KEY!);
|
await testImmediateAbort(llm);
|
||||||
});
|
});
|
||||||
|
});
|
||||||
|
|
||||||
it("should abort mid-stream", async () => {
|
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Responses Provider Abort", () => {
|
||||||
await testAbortSignal(llm);
|
const llm = getModel("openai", "gpt-5-mini");
|
||||||
});
|
|
||||||
|
|
||||||
it("should handle immediate abort", async () => {
|
it("should abort mid-stream", async () => {
|
||||||
await testImmediateAbort(llm);
|
await testAbortSignal(llm);
|
||||||
});
|
});
|
||||||
});
|
|
||||||
|
|
||||||
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Responses Provider Abort", () => {
|
it("should handle immediate abort", async () => {
|
||||||
let llm: OpenAIResponsesLLM;
|
await testImmediateAbort(llm);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
beforeAll(() => {
|
describe.skipIf(!process.env.ANTHROPIC_OAUTH_TOKEN)("Anthropic Provider Abort", () => {
|
||||||
const model = getModel("openai", "gpt-5-mini");
|
const llm = getModel("anthropic", "claude-opus-4-1-20250805");
|
||||||
if (!model) {
|
|
||||||
throw new Error("Model not found");
|
|
||||||
}
|
|
||||||
llm = new OpenAIResponsesLLM(model, process.env.OPENAI_API_KEY!);
|
|
||||||
});
|
|
||||||
|
|
||||||
it("should abort mid-stream", async () => {
|
it("should abort mid-stream", async () => {
|
||||||
await testAbortSignal(llm, {});
|
await testAbortSignal(llm, { thinkingEnabled: true, thinkingBudgetTokens: 2048 });
|
||||||
});
|
});
|
||||||
|
|
||||||
it("should handle immediate abort", async () => {
|
it("should handle immediate abort", async () => {
|
||||||
await testImmediateAbort(llm, {});
|
await testImmediateAbort(llm, { thinkingEnabled: true, thinkingBudgetTokens: 2048 });
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
describe.skipIf(!process.env.ANTHROPIC_OAUTH_TOKEN)("Anthropic Provider Abort", () => {
|
|
||||||
let llm: AnthropicLLM;
|
|
||||||
|
|
||||||
beforeAll(() => {
|
|
||||||
llm = new AnthropicLLM(getModel("anthropic", "claude-opus-4-1-20250805")!, process.env.ANTHROPIC_OAUTH_TOKEN!);
|
|
||||||
});
|
|
||||||
|
|
||||||
it("should abort mid-stream", async () => {
|
|
||||||
await testAbortSignal(llm, {thinking: { enabled: true, budgetTokens: 2048 }});
|
|
||||||
});
|
|
||||||
|
|
||||||
it("should handle immediate abort", async () => {
|
|
||||||
await testImmediateAbort(llm, {thinking: { enabled: true, budgetTokens: 2048 }});
|
|
||||||
});
|
|
||||||
});
|
|
||||||
});
|
});
|
||||||
|
|
@ -1,313 +1,265 @@
|
||||||
import { describe, it, beforeAll, expect } from "vitest";
|
import { describe, expect, it } from "vitest";
|
||||||
import { GoogleLLM } from "../src/providers/google.js";
|
import { complete } from "../src/generate.js";
|
||||||
import { OpenAICompletionsLLM } from "../src/providers/openai-completions.js";
|
|
||||||
import { OpenAIResponsesLLM } from "../src/providers/openai-responses.js";
|
|
||||||
import { AnthropicLLM } from "../src/providers/anthropic.js";
|
|
||||||
import type { LLM, LLMOptions, Context, UserMessage, AssistantMessage } from "../src/types.js";
|
|
||||||
import { getModel } from "../src/models.js";
|
import { getModel } from "../src/models.js";
|
||||||
|
import type { Api, AssistantMessage, Context, Model, OptionsForApi, UserMessage } from "../src/types.js";
|
||||||
|
|
||||||
async function testEmptyMessage<T extends LLMOptions>(llm: LLM<T>, options: T = {} as T) {
|
async function testEmptyMessage<TApi extends Api>(llm: Model<TApi>, options: OptionsForApi<TApi> = {}) {
|
||||||
// Test with completely empty content array
|
// Test with completely empty content array
|
||||||
const emptyMessage: UserMessage = {
|
const emptyMessage: UserMessage = {
|
||||||
role: "user",
|
role: "user",
|
||||||
content: []
|
content: [],
|
||||||
};
|
};
|
||||||
|
|
||||||
const context: Context = {
|
const context: Context = {
|
||||||
messages: [emptyMessage]
|
messages: [emptyMessage],
|
||||||
};
|
};
|
||||||
|
|
||||||
const response = await llm.generate(context, options);
|
const response = await complete(llm, context, options);
|
||||||
|
|
||||||
// Should either handle gracefully or return an error
|
// Should either handle gracefully or return an error
|
||||||
expect(response).toBeDefined();
|
expect(response).toBeDefined();
|
||||||
expect(response.role).toBe("assistant");
|
expect(response.role).toBe("assistant");
|
||||||
|
// Should handle empty string gracefully
|
||||||
// Most providers should return an error or empty response
|
if (response.stopReason === "error") {
|
||||||
if (response.stopReason === "error") {
|
expect(response.error).toBeDefined();
|
||||||
expect(response.error).toBeDefined();
|
} else {
|
||||||
} else {
|
expect(response.content).toBeDefined();
|
||||||
// If it didn't error, it should have some content or gracefully handle empty
|
}
|
||||||
expect(response.content).toBeDefined();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async function testEmptyStringMessage<T extends LLMOptions>(llm: LLM<T>, options: T = {} as T) {
|
async function testEmptyStringMessage<TApi extends Api>(llm: Model<TApi>, options: OptionsForApi<TApi> = {}) {
|
||||||
// Test with empty string content
|
// Test with empty string content
|
||||||
const context: Context = {
|
const context: Context = {
|
||||||
messages: [{
|
messages: [
|
||||||
role: "user",
|
{
|
||||||
content: ""
|
role: "user",
|
||||||
}]
|
content: "",
|
||||||
};
|
},
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
const response = await llm.generate(context, options);
|
const response = await complete(llm, context, options);
|
||||||
|
|
||||||
expect(response).toBeDefined();
|
expect(response).toBeDefined();
|
||||||
expect(response.role).toBe("assistant");
|
expect(response.role).toBe("assistant");
|
||||||
|
|
||||||
// Should handle empty string gracefully
|
// Should handle empty string gracefully
|
||||||
if (response.stopReason === "error") {
|
if (response.stopReason === "error") {
|
||||||
expect(response.error).toBeDefined();
|
expect(response.error).toBeDefined();
|
||||||
} else {
|
} else {
|
||||||
expect(response.content).toBeDefined();
|
expect(response.content).toBeDefined();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async function testWhitespaceOnlyMessage<T extends LLMOptions>(llm: LLM<T>, options: T = {} as T) {
|
async function testWhitespaceOnlyMessage<TApi extends Api>(llm: Model<TApi>, options: OptionsForApi<TApi> = {}) {
|
||||||
// Test with whitespace-only content
|
// Test with whitespace-only content
|
||||||
const context: Context = {
|
const context: Context = {
|
||||||
messages: [{
|
messages: [
|
||||||
role: "user",
|
{
|
||||||
content: " \n\t "
|
role: "user",
|
||||||
}]
|
content: " \n\t ",
|
||||||
};
|
},
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
const response = await llm.generate(context, options);
|
const response = await complete(llm, context, options);
|
||||||
|
|
||||||
expect(response).toBeDefined();
|
expect(response).toBeDefined();
|
||||||
expect(response.role).toBe("assistant");
|
expect(response.role).toBe("assistant");
|
||||||
|
|
||||||
// Should handle whitespace-only gracefully
|
// Should handle whitespace-only gracefully
|
||||||
if (response.stopReason === "error") {
|
if (response.stopReason === "error") {
|
||||||
expect(response.error).toBeDefined();
|
expect(response.error).toBeDefined();
|
||||||
} else {
|
} else {
|
||||||
expect(response.content).toBeDefined();
|
expect(response.content).toBeDefined();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async function testEmptyAssistantMessage<T extends LLMOptions>(llm: LLM<T>, options: T = {} as T) {
|
async function testEmptyAssistantMessage<TApi extends Api>(llm: Model<TApi>, options: OptionsForApi<TApi> = {}) {
|
||||||
// Test with empty assistant message in conversation flow
|
// Test with empty assistant message in conversation flow
|
||||||
// User -> Empty Assistant -> User
|
// User -> Empty Assistant -> User
|
||||||
const emptyAssistant: AssistantMessage = {
|
const emptyAssistant: AssistantMessage = {
|
||||||
role: "assistant",
|
role: "assistant",
|
||||||
content: [],
|
content: [],
|
||||||
api: llm.getApi(),
|
api: llm.api,
|
||||||
provider: llm.getModel().provider,
|
provider: llm.provider,
|
||||||
model: llm.getModel().id,
|
model: llm.id,
|
||||||
usage: {
|
usage: {
|
||||||
input: 10,
|
input: 10,
|
||||||
output: 0,
|
output: 0,
|
||||||
cacheRead: 0,
|
cacheRead: 0,
|
||||||
cacheWrite: 0,
|
cacheWrite: 0,
|
||||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }
|
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||||
},
|
},
|
||||||
stopReason: "stop"
|
stopReason: "stop",
|
||||||
};
|
};
|
||||||
|
|
||||||
const context: Context = {
|
const context: Context = {
|
||||||
messages: [
|
messages: [
|
||||||
{
|
{
|
||||||
role: "user",
|
role: "user",
|
||||||
content: "Hello, how are you?"
|
content: "Hello, how are you?",
|
||||||
},
|
},
|
||||||
emptyAssistant,
|
emptyAssistant,
|
||||||
{
|
{
|
||||||
role: "user",
|
role: "user",
|
||||||
content: "Please respond this time."
|
content: "Please respond this time.",
|
||||||
}
|
},
|
||||||
]
|
],
|
||||||
};
|
};
|
||||||
|
|
||||||
const response = await llm.generate(context, options);
|
const response = await complete(llm, context, options);
|
||||||
|
|
||||||
expect(response).toBeDefined();
|
expect(response).toBeDefined();
|
||||||
expect(response.role).toBe("assistant");
|
expect(response.role).toBe("assistant");
|
||||||
|
|
||||||
// Should handle empty assistant message in context gracefully
|
// Should handle empty assistant message in context gracefully
|
||||||
if (response.stopReason === "error") {
|
if (response.stopReason === "error") {
|
||||||
expect(response.error).toBeDefined();
|
expect(response.error).toBeDefined();
|
||||||
} else {
|
} else {
|
||||||
expect(response.content).toBeDefined();
|
expect(response.content).toBeDefined();
|
||||||
expect(response.content.length).toBeGreaterThan(0);
|
expect(response.content.length).toBeGreaterThan(0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
describe("AI Providers Empty Message Tests", () => {
|
describe("AI Providers Empty Message Tests", () => {
|
||||||
describe.skipIf(!process.env.GEMINI_API_KEY)("Google Provider Empty Messages", () => {
|
describe.skipIf(!process.env.GEMINI_API_KEY)("Google Provider Empty Messages", () => {
|
||||||
let llm: GoogleLLM;
|
const llm = getModel("google", "gemini-2.5-flash");
|
||||||
|
|
||||||
beforeAll(() => {
|
it("should handle empty content array", async () => {
|
||||||
llm = new GoogleLLM(getModel("google", "gemini-2.5-flash")!, process.env.GEMINI_API_KEY!);
|
await testEmptyMessage(llm);
|
||||||
});
|
});
|
||||||
|
|
||||||
it("should handle empty content array", async () => {
|
it("should handle empty string content", async () => {
|
||||||
await testEmptyMessage(llm);
|
await testEmptyStringMessage(llm);
|
||||||
});
|
});
|
||||||
|
|
||||||
it("should handle empty string content", async () => {
|
it("should handle whitespace-only content", async () => {
|
||||||
await testEmptyStringMessage(llm);
|
await testWhitespaceOnlyMessage(llm);
|
||||||
});
|
});
|
||||||
|
|
||||||
it("should handle whitespace-only content", async () => {
|
it("should handle empty assistant message in conversation", async () => {
|
||||||
await testWhitespaceOnlyMessage(llm);
|
await testEmptyAssistantMessage(llm);
|
||||||
});
|
});
|
||||||
|
});
|
||||||
|
|
||||||
it("should handle empty assistant message in conversation", async () => {
|
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Completions Provider Empty Messages", () => {
|
||||||
await testEmptyAssistantMessage(llm);
|
const llm = getModel("openai", "gpt-4o-mini");
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Completions Provider Empty Messages", () => {
|
it("should handle empty content array", async () => {
|
||||||
let llm: OpenAICompletionsLLM;
|
await testEmptyMessage(llm);
|
||||||
|
});
|
||||||
|
|
||||||
beforeAll(() => {
|
it("should handle empty string content", async () => {
|
||||||
llm = new OpenAICompletionsLLM(getModel("openai", "gpt-4o-mini")!, process.env.OPENAI_API_KEY!);
|
await testEmptyStringMessage(llm);
|
||||||
});
|
});
|
||||||
|
|
||||||
it("should handle empty content array", async () => {
|
it("should handle whitespace-only content", async () => {
|
||||||
await testEmptyMessage(llm);
|
await testWhitespaceOnlyMessage(llm);
|
||||||
});
|
});
|
||||||
|
|
||||||
it("should handle empty string content", async () => {
|
it("should handle empty assistant message in conversation", async () => {
|
||||||
await testEmptyStringMessage(llm);
|
await testEmptyAssistantMessage(llm);
|
||||||
});
|
});
|
||||||
|
});
|
||||||
|
|
||||||
it("should handle whitespace-only content", async () => {
|
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Responses Provider Empty Messages", () => {
|
||||||
await testWhitespaceOnlyMessage(llm);
|
const llm = getModel("openai", "gpt-5-mini");
|
||||||
});
|
|
||||||
|
|
||||||
it("should handle empty assistant message in conversation", async () => {
|
it("should handle empty content array", async () => {
|
||||||
await testEmptyAssistantMessage(llm);
|
await testEmptyMessage(llm);
|
||||||
});
|
});
|
||||||
});
|
|
||||||
|
|
||||||
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Responses Provider Empty Messages", () => {
|
it("should handle empty string content", async () => {
|
||||||
let llm: OpenAIResponsesLLM;
|
await testEmptyStringMessage(llm);
|
||||||
|
});
|
||||||
|
|
||||||
beforeAll(() => {
|
it("should handle whitespace-only content", async () => {
|
||||||
const model = getModel("openai", "gpt-5-mini");
|
await testWhitespaceOnlyMessage(llm);
|
||||||
if (!model) {
|
});
|
||||||
throw new Error("Model gpt-5-mini not found");
|
|
||||||
}
|
|
||||||
llm = new OpenAIResponsesLLM(model, process.env.OPENAI_API_KEY!);
|
|
||||||
});
|
|
||||||
|
|
||||||
it("should handle empty content array", async () => {
|
it("should handle empty assistant message in conversation", async () => {
|
||||||
await testEmptyMessage(llm);
|
await testEmptyAssistantMessage(llm);
|
||||||
});
|
});
|
||||||
|
});
|
||||||
|
|
||||||
it("should handle empty string content", async () => {
|
describe.skipIf(!process.env.ANTHROPIC_OAUTH_TOKEN)("Anthropic Provider Empty Messages", () => {
|
||||||
await testEmptyStringMessage(llm);
|
const llm = getModel("anthropic", "claude-3-5-haiku-20241022");
|
||||||
});
|
|
||||||
|
|
||||||
it("should handle whitespace-only content", async () => {
|
it("should handle empty content array", async () => {
|
||||||
await testWhitespaceOnlyMessage(llm);
|
await testEmptyMessage(llm);
|
||||||
});
|
});
|
||||||
|
|
||||||
it("should handle empty assistant message in conversation", async () => {
|
it("should handle empty string content", async () => {
|
||||||
await testEmptyAssistantMessage(llm);
|
await testEmptyStringMessage(llm);
|
||||||
});
|
});
|
||||||
});
|
|
||||||
|
|
||||||
describe.skipIf(!process.env.ANTHROPIC_OAUTH_TOKEN)("Anthropic Provider Empty Messages", () => {
|
it("should handle whitespace-only content", async () => {
|
||||||
let llm: AnthropicLLM;
|
await testWhitespaceOnlyMessage(llm);
|
||||||
|
});
|
||||||
|
|
||||||
beforeAll(() => {
|
it("should handle empty assistant message in conversation", async () => {
|
||||||
llm = new AnthropicLLM(getModel("anthropic", "claude-3-5-haiku-20241022")!, process.env.ANTHROPIC_OAUTH_TOKEN!);
|
await testEmptyAssistantMessage(llm);
|
||||||
});
|
});
|
||||||
|
});
|
||||||
|
|
||||||
it("should handle empty content array", async () => {
|
describe.skipIf(!process.env.XAI_API_KEY)("xAI Provider Empty Messages", () => {
|
||||||
await testEmptyMessage(llm);
|
const llm = getModel("xai", "grok-3");
|
||||||
});
|
|
||||||
|
|
||||||
it("should handle empty string content", async () => {
|
it("should handle empty content array", async () => {
|
||||||
await testEmptyStringMessage(llm);
|
await testEmptyMessage(llm);
|
||||||
});
|
});
|
||||||
|
|
||||||
it("should handle whitespace-only content", async () => {
|
it("should handle empty string content", async () => {
|
||||||
await testWhitespaceOnlyMessage(llm);
|
await testEmptyStringMessage(llm);
|
||||||
});
|
});
|
||||||
|
|
||||||
it("should handle empty assistant message in conversation", async () => {
|
it("should handle whitespace-only content", async () => {
|
||||||
await testEmptyAssistantMessage(llm);
|
await testWhitespaceOnlyMessage(llm);
|
||||||
});
|
});
|
||||||
});
|
|
||||||
|
|
||||||
// Test with xAI/Grok if available
|
it("should handle empty assistant message in conversation", async () => {
|
||||||
describe.skipIf(!process.env.XAI_API_KEY)("xAI Provider Empty Messages", () => {
|
await testEmptyAssistantMessage(llm);
|
||||||
let llm: OpenAICompletionsLLM;
|
});
|
||||||
|
});
|
||||||
|
|
||||||
beforeAll(() => {
|
describe.skipIf(!process.env.GROQ_API_KEY)("Groq Provider Empty Messages", () => {
|
||||||
const model = getModel("xai", "grok-3");
|
const llm = getModel("groq", "openai/gpt-oss-20b");
|
||||||
if (!model) {
|
|
||||||
throw new Error("Model grok-3 not found");
|
|
||||||
}
|
|
||||||
llm = new OpenAICompletionsLLM(model, process.env.XAI_API_KEY!);
|
|
||||||
});
|
|
||||||
|
|
||||||
it("should handle empty content array", async () => {
|
it("should handle empty content array", async () => {
|
||||||
await testEmptyMessage(llm);
|
await testEmptyMessage(llm);
|
||||||
});
|
});
|
||||||
|
|
||||||
it("should handle empty string content", async () => {
|
it("should handle empty string content", async () => {
|
||||||
await testEmptyStringMessage(llm);
|
await testEmptyStringMessage(llm);
|
||||||
});
|
});
|
||||||
|
|
||||||
it("should handle whitespace-only content", async () => {
|
it("should handle whitespace-only content", async () => {
|
||||||
await testWhitespaceOnlyMessage(llm);
|
await testWhitespaceOnlyMessage(llm);
|
||||||
});
|
});
|
||||||
|
|
||||||
it("should handle empty assistant message in conversation", async () => {
|
it("should handle empty assistant message in conversation", async () => {
|
||||||
await testEmptyAssistantMessage(llm);
|
await testEmptyAssistantMessage(llm);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
// Test with Groq if available
|
describe.skipIf(!process.env.CEREBRAS_API_KEY)("Cerebras Provider Empty Messages", () => {
|
||||||
describe.skipIf(!process.env.GROQ_API_KEY)("Groq Provider Empty Messages", () => {
|
const llm = getModel("cerebras", "gpt-oss-120b");
|
||||||
let llm: OpenAICompletionsLLM;
|
|
||||||
|
|
||||||
beforeAll(() => {
|
it("should handle empty content array", async () => {
|
||||||
const model = getModel("groq", "llama-3.3-70b-versatile");
|
await testEmptyMessage(llm);
|
||||||
if (!model) {
|
});
|
||||||
throw new Error("Model llama-3.3-70b-versatile not found");
|
|
||||||
}
|
|
||||||
llm = new OpenAICompletionsLLM(model, process.env.GROQ_API_KEY!);
|
|
||||||
});
|
|
||||||
|
|
||||||
it("should handle empty content array", async () => {
|
it("should handle empty string content", async () => {
|
||||||
await testEmptyMessage(llm);
|
await testEmptyStringMessage(llm);
|
||||||
});
|
});
|
||||||
|
|
||||||
it("should handle empty string content", async () => {
|
it("should handle whitespace-only content", async () => {
|
||||||
await testEmptyStringMessage(llm);
|
await testWhitespaceOnlyMessage(llm);
|
||||||
});
|
});
|
||||||
|
|
||||||
it("should handle whitespace-only content", async () => {
|
it("should handle empty assistant message in conversation", async () => {
|
||||||
await testWhitespaceOnlyMessage(llm);
|
await testEmptyAssistantMessage(llm);
|
||||||
});
|
});
|
||||||
|
});
|
||||||
it("should handle empty assistant message in conversation", async () => {
|
|
||||||
await testEmptyAssistantMessage(llm);
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
// Test with Cerebras if available
|
|
||||||
describe.skipIf(!process.env.CEREBRAS_API_KEY)("Cerebras Provider Empty Messages", () => {
|
|
||||||
let llm: OpenAICompletionsLLM;
|
|
||||||
|
|
||||||
beforeAll(() => {
|
|
||||||
const model = getModel("cerebras", "gpt-oss-120b");
|
|
||||||
if (!model) {
|
|
||||||
throw new Error("Model gpt-oss-120b not found");
|
|
||||||
}
|
|
||||||
llm = new OpenAICompletionsLLM(model, process.env.CEREBRAS_API_KEY!);
|
|
||||||
});
|
|
||||||
|
|
||||||
it("should handle empty content array", async () => {
|
|
||||||
await testEmptyMessage(llm);
|
|
||||||
});
|
|
||||||
|
|
||||||
it("should handle empty string content", async () => {
|
|
||||||
await testEmptyStringMessage(llm);
|
|
||||||
});
|
|
||||||
|
|
||||||
it("should handle whitespace-only content", async () => {
|
|
||||||
await testWhitespaceOnlyMessage(llm);
|
|
||||||
});
|
|
||||||
|
|
||||||
it("should handle empty assistant message in conversation", async () => {
|
|
||||||
await testEmptyAssistantMessage(llm);
|
|
||||||
});
|
|
||||||
});
|
|
||||||
});
|
});
|
||||||
|
|
@ -1,311 +1,612 @@
|
||||||
import { describe, it, beforeAll, expect } from "vitest";
|
import { type ChildProcess, execSync, spawn } from "child_process";
|
||||||
import { getModel } from "../src/models.js";
|
|
||||||
import { generate, generateComplete } from "../src/generate.js";
|
|
||||||
import type { Context, Tool, GenerateOptionsUnified, Model, ImageContent, GenerateStream, GenerateOptions } from "../src/types.js";
|
|
||||||
import { readFileSync } from "fs";
|
import { readFileSync } from "fs";
|
||||||
import { join, dirname } from "path";
|
import { dirname, join } from "path";
|
||||||
import { fileURLToPath } from "url";
|
import { fileURLToPath } from "url";
|
||||||
|
import { afterAll, beforeAll, describe, expect, it } from "vitest";
|
||||||
|
import { complete, stream } from "../src/generate.js";
|
||||||
|
import { getModel } from "../src/models.js";
|
||||||
|
import type { Api, Context, ImageContent, Model, OptionsForApi, Tool } from "../src/types.js";
|
||||||
|
|
||||||
const __filename = fileURLToPath(import.meta.url);
|
const __filename = fileURLToPath(import.meta.url);
|
||||||
const __dirname = dirname(__filename);
|
const __dirname = dirname(__filename);
|
||||||
|
|
||||||
// Calculator tool definition (same as examples)
|
// Calculator tool definition (same as examples)
|
||||||
const calculatorTool: Tool = {
|
const calculatorTool: Tool = {
|
||||||
name: "calculator",
|
name: "calculator",
|
||||||
description: "Perform basic arithmetic operations",
|
description: "Perform basic arithmetic operations",
|
||||||
parameters: {
|
parameters: {
|
||||||
type: "object",
|
type: "object",
|
||||||
properties: {
|
properties: {
|
||||||
a: { type: "number", description: "First number" },
|
a: { type: "number", description: "First number" },
|
||||||
b: { type: "number", description: "Second number" },
|
b: { type: "number", description: "Second number" },
|
||||||
operation: {
|
operation: {
|
||||||
type: "string",
|
type: "string",
|
||||||
enum: ["add", "subtract", "multiply", "divide"],
|
enum: ["add", "subtract", "multiply", "divide"],
|
||||||
description: "The operation to perform"
|
description: "The operation to perform",
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
required: ["a", "b", "operation"]
|
required: ["a", "b", "operation"],
|
||||||
}
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
async function basicTextGeneration<P extends GenerateOptions>(model: Model, options?: P) {
|
async function basicTextGeneration<TApi extends Api>(model: Model<TApi>, options?: OptionsForApi<TApi>) {
|
||||||
const context: Context = {
|
const context: Context = {
|
||||||
systemPrompt: "You are a helpful assistant. Be concise.",
|
systemPrompt: "You are a helpful assistant. Be concise.",
|
||||||
messages: [
|
messages: [{ role: "user", content: "Reply with exactly: 'Hello test successful'" }],
|
||||||
{ role: "user", content: "Reply with exactly: 'Hello test successful'" }
|
};
|
||||||
]
|
const response = await complete(model, context, options);
|
||||||
};
|
|
||||||
|
|
||||||
const response = await generateComplete(model, context, options);
|
expect(response.role).toBe("assistant");
|
||||||
|
expect(response.content).toBeTruthy();
|
||||||
|
expect(response.usage.input + response.usage.cacheRead).toBeGreaterThan(0);
|
||||||
|
expect(response.usage.output).toBeGreaterThan(0);
|
||||||
|
expect(response.error).toBeFalsy();
|
||||||
|
expect(response.content.map((b) => (b.type === "text" ? b.text : "")).join("")).toContain("Hello test successful");
|
||||||
|
|
||||||
expect(response.role).toBe("assistant");
|
context.messages.push(response);
|
||||||
expect(response.content).toBeTruthy();
|
context.messages.push({ role: "user", content: "Now say 'Goodbye test successful'" });
|
||||||
expect(response.usage.input + response.usage.cacheRead).toBeGreaterThan(0);
|
|
||||||
expect(response.usage.output).toBeGreaterThan(0);
|
|
||||||
expect(response.error).toBeFalsy();
|
|
||||||
expect(response.content.map(b => b.type == "text" ? b.text : "").join("")).toContain("Hello test successful");
|
|
||||||
|
|
||||||
context.messages.push(response);
|
const secondResponse = await complete(model, context, options);
|
||||||
context.messages.push({ role: "user", content: "Now say 'Goodbye test successful'" });
|
|
||||||
|
|
||||||
const secondResponse = await generateComplete(model, context, options);
|
expect(secondResponse.role).toBe("assistant");
|
||||||
|
expect(secondResponse.content).toBeTruthy();
|
||||||
expect(secondResponse.role).toBe("assistant");
|
expect(secondResponse.usage.input + secondResponse.usage.cacheRead).toBeGreaterThan(0);
|
||||||
expect(secondResponse.content).toBeTruthy();
|
expect(secondResponse.usage.output).toBeGreaterThan(0);
|
||||||
expect(secondResponse.usage.input + secondResponse.usage.cacheRead).toBeGreaterThan(0);
|
expect(secondResponse.error).toBeFalsy();
|
||||||
expect(secondResponse.usage.output).toBeGreaterThan(0);
|
expect(secondResponse.content.map((b) => (b.type === "text" ? b.text : "")).join("")).toContain(
|
||||||
expect(secondResponse.error).toBeFalsy();
|
"Goodbye test successful",
|
||||||
expect(secondResponse.content.map(b => b.type == "text" ? b.text : "").join("")).toContain("Goodbye test successful");
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
async function handleToolCall(model: Model, options?: GenerateOptionsUnified) {
|
async function handleToolCall<TApi extends Api>(model: Model<TApi>, options?: OptionsForApi<TApi>) {
|
||||||
const context: Context = {
|
const context: Context = {
|
||||||
systemPrompt: "You are a helpful assistant that uses tools when asked.",
|
systemPrompt: "You are a helpful assistant that uses tools when asked.",
|
||||||
messages: [{
|
messages: [
|
||||||
role: "user",
|
{
|
||||||
content: "Calculate 15 + 27 using the calculator tool."
|
role: "user",
|
||||||
}],
|
content: "Calculate 15 + 27 using the calculator tool.",
|
||||||
tools: [calculatorTool]
|
},
|
||||||
};
|
],
|
||||||
|
tools: [calculatorTool],
|
||||||
|
};
|
||||||
|
|
||||||
const response = await generateComplete(model, context, options);
|
const response = await complete(model, context, options);
|
||||||
expect(response.stopReason).toBe("toolUse");
|
expect(response.stopReason).toBe("toolUse");
|
||||||
expect(response.content.some(b => b.type == "toolCall")).toBeTruthy();
|
expect(response.content.some((b) => b.type === "toolCall")).toBeTruthy();
|
||||||
const toolCall = response.content.find(b => b.type == "toolCall");
|
const toolCall = response.content.find((b) => b.type === "toolCall");
|
||||||
if (toolCall && toolCall.type === "toolCall") {
|
if (toolCall && toolCall.type === "toolCall") {
|
||||||
expect(toolCall.name).toBe("calculator");
|
expect(toolCall.name).toBe("calculator");
|
||||||
expect(toolCall.id).toBeTruthy();
|
expect(toolCall.id).toBeTruthy();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async function handleStreaming(model: Model, options?: GenerateOptionsUnified) {
|
async function handleStreaming<TApi extends Api>(model: Model<TApi>, options?: OptionsForApi<TApi>) {
|
||||||
let textStarted = false;
|
let textStarted = false;
|
||||||
let textChunks = "";
|
let textChunks = "";
|
||||||
let textCompleted = false;
|
let textCompleted = false;
|
||||||
|
|
||||||
const context: Context = {
|
const context: Context = {
|
||||||
messages: [{ role: "user", content: "Count from 1 to 3" }]
|
messages: [{ role: "user", content: "Count from 1 to 3" }],
|
||||||
};
|
};
|
||||||
|
|
||||||
const stream = generate(model, context, options);
|
const s = stream(model, context, options);
|
||||||
|
|
||||||
for await (const event of stream) {
|
for await (const event of s) {
|
||||||
if (event.type === "text_start") {
|
if (event.type === "text_start") {
|
||||||
textStarted = true;
|
textStarted = true;
|
||||||
} else if (event.type === "text_delta") {
|
} else if (event.type === "text_delta") {
|
||||||
textChunks += event.delta;
|
textChunks += event.delta;
|
||||||
} else if (event.type === "text_end") {
|
} else if (event.type === "text_end") {
|
||||||
textCompleted = true;
|
textCompleted = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const response = await stream.finalMessage();
|
const response = await s.finalMessage();
|
||||||
|
|
||||||
expect(textStarted).toBe(true);
|
expect(textStarted).toBe(true);
|
||||||
expect(textChunks.length).toBeGreaterThan(0);
|
expect(textChunks.length).toBeGreaterThan(0);
|
||||||
expect(textCompleted).toBe(true);
|
expect(textCompleted).toBe(true);
|
||||||
expect(response.content.some(b => b.type == "text")).toBeTruthy();
|
expect(response.content.some((b) => b.type === "text")).toBeTruthy();
|
||||||
}
|
}
|
||||||
|
|
||||||
async function handleThinking(model: Model, options: GenerateOptionsUnified) {
|
async function handleThinking<TApi extends Api>(model: Model<TApi>, options?: OptionsForApi<TApi>) {
|
||||||
let thinkingStarted = false;
|
let thinkingStarted = false;
|
||||||
let thinkingChunks = "";
|
let thinkingChunks = "";
|
||||||
let thinkingCompleted = false;
|
let thinkingCompleted = false;
|
||||||
|
|
||||||
const context: Context = {
|
const context: Context = {
|
||||||
messages: [{ role: "user", content: `Think about ${(Math.random() * 255) | 0} + 27. Think step by step. Then output the result.` }]
|
messages: [
|
||||||
};
|
{
|
||||||
|
role: "user",
|
||||||
|
content: `Think about ${(Math.random() * 255) | 0} + 27. Think step by step. Then output the result.`,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
const stream = generate(model, context, options);
|
const s = stream(model, context, options);
|
||||||
|
|
||||||
for await (const event of stream) {
|
for await (const event of s) {
|
||||||
if (event.type === "thinking_start") {
|
if (event.type === "thinking_start") {
|
||||||
thinkingStarted = true;
|
thinkingStarted = true;
|
||||||
} else if (event.type === "thinking_delta") {
|
} else if (event.type === "thinking_delta") {
|
||||||
thinkingChunks += event.delta;
|
thinkingChunks += event.delta;
|
||||||
} else if (event.type === "thinking_end") {
|
} else if (event.type === "thinking_end") {
|
||||||
thinkingCompleted = true;
|
thinkingCompleted = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const response = await stream.finalMessage();
|
const response = await s.finalMessage();
|
||||||
|
|
||||||
expect(response.stopReason, `Error: ${response.error}`).toBe("stop");
|
expect(response.stopReason, `Error: ${response.error}`).toBe("stop");
|
||||||
expect(thinkingStarted).toBe(true);
|
expect(thinkingStarted).toBe(true);
|
||||||
expect(thinkingChunks.length).toBeGreaterThan(0);
|
expect(thinkingChunks.length).toBeGreaterThan(0);
|
||||||
expect(thinkingCompleted).toBe(true);
|
expect(thinkingCompleted).toBe(true);
|
||||||
expect(response.content.some(b => b.type == "thinking")).toBeTruthy();
|
expect(response.content.some((b) => b.type === "thinking")).toBeTruthy();
|
||||||
}
|
}
|
||||||
|
|
||||||
async function handleImage(model: Model, options?: GenerateOptionsUnified) {
|
async function handleImage<TApi extends Api>(model: Model<TApi>, options?: OptionsForApi<TApi>) {
|
||||||
// Check if the model supports images
|
// Check if the model supports images
|
||||||
if (!model.input.includes("image")) {
|
if (!model.input.includes("image")) {
|
||||||
console.log(`Skipping image test - model ${model.id} doesn't support images`);
|
console.log(`Skipping image test - model ${model.id} doesn't support images`);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read the test image
|
// Read the test image
|
||||||
const imagePath = join(__dirname, "data", "red-circle.png");
|
const imagePath = join(__dirname, "data", "red-circle.png");
|
||||||
const imageBuffer = readFileSync(imagePath);
|
const imageBuffer = readFileSync(imagePath);
|
||||||
const base64Image = imageBuffer.toString("base64");
|
const base64Image = imageBuffer.toString("base64");
|
||||||
|
|
||||||
const imageContent: ImageContent = {
|
const imageContent: ImageContent = {
|
||||||
type: "image",
|
type: "image",
|
||||||
data: base64Image,
|
data: base64Image,
|
||||||
mimeType: "image/png",
|
mimeType: "image/png",
|
||||||
};
|
};
|
||||||
|
|
||||||
const context: Context = {
|
const context: Context = {
|
||||||
messages: [
|
messages: [
|
||||||
{
|
{
|
||||||
role: "user",
|
role: "user",
|
||||||
content: [
|
content: [
|
||||||
{ type: "text", text: "What do you see in this image? Please describe the shape (circle, rectangle, square, triangle, ...) and color (red, blue, green, ...)." },
|
{
|
||||||
imageContent,
|
type: "text",
|
||||||
],
|
text: "What do you see in this image? Please describe the shape (circle, rectangle, square, triangle, ...) and color (red, blue, green, ...).",
|
||||||
},
|
},
|
||||||
],
|
imageContent,
|
||||||
};
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
const response = await generateComplete(model, context, options);
|
const response = await complete(model, context, options);
|
||||||
|
|
||||||
// Check the response mentions red and circle
|
// Check the response mentions red and circle
|
||||||
expect(response.content.length > 0).toBeTruthy();
|
expect(response.content.length > 0).toBeTruthy();
|
||||||
const textContent = response.content.find(b => b.type == "text");
|
const textContent = response.content.find((b) => b.type === "text");
|
||||||
if (textContent && textContent.type === "text") {
|
if (textContent && textContent.type === "text") {
|
||||||
const lowerContent = textContent.text.toLowerCase();
|
const lowerContent = textContent.text.toLowerCase();
|
||||||
expect(lowerContent).toContain("red");
|
expect(lowerContent).toContain("red");
|
||||||
expect(lowerContent).toContain("circle");
|
expect(lowerContent).toContain("circle");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async function multiTurn(model: Model, options?: GenerateOptionsUnified) {
|
async function multiTurn<TApi extends Api>(model: Model<TApi>, options?: OptionsForApi<TApi>) {
|
||||||
const context: Context = {
|
const context: Context = {
|
||||||
systemPrompt: "You are a helpful assistant that can use tools to answer questions.",
|
systemPrompt: "You are a helpful assistant that can use tools to answer questions.",
|
||||||
messages: [
|
messages: [
|
||||||
{
|
{
|
||||||
role: "user",
|
role: "user",
|
||||||
content: "Think about this briefly, then calculate 42 * 17 and 453 + 434 using the calculator tool."
|
content: "Think about this briefly, then calculate 42 * 17 and 453 + 434 using the calculator tool.",
|
||||||
}
|
},
|
||||||
],
|
],
|
||||||
tools: [calculatorTool]
|
tools: [calculatorTool],
|
||||||
};
|
};
|
||||||
|
|
||||||
// Collect all text content from all assistant responses
|
// Collect all text content from all assistant responses
|
||||||
let allTextContent = "";
|
let allTextContent = "";
|
||||||
let hasSeenThinking = false;
|
let hasSeenThinking = false;
|
||||||
let hasSeenToolCalls = false;
|
let hasSeenToolCalls = false;
|
||||||
const maxTurns = 5; // Prevent infinite loops
|
const maxTurns = 5; // Prevent infinite loops
|
||||||
|
|
||||||
for (let turn = 0; turn < maxTurns; turn++) {
|
for (let turn = 0; turn < maxTurns; turn++) {
|
||||||
const response = await generateComplete(model, context, options);
|
const response = await complete(model, context, options);
|
||||||
|
|
||||||
// Add the assistant response to context
|
// Add the assistant response to context
|
||||||
context.messages.push(response);
|
context.messages.push(response);
|
||||||
|
|
||||||
// Process content blocks
|
// Process content blocks
|
||||||
for (const block of response.content) {
|
for (const block of response.content) {
|
||||||
if (block.type === "text") {
|
if (block.type === "text") {
|
||||||
allTextContent += block.text;
|
allTextContent += block.text;
|
||||||
} else if (block.type === "thinking") {
|
} else if (block.type === "thinking") {
|
||||||
hasSeenThinking = true;
|
hasSeenThinking = true;
|
||||||
} else if (block.type === "toolCall") {
|
} else if (block.type === "toolCall") {
|
||||||
hasSeenToolCalls = true;
|
hasSeenToolCalls = true;
|
||||||
|
|
||||||
// Process the tool call
|
// Process the tool call
|
||||||
expect(block.name).toBe("calculator");
|
expect(block.name).toBe("calculator");
|
||||||
expect(block.id).toBeTruthy();
|
expect(block.id).toBeTruthy();
|
||||||
expect(block.arguments).toBeTruthy();
|
expect(block.arguments).toBeTruthy();
|
||||||
|
|
||||||
const { a, b, operation } = block.arguments;
|
const { a, b, operation } = block.arguments;
|
||||||
let result: number;
|
let result: number;
|
||||||
switch (operation) {
|
switch (operation) {
|
||||||
case "add": result = a + b; break;
|
case "add":
|
||||||
case "multiply": result = a * b; break;
|
result = a + b;
|
||||||
default: result = 0;
|
break;
|
||||||
}
|
case "multiply":
|
||||||
|
result = a * b;
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
result = 0;
|
||||||
|
}
|
||||||
|
|
||||||
// Add tool result to context
|
// Add tool result to context
|
||||||
context.messages.push({
|
context.messages.push({
|
||||||
role: "toolResult",
|
role: "toolResult",
|
||||||
toolCallId: block.id,
|
toolCallId: block.id,
|
||||||
toolName: block.name,
|
toolName: block.name,
|
||||||
content: `${result}`,
|
content: `${result}`,
|
||||||
isError: false
|
isError: false,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we got a stop response with text content, we're likely done
|
// If we got a stop response with text content, we're likely done
|
||||||
expect(response.stopReason).not.toBe("error");
|
expect(response.stopReason).not.toBe("error");
|
||||||
if (response.stopReason === "stop") {
|
if (response.stopReason === "stop") {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify we got either thinking content or tool calls (or both)
|
// Verify we got either thinking content or tool calls (or both)
|
||||||
expect(hasSeenThinking || hasSeenToolCalls).toBe(true);
|
expect(hasSeenThinking || hasSeenToolCalls).toBe(true);
|
||||||
|
|
||||||
// The accumulated text should reference both calculations
|
// The accumulated text should reference both calculations
|
||||||
expect(allTextContent).toBeTruthy();
|
expect(allTextContent).toBeTruthy();
|
||||||
expect(allTextContent.includes("714")).toBe(true);
|
expect(allTextContent.includes("714")).toBe(true);
|
||||||
expect(allTextContent.includes("887")).toBe(true);
|
expect(allTextContent.includes("887")).toBe(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
describe("Generate E2E Tests", () => {
|
describe("Generate E2E Tests", () => {
|
||||||
describe.skipIf(!process.env.ANTHROPIC_API_KEY)("Anthropic Provider (claude-3-5-haiku-20241022)", () => {
|
describe.skipIf(!process.env.GEMINI_API_KEY)("Gemini Provider (gemini-2.5-flash)", () => {
|
||||||
let model: Model;
|
const llm = getModel("google", "gemini-2.5-flash");
|
||||||
|
|
||||||
beforeAll(() => {
|
it("should complete basic text generation", async () => {
|
||||||
model = getModel("anthropic", "claude-3-5-haiku-20241022");
|
await basicTextGeneration(llm);
|
||||||
});
|
});
|
||||||
|
|
||||||
it("should complete basic text generation", async () => {
|
it("should handle tool calling", async () => {
|
||||||
await basicTextGeneration(model);
|
await handleToolCall(llm);
|
||||||
});
|
});
|
||||||
|
|
||||||
it("should handle tool calling", async () => {
|
it("should handle streaming", async () => {
|
||||||
await handleToolCall(model);
|
await handleStreaming(llm);
|
||||||
});
|
});
|
||||||
|
|
||||||
it("should handle streaming", async () => {
|
it("should handle ", async () => {
|
||||||
await handleStreaming(model);
|
await handleThinking(llm, { thinking: { enabled: true, budgetTokens: 1024 } });
|
||||||
});
|
});
|
||||||
|
|
||||||
it("should handle image input", async () => {
|
it("should handle multi-turn with thinking and tools", async () => {
|
||||||
await handleImage(model);
|
await multiTurn(llm, { thinking: { enabled: true, budgetTokens: 2048 } });
|
||||||
});
|
});
|
||||||
});
|
|
||||||
|
|
||||||
describe.skipIf(!process.env.ANTHROPIC_OAUTH_TOKEN)("Anthropic Provider (claude-sonnet-4-20250514)", () => {
|
it("should handle image input", async () => {
|
||||||
let model: Model;
|
await handleImage(llm);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
beforeAll(() => {
|
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Completions Provider (gpt-4o-mini)", () => {
|
||||||
model = getModel("anthropic", "claude-sonnet-4-20250514");
|
const llm: Model<"openai-completions"> = { ...getModel("openai", "gpt-4o-mini"), api: "openai-completions" };
|
||||||
});
|
|
||||||
|
|
||||||
it("should complete basic text generation", async () => {
|
it("should complete basic text generation", async () => {
|
||||||
await basicTextGeneration(model);
|
await basicTextGeneration(llm);
|
||||||
});
|
});
|
||||||
|
|
||||||
it("should handle tool calling", async () => {
|
it("should handle tool calling", async () => {
|
||||||
await handleToolCall(model);
|
await handleToolCall(llm);
|
||||||
});
|
});
|
||||||
|
|
||||||
it("should handle streaming", async () => {
|
it("should handle streaming", async () => {
|
||||||
await handleStreaming(model);
|
await handleStreaming(llm);
|
||||||
});
|
});
|
||||||
|
|
||||||
it("should handle thinking mode", async () => {
|
it("should handle image input", async () => {
|
||||||
await handleThinking(model, { reasoning: "low" });
|
await handleImage(llm);
|
||||||
});
|
});
|
||||||
|
});
|
||||||
|
|
||||||
it("should handle multi-turn with thinking and tools", async () => {
|
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Responses Provider (gpt-5-mini)", () => {
|
||||||
await multiTurn(model, { reasoning: "medium" });
|
const llm = getModel("openai", "gpt-5-mini");
|
||||||
});
|
|
||||||
|
|
||||||
it("should handle image input", async () => {
|
it("should complete basic text generation", async () => {
|
||||||
await handleImage(model);
|
await basicTextGeneration(llm);
|
||||||
});
|
});
|
||||||
});
|
|
||||||
|
it("should handle tool calling", async () => {
|
||||||
|
await handleToolCall(llm);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should handle streaming", async () => {
|
||||||
|
await handleStreaming(llm);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should handle ", { retry: 2 }, async () => {
|
||||||
|
await handleThinking(llm, { reasoningEffort: "medium" });
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should handle multi-turn with thinking and tools", async () => {
|
||||||
|
await multiTurn(llm, { reasoningEffort: "medium" });
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should handle image input", async () => {
|
||||||
|
await handleImage(llm);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe.skipIf(!process.env.ANTHROPIC_API_KEY)("Anthropic Provider (claude-3-5-haiku-20241022)", () => {
|
||||||
|
const model = getModel("anthropic", "claude-3-5-haiku-20241022");
|
||||||
|
|
||||||
|
it("should complete basic text generation", async () => {
|
||||||
|
await basicTextGeneration(model, { thinkingEnabled: true });
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should handle tool calling", async () => {
|
||||||
|
await handleToolCall(model);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should handle streaming", async () => {
|
||||||
|
await handleStreaming(model);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should handle image input", async () => {
|
||||||
|
await handleImage(model);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe.skipIf(!process.env.ANTHROPIC_OAUTH_TOKEN)("Anthropic Provider (claude-sonnet-4-20250514)", () => {
|
||||||
|
const model = getModel("anthropic", "claude-sonnet-4-20250514");
|
||||||
|
|
||||||
|
it("should complete basic text generation", async () => {
|
||||||
|
await basicTextGeneration(model, { thinkingEnabled: true });
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should handle tool calling", async () => {
|
||||||
|
await handleToolCall(model);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should handle streaming", async () => {
|
||||||
|
await handleStreaming(model);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should handle thinking", async () => {
|
||||||
|
await handleThinking(model, { thinkingEnabled: true });
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should handle multi-turn with thinking and tools", async () => {
|
||||||
|
await multiTurn(model, { thinkingEnabled: true });
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should handle image input", async () => {
|
||||||
|
await handleImage(model);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Responses Provider (gpt-5-mini)", () => {
|
||||||
|
const model = getModel("openai", "gpt-5-mini");
|
||||||
|
|
||||||
|
it("should complete basic text generation", async () => {
|
||||||
|
await basicTextGeneration(model);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should handle tool calling", async () => {
|
||||||
|
await handleToolCall(model);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should handle streaming", async () => {
|
||||||
|
await handleStreaming(model);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should handle image input", async () => {
|
||||||
|
await handleImage(model);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe.skipIf(!process.env.XAI_API_KEY)("xAI Provider (grok-code-fast-1 via OpenAI Completions)", () => {
|
||||||
|
const llm = getModel("xai", "grok-code-fast-1");
|
||||||
|
|
||||||
|
it("should complete basic text generation", async () => {
|
||||||
|
await basicTextGeneration(llm);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should handle tool calling", async () => {
|
||||||
|
await handleToolCall(llm);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should handle streaming", async () => {
|
||||||
|
await handleStreaming(llm);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should handle thinking mode", async () => {
|
||||||
|
await handleThinking(llm, { reasoningEffort: "medium" });
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should handle multi-turn with thinking and tools", async () => {
|
||||||
|
await multiTurn(llm, { reasoningEffort: "medium" });
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe.skipIf(!process.env.GROQ_API_KEY)("Groq Provider (gpt-oss-20b via OpenAI Completions)", () => {
|
||||||
|
const llm = getModel("groq", "openai/gpt-oss-20b");
|
||||||
|
|
||||||
|
it("should complete basic text generation", async () => {
|
||||||
|
await basicTextGeneration(llm);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should handle tool calling", async () => {
|
||||||
|
await handleToolCall(llm);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should handle streaming", async () => {
|
||||||
|
await handleStreaming(llm);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should handle thinking mode", async () => {
|
||||||
|
await handleThinking(llm, { reasoningEffort: "medium" });
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should handle multi-turn with thinking and tools", async () => {
|
||||||
|
await multiTurn(llm, { reasoningEffort: "medium" });
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe.skipIf(!process.env.CEREBRAS_API_KEY)("Cerebras Provider (gpt-oss-120b via OpenAI Completions)", () => {
|
||||||
|
const llm = getModel("cerebras", "gpt-oss-120b");
|
||||||
|
|
||||||
|
it("should complete basic text generation", async () => {
|
||||||
|
await basicTextGeneration(llm);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should handle tool calling", async () => {
|
||||||
|
await handleToolCall(llm);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should handle streaming", async () => {
|
||||||
|
await handleStreaming(llm);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should handle thinking mode", async () => {
|
||||||
|
await handleThinking(llm, { reasoningEffort: "medium" });
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should handle multi-turn with thinking and tools", async () => {
|
||||||
|
await multiTurn(llm, { reasoningEffort: "medium" });
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe.skipIf(!process.env.OPENROUTER_API_KEY)("OpenRouter Provider (glm-4.5v via OpenAI Completions)", () => {
|
||||||
|
const llm = getModel("openrouter", "z-ai/glm-4.5v");
|
||||||
|
|
||||||
|
it("should complete basic text generation", async () => {
|
||||||
|
await basicTextGeneration(llm);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should handle tool calling", async () => {
|
||||||
|
await handleToolCall(llm);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should handle streaming", async () => {
|
||||||
|
await handleStreaming(llm);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should handle thinking mode", async () => {
|
||||||
|
await handleThinking(llm, { reasoningEffort: "medium" });
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should handle multi-turn with thinking and tools", { retry: 2 }, async () => {
|
||||||
|
await multiTurn(llm, { reasoningEffort: "medium" });
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should handle image input", async () => {
|
||||||
|
await handleImage(llm);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
// Check if ollama is installed
|
||||||
|
let ollamaInstalled = false;
|
||||||
|
try {
|
||||||
|
execSync("which ollama", { stdio: "ignore" });
|
||||||
|
ollamaInstalled = true;
|
||||||
|
} catch {
|
||||||
|
ollamaInstalled = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
describe.skipIf(!ollamaInstalled)("Ollama Provider (gpt-oss-20b via OpenAI Completions)", () => {
|
||||||
|
let llm: Model<"openai-completions">;
|
||||||
|
let ollamaProcess: ChildProcess | null = null;
|
||||||
|
|
||||||
|
beforeAll(async () => {
|
||||||
|
// Check if model is available, if not pull it
|
||||||
|
try {
|
||||||
|
execSync("ollama list | grep -q 'gpt-oss:20b'", { stdio: "ignore" });
|
||||||
|
} catch {
|
||||||
|
console.log("Pulling gpt-oss:20b model for Ollama tests...");
|
||||||
|
try {
|
||||||
|
execSync("ollama pull gpt-oss:20b", { stdio: "inherit" });
|
||||||
|
} catch (e) {
|
||||||
|
console.warn("Failed to pull gpt-oss:20b model, tests will be skipped");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start ollama server
|
||||||
|
ollamaProcess = spawn("ollama", ["serve"], {
|
||||||
|
detached: false,
|
||||||
|
stdio: "ignore",
|
||||||
|
});
|
||||||
|
|
||||||
|
// Wait for server to be ready
|
||||||
|
await new Promise<void>((resolve) => {
|
||||||
|
const checkServer = async () => {
|
||||||
|
try {
|
||||||
|
const response = await fetch("http://localhost:11434/api/tags");
|
||||||
|
if (response.ok) {
|
||||||
|
resolve();
|
||||||
|
} else {
|
||||||
|
setTimeout(checkServer, 500);
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
setTimeout(checkServer, 500);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
setTimeout(checkServer, 1000); // Initial delay
|
||||||
|
});
|
||||||
|
|
||||||
|
llm = {
|
||||||
|
id: "gpt-oss:20b",
|
||||||
|
api: "openai-completions",
|
||||||
|
provider: "ollama",
|
||||||
|
baseUrl: "http://localhost:11434/v1",
|
||||||
|
reasoning: true,
|
||||||
|
input: ["text"],
|
||||||
|
contextWindow: 128000,
|
||||||
|
maxTokens: 16000,
|
||||||
|
cost: {
|
||||||
|
input: 0,
|
||||||
|
output: 0,
|
||||||
|
cacheRead: 0,
|
||||||
|
cacheWrite: 0,
|
||||||
|
},
|
||||||
|
name: "Ollama GPT-OSS 20B",
|
||||||
|
};
|
||||||
|
}, 30000); // 30 second timeout for setup
|
||||||
|
|
||||||
|
afterAll(() => {
|
||||||
|
// Kill ollama server
|
||||||
|
if (ollamaProcess) {
|
||||||
|
ollamaProcess.kill("SIGTERM");
|
||||||
|
ollamaProcess = null;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should complete basic text generation", async () => {
|
||||||
|
await basicTextGeneration(llm, { apiKey: "test" });
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should handle tool calling", async () => {
|
||||||
|
await handleToolCall(llm, { apiKey: "test" });
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should handle streaming", async () => {
|
||||||
|
await handleStreaming(llm, { apiKey: "test" });
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should handle thinking mode", async () => {
|
||||||
|
await handleThinking(llm, { apiKey: "test", reasoningEffort: "medium" });
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should handle multi-turn with thinking and tools", async () => {
|
||||||
|
await multiTurn(llm, { apiKey: "test", reasoningEffort: "medium" });
|
||||||
|
});
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
@ -1,503 +1,489 @@
|
||||||
import { describe, it, expect, beforeAll } from "vitest";
|
import { describe, expect, it } from "vitest";
|
||||||
import { GoogleLLM } from "../src/providers/google.js";
|
import { complete } from "../src/generate.js";
|
||||||
import { OpenAICompletionsLLM } from "../src/providers/openai-completions.js";
|
import { getModel } from "../src/models.js";
|
||||||
import { OpenAIResponsesLLM } from "../src/providers/openai-responses.js";
|
import type { Api, AssistantMessage, Context, Message, Model, Tool } from "../src/types.js";
|
||||||
import { AnthropicLLM } from "../src/providers/anthropic.js";
|
|
||||||
import type { LLM, Context, AssistantMessage, Tool, Message } from "../src/types.js";
|
|
||||||
import { createLLM, getModel } from "../src/models.js";
|
|
||||||
|
|
||||||
// Tool for testing
|
// Tool for testing
|
||||||
const weatherTool: Tool = {
|
const weatherTool: Tool = {
|
||||||
name: "get_weather",
|
name: "get_weather",
|
||||||
description: "Get the weather for a location",
|
description: "Get the weather for a location",
|
||||||
parameters: {
|
parameters: {
|
||||||
type: "object",
|
type: "object",
|
||||||
properties: {
|
properties: {
|
||||||
location: { type: "string", description: "City name" }
|
location: { type: "string", description: "City name" },
|
||||||
},
|
},
|
||||||
required: ["location"]
|
required: ["location"],
|
||||||
}
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
// Pre-built contexts representing typical outputs from each provider
|
// Pre-built contexts representing typical outputs from each provider
|
||||||
const providerContexts = {
|
const providerContexts = {
|
||||||
// Anthropic-style message with thinking block
|
// Anthropic-style message with thinking block
|
||||||
anthropic: {
|
anthropic: {
|
||||||
message: {
|
message: {
|
||||||
role: "assistant",
|
role: "assistant",
|
||||||
content: [
|
content: [
|
||||||
{
|
{
|
||||||
type: "thinking",
|
type: "thinking",
|
||||||
thinking: "Let me calculate 17 * 23. That's 17 * 20 + 17 * 3 = 340 + 51 = 391",
|
thinking: "Let me calculate 17 * 23. That's 17 * 20 + 17 * 3 = 340 + 51 = 391",
|
||||||
thinkingSignature: "signature_abc123"
|
thinkingSignature: "signature_abc123",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
type: "text",
|
type: "text",
|
||||||
text: "I'll help you with the calculation and check the weather. The result of 17 × 23 is 391. The capital of Austria is Vienna. Now let me check the weather for you."
|
text: "I'll help you with the calculation and check the weather. The result of 17 × 23 is 391. The capital of Austria is Vienna. Now let me check the weather for you.",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
type: "toolCall",
|
type: "toolCall",
|
||||||
id: "toolu_01abc123",
|
id: "toolu_01abc123",
|
||||||
name: "get_weather",
|
name: "get_weather",
|
||||||
arguments: { location: "Tokyo" }
|
arguments: { location: "Tokyo" },
|
||||||
}
|
},
|
||||||
],
|
],
|
||||||
provider: "anthropic",
|
provider: "anthropic",
|
||||||
model: "claude-3-5-haiku-latest",
|
model: "claude-3-5-haiku-latest",
|
||||||
usage: { input: 100, output: 50, cacheRead: 0, cacheWrite: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } },
|
usage: {
|
||||||
stopReason: "toolUse"
|
input: 100,
|
||||||
} as AssistantMessage,
|
output: 50,
|
||||||
toolResult: {
|
cacheRead: 0,
|
||||||
role: "toolResult" as const,
|
cacheWrite: 0,
|
||||||
toolCallId: "toolu_01abc123",
|
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||||
toolName: "get_weather",
|
},
|
||||||
content: "Weather in Tokyo: 18°C, partly cloudy",
|
stopReason: "toolUse",
|
||||||
isError: false
|
} as AssistantMessage,
|
||||||
},
|
toolResult: {
|
||||||
facts: {
|
role: "toolResult" as const,
|
||||||
calculation: 391,
|
toolCallId: "toolu_01abc123",
|
||||||
city: "Tokyo",
|
toolName: "get_weather",
|
||||||
temperature: 18,
|
content: "Weather in Tokyo: 18°C, partly cloudy",
|
||||||
capital: "Vienna"
|
isError: false,
|
||||||
}
|
},
|
||||||
},
|
facts: {
|
||||||
|
calculation: 391,
|
||||||
|
city: "Tokyo",
|
||||||
|
temperature: 18,
|
||||||
|
capital: "Vienna",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
// Google-style message with thinking
|
// Google-style message with thinking
|
||||||
google: {
|
google: {
|
||||||
message: {
|
message: {
|
||||||
role: "assistant",
|
role: "assistant",
|
||||||
content: [
|
content: [
|
||||||
{
|
{
|
||||||
type: "thinking",
|
type: "thinking",
|
||||||
thinking: "I need to multiply 19 * 24. Let me work through this: 19 * 24 = 19 * 20 + 19 * 4 = 380 + 76 = 456",
|
thinking:
|
||||||
thinkingSignature: undefined
|
"I need to multiply 19 * 24. Let me work through this: 19 * 24 = 19 * 20 + 19 * 4 = 380 + 76 = 456",
|
||||||
},
|
thinkingSignature: undefined,
|
||||||
{
|
},
|
||||||
type: "text",
|
{
|
||||||
text: "The multiplication of 19 × 24 equals 456. The capital of France is Paris. Let me check the weather in Berlin for you."
|
type: "text",
|
||||||
},
|
text: "The multiplication of 19 × 24 equals 456. The capital of France is Paris. Let me check the weather in Berlin for you.",
|
||||||
{
|
},
|
||||||
type: "toolCall",
|
{
|
||||||
id: "call_gemini_123",
|
type: "toolCall",
|
||||||
name: "get_weather",
|
id: "call_gemini_123",
|
||||||
arguments: { location: "Berlin" }
|
name: "get_weather",
|
||||||
}
|
arguments: { location: "Berlin" },
|
||||||
],
|
},
|
||||||
provider: "google",
|
],
|
||||||
model: "gemini-2.5-flash",
|
provider: "google",
|
||||||
usage: { input: 120, output: 60, cacheRead: 0, cacheWrite: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } },
|
model: "gemini-2.5-flash",
|
||||||
stopReason: "toolUse"
|
usage: {
|
||||||
} as AssistantMessage,
|
input: 120,
|
||||||
toolResult: {
|
output: 60,
|
||||||
role: "toolResult" as const,
|
cacheRead: 0,
|
||||||
toolCallId: "call_gemini_123",
|
cacheWrite: 0,
|
||||||
toolName: "get_weather",
|
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||||
content: "Weather in Berlin: 22°C, sunny",
|
},
|
||||||
isError: false
|
stopReason: "toolUse",
|
||||||
},
|
} as AssistantMessage,
|
||||||
facts: {
|
toolResult: {
|
||||||
calculation: 456,
|
role: "toolResult" as const,
|
||||||
city: "Berlin",
|
toolCallId: "call_gemini_123",
|
||||||
temperature: 22,
|
toolName: "get_weather",
|
||||||
capital: "Paris"
|
content: "Weather in Berlin: 22°C, sunny",
|
||||||
}
|
isError: false,
|
||||||
},
|
},
|
||||||
|
facts: {
|
||||||
|
calculation: 456,
|
||||||
|
city: "Berlin",
|
||||||
|
temperature: 22,
|
||||||
|
capital: "Paris",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
// OpenAI Completions style (with reasoning_content)
|
// OpenAI Completions style (with reasoning_content)
|
||||||
openaiCompletions: {
|
openaiCompletions: {
|
||||||
message: {
|
message: {
|
||||||
role: "assistant",
|
role: "assistant",
|
||||||
content: [
|
content: [
|
||||||
{
|
{
|
||||||
type: "thinking",
|
type: "thinking",
|
||||||
thinking: "Let me calculate 21 * 25. That's 21 * 25 = 525",
|
thinking: "Let me calculate 21 * 25. That's 21 * 25 = 525",
|
||||||
thinkingSignature: "reasoning_content"
|
thinkingSignature: "reasoning_content",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
type: "text",
|
type: "text",
|
||||||
text: "The result of 21 × 25 is 525. The capital of Spain is Madrid. I'll check the weather in London now."
|
text: "The result of 21 × 25 is 525. The capital of Spain is Madrid. I'll check the weather in London now.",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
type: "toolCall",
|
type: "toolCall",
|
||||||
id: "call_abc123",
|
id: "call_abc123",
|
||||||
name: "get_weather",
|
name: "get_weather",
|
||||||
arguments: { location: "London" }
|
arguments: { location: "London" },
|
||||||
}
|
},
|
||||||
],
|
],
|
||||||
provider: "openai",
|
provider: "openai",
|
||||||
model: "gpt-4o-mini",
|
model: "gpt-4o-mini",
|
||||||
usage: { input: 110, output: 55, cacheRead: 0, cacheWrite: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } },
|
usage: {
|
||||||
stopReason: "toolUse"
|
input: 110,
|
||||||
} as AssistantMessage,
|
output: 55,
|
||||||
toolResult: {
|
cacheRead: 0,
|
||||||
role: "toolResult" as const,
|
cacheWrite: 0,
|
||||||
toolCallId: "call_abc123",
|
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||||
toolName: "get_weather",
|
},
|
||||||
content: "Weather in London: 15°C, rainy",
|
stopReason: "toolUse",
|
||||||
isError: false
|
} as AssistantMessage,
|
||||||
},
|
toolResult: {
|
||||||
facts: {
|
role: "toolResult" as const,
|
||||||
calculation: 525,
|
toolCallId: "call_abc123",
|
||||||
city: "London",
|
toolName: "get_weather",
|
||||||
temperature: 15,
|
content: "Weather in London: 15°C, rainy",
|
||||||
capital: "Madrid"
|
isError: false,
|
||||||
}
|
},
|
||||||
},
|
facts: {
|
||||||
|
calculation: 525,
|
||||||
|
city: "London",
|
||||||
|
temperature: 15,
|
||||||
|
capital: "Madrid",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
// OpenAI Responses style (with complex tool call IDs)
|
// OpenAI Responses style (with complex tool call IDs)
|
||||||
openaiResponses: {
|
openaiResponses: {
|
||||||
message: {
|
message: {
|
||||||
role: "assistant",
|
role: "assistant",
|
||||||
content: [
|
content: [
|
||||||
{
|
{
|
||||||
type: "thinking",
|
type: "thinking",
|
||||||
thinking: "Calculating 18 * 27: 18 * 27 = 486",
|
thinking: "Calculating 18 * 27: 18 * 27 = 486",
|
||||||
thinkingSignature: '{"type":"reasoning","id":"rs_2b2342acdde","summary":[{"type":"summary_text","text":"Calculating 18 * 27: 18 * 27 = 486"}]}'
|
thinkingSignature:
|
||||||
},
|
'{"type":"reasoning","id":"rs_2b2342acdde","summary":[{"type":"summary_text","text":"Calculating 18 * 27: 18 * 27 = 486"}]}',
|
||||||
{
|
},
|
||||||
type: "text",
|
{
|
||||||
text: "The calculation of 18 × 27 gives us 486. The capital of Italy is Rome. Let me check Sydney's weather.",
|
type: "text",
|
||||||
textSignature: "msg_response_456"
|
text: "The calculation of 18 × 27 gives us 486. The capital of Italy is Rome. Let me check Sydney's weather.",
|
||||||
},
|
textSignature: "msg_response_456",
|
||||||
{
|
},
|
||||||
type: "toolCall",
|
{
|
||||||
id: "call_789_item_012", // Anthropic requires alphanumeric, dash, and underscore only
|
type: "toolCall",
|
||||||
name: "get_weather",
|
id: "call_789_item_012", // Anthropic requires alphanumeric, dash, and underscore only
|
||||||
arguments: { location: "Sydney" }
|
name: "get_weather",
|
||||||
}
|
arguments: { location: "Sydney" },
|
||||||
],
|
},
|
||||||
provider: "openai",
|
],
|
||||||
model: "gpt-5-mini",
|
provider: "openai",
|
||||||
usage: { input: 115, output: 58, cacheRead: 0, cacheWrite: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } },
|
model: "gpt-5-mini",
|
||||||
stopReason: "toolUse"
|
usage: {
|
||||||
} as AssistantMessage,
|
input: 115,
|
||||||
toolResult: {
|
output: 58,
|
||||||
role: "toolResult" as const,
|
cacheRead: 0,
|
||||||
toolCallId: "call_789_item_012", // Match the updated ID format
|
cacheWrite: 0,
|
||||||
toolName: "get_weather",
|
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||||
content: "Weather in Sydney: 25°C, clear",
|
},
|
||||||
isError: false
|
stopReason: "toolUse",
|
||||||
},
|
} as AssistantMessage,
|
||||||
facts: {
|
toolResult: {
|
||||||
calculation: 486,
|
role: "toolResult" as const,
|
||||||
city: "Sydney",
|
toolCallId: "call_789_item_012", // Match the updated ID format
|
||||||
temperature: 25,
|
toolName: "get_weather",
|
||||||
capital: "Rome"
|
content: "Weather in Sydney: 25°C, clear",
|
||||||
}
|
isError: false,
|
||||||
},
|
},
|
||||||
|
facts: {
|
||||||
|
calculation: 486,
|
||||||
|
city: "Sydney",
|
||||||
|
temperature: 25,
|
||||||
|
capital: "Rome",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
// Aborted message (stopReason: 'error')
|
// Aborted message (stopReason: 'error')
|
||||||
aborted: {
|
aborted: {
|
||||||
message: {
|
message: {
|
||||||
role: "assistant",
|
role: "assistant",
|
||||||
content: [
|
content: [
|
||||||
{
|
{
|
||||||
type: "thinking",
|
type: "thinking",
|
||||||
thinking: "Let me start calculating 20 * 30...",
|
thinking: "Let me start calculating 20 * 30...",
|
||||||
thinkingSignature: "partial_sig"
|
thinkingSignature: "partial_sig",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
type: "text",
|
type: "text",
|
||||||
text: "I was about to calculate 20 × 30 which is"
|
text: "I was about to calculate 20 × 30 which is",
|
||||||
}
|
},
|
||||||
],
|
],
|
||||||
provider: "test",
|
provider: "test",
|
||||||
model: "test-model",
|
model: "test-model",
|
||||||
usage: { input: 50, output: 25, cacheRead: 0, cacheWrite: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } },
|
usage: {
|
||||||
stopReason: "error",
|
input: 50,
|
||||||
error: "Request was aborted"
|
output: 25,
|
||||||
} as AssistantMessage,
|
cacheRead: 0,
|
||||||
toolResult: null,
|
cacheWrite: 0,
|
||||||
facts: {
|
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||||
calculation: 600,
|
},
|
||||||
city: "none",
|
stopReason: "error",
|
||||||
temperature: 0,
|
error: "Request was aborted",
|
||||||
capital: "none"
|
} as AssistantMessage,
|
||||||
}
|
toolResult: null,
|
||||||
}
|
facts: {
|
||||||
|
calculation: 600,
|
||||||
|
city: "none",
|
||||||
|
temperature: 0,
|
||||||
|
capital: "none",
|
||||||
|
},
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Test that a provider can handle contexts from different sources
|
* Test that a provider can handle contexts from different sources
|
||||||
*/
|
*/
|
||||||
async function testProviderHandoff(
|
async function testProviderHandoff<TApi extends Api>(
|
||||||
targetProvider: LLM<any>,
|
targetModel: Model<TApi>,
|
||||||
sourceLabel: string,
|
sourceLabel: string,
|
||||||
sourceContext: typeof providerContexts[keyof typeof providerContexts]
|
sourceContext: (typeof providerContexts)[keyof typeof providerContexts],
|
||||||
): Promise<boolean> {
|
): Promise<boolean> {
|
||||||
// Build conversation context
|
// Build conversation context
|
||||||
const messages: Message[] = [
|
const messages: Message[] = [
|
||||||
{
|
{
|
||||||
role: "user",
|
role: "user",
|
||||||
content: "Please do some calculations, tell me about capitals, and check the weather."
|
content: "Please do some calculations, tell me about capitals, and check the weather.",
|
||||||
},
|
},
|
||||||
sourceContext.message
|
sourceContext.message,
|
||||||
];
|
];
|
||||||
|
|
||||||
// Add tool result if present
|
// Add tool result if present
|
||||||
if (sourceContext.toolResult) {
|
if (sourceContext.toolResult) {
|
||||||
messages.push(sourceContext.toolResult);
|
messages.push(sourceContext.toolResult);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ask follow-up question
|
// Ask follow-up question
|
||||||
messages.push({
|
messages.push({
|
||||||
role: "user",
|
role: "user",
|
||||||
content: `Based on our conversation, please answer:
|
content: `Based on our conversation, please answer:
|
||||||
1) What was the multiplication result?
|
1) What was the multiplication result?
|
||||||
2) Which city's weather did we check?
|
2) Which city's weather did we check?
|
||||||
3) What was the temperature?
|
3) What was the temperature?
|
||||||
4) What capital city was mentioned?
|
4) What capital city was mentioned?
|
||||||
Please include the specific numbers and names.`
|
Please include the specific numbers and names.`,
|
||||||
});
|
});
|
||||||
|
|
||||||
const context: Context = {
|
const context: Context = {
|
||||||
messages,
|
messages,
|
||||||
tools: [weatherTool]
|
tools: [weatherTool],
|
||||||
};
|
};
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const response = await targetProvider.generate(context, {});
|
const response = await complete(targetModel, context, {});
|
||||||
|
|
||||||
// Check for error
|
// Check for error
|
||||||
if (response.stopReason === "error") {
|
if (response.stopReason === "error") {
|
||||||
console.log(`[${sourceLabel} → ${targetProvider.getModel().provider}] Failed with error: ${response.error}`);
|
console.log(`[${sourceLabel} → ${targetModel.provider}] Failed with error: ${response.error}`);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract text from response
|
// Extract text from response
|
||||||
const responseText = response.content
|
const responseText = response.content
|
||||||
.filter(b => b.type === "text")
|
.filter((b) => b.type === "text")
|
||||||
.map(b => b.text)
|
.map((b) => b.text)
|
||||||
.join(" ")
|
.join(" ")
|
||||||
.toLowerCase();
|
.toLowerCase();
|
||||||
|
|
||||||
// For aborted messages, we don't expect to find the facts
|
// For aborted messages, we don't expect to find the facts
|
||||||
if (sourceContext.message.stopReason === "error") {
|
if (sourceContext.message.stopReason === "error") {
|
||||||
const hasToolCalls = response.content.some(b => b.type === "toolCall");
|
const hasToolCalls = response.content.some((b) => b.type === "toolCall");
|
||||||
const hasThinking = response.content.some(b => b.type === "thinking");
|
const hasThinking = response.content.some((b) => b.type === "thinking");
|
||||||
const hasText = response.content.some(b => b.type === "text");
|
const hasText = response.content.some((b) => b.type === "text");
|
||||||
|
|
||||||
expect(response.stopReason === "stop" || response.stopReason === "toolUse").toBe(true);
|
expect(response.stopReason === "stop" || response.stopReason === "toolUse").toBe(true);
|
||||||
expect(hasThinking || hasText || hasToolCalls).toBe(true);
|
expect(hasThinking || hasText || hasToolCalls).toBe(true);
|
||||||
console.log(`[${sourceLabel} → ${targetProvider.getModel().provider}] Handled aborted message successfully, tool calls: ${hasToolCalls}, thinking: ${hasThinking}, text: ${hasText}`);
|
console.log(
|
||||||
return true;
|
`[${sourceLabel} → ${targetModel.provider}] Handled aborted message successfully, tool calls: ${hasToolCalls}, thinking: ${hasThinking}, text: ${hasText}`,
|
||||||
}
|
);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
// Check if response contains our facts
|
// Check if response contains our facts
|
||||||
const hasCalculation = responseText.includes(sourceContext.facts.calculation.toString());
|
const hasCalculation = responseText.includes(sourceContext.facts.calculation.toString());
|
||||||
const hasCity = sourceContext.facts.city !== "none" && responseText.includes(sourceContext.facts.city.toLowerCase());
|
const hasCity =
|
||||||
const hasTemperature = sourceContext.facts.temperature > 0 && responseText.includes(sourceContext.facts.temperature.toString());
|
sourceContext.facts.city !== "none" && responseText.includes(sourceContext.facts.city.toLowerCase());
|
||||||
const hasCapital = sourceContext.facts.capital !== "none" && responseText.includes(sourceContext.facts.capital.toLowerCase());
|
const hasTemperature =
|
||||||
|
sourceContext.facts.temperature > 0 && responseText.includes(sourceContext.facts.temperature.toString());
|
||||||
|
const hasCapital =
|
||||||
|
sourceContext.facts.capital !== "none" && responseText.includes(sourceContext.facts.capital.toLowerCase());
|
||||||
|
|
||||||
const success = hasCalculation && hasCity && hasTemperature && hasCapital;
|
const success = hasCalculation && hasCity && hasTemperature && hasCapital;
|
||||||
|
|
||||||
console.log(`[${sourceLabel} → ${targetProvider.getModel().provider}] Handoff test:`);
|
console.log(`[${sourceLabel} → ${targetModel.provider}] Handoff test:`);
|
||||||
if (!success) {
|
if (!success) {
|
||||||
console.log(` Calculation (${sourceContext.facts.calculation}): ${hasCalculation ? '✓' : '✗'}`);
|
console.log(` Calculation (${sourceContext.facts.calculation}): ${hasCalculation ? "✓" : "✗"}`);
|
||||||
console.log(` City (${sourceContext.facts.city}): ${hasCity ? '✓' : '✗'}`);
|
console.log(` City (${sourceContext.facts.city}): ${hasCity ? "✓" : "✗"}`);
|
||||||
console.log(` Temperature (${sourceContext.facts.temperature}): ${hasTemperature ? '✓' : '✗'}`);
|
console.log(` Temperature (${sourceContext.facts.temperature}): ${hasTemperature ? "✓" : "✗"}`);
|
||||||
console.log(` Capital (${sourceContext.facts.capital}): ${hasCapital ? '✓' : '✗'}`);
|
console.log(` Capital (${sourceContext.facts.capital}): ${hasCapital ? "✓" : "✗"}`);
|
||||||
} else {
|
} else {
|
||||||
console.log(` ✓ All facts found`);
|
console.log(` ✓ All facts found`);
|
||||||
}
|
}
|
||||||
|
|
||||||
return success;
|
return success;
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error(`[${sourceLabel} → ${targetProvider.getModel().provider}] Exception:`, error);
|
console.error(`[${sourceLabel} → ${targetModel.provider}] Exception:`, error);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
describe("Cross-Provider Handoff Tests", () => {
|
describe("Cross-Provider Handoff Tests", () => {
|
||||||
describe.skipIf(!process.env.ANTHROPIC_API_KEY)("Anthropic Provider Handoff", () => {
|
describe.skipIf(!process.env.ANTHROPIC_API_KEY)("Anthropic Provider Handoff", () => {
|
||||||
let provider: AnthropicLLM;
|
const model = getModel("anthropic", "claude-3-5-haiku-20241022");
|
||||||
|
|
||||||
beforeAll(() => {
|
it("should handle contexts from all providers", async () => {
|
||||||
const model = getModel("anthropic", "claude-3-5-haiku-20241022");
|
console.log("\nTesting Anthropic with pre-built contexts:\n");
|
||||||
if (model) {
|
|
||||||
provider = new AnthropicLLM(model, process.env.ANTHROPIC_API_KEY!);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
it("should handle contexts from all providers", async () => {
|
const contextTests = [
|
||||||
if (!provider) {
|
{ label: "Anthropic-style", context: providerContexts.anthropic, sourceModel: "claude-3-5-haiku-20241022" },
|
||||||
console.log("Anthropic provider not available, skipping");
|
{ label: "Google-style", context: providerContexts.google, sourceModel: "gemini-2.5-flash" },
|
||||||
return;
|
{ label: "OpenAI-Completions", context: providerContexts.openaiCompletions, sourceModel: "gpt-4o-mini" },
|
||||||
}
|
{ label: "OpenAI-Responses", context: providerContexts.openaiResponses, sourceModel: "gpt-5-mini" },
|
||||||
|
{ label: "Aborted", context: providerContexts.aborted, sourceModel: null },
|
||||||
|
];
|
||||||
|
|
||||||
console.log("\nTesting Anthropic with pre-built contexts:\n");
|
let successCount = 0;
|
||||||
|
let skippedCount = 0;
|
||||||
|
|
||||||
const contextTests = [
|
for (const { label, context, sourceModel } of contextTests) {
|
||||||
{ label: "Anthropic-style", context: providerContexts.anthropic, sourceModel: "claude-3-5-haiku-20241022" },
|
// Skip testing same model against itself
|
||||||
{ label: "Google-style", context: providerContexts.google, sourceModel: "gemini-2.5-flash" },
|
if (sourceModel && sourceModel === model.id) {
|
||||||
{ label: "OpenAI-Completions", context: providerContexts.openaiCompletions, sourceModel: "gpt-4o-mini" },
|
console.log(`[${label} → ${model.provider}] Skipping same-model test`);
|
||||||
{ label: "OpenAI-Responses", context: providerContexts.openaiResponses, sourceModel: "gpt-5-mini" },
|
skippedCount++;
|
||||||
{ label: "Aborted", context: providerContexts.aborted, sourceModel: null }
|
continue;
|
||||||
];
|
}
|
||||||
|
const success = await testProviderHandoff(model, label, context);
|
||||||
|
if (success) successCount++;
|
||||||
|
}
|
||||||
|
|
||||||
let successCount = 0;
|
const totalTests = contextTests.length - skippedCount;
|
||||||
let skippedCount = 0;
|
console.log(`\nAnthropic success rate: ${successCount}/${totalTests} (${skippedCount} skipped)\n`);
|
||||||
|
|
||||||
for (const { label, context, sourceModel } of contextTests) {
|
// All non-skipped handoffs should succeed
|
||||||
// Skip testing same model against itself
|
expect(successCount).toBe(totalTests);
|
||||||
if (sourceModel && sourceModel === provider.getModel().id) {
|
});
|
||||||
console.log(`[${label} → ${provider.getModel().provider}] Skipping same-model test`);
|
});
|
||||||
skippedCount++;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
const success = await testProviderHandoff(provider, label, context);
|
|
||||||
if (success) successCount++;
|
|
||||||
}
|
|
||||||
|
|
||||||
const totalTests = contextTests.length - skippedCount;
|
describe.skipIf(!process.env.GEMINI_API_KEY)("Google Provider Handoff", () => {
|
||||||
console.log(`\nAnthropic success rate: ${successCount}/${totalTests} (${skippedCount} skipped)\n`);
|
const model = getModel("google", "gemini-2.5-flash");
|
||||||
|
|
||||||
// All non-skipped handoffs should succeed
|
it("should handle contexts from all providers", async () => {
|
||||||
expect(successCount).toBe(totalTests);
|
console.log("\nTesting Google with pre-built contexts:\n");
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
describe.skipIf(!process.env.GEMINI_API_KEY)("Google Provider Handoff", () => {
|
const contextTests = [
|
||||||
let provider: GoogleLLM;
|
{ label: "Anthropic-style", context: providerContexts.anthropic, sourceModel: "claude-3-5-haiku-20241022" },
|
||||||
|
{ label: "Google-style", context: providerContexts.google, sourceModel: "gemini-2.5-flash" },
|
||||||
|
{ label: "OpenAI-Completions", context: providerContexts.openaiCompletions, sourceModel: "gpt-4o-mini" },
|
||||||
|
{ label: "OpenAI-Responses", context: providerContexts.openaiResponses, sourceModel: "gpt-5-mini" },
|
||||||
|
{ label: "Aborted", context: providerContexts.aborted, sourceModel: null },
|
||||||
|
];
|
||||||
|
|
||||||
beforeAll(() => {
|
let successCount = 0;
|
||||||
const model = getModel("google", "gemini-2.5-flash");
|
let skippedCount = 0;
|
||||||
if (model) {
|
|
||||||
provider = new GoogleLLM(model, process.env.GEMINI_API_KEY!);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
it("should handle contexts from all providers", async () => {
|
for (const { label, context, sourceModel } of contextTests) {
|
||||||
if (!provider) {
|
// Skip testing same model against itself
|
||||||
console.log("Google provider not available, skipping");
|
if (sourceModel && sourceModel === model.id) {
|
||||||
return;
|
console.log(`[${label} → ${model.provider}] Skipping same-model test`);
|
||||||
}
|
skippedCount++;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
const success = await testProviderHandoff(model, label, context);
|
||||||
|
if (success) successCount++;
|
||||||
|
}
|
||||||
|
|
||||||
console.log("\nTesting Google with pre-built contexts:\n");
|
const totalTests = contextTests.length - skippedCount;
|
||||||
|
console.log(`\nGoogle success rate: ${successCount}/${totalTests} (${skippedCount} skipped)\n`);
|
||||||
|
|
||||||
const contextTests = [
|
// All non-skipped handoffs should succeed
|
||||||
{ label: "Anthropic-style", context: providerContexts.anthropic, sourceModel: "claude-3-5-haiku-20241022" },
|
expect(successCount).toBe(totalTests);
|
||||||
{ label: "Google-style", context: providerContexts.google, sourceModel: "gemini-2.5-flash" },
|
});
|
||||||
{ label: "OpenAI-Completions", context: providerContexts.openaiCompletions, sourceModel: "gpt-4o-mini" },
|
});
|
||||||
{ label: "OpenAI-Responses", context: providerContexts.openaiResponses, sourceModel: "gpt-5-mini" },
|
|
||||||
{ label: "Aborted", context: providerContexts.aborted, sourceModel: null }
|
|
||||||
];
|
|
||||||
|
|
||||||
let successCount = 0;
|
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Completions Provider Handoff", () => {
|
||||||
let skippedCount = 0;
|
const model: Model<"openai-completions"> = { ...getModel("openai", "gpt-4o-mini"), api: "openai-completions" };
|
||||||
|
|
||||||
for (const { label, context, sourceModel } of contextTests) {
|
it("should handle contexts from all providers", async () => {
|
||||||
// Skip testing same model against itself
|
console.log("\nTesting OpenAI Completions with pre-built contexts:\n");
|
||||||
if (sourceModel && sourceModel === provider.getModel().id) {
|
|
||||||
console.log(`[${label} → ${provider.getModel().provider}] Skipping same-model test`);
|
|
||||||
skippedCount++;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
const success = await testProviderHandoff(provider, label, context);
|
|
||||||
if (success) successCount++;
|
|
||||||
}
|
|
||||||
|
|
||||||
const totalTests = contextTests.length - skippedCount;
|
const contextTests = [
|
||||||
console.log(`\nGoogle success rate: ${successCount}/${totalTests} (${skippedCount} skipped)\n`);
|
{ label: "Anthropic-style", context: providerContexts.anthropic, sourceModel: "claude-3-5-haiku-20241022" },
|
||||||
|
{ label: "Google-style", context: providerContexts.google, sourceModel: "gemini-2.5-flash" },
|
||||||
|
{ label: "OpenAI-Completions", context: providerContexts.openaiCompletions, sourceModel: "gpt-4o-mini" },
|
||||||
|
{ label: "OpenAI-Responses", context: providerContexts.openaiResponses, sourceModel: "gpt-5-mini" },
|
||||||
|
{ label: "Aborted", context: providerContexts.aborted, sourceModel: null },
|
||||||
|
];
|
||||||
|
|
||||||
// All non-skipped handoffs should succeed
|
let successCount = 0;
|
||||||
expect(successCount).toBe(totalTests);
|
let skippedCount = 0;
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Completions Provider Handoff", () => {
|
for (const { label, context, sourceModel } of contextTests) {
|
||||||
let provider: OpenAICompletionsLLM;
|
// Skip testing same model against itself
|
||||||
|
if (sourceModel && sourceModel === model.id) {
|
||||||
|
console.log(`[${label} → ${model.provider}] Skipping same-model test`);
|
||||||
|
skippedCount++;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
const success = await testProviderHandoff(model, label, context);
|
||||||
|
if (success) successCount++;
|
||||||
|
}
|
||||||
|
|
||||||
beforeAll(() => {
|
const totalTests = contextTests.length - skippedCount;
|
||||||
const model = getModel("openai", "gpt-4o-mini");
|
console.log(`\nOpenAI Completions success rate: ${successCount}/${totalTests} (${skippedCount} skipped)\n`);
|
||||||
if (model) {
|
|
||||||
provider = new OpenAICompletionsLLM(model, process.env.OPENAI_API_KEY!);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
it("should handle contexts from all providers", async () => {
|
// All non-skipped handoffs should succeed
|
||||||
if (!provider) {
|
expect(successCount).toBe(totalTests);
|
||||||
console.log("OpenAI Completions provider not available, skipping");
|
});
|
||||||
return;
|
});
|
||||||
}
|
|
||||||
|
|
||||||
console.log("\nTesting OpenAI Completions with pre-built contexts:\n");
|
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Responses Provider Handoff", () => {
|
||||||
|
const model = getModel("openai", "gpt-5-mini");
|
||||||
|
|
||||||
const contextTests = [
|
it("should handle contexts from all providers", async () => {
|
||||||
{ label: "Anthropic-style", context: providerContexts.anthropic, sourceModel: "claude-3-5-haiku-20241022" },
|
console.log("\nTesting OpenAI Responses with pre-built contexts:\n");
|
||||||
{ label: "Google-style", context: providerContexts.google, sourceModel: "gemini-2.5-flash" },
|
|
||||||
{ label: "OpenAI-Completions", context: providerContexts.openaiCompletions, sourceModel: "gpt-4o-mini" },
|
|
||||||
{ label: "OpenAI-Responses", context: providerContexts.openaiResponses, sourceModel: "gpt-5-mini" },
|
|
||||||
{ label: "Aborted", context: providerContexts.aborted, sourceModel: null }
|
|
||||||
];
|
|
||||||
|
|
||||||
let successCount = 0;
|
const contextTests = [
|
||||||
let skippedCount = 0;
|
{ label: "Anthropic-style", context: providerContexts.anthropic, sourceModel: "claude-3-5-haiku-20241022" },
|
||||||
|
{ label: "Google-style", context: providerContexts.google, sourceModel: "gemini-2.5-flash" },
|
||||||
|
{ label: "OpenAI-Completions", context: providerContexts.openaiCompletions, sourceModel: "gpt-4o-mini" },
|
||||||
|
{ label: "OpenAI-Responses", context: providerContexts.openaiResponses, sourceModel: "gpt-5-mini" },
|
||||||
|
{ label: "Aborted", context: providerContexts.aborted, sourceModel: null },
|
||||||
|
];
|
||||||
|
|
||||||
for (const { label, context, sourceModel } of contextTests) {
|
let successCount = 0;
|
||||||
// Skip testing same model against itself
|
let skippedCount = 0;
|
||||||
if (sourceModel && sourceModel === provider.getModel().id) {
|
|
||||||
console.log(`[${label} → ${provider.getModel().provider}] Skipping same-model test`);
|
|
||||||
skippedCount++;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
const success = await testProviderHandoff(provider, label, context);
|
|
||||||
if (success) successCount++;
|
|
||||||
}
|
|
||||||
|
|
||||||
const totalTests = contextTests.length - skippedCount;
|
for (const { label, context, sourceModel } of contextTests) {
|
||||||
console.log(`\nOpenAI Completions success rate: ${successCount}/${totalTests} (${skippedCount} skipped)\n`);
|
// Skip testing same model against itself
|
||||||
|
if (sourceModel && sourceModel === model.id) {
|
||||||
|
console.log(`[${label} → ${model.provider}] Skipping same-model test`);
|
||||||
|
skippedCount++;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
const success = await testProviderHandoff(model, label, context);
|
||||||
|
if (success) successCount++;
|
||||||
|
}
|
||||||
|
|
||||||
// All non-skipped handoffs should succeed
|
const totalTests = contextTests.length - skippedCount;
|
||||||
expect(successCount).toBe(totalTests);
|
console.log(`\nOpenAI Responses success rate: ${successCount}/${totalTests} (${skippedCount} skipped)\n`);
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Responses Provider Handoff", () => {
|
// All non-skipped handoffs should succeed
|
||||||
let provider: OpenAIResponsesLLM;
|
expect(successCount).toBe(totalTests);
|
||||||
|
});
|
||||||
beforeAll(() => {
|
});
|
||||||
const model = getModel("openai", "gpt-5-mini");
|
|
||||||
if (model) {
|
|
||||||
provider = new OpenAIResponsesLLM(model, process.env.OPENAI_API_KEY!);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
it("should handle contexts from all providers", async () => {
|
|
||||||
if (!provider) {
|
|
||||||
console.log("OpenAI Responses provider not available, skipping");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
console.log("\nTesting OpenAI Responses with pre-built contexts:\n");
|
|
||||||
|
|
||||||
const contextTests = [
|
|
||||||
{ label: "Anthropic-style", context: providerContexts.anthropic, sourceModel: "claude-3-5-haiku-20241022" },
|
|
||||||
{ label: "Google-style", context: providerContexts.google, sourceModel: "gemini-2.5-flash" },
|
|
||||||
{ label: "OpenAI-Completions", context: providerContexts.openaiCompletions, sourceModel: "gpt-4o-mini" },
|
|
||||||
{ label: "OpenAI-Responses", context: providerContexts.openaiResponses, sourceModel: "gpt-5-mini" },
|
|
||||||
{ label: "Aborted", context: providerContexts.aborted, sourceModel: null }
|
|
||||||
];
|
|
||||||
|
|
||||||
let successCount = 0;
|
|
||||||
let skippedCount = 0;
|
|
||||||
|
|
||||||
for (const { label, context, sourceModel } of contextTests) {
|
|
||||||
// Skip testing same model against itself
|
|
||||||
if (sourceModel && sourceModel === provider.getModel().id) {
|
|
||||||
console.log(`[${label} → ${provider.getModel().provider}] Skipping same-model test`);
|
|
||||||
skippedCount++;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
const success = await testProviderHandoff(provider, label, context);
|
|
||||||
if (success) successCount++;
|
|
||||||
}
|
|
||||||
|
|
||||||
const totalTests = contextTests.length - skippedCount;
|
|
||||||
console.log(`\nOpenAI Responses success rate: ${successCount}/${totalTests} (${skippedCount} skipped)\n`);
|
|
||||||
|
|
||||||
// All non-skipped handoffs should succeed
|
|
||||||
expect(successCount).toBe(totalTests);
|
|
||||||
});
|
|
||||||
});
|
|
||||||
});
|
});
|
||||||
|
|
@ -1,31 +0,0 @@
|
||||||
import { GoogleGenAI } from "@google/genai";
|
|
||||||
import OpenAI from "openai";
|
|
||||||
|
|
||||||
// Only referenced by the commented-out Gemini listing inside main(); kept so that snippet can be re-enabled unchanged.
const ai = new GoogleGenAI({});

/**
 * Debug helper: prints every model the configured OpenAI client can list.
 *
 * Each listed model is re-fetched through `models.retrieve()` and the retrieved record is written to stdout as
 * pretty-printed JSON followed by a `---` separator line.
 */
async function main() {
	/*let pager = await ai.models.list();
	do {
		for (const model of pager.page) {
			console.log(JSON.stringify(model, null, 2));
			console.log("---");
		}
		if (!pager.hasNextPage()) break;
		await pager.nextPage();
	} while (true);*/

	const openai = new OpenAI();

	// Iterate the list result directly so the SDK fetches follow-up pages as needed. The previous do/while kept
	// reading `response.data` of the first page and discarded the result of `getNextPage()`, so it could never
	// advance and looped forever whenever more than one page existed.
	for await (const model of openai.models.list()) {
		// Print the retrieved record: it was previously fetched and then ignored in favour of the list entry.
		const info = await openai.models.retrieve(model.id);
		console.log(JSON.stringify(info, null, 2));
		console.log("---");
	}
}

await main();
|
|
||||||
|
|
@ -1,618 +0,0 @@
|
||||||
import { describe, it, beforeAll, afterAll, expect } from "vitest";
|
|
||||||
import { GoogleLLM } from "../src/providers/google.js";
|
|
||||||
import { OpenAICompletionsLLM } from "../src/providers/openai-completions.js";
|
|
||||||
import { OpenAIResponsesLLM } from "../src/providers/openai-responses.js";
|
|
||||||
import { AnthropicLLM } from "../src/providers/anthropic.js";
|
|
||||||
import type { LLM, LLMOptions, Context, Tool, AssistantMessage, Model, ImageContent } from "../src/types.js";
|
|
||||||
import { spawn, ChildProcess, execSync } from "child_process";
|
|
||||||
import { createLLM, getModel } from "../src/models.js";
|
|
||||||
import { readFileSync } from "fs";
|
|
||||||
import { join, dirname } from "path";
|
|
||||||
import { fileURLToPath } from "url";
|
|
||||||
|
|
||||||
const __filename = fileURLToPath(import.meta.url);
|
|
||||||
const __dirname = dirname(__filename);
|
|
||||||
|
|
||||||
/**
 * Calculator tool shared by the tool-calling tests below (same definition as in the examples).
 *
 * The schema requires two numeric operands, `a` and `b`, and an `operation` restricted to
 * add / subtract / multiply / divide.
 */
const calculatorTool: Tool = {
	name: "calculator",
	description: "Perform basic arithmetic operations",
	parameters: {
		type: "object",
		properties: {
			a: { type: "number", description: "First number" },
			b: { type: "number", description: "Second number" },
			operation: {
				type: "string",
				enum: ["add", "subtract", "multiply", "divide"],
				description: "The operation to perform",
			},
		},
		required: ["a", "b", "operation"],
	},
};
|
|
||||||
|
|
||||||
/**
 * Two-turn plain-text smoke test.
 *
 * Asks the model to echo a fixed phrase, appends the reply plus a follow-up request to the same context and
 * asks again. Both replies must be assistant messages with content, non-zero input (including cache reads)
 * and output usage, no error, and text containing the requested phrase.
 *
 * @param llm Provider under test.
 */
async function basicTextGeneration<T extends LLMOptions>(llm: LLM<T>) {
	type Reply = Awaited<ReturnType<typeof llm.generate>>;

	// Expectations shared by both turns; only the phrase that must appear in the text differs.
	const expectReply = (reply: Reply, phrase: string) => {
		expect(reply.role).toBe("assistant");
		expect(reply.content).toBeTruthy();
		expect(reply.usage.input + reply.usage.cacheRead).toBeGreaterThan(0);
		expect(reply.usage.output).toBeGreaterThan(0);
		expect(reply.error).toBeFalsy();

		// Concatenate the text blocks; non-text blocks contribute nothing.
		const text = reply.content.map((block) => (block.type === "text" ? block.text : "")).join("");
		expect(text).toContain(phrase);
	};

	const context: Context = {
		systemPrompt: "You are a helpful assistant. Be concise.",
		messages: [{ role: "user", content: "Reply with exactly: 'Hello test successful'" }],
	};

	const firstReply = await llm.generate(context);
	expectReply(firstReply, "Hello test successful");

	// Second turn: the first reply becomes part of the conversation history.
	context.messages.push(firstReply, { role: "user", content: "Now say 'Goodbye test successful'" });

	const secondReply = await llm.generate(context);
	expectReply(secondReply, "Goodbye test successful");
}
|
|
||||||
|
|
||||||
/**
 * Checks that a request which explicitly asks for the calculator produces a tool call.
 *
 * The reply must stop with reason "toolUse" and contain a toolCall block named "calculator" with a
 * non-empty id.
 *
 * @param llm Provider under test.
 */
async function handleToolCall<T extends LLMOptions>(llm: LLM<T>) {
	const request: Context = {
		systemPrompt: "You are a helpful assistant that uses tools when asked.",
		messages: [
			{
				role: "user",
				content: "Calculate 15 + 27 using the calculator tool.",
			},
		],
		tools: [calculatorTool],
	};

	const reply = await llm.generate(request);
	expect(reply.stopReason).toBe("toolUse");

	const hasToolCall = reply.content.some((block) => block.type === "toolCall");
	expect(hasToolCall).toBeTruthy();

	// The assertion above guarantees that a toolCall block exists at this point.
	const toolCall = reply.content.find((block) => block.type === "toolCall")!;
	expect(toolCall.name).toBe("calculator");
	expect(toolCall.id).toBeTruthy();
}
|
|
||||||
|
|
||||||
/**
 * Checks that text is streamed through the `onEvent` callback.
 *
 * Records whether a text_start and a text_end event were seen and accumulates every text_delta, then
 * requires all three to have occurred and the final reply to contain at least one text block.
 *
 * @param llm Provider under test.
 */
async function handleStreaming<T extends LLMOptions>(llm: LLM<T>) {
	const seen = { start: false, end: false };
	let streamedText = "";

	const context: Context = {
		messages: [{ role: "user", content: "Count from 1 to 3" }],
	};

	const reply = await llm.generate(context, {
		onEvent: (event) => {
			switch (event.type) {
				case "text_start":
					seen.start = true;
					break;
				case "text_delta":
					streamedText += event.delta;
					break;
				case "text_end":
					seen.end = true;
					break;
			}
		},
	} as T);

	expect(seen.start).toBe(true);
	expect(streamedText.length).toBeGreaterThan(0);
	expect(seen.end).toBe(true);
	expect(reply.content.some((block) => block.type === "text")).toBeTruthy();
}
|
|
||||||
|
|
||||||
/**
 * Checks that thinking output is streamed and ends up in the final reply.
 *
 * The prompt uses a random operand, `(Math.random() * 255) | 0`, so it differs between runs. While
 * streaming, every thinking_delta must be a suffix of the event's accumulated `content`. Afterwards the
 * reply must have stopped normally, thinking start/delta/end events must all have been seen, and the reply
 * must contain a thinking block.
 *
 * @param llm Provider under test.
 * @param options Provider-specific options that enable thinking. They are spread after `onEvent`, so an
 *   `onEvent` supplied here would replace the recorder below.
 */
async function handleThinking<T extends LLMOptions>(llm: LLM<T>, options: T) {
	const seen = { start: false, end: false };
	let streamedThinking = "";

	const operand = (Math.random() * 255) | 0;
	const context: Context = {
		messages: [
			{
				role: "user",
				content: `Think about ${operand} + 27. Think step by step. Then output the result.`,
			},
		],
	};

	const reply = await llm.generate(context, {
		onEvent: (event) => {
			switch (event.type) {
				case "thinking_start":
					seen.start = true;
					break;
				case "thinking_delta":
					expect(event.content.endsWith(event.delta)).toBe(true);
					streamedThinking += event.delta;
					break;
				case "thinking_end":
					seen.end = true;
					break;
			}
		},
		...options,
	});

	// Surface the provider error in the assertion message when the stop reason is not "stop".
	expect(reply.stopReason, `Error: ${(reply as any).error}`).toBe("stop");
	expect(seen.start).toBe(true);
	expect(streamedThinking.length).toBeGreaterThan(0);
	expect(seen.end).toBe(true);
	expect(reply.content.some((block) => block.type === "thinking")).toBeTruthy();
}
|
|
||||||
|
|
||||||
/**
 * Checks that a model with image input can describe the bundled red-circle test image.
 *
 * Returns early (after logging) when the model's `input` list does not include "image". Otherwise it sends
 * `data/red-circle.png` (resolved relative to this test file) as base64 together with a question about
 * shape and color, and requires the reply text to mention "red" and "circle".
 *
 * @param llm Provider under test.
 */
async function handleImage<T extends LLMOptions>(llm: LLM<T>) {
	// Check if the model supports images
	const model = llm.getModel();
	if (!model.input.includes("image")) {
		console.log(`Skipping image test - model ${model.id} doesn't support images`);
		return;
	}

	// Read the test image
	const imagePath = join(__dirname, "data", "red-circle.png");
	const base64Image = readFileSync(imagePath).toString("base64");

	const imageContent: ImageContent = {
		type: "image",
		data: base64Image,
		mimeType: "image/png",
	};

	const context: Context = {
		messages: [
			{
				role: "user",
				content: [
					{
						type: "text",
						text: "What do you see in this image? Please describe the shape (circle, rectangle, square, triangle, ...) and color (red, blue, green, ...).",
					},
					imageContent,
				],
			},
		],
	};

	const response = await llm.generate(context);

	// Check the response mentions red and circle
	expect(response.content.length > 0).toBeTruthy();

	// Lower-case the text so replies such as "Red Circle" match; the variable was named `lowerContent`
	// but the text was previously compared as-is. All text blocks are joined rather than only the first.
	const lowerContent = response.content
		.map((block) => (block.type === "text" ? block.text : ""))
		.join("")
		.toLowerCase();
	expect(lowerContent).toContain("red");
	expect(lowerContent).toContain("circle");
}
|
|
||||||
|
|
||||||
/**
 * Multi-turn conversation test that combines thinking with calculator tool calls.
 *
 * Repeatedly calls the model (at most five turns), appends each reply to the context, executes every
 * calculator tool call locally and appends its result as a toolResult message. The loop ends on the first
 * reply whose stop reason is "stop". Afterwards at least one thinking block or tool call must have been
 * seen and the accumulated text must contain both expected results, 714 (42 * 17) and 887 (453 + 434).
 *
 * @param llm Provider under test.
 * @param thinkingOptions Provider-specific options passed to every `generate` call.
 */
async function multiTurn<T extends LLMOptions>(llm: LLM<T>, thinkingOptions: T) {
	const context: Context = {
		systemPrompt: "You are a helpful assistant that can use tools to answer questions.",
		messages: [
			{
				role: "user",
				content: "Think about this briefly, then calculate 42 * 17 and 453 + 434 using the calculator tool.",
			},
		],
		tools: [calculatorTool],
	};

	// Collect all text content from all assistant responses
	let allTextContent = "";
	let hasSeenThinking = false;
	let hasSeenToolCalls = false;
	const maxTurns = 5; // Prevent infinite loops

	for (let turn = 0; turn < maxTurns; turn++) {
		const response = await llm.generate(context, thinkingOptions);

		// Add the assistant response to context
		context.messages.push(response);

		// Process content blocks
		for (const block of response.content) {
			if (block.type === "text") {
				allTextContent += block.text;
			} else if (block.type === "thinking") {
				hasSeenThinking = true;
			} else if (block.type === "toolCall") {
				hasSeenToolCalls = true;

				// Process the tool call
				expect(block.name).toBe("calculator");
				expect(block.id).toBeTruthy();
				expect(block.arguments).toBeTruthy();

				const { a, b, operation } = block.arguments;
				let result: number;
				// Cover every operation the tool schema advertises. Subtract and divide previously fell
				// through to the default and reported 0, which would feed a wrong result back to the model.
				switch (operation) {
					case "add":
						result = a + b;
						break;
					case "subtract":
						result = a - b;
						break;
					case "multiply":
						result = a * b;
						break;
					case "divide":
						result = a / b;
						break;
					default:
						result = 0;
				}

				// Add tool result to context
				context.messages.push({
					role: "toolResult",
					toolCallId: block.id,
					toolName: block.name,
					content: `${result}`,
					isError: false,
				});
			}
		}

		// If we got a stop response with text content, we're likely done
		expect(response.stopReason).not.toBe("error");
		if (response.stopReason === "stop") {
			break;
		}
	}

	// Verify we got either thinking content or tool calls (or both)
	expect(hasSeenThinking || hasSeenToolCalls).toBe(true);

	// The accumulated text should reference both calculations
	expect(allTextContent).toBeTruthy();
	expect(allTextContent.includes("714")).toBe(true);
	expect(allTextContent.includes("887")).toBe(true);
}
|
|
||||||
|
|
||||||
describe("AI Providers E2E Tests", () => {
|
|
||||||
// Gemini suite: skipped unless GEMINI_API_KEY is set. Runs the shared helpers against gemini-2.5-flash,
// with thinking budgets of 1024 tokens (thinking test) and 2048 tokens (multi-turn test).
describe.skipIf(!process.env.GEMINI_API_KEY)("Gemini Provider (gemini-2.5-flash)", () => {
	let provider: GoogleLLM;

	beforeAll(() => {
		const model = getModel("google", "gemini-2.5-flash")!;
		provider = new GoogleLLM(model, process.env.GEMINI_API_KEY!);
	});

	it("should complete basic text generation", async () => {
		await basicTextGeneration(provider);
	});

	it("should handle tool calling", async () => {
		await handleToolCall(provider);
	});

	it("should handle streaming", async () => {
		await handleStreaming(provider);
	});

	it("should handle thinking mode", async () => {
		await handleThinking(provider, { thinking: { enabled: true, budgetTokens: 1024 } });
	});

	it("should handle multi-turn with thinking and tools", async () => {
		await multiTurn(provider, { thinking: { enabled: true, budgetTokens: 2048 } });
	});

	it("should handle image input", async () => {
		await handleImage(provider);
	});
});
|
|
||||||
|
|
||||||
// OpenAI Chat Completions suite: skipped unless OPENAI_API_KEY is set. Covers text, tool calling,
// streaming and image input against gpt-4o-mini; no thinking tests are run for this model.
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Completions Provider (gpt-4o-mini)", () => {
	let provider: OpenAICompletionsLLM;

	beforeAll(() => {
		const model = getModel("openai", "gpt-4o-mini")!;
		provider = new OpenAICompletionsLLM(model, process.env.OPENAI_API_KEY!);
	});

	it("should complete basic text generation", async () => {
		await basicTextGeneration(provider);
	});

	it("should handle tool calling", async () => {
		await handleToolCall(provider);
	});

	it("should handle streaming", async () => {
		await handleStreaming(provider);
	});

	it("should handle image input", async () => {
		await handleImage(provider);
	});
});
|
|
||||||
|
|
||||||
// OpenAI Responses suite: skipped unless OPENAI_API_KEY is set. Runs all shared helpers against gpt-5-mini
// with reasoningEffort "high"; the thinking test is allowed two retries.
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Responses Provider (gpt-5-mini)", () => {
	let provider: OpenAIResponsesLLM;

	beforeAll(() => {
		const model = getModel("openai", "gpt-5-mini")!;
		provider = new OpenAIResponsesLLM(model, process.env.OPENAI_API_KEY!);
	});

	it("should complete basic text generation", async () => {
		await basicTextGeneration(provider);
	});

	it("should handle tool calling", async () => {
		await handleToolCall(provider);
	});

	it("should handle streaming", async () => {
		await handleStreaming(provider);
	});

	it("should handle thinking mode", { retry: 2 }, async () => {
		await handleThinking(provider, { reasoningEffort: "high" });
	});

	it("should handle multi-turn with thinking and tools", async () => {
		await multiTurn(provider, { reasoningEffort: "high" });
	});

	it("should handle image input", async () => {
		await handleImage(provider);
	});
});
|
|
||||||
|
|
||||||
// Anthropic suite: skipped unless ANTHROPIC_OAUTH_TOKEN is set. Runs all shared helpers against
// claude-sonnet-4-20250514; only the multi-turn test passes an explicit thinking budget (2048 tokens).
describe.skipIf(!process.env.ANTHROPIC_OAUTH_TOKEN)("Anthropic Provider (claude-sonnet-4-20250514)", () => {
	let provider: AnthropicLLM;

	beforeAll(() => {
		const model = getModel("anthropic", "claude-sonnet-4-20250514")!;
		provider = new AnthropicLLM(model, process.env.ANTHROPIC_OAUTH_TOKEN!);
	});

	it("should complete basic text generation", async () => {
		await basicTextGeneration(provider);
	});

	it("should handle tool calling", async () => {
		await handleToolCall(provider);
	});

	it("should handle streaming", async () => {
		await handleStreaming(provider);
	});

	it("should handle thinking mode", async () => {
		await handleThinking(provider, { thinking: { enabled: true } });
	});

	it("should handle multi-turn with thinking and tools", async () => {
		await multiTurn(provider, { thinking: { enabled: true, budgetTokens: 2048 } });
	});

	it("should handle image input", async () => {
		await handleImage(provider);
	});
});
|
|
||||||
|
|
||||||
// xAI suite: skipped unless XAI_API_KEY is set. Drives grok-code-fast-1 through the OpenAI Completions
// provider with reasoningEffort "medium"; no image test is run.
describe.skipIf(!process.env.XAI_API_KEY)("xAI Provider (grok-code-fast-1 via OpenAI Completions)", () => {
	let provider: OpenAICompletionsLLM;

	beforeAll(() => {
		const model = getModel("xai", "grok-code-fast-1")!;
		provider = new OpenAICompletionsLLM(model, process.env.XAI_API_KEY!);
	});

	it("should complete basic text generation", async () => {
		await basicTextGeneration(provider);
	});

	it("should handle tool calling", async () => {
		await handleToolCall(provider);
	});

	it("should handle streaming", async () => {
		await handleStreaming(provider);
	});

	it("should handle thinking mode", async () => {
		await handleThinking(provider, { reasoningEffort: "medium" });
	});

	it("should handle multi-turn with thinking and tools", async () => {
		await multiTurn(provider, { reasoningEffort: "medium" });
	});
});
|
|
||||||
|
|
||||||
// Groq suite: skipped unless GROQ_API_KEY is set. Drives openai/gpt-oss-20b through the OpenAI Completions
// provider with reasoningEffort "medium"; no image test is run.
describe.skipIf(!process.env.GROQ_API_KEY)("Groq Provider (gpt-oss-20b via OpenAI Completions)", () => {
	let provider: OpenAICompletionsLLM;

	beforeAll(() => {
		const model = getModel("groq", "openai/gpt-oss-20b")!;
		provider = new OpenAICompletionsLLM(model, process.env.GROQ_API_KEY!);
	});

	it("should complete basic text generation", async () => {
		await basicTextGeneration(provider);
	});

	it("should handle tool calling", async () => {
		await handleToolCall(provider);
	});

	it("should handle streaming", async () => {
		await handleStreaming(provider);
	});

	it("should handle thinking mode", async () => {
		await handleThinking(provider, { reasoningEffort: "medium" });
	});

	it("should handle multi-turn with thinking and tools", async () => {
		await multiTurn(provider, { reasoningEffort: "medium" });
	});
});
|
|
||||||
|
|
||||||
// Cerebras suite: skipped unless CEREBRAS_API_KEY is set. Drives gpt-oss-120b through the OpenAI
// Completions provider with reasoningEffort "medium"; no image test is run.
describe.skipIf(!process.env.CEREBRAS_API_KEY)("Cerebras Provider (gpt-oss-120b via OpenAI Completions)", () => {
	let provider: OpenAICompletionsLLM;

	beforeAll(() => {
		const model = getModel("cerebras", "gpt-oss-120b")!;
		provider = new OpenAICompletionsLLM(model, process.env.CEREBRAS_API_KEY!);
	});

	it("should complete basic text generation", async () => {
		await basicTextGeneration(provider);
	});

	it("should handle tool calling", async () => {
		await handleToolCall(provider);
	});

	it("should handle streaming", async () => {
		await handleStreaming(provider);
	});

	it("should handle thinking mode", async () => {
		await handleThinking(provider, { reasoningEffort: "medium" });
	});

	it("should handle multi-turn with thinking and tools", async () => {
		await multiTurn(provider, { reasoningEffort: "medium" });
	});
});
|
|
||||||
|
|
||||||
// OpenRouter suite: skipped unless OPENROUTER_API_KEY is set. Drives z-ai/glm-4.5v through the OpenAI
// Completions provider with reasoningEffort "medium"; the multi-turn test is allowed two retries.
describe.skipIf(!process.env.OPENROUTER_API_KEY)("OpenRouter Provider (glm-4.5v via OpenAI Completions)", () => {
	let provider: OpenAICompletionsLLM;

	beforeAll(() => {
		const model = getModel("openrouter", "z-ai/glm-4.5v")!;
		provider = new OpenAICompletionsLLM(model, process.env.OPENROUTER_API_KEY!);
	});

	it("should complete basic text generation", async () => {
		await basicTextGeneration(provider);
	});

	it("should handle tool calling", async () => {
		await handleToolCall(provider);
	});

	it("should handle streaming", async () => {
		await handleStreaming(provider);
	});

	it("should handle thinking mode", async () => {
		await handleThinking(provider, { reasoningEffort: "medium" });
	});

	it("should handle multi-turn with thinking and tools", { retry: 2 }, async () => {
		await multiTurn(provider, { reasoningEffort: "medium" });
	});

	it("should handle image input", async () => {
		await handleImage(provider);
	});
});
|
|
||||||
|
|
||||||
// Probe for the ollama CLI by running `which ollama`; execSync throws when the command exits non-zero,
// which is treated as "not installed".
const ollamaInstalled = (() => {
	try {
		execSync("which ollama", { stdio: "ignore" });
		return true;
	} catch {
		return false;
	}
})();
|
|
||||||
|
|
||||||
// Ollama suite: skipped when the ollama CLI was not found. Setup pulls gpt-oss:20b if `ollama list` does
// not show it, starts `ollama serve`, waits until the local HTTP API answers, and points the OpenAI
// Completions provider at http://localhost:11434/v1. Teardown sends SIGTERM to the spawned server.
describe.skipIf(!ollamaInstalled)("Ollama Provider (gpt-oss-20b via OpenAI Completions)", () => {
	let ollamaLlm: OpenAICompletionsLLM;
	let ollamaProcess: ChildProcess | null = null;

	const sleep = (ms: number) => new Promise<void>((resolve) => setTimeout(resolve, ms));

	beforeAll(async () => {
		// Check if model is available, if not pull it
		try {
			execSync("ollama list | grep -q 'gpt-oss:20b'", { stdio: "ignore" });
		} catch {
			console.log("Pulling gpt-oss:20b model for Ollama tests...");
			try {
				execSync("ollama pull gpt-oss:20b", { stdio: "inherit" });
			} catch {
				console.warn("Failed to pull gpt-oss:20b model, tests will be skipped");
				// NOTE(review): returning here leaves the provider unassigned, but nothing in this suite
				// skips the tests below in that case — confirm the intended behavior.
				return;
			}
		}

		// Start ollama server
		ollamaProcess = spawn("ollama", ["serve"], {
			detached: false,
			stdio: "ignore",
		});

		// Wait for server to be ready: one second of initial delay, then poll every 500ms until
		// /api/tags answers with an OK status. The hook timeout below bounds the wait.
		await sleep(1000);
		for (;;) {
			try {
				const response = await fetch("http://localhost:11434/api/tags");
				if (response.ok) {
					break;
				}
			} catch {
				// Request failed; fall through to the delay and try again.
			}
			await sleep(500);
		}

		const model: Model = {
			id: "gpt-oss:20b",
			provider: "ollama",
			baseUrl: "http://localhost:11434/v1",
			reasoning: true,
			input: ["text"],
			contextWindow: 128000,
			maxTokens: 16000,
			cost: {
				input: 0,
				output: 0,
				cacheRead: 0,
				cacheWrite: 0,
			},
			name: "Ollama GPT-OSS 20B",
		};
		ollamaLlm = new OpenAICompletionsLLM(model, "dummy");
	}, 30000); // 30 second timeout for setup

	afterAll(() => {
		// Kill ollama server
		if (ollamaProcess) {
			ollamaProcess.kill("SIGTERM");
			ollamaProcess = null;
		}
	});

	it("should complete basic text generation", async () => {
		await basicTextGeneration(ollamaLlm);
	});

	it("should handle tool calling", async () => {
		await handleToolCall(ollamaLlm);
	});

	it("should handle streaming", async () => {
		await handleStreaming(ollamaLlm);
	});

	it("should handle thinking mode", async () => {
		await handleThinking(ollamaLlm, { reasoningEffort: "medium" });
	});

	it("should handle multi-turn with thinking and tools", async () => {
		await multiTurn(ollamaLlm, { reasoningEffort: "medium" });
	});
});
|
|
||||||
|
|
||||||
/*
|
|
||||||
describe.skipIf(!process.env.ANTHROPIC_API_KEY)("Anthropic Provider (Haiku 3.5)", () => {
|
|
||||||
let llm: AnthropicLLM;
|
|
||||||
|
|
||||||
beforeAll(() => {
|
|
||||||
llm = createLLM("anthropic", "claude-3-5-haiku-latest");
|
|
||||||
});
|
|
||||||
|
|
||||||
it("should complete basic text generation", async () => {
|
|
||||||
await basicTextGeneration(llm);
|
|
||||||
});
|
|
||||||
|
|
||||||
it("should handle tool calling", async () => {
|
|
||||||
await handleToolCall(llm);
|
|
||||||
});
|
|
||||||
|
|
||||||
it("should handle streaming", async () => {
|
|
||||||
await handleStreaming(llm);
|
|
||||||
});
|
|
||||||
|
|
||||||
it("should handle multi-turn with thinking and tools", async () => {
|
|
||||||
await multiTurn(llm, {thinking: {enabled: true}});
|
|
||||||
});
|
|
||||||
|
|
||||||
it("should handle image input", async () => {
|
|
||||||
await handleImage(llm);
|
|
||||||
});
|
|
||||||
});
|
|
||||||
*/
|
|
||||||
});
|
|
||||||
|
|
@ -1,15 +1,6 @@
|
||||||
#!/usr/bin/env npx tsx
|
#!/usr/bin/env npx tsx
|
||||||
import {
|
|
||||||
Container,
|
|
||||||
LoadingAnimation,
|
|
||||||
TextComponent,
|
|
||||||
TextEditor,
|
|
||||||
TUI,
|
|
||||||
WhitespaceComponent,
|
|
||||||
} from "../src/index.js";
|
|
||||||
import chalk from "chalk";
|
import chalk from "chalk";
|
||||||
|
import { Container, LoadingAnimation, TextComponent, TextEditor, TUI, WhitespaceComponent } from "../src/index.js";
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Test the new smart double-buffered TUI implementation
|
* Test the new smart double-buffered TUI implementation
|
||||||
|
|
@ -24,7 +15,7 @@ async function main() {
|
||||||
|
|
||||||
// Monkey-patch requestRender to measure performance
|
// Monkey-patch requestRender to measure performance
|
||||||
const originalRequestRender = ui.requestRender.bind(ui);
|
const originalRequestRender = ui.requestRender.bind(ui);
|
||||||
ui.requestRender = function() {
|
ui.requestRender = () => {
|
||||||
const startTime = process.hrtime.bigint();
|
const startTime = process.hrtime.bigint();
|
||||||
originalRequestRender();
|
originalRequestRender();
|
||||||
process.nextTick(() => {
|
process.nextTick(() => {
|
||||||
|
|
@ -38,10 +29,12 @@ async function main() {
|
||||||
|
|
||||||
// Add header
|
// Add header
|
||||||
const header = new TextComponent(
|
const header = new TextComponent(
|
||||||
chalk.bold.green("Smart Double Buffer TUI Test") + "\n" +
|
chalk.bold.green("Smart Double Buffer TUI Test") +
|
||||||
chalk.dim("Testing new implementation with component-level caching and smart diffing") + "\n" +
|
"\n" +
|
||||||
chalk.dim("Press CTRL+C to exit"),
|
chalk.dim("Testing new implementation with component-level caching and smart diffing") +
|
||||||
{ bottom: 1 }
|
"\n" +
|
||||||
|
chalk.dim("Press CTRL+C to exit"),
|
||||||
|
{ bottom: 1 },
|
||||||
);
|
);
|
||||||
ui.addChild(header);
|
ui.addChild(header);
|
||||||
|
|
||||||
|
|
@ -57,7 +50,9 @@ async function main() {
|
||||||
|
|
||||||
// Add text editor
|
// Add text editor
|
||||||
const editor = new TextEditor();
|
const editor = new TextEditor();
|
||||||
editor.setText("Type here to test the text editor.\n\nWith smart diffing, only changed lines are redrawn!\n\nThe animation above updates every 80ms but the editor stays perfectly still.");
|
editor.setText(
|
||||||
|
"Type here to test the text editor.\n\nWith smart diffing, only changed lines are redrawn!\n\nThe animation above updates every 80ms but the editor stays perfectly still.",
|
||||||
|
);
|
||||||
container.addChild(editor);
|
container.addChild(editor);
|
||||||
|
|
||||||
// Add the container to UI
|
// Add the container to UI
|
||||||
|
|
@ -71,15 +66,20 @@ async function main() {
|
||||||
const statsInterval = setInterval(() => {
|
const statsInterval = setInterval(() => {
|
||||||
if (renderCount > 0) {
|
if (renderCount > 0) {
|
||||||
const avgRenderTime = Number(totalRenderTime / BigInt(renderCount)) / 1_000_000; // Convert to ms
|
const avgRenderTime = Number(totalRenderTime / BigInt(renderCount)) / 1_000_000; // Convert to ms
|
||||||
const lastRenderTime = renderTimings.length > 0
|
const lastRenderTime =
|
||||||
? Number(renderTimings[renderTimings.length - 1]) / 1_000_000
|
renderTimings.length > 0 ? Number(renderTimings[renderTimings.length - 1]) / 1_000_000 : 0;
|
||||||
: 0;
|
|
||||||
const avgLinesRedrawn = ui.getAverageLinesRedrawn();
|
const avgLinesRedrawn = ui.getAverageLinesRedrawn();
|
||||||
|
|
||||||
statsComponent.setText(
|
statsComponent.setText(
|
||||||
chalk.yellow(`Performance Stats:`) + "\n" +
|
chalk.yellow(`Performance Stats:`) +
|
||||||
chalk.dim(`Renders: ${renderCount} | Avg Time: ${avgRenderTime.toFixed(2)}ms | Last: ${lastRenderTime.toFixed(2)}ms`) + "\n" +
|
"\n" +
|
||||||
chalk.dim(`Lines Redrawn: ${ui.getLinesRedrawn()} total | Avg per render: ${avgLinesRedrawn.toFixed(1)}`)
|
chalk.dim(
|
||||||
|
`Renders: ${renderCount} | Avg Time: ${avgRenderTime.toFixed(2)}ms | Last: ${lastRenderTime.toFixed(2)}ms`,
|
||||||
|
) +
|
||||||
|
"\n" +
|
||||||
|
chalk.dim(
|
||||||
|
`Lines Redrawn: ${ui.getLinesRedrawn()} total | Avg per render: ${avgLinesRedrawn.toFixed(1)}`,
|
||||||
|
),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}, 1000);
|
}, 1000);
|
||||||
|
|
@ -96,7 +96,11 @@ async function main() {
|
||||||
ui.stop();
|
ui.stop();
|
||||||
console.log("\n" + chalk.green("Exited double-buffer test"));
|
console.log("\n" + chalk.green("Exited double-buffer test"));
|
||||||
console.log(chalk.dim(`Total renders: ${renderCount}`));
|
console.log(chalk.dim(`Total renders: ${renderCount}`));
|
||||||
console.log(chalk.dim(`Average render time: ${renderCount > 0 ? (Number(totalRenderTime / BigInt(renderCount)) / 1_000_000).toFixed(2) : 0}ms`));
|
console.log(
|
||||||
|
chalk.dim(
|
||||||
|
`Average render time: ${renderCount > 0 ? (Number(totalRenderTime / BigInt(renderCount)) / 1_000_000).toFixed(2) : 0}ms`,
|
||||||
|
),
|
||||||
|
);
|
||||||
console.log(chalk.dim(`Total lines redrawn: ${ui.getLinesRedrawn()}`));
|
console.log(chalk.dim(`Total lines redrawn: ${ui.getLinesRedrawn()}`));
|
||||||
console.log(chalk.dim(`Average lines redrawn per render: ${ui.getAverageLinesRedrawn().toFixed(1)}`));
|
console.log(chalk.dim(`Average lines redrawn per render: ${ui.getAverageLinesRedrawn().toFixed(1)}`));
|
||||||
process.exit(0);
|
process.exit(0);
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,12 @@
|
||||||
#!/usr/bin/env npx tsx
|
#!/usr/bin/env npx tsx
|
||||||
import { TUI, Container, TextEditor, TextComponent, MarkdownComponent, CombinedAutocompleteProvider } from "../src/index.js";
|
import {
|
||||||
|
CombinedAutocompleteProvider,
|
||||||
|
Container,
|
||||||
|
MarkdownComponent,
|
||||||
|
TextComponent,
|
||||||
|
TextEditor,
|
||||||
|
TUI,
|
||||||
|
} from "../src/index.js";
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Chat Application with Autocomplete
|
* Chat Application with Autocomplete
|
||||||
|
|
@ -16,7 +23,7 @@ const ui = new TUI();
|
||||||
// Add header with instructions
|
// Add header with instructions
|
||||||
const header = new TextComponent(
|
const header = new TextComponent(
|
||||||
"💬 Chat Demo | Type '/' for commands | Start typing a filename + Tab to autocomplete | Ctrl+C to exit",
|
"💬 Chat Demo | Type '/' for commands | Start typing a filename + Tab to autocomplete | Ctrl+C to exit",
|
||||||
{ bottom: 1 }
|
{ bottom: 1 },
|
||||||
);
|
);
|
||||||
|
|
||||||
const chatHistory = new Container();
|
const chatHistory = new Container();
|
||||||
|
|
@ -82,7 +89,8 @@ ui.onGlobalKeyPress = (data: string) => {
|
||||||
};
|
};
|
||||||
|
|
||||||
// Add initial welcome message to chat history
|
// Add initial welcome message to chat history
|
||||||
chatHistory.addChild(new MarkdownComponent(`
|
chatHistory.addChild(
|
||||||
|
new MarkdownComponent(`
|
||||||
## Welcome to the Chat Demo!
|
## Welcome to the Chat Demo!
|
||||||
|
|
||||||
**Available slash commands:**
|
**Available slash commands:**
|
||||||
|
|
@ -96,7 +104,8 @@ chatHistory.addChild(new MarkdownComponent(`
|
||||||
- Works with home directory (\`~/\`)
|
- Works with home directory (\`~/\`)
|
||||||
|
|
||||||
Try it out! Type a message or command below.
|
Try it out! Type a message or command below.
|
||||||
`));
|
`),
|
||||||
|
);
|
||||||
|
|
||||||
ui.addChild(header);
|
ui.addChild(header);
|
||||||
ui.addChild(chatHistory);
|
ui.addChild(chatHistory);
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import { test, describe } from "node:test";
|
|
||||||
import assert from "node:assert";
|
import assert from "node:assert";
|
||||||
|
import { describe, test } from "node:test";
|
||||||
|
import { Container, TextComponent, TextEditor, TUI } from "../src/index.js";
|
||||||
import { VirtualTerminal } from "./virtual-terminal.js";
|
import { VirtualTerminal } from "./virtual-terminal.js";
|
||||||
import { TUI, Container, TextComponent, TextEditor } from "../src/index.js";
|
|
||||||
|
|
||||||
describe("Differential Rendering - Dynamic Content", () => {
|
describe("Differential Rendering - Dynamic Content", () => {
|
||||||
test("handles static text, dynamic container, and text editor correctly", async () => {
|
test("handles static text, dynamic container, and text editor correctly", async () => {
|
||||||
|
|
@ -23,7 +23,7 @@ describe("Differential Rendering - Dynamic Content", () => {
|
||||||
ui.setFocus(editor);
|
ui.setFocus(editor);
|
||||||
|
|
||||||
// Wait for next tick to complete and flush virtual terminal
|
// Wait for next tick to complete and flush virtual terminal
|
||||||
await new Promise(resolve => process.nextTick(resolve));
|
await new Promise((resolve) => process.nextTick(resolve));
|
||||||
await terminal.flush();
|
await terminal.flush();
|
||||||
|
|
||||||
// Step 4: Check initial output in scrollbuffer
|
// Step 4: Check initial output in scrollbuffer
|
||||||
|
|
@ -35,12 +35,14 @@ describe("Differential Rendering - Dynamic Content", () => {
|
||||||
console.log("ScrollBuffer lines:", scrollBuffer.length);
|
console.log("ScrollBuffer lines:", scrollBuffer.length);
|
||||||
|
|
||||||
// Count non-empty lines in scrollbuffer
|
// Count non-empty lines in scrollbuffer
|
||||||
let nonEmptyInBuffer = scrollBuffer.filter(line => line.trim() !== "").length;
|
const nonEmptyInBuffer = scrollBuffer.filter((line) => line.trim() !== "").length;
|
||||||
console.log("Non-empty lines in scrollbuffer:", nonEmptyInBuffer);
|
console.log("Non-empty lines in scrollbuffer:", nonEmptyInBuffer);
|
||||||
|
|
||||||
// Verify initial render has static text in scrollbuffer
|
// Verify initial render has static text in scrollbuffer
|
||||||
assert.ok(scrollBuffer.some(line => line.includes("Static Header Text")),
|
assert.ok(
|
||||||
`Expected static text in scrollbuffer`);
|
scrollBuffer.some((line) => line.includes("Static Header Text")),
|
||||||
|
`Expected static text in scrollbuffer`,
|
||||||
|
);
|
||||||
|
|
||||||
// Step 5: Add 100 text components to container
|
// Step 5: Add 100 text components to container
|
||||||
console.log("\nAdding 100 components to container...");
|
console.log("\nAdding 100 components to container...");
|
||||||
|
|
@ -52,7 +54,7 @@ describe("Differential Rendering - Dynamic Content", () => {
|
||||||
ui.requestRender();
|
ui.requestRender();
|
||||||
|
|
||||||
// Wait for next tick to complete and flush
|
// Wait for next tick to complete and flush
|
||||||
await new Promise(resolve => process.nextTick(resolve));
|
await new Promise((resolve) => process.nextTick(resolve));
|
||||||
await terminal.flush();
|
await terminal.flush();
|
||||||
|
|
||||||
// Step 6: Check output after adding 100 components
|
// Step 6: Check output after adding 100 components
|
||||||
|
|
@ -65,7 +67,7 @@ describe("Differential Rendering - Dynamic Content", () => {
|
||||||
|
|
||||||
// Count all dynamic items in scrollbuffer
|
// Count all dynamic items in scrollbuffer
|
||||||
let dynamicItemsInBuffer = 0;
|
let dynamicItemsInBuffer = 0;
|
||||||
let allItemNumbers = new Set<number>();
|
const allItemNumbers = new Set<number>();
|
||||||
for (const line of scrollBuffer) {
|
for (const line of scrollBuffer) {
|
||||||
const match = line.match(/Dynamic Item (\d+)/);
|
const match = line.match(/Dynamic Item (\d+)/);
|
||||||
if (match) {
|
if (match) {
|
||||||
|
|
@ -80,8 +82,11 @@ describe("Differential Rendering - Dynamic Content", () => {
|
||||||
|
|
||||||
// CRITICAL TEST: The scrollbuffer should contain ALL 100 items
|
// CRITICAL TEST: The scrollbuffer should contain ALL 100 items
|
||||||
// This is what the differential render should preserve!
|
// This is what the differential render should preserve!
|
||||||
assert.strictEqual(allItemNumbers.size, 100,
|
assert.strictEqual(
|
||||||
`Expected all 100 unique items in scrollbuffer, but found ${allItemNumbers.size}`);
|
allItemNumbers.size,
|
||||||
|
100,
|
||||||
|
`Expected all 100 unique items in scrollbuffer, but found ${allItemNumbers.size}`,
|
||||||
|
);
|
||||||
|
|
||||||
// Verify items are 1-100
|
// Verify items are 1-100
|
||||||
for (let i = 1; i <= 100; i++) {
|
for (let i = 1; i <= 100; i++) {
|
||||||
|
|
@ -89,15 +94,20 @@ describe("Differential Rendering - Dynamic Content", () => {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Also verify the static header is still in scrollbuffer
|
// Also verify the static header is still in scrollbuffer
|
||||||
assert.ok(scrollBuffer.some(line => line.includes("Static Header Text")),
|
assert.ok(
|
||||||
"Static header should still be in scrollbuffer");
|
scrollBuffer.some((line) => line.includes("Static Header Text")),
|
||||||
|
"Static header should still be in scrollbuffer",
|
||||||
|
);
|
||||||
|
|
||||||
// And the editor should be there too
|
// And the editor should be there too
|
||||||
assert.ok(scrollBuffer.some(line => line.includes("╭") && line.includes("╮")),
|
assert.ok(
|
||||||
"Editor top border should be in scrollbuffer");
|
scrollBuffer.some((line) => line.includes("╭") && line.includes("╮")),
|
||||||
assert.ok(scrollBuffer.some(line => line.includes("╰") && line.includes("╯")),
|
"Editor top border should be in scrollbuffer",
|
||||||
"Editor bottom border should be in scrollbuffer");
|
);
|
||||||
|
assert.ok(
|
||||||
|
scrollBuffer.some((line) => line.includes("╰") && line.includes("╯")),
|
||||||
|
"Editor bottom border should be in scrollbuffer",
|
||||||
|
);
|
||||||
|
|
||||||
ui.stop();
|
ui.stop();
|
||||||
});
|
});
|
||||||
|
|
@ -124,7 +134,7 @@ describe("Differential Rendering - Dynamic Content", () => {
|
||||||
contentContainer.addChild(new TextComponent("Content Line 2"));
|
contentContainer.addChild(new TextComponent("Content Line 2"));
|
||||||
|
|
||||||
// Initial render
|
// Initial render
|
||||||
await new Promise(resolve => process.nextTick(resolve));
|
await new Promise((resolve) => process.nextTick(resolve));
|
||||||
await terminal.flush();
|
await terminal.flush();
|
||||||
|
|
||||||
let viewport = terminal.getViewport();
|
let viewport = terminal.getViewport();
|
||||||
|
|
@ -142,7 +152,7 @@ describe("Differential Rendering - Dynamic Content", () => {
|
||||||
statusContainer.addChild(new TextComponent("Status: Processing..."));
|
statusContainer.addChild(new TextComponent("Status: Processing..."));
|
||||||
ui.requestRender();
|
ui.requestRender();
|
||||||
|
|
||||||
await new Promise(resolve => process.nextTick(resolve));
|
await new Promise((resolve) => process.nextTick(resolve));
|
||||||
await terminal.flush();
|
await terminal.flush();
|
||||||
|
|
||||||
viewport = terminal.getViewport();
|
viewport = terminal.getViewport();
|
||||||
|
|
@ -162,7 +172,7 @@ describe("Differential Rendering - Dynamic Content", () => {
|
||||||
}
|
}
|
||||||
ui.requestRender();
|
ui.requestRender();
|
||||||
|
|
||||||
await new Promise(resolve => process.nextTick(resolve));
|
await new Promise((resolve) => process.nextTick(resolve));
|
||||||
await terminal.flush();
|
await terminal.flush();
|
||||||
|
|
||||||
viewport = terminal.getViewport();
|
viewport = terminal.getViewport();
|
||||||
|
|
@ -180,7 +190,7 @@ describe("Differential Rendering - Dynamic Content", () => {
|
||||||
contentLine10.setText("Content Line 10 - MODIFIED");
|
contentLine10.setText("Content Line 10 - MODIFIED");
|
||||||
ui.requestRender();
|
ui.requestRender();
|
||||||
|
|
||||||
await new Promise(resolve => process.nextTick(resolve));
|
await new Promise((resolve) => process.nextTick(resolve));
|
||||||
await terminal.flush();
|
await terminal.flush();
|
||||||
|
|
||||||
viewport = terminal.getViewport();
|
viewport = terminal.getViewport();
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
import { TUI, SelectList } from "../src/index.js";
|
|
||||||
import { readdirSync, statSync } from "fs";
|
import { readdirSync, statSync } from "fs";
|
||||||
import { join } from "path";
|
import { join } from "path";
|
||||||
|
import { SelectList, TUI } from "../src/index.js";
|
||||||
|
|
||||||
const ui = new TUI();
|
const ui = new TUI();
|
||||||
ui.start();
|
ui.start();
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
import { describe, test } from "node:test";
|
|
||||||
import assert from "node:assert";
|
import assert from "node:assert";
|
||||||
import { TextEditor, TextComponent, Container, TUI } from "../src/index.js";
|
import { describe, test } from "node:test";
|
||||||
|
import { Container, TextComponent, TextEditor, TUI } from "../src/index.js";
|
||||||
import { VirtualTerminal } from "./virtual-terminal.js";
|
import { VirtualTerminal } from "./virtual-terminal.js";
|
||||||
|
|
||||||
describe("Layout shift artifacts", () => {
|
describe("Layout shift artifacts", () => {
|
||||||
|
|
@ -27,7 +27,7 @@ describe("Layout shift artifacts", () => {
|
||||||
|
|
||||||
// Initial render
|
// Initial render
|
||||||
ui.start();
|
ui.start();
|
||||||
await new Promise(resolve => process.nextTick(resolve));
|
await new Promise((resolve) => process.nextTick(resolve));
|
||||||
await term.flush();
|
await term.flush();
|
||||||
|
|
||||||
// Capture initial state
|
// Capture initial state
|
||||||
|
|
@ -40,7 +40,7 @@ describe("Layout shift artifacts", () => {
|
||||||
ui.requestRender();
|
ui.requestRender();
|
||||||
|
|
||||||
// Wait for render
|
// Wait for render
|
||||||
await new Promise(resolve => process.nextTick(resolve));
|
await new Promise((resolve) => process.nextTick(resolve));
|
||||||
await term.flush();
|
await term.flush();
|
||||||
|
|
||||||
// Capture state with status message
|
// Capture state with status message
|
||||||
|
|
@ -51,7 +51,7 @@ describe("Layout shift artifacts", () => {
|
||||||
ui.requestRender();
|
ui.requestRender();
|
||||||
|
|
||||||
// Wait for render
|
// Wait for render
|
||||||
await new Promise(resolve => process.nextTick(resolve));
|
await new Promise((resolve) => process.nextTick(resolve));
|
||||||
await term.flush();
|
await term.flush();
|
||||||
|
|
||||||
// Capture final state
|
// Capture final state
|
||||||
|
|
@ -64,8 +64,12 @@ describe("Layout shift artifacts", () => {
|
||||||
const nextLine = finalViewport[i + 1];
|
const nextLine = finalViewport[i + 1];
|
||||||
|
|
||||||
// Check if we have duplicate bottom borders (the artifact)
|
// Check if we have duplicate bottom borders (the artifact)
|
||||||
if (currentLine.includes("╰") && currentLine.includes("╯") &&
|
if (
|
||||||
nextLine.includes("╰") && nextLine.includes("╯")) {
|
currentLine.includes("╰") &&
|
||||||
|
currentLine.includes("╯") &&
|
||||||
|
nextLine.includes("╰") &&
|
||||||
|
nextLine.includes("╯")
|
||||||
|
) {
|
||||||
foundDuplicateBorder = true;
|
foundDuplicateBorder = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -74,18 +78,12 @@ describe("Layout shift artifacts", () => {
|
||||||
assert.strictEqual(foundDuplicateBorder, false, "Found duplicate bottom borders - rendering artifact detected!");
|
assert.strictEqual(foundDuplicateBorder, false, "Found duplicate bottom borders - rendering artifact detected!");
|
||||||
|
|
||||||
// Also check that there's only one bottom border total
|
// Also check that there's only one bottom border total
|
||||||
const bottomBorderCount = finalViewport.filter((line) =>
|
const bottomBorderCount = finalViewport.filter((line) => line.includes("╰")).length;
|
||||||
line.includes("╰")
|
|
||||||
).length;
|
|
||||||
assert.strictEqual(bottomBorderCount, 1, `Expected 1 bottom border, found ${bottomBorderCount}`);
|
assert.strictEqual(bottomBorderCount, 1, `Expected 1 bottom border, found ${bottomBorderCount}`);
|
||||||
|
|
||||||
// Verify the editor is back in its original position
|
// Verify the editor is back in its original position
|
||||||
const finalEditorStartLine = finalViewport.findIndex((line) =>
|
const finalEditorStartLine = finalViewport.findIndex((line) => line.includes("╭"));
|
||||||
line.includes("╭")
|
const initialEditorStartLine = initialViewport.findIndex((line) => line.includes("╭"));
|
||||||
);
|
|
||||||
const initialEditorStartLine = initialViewport.findIndex((line) =>
|
|
||||||
line.includes("╭")
|
|
||||||
);
|
|
||||||
assert.strictEqual(finalEditorStartLine, initialEditorStartLine);
|
assert.strictEqual(finalEditorStartLine, initialEditorStartLine);
|
||||||
|
|
||||||
ui.stop();
|
ui.stop();
|
||||||
|
|
@ -103,7 +101,7 @@ describe("Layout shift artifacts", () => {
|
||||||
|
|
||||||
// Initial render
|
// Initial render
|
||||||
ui.start();
|
ui.start();
|
||||||
await new Promise(resolve => process.nextTick(resolve));
|
await new Promise((resolve) => process.nextTick(resolve));
|
||||||
await term.flush();
|
await term.flush();
|
||||||
|
|
||||||
// Rapidly add and remove a status message
|
// Rapidly add and remove a status message
|
||||||
|
|
@ -112,25 +110,21 @@ describe("Layout shift artifacts", () => {
|
||||||
// Add status
|
// Add status
|
||||||
ui.children.splice(1, 0, status);
|
ui.children.splice(1, 0, status);
|
||||||
ui.requestRender();
|
ui.requestRender();
|
||||||
await new Promise(resolve => process.nextTick(resolve));
|
await new Promise((resolve) => process.nextTick(resolve));
|
||||||
await term.flush();
|
await term.flush();
|
||||||
|
|
||||||
// Remove status immediately
|
// Remove status immediately
|
||||||
ui.children.splice(1, 1);
|
ui.children.splice(1, 1);
|
||||||
ui.requestRender();
|
ui.requestRender();
|
||||||
await new Promise(resolve => process.nextTick(resolve));
|
await new Promise((resolve) => process.nextTick(resolve));
|
||||||
await term.flush();
|
await term.flush();
|
||||||
|
|
||||||
// Final output check
|
// Final output check
|
||||||
const finalViewport = term.getViewport();
|
const finalViewport = term.getViewport();
|
||||||
|
|
||||||
// Should only have one set of borders for the editor
|
// Should only have one set of borders for the editor
|
||||||
const topBorderCount = finalViewport.filter((line) =>
|
const topBorderCount = finalViewport.filter((line) => line.includes("╭") && line.includes("╮")).length;
|
||||||
line.includes("╭") && line.includes("╮")
|
const bottomBorderCount = finalViewport.filter((line) => line.includes("╰") && line.includes("╯")).length;
|
||||||
).length;
|
|
||||||
const bottomBorderCount = finalViewport.filter((line) =>
|
|
||||||
line.includes("╰") && line.includes("╯")
|
|
||||||
).length;
|
|
||||||
|
|
||||||
assert.strictEqual(topBorderCount, 1);
|
assert.strictEqual(topBorderCount, 1);
|
||||||
assert.strictEqual(bottomBorderCount, 1);
|
assert.strictEqual(bottomBorderCount, 1);
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
#!/usr/bin/env npx tsx
|
#!/usr/bin/env npx tsx
|
||||||
import { TUI, Container, TextComponent, TextEditor, MarkdownComponent } from "../src/index.js";
|
import { Container, MarkdownComponent, TextComponent, TextEditor, TUI } from "../src/index.js";
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Multi-Component Layout Demo
|
* Multi-Component Layout Demo
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import { test, describe } from "node:test";
|
|
||||||
import assert from "node:assert";
|
import assert from "node:assert";
|
||||||
|
import { describe, test } from "node:test";
|
||||||
|
import { Container, LoadingAnimation, MarkdownComponent, TextComponent, TextEditor, TUI } from "../src/index.js";
|
||||||
import { VirtualTerminal } from "./virtual-terminal.js";
|
import { VirtualTerminal } from "./virtual-terminal.js";
|
||||||
import { TUI, Container, TextComponent, MarkdownComponent, TextEditor, LoadingAnimation } from "../src/index.js";
|
|
||||||
|
|
||||||
describe("Multi-Message Garbled Output Reproduction", () => {
|
describe("Multi-Message Garbled Output Reproduction", () => {
|
||||||
test("handles rapid message additions with large content without garbling", async () => {
|
test("handles rapid message additions with large content without garbling", async () => {
|
||||||
|
|
@ -20,7 +20,7 @@ describe("Multi-Message Garbled Output Reproduction", () => {
|
||||||
ui.setFocus(editor);
|
ui.setFocus(editor);
|
||||||
|
|
||||||
// Initial render
|
// Initial render
|
||||||
await new Promise(resolve => process.nextTick(resolve));
|
await new Promise((resolve) => process.nextTick(resolve));
|
||||||
await terminal.flush();
|
await terminal.flush();
|
||||||
|
|
||||||
// Step 1: Simulate user message
|
// Step 1: Simulate user message
|
||||||
|
|
@ -32,7 +32,7 @@ describe("Multi-Message Garbled Output Reproduction", () => {
|
||||||
statusContainer.addChild(loadingAnim);
|
statusContainer.addChild(loadingAnim);
|
||||||
|
|
||||||
ui.requestRender();
|
ui.requestRender();
|
||||||
await new Promise(resolve => process.nextTick(resolve));
|
await new Promise((resolve) => process.nextTick(resolve));
|
||||||
await terminal.flush();
|
await terminal.flush();
|
||||||
|
|
||||||
// Step 3: Simulate rapid tool calls with large outputs
|
// Step 3: Simulate rapid tool calls with large outputs
|
||||||
|
|
@ -54,7 +54,7 @@ node_modules/get-tsconfig/README.md
|
||||||
chatContainer.addChild(new TextComponent(globResult));
|
chatContainer.addChild(new TextComponent(globResult));
|
||||||
|
|
||||||
ui.requestRender();
|
ui.requestRender();
|
||||||
await new Promise(resolve => process.nextTick(resolve));
|
await new Promise((resolve) => process.nextTick(resolve));
|
||||||
await terminal.flush();
|
await terminal.flush();
|
||||||
|
|
||||||
// Simulate multiple read tool calls with long content
|
// Simulate multiple read tool calls with long content
|
||||||
|
|
@ -74,7 +74,7 @@ A collection of tools for managing LLM deployments and building AI agents.
|
||||||
chatContainer.addChild(new MarkdownComponent(readmeContent));
|
chatContainer.addChild(new MarkdownComponent(readmeContent));
|
||||||
|
|
||||||
ui.requestRender();
|
ui.requestRender();
|
||||||
await new Promise(resolve => process.nextTick(resolve));
|
await new Promise((resolve) => process.nextTick(resolve));
|
||||||
await terminal.flush();
|
await terminal.flush();
|
||||||
|
|
||||||
// Second read with even more content
|
// Second read with even more content
|
||||||
|
|
@ -94,7 +94,7 @@ Terminal UI framework with surgical differential rendering for building flicker-
|
||||||
chatContainer.addChild(new MarkdownComponent(tuiReadmeContent));
|
chatContainer.addChild(new MarkdownComponent(tuiReadmeContent));
|
||||||
|
|
||||||
ui.requestRender();
|
ui.requestRender();
|
||||||
await new Promise(resolve => process.nextTick(resolve));
|
await new Promise((resolve) => process.nextTick(resolve));
|
||||||
await terminal.flush();
|
await terminal.flush();
|
||||||
|
|
||||||
// Step 4: Stop loading animation and add assistant response
|
// Step 4: Stop loading animation and add assistant response
|
||||||
|
|
@ -114,7 +114,7 @@ The TUI library features surgical differential rendering that minimizes screen u
|
||||||
chatContainer.addChild(new MarkdownComponent(assistantResponse));
|
chatContainer.addChild(new MarkdownComponent(assistantResponse));
|
||||||
|
|
||||||
ui.requestRender();
|
ui.requestRender();
|
||||||
await new Promise(resolve => process.nextTick(resolve));
|
await new Promise((resolve) => process.nextTick(resolve));
|
||||||
await terminal.flush();
|
await terminal.flush();
|
||||||
|
|
||||||
// Step 5: CRITICAL - Send a new message while previous content is displayed
|
// Step 5: CRITICAL - Send a new message while previous content is displayed
|
||||||
|
|
@ -126,7 +126,7 @@ The TUI library features surgical differential rendering that minimizes screen u
|
||||||
statusContainer.addChild(loadingAnim2);
|
statusContainer.addChild(loadingAnim2);
|
||||||
|
|
||||||
ui.requestRender();
|
ui.requestRender();
|
||||||
await new Promise(resolve => process.nextTick(resolve));
|
await new Promise((resolve) => process.nextTick(resolve));
|
||||||
await terminal.flush();
|
await terminal.flush();
|
||||||
|
|
||||||
// Add assistant response
|
// Add assistant response
|
||||||
|
|
@ -144,7 +144,7 @@ Key aspects:
|
||||||
chatContainer.addChild(new MarkdownComponent(secondResponse));
|
chatContainer.addChild(new MarkdownComponent(secondResponse));
|
||||||
|
|
||||||
ui.requestRender();
|
ui.requestRender();
|
||||||
await new Promise(resolve => process.nextTick(resolve));
|
await new Promise((resolve) => process.nextTick(resolve));
|
||||||
await terminal.flush();
|
await terminal.flush();
|
||||||
|
|
||||||
// Debug: Show the garbled output after the problematic step
|
// Debug: Show the garbled output after the problematic step
|
||||||
|
|
@ -158,14 +158,20 @@ Key aspects:
|
||||||
const finalOutput = terminal.getScrollBuffer();
|
const finalOutput = terminal.getScrollBuffer();
|
||||||
|
|
||||||
// Check that first user message is NOT garbled
|
// Check that first user message is NOT garbled
|
||||||
const userLine1 = finalOutput.find(line => line.includes("read all README.md files"));
|
const userLine1 = finalOutput.find((line) => line.includes("read all README.md files"));
|
||||||
assert.strictEqual(userLine1, "read all README.md files except in node_modules",
|
assert.strictEqual(
|
||||||
`First user message is garbled: "${userLine1}"`);
|
userLine1,
|
||||||
|
"read all README.md files except in node_modules",
|
||||||
|
`First user message is garbled: "${userLine1}"`,
|
||||||
|
);
|
||||||
|
|
||||||
// Check that second user message is clean
|
// Check that second user message is clean
|
||||||
const userLine2 = finalOutput.find(line => line.includes("What is the main purpose"));
|
const userLine2 = finalOutput.find((line) => line.includes("What is the main purpose"));
|
||||||
assert.strictEqual(userLine2, "What is the main purpose of the TUI library?",
|
assert.strictEqual(
|
||||||
`Second user message is garbled: "${userLine2}"`);
|
userLine2,
|
||||||
|
"What is the main purpose of the TUI library?",
|
||||||
|
`Second user message is garbled: "${userLine2}"`,
|
||||||
|
);
|
||||||
|
|
||||||
// Check for common garbling patterns
|
// Check for common garbling patterns
|
||||||
const garbledPatterns = [
|
const garbledPatterns = [
|
||||||
|
|
@ -173,11 +179,11 @@ Key aspects:
|
||||||
"README.mdectly",
|
"README.mdectly",
|
||||||
"modulesl rendering",
|
"modulesl rendering",
|
||||||
"[assistant]ns.",
|
"[assistant]ns.",
|
||||||
"node_modules/@esbuild/darwin-arm64/README.mdategy"
|
"node_modules/@esbuild/darwin-arm64/README.mdategy",
|
||||||
];
|
];
|
||||||
|
|
||||||
for (const pattern of garbledPatterns) {
|
for (const pattern of garbledPatterns) {
|
||||||
const hasGarbled = finalOutput.some(line => line.includes(pattern));
|
const hasGarbled = finalOutput.some((line) => line.includes(pattern));
|
||||||
assert.ok(!hasGarbled, `Found garbled pattern "${pattern}" in output`);
|
assert.ok(!hasGarbled, `Found garbled pattern "${pattern}" in output`);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,18 +1,17 @@
|
||||||
import { test, describe } from "node:test";
|
|
||||||
import assert from "node:assert";
|
import assert from "node:assert";
|
||||||
import { VirtualTerminal } from "./virtual-terminal.js";
|
import { describe, test } from "node:test";
|
||||||
import {
|
import {
|
||||||
TUI,
|
|
||||||
Container,
|
Container,
|
||||||
TextComponent,
|
|
||||||
TextEditor,
|
|
||||||
WhitespaceComponent,
|
|
||||||
MarkdownComponent,
|
MarkdownComponent,
|
||||||
SelectList,
|
SelectList,
|
||||||
|
TextComponent,
|
||||||
|
TextEditor,
|
||||||
|
TUI,
|
||||||
|
WhitespaceComponent,
|
||||||
} from "../src/index.js";
|
} from "../src/index.js";
|
||||||
|
import { VirtualTerminal } from "./virtual-terminal.js";
|
||||||
|
|
||||||
describe("TUI Rendering", () => {
|
describe("TUI Rendering", () => {
|
||||||
|
|
||||||
test("renders single text component", async () => {
|
test("renders single text component", async () => {
|
||||||
const terminal = new VirtualTerminal(80, 24);
|
const terminal = new VirtualTerminal(80, 24);
|
||||||
const ui = new TUI(terminal);
|
const ui = new TUI(terminal);
|
||||||
|
|
@ -22,7 +21,7 @@ describe("TUI Rendering", () => {
|
||||||
ui.addChild(text);
|
ui.addChild(text);
|
||||||
|
|
||||||
// Wait for next tick for render to complete
|
// Wait for next tick for render to complete
|
||||||
await new Promise(resolve => process.nextTick(resolve));
|
await new Promise((resolve) => process.nextTick(resolve));
|
||||||
|
|
||||||
// Wait for writes to complete and get the rendered output
|
// Wait for writes to complete and get the rendered output
|
||||||
const output = await terminal.flushAndGetViewport();
|
const output = await terminal.flushAndGetViewport();
|
||||||
|
|
@ -48,7 +47,7 @@ describe("TUI Rendering", () => {
|
||||||
ui.addChild(new TextComponent("Line 3"));
|
ui.addChild(new TextComponent("Line 3"));
|
||||||
|
|
||||||
// Wait for next tick for render to complete
|
// Wait for next tick for render to complete
|
||||||
await new Promise(resolve => process.nextTick(resolve));
|
await new Promise((resolve) => process.nextTick(resolve));
|
||||||
|
|
||||||
const output = await terminal.flushAndGetViewport();
|
const output = await terminal.flushAndGetViewport();
|
||||||
assert.strictEqual(output[0], "Line 1");
|
assert.strictEqual(output[0], "Line 1");
|
||||||
|
|
@ -68,7 +67,7 @@ describe("TUI Rendering", () => {
|
||||||
ui.addChild(new TextComponent("Bottom text"));
|
ui.addChild(new TextComponent("Bottom text"));
|
||||||
|
|
||||||
// Wait for next tick for render to complete
|
// Wait for next tick for render to complete
|
||||||
await new Promise(resolve => process.nextTick(resolve));
|
await new Promise((resolve) => process.nextTick(resolve));
|
||||||
|
|
||||||
const output = await terminal.flushAndGetViewport();
|
const output = await terminal.flushAndGetViewport();
|
||||||
assert.strictEqual(output[0], "Top text");
|
assert.strictEqual(output[0], "Top text");
|
||||||
|
|
@ -96,7 +95,7 @@ describe("TUI Rendering", () => {
|
||||||
ui.addChild(new TextComponent("After container"));
|
ui.addChild(new TextComponent("After container"));
|
||||||
|
|
||||||
// Wait for next tick for render to complete
|
// Wait for next tick for render to complete
|
||||||
await new Promise(resolve => process.nextTick(resolve));
|
await new Promise((resolve) => process.nextTick(resolve));
|
||||||
|
|
||||||
const output = await terminal.flushAndGetViewport();
|
const output = await terminal.flushAndGetViewport();
|
||||||
assert.strictEqual(output[0], "Before container");
|
assert.strictEqual(output[0], "Before container");
|
||||||
|
|
@ -117,7 +116,7 @@ describe("TUI Rendering", () => {
|
||||||
ui.setFocus(editor);
|
ui.setFocus(editor);
|
||||||
|
|
||||||
// Wait for next tick for render to complete
|
// Wait for next tick for render to complete
|
||||||
await new Promise(resolve => process.nextTick(resolve));
|
await new Promise((resolve) => process.nextTick(resolve));
|
||||||
|
|
||||||
// Initial state - empty editor with cursor
|
// Initial state - empty editor with cursor
|
||||||
const output = await terminal.flushAndGetViewport();
|
const output = await terminal.flushAndGetViewport();
|
||||||
|
|
@ -142,7 +141,7 @@ describe("TUI Rendering", () => {
|
||||||
ui.addChild(dynamicText);
|
ui.addChild(dynamicText);
|
||||||
|
|
||||||
// Wait for initial render
|
// Wait for initial render
|
||||||
await new Promise(resolve => process.nextTick(resolve));
|
await new Promise((resolve) => process.nextTick(resolve));
|
||||||
await terminal.flush();
|
await terminal.flush();
|
||||||
|
|
||||||
// Save initial state
|
// Save initial state
|
||||||
|
|
@ -153,7 +152,7 @@ describe("TUI Rendering", () => {
|
||||||
ui.requestRender();
|
ui.requestRender();
|
||||||
|
|
||||||
// Wait for render
|
// Wait for render
|
||||||
await new Promise(resolve => process.nextTick(resolve));
|
await new Promise((resolve) => process.nextTick(resolve));
|
||||||
|
|
||||||
// Flush terminal buffer
|
// Flush terminal buffer
|
||||||
await terminal.flush();
|
await terminal.flush();
|
||||||
|
|
@ -180,7 +179,7 @@ describe("TUI Rendering", () => {
|
||||||
ui.addChild(text3);
|
ui.addChild(text3);
|
||||||
|
|
||||||
// Wait for initial render
|
// Wait for initial render
|
||||||
await new Promise(resolve => process.nextTick(resolve));
|
await new Promise((resolve) => process.nextTick(resolve));
|
||||||
|
|
||||||
let output = await terminal.flushAndGetViewport();
|
let output = await terminal.flushAndGetViewport();
|
||||||
assert.strictEqual(output[0], "Line 1");
|
assert.strictEqual(output[0], "Line 1");
|
||||||
|
|
@ -191,7 +190,7 @@ describe("TUI Rendering", () => {
|
||||||
ui.removeChild(text2);
|
ui.removeChild(text2);
|
||||||
ui.requestRender();
|
ui.requestRender();
|
||||||
|
|
||||||
await new Promise(resolve => setImmediate(resolve));
|
await new Promise((resolve) => setImmediate(resolve));
|
||||||
|
|
||||||
output = await terminal.flushAndGetViewport();
|
output = await terminal.flushAndGetViewport();
|
||||||
assert.strictEqual(output[0], "Line 1");
|
assert.strictEqual(output[0], "Line 1");
|
||||||
|
|
@ -212,7 +211,7 @@ describe("TUI Rendering", () => {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait for next tick for render to complete
|
// Wait for next tick for render to complete
|
||||||
await new Promise(resolve => process.nextTick(resolve));
|
await new Promise((resolve) => process.nextTick(resolve));
|
||||||
|
|
||||||
const output = await terminal.flushAndGetViewport();
|
const output = await terminal.flushAndGetViewport();
|
||||||
|
|
||||||
|
|
@ -241,7 +240,7 @@ describe("TUI Rendering", () => {
|
||||||
ui.addChild(new TextComponent("After"));
|
ui.addChild(new TextComponent("After"));
|
||||||
|
|
||||||
// Wait for next tick for render to complete
|
// Wait for next tick for render to complete
|
||||||
await new Promise(resolve => process.nextTick(resolve));
|
await new Promise((resolve) => process.nextTick(resolve));
|
||||||
|
|
||||||
const output = await terminal.flushAndGetViewport();
|
const output = await terminal.flushAndGetViewport();
|
||||||
assert.strictEqual(output[0], "Before");
|
assert.strictEqual(output[0], "Before");
|
||||||
|
|
@ -262,7 +261,7 @@ describe("TUI Rendering", () => {
|
||||||
ui.addChild(markdown);
|
ui.addChild(markdown);
|
||||||
|
|
||||||
// Wait for next tick for render to complete
|
// Wait for next tick for render to complete
|
||||||
await new Promise(resolve => process.nextTick(resolve));
|
await new Promise((resolve) => process.nextTick(resolve));
|
||||||
|
|
||||||
const output = await terminal.flushAndGetViewport();
|
const output = await terminal.flushAndGetViewport();
|
||||||
// Should have formatted markdown
|
// Should have formatted markdown
|
||||||
|
|
@ -289,7 +288,7 @@ describe("TUI Rendering", () => {
|
||||||
ui.setFocus(selectList);
|
ui.setFocus(selectList);
|
||||||
|
|
||||||
// Wait for next tick for render to complete
|
// Wait for next tick for render to complete
|
||||||
await new Promise(resolve => process.nextTick(resolve));
|
await new Promise((resolve) => process.nextTick(resolve));
|
||||||
|
|
||||||
const output = await terminal.flushAndGetViewport();
|
const output = await terminal.flushAndGetViewport();
|
||||||
// First option should be selected (has → indicator)
|
// First option should be selected (has → indicator)
|
||||||
|
|
@ -334,7 +333,7 @@ describe("TUI Rendering", () => {
|
||||||
ui.setFocus(editor);
|
ui.setFocus(editor);
|
||||||
|
|
||||||
// Wait for initial render
|
// Wait for initial render
|
||||||
await new Promise(resolve => process.nextTick(resolve));
|
await new Promise((resolve) => process.nextTick(resolve));
|
||||||
await terminal.flush();
|
await terminal.flush();
|
||||||
|
|
||||||
// Check that the editor is rendered after the existing content
|
// Check that the editor is rendered after the existing content
|
||||||
|
|
@ -365,7 +364,7 @@ describe("TUI Rendering", () => {
|
||||||
terminal.sendInput("Hello World");
|
terminal.sendInput("Hello World");
|
||||||
|
|
||||||
// Wait for the input to be processed
|
// Wait for the input to be processed
|
||||||
await new Promise(resolve => process.nextTick(resolve));
|
await new Promise((resolve) => process.nextTick(resolve));
|
||||||
await terminal.flush();
|
await terminal.flush();
|
||||||
|
|
||||||
// Check that text appears in the editor
|
// Check that text appears in the editor
|
||||||
|
|
@ -383,7 +382,7 @@ describe("TUI Rendering", () => {
|
||||||
terminal.sendInput("\n");
|
terminal.sendInput("\n");
|
||||||
|
|
||||||
// Wait for the input to be processed
|
// Wait for the input to be processed
|
||||||
await new Promise(resolve => process.nextTick(resolve));
|
await new Promise((resolve) => process.nextTick(resolve));
|
||||||
await terminal.flush();
|
await terminal.flush();
|
||||||
|
|
||||||
// Check that existing content is still preserved after adding new line
|
// Check that existing content is still preserved after adding new line
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
import { test, describe } from "node:test";
|
|
||||||
import assert from "node:assert";
|
import assert from "node:assert";
|
||||||
|
import { describe, test } from "node:test";
|
||||||
import { VirtualTerminal } from "./virtual-terminal.js";
|
import { VirtualTerminal } from "./virtual-terminal.js";
|
||||||
|
|
||||||
describe("VirtualTerminal", () => {
|
describe("VirtualTerminal", () => {
|
||||||
|
|
@ -86,13 +86,13 @@ describe("VirtualTerminal", () => {
|
||||||
assert.strictEqual(viewport.length, 10);
|
assert.strictEqual(viewport.length, 10);
|
||||||
assert.strictEqual(viewport[0], "Line 7");
|
assert.strictEqual(viewport[0], "Line 7");
|
||||||
assert.strictEqual(viewport[8], "Line 15");
|
assert.strictEqual(viewport[8], "Line 15");
|
||||||
assert.strictEqual(viewport[9], ""); // Last line is empty after the final \r\n
|
assert.strictEqual(viewport[9], ""); // Last line is empty after the final \r\n
|
||||||
|
|
||||||
// Scroll buffer should have all lines
|
// Scroll buffer should have all lines
|
||||||
assert.ok(scrollBuffer.length >= 15);
|
assert.ok(scrollBuffer.length >= 15);
|
||||||
// Check specific lines exist in the buffer
|
// Check specific lines exist in the buffer
|
||||||
const hasLine1 = scrollBuffer.some(line => line === "Line 1");
|
const hasLine1 = scrollBuffer.some((line) => line === "Line 1");
|
||||||
const hasLine15 = scrollBuffer.some(line => line === "Line 15");
|
const hasLine15 = scrollBuffer.some((line) => line === "Line 15");
|
||||||
assert.ok(hasLine1, "Buffer should contain 'Line 1'");
|
assert.ok(hasLine1, "Buffer should contain 'Line 1'");
|
||||||
assert.ok(hasLine15, "Buffer should contain 'Line 15'");
|
assert.ok(hasLine15, "Buffer should contain 'Line 15'");
|
||||||
});
|
});
|
||||||
|
|
@ -129,9 +129,12 @@ describe("VirtualTerminal", () => {
|
||||||
const terminal = new VirtualTerminal(80, 24);
|
const terminal = new VirtualTerminal(80, 24);
|
||||||
|
|
||||||
let received = "";
|
let received = "";
|
||||||
terminal.start((data) => {
|
terminal.start(
|
||||||
received = data;
|
(data) => {
|
||||||
}, () => {});
|
received = data;
|
||||||
|
},
|
||||||
|
() => {},
|
||||||
|
);
|
||||||
|
|
||||||
terminal.sendInput("a");
|
terminal.sendInput("a");
|
||||||
assert.strictEqual(received, "a");
|
assert.strictEqual(received, "a");
|
||||||
|
|
@ -146,9 +149,12 @@ describe("VirtualTerminal", () => {
|
||||||
const terminal = new VirtualTerminal(80, 24);
|
const terminal = new VirtualTerminal(80, 24);
|
||||||
|
|
||||||
let resized = false;
|
let resized = false;
|
||||||
terminal.start(() => {}, () => {
|
terminal.start(
|
||||||
resized = true;
|
() => {},
|
||||||
});
|
() => {
|
||||||
|
resized = true;
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
terminal.resize(100, 30);
|
terminal.resize(100, 30);
|
||||||
assert.strictEqual(resized, true);
|
assert.strictEqual(resized, true);
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
import xterm from '@xterm/headless';
|
import type { Terminal as XtermTerminalType } from "@xterm/headless";
|
||||||
import type { Terminal as XtermTerminalType } from '@xterm/headless';
|
import xterm from "@xterm/headless";
|
||||||
import { Terminal } from '../src/terminal.js';
|
import type { Terminal } from "../src/terminal.js";
|
||||||
|
|
||||||
// Extract Terminal class from the module
|
// Extract Terminal class from the module
|
||||||
const XtermTerminal = xterm.Terminal;
|
const XtermTerminal = xterm.Terminal;
|
||||||
|
|
@ -81,7 +81,7 @@ export class VirtualTerminal implements Terminal {
|
||||||
async flush(): Promise<void> {
|
async flush(): Promise<void> {
|
||||||
// Write an empty string to ensure all previous writes are flushed
|
// Write an empty string to ensure all previous writes are flushed
|
||||||
return new Promise<void>((resolve) => {
|
return new Promise<void>((resolve) => {
|
||||||
this.xterm.write('', () => resolve());
|
this.xterm.write("", () => resolve());
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -107,7 +107,7 @@ export class VirtualTerminal implements Terminal {
|
||||||
if (line) {
|
if (line) {
|
||||||
lines.push(line.translateToString(true));
|
lines.push(line.translateToString(true));
|
||||||
} else {
|
} else {
|
||||||
lines.push('');
|
lines.push("");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -127,7 +127,7 @@ export class VirtualTerminal implements Terminal {
|
||||||
if (line) {
|
if (line) {
|
||||||
lines.push(line.translateToString(true));
|
lines.push(line.translateToString(true));
|
||||||
} else {
|
} else {
|
||||||
lines.push('');
|
lines.push("");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -155,7 +155,7 @@ export class VirtualTerminal implements Terminal {
|
||||||
const buffer = this.xterm.buffer.active;
|
const buffer = this.xterm.buffer.active;
|
||||||
return {
|
return {
|
||||||
x: buffer.cursorX,
|
x: buffer.cursorX,
|
||||||
y: buffer.cursorY
|
y: buffer.cursorY,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Loading…
Add table
Add a link
Reference in a new issue