mirror of
https://github.com/getcompanion-ai/co-mono.git
synced 2026-04-19 00:05:24 +00:00
feat(ai): Add new streaming generate API with AsyncIterable interface
- Implement QueuedGenerateStream class that extends AsyncIterable with finalMessage() method - Add new types: GenerateStream, GenerateOptions, GenerateOptionsUnified, GenerateFunction - Create generateAnthropic function-based implementation replacing class-based approach - Add comprehensive test suite for the new generate API - Support streaming events with text, thinking, and tool call deltas - Map ReasoningEffort to provider-specific options - Include apiKey in options instead of constructor parameter
This commit is contained in:
parent
be07c08a75
commit
004de3c9d0
6 changed files with 1106 additions and 129 deletions
|
|
@ -1,108 +1,44 @@
|
|||
import { PROVIDERS } from "./models.generated.js";
|
||||
import { AnthropicLLM } from "./providers/anthropic.js";
|
||||
import { GoogleLLM } from "./providers/google.js";
|
||||
import { OpenAICompletionsLLM } from "./providers/openai-completions.js";
|
||||
import { OpenAIResponsesLLM } from "./providers/openai-responses.js";
|
||||
import type { Model, Usage } from "./types.js";
|
||||
import type { KnownProvider, Model, Usage } from "./types.js";
|
||||
|
||||
// Provider configuration with factory functions
|
||||
export const PROVIDER_CONFIG = {
|
||||
google: {
|
||||
envKey: "GEMINI_API_KEY",
|
||||
create: (model: Model, apiKey: string) => new GoogleLLM(model, apiKey),
|
||||
},
|
||||
openai: {
|
||||
envKey: "OPENAI_API_KEY",
|
||||
create: (model: Model, apiKey: string) => new OpenAIResponsesLLM(model, apiKey),
|
||||
},
|
||||
anthropic: {
|
||||
envKey: "ANTHROPIC_API_KEY",
|
||||
create: (model: Model, apiKey: string) => new AnthropicLLM(model, apiKey),
|
||||
},
|
||||
xai: {
|
||||
envKey: "XAI_API_KEY",
|
||||
create: (model: Model, apiKey: string) => new OpenAICompletionsLLM(model, apiKey),
|
||||
},
|
||||
groq: {
|
||||
envKey: "GROQ_API_KEY",
|
||||
create: (model: Model, apiKey: string) => new OpenAICompletionsLLM(model, apiKey),
|
||||
},
|
||||
cerebras: {
|
||||
envKey: "CEREBRAS_API_KEY",
|
||||
create: (model: Model, apiKey: string) => new OpenAICompletionsLLM(model, apiKey),
|
||||
},
|
||||
openrouter: {
|
||||
envKey: "OPENROUTER_API_KEY",
|
||||
create: (model: Model, apiKey: string) => new OpenAICompletionsLLM(model, apiKey),
|
||||
},
|
||||
} as const;
|
||||
// Re-export Model type
|
||||
export type { KnownProvider, Model } from "./types.js";
|
||||
|
||||
// Type mapping from provider to LLM implementation
|
||||
export type ProviderToLLM = {
|
||||
google: GoogleLLM;
|
||||
openai: OpenAIResponsesLLM;
|
||||
anthropic: AnthropicLLM;
|
||||
xai: OpenAICompletionsLLM;
|
||||
groq: OpenAICompletionsLLM;
|
||||
cerebras: OpenAICompletionsLLM;
|
||||
openrouter: OpenAICompletionsLLM;
|
||||
};
|
||||
// Dynamic model registry initialized from PROVIDERS
|
||||
const modelRegistry: Map<string, Map<string, Model>> = new Map();
|
||||
|
||||
// Extract model types for each provider
|
||||
export type GoogleModel = keyof typeof PROVIDERS.google.models;
|
||||
export type OpenAIModel = keyof typeof PROVIDERS.openai.models;
|
||||
export type AnthropicModel = keyof typeof PROVIDERS.anthropic.models;
|
||||
export type XAIModel = keyof typeof PROVIDERS.xai.models;
|
||||
export type GroqModel = keyof typeof PROVIDERS.groq.models;
|
||||
export type CerebrasModel = keyof typeof PROVIDERS.cerebras.models;
|
||||
export type OpenRouterModel = keyof typeof PROVIDERS.openrouter.models;
|
||||
|
||||
// Map providers to their model types
|
||||
export type ProviderModels = {
|
||||
google: GoogleModel;
|
||||
openai: OpenAIModel;
|
||||
anthropic: AnthropicModel;
|
||||
xai: XAIModel;
|
||||
groq: GroqModel;
|
||||
cerebras: CerebrasModel;
|
||||
openrouter: OpenRouterModel;
|
||||
};
|
||||
|
||||
// Single generic factory function
|
||||
export function createLLM<P extends keyof typeof PROVIDERS, M extends keyof (typeof PROVIDERS)[P]["models"]>(
|
||||
provider: P,
|
||||
model: M,
|
||||
apiKey?: string,
|
||||
): ProviderToLLM[P] {
|
||||
const config = PROVIDER_CONFIG[provider as keyof typeof PROVIDER_CONFIG];
|
||||
if (!config) throw new Error(`Unknown provider: ${provider}`);
|
||||
|
||||
const providerData = PROVIDERS[provider];
|
||||
if (!providerData) throw new Error(`Unknown provider: ${provider}`);
|
||||
|
||||
// Type-safe model lookup
|
||||
const models = providerData.models as Record<string, Model>;
|
||||
const modelData = models[model as string];
|
||||
if (!modelData) throw new Error(`Unknown model: ${String(model)} for provider ${provider}`);
|
||||
|
||||
const key = apiKey || process.env[config.envKey];
|
||||
if (!key) throw new Error(`No API key provided for ${provider}. Set ${config.envKey} or pass apiKey.`);
|
||||
|
||||
return config.create(modelData, key) as ProviderToLLM[P];
|
||||
// Initialize registry from PROVIDERS on module load
|
||||
for (const [provider, models] of Object.entries(PROVIDERS)) {
|
||||
const providerModels = new Map<string, Model>();
|
||||
for (const [id, model] of Object.entries(models)) {
|
||||
providerModels.set(id, model as Model);
|
||||
}
|
||||
modelRegistry.set(provider, providerModels);
|
||||
}
|
||||
|
||||
// Helper function to get model info with type-safe model IDs
|
||||
export function getModel<P extends keyof typeof PROVIDERS>(
|
||||
provider: P,
|
||||
modelId: keyof (typeof PROVIDERS)[P]["models"],
|
||||
): Model | undefined {
|
||||
const providerData = PROVIDERS[provider];
|
||||
if (!providerData) return undefined;
|
||||
const models = providerData.models as Record<string, Model>;
|
||||
return models[modelId as string];
|
||||
/**
|
||||
* Get a model from the registry - typed overload for known providers
|
||||
*/
|
||||
export function getModel<P extends KnownProvider>(provider: P, modelId: keyof (typeof PROVIDERS)[P]): Model;
|
||||
export function getModel(provider: string, modelId: string): Model | undefined;
|
||||
export function getModel(provider: any, modelId: any): Model | undefined {
|
||||
return modelRegistry.get(provider)?.get(modelId);
|
||||
}
|
||||
|
||||
export function calculateCost(model: Model, usage: Usage) {
|
||||
/**
|
||||
* Register a custom model
|
||||
*/
|
||||
export function registerModel(model: Model): void {
|
||||
if (!modelRegistry.has(model.provider)) {
|
||||
modelRegistry.set(model.provider, new Map());
|
||||
}
|
||||
modelRegistry.get(model.provider)!.set(model.id, model);
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate cost for token usage
|
||||
*/
|
||||
export function calculateCost(model: Model, usage: Usage): Usage["cost"] {
|
||||
usage.cost.input = (model.cost.input / 1000000) * usage.input;
|
||||
usage.cost.output = (model.cost.output / 1000000) * usage.output;
|
||||
usage.cost.cacheRead = (model.cost.cacheRead / 1000000) * usage.cacheRead;
|
||||
|
|
@ -110,6 +46,3 @@ export function calculateCost(model: Model, usage: Usage) {
|
|||
usage.cost.total = usage.cost.input + usage.cost.output + usage.cost.cacheRead + usage.cost.cacheWrite;
|
||||
return usage.cost;
|
||||
}
|
||||
|
||||
// Re-export Model type for convenience
|
||||
export type { Model };
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue