import type {
	Api,
	AssistantMessageEventStream,
	Context,
	Model,
	SimpleStreamOptions,
	StreamFunction,
	StreamOptions,
} from "./types.js";

/**
 * Type-erased stream entry point stored in the registry.
 * It accepts a model of any api plus the base `StreamOptions`.
 */
export type ApiStreamFunction = (
	model: Model<Api>,
	context: Context,
	options?: StreamOptions,
) => AssistantMessageEventStream;

/**
 * Type-erased "simple" stream entry point stored in the registry.
 * It accepts a model of any api plus `SimpleStreamOptions`.
 */
export type ApiStreamSimpleFunction = (
	model: Model<Api>,
	context: Context,
	options?: SimpleStreamOptions,
) => AssistantMessageEventStream;

/**
 * A provider implementation for a single api, as supplied by callers of
 * {@link registerApiProvider}.
 *
 * @typeParam TApi - The api identifier this provider handles.
 * @typeParam TOptions - The provider-specific options accepted by `stream`.
 */
export interface ApiProvider<TApi extends Api = Api, TOptions extends StreamOptions = StreamOptions> {
	api: TApi;
	stream: StreamFunction<TApi, TOptions>;
	streamSimple: StreamFunction<TApi, SimpleStreamOptions>;
}

/** Registry-side view of a provider, with its stream functions type-erased. */
interface ApiProviderInternal {
	api: Api;
	stream: ApiStreamFunction;
	streamSimple: ApiStreamSimpleFunction;
}

/** A registry entry: the wrapped provider plus the optional id of whoever registered it. */
type RegisteredApiProvider = {
	provider: ApiProviderInternal;
	sourceId?: string;
};

/** Module-level registry, keyed by api identifier; at most one provider per api. */
const apiProviderRegistry = new Map<Api, RegisteredApiProvider>();

/**
 * Adapts a typed provider `stream` function to the type-erased registry signature.
 * It checks at runtime that the model really targets `api`.
 *
 * @throws Error when `model.api` differs from `api`.
 */
function wrapStream<TApi extends Api, TOptions extends StreamOptions>(
	api: TApi,
	stream: StreamFunction<TApi, TOptions>,
): ApiStreamFunction {
	return (model, context, options) => {
		if (model.api !== api) {
			throw new Error(`Mismatched api: ${model.api} expected ${api}`);
		}
		// The api check above justifies narrowing the model. The options cast is NOT
		// checked at runtime: callers are trusted to pass options suited to this provider.
		return stream(model as Model<TApi>, context, options as TOptions);
	};
}

/**
 * Adapts a typed provider `streamSimple` function to the type-erased registry signature.
 * It checks at runtime that the model really targets `api`.
 *
 * @throws Error when `model.api` differs from `api`.
 */
function wrapStreamSimple<TApi extends Api>(
	api: TApi,
	streamSimple: StreamFunction<TApi, SimpleStreamOptions>,
): ApiStreamSimpleFunction {
	return (model, context, options) => {
		if (model.api !== api) {
			throw new Error(`Mismatched api: ${model.api} expected ${api}`);
		}
		return streamSimple(model as Model<TApi>, context, options);
	};
}

/**
 * Registers (or replaces) the provider for `provider.api`.
 *
 * Any provider already registered for the same api is overwritten, regardless
 * of its `sourceId`.
 *
 * @param provider - The provider implementation. Its stream functions are wrapped
 *   so that calling them with a model for a different api throws.
 * @param sourceId - Optional tag identifying the registrant. It is used by
 *   {@link unregisterApiProviders} to remove every provider registered under it.
 */
export function registerApiProvider<TApi extends Api, TOptions extends StreamOptions>(
	provider: ApiProvider<TApi, TOptions>,
	sourceId?: string,
): void {
	apiProviderRegistry.set(provider.api, {
		provider: {
			api: provider.api,
			stream: wrapStream(provider.api, provider.stream),
			streamSimple: wrapStreamSimple(provider.api, provider.streamSimple),
		},
		sourceId,
	});
}

/**
 * Looks up the provider registered for `api`.
 *
 * @returns The wrapped provider, or `undefined` if none is registered.
 */
export function getApiProvider(api: Api): ApiProviderInternal | undefined {
	return apiProviderRegistry.get(api)?.provider;
}

/**
 * Returns every registered provider.
 * The result is a new array in the Map's insertion order.
 */
export function getApiProviders(): ApiProviderInternal[] {
	return Array.from(apiProviderRegistry.values(), (entry) => entry.provider);
}

/**
 * Removes every provider that was registered with the given `sourceId`.
 * Providers registered without a sourceId, or with a different one, are kept.
 */
export function unregisterApiProviders(sourceId: string): void {
	// Deleting the current key while iterating a Map is well-defined in JS:
	// the iterator simply continues with the remaining entries.
	for (const [api, entry] of apiProviderRegistry.entries()) {
		if (entry.sourceId === sourceId) {
			apiProviderRegistry.delete(api);
		}
	}
}

/** Removes all registered providers. */
export function clearApiProviders(): void {
	apiProviderRegistry.clear();
}