mirror of
https://github.com/harivansh-afk/clanker-agent.git
synced 2026-04-20 09:01:52 +00:00
move pi-mono into companion-cloud as apps/companion-os
- Copy all pi-mono source into apps/companion-os/ - Update Dockerfile to COPY pre-built binary instead of downloading from GitHub Releases - Update deploy-staging.yml to build pi from source (bun compile) before Docker build - Add apps/companion-os/** to path triggers - No more cross-repo dispatch needed Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
commit
0250f72976
579 changed files with 206942 additions and 0 deletions
101
packages/ai/src/api-registry.ts
Normal file
101
packages/ai/src/api-registry.ts
Normal file
|
|
@ -0,0 +1,101 @@
|
|||
import type {
|
||||
Api,
|
||||
AssistantMessageEventStream,
|
||||
Context,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
StreamFunction,
|
||||
StreamOptions,
|
||||
} from "./types.js";
|
||||
|
||||
/** Registry-facing stream signature: dispatches any `Model<Api>`. */
export type ApiStreamFunction = (
  model: Model<Api>,
  context: Context,
  options?: StreamOptions,
) => AssistantMessageEventStream;

/** Like `ApiStreamFunction`, but taking the simplified option set. */
export type ApiStreamSimpleFunction = (
  model: Model<Api>,
  context: Context,
  options?: SimpleStreamOptions,
) => AssistantMessageEventStream;

/** Shape a provider implements to plug its stream functions into the registry. */
export interface ApiProvider<
  TApi extends Api = Api,
  TOptions extends StreamOptions = StreamOptions,
> {
  api: TApi;
  stream: StreamFunction<TApi, TOptions>;
  streamSimple: StreamFunction<TApi, SimpleStreamOptions>;
}

/** Type-erased view of a provider as stored in the registry. */
interface ApiProviderInternal {
  api: Api;
  stream: ApiStreamFunction;
  streamSimple: ApiStreamSimpleFunction;
}

/** Registry entry: the provider plus the optional source that registered it. */
type RegisteredApiProvider = {
  provider: ApiProviderInternal;
  sourceId?: string;
};

// Keyed by `Api` identifier; Map.set semantics mean a later registration for
// the same api replaces the earlier one.
const apiProviderRegistry = new Map<string, RegisteredApiProvider>();
|
||||
|
||||
function wrapStream<TApi extends Api, TOptions extends StreamOptions>(
|
||||
api: TApi,
|
||||
stream: StreamFunction<TApi, TOptions>,
|
||||
): ApiStreamFunction {
|
||||
return (model, context, options) => {
|
||||
if (model.api !== api) {
|
||||
throw new Error(`Mismatched api: ${model.api} expected ${api}`);
|
||||
}
|
||||
return stream(model as Model<TApi>, context, options as TOptions);
|
||||
};
|
||||
}
|
||||
|
||||
function wrapStreamSimple<TApi extends Api>(
|
||||
api: TApi,
|
||||
streamSimple: StreamFunction<TApi, SimpleStreamOptions>,
|
||||
): ApiStreamSimpleFunction {
|
||||
return (model, context, options) => {
|
||||
if (model.api !== api) {
|
||||
throw new Error(`Mismatched api: ${model.api} expected ${api}`);
|
||||
}
|
||||
return streamSimple(model as Model<TApi>, context, options);
|
||||
};
|
||||
}
|
||||
|
||||
export function registerApiProvider<
|
||||
TApi extends Api,
|
||||
TOptions extends StreamOptions,
|
||||
>(provider: ApiProvider<TApi, TOptions>, sourceId?: string): void {
|
||||
apiProviderRegistry.set(provider.api, {
|
||||
provider: {
|
||||
api: provider.api,
|
||||
stream: wrapStream(provider.api, provider.stream),
|
||||
streamSimple: wrapStreamSimple(provider.api, provider.streamSimple),
|
||||
},
|
||||
sourceId,
|
||||
});
|
||||
}
|
||||
|
||||
export function getApiProvider(api: Api): ApiProviderInternal | undefined {
|
||||
return apiProviderRegistry.get(api)?.provider;
|
||||
}
|
||||
|
||||
export function getApiProviders(): ApiProviderInternal[] {
|
||||
return Array.from(apiProviderRegistry.values(), (entry) => entry.provider);
|
||||
}
|
||||
|
||||
export function unregisterApiProviders(sourceId: string): void {
|
||||
for (const [api, entry] of apiProviderRegistry.entries()) {
|
||||
if (entry.sourceId === sourceId) {
|
||||
apiProviderRegistry.delete(api);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** Drop every registered provider, regardless of source. */
export function clearApiProviders(): void {
  apiProviderRegistry.clear();
}
|
||||
9
packages/ai/src/bedrock-provider.ts
Normal file
9
packages/ai/src/bedrock-provider.ts
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
import {
|
||||
streamBedrock,
|
||||
streamSimpleBedrock,
|
||||
} from "./providers/amazon-bedrock.js";
|
||||
|
||||
// Bundles the Bedrock stream entry points behind one object — presumably so
// consumers can load the AWS SDK dependency chain on demand; confirm against
// the import sites.
export const bedrockProviderModule = {
  streamBedrock,
  streamSimpleBedrock,
};
|
||||
152
packages/ai/src/cli.ts
Normal file
152
packages/ai/src/cli.ts
Normal file
|
|
@ -0,0 +1,152 @@
|
|||
#!/usr/bin/env node
|
||||
|
||||
import { existsSync, readFileSync, writeFileSync } from "fs";
|
||||
import { createInterface } from "readline";
|
||||
import { getOAuthProvider, getOAuthProviders } from "./utils/oauth/index.js";
|
||||
import type { OAuthCredentials, OAuthProviderId } from "./utils/oauth/types.js";
|
||||
|
||||
// Relative path: credentials are written next to the process working directory.
const AUTH_FILE = "auth.json";
// Snapshot of all known OAuth providers, taken once at startup.
const PROVIDERS = getOAuthProviders();
|
||||
|
||||
function prompt(
|
||||
rl: ReturnType<typeof createInterface>,
|
||||
question: string,
|
||||
): Promise<string> {
|
||||
return new Promise((resolve) => rl.question(question, resolve));
|
||||
}
|
||||
|
||||
function loadAuth(): Record<string, { type: "oauth" } & OAuthCredentials> {
|
||||
if (!existsSync(AUTH_FILE)) return {};
|
||||
try {
|
||||
return JSON.parse(readFileSync(AUTH_FILE, "utf-8"));
|
||||
} catch {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
function saveAuth(
|
||||
auth: Record<string, { type: "oauth" } & OAuthCredentials>,
|
||||
): void {
|
||||
writeFileSync(AUTH_FILE, JSON.stringify(auth, null, 2), "utf-8");
|
||||
}
|
||||
|
||||
/**
 * Run the interactive OAuth login flow for `providerId` and persist the
 * resulting credentials into AUTH_FILE.
 *
 * Exits the process with code 1 when the provider id is unknown. The
 * readline interface is always closed, even when the provider's login throws.
 */
async function login(providerId: OAuthProviderId): Promise<void> {
  const provider = getOAuthProvider(providerId);
  if (!provider) {
    console.error(`Unknown provider: ${providerId}`);
    process.exit(1);
  }

  const rl = createInterface({ input: process.stdin, output: process.stdout });
  // Trailing space visually separates the prompt text from the user's input.
  const promptFn = (msg: string) => prompt(rl, `${msg} `);

  try {
    const credentials = await provider.login({
      // Invoked when the provider needs the user to visit an auth URL.
      onAuth: (info) => {
        console.log(`\nOpen this URL in your browser:\n${info.url}`);
        if (info.instructions) console.log(info.instructions);
        console.log();
      },
      // Invoked when the provider needs a value typed in (e.g. a code).
      onPrompt: async (p) => {
        return await promptFn(
          `${p.message}${p.placeholder ? ` (${p.placeholder})` : ""}:`,
        );
      },
      onProgress: (msg) => console.log(msg),
    });

    // Merge into any previously saved credentials rather than overwriting
    // the whole file.
    const auth = loadAuth();
    auth[providerId] = { type: "oauth", ...credentials };
    saveAuth(auth);

    console.log(`\nCredentials saved to ${AUTH_FILE}`);
  } finally {
    rl.close();
  }
}
|
||||
|
||||
async function main(): Promise<void> {
|
||||
const args = process.argv.slice(2);
|
||||
const command = args[0];
|
||||
|
||||
if (
|
||||
!command ||
|
||||
command === "help" ||
|
||||
command === "--help" ||
|
||||
command === "-h"
|
||||
) {
|
||||
const providerList = PROVIDERS.map(
|
||||
(p) => ` ${p.id.padEnd(20)} ${p.name}`,
|
||||
).join("\n");
|
||||
console.log(`Usage: npx @mariozechner/pi-ai <command> [provider]
|
||||
|
||||
Commands:
|
||||
login [provider] Login to an OAuth provider
|
||||
list List available providers
|
||||
|
||||
Providers:
|
||||
${providerList}
|
||||
|
||||
Examples:
|
||||
npx @mariozechner/pi-ai login # interactive provider selection
|
||||
npx @mariozechner/pi-ai login anthropic # login to specific provider
|
||||
npx @mariozechner/pi-ai list # list providers
|
||||
`);
|
||||
return;
|
||||
}
|
||||
|
||||
if (command === "list") {
|
||||
console.log("Available OAuth providers:\n");
|
||||
for (const p of PROVIDERS) {
|
||||
console.log(` ${p.id.padEnd(20)} ${p.name}`);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (command === "login") {
|
||||
let provider = args[1] as OAuthProviderId | undefined;
|
||||
|
||||
if (!provider) {
|
||||
const rl = createInterface({
|
||||
input: process.stdin,
|
||||
output: process.stdout,
|
||||
});
|
||||
console.log("Select a provider:\n");
|
||||
for (let i = 0; i < PROVIDERS.length; i++) {
|
||||
console.log(` ${i + 1}. ${PROVIDERS[i].name}`);
|
||||
}
|
||||
console.log();
|
||||
|
||||
const choice = await prompt(rl, `Enter number (1-${PROVIDERS.length}): `);
|
||||
rl.close();
|
||||
|
||||
const index = parseInt(choice, 10) - 1;
|
||||
if (index < 0 || index >= PROVIDERS.length) {
|
||||
console.error("Invalid selection");
|
||||
process.exit(1);
|
||||
}
|
||||
provider = PROVIDERS[index].id;
|
||||
}
|
||||
|
||||
if (!PROVIDERS.some((p) => p.id === provider)) {
|
||||
console.error(`Unknown provider: ${provider}`);
|
||||
console.error(
|
||||
`Use 'npx @mariozechner/pi-ai list' to see available providers`,
|
||||
);
|
||||
process.exit(1);
|
||||
}
|
||||
|
||||
console.log(`Logging in to ${provider}...`);
|
||||
await login(provider);
|
||||
return;
|
||||
}
|
||||
|
||||
console.error(`Unknown command: ${command}`);
|
||||
console.error(`Use 'npx @mariozechner/pi-ai --help' for usage`);
|
||||
process.exit(1);
|
||||
}
|
||||
|
||||
// Top-level entry: surface any unhandled error from main() and exit non-zero.
main().catch((err) => {
  console.error("Error:", err.message);
  process.exit(1);
});
|
||||
145
packages/ai/src/env-api-keys.ts
Normal file
145
packages/ai/src/env-api-keys.ts
Normal file
|
|
@ -0,0 +1,145 @@
|
|||
// NEVER convert to top-level imports - breaks browser/Vite builds (web-ui)
// Node builtins are captured into nullable slots once the async imports
// below settle; browser builds simply leave them null.
let _existsSync: typeof import("node:fs").existsSync | null = null;
let _homedir: typeof import("node:os").homedir | null = null;
let _join: typeof import("node:path").join | null = null;

type DynamicImport = (specifier: string) => Promise<unknown>;

const dynamicImport: DynamicImport = (specifier) => import(specifier);
// Specifiers are assembled by concatenation — presumably to keep bundlers
// from statically resolving these node builtins; confirm against the
// web-ui build setup.
const NODE_FS_SPECIFIER = "node:" + "fs";
const NODE_OS_SPECIFIER = "node:" + "os";
const NODE_PATH_SPECIFIER = "node:" + "path";

// Eagerly load in Node.js/Bun environment only
if (
  typeof process !== "undefined" &&
  (process.versions?.node || process.versions?.bun)
) {
  // Fire-and-forget: code that runs before these promises settle must
  // tolerate the null slots (see hasVertexAdcCredentials' retry logic).
  dynamicImport(NODE_FS_SPECIFIER).then((m) => {
    _existsSync = (m as typeof import("node:fs")).existsSync;
  });
  dynamicImport(NODE_OS_SPECIFIER).then((m) => {
    _homedir = (m as typeof import("node:os")).homedir;
  });
  dynamicImport(NODE_PATH_SPECIFIER).then((m) => {
    _join = (m as typeof import("node:path")).join;
  });
}
|
||||
|
||||
import type { KnownProvider } from "./types.js";
|
||||
|
||||
// Caches the ADC existence check; null = not yet determined.
let cachedVertexAdcCredentialsExists: boolean | null = null;

/**
 * Whether Google Application Default Credentials appear to be available.
 *
 * Checks the GOOGLE_APPLICATION_CREDENTIALS path first, then the default
 * gcloud ADC file under ~/.config/gcloud. The boolean result is cached;
 * see the inline comments for the startup race with the async fs/os/path
 * imports above.
 */
function hasVertexAdcCredentials(): boolean {
  if (cachedVertexAdcCredentialsExists === null) {
    // If node modules haven't loaded yet (async import race at startup),
    // return false WITHOUT caching so the next call retries once they're ready.
    // Only cache false permanently in a browser environment where fs is never available.
    if (!_existsSync || !_homedir || !_join) {
      const isNode =
        typeof process !== "undefined" &&
        (process.versions?.node || process.versions?.bun);
      if (!isNode) {
        // Definitively in a browser — safe to cache false permanently
        cachedVertexAdcCredentialsExists = false;
      }
      return false;
    }

    // Check GOOGLE_APPLICATION_CREDENTIALS env var first (standard way)
    const gacPath = process.env.GOOGLE_APPLICATION_CREDENTIALS;
    if (gacPath) {
      cachedVertexAdcCredentialsExists = _existsSync(gacPath);
    } else {
      // Fall back to default ADC path (lazy evaluation)
      cachedVertexAdcCredentialsExists = _existsSync(
        _join(
          _homedir(),
          ".config",
          "gcloud",
          "application_default_credentials.json",
        ),
      );
    }
  }
  return cachedVertexAdcCredentialsExists;
}
|
||||
|
||||
/**
|
||||
* Get API key for provider from known environment variables, e.g. OPENAI_API_KEY.
|
||||
*
|
||||
* Will not return API keys for providers that require OAuth tokens.
|
||||
*/
|
||||
export function getEnvApiKey(provider: KnownProvider): string | undefined;
|
||||
export function getEnvApiKey(provider: string): string | undefined;
|
||||
export function getEnvApiKey(provider: any): string | undefined {
|
||||
// Fall back to environment variables
|
||||
if (provider === "github-copilot") {
|
||||
return (
|
||||
process.env.COPILOT_GITHUB_TOKEN ||
|
||||
process.env.GH_TOKEN ||
|
||||
process.env.GITHUB_TOKEN
|
||||
);
|
||||
}
|
||||
|
||||
// ANTHROPIC_OAUTH_TOKEN takes precedence over ANTHROPIC_API_KEY
|
||||
if (provider === "anthropic") {
|
||||
return process.env.ANTHROPIC_OAUTH_TOKEN || process.env.ANTHROPIC_API_KEY;
|
||||
}
|
||||
|
||||
// Vertex AI uses Application Default Credentials, not API keys.
|
||||
// Auth is configured via `gcloud auth application-default login`.
|
||||
if (provider === "google-vertex") {
|
||||
const hasCredentials = hasVertexAdcCredentials();
|
||||
const hasProject = !!(
|
||||
process.env.GOOGLE_CLOUD_PROJECT || process.env.GCLOUD_PROJECT
|
||||
);
|
||||
const hasLocation = !!process.env.GOOGLE_CLOUD_LOCATION;
|
||||
|
||||
if (hasCredentials && hasProject && hasLocation) {
|
||||
return "<authenticated>";
|
||||
}
|
||||
}
|
||||
|
||||
if (provider === "amazon-bedrock") {
|
||||
// Amazon Bedrock supports multiple credential sources:
|
||||
// 1. AWS_PROFILE - named profile from ~/.aws/credentials
|
||||
// 2. AWS_ACCESS_KEY_ID + AWS_SECRET_ACCESS_KEY - standard IAM keys
|
||||
// 3. AWS_BEARER_TOKEN_BEDROCK - Bedrock API keys (bearer token)
|
||||
// 4. AWS_CONTAINER_CREDENTIALS_RELATIVE_URI - ECS task roles
|
||||
// 5. AWS_CONTAINER_CREDENTIALS_FULL_URI - ECS task roles (full URI)
|
||||
// 6. AWS_WEB_IDENTITY_TOKEN_FILE - IRSA (IAM Roles for Service Accounts)
|
||||
if (
|
||||
process.env.AWS_PROFILE ||
|
||||
(process.env.AWS_ACCESS_KEY_ID && process.env.AWS_SECRET_ACCESS_KEY) ||
|
||||
process.env.AWS_BEARER_TOKEN_BEDROCK ||
|
||||
process.env.AWS_CONTAINER_CREDENTIALS_RELATIVE_URI ||
|
||||
process.env.AWS_CONTAINER_CREDENTIALS_FULL_URI ||
|
||||
process.env.AWS_WEB_IDENTITY_TOKEN_FILE
|
||||
) {
|
||||
return "<authenticated>";
|
||||
}
|
||||
}
|
||||
|
||||
const envMap: Record<string, string> = {
|
||||
openai: "OPENAI_API_KEY",
|
||||
"azure-openai-responses": "AZURE_OPENAI_API_KEY",
|
||||
google: "GEMINI_API_KEY",
|
||||
groq: "GROQ_API_KEY",
|
||||
cerebras: "CEREBRAS_API_KEY",
|
||||
xai: "XAI_API_KEY",
|
||||
openrouter: "OPENROUTER_API_KEY",
|
||||
"vercel-ai-gateway": "AI_GATEWAY_API_KEY",
|
||||
zai: "ZAI_API_KEY",
|
||||
mistral: "MISTRAL_API_KEY",
|
||||
minimax: "MINIMAX_API_KEY",
|
||||
"minimax-cn": "MINIMAX_CN_API_KEY",
|
||||
huggingface: "HF_TOKEN",
|
||||
opencode: "OPENCODE_API_KEY",
|
||||
"opencode-go": "OPENCODE_API_KEY",
|
||||
"kimi-coding": "KIMI_API_KEY",
|
||||
};
|
||||
|
||||
const envVar = envMap[provider];
|
||||
return envVar ? process.env[envVar] : undefined;
|
||||
}
|
||||
32
packages/ai/src/index.ts
Normal file
32
packages/ai/src/index.ts
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
export type { Static, TSchema } from "@sinclair/typebox";
|
||||
export { Type } from "@sinclair/typebox";
|
||||
|
||||
export * from "./api-registry.js";
|
||||
export * from "./env-api-keys.js";
|
||||
export * from "./models.js";
|
||||
export * from "./providers/anthropic.js";
|
||||
export * from "./providers/azure-openai-responses.js";
|
||||
export * from "./providers/google.js";
|
||||
export * from "./providers/google-gemini-cli.js";
|
||||
export * from "./providers/google-vertex.js";
|
||||
export * from "./providers/mistral.js";
|
||||
export * from "./providers/openai-completions.js";
|
||||
export * from "./providers/openai-responses.js";
|
||||
export * from "./providers/register-builtins.js";
|
||||
export * from "./stream.js";
|
||||
export * from "./types.js";
|
||||
export * from "./utils/event-stream.js";
|
||||
export * from "./utils/json-parse.js";
|
||||
export type {
|
||||
OAuthAuthInfo,
|
||||
OAuthCredentials,
|
||||
OAuthLoginCallbacks,
|
||||
OAuthPrompt,
|
||||
OAuthProvider,
|
||||
OAuthProviderId,
|
||||
OAuthProviderInfo,
|
||||
OAuthProviderInterface,
|
||||
} from "./utils/oauth/types.js";
|
||||
export * from "./utils/overflow.js";
|
||||
export * from "./utils/typebox-helpers.js";
|
||||
export * from "./utils/validation.js";
|
||||
13496
packages/ai/src/models.generated.ts
Normal file
13496
packages/ai/src/models.generated.ts
Normal file
File diff suppressed because it is too large
Load diff
101
packages/ai/src/models.ts
Normal file
101
packages/ai/src/models.ts
Normal file
|
|
@ -0,0 +1,101 @@
|
|||
import { MODELS } from "./models.generated.js";
|
||||
import type { Api, KnownProvider, Model, Usage } from "./types.js";
|
||||
|
||||
// provider id -> (model id -> model definition), seeded from the generated
// MODELS table and queried by getModel/getModels/getProviders below.
const modelRegistry: Map<string, Map<string, Model<Api>>> = new Map();

// Initialize registry from MODELS on module load
for (const [provider, models] of Object.entries(MODELS)) {
  const providerModels = new Map<string, Model<Api>>();
  for (const [id, model] of Object.entries(models)) {
    providerModels.set(id, model as Model<Api>);
  }
  modelRegistry.set(provider, providerModels);
}
|
||||
|
||||
/**
 * Resolve the `api` literal type of a known provider/model pair from the
 * generated MODELS table, or `never` when the entry has no usable `api`
 * field.
 */
type ModelApi<
  TProvider extends KnownProvider,
  TModelId extends keyof (typeof MODELS)[TProvider],
> = (typeof MODELS)[TProvider][TModelId] extends { api: infer TApi }
  ? TApi extends Api
    ? TApi
    : never
  : never;
|
||||
|
||||
/**
 * Look up a model by provider and model id, typed via the generated MODELS
 * table.
 *
 * NOTE(review): the return type is non-optional, but the cast hides that an
 * id missing from the registry yields `undefined` at runtime — confirm
 * callers only pass ids that exist in MODELS.
 */
export function getModel<
  TProvider extends KnownProvider,
  TModelId extends keyof (typeof MODELS)[TProvider],
>(
  provider: TProvider,
  modelId: TModelId,
): Model<ModelApi<TProvider, TModelId>> {
  const providerModels = modelRegistry.get(provider);
  return providerModels?.get(modelId as string) as Model<
    ModelApi<TProvider, TModelId>
  >;
}
|
||||
|
||||
export function getProviders(): KnownProvider[] {
|
||||
return Array.from(modelRegistry.keys()) as KnownProvider[];
|
||||
}
|
||||
|
||||
/**
 * All models registered for `provider`, or an empty array when the provider
 * is unknown. The element type is the union of the provider's model APIs.
 */
export function getModels<TProvider extends KnownProvider>(
  provider: TProvider,
): Model<ModelApi<TProvider, keyof (typeof MODELS)[TProvider]>>[] {
  const models = modelRegistry.get(provider);
  return models
    ? (Array.from(models.values()) as Model<
        ModelApi<TProvider, keyof (typeof MODELS)[TProvider]>
      >[])
    : [];
}
|
||||
|
||||
export function calculateCost<TApi extends Api>(
|
||||
model: Model<TApi>,
|
||||
usage: Usage,
|
||||
): Usage["cost"] {
|
||||
usage.cost.input = (model.cost.input / 1000000) * usage.input;
|
||||
usage.cost.output = (model.cost.output / 1000000) * usage.output;
|
||||
usage.cost.cacheRead = (model.cost.cacheRead / 1000000) * usage.cacheRead;
|
||||
usage.cost.cacheWrite = (model.cost.cacheWrite / 1000000) * usage.cacheWrite;
|
||||
usage.cost.total =
|
||||
usage.cost.input +
|
||||
usage.cost.output +
|
||||
usage.cost.cacheRead +
|
||||
usage.cost.cacheWrite;
|
||||
return usage.cost;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a model supports xhigh thinking level.
|
||||
*
|
||||
* Supported today:
|
||||
* - GPT-5.2 / GPT-5.3 / GPT-5.4 model families
|
||||
* - Anthropic Messages API Opus 4.6 models (xhigh maps to adaptive effort "max")
|
||||
*/
|
||||
export function supportsXhigh<TApi extends Api>(model: Model<TApi>): boolean {
|
||||
if (
|
||||
model.id.includes("gpt-5.2") ||
|
||||
model.id.includes("gpt-5.3") ||
|
||||
model.id.includes("gpt-5.4")
|
||||
) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (model.api === "anthropic-messages") {
|
||||
return model.id.includes("opus-4-6") || model.id.includes("opus-4.6");
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if two models are equal by comparing both their id and provider.
|
||||
* Returns false if either model is null or undefined.
|
||||
*/
|
||||
export function modelsAreEqual<TApi extends Api>(
|
||||
a: Model<TApi> | null | undefined,
|
||||
b: Model<TApi> | null | undefined,
|
||||
): boolean {
|
||||
if (!a || !b) return false;
|
||||
return a.id === b.id && a.provider === b.provider;
|
||||
}
|
||||
1
packages/ai/src/oauth.ts
Normal file
1
packages/ai/src/oauth.ts
Normal file
|
|
@ -0,0 +1 @@
|
|||
export * from "./utils/oauth/index.js";
|
||||
894
packages/ai/src/providers/amazon-bedrock.ts
Normal file
894
packages/ai/src/providers/amazon-bedrock.ts
Normal file
|
|
@ -0,0 +1,894 @@
|
|||
import {
|
||||
BedrockRuntimeClient,
|
||||
type BedrockRuntimeClientConfig,
|
||||
StopReason as BedrockStopReason,
|
||||
type Tool as BedrockTool,
|
||||
CachePointType,
|
||||
CacheTTL,
|
||||
type ContentBlock,
|
||||
type ContentBlockDeltaEvent,
|
||||
type ContentBlockStartEvent,
|
||||
type ContentBlockStopEvent,
|
||||
ConversationRole,
|
||||
ConverseStreamCommand,
|
||||
type ConverseStreamMetadataEvent,
|
||||
ImageFormat,
|
||||
type Message,
|
||||
type SystemContentBlock,
|
||||
type ToolChoice,
|
||||
type ToolConfiguration,
|
||||
ToolResultStatus,
|
||||
} from "@aws-sdk/client-bedrock-runtime";
|
||||
|
||||
import { calculateCost } from "../models.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
CacheRetention,
|
||||
Context,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
StopReason,
|
||||
StreamFunction,
|
||||
StreamOptions,
|
||||
TextContent,
|
||||
ThinkingBudgets,
|
||||
ThinkingContent,
|
||||
ThinkingLevel,
|
||||
Tool,
|
||||
ToolCall,
|
||||
ToolResultMessage,
|
||||
} from "../types.js";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import { parseStreamingJson } from "../utils/json-parse.js";
|
||||
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
||||
import {
|
||||
adjustMaxTokensForThinking,
|
||||
buildBaseOptions,
|
||||
clampReasoning,
|
||||
} from "./simple-options.js";
|
||||
import { transformMessages } from "./transform-messages.js";
|
||||
|
||||
/** Options for streaming via Amazon Bedrock's ConverseStream API. */
export interface BedrockOptions extends StreamOptions {
  /* AWS region; when omitted, falls back to env vars, profile config, or us-east-1 (see streamBedrock). */
  region?: string;
  /* Named AWS profile passed through to the Bedrock runtime client config. */
  profile?: string;
  /* Tool-choice strategy forwarded to the Converse API. */
  toolChoice?: "auto" | "any" | "none" | { type: "tool"; name: string };
  /* See https://docs.aws.amazon.com/bedrock/latest/userguide/inference-reasoning.html for supported models. */
  reasoning?: ThinkingLevel;
  /* Custom token budgets per thinking level. Overrides default budgets. */
  thinkingBudgets?: ThinkingBudgets;
  /* Only supported by Claude 4.x models, see https://docs.aws.amazon.com/bedrock/latest/userguide/claude-messages-extended-thinking.html#claude-messages-extended-thinking-tool-use-interleaved */
  interleavedThinking?: boolean;
}
|
||||
|
||||
// Content block under construction while streaming. `index` ties the block to
// Bedrock's contentBlockIndex; `partialJson` accumulates streamed tool-call
// argument text. Both bookkeeping fields are deleted on the error path in
// streamBedrock.
type Block = (TextContent | ThinkingContent | ToolCall) & {
  index?: number;
  partialJson?: string;
};
|
||||
|
||||
/**
 * Stream an assistant response from Amazon Bedrock's ConverseStream API.
 *
 * Returns immediately with an event stream; the request runs in a detached
 * async IIFE that pushes start/delta/done (or error) events while
 * accumulating the final AssistantMessage, including usage and cost.
 */
export const streamBedrock: StreamFunction<
  "bedrock-converse-stream",
  BedrockOptions
> = (
  model: Model<"bedrock-converse-stream">,
  context: Context,
  options: BedrockOptions = {},
): AssistantMessageEventStream => {
  const stream = new AssistantMessageEventStream();

  (async () => {
    // Message skeleton that is filled in as stream events arrive.
    const output: AssistantMessage = {
      role: "assistant",
      content: [],
      api: "bedrock-converse-stream" as Api,
      provider: model.provider,
      model: model.id,
      usage: {
        input: 0,
        output: 0,
        cacheRead: 0,
        cacheWrite: 0,
        totalTokens: 0,
        cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
      },
      stopReason: "stop",
      timestamp: Date.now(),
    };

    // Alias: the handlers mutate output.content through this view.
    const blocks = output.content as Block[];

    const config: BedrockRuntimeClientConfig = {
      profile: options.profile,
    };

    // in Node.js/Bun environment only
    if (
      typeof process !== "undefined" &&
      (process.versions?.node || process.versions?.bun)
    ) {
      // Region resolution: explicit option > env vars > SDK default chain.
      // When AWS_PROFILE is set, we leave region undefined so the SDK can
      // resolve it from aws profile configs. Otherwise fall back to us-east-1.
      const explicitRegion =
        options.region ||
        process.env.AWS_REGION ||
        process.env.AWS_DEFAULT_REGION;
      if (explicitRegion) {
        config.region = explicitRegion;
      } else if (!process.env.AWS_PROFILE) {
        config.region = "us-east-1";
      }

      // Support proxies that don't need authentication
      if (process.env.AWS_BEDROCK_SKIP_AUTH === "1") {
        config.credentials = {
          accessKeyId: "dummy-access-key",
          secretAccessKey: "dummy-secret-key",
        };
      }

      if (
        process.env.HTTP_PROXY ||
        process.env.HTTPS_PROXY ||
        process.env.NO_PROXY ||
        process.env.http_proxy ||
        process.env.https_proxy ||
        process.env.no_proxy
      ) {
        const nodeHttpHandler = await import("@smithy/node-http-handler");
        const proxyAgent = await import("proxy-agent");

        const agent = new proxyAgent.ProxyAgent();

        // Bedrock runtime uses NodeHttp2Handler by default since v3.798.0, which is based
        // on `http2` module and has no support for http agent.
        // Use NodeHttpHandler to support http agent.
        config.requestHandler = new nodeHttpHandler.NodeHttpHandler({
          httpAgent: agent,
          httpsAgent: agent,
        });
      } else if (process.env.AWS_BEDROCK_FORCE_HTTP1 === "1") {
        // Some custom endpoints require HTTP/1.1 instead of HTTP/2
        const nodeHttpHandler = await import("@smithy/node-http-handler");
        config.requestHandler = new nodeHttpHandler.NodeHttpHandler();
      }
    } else {
      // Non-Node environment (browser): fall back to us-east-1 since
      // there's no config file resolution available.
      config.region = options.region || "us-east-1";
    }

    try {
      const client = new BedrockRuntimeClient(config);

      const cacheRetention = resolveCacheRetention(options.cacheRetention);
      const commandInput = {
        modelId: model.id,
        messages: convertMessages(context, model, cacheRetention),
        system: buildSystemPrompt(context.systemPrompt, model, cacheRetention),
        inferenceConfig: {
          maxTokens: options.maxTokens,
          temperature: options.temperature,
        },
        toolConfig: convertToolConfig(context.tools, options.toolChoice),
        additionalModelRequestFields: buildAdditionalModelRequestFields(
          model,
          options,
        ),
      };
      // Let callers observe the exact request payload before it is sent.
      options?.onPayload?.(commandInput);
      const command = new ConverseStreamCommand(commandInput);

      const response = await client.send(command, {
        abortSignal: options.signal,
      });

      // Dispatch each union member of the ConverseStream event stream.
      for await (const item of response.stream!) {
        if (item.messageStart) {
          if (item.messageStart.role !== ConversationRole.ASSISTANT) {
            throw new Error(
              "Unexpected assistant message start but got user message start instead",
            );
          }
          stream.push({ type: "start", partial: output });
        } else if (item.contentBlockStart) {
          handleContentBlockStart(
            item.contentBlockStart,
            blocks,
            output,
            stream,
          );
        } else if (item.contentBlockDelta) {
          handleContentBlockDelta(
            item.contentBlockDelta,
            blocks,
            output,
            stream,
          );
        } else if (item.contentBlockStop) {
          handleContentBlockStop(item.contentBlockStop, blocks, output, stream);
        } else if (item.messageStop) {
          output.stopReason = mapStopReason(item.messageStop.stopReason);
        } else if (item.metadata) {
          handleMetadata(item.metadata, model, output);
        } else if (item.internalServerException) {
          throw new Error(
            `Internal server error: ${item.internalServerException.message}`,
          );
        } else if (item.modelStreamErrorException) {
          throw new Error(
            `Model stream error: ${item.modelStreamErrorException.message}`,
          );
        } else if (item.validationException) {
          throw new Error(
            `Validation error: ${item.validationException.message}`,
          );
        } else if (item.throttlingException) {
          throw new Error(
            `Throttling error: ${item.throttlingException.message}`,
          );
        } else if (item.serviceUnavailableException) {
          throw new Error(
            `Service unavailable: ${item.serviceUnavailableException.message}`,
          );
        }
      }

      // An abort can end the stream loop without an exception — normalize it.
      if (options.signal?.aborted) {
        throw new Error("Request was aborted");
      }

      if (output.stopReason === "error" || output.stopReason === "aborted") {
        throw new Error("An unknown error occurred");
      }

      stream.push({ type: "done", reason: output.stopReason, message: output });
      stream.end();
    } catch (error) {
      // Strip streaming-only bookkeeping fields before exposing the message.
      for (const block of output.content) {
        delete (block as Block).index;
        delete (block as Block).partialJson;
      }
      output.stopReason = options.signal?.aborted ? "aborted" : "error";
      output.errorMessage =
        error instanceof Error ? error.message : JSON.stringify(error);
      stream.push({ type: "error", reason: output.stopReason, error: output });
      stream.end();
    }
  })();

  return stream;
};
|
||||
|
||||
/**
 * Simplified-options wrapper around streamBedrock.
 *
 * Translates SimpleStreamOptions into BedrockOptions: passthrough when no
 * reasoning is requested, a thinking-budget adjustment for non-adaptive
 * Claude models, and plain forwarding otherwise.
 */
export const streamSimpleBedrock: StreamFunction<
  "bedrock-converse-stream",
  SimpleStreamOptions
> = (
  model: Model<"bedrock-converse-stream">,
  context: Context,
  options?: SimpleStreamOptions,
): AssistantMessageEventStream => {
  const base = buildBaseOptions(model, options, undefined);
  // No reasoning requested: forward with reasoning explicitly disabled.
  if (!options?.reasoning) {
    return streamBedrock(model, context, {
      ...base,
      reasoning: undefined,
    } satisfies BedrockOptions);
  }

  // Claude model ids on Bedrock appear as "anthropic.claude..." or
  // "anthropic/claude...".
  if (
    model.id.includes("anthropic.claude") ||
    model.id.includes("anthropic/claude")
  ) {
    // supportsAdaptiveThinking is defined elsewhere in this file — assumed
    // to gate Claude models that manage their own thinking budget; confirm.
    if (supportsAdaptiveThinking(model.id)) {
      return streamBedrock(model, context, {
        ...base,
        reasoning: options.reasoning,
        thinkingBudgets: options.thinkingBudgets,
      } satisfies BedrockOptions);
    }

    // Non-adaptive Claude: carve an explicit thinking budget out of
    // maxTokens for the requested level.
    const adjusted = adjustMaxTokensForThinking(
      base.maxTokens || 0,
      model.maxTokens,
      options.reasoning,
      options.thinkingBudgets,
    );

    return streamBedrock(model, context, {
      ...base,
      maxTokens: adjusted.maxTokens,
      reasoning: options.reasoning,
      thinkingBudgets: {
        ...(options.thinkingBudgets || {}),
        [clampReasoning(options.reasoning)!]: adjusted.thinkingBudget,
      },
    } satisfies BedrockOptions);
  }

  // Non-Claude models: forward reasoning options unchanged.
  return streamBedrock(model, context, {
    ...base,
    reasoning: options.reasoning,
    thinkingBudgets: options.thinkingBudgets,
  } satisfies BedrockOptions);
};
|
||||
|
||||
/**
 * Handle a ConverseStream contentBlockStart event.
 *
 * Only toolUse starts create a block here; text and thinking blocks are
 * created lazily in handleContentBlockDelta, since no start event is sent
 * for them (see the comment there). Note `blocks` aliases `output.content`
 * (see streamBedrock), so pushing to output.content also grows `blocks`.
 */
function handleContentBlockStart(
  event: ContentBlockStartEvent,
  blocks: Block[],
  output: AssistantMessage,
  stream: AssistantMessageEventStream,
): void {
  const index = event.contentBlockIndex!;
  const start = event.start;

  if (start?.toolUse) {
    const block: Block = {
      type: "toolCall",
      id: start.toolUse.toolUseId || "",
      name: start.toolUse.name || "",
      arguments: {},
      partialJson: "", // accumulates streamed JSON argument text
      index, // Bedrock's contentBlockIndex, used to match later deltas
    };
    output.content.push(block);
    stream.push({
      type: "toolcall_start",
      contentIndex: blocks.length - 1,
      partial: output,
    });
  }
}
|
||||
|
||||
/**
 * Handle a Bedrock `contentBlockDelta` event, routing the delta to the
 * matching content block (text, tool-call JSON, or reasoning/thinking).
 *
 * `blocks` is expected to alias `output.content` (pushing to one must be
 * visible through the other) — the lazy-creation branches below rely on that.
 * Deltas for an index with no matching block type are silently dropped.
 */
function handleContentBlockDelta(
  event: ContentBlockDeltaEvent,
  blocks: Block[],
  output: AssistantMessage,
  stream: AssistantMessageEventStream,
): void {
  const contentBlockIndex = event.contentBlockIndex!;
  const delta = event.delta;
  // Locate the block previously created for this provider-side index.
  let index = blocks.findIndex((b) => b.index === contentBlockIndex);
  let block = blocks[index];

  if (delta?.text !== undefined) {
    // If no text block exists yet, create one, as `handleContentBlockStart` is not sent for text blocks
    if (!block) {
      const newBlock: Block = {
        type: "text",
        text: "",
        index: contentBlockIndex,
      };
      output.content.push(newBlock);
      // `blocks` aliases `output.content`, so the new block is at the tail.
      index = blocks.length - 1;
      block = blocks[index];
      stream.push({ type: "text_start", contentIndex: index, partial: output });
    }
    if (block.type === "text") {
      block.text += delta.text;
      stream.push({
        type: "text_delta",
        contentIndex: index,
        delta: delta.text,
        partial: output,
      });
    }
  } else if (delta?.toolUse && block?.type === "toolCall") {
    // Accumulate the raw JSON fragments and keep a best-effort parse of the
    // arguments available on every delta.
    block.partialJson = (block.partialJson || "") + (delta.toolUse.input || "");
    block.arguments = parseStreamingJson(block.partialJson);
    stream.push({
      type: "toolcall_delta",
      contentIndex: index,
      delta: delta.toolUse.input || "",
      partial: output,
    });
  } else if (delta?.reasoningContent) {
    let thinkingBlock = block;
    let thinkingIndex = index;

    // Like text, thinking blocks may arrive without a start event — create
    // one lazily on the first reasoning delta.
    if (!thinkingBlock) {
      const newBlock: Block = {
        type: "thinking",
        thinking: "",
        thinkingSignature: "",
        index: contentBlockIndex,
      };
      output.content.push(newBlock);
      thinkingIndex = blocks.length - 1;
      thinkingBlock = blocks[thinkingIndex];
      stream.push({
        type: "thinking_start",
        contentIndex: thinkingIndex,
        partial: output,
      });
    }

    if (thinkingBlock?.type === "thinking") {
      if (delta.reasoningContent.text) {
        thinkingBlock.thinking += delta.reasoningContent.text;
        stream.push({
          type: "thinking_delta",
          contentIndex: thinkingIndex,
          delta: delta.reasoningContent.text,
          partial: output,
        });
      }
      // Signatures accumulate silently — no dedicated stream event for them.
      if (delta.reasoningContent.signature) {
        thinkingBlock.thinkingSignature =
          (thinkingBlock.thinkingSignature || "") +
          delta.reasoningContent.signature;
      }
    }
  }
}
|
||||
|
||||
function handleMetadata(
|
||||
event: ConverseStreamMetadataEvent,
|
||||
model: Model<"bedrock-converse-stream">,
|
||||
output: AssistantMessage,
|
||||
): void {
|
||||
if (event.usage) {
|
||||
output.usage.input = event.usage.inputTokens || 0;
|
||||
output.usage.output = event.usage.outputTokens || 0;
|
||||
output.usage.cacheRead = event.usage.cacheReadInputTokens || 0;
|
||||
output.usage.cacheWrite = event.usage.cacheWriteInputTokens || 0;
|
||||
output.usage.totalTokens =
|
||||
event.usage.totalTokens || output.usage.input + output.usage.output;
|
||||
calculateCost(model, output.usage);
|
||||
}
|
||||
}
|
||||
|
||||
function handleContentBlockStop(
|
||||
event: ContentBlockStopEvent,
|
||||
blocks: Block[],
|
||||
output: AssistantMessage,
|
||||
stream: AssistantMessageEventStream,
|
||||
): void {
|
||||
const index = blocks.findIndex((b) => b.index === event.contentBlockIndex);
|
||||
const block = blocks[index];
|
||||
if (!block) return;
|
||||
delete (block as Block).index;
|
||||
|
||||
switch (block.type) {
|
||||
case "text":
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
contentIndex: index,
|
||||
content: block.text,
|
||||
partial: output,
|
||||
});
|
||||
break;
|
||||
case "thinking":
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
contentIndex: index,
|
||||
content: block.thinking,
|
||||
partial: output,
|
||||
});
|
||||
break;
|
||||
case "toolCall":
|
||||
block.arguments = parseStreamingJson(block.partialJson);
|
||||
delete (block as Block).partialJson;
|
||||
stream.push({
|
||||
type: "toolcall_end",
|
||||
contentIndex: index,
|
||||
toolCall: block,
|
||||
partial: output,
|
||||
});
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if the model supports adaptive thinking (Opus 4.6 and Sonnet 4.6).
|
||||
*/
|
||||
function supportsAdaptiveThinking(modelId: string): boolean {
|
||||
return (
|
||||
modelId.includes("opus-4-6") ||
|
||||
modelId.includes("opus-4.6") ||
|
||||
modelId.includes("sonnet-4-6") ||
|
||||
modelId.includes("sonnet-4.6")
|
||||
);
|
||||
}
|
||||
|
||||
function mapThinkingLevelToEffort(
|
||||
level: SimpleStreamOptions["reasoning"],
|
||||
modelId: string,
|
||||
): "low" | "medium" | "high" | "max" {
|
||||
switch (level) {
|
||||
case "minimal":
|
||||
case "low":
|
||||
return "low";
|
||||
case "medium":
|
||||
return "medium";
|
||||
case "high":
|
||||
return "high";
|
||||
case "xhigh":
|
||||
return modelId.includes("opus-4-6") || modelId.includes("opus-4.6")
|
||||
? "max"
|
||||
: "high";
|
||||
default:
|
||||
return "high";
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolve cache retention preference.
|
||||
* Defaults to "short" and uses PI_CACHE_RETENTION for backward compatibility.
|
||||
*/
|
||||
function resolveCacheRetention(
|
||||
cacheRetention?: CacheRetention,
|
||||
): CacheRetention {
|
||||
if (cacheRetention) {
|
||||
return cacheRetention;
|
||||
}
|
||||
if (
|
||||
typeof process !== "undefined" &&
|
||||
process.env.PI_CACHE_RETENTION === "long"
|
||||
) {
|
||||
return "long";
|
||||
}
|
||||
return "short";
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if the model supports prompt caching.
|
||||
* Supported: Claude 3.5 Haiku, Claude 3.7 Sonnet, Claude 4.x models
|
||||
*/
|
||||
function supportsPromptCaching(
|
||||
model: Model<"bedrock-converse-stream">,
|
||||
): boolean {
|
||||
if (model.cost.cacheRead || model.cost.cacheWrite) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const id = model.id.toLowerCase();
|
||||
// Claude 4.x models (opus-4, sonnet-4, haiku-4)
|
||||
if (id.includes("claude") && (id.includes("-4-") || id.includes("-4.")))
|
||||
return true;
|
||||
// Claude 3.7 Sonnet
|
||||
if (id.includes("claude-3-7-sonnet")) return true;
|
||||
// Claude 3.5 Haiku
|
||||
if (id.includes("claude-3-5-haiku")) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if the model supports thinking signatures in reasoningContent.
|
||||
* Only Anthropic Claude models support the signature field.
|
||||
* Other models (OpenAI, Qwen, Minimax, Moonshot, etc.) reject it with:
|
||||
* "This model doesn't support the reasoningContent.reasoningText.signature field"
|
||||
*/
|
||||
function supportsThinkingSignature(
|
||||
model: Model<"bedrock-converse-stream">,
|
||||
): boolean {
|
||||
const id = model.id.toLowerCase();
|
||||
return id.includes("anthropic.claude") || id.includes("anthropic/claude");
|
||||
}
|
||||
|
||||
function buildSystemPrompt(
|
||||
systemPrompt: string | undefined,
|
||||
model: Model<"bedrock-converse-stream">,
|
||||
cacheRetention: CacheRetention,
|
||||
): SystemContentBlock[] | undefined {
|
||||
if (!systemPrompt) return undefined;
|
||||
|
||||
const blocks: SystemContentBlock[] = [
|
||||
{ text: sanitizeSurrogates(systemPrompt) },
|
||||
];
|
||||
|
||||
// Add cache point for supported Claude models when caching is enabled
|
||||
if (cacheRetention !== "none" && supportsPromptCaching(model)) {
|
||||
blocks.push({
|
||||
cachePoint: {
|
||||
type: CachePointType.DEFAULT,
|
||||
...(cacheRetention === "long" ? { ttl: CacheTTL.ONE_HOUR } : {}),
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
return blocks;
|
||||
}
|
||||
|
||||
function normalizeToolCallId(id: string): string {
|
||||
const sanitized = id.replace(/[^a-zA-Z0-9_-]/g, "_");
|
||||
return sanitized.length > 64 ? sanitized.slice(0, 64) : sanitized;
|
||||
}
|
||||
|
||||
/**
 * Convert the provider-agnostic context messages into Bedrock `Message`s.
 *
 * - User messages become USER messages (text and images).
 * - Assistant messages become ASSISTANT messages; empty text/thinking blocks
 *   and fully-empty messages are dropped, since Bedrock rejects empty content.
 * - Consecutive toolResult messages are merged into a single USER message.
 * - When caching is enabled and the model supports it, a cache point is
 *   appended to the final USER message.
 *
 * Throws on unknown content types or roles.
 */
function convertMessages(
  context: Context,
  model: Model<"bedrock-converse-stream">,
  cacheRetention: CacheRetention,
): Message[] {
  const result: Message[] = [];
  // Tool-call ids are normalized here so assistant toolUse ids and the
  // toolResult ids that reference them stay in sync.
  const transformedMessages = transformMessages(
    context.messages,
    model,
    normalizeToolCallId,
  );

  // Index-based loop: the toolResult branch advances `i` past messages it
  // consumed via lookahead.
  for (let i = 0; i < transformedMessages.length; i++) {
    const m = transformedMessages[i];

    switch (m.role) {
      case "user":
        result.push({
          role: ConversationRole.USER,
          content:
            typeof m.content === "string"
              ? [{ text: sanitizeSurrogates(m.content) }]
              : m.content.map((c) => {
                  switch (c.type) {
                    case "text":
                      return { text: sanitizeSurrogates(c.text) };
                    case "image":
                      return { image: createImageBlock(c.mimeType, c.data) };
                    default:
                      throw new Error("Unknown user content type");
                  }
                }),
        });
        break;
      case "assistant": {
        // Skip assistant messages with empty content (e.g., from aborted requests)
        // Bedrock rejects messages with empty content arrays
        if (m.content.length === 0) {
          continue;
        }
        const contentBlocks: ContentBlock[] = [];
        for (const c of m.content) {
          switch (c.type) {
            case "text":
              // Skip empty text blocks
              if (c.text.trim().length === 0) continue;
              contentBlocks.push({ text: sanitizeSurrogates(c.text) });
              break;
            case "toolCall":
              contentBlocks.push({
                toolUse: { toolUseId: c.id, name: c.name, input: c.arguments },
              });
              break;
            case "thinking":
              // Skip empty thinking blocks
              if (c.thinking.trim().length === 0) continue;
              // Only Anthropic models support the signature field in reasoningText.
              // For other models, we omit the signature to avoid errors like:
              // "This model doesn't support the reasoningContent.reasoningText.signature field"
              if (supportsThinkingSignature(model)) {
                contentBlocks.push({
                  reasoningContent: {
                    reasoningText: {
                      text: sanitizeSurrogates(c.thinking),
                      signature: c.thinkingSignature,
                    },
                  },
                });
              } else {
                contentBlocks.push({
                  reasoningContent: {
                    reasoningText: { text: sanitizeSurrogates(c.thinking) },
                  },
                });
              }
              break;
            default:
              throw new Error("Unknown assistant content type");
          }
        }
        // Skip if all content blocks were filtered out
        if (contentBlocks.length === 0) {
          continue;
        }
        result.push({
          role: ConversationRole.ASSISTANT,
          content: contentBlocks,
        });
        break;
      }
      case "toolResult": {
        // Collect all consecutive toolResult messages into a single user message
        // Bedrock requires all tool results to be in one message
        const toolResults: ContentBlock.ToolResultMember[] = [];

        // Add current tool result with all content blocks combined
        toolResults.push({
          toolResult: {
            toolUseId: m.toolCallId,
            content: m.content.map((c) =>
              c.type === "image"
                ? { image: createImageBlock(c.mimeType, c.data) }
                : { text: sanitizeSurrogates(c.text) },
            ),
            status: m.isError
              ? ToolResultStatus.ERROR
              : ToolResultStatus.SUCCESS,
          },
        });

        // Look ahead for consecutive toolResult messages
        let j = i + 1;
        while (
          j < transformedMessages.length &&
          transformedMessages[j].role === "toolResult"
        ) {
          const nextMsg = transformedMessages[j] as ToolResultMessage;
          toolResults.push({
            toolResult: {
              toolUseId: nextMsg.toolCallId,
              content: nextMsg.content.map((c) =>
                c.type === "image"
                  ? { image: createImageBlock(c.mimeType, c.data) }
                  : { text: sanitizeSurrogates(c.text) },
              ),
              status: nextMsg.isError
                ? ToolResultStatus.ERROR
                : ToolResultStatus.SUCCESS,
            },
          });
          j++;
        }

        // Skip the messages we've already processed
        i = j - 1;

        result.push({
          role: ConversationRole.USER,
          content: toolResults,
        });
        break;
      }
      default:
        throw new Error("Unknown message role");
    }
  }

  // Add cache point to the last user message for supported Claude models when caching is enabled
  if (
    cacheRetention !== "none" &&
    supportsPromptCaching(model) &&
    result.length > 0
  ) {
    const lastMessage = result[result.length - 1];
    if (lastMessage.role === ConversationRole.USER && lastMessage.content) {
      (lastMessage.content as ContentBlock[]).push({
        cachePoint: {
          type: CachePointType.DEFAULT,
          ...(cacheRetention === "long" ? { ttl: CacheTTL.ONE_HOUR } : {}),
        },
      });
    }
  }

  return result;
}
|
||||
|
||||
function convertToolConfig(
|
||||
tools: Tool[] | undefined,
|
||||
toolChoice: BedrockOptions["toolChoice"],
|
||||
): ToolConfiguration | undefined {
|
||||
if (!tools?.length || toolChoice === "none") return undefined;
|
||||
|
||||
const bedrockTools: BedrockTool[] = tools.map((tool) => ({
|
||||
toolSpec: {
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
inputSchema: { json: tool.parameters },
|
||||
},
|
||||
}));
|
||||
|
||||
let bedrockToolChoice: ToolChoice | undefined;
|
||||
switch (toolChoice) {
|
||||
case "auto":
|
||||
bedrockToolChoice = { auto: {} };
|
||||
break;
|
||||
case "any":
|
||||
bedrockToolChoice = { any: {} };
|
||||
break;
|
||||
default:
|
||||
if (toolChoice?.type === "tool") {
|
||||
bedrockToolChoice = { tool: { name: toolChoice.name } };
|
||||
}
|
||||
}
|
||||
|
||||
return { tools: bedrockTools, toolChoice: bedrockToolChoice };
|
||||
}
|
||||
|
||||
function mapStopReason(reason: string | undefined): StopReason {
|
||||
switch (reason) {
|
||||
case BedrockStopReason.END_TURN:
|
||||
case BedrockStopReason.STOP_SEQUENCE:
|
||||
return "stop";
|
||||
case BedrockStopReason.MAX_TOKENS:
|
||||
case BedrockStopReason.MODEL_CONTEXT_WINDOW_EXCEEDED:
|
||||
return "length";
|
||||
case BedrockStopReason.TOOL_USE:
|
||||
return "toolUse";
|
||||
default:
|
||||
return "error";
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Build the `additionalModelRequestFields` payload that enables extended
 * thinking on Bedrock.
 *
 * Returns undefined unless reasoning is requested AND the model declares
 * reasoning support AND the model id is an Anthropic Claude id. For adaptive
 * models (Opus/Sonnet 4.6) an effort-based config is emitted; older Claude
 * models get a token-budget config plus the interleaved-thinking beta flag
 * (unless explicitly disabled via options.interleavedThinking).
 */
function buildAdditionalModelRequestFields(
  model: Model<"bedrock-converse-stream">,
  options: BedrockOptions,
): Record<string, any> | undefined {
  if (!options.reasoning || !model.reasoning) {
    return undefined;
  }

  if (
    model.id.includes("anthropic.claude") ||
    model.id.includes("anthropic/claude")
  ) {
    const result: Record<string, any> = supportsAdaptiveThinking(model.id)
      ? {
          // Adaptive thinking: the model allocates thinking per the effort hint.
          thinking: { type: "adaptive" },
          output_config: {
            effort: mapThinkingLevelToEffort(options.reasoning, model.id),
          },
        }
      : (() => {
          // Budget-based thinking for pre-4.6 Claude models.
          const defaultBudgets: Record<ThinkingLevel, number> = {
            minimal: 1024,
            low: 2048,
            medium: 8192,
            high: 16384,
            xhigh: 16384, // Claude doesn't support xhigh, clamp to high
          };

          // Custom budgets override defaults (xhigh not in ThinkingBudgets, use high)
          const level =
            options.reasoning === "xhigh" ? "high" : options.reasoning;
          const budget =
            options.thinkingBudgets?.[level] ??
            defaultBudgets[options.reasoning];

          return {
            thinking: {
              type: "enabled",
              budget_tokens: budget,
            },
          };
        })();

    // Interleaved thinking beta only applies to the budget-based path;
    // it defaults to on when the option is unset.
    if (
      !supportsAdaptiveThinking(model.id) &&
      (options.interleavedThinking ?? true)
    ) {
      result.anthropic_beta = ["interleaved-thinking-2025-05-14"];
    }

    return result;
  }

  return undefined;
}
|
||||
|
||||
function createImageBlock(mimeType: string, data: string) {
|
||||
let format: ImageFormat;
|
||||
switch (mimeType) {
|
||||
case "image/jpeg":
|
||||
case "image/jpg":
|
||||
format = ImageFormat.JPEG;
|
||||
break;
|
||||
case "image/png":
|
||||
format = ImageFormat.PNG;
|
||||
break;
|
||||
case "image/gif":
|
||||
format = ImageFormat.GIF;
|
||||
break;
|
||||
case "image/webp":
|
||||
format = ImageFormat.WEBP;
|
||||
break;
|
||||
default:
|
||||
throw new Error(`Unknown image type: ${mimeType}`);
|
||||
}
|
||||
|
||||
const binaryString = atob(data);
|
||||
const bytes = new Uint8Array(binaryString.length);
|
||||
for (let i = 0; i < binaryString.length; i++) {
|
||||
bytes[i] = binaryString.charCodeAt(i);
|
||||
}
|
||||
|
||||
return { source: { bytes }, format };
|
||||
}
|
||||
989
packages/ai/src/providers/anthropic.ts
Normal file
989
packages/ai/src/providers/anthropic.ts
Normal file
|
|
@ -0,0 +1,989 @@
|
|||
import Anthropic from "@anthropic-ai/sdk";
|
||||
import type {
|
||||
ContentBlockParam,
|
||||
MessageCreateParamsStreaming,
|
||||
MessageParam,
|
||||
} from "@anthropic-ai/sdk/resources/messages.js";
|
||||
import { getEnvApiKey } from "../env-api-keys.js";
|
||||
import { calculateCost } from "../models.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
CacheRetention,
|
||||
Context,
|
||||
ImageContent,
|
||||
Message,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
StopReason,
|
||||
StreamFunction,
|
||||
StreamOptions,
|
||||
TextContent,
|
||||
ThinkingContent,
|
||||
Tool,
|
||||
ToolCall,
|
||||
ToolResultMessage,
|
||||
} from "../types.js";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import { parseStreamingJson } from "../utils/json-parse.js";
|
||||
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
||||
|
||||
import {
|
||||
buildCopilotDynamicHeaders,
|
||||
hasCopilotVisionInput,
|
||||
} from "./github-copilot-headers.js";
|
||||
import {
|
||||
adjustMaxTokensForThinking,
|
||||
buildBaseOptions,
|
||||
} from "./simple-options.js";
|
||||
import { transformMessages } from "./transform-messages.js";
|
||||
|
||||
/**
|
||||
* Resolve cache retention preference.
|
||||
* Defaults to "short" and uses PI_CACHE_RETENTION for backward compatibility.
|
||||
*/
|
||||
function resolveCacheRetention(
|
||||
cacheRetention?: CacheRetention,
|
||||
): CacheRetention {
|
||||
if (cacheRetention) {
|
||||
return cacheRetention;
|
||||
}
|
||||
if (
|
||||
typeof process !== "undefined" &&
|
||||
process.env.PI_CACHE_RETENTION === "long"
|
||||
) {
|
||||
return "long";
|
||||
}
|
||||
return "short";
|
||||
}
|
||||
|
||||
function getCacheControl(
|
||||
baseUrl: string,
|
||||
cacheRetention?: CacheRetention,
|
||||
): {
|
||||
retention: CacheRetention;
|
||||
cacheControl?: { type: "ephemeral"; ttl?: "1h" };
|
||||
} {
|
||||
const retention = resolveCacheRetention(cacheRetention);
|
||||
if (retention === "none") {
|
||||
return { retention };
|
||||
}
|
||||
const ttl =
|
||||
retention === "long" && baseUrl.includes("api.anthropic.com")
|
||||
? "1h"
|
||||
: undefined;
|
||||
return {
|
||||
retention,
|
||||
cacheControl: { type: "ephemeral", ...(ttl && { ttl }) },
|
||||
};
|
||||
}
|
||||
|
||||
// Stealth mode: Mimic Claude Code's tool naming exactly
|
||||
const claudeCodeVersion = "2.1.62";
|
||||
|
||||
// Claude Code 2.x tool names (canonical casing)
|
||||
// Source: https://cchistory.mariozechner.at/data/prompts-2.1.11.md
|
||||
// To update: https://github.com/badlogic/cchistory
|
||||
const claudeCodeTools = [
|
||||
"Read",
|
||||
"Write",
|
||||
"Edit",
|
||||
"Bash",
|
||||
"Grep",
|
||||
"Glob",
|
||||
"AskUserQuestion",
|
||||
"EnterPlanMode",
|
||||
"ExitPlanMode",
|
||||
"KillShell",
|
||||
"NotebookEdit",
|
||||
"Skill",
|
||||
"Task",
|
||||
"TaskOutput",
|
||||
"TodoWrite",
|
||||
"WebFetch",
|
||||
"WebSearch",
|
||||
];
|
||||
|
||||
const ccToolLookup = new Map(claudeCodeTools.map((t) => [t.toLowerCase(), t]));
|
||||
|
||||
// Convert tool name to CC canonical casing if it matches (case-insensitive)
|
||||
const toClaudeCodeName = (name: string) =>
|
||||
ccToolLookup.get(name.toLowerCase()) ?? name;
|
||||
const fromClaudeCodeName = (name: string, tools?: Tool[]) => {
|
||||
if (tools && tools.length > 0) {
|
||||
const lowerName = name.toLowerCase();
|
||||
const matchedTool = tools.find(
|
||||
(tool) => tool.name.toLowerCase() === lowerName,
|
||||
);
|
||||
if (matchedTool) return matchedTool.name;
|
||||
}
|
||||
return name;
|
||||
};
|
||||
|
||||
/**
|
||||
* Convert content blocks to Anthropic API format
|
||||
*/
|
||||
function convertContentBlocks(content: (TextContent | ImageContent)[]):
|
||||
| string
|
||||
| Array<
|
||||
| { type: "text"; text: string }
|
||||
| {
|
||||
type: "image";
|
||||
source: {
|
||||
type: "base64";
|
||||
media_type: "image/jpeg" | "image/png" | "image/gif" | "image/webp";
|
||||
data: string;
|
||||
};
|
||||
}
|
||||
> {
|
||||
// If only text blocks, return as concatenated string for simplicity
|
||||
const hasImages = content.some((c) => c.type === "image");
|
||||
if (!hasImages) {
|
||||
return sanitizeSurrogates(
|
||||
content.map((c) => (c as TextContent).text).join("\n"),
|
||||
);
|
||||
}
|
||||
|
||||
// If we have images, convert to content block array
|
||||
const blocks = content.map((block) => {
|
||||
if (block.type === "text") {
|
||||
return {
|
||||
type: "text" as const,
|
||||
text: sanitizeSurrogates(block.text),
|
||||
};
|
||||
}
|
||||
return {
|
||||
type: "image" as const,
|
||||
source: {
|
||||
type: "base64" as const,
|
||||
media_type: block.mimeType as
|
||||
| "image/jpeg"
|
||||
| "image/png"
|
||||
| "image/gif"
|
||||
| "image/webp",
|
||||
data: block.data,
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
// If only images (no text), add placeholder text block
|
||||
const hasText = blocks.some((b) => b.type === "text");
|
||||
if (!hasText) {
|
||||
blocks.unshift({
|
||||
type: "text" as const,
|
||||
text: "(see attached image)",
|
||||
});
|
||||
}
|
||||
|
||||
return blocks;
|
||||
}
|
||||
|
||||
export type AnthropicEffort = "low" | "medium" | "high" | "max";

export interface AnthropicOptions extends StreamOptions {
  /**
   * Enable extended thinking.
   * For Opus 4.6 and Sonnet 4.6: uses adaptive thinking (model decides when/how much to think).
   * For older models: uses budget-based thinking with thinkingBudgetTokens.
   */
  thinkingEnabled?: boolean;
  /**
   * Token budget for extended thinking (older models only).
   * Ignored for Opus 4.6 and Sonnet 4.6, which use adaptive thinking.
   */
  thinkingBudgetTokens?: number;
  /**
   * Effort level for adaptive thinking (Opus 4.6 and Sonnet 4.6).
   * Controls how much thinking Claude allocates:
   * - "max": Always thinks with no constraints (Opus 4.6 only)
   * - "high": Always thinks, deep reasoning (default)
   * - "medium": Moderate thinking, may skip for simple queries
   * - "low": Minimal thinking, skips for simple tasks
   * Ignored for older models.
   */
  effort?: AnthropicEffort;
  /**
   * Whether to enable interleaved thinking. Treated as enabled when unset
   * (`?? true` at the call sites).
   */
  interleavedThinking?: boolean;
  /**
   * Tool selection strategy: let the model decide ("auto"), force any tool
   * ("any"), disable tools ("none"), or force one specific named tool.
   */
  toolChoice?: "auto" | "any" | "none" | { type: "tool"; name: string };
}
|
||||
|
||||
function mergeHeaders(
|
||||
...headerSources: (Record<string, string> | undefined)[]
|
||||
): Record<string, string> {
|
||||
const merged: Record<string, string> = {};
|
||||
for (const headers of headerSources) {
|
||||
if (headers) {
|
||||
Object.assign(merged, headers);
|
||||
}
|
||||
}
|
||||
return merged;
|
||||
}
|
||||
|
||||
/**
 * Stream an assistant message from the Anthropic Messages API.
 *
 * Returns an event stream immediately and runs the request in a detached
 * async block: emits `start`, then per-block `*_start`/`*_delta`/`*_end`
 * events, and finally `done` (or `error` with stopReason "aborted"/"error").
 * Usage is captured from both `message_start` and `message_delta` so token
 * counts survive early aborts. When the credential is an OAuth token,
 * tool names are mapped back from Claude Code canonical casing.
 */
export const streamAnthropic: StreamFunction<
  "anthropic-messages",
  AnthropicOptions
> = (
  model: Model<"anthropic-messages">,
  context: Context,
  options?: AnthropicOptions,
): AssistantMessageEventStream => {
  const stream = new AssistantMessageEventStream();

  (async () => {
    // Partial message mutated in place as events arrive; every stream event
    // carries it as `partial`.
    const output: AssistantMessage = {
      role: "assistant",
      content: [],
      api: model.api as Api,
      provider: model.provider,
      model: model.id,
      usage: {
        input: 0,
        output: 0,
        cacheRead: 0,
        cacheWrite: 0,
        totalTokens: 0,
        cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
      },
      stopReason: "stop",
      timestamp: Date.now(),
    };

    try {
      const apiKey = options?.apiKey ?? getEnvApiKey(model.provider) ?? "";

      // GitHub Copilot requires request-specific headers (e.g. vision).
      let copilotDynamicHeaders: Record<string, string> | undefined;
      if (model.provider === "github-copilot") {
        const hasImages = hasCopilotVisionInput(context.messages);
        copilotDynamicHeaders = buildCopilotDynamicHeaders({
          messages: context.messages,
          hasImages,
        });
      }

      const { client, isOAuthToken } = createClient(
        model,
        apiKey,
        options?.interleavedThinking ?? true,
        options?.headers,
        copilotDynamicHeaders,
      );
      const params = buildParams(model, context, isOAuthToken, options);
      options?.onPayload?.(params);
      const anthropicStream = client.messages.stream(
        { ...params, stream: true },
        { signal: options?.signal },
      );
      stream.push({ type: "start", partial: output });

      // In-flight blocks keep the provider's event index plus (for tool
      // calls) the raw JSON accumulated so far; both are stripped on stop.
      type Block = (
        | ThinkingContent
        | TextContent
        | (ToolCall & { partialJson: string })
      ) & { index: number };
      const blocks = output.content as Block[];

      for await (const event of anthropicStream) {
        if (event.type === "message_start") {
          // Capture initial token usage from message_start event
          // This ensures we have input token counts even if the stream is aborted early
          output.usage.input = event.message.usage.input_tokens || 0;
          output.usage.output = event.message.usage.output_tokens || 0;
          output.usage.cacheRead =
            event.message.usage.cache_read_input_tokens || 0;
          output.usage.cacheWrite =
            event.message.usage.cache_creation_input_tokens || 0;
          // Anthropic doesn't provide total_tokens, compute from components
          output.usage.totalTokens =
            output.usage.input +
            output.usage.output +
            output.usage.cacheRead +
            output.usage.cacheWrite;
          calculateCost(model, output.usage);
        } else if (event.type === "content_block_start") {
          if (event.content_block.type === "text") {
            const block: Block = {
              type: "text",
              text: "",
              index: event.index,
            };
            output.content.push(block);
            stream.push({
              type: "text_start",
              contentIndex: output.content.length - 1,
              partial: output,
            });
          } else if (event.content_block.type === "thinking") {
            const block: Block = {
              type: "thinking",
              thinking: "",
              thinkingSignature: "",
              index: event.index,
            };
            output.content.push(block);
            stream.push({
              type: "thinking_start",
              contentIndex: output.content.length - 1,
              partial: output,
            });
          } else if (event.content_block.type === "redacted_thinking") {
            // Redacted thinking carries opaque data in place of a signature;
            // surface it as a thinking block with a fixed placeholder text.
            const block: Block = {
              type: "thinking",
              thinking: "[Reasoning redacted]",
              thinkingSignature: event.content_block.data,
              redacted: true,
              index: event.index,
            };
            output.content.push(block);
            stream.push({
              type: "thinking_start",
              contentIndex: output.content.length - 1,
              partial: output,
            });
          } else if (event.content_block.type === "tool_use") {
            const block: Block = {
              type: "toolCall",
              id: event.content_block.id,
              // OAuth (Claude Code) sessions use CC-cased tool names; map
              // back to the caller's declared names.
              name: isOAuthToken
                ? fromClaudeCodeName(event.content_block.name, context.tools)
                : event.content_block.name,
              arguments:
                (event.content_block.input as Record<string, any>) ?? {},
              partialJson: "",
              index: event.index,
            };
            output.content.push(block);
            stream.push({
              type: "toolcall_start",
              contentIndex: output.content.length - 1,
              partial: output,
            });
          }
        } else if (event.type === "content_block_delta") {
          if (event.delta.type === "text_delta") {
            const index = blocks.findIndex((b) => b.index === event.index);
            const block = blocks[index];
            if (block && block.type === "text") {
              block.text += event.delta.text;
              stream.push({
                type: "text_delta",
                contentIndex: index,
                delta: event.delta.text,
                partial: output,
              });
            }
          } else if (event.delta.type === "thinking_delta") {
            const index = blocks.findIndex((b) => b.index === event.index);
            const block = blocks[index];
            if (block && block.type === "thinking") {
              block.thinking += event.delta.thinking;
              stream.push({
                type: "thinking_delta",
                contentIndex: index,
                delta: event.delta.thinking,
                partial: output,
              });
            }
          } else if (event.delta.type === "input_json_delta") {
            const index = blocks.findIndex((b) => b.index === event.index);
            const block = blocks[index];
            if (block && block.type === "toolCall") {
              // Keep a best-effort parse of the arguments on every delta.
              block.partialJson += event.delta.partial_json;
              block.arguments = parseStreamingJson(block.partialJson);
              stream.push({
                type: "toolcall_delta",
                contentIndex: index,
                delta: event.delta.partial_json,
                partial: output,
              });
            }
          } else if (event.delta.type === "signature_delta") {
            // Signature fragments accumulate without a dedicated stream event.
            const index = blocks.findIndex((b) => b.index === event.index);
            const block = blocks[index];
            if (block && block.type === "thinking") {
              block.thinkingSignature = block.thinkingSignature || "";
              block.thinkingSignature += event.delta.signature;
            }
          }
        } else if (event.type === "content_block_stop") {
          const index = blocks.findIndex((b) => b.index === event.index);
          const block = blocks[index];
          if (block) {
            // The index was only needed to correlate streaming events.
            delete (block as any).index;
            if (block.type === "text") {
              stream.push({
                type: "text_end",
                contentIndex: index,
                content: block.text,
                partial: output,
              });
            } else if (block.type === "thinking") {
              stream.push({
                type: "thinking_end",
                contentIndex: index,
                content: block.thinking,
                partial: output,
              });
            } else if (block.type === "toolCall") {
              block.arguments = parseStreamingJson(block.partialJson);
              delete (block as any).partialJson;
              stream.push({
                type: "toolcall_end",
                contentIndex: index,
                toolCall: block,
                partial: output,
              });
            }
          }
        } else if (event.type === "message_delta") {
          if (event.delta.stop_reason) {
            output.stopReason = mapStopReason(event.delta.stop_reason);
          }
          // Only update usage fields if present (not null).
          // Preserves input_tokens from message_start when proxies omit it in message_delta.
          if (event.usage.input_tokens != null) {
            output.usage.input = event.usage.input_tokens;
          }
          if (event.usage.output_tokens != null) {
            output.usage.output = event.usage.output_tokens;
          }
          if (event.usage.cache_read_input_tokens != null) {
            output.usage.cacheRead = event.usage.cache_read_input_tokens;
          }
          if (event.usage.cache_creation_input_tokens != null) {
            output.usage.cacheWrite = event.usage.cache_creation_input_tokens;
          }
          // Anthropic doesn't provide total_tokens, compute from components
          output.usage.totalTokens =
            output.usage.input +
            output.usage.output +
            output.usage.cacheRead +
            output.usage.cacheWrite;
          calculateCost(model, output.usage);
        }
      }

      if (options?.signal?.aborted) {
        throw new Error("Request was aborted");
      }

      if (output.stopReason === "aborted" || output.stopReason === "error") {
        throw new Error("An unknown error occurred");
      }

      stream.push({ type: "done", reason: output.stopReason, message: output });
      stream.end();
    } catch (error) {
      // Strip internal indices so the partial message is clean for consumers.
      for (const block of output.content) delete (block as any).index;
      output.stopReason = options?.signal?.aborted ? "aborted" : "error";
      output.errorMessage =
        error instanceof Error ? error.message : JSON.stringify(error);
      stream.push({ type: "error", reason: output.stopReason, error: output });
      stream.end();
    }
  })();

  return stream;
};
|
||||
|
||||
/**
|
||||
* Check if a model supports adaptive thinking (Opus 4.6 and Sonnet 4.6)
|
||||
*/
|
||||
function supportsAdaptiveThinking(modelId: string): boolean {
|
||||
// Opus 4.6 and Sonnet 4.6 model IDs (with or without date suffix)
|
||||
return (
|
||||
modelId.includes("opus-4-6") ||
|
||||
modelId.includes("opus-4.6") ||
|
||||
modelId.includes("sonnet-4-6") ||
|
||||
modelId.includes("sonnet-4.6")
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Map ThinkingLevel to Anthropic effort levels for adaptive thinking.
|
||||
* Note: effort "max" is only valid on Opus 4.6.
|
||||
*/
|
||||
function mapThinkingLevelToEffort(
|
||||
level: SimpleStreamOptions["reasoning"],
|
||||
modelId: string,
|
||||
): AnthropicEffort {
|
||||
switch (level) {
|
||||
case "minimal":
|
||||
return "low";
|
||||
case "low":
|
||||
return "low";
|
||||
case "medium":
|
||||
return "medium";
|
||||
case "high":
|
||||
return "high";
|
||||
case "xhigh":
|
||||
return modelId.includes("opus-4-6") || modelId.includes("opus-4.6")
|
||||
? "max"
|
||||
: "high";
|
||||
default:
|
||||
return "high";
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Simplified streaming entry point for the Anthropic Messages API.
 *
 * Resolves the API key, then dispatches to streamAnthropic with thinking
 * configured in one of three ways: disabled (no reasoning requested),
 * adaptive with an effort level (Opus 4.6 / Sonnet 4.6), or budget-based
 * (older models), carving the thinking budget out of maxTokens.
 *
 * Throws synchronously if no API key is available for the provider.
 */
export const streamSimpleAnthropic: StreamFunction<
  "anthropic-messages",
  SimpleStreamOptions
> = (
  model: Model<"anthropic-messages">,
  context: Context,
  options?: SimpleStreamOptions,
): AssistantMessageEventStream => {
  // Resolve the key eagerly so misconfiguration fails before streaming starts.
  const apiKey = options?.apiKey || getEnvApiKey(model.provider);
  if (!apiKey) {
    throw new Error(`No API key for provider: ${model.provider}`);
  }

  const base = buildBaseOptions(model, options, apiKey);
  // No reasoning requested: stream with thinking disabled.
  if (!options?.reasoning) {
    return streamAnthropic(model, context, {
      ...base,
      thinkingEnabled: false,
    } satisfies AnthropicOptions);
  }

  // For Opus 4.6 and Sonnet 4.6: use adaptive thinking with effort level
  // For older models: use budget-based thinking
  if (supportsAdaptiveThinking(model.id)) {
    const effort = mapThinkingLevelToEffort(options.reasoning, model.id);
    return streamAnthropic(model, context, {
      ...base,
      thinkingEnabled: true,
      effort,
    } satisfies AnthropicOptions);
  }

  // Budget-based thinking: shrink/raise maxTokens so thinking + answer fit
  // within the model's output limit.
  const adjusted = adjustMaxTokensForThinking(
    base.maxTokens || 0,
    model.maxTokens,
    options.reasoning,
    options.thinkingBudgets,
  );

  return streamAnthropic(model, context, {
    ...base,
    maxTokens: adjusted.maxTokens,
    thinkingEnabled: true,
    thinkingBudgetTokens: adjusted.thinkingBudget,
  } satisfies AnthropicOptions);
};
|
||||
|
||||
function isOAuthToken(apiKey: string): boolean {
|
||||
return apiKey.includes("sk-ant-oat");
|
||||
}
|
||||
|
||||
/**
 * Construct an Anthropic SDK client configured for one of three auth paths:
 *
 * 1. GitHub Copilot — Bearer auth via authToken, selective beta headers
 *    (no fine-grained-tool-streaming).
 * 2. OAuth (Claude Code token) — Bearer auth plus Claude Code identity
 *    headers and OAuth beta flags.
 * 3. Plain API key — x-api-key auth with the standard beta set.
 *
 * Returns the client plus whether the key was an OAuth token, so callers can
 * apply OAuth-specific request shaping (e.g. Claude Code system prompt).
 *
 * NOTE(review): dynamicHeaders is only merged on the Copilot path — appears
 * intentional (it carries Copilot-specific X-Initiator/vision headers), but
 * confirm before relying on it elsewhere.
 */
function createClient(
  model: Model<"anthropic-messages">,
  apiKey: string,
  interleavedThinking: boolean,
  optionsHeaders?: Record<string, string>,
  dynamicHeaders?: Record<string, string>,
): { client: Anthropic; isOAuthToken: boolean } {
  // Adaptive thinking models (Opus 4.6, Sonnet 4.6) have interleaved thinking built-in.
  // The beta header is deprecated on Opus 4.6 and redundant on Sonnet 4.6, so skip it.
  const needsInterleavedBeta =
    interleavedThinking && !supportsAdaptiveThinking(model.id);

  // Copilot: Bearer auth, selective betas (no fine-grained-tool-streaming)
  if (model.provider === "github-copilot") {
    const betaFeatures: string[] = [];
    if (needsInterleavedBeta) {
      betaFeatures.push("interleaved-thinking-2025-05-14");
    }

    const client = new Anthropic({
      // apiKey must be null so the SDK uses Bearer auth via authToken instead.
      apiKey: null,
      authToken: apiKey,
      baseURL: model.baseUrl,
      dangerouslyAllowBrowser: true,
      defaultHeaders: mergeHeaders(
        {
          accept: "application/json",
          "anthropic-dangerous-direct-browser-access": "true",
          // Only send the beta header when there is at least one feature.
          ...(betaFeatures.length > 0
            ? { "anthropic-beta": betaFeatures.join(",") }
            : {}),
        },
        model.headers,
        dynamicHeaders,
        optionsHeaders,
      ),
    });

    return { client, isOAuthToken: false };
  }

  // Non-Copilot paths always get fine-grained tool streaming.
  const betaFeatures = ["fine-grained-tool-streaming-2025-05-14"];
  if (needsInterleavedBeta) {
    betaFeatures.push("interleaved-thinking-2025-05-14");
  }

  // OAuth: Bearer auth, Claude Code identity headers
  if (isOAuthToken(apiKey)) {
    const client = new Anthropic({
      apiKey: null,
      authToken: apiKey,
      baseURL: model.baseUrl,
      dangerouslyAllowBrowser: true,
      defaultHeaders: mergeHeaders(
        {
          accept: "application/json",
          "anthropic-dangerous-direct-browser-access": "true",
          // OAuth betas first, then the shared feature set.
          "anthropic-beta": `claude-code-20250219,oauth-2025-04-20,${betaFeatures.join(",")}`,
          "user-agent": `claude-cli/${claudeCodeVersion}`,
          "x-app": "cli",
        },
        model.headers,
        optionsHeaders,
      ),
    });

    return { client, isOAuthToken: true };
  }

  // API key auth
  const client = new Anthropic({
    apiKey,
    baseURL: model.baseUrl,
    dangerouslyAllowBrowser: true,
    defaultHeaders: mergeHeaders(
      {
        accept: "application/json",
        "anthropic-dangerous-direct-browser-access": "true",
        "anthropic-beta": betaFeatures.join(","),
      },
      model.headers,
      optionsHeaders,
    ),
  });

  return { client, isOAuthToken: false };
}
|
||||
|
||||
/**
 * Build the streaming request parameters for the Anthropic Messages API.
 *
 * Handles: message conversion with cache_control placement, system-prompt
 * assembly (with the mandatory Claude Code identity block under OAuth),
 * temperature (suppressed when thinking is on), tools, thinking mode
 * (adaptive vs budget-based), metadata passthrough, and tool_choice.
 */
function buildParams(
  model: Model<"anthropic-messages">,
  context: Context,
  isOAuthToken: boolean,
  options?: AnthropicOptions,
): MessageCreateParamsStreaming {
  const { cacheControl } = getCacheControl(
    model.baseUrl,
    options?.cacheRetention,
  );
  const params: MessageCreateParamsStreaming = {
    model: model.id,
    messages: convertMessages(
      context.messages,
      model,
      isOAuthToken,
      cacheControl,
    ),
    // Default to a third of the model's output limit, truncated to an int via |0.
    max_tokens: options?.maxTokens || (model.maxTokens / 3) | 0,
    stream: true,
  };

  // For OAuth tokens, we MUST include Claude Code identity
  if (isOAuthToken) {
    params.system = [
      {
        type: "text",
        text: "You are Claude Code, Anthropic's official CLI for Claude.",
        ...(cacheControl ? { cache_control: cacheControl } : {}),
      },
    ];
    // The caller's own system prompt follows the identity block.
    if (context.systemPrompt) {
      params.system.push({
        type: "text",
        text: sanitizeSurrogates(context.systemPrompt),
        ...(cacheControl ? { cache_control: cacheControl } : {}),
      });
    }
  } else if (context.systemPrompt) {
    // Add cache control to system prompt for non-OAuth tokens
    params.system = [
      {
        type: "text",
        text: sanitizeSurrogates(context.systemPrompt),
        ...(cacheControl ? { cache_control: cacheControl } : {}),
      },
    ];
  }

  // Temperature is incompatible with extended thinking (adaptive or budget-based).
  if (options?.temperature !== undefined && !options?.thinkingEnabled) {
    params.temperature = options.temperature;
  }

  if (context.tools) {
    params.tools = convertTools(context.tools, isOAuthToken);
  }

  // Configure thinking mode: adaptive (Opus 4.6 and Sonnet 4.6) or budget-based (older models)
  if (options?.thinkingEnabled && model.reasoning) {
    if (supportsAdaptiveThinking(model.id)) {
      // Adaptive thinking: Claude decides when and how much to think
      params.thinking = { type: "adaptive" };
      if (options.effort) {
        params.output_config = { effort: options.effort };
      }
    } else {
      // Budget-based thinking for older models
      params.thinking = {
        type: "enabled",
        budget_tokens: options.thinkingBudgetTokens || 1024,
      };
    }
  }

  // Only forward metadata.user_id, and only when it is actually a string.
  if (options?.metadata) {
    const userId = options.metadata.user_id;
    if (typeof userId === "string") {
      params.metadata = { user_id: userId };
    }
  }

  // toolChoice may be a shorthand string ("auto" etc.) or a full object.
  if (options?.toolChoice) {
    if (typeof options.toolChoice === "string") {
      params.tool_choice = { type: options.toolChoice };
    } else {
      params.tool_choice = options.toolChoice;
    }
  }

  return params;
}
|
||||
|
||||
// Normalize tool call IDs to match Anthropic's required pattern and length
|
||||
function normalizeToolCallId(id: string): string {
|
||||
return id.replace(/[^a-zA-Z0-9_-]/g, "_").slice(0, 64);
|
||||
}
|
||||
|
||||
/**
 * Convert internal messages into Anthropic MessageParam[] form.
 *
 * Responsibilities:
 * - Drop empty text content (blank user strings, blank assistant/text blocks).
 * - Strip image blocks for models that don't accept image input.
 * - Re-emit assistant thinking blocks with their signatures (or degrade to
 *   plain text when the signature is missing, e.g. after an aborted stream).
 * - Fold consecutive toolResult messages into a single user message.
 * - Attach cache_control to the final user message so history is cached.
 */
function convertMessages(
  messages: Message[],
  model: Model<"anthropic-messages">,
  isOAuthToken: boolean,
  cacheControl?: { type: "ephemeral"; ttl?: "1h" },
): MessageParam[] {
  const params: MessageParam[] = [];

  // Transform messages for cross-provider compatibility
  const transformedMessages = transformMessages(
    messages,
    model,
    normalizeToolCallId,
  );

  // Index-based loop: the toolResult branch advances i past look-ahead items.
  for (let i = 0; i < transformedMessages.length; i++) {
    const msg = transformedMessages[i];

    if (msg.role === "user") {
      if (typeof msg.content === "string") {
        // Skip whitespace-only user turns entirely.
        if (msg.content.trim().length > 0) {
          params.push({
            role: "user",
            content: sanitizeSurrogates(msg.content),
          });
        }
      } else {
        const blocks: ContentBlockParam[] = msg.content.map((item) => {
          if (item.type === "text") {
            return {
              type: "text",
              text: sanitizeSurrogates(item.text),
            };
          } else {
            return {
              type: "image",
              source: {
                type: "base64",
                media_type: item.mimeType as
                  | "image/jpeg"
                  | "image/png"
                  | "image/gif"
                  | "image/webp",
                data: item.data,
              },
            };
          }
        });
        // Remove images for text-only models, then drop blank text blocks.
        let filteredBlocks = !model?.input.includes("image")
          ? blocks.filter((b) => b.type !== "image")
          : blocks;
        filteredBlocks = filteredBlocks.filter((b) => {
          if (b.type === "text") {
            return b.text.trim().length > 0;
          }
          return true;
        });
        if (filteredBlocks.length === 0) continue;
        params.push({
          role: "user",
          content: filteredBlocks,
        });
      }
    } else if (msg.role === "assistant") {
      const blocks: ContentBlockParam[] = [];

      for (const block of msg.content) {
        if (block.type === "text") {
          if (block.text.trim().length === 0) continue;
          blocks.push({
            type: "text",
            text: sanitizeSurrogates(block.text),
          });
        } else if (block.type === "thinking") {
          // Redacted thinking: pass the opaque payload back as redacted_thinking
          if (block.redacted) {
            blocks.push({
              type: "redacted_thinking",
              data: block.thinkingSignature!,
            });
            continue;
          }
          if (block.thinking.trim().length === 0) continue;
          // If thinking signature is missing/empty (e.g., from aborted stream),
          // convert to plain text block without <thinking> tags to avoid API rejection
          // and prevent Claude from mimicking the tags in responses
          if (
            !block.thinkingSignature ||
            block.thinkingSignature.trim().length === 0
          ) {
            blocks.push({
              type: "text",
              text: sanitizeSurrogates(block.thinking),
            });
          } else {
            blocks.push({
              type: "thinking",
              thinking: sanitizeSurrogates(block.thinking),
              signature: block.thinkingSignature,
            });
          }
        } else if (block.type === "toolCall") {
          blocks.push({
            type: "tool_use",
            id: block.id,
            // OAuth (Claude Code) requires Claude Code tool naming.
            name: isOAuthToken ? toClaudeCodeName(block.name) : block.name,
            input: block.arguments ?? {},
          });
        }
      }
      if (blocks.length === 0) continue;
      params.push({
        role: "assistant",
        content: blocks,
      });
    } else if (msg.role === "toolResult") {
      // Collect all consecutive toolResult messages, needed for z.ai Anthropic endpoint
      const toolResults: ContentBlockParam[] = [];

      // Add the current tool result
      toolResults.push({
        type: "tool_result",
        tool_use_id: msg.toolCallId,
        content: convertContentBlocks(msg.content),
        is_error: msg.isError,
      });

      // Look ahead for consecutive toolResult messages
      let j = i + 1;
      while (
        j < transformedMessages.length &&
        transformedMessages[j].role === "toolResult"
      ) {
        const nextMsg = transformedMessages[j] as ToolResultMessage; // We know it's a toolResult
        toolResults.push({
          type: "tool_result",
          tool_use_id: nextMsg.toolCallId,
          content: convertContentBlocks(nextMsg.content),
          is_error: nextMsg.isError,
        });
        j++;
      }

      // Skip the messages we've already processed
      i = j - 1;

      // Add a single user message with all tool results
      params.push({
        role: "user",
        content: toolResults,
      });
    }
  }

  // Add cache_control to the last user message to cache conversation history
  if (cacheControl && params.length > 0) {
    const lastMessage = params[params.length - 1];
    if (lastMessage.role === "user") {
      if (Array.isArray(lastMessage.content)) {
        // Only block types that accept cache_control get the marker.
        const lastBlock = lastMessage.content[lastMessage.content.length - 1];
        if (
          lastBlock &&
          (lastBlock.type === "text" ||
            lastBlock.type === "image" ||
            lastBlock.type === "tool_result")
        ) {
          (lastBlock as any).cache_control = cacheControl;
        }
      } else if (typeof lastMessage.content === "string") {
        // Promote a plain string into a text block so it can carry cache_control.
        lastMessage.content = [
          {
            type: "text",
            text: lastMessage.content,
            cache_control: cacheControl,
          },
        ] as any;
      }
    }
  }

  return params;
}
|
||||
|
||||
function convertTools(
|
||||
tools: Tool[],
|
||||
isOAuthToken: boolean,
|
||||
): Anthropic.Messages.Tool[] {
|
||||
if (!tools) return [];
|
||||
|
||||
return tools.map((tool) => {
|
||||
const jsonSchema = tool.parameters as any; // TypeBox already generates JSON Schema
|
||||
|
||||
return {
|
||||
name: isOAuthToken ? toClaudeCodeName(tool.name) : tool.name,
|
||||
description: tool.description,
|
||||
input_schema: {
|
||||
type: "object" as const,
|
||||
properties: jsonSchema.properties || {},
|
||||
required: jsonSchema.required || [],
|
||||
},
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
function mapStopReason(
|
||||
reason: Anthropic.Messages.StopReason | string,
|
||||
): StopReason {
|
||||
switch (reason) {
|
||||
case "end_turn":
|
||||
return "stop";
|
||||
case "max_tokens":
|
||||
return "length";
|
||||
case "tool_use":
|
||||
return "toolUse";
|
||||
case "refusal":
|
||||
return "error";
|
||||
case "pause_turn": // Stop is good enough -> resubmit
|
||||
return "stop";
|
||||
case "stop_sequence":
|
||||
return "stop"; // We don't supply stop sequences, so this should never happen
|
||||
case "sensitive": // Content flagged by safety filters (not yet in SDK types)
|
||||
return "error";
|
||||
default:
|
||||
// Handle unknown stop reasons gracefully (API may add new values)
|
||||
throw new Error(`Unhandled stop reason: ${reason}`);
|
||||
}
|
||||
}
|
||||
297
packages/ai/src/providers/azure-openai-responses.ts
Normal file
297
packages/ai/src/providers/azure-openai-responses.ts
Normal file
|
|
@ -0,0 +1,297 @@
|
|||
import { AzureOpenAI } from "openai";
|
||||
import type { ResponseCreateParamsStreaming } from "openai/resources/responses/responses.js";
|
||||
import { getEnvApiKey } from "../env-api-keys.js";
|
||||
import { supportsXhigh } from "../models.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
Context,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
StreamFunction,
|
||||
StreamOptions,
|
||||
} from "../types.js";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import {
|
||||
convertResponsesMessages,
|
||||
convertResponsesTools,
|
||||
processResponsesStream,
|
||||
} from "./openai-responses-shared.js";
|
||||
import { buildBaseOptions, clampReasoning } from "./simple-options.js";
|
||||
|
||||
const DEFAULT_AZURE_API_VERSION = "v1";
|
||||
const AZURE_TOOL_CALL_PROVIDERS = new Set([
|
||||
"openai",
|
||||
"openai-codex",
|
||||
"opencode",
|
||||
"azure-openai-responses",
|
||||
]);
|
||||
|
||||
function parseDeploymentNameMap(
|
||||
value: string | undefined,
|
||||
): Map<string, string> {
|
||||
const map = new Map<string, string>();
|
||||
if (!value) return map;
|
||||
for (const entry of value.split(",")) {
|
||||
const trimmed = entry.trim();
|
||||
if (!trimmed) continue;
|
||||
const [modelId, deploymentName] = trimmed.split("=", 2);
|
||||
if (!modelId || !deploymentName) continue;
|
||||
map.set(modelId.trim(), deploymentName.trim());
|
||||
}
|
||||
return map;
|
||||
}
|
||||
|
||||
function resolveDeploymentName(
|
||||
model: Model<"azure-openai-responses">,
|
||||
options?: AzureOpenAIResponsesOptions,
|
||||
): string {
|
||||
if (options?.azureDeploymentName) {
|
||||
return options.azureDeploymentName;
|
||||
}
|
||||
const mappedDeployment = parseDeploymentNameMap(
|
||||
process.env.AZURE_OPENAI_DEPLOYMENT_NAME_MAP,
|
||||
).get(model.id);
|
||||
return mappedDeployment || model.id;
|
||||
}
|
||||
|
||||
// Azure OpenAI Responses-specific options
export interface AzureOpenAIResponsesOptions extends StreamOptions {
  // Reasoning effort forwarded to the Responses API ("xhigh" only where supported).
  reasoningEffort?: "minimal" | "low" | "medium" | "high" | "xhigh";
  // Reasoning summary verbosity; null disables summaries.
  reasoningSummary?: "auto" | "detailed" | "concise" | null;
  // Azure API version; falls back to AZURE_OPENAI_API_VERSION, then "v1".
  azureApiVersion?: string;
  // Azure resource name, used to derive the default endpoint URL.
  azureResourceName?: string;
  // Full base URL; takes precedence over the resource-name-derived URL.
  azureBaseUrl?: string;
  // Explicit deployment name; overrides AZURE_OPENAI_DEPLOYMENT_NAME_MAP and model id.
  azureDeploymentName?: string;
}
|
||||
|
||||
/**
 * Generate function for the Azure OpenAI Responses API.
 *
 * Returns an event stream immediately; an async IIFE drives the request and
 * pushes start/done events on success, or a single error event on failure
 * (including abort via options.signal). The stream is always ended.
 */
export const streamAzureOpenAIResponses: StreamFunction<
  "azure-openai-responses",
  AzureOpenAIResponsesOptions
> = (
  model: Model<"azure-openai-responses">,
  context: Context,
  options?: AzureOpenAIResponsesOptions,
): AssistantMessageEventStream => {
  const stream = new AssistantMessageEventStream();

  // Start async processing
  (async () => {
    const deploymentName = resolveDeploymentName(model, options);

    // Accumulator for the in-flight assistant message; mutated by the
    // stream processor and emitted in partial/done/error events.
    const output: AssistantMessage = {
      role: "assistant",
      content: [],
      api: "azure-openai-responses" as Api,
      provider: model.provider,
      model: model.id,
      usage: {
        input: 0,
        output: 0,
        cacheRead: 0,
        cacheWrite: 0,
        totalTokens: 0,
        cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
      },
      stopReason: "stop",
      timestamp: Date.now(),
    };

    try {
      // Create Azure OpenAI client
      const apiKey = options?.apiKey || getEnvApiKey(model.provider) || "";
      const client = createClient(model, apiKey, options);
      const params = buildParams(model, context, options, deploymentName);
      // Expose the raw request payload to observers (debugging/telemetry hook).
      options?.onPayload?.(params);
      const openaiStream = await client.responses.create(
        params,
        options?.signal ? { signal: options.signal } : undefined,
      );
      stream.push({ type: "start", partial: output });

      await processResponsesStream(openaiStream, output, stream, model);

      // Abort may have fired mid-stream without throwing; surface it explicitly.
      if (options?.signal?.aborted) {
        throw new Error("Request was aborted");
      }

      if (output.stopReason === "aborted" || output.stopReason === "error") {
        throw new Error("An unknown error occurred");
      }

      stream.push({ type: "done", reason: output.stopReason, message: output });
      stream.end();
    } catch (error) {
      // Strip streaming bookkeeping before emitting the error payload.
      for (const block of output.content)
        delete (block as { index?: number }).index;
      output.stopReason = options?.signal?.aborted ? "aborted" : "error";
      output.errorMessage =
        error instanceof Error ? error.message : JSON.stringify(error);
      stream.push({ type: "error", reason: output.stopReason, error: output });
      stream.end();
    }
  })();

  return stream;
};
|
||||
|
||||
/**
 * Simplified streaming entry point for Azure OpenAI Responses.
 *
 * Resolves the API key (option, then environment) and forwards the request
 * with a reasoning effort derived from options.reasoning — clamped when the
 * model does not support the "xhigh" level.
 *
 * Throws synchronously if no API key is available for the provider.
 */
export const streamSimpleAzureOpenAIResponses: StreamFunction<
  "azure-openai-responses",
  SimpleStreamOptions
> = (
  model: Model<"azure-openai-responses">,
  context: Context,
  options?: SimpleStreamOptions,
): AssistantMessageEventStream => {
  const apiKey = options?.apiKey || getEnvApiKey(model.provider);
  if (!apiKey) {
    throw new Error(`No API key for provider: ${model.provider}`);
  }

  const base = buildBaseOptions(model, options, apiKey);
  // Pass the requested level through as-is only when the model accepts "xhigh".
  const reasoningEffort = supportsXhigh(model)
    ? options?.reasoning
    : clampReasoning(options?.reasoning);

  return streamAzureOpenAIResponses(model, context, {
    ...base,
    reasoningEffort,
  } satisfies AzureOpenAIResponsesOptions);
};
|
||||
|
||||
function normalizeAzureBaseUrl(baseUrl: string): string {
|
||||
return baseUrl.replace(/\/+$/, "");
|
||||
}
|
||||
|
||||
function buildDefaultBaseUrl(resourceName: string): string {
|
||||
return `https://${resourceName}.openai.azure.com/openai/v1`;
|
||||
}
|
||||
|
||||
/**
 * Resolve the Azure endpoint configuration for a request.
 *
 * apiVersion precedence: option → AZURE_OPENAI_API_VERSION → "v1".
 * baseUrl precedence: azureBaseUrl option → AZURE_OPENAI_BASE_URL env →
 * URL derived from the resource name (option or env) → model.baseUrl.
 * Throws if no base URL can be resolved; the result is normalized
 * (trailing slashes stripped).
 */
function resolveAzureConfig(
  model: Model<"azure-openai-responses">,
  options?: AzureOpenAIResponsesOptions,
): { baseUrl: string; apiVersion: string } {
  const apiVersion =
    options?.azureApiVersion ||
    process.env.AZURE_OPENAI_API_VERSION ||
    DEFAULT_AZURE_API_VERSION;

  // Whitespace-only values are treated as unset.
  const baseUrl =
    options?.azureBaseUrl?.trim() ||
    process.env.AZURE_OPENAI_BASE_URL?.trim() ||
    undefined;
  const resourceName =
    options?.azureResourceName || process.env.AZURE_OPENAI_RESOURCE_NAME;

  let resolvedBaseUrl = baseUrl;

  // Derive the canonical endpoint from the resource name when no URL given.
  if (!resolvedBaseUrl && resourceName) {
    resolvedBaseUrl = buildDefaultBaseUrl(resourceName);
  }

  // Last resort: the model's own configured base URL.
  if (!resolvedBaseUrl && model.baseUrl) {
    resolvedBaseUrl = model.baseUrl;
  }

  if (!resolvedBaseUrl) {
    throw new Error(
      "Azure OpenAI base URL is required. Set AZURE_OPENAI_BASE_URL or AZURE_OPENAI_RESOURCE_NAME, or pass azureBaseUrl, azureResourceName, or model.baseUrl.",
    );
  }

  return {
    baseUrl: normalizeAzureBaseUrl(resolvedBaseUrl),
    apiVersion,
  };
}
|
||||
|
||||
function createClient(
|
||||
model: Model<"azure-openai-responses">,
|
||||
apiKey: string,
|
||||
options?: AzureOpenAIResponsesOptions,
|
||||
) {
|
||||
if (!apiKey) {
|
||||
if (!process.env.AZURE_OPENAI_API_KEY) {
|
||||
throw new Error(
|
||||
"Azure OpenAI API key is required. Set AZURE_OPENAI_API_KEY environment variable or pass it as an argument.",
|
||||
);
|
||||
}
|
||||
apiKey = process.env.AZURE_OPENAI_API_KEY;
|
||||
}
|
||||
|
||||
const headers = { ...model.headers };
|
||||
|
||||
if (options?.headers) {
|
||||
Object.assign(headers, options.headers);
|
||||
}
|
||||
|
||||
const { baseUrl, apiVersion } = resolveAzureConfig(model, options);
|
||||
|
||||
return new AzureOpenAI({
|
||||
apiKey,
|
||||
apiVersion,
|
||||
dangerouslyAllowBrowser: true,
|
||||
defaultHeaders: headers,
|
||||
baseURL: baseUrl,
|
||||
});
|
||||
}
|
||||
|
||||
/**
 * Build the streaming request parameters for the Azure OpenAI Responses API.
 *
 * Converts messages/tools, targets the resolved deployment name, uses the
 * session id as the prompt cache key, and configures reasoning: when effort
 * or summary is requested, enables reasoning and asks for encrypted reasoning
 * content; otherwise, for gpt-5 models, appends a developer message that
 * suppresses reasoning (no official off switch exists).
 */
function buildParams(
  model: Model<"azure-openai-responses">,
  context: Context,
  options: AzureOpenAIResponsesOptions | undefined,
  deploymentName: string,
) {
  const messages = convertResponsesMessages(
    model,
    context,
    AZURE_TOOL_CALL_PROVIDERS,
  );

  const params: ResponseCreateParamsStreaming = {
    // Azure routes by deployment name, not by the underlying model id.
    model: deploymentName,
    input: messages,
    stream: true,
    prompt_cache_key: options?.sessionId,
  };

  if (options?.maxTokens) {
    params.max_output_tokens = options?.maxTokens;
  }

  if (options?.temperature !== undefined) {
    params.temperature = options?.temperature;
  }

  if (context.tools) {
    params.tools = convertResponsesTools(context.tools);
  }

  if (model.reasoning) {
    if (options?.reasoningEffort || options?.reasoningSummary) {
      params.reasoning = {
        effort: options?.reasoningEffort || "medium",
        summary: options?.reasoningSummary || "auto",
      };
      // Needed to replay reasoning across turns without storing plaintext.
      params.include = ["reasoning.encrypted_content"];
    } else {
      if (model.name.toLowerCase().startsWith("gpt-5")) {
        // Workaround: gpt-5 has no reasoning:false option, but this developer
        // message effectively disables it. See
        // https://community.openai.com/t/need-reasoning-false-option-for-gpt-5/1351588/7
        messages.push({
          role: "developer",
          content: [
            {
              type: "input_text",
              text: "# Juice: 0 !important",
            },
          ],
        });
      }
    }
  }

  return params;
}
|
||||
37
packages/ai/src/providers/github-copilot-headers.ts
Normal file
37
packages/ai/src/providers/github-copilot-headers.ts
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
import type { Message } from "../types.js";
|
||||
|
||||
// Copilot expects X-Initiator to indicate whether the request is user-initiated
|
||||
// or agent-initiated (e.g. follow-up after assistant/tool messages).
|
||||
export function inferCopilotInitiator(messages: Message[]): "user" | "agent" {
|
||||
const last = messages[messages.length - 1];
|
||||
return last && last.role !== "user" ? "agent" : "user";
|
||||
}
|
||||
|
||||
// Copilot requires Copilot-Vision-Request header when sending images
|
||||
export function hasCopilotVisionInput(messages: Message[]): boolean {
|
||||
return messages.some((msg) => {
|
||||
if (msg.role === "user" && Array.isArray(msg.content)) {
|
||||
return msg.content.some((c) => c.type === "image");
|
||||
}
|
||||
if (msg.role === "toolResult" && Array.isArray(msg.content)) {
|
||||
return msg.content.some((c) => c.type === "image");
|
||||
}
|
||||
return false;
|
||||
});
|
||||
}
|
||||
|
||||
export function buildCopilotDynamicHeaders(params: {
|
||||
messages: Message[];
|
||||
hasImages: boolean;
|
||||
}): Record<string, string> {
|
||||
const headers: Record<string, string> = {
|
||||
"X-Initiator": inferCopilotInitiator(params.messages),
|
||||
"Openai-Intent": "conversation-edits",
|
||||
};
|
||||
|
||||
if (params.hasImages) {
|
||||
headers["Copilot-Vision-Request"] = "true";
|
||||
}
|
||||
|
||||
return headers;
|
||||
}
|
||||
1074
packages/ai/src/providers/google-gemini-cli.ts
Normal file
1074
packages/ai/src/providers/google-gemini-cli.ts
Normal file
File diff suppressed because it is too large
Load diff
373
packages/ai/src/providers/google-shared.ts
Normal file
373
packages/ai/src/providers/google-shared.ts
Normal file
|
|
@ -0,0 +1,373 @@
|
|||
/**
|
||||
* Shared utilities for Google Generative AI and Google Cloud Code Assist providers.
|
||||
*/
|
||||
|
||||
import {
|
||||
type Content,
|
||||
FinishReason,
|
||||
FunctionCallingConfigMode,
|
||||
type Part,
|
||||
} from "@google/genai";
|
||||
import type {
|
||||
Context,
|
||||
ImageContent,
|
||||
Model,
|
||||
StopReason,
|
||||
TextContent,
|
||||
Tool,
|
||||
} from "../types.js";
|
||||
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
||||
import { transformMessages } from "./transform-messages.js";
|
||||
|
||||
type GoogleApiType =
|
||||
| "google-generative-ai"
|
||||
| "google-gemini-cli"
|
||||
| "google-vertex";
|
||||
|
||||
/**
|
||||
* Determines whether a streamed Gemini `Part` should be treated as "thinking".
|
||||
*
|
||||
* Protocol note (Gemini / Vertex AI thought signatures):
|
||||
* - `thought: true` is the definitive marker for thinking content (thought summaries).
|
||||
* - `thoughtSignature` is an encrypted representation of the model's internal thought process
|
||||
* used to preserve reasoning context across multi-turn interactions.
|
||||
* - `thoughtSignature` can appear on ANY part type (text, functionCall, etc.) - it does NOT
|
||||
* indicate the part itself is thinking content.
|
||||
* - For non-functionCall responses, the signature appears on the last part for context replay.
|
||||
* - When persisting/replaying model outputs, signature-bearing parts must be preserved as-is;
|
||||
* do not merge/move signatures across parts.
|
||||
*
|
||||
* See: https://ai.google.dev/gemini-api/docs/thought-signatures
|
||||
*/
|
||||
export function isThinkingPart(
|
||||
part: Pick<Part, "thought" | "thoughtSignature">,
|
||||
): boolean {
|
||||
return part.thought === true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Retain thought signatures during streaming.
|
||||
*
|
||||
* Some backends only send `thoughtSignature` on the first delta for a given part/block; later deltas may omit it.
|
||||
* This helper preserves the last non-empty signature for the current block.
|
||||
*
|
||||
* Note: this does NOT merge or move signatures across distinct response parts. It only prevents
|
||||
* a signature from being overwritten with `undefined` within the same streamed block.
|
||||
*/
|
||||
export function retainThoughtSignature(
|
||||
existing: string | undefined,
|
||||
incoming: string | undefined,
|
||||
): string | undefined {
|
||||
if (typeof incoming === "string" && incoming.length > 0) return incoming;
|
||||
return existing;
|
||||
}
|
||||
|
||||
// Thought signatures must be base64 for Google APIs (TYPE_BYTES).
// Character-class check only; length divisibility is validated separately below.
const base64SignaturePattern = /^[A-Za-z0-9+/]+={0,2}$/;

// Sentinel value that tells the Gemini API to skip thought signature validation.
// Used for unsigned function call parts (e.g. replayed from providers without thought signatures).
// See: https://ai.google.dev/gemini-api/docs/thought-signatures
const SKIP_THOUGHT_SIGNATURE = "skip_thought_signature_validator";

// A usable thought signature is non-empty, padded to a multiple of 4, and
// contains only standard base64 characters with at most two trailing "=".
function isValidThoughtSignature(signature: string | undefined): boolean {
  if (!signature) return false;
  if (signature.length % 4 !== 0) return false;
  return base64SignaturePattern.test(signature);
}
|
||||
|
||||
/**
|
||||
* Only keep signatures from the same provider/model and with valid base64.
|
||||
*/
|
||||
function resolveThoughtSignature(
|
||||
isSameProviderAndModel: boolean,
|
||||
signature: string | undefined,
|
||||
): string | undefined {
|
||||
return isSameProviderAndModel && isValidThoughtSignature(signature)
|
||||
? signature
|
||||
: undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Models via Google APIs that require explicit tool call IDs in function calls/responses.
|
||||
*/
|
||||
export function requiresToolCallId(modelId: string): boolean {
|
||||
return modelId.startsWith("claude-") || modelId.startsWith("gpt-oss-");
|
||||
}
|
||||
|
||||
/**
 * Convert internal messages to Gemini Content[] format.
 *
 * Handling per role:
 * - "user": text is sanitized and images become inlineData parts; image parts
 *   are dropped entirely when the model does not accept image input.
 * - "assistant": text/thinking/toolCall blocks become "model" parts. Thinking
 *   blocks (and thought signatures) are only preserved when the message came
 *   from the same provider AND model; otherwise thinking is downgraded to
 *   plain text and signatures are discarded.
 * - "toolResult": emitted as a user-turn functionResponse part; consecutive
 *   tool results are merged into a single user turn.
 */
export function convertMessages<T extends GoogleApiType>(
  model: Model<T>,
  context: Context,
): Content[] {
  const contents: Content[] = [];
  // For models that require explicit tool call IDs, restrict IDs to a safe
  // character set and cap length at 64; other models keep IDs untouched.
  const normalizeToolCallId = (id: string): string => {
    if (!requiresToolCallId(model.id)) return id;
    return id.replace(/[^a-zA-Z0-9_-]/g, "_").slice(0, 64);
  };

  // NOTE(review): transformMessages is defined elsewhere in this file;
  // presumably it splices/repairs the history before conversion — confirm.
  const transformedMessages = transformMessages(
    context.messages,
    model,
    normalizeToolCallId,
  );

  for (const msg of transformedMessages) {
    if (msg.role === "user") {
      if (typeof msg.content === "string") {
        // Plain-string user message: a single sanitized text part.
        contents.push({
          role: "user",
          parts: [{ text: sanitizeSurrogates(msg.content) }],
        });
      } else {
        // Structured user message: text items and inline binary (image) items.
        const parts: Part[] = msg.content.map((item) => {
          if (item.type === "text") {
            return { text: sanitizeSurrogates(item.text) };
          } else {
            return {
              inlineData: {
                mimeType: item.mimeType,
                data: item.data,
              },
            };
          }
        });
        // Text-only models: silently drop non-text parts.
        const filteredParts = !model.input.includes("image")
          ? parts.filter((p) => p.text !== undefined)
          : parts;
        // Skip messages that end up with no usable parts.
        if (filteredParts.length === 0) continue;
        contents.push({
          role: "user",
          parts: filteredParts,
        });
      }
    } else if (msg.role === "assistant") {
      const parts: Part[] = [];
      // Check if message is from same provider and model - only then keep thinking blocks
      const isSameProviderAndModel =
        msg.provider === model.provider && msg.model === model.id;

      for (const block of msg.content) {
        if (block.type === "text") {
          // Skip empty text blocks - they can cause issues with some models (e.g. Claude via Antigravity)
          if (!block.text || block.text.trim() === "") continue;
          const thoughtSignature = resolveThoughtSignature(
            isSameProviderAndModel,
            block.textSignature,
          );
          parts.push({
            text: sanitizeSurrogates(block.text),
            ...(thoughtSignature && { thoughtSignature }),
          });
        } else if (block.type === "thinking") {
          // Skip empty thinking blocks
          if (!block.thinking || block.thinking.trim() === "") continue;
          // Only keep as thinking block if same provider AND same model
          // Otherwise convert to plain text (no tags to avoid model mimicking them)
          if (isSameProviderAndModel) {
            const thoughtSignature = resolveThoughtSignature(
              isSameProviderAndModel,
              block.thinkingSignature,
            );
            parts.push({
              thought: true,
              text: sanitizeSurrogates(block.thinking),
              ...(thoughtSignature && { thoughtSignature }),
            });
          } else {
            parts.push({
              text: sanitizeSurrogates(block.thinking),
            });
          }
        } else if (block.type === "toolCall") {
          const thoughtSignature = resolveThoughtSignature(
            isSameProviderAndModel,
            block.thoughtSignature,
          );
          // Gemini 3 requires thoughtSignature on all function calls when thinking mode is enabled.
          // Use the skip_thought_signature_validator sentinel for unsigned function calls
          // (e.g. replayed from providers without thought signatures like Claude via Antigravity).
          const isGemini3 = model.id.toLowerCase().includes("gemini-3");
          const effectiveSignature =
            thoughtSignature ||
            (isGemini3 ? SKIP_THOUGHT_SIGNATURE : undefined);
          const part: Part = {
            functionCall: {
              name: block.name,
              args: block.arguments ?? {},
              // Only attach the ID for model families that require it.
              ...(requiresToolCallId(model.id) ? { id: block.id } : {}),
            },
            ...(effectiveSignature && { thoughtSignature: effectiveSignature }),
          };
          parts.push(part);
        }
      }

      // Drop assistant turns where every block was filtered out.
      if (parts.length === 0) continue;
      contents.push({
        role: "model",
        parts,
      });
    } else if (msg.role === "toolResult") {
      // Extract text and image content
      const textContent = msg.content.filter(
        (c): c is TextContent => c.type === "text",
      );
      const textResult = textContent.map((c) => c.text).join("\n");
      // Image outputs are only forwarded when the model accepts image input.
      const imageContent = model.input.includes("image")
        ? msg.content.filter((c): c is ImageContent => c.type === "image")
        : [];

      const hasText = textResult.length > 0;
      const hasImages = imageContent.length > 0;

      // Gemini 3 supports multimodal function responses with images nested inside functionResponse.parts
      // See: https://ai.google.dev/gemini-api/docs/function-calling#multimodal
      // Older models don't support this, so we put images in a separate user message.
      const supportsMultimodalFunctionResponse = model.id.includes("gemini-3");

      // Use "output" key for success, "error" key for errors as per SDK documentation
      const responseValue = hasText
        ? sanitizeSurrogates(textResult)
        : hasImages
          ? "(see attached image)"
          : "";

      const imageParts: Part[] = imageContent.map((imageBlock) => ({
        inlineData: {
          mimeType: imageBlock.mimeType,
          data: imageBlock.data,
        },
      }));

      const includeId = requiresToolCallId(model.id);
      const functionResponsePart: Part = {
        functionResponse: {
          name: msg.toolName,
          response: msg.isError
            ? { error: responseValue }
            : { output: responseValue },
          // Nest images inside functionResponse.parts for Gemini 3
          ...(hasImages &&
            supportsMultimodalFunctionResponse && { parts: imageParts }),
          ...(includeId ? { id: msg.toolCallId } : {}),
        },
      };

      // Cloud Code Assist API requires all function responses to be in a single user turn.
      // Check if the last content is already a user turn with function responses and merge.
      const lastContent = contents[contents.length - 1];
      if (
        lastContent?.role === "user" &&
        lastContent.parts?.some((p) => p.functionResponse)
      ) {
        lastContent.parts.push(functionResponsePart);
      } else {
        contents.push({
          role: "user",
          parts: [functionResponsePart],
        });
      }

      // For older models, add images in a separate user message
      if (hasImages && !supportsMultimodalFunctionResponse) {
        contents.push({
          role: "user",
          parts: [{ text: "Tool result image:" }, ...imageParts],
        });
      }
    }
  }

  return contents;
}
||||
/**
|
||||
* Convert tools to Gemini function declarations format.
|
||||
*
|
||||
* By default uses `parametersJsonSchema` which supports full JSON Schema (including
|
||||
* anyOf, oneOf, const, etc.). Set `useParameters` to true to use the legacy `parameters`
|
||||
* field instead (OpenAPI 3.03 Schema). This is needed for Cloud Code Assist with Claude
|
||||
* models, where the API translates `parameters` into Anthropic's `input_schema`.
|
||||
*/
|
||||
export function convertTools(
|
||||
tools: Tool[],
|
||||
useParameters = false,
|
||||
): { functionDeclarations: Record<string, unknown>[] }[] | undefined {
|
||||
if (tools.length === 0) return undefined;
|
||||
return [
|
||||
{
|
||||
functionDeclarations: tools.map((tool) => ({
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
...(useParameters
|
||||
? { parameters: tool.parameters }
|
||||
: { parametersJsonSchema: tool.parameters }),
|
||||
})),
|
||||
},
|
||||
];
|
||||
}
|
||||
|
||||
/**
|
||||
* Map tool choice string to Gemini FunctionCallingConfigMode.
|
||||
*/
|
||||
export function mapToolChoice(choice: string): FunctionCallingConfigMode {
|
||||
switch (choice) {
|
||||
case "auto":
|
||||
return FunctionCallingConfigMode.AUTO;
|
||||
case "none":
|
||||
return FunctionCallingConfigMode.NONE;
|
||||
case "any":
|
||||
return FunctionCallingConfigMode.ANY;
|
||||
default:
|
||||
return FunctionCallingConfigMode.AUTO;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Map Gemini FinishReason to our StopReason.
|
||||
*/
|
||||
export function mapStopReason(reason: FinishReason): StopReason {
|
||||
switch (reason) {
|
||||
case FinishReason.STOP:
|
||||
return "stop";
|
||||
case FinishReason.MAX_TOKENS:
|
||||
return "length";
|
||||
case FinishReason.BLOCKLIST:
|
||||
case FinishReason.PROHIBITED_CONTENT:
|
||||
case FinishReason.SPII:
|
||||
case FinishReason.SAFETY:
|
||||
case FinishReason.IMAGE_SAFETY:
|
||||
case FinishReason.IMAGE_PROHIBITED_CONTENT:
|
||||
case FinishReason.IMAGE_RECITATION:
|
||||
case FinishReason.IMAGE_OTHER:
|
||||
case FinishReason.RECITATION:
|
||||
case FinishReason.FINISH_REASON_UNSPECIFIED:
|
||||
case FinishReason.OTHER:
|
||||
case FinishReason.LANGUAGE:
|
||||
case FinishReason.MALFORMED_FUNCTION_CALL:
|
||||
case FinishReason.UNEXPECTED_TOOL_CALL:
|
||||
case FinishReason.NO_IMAGE:
|
||||
return "error";
|
||||
default: {
|
||||
const _exhaustive: never = reason;
|
||||
throw new Error(`Unhandled stop reason: ${_exhaustive}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Map string finish reason to our StopReason (for raw API responses).
|
||||
*/
|
||||
export function mapStopReasonString(reason: string): StopReason {
|
||||
switch (reason) {
|
||||
case "STOP":
|
||||
return "stop";
|
||||
case "MAX_TOKENS":
|
||||
return "length";
|
||||
default:
|
||||
return "error";
|
||||
}
|
||||
}
|
||||
529
packages/ai/src/providers/google-vertex.ts
Normal file
529
packages/ai/src/providers/google-vertex.ts
Normal file
|
|
@ -0,0 +1,529 @@
|
|||
import {
|
||||
type GenerateContentConfig,
|
||||
type GenerateContentParameters,
|
||||
GoogleGenAI,
|
||||
type ThinkingConfig,
|
||||
ThinkingLevel,
|
||||
} from "@google/genai";
|
||||
import { calculateCost } from "../models.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
Context,
|
||||
Model,
|
||||
ThinkingLevel as PiThinkingLevel,
|
||||
SimpleStreamOptions,
|
||||
StreamFunction,
|
||||
StreamOptions,
|
||||
TextContent,
|
||||
ThinkingBudgets,
|
||||
ThinkingContent,
|
||||
ToolCall,
|
||||
} from "../types.js";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
||||
import type { GoogleThinkingLevel } from "./google-gemini-cli.js";
|
||||
import {
|
||||
convertMessages,
|
||||
convertTools,
|
||||
isThinkingPart,
|
||||
mapStopReason,
|
||||
mapToolChoice,
|
||||
retainThoughtSignature,
|
||||
} from "./google-shared.js";
|
||||
import { buildBaseOptions, clampReasoning } from "./simple-options.js";
|
||||
|
||||
/**
 * Options for streaming against Vertex AI.
 *
 * `project`/`location` fall back to the GOOGLE_CLOUD_PROJECT / GCLOUD_PROJECT
 * and GOOGLE_CLOUD_LOCATION environment variables (see resolveProject /
 * resolveLocation below).
 */
export interface GoogleVertexOptions extends StreamOptions {
  // Mapped onto Gemini FunctionCallingConfigMode (AUTO / NONE / ANY).
  toolChoice?: "auto" | "none" | "any";
  thinking?: {
    enabled: boolean;
    budgetTokens?: number; // -1 for dynamic, 0 to disable
    // When set, takes precedence over budgetTokens (see buildParams).
    level?: GoogleThinkingLevel;
  };
  // GCP project ID; required (directly or via env) when streaming.
  project?: string;
  // Vertex AI region; required (directly or via env) when streaming.
  location?: string;
}
|
||||
// Vertex AI REST API version passed to the GoogleGenAI client.
const API_VERSION = "v1";

// Maps our string thinking levels onto the SDK's ThinkingLevel enum.
const THINKING_LEVEL_MAP: Record<GoogleThinkingLevel, ThinkingLevel> = {
  THINKING_LEVEL_UNSPECIFIED: ThinkingLevel.THINKING_LEVEL_UNSPECIFIED,
  MINIMAL: ThinkingLevel.MINIMAL,
  LOW: ThinkingLevel.LOW,
  MEDIUM: ThinkingLevel.MEDIUM,
  HIGH: ThinkingLevel.HIGH,
};

// Counter for generating unique tool call IDs when the API omits one
// or returns a duplicate (see the functionCall handling in the stream loop).
let toolCallCounter = 0;
||||
/**
 * Stream a completion from Vertex AI via the `@google/genai` client.
 *
 * Returns an AssistantMessageEventStream immediately; the request itself runs
 * in a detached async IIFE that pushes incremental events
 * (text/thinking/toolcall start–delta–end, usage updates) and terminates the
 * stream with either a "done" or an "error" event.
 */
export const streamGoogleVertex: StreamFunction<
  "google-vertex",
  GoogleVertexOptions
> = (
  model: Model<"google-vertex">,
  context: Context,
  options?: GoogleVertexOptions,
): AssistantMessageEventStream => {
  const stream = new AssistantMessageEventStream();

  (async () => {
    // Accumulated assistant message; mutated in place and also sent as the
    // `partial` snapshot on every event.
    const output: AssistantMessage = {
      role: "assistant",
      content: [],
      api: "google-vertex" as Api,
      provider: model.provider,
      model: model.id,
      usage: {
        input: 0,
        output: 0,
        cacheRead: 0,
        cacheWrite: 0,
        totalTokens: 0,
        cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
      },
      stopReason: "stop",
      timestamp: Date.now(),
    };

    try {
      // Both throw if the value is not provided via options or environment.
      const project = resolveProject(options);
      const location = resolveLocation(options);
      const client = createClient(model, project, location, options?.headers);
      const params = buildParams(model, context, options);
      options?.onPayload?.(params);
      const googleStream = await client.models.generateContentStream(params);

      stream.push({ type: "start", partial: output });
      // The text/thinking block currently being accumulated (tool calls are
      // emitted atomically and never become the current block).
      let currentBlock: TextContent | ThinkingContent | null = null;
      const blocks = output.content;
      const blockIndex = () => blocks.length - 1;
      for await (const chunk of googleStream) {
        const candidate = chunk.candidates?.[0];
        if (candidate?.content?.parts) {
          for (const part of candidate.content.parts) {
            if (part.text !== undefined) {
              const isThinking = isThinkingPart(part);
              // Open a new block when there is none, or when the part kind
              // (thinking vs text) differs from the block being accumulated.
              if (
                !currentBlock ||
                (isThinking && currentBlock.type !== "thinking") ||
                (!isThinking && currentBlock.type !== "text")
              ) {
                // Close out the previous block first.
                if (currentBlock) {
                  if (currentBlock.type === "text") {
                    stream.push({
                      type: "text_end",
                      contentIndex: blocks.length - 1,
                      content: currentBlock.text,
                      partial: output,
                    });
                  } else {
                    stream.push({
                      type: "thinking_end",
                      contentIndex: blockIndex(),
                      content: currentBlock.thinking,
                      partial: output,
                    });
                  }
                }
                if (isThinking) {
                  currentBlock = {
                    type: "thinking",
                    thinking: "",
                    thinkingSignature: undefined,
                  };
                  output.content.push(currentBlock);
                  stream.push({
                    type: "thinking_start",
                    contentIndex: blockIndex(),
                    partial: output,
                  });
                } else {
                  currentBlock = { type: "text", text: "" };
                  output.content.push(currentBlock);
                  stream.push({
                    type: "text_start",
                    contentIndex: blockIndex(),
                    partial: output,
                  });
                }
              }
              if (currentBlock.type === "thinking") {
                currentBlock.thinking += part.text;
                // Some backends send the signature only on the first delta;
                // keep the last non-empty one for this block.
                currentBlock.thinkingSignature = retainThoughtSignature(
                  currentBlock.thinkingSignature,
                  part.thoughtSignature,
                );
                stream.push({
                  type: "thinking_delta",
                  contentIndex: blockIndex(),
                  delta: part.text,
                  partial: output,
                });
              } else {
                currentBlock.text += part.text;
                currentBlock.textSignature = retainThoughtSignature(
                  currentBlock.textSignature,
                  part.thoughtSignature,
                );
                stream.push({
                  type: "text_delta",
                  contentIndex: blockIndex(),
                  delta: part.text,
                  partial: output,
                });
              }
            }

            if (part.functionCall) {
              // A function call terminates any open text/thinking block.
              if (currentBlock) {
                if (currentBlock.type === "text") {
                  stream.push({
                    type: "text_end",
                    contentIndex: blockIndex(),
                    content: currentBlock.text,
                    partial: output,
                  });
                } else {
                  stream.push({
                    type: "thinking_end",
                    contentIndex: blockIndex(),
                    content: currentBlock.thinking,
                    partial: output,
                  });
                }
                currentBlock = null;
              }

              // Generate a unique ID when the API omitted one or when the
              // provided ID collides with an earlier tool call in this message.
              const providedId = part.functionCall.id;
              const needsNewId =
                !providedId ||
                output.content.some(
                  (b) => b.type === "toolCall" && b.id === providedId,
                );
              const toolCallId = needsNewId
                ? `${part.functionCall.name}_${Date.now()}_${++toolCallCounter}`
                : providedId;

              const toolCall: ToolCall = {
                type: "toolCall",
                id: toolCallId,
                name: part.functionCall.name || "",
                arguments:
                  (part.functionCall.args as Record<string, any>) ?? {},
                ...(part.thoughtSignature && {
                  thoughtSignature: part.thoughtSignature,
                }),
              };

              // Tool calls arrive whole, so start/delta/end are emitted
              // back-to-back for the same content index.
              output.content.push(toolCall);
              stream.push({
                type: "toolcall_start",
                contentIndex: blockIndex(),
                partial: output,
              });
              stream.push({
                type: "toolcall_delta",
                contentIndex: blockIndex(),
                delta: JSON.stringify(toolCall.arguments),
                partial: output,
              });
              stream.push({
                type: "toolcall_end",
                contentIndex: blockIndex(),
                toolCall,
                partial: output,
              });
            }
          }
        }

        if (candidate?.finishReason) {
          output.stopReason = mapStopReason(candidate.finishReason);
          // The presence of any tool call overrides the API's finish reason.
          if (output.content.some((b) => b.type === "toolCall")) {
            output.stopReason = "toolUse";
          }
        }

        // Usage metadata may arrive on multiple chunks; the last one wins.
        if (chunk.usageMetadata) {
          output.usage = {
            input: chunk.usageMetadata.promptTokenCount || 0,
            // Thinking tokens are billed as output tokens.
            output:
              (chunk.usageMetadata.candidatesTokenCount || 0) +
              (chunk.usageMetadata.thoughtsTokenCount || 0),
            cacheRead: chunk.usageMetadata.cachedContentTokenCount || 0,
            cacheWrite: 0,
            totalTokens: chunk.usageMetadata.totalTokenCount || 0,
            cost: {
              input: 0,
              output: 0,
              cacheRead: 0,
              cacheWrite: 0,
              total: 0,
            },
          };
          calculateCost(model, output.usage);
        }
      }

      // Close any block still open when the stream ends.
      if (currentBlock) {
        if (currentBlock.type === "text") {
          stream.push({
            type: "text_end",
            contentIndex: blockIndex(),
            content: currentBlock.text,
            partial: output,
          });
        } else {
          stream.push({
            type: "thinking_end",
            contentIndex: blockIndex(),
            content: currentBlock.thinking,
            partial: output,
          });
        }
      }

      if (options?.signal?.aborted) {
        throw new Error("Request was aborted");
      }

      if (output.stopReason === "aborted" || output.stopReason === "error") {
        throw new Error("An unknown error occurred");
      }

      stream.push({ type: "done", reason: output.stopReason, message: output });
      stream.end();
    } catch (error) {
      // Remove internal index property used during streaming
      for (const block of output.content) {
        if ("index" in block) {
          delete (block as { index?: number }).index;
        }
      }
      // Distinguish caller-initiated aborts from genuine failures.
      output.stopReason = options?.signal?.aborted ? "aborted" : "error";
      output.errorMessage =
        error instanceof Error ? error.message : JSON.stringify(error);
      stream.push({ type: "error", reason: output.stopReason, error: output });
      stream.end();
    }
  })();

  return stream;
};
||||
export const streamSimpleGoogleVertex: StreamFunction<
|
||||
"google-vertex",
|
||||
SimpleStreamOptions
|
||||
> = (
|
||||
model: Model<"google-vertex">,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const base = buildBaseOptions(model, options, undefined);
|
||||
if (!options?.reasoning) {
|
||||
return streamGoogleVertex(model, context, {
|
||||
...base,
|
||||
thinking: { enabled: false },
|
||||
} satisfies GoogleVertexOptions);
|
||||
}
|
||||
|
||||
const effort = clampReasoning(options.reasoning)!;
|
||||
const geminiModel = model as unknown as Model<"google-generative-ai">;
|
||||
|
||||
if (isGemini3ProModel(geminiModel) || isGemini3FlashModel(geminiModel)) {
|
||||
return streamGoogleVertex(model, context, {
|
||||
...base,
|
||||
thinking: {
|
||||
enabled: true,
|
||||
level: getGemini3ThinkingLevel(effort, geminiModel),
|
||||
},
|
||||
} satisfies GoogleVertexOptions);
|
||||
}
|
||||
|
||||
return streamGoogleVertex(model, context, {
|
||||
...base,
|
||||
thinking: {
|
||||
enabled: true,
|
||||
budgetTokens: getGoogleBudget(
|
||||
geminiModel,
|
||||
effort,
|
||||
options.thinkingBudgets,
|
||||
),
|
||||
},
|
||||
} satisfies GoogleVertexOptions);
|
||||
};
|
||||
|
||||
function createClient(
|
||||
model: Model<"google-vertex">,
|
||||
project: string,
|
||||
location: string,
|
||||
optionsHeaders?: Record<string, string>,
|
||||
): GoogleGenAI {
|
||||
const httpOptions: { headers?: Record<string, string> } = {};
|
||||
|
||||
if (model.headers || optionsHeaders) {
|
||||
httpOptions.headers = { ...model.headers, ...optionsHeaders };
|
||||
}
|
||||
|
||||
const hasHttpOptions = Object.values(httpOptions).some(Boolean);
|
||||
|
||||
return new GoogleGenAI({
|
||||
vertexai: true,
|
||||
project,
|
||||
location,
|
||||
apiVersion: API_VERSION,
|
||||
httpOptions: hasHttpOptions ? httpOptions : undefined,
|
||||
});
|
||||
}
|
||||
|
||||
function resolveProject(options?: GoogleVertexOptions): string {
|
||||
const project =
|
||||
options?.project ||
|
||||
process.env.GOOGLE_CLOUD_PROJECT ||
|
||||
process.env.GCLOUD_PROJECT;
|
||||
if (!project) {
|
||||
throw new Error(
|
||||
"Vertex AI requires a project ID. Set GOOGLE_CLOUD_PROJECT/GCLOUD_PROJECT or pass project in options.",
|
||||
);
|
||||
}
|
||||
return project;
|
||||
}
|
||||
|
||||
function resolveLocation(options?: GoogleVertexOptions): string {
|
||||
const location = options?.location || process.env.GOOGLE_CLOUD_LOCATION;
|
||||
if (!location) {
|
||||
throw new Error(
|
||||
"Vertex AI requires a location. Set GOOGLE_CLOUD_LOCATION or pass location in options.",
|
||||
);
|
||||
}
|
||||
return location;
|
||||
}
|
||||
|
||||
/**
 * Assemble GenerateContentParameters for a Vertex AI request from the
 * internal model/context/options: converted messages, sanitized system
 * prompt, tool declarations, tool choice, thinking config, and abort signal.
 */
function buildParams(
  model: Model<"google-vertex">,
  context: Context,
  options: GoogleVertexOptions = {},
): GenerateContentParameters {
  const contents = convertMessages(model, context);

  // Only include generation settings the caller actually specified.
  const generationConfig: GenerateContentConfig = {};
  if (options.temperature !== undefined) {
    generationConfig.temperature = options.temperature;
  }
  if (options.maxTokens !== undefined) {
    generationConfig.maxOutputTokens = options.maxTokens;
  }

  // Spreading `false` is a no-op, so empty sub-configs are omitted entirely.
  const config: GenerateContentConfig = {
    ...(Object.keys(generationConfig).length > 0 && generationConfig),
    ...(context.systemPrompt && {
      systemInstruction: sanitizeSurrogates(context.systemPrompt),
    }),
    ...(context.tools &&
      context.tools.length > 0 && { tools: convertTools(context.tools) }),
  };

  // toolConfig only applies when tools exist AND a choice was given;
  // otherwise the key is explicitly set to undefined.
  if (context.tools && context.tools.length > 0 && options.toolChoice) {
    config.toolConfig = {
      functionCallingConfig: {
        mode: mapToolChoice(options.toolChoice),
      },
    };
  } else {
    config.toolConfig = undefined;
  }

  // Thinking is only configured for reasoning-capable models; an explicit
  // level takes precedence over a token budget.
  if (options.thinking?.enabled && model.reasoning) {
    const thinkingConfig: ThinkingConfig = { includeThoughts: true };
    if (options.thinking.level !== undefined) {
      thinkingConfig.thinkingLevel = THINKING_LEVEL_MAP[options.thinking.level];
    } else if (options.thinking.budgetTokens !== undefined) {
      thinkingConfig.thinkingBudget = options.thinking.budgetTokens;
    }
    config.thinkingConfig = thinkingConfig;
  }

  // Fail fast on an already-aborted signal instead of issuing the request.
  if (options.signal) {
    if (options.signal.aborted) {
      throw new Error("Request aborted");
    }
    config.abortSignal = options.signal;
  }

  const params: GenerateContentParameters = {
    model: model.id,
    contents,
    config,
  };

  return params;
}
|
||||
type ClampedThinkingLevel = Exclude<PiThinkingLevel, "xhigh">;
|
||||
|
||||
function isGemini3ProModel(model: Model<"google-generative-ai">): boolean {
|
||||
return /gemini-3(?:\.\d+)?-pro/.test(model.id.toLowerCase());
|
||||
}
|
||||
|
||||
function isGemini3FlashModel(model: Model<"google-generative-ai">): boolean {
|
||||
return /gemini-3(?:\.\d+)?-flash/.test(model.id.toLowerCase());
|
||||
}
|
||||
|
||||
function getGemini3ThinkingLevel(
|
||||
effort: ClampedThinkingLevel,
|
||||
model: Model<"google-generative-ai">,
|
||||
): GoogleThinkingLevel {
|
||||
if (isGemini3ProModel(model)) {
|
||||
switch (effort) {
|
||||
case "minimal":
|
||||
case "low":
|
||||
return "LOW";
|
||||
case "medium":
|
||||
case "high":
|
||||
return "HIGH";
|
||||
}
|
||||
}
|
||||
switch (effort) {
|
||||
case "minimal":
|
||||
return "MINIMAL";
|
||||
case "low":
|
||||
return "LOW";
|
||||
case "medium":
|
||||
return "MEDIUM";
|
||||
case "high":
|
||||
return "HIGH";
|
||||
}
|
||||
}
|
||||
|
||||
function getGoogleBudget(
|
||||
model: Model<"google-generative-ai">,
|
||||
effort: ClampedThinkingLevel,
|
||||
customBudgets?: ThinkingBudgets,
|
||||
): number {
|
||||
if (customBudgets?.[effort] !== undefined) {
|
||||
return customBudgets[effort]!;
|
||||
}
|
||||
|
||||
if (model.id.includes("2.5-pro")) {
|
||||
const budgets: Record<ClampedThinkingLevel, number> = {
|
||||
minimal: 128,
|
||||
low: 2048,
|
||||
medium: 8192,
|
||||
high: 32768,
|
||||
};
|
||||
return budgets[effort];
|
||||
}
|
||||
|
||||
if (model.id.includes("2.5-flash")) {
|
||||
const budgets: Record<ClampedThinkingLevel, number> = {
|
||||
minimal: 128,
|
||||
low: 2048,
|
||||
medium: 8192,
|
||||
high: 24576,
|
||||
};
|
||||
return budgets[effort];
|
||||
}
|
||||
|
||||
return -1;
|
||||
}
|
||||
501
packages/ai/src/providers/google.ts
Normal file
501
packages/ai/src/providers/google.ts
Normal file
|
|
@ -0,0 +1,501 @@
|
|||
import {
|
||||
type GenerateContentConfig,
|
||||
type GenerateContentParameters,
|
||||
GoogleGenAI,
|
||||
type ThinkingConfig,
|
||||
} from "@google/genai";
|
||||
import { getEnvApiKey } from "../env-api-keys.js";
|
||||
import { calculateCost } from "../models.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
Context,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
StreamFunction,
|
||||
StreamOptions,
|
||||
TextContent,
|
||||
ThinkingBudgets,
|
||||
ThinkingContent,
|
||||
ThinkingLevel,
|
||||
ToolCall,
|
||||
} from "../types.js";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
||||
import type { GoogleThinkingLevel } from "./google-gemini-cli.js";
|
||||
import {
|
||||
convertMessages,
|
||||
convertTools,
|
||||
isThinkingPart,
|
||||
mapStopReason,
|
||||
mapToolChoice,
|
||||
retainThoughtSignature,
|
||||
} from "./google-shared.js";
|
||||
import { buildBaseOptions, clampReasoning } from "./simple-options.js";
|
||||
|
||||
/**
 * Options for streaming against the Google Generative AI (Gemini) API.
 */
export interface GoogleOptions extends StreamOptions {
  // Mapped onto Gemini FunctionCallingConfigMode (AUTO / NONE / ANY).
  toolChoice?: "auto" | "none" | "any";
  thinking?: {
    enabled: boolean;
    budgetTokens?: number; // -1 for dynamic, 0 to disable
    // NOTE(review): presumably level wins over budgetTokens when both are
    // set, mirroring the Vertex buildParams — confirm in this file's
    // buildParams implementation.
    level?: GoogleThinkingLevel;
  };
}
|
||||
// Counter for generating unique tool call IDs when the API omits one or
// returns a duplicate (see the functionCall handling in the stream loop).
let toolCallCounter = 0;
|
||||
export const streamGoogle: StreamFunction<
|
||||
"google-generative-ai",
|
||||
GoogleOptions
|
||||
> = (
|
||||
model: Model<"google-generative-ai">,
|
||||
context: Context,
|
||||
options?: GoogleOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const stream = new AssistantMessageEventStream();
|
||||
|
||||
(async () => {
|
||||
const output: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: "google-generative-ai" as Api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "stop",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
try {
|
||||
const apiKey = options?.apiKey || getEnvApiKey(model.provider) || "";
|
||||
const client = createClient(model, apiKey, options?.headers);
|
||||
const params = buildParams(model, context, options);
|
||||
options?.onPayload?.(params);
|
||||
const googleStream = await client.models.generateContentStream(params);
|
||||
|
||||
stream.push({ type: "start", partial: output });
|
||||
let currentBlock: TextContent | ThinkingContent | null = null;
|
||||
const blocks = output.content;
|
||||
const blockIndex = () => blocks.length - 1;
|
||||
for await (const chunk of googleStream) {
|
||||
const candidate = chunk.candidates?.[0];
|
||||
if (candidate?.content?.parts) {
|
||||
for (const part of candidate.content.parts) {
|
||||
if (part.text !== undefined) {
|
||||
const isThinking = isThinkingPart(part);
|
||||
if (
|
||||
!currentBlock ||
|
||||
(isThinking && currentBlock.type !== "thinking") ||
|
||||
(!isThinking && currentBlock.type !== "text")
|
||||
) {
|
||||
if (currentBlock) {
|
||||
if (currentBlock.type === "text") {
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
contentIndex: blocks.length - 1,
|
||||
content: currentBlock.text,
|
||||
partial: output,
|
||||
});
|
||||
} else {
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: currentBlock.thinking,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
if (isThinking) {
|
||||
currentBlock = {
|
||||
type: "thinking",
|
||||
thinking: "",
|
||||
thinkingSignature: undefined,
|
||||
};
|
||||
output.content.push(currentBlock);
|
||||
stream.push({
|
||||
type: "thinking_start",
|
||||
contentIndex: blockIndex(),
|
||||
partial: output,
|
||||
});
|
||||
} else {
|
||||
currentBlock = { type: "text", text: "" };
|
||||
output.content.push(currentBlock);
|
||||
stream.push({
|
||||
type: "text_start",
|
||||
contentIndex: blockIndex(),
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
if (currentBlock.type === "thinking") {
|
||||
currentBlock.thinking += part.text;
|
||||
currentBlock.thinkingSignature = retainThoughtSignature(
|
||||
currentBlock.thinkingSignature,
|
||||
part.thoughtSignature,
|
||||
);
|
||||
stream.push({
|
||||
type: "thinking_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: part.text,
|
||||
partial: output,
|
||||
});
|
||||
} else {
|
||||
currentBlock.text += part.text;
|
||||
currentBlock.textSignature = retainThoughtSignature(
|
||||
currentBlock.textSignature,
|
||||
part.thoughtSignature,
|
||||
);
|
||||
stream.push({
|
||||
type: "text_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: part.text,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (part.functionCall) {
|
||||
if (currentBlock) {
|
||||
if (currentBlock.type === "text") {
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: currentBlock.text,
|
||||
partial: output,
|
||||
});
|
||||
} else {
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: currentBlock.thinking,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
currentBlock = null;
|
||||
}
|
||||
|
||||
// Generate unique ID if not provided or if it's a duplicate
|
||||
const providedId = part.functionCall.id;
|
||||
const needsNewId =
|
||||
!providedId ||
|
||||
output.content.some(
|
||||
(b) => b.type === "toolCall" && b.id === providedId,
|
||||
);
|
||||
const toolCallId = needsNewId
|
||||
? `${part.functionCall.name}_${Date.now()}_${++toolCallCounter}`
|
||||
: providedId;
|
||||
|
||||
const toolCall: ToolCall = {
|
||||
type: "toolCall",
|
||||
id: toolCallId,
|
||||
name: part.functionCall.name || "",
|
||||
arguments:
|
||||
(part.functionCall.args as Record<string, any>) ?? {},
|
||||
...(part.thoughtSignature && {
|
||||
thoughtSignature: part.thoughtSignature,
|
||||
}),
|
||||
};
|
||||
|
||||
output.content.push(toolCall);
|
||||
stream.push({
|
||||
type: "toolcall_start",
|
||||
contentIndex: blockIndex(),
|
||||
partial: output,
|
||||
});
|
||||
stream.push({
|
||||
type: "toolcall_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: JSON.stringify(toolCall.arguments),
|
||||
partial: output,
|
||||
});
|
||||
stream.push({
|
||||
type: "toolcall_end",
|
||||
contentIndex: blockIndex(),
|
||||
toolCall,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (candidate?.finishReason) {
|
||||
output.stopReason = mapStopReason(candidate.finishReason);
|
||||
if (output.content.some((b) => b.type === "toolCall")) {
|
||||
output.stopReason = "toolUse";
|
||||
}
|
||||
}
|
||||
|
||||
if (chunk.usageMetadata) {
|
||||
output.usage = {
|
||||
input: chunk.usageMetadata.promptTokenCount || 0,
|
||||
output:
|
||||
(chunk.usageMetadata.candidatesTokenCount || 0) +
|
||||
(chunk.usageMetadata.thoughtsTokenCount || 0),
|
||||
cacheRead: chunk.usageMetadata.cachedContentTokenCount || 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: chunk.usageMetadata.totalTokenCount || 0,
|
||||
cost: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
total: 0,
|
||||
},
|
||||
};
|
||||
calculateCost(model, output.usage);
|
||||
}
|
||||
}
|
||||
|
||||
if (currentBlock) {
|
||||
if (currentBlock.type === "text") {
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: currentBlock.text,
|
||||
partial: output,
|
||||
});
|
||||
} else {
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: currentBlock.thinking,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (options?.signal?.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
|
||||
if (output.stopReason === "aborted" || output.stopReason === "error") {
|
||||
throw new Error("An unknown error occurred");
|
||||
}
|
||||
|
||||
stream.push({ type: "done", reason: output.stopReason, message: output });
|
||||
stream.end();
|
||||
} catch (error) {
|
||||
// Remove internal index property used during streaming
|
||||
for (const block of output.content) {
|
||||
if ("index" in block) {
|
||||
delete (block as { index?: number }).index;
|
||||
}
|
||||
}
|
||||
output.stopReason = options?.signal?.aborted ? "aborted" : "error";
|
||||
output.errorMessage =
|
||||
error instanceof Error ? error.message : JSON.stringify(error);
|
||||
stream.push({ type: "error", reason: output.stopReason, error: output });
|
||||
stream.end();
|
||||
}
|
||||
})();
|
||||
|
||||
return stream;
|
||||
};
|
||||
|
||||
export const streamSimpleGoogle: StreamFunction<
|
||||
"google-generative-ai",
|
||||
SimpleStreamOptions
|
||||
> = (
|
||||
model: Model<"google-generative-ai">,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const apiKey = options?.apiKey || getEnvApiKey(model.provider);
|
||||
if (!apiKey) {
|
||||
throw new Error(`No API key for provider: ${model.provider}`);
|
||||
}
|
||||
|
||||
const base = buildBaseOptions(model, options, apiKey);
|
||||
if (!options?.reasoning) {
|
||||
return streamGoogle(model, context, {
|
||||
...base,
|
||||
thinking: { enabled: false },
|
||||
} satisfies GoogleOptions);
|
||||
}
|
||||
|
||||
const effort = clampReasoning(options.reasoning)!;
|
||||
const googleModel = model as Model<"google-generative-ai">;
|
||||
|
||||
if (isGemini3ProModel(googleModel) || isGemini3FlashModel(googleModel)) {
|
||||
return streamGoogle(model, context, {
|
||||
...base,
|
||||
thinking: {
|
||||
enabled: true,
|
||||
level: getGemini3ThinkingLevel(effort, googleModel),
|
||||
},
|
||||
} satisfies GoogleOptions);
|
||||
}
|
||||
|
||||
return streamGoogle(model, context, {
|
||||
...base,
|
||||
thinking: {
|
||||
enabled: true,
|
||||
budgetTokens: getGoogleBudget(
|
||||
googleModel,
|
||||
effort,
|
||||
options.thinkingBudgets,
|
||||
),
|
||||
},
|
||||
} satisfies GoogleOptions);
|
||||
};
|
||||
|
||||
function createClient(
|
||||
model: Model<"google-generative-ai">,
|
||||
apiKey?: string,
|
||||
optionsHeaders?: Record<string, string>,
|
||||
): GoogleGenAI {
|
||||
const httpOptions: {
|
||||
baseUrl?: string;
|
||||
apiVersion?: string;
|
||||
headers?: Record<string, string>;
|
||||
} = {};
|
||||
if (model.baseUrl) {
|
||||
httpOptions.baseUrl = model.baseUrl;
|
||||
httpOptions.apiVersion = ""; // baseUrl already includes version path, don't append
|
||||
}
|
||||
if (model.headers || optionsHeaders) {
|
||||
httpOptions.headers = { ...model.headers, ...optionsHeaders };
|
||||
}
|
||||
|
||||
return new GoogleGenAI({
|
||||
apiKey,
|
||||
httpOptions: Object.keys(httpOptions).length > 0 ? httpOptions : undefined,
|
||||
});
|
||||
}
|
||||
|
||||
/**
 * Translate a provider-agnostic request (model, context, options) into the
 * `GenerateContentParameters` shape expected by the Google GenAI SDK.
 *
 * Throws synchronously if `options.signal` is already aborted.
 */
function buildParams(
  model: Model<"google-generative-ai">,
  context: Context,
  options: GoogleOptions = {},
): GenerateContentParameters {
  const contents = convertMessages(model, context);

  const generationConfig: GenerateContentConfig = {};
  if (options.temperature !== undefined) {
    generationConfig.temperature = options.temperature;
  }
  if (options.maxTokens !== undefined) {
    generationConfig.maxOutputTokens = options.maxTokens;
  }

  // Conditional spreads keep empty sub-objects out of the request entirely.
  const config: GenerateContentConfig = {
    ...(Object.keys(generationConfig).length > 0 && generationConfig),
    ...(context.systemPrompt && {
      systemInstruction: sanitizeSurrogates(context.systemPrompt),
    }),
    ...(context.tools &&
      context.tools.length > 0 && { tools: convertTools(context.tools) }),
  };

  // toolConfig is only meaningful when tools exist AND a choice was requested;
  // otherwise it is explicitly set to undefined.
  if (context.tools && context.tools.length > 0 && options.toolChoice) {
    config.toolConfig = {
      functionCallingConfig: {
        mode: mapToolChoice(options.toolChoice),
      },
    };
  } else {
    config.toolConfig = undefined;
  }

  // Thinking config is attached only for reasoning-capable models with
  // thinking enabled; an explicit level takes priority over a token budget.
  if (options.thinking?.enabled && model.reasoning) {
    const thinkingConfig: ThinkingConfig = { includeThoughts: true };
    if (options.thinking.level !== undefined) {
      // Cast to any since our GoogleThinkingLevel mirrors Google's ThinkingLevel enum values
      thinkingConfig.thinkingLevel = options.thinking.level as any;
    } else if (options.thinking.budgetTokens !== undefined) {
      thinkingConfig.thinkingBudget = options.thinking.budgetTokens;
    }
    config.thinkingConfig = thinkingConfig;
  }

  if (options.signal) {
    // Fail fast instead of issuing a request that is already cancelled.
    if (options.signal.aborted) {
      throw new Error("Request aborted");
    }
    config.abortSignal = options.signal;
  }

  const params: GenerateContentParameters = {
    model: model.id,
    contents,
    config,
  };

  return params;
}
|
||||
|
||||
type ClampedThinkingLevel = Exclude<ThinkingLevel, "xhigh">;
|
||||
|
||||
function isGemini3ProModel(model: Model<"google-generative-ai">): boolean {
|
||||
return /gemini-3(?:\.\d+)?-pro/.test(model.id.toLowerCase());
|
||||
}
|
||||
|
||||
function isGemini3FlashModel(model: Model<"google-generative-ai">): boolean {
|
||||
return /gemini-3(?:\.\d+)?-flash/.test(model.id.toLowerCase());
|
||||
}
|
||||
|
||||
function getGemini3ThinkingLevel(
|
||||
effort: ClampedThinkingLevel,
|
||||
model: Model<"google-generative-ai">,
|
||||
): GoogleThinkingLevel {
|
||||
if (isGemini3ProModel(model)) {
|
||||
switch (effort) {
|
||||
case "minimal":
|
||||
case "low":
|
||||
return "LOW";
|
||||
case "medium":
|
||||
case "high":
|
||||
return "HIGH";
|
||||
}
|
||||
}
|
||||
switch (effort) {
|
||||
case "minimal":
|
||||
return "MINIMAL";
|
||||
case "low":
|
||||
return "LOW";
|
||||
case "medium":
|
||||
return "MEDIUM";
|
||||
case "high":
|
||||
return "HIGH";
|
||||
}
|
||||
}
|
||||
|
||||
function getGoogleBudget(
|
||||
model: Model<"google-generative-ai">,
|
||||
effort: ClampedThinkingLevel,
|
||||
customBudgets?: ThinkingBudgets,
|
||||
): number {
|
||||
if (customBudgets?.[effort] !== undefined) {
|
||||
return customBudgets[effort]!;
|
||||
}
|
||||
|
||||
if (model.id.includes("2.5-pro")) {
|
||||
const budgets: Record<ClampedThinkingLevel, number> = {
|
||||
minimal: 128,
|
||||
low: 2048,
|
||||
medium: 8192,
|
||||
high: 32768,
|
||||
};
|
||||
return budgets[effort];
|
||||
}
|
||||
|
||||
if (model.id.includes("2.5-flash")) {
|
||||
const budgets: Record<ClampedThinkingLevel, number> = {
|
||||
minimal: 128,
|
||||
low: 2048,
|
||||
medium: 8192,
|
||||
high: 24576,
|
||||
};
|
||||
return budgets[effort];
|
||||
}
|
||||
|
||||
return -1;
|
||||
}
|
||||
688
packages/ai/src/providers/mistral.ts
Normal file
688
packages/ai/src/providers/mistral.ts
Normal file
|
|
@ -0,0 +1,688 @@
|
|||
import { Mistral } from "@mistralai/mistralai";
|
||||
import type { RequestOptions } from "@mistralai/mistralai/lib/sdks.js";
|
||||
import type {
|
||||
ChatCompletionStreamRequest,
|
||||
ChatCompletionStreamRequestMessages,
|
||||
CompletionEvent,
|
||||
ContentChunk,
|
||||
FunctionTool,
|
||||
} from "@mistralai/mistralai/models/components/index.js";
|
||||
import { getEnvApiKey } from "../env-api-keys.js";
|
||||
import { calculateCost } from "../models.js";
|
||||
import type {
|
||||
AssistantMessage,
|
||||
Context,
|
||||
Message,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
StopReason,
|
||||
StreamFunction,
|
||||
StreamOptions,
|
||||
TextContent,
|
||||
ThinkingContent,
|
||||
Tool,
|
||||
ToolCall,
|
||||
} from "../types.js";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import { shortHash } from "../utils/hash.js";
|
||||
import { parseStreamingJson } from "../utils/json-parse.js";
|
||||
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
||||
import { buildBaseOptions, clampReasoning } from "./simple-options.js";
|
||||
import { transformMessages } from "./transform-messages.js";
|
||||
|
||||
// Target length for normalized tool-call ids (see deriveMistralToolCallId,
// which pads/hashes ids down to exactly this many alphanumeric characters).
const MISTRAL_TOOL_CALL_ID_LENGTH = 9;
// Cap on how much of an HTTP error body is echoed into error messages.
const MAX_MISTRAL_ERROR_BODY_CHARS = 4000;

/**
 * Provider-specific options for the Mistral API.
 */
export interface MistralOptions extends StreamOptions {
  /** How the model may use tools; forwarded to the request's tool choice. */
  toolChoice?:
    | "auto"
    | "none"
    | "any"
    | "required"
    | { type: "function"; function: { name: string } };
  /** When set, forwarded as the request's prompt mode (reasoning models). */
  promptMode?: "reasoning";
}
|
||||
|
||||
/**
 * Stream responses from Mistral using `chat.stream`.
 *
 * Returns an `AssistantMessageEventStream` immediately; the request itself
 * runs in a detached async IIFE that pushes `start`, delta, and `done`
 * events into the stream (or a single `error` event on failure) and always
 * ends the stream when finished.
 */
export const streamMistral: StreamFunction<
  "mistral-conversations",
  MistralOptions
> = (
  model: Model<"mistral-conversations">,
  context: Context,
  options?: MistralOptions,
): AssistantMessageEventStream => {
  const stream = new AssistantMessageEventStream();

  (async () => {
    // Accumulates all streamed content/usage; also reused as the error payload.
    const output = createOutput(model);

    try {
      const apiKey = options?.apiKey || getEnvApiKey(model.provider);
      if (!apiKey) {
        throw new Error(`No API key for provider: ${model.provider}`);
      }

      // Intentionally per-request: avoids shared SDK mutable state across concurrent consumers.
      const mistral = new Mistral({
        apiKey,
        serverURL: model.baseUrl,
      });

      // Rewrite incoming tool-call ids to the normalized form Mistral accepts
      // (see deriveMistralToolCallId), keeping the mapping stable per request.
      const normalizeMistralToolCallId = createMistralToolCallIdNormalizer();
      const transformedMessages = transformMessages(
        context.messages,
        model,
        (id) => normalizeMistralToolCallId(id),
      );

      const payload = buildChatPayload(
        model,
        context,
        transformedMessages,
        options,
      );
      // Let callers observe the exact request body (e.g. for logging).
      options?.onPayload?.(payload);
      const mistralStream = await mistral.chat.stream(
        payload,
        buildRequestOptions(model, options),
      );
      stream.push({ type: "start", partial: output });
      await consumeChatStream(model, output, stream, mistralStream);

      // Surface a post-stream abort as an error rather than a normal completion.
      if (options?.signal?.aborted) {
        throw new Error("Request was aborted");
      }

      if (output.stopReason === "aborted" || output.stopReason === "error") {
        throw new Error("An unknown error occurred");
      }

      stream.push({ type: "done", reason: output.stopReason, message: output });
      stream.end();
    } catch (error) {
      output.stopReason = options?.signal?.aborted ? "aborted" : "error";
      output.errorMessage = formatMistralError(error);
      stream.push({ type: "error", reason: output.stopReason, error: output });
      stream.end();
    }
  })();

  return stream;
};
|
||||
|
||||
/**
|
||||
* Maps provider-agnostic `SimpleStreamOptions` to Mistral options.
|
||||
*/
|
||||
export const streamSimpleMistral: StreamFunction<
|
||||
"mistral-conversations",
|
||||
SimpleStreamOptions
|
||||
> = (
|
||||
model: Model<"mistral-conversations">,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const apiKey = options?.apiKey || getEnvApiKey(model.provider);
|
||||
if (!apiKey) {
|
||||
throw new Error(`No API key for provider: ${model.provider}`);
|
||||
}
|
||||
|
||||
const base = buildBaseOptions(model, options, apiKey);
|
||||
const reasoning = clampReasoning(options?.reasoning);
|
||||
|
||||
return streamMistral(model, context, {
|
||||
...base,
|
||||
promptMode: model.reasoning && reasoning ? "reasoning" : undefined,
|
||||
} satisfies MistralOptions);
|
||||
};
|
||||
|
||||
function createOutput(model: Model<"mistral-conversations">): AssistantMessage {
|
||||
return {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: model.api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "stop",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
}
|
||||
|
||||
function createMistralToolCallIdNormalizer(): (id: string) => string {
|
||||
const idMap = new Map<string, string>();
|
||||
const reverseMap = new Map<string, string>();
|
||||
|
||||
return (id: string): string => {
|
||||
const existing = idMap.get(id);
|
||||
if (existing) return existing;
|
||||
|
||||
let attempt = 0;
|
||||
while (true) {
|
||||
const candidate = deriveMistralToolCallId(id, attempt);
|
||||
const owner = reverseMap.get(candidate);
|
||||
if (!owner || owner === id) {
|
||||
idMap.set(id, candidate);
|
||||
reverseMap.set(candidate, id);
|
||||
return candidate;
|
||||
}
|
||||
attempt++;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
function deriveMistralToolCallId(id: string, attempt: number): string {
|
||||
const normalized = id.replace(/[^a-zA-Z0-9]/g, "");
|
||||
if (attempt === 0 && normalized.length === MISTRAL_TOOL_CALL_ID_LENGTH)
|
||||
return normalized;
|
||||
const seedBase = normalized || id;
|
||||
const seed = attempt === 0 ? seedBase : `${seedBase}:${attempt}`;
|
||||
return shortHash(seed)
|
||||
.replace(/[^a-zA-Z0-9]/g, "")
|
||||
.slice(0, MISTRAL_TOOL_CALL_ID_LENGTH);
|
||||
}
|
||||
|
||||
function formatMistralError(error: unknown): string {
|
||||
if (error instanceof Error) {
|
||||
const sdkError = error as Error & { statusCode?: unknown; body?: unknown };
|
||||
const statusCode =
|
||||
typeof sdkError.statusCode === "number" ? sdkError.statusCode : undefined;
|
||||
const bodyText =
|
||||
typeof sdkError.body === "string" ? sdkError.body.trim() : undefined;
|
||||
if (statusCode !== undefined && bodyText) {
|
||||
return `Mistral API error (${statusCode}): ${truncateErrorText(bodyText, MAX_MISTRAL_ERROR_BODY_CHARS)}`;
|
||||
}
|
||||
if (statusCode !== undefined)
|
||||
return `Mistral API error (${statusCode}): ${error.message}`;
|
||||
return error.message;
|
||||
}
|
||||
return safeJsonStringify(error);
|
||||
}
|
||||
|
||||
function truncateErrorText(text: string, maxChars: number): string {
|
||||
if (text.length <= maxChars) return text;
|
||||
return `${text.slice(0, maxChars)}... [truncated ${text.length - maxChars} chars]`;
|
||||
}
|
||||
|
||||
function safeJsonStringify(value: unknown): string {
|
||||
try {
|
||||
const serialized = JSON.stringify(value);
|
||||
return serialized === undefined ? String(value) : serialized;
|
||||
} catch {
|
||||
return String(value);
|
||||
}
|
||||
}
|
||||
|
||||
function buildRequestOptions(
|
||||
model: Model<"mistral-conversations">,
|
||||
options?: MistralOptions,
|
||||
): RequestOptions {
|
||||
const requestOptions: RequestOptions = {};
|
||||
if (options?.signal) requestOptions.signal = options.signal;
|
||||
requestOptions.retries = { strategy: "none" };
|
||||
|
||||
const headers: Record<string, string> = {};
|
||||
if (model.headers) Object.assign(headers, model.headers);
|
||||
if (options?.headers) Object.assign(headers, options.headers);
|
||||
|
||||
// Mistral infrastructure uses `x-affinity` for KV-cache reuse (prefix caching).
|
||||
// Respect explicit caller-provided header values.
|
||||
if (options?.sessionId && !headers["x-affinity"]) {
|
||||
headers["x-affinity"] = options.sessionId;
|
||||
}
|
||||
|
||||
if (Object.keys(headers).length > 0) {
|
||||
requestOptions.headers = headers;
|
||||
}
|
||||
|
||||
return requestOptions;
|
||||
}
|
||||
|
||||
function buildChatPayload(
|
||||
model: Model<"mistral-conversations">,
|
||||
context: Context,
|
||||
messages: Message[],
|
||||
options?: MistralOptions,
|
||||
): ChatCompletionStreamRequest {
|
||||
const payload: ChatCompletionStreamRequest = {
|
||||
model: model.id,
|
||||
stream: true,
|
||||
messages: toChatMessages(messages, model.input.includes("image")),
|
||||
};
|
||||
|
||||
if (context.tools?.length) payload.tools = toFunctionTools(context.tools);
|
||||
if (options?.temperature !== undefined)
|
||||
payload.temperature = options.temperature;
|
||||
if (options?.maxTokens !== undefined) payload.maxTokens = options.maxTokens;
|
||||
if (options?.toolChoice)
|
||||
payload.toolChoice = mapToolChoice(options.toolChoice);
|
||||
if (options?.promptMode) payload.promptMode = options.promptMode as any;
|
||||
|
||||
if (context.systemPrompt) {
|
||||
payload.messages.unshift({
|
||||
role: "system",
|
||||
content: sanitizeSurrogates(context.systemPrompt),
|
||||
});
|
||||
}
|
||||
|
||||
return payload;
|
||||
}
|
||||
|
||||
/**
 * Drain a Mistral completion event stream into `output`, emitting
 * start/delta/end events for text, thinking, and tool-call blocks.
 *
 * Text and thinking deltas are coalesced into a single "current" block until
 * the content type switches. Tool-call fragments are keyed by call id +
 * index so argument deltas for the same call concatenate; their `..._end`
 * events are deferred until the provider stream is exhausted.
 */
async function consumeChatStream(
  model: Model<"mistral-conversations">,
  output: AssistantMessage,
  stream: AssistantMessageEventStream,
  mistralStream: AsyncIterable<CompletionEvent>,
): Promise<void> {
  // Open text/thinking block, if any; closed when the content type changes.
  let currentBlock: TextContent | ThinkingContent | null = null;
  const blocks = output.content;
  // Index of the most recently appended content block.
  const blockIndex = () => blocks.length - 1;
  // Maps "callId:index" -> position of that tool-call block in output.content.
  const toolBlocksByKey = new Map<string, number>();

  // Emit the matching *_end event for the open text/thinking block.
  const finishCurrentBlock = (block?: typeof currentBlock) => {
    if (!block) return;
    if (block.type === "text") {
      stream.push({
        type: "text_end",
        contentIndex: blockIndex(),
        content: block.text,
        partial: output,
      });
      return;
    }
    if (block.type === "thinking") {
      stream.push({
        type: "thinking_end",
        contentIndex: blockIndex(),
        content: block.thinking,
        partial: output,
      });
    }
  };

  for await (const event of mistralStream) {
    const chunk = event.data;

    // Usage may arrive on any chunk; overwrite with the latest totals.
    if (chunk.usage) {
      output.usage.input = chunk.usage.promptTokens || 0;
      output.usage.output = chunk.usage.completionTokens || 0;
      output.usage.cacheRead = 0;
      output.usage.cacheWrite = 0;
      output.usage.totalTokens =
        chunk.usage.totalTokens || output.usage.input + output.usage.output;
      calculateCost(model, output.usage);
    }

    const choice = chunk.choices[0];
    if (!choice) continue;

    if (choice.finishReason) {
      output.stopReason = mapChatStopReason(choice.finishReason);
    }

    const delta = choice.delta;
    if (delta.content !== null && delta.content !== undefined) {
      // Content may be a bare string or a list of typed chunks.
      const contentItems =
        typeof delta.content === "string" ? [delta.content] : delta.content;
      for (const item of contentItems) {
        if (typeof item === "string") {
          const textDelta = sanitizeSurrogates(item);
          // Switching into text: close whatever block was open, start a new one.
          if (!currentBlock || currentBlock.type !== "text") {
            finishCurrentBlock(currentBlock);
            currentBlock = { type: "text", text: "" };
            output.content.push(currentBlock);
            stream.push({
              type: "text_start",
              contentIndex: blockIndex(),
              partial: output,
            });
          }
          currentBlock.text += textDelta;
          stream.push({
            type: "text_delta",
            contentIndex: blockIndex(),
            delta: textDelta,
            partial: output,
          });
          continue;
        }

        if (item.type === "thinking") {
          // Flatten the thinking parts into one string delta.
          const deltaText = item.thinking
            .map((part) => ("text" in part ? part.text : ""))
            .filter((text) => text.length > 0)
            .join("");
          const thinkingDelta = sanitizeSurrogates(deltaText);
          if (!thinkingDelta) continue;
          if (!currentBlock || currentBlock.type !== "thinking") {
            finishCurrentBlock(currentBlock);
            currentBlock = { type: "thinking", thinking: "" };
            output.content.push(currentBlock);
            stream.push({
              type: "thinking_start",
              contentIndex: blockIndex(),
              partial: output,
            });
          }
          currentBlock.thinking += thinkingDelta;
          stream.push({
            type: "thinking_delta",
            contentIndex: blockIndex(),
            delta: thinkingDelta,
            partial: output,
          });
          continue;
        }

        if (item.type === "text") {
          const textDelta = sanitizeSurrogates(item.text);
          if (!currentBlock || currentBlock.type !== "text") {
            finishCurrentBlock(currentBlock);
            currentBlock = { type: "text", text: "" };
            output.content.push(currentBlock);
            stream.push({
              type: "text_start",
              contentIndex: blockIndex(),
              partial: output,
            });
          }
          currentBlock.text += textDelta;
          stream.push({
            type: "text_delta",
            contentIndex: blockIndex(),
            delta: textDelta,
            partial: output,
          });
        }
      }
    }

    const toolCalls = delta.toolCalls || [];
    for (const toolCall of toolCalls) {
      // Tool calls always terminate any open text/thinking block.
      if (currentBlock) {
        finishCurrentBlock(currentBlock);
        currentBlock = null;
      }
      // The SDK may send a missing or literal "null" id; derive a stable one.
      const callId =
        toolCall.id && toolCall.id !== "null"
          ? toolCall.id
          : deriveMistralToolCallId(`toolcall:${toolCall.index ?? 0}`, 0);
      const key = `${callId}:${toolCall.index || 0}`;
      const existingIndex = toolBlocksByKey.get(key);
      let block: (ToolCall & { partialArgs?: string }) | undefined;

      if (existingIndex !== undefined) {
        const existing = output.content[existingIndex];
        if (existing?.type === "toolCall") {
          block = existing as ToolCall & { partialArgs?: string };
        }
      }

      if (!block) {
        block = {
          type: "toolCall",
          id: callId,
          name: toolCall.function.name,
          arguments: {},
          partialArgs: "",
        };
        output.content.push(block);
        toolBlocksByKey.set(key, output.content.length - 1);
        stream.push({
          type: "toolcall_start",
          contentIndex: output.content.length - 1,
          partial: output,
        });
      }

      // Accumulate raw argument text and keep a best-effort parsed snapshot.
      const argsDelta =
        typeof toolCall.function.arguments === "string"
          ? toolCall.function.arguments
          : JSON.stringify(toolCall.function.arguments || {});
      block.partialArgs = (block.partialArgs || "") + argsDelta;
      block.arguments = parseStreamingJson<Record<string, unknown>>(
        block.partialArgs,
      );
      stream.push({
        type: "toolcall_delta",
        contentIndex: toolBlocksByKey.get(key)!,
        delta: argsDelta,
        partial: output,
      });
    }
  }

  // Stream is exhausted: close the open block and finalize every tool call.
  finishCurrentBlock(currentBlock);
  for (const index of toolBlocksByKey.values()) {
    const block = output.content[index];
    if (block.type !== "toolCall") continue;
    const toolBlock = block as ToolCall & { partialArgs?: string };
    toolBlock.arguments = parseStreamingJson<Record<string, unknown>>(
      toolBlock.partialArgs,
    );
    // Drop the internal accumulator before exposing the final tool call.
    delete toolBlock.partialArgs;
    stream.push({
      type: "toolcall_end",
      contentIndex: index,
      toolCall: toolBlock,
      partial: output,
    });
  }
}
|
||||
|
||||
function toFunctionTools(
|
||||
tools: Tool[],
|
||||
): Array<FunctionTool & { type: "function" }> {
|
||||
return tools.map((tool) => ({
|
||||
type: "function",
|
||||
function: {
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
parameters: tool.parameters as unknown as Record<string, unknown>,
|
||||
strict: false,
|
||||
},
|
||||
}));
|
||||
}
|
||||
|
||||
/**
 * Convert provider-agnostic messages into Mistral chat messages.
 *
 * - User messages keep text always; images only when `supportsImages`
 *   (otherwise a placeholder string is substituted).
 * - Assistant messages carry non-empty text/thinking parts plus tool calls;
 *   assistant messages that end up empty are dropped entirely.
 * - All other roles are treated as tool results and become `role: "tool"`
 *   entries whose text is produced by `buildToolResultText`.
 */
function toChatMessages(
  messages: Message[],
  supportsImages: boolean,
): ChatCompletionStreamRequestMessages[] {
  const result: ChatCompletionStreamRequestMessages[] = [];

  for (const msg of messages) {
    if (msg.role === "user") {
      // Plain-string user content passes through (sanitized) unchanged.
      if (typeof msg.content === "string") {
        result.push({ role: "user", content: sanitizeSurrogates(msg.content) });
        continue;
      }
      const hadImages = msg.content.some((item) => item.type === "image");
      // Keep text always; keep images only when the model accepts them.
      const content: ContentChunk[] = msg.content
        .filter((item) => item.type === "text" || supportsImages)
        .map((item) => {
          if (item.type === "text")
            return { type: "text", text: sanitizeSurrogates(item.text) };
          return {
            type: "image_url",
            imageUrl: `data:${item.mimeType};base64,${item.data}`,
          };
        });
      if (content.length > 0) {
        result.push({ role: "user", content });
        continue;
      }
      // Image-only message on a text-only model: substitute a placeholder.
      if (hadImages && !supportsImages) {
        result.push({
          role: "user",
          content: "(image omitted: model does not support images)",
        });
      }
      continue;
    }

    if (msg.role === "assistant") {
      const contentParts: ContentChunk[] = [];
      const toolCalls: Array<{
        id: string;
        type: "function";
        function: { name: string; arguments: string };
      }> = [];

      for (const block of msg.content) {
        if (block.type === "text") {
          // Skip whitespace-only text blocks.
          if (block.text.trim().length > 0) {
            contentParts.push({
              type: "text",
              text: sanitizeSurrogates(block.text),
            });
          }
          continue;
        }
        if (block.type === "thinking") {
          if (block.thinking.trim().length > 0) {
            contentParts.push({
              type: "thinking",
              thinking: [
                { type: "text", text: sanitizeSurrogates(block.thinking) },
              ],
            });
          }
          continue;
        }
        // Remaining block type is a tool call; arguments are re-serialized.
        toolCalls.push({
          id: block.id,
          type: "function",
          function: {
            name: block.name,
            arguments: JSON.stringify(block.arguments || {}),
          },
        });
      }

      const assistantMessage: ChatCompletionStreamRequestMessages = {
        role: "assistant",
      };
      if (contentParts.length > 0) assistantMessage.content = contentParts;
      if (toolCalls.length > 0) assistantMessage.toolCalls = toolCalls;
      // Drop assistant messages with neither content nor tool calls.
      if (contentParts.length > 0 || toolCalls.length > 0)
        result.push(assistantMessage);
      continue;
    }

    // Tool-result message: always emit a text chunk first, then any images.
    const toolContent: ContentChunk[] = [];
    const textResult = msg.content
      .filter((part) => part.type === "text")
      .map((part) =>
        part.type === "text" ? sanitizeSurrogates(part.text) : "",
      )
      .join("\n");
    const hasImages = msg.content.some((part) => part.type === "image");
    const toolText = buildToolResultText(
      textResult,
      hasImages,
      supportsImages,
      msg.isError,
    );
    toolContent.push({ type: "text", text: toolText });
    for (const part of msg.content) {
      if (!supportsImages) continue;
      if (part.type !== "image") continue;
      toolContent.push({
        type: "image_url",
        imageUrl: `data:${part.mimeType};base64,${part.data}`,
      });
    }
    result.push({
      role: "tool",
      toolCallId: msg.toolCallId,
      name: msg.toolName,
      content: toolContent,
    });
  }

  return result;
}
|
||||
|
||||
function buildToolResultText(
|
||||
text: string,
|
||||
hasImages: boolean,
|
||||
supportsImages: boolean,
|
||||
isError: boolean,
|
||||
): string {
|
||||
const trimmed = text.trim();
|
||||
const errorPrefix = isError ? "[tool error] " : "";
|
||||
|
||||
if (trimmed.length > 0) {
|
||||
const imageSuffix =
|
||||
hasImages && !supportsImages
|
||||
? "\n[tool image omitted: model does not support images]"
|
||||
: "";
|
||||
return `${errorPrefix}${trimmed}${imageSuffix}`;
|
||||
}
|
||||
|
||||
if (hasImages) {
|
||||
if (supportsImages) {
|
||||
return isError
|
||||
? "[tool error] (see attached image)"
|
||||
: "(see attached image)";
|
||||
}
|
||||
return isError
|
||||
? "[tool error] (image omitted: model does not support images)"
|
||||
: "(image omitted: model does not support images)";
|
||||
}
|
||||
|
||||
return isError ? "[tool error] (no tool output)" : "(no tool output)";
|
||||
}
|
||||
|
||||
function mapToolChoice(
|
||||
choice: MistralOptions["toolChoice"],
|
||||
):
|
||||
| "auto"
|
||||
| "none"
|
||||
| "any"
|
||||
| "required"
|
||||
| { type: "function"; function: { name: string } }
|
||||
| undefined {
|
||||
if (!choice) return undefined;
|
||||
if (
|
||||
choice === "auto" ||
|
||||
choice === "none" ||
|
||||
choice === "any" ||
|
||||
choice === "required"
|
||||
) {
|
||||
return choice as any;
|
||||
}
|
||||
return {
|
||||
type: "function",
|
||||
function: { name: choice.function.name },
|
||||
};
|
||||
}
|
||||
|
||||
function mapChatStopReason(reason: string | null): StopReason {
|
||||
if (reason === null) return "stop";
|
||||
switch (reason) {
|
||||
case "stop":
|
||||
return "stop";
|
||||
case "length":
|
||||
case "model_length":
|
||||
return "length";
|
||||
case "tool_calls":
|
||||
return "toolUse";
|
||||
case "error":
|
||||
return "error";
|
||||
default:
|
||||
return "stop";
|
||||
}
|
||||
}
|
||||
1016
packages/ai/src/providers/openai-codex-responses.ts
Normal file
1016
packages/ai/src/providers/openai-codex-responses.ts
Normal file
File diff suppressed because it is too large
Load diff
949
packages/ai/src/providers/openai-completions.ts
Normal file
949
packages/ai/src/providers/openai-completions.ts
Normal file
|
|
@ -0,0 +1,949 @@
|
|||
import OpenAI from "openai";
|
||||
import type {
|
||||
ChatCompletionAssistantMessageParam,
|
||||
ChatCompletionChunk,
|
||||
ChatCompletionContentPart,
|
||||
ChatCompletionContentPartImage,
|
||||
ChatCompletionContentPartText,
|
||||
ChatCompletionMessageParam,
|
||||
ChatCompletionToolMessageParam,
|
||||
} from "openai/resources/chat/completions.js";
|
||||
import { getEnvApiKey } from "../env-api-keys.js";
|
||||
import { calculateCost, supportsXhigh } from "../models.js";
|
||||
import type {
|
||||
AssistantMessage,
|
||||
Context,
|
||||
Message,
|
||||
Model,
|
||||
OpenAICompletionsCompat,
|
||||
SimpleStreamOptions,
|
||||
StopReason,
|
||||
StreamFunction,
|
||||
StreamOptions,
|
||||
TextContent,
|
||||
ThinkingContent,
|
||||
Tool,
|
||||
ToolCall,
|
||||
ToolResultMessage,
|
||||
} from "../types.js";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import { parseStreamingJson } from "../utils/json-parse.js";
|
||||
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
||||
import {
|
||||
buildCopilotDynamicHeaders,
|
||||
hasCopilotVisionInput,
|
||||
} from "./github-copilot-headers.js";
|
||||
import { buildBaseOptions, clampReasoning } from "./simple-options.js";
|
||||
import { transformMessages } from "./transform-messages.js";
|
||||
|
||||
/**
|
||||
* Check if conversation messages contain tool calls or tool results.
|
||||
* This is needed because Anthropic (via proxy) requires the tools param
|
||||
* to be present when messages include tool_calls or tool role messages.
|
||||
*/
|
||||
function hasToolHistory(messages: Message[]): boolean {
|
||||
for (const msg of messages) {
|
||||
if (msg.role === "toolResult") {
|
||||
return true;
|
||||
}
|
||||
if (msg.role === "assistant") {
|
||||
if (msg.content.some((block) => block.type === "toolCall")) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/** Stream options specific to the OpenAI Chat Completions API. */
export interface OpenAICompletionsOptions extends StreamOptions {
  // Forwarded to the API as `tool_choice`: let the model decide, forbid tools,
  // require some tool, or pin one specific function by name.
  toolChoice?:
    | "auto"
    | "none"
    | "required"
    | { type: "function"; function: { name: string } };
  // Reasoning effort hint; "xhigh" is only honored by models for which
  // supportsXhigh() is true (others get it clamped by the simple wrapper).
  reasoningEffort?: "minimal" | "low" | "medium" | "high" | "xhigh";
}
|
||||
|
||||
/**
 * Stream a Chat Completions request for an "openai-completions" model.
 *
 * Returns an AssistantMessageEventStream immediately; a background async IIFE
 * drives the HTTP stream, translating chunk deltas into block-scoped events
 * (text / thinking / toolCall start, delta, end), and finishes with a "done"
 * event on success or an "error" event on failure/abort. Every event carries
 * the accumulating AssistantMessage as `partial`.
 */
export const streamOpenAICompletions: StreamFunction<
  "openai-completions",
  OpenAICompletionsOptions
> = (
  model: Model<"openai-completions">,
  context: Context,
  options?: OpenAICompletionsOptions,
): AssistantMessageEventStream => {
  const stream = new AssistantMessageEventStream();

  (async () => {
    // Accumulator for the final assistant message; usage/cost start zeroed
    // and are overwritten when the provider reports usage in a chunk.
    const output: AssistantMessage = {
      role: "assistant",
      content: [],
      api: model.api,
      provider: model.provider,
      model: model.id,
      usage: {
        input: 0,
        output: 0,
        cacheRead: 0,
        cacheWrite: 0,
        totalTokens: 0,
        cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
      },
      stopReason: "stop",
      timestamp: Date.now(),
    };

    try {
      const apiKey = options?.apiKey || getEnvApiKey(model.provider) || "";
      const client = createClient(model, context, apiKey, options?.headers);
      const params = buildParams(model, context, options);
      // Let callers observe the exact request payload (e.g. for logging).
      options?.onPayload?.(params);
      const openaiStream = await client.chat.completions.create(params, {
        signal: options?.signal,
      });
      stream.push({ type: "start", partial: output });

      // At most one content block is "open" at a time; tool-call blocks carry
      // the raw streamed JSON in partialArgs until they are finalized.
      let currentBlock:
        | TextContent
        | ThinkingContent
        | (ToolCall & { partialArgs?: string })
        | null = null;
      const blocks = output.content;
      const blockIndex = () => blocks.length - 1;
      // Close out the in-flight block: emit its *_end event and, for tool
      // calls, parse the accumulated JSON arguments one final time.
      const finishCurrentBlock = (block?: typeof currentBlock) => {
        if (block) {
          if (block.type === "text") {
            stream.push({
              type: "text_end",
              contentIndex: blockIndex(),
              content: block.text,
              partial: output,
            });
          } else if (block.type === "thinking") {
            stream.push({
              type: "thinking_end",
              contentIndex: blockIndex(),
              content: block.thinking,
              partial: output,
            });
          } else if (block.type === "toolCall") {
            block.arguments = parseStreamingJson(block.partialArgs);
            delete block.partialArgs;
            stream.push({
              type: "toolcall_end",
              contentIndex: blockIndex(),
              toolCall: block,
              partial: output,
            });
          }
        }
      };

      for await (const chunk of openaiStream) {
        if (chunk.usage) {
          const cachedTokens =
            chunk.usage.prompt_tokens_details?.cached_tokens || 0;
          const reasoningTokens =
            chunk.usage.completion_tokens_details?.reasoning_tokens || 0;
          const input = (chunk.usage.prompt_tokens || 0) - cachedTokens;
          const outputTokens =
            (chunk.usage.completion_tokens || 0) + reasoningTokens;
          output.usage = {
            // OpenAI includes cached tokens in prompt_tokens, so subtract to get non-cached input
            input,
            output: outputTokens,
            cacheRead: cachedTokens,
            cacheWrite: 0,
            // Compute totalTokens ourselves since we add reasoning_tokens to output
            // and some providers (e.g., Groq) don't include them in total_tokens
            totalTokens: input + outputTokens + cachedTokens,
            cost: {
              input: 0,
              output: 0,
              cacheRead: 0,
              cacheWrite: 0,
              total: 0,
            },
          };
          calculateCost(model, output.usage);
        }

        const choice = chunk.choices?.[0];
        if (!choice) continue;

        if (choice.finish_reason) {
          output.stopReason = mapStopReason(choice.finish_reason);
        }

        if (choice.delta) {
          // --- plain text deltas ---
          if (
            choice.delta.content !== null &&
            choice.delta.content !== undefined &&
            choice.delta.content.length > 0
          ) {
            if (!currentBlock || currentBlock.type !== "text") {
              finishCurrentBlock(currentBlock);
              currentBlock = { type: "text", text: "" };
              output.content.push(currentBlock);
              stream.push({
                type: "text_start",
                contentIndex: blockIndex(),
                partial: output,
              });
            }

            if (currentBlock.type === "text") {
              currentBlock.text += choice.delta.content;
              stream.push({
                type: "text_delta",
                contentIndex: blockIndex(),
                delta: choice.delta.content,
                partial: output,
              });
            }
          }

          // Some endpoints return reasoning in reasoning_content (llama.cpp),
          // or reasoning (other openai compatible endpoints)
          // Use the first non-empty reasoning field to avoid duplication
          // (e.g., chutes.ai returns both reasoning_content and reasoning with same content)
          const reasoningFields = [
            "reasoning_content",
            "reasoning",
            "reasoning_text",
          ];
          let foundReasoningField: string | null = null;
          for (const field of reasoningFields) {
            if (
              (choice.delta as any)[field] !== null &&
              (choice.delta as any)[field] !== undefined &&
              (choice.delta as any)[field].length > 0
            ) {
              if (!foundReasoningField) {
                foundReasoningField = field;
                break;
              }
            }
          }

          if (foundReasoningField) {
            if (!currentBlock || currentBlock.type !== "thinking") {
              finishCurrentBlock(currentBlock);
              // The field name is kept as the block's signature so replay can
              // write the thinking back under the same provider field.
              currentBlock = {
                type: "thinking",
                thinking: "",
                thinkingSignature: foundReasoningField,
              };
              output.content.push(currentBlock);
              stream.push({
                type: "thinking_start",
                contentIndex: blockIndex(),
                partial: output,
              });
            }

            if (currentBlock.type === "thinking") {
              const delta = (choice.delta as any)[foundReasoningField];
              currentBlock.thinking += delta;
              stream.push({
                type: "thinking_delta",
                contentIndex: blockIndex(),
                delta,
                partial: output,
              });
            }
          }

          // --- tool-call deltas ---
          if (choice?.delta?.tool_calls) {
            for (const toolCall of choice.delta.tool_calls) {
              // A new id (or a non-toolCall current block) starts a new call.
              if (
                !currentBlock ||
                currentBlock.type !== "toolCall" ||
                (toolCall.id && currentBlock.id !== toolCall.id)
              ) {
                finishCurrentBlock(currentBlock);
                currentBlock = {
                  type: "toolCall",
                  id: toolCall.id || "",
                  name: toolCall.function?.name || "",
                  arguments: {},
                  partialArgs: "",
                };
                output.content.push(currentBlock);
                stream.push({
                  type: "toolcall_start",
                  contentIndex: blockIndex(),
                  partial: output,
                });
              }

              if (currentBlock.type === "toolCall") {
                if (toolCall.id) currentBlock.id = toolCall.id;
                if (toolCall.function?.name)
                  currentBlock.name = toolCall.function.name;
                let delta = "";
                if (toolCall.function?.arguments) {
                  delta = toolCall.function.arguments;
                  currentBlock.partialArgs += toolCall.function.arguments;
                  // Keep a best-effort parse of the (possibly incomplete) JSON
                  // so partial consumers see structured arguments.
                  currentBlock.arguments = parseStreamingJson(
                    currentBlock.partialArgs,
                  );
                }
                stream.push({
                  type: "toolcall_delta",
                  contentIndex: blockIndex(),
                  delta,
                  partial: output,
                });
              }
            }
          }

          // OpenRouter-style encrypted reasoning attached to tool calls:
          // stash the full detail on the matching tool call's thoughtSignature.
          const reasoningDetails = (choice.delta as any).reasoning_details;
          if (reasoningDetails && Array.isArray(reasoningDetails)) {
            for (const detail of reasoningDetails) {
              if (
                detail.type === "reasoning.encrypted" &&
                detail.id &&
                detail.data
              ) {
                const matchingToolCall = output.content.find(
                  (b) => b.type === "toolCall" && b.id === detail.id,
                ) as ToolCall | undefined;
                if (matchingToolCall) {
                  matchingToolCall.thoughtSignature = JSON.stringify(detail);
                }
              }
            }
          }
        }
      }

      finishCurrentBlock(currentBlock);
      if (options?.signal?.aborted) {
        throw new Error("Request was aborted");
      }

      if (output.stopReason === "aborted" || output.stopReason === "error") {
        throw new Error("An unknown error occurred");
      }

      stream.push({ type: "done", reason: output.stopReason, message: output });
      stream.end();
    } catch (error) {
      // Defensive cleanup of a transient field; 'index' is not set in this
      // function — presumably added elsewhere on partial blocks. TODO confirm.
      for (const block of output.content) delete (block as any).index;
      output.stopReason = options?.signal?.aborted ? "aborted" : "error";
      output.errorMessage =
        error instanceof Error ? error.message : JSON.stringify(error);
      // Some providers via OpenRouter give additional information in this field.
      const rawMetadata = (error as any)?.error?.metadata?.raw;
      if (rawMetadata) output.errorMessage += `\n${rawMetadata}`;
      stream.push({ type: "error", reason: output.stopReason, error: output });
      stream.end();
    }
  })();

  return stream;
};
|
||||
|
||||
export const streamSimpleOpenAICompletions: StreamFunction<
|
||||
"openai-completions",
|
||||
SimpleStreamOptions
|
||||
> = (
|
||||
model: Model<"openai-completions">,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const apiKey = options?.apiKey || getEnvApiKey(model.provider);
|
||||
if (!apiKey) {
|
||||
throw new Error(`No API key for provider: ${model.provider}`);
|
||||
}
|
||||
|
||||
const base = buildBaseOptions(model, options, apiKey);
|
||||
const reasoningEffort = supportsXhigh(model)
|
||||
? options?.reasoning
|
||||
: clampReasoning(options?.reasoning);
|
||||
const toolChoice = (options as OpenAICompletionsOptions | undefined)
|
||||
?.toolChoice;
|
||||
|
||||
return streamOpenAICompletions(model, context, {
|
||||
...base,
|
||||
reasoningEffort,
|
||||
toolChoice,
|
||||
} satisfies OpenAICompletionsOptions);
|
||||
};
|
||||
|
||||
function createClient(
|
||||
model: Model<"openai-completions">,
|
||||
context: Context,
|
||||
apiKey?: string,
|
||||
optionsHeaders?: Record<string, string>,
|
||||
) {
|
||||
if (!apiKey) {
|
||||
if (!process.env.OPENAI_API_KEY) {
|
||||
throw new Error(
|
||||
"OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass it as an argument.",
|
||||
);
|
||||
}
|
||||
apiKey = process.env.OPENAI_API_KEY;
|
||||
}
|
||||
|
||||
const headers = { ...model.headers };
|
||||
if (model.provider === "github-copilot") {
|
||||
const hasImages = hasCopilotVisionInput(context.messages);
|
||||
const copilotHeaders = buildCopilotDynamicHeaders({
|
||||
messages: context.messages,
|
||||
hasImages,
|
||||
});
|
||||
Object.assign(headers, copilotHeaders);
|
||||
}
|
||||
|
||||
// Merge options headers last so they can override defaults
|
||||
if (optionsHeaders) {
|
||||
Object.assign(headers, optionsHeaders);
|
||||
}
|
||||
|
||||
return new OpenAI({
|
||||
apiKey,
|
||||
baseURL: model.baseUrl,
|
||||
dangerouslyAllowBrowser: true,
|
||||
defaultHeaders: headers,
|
||||
});
|
||||
}
|
||||
|
||||
/**
 * Assemble the streaming Chat Completions request payload for this model,
 * applying per-provider compatibility quirks (field names, optional params)
 * resolved via getCompat().
 */
function buildParams(
  model: Model<"openai-completions">,
  context: Context,
  options?: OpenAICompletionsOptions,
) {
  const compat = getCompat(model);
  const messages = convertMessages(model, context, compat);
  // Adds Anthropic cache_control markers when routed through OpenRouter.
  maybeAddOpenRouterAnthropicCacheControl(model, messages);

  const params: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
    model: model.id,
    messages,
    stream: true,
  };

  // Ask for usage in the final stream chunk where the provider supports it.
  if (compat.supportsUsageInStreaming !== false) {
    (params as any).stream_options = { include_usage: true };
  }

  // Explicitly opt out of server-side storage on providers that accept it.
  if (compat.supportsStore) {
    params.store = false;
  }

  if (options?.maxTokens) {
    // Some providers only understand the legacy max_tokens field.
    if (compat.maxTokensField === "max_tokens") {
      (params as any).max_tokens = options.maxTokens;
    } else {
      params.max_completion_tokens = options.maxTokens;
    }
  }

  if (options?.temperature !== undefined) {
    params.temperature = options.temperature;
  }

  if (context.tools) {
    params.tools = convertTools(context.tools, compat);
  } else if (hasToolHistory(context.messages)) {
    // Anthropic (via LiteLLM/proxy) requires tools param when conversation has tool_calls/tool_results
    params.tools = [];
  }

  if (options?.toolChoice) {
    params.tool_choice = options.toolChoice;
  }

  if (
    (compat.thinkingFormat === "zai" || compat.thinkingFormat === "qwen") &&
    model.reasoning
  ) {
    // Both Z.ai and Qwen use enable_thinking: boolean
    (params as any).enable_thinking = !!options?.reasoningEffort;
  } else if (
    options?.reasoningEffort &&
    model.reasoning &&
    compat.supportsReasoningEffort
  ) {
    // OpenAI-style reasoning_effort
    (params as any).reasoning_effort = mapReasoningEffort(
      options.reasoningEffort,
      compat.reasoningEffortMap,
    );
  }

  // OpenRouter provider routing preferences
  if (
    model.baseUrl.includes("openrouter.ai") &&
    model.compat?.openRouterRouting
  ) {
    (params as any).provider = model.compat.openRouterRouting;
  }

  // Vercel AI Gateway provider routing preferences
  if (
    model.baseUrl.includes("ai-gateway.vercel.sh") &&
    model.compat?.vercelGatewayRouting
  ) {
    const routing = model.compat.vercelGatewayRouting;
    if (routing.only || routing.order) {
      const gatewayOptions: Record<string, string[]> = {};
      if (routing.only) gatewayOptions.only = routing.only;
      if (routing.order) gatewayOptions.order = routing.order;
      (params as any).providerOptions = { gateway: gatewayOptions };
    }
  }

  return params;
}
|
||||
|
||||
function mapReasoningEffort(
|
||||
effort: NonNullable<OpenAICompletionsOptions["reasoningEffort"]>,
|
||||
reasoningEffortMap: Partial<
|
||||
Record<NonNullable<OpenAICompletionsOptions["reasoningEffort"]>, string>
|
||||
>,
|
||||
): string {
|
||||
return reasoningEffortMap[effort] ?? effort;
|
||||
}
|
||||
|
||||
function maybeAddOpenRouterAnthropicCacheControl(
|
||||
model: Model<"openai-completions">,
|
||||
messages: ChatCompletionMessageParam[],
|
||||
): void {
|
||||
if (model.provider !== "openrouter" || !model.id.startsWith("anthropic/"))
|
||||
return;
|
||||
|
||||
// Anthropic-style caching requires cache_control on a text part. Add a breakpoint
|
||||
// on the last user/assistant message (walking backwards until we find text content).
|
||||
for (let i = messages.length - 1; i >= 0; i--) {
|
||||
const msg = messages[i];
|
||||
if (msg.role !== "user" && msg.role !== "assistant") continue;
|
||||
|
||||
const content = msg.content;
|
||||
if (typeof content === "string") {
|
||||
msg.content = [
|
||||
Object.assign(
|
||||
{ type: "text" as const, text: content },
|
||||
{ cache_control: { type: "ephemeral" } },
|
||||
),
|
||||
];
|
||||
return;
|
||||
}
|
||||
|
||||
if (!Array.isArray(content)) continue;
|
||||
|
||||
// Find last text part and add cache_control
|
||||
for (let j = content.length - 1; j >= 0; j--) {
|
||||
const part = content[j];
|
||||
if (part?.type === "text") {
|
||||
Object.assign(part, { cache_control: { type: "ephemeral" } });
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Convert our internal Context into OpenAI Chat Completions messages.
 *
 * Handles: system/developer role selection, tool-call id normalization,
 * surrogate sanitization, image filtering for non-vision models, assistant
 * text/thinking/toolCall blocks, grouping of consecutive tool results, and
 * compat quirks (synthetic assistant bridges, tool-result names).
 */
export function convertMessages(
  model: Model<"openai-completions">,
  context: Context,
  compat: Required<OpenAICompletionsCompat>,
): ChatCompletionMessageParam[] {
  const params: ChatCompletionMessageParam[] = [];

  const normalizeToolCallId = (id: string): string => {
    // Handle pipe-separated IDs from OpenAI Responses API
    // Format: {call_id}|{id} where {id} can be 400+ chars with special chars (+, /, =)
    // These come from providers like github-copilot, openai-codex, opencode
    // Extract just the call_id part and normalize it
    if (id.includes("|")) {
      const [callId] = id.split("|");
      // Sanitize to allowed chars and truncate to 40 chars (OpenAI limit)
      return callId.replace(/[^a-zA-Z0-9_-]/g, "_").slice(0, 40);
    }

    // OpenAI itself enforces a 40-char id limit; other providers get ids as-is.
    if (model.provider === "openai")
      return id.length > 40 ? id.slice(0, 40) : id;
    return id;
  };

  const transformedMessages = transformMessages(context.messages, model, (id) =>
    normalizeToolCallId(id),
  );

  if (context.systemPrompt) {
    // Reasoning models on compliant providers take the "developer" role.
    const useDeveloperRole = model.reasoning && compat.supportsDeveloperRole;
    const role = useDeveloperRole ? "developer" : "system";
    params.push({
      role: role,
      content: sanitizeSurrogates(context.systemPrompt),
    });
  }

  // Tracks the previous emitted message kind so compat bridges can be added.
  let lastRole: string | null = null;

  for (let i = 0; i < transformedMessages.length; i++) {
    const msg = transformedMessages[i];
    // Some providers don't allow user messages directly after tool results
    // Insert a synthetic assistant message to bridge the gap
    if (
      compat.requiresAssistantAfterToolResult &&
      lastRole === "toolResult" &&
      msg.role === "user"
    ) {
      params.push({
        role: "assistant",
        content: "I have processed the tool results.",
      });
    }

    if (msg.role === "user") {
      if (typeof msg.content === "string") {
        params.push({
          role: "user",
          content: sanitizeSurrogates(msg.content),
        });
      } else {
        const content: ChatCompletionContentPart[] = msg.content.map(
          (item): ChatCompletionContentPart => {
            if (item.type === "text") {
              return {
                type: "text",
                text: sanitizeSurrogates(item.text),
              } satisfies ChatCompletionContentPartText;
            } else {
              return {
                type: "image_url",
                image_url: {
                  url: `data:${item.mimeType};base64,${item.data}`,
                },
              } satisfies ChatCompletionContentPartImage;
            }
          },
        );
        // Drop image parts for models without image input support.
        const filteredContent = !model.input.includes("image")
          ? content.filter((c) => c.type !== "image_url")
          : content;
        if (filteredContent.length === 0) continue;
        params.push({
          role: "user",
          content: filteredContent,
        });
      }
    } else if (msg.role === "assistant") {
      // Some providers don't accept null content, use empty string instead
      const assistantMsg: ChatCompletionAssistantMessageParam = {
        role: "assistant",
        content: compat.requiresAssistantAfterToolResult ? "" : null,
      };

      const textBlocks = msg.content.filter(
        (b) => b.type === "text",
      ) as TextContent[];
      // Filter out empty text blocks to avoid API validation errors
      const nonEmptyTextBlocks = textBlocks.filter(
        (b) => b.text && b.text.trim().length > 0,
      );
      if (nonEmptyTextBlocks.length > 0) {
        // GitHub Copilot requires assistant content as a string, not an array.
        // Sending as array causes Claude models to re-answer all previous prompts.
        if (model.provider === "github-copilot") {
          assistantMsg.content = nonEmptyTextBlocks
            .map((b) => sanitizeSurrogates(b.text))
            .join("");
        } else {
          assistantMsg.content = nonEmptyTextBlocks.map((b) => {
            return { type: "text", text: sanitizeSurrogates(b.text) };
          });
        }
      }

      // Handle thinking blocks
      const thinkingBlocks = msg.content.filter(
        (b) => b.type === "thinking",
      ) as ThinkingContent[];
      // Filter out empty thinking blocks to avoid API validation errors
      const nonEmptyThinkingBlocks = thinkingBlocks.filter(
        (b) => b.thinking && b.thinking.trim().length > 0,
      );
      if (nonEmptyThinkingBlocks.length > 0) {
        if (compat.requiresThinkingAsText) {
          // Convert thinking blocks to plain text (no tags to avoid model mimicking them)
          const thinkingText = nonEmptyThinkingBlocks
            .map((b) => b.thinking)
            .join("\n\n");
          const textContent = assistantMsg.content as Array<{
            type: "text";
            text: string;
          }> | null;
          if (textContent) {
            // Thinking text goes before the regular text parts.
            textContent.unshift({ type: "text", text: thinkingText });
          } else {
            assistantMsg.content = [{ type: "text", text: thinkingText }];
          }
        } else {
          // Use the signature from the first thinking block if available (for llama.cpp server + gpt-oss)
          const signature = nonEmptyThinkingBlocks[0].thinkingSignature;
          if (signature && signature.length > 0) {
            // The signature is the provider's original field name, so write
            // the joined thinking back under that same field.
            (assistantMsg as any)[signature] = nonEmptyThinkingBlocks
              .map((b) => b.thinking)
              .join("\n");
          }
        }
      }

      const toolCalls = msg.content.filter(
        (b) => b.type === "toolCall",
      ) as ToolCall[];
      if (toolCalls.length > 0) {
        assistantMsg.tool_calls = toolCalls.map((tc) => ({
          id: tc.id,
          type: "function" as const,
          function: {
            name: tc.name,
            arguments: JSON.stringify(tc.arguments),
          },
        }));
        // Round-trip encrypted reasoning details captured during streaming.
        const reasoningDetails = toolCalls
          .filter((tc) => tc.thoughtSignature)
          .map((tc) => {
            try {
              return JSON.parse(tc.thoughtSignature!);
            } catch {
              return null;
            }
          })
          .filter(Boolean);
        if (reasoningDetails.length > 0) {
          (assistantMsg as any).reasoning_details = reasoningDetails;
        }
      }
      // Skip assistant messages that have no content and no tool calls.
      // Some providers require "either content or tool_calls, but not none".
      // Other providers also don't accept empty assistant messages.
      // This handles aborted assistant responses that got no content.
      const content = assistantMsg.content;
      const hasContent =
        content !== null &&
        content !== undefined &&
        (typeof content === "string" ? content.length > 0 : content.length > 0);
      if (!hasContent && !assistantMsg.tool_calls) {
        continue;
      }
      params.push(assistantMsg);
    } else if (msg.role === "toolResult") {
      // Consume the whole run of consecutive toolResult messages at once so
      // their images can be attached in a single trailing user message.
      const imageBlocks: Array<{
        type: "image_url";
        image_url: { url: string };
      }> = [];
      let j = i;

      for (
        ;
        j < transformedMessages.length &&
        transformedMessages[j].role === "toolResult";
        j++
      ) {
        const toolMsg = transformedMessages[j] as ToolResultMessage;

        // Extract text and image content
        const textResult = toolMsg.content
          .filter((c) => c.type === "text")
          .map((c) => (c as any).text)
          .join("\n");
        const hasImages = toolMsg.content.some((c) => c.type === "image");

        // Always send tool result with text (or placeholder if only images)
        const hasText = textResult.length > 0;
        // Some providers require the 'name' field in tool results
        const toolResultMsg: ChatCompletionToolMessageParam = {
          role: "tool",
          content: sanitizeSurrogates(
            hasText ? textResult : "(see attached image)",
          ),
          tool_call_id: toolMsg.toolCallId,
        };
        if (compat.requiresToolResultName && toolMsg.toolName) {
          (toolResultMsg as any).name = toolMsg.toolName;
        }
        params.push(toolResultMsg);

        if (hasImages && model.input.includes("image")) {
          for (const block of toolMsg.content) {
            if (block.type === "image") {
              imageBlocks.push({
                type: "image_url",
                image_url: {
                  url: `data:${(block as any).mimeType};base64,${(block as any).data}`,
                },
              });
            }
          }
        }
      }

      // Skip the outer loop past the run we just consumed.
      i = j - 1;

      if (imageBlocks.length > 0) {
        // Tool messages can't carry images, so ship them in a follow-up user
        // message (with the compat assistant bridge first where required).
        if (compat.requiresAssistantAfterToolResult) {
          params.push({
            role: "assistant",
            content: "I have processed the tool results.",
          });
        }

        params.push({
          role: "user",
          content: [
            {
              type: "text",
              text: "Attached image(s) from tool result:",
            },
            ...imageBlocks,
          ],
        });
        lastRole = "user";
      } else {
        lastRole = "toolResult";
      }
      continue;
    }

    lastRole = msg.role;
  }

  return params;
}
|
||||
|
||||
function convertTools(
|
||||
tools: Tool[],
|
||||
compat: Required<OpenAICompletionsCompat>,
|
||||
): OpenAI.Chat.Completions.ChatCompletionTool[] {
|
||||
return tools.map((tool) => ({
|
||||
type: "function",
|
||||
function: {
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
parameters: tool.parameters as any, // TypeBox already generates JSON Schema
|
||||
// Only include strict if provider supports it. Some reject unknown fields.
|
||||
...(compat.supportsStrictMode !== false && { strict: false }),
|
||||
},
|
||||
}));
|
||||
}
|
||||
|
||||
function mapStopReason(
|
||||
reason: ChatCompletionChunk.Choice["finish_reason"],
|
||||
): StopReason {
|
||||
if (reason === null) return "stop";
|
||||
switch (reason) {
|
||||
case "stop":
|
||||
return "stop";
|
||||
case "length":
|
||||
return "length";
|
||||
case "function_call":
|
||||
case "tool_calls":
|
||||
return "toolUse";
|
||||
case "content_filter":
|
||||
return "error";
|
||||
default: {
|
||||
const _exhaustive: never = reason;
|
||||
throw new Error(`Unhandled stop reason: ${_exhaustive}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Detect compatibility settings from provider and baseUrl for known providers.
|
||||
* Provider takes precedence over URL-based detection since it's explicitly configured.
|
||||
* Returns a fully resolved OpenAICompletionsCompat object with all fields set.
|
||||
*/
|
||||
function detectCompat(
|
||||
model: Model<"openai-completions">,
|
||||
): Required<OpenAICompletionsCompat> {
|
||||
const provider = model.provider;
|
||||
const baseUrl = model.baseUrl;
|
||||
|
||||
const isZai = provider === "zai" || baseUrl.includes("api.z.ai");
|
||||
|
||||
const isNonStandard =
|
||||
provider === "cerebras" ||
|
||||
baseUrl.includes("cerebras.ai") ||
|
||||
provider === "xai" ||
|
||||
baseUrl.includes("api.x.ai") ||
|
||||
baseUrl.includes("chutes.ai") ||
|
||||
baseUrl.includes("deepseek.com") ||
|
||||
isZai ||
|
||||
provider === "opencode" ||
|
||||
baseUrl.includes("opencode.ai");
|
||||
|
||||
const useMaxTokens = baseUrl.includes("chutes.ai");
|
||||
|
||||
const isGrok = provider === "xai" || baseUrl.includes("api.x.ai");
|
||||
const isGroq = provider === "groq" || baseUrl.includes("groq.com");
|
||||
|
||||
const reasoningEffortMap =
|
||||
isGroq && model.id === "qwen/qwen3-32b"
|
||||
? {
|
||||
minimal: "default",
|
||||
low: "default",
|
||||
medium: "default",
|
||||
high: "default",
|
||||
xhigh: "default",
|
||||
}
|
||||
: {};
|
||||
return {
|
||||
supportsStore: !isNonStandard,
|
||||
supportsDeveloperRole: !isNonStandard,
|
||||
supportsReasoningEffort: !isGrok && !isZai,
|
||||
reasoningEffortMap,
|
||||
supportsUsageInStreaming: true,
|
||||
maxTokensField: useMaxTokens ? "max_tokens" : "max_completion_tokens",
|
||||
requiresToolResultName: false,
|
||||
requiresAssistantAfterToolResult: false,
|
||||
requiresThinkingAsText: false,
|
||||
thinkingFormat: isZai ? "zai" : "openai",
|
||||
openRouterRouting: {},
|
||||
vercelGatewayRouting: {},
|
||||
supportsStrictMode: true,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Get resolved compatibility settings for a model.
|
||||
* Uses explicit model.compat if provided, otherwise auto-detects from provider/URL.
|
||||
*/
|
||||
function getCompat(
|
||||
model: Model<"openai-completions">,
|
||||
): Required<OpenAICompletionsCompat> {
|
||||
const detected = detectCompat(model);
|
||||
if (!model.compat) return detected;
|
||||
|
||||
return {
|
||||
supportsStore: model.compat.supportsStore ?? detected.supportsStore,
|
||||
supportsDeveloperRole:
|
||||
model.compat.supportsDeveloperRole ?? detected.supportsDeveloperRole,
|
||||
supportsReasoningEffort:
|
||||
model.compat.supportsReasoningEffort ?? detected.supportsReasoningEffort,
|
||||
reasoningEffortMap:
|
||||
model.compat.reasoningEffortMap ?? detected.reasoningEffortMap,
|
||||
supportsUsageInStreaming:
|
||||
model.compat.supportsUsageInStreaming ??
|
||||
detected.supportsUsageInStreaming,
|
||||
maxTokensField: model.compat.maxTokensField ?? detected.maxTokensField,
|
||||
requiresToolResultName:
|
||||
model.compat.requiresToolResultName ?? detected.requiresToolResultName,
|
||||
requiresAssistantAfterToolResult:
|
||||
model.compat.requiresAssistantAfterToolResult ??
|
||||
detected.requiresAssistantAfterToolResult,
|
||||
requiresThinkingAsText:
|
||||
model.compat.requiresThinkingAsText ?? detected.requiresThinkingAsText,
|
||||
thinkingFormat: model.compat.thinkingFormat ?? detected.thinkingFormat,
|
||||
openRouterRouting: model.compat.openRouterRouting ?? {},
|
||||
vercelGatewayRouting:
|
||||
model.compat.vercelGatewayRouting ?? detected.vercelGatewayRouting,
|
||||
supportsStrictMode:
|
||||
model.compat.supportsStrictMode ?? detected.supportsStrictMode,
|
||||
};
|
||||
}
|
||||
583
packages/ai/src/providers/openai-responses-shared.ts
Normal file
583
packages/ai/src/providers/openai-responses-shared.ts
Normal file
|
|
@ -0,0 +1,583 @@
|
|||
import type OpenAI from "openai";
|
||||
import type {
|
||||
Tool as OpenAITool,
|
||||
ResponseCreateParamsStreaming,
|
||||
ResponseFunctionToolCall,
|
||||
ResponseInput,
|
||||
ResponseInputContent,
|
||||
ResponseInputImage,
|
||||
ResponseInputText,
|
||||
ResponseOutputMessage,
|
||||
ResponseReasoningItem,
|
||||
ResponseStreamEvent,
|
||||
} from "openai/resources/responses/responses.js";
|
||||
import { calculateCost } from "../models.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
Context,
|
||||
ImageContent,
|
||||
Model,
|
||||
StopReason,
|
||||
TextContent,
|
||||
TextSignatureV1,
|
||||
ThinkingContent,
|
||||
Tool,
|
||||
ToolCall,
|
||||
Usage,
|
||||
} from "../types.js";
|
||||
import type { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import { shortHash } from "../utils/hash.js";
|
||||
import { parseStreamingJson } from "../utils/json-parse.js";
|
||||
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
||||
import { transformMessages } from "./transform-messages.js";
|
||||
|
||||
// =============================================================================
|
||||
// Utilities
|
||||
// =============================================================================
|
||||
|
||||
function encodeTextSignatureV1(
|
||||
id: string,
|
||||
phase?: TextSignatureV1["phase"],
|
||||
): string {
|
||||
const payload: TextSignatureV1 = { v: 1, id };
|
||||
if (phase) payload.phase = phase;
|
||||
return JSON.stringify(payload);
|
||||
}
|
||||
|
||||
function parseTextSignature(
|
||||
signature: string | undefined,
|
||||
): { id: string; phase?: TextSignatureV1["phase"] } | undefined {
|
||||
if (!signature) return undefined;
|
||||
if (signature.startsWith("{")) {
|
||||
try {
|
||||
const parsed = JSON.parse(signature) as Partial<TextSignatureV1>;
|
||||
if (parsed.v === 1 && typeof parsed.id === "string") {
|
||||
if (parsed.phase === "commentary" || parsed.phase === "final_answer") {
|
||||
return { id: parsed.id, phase: parsed.phase };
|
||||
}
|
||||
return { id: parsed.id };
|
||||
}
|
||||
} catch {
|
||||
// Fall through to legacy plain-string handling.
|
||||
}
|
||||
}
|
||||
return { id: signature };
|
||||
}
|
||||
|
||||
/**
 * Hooks for provider-specific behavior when processing a Responses stream.
 */
export interface OpenAIResponsesStreamOptions {
  /** Service tier requested for this call; used as a pricing fallback. */
  serviceTier?: ResponseCreateParamsStreaming["service_tier"];
  /** Adjusts the cost fields in `usage` for the resolved service tier. */
  applyServiceTierPricing?: (
    usage: Usage,
    serviceTier: ResponseCreateParamsStreaming["service_tier"] | undefined,
  ) => void;
}

/** Options for {@link convertResponsesMessages}. */
export interface ConvertResponsesMessagesOptions {
  /** When false, the context's system prompt is omitted. Defaults to true. */
  includeSystemPrompt?: boolean;
}

/** Options for {@link convertResponsesTools}. */
export interface ConvertResponsesToolsOptions {
  /** OpenAI strict function-calling mode; an explicit null is passed through. */
  strict?: boolean | null;
}
|
||||
|
||||
// =============================================================================
|
||||
// Message conversion
|
||||
// =============================================================================
|
||||
|
||||
/**
 * Convert the internal conversation Context into OpenAI Responses API input.
 *
 * Handles the system/developer prompt, user text+image content, assistant
 * reasoning/text/tool-call items, and tool results (with an image follow-up
 * message when the model accepts images).
 *
 * @param model Target model; drives prompt role, image filtering, and id rules.
 * @param context Conversation to convert.
 * @param allowedToolCallProviders Providers whose composite "callId|itemId"
 *   tool-call ids must be normalized to OpenAI's id constraints.
 * @param options Conversion switches (e.g. omitting the system prompt).
 */
export function convertResponsesMessages<TApi extends Api>(
  model: Model<TApi>,
  context: Context,
  allowedToolCallProviders: ReadonlySet<string>,
  options?: ConvertResponsesMessagesOptions,
): ResponseInput {
  const messages: ResponseInput = [];

  // Sanitize composite "callId|itemId" tool-call ids: restrict the charset,
  // force an "fc" prefix on the item half, cap both halves at 64 chars, and
  // strip trailing underscores. Non-composite ids pass through untouched.
  const normalizeToolCallId = (id: string): string => {
    if (!allowedToolCallProviders.has(model.provider)) return id;
    if (!id.includes("|")) return id;
    const [callId, itemId] = id.split("|");
    const sanitizedCallId = callId.replace(/[^a-zA-Z0-9_-]/g, "_");
    let sanitizedItemId = itemId.replace(/[^a-zA-Z0-9_-]/g, "_");
    // OpenAI Responses API requires item id to start with "fc"
    if (!sanitizedItemId.startsWith("fc")) {
      sanitizedItemId = `fc_${sanitizedItemId}`;
    }
    // Truncate to 64 chars and strip trailing underscores (OpenAI Codex rejects them)
    let normalizedCallId =
      sanitizedCallId.length > 64
        ? sanitizedCallId.slice(0, 64)
        : sanitizedCallId;
    let normalizedItemId =
      sanitizedItemId.length > 64
        ? sanitizedItemId.slice(0, 64)
        : sanitizedItemId;
    normalizedCallId = normalizedCallId.replace(/_+$/, "");
    normalizedItemId = normalizedItemId.replace(/_+$/, "");
    return `${normalizedCallId}|${normalizedItemId}`;
  };

  const transformedMessages = transformMessages(
    context.messages,
    model,
    normalizeToolCallId,
  );

  // Reasoning models take the prompt under the "developer" role.
  const includeSystemPrompt = options?.includeSystemPrompt ?? true;
  if (includeSystemPrompt && context.systemPrompt) {
    const role = model.reasoning ? "developer" : "system";
    messages.push({
      role,
      content: sanitizeSurrogates(context.systemPrompt),
    });
  }

  let msgIndex = 0;
  for (const msg of transformedMessages) {
    if (msg.role === "user") {
      if (typeof msg.content === "string") {
        messages.push({
          role: "user",
          content: [
            { type: "input_text", text: sanitizeSurrogates(msg.content) },
          ],
        });
      } else {
        const content: ResponseInputContent[] = msg.content.map(
          (item): ResponseInputContent => {
            if (item.type === "text") {
              return {
                type: "input_text",
                text: sanitizeSurrogates(item.text),
              } satisfies ResponseInputText;
            }
            return {
              type: "input_image",
              detail: "auto",
              image_url: `data:${item.mimeType};base64,${item.data}`,
            } satisfies ResponseInputImage;
          },
        );
        // Drop images for text-only models; skip the message if nothing remains.
        const filteredContent = !model.input.includes("image")
          ? content.filter((c) => c.type !== "input_image")
          : content;
        if (filteredContent.length === 0) continue;
        messages.push({
          role: "user",
          content: filteredContent,
        });
      }
    } else if (msg.role === "assistant") {
      const output: ResponseInput = [];
      const assistantMsg = msg as AssistantMessage;
      // Same provider+api but a different model id: its fc_* ids were never
      // paired with reasoning items under the current model (see below).
      const isDifferentModel =
        assistantMsg.model !== model.id &&
        assistantMsg.provider === model.provider &&
        assistantMsg.api === model.api;

      for (const block of msg.content) {
        if (block.type === "thinking") {
          if (block.thinking.trim().length === 0) continue;
          // The signature holds the full serialized reasoning item; replay it verbatim.
          if (block.thinkingSignature) {
            const reasoningItem = JSON.parse(
              block.thinkingSignature,
            ) as ResponseReasoningItem;
            output.push(reasoningItem);
          }
        } else if (block.type === "text") {
          const textBlock = block as TextContent;
          const parsedSignature = parseTextSignature(textBlock.textSignature);
          // OpenAI requires id to be max 64 characters
          let msgId = parsedSignature?.id;
          if (!msgId) {
            msgId = `msg_${msgIndex}`;
          } else if (msgId.length > 64) {
            msgId = `msg_${shortHash(msgId)}`;
          }
          output.push({
            type: "message",
            role: "assistant",
            content: [
              {
                type: "output_text",
                text: sanitizeSurrogates(textBlock.text),
                annotations: [],
              },
            ],
            status: "completed",
            id: msgId,
            phase: parsedSignature?.phase,
          } satisfies ResponseOutputMessage);
        } else if (block.type === "toolCall") {
          const toolCall = block as ToolCall;
          const [callId, itemIdRaw] = toolCall.id.split("|");
          let itemId: string | undefined = itemIdRaw;

          // For different-model messages, set id to undefined to avoid pairing validation.
          // OpenAI tracks which fc_xxx IDs were paired with rs_xxx reasoning items.
          // By omitting the id, we avoid triggering that validation (like cross-provider does).
          if (isDifferentModel && itemId?.startsWith("fc_")) {
            itemId = undefined;
          }

          output.push({
            type: "function_call",
            id: itemId,
            call_id: callId,
            name: toolCall.name,
            arguments: JSON.stringify(toolCall.arguments),
          });
        }
      }
      if (output.length === 0) continue;
      messages.push(...output);
    } else if (msg.role === "toolResult") {
      // Extract text and image content
      const textResult = msg.content
        .filter((c): c is TextContent => c.type === "text")
        .map((c) => c.text)
        .join("\n");
      const hasImages = msg.content.some(
        (c): c is ImageContent => c.type === "image",
      );

      // Always send function_call_output with text (or placeholder if only images)
      const hasText = textResult.length > 0;
      const [callId] = msg.toolCallId.split("|");
      messages.push({
        type: "function_call_output",
        call_id: callId,
        output: sanitizeSurrogates(
          hasText ? textResult : "(see attached image)",
        ),
      });

      // If there are images and model supports them, send a follow-up user message with images
      if (hasImages && model.input.includes("image")) {
        const contentParts: ResponseInputContent[] = [];

        // Add text prefix
        contentParts.push({
          type: "input_text",
          text: "Attached image(s) from tool result:",
        } satisfies ResponseInputText);

        // Add images
        for (const block of msg.content) {
          if (block.type === "image") {
            contentParts.push({
              type: "input_image",
              detail: "auto",
              image_url: `data:${block.mimeType};base64,${block.data}`,
            } satisfies ResponseInputImage);
          }
        }

        messages.push({
          role: "user",
          content: contentParts,
        });
      }
    }
    msgIndex++;
  }

  return messages;
}
|
||||
|
||||
// =============================================================================
|
||||
// Tool conversion
|
||||
// =============================================================================
|
||||
|
||||
export function convertResponsesTools(
|
||||
tools: Tool[],
|
||||
options?: ConvertResponsesToolsOptions,
|
||||
): OpenAITool[] {
|
||||
const strict = options?.strict === undefined ? false : options.strict;
|
||||
return tools.map((tool) => ({
|
||||
type: "function",
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
parameters: tool.parameters as any, // TypeBox already generates JSON Schema
|
||||
strict,
|
||||
}));
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Stream processing
|
||||
// =============================================================================
|
||||
|
||||
export async function processResponsesStream<TApi extends Api>(
|
||||
openaiStream: AsyncIterable<ResponseStreamEvent>,
|
||||
output: AssistantMessage,
|
||||
stream: AssistantMessageEventStream,
|
||||
model: Model<TApi>,
|
||||
options?: OpenAIResponsesStreamOptions,
|
||||
): Promise<void> {
|
||||
let currentItem:
|
||||
| ResponseReasoningItem
|
||||
| ResponseOutputMessage
|
||||
| ResponseFunctionToolCall
|
||||
| null = null;
|
||||
let currentBlock:
|
||||
| ThinkingContent
|
||||
| TextContent
|
||||
| (ToolCall & { partialJson: string })
|
||||
| null = null;
|
||||
const blocks = output.content;
|
||||
const blockIndex = () => blocks.length - 1;
|
||||
|
||||
for await (const event of openaiStream) {
|
||||
if (event.type === "response.output_item.added") {
|
||||
const item = event.item;
|
||||
if (item.type === "reasoning") {
|
||||
currentItem = item;
|
||||
currentBlock = { type: "thinking", thinking: "" };
|
||||
output.content.push(currentBlock);
|
||||
stream.push({
|
||||
type: "thinking_start",
|
||||
contentIndex: blockIndex(),
|
||||
partial: output,
|
||||
});
|
||||
} else if (item.type === "message") {
|
||||
currentItem = item;
|
||||
currentBlock = { type: "text", text: "" };
|
||||
output.content.push(currentBlock);
|
||||
stream.push({
|
||||
type: "text_start",
|
||||
contentIndex: blockIndex(),
|
||||
partial: output,
|
||||
});
|
||||
} else if (item.type === "function_call") {
|
||||
currentItem = item;
|
||||
currentBlock = {
|
||||
type: "toolCall",
|
||||
id: `${item.call_id}|${item.id}`,
|
||||
name: item.name,
|
||||
arguments: {},
|
||||
partialJson: item.arguments || "",
|
||||
};
|
||||
output.content.push(currentBlock);
|
||||
stream.push({
|
||||
type: "toolcall_start",
|
||||
contentIndex: blockIndex(),
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
} else if (event.type === "response.reasoning_summary_part.added") {
|
||||
if (currentItem && currentItem.type === "reasoning") {
|
||||
currentItem.summary = currentItem.summary || [];
|
||||
currentItem.summary.push(event.part);
|
||||
}
|
||||
} else if (event.type === "response.reasoning_summary_text.delta") {
|
||||
if (
|
||||
currentItem?.type === "reasoning" &&
|
||||
currentBlock?.type === "thinking"
|
||||
) {
|
||||
currentItem.summary = currentItem.summary || [];
|
||||
const lastPart = currentItem.summary[currentItem.summary.length - 1];
|
||||
if (lastPart) {
|
||||
currentBlock.thinking += event.delta;
|
||||
lastPart.text += event.delta;
|
||||
stream.push({
|
||||
type: "thinking_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: event.delta,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
} else if (event.type === "response.reasoning_summary_part.done") {
|
||||
if (
|
||||
currentItem?.type === "reasoning" &&
|
||||
currentBlock?.type === "thinking"
|
||||
) {
|
||||
currentItem.summary = currentItem.summary || [];
|
||||
const lastPart = currentItem.summary[currentItem.summary.length - 1];
|
||||
if (lastPart) {
|
||||
currentBlock.thinking += "\n\n";
|
||||
lastPart.text += "\n\n";
|
||||
stream.push({
|
||||
type: "thinking_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: "\n\n",
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
} else if (event.type === "response.content_part.added") {
|
||||
if (currentItem?.type === "message") {
|
||||
currentItem.content = currentItem.content || [];
|
||||
// Filter out ReasoningText, only accept output_text and refusal
|
||||
if (
|
||||
event.part.type === "output_text" ||
|
||||
event.part.type === "refusal"
|
||||
) {
|
||||
currentItem.content.push(event.part);
|
||||
}
|
||||
}
|
||||
} else if (event.type === "response.output_text.delta") {
|
||||
if (currentItem?.type === "message" && currentBlock?.type === "text") {
|
||||
if (!currentItem.content || currentItem.content.length === 0) {
|
||||
continue;
|
||||
}
|
||||
const lastPart = currentItem.content[currentItem.content.length - 1];
|
||||
if (lastPart?.type === "output_text") {
|
||||
currentBlock.text += event.delta;
|
||||
lastPart.text += event.delta;
|
||||
stream.push({
|
||||
type: "text_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: event.delta,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
} else if (event.type === "response.refusal.delta") {
|
||||
if (currentItem?.type === "message" && currentBlock?.type === "text") {
|
||||
if (!currentItem.content || currentItem.content.length === 0) {
|
||||
continue;
|
||||
}
|
||||
const lastPart = currentItem.content[currentItem.content.length - 1];
|
||||
if (lastPart?.type === "refusal") {
|
||||
currentBlock.text += event.delta;
|
||||
lastPart.refusal += event.delta;
|
||||
stream.push({
|
||||
type: "text_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: event.delta,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
} else if (event.type === "response.function_call_arguments.delta") {
|
||||
if (
|
||||
currentItem?.type === "function_call" &&
|
||||
currentBlock?.type === "toolCall"
|
||||
) {
|
||||
currentBlock.partialJson += event.delta;
|
||||
currentBlock.arguments = parseStreamingJson(currentBlock.partialJson);
|
||||
stream.push({
|
||||
type: "toolcall_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: event.delta,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
} else if (event.type === "response.function_call_arguments.done") {
|
||||
if (
|
||||
currentItem?.type === "function_call" &&
|
||||
currentBlock?.type === "toolCall"
|
||||
) {
|
||||
currentBlock.partialJson = event.arguments;
|
||||
currentBlock.arguments = parseStreamingJson(currentBlock.partialJson);
|
||||
}
|
||||
} else if (event.type === "response.output_item.done") {
|
||||
const item = event.item;
|
||||
|
||||
if (item.type === "reasoning" && currentBlock?.type === "thinking") {
|
||||
currentBlock.thinking =
|
||||
item.summary?.map((s) => s.text).join("\n\n") || "";
|
||||
currentBlock.thinkingSignature = JSON.stringify(item);
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: currentBlock.thinking,
|
||||
partial: output,
|
||||
});
|
||||
currentBlock = null;
|
||||
} else if (item.type === "message" && currentBlock?.type === "text") {
|
||||
currentBlock.text = item.content
|
||||
.map((c) => (c.type === "output_text" ? c.text : c.refusal))
|
||||
.join("");
|
||||
currentBlock.textSignature = encodeTextSignatureV1(
|
||||
item.id,
|
||||
item.phase ?? undefined,
|
||||
);
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: currentBlock.text,
|
||||
partial: output,
|
||||
});
|
||||
currentBlock = null;
|
||||
} else if (item.type === "function_call") {
|
||||
const args =
|
||||
currentBlock?.type === "toolCall" && currentBlock.partialJson
|
||||
? parseStreamingJson(currentBlock.partialJson)
|
||||
: parseStreamingJson(item.arguments || "{}");
|
||||
const toolCall: ToolCall = {
|
||||
type: "toolCall",
|
||||
id: `${item.call_id}|${item.id}`,
|
||||
name: item.name,
|
||||
arguments: args,
|
||||
};
|
||||
|
||||
currentBlock = null;
|
||||
stream.push({
|
||||
type: "toolcall_end",
|
||||
contentIndex: blockIndex(),
|
||||
toolCall,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
} else if (event.type === "response.completed") {
|
||||
const response = event.response;
|
||||
if (response?.usage) {
|
||||
const cachedTokens =
|
||||
response.usage.input_tokens_details?.cached_tokens || 0;
|
||||
output.usage = {
|
||||
// OpenAI includes cached tokens in input_tokens, so subtract to get non-cached input
|
||||
input: (response.usage.input_tokens || 0) - cachedTokens,
|
||||
output: response.usage.output_tokens || 0,
|
||||
cacheRead: cachedTokens,
|
||||
cacheWrite: 0,
|
||||
totalTokens: response.usage.total_tokens || 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
};
|
||||
}
|
||||
calculateCost(model, output.usage);
|
||||
if (options?.applyServiceTierPricing) {
|
||||
const serviceTier = response?.service_tier ?? options.serviceTier;
|
||||
options.applyServiceTierPricing(output.usage, serviceTier);
|
||||
}
|
||||
// Map status to stop reason
|
||||
output.stopReason = mapStopReason(response?.status);
|
||||
if (
|
||||
output.content.some((b) => b.type === "toolCall") &&
|
||||
output.stopReason === "stop"
|
||||
) {
|
||||
output.stopReason = "toolUse";
|
||||
}
|
||||
} else if (event.type === "error") {
|
||||
throw new Error(
|
||||
`Error Code ${event.code}: ${event.message}` || "Unknown error",
|
||||
);
|
||||
} else if (event.type === "response.failed") {
|
||||
throw new Error("Unknown error");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function mapStopReason(
|
||||
status: OpenAI.Responses.ResponseStatus | undefined,
|
||||
): StopReason {
|
||||
if (!status) return "stop";
|
||||
switch (status) {
|
||||
case "completed":
|
||||
return "stop";
|
||||
case "incomplete":
|
||||
return "length";
|
||||
case "failed":
|
||||
case "cancelled":
|
||||
return "error";
|
||||
// These two are wonky ...
|
||||
case "in_progress":
|
||||
case "queued":
|
||||
return "stop";
|
||||
default: {
|
||||
const _exhaustive: never = status;
|
||||
throw new Error(`Unhandled stop reason: ${_exhaustive}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
309
packages/ai/src/providers/openai-responses.ts
Normal file
309
packages/ai/src/providers/openai-responses.ts
Normal file
|
|
@ -0,0 +1,309 @@
|
|||
import OpenAI from "openai";
|
||||
import type { ResponseCreateParamsStreaming } from "openai/resources/responses/responses.js";
|
||||
import { getEnvApiKey } from "../env-api-keys.js";
|
||||
import { supportsXhigh } from "../models.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
CacheRetention,
|
||||
Context,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
StreamFunction,
|
||||
StreamOptions,
|
||||
Usage,
|
||||
} from "../types.js";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import {
|
||||
buildCopilotDynamicHeaders,
|
||||
hasCopilotVisionInput,
|
||||
} from "./github-copilot-headers.js";
|
||||
import {
|
||||
convertResponsesMessages,
|
||||
convertResponsesTools,
|
||||
processResponsesStream,
|
||||
} from "./openai-responses-shared.js";
|
||||
import { buildBaseOptions, clampReasoning } from "./simple-options.js";
|
||||
|
||||
// Providers whose composite "callId|itemId" tool-call ids must be normalized
// to OpenAI's id constraints (see convertResponsesMessages).
const OPENAI_TOOL_CALL_PROVIDERS = new Set([
  "openai",
  "openai-codex",
  "opencode",
]);
|
||||
|
||||
/**
|
||||
* Resolve cache retention preference.
|
||||
* Defaults to "short" and uses PI_CACHE_RETENTION for backward compatibility.
|
||||
*/
|
||||
function resolveCacheRetention(
|
||||
cacheRetention?: CacheRetention,
|
||||
): CacheRetention {
|
||||
if (cacheRetention) {
|
||||
return cacheRetention;
|
||||
}
|
||||
if (
|
||||
typeof process !== "undefined" &&
|
||||
process.env.PI_CACHE_RETENTION === "long"
|
||||
) {
|
||||
return "long";
|
||||
}
|
||||
return "short";
|
||||
}
|
||||
|
||||
/**
|
||||
* Get prompt cache retention based on cacheRetention and base URL.
|
||||
* Only applies to direct OpenAI API calls (api.openai.com).
|
||||
*/
|
||||
function getPromptCacheRetention(
|
||||
baseUrl: string,
|
||||
cacheRetention: CacheRetention,
|
||||
): "24h" | undefined {
|
||||
if (cacheRetention !== "long") {
|
||||
return undefined;
|
||||
}
|
||||
if (baseUrl.includes("api.openai.com")) {
|
||||
return "24h";
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
// OpenAI Responses-specific options
export interface OpenAIResponsesOptions extends StreamOptions {
  /** Reasoning effort; "xhigh" is only valid where supportsXhigh(model) holds. */
  reasoningEffort?: "minimal" | "low" | "medium" | "high" | "xhigh";
  /** Reasoning summary verbosity; null disables summaries. */
  reasoningSummary?: "auto" | "detailed" | "concise" | null;
  /** OpenAI service tier; affects pricing (see applyServiceTierPricing). */
  serviceTier?: ResponseCreateParamsStreaming["service_tier"];
}
|
||||
|
||||
/**
 * Generate function for OpenAI Responses API.
 *
 * Returns an event stream immediately; the request runs in a detached async
 * task that pushes start/delta/end events and closes the stream with either
 * a "done" or an "error" event.
 */
export const streamOpenAIResponses: StreamFunction<
  "openai-responses",
  OpenAIResponsesOptions
> = (
  model: Model<"openai-responses">,
  context: Context,
  options?: OpenAIResponsesOptions,
): AssistantMessageEventStream => {
  const stream = new AssistantMessageEventStream();

  // Start async processing
  (async () => {
    // Message skeleton mutated in place as stream events arrive.
    const output: AssistantMessage = {
      role: "assistant",
      content: [],
      api: model.api as Api,
      provider: model.provider,
      model: model.id,
      usage: {
        input: 0,
        output: 0,
        cacheRead: 0,
        cacheWrite: 0,
        totalTokens: 0,
        cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
      },
      stopReason: "stop",
      timestamp: Date.now(),
    };

    try {
      // Create OpenAI client
      const apiKey = options?.apiKey || getEnvApiKey(model.provider) || "";
      const client = createClient(model, context, apiKey, options?.headers);
      const params = buildParams(model, context, options);
      // Let callers observe the exact request payload (debugging/telemetry).
      options?.onPayload?.(params);
      const openaiStream = await client.responses.create(
        params,
        options?.signal ? { signal: options.signal } : undefined,
      );
      stream.push({ type: "start", partial: output });

      await processResponsesStream(openaiStream, output, stream, model, {
        serviceTier: options?.serviceTier,
        applyServiceTierPricing,
      });

      // An abort may only surface after the stream drains; re-check here.
      if (options?.signal?.aborted) {
        throw new Error("Request was aborted");
      }

      if (output.stopReason === "aborted" || output.stopReason === "error") {
        throw new Error("An unknown error occurred");
      }

      stream.push({ type: "done", reason: output.stopReason, message: output });
      stream.end();
    } catch (error) {
      // Strip transient per-block bookkeeping before reporting the error.
      // NOTE(review): nothing in this file sets `.index` on content blocks —
      // presumably another stream path does; confirm before removing.
      for (const block of output.content)
        delete (block as { index?: number }).index;
      output.stopReason = options?.signal?.aborted ? "aborted" : "error";
      output.errorMessage =
        error instanceof Error ? error.message : JSON.stringify(error);
      stream.push({ type: "error", reason: output.stopReason, error: output });
      stream.end();
    }
  })();

  return stream;
};
|
||||
|
||||
export const streamSimpleOpenAIResponses: StreamFunction<
|
||||
"openai-responses",
|
||||
SimpleStreamOptions
|
||||
> = (
|
||||
model: Model<"openai-responses">,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const apiKey = options?.apiKey || getEnvApiKey(model.provider);
|
||||
if (!apiKey) {
|
||||
throw new Error(`No API key for provider: ${model.provider}`);
|
||||
}
|
||||
|
||||
const base = buildBaseOptions(model, options, apiKey);
|
||||
const reasoningEffort = supportsXhigh(model)
|
||||
? options?.reasoning
|
||||
: clampReasoning(options?.reasoning);
|
||||
|
||||
return streamOpenAIResponses(model, context, {
|
||||
...base,
|
||||
reasoningEffort,
|
||||
} satisfies OpenAIResponsesOptions);
|
||||
};
|
||||
|
||||
function createClient(
|
||||
model: Model<"openai-responses">,
|
||||
context: Context,
|
||||
apiKey?: string,
|
||||
optionsHeaders?: Record<string, string>,
|
||||
) {
|
||||
if (!apiKey) {
|
||||
if (!process.env.OPENAI_API_KEY) {
|
||||
throw new Error(
|
||||
"OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass it as an argument.",
|
||||
);
|
||||
}
|
||||
apiKey = process.env.OPENAI_API_KEY;
|
||||
}
|
||||
|
||||
const headers = { ...model.headers };
|
||||
if (model.provider === "github-copilot") {
|
||||
const hasImages = hasCopilotVisionInput(context.messages);
|
||||
const copilotHeaders = buildCopilotDynamicHeaders({
|
||||
messages: context.messages,
|
||||
hasImages,
|
||||
});
|
||||
Object.assign(headers, copilotHeaders);
|
||||
}
|
||||
|
||||
// Merge options headers last so they can override defaults
|
||||
if (optionsHeaders) {
|
||||
Object.assign(headers, optionsHeaders);
|
||||
}
|
||||
|
||||
return new OpenAI({
|
||||
apiKey,
|
||||
baseURL: model.baseUrl,
|
||||
dangerouslyAllowBrowser: true,
|
||||
defaultHeaders: headers,
|
||||
});
|
||||
}
|
||||
|
||||
/**
 * Assemble the Responses API request payload from model, context, and options.
 * Covers prompt caching, sampling parameters, service tier, tools, and the
 * reasoning block (including the gpt-5 "no reasoning" workaround).
 */
function buildParams(
  model: Model<"openai-responses">,
  context: Context,
  options?: OpenAIResponsesOptions,
) {
  const messages = convertResponsesMessages(
    model,
    context,
    OPENAI_TOOL_CALL_PROVIDERS,
  );

  const cacheRetention = resolveCacheRetention(options?.cacheRetention);
  const params: ResponseCreateParamsStreaming = {
    model: model.id,
    input: messages,
    stream: true,
    // The session id keys the prompt cache; omitted when caching is disabled.
    prompt_cache_key:
      cacheRetention === "none" ? undefined : options?.sessionId,
    prompt_cache_retention: getPromptCacheRetention(
      model.baseUrl,
      cacheRetention,
    ),
    store: false,
  };

  if (options?.maxTokens) {
    params.max_output_tokens = options?.maxTokens;
  }

  if (options?.temperature !== undefined) {
    params.temperature = options?.temperature;
  }

  if (options?.serviceTier !== undefined) {
    params.service_tier = options.serviceTier;
  }

  if (context.tools) {
    params.tools = convertResponsesTools(context.tools);
  }

  if (model.reasoning) {
    if (options?.reasoningEffort || options?.reasoningSummary) {
      params.reasoning = {
        effort: options?.reasoningEffort || "medium",
        summary: options?.reasoningSummary || "auto",
      };
      // Required so encrypted reasoning items come back and can be replayed.
      params.include = ["reasoning.encrypted_content"];
    } else {
      // No reasoning requested on a reasoning-capable model.
      if (model.name.startsWith("gpt-5")) {
        // Jesus Christ, see https://community.openai.com/t/need-reasoning-false-option-for-gpt-5/1351588/7
        messages.push({
          role: "developer",
          content: [
            {
              type: "input_text",
              text: "# Juice: 0 !important",
            },
          ],
        });
      }
    }
  }

  return params;
}
|
||||
|
||||
function getServiceTierCostMultiplier(
|
||||
serviceTier: ResponseCreateParamsStreaming["service_tier"] | undefined,
|
||||
): number {
|
||||
switch (serviceTier) {
|
||||
case "flex":
|
||||
return 0.5;
|
||||
case "priority":
|
||||
return 2;
|
||||
default:
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
function applyServiceTierPricing(
|
||||
usage: Usage,
|
||||
serviceTier: ResponseCreateParamsStreaming["service_tier"] | undefined,
|
||||
) {
|
||||
const multiplier = getServiceTierCostMultiplier(serviceTier);
|
||||
if (multiplier === 1) return;
|
||||
|
||||
usage.cost.input *= multiplier;
|
||||
usage.cost.output *= multiplier;
|
||||
usage.cost.cacheRead *= multiplier;
|
||||
usage.cost.cacheWrite *= multiplier;
|
||||
usage.cost.total =
|
||||
usage.cost.input +
|
||||
usage.cost.output +
|
||||
usage.cost.cacheRead +
|
||||
usage.cost.cacheWrite;
|
||||
}
|
||||
216
packages/ai/src/providers/register-builtins.ts
Normal file
216
packages/ai/src/providers/register-builtins.ts
Normal file
|
|
@ -0,0 +1,216 @@
|
|||
import { clearApiProviders, registerApiProvider } from "../api-registry.js";
|
||||
import type {
|
||||
AssistantMessage,
|
||||
AssistantMessageEvent,
|
||||
Context,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
StreamOptions,
|
||||
} from "../types.js";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import { streamAnthropic, streamSimpleAnthropic } from "./anthropic.js";
|
||||
import {
|
||||
streamAzureOpenAIResponses,
|
||||
streamSimpleAzureOpenAIResponses,
|
||||
} from "./azure-openai-responses.js";
|
||||
import { streamGoogle, streamSimpleGoogle } from "./google.js";
|
||||
import {
|
||||
streamGoogleGeminiCli,
|
||||
streamSimpleGoogleGeminiCli,
|
||||
} from "./google-gemini-cli.js";
|
||||
import {
|
||||
streamGoogleVertex,
|
||||
streamSimpleGoogleVertex,
|
||||
} from "./google-vertex.js";
|
||||
import { streamMistral, streamSimpleMistral } from "./mistral.js";
|
||||
import {
|
||||
streamOpenAICodexResponses,
|
||||
streamSimpleOpenAICodexResponses,
|
||||
} from "./openai-codex-responses.js";
|
||||
import {
|
||||
streamOpenAICompletions,
|
||||
streamSimpleOpenAICompletions,
|
||||
} from "./openai-completions.js";
|
||||
import {
|
||||
streamOpenAIResponses,
|
||||
streamSimpleOpenAIResponses,
|
||||
} from "./openai-responses.js";
|
||||
|
||||
/**
 * Shape of the lazily-loaded Amazon Bedrock provider module.
 * Mirrors the stream/streamSimple pair that every built-in provider exposes.
 */
interface BedrockProviderModule {
  streamBedrock: (
    model: Model<"bedrock-converse-stream">,
    context: Context,
    options?: StreamOptions,
  ) => AsyncIterable<AssistantMessageEvent>;
  streamSimpleBedrock: (
    model: Model<"bedrock-converse-stream">,
    context: Context,
    options?: SimpleStreamOptions,
  ) => AsyncIterable<AssistantMessageEvent>;
}
|
||||
|
||||
type DynamicImport = (specifier: string) => Promise<unknown>;

// Indirection through a function-typed constant for the dynamic import.
const dynamicImport: DynamicImport = (specifier) => import(specifier);
// Specifier is split across a concatenation — presumably to keep bundlers'
// static analysis from eagerly resolving/bundling the Bedrock module;
// confirm before "simplifying" this into a single literal.
const BEDROCK_PROVIDER_SPECIFIER = "./amazon-" + "bedrock.js";

// Injection point: when set, lazy loading is bypassed entirely.
let bedrockProviderModuleOverride: BedrockProviderModule | undefined;

/** Inject a Bedrock provider implementation (tests / custom builds). */
export function setBedrockProviderModule(module: BedrockProviderModule): void {
  bedrockProviderModuleOverride = module;
}

/** Load the Bedrock provider lazily, preferring an injected override. */
async function loadBedrockProviderModule(): Promise<BedrockProviderModule> {
  if (bedrockProviderModuleOverride) {
    return bedrockProviderModuleOverride;
  }
  const module = await dynamicImport(BEDROCK_PROVIDER_SPECIFIER);
  return module as BedrockProviderModule;
}
|
||||
|
||||
/**
 * Pump every event from `source` into `target`, then close `target`.
 *
 * NOTE(review): rejections from `source` are not caught here — if iteration
 * throws, `target` is never ended and the rejection is unhandled. Callers
 * currently rely on inner streams emitting an error event and completing
 * normally; confirm that holds before adding new call sites.
 */
function forwardStream(
  target: AssistantMessageEventStream,
  source: AsyncIterable<AssistantMessageEvent>,
): void {
  // Fire-and-forget pump; completion is signaled via target.end().
  (async () => {
    for await (const event of source) {
      target.push(event);
    }
    target.end();
  })();
}
|
||||
|
||||
function createLazyLoadErrorMessage(
|
||||
model: Model<"bedrock-converse-stream">,
|
||||
error: unknown,
|
||||
): AssistantMessage {
|
||||
return {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: "bedrock-converse-stream",
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "error",
|
||||
errorMessage: error instanceof Error ? error.message : String(error),
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
}
|
||||
|
||||
/**
 * Streaming entry point for Bedrock that defers loading the provider module
 * until first use. The outer stream is returned synchronously; events are
 * forwarded into it once the module resolves. If loading fails, a single
 * terminal error event carrying a synthetic error message is emitted and the
 * stream ends, so consumers never see a rejected promise.
 */
function streamBedrockLazy(
  model: Model<"bedrock-converse-stream">,
  context: Context,
  options?: StreamOptions,
): AssistantMessageEventStream {
  const outer = new AssistantMessageEventStream();

  loadBedrockProviderModule()
    .then((module) => {
      const inner = module.streamBedrock(model, context, options);
      forwardStream(outer, inner);
    })
    .catch((error) => {
      // Surface the load failure as a normal error event + terminal result.
      const message = createLazyLoadErrorMessage(model, error);
      outer.push({ type: "error", reason: "error", error: message });
      outer.end(message);
    });

  return outer;
}
|
||||
|
||||
/**
 * Simple-options variant of the lazy Bedrock stream entry point: same
 * deferred module load and error surfacing, but delegates to the provider's
 * `streamSimpleBedrock` with `SimpleStreamOptions`.
 */
function streamSimpleBedrockLazy(
  model: Model<"bedrock-converse-stream">,
  context: Context,
  options?: SimpleStreamOptions,
): AssistantMessageEventStream {
  const outer = new AssistantMessageEventStream();

  loadBedrockProviderModule()
    .then((module) => {
      const inner = module.streamSimpleBedrock(model, context, options);
      forwardStream(outer, inner);
    })
    .catch((error) => {
      // Surface the load failure as a normal error event + terminal result.
      const message = createLazyLoadErrorMessage(model, error);
      outer.push({ type: "error", reason: "error", error: message });
      outer.end(message);
    });

  return outer;
}
|
||||
|
||||
/**
 * Register the stream implementations that ship with this package, one per
 * supported API protocol. Bedrock registers the lazy wrappers above so its
 * provider module (and the AWS SDK behind it) is only loaded on first use.
 */
export function registerBuiltInApiProviders(): void {
  registerApiProvider({
    api: "anthropic-messages",
    stream: streamAnthropic,
    streamSimple: streamSimpleAnthropic,
  });

  registerApiProvider({
    api: "openai-completions",
    stream: streamOpenAICompletions,
    streamSimple: streamSimpleOpenAICompletions,
  });

  registerApiProvider({
    api: "mistral-conversations",
    stream: streamMistral,
    streamSimple: streamSimpleMistral,
  });

  registerApiProvider({
    api: "openai-responses",
    stream: streamOpenAIResponses,
    streamSimple: streamSimpleOpenAIResponses,
  });

  registerApiProvider({
    api: "azure-openai-responses",
    stream: streamAzureOpenAIResponses,
    streamSimple: streamSimpleAzureOpenAIResponses,
  });

  registerApiProvider({
    api: "openai-codex-responses",
    stream: streamOpenAICodexResponses,
    streamSimple: streamSimpleOpenAICodexResponses,
  });

  registerApiProvider({
    api: "google-generative-ai",
    stream: streamGoogle,
    streamSimple: streamSimpleGoogle,
  });

  registerApiProvider({
    api: "google-gemini-cli",
    stream: streamGoogleGeminiCli,
    streamSimple: streamSimpleGoogleGeminiCli,
  });

  registerApiProvider({
    api: "google-vertex",
    stream: streamGoogleVertex,
    streamSimple: streamSimpleGoogleVertex,
  });

  // Lazy wrappers — see streamBedrockLazy / streamSimpleBedrockLazy.
  registerApiProvider({
    api: "bedrock-converse-stream",
    stream: streamBedrockLazy,
    streamSimple: streamSimpleBedrockLazy,
  });
}
|
||||
|
||||
/**
 * Restore the registry to its built-in state: drop every registered provider
 * (including externally added ones) and re-register the defaults.
 */
export function resetApiProviders(): void {
  clearApiProviders();
  registerBuiltInApiProviders();
}

// Module side effect: the built-in providers are available as soon as this
// module is imported.
registerBuiltInApiProviders();
|
||||
59
packages/ai/src/providers/simple-options.ts
Normal file
59
packages/ai/src/providers/simple-options.ts
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
import type {
|
||||
Api,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
StreamOptions,
|
||||
ThinkingBudgets,
|
||||
ThinkingLevel,
|
||||
} from "../types.js";
|
||||
|
||||
export function buildBaseOptions(
|
||||
model: Model<Api>,
|
||||
options?: SimpleStreamOptions,
|
||||
apiKey?: string,
|
||||
): StreamOptions {
|
||||
return {
|
||||
temperature: options?.temperature,
|
||||
maxTokens: options?.maxTokens || Math.min(model.maxTokens, 32000),
|
||||
signal: options?.signal,
|
||||
apiKey: apiKey || options?.apiKey,
|
||||
cacheRetention: options?.cacheRetention,
|
||||
sessionId: options?.sessionId,
|
||||
headers: options?.headers,
|
||||
onPayload: options?.onPayload,
|
||||
maxRetryDelayMs: options?.maxRetryDelayMs,
|
||||
metadata: options?.metadata,
|
||||
};
|
||||
}
|
||||
|
||||
export function clampReasoning(
|
||||
effort: ThinkingLevel | undefined,
|
||||
): Exclude<ThinkingLevel, "xhigh"> | undefined {
|
||||
return effort === "xhigh" ? "high" : effort;
|
||||
}
|
||||
|
||||
export function adjustMaxTokensForThinking(
|
||||
baseMaxTokens: number,
|
||||
modelMaxTokens: number,
|
||||
reasoningLevel: ThinkingLevel,
|
||||
customBudgets?: ThinkingBudgets,
|
||||
): { maxTokens: number; thinkingBudget: number } {
|
||||
const defaultBudgets: ThinkingBudgets = {
|
||||
minimal: 1024,
|
||||
low: 2048,
|
||||
medium: 8192,
|
||||
high: 16384,
|
||||
};
|
||||
const budgets = { ...defaultBudgets, ...customBudgets };
|
||||
|
||||
const minOutputTokens = 1024;
|
||||
const level = clampReasoning(reasoningLevel)!;
|
||||
let thinkingBudget = budgets[level]!;
|
||||
const maxTokens = Math.min(baseMaxTokens + thinkingBudget, modelMaxTokens);
|
||||
|
||||
if (maxTokens <= thinkingBudget) {
|
||||
thinkingBudget = Math.max(0, maxTokens - minOutputTokens);
|
||||
}
|
||||
|
||||
return { maxTokens, thinkingBudget };
|
||||
}
|
||||
193
packages/ai/src/providers/transform-messages.ts
Normal file
193
packages/ai/src/providers/transform-messages.ts
Normal file
|
|
@ -0,0 +1,193 @@
|
|||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
Message,
|
||||
Model,
|
||||
ToolCall,
|
||||
ToolResultMessage,
|
||||
} from "../types.js";
|
||||
|
||||
/**
 * Prepare a conversation history for replay against `model`.
 *
 * Pass 1 normalizes each message for the target model:
 * - thinking blocks are kept verbatim for the same model, but for a different
 *   model are converted to plain text (or dropped when empty/redacted);
 * - cross-model tool calls lose their Google `thoughtSignature`, and — when
 *   `normalizeToolCallId` is provided — get rewritten IDs. (OpenAI Responses
 *   IDs can be 450+ chars with `|` characters, while Anthropic requires
 *   ^[a-zA-Z0-9_-]+$, max 64 chars.) Matching toolResult messages are
 *   re-pointed at the rewritten IDs.
 *
 * Pass 2 enforces the tool-call/tool-result pairing the APIs require:
 * errored/aborted assistant turns are dropped, and synthetic error
 * toolResults are inserted for tool calls that never received a result.
 */
export function transformMessages<TApi extends Api>(
  messages: Message[],
  model: Model<TApi>,
  normalizeToolCallId?: (
    id: string,
    model: Model<TApi>,
    source: AssistantMessage,
  ) => string,
): Message[] {
  // Build a map of original tool call IDs to normalized IDs
  const toolCallIdMap = new Map<string, string>();

  // First pass: transform messages (thinking blocks, tool call ID normalization)
  // Relies on assistant messages appearing before their toolResults, so the
  // ID map is populated by the time the corresponding result is visited.
  const transformed = messages.map((msg) => {
    // User messages pass through unchanged
    if (msg.role === "user") {
      return msg;
    }

    // Handle toolResult messages - normalize toolCallId if we have a mapping
    if (msg.role === "toolResult") {
      const normalizedId = toolCallIdMap.get(msg.toolCallId);
      if (normalizedId && normalizedId !== msg.toolCallId) {
        return { ...msg, toolCallId: normalizedId };
      }
      return msg;
    }

    // Assistant messages need transformation check
    if (msg.role === "assistant") {
      const assistantMsg = msg as AssistantMessage;
      // Same provider + api + model id means signatures and encrypted
      // payloads can be replayed as-is.
      const isSameModel =
        assistantMsg.provider === model.provider &&
        assistantMsg.api === model.api &&
        assistantMsg.model === model.id;

      const transformedContent = assistantMsg.content.flatMap((block) => {
        if (block.type === "thinking") {
          // Redacted thinking is opaque encrypted content, only valid for the same model.
          // Drop it for cross-model to avoid API errors.
          if (block.redacted) {
            return isSameModel ? block : [];
          }
          // For same model: keep thinking blocks with signatures (needed for replay)
          // even if the thinking text is empty (OpenAI encrypted reasoning)
          if (isSameModel && block.thinkingSignature) return block;
          // Skip empty thinking blocks, convert others to plain text
          if (!block.thinking || block.thinking.trim() === "") return [];
          if (isSameModel) return block;
          return {
            type: "text" as const,
            text: block.thinking,
          };
        }

        if (block.type === "text") {
          // Cross-model: rebuild as a bare text block (drops any
          // model-specific extras such as textSignature).
          if (isSameModel) return block;
          return {
            type: "text" as const,
            text: block.text,
          };
        }

        if (block.type === "toolCall") {
          const toolCall = block as ToolCall;
          let normalizedToolCall: ToolCall = toolCall;

          // thoughtSignature is model-specific; strip it for cross-model replay.
          if (!isSameModel && toolCall.thoughtSignature) {
            normalizedToolCall = { ...toolCall };
            delete (normalizedToolCall as { thoughtSignature?: string })
              .thoughtSignature;
          }

          if (!isSameModel && normalizeToolCallId) {
            const normalizedId = normalizeToolCallId(
              toolCall.id,
              model,
              assistantMsg,
            );
            if (normalizedId !== toolCall.id) {
              // Record the mapping so later toolResults can be re-pointed.
              toolCallIdMap.set(toolCall.id, normalizedId);
              normalizedToolCall = { ...normalizedToolCall, id: normalizedId };
            }
          }

          return normalizedToolCall;
        }

        return block;
      });

      return {
        ...assistantMsg,
        content: transformedContent,
      };
    }
    return msg;
  });

  // Second pass: insert synthetic empty tool results for orphaned tool calls
  // This preserves thinking signatures and satisfies API requirements
  const result: Message[] = [];
  // Tool calls from the most recent assistant message that may still be
  // awaiting results, plus the result IDs seen since that message.
  let pendingToolCalls: ToolCall[] = [];
  let existingToolResultIds = new Set<string>();

  for (let i = 0; i < transformed.length; i++) {
    const msg = transformed[i];

    if (msg.role === "assistant") {
      // If we have pending orphaned tool calls from a previous assistant, insert synthetic results now
      if (pendingToolCalls.length > 0) {
        for (const tc of pendingToolCalls) {
          if (!existingToolResultIds.has(tc.id)) {
            result.push({
              role: "toolResult",
              toolCallId: tc.id,
              toolName: tc.name,
              content: [{ type: "text", text: "No result provided" }],
              isError: true,
              timestamp: Date.now(),
            } as ToolResultMessage);
          }
        }
        pendingToolCalls = [];
        existingToolResultIds = new Set();
      }

      // Skip errored/aborted assistant messages entirely.
      // These are incomplete turns that shouldn't be replayed:
      // - May have partial content (reasoning without message, incomplete tool calls)
      // - Replaying them can cause API errors (e.g., OpenAI "reasoning without following item")
      // - The model should retry from the last valid state
      const assistantMsg = msg as AssistantMessage;
      if (
        assistantMsg.stopReason === "error" ||
        assistantMsg.stopReason === "aborted"
      ) {
        continue;
      }

      // Track tool calls from this assistant message
      const toolCalls = assistantMsg.content.filter(
        (b) => b.type === "toolCall",
      ) as ToolCall[];
      if (toolCalls.length > 0) {
        pendingToolCalls = toolCalls;
        existingToolResultIds = new Set();
      }

      result.push(msg);
    } else if (msg.role === "toolResult") {
      existingToolResultIds.add(msg.toolCallId);
      result.push(msg);
    } else if (msg.role === "user") {
      // User message interrupts tool flow - insert synthetic results for orphaned calls
      if (pendingToolCalls.length > 0) {
        for (const tc of pendingToolCalls) {
          if (!existingToolResultIds.has(tc.id)) {
            result.push({
              role: "toolResult",
              toolCallId: tc.id,
              toolName: tc.name,
              content: [{ type: "text", text: "No result provided" }],
              isError: true,
              timestamp: Date.now(),
            } as ToolResultMessage);
          }
        }
        pendingToolCalls = [];
        existingToolResultIds = new Set();
      }
      result.push(msg);
    } else {
      result.push(msg);
    }
  }

  // NOTE(review): tool calls pending at the very end of the history are
  // deliberately left without synthetic results — that is the normal
  // "awaiting tool execution" state.
  return result;
}
|
||||
59
packages/ai/src/stream.ts
Normal file
59
packages/ai/src/stream.ts
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
import "./providers/register-builtins.js";
|
||||
|
||||
import { getApiProvider } from "./api-registry.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
AssistantMessageEventStream,
|
||||
Context,
|
||||
Model,
|
||||
ProviderStreamOptions,
|
||||
SimpleStreamOptions,
|
||||
StreamOptions,
|
||||
} from "./types.js";
|
||||
|
||||
export { getEnvApiKey } from "./env-api-keys.js";
|
||||
|
||||
function resolveApiProvider(api: Api) {
|
||||
const provider = getApiProvider(api);
|
||||
if (!provider) {
|
||||
throw new Error(`No API provider registered for api: ${api}`);
|
||||
}
|
||||
return provider;
|
||||
}
|
||||
|
||||
/**
 * Start a raw streaming request against `model` via its registered API
 * provider. Provider-specific option fields are forwarded as-is.
 * @throws Error if no provider is registered for `model.api`.
 */
export function stream<TApi extends Api>(
  model: Model<TApi>,
  context: Context,
  options?: ProviderStreamOptions,
): AssistantMessageEventStream {
  const provider = resolveApiProvider(model.api);
  return provider.stream(model, context, options as StreamOptions);
}
|
||||
|
||||
/**
 * Convenience wrapper over `stream` that resolves with the stream's final
 * assistant message once a terminal event arrives.
 * @throws Error if no provider is registered for `model.api`.
 */
export async function complete<TApi extends Api>(
  model: Model<TApi>,
  context: Context,
  options?: ProviderStreamOptions,
): Promise<AssistantMessage> {
  const s = stream(model, context, options);
  return s.result();
}
|
||||
|
||||
/**
 * Start a streaming request using the simplified, provider-agnostic option
 * set (including the unified `reasoning` knob).
 * @throws Error if no provider is registered for `model.api`.
 */
export function streamSimple<TApi extends Api>(
  model: Model<TApi>,
  context: Context,
  options?: SimpleStreamOptions,
): AssistantMessageEventStream {
  const provider = resolveApiProvider(model.api);
  return provider.streamSimple(model, context, options);
}
|
||||
|
||||
/**
 * Convenience wrapper over `streamSimple` that resolves with the final
 * assistant message once the stream completes.
 * @throws Error if no provider is registered for `model.api`.
 */
export async function completeSimple<TApi extends Api>(
  model: Model<TApi>,
  context: Context,
  options?: SimpleStreamOptions,
): Promise<AssistantMessage> {
  const s = streamSimple(model, context, options);
  return s.result();
}
|
||||
361
packages/ai/src/types.ts
Normal file
361
packages/ai/src/types.ts
Normal file
|
|
@ -0,0 +1,361 @@
|
|||
import type { AssistantMessageEventStream } from "./utils/event-stream.js";
|
||||
|
||||
export type { AssistantMessageEventStream } from "./utils/event-stream.js";
|
||||
|
||||
/** API wire protocols with first-class stream implementations in this package. */
export type KnownApi =
  | "openai-completions"
  | "mistral-conversations"
  | "openai-responses"
  | "azure-openai-responses"
  | "openai-codex-responses"
  | "anthropic-messages"
  | "bedrock-converse-stream"
  | "google-generative-ai"
  | "google-gemini-cli"
  | "google-vertex";

// `string & {}` keeps editor autocomplete for the known literals while still
// accepting arbitrary custom API identifiers.
export type Api = KnownApi | (string & {});
|
||||
|
||||
export type KnownProvider =
|
||||
| "amazon-bedrock"
|
||||
| "anthropic"
|
||||
| "google"
|
||||
| "google-gemini-cli"
|
||||
| "google-antigravity"
|
||||
| "google-vertex"
|
||||
| "openai"
|
||||
| "azure-openai-responses"
|
||||
| "openai-codex"
|
||||
| "github-copilot"
|
||||
| "xai"
|
||||
| "groq"
|
||||
| "cerebras"
|
||||
| "openrouter"
|
||||
| "vercel-ai-gateway"
|
||||
| "zai"
|
||||
| "mistral"
|
||||
| "minimax"
|
||||
| "minimax-cn"
|
||||
| "huggingface"
|
||||
| "opencode"
|
||||
| "opencode-go"
|
||||
| "kimi-coding";
|
||||
export type Provider = KnownProvider | string;
|
||||
|
||||
/** Reasoning-effort levels; "xhigh" is an extension that gets clamped to
 * "high" for providers without native support. */
export type ThinkingLevel = "minimal" | "low" | "medium" | "high" | "xhigh";

/** Token budgets for each thinking level (token-based providers only) */
export interface ThinkingBudgets {
  minimal?: number;
  low?: number;
  medium?: number;
  high?: number;
}

// Base options all providers share
export type CacheRetention = "none" | "short" | "long";

export type Transport = "sse" | "websocket" | "auto";

export interface StreamOptions {
  // Sampling temperature; provider default when omitted.
  temperature?: number;
  // Maximum output tokens; provider/model default when omitted.
  maxTokens?: number;
  // Abort signal to cancel the in-flight request.
  signal?: AbortSignal;
  // Per-request API key, overriding environment/default resolution.
  apiKey?: string;
  /**
   * Preferred transport for providers that support multiple transports.
   * Providers that do not support this option ignore it.
   */
  transport?: Transport;
  /**
   * Prompt cache retention preference. Providers map this to their supported values.
   * Default: "short".
   */
  cacheRetention?: CacheRetention;
  /**
   * Optional session identifier for providers that support session-based caching.
   * Providers can use this to enable prompt caching, request routing, or other
   * session-aware features. Ignored by providers that don't support it.
   */
  sessionId?: string;
  /**
   * Optional callback for inspecting provider payloads before sending.
   */
  onPayload?: (payload: unknown) => void;
  /**
   * Optional custom HTTP headers to include in API requests.
   * Merged with provider defaults; can override default headers.
   * Not supported by all providers (e.g., AWS Bedrock uses SDK auth).
   */
  headers?: Record<string, string>;
  /**
   * Maximum delay in milliseconds to wait for a retry when the server requests a long wait.
   * If the server's requested delay exceeds this value, the request fails immediately
   * with an error containing the requested delay, allowing higher-level retry logic
   * to handle it with user visibility.
   * Default: 60000 (60 seconds). Set to 0 to disable the cap.
   */
  maxRetryDelayMs?: number;
  /**
   * Optional metadata to include in API requests.
   * Providers extract the fields they understand and ignore the rest.
   * For example, Anthropic uses `user_id` for abuse tracking and rate limiting.
   */
  metadata?: Record<string, unknown>;
}

// Open-ended variant: providers may accept arbitrary extra fields.
export type ProviderStreamOptions = StreamOptions & Record<string, unknown>;

// Unified options with reasoning passed to streamSimple() and completeSimple()
export interface SimpleStreamOptions extends StreamOptions {
  reasoning?: ThinkingLevel;
  /** Custom token budgets for thinking levels (token-based providers only) */
  thinkingBudgets?: ThinkingBudgets;
}
|
||||
|
||||
// Generic StreamFunction with typed options — the shape every provider's
// stream/streamSimple implementation must satisfy.
export type StreamFunction<
  TApi extends Api = Api,
  TOptions extends StreamOptions = StreamOptions,
> = (
  model: Model<TApi>,
  context: Context,
  options?: TOptions,
) => AssistantMessageEventStream;

/** Structured form of `TextContent.textSignature` (JSON-encoded). */
export interface TextSignatureV1 {
  v: 1;
  id: string;
  phase?: "commentary" | "final_answer";
}

export interface TextContent {
  type: "text";
  text: string;
  textSignature?: string; // e.g., for OpenAI responses, message metadata (legacy id string or TextSignatureV1 JSON)
}

export interface ThinkingContent {
  type: "thinking";
  thinking: string;
  thinkingSignature?: string; // e.g., for OpenAI responses, the reasoning item ID
  /** When true, the thinking content was redacted by safety filters. The opaque
   * encrypted payload is stored in `thinkingSignature` so it can be passed back
   * to the API for multi-turn continuity. */
  redacted?: boolean;
}

export interface ImageContent {
  type: "image";
  data: string; // base64 encoded image data
  mimeType: string; // e.g., "image/jpeg", "image/png"
}

/** A tool invocation requested by the model. */
export interface ToolCall {
  type: "toolCall";
  id: string;
  name: string;
  arguments: Record<string, any>;
  thoughtSignature?: string; // Google-specific: opaque signature for reusing thought context
}

/** Token counts and derived dollar costs for a single response. */
export interface Usage {
  input: number;
  output: number;
  cacheRead: number;
  cacheWrite: number;
  totalTokens: number;
  cost: {
    input: number;
    output: number;
    cacheRead: number;
    cacheWrite: number;
    total: number;
  };
}

// Why a turn ended: natural stop, token limit, tool request, or failure.
export type StopReason = "stop" | "length" | "toolUse" | "error" | "aborted";

export interface UserMessage {
  role: "user";
  content: string | (TextContent | ImageContent)[];
  timestamp: number; // Unix timestamp in milliseconds
}

export interface AssistantMessage {
  role: "assistant";
  content: (TextContent | ThinkingContent | ToolCall)[];
  // Provenance: which api/provider/model produced this message. Used to
  // decide whether signatures can be replayed (see transformMessages).
  api: Api;
  provider: Provider;
  model: string;
  usage: Usage;
  stopReason: StopReason;
  errorMessage?: string;
  timestamp: number; // Unix timestamp in milliseconds
}

export interface ToolResultMessage<TDetails = any> {
  role: "toolResult";
  toolCallId: string;
  toolName: string;
  content: (TextContent | ImageContent)[]; // Supports text and images
  details?: TDetails;
  isError: boolean;
  timestamp: number; // Unix timestamp in milliseconds
}

export type Message = UserMessage | AssistantMessage | ToolResultMessage;
|
||||
|
||||
// NOTE(review): type-only import placed mid-module; conventionally this
// belongs with the imports at the top of the file.
import type { TSchema } from "@sinclair/typebox";

/** A tool the model may call; `parameters` is a TypeBox JSON schema. */
export interface Tool<TParameters extends TSchema = TSchema> {
  name: string;
  description: string;
  parameters: TParameters;
}

/** Full request context: optional system prompt, history, and tools. */
export interface Context {
  systemPrompt?: string;
  messages: Message[];
  tools?: Tool[];
}
|
||||
|
||||
// Streaming lifecycle events. Each block kind emits start/delta/end triples
// keyed by `contentIndex`; every event carries the accumulated `partial`
// message. The stream terminates with exactly one "done" or "error" event.
export type AssistantMessageEvent =
  | { type: "start"; partial: AssistantMessage }
  | { type: "text_start"; contentIndex: number; partial: AssistantMessage }
  | {
      type: "text_delta";
      contentIndex: number;
      delta: string;
      partial: AssistantMessage;
    }
  | {
      type: "text_end";
      contentIndex: number;
      content: string;
      partial: AssistantMessage;
    }
  | { type: "thinking_start"; contentIndex: number; partial: AssistantMessage }
  | {
      type: "thinking_delta";
      contentIndex: number;
      delta: string;
      partial: AssistantMessage;
    }
  | {
      type: "thinking_end";
      contentIndex: number;
      content: string;
      partial: AssistantMessage;
    }
  | { type: "toolcall_start"; contentIndex: number; partial: AssistantMessage }
  | {
      type: "toolcall_delta";
      contentIndex: number;
      delta: string;
      partial: AssistantMessage;
    }
  | {
      type: "toolcall_end";
      contentIndex: number;
      toolCall: ToolCall;
      partial: AssistantMessage;
    }
  | {
      type: "done";
      reason: Extract<StopReason, "stop" | "length" | "toolUse">;
      message: AssistantMessage;
    }
  | {
      type: "error";
      reason: Extract<StopReason, "aborted" | "error">;
      error: AssistantMessage;
    };
|
||||
|
||||
/**
 * Compatibility settings for OpenAI-compatible completions APIs.
 * Use this to override URL-based auto-detection for custom providers.
 */
export interface OpenAICompletionsCompat {
  /** Whether the provider supports the `store` field. Default: auto-detected from URL. */
  supportsStore?: boolean;
  /** Whether the provider supports the `developer` role (vs `system`). Default: auto-detected from URL. */
  supportsDeveloperRole?: boolean;
  /** Whether the provider supports `reasoning_effort`. Default: auto-detected from URL. */
  supportsReasoningEffort?: boolean;
  /** Optional mapping from pi-ai reasoning levels to provider/model-specific `reasoning_effort` values. */
  reasoningEffortMap?: Partial<Record<ThinkingLevel, string>>;
  /** Whether the provider supports `stream_options: { include_usage: true }` for token usage in streaming responses. Default: true. */
  supportsUsageInStreaming?: boolean;
  /** Which field to use for max tokens. Default: auto-detected from URL. */
  maxTokensField?: "max_completion_tokens" | "max_tokens";
  /** Whether tool results require the `name` field. Default: auto-detected from URL. */
  requiresToolResultName?: boolean;
  /** Whether a user message after tool results requires an assistant message in between. Default: auto-detected from URL. */
  requiresAssistantAfterToolResult?: boolean;
  /** Whether thinking blocks must be converted to text blocks with <thinking> delimiters. Default: auto-detected from URL. */
  requiresThinkingAsText?: boolean;
  /** Format for reasoning/thinking parameter. "openai" uses reasoning_effort, "zai" uses thinking: { type: "enabled" }, "qwen" uses enable_thinking: boolean. Default: "openai". */
  thinkingFormat?: "openai" | "zai" | "qwen";
  /** OpenRouter-specific routing preferences. Only used when baseUrl points to OpenRouter. */
  openRouterRouting?: OpenRouterRouting;
  /** Vercel AI Gateway routing preferences. Only used when baseUrl points to Vercel AI Gateway. */
  vercelGatewayRouting?: VercelGatewayRouting;
  /** Whether the provider supports the `strict` field in tool definitions. Default: true. */
  supportsStrictMode?: boolean;
}

/** Compatibility settings for OpenAI Responses APIs. */
export interface OpenAIResponsesCompat {
  // Reserved for future use
}

/**
 * OpenRouter provider routing preferences.
 * Controls which upstream providers OpenRouter routes requests to.
 * @see https://openrouter.ai/docs/provider-routing
 */
export interface OpenRouterRouting {
  /** List of provider slugs to exclusively use for this request (e.g., ["amazon-bedrock", "anthropic"]). */
  only?: string[];
  /** List of provider slugs to try in order (e.g., ["anthropic", "openai"]). */
  order?: string[];
}

/**
 * Vercel AI Gateway routing preferences.
 * Controls which upstream providers the gateway routes requests to.
 * @see https://vercel.com/docs/ai-gateway/models-and-providers/provider-options
 */
export interface VercelGatewayRouting {
  /** List of provider slugs to exclusively use for this request (e.g., ["bedrock", "anthropic"]). */
  only?: string[];
  /** List of provider slugs to try in order (e.g., ["anthropic", "openai"]). */
  order?: string[];
}
|
||||
|
||||
// Model interface for the unified model system
export interface Model<TApi extends Api> {
  // Provider-facing model identifier (sent on the wire).
  id: string;
  // Human-readable display name.
  name: string;
  api: TApi;
  provider: Provider;
  baseUrl: string;
  // Whether the model supports extended thinking/reasoning.
  reasoning: boolean;
  // Accepted input modalities.
  input: ("text" | "image")[];
  cost: {
    input: number; // $/million tokens
    output: number; // $/million tokens
    cacheRead: number; // $/million tokens
    cacheWrite: number; // $/million tokens
  };
  contextWindow: number;
  maxTokens: number;
  headers?: Record<string, string>;
  /** Compatibility overrides for OpenAI-compatible APIs. If not set, auto-detected from baseUrl. */
  compat?: TApi extends "openai-completions"
    ? OpenAICompletionsCompat
    : TApi extends "openai-responses"
      ? OpenAIResponsesCompat
      : never;
}
|
||||
92
packages/ai/src/utils/event-stream.ts
Normal file
92
packages/ai/src/utils/event-stream.ts
Normal file
|
|
@ -0,0 +1,92 @@
|
|||
import type { AssistantMessage, AssistantMessageEvent } from "../types.js";
|
||||
|
||||
/**
 * Generic push-based event stream supporting a single producer and async
 * iteration by consumers, plus a promise for the stream's final result.
 *
 * `isComplete` identifies the terminal event; `extractResult` pulls the final
 * result `R` out of it. Events pushed after completion are dropped.
 */
export class EventStream<T, R = T> implements AsyncIterable<T> {
  // Events pushed while no consumer is awaiting.
  private queue: T[] = [];
  // Resolvers for consumers currently awaiting the next event.
  private waiting: ((value: IteratorResult<T>) => void)[] = [];
  private done = false;
  private finalResultPromise: Promise<R>;
  private resolveFinalResult!: (result: R) => void;

  constructor(
    private isComplete: (event: T) => boolean,
    private extractResult: (event: T) => R,
  ) {
    this.finalResultPromise = new Promise((resolve) => {
      this.resolveFinalResult = resolve;
    });
  }

  /**
   * Publish an event. A terminal event (per `isComplete`) marks the stream
   * done and resolves the final-result promise — and is still delivered to
   * consumers; anything pushed afterwards is silently ignored.
   */
  push(event: T): void {
    if (this.done) return;

    if (this.isComplete(event)) {
      this.done = true;
      this.resolveFinalResult(this.extractResult(event));
    }

    // Deliver to waiting consumer or queue it
    const waiter = this.waiting.shift();
    if (waiter) {
      waiter({ value: event, done: false });
    } else {
      this.queue.push(event);
    }
  }

  /**
   * Terminate the stream, optionally supplying the final result.
   * NOTE(review): when called without `result` and no terminal event was
   * pushed, `result()` remains pending forever — callers appear to rely on a
   * terminal event having resolved it first; confirm before depending on it.
   */
  end(result?: R): void {
    this.done = true;
    if (result !== undefined) {
      this.resolveFinalResult(result);
    }
    // Notify all waiting consumers that we're done
    while (this.waiting.length > 0) {
      const waiter = this.waiting.shift()!;
      waiter({ value: undefined as any, done: true });
    }
  }

  // Drain queued events first, then await new ones until the stream ends.
  // NOTE(review): multiple concurrent iterators share one queue, so each
  // event is delivered to at most one consumer.
  async *[Symbol.asyncIterator](): AsyncIterator<T> {
    while (true) {
      if (this.queue.length > 0) {
        yield this.queue.shift()!;
      } else if (this.done) {
        return;
      } else {
        const result = await new Promise<IteratorResult<T>>((resolve) =>
          this.waiting.push(resolve),
        );
        if (result.done) return;
        yield result.value;
      }
    }
  }

  /** Promise for the stream's final result (see `end` note above). */
  result(): Promise<R> {
    return this.finalResultPromise;
  }
}
|
||||
|
||||
export class AssistantMessageEventStream extends EventStream<
|
||||
AssistantMessageEvent,
|
||||
AssistantMessage
|
||||
> {
|
||||
constructor() {
|
||||
super(
|
||||
(event) => event.type === "done" || event.type === "error",
|
||||
(event) => {
|
||||
if (event.type === "done") {
|
||||
return event.message;
|
||||
} else if (event.type === "error") {
|
||||
return event.error;
|
||||
}
|
||||
throw new Error("Unexpected event type for final result");
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/** Factory function for AssistantMessageEventStream (for use in extensions)
 * that cannot (or prefer not to) depend on the class constructor directly. */
export function createAssistantMessageEventStream(): AssistantMessageEventStream {
  return new AssistantMessageEventStream();
}
|
||||
17
packages/ai/src/utils/hash.ts
Normal file
17
packages/ai/src/utils/hash.ts
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
/** Fast deterministic hash to shorten long strings */
|
||||
export function shortHash(str: string): string {
|
||||
let h1 = 0xdeadbeef;
|
||||
let h2 = 0x41c6ce57;
|
||||
for (let i = 0; i < str.length; i++) {
|
||||
const ch = str.charCodeAt(i);
|
||||
h1 = Math.imul(h1 ^ ch, 2654435761);
|
||||
h2 = Math.imul(h2 ^ ch, 1597334677);
|
||||
}
|
||||
h1 =
|
||||
Math.imul(h1 ^ (h1 >>> 16), 2246822507) ^
|
||||
Math.imul(h2 ^ (h2 >>> 13), 3266489909);
|
||||
h2 =
|
||||
Math.imul(h2 ^ (h2 >>> 16), 2246822507) ^
|
||||
Math.imul(h1 ^ (h1 >>> 13), 3266489909);
|
||||
return (h2 >>> 0).toString(36) + (h1 >>> 0).toString(36);
|
||||
}
|
||||
30
packages/ai/src/utils/json-parse.ts
Normal file
30
packages/ai/src/utils/json-parse.ts
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
import { parse as partialParse } from "partial-json";
|
||||
|
||||
/**
|
||||
* Attempts to parse potentially incomplete JSON during streaming.
|
||||
* Always returns a valid object, even if the JSON is incomplete.
|
||||
*
|
||||
* @param partialJson The partial JSON string from streaming
|
||||
* @returns Parsed object or empty object if parsing fails
|
||||
*/
|
||||
export function parseStreamingJson<T = any>(
|
||||
partialJson: string | undefined,
|
||||
): T {
|
||||
if (!partialJson || partialJson.trim() === "") {
|
||||
return {} as T;
|
||||
}
|
||||
|
||||
// Try standard parsing first (fastest for complete JSON)
|
||||
try {
|
||||
return JSON.parse(partialJson) as T;
|
||||
} catch {
|
||||
// Try partial-json for incomplete JSON
|
||||
try {
|
||||
const result = partialParse(partialJson);
|
||||
return (result ?? {}) as T;
|
||||
} catch {
|
||||
// If all parsing fails, return empty object
|
||||
return {} as T;
|
||||
}
|
||||
}
|
||||
}
|
||||
144
packages/ai/src/utils/oauth/anthropic.ts
Normal file
144
packages/ai/src/utils/oauth/anthropic.ts
Normal file
|
|
@ -0,0 +1,144 @@
|
|||
/**
|
||||
* Anthropic OAuth flow (Claude Pro/Max)
|
||||
*/
|
||||
|
||||
import { generatePKCE } from "./pkce.js";
|
||||
import type {
|
||||
OAuthCredentials,
|
||||
OAuthLoginCallbacks,
|
||||
OAuthProviderInterface,
|
||||
} from "./types.js";
|
||||
|
||||
const decode = (s: string) => atob(s);
|
||||
const CLIENT_ID = decode("OWQxYzI1MGEtZTYxYi00NGQ5LTg4ZWQtNTk0NGQxOTYyZjVl");
|
||||
const AUTHORIZE_URL = "https://claude.ai/oauth/authorize";
|
||||
const TOKEN_URL = "https://console.anthropic.com/v1/oauth/token";
|
||||
const REDIRECT_URI = "https://console.anthropic.com/oauth/code/callback";
|
||||
const SCOPES = "org:create_api_key user:profile user:inference";
|
||||
|
||||
/**
 * Login with Anthropic OAuth (device code flow)
 *
 * Builds a PKCE-protected authorization URL, hands it to the caller to
 * open, then exchanges the pasted `code#state` value for tokens.
 *
 * @param onAuthUrl - Callback to handle the authorization URL (e.g., open browser)
 * @param onPromptCode - Callback to prompt user for the authorization code
 * @returns Credentials with a 5-minute-early expiry timestamp
 * @throws Error when the token exchange responds with a non-2xx status
 */
export async function loginAnthropic(
  onAuthUrl: (url: string) => void,
  onPromptCode: () => Promise<string>,
): Promise<OAuthCredentials> {
  // PKCE pair; the verifier is also reused as the OAuth `state` value.
  const { verifier, challenge } = await generatePKCE();

  // Build authorization URL
  const authParams = new URLSearchParams({
    code: "true",
    client_id: CLIENT_ID,
    response_type: "code",
    redirect_uri: REDIRECT_URI,
    scope: SCOPES,
    code_challenge: challenge,
    code_challenge_method: "S256",
    state: verifier,
  });

  const authUrl = `${AUTHORIZE_URL}?${authParams.toString()}`;

  // Notify caller with URL to open
  onAuthUrl(authUrl);

  // Wait for user to paste authorization code (format: code#state)
  const authCode = await onPromptCode();
  const splits = authCode.split("#");
  const code = splits[0];
  // NOTE(review): `state` is undefined if the user pastes only the code
  // without the "#state" suffix — confirm the endpoint tolerates that.
  const state = splits[1];

  // Exchange code for tokens
  const tokenResponse = await fetch(TOKEN_URL, {
    method: "POST",
    headers: {
      "Content-Type": "application/json",
    },
    body: JSON.stringify({
      grant_type: "authorization_code",
      client_id: CLIENT_ID,
      code: code,
      state: state,
      redirect_uri: REDIRECT_URI,
      code_verifier: verifier,
    }),
  });

  if (!tokenResponse.ok) {
    const error = await tokenResponse.text();
    throw new Error(`Token exchange failed: ${error}`);
  }

  const tokenData = (await tokenResponse.json()) as {
    access_token: string;
    refresh_token: string;
    expires_in: number;
  };

  // Calculate expiry time (current time + expires_in seconds - 5 min buffer)
  const expiresAt = Date.now() + tokenData.expires_in * 1000 - 5 * 60 * 1000;

  // Save credentials
  return {
    refresh: tokenData.refresh_token,
    access: tokenData.access_token,
    expires: expiresAt,
  };
}
|
||||
|
||||
/**
 * Refresh Anthropic OAuth token
 *
 * Exchanges a stored refresh token for a fresh access token.
 *
 * @param refreshToken - Refresh token from a previous login
 * @returns New credentials (the server may rotate the refresh token)
 * @throws Error when the token endpoint responds with a non-2xx status
 */
export async function refreshAnthropicToken(
  refreshToken: string,
): Promise<OAuthCredentials> {
  const response = await fetch(TOKEN_URL, {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({
      grant_type: "refresh_token",
      client_id: CLIENT_ID,
      refresh_token: refreshToken,
    }),
  });

  if (!response.ok) {
    const error = await response.text();
    throw new Error(`Anthropic token refresh failed: ${error}`);
  }

  const data = (await response.json()) as {
    access_token: string;
    refresh_token: string;
    expires_in: number;
  };

  return {
    refresh: data.refresh_token,
    access: data.access_token,
    // Expire 5 minutes early so callers refresh before the server does.
    expires: Date.now() + data.expires_in * 1000 - 5 * 60 * 1000,
  };
}
|
||||
|
||||
// OAuthProviderInterface implementation wiring the Anthropic login and
// refresh flows above into the generic OAuth provider registry.
export const anthropicOAuthProvider: OAuthProviderInterface = {
  id: "anthropic",
  name: "Anthropic (Claude Pro/Max)",

  // Adapts the generic callbacks to loginAnthropic's two-callback shape.
  async login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials> {
    return loginAnthropic(
      (url) => callbacks.onAuth({ url }),
      () => callbacks.onPrompt({ message: "Paste the authorization code:" }),
    );
  },

  async refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials> {
    return refreshAnthropicToken(credentials.refresh);
  },

  // The OAuth access token is used directly as the API key.
  getApiKey(credentials: OAuthCredentials): string {
    return credentials.access;
  },
};
|
||||
423
packages/ai/src/utils/oauth/github-copilot.ts
Normal file
423
packages/ai/src/utils/oauth/github-copilot.ts
Normal file
|
|
@ -0,0 +1,423 @@
|
|||
/**
|
||||
* GitHub Copilot OAuth flow
|
||||
*/
|
||||
|
||||
import { getModels } from "../../models.js";
|
||||
import type { Api, Model } from "../../types.js";
|
||||
import type {
|
||||
OAuthCredentials,
|
||||
OAuthLoginCallbacks,
|
||||
OAuthProviderInterface,
|
||||
} from "./types.js";
|
||||
|
||||
type CopilotCredentials = OAuthCredentials & {
|
||||
enterpriseUrl?: string;
|
||||
};
|
||||
|
||||
const decode = (s: string) => atob(s);
|
||||
const CLIENT_ID = decode("SXYxLmI1MDdhMDhjODdlY2ZlOTg=");
|
||||
|
||||
const COPILOT_HEADERS = {
|
||||
"User-Agent": "GitHubCopilotChat/0.35.0",
|
||||
"Editor-Version": "vscode/1.107.0",
|
||||
"Editor-Plugin-Version": "copilot-chat/0.35.0",
|
||||
"Copilot-Integration-Id": "vscode-chat",
|
||||
} as const;
|
||||
|
||||
type DeviceCodeResponse = {
|
||||
device_code: string;
|
||||
user_code: string;
|
||||
verification_uri: string;
|
||||
interval: number;
|
||||
expires_in: number;
|
||||
};
|
||||
|
||||
type DeviceTokenSuccessResponse = {
|
||||
access_token: string;
|
||||
token_type?: string;
|
||||
scope?: string;
|
||||
};
|
||||
|
||||
type DeviceTokenErrorResponse = {
|
||||
error: string;
|
||||
error_description?: string;
|
||||
interval?: number;
|
||||
};
|
||||
|
||||
export function normalizeDomain(input: string): string | null {
|
||||
const trimmed = input.trim();
|
||||
if (!trimmed) return null;
|
||||
try {
|
||||
const url = trimmed.includes("://")
|
||||
? new URL(trimmed)
|
||||
: new URL(`https://${trimmed}`);
|
||||
return url.hostname;
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
function getUrls(domain: string): {
|
||||
deviceCodeUrl: string;
|
||||
accessTokenUrl: string;
|
||||
copilotTokenUrl: string;
|
||||
} {
|
||||
return {
|
||||
deviceCodeUrl: `https://${domain}/login/device/code`,
|
||||
accessTokenUrl: `https://${domain}/login/oauth/access_token`,
|
||||
copilotTokenUrl: `https://api.${domain}/copilot_internal/v2/token`,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse the proxy-ep from a Copilot token and convert to API base URL.
|
||||
* Token format: tid=...;exp=...;proxy-ep=proxy.individual.githubcopilot.com;...
|
||||
* Returns API URL like https://api.individual.githubcopilot.com
|
||||
*/
|
||||
function getBaseUrlFromToken(token: string): string | null {
|
||||
const match = token.match(/proxy-ep=([^;]+)/);
|
||||
if (!match) return null;
|
||||
const proxyHost = match[1];
|
||||
// Convert proxy.xxx to api.xxx
|
||||
const apiHost = proxyHost.replace(/^proxy\./, "api.");
|
||||
return `https://${apiHost}`;
|
||||
}
|
||||
|
||||
export function getGitHubCopilotBaseUrl(
|
||||
token?: string,
|
||||
enterpriseDomain?: string,
|
||||
): string {
|
||||
// If we have a token, extract the base URL from proxy-ep
|
||||
if (token) {
|
||||
const urlFromToken = getBaseUrlFromToken(token);
|
||||
if (urlFromToken) return urlFromToken;
|
||||
}
|
||||
// Fallback for enterprise or if token parsing fails
|
||||
if (enterpriseDomain) return `https://copilot-api.${enterpriseDomain}`;
|
||||
return "https://api.individual.githubcopilot.com";
|
||||
}
|
||||
|
||||
async function fetchJson(url: string, init: RequestInit): Promise<unknown> {
|
||||
const response = await fetch(url, init);
|
||||
if (!response.ok) {
|
||||
const text = await response.text();
|
||||
throw new Error(`${response.status} ${response.statusText}: ${text}`);
|
||||
}
|
||||
return response.json();
|
||||
}
|
||||
|
||||
/**
 * Start the GitHub device-code flow.
 *
 * POSTs to the device-code endpoint and validates every field of the
 * response before returning, so downstream code can trust the types.
 *
 * @param domain - GitHub host ("github.com" or an enterprise domain)
 * @throws Error on HTTP failure or an unexpected response shape
 */
async function startDeviceFlow(domain: string): Promise<DeviceCodeResponse> {
  const urls = getUrls(domain);
  const data = await fetchJson(urls.deviceCodeUrl, {
    method: "POST",
    headers: {
      Accept: "application/json",
      "Content-Type": "application/json",
      "User-Agent": "GitHubCopilotChat/0.35.0",
    },
    body: JSON.stringify({
      client_id: CLIENT_ID,
      scope: "read:user",
    }),
  });

  if (!data || typeof data !== "object") {
    throw new Error("Invalid device code response");
  }

  // Validate each field individually — the response is external input.
  const deviceCode = (data as Record<string, unknown>).device_code;
  const userCode = (data as Record<string, unknown>).user_code;
  const verificationUri = (data as Record<string, unknown>).verification_uri;
  const interval = (data as Record<string, unknown>).interval;
  const expiresIn = (data as Record<string, unknown>).expires_in;

  if (
    typeof deviceCode !== "string" ||
    typeof userCode !== "string" ||
    typeof verificationUri !== "string" ||
    typeof interval !== "number" ||
    typeof expiresIn !== "number"
  ) {
    throw new Error("Invalid device code response fields");
  }

  return {
    device_code: deviceCode,
    user_code: userCode,
    verification_uri: verificationUri,
    interval,
    expires_in: expiresIn,
  };
}
|
||||
|
||||
/**
|
||||
* Sleep that can be interrupted by an AbortSignal
|
||||
*/
|
||||
function abortableSleep(ms: number, signal?: AbortSignal): Promise<void> {
|
||||
return new Promise((resolve, reject) => {
|
||||
if (signal?.aborted) {
|
||||
reject(new Error("Login cancelled"));
|
||||
return;
|
||||
}
|
||||
|
||||
const timeout = setTimeout(resolve, ms);
|
||||
|
||||
signal?.addEventListener(
|
||||
"abort",
|
||||
() => {
|
||||
clearTimeout(timeout);
|
||||
reject(new Error("Login cancelled"));
|
||||
},
|
||||
{ once: true },
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
/**
 * Poll GitHub's token endpoint until the user approves the device code.
 *
 * Repeats at the server-suggested interval, backing off by 5s on each
 * "slow_down" response, and gives up when the device code expires.
 *
 * @param domain - GitHub host being polled
 * @param deviceCode - Device code from startDeviceFlow
 * @param intervalSeconds - Initial poll interval suggested by the server
 * @param expiresIn - Seconds until the device code expires
 * @param signal - Optional cancellation signal
 * @returns The GitHub OAuth access token
 * @throws Error on cancellation, a fatal device-flow error, or timeout
 */
async function pollForGitHubAccessToken(
  domain: string,
  deviceCode: string,
  intervalSeconds: number,
  expiresIn: number,
  signal?: AbortSignal,
) {
  const urls = getUrls(domain);
  const deadline = Date.now() + expiresIn * 1000;
  // Never poll faster than once per second regardless of server hint.
  let intervalMs = Math.max(1000, Math.floor(intervalSeconds * 1000));

  while (Date.now() < deadline) {
    if (signal?.aborted) {
      throw new Error("Login cancelled");
    }

    const raw = await fetchJson(urls.accessTokenUrl, {
      method: "POST",
      headers: {
        Accept: "application/json",
        "Content-Type": "application/json",
        "User-Agent": "GitHubCopilotChat/0.35.0",
      },
      body: JSON.stringify({
        client_id: CLIENT_ID,
        device_code: deviceCode,
        grant_type: "urn:ietf:params:oauth:grant-type:device_code",
      }),
    });

    // Success: response carries the access token.
    if (
      raw &&
      typeof raw === "object" &&
      typeof (raw as DeviceTokenSuccessResponse).access_token === "string"
    ) {
      return (raw as DeviceTokenSuccessResponse).access_token;
    }

    // Error responses: pending/slow_down are retryable, others are fatal.
    if (
      raw &&
      typeof raw === "object" &&
      typeof (raw as DeviceTokenErrorResponse).error === "string"
    ) {
      const err = (raw as DeviceTokenErrorResponse).error;
      if (err === "authorization_pending") {
        await abortableSleep(intervalMs, signal);
        continue;
      }

      if (err === "slow_down") {
        // Per the device-grant spec, add 5 seconds to the interval.
        intervalMs += 5000;
        await abortableSleep(intervalMs, signal);
        continue;
      }

      throw new Error(`Device flow failed: ${err}`);
    }

    // Unrecognized payload: wait and retry until the deadline.
    await abortableSleep(intervalMs, signal);
  }

  throw new Error("Device flow timed out");
}
|
||||
|
||||
/**
 * Refresh GitHub Copilot token
 *
 * Despite the name, this exchanges the long-lived GitHub OAuth token
 * (stored in `refresh`) for a short-lived Copilot session token.
 *
 * @param refreshToken - GitHub OAuth access token
 * @param enterpriseDomain - Optional GitHub Enterprise domain
 * @throws Error on HTTP failure or an unexpected response shape
 */
export async function refreshGitHubCopilotToken(
  refreshToken: string,
  enterpriseDomain?: string,
): Promise<OAuthCredentials> {
  const domain = enterpriseDomain || "github.com";
  const urls = getUrls(domain);

  const raw = await fetchJson(urls.copilotTokenUrl, {
    headers: {
      Accept: "application/json",
      Authorization: `Bearer ${refreshToken}`,
      ...COPILOT_HEADERS,
    },
  });

  if (!raw || typeof raw !== "object") {
    throw new Error("Invalid Copilot token response");
  }

  const token = (raw as Record<string, unknown>).token;
  const expiresAt = (raw as Record<string, unknown>).expires_at;

  if (typeof token !== "string" || typeof expiresAt !== "number") {
    throw new Error("Invalid Copilot token response fields");
  }

  return {
    refresh: refreshToken,
    access: token,
    // expires_at is epoch seconds; convert to ms and keep a 5-min buffer.
    expires: expiresAt * 1000 - 5 * 60 * 1000,
    enterpriseUrl: enterpriseDomain,
  };
}
|
||||
|
||||
/**
 * Enable a model for the user's GitHub Copilot account.
 * This is required for some models (like Claude, Grok) before they can be used.
 *
 * @param token - Copilot session token
 * @param modelId - Model to enable via its policy endpoint
 * @param enterpriseDomain - Optional GitHub Enterprise domain
 * @returns true when the policy request succeeded; false on HTTP or
 *          network failure — this function never throws
 */
async function enableGitHubCopilotModel(
  token: string,
  modelId: string,
  enterpriseDomain?: string,
): Promise<boolean> {
  const baseUrl = getGitHubCopilotBaseUrl(token, enterpriseDomain);
  const url = `${baseUrl}/models/${modelId}/policy`;

  try {
    const response = await fetch(url, {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
        Authorization: `Bearer ${token}`,
        ...COPILOT_HEADERS,
        "openai-intent": "chat-policy",
        "x-interaction-type": "chat-policy",
      },
      body: JSON.stringify({ state: "enabled" }),
    });
    return response.ok;
  } catch {
    // Best-effort: network errors are reported as a failed enable.
    return false;
  }
}
|
||||
|
||||
/**
 * Enable all known GitHub Copilot models that may require policy acceptance.
 * Called after successful login to ensure all models are available.
 *
 * @param token - Copilot session token
 * @param enterpriseDomain - Optional GitHub Enterprise domain
 * @param onProgress - Invoked per model with whether enabling succeeded
 */
async function enableAllGitHubCopilotModels(
  token: string,
  enterpriseDomain?: string,
  onProgress?: (model: string, success: boolean) => void,
): Promise<void> {
  const models = getModels("github-copilot");
  // All policy requests run in parallel; individual failures are surfaced
  // via onProgress but never abort the batch (enable is best-effort).
  await Promise.all(
    models.map(async (model) => {
      const success = await enableGitHubCopilotModel(
        token,
        model.id,
        enterpriseDomain,
      );
      onProgress?.(model.id, success);
    }),
  );
}
|
||||
|
||||
/**
 * Login with GitHub Copilot OAuth (device code flow)
 *
 * Prompts for an optional enterprise host, runs the device-code flow,
 * exchanges the GitHub token for a Copilot session token, then enables
 * all known models.
 *
 * @param options.onAuth - Callback with URL and optional instructions (user code)
 * @param options.onPrompt - Callback to prompt user for input
 * @param options.onProgress - Optional progress callback
 * @param options.signal - Optional AbortSignal for cancellation
 * @throws Error on cancellation, invalid enterprise input, or flow failure
 */
export async function loginGitHubCopilot(options: {
  onAuth: (url: string, instructions?: string) => void;
  onPrompt: (prompt: {
    message: string;
    placeholder?: string;
    allowEmpty?: boolean;
  }) => Promise<string>;
  onProgress?: (message: string) => void;
  signal?: AbortSignal;
}): Promise<OAuthCredentials> {
  // Blank input means public github.com.
  const input = await options.onPrompt({
    message: "GitHub Enterprise URL/domain (blank for github.com)",
    placeholder: "company.ghe.com",
    allowEmpty: true,
  });

  if (options.signal?.aborted) {
    throw new Error("Login cancelled");
  }

  const trimmed = input.trim();
  const enterpriseDomain = normalizeDomain(input);
  // Non-empty input that fails to normalize is a user error, not a fallback.
  if (trimmed && !enterpriseDomain) {
    throw new Error("Invalid GitHub Enterprise URL/domain");
  }
  const domain = enterpriseDomain || "github.com";

  const device = await startDeviceFlow(domain);
  // Surface the verification URL together with the one-time user code.
  options.onAuth(device.verification_uri, `Enter code: ${device.user_code}`);

  const githubAccessToken = await pollForGitHubAccessToken(
    domain,
    device.device_code,
    device.interval,
    device.expires_in,
    options.signal,
  );
  // Exchange the GitHub OAuth token for a Copilot session token.
  const credentials = await refreshGitHubCopilotToken(
    githubAccessToken,
    enterpriseDomain ?? undefined,
  );

  // Enable all models after successful login
  options.onProgress?.("Enabling models...");
  await enableAllGitHubCopilotModels(
    credentials.access,
    enterpriseDomain ?? undefined,
  );
  return credentials;
}
|
||||
|
||||
// OAuthProviderInterface implementation for GitHub Copilot.
export const githubCopilotOAuthProvider: OAuthProviderInterface = {
  id: "github-copilot",
  name: "GitHub Copilot",

  async login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials> {
    return loginGitHubCopilot({
      onAuth: (url, instructions) => callbacks.onAuth({ url, instructions }),
      onPrompt: callbacks.onPrompt,
      onProgress: callbacks.onProgress,
      signal: callbacks.signal,
    });
  },

  // "Refresh" re-exchanges the stored GitHub token for a new Copilot
  // session token (see refreshGitHubCopilotToken).
  async refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials> {
    const creds = credentials as CopilotCredentials;
    return refreshGitHubCopilotToken(creds.refresh, creds.enterpriseUrl);
  },

  getApiKey(credentials: OAuthCredentials): string {
    return credentials.access;
  },

  // Points Copilot models at the base URL derived from the session token
  // (or enterprise domain); models from other providers pass through.
  modifyModels(
    models: Model<Api>[],
    credentials: OAuthCredentials,
  ): Model<Api>[] {
    const creds = credentials as CopilotCredentials;
    const domain = creds.enterpriseUrl
      ? (normalizeDomain(creds.enterpriseUrl) ?? undefined)
      : undefined;
    const baseUrl = getGitHubCopilotBaseUrl(creds.access, domain);
    return models.map((m) =>
      m.provider === "github-copilot" ? { ...m, baseUrl } : m,
    );
  },
};
|
||||
492
packages/ai/src/utils/oauth/google-antigravity.ts
Normal file
492
packages/ai/src/utils/oauth/google-antigravity.ts
Normal file
|
|
@ -0,0 +1,492 @@
|
|||
/**
|
||||
* Antigravity OAuth flow (Gemini 3, Claude, GPT-OSS via Google Cloud)
|
||||
* Uses different OAuth credentials than google-gemini-cli for access to additional models.
|
||||
*
|
||||
* NOTE: This module uses Node.js http.createServer for the OAuth callback.
|
||||
* It is only intended for CLI use, not browser environments.
|
||||
*/
|
||||
|
||||
import type { Server } from "node:http";
|
||||
import { generatePKCE } from "./pkce.js";
|
||||
import type {
|
||||
OAuthCredentials,
|
||||
OAuthLoginCallbacks,
|
||||
OAuthProviderInterface,
|
||||
} from "./types.js";
|
||||
|
||||
type AntigravityCredentials = OAuthCredentials & {
|
||||
projectId: string;
|
||||
};
|
||||
|
||||
let _createServer: typeof import("node:http").createServer | null = null;
|
||||
let _httpImportPromise: Promise<void> | null = null;
|
||||
if (
|
||||
typeof process !== "undefined" &&
|
||||
(process.versions?.node || process.versions?.bun)
|
||||
) {
|
||||
_httpImportPromise = import("node:http").then((m) => {
|
||||
_createServer = m.createServer;
|
||||
});
|
||||
}
|
||||
|
||||
// Antigravity OAuth credentials (different from Gemini CLI)
|
||||
const decode = (s: string) => atob(s);
|
||||
const CLIENT_ID = decode(
|
||||
"MTA3MTAwNjA2MDU5MS10bWhzc2luMmgyMWxjcmUyMzV2dG9sb2poNGc0MDNlcC5hcHBzLmdvb2dsZXVzZXJjb250ZW50LmNvbQ==",
|
||||
);
|
||||
const CLIENT_SECRET = decode(
|
||||
"R09DU1BYLUs1OEZXUjQ4NkxkTEoxbUxCOHNYQzR6NnFEQWY=",
|
||||
);
|
||||
const REDIRECT_URI = "http://localhost:51121/oauth-callback";
|
||||
|
||||
// Antigravity requires additional scopes
|
||||
const SCOPES = [
|
||||
"https://www.googleapis.com/auth/cloud-platform",
|
||||
"https://www.googleapis.com/auth/userinfo.email",
|
||||
"https://www.googleapis.com/auth/userinfo.profile",
|
||||
"https://www.googleapis.com/auth/cclog",
|
||||
"https://www.googleapis.com/auth/experimentsandconfigs",
|
||||
];
|
||||
|
||||
const AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth";
|
||||
const TOKEN_URL = "https://oauth2.googleapis.com/token";
|
||||
|
||||
// Fallback project ID when discovery fails
|
||||
const DEFAULT_PROJECT_ID = "rising-fact-p41fc";
|
||||
|
||||
type CallbackServerInfo = {
|
||||
server: Server;
|
||||
cancelWait: () => void;
|
||||
waitForCode: () => Promise<{ code: string; state: string } | null>;
|
||||
};
|
||||
|
||||
/**
|
||||
* Start a local HTTP server to receive the OAuth callback
|
||||
*/
|
||||
async function getNodeCreateServer(): Promise<
|
||||
typeof import("node:http").createServer
|
||||
> {
|
||||
if (_createServer) return _createServer;
|
||||
if (_httpImportPromise) {
|
||||
await _httpImportPromise;
|
||||
}
|
||||
if (_createServer) return _createServer;
|
||||
throw new Error(
|
||||
"Antigravity OAuth is only available in Node.js environments",
|
||||
);
|
||||
}
|
||||
|
||||
/**
 * Start a local HTTP server on fixed port 51121 to receive the OAuth
 * callback at /oauth-callback.
 *
 * @returns The server, a cancelWait() to abandon waiting, and
 *          waitForCode(), which resolves with the (code, state) pair or
 *          null if cancelled first.
 */
async function startCallbackServer(): Promise<CallbackServerInfo> {
  const createServer = await getNodeCreateServer();

  return new Promise((resolve, reject) => {
    // Filled in by the request handler once a valid callback arrives.
    let result: { code: string; state: string } | null = null;
    let cancelled = false;

    const server = createServer((req, res) => {
      const url = new URL(req.url || "", `http://localhost:51121`);

      if (url.pathname === "/oauth-callback") {
        const code = url.searchParams.get("code");
        const state = url.searchParams.get("state");
        const error = url.searchParams.get("error");

        if (error) {
          res.writeHead(400, { "Content-Type": "text/html" });
          res.end(
            `<html><body><h1>Authentication Failed</h1><p>Error: ${error}</p><p>You can close this window.</p></body></html>`,
          );
          return;
        }

        if (code && state) {
          res.writeHead(200, { "Content-Type": "text/html" });
          res.end(
            `<html><body><h1>Authentication Successful</h1><p>You can close this window and return to the terminal.</p></body></html>`,
          );
          result = { code, state };
        } else {
          res.writeHead(400, { "Content-Type": "text/html" });
          res.end(
            `<html><body><h1>Authentication Failed</h1><p>Missing code or state parameter.</p></body></html>`,
          );
        }
      } else {
        res.writeHead(404);
        res.end();
      }
    });

    // listen() failures (e.g. port 51121 already in use) reject here.
    server.on("error", (err) => {
      reject(err);
    });

    server.listen(51121, "127.0.0.1", () => {
      resolve({
        server,
        cancelWait: () => {
          cancelled = true;
        },
        // Polls every 100ms until a callback lands or cancelWait() fires;
        // returns null when cancelled before a code arrived.
        waitForCode: async () => {
          const sleep = () => new Promise((r) => setTimeout(r, 100));
          while (!result && !cancelled) {
            await sleep();
          }
          return result;
        },
      });
    });
  });
}
|
||||
|
||||
/**
|
||||
* Parse redirect URL to extract code and state
|
||||
*/
|
||||
function parseRedirectUrl(input: string): { code?: string; state?: string } {
|
||||
const value = input.trim();
|
||||
if (!value) return {};
|
||||
|
||||
try {
|
||||
const url = new URL(value);
|
||||
return {
|
||||
code: url.searchParams.get("code") ?? undefined,
|
||||
state: url.searchParams.get("state") ?? undefined,
|
||||
};
|
||||
} catch {
|
||||
// Not a URL, return empty
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
interface LoadCodeAssistPayload {
|
||||
cloudaicompanionProject?: string | { id?: string };
|
||||
currentTier?: { id?: string };
|
||||
allowedTiers?: Array<{ id?: string; isDefault?: boolean }>;
|
||||
}
|
||||
|
||||
/**
 * Discover or provision a project for the user
 *
 * Tries the production then sandbox Code Assist endpoints; if neither
 * returns a project, falls back to DEFAULT_PROJECT_ID. Never rejects.
 *
 * @param accessToken - Google OAuth access token
 * @param onProgress - Optional progress callback
 * @returns A Cloud project ID
 */
async function discoverProject(
  accessToken: string,
  onProgress?: (message: string) => void,
): Promise<string> {
  const headers = {
    Authorization: `Bearer ${accessToken}`,
    "Content-Type": "application/json",
    "User-Agent": "google-api-nodejs-client/9.15.1",
    "X-Goog-Api-Client": "google-cloud-sdk vscode_cloudshelleditor/0.1",
    "Client-Metadata": JSON.stringify({
      ideType: "IDE_UNSPECIFIED",
      platform: "PLATFORM_UNSPECIFIED",
      pluginType: "GEMINI",
    }),
  };

  // Try endpoints in order: prod first, then sandbox
  const endpoints = [
    "https://cloudcode-pa.googleapis.com",
    "https://daily-cloudcode-pa.sandbox.googleapis.com",
  ];

  onProgress?.("Checking for existing project...");

  for (const endpoint of endpoints) {
    try {
      const loadResponse = await fetch(
        `${endpoint}/v1internal:loadCodeAssist`,
        {
          method: "POST",
          headers,
          body: JSON.stringify({
            metadata: {
              ideType: "IDE_UNSPECIFIED",
              platform: "PLATFORM_UNSPECIFIED",
              pluginType: "GEMINI",
            },
          }),
        },
      );

      if (loadResponse.ok) {
        const data = (await loadResponse.json()) as LoadCodeAssistPayload;

        // Handle both string and object formats
        if (
          typeof data.cloudaicompanionProject === "string" &&
          data.cloudaicompanionProject
        ) {
          return data.cloudaicompanionProject;
        }
        if (
          data.cloudaicompanionProject &&
          typeof data.cloudaicompanionProject === "object" &&
          data.cloudaicompanionProject.id
        ) {
          return data.cloudaicompanionProject.id;
        }
      }
    } catch {
      // Try next endpoint
    }
  }

  // Use fallback project ID
  onProgress?.("Using default project...");
  return DEFAULT_PROJECT_ID;
}
|
||||
|
||||
/**
|
||||
* Get user email from the access token
|
||||
*/
|
||||
async function getUserEmail(accessToken: string): Promise<string | undefined> {
|
||||
try {
|
||||
const response = await fetch(
|
||||
"https://www.googleapis.com/oauth2/v1/userinfo?alt=json",
|
||||
{
|
||||
headers: {
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
if (response.ok) {
|
||||
const data = (await response.json()) as { email?: string };
|
||||
return data.email;
|
||||
}
|
||||
} catch {
|
||||
// Ignore errors, email is optional
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/**
 * Refresh Antigravity token
 *
 * @param refreshToken - Google OAuth refresh token
 * @param projectId - Project ID carried through into the new credentials
 * @throws Error when the token endpoint responds with a non-2xx status
 */
export async function refreshAntigravityToken(
  refreshToken: string,
  projectId: string,
): Promise<OAuthCredentials> {
  const response = await fetch(TOKEN_URL, {
    method: "POST",
    headers: { "Content-Type": "application/x-www-form-urlencoded" },
    body: new URLSearchParams({
      client_id: CLIENT_ID,
      client_secret: CLIENT_SECRET,
      refresh_token: refreshToken,
      grant_type: "refresh_token",
    }),
  });

  if (!response.ok) {
    const error = await response.text();
    throw new Error(`Antigravity token refresh failed: ${error}`);
  }

  const data = (await response.json()) as {
    access_token: string;
    expires_in: number;
    refresh_token?: string;
  };

  return {
    // Google may rotate the refresh token; keep the old one otherwise.
    refresh: data.refresh_token || refreshToken,
    access: data.access_token,
    // Expire 5 minutes early so callers refresh before the server does.
    expires: Date.now() + data.expires_in * 1000 - 5 * 60 * 1000,
    projectId,
  };
}
|
||||
|
||||
/**
 * Login with Antigravity OAuth
 *
 * @param onAuth - Callback with URL and optional instructions
 * @param onProgress - Optional progress callback
 * @param onManualCodeInput - Optional promise that resolves with user-pasted redirect URL.
 *        Races with browser callback - whichever completes first wins.
 * @returns Credentials including the discovered projectId and user email
 * @throws Error on cancellation, state mismatch, missing code/refresh
 *         token, or token-exchange failure
 */
export async function loginAntigravity(
  onAuth: (info: { url: string; instructions?: string }) => void,
  onProgress?: (message: string) => void,
  onManualCodeInput?: () => Promise<string>,
): Promise<OAuthCredentials> {
  // PKCE pair; the verifier is also reused as the OAuth `state` value.
  const { verifier, challenge } = await generatePKCE();

  // Start local server for callback
  onProgress?.("Starting local server for OAuth callback...");
  const server = await startCallbackServer();

  let code: string | undefined;

  try {
    // Build authorization URL
    const authParams = new URLSearchParams({
      client_id: CLIENT_ID,
      response_type: "code",
      redirect_uri: REDIRECT_URI,
      scope: SCOPES.join(" "),
      code_challenge: challenge,
      code_challenge_method: "S256",
      state: verifier,
      // offline + consent ensures Google issues a refresh token.
      access_type: "offline",
      prompt: "consent",
    });

    const authUrl = `${AUTH_URL}?${authParams.toString()}`;

    // Notify caller with URL to open
    onAuth({
      url: authUrl,
      instructions: "Complete the sign-in in your browser.",
    });

    // Wait for the callback, racing with manual input if provided
    onProgress?.("Waiting for OAuth callback...");

    if (onManualCodeInput) {
      // Race between browser callback and manual input
      let manualInput: string | undefined;
      let manualError: Error | undefined;
      // Either outcome of the manual prompt unblocks waitForCode()
      // via cancelWait(); results are captured in the closure vars.
      const manualPromise = onManualCodeInput()
        .then((input) => {
          manualInput = input;
          server.cancelWait();
        })
        .catch((err) => {
          manualError = err instanceof Error ? err : new Error(String(err));
          server.cancelWait();
        });

      const result = await server.waitForCode();

      // If manual input was cancelled, throw that error
      if (manualError) {
        throw manualError;
      }

      if (result?.code) {
        // Browser callback won - verify state
        if (result.state !== verifier) {
          throw new Error("OAuth state mismatch - possible CSRF attack");
        }
        code = result.code;
      } else if (manualInput) {
        // Manual input won
        const parsed = parseRedirectUrl(manualInput);
        // Manual pastes may omit state; only verify it when present.
        if (parsed.state && parsed.state !== verifier) {
          throw new Error("OAuth state mismatch - possible CSRF attack");
        }
        code = parsed.code;
      }

      // If still no code, wait for manual promise and try that
      if (!code) {
        await manualPromise;
        if (manualError) {
          throw manualError;
        }
        if (manualInput) {
          const parsed = parseRedirectUrl(manualInput);
          if (parsed.state && parsed.state !== verifier) {
            throw new Error("OAuth state mismatch - possible CSRF attack");
          }
          code = parsed.code;
        }
      }
    } else {
      // Original flow: just wait for callback
      const result = await server.waitForCode();
      if (result?.code) {
        if (result.state !== verifier) {
          throw new Error("OAuth state mismatch - possible CSRF attack");
        }
        code = result.code;
      }
    }

    if (!code) {
      throw new Error("No authorization code received");
    }

    // Exchange code for tokens
    onProgress?.("Exchanging authorization code for tokens...");
    const tokenResponse = await fetch(TOKEN_URL, {
      method: "POST",
      headers: {
        "Content-Type": "application/x-www-form-urlencoded",
      },
      body: new URLSearchParams({
        client_id: CLIENT_ID,
        client_secret: CLIENT_SECRET,
        code,
        grant_type: "authorization_code",
        redirect_uri: REDIRECT_URI,
        code_verifier: verifier,
      }),
    });

    if (!tokenResponse.ok) {
      const error = await tokenResponse.text();
      throw new Error(`Token exchange failed: ${error}`);
    }

    const tokenData = (await tokenResponse.json()) as {
      access_token: string;
      refresh_token: string;
      expires_in: number;
    };

    if (!tokenData.refresh_token) {
      throw new Error("No refresh token received. Please try again.");
    }

    // Get user email
    onProgress?.("Getting user info...");
    const email = await getUserEmail(tokenData.access_token);

    // Discover project
    const projectId = await discoverProject(tokenData.access_token, onProgress);

    // Calculate expiry time (current time + expires_in seconds - 5 min buffer)
    const expiresAt = Date.now() + tokenData.expires_in * 1000 - 5 * 60 * 1000;

    const credentials: OAuthCredentials = {
      refresh: tokenData.refresh_token,
      access: tokenData.access_token,
      expires: expiresAt,
      projectId,
      email,
    };

    return credentials;
  } finally {
    // Always shut the callback server down, success or failure.
    server.server.close();
  }
}
|
||||
|
||||
/**
 * OAuth provider descriptor for Google Antigravity.
 *
 * Wires the Antigravity login/refresh flows into the generic
 * `OAuthProviderInterface` consumed by the provider registry.
 */
export const antigravityOAuthProvider: OAuthProviderInterface = {
  id: "google-antigravity",
  name: "Antigravity (Gemini 3, Claude, GPT-OSS)",
  // Login spins up a local HTTP server to receive the OAuth redirect.
  usesCallbackServer: true,

  // Delegates to the module-level login flow, forwarding the caller's
  // auth-URL, progress, and manual-code callbacks.
  async login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials> {
    return loginAntigravity(
      callbacks.onAuth,
      callbacks.onProgress,
      callbacks.onManualCodeInput,
    );
  },

  // Refreshes the access token; Antigravity requires a Google Cloud
  // project id alongside the refresh token.
  async refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials> {
    const creds = credentials as AntigravityCredentials;
    if (!creds.projectId) {
      throw new Error("Antigravity credentials missing projectId");
    }
    return refreshAntigravityToken(creds.refresh, creds.projectId);
  },

  // The "API key" is a JSON envelope carrying both the bearer token and
  // the project id, since downstream callers need both.
  getApiKey(credentials: OAuthCredentials): string {
    const creds = credentials as AntigravityCredentials;
    return JSON.stringify({ token: creds.access, projectId: creds.projectId });
  },
};
|
||||
648
packages/ai/src/utils/oauth/google-gemini-cli.ts
Normal file
648
packages/ai/src/utils/oauth/google-gemini-cli.ts
Normal file
|
|
@ -0,0 +1,648 @@
|
|||
/**
|
||||
* Gemini CLI OAuth flow (Google Cloud Code Assist)
|
||||
* Standard Gemini models only (gemini-2.0-flash, gemini-2.5-*)
|
||||
*
|
||||
* NOTE: This module uses Node.js http.createServer for the OAuth callback.
|
||||
* It is only intended for CLI use, not browser environments.
|
||||
*/
|
||||
|
||||
import type { Server } from "node:http";
|
||||
import { generatePKCE } from "./pkce.js";
|
||||
import type {
|
||||
OAuthCredentials,
|
||||
OAuthLoginCallbacks,
|
||||
OAuthProviderInterface,
|
||||
} from "./types.js";
|
||||
|
||||
// OAuth credentials extended with the Google Cloud project id that the
// Code Assist API requires on every refresh/request.
type GeminiCredentials = OAuthCredentials & {
  projectId: string;
};
|
||||
|
||||
// Lazily import node:http only when running under Node/Bun, so this module
// stays loadable in browser bundles (the import is deferred, never top-level).
let _createServer: typeof import("node:http").createServer | null = null;
// Pending dynamic import; awaited by getNodeCreateServer() before first use.
let _httpImportPromise: Promise<void> | null = null;
if (
  typeof process !== "undefined" &&
  (process.versions?.node || process.versions?.bun)
) {
  _httpImportPromise = import("node:http").then((m) => {
    _createServer = m.createServer;
  });
}
|
||||
|
||||
// Base64 decoder used to lightly obscure the embedded OAuth client values.
const decode = (s: string) => atob(s);
// OAuth client id/secret for the Gemini CLI installed-app flow
// (stored base64-encoded; decoded at module load).
const CLIENT_ID = decode(
  "NjgxMjU1ODA5Mzk1LW9vOGZ0Mm9wcmRybnA5ZTNhcWY2YXYzaG1kaWIxMzVqLmFwcHMuZ29vZ2xldXNlcmNvbnRlbnQuY29t",
);
const CLIENT_SECRET = decode(
  "R09DU1BYLTR1SGdNUG0tMW83U2stZ2VWNkN1NWNsWEZzeGw=",
);
// Loopback redirect handled by the local callback server on port 8085.
const REDIRECT_URI = "http://localhost:8085/oauth2callback";
// Scopes requested: Cloud Platform access plus basic identity info.
const SCOPES = [
  "https://www.googleapis.com/auth/cloud-platform",
  "https://www.googleapis.com/auth/userinfo.email",
  "https://www.googleapis.com/auth/userinfo.profile",
];
const AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth";
const TOKEN_URL = "https://oauth2.googleapis.com/token";
// Cloud Code Assist (private) API base used for project discovery/onboarding.
const CODE_ASSIST_ENDPOINT = "https://cloudcode-pa.googleapis.com";
|
||||
|
||||
// Handle for the local OAuth callback server:
// - server: the underlying node:http server (caller must close it)
// - cancelWait: makes a pending waitForCode() resolve with null
// - waitForCode: resolves once the redirect delivers code+state, or null if cancelled
type CallbackServerInfo = {
  server: Server;
  cancelWait: () => void;
  waitForCode: () => Promise<{ code: string; state: string } | null>;
};
|
||||
|
||||
/**
|
||||
* Start a local HTTP server to receive the OAuth callback
|
||||
*/
|
||||
async function getNodeCreateServer(): Promise<
|
||||
typeof import("node:http").createServer
|
||||
> {
|
||||
if (_createServer) return _createServer;
|
||||
if (_httpImportPromise) {
|
||||
await _httpImportPromise;
|
||||
}
|
||||
if (_createServer) return _createServer;
|
||||
throw new Error("Gemini CLI OAuth is only available in Node.js environments");
|
||||
}
|
||||
|
||||
/**
 * Start a local HTTP server on 127.0.0.1:8085 to receive the OAuth redirect.
 *
 * Resolves once the server is listening with a handle exposing:
 * - waitForCode(): polls every 100ms until the redirect delivers code+state,
 *   or until cancelWait() is called (then resolves null).
 * - cancelWait(): aborts a pending waitForCode().
 * The caller is responsible for closing `server` when done.
 *
 * @throws Rejects if the server emits an error (e.g. port 8085 already in use).
 */
async function startCallbackServer(): Promise<CallbackServerInfo> {
  const createServer = await getNodeCreateServer();

  return new Promise((resolve, reject) => {
    // Shared state between the request handler and waitForCode's poll loop.
    let result: { code: string; state: string } | null = null;
    let cancelled = false;

    const server = createServer((req, res) => {
      // req.url is path+query only; base supplies scheme/host for URL parsing.
      const url = new URL(req.url || "", `http://localhost:8085`);

      if (url.pathname === "/oauth2callback") {
        const code = url.searchParams.get("code");
        const state = url.searchParams.get("state");
        const error = url.searchParams.get("error");

        // Provider signalled failure (e.g. user denied consent).
        if (error) {
          res.writeHead(400, { "Content-Type": "text/html" });
          res.end(
            `<html><body><h1>Authentication Failed</h1><p>Error: ${error}</p><p>You can close this window.</p></body></html>`,
          );
          return;
        }

        if (code && state) {
          res.writeHead(200, { "Content-Type": "text/html" });
          res.end(
            `<html><body><h1>Authentication Successful</h1><p>You can close this window and return to the terminal.</p></body></html>`,
          );
          // Publishing result unblocks the waitForCode() poll loop.
          result = { code, state };
        } else {
          res.writeHead(400, { "Content-Type": "text/html" });
          res.end(
            `<html><body><h1>Authentication Failed</h1><p>Missing code or state parameter.</p></body></html>`,
          );
        }
      } else {
        // Anything other than the callback path is not served.
        res.writeHead(404);
        res.end();
      }
    });

    server.on("error", (err) => {
      reject(err);
    });

    // Bind to loopback only; resolve once the port is actually listening.
    server.listen(8085, "127.0.0.1", () => {
      resolve({
        server,
        cancelWait: () => {
          cancelled = true;
        },
        // Simple 100ms poll; returns null when cancelled before a result arrives.
        waitForCode: async () => {
          const sleep = () => new Promise((r) => setTimeout(r, 100));
          while (!result && !cancelled) {
            await sleep();
          }
          return result;
        },
      });
    });
  });
}
|
||||
|
||||
/**
|
||||
* Parse redirect URL to extract code and state
|
||||
*/
|
||||
function parseRedirectUrl(input: string): { code?: string; state?: string } {
|
||||
const value = input.trim();
|
||||
if (!value) return {};
|
||||
|
||||
try {
|
||||
const url = new URL(value);
|
||||
return {
|
||||
code: url.searchParams.get("code") ?? undefined,
|
||||
state: url.searchParams.get("state") ?? undefined,
|
||||
};
|
||||
} catch {
|
||||
// Not a URL, return empty
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
// Response shape of v1internal:loadCodeAssist (only the fields read here).
interface LoadCodeAssistPayload {
  // Managed project id, when Google already provisioned one for this account.
  cloudaicompanionProject?: string;
  // Present when the account is already onboarded to some tier.
  currentTier?: { id?: string };
  // Tiers the account may onboard to; one may be flagged as the default.
  allowedTiers?: Array<{ id?: string; isDefault?: boolean }>;
}

/**
 * Long-running operation response from onboardUser
 */
interface LongRunningOperationResponse {
  // Operation resource name, used to poll v1internal/{name}.
  name?: string;
  done?: boolean;
  response?: {
    cloudaicompanionProject?: { id?: string };
  };
}

// Tier IDs as used by the Cloud Code API
const TIER_FREE = "free-tier";
const TIER_LEGACY = "legacy-tier";
const TIER_STANDARD = "standard-tier";

// google.rpc error envelope (only the detail reasons are inspected).
interface GoogleRpcErrorResponse {
  error?: {
    details?: Array<{ reason?: string }>;
  };
}
|
||||
|
||||
/**
|
||||
* Wait helper for onboarding retries
|
||||
*/
|
||||
function wait(ms: number): Promise<void> {
|
||||
return new Promise((resolve) => setTimeout(resolve, ms));
|
||||
}
|
||||
|
||||
/**
|
||||
* Get default tier from allowed tiers
|
||||
*/
|
||||
function getDefaultTier(
|
||||
allowedTiers?: Array<{ id?: string; isDefault?: boolean }>,
|
||||
): { id?: string } {
|
||||
if (!allowedTiers || allowedTiers.length === 0) return { id: TIER_LEGACY };
|
||||
const defaultTier = allowedTiers.find((t) => t.isDefault);
|
||||
return defaultTier ?? { id: TIER_LEGACY };
|
||||
}
|
||||
|
||||
function isVpcScAffectedUser(payload: unknown): boolean {
|
||||
if (!payload || typeof payload !== "object") return false;
|
||||
if (!("error" in payload)) return false;
|
||||
const error = (payload as GoogleRpcErrorResponse).error;
|
||||
if (!error?.details || !Array.isArray(error.details)) return false;
|
||||
return error.details.some(
|
||||
(detail) => detail.reason === "SECURITY_POLICY_VIOLATED",
|
||||
);
|
||||
}
|
||||
|
||||
/**
 * Poll a long-running operation until completion.
 *
 * Issues a GET against v1internal/{operationName} every 5 seconds (after the
 * first immediate attempt) until the operation reports `done`.
 *
 * NOTE(review): there is no maximum attempt count — this loops until the
 * server reports done or a non-OK response aborts it. Confirm that is the
 * intended behavior for stuck provisioning.
 *
 * @param operationName - Operation resource name returned by onboardUser.
 * @param headers - Auth headers to reuse for each poll request.
 * @param onProgress - Optional status callback, invoked on retries.
 * @throws When any poll request returns a non-OK status.
 */
async function pollOperation(
  operationName: string,
  headers: Record<string, string>,
  onProgress?: (message: string) => void,
): Promise<LongRunningOperationResponse> {
  let attempt = 0;
  while (true) {
    // First attempt is immediate; subsequent attempts wait 5s and report.
    if (attempt > 0) {
      onProgress?.(
        `Waiting for project provisioning (attempt ${attempt + 1})...`,
      );
      await wait(5000);
    }

    const response = await fetch(
      `${CODE_ASSIST_ENDPOINT}/v1internal/${operationName}`,
      {
        method: "GET",
        headers,
      },
    );

    if (!response.ok) {
      throw new Error(
        `Failed to poll operation: ${response.status} ${response.statusText}`,
      );
    }

    const data = (await response.json()) as LongRunningOperationResponse;
    if (data.done) {
      return data;
    }

    attempt += 1;
  }
}
|
||||
|
||||
/**
 * Discover or provision a Google Cloud project for the user.
 *
 * Order of resolution:
 * 1. loadCodeAssist — if the account is already onboarded, use its managed
 *    project, else fall back to GOOGLE_CLOUD_PROJECT(_ID) from the env.
 * 2. Otherwise onboard the account to its default tier via onboardUser
 *    (a long-running operation, polled to completion) and use the project
 *    it provisions.
 * 3. Final fallback: the env-provided project id.
 *
 * @param accessToken - OAuth bearer token for the Cloud Code API.
 * @param onProgress - Optional status callback.
 * @throws When no project can be discovered or provisioned, or when the
 *         account's tier requires an env-provided project id that is absent.
 */
async function discoverProject(
  accessToken: string,
  onProgress?: (message: string) => void,
): Promise<string> {
  // Check for user-provided project ID via environment variable
  const envProjectId =
    process.env.GOOGLE_CLOUD_PROJECT || process.env.GOOGLE_CLOUD_PROJECT_ID;

  const headers = {
    Authorization: `Bearer ${accessToken}`,
    "Content-Type": "application/json",
    // Mimics the official Google API client's identification headers.
    "User-Agent": "google-api-nodejs-client/9.15.1",
    "X-Goog-Api-Client": "gl-node/22.17.0",
  };

  // Try to load existing project via loadCodeAssist
  onProgress?.("Checking for existing Cloud Code Assist project...");
  const loadResponse = await fetch(
    `${CODE_ASSIST_ENDPOINT}/v1internal:loadCodeAssist`,
    {
      method: "POST",
      headers,
      body: JSON.stringify({
        cloudaicompanionProject: envProjectId,
        metadata: {
          ideType: "IDE_UNSPECIFIED",
          platform: "PLATFORM_UNSPECIFIED",
          pluginType: "GEMINI",
          duetProject: envProjectId,
        },
      }),
    },
  );

  let data: LoadCodeAssistPayload;

  if (!loadResponse.ok) {
    // Clone before .json() so the body can still be read as text on the
    // error path below.
    let errorPayload: unknown;
    try {
      errorPayload = await loadResponse.clone().json();
    } catch {
      errorPayload = undefined;
    }

    if (isVpcScAffectedUser(errorPayload)) {
      // VPC-SC blocked accounts are treated as standard tier and continue.
      data = { currentTier: { id: TIER_STANDARD } };
    } else {
      const errorText = await loadResponse.text();
      throw new Error(
        `loadCodeAssist failed: ${loadResponse.status} ${loadResponse.statusText}: ${errorText}`,
      );
    }
  } else {
    data = (await loadResponse.json()) as LoadCodeAssistPayload;
  }

  // If user already has a current tier and project, use it
  if (data.currentTier) {
    if (data.cloudaicompanionProject) {
      return data.cloudaicompanionProject;
    }
    // User has a tier but no managed project - they need to provide one via env var
    if (envProjectId) {
      return envProjectId;
    }
    throw new Error(
      "This account requires setting the GOOGLE_CLOUD_PROJECT or GOOGLE_CLOUD_PROJECT_ID environment variable. " +
        "See https://goo.gle/gemini-cli-auth-docs#workspace-gca",
    );
  }

  // User needs to be onboarded - get the default tier
  const tier = getDefaultTier(data.allowedTiers);
  const tierId = tier?.id ?? TIER_FREE;

  // Non-free tiers cannot be onboarded without a caller-supplied project.
  if (tierId !== TIER_FREE && !envProjectId) {
    throw new Error(
      "This account requires setting the GOOGLE_CLOUD_PROJECT or GOOGLE_CLOUD_PROJECT_ID environment variable. " +
        "See https://goo.gle/gemini-cli-auth-docs#workspace-gca",
    );
  }

  onProgress?.(
    "Provisioning Cloud Code Assist project (this may take a moment)...",
  );

  // Build onboard request - for free tier, don't include project ID (Google provisions one)
  // For other tiers, include the user's project ID if available
  const onboardBody: Record<string, unknown> = {
    tierId,
    metadata: {
      ideType: "IDE_UNSPECIFIED",
      platform: "PLATFORM_UNSPECIFIED",
      pluginType: "GEMINI",
    },
  };

  if (tierId !== TIER_FREE && envProjectId) {
    onboardBody.cloudaicompanionProject = envProjectId;
    (onboardBody.metadata as Record<string, unknown>).duetProject =
      envProjectId;
  }

  // Start onboarding - this returns a long-running operation
  const onboardResponse = await fetch(
    `${CODE_ASSIST_ENDPOINT}/v1internal:onboardUser`,
    {
      method: "POST",
      headers,
      body: JSON.stringify(onboardBody),
    },
  );

  if (!onboardResponse.ok) {
    const errorText = await onboardResponse.text();
    throw new Error(
      `onboardUser failed: ${onboardResponse.status} ${onboardResponse.statusText}: ${errorText}`,
    );
  }

  let lroData = (await onboardResponse.json()) as LongRunningOperationResponse;

  // If the operation isn't done yet, poll until completion
  if (!lroData.done && lroData.name) {
    lroData = await pollOperation(lroData.name, headers, onProgress);
  }

  // Try to get project ID from the response
  const projectId = lroData.response?.cloudaicompanionProject?.id;
  if (projectId) {
    return projectId;
  }

  // If no project ID from onboarding, fall back to env var
  if (envProjectId) {
    return envProjectId;
  }

  throw new Error(
    "Could not discover or provision a Google Cloud project. " +
      "Try setting the GOOGLE_CLOUD_PROJECT or GOOGLE_CLOUD_PROJECT_ID environment variable. " +
      "See https://goo.gle/gemini-cli-auth-docs#workspace-gca",
  );
}
|
||||
|
||||
/**
 * Fetch the authenticated user's email from Google's userinfo endpoint.
 *
 * Best-effort: any network/parse failure or non-OK response yields
 * undefined rather than throwing, since the email is optional metadata.
 *
 * @param accessToken - OAuth bearer token with the userinfo.email scope.
 */
async function getUserEmail(accessToken: string): Promise<string | undefined> {
  try {
    const response = await fetch(
      "https://www.googleapis.com/oauth2/v1/userinfo?alt=json",
      {
        headers: {
          Authorization: `Bearer ${accessToken}`,
        },
      },
    );

    if (response.ok) {
      const data = (await response.json()) as { email?: string };
      return data.email;
    }
  } catch {
    // Ignore errors, email is optional
  }
  return undefined;
}
|
||||
|
||||
/**
 * Refresh Google Cloud Code Assist token
 *
 * Exchanges the stored refresh token for a new access token. The returned
 * credentials keep the original refresh token unless the server rotates it,
 * and set `expires` 5 minutes early as a safety buffer.
 *
 * @param refreshToken - Long-lived OAuth refresh token.
 * @param projectId - Cloud project id to carry forward in the credentials.
 * @throws When the token endpoint responds with a non-OK status.
 */
export async function refreshGoogleCloudToken(
  refreshToken: string,
  projectId: string,
): Promise<OAuthCredentials> {
  const response = await fetch(TOKEN_URL, {
    method: "POST",
    headers: { "Content-Type": "application/x-www-form-urlencoded" },
    body: new URLSearchParams({
      client_id: CLIENT_ID,
      client_secret: CLIENT_SECRET,
      refresh_token: refreshToken,
      grant_type: "refresh_token",
    }),
  });

  if (!response.ok) {
    const error = await response.text();
    throw new Error(`Google Cloud token refresh failed: ${error}`);
  }

  const data = (await response.json()) as {
    access_token: string;
    expires_in: number;
    refresh_token?: string;
  };

  return {
    // Server may rotate the refresh token; otherwise keep the old one.
    refresh: data.refresh_token || refreshToken,
    access: data.access_token,
    // Expire 5 minutes early so callers refresh before the token dies.
    expires: Date.now() + data.expires_in * 1000 - 5 * 60 * 1000,
    projectId,
  };
}
|
||||
|
||||
/**
 * Login with Gemini CLI (Google Cloud Code Assist) OAuth
 *
 * Runs the full installed-app PKCE flow: starts a loopback callback server,
 * directs the user to Google's consent page, collects the authorization code
 * (from the browser redirect or a manually pasted URL), exchanges it for
 * tokens, then discovers/provisions a Cloud project.
 *
 * NOTE(review): the PKCE `verifier` doubles as the OAuth `state` parameter
 * (and is compared against the returned state for CSRF protection). Confirm
 * reusing the verifier as state is intentional rather than a separate nonce.
 *
 * @param onAuth - Callback with URL and optional instructions
 * @param onProgress - Optional progress callback
 * @param onManualCodeInput - Optional promise that resolves with user-pasted redirect URL.
 *                            Races with browser callback - whichever completes first wins.
 * @throws On state mismatch, missing authorization code, or token-exchange failure.
 */
export async function loginGeminiCli(
  onAuth: (info: { url: string; instructions?: string }) => void,
  onProgress?: (message: string) => void,
  onManualCodeInput?: () => Promise<string>,
): Promise<OAuthCredentials> {
  const { verifier, challenge } = await generatePKCE();

  // Start local server for callback
  onProgress?.("Starting local server for OAuth callback...");
  const server = await startCallbackServer();

  let code: string | undefined;

  try {
    // Build authorization URL
    const authParams = new URLSearchParams({
      client_id: CLIENT_ID,
      response_type: "code",
      redirect_uri: REDIRECT_URI,
      scope: SCOPES.join(" "),
      code_challenge: challenge,
      code_challenge_method: "S256",
      // Verifier reused as the CSRF state value (see NOTE above).
      state: verifier,
      // offline + consent ensures a refresh token is issued.
      access_type: "offline",
      prompt: "consent",
    });

    const authUrl = `${AUTH_URL}?${authParams.toString()}`;

    // Notify caller with URL to open
    onAuth({
      url: authUrl,
      instructions: "Complete the sign-in in your browser.",
    });

    // Wait for the callback, racing with manual input if provided
    onProgress?.("Waiting for OAuth callback...");

    if (onManualCodeInput) {
      // Race between browser callback and manual input
      let manualInput: string | undefined;
      let manualError: Error | undefined;
      // When manual input settles (either way) it cancels the server's wait
      // so the race below can resolve.
      const manualPromise = onManualCodeInput()
        .then((input) => {
          manualInput = input;
          server.cancelWait();
        })
        .catch((err) => {
          manualError = err instanceof Error ? err : new Error(String(err));
          server.cancelWait();
        });

      const result = await server.waitForCode();

      // If manual input was cancelled, throw that error
      if (manualError) {
        throw manualError;
      }

      if (result?.code) {
        // Browser callback won - verify state
        if (result.state !== verifier) {
          throw new Error("OAuth state mismatch - possible CSRF attack");
        }
        code = result.code;
      } else if (manualInput) {
        // Manual input won
        const parsed = parseRedirectUrl(manualInput);
        // Pasted URLs may omit state; only verify it when present.
        if (parsed.state && parsed.state !== verifier) {
          throw new Error("OAuth state mismatch - possible CSRF attack");
        }
        code = parsed.code;
      }

      // If still no code, wait for manual promise and try that
      if (!code) {
        await manualPromise;
        if (manualError) {
          throw manualError;
        }
        if (manualInput) {
          const parsed = parseRedirectUrl(manualInput);
          if (parsed.state && parsed.state !== verifier) {
            throw new Error("OAuth state mismatch - possible CSRF attack");
          }
          code = parsed.code;
        }
      }
    } else {
      // Original flow: just wait for callback
      const result = await server.waitForCode();
      if (result?.code) {
        if (result.state !== verifier) {
          throw new Error("OAuth state mismatch - possible CSRF attack");
        }
        code = result.code;
      }
    }

    if (!code) {
      throw new Error("No authorization code received");
    }

    // Exchange code for tokens
    onProgress?.("Exchanging authorization code for tokens...");
    const tokenResponse = await fetch(TOKEN_URL, {
      method: "POST",
      headers: {
        "Content-Type": "application/x-www-form-urlencoded",
      },
      body: new URLSearchParams({
        client_id: CLIENT_ID,
        client_secret: CLIENT_SECRET,
        code,
        grant_type: "authorization_code",
        redirect_uri: REDIRECT_URI,
        code_verifier: verifier,
      }),
    });

    if (!tokenResponse.ok) {
      const error = await tokenResponse.text();
      throw new Error(`Token exchange failed: ${error}`);
    }

    const tokenData = (await tokenResponse.json()) as {
      access_token: string;
      refresh_token: string;
      expires_in: number;
    };

    if (!tokenData.refresh_token) {
      throw new Error("No refresh token received. Please try again.");
    }

    // Get user email
    onProgress?.("Getting user info...");
    const email = await getUserEmail(tokenData.access_token);

    // Discover project
    const projectId = await discoverProject(tokenData.access_token, onProgress);

    // Calculate expiry time (current time + expires_in seconds - 5 min buffer)
    const expiresAt = Date.now() + tokenData.expires_in * 1000 - 5 * 60 * 1000;

    const credentials: OAuthCredentials = {
      refresh: tokenData.refresh_token,
      access: tokenData.access_token,
      expires: expiresAt,
      projectId,
      email,
    };

    return credentials;
  } finally {
    // Always tear down the loopback server, even on error paths.
    server.server.close();
  }
}
|
||||
|
||||
/**
 * OAuth provider descriptor for Google Cloud Code Assist (Gemini CLI).
 *
 * Adapts the module's login/refresh functions to the generic
 * `OAuthProviderInterface` used by the provider registry.
 */
export const geminiCliOAuthProvider: OAuthProviderInterface = {
  id: "google-gemini-cli",
  name: "Google Cloud Code Assist (Gemini CLI)",
  // Login runs a loopback HTTP server on port 8085 for the redirect.
  usesCallbackServer: true,

  async login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials> {
    return loginGeminiCli(
      callbacks.onAuth,
      callbacks.onProgress,
      callbacks.onManualCodeInput,
    );
  },

  // Refresh requires the Cloud project id stored alongside the tokens.
  async refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials> {
    const creds = credentials as GeminiCredentials;
    if (!creds.projectId) {
      throw new Error("Google Cloud credentials missing projectId");
    }
    return refreshGoogleCloudToken(creds.refresh, creds.projectId);
  },

  // "API key" is a JSON envelope of token + project id, since the Cloud
  // Code API needs both on every request.
  getApiKey(credentials: OAuthCredentials): string {
    const creds = credentials as GeminiCredentials;
    return JSON.stringify({ token: creds.access, projectId: creds.projectId });
  },
};
|
||||
187
packages/ai/src/utils/oauth/index.ts
Normal file
187
packages/ai/src/utils/oauth/index.ts
Normal file
|
|
@ -0,0 +1,187 @@
|
|||
/**
|
||||
* OAuth credential management for AI providers.
|
||||
*
|
||||
* This module handles login, token refresh, and credential storage
|
||||
* for OAuth-based providers:
|
||||
* - Anthropic (Claude Pro/Max)
|
||||
* - GitHub Copilot
|
||||
* - Google Cloud Code Assist (Gemini CLI)
|
||||
* - Antigravity (Gemini 3, Claude, GPT-OSS via Google Cloud)
|
||||
*/
|
||||
|
||||
// Anthropic
|
||||
export {
|
||||
anthropicOAuthProvider,
|
||||
loginAnthropic,
|
||||
refreshAnthropicToken,
|
||||
} from "./anthropic.js";
|
||||
// GitHub Copilot
|
||||
export {
|
||||
getGitHubCopilotBaseUrl,
|
||||
githubCopilotOAuthProvider,
|
||||
loginGitHubCopilot,
|
||||
normalizeDomain,
|
||||
refreshGitHubCopilotToken,
|
||||
} from "./github-copilot.js";
|
||||
// Google Antigravity
|
||||
export {
|
||||
antigravityOAuthProvider,
|
||||
loginAntigravity,
|
||||
refreshAntigravityToken,
|
||||
} from "./google-antigravity.js";
|
||||
// Google Gemini CLI
|
||||
export {
|
||||
geminiCliOAuthProvider,
|
||||
loginGeminiCli,
|
||||
refreshGoogleCloudToken,
|
||||
} from "./google-gemini-cli.js";
|
||||
// OpenAI Codex (ChatGPT OAuth)
|
||||
export {
|
||||
loginOpenAICodex,
|
||||
openaiCodexOAuthProvider,
|
||||
refreshOpenAICodexToken,
|
||||
} from "./openai-codex.js";
|
||||
|
||||
export * from "./types.js";
|
||||
|
||||
// ============================================================================
|
||||
// Provider Registry
|
||||
// ============================================================================
|
||||
|
||||
import { anthropicOAuthProvider } from "./anthropic.js";
|
||||
import { githubCopilotOAuthProvider } from "./github-copilot.js";
|
||||
import { antigravityOAuthProvider } from "./google-antigravity.js";
|
||||
import { geminiCliOAuthProvider } from "./google-gemini-cli.js";
|
||||
import { openaiCodexOAuthProvider } from "./openai-codex.js";
|
||||
import type {
|
||||
OAuthCredentials,
|
||||
OAuthProviderId,
|
||||
OAuthProviderInfo,
|
||||
OAuthProviderInterface,
|
||||
} from "./types.js";
|
||||
|
||||
// Canonical list of built-in OAuth providers; used both to seed the registry
// and to restore built-ins in unregister/reset.
const BUILT_IN_OAUTH_PROVIDERS: OAuthProviderInterface[] = [
  anthropicOAuthProvider,
  githubCopilotOAuthProvider,
  geminiCliOAuthProvider,
  antigravityOAuthProvider,
  openaiCodexOAuthProvider,
];

// Mutable registry keyed by provider id; custom providers may be added or
// may shadow built-ins via registerOAuthProvider().
const oauthProviderRegistry = new Map<string, OAuthProviderInterface>(
  BUILT_IN_OAUTH_PROVIDERS.map((provider) => [provider.id, provider]),
);
|
||||
|
||||
/**
 * Get an OAuth provider by ID
 *
 * @returns The registered provider, or undefined when no provider has that id.
 */
export function getOAuthProvider(
  id: OAuthProviderId,
): OAuthProviderInterface | undefined {
  return oauthProviderRegistry.get(id);
}
|
||||
|
||||
/**
 * Register a custom OAuth provider
 *
 * Overwrites any existing provider (including a built-in) with the same id.
 */
export function registerOAuthProvider(provider: OAuthProviderInterface): void {
  oauthProviderRegistry.set(provider.id, provider);
}
|
||||
|
||||
/**
|
||||
* Unregister an OAuth provider.
|
||||
*
|
||||
* If the provider is built-in, restores the built-in implementation.
|
||||
* Custom providers are removed completely.
|
||||
*/
|
||||
export function unregisterOAuthProvider(id: string): void {
|
||||
const builtInProvider = BUILT_IN_OAUTH_PROVIDERS.find(
|
||||
(provider) => provider.id === id,
|
||||
);
|
||||
if (builtInProvider) {
|
||||
oauthProviderRegistry.set(id, builtInProvider);
|
||||
return;
|
||||
}
|
||||
oauthProviderRegistry.delete(id);
|
||||
}
|
||||
|
||||
/**
|
||||
* Reset OAuth providers to built-ins.
|
||||
*/
|
||||
export function resetOAuthProviders(): void {
|
||||
oauthProviderRegistry.clear();
|
||||
for (const provider of BUILT_IN_OAUTH_PROVIDERS) {
|
||||
oauthProviderRegistry.set(provider.id, provider);
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Get all registered OAuth providers
 *
 * @returns A fresh array snapshot of the registry's current values.
 */
export function getOAuthProviders(): OAuthProviderInterface[] {
  return Array.from(oauthProviderRegistry.values());
}
|
||||
|
||||
/**
|
||||
* @deprecated Use getOAuthProviders() which returns OAuthProviderInterface[]
|
||||
*/
|
||||
export function getOAuthProviderInfoList(): OAuthProviderInfo[] {
|
||||
return getOAuthProviders().map((p) => ({
|
||||
id: p.id,
|
||||
name: p.name,
|
||||
available: true,
|
||||
}));
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// High-level API (uses provider registry)
|
||||
// ============================================================================
|
||||
|
||||
/**
 * Refresh token for any OAuth provider.
 * @deprecated Use getOAuthProvider(id).refreshToken() instead
 *
 * @throws When no provider is registered under `providerId`, or when the
 *         provider's own refresh fails.
 */
export async function refreshOAuthToken(
  providerId: OAuthProviderId,
  credentials: OAuthCredentials,
): Promise<OAuthCredentials> {
  const provider = getOAuthProvider(providerId);
  if (!provider) {
    throw new Error(`Unknown OAuth provider: ${providerId}`);
  }
  return provider.refreshToken(credentials);
}
|
||||
|
||||
/**
|
||||
* Get API key for a provider from OAuth credentials.
|
||||
* Automatically refreshes expired tokens.
|
||||
*
|
||||
* @returns API key string and updated credentials, or null if no credentials
|
||||
* @throws Error if refresh fails
|
||||
*/
|
||||
export async function getOAuthApiKey(
|
||||
providerId: OAuthProviderId,
|
||||
credentials: Record<string, OAuthCredentials>,
|
||||
): Promise<{ newCredentials: OAuthCredentials; apiKey: string } | null> {
|
||||
const provider = getOAuthProvider(providerId);
|
||||
if (!provider) {
|
||||
throw new Error(`Unknown OAuth provider: ${providerId}`);
|
||||
}
|
||||
|
||||
let creds = credentials[providerId];
|
||||
if (!creds) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Refresh if expired
|
||||
if (Date.now() >= creds.expires) {
|
||||
try {
|
||||
creds = await provider.refreshToken(creds);
|
||||
} catch (_error) {
|
||||
throw new Error(`Failed to refresh OAuth token for ${providerId}`);
|
||||
}
|
||||
}
|
||||
|
||||
const apiKey = provider.getApiKey(creds);
|
||||
return { newCredentials: creds, apiKey };
|
||||
}
|
||||
499
packages/ai/src/utils/oauth/openai-codex.ts
Normal file
499
packages/ai/src/utils/oauth/openai-codex.ts
Normal file
|
|
@ -0,0 +1,499 @@
|
|||
/**
|
||||
* OpenAI Codex (ChatGPT OAuth) flow
|
||||
*
|
||||
* NOTE: This module uses Node.js crypto and http for the OAuth callback.
|
||||
* It is only intended for CLI use, not browser environments.
|
||||
*/
|
||||
|
||||
// NEVER convert to top-level imports - breaks browser/Vite builds (web-ui)
// Lazily loaded Node built-ins; null until the dynamic imports settle.
let _randomBytes: typeof import("node:crypto").randomBytes | null = null;
let _http: typeof import("node:http") | null = null;
if (
  typeof process !== "undefined" &&
  (process.versions?.node || process.versions?.bun)
) {
  // NOTE(review): unlike google-gemini-cli.ts, these import promises are not
  // stored/awaited, so a caller hitting createState() immediately at startup
  // could race the import and see _randomBytes still null — confirm intended.
  import("node:crypto").then((m) => {
    _randomBytes = m.randomBytes;
  });
  import("node:http").then((m) => {
    _http = m;
  });
}
|
||||
|
||||
import { generatePKCE } from "./pkce.js";
|
||||
import type {
|
||||
OAuthCredentials,
|
||||
OAuthLoginCallbacks,
|
||||
OAuthPrompt,
|
||||
OAuthProviderInterface,
|
||||
} from "./types.js";
|
||||
|
||||
// Public OAuth client id for the ChatGPT/Codex installed-app flow.
const CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann";
const AUTHORIZE_URL = "https://auth.openai.com/oauth/authorize";
const TOKEN_URL = "https://auth.openai.com/oauth/token";
// Loopback redirect handled by a local callback server on port 1455.
const REDIRECT_URI = "http://localhost:1455/auth/callback";
// offline_access is required so a refresh token is issued.
const SCOPE = "openid profile email offline_access";
// Namespaced custom-claim key inside the issued JWT's payload.
const JWT_CLAIM_PATH = "https://api.openai.com/auth";
|
||||
|
||||
// Minimal HTML page served to the browser after a successful callback,
// telling the user to return to the terminal.
const SUCCESS_HTML = `<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<title>Authentication successful</title>
</head>
<body>
<p>Authentication successful. Return to your terminal to continue.</p>
</body>
</html>`;
|
||||
|
||||
// Discriminated result of a token-endpoint exchange.
type TokenSuccess = {
  type: "success";
  access: string;
  refresh: string;
  // Absolute epoch-ms expiry for the access token.
  expires: number;
};
type TokenFailure = { type: "failed" };
type TokenResult = TokenSuccess | TokenFailure;

// Shape of the decoded JWT payload; OpenAI nests account info under the
// namespaced claim key, everything else is passed through untyped.
type JwtPayload = {
  [JWT_CLAIM_PATH]?: {
    chatgpt_account_id?: string;
  };
  [key: string]: unknown;
};
|
||||
|
||||
/**
 * Generate a random 32-hex-char OAuth state value.
 *
 * Relies on the lazily imported node:crypto randomBytes; throws when the
 * module is unavailable (non-Node environment, or import not yet settled).
 */
function createState(): string {
  if (!_randomBytes) {
    throw new Error(
      "OpenAI Codex OAuth is only available in Node.js environments",
    );
  }
  return _randomBytes(16).toString("hex");
}
|
||||
|
||||
function parseAuthorizationInput(input: string): {
|
||||
code?: string;
|
||||
state?: string;
|
||||
} {
|
||||
const value = input.trim();
|
||||
if (!value) return {};
|
||||
|
||||
try {
|
||||
const url = new URL(value);
|
||||
return {
|
||||
code: url.searchParams.get("code") ?? undefined,
|
||||
state: url.searchParams.get("state") ?? undefined,
|
||||
};
|
||||
} catch {
|
||||
// not a URL
|
||||
}
|
||||
|
||||
if (value.includes("#")) {
|
||||
const [code, state] = value.split("#", 2);
|
||||
return { code, state };
|
||||
}
|
||||
|
||||
if (value.includes("code=")) {
|
||||
const params = new URLSearchParams(value);
|
||||
return {
|
||||
code: params.get("code") ?? undefined,
|
||||
state: params.get("state") ?? undefined,
|
||||
};
|
||||
}
|
||||
|
||||
return { code: value };
|
||||
}
|
||||
|
||||
function decodeJwt(token: string): JwtPayload | null {
|
||||
try {
|
||||
const parts = token.split(".");
|
||||
if (parts.length !== 3) return null;
|
||||
const payload = parts[1] ?? "";
|
||||
const decoded = atob(payload);
|
||||
return JSON.parse(decoded) as JwtPayload;
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
async function exchangeAuthorizationCode(
|
||||
code: string,
|
||||
verifier: string,
|
||||
redirectUri: string = REDIRECT_URI,
|
||||
): Promise<TokenResult> {
|
||||
const response = await fetch(TOKEN_URL, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/x-www-form-urlencoded" },
|
||||
body: new URLSearchParams({
|
||||
grant_type: "authorization_code",
|
||||
client_id: CLIENT_ID,
|
||||
code,
|
||||
code_verifier: verifier,
|
||||
redirect_uri: redirectUri,
|
||||
}),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const text = await response.text().catch(() => "");
|
||||
console.error("[openai-codex] code->token failed:", response.status, text);
|
||||
return { type: "failed" };
|
||||
}
|
||||
|
||||
const json = (await response.json()) as {
|
||||
access_token?: string;
|
||||
refresh_token?: string;
|
||||
expires_in?: number;
|
||||
};
|
||||
|
||||
if (
|
||||
!json.access_token ||
|
||||
!json.refresh_token ||
|
||||
typeof json.expires_in !== "number"
|
||||
) {
|
||||
console.error("[openai-codex] token response missing fields:", json);
|
||||
return { type: "failed" };
|
||||
}
|
||||
|
||||
return {
|
||||
type: "success",
|
||||
access: json.access_token,
|
||||
refresh: json.refresh_token,
|
||||
expires: Date.now() + json.expires_in * 1000,
|
||||
};
|
||||
}
|
||||
|
||||
async function refreshAccessToken(refreshToken: string): Promise<TokenResult> {
|
||||
try {
|
||||
const response = await fetch(TOKEN_URL, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/x-www-form-urlencoded" },
|
||||
body: new URLSearchParams({
|
||||
grant_type: "refresh_token",
|
||||
refresh_token: refreshToken,
|
||||
client_id: CLIENT_ID,
|
||||
}),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const text = await response.text().catch(() => "");
|
||||
console.error(
|
||||
"[openai-codex] Token refresh failed:",
|
||||
response.status,
|
||||
text,
|
||||
);
|
||||
return { type: "failed" };
|
||||
}
|
||||
|
||||
const json = (await response.json()) as {
|
||||
access_token?: string;
|
||||
refresh_token?: string;
|
||||
expires_in?: number;
|
||||
};
|
||||
|
||||
if (
|
||||
!json.access_token ||
|
||||
!json.refresh_token ||
|
||||
typeof json.expires_in !== "number"
|
||||
) {
|
||||
console.error(
|
||||
"[openai-codex] Token refresh response missing fields:",
|
||||
json,
|
||||
);
|
||||
return { type: "failed" };
|
||||
}
|
||||
|
||||
return {
|
||||
type: "success",
|
||||
access: json.access_token,
|
||||
refresh: json.refresh_token,
|
||||
expires: Date.now() + json.expires_in * 1000,
|
||||
};
|
||||
} catch (error) {
|
||||
console.error("[openai-codex] Token refresh error:", error);
|
||||
return { type: "failed" };
|
||||
}
|
||||
}
|
||||
|
||||
async function createAuthorizationFlow(
|
||||
originator: string = "pi",
|
||||
): Promise<{ verifier: string; state: string; url: string }> {
|
||||
const { verifier, challenge } = await generatePKCE();
|
||||
const state = createState();
|
||||
|
||||
const url = new URL(AUTHORIZE_URL);
|
||||
url.searchParams.set("response_type", "code");
|
||||
url.searchParams.set("client_id", CLIENT_ID);
|
||||
url.searchParams.set("redirect_uri", REDIRECT_URI);
|
||||
url.searchParams.set("scope", SCOPE);
|
||||
url.searchParams.set("code_challenge", challenge);
|
||||
url.searchParams.set("code_challenge_method", "S256");
|
||||
url.searchParams.set("state", state);
|
||||
url.searchParams.set("id_token_add_organizations", "true");
|
||||
url.searchParams.set("codex_cli_simplified_flow", "true");
|
||||
url.searchParams.set("originator", originator);
|
||||
|
||||
return { verifier, state, url: url.toString() };
|
||||
}
|
||||
|
||||
// Handle onto the local loopback callback server.
type OAuthServerInfo = {
  close: () => void; // shut the HTTP server down
  cancelWait: () => void; // abort a pending waitForCode() poll loop
  // Resolves with the received code, or null on timeout/cancellation.
  waitForCode: () => Promise<{ code: string } | null>;
};
|
||||
|
||||
/**
 * Start a loopback HTTP server on 127.0.0.1:1455 that receives the OAuth
 * redirect at /auth/callback.
 *
 * Always resolves: if the port cannot be bound (e.g. already in use) it
 * resolves with a stub handle whose waitForCode() immediately yields null,
 * so the caller falls back to manual code entry.
 *
 * @param state expected `state` query value; mismatching callbacks are
 *              rejected with 400 (CSRF protection).
 * @throws synchronously when node:http has not been loaded (non-Node runtime).
 */
function startLocalOAuthServer(state: string): Promise<OAuthServerInfo> {
  if (!_http) {
    throw new Error(
      "OpenAI Codex OAuth is only available in Node.js environments",
    );
  }
  // Written by the request handler once a valid callback arrives.
  let lastCode: string | null = null;
  // Flipped by cancelWait() to stop the polling loop early.
  let cancelled = false;
  const server = _http.createServer((req, res) => {
    try {
      // req.url is path+query only; supply a dummy origin for URL parsing.
      const url = new URL(req.url || "", "http://localhost");
      if (url.pathname !== "/auth/callback") {
        res.statusCode = 404;
        res.end("Not found");
        return;
      }
      if (url.searchParams.get("state") !== state) {
        res.statusCode = 400;
        res.end("State mismatch");
        return;
      }
      const code = url.searchParams.get("code");
      if (!code) {
        res.statusCode = 400;
        res.end("Missing authorization code");
        return;
      }
      res.statusCode = 200;
      res.setHeader("Content-Type", "text/html; charset=utf-8");
      res.end(SUCCESS_HTML);
      lastCode = code;
    } catch {
      res.statusCode = 500;
      res.end("Internal error");
    }
  });

  return new Promise((resolve) => {
    server
      .listen(1455, "127.0.0.1", () => {
        resolve({
          close: () => server.close(),
          cancelWait: () => {
            cancelled = true;
          },
          // Poll every 100ms, up to 600 iterations (~60s), until a code
          // arrives, the wait is cancelled, or we time out.
          waitForCode: async () => {
            const sleep = () => new Promise((r) => setTimeout(r, 100));
            for (let i = 0; i < 600; i += 1) {
              if (lastCode) return { code: lastCode };
              if (cancelled) return null;
              await sleep();
            }
            return null;
          },
        });
      })
      .on("error", (err: NodeJS.ErrnoException) => {
        console.error(
          "[openai-codex] Failed to bind http://127.0.0.1:1455 (",
          err.code,
          ") Falling back to manual paste.",
        );
        // Bind failed: resolve with a no-op handle so login can continue
        // via the manual code-paste path.
        resolve({
          close: () => {
            try {
              server.close();
            } catch {
              // ignore
            }
          },
          cancelWait: () => {},
          waitForCode: async () => null,
        });
      });
  });
}
|
||||
|
||||
function getAccountId(accessToken: string): string | null {
|
||||
const payload = decodeJwt(accessToken);
|
||||
const auth = payload?.[JWT_CLAIM_PATH];
|
||||
const accountId = auth?.chatgpt_account_id;
|
||||
return typeof accountId === "string" && accountId.length > 0
|
||||
? accountId
|
||||
: null;
|
||||
}
|
||||
|
||||
/**
 * Login with OpenAI Codex OAuth
 *
 * Flow: start a loopback callback server, hand the authorization URL to the
 * caller, then race the browser redirect against optional manual code entry;
 * finally exchange the winning code (plus the PKCE verifier) for tokens.
 *
 * @param options.onAuth - Called with URL and instructions when auth starts
 * @param options.onPrompt - Called to prompt user for manual code paste (fallback if no onManualCodeInput)
 * @param options.onProgress - Optional progress messages
 * @param options.onManualCodeInput - Optional promise that resolves with user-pasted code.
 *   Races with browser callback - whichever completes first wins.
 *   Useful for showing paste input immediately alongside browser flow.
 * @param options.originator - OAuth originator parameter (defaults to "pi")
 * @throws on state mismatch, missing code, failed token exchange, or when
 *         the access token carries no ChatGPT account id.
 */
export async function loginOpenAICodex(options: {
  onAuth: (info: { url: string; instructions?: string }) => void;
  onPrompt: (prompt: OAuthPrompt) => Promise<string>;
  onProgress?: (message: string) => void;
  onManualCodeInput?: () => Promise<string>;
  originator?: string;
}): Promise<OAuthCredentials> {
  const { verifier, state, url } = await createAuthorizationFlow(
    options.originator,
  );
  const server = await startLocalOAuthServer(state);

  options.onAuth({
    url,
    instructions: "A browser window should open. Complete login to finish.",
  });

  let code: string | undefined;
  try {
    if (options.onManualCodeInput) {
      // Race between browser callback and manual input
      let manualCode: string | undefined;
      let manualError: Error | undefined;
      // Whichever way manual entry settles (value or cancellation), stop the
      // server's polling so waitForCode() returns promptly.
      const manualPromise = options
        .onManualCodeInput()
        .then((input) => {
          manualCode = input;
          server.cancelWait();
        })
        .catch((err) => {
          manualError = err instanceof Error ? err : new Error(String(err));
          server.cancelWait();
        });

      const result = await server.waitForCode();

      // If manual input was cancelled, throw that error
      if (manualError) {
        throw manualError;
      }

      if (result?.code) {
        // Browser callback won
        code = result.code;
      } else if (manualCode) {
        // Manual input won (or callback timed out and user had entered code)
        const parsed = parseAuthorizationInput(manualCode);
        if (parsed.state && parsed.state !== state) {
          throw new Error("State mismatch");
        }
        code = parsed.code;
      }

      // If still no code, wait for manual promise to complete and try that
      if (!code) {
        await manualPromise;
        if (manualError) {
          throw manualError;
        }
        if (manualCode) {
          const parsed = parseAuthorizationInput(manualCode);
          if (parsed.state && parsed.state !== state) {
            throw new Error("State mismatch");
          }
          code = parsed.code;
        }
      }
    } else {
      // Original flow: wait for callback, then prompt if needed
      const result = await server.waitForCode();
      if (result?.code) {
        code = result.code;
      }
    }

    // Fallback to onPrompt if still no code
    if (!code) {
      const input = await options.onPrompt({
        message: "Paste the authorization code (or full redirect URL):",
      });
      const parsed = parseAuthorizationInput(input);
      if (parsed.state && parsed.state !== state) {
        throw new Error("State mismatch");
      }
      code = parsed.code;
    }

    if (!code) {
      throw new Error("Missing authorization code");
    }

    const tokenResult = await exchangeAuthorizationCode(code, verifier);
    if (tokenResult.type !== "success") {
      throw new Error("Token exchange failed");
    }

    // The account id lives in a namespaced claim inside the access token JWT
    // and is required downstream, so its absence is a hard failure.
    const accountId = getAccountId(tokenResult.access);
    if (!accountId) {
      throw new Error("Failed to extract accountId from token");
    }

    return {
      access: tokenResult.access,
      refresh: tokenResult.refresh,
      expires: tokenResult.expires,
      accountId,
    };
  } finally {
    server.close();
  }
}
|
||||
|
||||
/**
|
||||
* Refresh OpenAI Codex OAuth token
|
||||
*/
|
||||
export async function refreshOpenAICodexToken(
|
||||
refreshToken: string,
|
||||
): Promise<OAuthCredentials> {
|
||||
const result = await refreshAccessToken(refreshToken);
|
||||
if (result.type !== "success") {
|
||||
throw new Error("Failed to refresh OpenAI Codex token");
|
||||
}
|
||||
|
||||
const accountId = getAccountId(result.access);
|
||||
if (!accountId) {
|
||||
throw new Error("Failed to extract accountId from token");
|
||||
}
|
||||
|
||||
return {
|
||||
access: result.access,
|
||||
refresh: result.refresh,
|
||||
expires: result.expires,
|
||||
accountId,
|
||||
};
|
||||
}
|
||||
|
||||
/**
 * OAuth provider registration for ChatGPT Plus/Pro (Codex) subscriptions.
 */
export const openaiCodexOAuthProvider: OAuthProviderInterface = {
  id: "openai-codex",
  name: "ChatGPT Plus/Pro (Codex Subscription)",
  // Login runs a local callback server and also supports manual code paste.
  usesCallbackServer: true,

  async login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials> {
    return loginOpenAICodex({
      onAuth: callbacks.onAuth,
      onPrompt: callbacks.onPrompt,
      onProgress: callbacks.onProgress,
      onManualCodeInput: callbacks.onManualCodeInput,
    });
  },

  async refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials> {
    return refreshOpenAICodexToken(credentials.refresh);
  },

  // The OAuth access token itself is used as the bearer API key.
  getApiKey(credentials: OAuthCredentials): string {
    return credentials.access;
  },
};
|
||||
37
packages/ai/src/utils/oauth/pkce.ts
Normal file
37
packages/ai/src/utils/oauth/pkce.ts
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
/**
|
||||
* PKCE utilities using Web Crypto API.
|
||||
* Works in both Node.js 20+ and browsers.
|
||||
*/
|
||||
|
||||
/**
|
||||
* Encode bytes as base64url string.
|
||||
*/
|
||||
function base64urlEncode(bytes: Uint8Array): string {
|
||||
let binary = "";
|
||||
for (const byte of bytes) {
|
||||
binary += String.fromCharCode(byte);
|
||||
}
|
||||
return btoa(binary).replace(/\+/g, "-").replace(/\//g, "_").replace(/=/g, "");
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate PKCE code verifier and challenge.
|
||||
* Uses Web Crypto API for cross-platform compatibility.
|
||||
*/
|
||||
export async function generatePKCE(): Promise<{
|
||||
verifier: string;
|
||||
challenge: string;
|
||||
}> {
|
||||
// Generate random verifier
|
||||
const verifierBytes = new Uint8Array(32);
|
||||
crypto.getRandomValues(verifierBytes);
|
||||
const verifier = base64urlEncode(verifierBytes);
|
||||
|
||||
// Compute SHA-256 challenge
|
||||
const encoder = new TextEncoder();
|
||||
const data = encoder.encode(verifier);
|
||||
const hashBuffer = await crypto.subtle.digest("SHA-256", data);
|
||||
const challenge = base64urlEncode(new Uint8Array(hashBuffer));
|
||||
|
||||
return { verifier, challenge };
|
||||
}
|
||||
62
packages/ai/src/utils/oauth/types.ts
Normal file
62
packages/ai/src/utils/oauth/types.ts
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
import type { Api, Model } from "../../types.js";

// Persisted OAuth token set. Provider-specific extras (e.g. an accountId)
// are permitted via the index signature.
export type OAuthCredentials = {
  refresh: string; // refresh token
  access: string; // current access token
  expires: number; // absolute expiry time, epoch milliseconds
  [key: string]: unknown;
};

export type OAuthProviderId = string;

/** @deprecated Use OAuthProviderId instead */
export type OAuthProvider = OAuthProviderId;

// A request for free-form user input during login (e.g. pasting a code).
export type OAuthPrompt = {
  message: string;
  placeholder?: string;
  allowEmpty?: boolean;
};

// Shown to the user when the browser-based authorization step starts.
export type OAuthAuthInfo = {
  url: string;
  instructions?: string;
};

// Hooks a host UI supplies to drive an interactive login flow.
export interface OAuthLoginCallbacks {
  onAuth: (info: OAuthAuthInfo) => void;
  onPrompt: (prompt: OAuthPrompt) => Promise<string>;
  onProgress?: (message: string) => void;
  // Optional: resolves with a manually pasted code; raced against the
  // local callback server when the provider supports both.
  onManualCodeInput?: () => Promise<string>;
  signal?: AbortSignal;
}

export interface OAuthProviderInterface {
  readonly id: OAuthProviderId;
  readonly name: string;

  /** Run the login flow, return credentials to persist */
  login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials>;

  /** Whether login uses a local callback server and supports manual code input. */
  usesCallbackServer?: boolean;

  /** Refresh expired credentials, return updated credentials to persist */
  refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials>;

  /** Convert credentials to API key string for the provider */
  getApiKey(credentials: OAuthCredentials): string;

  /** Optional: modify models for this provider (e.g., update baseUrl) */
  modifyModels?(
    models: Model<Api>[],
    credentials: OAuthCredentials,
  ): Model<Api>[];
}

/** @deprecated Use OAuthProviderInterface instead */
export interface OAuthProviderInfo {
  id: OAuthProviderId;
  name: string;
  available: boolean;
}
|
||||
127
packages/ai/src/utils/overflow.ts
Normal file
127
packages/ai/src/utils/overflow.ts
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
import type { AssistantMessage } from "../types.js";
|
||||
|
||||
/**
|
||||
* Regex patterns to detect context overflow errors from different providers.
|
||||
*
|
||||
* These patterns match error messages returned when the input exceeds
|
||||
* the model's context window.
|
||||
*
|
||||
* Provider-specific patterns (with example error messages):
|
||||
*
|
||||
* - Anthropic: "prompt is too long: 213462 tokens > 200000 maximum"
|
||||
* - OpenAI: "Your input exceeds the context window of this model"
|
||||
* - Google: "The input token count (1196265) exceeds the maximum number of tokens allowed (1048575)"
|
||||
* - xAI: "This model's maximum prompt length is 131072 but the request contains 537812 tokens"
|
||||
* - Groq: "Please reduce the length of the messages or completion"
|
||||
* - OpenRouter: "This endpoint's maximum context length is X tokens. However, you requested about Y tokens"
|
||||
* - llama.cpp: "the request exceeds the available context size, try increasing it"
|
||||
* - LM Studio: "tokens to keep from the initial prompt is greater than the context length"
|
||||
* - GitHub Copilot: "prompt token count of X exceeds the limit of Y"
|
||||
* - MiniMax: "invalid params, context window exceeds limit"
|
||||
* - Kimi For Coding: "Your request exceeded model token limit: X (requested: Y)"
|
||||
* - Cerebras: Returns "400/413 status code (no body)" - handled separately below
|
||||
* - Mistral: "Prompt contains X tokens ... too large for model with Y maximum context length"
|
||||
* - z.ai: Does NOT error, accepts overflow silently - handled via usage.input > contextWindow
|
||||
* - Ollama: Silently truncates input - not detectable via error message
|
||||
*/
|
||||
const OVERFLOW_PATTERNS = [
|
||||
/prompt is too long/i, // Anthropic
|
||||
/input is too long for requested model/i, // Amazon Bedrock
|
||||
/exceeds the context window/i, // OpenAI (Completions & Responses API)
|
||||
/input token count.*exceeds the maximum/i, // Google (Gemini)
|
||||
/maximum prompt length is \d+/i, // xAI (Grok)
|
||||
/reduce the length of the messages/i, // Groq
|
||||
/maximum context length is \d+ tokens/i, // OpenRouter (all backends)
|
||||
/exceeds the limit of \d+/i, // GitHub Copilot
|
||||
/exceeds the available context size/i, // llama.cpp server
|
||||
/greater than the context length/i, // LM Studio
|
||||
/context window exceeds limit/i, // MiniMax
|
||||
/exceeded model token limit/i, // Kimi For Coding
|
||||
/too large for model with \d+ maximum context length/i, // Mistral
|
||||
/context[_ ]length[_ ]exceeded/i, // Generic fallback
|
||||
/too many tokens/i, // Generic fallback
|
||||
/token limit exceeded/i, // Generic fallback
|
||||
];
|
||||
|
||||
/**
|
||||
* Check if an assistant message represents a context overflow error.
|
||||
*
|
||||
* This handles two cases:
|
||||
* 1. Error-based overflow: Most providers return stopReason "error" with a
|
||||
* specific error message pattern.
|
||||
* 2. Silent overflow: Some providers accept overflow requests and return
|
||||
* successfully. For these, we check if usage.input exceeds the context window.
|
||||
*
|
||||
* ## Reliability by Provider
|
||||
*
|
||||
* **Reliable detection (returns error with detectable message):**
|
||||
* - Anthropic: "prompt is too long: X tokens > Y maximum"
|
||||
* - OpenAI (Completions & Responses): "exceeds the context window"
|
||||
* - Google Gemini: "input token count exceeds the maximum"
|
||||
* - xAI (Grok): "maximum prompt length is X but request contains Y"
|
||||
* - Groq: "reduce the length of the messages"
|
||||
* - Cerebras: 400/413 status code (no body)
|
||||
* - Mistral: "Prompt contains X tokens ... too large for model with Y maximum context length"
|
||||
* - OpenRouter (all backends): "maximum context length is X tokens"
|
||||
* - llama.cpp: "exceeds the available context size"
|
||||
* - LM Studio: "greater than the context length"
|
||||
* - Kimi For Coding: "exceeded model token limit: X (requested: Y)"
|
||||
*
|
||||
* **Unreliable detection:**
|
||||
* - z.ai: Sometimes accepts overflow silently (detectable via usage.input > contextWindow),
|
||||
* sometimes returns rate limit errors. Pass contextWindow param to detect silent overflow.
|
||||
* - Ollama: Silently truncates input without error. Cannot be detected via this function.
|
||||
* The response will have usage.input < expected, but we don't know the expected value.
|
||||
*
|
||||
* ## Custom Providers
|
||||
*
|
||||
* If you've added custom models via settings.json, this function may not detect
|
||||
* overflow errors from those providers. To add support:
|
||||
*
|
||||
* 1. Send a request that exceeds the model's context window
|
||||
* 2. Check the errorMessage in the response
|
||||
* 3. Create a regex pattern that matches the error
|
||||
* 4. The pattern should be added to OVERFLOW_PATTERNS in this file, or
|
||||
* check the errorMessage yourself before calling this function
|
||||
*
|
||||
* @param message - The assistant message to check
|
||||
* @param contextWindow - Optional context window size for detecting silent overflow (z.ai)
|
||||
* @returns true if the message indicates a context overflow
|
||||
*/
|
||||
export function isContextOverflow(
|
||||
message: AssistantMessage,
|
||||
contextWindow?: number,
|
||||
): boolean {
|
||||
// Case 1: Check error message patterns
|
||||
if (message.stopReason === "error" && message.errorMessage) {
|
||||
// Check known patterns
|
||||
if (OVERFLOW_PATTERNS.some((p) => p.test(message.errorMessage!))) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Cerebras returns 400/413 with no body for context overflow
|
||||
// Note: 429 is rate limiting (requests/tokens per time), NOT context overflow
|
||||
if (
|
||||
/^4(00|13)\s*(status code)?\s*\(no body\)/i.test(message.errorMessage)
|
||||
) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// Case 2: Silent overflow (z.ai style) - successful but usage exceeds context
|
||||
if (contextWindow && message.stopReason === "stop") {
|
||||
const inputTokens = message.usage.input + message.usage.cacheRead;
|
||||
if (inputTokens > contextWindow) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the overflow patterns for testing purposes.
|
||||
*/
|
||||
export function getOverflowPatterns(): RegExp[] {
|
||||
return [...OVERFLOW_PATTERNS];
|
||||
}
|
||||
28
packages/ai/src/utils/sanitize-unicode.ts
Normal file
28
packages/ai/src/utils/sanitize-unicode.ts
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
/**
|
||||
* Removes unpaired Unicode surrogate characters from a string.
|
||||
*
|
||||
* Unpaired surrogates (high surrogates 0xD800-0xDBFF without matching low surrogates 0xDC00-0xDFFF,
|
||||
* or vice versa) cause JSON serialization errors in many API providers.
|
||||
*
|
||||
* Valid emoji and other characters outside the Basic Multilingual Plane use properly paired
|
||||
* surrogates and will NOT be affected by this function.
|
||||
*
|
||||
* @param text - The text to sanitize
|
||||
* @returns The sanitized text with unpaired surrogates removed
|
||||
*
|
||||
* @example
|
||||
* // Valid emoji (properly paired surrogates) are preserved
|
||||
* sanitizeSurrogates("Hello 🙈 World") // => "Hello 🙈 World"
|
||||
*
|
||||
* // Unpaired high surrogate is removed
|
||||
* const unpaired = String.fromCharCode(0xD83D); // high surrogate without low
|
||||
* sanitizeSurrogates(`Text ${unpaired} here`) // => "Text here"
|
||||
*/
|
||||
export function sanitizeSurrogates(text: string): string {
|
||||
// Replace unpaired high surrogates (0xD800-0xDBFF not followed by low surrogate)
|
||||
// Replace unpaired low surrogates (0xDC00-0xDFFF not preceded by high surrogate)
|
||||
return text.replace(
|
||||
/[\uD800-\uDBFF](?![\uDC00-\uDFFF])|(?<![\uD800-\uDBFF])[\uDC00-\uDFFF]/g,
|
||||
"",
|
||||
);
|
||||
}
|
||||
24
packages/ai/src/utils/typebox-helpers.ts
Normal file
24
packages/ai/src/utils/typebox-helpers.ts
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
import { type TUnsafe, Type } from "@sinclair/typebox";
|
||||
|
||||
/**
|
||||
* Creates a string enum schema compatible with Google's API and other providers
|
||||
* that don't support anyOf/const patterns.
|
||||
*
|
||||
* @example
|
||||
* const OperationSchema = StringEnum(["add", "subtract", "multiply", "divide"], {
|
||||
* description: "The operation to perform"
|
||||
* });
|
||||
*
|
||||
* type Operation = Static<typeof OperationSchema>; // "add" | "subtract" | "multiply" | "divide"
|
||||
*/
|
||||
export function StringEnum<T extends readonly string[]>(
|
||||
values: T,
|
||||
options?: { description?: string; default?: T[number] },
|
||||
): TUnsafe<T[number]> {
|
||||
return Type.Unsafe<T[number]>({
|
||||
type: "string",
|
||||
enum: values as any,
|
||||
...(options?.description && { description: options.description }),
|
||||
...(options?.default && { default: options.default }),
|
||||
});
|
||||
}
|
||||
88
packages/ai/src/utils/validation.ts
Normal file
88
packages/ai/src/utils/validation.ts
Normal file
|
|
@ -0,0 +1,88 @@
|
|||
import AjvModule from "ajv";
import addFormatsModule from "ajv-formats";

// Handle both default and named exports (CJS/ESM interop differs by bundler)
const Ajv = (AjvModule as any).default || AjvModule;
const addFormats = (addFormatsModule as any).default || addFormatsModule;

import type { Tool, ToolCall } from "../types.js";

// Detect if we're in a browser extension environment with strict CSP
// Chrome extensions with Manifest V3 don't allow eval/Function constructor
const isBrowserExtension =
  typeof globalThis !== "undefined" &&
  (globalThis as any).chrome?.runtime?.id !== undefined;

// Create a singleton AJV instance with formats (only if not in browser extension)
// AJV requires 'unsafe-eval' CSP which is not allowed in Manifest V3
let ajv: any = null;
if (!isBrowserExtension) {
  try {
    ajv = new Ajv({
      allErrors: true, // collect every violation, not just the first
      strict: false, // tolerate unknown keywords in tool schemas
      coerceTypes: true, // mutate/coerce argument types in place during validation
    });
    addFormats(ajv);
  } catch (_e) {
    // AJV initialization failed (likely CSP restriction)
    console.warn("AJV validation disabled due to CSP restrictions");
  }
}
|
||||
|
||||
/**
|
||||
* Finds a tool by name and validates the tool call arguments against its TypeBox schema
|
||||
* @param tools Array of tool definitions
|
||||
* @param toolCall The tool call from the LLM
|
||||
* @returns The validated arguments
|
||||
* @throws Error if tool is not found or validation fails
|
||||
*/
|
||||
export function validateToolCall(tools: Tool[], toolCall: ToolCall): any {
|
||||
const tool = tools.find((t) => t.name === toolCall.name);
|
||||
if (!tool) {
|
||||
throw new Error(`Tool "${toolCall.name}" not found`);
|
||||
}
|
||||
return validateToolArguments(tool, toolCall);
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates tool call arguments against the tool's TypeBox schema
|
||||
* @param tool The tool definition with TypeBox schema
|
||||
* @param toolCall The tool call from the LLM
|
||||
* @returns The validated (and potentially coerced) arguments
|
||||
* @throws Error with formatted message if validation fails
|
||||
*/
|
||||
export function validateToolArguments(tool: Tool, toolCall: ToolCall): any {
|
||||
// Skip validation in browser extension environment (CSP restrictions prevent AJV from working)
|
||||
if (!ajv || isBrowserExtension) {
|
||||
// Trust the LLM's output without validation
|
||||
// Browser extensions can't use AJV due to Manifest V3 CSP restrictions
|
||||
return toolCall.arguments;
|
||||
}
|
||||
|
||||
// Compile the schema
|
||||
const validate = ajv.compile(tool.parameters);
|
||||
|
||||
// Clone arguments so AJV can safely mutate for type coercion
|
||||
const args = structuredClone(toolCall.arguments);
|
||||
|
||||
// Validate the arguments (AJV mutates args in-place for type coercion)
|
||||
if (validate(args)) {
|
||||
return args;
|
||||
}
|
||||
|
||||
// Format validation errors nicely
|
||||
const errors =
|
||||
validate.errors
|
||||
?.map((err: any) => {
|
||||
const path = err.instancePath
|
||||
? err.instancePath.substring(1)
|
||||
: err.params.missingProperty || "root";
|
||||
return ` - ${path}: ${err.message}`;
|
||||
})
|
||||
.join("\n") || "Unknown validation error";
|
||||
|
||||
const errorMessage = `Validation failed for tool "${toolCall.name}":\n${errors}\n\nReceived arguments:\n${JSON.stringify(toolCall.arguments, null, 2)}`;
|
||||
|
||||
throw new Error(errorMessage);
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue