co-mono/packages/ai/src/api-registry.ts
2026-01-24 23:15:11 +01:00

98 lines
2.5 KiB
TypeScript

import type {
Api,
AssistantMessageEventStream,
Context,
Model,
SimpleStreamOptions,
StreamFunction,
StreamOptions,
} from "./types.js";
/**
 * Type-erased streaming signature: accepts a model for any `Api` together with
 * a context and optional `StreamOptions`, and returns an
 * `AssistantMessageEventStream`.
 *
 * This is the shape of the `stream` member on providers held by the registry.
 */
export type ApiStreamFunction = (
model: Model<Api>,
context: Context,
options?: StreamOptions,
) => AssistantMessageEventStream;
/**
 * Type-erased counterpart of a provider's `streamSimple` function: same shape
 * as `ApiStreamFunction`, but the optional options argument is
 * `SimpleStreamOptions`.
 */
export type ApiStreamSimpleFunction = (
model: Model<Api>,
context: Context,
options?: SimpleStreamOptions,
) => AssistantMessageEventStream;
/**
 * A strongly typed provider as supplied to `registerApiProvider`.
 *
 * @typeParam TApi - The specific API identifier this provider serves.
 * @typeParam TOptions - The options type accepted by `stream`.
 */
export interface ApiProvider<TApi extends Api = Api, TOptions extends StreamOptions = StreamOptions> {
/** API identifier; `registerApiProvider` uses it as the registry key. */
api: TApi;
/** Streaming entry point taking the provider-specific options type. */
stream: StreamFunction<TApi, TOptions>;
/** Streaming entry point taking `SimpleStreamOptions`. */
streamSimple: StreamFunction<TApi, SimpleStreamOptions>;
}
/**
 * Type-erased provider record: the form stored in the registry and returned by
 * `getApiProvider` / `getApiProviders`. Its functions accept a model for any
 * `Api` rather than one specific API.
 */
interface ApiProviderInternal {
/** API identifier this provider was registered under. */
api: Api;
/** Wrapped `stream` function (see `wrapStream`). */
stream: ApiStreamFunction;
/** Wrapped `streamSimple` function (see `wrapStreamSimple`). */
streamSimple: ApiStreamSimpleFunction;
}
/** A single registry entry: the provider plus an optional tag naming who registered it. */
type RegisteredApiProvider = {
provider: ApiProviderInternal;
// Optional tag passed to `registerApiProvider`; `unregisterApiProviders`
// removes the entries whose tag equals its argument.
sourceId?: string;
};
/**
 * Module-level registry of providers, keyed by the provider's `api` value.
 * Because it is a `Map`, registering a second provider under the same key
 * replaces the earlier entry.
 */
const apiProviderRegistry: Map<string, RegisteredApiProvider> = new Map();
/**
 * Adapts a provider-specific `stream` function to the type-erased
 * `ApiStreamFunction` signature.
 *
 * The returned function checks at call time that the incoming model targets
 * the same API the provider was created for, then forwards the call unchanged.
 *
 * @param api - API identifier the wrapped function is allowed to serve.
 * @param stream - The provider's strongly typed stream function.
 * @returns A function accepting any model, which delegates to `stream`.
 * @throws Error (from the returned function) when `model.api` differs from `api`.
 */
function wrapStream<TApi extends Api, TOptions extends StreamOptions>(
	api: TApi,
	stream: StreamFunction<TApi, TOptions>,
): ApiStreamFunction {
	return (requestedModel, ctx, streamOptions) => {
		if (requestedModel.api === api) {
			// The runtime check above is what justifies narrowing the erased types here.
			return stream(requestedModel as Model<TApi>, ctx, streamOptions as TOptions);
		}
		throw new Error(`Mismatched api: ${requestedModel.api} expected ${api}`);
	};
}
/**
 * Adapts a provider-specific `streamSimple` function to the type-erased
 * `ApiStreamSimpleFunction` signature.
 *
 * The returned function verifies that the incoming model's `api` equals the
 * one given here before delegating; options are passed through untouched.
 *
 * @param api - API identifier the wrapped function is allowed to serve.
 * @param streamSimple - The provider's strongly typed simple stream function.
 * @returns A function accepting any model, which delegates to `streamSimple`.
 * @throws Error (from the returned function) when `model.api` differs from `api`.
 */
function wrapStreamSimple<TApi extends Api>(
	api: TApi,
	streamSimple: StreamFunction<TApi, SimpleStreamOptions>,
): ApiStreamSimpleFunction {
	return (requestedModel, ctx, simpleOptions) => {
		if (requestedModel.api === api) {
			// Safe to narrow: the model was just confirmed to target this API.
			return streamSimple(requestedModel as Model<TApi>, ctx, simpleOptions);
		}
		throw new Error(`Mismatched api: ${requestedModel.api} expected ${api}`);
	};
}
/**
 * Registers a provider under its `api` identifier.
 *
 * The provider's `stream` and `streamSimple` functions are wrapped so they can
 * be called with any model and will throw if the model's `api` does not match.
 * An existing entry for the same `api` is replaced.
 *
 * @param provider - The strongly typed provider to register.
 * @param sourceId - Optional tag recorded with the entry;
 *   `unregisterApiProviders` removes the entries carrying a given tag.
 */
export function registerApiProvider<TApi extends Api, TOptions extends StreamOptions>(
	provider: ApiProvider<TApi, TOptions>,
	sourceId?: string,
): void {
	const { api, stream, streamSimple } = provider;
	const erasedProvider: ApiProviderInternal = {
		api,
		stream: wrapStream(api, stream),
		streamSimple: wrapStreamSimple(api, streamSimple),
	};
	apiProviderRegistry.set(api, { provider: erasedProvider, sourceId });
}
/**
 * Looks up the provider registered for an API identifier.
 *
 * @param api - The API identifier to look up.
 * @returns The registered provider, or `undefined` when none is registered
 *   under `api`.
 */
export function getApiProvider(api: Api): ApiProviderInternal | undefined {
	const entry = apiProviderRegistry.get(api);
	if (!entry) {
		return undefined;
	}
	return entry.provider;
}
/**
 * Lists every registered provider.
 *
 * @returns A newly allocated array of providers in the registry's iteration
 *   order; mutating the array does not affect the registry.
 */
export function getApiProviders(): ApiProviderInternal[] {
	const providers: ApiProviderInternal[] = [];
	for (const entry of apiProviderRegistry.values()) {
		providers.push(entry.provider);
	}
	return providers;
}
/**
 * Removes every provider whose entry was registered with the given `sourceId`.
 * Entries registered without a `sourceId`, or with a different one, are kept.
 *
 * @param sourceId - Tag identifying the registrations to remove.
 */
export function unregisterApiProviders(sourceId: string): void {
	// Collect the matching keys first, then delete them in a separate pass.
	const matchingApis = [...apiProviderRegistry]
		.filter(([, entry]) => entry.sourceId === sourceId)
		.map(([api]) => api);
	for (const api of matchingApis) {
		apiProviderRegistry.delete(api);
	}
}
/**
 * Empties the registry, removing every registered provider regardless of the
 * `sourceId` it was registered with.
 */
export function clearApiProviders(): void {
apiProviderRegistry.clear();
}