move pi-mono into companion-cloud as apps/companion-os

- Copy all pi-mono source into apps/companion-os/
- Update Dockerfile to COPY pre-built binary instead of downloading from GitHub Releases
- Update deploy-staging.yml to build pi from source (bun compile) before Docker build
- Add apps/companion-os/** to path triggers
- No more cross-repo dispatch needed

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Harivansh Rathi 2026-03-07 09:22:50 -08:00
commit 0250f72976
579 changed files with 206942 additions and 0 deletions

View file

@ -0,0 +1,101 @@
import type {
Api,
AssistantMessageEventStream,
Context,
Model,
SimpleStreamOptions,
StreamFunction,
StreamOptions,
} from "./types.js";
// Stream function with the api type parameter erased so providers for
// different apis can be stored in a single registry.
export type ApiStreamFunction = (
  model: Model<Api>,
  context: Context,
  options?: StreamOptions,
) => AssistantMessageEventStream;

// Same erasure for the "simple" (normalized-options) stream variant.
export type ApiStreamSimpleFunction = (
  model: Model<Api>,
  context: Context,
  options?: SimpleStreamOptions,
) => AssistantMessageEventStream;

/** Public provider contract: a typed pair of stream functions for one api. */
export interface ApiProvider<
  TApi extends Api = Api,
  TOptions extends StreamOptions = StreamOptions,
> {
  api: TApi;
  stream: StreamFunction<TApi, TOptions>;
  streamSimple: StreamFunction<TApi, SimpleStreamOptions>;
}

// Internal, type-erased form actually kept in the registry.
interface ApiProviderInternal {
  api: Api;
  stream: ApiStreamFunction;
  streamSimple: ApiStreamSimpleFunction;
}

// Registry entry; sourceId lets a caller bulk-remove everything it
// registered (see unregisterApiProviders).
type RegisteredApiProvider = {
  provider: ApiProviderInternal;
  sourceId?: string;
};

// Keyed by api identifier; re-registering the same api overwrites the entry.
const apiProviderRegistry = new Map<string, RegisteredApiProvider>();
/**
 * Erase the api type parameter of a stream function while guarding at
 * runtime that the model handed in actually belongs to the wrapped api.
 */
function wrapStream<TApi extends Api, TOptions extends StreamOptions>(
  api: TApi,
  stream: StreamFunction<TApi, TOptions>,
): ApiStreamFunction {
  const guarded: ApiStreamFunction = (model, context, options) => {
    if (model.api === api) {
      return stream(model as Model<TApi>, context, options as TOptions);
    }
    throw new Error(`Mismatched api: ${model.api} expected ${api}`);
  };
  return guarded;
}
/**
 * Same runtime api guard as wrapStream, specialised to the simple-options
 * stream variant (options pass through unchanged).
 */
function wrapStreamSimple<TApi extends Api>(
  api: TApi,
  streamSimple: StreamFunction<TApi, SimpleStreamOptions>,
): ApiStreamSimpleFunction {
  return (model, context, options) => {
    const actual = model.api;
    if (actual !== api) {
      throw new Error(`Mismatched api: ${actual} expected ${api}`);
    }
    return streamSimple(model as Model<TApi>, context, options);
  };
}
/**
 * Register (or replace) the provider for `provider.api`, erasing its type
 * parameters via the wrap helpers. `sourceId` tags the entry so a plugin
 * can later remove everything it registered in one call.
 */
export function registerApiProvider<
  TApi extends Api,
  TOptions extends StreamOptions,
>(provider: ApiProvider<TApi, TOptions>, sourceId?: string): void {
  const erased = {
    api: provider.api,
    stream: wrapStream(provider.api, provider.stream),
    streamSimple: wrapStreamSimple(provider.api, provider.streamSimple),
  };
  apiProviderRegistry.set(provider.api, { provider: erased, sourceId });
}
/** Look up the registered provider for an api, if any. */
export function getApiProvider(api: Api): ApiProviderInternal | undefined {
  const entry = apiProviderRegistry.get(api);
  return entry ? entry.provider : undefined;
}

/** Snapshot of all registered providers, in registration order. */
export function getApiProviders(): ApiProviderInternal[] {
  const providers: ApiProviderInternal[] = [];
  for (const entry of apiProviderRegistry.values()) {
    providers.push(entry.provider);
  }
  return providers;
}

/** Remove every provider that was registered under the given sourceId. */
export function unregisterApiProviders(sourceId: string): void {
  const doomed: string[] = [];
  for (const [api, entry] of apiProviderRegistry.entries()) {
    if (entry.sourceId === sourceId) doomed.push(api);
  }
  for (const api of doomed) apiProviderRegistry.delete(api);
}

/** Drop every registered provider. */
export function clearApiProviders(): void {
  apiProviderRegistry.clear();
}

View file

@ -0,0 +1,9 @@
import {
streamBedrock,
streamSimpleBedrock,
} from "./providers/amazon-bedrock.js";
// Grouped export of the Bedrock stream functions.
// NOTE(review): presumably kept in its own module so the AWS SDK dependency
// is only loaded when a consumer opts in — confirm against the bundler setup.
export const bedrockProviderModule = {
  streamBedrock,
  streamSimpleBedrock,
};

152
packages/ai/src/cli.ts Normal file
View file

@ -0,0 +1,152 @@
#!/usr/bin/env node
import { existsSync, readFileSync, writeFileSync } from "fs";
import { createInterface } from "readline";
import { getOAuthProvider, getOAuthProviders } from "./utils/oauth/index.js";
import type { OAuthCredentials, OAuthProviderId } from "./utils/oauth/types.js";
const AUTH_FILE = "auth.json";
const PROVIDERS = getOAuthProviders();
/** Promise wrapper around readline's callback-style question(). */
function prompt(
  rl: ReturnType<typeof createInterface>,
  question: string,
): Promise<string> {
  return new Promise((resolve) => {
    rl.question(question, (answer) => resolve(answer));
  });
}
/**
 * Read stored OAuth credentials from AUTH_FILE.
 * Returns an empty record when the file is missing or unparseable.
 */
function loadAuth(): Record<string, { type: "oauth" } & OAuthCredentials> {
  if (!existsSync(AUTH_FILE)) {
    return {};
  }
  try {
    const raw = readFileSync(AUTH_FILE, "utf-8");
    return JSON.parse(raw);
  } catch {
    return {};
  }
}
/** Persist the credential map to AUTH_FILE as pretty-printed JSON. */
function saveAuth(
  auth: Record<string, { type: "oauth" } & OAuthCredentials>,
): void {
  const serialized = JSON.stringify(auth, null, 2);
  writeFileSync(AUTH_FILE, serialized, "utf-8");
}
/**
 * Run the interactive OAuth flow for `providerId` and persist the resulting
 * credentials into AUTH_FILE (merged with any existing entries).
 *
 * Exits the process with code 1 when the provider id is unknown.
 */
async function login(providerId: OAuthProviderId): Promise<void> {
  const provider = getOAuthProvider(providerId);
  if (!provider) {
    console.error(`Unknown provider: ${providerId}`);
    process.exit(1);
  }
  const rl = createInterface({ input: process.stdin, output: process.stdout });
  const promptFn = (msg: string) => prompt(rl, `${msg} `);
  try {
    const credentials = await provider.login({
      // Provider asks the user to open an authorization URL.
      onAuth: (info) => {
        console.log(`\nOpen this URL in your browser:\n${info.url}`);
        if (info.instructions) console.log(info.instructions);
        console.log();
      },
      // Provider needs a value typed by the user (e.g. a pasted code).
      onPrompt: async (p) => {
        return await promptFn(
          `${p.message}${p.placeholder ? ` (${p.placeholder})` : ""}:`,
        );
      },
      onProgress: (msg) => console.log(msg),
    });
    const auth = loadAuth();
    auth[providerId] = { type: "oauth", ...credentials };
    saveAuth(auth);
    console.log(`\nCredentials saved to ${AUTH_FILE}`);
  } finally {
    // Always release stdin, even if the login flow throws.
    rl.close();
  }
}
/**
 * CLI entry point.
 *
 * Commands:
 *   help/--help/-h  print usage (also the default with no args)
 *   list            print the available OAuth providers
 *   login [id]      run the OAuth flow, interactively when no id is given
 *
 * Exits with code 1 on an unknown command, unknown provider, or an
 * invalid interactive selection.
 */
async function main(): Promise<void> {
  const args = process.argv.slice(2);
  const command = args[0];
  if (
    !command ||
    command === "help" ||
    command === "--help" ||
    command === "-h"
  ) {
    const providerList = PROVIDERS.map(
      (p) => ` ${p.id.padEnd(20)} ${p.name}`,
    ).join("\n");
    console.log(`Usage: npx @mariozechner/pi-ai <command> [provider]
Commands:
login [provider] Login to an OAuth provider
list List available providers
Providers:
${providerList}
Examples:
npx @mariozechner/pi-ai login # interactive provider selection
npx @mariozechner/pi-ai login anthropic # login to specific provider
npx @mariozechner/pi-ai list # list providers
`);
    return;
  }
  if (command === "list") {
    console.log("Available OAuth providers:\n");
    for (const p of PROVIDERS) {
      console.log(` ${p.id.padEnd(20)} ${p.name}`);
    }
    return;
  }
  if (command === "login") {
    let provider = args[1] as OAuthProviderId | undefined;
    if (!provider) {
      const rl = createInterface({
        input: process.stdin,
        output: process.stdout,
      });
      console.log("Select a provider:\n");
      for (let i = 0; i < PROVIDERS.length; i++) {
        console.log(` ${i + 1}. ${PROVIDERS[i].name}`);
      }
      console.log();
      const choice = await prompt(rl, `Enter number (1-${PROVIDERS.length}): `);
      rl.close();
      const index = parseInt(choice, 10) - 1;
      // parseInt yields NaN for non-numeric input; NaN fails BOTH the < 0 and
      // the >= length comparison, so it must be rejected explicitly or
      // PROVIDERS[NaN] crashes below.
      if (Number.isNaN(index) || index < 0 || index >= PROVIDERS.length) {
        console.error("Invalid selection");
        process.exit(1);
      }
      provider = PROVIDERS[index].id;
    }
    if (!PROVIDERS.some((p) => p.id === provider)) {
      console.error(`Unknown provider: ${provider}`);
      console.error(
        `Use 'npx @mariozechner/pi-ai list' to see available providers`,
      );
      process.exit(1);
    }
    console.log(`Logging in to ${provider}...`);
    await login(provider);
    return;
  }
  console.error(`Unknown command: ${command}`);
  console.error(`Use 'npx @mariozechner/pi-ai --help' for usage`);
  process.exit(1);
}
// Top-level launcher: surface the error message and exit non-zero so
// shell scripts can detect failure.
main().catch((err) => {
  console.error("Error:", err.message);
  process.exit(1);
});

View file

@ -0,0 +1,145 @@
// NEVER convert to top-level imports - breaks browser/Vite builds (web-ui)
// The node builtins are resolved lazily through an indirection so bundlers
// cannot statically see (and fail on) the "node:*" specifiers.
let _existsSync: typeof import("node:fs").existsSync | null = null;
let _homedir: typeof import("node:os").homedir | null = null;
let _join: typeof import("node:path").join | null = null;
type DynamicImport = (specifier: string) => Promise<unknown>;
// Indirect call keeps the import() out of bundler static analysis.
const dynamicImport: DynamicImport = (specifier) => import(specifier);
// Specifiers are concatenated at runtime for the same reason.
const NODE_FS_SPECIFIER = "node:" + "fs";
const NODE_OS_SPECIFIER = "node:" + "os";
const NODE_PATH_SPECIFIER = "node:" + "path";
// Eagerly load in Node.js/Bun environment only.
// NOTE(review): these resolve asynchronously, so the lets above may still be
// null on the very first calls after startup; hasVertexAdcCredentials
// handles that by retrying without caching.
if (
  typeof process !== "undefined" &&
  (process.versions?.node || process.versions?.bun)
) {
  dynamicImport(NODE_FS_SPECIFIER).then((m) => {
    _existsSync = (m as typeof import("node:fs")).existsSync;
  });
  dynamicImport(NODE_OS_SPECIFIER).then((m) => {
    _homedir = (m as typeof import("node:os")).homedir;
  });
  dynamicImport(NODE_PATH_SPECIFIER).then((m) => {
    _join = (m as typeof import("node:path")).join;
  });
}
import type { KnownProvider } from "./types.js";
// Cached result of the ADC-file existence check; null = not yet determined.
let cachedVertexAdcCredentialsExists: boolean | null = null;
/**
 * Whether Google Application Default Credentials appear to exist on disk:
 * either the file named by GOOGLE_APPLICATION_CREDENTIALS or the gcloud
 * default ADC path (~/.config/gcloud/application_default_credentials.json).
 * The answer is cached once it can be determined definitively.
 */
function hasVertexAdcCredentials(): boolean {
  if (cachedVertexAdcCredentialsExists === null) {
    // If node modules haven't loaded yet (async import race at startup),
    // return false WITHOUT caching so the next call retries once they're ready.
    // Only cache false permanently in a browser environment where fs is never available.
    if (!_existsSync || !_homedir || !_join) {
      const isNode =
        typeof process !== "undefined" &&
        (process.versions?.node || process.versions?.bun);
      if (!isNode) {
        // Definitively in a browser — safe to cache false permanently
        cachedVertexAdcCredentialsExists = false;
      }
      return false;
    }
    // Check GOOGLE_APPLICATION_CREDENTIALS env var first (standard way)
    const gacPath = process.env.GOOGLE_APPLICATION_CREDENTIALS;
    if (gacPath) {
      cachedVertexAdcCredentialsExists = _existsSync(gacPath);
    } else {
      // Fall back to default ADC path (lazy evaluation)
      cachedVertexAdcCredentialsExists = _existsSync(
        _join(
          _homedir(),
          ".config",
          "gcloud",
          "application_default_credentials.json",
        ),
      );
    }
  }
  return cachedVertexAdcCredentialsExists;
}
/**
 * Get API key for provider from known environment variables, e.g. OPENAI_API_KEY.
 *
 * Will not return API keys for providers that require OAuth tokens.
 *
 * For google-vertex and amazon-bedrock, returns the sentinel string
 * "<authenticated>" when ambient credentials are detected (there is no
 * literal API key in those flows).
 */
export function getEnvApiKey(provider: KnownProvider): string | undefined;
export function getEnvApiKey(provider: string): string | undefined;
// Implementation signature typed as `string` (not `any`) so the body is
// fully type-checked; both overloads remain assignable to it.
export function getEnvApiKey(provider: string): string | undefined {
  // Fall back to environment variables
  if (provider === "github-copilot") {
    return (
      process.env.COPILOT_GITHUB_TOKEN ||
      process.env.GH_TOKEN ||
      process.env.GITHUB_TOKEN
    );
  }
  // ANTHROPIC_OAUTH_TOKEN takes precedence over ANTHROPIC_API_KEY
  if (provider === "anthropic") {
    return process.env.ANTHROPIC_OAUTH_TOKEN || process.env.ANTHROPIC_API_KEY;
  }
  // Vertex AI uses Application Default Credentials, not API keys.
  // Auth is configured via `gcloud auth application-default login`.
  // Falls through to the generic map (and thus undefined) when incomplete.
  if (provider === "google-vertex") {
    const hasCredentials = hasVertexAdcCredentials();
    const hasProject = !!(
      process.env.GOOGLE_CLOUD_PROJECT || process.env.GCLOUD_PROJECT
    );
    const hasLocation = !!process.env.GOOGLE_CLOUD_LOCATION;
    if (hasCredentials && hasProject && hasLocation) {
      return "<authenticated>";
    }
  }
  if (provider === "amazon-bedrock") {
    // Amazon Bedrock supports multiple credential sources:
    // 1. AWS_PROFILE - named profile from ~/.aws/credentials
    // 2. AWS_ACCESS_KEY_ID + AWS_SECRET_ACCESS_KEY - standard IAM keys
    // 3. AWS_BEARER_TOKEN_BEDROCK - Bedrock API keys (bearer token)
    // 4. AWS_CONTAINER_CREDENTIALS_RELATIVE_URI - ECS task roles
    // 5. AWS_CONTAINER_CREDENTIALS_FULL_URI - ECS task roles (full URI)
    // 6. AWS_WEB_IDENTITY_TOKEN_FILE - IRSA (IAM Roles for Service Accounts)
    if (
      process.env.AWS_PROFILE ||
      (process.env.AWS_ACCESS_KEY_ID && process.env.AWS_SECRET_ACCESS_KEY) ||
      process.env.AWS_BEARER_TOKEN_BEDROCK ||
      process.env.AWS_CONTAINER_CREDENTIALS_RELATIVE_URI ||
      process.env.AWS_CONTAINER_CREDENTIALS_FULL_URI ||
      process.env.AWS_WEB_IDENTITY_TOKEN_FILE
    ) {
      return "<authenticated>";
    }
  }
  // Simple one-variable providers.
  const envMap: Record<string, string> = {
    openai: "OPENAI_API_KEY",
    "azure-openai-responses": "AZURE_OPENAI_API_KEY",
    google: "GEMINI_API_KEY",
    groq: "GROQ_API_KEY",
    cerebras: "CEREBRAS_API_KEY",
    xai: "XAI_API_KEY",
    openrouter: "OPENROUTER_API_KEY",
    "vercel-ai-gateway": "AI_GATEWAY_API_KEY",
    zai: "ZAI_API_KEY",
    mistral: "MISTRAL_API_KEY",
    minimax: "MINIMAX_API_KEY",
    "minimax-cn": "MINIMAX_CN_API_KEY",
    huggingface: "HF_TOKEN",
    opencode: "OPENCODE_API_KEY",
    "opencode-go": "OPENCODE_API_KEY",
    "kimi-coding": "KIMI_API_KEY",
  };
  const envVar = envMap[provider];
  return envVar ? process.env[envVar] : undefined;
}

32
packages/ai/src/index.ts Normal file
View file

@ -0,0 +1,32 @@
// Public API surface of the package: registry, environment key lookup,
// model catalogue, built-in providers, streaming primitives, and utilities.
// Only OAuth *types* are re-exported here; the implementations live under
// ./utils/oauth/.
export type { Static, TSchema } from "@sinclair/typebox";
export { Type } from "@sinclair/typebox";
export * from "./api-registry.js";
export * from "./env-api-keys.js";
export * from "./models.js";
export * from "./providers/anthropic.js";
export * from "./providers/azure-openai-responses.js";
export * from "./providers/google.js";
export * from "./providers/google-gemini-cli.js";
export * from "./providers/google-vertex.js";
export * from "./providers/mistral.js";
export * from "./providers/openai-completions.js";
export * from "./providers/openai-responses.js";
export * from "./providers/register-builtins.js";
export * from "./stream.js";
export * from "./types.js";
export * from "./utils/event-stream.js";
export * from "./utils/json-parse.js";
export type {
  OAuthAuthInfo,
  OAuthCredentials,
  OAuthLoginCallbacks,
  OAuthPrompt,
  OAuthProvider,
  OAuthProviderId,
  OAuthProviderInfo,
  OAuthProviderInterface,
} from "./utils/oauth/types.js";
export * from "./utils/overflow.js";
export * from "./utils/typebox-helpers.js";
export * from "./utils/validation.js";

File diff suppressed because it is too large Load diff

101
packages/ai/src/models.ts Normal file
View file

@ -0,0 +1,101 @@
import { MODELS } from "./models.generated.js";
import type { Api, KnownProvider, Model, Usage } from "./types.js";
// provider id -> (model id -> model), built once from the generated catalogue.
const modelRegistry: Map<string, Map<string, Model<Api>>> = new Map();
// Initialize registry from MODELS on module load
for (const [provider, models] of Object.entries(MODELS)) {
  const providerModels = new Map<string, Model<Api>>();
  for (const [id, model] of Object.entries(models)) {
    providerModels.set(id, model as Model<Api>);
  }
  modelRegistry.set(provider, providerModels);
}
// Extracts the literal `api` type of a known provider/model entry; resolves
// to never when the entry lacks a recognisable `api` field.
type ModelApi<
  TProvider extends KnownProvider,
  TModelId extends keyof (typeof MODELS)[TProvider],
> = (typeof MODELS)[TProvider][TModelId] extends { api: infer TApi }
  ? TApi extends Api
    ? TApi
    : never
  : never;
/**
 * Look up a model by provider and id, preserving the literal api type of
 * the generated catalogue entry.
 *
 * NOTE(review): the cast hides that an unknown id returns undefined at
 * runtime even though the signature promises a Model — callers passing
 * dynamic ids should guard the result.
 */
export function getModel<
  TProvider extends KnownProvider,
  TModelId extends keyof (typeof MODELS)[TProvider],
>(
  provider: TProvider,
  modelId: TModelId,
): Model<ModelApi<TProvider, TModelId>> {
  const providerModels = modelRegistry.get(provider);
  return providerModels?.get(modelId as string) as Model<
    ModelApi<TProvider, TModelId>
  >;
}
/** All provider ids present in the model registry. */
export function getProviders(): KnownProvider[] {
  const keys = [...modelRegistry.keys()];
  return keys as KnownProvider[];
}
export function getModels<TProvider extends KnownProvider>(
provider: TProvider,
): Model<ModelApi<TProvider, keyof (typeof MODELS)[TProvider]>>[] {
const models = modelRegistry.get(provider);
return models
? (Array.from(models.values()) as Model<
ModelApi<TProvider, keyof (typeof MODELS)[TProvider]>
>[])
: [];
}
/**
 * Fill in usage.cost from the model's per-million-token prices.
 * Mutates usage.cost in place and returns it.
 */
export function calculateCost<TApi extends Api>(
  model: Model<TApi>,
  usage: Usage,
): Usage["cost"] {
  const cost = usage.cost;
  cost.input = (model.cost.input / 1000000) * usage.input;
  cost.output = (model.cost.output / 1000000) * usage.output;
  cost.cacheRead = (model.cost.cacheRead / 1000000) * usage.cacheRead;
  cost.cacheWrite = (model.cost.cacheWrite / 1000000) * usage.cacheWrite;
  cost.total = cost.input + cost.output + cost.cacheRead + cost.cacheWrite;
  return cost;
}
/**
 * Check if a model supports xhigh thinking level.
 *
 * Supported today:
 * - GPT-5.2 / GPT-5.3 / GPT-5.4 model families
 * - Anthropic Messages API Opus 4.6 models (xhigh maps to adaptive effort "max")
 */
export function supportsXhigh<TApi extends Api>(model: Model<TApi>): boolean {
  const id = model.id;
  const gptFamilies = ["gpt-5.2", "gpt-5.3", "gpt-5.4"];
  if (gptFamilies.some((family) => id.includes(family))) {
    return true;
  }
  if (model.api !== "anthropic-messages") {
    return false;
  }
  return id.includes("opus-4-6") || id.includes("opus-4.6");
}
/**
 * Check if two models are equal by comparing both their id and provider.
 * Returns false if either model is null or undefined.
 */
export function modelsAreEqual<TApi extends Api>(
  a: Model<TApi> | null | undefined,
  b: Model<TApi> | null | undefined,
): boolean {
  if (!a || !b) {
    return false;
  }
  const sameId = a.id === b.id;
  const sameProvider = a.provider === b.provider;
  return sameId && sameProvider;
}

1
packages/ai/src/oauth.ts Normal file
View file

@ -0,0 +1 @@
export * from "./utils/oauth/index.js";

View file

@ -0,0 +1,894 @@
import {
BedrockRuntimeClient,
type BedrockRuntimeClientConfig,
StopReason as BedrockStopReason,
type Tool as BedrockTool,
CachePointType,
CacheTTL,
type ContentBlock,
type ContentBlockDeltaEvent,
type ContentBlockStartEvent,
type ContentBlockStopEvent,
ConversationRole,
ConverseStreamCommand,
type ConverseStreamMetadataEvent,
ImageFormat,
type Message,
type SystemContentBlock,
type ToolChoice,
type ToolConfiguration,
ToolResultStatus,
} from "@aws-sdk/client-bedrock-runtime";
import { calculateCost } from "../models.js";
import type {
Api,
AssistantMessage,
CacheRetention,
Context,
Model,
SimpleStreamOptions,
StopReason,
StreamFunction,
StreamOptions,
TextContent,
ThinkingBudgets,
ThinkingContent,
ThinkingLevel,
Tool,
ToolCall,
ToolResultMessage,
} from "../types.js";
import { AssistantMessageEventStream } from "../utils/event-stream.js";
import { parseStreamingJson } from "../utils/json-parse.js";
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
import {
adjustMaxTokensForThinking,
buildBaseOptions,
clampReasoning,
} from "./simple-options.js";
import { transformMessages } from "./transform-messages.js";
/** Provider-specific stream options for Amazon Bedrock ConverseStream. */
export interface BedrockOptions extends StreamOptions {
  /* AWS region override; otherwise resolved from env/profile (see streamBedrock). */
  region?: string;
  /* Named profile from the shared AWS config/credentials files. */
  profile?: string;
  /* Tool selection strategy, mapped onto Converse's toolChoice. */
  toolChoice?: "auto" | "any" | "none" | { type: "tool"; name: string };
  /* See https://docs.aws.amazon.com/bedrock/latest/userguide/inference-reasoning.html for supported models. */
  reasoning?: ThinkingLevel;
  /* Custom token budgets per thinking level. Overrides default budgets. */
  thinkingBudgets?: ThinkingBudgets;
  /* Only supported by Claude 4.x models, see https://docs.aws.amazon.com/bedrock/latest/userguide/claude-messages-extended-thinking.html#claude-messages-extended-thinking-tool-use-interleaved */
  interleavedThinking?: boolean;
}
// In-flight content block: public content shape plus transient bookkeeping.
// `index` ties the block to Bedrock's contentBlockIndex and `partialJson`
// accumulates tool arguments; both are deleted before the message is final.
type Block = (TextContent | ThinkingContent | ToolCall) & {
  index?: number;
  partialJson?: string;
};
/**
 * Stream an assistant response from Amazon Bedrock's ConverseStream API.
 *
 * Returns the event stream synchronously; the request itself runs in a
 * detached async IIFE that mutates `output` in place per stream element and
 * finishes the stream with either a "done" or an "error" event.
 */
export const streamBedrock: StreamFunction<
  "bedrock-converse-stream",
  BedrockOptions
> = (
  model: Model<"bedrock-converse-stream">,
  context: Context,
  options: BedrockOptions = {},
): AssistantMessageEventStream => {
  const stream = new AssistantMessageEventStream();
  (async () => {
    // Accumulator for the final assistant message, updated as events arrive.
    const output: AssistantMessage = {
      role: "assistant",
      content: [],
      api: "bedrock-converse-stream" as Api,
      provider: model.provider,
      model: model.id,
      usage: {
        input: 0,
        output: 0,
        cacheRead: 0,
        cacheWrite: 0,
        totalTokens: 0,
        cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
      },
      stopReason: "stop",
      timestamp: Date.now(),
    };
    // Same array as output.content, viewed with streaming bookkeeping fields.
    const blocks = output.content as Block[];
    const config: BedrockRuntimeClientConfig = {
      profile: options.profile,
    };
    // in Node.js/Bun environment only
    if (
      typeof process !== "undefined" &&
      (process.versions?.node || process.versions?.bun)
    ) {
      // Region resolution: explicit option > env vars > SDK default chain.
      // When AWS_PROFILE is set, we leave region undefined so the SDK can
      // resolve it from aws profile configs. Otherwise fall back to us-east-1.
      const explicitRegion =
        options.region ||
        process.env.AWS_REGION ||
        process.env.AWS_DEFAULT_REGION;
      if (explicitRegion) {
        config.region = explicitRegion;
      } else if (!process.env.AWS_PROFILE) {
        config.region = "us-east-1";
      }
      // Support proxies that don't need authentication
      if (process.env.AWS_BEDROCK_SKIP_AUTH === "1") {
        config.credentials = {
          accessKeyId: "dummy-access-key",
          secretAccessKey: "dummy-secret-key",
        };
      }
      if (
        process.env.HTTP_PROXY ||
        process.env.HTTPS_PROXY ||
        process.env.NO_PROXY ||
        process.env.http_proxy ||
        process.env.https_proxy ||
        process.env.no_proxy
      ) {
        const nodeHttpHandler = await import("@smithy/node-http-handler");
        const proxyAgent = await import("proxy-agent");
        const agent = new proxyAgent.ProxyAgent();
        // Bedrock runtime uses NodeHttp2Handler by default since v3.798.0, which is based
        // on `http2` module and has no support for http agent.
        // Use NodeHttpHandler to support http agent.
        config.requestHandler = new nodeHttpHandler.NodeHttpHandler({
          httpAgent: agent,
          httpsAgent: agent,
        });
      } else if (process.env.AWS_BEDROCK_FORCE_HTTP1 === "1") {
        // Some custom endpoints require HTTP/1.1 instead of HTTP/2
        const nodeHttpHandler = await import("@smithy/node-http-handler");
        config.requestHandler = new nodeHttpHandler.NodeHttpHandler();
      }
    } else {
      // Non-Node environment (browser): fall back to us-east-1 since
      // there's no config file resolution available.
      config.region = options.region || "us-east-1";
    }
    try {
      const client = new BedrockRuntimeClient(config);
      const cacheRetention = resolveCacheRetention(options.cacheRetention);
      const commandInput = {
        modelId: model.id,
        messages: convertMessages(context, model, cacheRetention),
        system: buildSystemPrompt(context.systemPrompt, model, cacheRetention),
        inferenceConfig: {
          maxTokens: options.maxTokens,
          temperature: options.temperature,
        },
        toolConfig: convertToolConfig(context.tools, options.toolChoice),
        additionalModelRequestFields: buildAdditionalModelRequestFields(
          model,
          options,
        ),
      };
      // Let callers observe the exact request payload (debugging/logging).
      options?.onPayload?.(commandInput);
      const command = new ConverseStreamCommand(commandInput);
      const response = await client.send(command, {
        abortSignal: options.signal,
      });
      // Dispatch each stream element; the error variants are converted to
      // throws and handled uniformly by the catch below.
      for await (const item of response.stream!) {
        if (item.messageStart) {
          if (item.messageStart.role !== ConversationRole.ASSISTANT) {
            throw new Error(
              "Unexpected assistant message start but got user message start instead",
            );
          }
          stream.push({ type: "start", partial: output });
        } else if (item.contentBlockStart) {
          handleContentBlockStart(
            item.contentBlockStart,
            blocks,
            output,
            stream,
          );
        } else if (item.contentBlockDelta) {
          handleContentBlockDelta(
            item.contentBlockDelta,
            blocks,
            output,
            stream,
          );
        } else if (item.contentBlockStop) {
          handleContentBlockStop(item.contentBlockStop, blocks, output, stream);
        } else if (item.messageStop) {
          output.stopReason = mapStopReason(item.messageStop.stopReason);
        } else if (item.metadata) {
          handleMetadata(item.metadata, model, output);
        } else if (item.internalServerException) {
          throw new Error(
            `Internal server error: ${item.internalServerException.message}`,
          );
        } else if (item.modelStreamErrorException) {
          throw new Error(
            `Model stream error: ${item.modelStreamErrorException.message}`,
          );
        } else if (item.validationException) {
          throw new Error(
            `Validation error: ${item.validationException.message}`,
          );
        } else if (item.throttlingException) {
          throw new Error(
            `Throttling error: ${item.throttlingException.message}`,
          );
        } else if (item.serviceUnavailableException) {
          throw new Error(
            `Service unavailable: ${item.serviceUnavailableException.message}`,
          );
        }
      }
      if (options.signal?.aborted) {
        throw new Error("Request was aborted");
      }
      if (output.stopReason === "error" || output.stopReason === "aborted") {
        throw new Error("An unknown error occurred");
      }
      stream.push({ type: "done", reason: output.stopReason, message: output });
      stream.end();
    } catch (error) {
      // Strip streaming bookkeeping so the error payload is a clean message.
      for (const block of output.content) {
        delete (block as Block).index;
        delete (block as Block).partialJson;
      }
      output.stopReason = options.signal?.aborted ? "aborted" : "error";
      output.errorMessage =
        error instanceof Error ? error.message : JSON.stringify(error);
      stream.push({ type: "error", reason: output.stopReason, error: output });
      stream.end();
    }
  })();
  return stream;
};
/**
 * Simple-options entry point for Bedrock: maps SimpleStreamOptions onto
 * BedrockOptions and delegates to streamBedrock.
 */
export const streamSimpleBedrock: StreamFunction<
  "bedrock-converse-stream",
  SimpleStreamOptions
> = (
  model: Model<"bedrock-converse-stream">,
  context: Context,
  options?: SimpleStreamOptions,
): AssistantMessageEventStream => {
  const base = buildBaseOptions(model, options, undefined);
  // No reasoning requested: stream with thinking explicitly disabled.
  if (!options?.reasoning) {
    return streamBedrock(model, context, {
      ...base,
      reasoning: undefined,
    } satisfies BedrockOptions);
  }
  if (
    model.id.includes("anthropic.claude") ||
    model.id.includes("anthropic/claude")
  ) {
    // Opus/Sonnet 4.6 manage their own budgets (adaptive thinking):
    // pass the caller's options through untouched.
    if (supportsAdaptiveThinking(model.id)) {
      return streamBedrock(model, context, {
        ...base,
        reasoning: options.reasoning,
        thinkingBudgets: options.thinkingBudgets,
      } satisfies BedrockOptions);
    }
    // Older Claude models: grow maxTokens so the thinking budget fits and
    // pin the computed budget for the (clamped) requested level.
    const adjusted = adjustMaxTokensForThinking(
      base.maxTokens || 0,
      model.maxTokens,
      options.reasoning,
      options.thinkingBudgets,
    );
    return streamBedrock(model, context, {
      ...base,
      maxTokens: adjusted.maxTokens,
      reasoning: options.reasoning,
      thinkingBudgets: {
        ...(options.thinkingBudgets || {}),
        [clampReasoning(options.reasoning)!]: adjusted.thinkingBudget,
      },
    } satisfies BedrockOptions);
  }
  // Non-Claude models: forward reasoning options unchanged.
  return streamBedrock(model, context, {
    ...base,
    reasoning: options.reasoning,
    thinkingBudgets: options.thinkingBudgets,
  } satisfies BedrockOptions);
};
/**
 * Open a new tool-call block when Bedrock starts a toolUse content block.
 * Text and thinking blocks get no start event from Bedrock and are created
 * lazily in handleContentBlockDelta instead.
 */
function handleContentBlockStart(
  event: ContentBlockStartEvent,
  blocks: Block[],
  output: AssistantMessage,
  stream: AssistantMessageEventStream,
): void {
  const toolUse = event.start?.toolUse;
  if (!toolUse) {
    return;
  }
  const block: Block = {
    type: "toolCall",
    id: toolUse.toolUseId || "",
    name: toolUse.name || "",
    arguments: {},
    partialJson: "",
    index: event.contentBlockIndex!,
  };
  output.content.push(block);
  stream.push({
    type: "toolcall_start",
    contentIndex: blocks.length - 1,
    partial: output,
  });
}
/**
 * Apply a streaming delta to the block matching the event's content index.
 * Text and thinking blocks are created lazily here because Bedrock emits
 * contentBlockStart only for toolUse blocks.
 */
function handleContentBlockDelta(
  event: ContentBlockDeltaEvent,
  blocks: Block[],
  output: AssistantMessage,
  stream: AssistantMessageEventStream,
): void {
  const contentBlockIndex = event.contentBlockIndex!;
  const delta = event.delta;
  let index = blocks.findIndex((b) => b.index === contentBlockIndex);
  let block = blocks[index];
  if (delta?.text !== undefined) {
    // If no text block exists yet, create one, as `handleContentBlockStart` is not sent for text blocks
    if (!block) {
      const newBlock: Block = {
        type: "text",
        text: "",
        index: contentBlockIndex,
      };
      output.content.push(newBlock);
      index = blocks.length - 1;
      block = blocks[index];
      stream.push({ type: "text_start", contentIndex: index, partial: output });
    }
    if (block.type === "text") {
      block.text += delta.text;
      stream.push({
        type: "text_delta",
        contentIndex: index,
        delta: delta.text,
        partial: output,
      });
    }
  } else if (delta?.toolUse && block?.type === "toolCall") {
    // Accumulate raw JSON and keep a best-effort parse of the partial input.
    block.partialJson = (block.partialJson || "") + (delta.toolUse.input || "");
    block.arguments = parseStreamingJson(block.partialJson);
    stream.push({
      type: "toolcall_delta",
      contentIndex: index,
      delta: delta.toolUse.input || "",
      partial: output,
    });
  } else if (delta?.reasoningContent) {
    let thinkingBlock = block;
    let thinkingIndex = index;
    // Thinking blocks are also created lazily on the first reasoning delta.
    if (!thinkingBlock) {
      const newBlock: Block = {
        type: "thinking",
        thinking: "",
        thinkingSignature: "",
        index: contentBlockIndex,
      };
      output.content.push(newBlock);
      thinkingIndex = blocks.length - 1;
      thinkingBlock = blocks[thinkingIndex];
      stream.push({
        type: "thinking_start",
        contentIndex: thinkingIndex,
        partial: output,
      });
    }
    if (thinkingBlock?.type === "thinking") {
      if (delta.reasoningContent.text) {
        thinkingBlock.thinking += delta.reasoningContent.text;
        stream.push({
          type: "thinking_delta",
          contentIndex: thinkingIndex,
          delta: delta.reasoningContent.text,
          partial: output,
        });
      }
      // Signature arrives in chunks; concatenate without emitting deltas.
      if (delta.reasoningContent.signature) {
        thinkingBlock.thinkingSignature =
          (thinkingBlock.thinkingSignature || "") +
          delta.reasoningContent.signature;
      }
    }
  }
}
/**
 * Copy token counts from the stream's metadata event into the message's
 * usage record and recompute its dollar cost.
 */
function handleMetadata(
  event: ConverseStreamMetadataEvent,
  model: Model<"bedrock-converse-stream">,
  output: AssistantMessage,
): void {
  const usage = event.usage;
  if (!usage) {
    return;
  }
  output.usage.input = usage.inputTokens || 0;
  output.usage.output = usage.outputTokens || 0;
  output.usage.cacheRead = usage.cacheReadInputTokens || 0;
  output.usage.cacheWrite = usage.cacheWriteInputTokens || 0;
  const fallbackTotal = output.usage.input + output.usage.output;
  output.usage.totalTokens = usage.totalTokens || fallbackTotal;
  calculateCost(model, output.usage);
}
/**
 * Finalise the content block matching the event's index: strip the
 * streaming bookkeeping fields and emit the corresponding *_end event.
 * Tool calls get a final parse of their accumulated argument JSON.
 */
function handleContentBlockStop(
  event: ContentBlockStopEvent,
  blocks: Block[],
  output: AssistantMessage,
  stream: AssistantMessageEventStream,
): void {
  const index = blocks.findIndex((b) => b.index === event.contentBlockIndex);
  const block = blocks[index];
  if (block === undefined) {
    return;
  }
  delete (block as Block).index;
  if (block.type === "text") {
    stream.push({
      type: "text_end",
      contentIndex: index,
      content: block.text,
      partial: output,
    });
  } else if (block.type === "thinking") {
    stream.push({
      type: "thinking_end",
      contentIndex: index,
      content: block.thinking,
      partial: output,
    });
  } else if (block.type === "toolCall") {
    block.arguments = parseStreamingJson(block.partialJson);
    delete (block as Block).partialJson;
    stream.push({
      type: "toolcall_end",
      contentIndex: index,
      toolCall: block,
      partial: output,
    });
  }
}
/**
 * Check if the model supports adaptive thinking (Opus 4.6 and Sonnet 4.6).
 */
function supportsAdaptiveThinking(modelId: string): boolean {
  const markers = ["opus-4-6", "opus-4.6", "sonnet-4-6", "sonnet-4.6"];
  return markers.some((marker) => modelId.includes(marker));
}
/**
 * Translate a generic thinking level into Bedrock's adaptive effort scale.
 * "xhigh" maps to "max" only on Opus 4.6; anything unrecognised is "high".
 */
function mapThinkingLevelToEffort(
  level: SimpleStreamOptions["reasoning"],
  modelId: string,
): "low" | "medium" | "high" | "max" {
  if (level === "minimal" || level === "low") return "low";
  if (level === "medium") return "medium";
  if (level === "high") return "high";
  if (level === "xhigh") {
    const isOpus46 =
      modelId.includes("opus-4-6") || modelId.includes("opus-4.6");
    return isOpus46 ? "max" : "high";
  }
  return "high";
}
/**
 * Resolve cache retention preference.
 * Defaults to "short" and uses PI_CACHE_RETENTION for backward compatibility.
 */
function resolveCacheRetention(
  cacheRetention?: CacheRetention,
): CacheRetention {
  if (cacheRetention) {
    return cacheRetention;
  }
  const hasProcess = typeof process !== "undefined";
  if (hasProcess && process.env.PI_CACHE_RETENTION === "long") {
    return "long";
  }
  return "short";
}
/**
 * Check if the model supports prompt caching.
 * Supported: Claude 3.5 Haiku, Claude 3.7 Sonnet, Claude 4.x models.
 * Non-zero cache pricing in the catalogue also counts as support.
 */
function supportsPromptCaching(
  model: Model<"bedrock-converse-stream">,
): boolean {
  if (model.cost.cacheRead || model.cost.cacheWrite) {
    return true;
  }
  const id = model.id.toLowerCase();
  const isClaude4 =
    id.includes("claude") && (id.includes("-4-") || id.includes("-4."));
  return (
    isClaude4 ||
    id.includes("claude-3-7-sonnet") ||
    id.includes("claude-3-5-haiku")
  );
}
/**
 * Check if the model supports thinking signatures in reasoningContent.
 * Only Anthropic Claude models support the signature field.
 * Other models (OpenAI, Qwen, Minimax, Moonshot, etc.) reject it with:
 * "This model doesn't support the reasoningContent.reasoningText.signature field"
 */
function supportsThinkingSignature(
  model: Model<"bedrock-converse-stream">,
): boolean {
  const id = model.id.toLowerCase();
  const anthropicMarkers = ["anthropic.claude", "anthropic/claude"];
  return anthropicMarkers.some((marker) => id.includes(marker));
}
function buildSystemPrompt(
systemPrompt: string | undefined,
model: Model<"bedrock-converse-stream">,
cacheRetention: CacheRetention,
): SystemContentBlock[] | undefined {
if (!systemPrompt) return undefined;
const blocks: SystemContentBlock[] = [
{ text: sanitizeSurrogates(systemPrompt) },
];
// Add cache point for supported Claude models when caching is enabled
if (cacheRetention !== "none" && supportsPromptCaching(model)) {
blocks.push({
cachePoint: {
type: CachePointType.DEFAULT,
...(cacheRetention === "long" ? { ttl: CacheTTL.ONE_HOUR } : {}),
},
});
}
return blocks;
}
/**
 * Normalise a tool-call id to characters [a-zA-Z0-9_-] and at most 64
 * characters: anything else becomes "_", then the result is truncated.
 */
function normalizeToolCallId(id: string): string {
  const cleaned = id.replace(/[^a-zA-Z0-9_-]/g, "_");
  return cleaned.slice(0, 64);
}
/**
 * Convert the provider-agnostic conversation into Bedrock Converse messages.
 *
 * Handles: string vs. block user content, dropping empty assistant blocks
 * (Bedrock rejects empty content arrays), Anthropic-only thinking
 * signatures, merging consecutive tool results into a single user message,
 * and appending a cache point to the last user message when caching applies.
 */
function convertMessages(
  context: Context,
  model: Model<"bedrock-converse-stream">,
  cacheRetention: CacheRetention,
): Message[] {
  const result: Message[] = [];
  // Normalize for cross-provider compatibility (incl. tool-call id rules).
  const transformedMessages = transformMessages(
    context.messages,
    model,
    normalizeToolCallId,
  );
  for (let i = 0; i < transformedMessages.length; i++) {
    const m = transformedMessages[i];
    switch (m.role) {
      case "user":
        result.push({
          role: ConversationRole.USER,
          content:
            typeof m.content === "string"
              ? [{ text: sanitizeSurrogates(m.content) }]
              : m.content.map((c) => {
                  switch (c.type) {
                    case "text":
                      return { text: sanitizeSurrogates(c.text) };
                    case "image":
                      return { image: createImageBlock(c.mimeType, c.data) };
                    default:
                      throw new Error("Unknown user content type");
                  }
                }),
        });
        break;
      case "assistant": {
        // Skip assistant messages with empty content (e.g., from aborted requests)
        // Bedrock rejects messages with empty content arrays
        // (`continue` advances the outer for loop, not just the switch)
        if (m.content.length === 0) {
          continue;
        }
        const contentBlocks: ContentBlock[] = [];
        for (const c of m.content) {
          switch (c.type) {
            case "text":
              // Skip empty text blocks
              if (c.text.trim().length === 0) continue;
              contentBlocks.push({ text: sanitizeSurrogates(c.text) });
              break;
            case "toolCall":
              contentBlocks.push({
                toolUse: { toolUseId: c.id, name: c.name, input: c.arguments },
              });
              break;
            case "thinking":
              // Skip empty thinking blocks
              if (c.thinking.trim().length === 0) continue;
              // Only Anthropic models support the signature field in reasoningText.
              // For other models, we omit the signature to avoid errors like:
              // "This model doesn't support the reasoningContent.reasoningText.signature field"
              if (supportsThinkingSignature(model)) {
                contentBlocks.push({
                  reasoningContent: {
                    reasoningText: {
                      text: sanitizeSurrogates(c.thinking),
                      signature: c.thinkingSignature,
                    },
                  },
                });
              } else {
                contentBlocks.push({
                  reasoningContent: {
                    reasoningText: { text: sanitizeSurrogates(c.thinking) },
                  },
                });
              }
              break;
            default:
              throw new Error("Unknown assistant content type");
          }
        }
        // Skip if all content blocks were filtered out
        if (contentBlocks.length === 0) {
          continue;
        }
        result.push({
          role: ConversationRole.ASSISTANT,
          content: contentBlocks,
        });
        break;
      }
      case "toolResult": {
        // Collect all consecutive toolResult messages into a single user message
        // Bedrock requires all tool results to be in one message
        const toolResults: ContentBlock.ToolResultMember[] = [];
        // Add current tool result with all content blocks combined
        toolResults.push({
          toolResult: {
            toolUseId: m.toolCallId,
            content: m.content.map((c) =>
              c.type === "image"
                ? { image: createImageBlock(c.mimeType, c.data) }
                : { text: sanitizeSurrogates(c.text) },
            ),
            status: m.isError
              ? ToolResultStatus.ERROR
              : ToolResultStatus.SUCCESS,
          },
        });
        // Look ahead for consecutive toolResult messages
        let j = i + 1;
        while (
          j < transformedMessages.length &&
          transformedMessages[j].role === "toolResult"
        ) {
          const nextMsg = transformedMessages[j] as ToolResultMessage;
          toolResults.push({
            toolResult: {
              toolUseId: nextMsg.toolCallId,
              content: nextMsg.content.map((c) =>
                c.type === "image"
                  ? { image: createImageBlock(c.mimeType, c.data) }
                  : { text: sanitizeSurrogates(c.text) },
              ),
              status: nextMsg.isError
                ? ToolResultStatus.ERROR
                : ToolResultStatus.SUCCESS,
            },
          });
          j++;
        }
        // Skip the messages we've already processed
        i = j - 1;
        result.push({
          role: ConversationRole.USER,
          content: toolResults,
        });
        break;
      }
      default:
        throw new Error("Unknown message role");
    }
  }
  // Add cache point to the last user message for supported Claude models when caching is enabled
  if (
    cacheRetention !== "none" &&
    supportsPromptCaching(model) &&
    result.length > 0
  ) {
    const lastMessage = result[result.length - 1];
    if (lastMessage.role === ConversationRole.USER && lastMessage.content) {
      (lastMessage.content as ContentBlock[]).push({
        cachePoint: {
          type: CachePointType.DEFAULT,
          ...(cacheRetention === "long" ? { ttl: CacheTTL.ONE_HOUR } : {}),
        },
      });
    }
  }
  return result;
}
// Translate our tool definitions and tool-choice into Bedrock's ToolConfiguration.
// Returns undefined when there are no tools or tools are explicitly disabled.
function convertToolConfig(
  tools: Tool[] | undefined,
  toolChoice: BedrockOptions["toolChoice"],
): ToolConfiguration | undefined {
  if (!tools?.length || toolChoice === "none") return undefined;
  const specs: BedrockTool[] = [];
  for (const tool of tools) {
    specs.push({
      toolSpec: {
        name: tool.name,
        description: tool.description,
        inputSchema: { json: tool.parameters },
      },
    });
  }
  let choice: ToolChoice | undefined;
  if (toolChoice === "auto") {
    choice = { auto: {} };
  } else if (toolChoice === "any") {
    choice = { any: {} };
  } else if (toolChoice?.type === "tool") {
    choice = { tool: { name: toolChoice.name } };
  }
  return { tools: specs, toolChoice: choice };
}
// Collapse Bedrock stop reasons onto our cross-provider StopReason values;
// anything unrecognized (including undefined) is treated as an error.
function mapStopReason(reason: string | undefined): StopReason {
  if (
    reason === BedrockStopReason.END_TURN ||
    reason === BedrockStopReason.STOP_SEQUENCE
  ) {
    return "stop";
  }
  if (
    reason === BedrockStopReason.MAX_TOKENS ||
    reason === BedrockStopReason.MODEL_CONTEXT_WINDOW_EXCEEDED
  ) {
    return "length";
  }
  if (reason === BedrockStopReason.TOOL_USE) {
    return "toolUse";
  }
  return "error";
}
/**
 * Build Bedrock `additionalModelRequestFields` to enable extended thinking.
 *
 * Only Claude model ids are handled: adaptive-thinking models get an effort
 * level, older models get a token budget (plus the interleaved-thinking
 * beta unless disabled). Returns undefined when reasoning is off, the model
 * doesn't support it, or the model is not a Claude.
 */
function buildAdditionalModelRequestFields(
  model: Model<"bedrock-converse-stream">,
  options: BedrockOptions,
): Record<string, any> | undefined {
  if (!options.reasoning || !model.reasoning) {
    return undefined;
  }
  if (
    model.id.includes("anthropic.claude") ||
    model.id.includes("anthropic/claude")
  ) {
    const result: Record<string, any> = supportsAdaptiveThinking(model.id)
      ? {
          thinking: { type: "adaptive" },
          output_config: {
            effort: mapThinkingLevelToEffort(options.reasoning, model.id),
          },
        }
      : (() => {
          const defaultBudgets: Record<ThinkingLevel, number> = {
            minimal: 1024,
            low: 2048,
            medium: 8192,
            high: 16384,
            xhigh: 16384, // Claude doesn't support xhigh, clamp to high
          };
          // Custom budgets override defaults (xhigh not in ThinkingBudgets, use high)
          const level =
            options.reasoning === "xhigh" ? "high" : options.reasoning;
          const budget =
            options.thinkingBudgets?.[level] ??
            defaultBudgets[options.reasoning];
          return {
            thinking: {
              type: "enabled",
              budget_tokens: budget,
            },
          };
        })();
    // Interleaved thinking applies only to budget-based models; it is on by default.
    if (
      !supportsAdaptiveThinking(model.id) &&
      (options.interleavedThinking ?? true)
    ) {
      result.anthropic_beta = ["interleaved-thinking-2025-05-14"];
    }
    return result;
  }
  return undefined;
}
// Decode a base64 image payload into Bedrock's { source, format } image block.
function createImageBlock(mimeType: string, data: string) {
  const formatByMime: Record<string, ImageFormat> = {
    "image/jpeg": ImageFormat.JPEG,
    "image/jpg": ImageFormat.JPEG,
    "image/png": ImageFormat.PNG,
    "image/gif": ImageFormat.GIF,
    "image/webp": ImageFormat.WEBP,
  };
  const format = formatByMime[mimeType];
  if (format === undefined) {
    throw new Error(`Unknown image type: ${mimeType}`);
  }
  // atob yields a binary string (one char per byte); repack into raw bytes.
  const decoded = atob(data);
  const bytes = Uint8Array.from(decoded, (ch) => ch.charCodeAt(0));
  return { source: { bytes }, format };
}

View file

@ -0,0 +1,989 @@
import Anthropic from "@anthropic-ai/sdk";
import type {
ContentBlockParam,
MessageCreateParamsStreaming,
MessageParam,
} from "@anthropic-ai/sdk/resources/messages.js";
import { getEnvApiKey } from "../env-api-keys.js";
import { calculateCost } from "../models.js";
import type {
Api,
AssistantMessage,
CacheRetention,
Context,
ImageContent,
Message,
Model,
SimpleStreamOptions,
StopReason,
StreamFunction,
StreamOptions,
TextContent,
ThinkingContent,
Tool,
ToolCall,
ToolResultMessage,
} from "../types.js";
import { AssistantMessageEventStream } from "../utils/event-stream.js";
import { parseStreamingJson } from "../utils/json-parse.js";
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
import {
buildCopilotDynamicHeaders,
hasCopilotVisionInput,
} from "./github-copilot-headers.js";
import {
adjustMaxTokensForThinking,
buildBaseOptions,
} from "./simple-options.js";
import { transformMessages } from "./transform-messages.js";
/**
 * Determine the effective cache retention setting.
 * An explicit option wins; otherwise the PI_CACHE_RETENTION env var may opt
 * into "long" (backward compatibility); the fallback is "short".
 */
function resolveCacheRetention(
  cacheRetention?: CacheRetention,
): CacheRetention {
  if (cacheRetention) return cacheRetention;
  const envValue =
    typeof process !== "undefined"
      ? process.env.PI_CACHE_RETENTION
      : undefined;
  return envValue === "long" ? "long" : "short";
}
// Resolve retention and produce the matching cache_control payload.
// The 1h TTL is only emitted for the first-party Anthropic API base URL.
function getCacheControl(
  baseUrl: string,
  cacheRetention?: CacheRetention,
): {
  retention: CacheRetention;
  cacheControl?: { type: "ephemeral"; ttl?: "1h" };
} {
  const retention = resolveCacheRetention(cacheRetention);
  // Caching disabled: omit cache_control entirely.
  if (retention === "none") {
    return { retention };
  }
  const useLongTtl =
    retention === "long" && baseUrl.includes("api.anthropic.com");
  const cacheControl: { type: "ephemeral"; ttl?: "1h" } = useLongTtl
    ? { type: "ephemeral", ttl: "1h" }
    : { type: "ephemeral" };
  return { retention, cacheControl };
}
// Stealth mode: Mimic Claude Code's tool naming exactly
const claudeCodeVersion = "2.1.62";
// Claude Code 2.x tool names (canonical casing)
// Source: https://cchistory.mariozechner.at/data/prompts-2.1.11.md
// To update: https://github.com/badlogic/cchistory
const claudeCodeTools = [
  "Read",
  "Write",
  "Edit",
  "Bash",
  "Grep",
  "Glob",
  "AskUserQuestion",
  "EnterPlanMode",
  "ExitPlanMode",
  "KillShell",
  "NotebookEdit",
  "Skill",
  "Task",
  "TaskOutput",
  "TodoWrite",
  "WebFetch",
  "WebSearch",
];
// Lowercased-name lookup so matching is case-insensitive.
const ccToolLookup = new Map(claudeCodeTools.map((t) => [t.toLowerCase(), t]));
// Convert tool name to CC canonical casing if it matches (case-insensitive)
const toClaudeCodeName = (name: string) =>
  ccToolLookup.get(name.toLowerCase()) ?? name;
// Reverse mapping: recover the caller's original tool name via a
// case-insensitive match against the provided tool list; falls back to the
// name unchanged when there is no match (or no tool list).
const fromClaudeCodeName = (name: string, tools?: Tool[]) => {
  if (tools && tools.length > 0) {
    const lowerName = name.toLowerCase();
    const matchedTool = tools.find(
      (tool) => tool.name.toLowerCase() === lowerName,
    );
    if (matchedTool) return matchedTool.name;
  }
  return name;
};
/**
 * Convert tool-result content to the Anthropic API's content shape.
 * Text-only content collapses to a single sanitized string; any image
 * forces the full content-block array form, with a placeholder text block
 * prepended when no text is present.
 */
function convertContentBlocks(content: (TextContent | ImageContent)[]):
  | string
  | Array<
      | { type: "text"; text: string }
      | {
          type: "image";
          source: {
            type: "base64";
            media_type: "image/jpeg" | "image/png" | "image/gif" | "image/webp";
            data: string;
          };
        }
    > {
  const containsImage = content.some((item) => item.type === "image");
  if (!containsImage) {
    const joined = content.map((item) => (item as TextContent).text).join("\n");
    return sanitizeSurrogates(joined);
  }
  const blocks = content.map((item) =>
    item.type === "text"
      ? {
          type: "text" as const,
          text: sanitizeSurrogates(item.text),
        }
      : {
          type: "image" as const,
          source: {
            type: "base64" as const,
            media_type: item.mimeType as
              | "image/jpeg"
              | "image/png"
              | "image/gif"
              | "image/webp",
            data: item.data,
          },
        },
  );
  // The API requires at least one text block alongside images.
  if (!blocks.some((b) => b.type === "text")) {
    blocks.unshift({
      type: "text" as const,
      text: "(see attached image)",
    });
  }
  return blocks;
}
/** Effort levels for adaptive thinking ("max" is accepted by Opus 4.6 only). */
export type AnthropicEffort = "low" | "medium" | "high" | "max";
export interface AnthropicOptions extends StreamOptions {
  /**
   * Enable extended thinking.
   * For Opus 4.6 and Sonnet 4.6: uses adaptive thinking (model decides when/how much to think).
   * For older models: uses budget-based thinking with thinkingBudgetTokens.
   */
  thinkingEnabled?: boolean;
  /**
   * Token budget for extended thinking (older models only).
   * Ignored for Opus 4.6 and Sonnet 4.6, which use adaptive thinking.
   */
  thinkingBudgetTokens?: number;
  /**
   * Effort level for adaptive thinking (Opus 4.6 and Sonnet 4.6).
   * Controls how much thinking Claude allocates:
   * - "max": Always thinks with no constraints (Opus 4.6 only)
   * - "high": Always thinks, deep reasoning (default)
   * - "medium": Moderate thinking, may skip for simple queries
   * - "low": Minimal thinking, skips for simple tasks
   * Ignored for older models.
   */
  effort?: AnthropicEffort;
  /** Request the interleaved-thinking beta for budget-based models (defaults to true). */
  interleavedThinking?: boolean;
  /** Tool selection strategy forwarded to the Messages API tool_choice field. */
  toolChoice?: "auto" | "any" | "none" | { type: "tool"; name: string };
}
// Merge header maps left-to-right; later sources override earlier keys,
// and undefined sources are skipped.
function mergeHeaders(
  ...headerSources: (Record<string, string> | undefined)[]
): Record<string, string> {
  return headerSources.reduce<Record<string, string>>(
    (merged, source) => (source ? { ...merged, ...source } : merged),
    {},
  );
}
/**
 * Stream a completion from the Anthropic Messages API.
 *
 * Pushes start/delta/end events onto the returned AssistantMessageEventStream
 * while accumulating the full AssistantMessage (content blocks, usage, stop
 * reason, cost). Failures and aborts surface as a terminal "error" event on
 * the stream rather than a rejected promise.
 */
export const streamAnthropic: StreamFunction<
  "anthropic-messages",
  AnthropicOptions
> = (
  model: Model<"anthropic-messages">,
  context: Context,
  options?: AnthropicOptions,
): AssistantMessageEventStream => {
  const stream = new AssistantMessageEventStream();
  (async () => {
    // Message being accumulated; it is also what error events carry.
    const output: AssistantMessage = {
      role: "assistant",
      content: [],
      api: model.api as Api,
      provider: model.provider,
      model: model.id,
      usage: {
        input: 0,
        output: 0,
        cacheRead: 0,
        cacheWrite: 0,
        totalTokens: 0,
        cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
      },
      stopReason: "stop",
      timestamp: Date.now(),
    };
    try {
      const apiKey = options?.apiKey ?? getEnvApiKey(model.provider) ?? "";
      let copilotDynamicHeaders: Record<string, string> | undefined;
      if (model.provider === "github-copilot") {
        // Copilot requires per-request headers derived from the payload.
        const hasImages = hasCopilotVisionInput(context.messages);
        copilotDynamicHeaders = buildCopilotDynamicHeaders({
          messages: context.messages,
          hasImages,
        });
      }
      const { client, isOAuthToken } = createClient(
        model,
        apiKey,
        options?.interleavedThinking ?? true,
        options?.headers,
        copilotDynamicHeaders,
      );
      const params = buildParams(model, context, isOAuthToken, options);
      options?.onPayload?.(params);
      const anthropicStream = client.messages.stream(
        { ...params, stream: true },
        { signal: options?.signal },
      );
      stream.push({ type: "start", partial: output });
      // In-flight blocks keep the API's content index (and accumulated
      // partial tool-call JSON) until their content_block_stop arrives;
      // both extras are deleted before the block is considered final.
      type Block = (
        | ThinkingContent
        | TextContent
        | (ToolCall & { partialJson: string })
      ) & { index: number };
      const blocks = output.content as Block[];
      for await (const event of anthropicStream) {
        if (event.type === "message_start") {
          // Capture initial token usage from message_start event
          // This ensures we have input token counts even if the stream is aborted early
          output.usage.input = event.message.usage.input_tokens || 0;
          output.usage.output = event.message.usage.output_tokens || 0;
          output.usage.cacheRead =
            event.message.usage.cache_read_input_tokens || 0;
          output.usage.cacheWrite =
            event.message.usage.cache_creation_input_tokens || 0;
          // Anthropic doesn't provide total_tokens, compute from components
          output.usage.totalTokens =
            output.usage.input +
            output.usage.output +
            output.usage.cacheRead +
            output.usage.cacheWrite;
          calculateCost(model, output.usage);
        } else if (event.type === "content_block_start") {
          if (event.content_block.type === "text") {
            const block: Block = {
              type: "text",
              text: "",
              index: event.index,
            };
            output.content.push(block);
            stream.push({
              type: "text_start",
              contentIndex: output.content.length - 1,
              partial: output,
            });
          } else if (event.content_block.type === "thinking") {
            const block: Block = {
              type: "thinking",
              thinking: "",
              thinkingSignature: "",
              index: event.index,
            };
            output.content.push(block);
            stream.push({
              type: "thinking_start",
              contentIndex: output.content.length - 1,
              partial: output,
            });
          } else if (event.content_block.type === "redacted_thinking") {
            // Redacted thinking arrives complete; the opaque payload is
            // preserved in thinkingSignature for round-tripping.
            const block: Block = {
              type: "thinking",
              thinking: "[Reasoning redacted]",
              thinkingSignature: event.content_block.data,
              redacted: true,
              index: event.index,
            };
            output.content.push(block);
            stream.push({
              type: "thinking_start",
              contentIndex: output.content.length - 1,
              partial: output,
            });
          } else if (event.content_block.type === "tool_use") {
            const block: Block = {
              type: "toolCall",
              id: event.content_block.id,
              name: isOAuthToken
                ? fromClaudeCodeName(event.content_block.name, context.tools)
                : event.content_block.name,
              arguments:
                (event.content_block.input as Record<string, any>) ?? {},
              partialJson: "",
              index: event.index,
            };
            output.content.push(block);
            stream.push({
              type: "toolcall_start",
              contentIndex: output.content.length - 1,
              partial: output,
            });
          }
        } else if (event.type === "content_block_delta") {
          // Deltas are matched to their block via the API-provided index.
          if (event.delta.type === "text_delta") {
            const index = blocks.findIndex((b) => b.index === event.index);
            const block = blocks[index];
            if (block && block.type === "text") {
              block.text += event.delta.text;
              stream.push({
                type: "text_delta",
                contentIndex: index,
                delta: event.delta.text,
                partial: output,
              });
            }
          } else if (event.delta.type === "thinking_delta") {
            const index = blocks.findIndex((b) => b.index === event.index);
            const block = blocks[index];
            if (block && block.type === "thinking") {
              block.thinking += event.delta.thinking;
              stream.push({
                type: "thinking_delta",
                contentIndex: index,
                delta: event.delta.thinking,
                partial: output,
              });
            }
          } else if (event.delta.type === "input_json_delta") {
            const index = blocks.findIndex((b) => b.index === event.index);
            const block = blocks[index];
            if (block && block.type === "toolCall") {
              // Re-parse the accumulated partial JSON on every delta so
              // `arguments` always reflects the best-known state.
              block.partialJson += event.delta.partial_json;
              block.arguments = parseStreamingJson(block.partialJson);
              stream.push({
                type: "toolcall_delta",
                contentIndex: index,
                delta: event.delta.partial_json,
                partial: output,
              });
            }
          } else if (event.delta.type === "signature_delta") {
            const index = blocks.findIndex((b) => b.index === event.index);
            const block = blocks[index];
            if (block && block.type === "thinking") {
              block.thinkingSignature = block.thinkingSignature || "";
              block.thinkingSignature += event.delta.signature;
            }
          }
        } else if (event.type === "content_block_stop") {
          const index = blocks.findIndex((b) => b.index === event.index);
          const block = blocks[index];
          if (block) {
            delete (block as any).index;
            if (block.type === "text") {
              stream.push({
                type: "text_end",
                contentIndex: index,
                content: block.text,
                partial: output,
              });
            } else if (block.type === "thinking") {
              stream.push({
                type: "thinking_end",
                contentIndex: index,
                content: block.thinking,
                partial: output,
              });
            } else if (block.type === "toolCall") {
              block.arguments = parseStreamingJson(block.partialJson);
              delete (block as any).partialJson;
              stream.push({
                type: "toolcall_end",
                contentIndex: index,
                toolCall: block,
                partial: output,
              });
            }
          }
        } else if (event.type === "message_delta") {
          if (event.delta.stop_reason) {
            output.stopReason = mapStopReason(event.delta.stop_reason);
          }
          // Only update usage fields if present (not null).
          // Preserves input_tokens from message_start when proxies omit it in message_delta.
          if (event.usage.input_tokens != null) {
            output.usage.input = event.usage.input_tokens;
          }
          if (event.usage.output_tokens != null) {
            output.usage.output = event.usage.output_tokens;
          }
          if (event.usage.cache_read_input_tokens != null) {
            output.usage.cacheRead = event.usage.cache_read_input_tokens;
          }
          if (event.usage.cache_creation_input_tokens != null) {
            output.usage.cacheWrite = event.usage.cache_creation_input_tokens;
          }
          // Anthropic doesn't provide total_tokens, compute from components
          output.usage.totalTokens =
            output.usage.input +
            output.usage.output +
            output.usage.cacheRead +
            output.usage.cacheWrite;
          calculateCost(model, output.usage);
        }
      }
      if (options?.signal?.aborted) {
        throw new Error("Request was aborted");
      }
      if (output.stopReason === "aborted" || output.stopReason === "error") {
        throw new Error("An unknown error occurred");
      }
      stream.push({ type: "done", reason: output.stopReason, message: output });
      stream.end();
    } catch (error) {
      // Strip the in-flight index markers so error payloads are clean.
      for (const block of output.content) delete (block as any).index;
      output.stopReason = options?.signal?.aborted ? "aborted" : "error";
      output.errorMessage =
        error instanceof Error ? error.message : JSON.stringify(error);
      stream.push({ type: "error", reason: output.stopReason, error: output });
      stream.end();
    }
  })();
  return stream;
};
/**
 * Whether the model supports adaptive thinking (Opus 4.6 and Sonnet 4.6),
 * matching both dashed and dotted model-id spellings, with or without a
 * date suffix.
 */
function supportsAdaptiveThinking(modelId: string): boolean {
  const markers = ["opus-4-6", "opus-4.6", "sonnet-4-6", "sonnet-4.6"];
  return markers.some((marker) => modelId.includes(marker));
}
/**
 * Map ThinkingLevel to Anthropic effort levels for adaptive thinking.
 * "max" is only valid on Opus 4.6; on other models "xhigh" caps at "high".
 * Unknown/absent levels default to "high".
 */
function mapThinkingLevelToEffort(
  level: SimpleStreamOptions["reasoning"],
  modelId: string,
): AnthropicEffort {
  if (level === "minimal" || level === "low") return "low";
  if (level === "medium") return "medium";
  if (level === "xhigh") {
    const isOpus46 =
      modelId.includes("opus-4-6") || modelId.includes("opus-4.6");
    return isOpus46 ? "max" : "high";
  }
  return "high";
}
/**
 * Simplified entry point: resolve the API key, translate the generic
 * `reasoning` level into the model-appropriate thinking configuration
 * (adaptive effort vs. token budget), and delegate to streamAnthropic.
 *
 * Throws synchronously when no API key is available for the provider.
 */
export const streamSimpleAnthropic: StreamFunction<
  "anthropic-messages",
  SimpleStreamOptions
> = (
  model: Model<"anthropic-messages">,
  context: Context,
  options?: SimpleStreamOptions,
): AssistantMessageEventStream => {
  const apiKey = options?.apiKey || getEnvApiKey(model.provider);
  if (!apiKey) {
    throw new Error(`No API key for provider: ${model.provider}`);
  }
  const base = buildBaseOptions(model, options, apiKey);
  // No reasoning requested: plain streaming with thinking off.
  if (!options?.reasoning) {
    return streamAnthropic(model, context, {
      ...base,
      thinkingEnabled: false,
    } satisfies AnthropicOptions);
  }
  // For Opus 4.6 and Sonnet 4.6: use adaptive thinking with effort level
  // For older models: use budget-based thinking
  if (supportsAdaptiveThinking(model.id)) {
    const effort = mapThinkingLevelToEffort(options.reasoning, model.id);
    return streamAnthropic(model, context, {
      ...base,
      thinkingEnabled: true,
      effort,
    } satisfies AnthropicOptions);
  }
  // Budget-based models must reserve part of max_tokens for thinking.
  const adjusted = adjustMaxTokensForThinking(
    base.maxTokens || 0,
    model.maxTokens,
    options.reasoning,
    options.thinkingBudgets,
  );
  return streamAnthropic(model, context, {
    ...base,
    maxTokens: adjusted.maxTokens,
    thinkingEnabled: true,
    thinkingBudgetTokens: adjusted.thinkingBudget,
  } satisfies AnthropicOptions);
};
// Anthropic OAuth access tokens carry the "sk-ant-oat" marker in the key.
function isOAuthToken(apiKey: string): boolean {
  return apiKey.indexOf("sk-ant-oat") !== -1;
}
/**
 * Build an Anthropic SDK client configured for one of three auth modes:
 * GitHub Copilot (Bearer token, selective betas), Anthropic OAuth (Claude
 * Code identity headers), or a plain API key. Returns whether OAuth mode is
 * active so callers can apply Claude Code tool naming.
 *
 * Header precedence (last wins): defaults < model.headers < dynamicHeaders
 * (Copilot only) < optionsHeaders.
 */
function createClient(
  model: Model<"anthropic-messages">,
  apiKey: string,
  interleavedThinking: boolean,
  optionsHeaders?: Record<string, string>,
  dynamicHeaders?: Record<string, string>,
): { client: Anthropic; isOAuthToken: boolean } {
  // Adaptive thinking models (Opus 4.6, Sonnet 4.6) have interleaved thinking built-in.
  // The beta header is deprecated on Opus 4.6 and redundant on Sonnet 4.6, so skip it.
  const needsInterleavedBeta =
    interleavedThinking && !supportsAdaptiveThinking(model.id);
  // Copilot: Bearer auth, selective betas (no fine-grained-tool-streaming)
  if (model.provider === "github-copilot") {
    const betaFeatures: string[] = [];
    if (needsInterleavedBeta) {
      betaFeatures.push("interleaved-thinking-2025-05-14");
    }
    const client = new Anthropic({
      apiKey: null,
      authToken: apiKey,
      baseURL: model.baseUrl,
      dangerouslyAllowBrowser: true,
      defaultHeaders: mergeHeaders(
        {
          accept: "application/json",
          "anthropic-dangerous-direct-browser-access": "true",
          ...(betaFeatures.length > 0
            ? { "anthropic-beta": betaFeatures.join(",") }
            : {}),
        },
        model.headers,
        dynamicHeaders,
        optionsHeaders,
      ),
    });
    return { client, isOAuthToken: false };
  }
  const betaFeatures = ["fine-grained-tool-streaming-2025-05-14"];
  if (needsInterleavedBeta) {
    betaFeatures.push("interleaved-thinking-2025-05-14");
  }
  // OAuth: Bearer auth, Claude Code identity headers
  if (isOAuthToken(apiKey)) {
    const client = new Anthropic({
      apiKey: null,
      authToken: apiKey,
      baseURL: model.baseUrl,
      dangerouslyAllowBrowser: true,
      defaultHeaders: mergeHeaders(
        {
          accept: "application/json",
          "anthropic-dangerous-direct-browser-access": "true",
          "anthropic-beta": `claude-code-20250219,oauth-2025-04-20,${betaFeatures.join(",")}`,
          "user-agent": `claude-cli/${claudeCodeVersion}`,
          "x-app": "cli",
        },
        model.headers,
        optionsHeaders,
      ),
    });
    return { client, isOAuthToken: true };
  }
  // API key auth
  const client = new Anthropic({
    apiKey,
    baseURL: model.baseUrl,
    dangerouslyAllowBrowser: true,
    defaultHeaders: mergeHeaders(
      {
        accept: "application/json",
        "anthropic-dangerous-direct-browser-access": "true",
        "anthropic-beta": betaFeatures.join(","),
      },
      model.headers,
      optionsHeaders,
    ),
  });
  return { client, isOAuthToken: false };
}
/**
 * Assemble the Messages API request: converted messages, system prompt
 * (with Claude Code identity injected first for OAuth tokens),
 * cache_control, temperature, tools, thinking configuration, metadata,
 * and tool choice.
 */
function buildParams(
  model: Model<"anthropic-messages">,
  context: Context,
  isOAuthToken: boolean,
  options?: AnthropicOptions,
): MessageCreateParamsStreaming {
  const { cacheControl } = getCacheControl(
    model.baseUrl,
    options?.cacheRetention,
  );
  const params: MessageCreateParamsStreaming = {
    model: model.id,
    messages: convertMessages(
      context.messages,
      model,
      isOAuthToken,
      cacheControl,
    ),
    // Default max_tokens: a third of the model limit, truncated to an integer.
    max_tokens: options?.maxTokens || (model.maxTokens / 3) | 0,
    stream: true,
  };
  // For OAuth tokens, we MUST include Claude Code identity
  if (isOAuthToken) {
    params.system = [
      {
        type: "text",
        text: "You are Claude Code, Anthropic's official CLI for Claude.",
        ...(cacheControl ? { cache_control: cacheControl } : {}),
      },
    ];
    if (context.systemPrompt) {
      params.system.push({
        type: "text",
        text: sanitizeSurrogates(context.systemPrompt),
        ...(cacheControl ? { cache_control: cacheControl } : {}),
      });
    }
  } else if (context.systemPrompt) {
    // Add cache control to system prompt for non-OAuth tokens
    params.system = [
      {
        type: "text",
        text: sanitizeSurrogates(context.systemPrompt),
        ...(cacheControl ? { cache_control: cacheControl } : {}),
      },
    ];
  }
  // Temperature is incompatible with extended thinking (adaptive or budget-based).
  if (options?.temperature !== undefined && !options?.thinkingEnabled) {
    params.temperature = options.temperature;
  }
  if (context.tools) {
    params.tools = convertTools(context.tools, isOAuthToken);
  }
  // Configure thinking mode: adaptive (Opus 4.6 and Sonnet 4.6) or budget-based (older models)
  if (options?.thinkingEnabled && model.reasoning) {
    if (supportsAdaptiveThinking(model.id)) {
      // Adaptive thinking: Claude decides when and how much to think
      params.thinking = { type: "adaptive" };
      if (options.effort) {
        params.output_config = { effort: options.effort };
      }
    } else {
      // Budget-based thinking for older models
      params.thinking = {
        type: "enabled",
        budget_tokens: options.thinkingBudgetTokens || 1024,
      };
    }
  }
  // Only a string user_id is forwarded from metadata.
  if (options?.metadata) {
    const userId = options.metadata.user_id;
    if (typeof userId === "string") {
      params.metadata = { user_id: userId };
    }
  }
  if (options?.toolChoice) {
    if (typeof options.toolChoice === "string") {
      params.tool_choice = { type: options.toolChoice };
    } else {
      params.tool_choice = options.toolChoice;
    }
  }
  return params;
}
// Anthropic requires tool_use ids matching [a-zA-Z0-9_-] with max length 64;
// replace invalid characters and truncate.
function normalizeToolCallId(id: string): string {
  const sanitized = id.replace(/[^a-zA-Z0-9_-]/g, "_");
  return sanitized.length > 64 ? sanitized.slice(0, 64) : sanitized;
}
/**
 * Convert our message history into Anthropic MessageParam form.
 *
 * Responsibilities: sanitize text, drop empty blocks/messages, strip image
 * blocks for models without image input, round-trip thinking blocks
 * (including redacted and signature-less ones), apply Claude Code tool
 * naming in OAuth mode, merge consecutive tool results into one user
 * message, and attach cache_control to the final user message.
 */
function convertMessages(
  messages: Message[],
  model: Model<"anthropic-messages">,
  isOAuthToken: boolean,
  cacheControl?: { type: "ephemeral"; ttl?: "1h" },
): MessageParam[] {
  const params: MessageParam[] = [];
  // Transform messages for cross-provider compatibility
  const transformedMessages = transformMessages(
    messages,
    model,
    normalizeToolCallId,
  );
  for (let i = 0; i < transformedMessages.length; i++) {
    const msg = transformedMessages[i];
    if (msg.role === "user") {
      if (typeof msg.content === "string") {
        // Whitespace-only user strings are dropped entirely.
        if (msg.content.trim().length > 0) {
          params.push({
            role: "user",
            content: sanitizeSurrogates(msg.content),
          });
        }
      } else {
        const blocks: ContentBlockParam[] = msg.content.map((item) => {
          if (item.type === "text") {
            return {
              type: "text",
              text: sanitizeSurrogates(item.text),
            };
          } else {
            return {
              type: "image",
              source: {
                type: "base64",
                media_type: item.mimeType as
                  | "image/jpeg"
                  | "image/png"
                  | "image/gif"
                  | "image/webp",
                data: item.data,
              },
            };
          }
        });
        // Strip images for text-only models, then drop empty text blocks.
        let filteredBlocks = !model?.input.includes("image")
          ? blocks.filter((b) => b.type !== "image")
          : blocks;
        filteredBlocks = filteredBlocks.filter((b) => {
          if (b.type === "text") {
            return b.text.trim().length > 0;
          }
          return true;
        });
        if (filteredBlocks.length === 0) continue;
        params.push({
          role: "user",
          content: filteredBlocks,
        });
      }
    } else if (msg.role === "assistant") {
      const blocks: ContentBlockParam[] = [];
      for (const block of msg.content) {
        if (block.type === "text") {
          if (block.text.trim().length === 0) continue;
          blocks.push({
            type: "text",
            text: sanitizeSurrogates(block.text),
          });
        } else if (block.type === "thinking") {
          // Redacted thinking: pass the opaque payload back as redacted_thinking
          if (block.redacted) {
            blocks.push({
              type: "redacted_thinking",
              data: block.thinkingSignature!,
            });
            continue;
          }
          if (block.thinking.trim().length === 0) continue;
          // If thinking signature is missing/empty (e.g., from aborted stream),
          // convert to plain text block without <thinking> tags to avoid API rejection
          // and prevent Claude from mimicking the tags in responses
          if (
            !block.thinkingSignature ||
            block.thinkingSignature.trim().length === 0
          ) {
            blocks.push({
              type: "text",
              text: sanitizeSurrogates(block.thinking),
            });
          } else {
            blocks.push({
              type: "thinking",
              thinking: sanitizeSurrogates(block.thinking),
              signature: block.thinkingSignature,
            });
          }
        } else if (block.type === "toolCall") {
          blocks.push({
            type: "tool_use",
            id: block.id,
            name: isOAuthToken ? toClaudeCodeName(block.name) : block.name,
            input: block.arguments ?? {},
          });
        }
      }
      // Skip assistant messages whose blocks were all filtered out.
      if (blocks.length === 0) continue;
      params.push({
        role: "assistant",
        content: blocks,
      });
    } else if (msg.role === "toolResult") {
      // Collect all consecutive toolResult messages, needed for z.ai Anthropic endpoint
      const toolResults: ContentBlockParam[] = [];
      // Add the current tool result
      toolResults.push({
        type: "tool_result",
        tool_use_id: msg.toolCallId,
        content: convertContentBlocks(msg.content),
        is_error: msg.isError,
      });
      // Look ahead for consecutive toolResult messages
      let j = i + 1;
      while (
        j < transformedMessages.length &&
        transformedMessages[j].role === "toolResult"
      ) {
        const nextMsg = transformedMessages[j] as ToolResultMessage; // We know it's a toolResult
        toolResults.push({
          type: "tool_result",
          tool_use_id: nextMsg.toolCallId,
          content: convertContentBlocks(nextMsg.content),
          is_error: nextMsg.isError,
        });
        j++;
      }
      // Skip the messages we've already processed
      i = j - 1;
      // Add a single user message with all tool results
      params.push({
        role: "user",
        content: toolResults,
      });
    }
  }
  // Add cache_control to the last user message to cache conversation history
  if (cacheControl && params.length > 0) {
    const lastMessage = params[params.length - 1];
    if (lastMessage.role === "user") {
      if (Array.isArray(lastMessage.content)) {
        const lastBlock = lastMessage.content[lastMessage.content.length - 1];
        if (
          lastBlock &&
          (lastBlock.type === "text" ||
            lastBlock.type === "image" ||
            lastBlock.type === "tool_result")
        ) {
          (lastBlock as any).cache_control = cacheControl;
        }
      } else if (typeof lastMessage.content === "string") {
        // Promote a string message to block form so it can carry cache_control.
        lastMessage.content = [
          {
            type: "text",
            text: lastMessage.content,
            cache_control: cacheControl,
          },
        ] as any;
      }
    }
  }
  return params;
}
// Map our Tool definitions onto the Anthropic tool schema, applying Claude
// Code canonical names when authenticated via OAuth.
function convertTools(
  tools: Tool[],
  isOAuthToken: boolean,
): Anthropic.Messages.Tool[] {
  if (!tools) return [];
  const converted: Anthropic.Messages.Tool[] = [];
  for (const tool of tools) {
    // tool.parameters already carries JSON Schema (TypeBox output).
    const schema = tool.parameters as any;
    converted.push({
      name: isOAuthToken ? toClaudeCodeName(tool.name) : tool.name,
      description: tool.description,
      input_schema: {
        type: "object" as const,
        properties: schema.properties || {},
        required: schema.required || [],
      },
    });
  }
  return converted;
}
// Map Anthropic stop reasons onto the cross-provider StopReason union.
function mapStopReason(
  reason: Anthropic.Messages.StopReason | string,
): StopReason {
  switch (reason) {
    case "end_turn":
    case "pause_turn": // a simple stop is sufficient -> caller resubmits
    case "stop_sequence": // we never supply stop sequences; kept for completeness
      return "stop";
    case "max_tokens":
      return "length";
    case "tool_use":
      return "toolUse";
    case "refusal":
    case "sensitive": // content flagged by safety filters (not yet in SDK types)
      return "error";
    default:
      // Unknown values surface loudly so newly added API stop reasons get handled.
      throw new Error(`Unhandled stop reason: ${reason}`);
  }
}

View file

@ -0,0 +1,297 @@
import { AzureOpenAI } from "openai";
import type { ResponseCreateParamsStreaming } from "openai/resources/responses/responses.js";
import { getEnvApiKey } from "../env-api-keys.js";
import { supportsXhigh } from "../models.js";
import type {
Api,
AssistantMessage,
Context,
Model,
SimpleStreamOptions,
StreamFunction,
StreamOptions,
} from "../types.js";
import { AssistantMessageEventStream } from "../utils/event-stream.js";
import {
convertResponsesMessages,
convertResponsesTools,
processResponsesStream,
} from "./openai-responses-shared.js";
import { buildBaseOptions, clampReasoning } from "./simple-options.js";
// Default API version for the Azure Responses endpoint.
// NOTE(review): consumed outside this chunk — confirm at the call site.
const DEFAULT_AZURE_API_VERSION = "v1";
// Providers grouped for tool-call handling.
// NOTE(review): usage is outside this chunk — confirm semantics where it is read.
const AZURE_TOOL_CALL_PROVIDERS = new Set([
  "openai",
  "openai-codex",
  "opencode",
  "azure-openai-responses",
]);
/**
 * Parse AZURE_OPENAI_DEPLOYMENT_NAME_MAP ("modelId=deployment,modelId2=dep2")
 * into a modelId -> deploymentName map.
 *
 * Each entry is split on its FIRST "=" only, so deployment names that
 * themselves contain "=" are preserved. (The previous `split("=", 2)`
 * silently truncated them: JS split's limit drops the remainder rather
 * than keeping it, unlike similar APIs in other languages.)
 * Blank entries and entries with a missing id or name are skipped.
 */
function parseDeploymentNameMap(
  value: string | undefined,
): Map<string, string> {
  const map = new Map<string, string>();
  if (!value) return map;
  for (const entry of value.split(",")) {
    const trimmed = entry.trim();
    if (!trimmed) continue;
    const separator = trimmed.indexOf("=");
    if (separator === -1) continue;
    const modelId = trimmed.slice(0, separator).trim();
    const deploymentName = trimmed.slice(separator + 1).trim();
    if (!modelId || !deploymentName) continue;
    map.set(modelId, deploymentName);
  }
  return map;
}
// Pick the Azure deployment name: explicit option > env-var map > model id.
function resolveDeploymentName(
  model: Model<"azure-openai-responses">,
  options?: AzureOpenAIResponsesOptions,
): string {
  const explicit = options?.azureDeploymentName;
  if (explicit) return explicit;
  const envMap = parseDeploymentNameMap(
    process.env.AZURE_OPENAI_DEPLOYMENT_NAME_MAP,
  );
  return envMap.get(model.id) || model.id;
}
// Azure OpenAI Responses-specific options
export interface AzureOpenAIResponsesOptions extends StreamOptions {
  /** Reasoning effort level forwarded to the Responses API. */
  reasoningEffort?: "minimal" | "low" | "medium" | "high" | "xhigh";
  /** Reasoning summary verbosity (null presumably disables summaries — confirm against the API). */
  reasoningSummary?: "auto" | "detailed" | "concise" | null;
  // Azure connection overrides. NOTE(review): consumption of the first three
  // is outside this chunk — confirm at the client-construction site.
  azureApiVersion?: string;
  azureResourceName?: string;
  azureBaseUrl?: string;
  /** Explicit deployment name; overrides the env-var map and model id (see resolveDeploymentName). */
  azureDeploymentName?: string;
}
/**
 * Generate function for Azure OpenAI Responses API
 *
 * Returns an event stream immediately; the request runs in a detached async
 * IIFE that pushes start/done (or error) events onto the stream while
 * mutating `output` in place.
 */
export const streamAzureOpenAIResponses: StreamFunction<
  "azure-openai-responses",
  AzureOpenAIResponsesOptions
> = (
  model: Model<"azure-openai-responses">,
  context: Context,
  options?: AzureOpenAIResponsesOptions,
): AssistantMessageEventStream => {
  const stream = new AssistantMessageEventStream();
  // Start async processing
  (async () => {
    const deploymentName = resolveDeploymentName(model, options);
    // Partial assistant message, filled in as the stream is processed.
    const output: AssistantMessage = {
      role: "assistant",
      content: [],
      api: "azure-openai-responses" as Api,
      provider: model.provider,
      model: model.id,
      usage: {
        input: 0,
        output: 0,
        cacheRead: 0,
        cacheWrite: 0,
        totalTokens: 0,
        cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
      },
      stopReason: "stop",
      timestamp: Date.now(),
    };
    try {
      // Create Azure OpenAI client
      const apiKey = options?.apiKey || getEnvApiKey(model.provider) || "";
      const client = createClient(model, apiKey, options);
      const params = buildParams(model, context, options, deploymentName);
      // Let callers observe the exact request payload (debugging/logging).
      options?.onPayload?.(params);
      const openaiStream = await client.responses.create(
        params,
        options?.signal ? { signal: options.signal } : undefined,
      );
      stream.push({ type: "start", partial: output });
      await processResponsesStream(openaiStream, output, stream, model);
      if (options?.signal?.aborted) {
        throw new Error("Request was aborted");
      }
      if (output.stopReason === "aborted" || output.stopReason === "error") {
        throw new Error("An unknown error occurred");
      }
      stream.push({ type: "done", reason: output.stopReason, message: output });
      stream.end();
    } catch (error) {
      // Strip streaming-internal index bookkeeping before reporting.
      for (const block of output.content)
        delete (block as { index?: number }).index;
      output.stopReason = options?.signal?.aborted ? "aborted" : "error";
      output.errorMessage =
        error instanceof Error ? error.message : JSON.stringify(error);
      stream.push({ type: "error", reason: output.stopReason, error: output });
      stream.end();
    }
  })();
  return stream;
};
/**
 * Simplified streaming entry point: resolves the API key, builds base
 * options, clamps reasoning effort for models without "xhigh" support, and
 * delegates to streamAzureOpenAIResponses.
 */
export const streamSimpleAzureOpenAIResponses: StreamFunction<
  "azure-openai-responses",
  SimpleStreamOptions
> = (
  model: Model<"azure-openai-responses">,
  context: Context,
  options?: SimpleStreamOptions,
): AssistantMessageEventStream => {
  const apiKey = options?.apiKey || getEnvApiKey(model.provider);
  if (!apiKey) {
    throw new Error(`No API key for provider: ${model.provider}`);
  }
  const reasoningEffort = supportsXhigh(model)
    ? options?.reasoning
    : clampReasoning(options?.reasoning);
  const fullOptions: AzureOpenAIResponsesOptions = {
    ...buildBaseOptions(model, options, apiKey),
    reasoningEffort,
  };
  return streamAzureOpenAIResponses(model, context, fullOptions);
};
// Drop any trailing "/" characters so path segments can be appended safely.
function normalizeAzureBaseUrl(baseUrl: string): string {
  let end = baseUrl.length;
  while (end > 0 && baseUrl[end - 1] === "/") end--;
  return baseUrl.slice(0, end);
}
// Canonical Azure OpenAI v1 endpoint for a given resource name.
function buildDefaultBaseUrl(resourceName: string): string {
  return ["https://", resourceName, ".openai.azure.com/openai/v1"].join("");
}
/**
 * Resolve the effective Azure base URL and API version.
 *
 * Base URL precedence: options.azureBaseUrl > AZURE_OPENAI_BASE_URL env >
 * resource-name-derived default > model.baseUrl; throws when none resolves.
 * API version precedence: options.azureApiVersion > AZURE_OPENAI_API_VERSION
 * env > "v1" default.
 */
function resolveAzureConfig(
  model: Model<"azure-openai-responses">,
  options?: AzureOpenAIResponsesOptions,
): { baseUrl: string; apiVersion: string } {
  const apiVersion =
    options?.azureApiVersion ||
    process.env.AZURE_OPENAI_API_VERSION ||
    DEFAULT_AZURE_API_VERSION;
  const resourceName =
    options?.azureResourceName || process.env.AZURE_OPENAI_RESOURCE_NAME;
  // Candidates in priority order; empty/undefined entries are skipped.
  const candidates: Array<string | undefined> = [
    options?.azureBaseUrl?.trim(),
    process.env.AZURE_OPENAI_BASE_URL?.trim(),
    resourceName ? buildDefaultBaseUrl(resourceName) : undefined,
    model.baseUrl,
  ];
  const resolved = candidates.find((url) => !!url);
  if (!resolved) {
    throw new Error(
      "Azure OpenAI base URL is required. Set AZURE_OPENAI_BASE_URL or AZURE_OPENAI_RESOURCE_NAME, or pass azureBaseUrl, azureResourceName, or model.baseUrl.",
    );
  }
  return { baseUrl: normalizeAzureBaseUrl(resolved), apiVersion };
}
/**
 * Build an AzureOpenAI SDK client for the model.
 *
 * Falls back to the AZURE_OPENAI_API_KEY env var when no key was supplied;
 * throws if neither is available. Model-level headers are merged with (and
 * overridden by) per-call option headers.
 */
function createClient(
  model: Model<"azure-openai-responses">,
  apiKey: string,
  options?: AzureOpenAIResponsesOptions,
) {
  if (!apiKey) {
    if (!process.env.AZURE_OPENAI_API_KEY) {
      throw new Error(
        "Azure OpenAI API key is required. Set AZURE_OPENAI_API_KEY environment variable or pass it as an argument.",
      );
    }
    apiKey = process.env.AZURE_OPENAI_API_KEY;
  }
  const headers = { ...model.headers };
  if (options?.headers) {
    Object.assign(headers, options.headers);
  }
  const { baseUrl, apiVersion } = resolveAzureConfig(model, options);
  return new AzureOpenAI({
    apiKey,
    apiVersion,
    // NOTE(review): exposes the API key in browser contexts — presumably
    // intentional for browser-like runtimes; confirm.
    dangerouslyAllowBrowser: true,
    defaultHeaders: headers,
    baseURL: baseUrl,
  });
}
/**
 * Build the streaming request parameters for the Azure Responses API.
 *
 * @param deploymentName Azure deployment to target (sent as `model`, since
 *   Azure routes by deployment name rather than raw model ID).
 */
function buildParams(
  model: Model<"azure-openai-responses">,
  context: Context,
  options: AzureOpenAIResponsesOptions | undefined,
  deploymentName: string,
) {
  const messages = convertResponsesMessages(
    model,
    context,
    AZURE_TOOL_CALL_PROVIDERS,
  );
  const params: ResponseCreateParamsStreaming = {
    model: deploymentName,
    input: messages,
    stream: true,
    // Stable key so repeated session turns can reuse the prompt cache.
    prompt_cache_key: options?.sessionId,
  };
  if (options?.maxTokens) {
    params.max_output_tokens = options?.maxTokens;
  }
  if (options?.temperature !== undefined) {
    params.temperature = options?.temperature;
  }
  if (context.tools) {
    params.tools = convertResponsesTools(context.tools);
  }
  if (model.reasoning) {
    if (options?.reasoningEffort || options?.reasoningSummary) {
      params.reasoning = {
        effort: options?.reasoningEffort || "medium",
        summary: options?.reasoningSummary || "auto",
      };
      // Request encrypted reasoning so it can be replayed across turns.
      params.include = ["reasoning.encrypted_content"];
    } else {
      if (model.name.toLowerCase().startsWith("gpt-5")) {
        // Jesus Christ, see https://community.openai.com/t/need-reasoning-false-option-for-gpt-5/1351588/7
        messages.push({
          role: "developer",
          content: [
            {
              type: "input_text",
              text: "# Juice: 0 !important",
            },
          ],
        });
      }
    }
  }
  return params;
}

View file

@ -0,0 +1,37 @@
import type { Message } from "../types.js";
// Copilot expects X-Initiator to indicate whether the request is user-initiated
// or agent-initiated (e.g. follow-up after assistant/tool messages).
export function inferCopilotInitiator(messages: Message[]): "user" | "agent" {
  if (messages.length === 0) return "user";
  const lastRole = messages[messages.length - 1].role;
  return lastRole === "user" ? "user" : "agent";
}
// Copilot requires Copilot-Vision-Request header when sending images.
// Only user and toolResult messages can carry image content blocks.
export function hasCopilotVisionInput(messages: Message[]): boolean {
  for (const msg of messages) {
    if (msg.role !== "user" && msg.role !== "toolResult") continue;
    if (!Array.isArray(msg.content)) continue;
    if (msg.content.some((c) => c.type === "image")) return true;
  }
  return false;
}
/**
 * Build the per-request Copilot headers: X-Initiator (user vs agent),
 * Openai-Intent, and Copilot-Vision-Request when images are present.
 */
export function buildCopilotDynamicHeaders(params: {
  messages: Message[];
  hasImages: boolean;
}): Record<string, string> {
  const base: Record<string, string> = {
    "X-Initiator": inferCopilotInitiator(params.messages),
    "Openai-Intent": "conversation-edits",
  };
  return params.hasImages
    ? { ...base, "Copilot-Vision-Request": "true" }
    : base;
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,373 @@
/**
* Shared utilities for Google Generative AI and Google Cloud Code Assist providers.
*/
import {
type Content,
FinishReason,
FunctionCallingConfigMode,
type Part,
} from "@google/genai";
import type {
Context,
ImageContent,
Model,
StopReason,
TextContent,
Tool,
} from "../types.js";
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
import { transformMessages } from "./transform-messages.js";
// The three Google-backed APIs that share these conversion helpers.
type GoogleApiType =
  | "google-generative-ai"
  | "google-gemini-cli"
  | "google-vertex";
/**
 * Determines whether a streamed Gemini `Part` should be treated as "thinking".
 *
 * Protocol note (Gemini / Vertex AI thought signatures):
 * - `thought: true` is the definitive marker for thinking content.
 * - `thoughtSignature` is an encrypted representation of the model's internal
 *   reasoning used to preserve context across multi-turn interactions. It can
 *   appear on ANY part type (text, functionCall, ...) and does NOT make the
 *   part thinking content, so it is deliberately ignored here.
 * - Signature-bearing parts must be replayed as-is; never merge/move
 *   signatures across parts.
 *
 * See: https://ai.google.dev/gemini-api/docs/thought-signatures
 */
export function isThinkingPart(
  part: Pick<Part, "thought" | "thoughtSignature">,
): boolean {
  const { thought } = part;
  return thought === true;
}
/**
 * Retain thought signatures during streaming.
 *
 * Some backends only send `thoughtSignature` on the first delta for a block;
 * later deltas may omit it. Keep the last non-empty signature for the current
 * block instead of letting `undefined` clobber it. This never merges or moves
 * signatures across distinct response parts.
 */
export function retainThoughtSignature(
  existing: string | undefined,
  incoming: string | undefined,
): string | undefined {
  const hasIncoming = typeof incoming === "string" && incoming.length > 0;
  return hasIncoming ? incoming : existing;
}
// Thought signatures must be base64 for Google APIs (TYPE_BYTES).
const base64SignaturePattern = /^[A-Za-z0-9+/]+={0,2}$/;
// Sentinel value that tells the Gemini API to skip thought signature validation.
// Used for unsigned function call parts (e.g. replayed from providers without
// thought signatures). See: https://ai.google.dev/gemini-api/docs/thought-signatures
const SKIP_THOUGHT_SIGNATURE = "skip_thought_signature_validator";
// A signature is usable only if it is plausibly valid base64: non-empty,
// length a multiple of 4, base64 alphabet with trailing '=' padding only.
function isValidThoughtSignature(signature: string | undefined): boolean {
  if (!signature || signature.length % 4 !== 0) return false;
  return base64SignaturePattern.test(signature);
}
/**
 * Only keep signatures from the same provider/model and with valid base64;
 * otherwise drop them (they would fail server-side validation).
 */
function resolveThoughtSignature(
  isSameProviderAndModel: boolean,
  signature: string | undefined,
): string | undefined {
  if (!isSameProviderAndModel) return undefined;
  if (!isValidThoughtSignature(signature)) return undefined;
  return signature;
}
/**
 * Models via Google APIs that require explicit tool call IDs in function
 * calls/responses (Anthropic and gpt-oss models routed through Google).
 */
export function requiresToolCallId(modelId: string): boolean {
  return ["claude-", "gpt-oss-"].some((prefix) => modelId.startsWith(prefix));
}
/**
 * Convert internal messages to Gemini Content[] format.
 *
 * Handles user text/image content, assistant text/thinking/toolCall blocks
 * (with thought-signature filtering), and tool results. Tool results are
 * merged into a single consecutive user turn, as required by the Cloud Code
 * Assist API.
 */
export function convertMessages<T extends GoogleApiType>(
  model: Model<T>,
  context: Context,
): Content[] {
  const contents: Content[] = [];
  // Models that need explicit tool call IDs also need them sanitized and
  // length-limited (alphanumerics/underscore/hyphen, max 64 chars).
  const normalizeToolCallId = (id: string): string => {
    if (!requiresToolCallId(model.id)) return id;
    return id.replace(/[^a-zA-Z0-9_-]/g, "_").slice(0, 64);
  };
  const transformedMessages = transformMessages(
    context.messages,
    model,
    normalizeToolCallId,
  );
  for (const msg of transformedMessages) {
    if (msg.role === "user") {
      if (typeof msg.content === "string") {
        contents.push({
          role: "user",
          parts: [{ text: sanitizeSurrogates(msg.content) }],
        });
      } else {
        const parts: Part[] = msg.content.map((item) => {
          if (item.type === "text") {
            return { text: sanitizeSurrogates(item.text) };
          } else {
            return {
              inlineData: {
                mimeType: item.mimeType,
                data: item.data,
              },
            };
          }
        });
        // Drop image parts for text-only models.
        const filteredParts = !model.input.includes("image")
          ? parts.filter((p) => p.text !== undefined)
          : parts;
        if (filteredParts.length === 0) continue;
        contents.push({
          role: "user",
          parts: filteredParts,
        });
      }
    } else if (msg.role === "assistant") {
      const parts: Part[] = [];
      // Check if message is from same provider and model - only then keep thinking blocks
      const isSameProviderAndModel =
        msg.provider === model.provider && msg.model === model.id;
      for (const block of msg.content) {
        if (block.type === "text") {
          // Skip empty text blocks - they can cause issues with some models (e.g. Claude via Antigravity)
          if (!block.text || block.text.trim() === "") continue;
          const thoughtSignature = resolveThoughtSignature(
            isSameProviderAndModel,
            block.textSignature,
          );
          parts.push({
            text: sanitizeSurrogates(block.text),
            ...(thoughtSignature && { thoughtSignature }),
          });
        } else if (block.type === "thinking") {
          // Skip empty thinking blocks
          if (!block.thinking || block.thinking.trim() === "") continue;
          // Only keep as thinking block if same provider AND same model
          // Otherwise convert to plain text (no tags to avoid model mimicking them)
          if (isSameProviderAndModel) {
            const thoughtSignature = resolveThoughtSignature(
              isSameProviderAndModel,
              block.thinkingSignature,
            );
            parts.push({
              thought: true,
              text: sanitizeSurrogates(block.thinking),
              ...(thoughtSignature && { thoughtSignature }),
            });
          } else {
            parts.push({
              text: sanitizeSurrogates(block.thinking),
            });
          }
        } else if (block.type === "toolCall") {
          const thoughtSignature = resolveThoughtSignature(
            isSameProviderAndModel,
            block.thoughtSignature,
          );
          // Gemini 3 requires thoughtSignature on all function calls when thinking mode is enabled.
          // Use the skip_thought_signature_validator sentinel for unsigned function calls
          // (e.g. replayed from providers without thought signatures like Claude via Antigravity).
          const isGemini3 = model.id.toLowerCase().includes("gemini-3");
          const effectiveSignature =
            thoughtSignature ||
            (isGemini3 ? SKIP_THOUGHT_SIGNATURE : undefined);
          const part: Part = {
            functionCall: {
              name: block.name,
              args: block.arguments ?? {},
              ...(requiresToolCallId(model.id) ? { id: block.id } : {}),
            },
            ...(effectiveSignature && { thoughtSignature: effectiveSignature }),
          };
          parts.push(part);
        }
      }
      if (parts.length === 0) continue;
      contents.push({
        role: "model",
        parts,
      });
    } else if (msg.role === "toolResult") {
      // Extract text and image content
      const textContent = msg.content.filter(
        (c): c is TextContent => c.type === "text",
      );
      const textResult = textContent.map((c) => c.text).join("\n");
      const imageContent = model.input.includes("image")
        ? msg.content.filter((c): c is ImageContent => c.type === "image")
        : [];
      const hasText = textResult.length > 0;
      const hasImages = imageContent.length > 0;
      // Gemini 3 supports multimodal function responses with images nested inside functionResponse.parts
      // See: https://ai.google.dev/gemini-api/docs/function-calling#multimodal
      // Older models don't support this, so we put images in a separate user message.
      const supportsMultimodalFunctionResponse = model.id.includes("gemini-3");
      // Use "output" key for success, "error" key for errors as per SDK documentation
      const responseValue = hasText
        ? sanitizeSurrogates(textResult)
        : hasImages
          ? "(see attached image)"
          : "";
      const imageParts: Part[] = imageContent.map((imageBlock) => ({
        inlineData: {
          mimeType: imageBlock.mimeType,
          data: imageBlock.data,
        },
      }));
      const includeId = requiresToolCallId(model.id);
      const functionResponsePart: Part = {
        functionResponse: {
          name: msg.toolName,
          response: msg.isError
            ? { error: responseValue }
            : { output: responseValue },
          // Nest images inside functionResponse.parts for Gemini 3
          ...(hasImages &&
            supportsMultimodalFunctionResponse && { parts: imageParts }),
          ...(includeId ? { id: msg.toolCallId } : {}),
        },
      };
      // Cloud Code Assist API requires all function responses to be in a single user turn.
      // Check if the last content is already a user turn with function responses and merge.
      const lastContent = contents[contents.length - 1];
      if (
        lastContent?.role === "user" &&
        lastContent.parts?.some((p) => p.functionResponse)
      ) {
        lastContent.parts.push(functionResponsePart);
      } else {
        contents.push({
          role: "user",
          parts: [functionResponsePart],
        });
      }
      // For older models, add images in a separate user message
      if (hasImages && !supportsMultimodalFunctionResponse) {
        contents.push({
          role: "user",
          parts: [{ text: "Tool result image:" }, ...imageParts],
        });
      }
    }
  }
  return contents;
}
/**
 * Convert tools to Gemini function declarations format.
 *
 * By default uses `parametersJsonSchema`, which supports full JSON Schema
 * (anyOf, oneOf, const, ...). Set `useParameters` to true to use the legacy
 * `parameters` field (OpenAPI 3.03 Schema) instead — needed for Cloud Code
 * Assist with Claude models, where the API translates `parameters` into
 * Anthropic's `input_schema`.
 *
 * @returns undefined when there are no tools to declare.
 */
export function convertTools(
  tools: Tool[],
  useParameters = false,
): { functionDeclarations: Record<string, unknown>[] }[] | undefined {
  if (tools.length === 0) return undefined;
  const functionDeclarations = tools.map((tool) => {
    const declaration: Record<string, unknown> = {
      name: tool.name,
      description: tool.description,
    };
    if (useParameters) {
      declaration.parameters = tool.parameters;
    } else {
      declaration.parametersJsonSchema = tool.parameters;
    }
    return declaration;
  });
  return [{ functionDeclarations }];
}
/**
 * Map tool choice string to Gemini FunctionCallingConfigMode.
 * Unknown values fall back to AUTO.
 */
export function mapToolChoice(choice: string): FunctionCallingConfigMode {
  if (choice === "none") return FunctionCallingConfigMode.NONE;
  if (choice === "any") return FunctionCallingConfigMode.ANY;
  return FunctionCallingConfigMode.AUTO;
}
/**
 * Map Gemini FinishReason to our StopReason.
 *
 * STOP -> "stop", MAX_TOKENS -> "length"; every safety/recitation/malformed
 * variant collapses to "error". The `never` default keeps the switch
 * exhaustive at compile time if the SDK adds new reasons.
 */
export function mapStopReason(reason: FinishReason): StopReason {
  switch (reason) {
    case FinishReason.STOP:
      return "stop";
    case FinishReason.MAX_TOKENS:
      return "length";
    case FinishReason.BLOCKLIST:
    case FinishReason.PROHIBITED_CONTENT:
    case FinishReason.SPII:
    case FinishReason.SAFETY:
    case FinishReason.IMAGE_SAFETY:
    case FinishReason.IMAGE_PROHIBITED_CONTENT:
    case FinishReason.IMAGE_RECITATION:
    case FinishReason.IMAGE_OTHER:
    case FinishReason.RECITATION:
    case FinishReason.FINISH_REASON_UNSPECIFIED:
    case FinishReason.OTHER:
    case FinishReason.LANGUAGE:
    case FinishReason.MALFORMED_FUNCTION_CALL:
    case FinishReason.UNEXPECTED_TOOL_CALL:
    case FinishReason.NO_IMAGE:
      return "error";
    default: {
      // Exhaustiveness check: fails to compile if a FinishReason is unhandled.
      const _exhaustive: never = reason;
      throw new Error(`Unhandled stop reason: ${_exhaustive}`);
    }
  }
}
/**
 * Map string finish reason to our StopReason (for raw API responses).
 * Anything other than STOP / MAX_TOKENS is treated as an error.
 */
export function mapStopReasonString(reason: string): StopReason {
  const known: Record<string, StopReason> = {
    STOP: "stop",
    MAX_TOKENS: "length",
  };
  return known[reason] ?? "error";
}

View file

@ -0,0 +1,529 @@
import {
type GenerateContentConfig,
type GenerateContentParameters,
GoogleGenAI,
type ThinkingConfig,
ThinkingLevel,
} from "@google/genai";
import { calculateCost } from "../models.js";
import type {
Api,
AssistantMessage,
Context,
Model,
ThinkingLevel as PiThinkingLevel,
SimpleStreamOptions,
StreamFunction,
StreamOptions,
TextContent,
ThinkingBudgets,
ThinkingContent,
ToolCall,
} from "../types.js";
import { AssistantMessageEventStream } from "../utils/event-stream.js";
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
import type { GoogleThinkingLevel } from "./google-gemini-cli.js";
import {
convertMessages,
convertTools,
isThinkingPart,
mapStopReason,
mapToolChoice,
retainThoughtSignature,
} from "./google-shared.js";
import { buildBaseOptions, clampReasoning } from "./simple-options.js";
// Vertex-specific stream options.
export interface GoogleVertexOptions extends StreamOptions {
  // How the model may use tools; mapped via mapToolChoice.
  toolChoice?: "auto" | "none" | "any";
  thinking?: {
    enabled: boolean;
    budgetTokens?: number; // -1 for dynamic, 0 to disable
    level?: GoogleThinkingLevel; // Gemini 3 thinking level (used instead of budget)
  };
  // GCP project/location; fall back to env vars (see resolveProject/resolveLocation).
  project?: string;
  location?: string;
}
// Vertex AI API version requested from the GoogleGenAI client.
const API_VERSION = "v1";
// Map our string thinking levels onto the SDK's ThinkingLevel enum.
const THINKING_LEVEL_MAP: Record<GoogleThinkingLevel, ThinkingLevel> = {
  THINKING_LEVEL_UNSPECIFIED: ThinkingLevel.THINKING_LEVEL_UNSPECIFIED,
  MINIMAL: ThinkingLevel.MINIMAL,
  LOW: ThinkingLevel.LOW,
  MEDIUM: ThinkingLevel.MEDIUM,
  HIGH: ThinkingLevel.HIGH,
};
// Counter for generating unique tool call IDs
let toolCallCounter = 0;
/**
 * Stream a completion from Google Vertex AI.
 *
 * Returns an event stream immediately; the request runs in a detached async
 * IIFE that walks the SDK's chunk stream, maintains the current text/thinking
 * block, emits start/delta/end events, and finishes with done or error.
 */
export const streamGoogleVertex: StreamFunction<
  "google-vertex",
  GoogleVertexOptions
> = (
  model: Model<"google-vertex">,
  context: Context,
  options?: GoogleVertexOptions,
): AssistantMessageEventStream => {
  const stream = new AssistantMessageEventStream();
  (async () => {
    // Partial assistant message, mutated in place as chunks arrive.
    const output: AssistantMessage = {
      role: "assistant",
      content: [],
      api: "google-vertex" as Api,
      provider: model.provider,
      model: model.id,
      usage: {
        input: 0,
        output: 0,
        cacheRead: 0,
        cacheWrite: 0,
        totalTokens: 0,
        cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
      },
      stopReason: "stop",
      timestamp: Date.now(),
    };
    try {
      const project = resolveProject(options);
      const location = resolveLocation(options);
      const client = createClient(model, project, location, options?.headers);
      const params = buildParams(model, context, options);
      // Let callers observe the exact request payload (debugging/logging).
      options?.onPayload?.(params);
      const googleStream = await client.models.generateContentStream(params);
      stream.push({ type: "start", partial: output });
      // The block currently receiving deltas (text or thinking), if any.
      let currentBlock: TextContent | ThinkingContent | null = null;
      const blocks = output.content;
      const blockIndex = () => blocks.length - 1;
      for await (const chunk of googleStream) {
        const candidate = chunk.candidates?.[0];
        if (candidate?.content?.parts) {
          for (const part of candidate.content.parts) {
            if (part.text !== undefined) {
              const isThinking = isThinkingPart(part);
              // Close the current block and open a new one when the part
              // kind (text vs thinking) changes or no block is open.
              if (
                !currentBlock ||
                (isThinking && currentBlock.type !== "thinking") ||
                (!isThinking && currentBlock.type !== "text")
              ) {
                if (currentBlock) {
                  if (currentBlock.type === "text") {
                    stream.push({
                      type: "text_end",
                      contentIndex: blocks.length - 1,
                      content: currentBlock.text,
                      partial: output,
                    });
                  } else {
                    stream.push({
                      type: "thinking_end",
                      contentIndex: blockIndex(),
                      content: currentBlock.thinking,
                      partial: output,
                    });
                  }
                }
                if (isThinking) {
                  currentBlock = {
                    type: "thinking",
                    thinking: "",
                    thinkingSignature: undefined,
                  };
                  output.content.push(currentBlock);
                  stream.push({
                    type: "thinking_start",
                    contentIndex: blockIndex(),
                    partial: output,
                  });
                } else {
                  currentBlock = { type: "text", text: "" };
                  output.content.push(currentBlock);
                  stream.push({
                    type: "text_start",
                    contentIndex: blockIndex(),
                    partial: output,
                  });
                }
              }
              if (currentBlock.type === "thinking") {
                currentBlock.thinking += part.text;
                // Keep the last non-empty signature for this block.
                currentBlock.thinkingSignature = retainThoughtSignature(
                  currentBlock.thinkingSignature,
                  part.thoughtSignature,
                );
                stream.push({
                  type: "thinking_delta",
                  contentIndex: blockIndex(),
                  delta: part.text,
                  partial: output,
                });
              } else {
                currentBlock.text += part.text;
                currentBlock.textSignature = retainThoughtSignature(
                  currentBlock.textSignature,
                  part.thoughtSignature,
                );
                stream.push({
                  type: "text_delta",
                  contentIndex: blockIndex(),
                  delta: part.text,
                  partial: output,
                });
              }
            }
            if (part.functionCall) {
              // Function calls close whatever block was streaming.
              if (currentBlock) {
                if (currentBlock.type === "text") {
                  stream.push({
                    type: "text_end",
                    contentIndex: blockIndex(),
                    content: currentBlock.text,
                    partial: output,
                  });
                } else {
                  stream.push({
                    type: "thinking_end",
                    contentIndex: blockIndex(),
                    content: currentBlock.thinking,
                    partial: output,
                  });
                }
                currentBlock = null;
              }
              // Generate a unique ID when the API omits one or repeats an
              // ID already used in this message.
              const providedId = part.functionCall.id;
              const needsNewId =
                !providedId ||
                output.content.some(
                  (b) => b.type === "toolCall" && b.id === providedId,
                );
              const toolCallId = needsNewId
                ? `${part.functionCall.name}_${Date.now()}_${++toolCallCounter}`
                : providedId;
              const toolCall: ToolCall = {
                type: "toolCall",
                id: toolCallId,
                name: part.functionCall.name || "",
                arguments:
                  (part.functionCall.args as Record<string, any>) ?? {},
                ...(part.thoughtSignature && {
                  thoughtSignature: part.thoughtSignature,
                }),
              };
              output.content.push(toolCall);
              // Arguments arrive whole, so start/delta/end are emitted back to back.
              stream.push({
                type: "toolcall_start",
                contentIndex: blockIndex(),
                partial: output,
              });
              stream.push({
                type: "toolcall_delta",
                contentIndex: blockIndex(),
                delta: JSON.stringify(toolCall.arguments),
                partial: output,
              });
              stream.push({
                type: "toolcall_end",
                contentIndex: blockIndex(),
                toolCall,
                partial: output,
              });
            }
          }
        }
        if (candidate?.finishReason) {
          output.stopReason = mapStopReason(candidate.finishReason);
          // Any tool call in the output overrides the stop reason.
          if (output.content.some((b) => b.type === "toolCall")) {
            output.stopReason = "toolUse";
          }
        }
        if (chunk.usageMetadata) {
          output.usage = {
            input: chunk.usageMetadata.promptTokenCount || 0,
            output:
              (chunk.usageMetadata.candidatesTokenCount || 0) +
              (chunk.usageMetadata.thoughtsTokenCount || 0),
            cacheRead: chunk.usageMetadata.cachedContentTokenCount || 0,
            cacheWrite: 0,
            totalTokens: chunk.usageMetadata.totalTokenCount || 0,
            cost: {
              input: 0,
              output: 0,
              cacheRead: 0,
              cacheWrite: 0,
              total: 0,
            },
          };
          calculateCost(model, output.usage);
        }
      }
      // Close any block left open when the stream ended.
      if (currentBlock) {
        if (currentBlock.type === "text") {
          stream.push({
            type: "text_end",
            contentIndex: blockIndex(),
            content: currentBlock.text,
            partial: output,
          });
        } else {
          stream.push({
            type: "thinking_end",
            contentIndex: blockIndex(),
            content: currentBlock.thinking,
            partial: output,
          });
        }
      }
      if (options?.signal?.aborted) {
        throw new Error("Request was aborted");
      }
      if (output.stopReason === "aborted" || output.stopReason === "error") {
        throw new Error("An unknown error occurred");
      }
      stream.push({ type: "done", reason: output.stopReason, message: output });
      stream.end();
    } catch (error) {
      // Remove internal index property used during streaming
      for (const block of output.content) {
        if ("index" in block) {
          delete (block as { index?: number }).index;
        }
      }
      output.stopReason = options?.signal?.aborted ? "aborted" : "error";
      output.errorMessage =
        error instanceof Error ? error.message : JSON.stringify(error);
      stream.push({ type: "error", reason: output.stopReason, error: output });
      stream.end();
    }
  })();
  return stream;
};
/**
 * Simplified Vertex streaming: translates the generic `reasoning` level into
 * Vertex thinking config — Gemini 3 models get a thinking *level*, everything
 * else gets a token *budget* — then delegates to streamGoogleVertex.
 */
export const streamSimpleGoogleVertex: StreamFunction<
  "google-vertex",
  SimpleStreamOptions
> = (
  model: Model<"google-vertex">,
  context: Context,
  options?: SimpleStreamOptions,
): AssistantMessageEventStream => {
  const base = buildBaseOptions(model, options, undefined);
  // No reasoning requested: explicitly disable thinking.
  if (!options?.reasoning) {
    return streamGoogleVertex(model, context, {
      ...base,
      thinking: { enabled: false },
    } satisfies GoogleVertexOptions);
  }
  const effort = clampReasoning(options.reasoning)!;
  const geminiModel = model as unknown as Model<"google-generative-ai">;
  if (isGemini3ProModel(geminiModel) || isGemini3FlashModel(geminiModel)) {
    return streamGoogleVertex(model, context, {
      ...base,
      thinking: {
        enabled: true,
        level: getGemini3ThinkingLevel(effort, geminiModel),
      },
    } satisfies GoogleVertexOptions);
  }
  return streamGoogleVertex(model, context, {
    ...base,
    thinking: {
      enabled: true,
      budgetTokens: getGoogleBudget(
        geminiModel,
        effort,
        options.thinkingBudgets,
      ),
    },
  } satisfies GoogleVertexOptions);
};
/**
 * Construct a GoogleGenAI client bound to Vertex AI for the given
 * project/location, merging model-level and per-call headers (per-call wins).
 */
function createClient(
  model: Model<"google-vertex">,
  project: string,
  location: string,
  optionsHeaders?: Record<string, string>,
): GoogleGenAI {
  let headers: Record<string, string> | undefined;
  if (model.headers || optionsHeaders) {
    headers = { ...model.headers, ...optionsHeaders };
  }
  return new GoogleGenAI({
    vertexai: true,
    project,
    location,
    apiVersion: API_VERSION,
    httpOptions: headers ? { headers } : undefined,
  });
}
/** Resolve the GCP project ID from options or env; throws when absent. */
function resolveProject(options?: GoogleVertexOptions): string {
  const { GOOGLE_CLOUD_PROJECT, GCLOUD_PROJECT } = process.env;
  const project = options?.project || GOOGLE_CLOUD_PROJECT || GCLOUD_PROJECT;
  if (project) return project;
  throw new Error(
    "Vertex AI requires a project ID. Set GOOGLE_CLOUD_PROJECT/GCLOUD_PROJECT or pass project in options.",
  );
}
/** Resolve the Vertex AI location from options or env; throws when absent. */
function resolveLocation(options?: GoogleVertexOptions): string {
  const location = options?.location || process.env.GOOGLE_CLOUD_LOCATION;
  if (location) return location;
  throw new Error(
    "Vertex AI requires a location. Set GOOGLE_CLOUD_LOCATION or pass location in options.",
  );
}
/**
 * Build GenerateContentParameters for a Vertex AI streaming request:
 * converted messages, generation config (temperature/max tokens), system
 * instruction, tools + tool choice, thinking config, and abort signal.
 */
function buildParams(
  model: Model<"google-vertex">,
  context: Context,
  options: GoogleVertexOptions = {},
): GenerateContentParameters {
  const contents = convertMessages(model, context);
  const generationConfig: GenerateContentConfig = {};
  if (options.temperature !== undefined) {
    generationConfig.temperature = options.temperature;
  }
  if (options.maxTokens !== undefined) {
    generationConfig.maxOutputTokens = options.maxTokens;
  }
  const config: GenerateContentConfig = {
    ...(Object.keys(generationConfig).length > 0 && generationConfig),
    ...(context.systemPrompt && {
      systemInstruction: sanitizeSurrogates(context.systemPrompt),
    }),
    ...(context.tools &&
      context.tools.length > 0 && { tools: convertTools(context.tools) }),
  };
  // Tool choice only applies when tools are present.
  if (context.tools && context.tools.length > 0 && options.toolChoice) {
    config.toolConfig = {
      functionCallingConfig: {
        mode: mapToolChoice(options.toolChoice),
      },
    };
  } else {
    config.toolConfig = undefined;
  }
  // Thinking is only configured for reasoning-capable models; an explicit
  // level takes precedence over a token budget.
  if (options.thinking?.enabled && model.reasoning) {
    const thinkingConfig: ThinkingConfig = { includeThoughts: true };
    if (options.thinking.level !== undefined) {
      thinkingConfig.thinkingLevel = THINKING_LEVEL_MAP[options.thinking.level];
    } else if (options.thinking.budgetTokens !== undefined) {
      thinkingConfig.thinkingBudget = options.thinking.budgetTokens;
    }
    config.thinkingConfig = thinkingConfig;
  }
  if (options.signal) {
    // Fail fast if the caller aborted before the request was built.
    if (options.signal.aborted) {
      throw new Error("Request aborted");
    }
    config.abortSignal = options.signal;
  }
  const params: GenerateContentParameters = {
    model: model.id,
    contents,
    config,
  };
  return params;
}
// ThinkingLevel without "xhigh" (callers clamp before reaching these helpers).
type ClampedThinkingLevel = Exclude<PiThinkingLevel, "xhigh">;
// Matches gemini-3-pro, gemini-3.1-pro, etc. (case-insensitive).
function isGemini3ProModel(model: Model<"google-generative-ai">): boolean {
  const id = model.id.toLowerCase();
  return /gemini-3(?:\.\d+)?-pro/.test(id);
}
// Matches gemini-3-flash, gemini-3.1-flash, etc. (case-insensitive).
function isGemini3FlashModel(model: Model<"google-generative-ai">): boolean {
  const id = model.id.toLowerCase();
  return /gemini-3(?:\.\d+)?-flash/.test(id);
}
/**
 * Map a clamped effort level to a Gemini 3 thinking level. Pro models only
 * support LOW/HIGH; other models get the full four-level range.
 */
function getGemini3ThinkingLevel(
  effort: ClampedThinkingLevel,
  model: Model<"google-generative-ai">,
): GoogleThinkingLevel {
  if (isGemini3ProModel(model)) {
    return effort === "minimal" || effort === "low" ? "LOW" : "HIGH";
  }
  const levels: Record<ClampedThinkingLevel, GoogleThinkingLevel> = {
    minimal: "MINIMAL",
    low: "LOW",
    medium: "MEDIUM",
    high: "HIGH",
  };
  return levels[effort];
}
/**
 * Thinking token budget for a clamped effort level. Custom budgets win;
 * Gemini 2.5 Pro/Flash get tuned defaults (differing only at "high");
 * everything else gets -1 (dynamic budget).
 */
function getGoogleBudget(
  model: Model<"google-generative-ai">,
  effort: ClampedThinkingLevel,
  customBudgets?: ThinkingBudgets,
): number {
  const custom = customBudgets?.[effort];
  if (custom !== undefined) return custom;
  const highCap = model.id.includes("2.5-pro")
    ? 32768
    : model.id.includes("2.5-flash")
      ? 24576
      : undefined;
  if (highCap === undefined) return -1;
  const budgets: Record<ClampedThinkingLevel, number> = {
    minimal: 128,
    low: 2048,
    medium: 8192,
    high: highCap,
  };
  return budgets[effort];
}

View file

@ -0,0 +1,501 @@
import {
type GenerateContentConfig,
type GenerateContentParameters,
GoogleGenAI,
type ThinkingConfig,
} from "@google/genai";
import { getEnvApiKey } from "../env-api-keys.js";
import { calculateCost } from "../models.js";
import type {
Api,
AssistantMessage,
Context,
Model,
SimpleStreamOptions,
StreamFunction,
StreamOptions,
TextContent,
ThinkingBudgets,
ThinkingContent,
ThinkingLevel,
ToolCall,
} from "../types.js";
import { AssistantMessageEventStream } from "../utils/event-stream.js";
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
import type { GoogleThinkingLevel } from "./google-gemini-cli.js";
import {
convertMessages,
convertTools,
isThinkingPart,
mapStopReason,
mapToolChoice,
retainThoughtSignature,
} from "./google-shared.js";
import { buildBaseOptions, clampReasoning } from "./simple-options.js";
// Google Generative AI-specific stream options.
export interface GoogleOptions extends StreamOptions {
  // How the model may use tools; mapped via mapToolChoice.
  toolChoice?: "auto" | "none" | "any";
  thinking?: {
    enabled: boolean;
    budgetTokens?: number; // -1 for dynamic, 0 to disable
    level?: GoogleThinkingLevel; // Gemini 3 thinking level
  };
}
// Counter for generating unique tool call IDs
let toolCallCounter = 0;
/**
 * Stream a response from the Google Generative AI (Gemini) API.
 *
 * Translates the SDK's `generateContentStream` chunks into the shared
 * `AssistantMessageEventStream` protocol: `start`, then interleaved
 * text/thinking/toolcall events, ending with `done` (or `error`).
 * The full assistant message is accumulated in `output` as events are pushed.
 */
export const streamGoogle: StreamFunction<
  "google-generative-ai",
  GoogleOptions
> = (
  model: Model<"google-generative-ai">,
  context: Context,
  options?: GoogleOptions,
): AssistantMessageEventStream => {
  const stream = new AssistantMessageEventStream();
  // Fire-and-forget worker; all results flow through `stream`.
  (async () => {
    // Accumulator for the final assistant message (usage filled in later).
    const output: AssistantMessage = {
      role: "assistant",
      content: [],
      api: "google-generative-ai" as Api,
      provider: model.provider,
      model: model.id,
      usage: {
        input: 0,
        output: 0,
        cacheRead: 0,
        cacheWrite: 0,
        totalTokens: 0,
        cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
      },
      stopReason: "stop",
      timestamp: Date.now(),
    };
    try {
      const apiKey = options?.apiKey || getEnvApiKey(model.provider) || "";
      const client = createClient(model, apiKey, options?.headers);
      const params = buildParams(model, context, options);
      // Let callers observe the exact request payload (e.g. for logging).
      options?.onPayload?.(params);
      const googleStream = await client.models.generateContentStream(params);
      stream.push({ type: "start", partial: output });
      // The currently open text/thinking block, if any; tool calls arrive whole.
      let currentBlock: TextContent | ThinkingContent | null = null;
      const blocks = output.content;
      // Index of the most recently pushed content block.
      const blockIndex = () => blocks.length - 1;
      for await (const chunk of googleStream) {
        const candidate = chunk.candidates?.[0];
        if (candidate?.content?.parts) {
          for (const part of candidate.content.parts) {
            if (part.text !== undefined) {
              const isThinking = isThinkingPart(part);
              // Open a new block when there is none, or when the part kind
              // (text vs thinking) differs from the open block's kind.
              if (
                !currentBlock ||
                (isThinking && currentBlock.type !== "thinking") ||
                (!isThinking && currentBlock.type !== "text")
              ) {
                // Close the previously open block first.
                if (currentBlock) {
                  if (currentBlock.type === "text") {
                    stream.push({
                      type: "text_end",
                      contentIndex: blocks.length - 1,
                      content: currentBlock.text,
                      partial: output,
                    });
                  } else {
                    stream.push({
                      type: "thinking_end",
                      contentIndex: blockIndex(),
                      content: currentBlock.thinking,
                      partial: output,
                    });
                  }
                }
                if (isThinking) {
                  currentBlock = {
                    type: "thinking",
                    thinking: "",
                    thinkingSignature: undefined,
                  };
                  output.content.push(currentBlock);
                  stream.push({
                    type: "thinking_start",
                    contentIndex: blockIndex(),
                    partial: output,
                  });
                } else {
                  currentBlock = { type: "text", text: "" };
                  output.content.push(currentBlock);
                  stream.push({
                    type: "text_start",
                    contentIndex: blockIndex(),
                    partial: output,
                  });
                }
              }
              // Append the delta and propagate any thought signature the API
              // attached to this part (needed to round-trip thinking content).
              if (currentBlock.type === "thinking") {
                currentBlock.thinking += part.text;
                currentBlock.thinkingSignature = retainThoughtSignature(
                  currentBlock.thinkingSignature,
                  part.thoughtSignature,
                );
                stream.push({
                  type: "thinking_delta",
                  contentIndex: blockIndex(),
                  delta: part.text,
                  partial: output,
                });
              } else {
                currentBlock.text += part.text;
                currentBlock.textSignature = retainThoughtSignature(
                  currentBlock.textSignature,
                  part.thoughtSignature,
                );
                stream.push({
                  type: "text_delta",
                  contentIndex: blockIndex(),
                  delta: part.text,
                  partial: output,
                });
              }
            }
            if (part.functionCall) {
              // A function call ends any open text/thinking block.
              if (currentBlock) {
                if (currentBlock.type === "text") {
                  stream.push({
                    type: "text_end",
                    contentIndex: blockIndex(),
                    content: currentBlock.text,
                    partial: output,
                  });
                } else {
                  stream.push({
                    type: "thinking_end",
                    contentIndex: blockIndex(),
                    content: currentBlock.thinking,
                    partial: output,
                  });
                }
                currentBlock = null;
              }
              // Generate unique ID if not provided or if it's a duplicate
              const providedId = part.functionCall.id;
              const needsNewId =
                !providedId ||
                output.content.some(
                  (b) => b.type === "toolCall" && b.id === providedId,
                );
              const toolCallId = needsNewId
                ? `${part.functionCall.name}_${Date.now()}_${++toolCallCounter}`
                : providedId;
              const toolCall: ToolCall = {
                type: "toolCall",
                id: toolCallId,
                name: part.functionCall.name || "",
                arguments:
                  (part.functionCall.args as Record<string, any>) ?? {},
                ...(part.thoughtSignature && {
                  thoughtSignature: part.thoughtSignature,
                }),
              };
              output.content.push(toolCall);
              // Tool calls arrive complete, so emit start/delta/end together.
              stream.push({
                type: "toolcall_start",
                contentIndex: blockIndex(),
                partial: output,
              });
              stream.push({
                type: "toolcall_delta",
                contentIndex: blockIndex(),
                delta: JSON.stringify(toolCall.arguments),
                partial: output,
              });
              stream.push({
                type: "toolcall_end",
                contentIndex: blockIndex(),
                toolCall,
                partial: output,
              });
            }
          }
        }
        if (candidate?.finishReason) {
          output.stopReason = mapStopReason(candidate.finishReason);
          // Any emitted tool call overrides the API's finish reason.
          if (output.content.some((b) => b.type === "toolCall")) {
            output.stopReason = "toolUse";
          }
        }
        // Usage metadata may arrive on any chunk; keep the latest snapshot.
        if (chunk.usageMetadata) {
          output.usage = {
            input: chunk.usageMetadata.promptTokenCount || 0,
            // Thinking tokens are billed as output, so fold them in.
            output:
              (chunk.usageMetadata.candidatesTokenCount || 0) +
              (chunk.usageMetadata.thoughtsTokenCount || 0),
            cacheRead: chunk.usageMetadata.cachedContentTokenCount || 0,
            cacheWrite: 0,
            totalTokens: chunk.usageMetadata.totalTokenCount || 0,
            cost: {
              input: 0,
              output: 0,
              cacheRead: 0,
              cacheWrite: 0,
              total: 0,
            },
          };
          calculateCost(model, output.usage);
        }
      }
      // Close whatever block was still open when the stream finished.
      if (currentBlock) {
        if (currentBlock.type === "text") {
          stream.push({
            type: "text_end",
            contentIndex: blockIndex(),
            content: currentBlock.text,
            partial: output,
          });
        } else {
          stream.push({
            type: "thinking_end",
            contentIndex: blockIndex(),
            content: currentBlock.thinking,
            partial: output,
          });
        }
      }
      if (options?.signal?.aborted) {
        throw new Error("Request was aborted");
      }
      if (output.stopReason === "aborted" || output.stopReason === "error") {
        throw new Error("An unknown error occurred");
      }
      stream.push({ type: "done", reason: output.stopReason, message: output });
      stream.end();
    } catch (error) {
      // Remove internal index property used during streaming
      for (const block of output.content) {
        if ("index" in block) {
          delete (block as { index?: number }).index;
        }
      }
      output.stopReason = options?.signal?.aborted ? "aborted" : "error";
      output.errorMessage =
        error instanceof Error ? error.message : JSON.stringify(error);
      stream.push({ type: "error", reason: output.stopReason, error: output });
      stream.end();
    }
  })();
  return stream;
};
/**
 * Simple-options entry point for Google Generative AI.
 *
 * Resolves the API key, maps the provider-agnostic reasoning setting onto
 * the appropriate Gemini thinking configuration (discrete levels for
 * Gemini 3 models, token budgets otherwise), and delegates to `streamGoogle`.
 */
export const streamSimpleGoogle: StreamFunction<
  "google-generative-ai",
  SimpleStreamOptions
> = (
  model: Model<"google-generative-ai">,
  context: Context,
  options?: SimpleStreamOptions,
): AssistantMessageEventStream => {
  const key = options?.apiKey || getEnvApiKey(model.provider);
  if (!key) {
    throw new Error(`No API key for provider: ${model.provider}`);
  }
  const base = buildBaseOptions(model, options, key);
  let thinking: GoogleOptions["thinking"];
  if (!options?.reasoning) {
    thinking = { enabled: false };
  } else {
    const effort = clampReasoning(options.reasoning)!;
    if (isGemini3ProModel(model) || isGemini3FlashModel(model)) {
      // Gemini 3 models take discrete thinking levels rather than budgets.
      thinking = {
        enabled: true,
        level: getGemini3ThinkingLevel(effort, model),
      };
    } else {
      // Older models take an explicit token budget (-1 = dynamic).
      thinking = {
        enabled: true,
        budgetTokens: getGoogleBudget(model, effort, options.thinkingBudgets),
      };
    }
  }
  return streamGoogle(model, context, {
    ...base,
    thinking,
  } satisfies GoogleOptions);
};
/**
 * Construct a GoogleGenAI client for the given model.
 *
 * Applies the model's custom base URL (suppressing the SDK's own apiVersion
 * segment, since the configured URL already includes the version path) and
 * merges headers, with request-level headers overriding model defaults.
 */
function createClient(
  model: Model<"google-generative-ai">,
  apiKey?: string,
  optionsHeaders?: Record<string, string>,
): GoogleGenAI {
  const httpOptions: {
    baseUrl?: string;
    apiVersion?: string;
    headers?: Record<string, string>;
  } = {};
  if (model.baseUrl) {
    httpOptions.baseUrl = model.baseUrl;
    // baseUrl already includes the version path; stop the SDK appending one.
    httpOptions.apiVersion = "";
  }
  if (model.headers || optionsHeaders) {
    httpOptions.headers = { ...model.headers, ...optionsHeaders };
  }
  const hasHttpOptions = Object.keys(httpOptions).length > 0;
  return new GoogleGenAI({
    apiKey,
    httpOptions: hasHttpOptions ? httpOptions : undefined,
  });
}
/**
 * Translate the provider-agnostic context and options into Gemini
 * `GenerateContentParameters`: converted messages, sampling settings,
 * system instruction, tools/tool choice, thinking config, and abort signal.
 *
 * Throws immediately if the caller's signal is already aborted.
 */
function buildParams(
  model: Model<"google-generative-ai">,
  context: Context,
  options: GoogleOptions = {},
): GenerateContentParameters {
  const contents = convertMessages(model, context);
  const generationConfig: GenerateContentConfig = {};
  if (options.temperature !== undefined) {
    generationConfig.temperature = options.temperature;
  }
  if (options.maxTokens !== undefined) {
    generationConfig.maxOutputTokens = options.maxTokens;
  }
  const config: GenerateContentConfig = { ...generationConfig };
  if (context.systemPrompt) {
    config.systemInstruction = sanitizeSurrogates(context.systemPrompt);
  }
  const tools = context.tools ?? [];
  if (tools.length > 0) {
    config.tools = convertTools(tools);
  }
  // toolConfig is always assigned (possibly undefined) to mirror the shape
  // callers may rely on when inspecting the payload.
  config.toolConfig =
    tools.length > 0 && options.toolChoice
      ? { functionCallingConfig: { mode: mapToolChoice(options.toolChoice) } }
      : undefined;
  if (options.thinking?.enabled && model.reasoning) {
    const thinkingConfig: ThinkingConfig = { includeThoughts: true };
    if (options.thinking.level !== undefined) {
      // Cast to any since our GoogleThinkingLevel mirrors Google's ThinkingLevel enum values
      thinkingConfig.thinkingLevel = options.thinking.level as any;
    } else if (options.thinking.budgetTokens !== undefined) {
      thinkingConfig.thinkingBudget = options.thinking.budgetTokens;
    }
    config.thinkingConfig = thinkingConfig;
  }
  const signal = options.signal;
  if (signal) {
    if (signal.aborted) {
      throw new Error("Request aborted");
    }
    config.abortSignal = signal;
  }
  return { model: model.id, contents, config };
}
// Reasoning efforts this provider maps to levels/budgets; "xhigh" is excluded.
type ClampedThinkingLevel = Exclude<ThinkingLevel, "xhigh">;
/** True when the model id names a Gemini 3 (or 3.x) Pro variant. */
function isGemini3ProModel(model: Model<"google-generative-ai">): boolean {
  const id = model.id.toLowerCase();
  return /gemini-3(?:\.\d+)?-pro/.test(id);
}
/** True when the model id names a Gemini 3 (or 3.x) Flash variant. */
function isGemini3FlashModel(model: Model<"google-generative-ai">): boolean {
  const id = model.id.toLowerCase();
  return /gemini-3(?:\.\d+)?-flash/.test(id);
}
/**
 * Map a clamped reasoning effort onto a Gemini 3 thinking level.
 *
 * Pro models collapse to a two-level scheme (minimal/low -> LOW,
 * medium/high -> HIGH); other Gemini 3 models keep all four levels.
 */
function getGemini3ThinkingLevel(
  effort: ClampedThinkingLevel,
  model: Model<"google-generative-ai">,
): GoogleThinkingLevel {
  if (isGemini3ProModel(model)) {
    return effort === "minimal" || effort === "low" ? "LOW" : "HIGH";
  }
  const levels: Record<ClampedThinkingLevel, GoogleThinkingLevel> = {
    minimal: "MINIMAL",
    low: "LOW",
    medium: "MEDIUM",
    high: "HIGH",
  };
  return levels[effort];
}
/**
 * Resolve the thinking-token budget for a non-Gemini-3 model.
 *
 * Precedence: an explicit per-effort override from `customBudgets`, then the
 * built-in tables for 2.5 Pro / 2.5 Flash, otherwise -1 (dynamic budget,
 * letting the API decide).
 */
function getGoogleBudget(
  model: Model<"google-generative-ai">,
  effort: ClampedThinkingLevel,
  customBudgets?: ThinkingBudgets,
): number {
  // Capture the indexed value once so no non-null assertion is needed.
  const custom = customBudgets?.[effort];
  if (custom !== undefined) {
    return custom;
  }
  if (model.id.includes("2.5-pro")) {
    const budgets: Record<ClampedThinkingLevel, number> = {
      minimal: 128,
      low: 2048,
      medium: 8192,
      high: 32768,
    };
    return budgets[effort];
  }
  if (model.id.includes("2.5-flash")) {
    const budgets: Record<ClampedThinkingLevel, number> = {
      minimal: 128,
      low: 2048,
      medium: 8192,
      high: 24576,
    };
    return budgets[effort];
  }
  // Unknown model: request a dynamic budget.
  return -1;
}

View file

@ -0,0 +1,688 @@
import { Mistral } from "@mistralai/mistralai";
import type { RequestOptions } from "@mistralai/mistralai/lib/sdks.js";
import type {
ChatCompletionStreamRequest,
ChatCompletionStreamRequestMessages,
CompletionEvent,
ContentChunk,
FunctionTool,
} from "@mistralai/mistralai/models/components/index.js";
import { getEnvApiKey } from "../env-api-keys.js";
import { calculateCost } from "../models.js";
import type {
AssistantMessage,
Context,
Message,
Model,
SimpleStreamOptions,
StopReason,
StreamFunction,
StreamOptions,
TextContent,
ThinkingContent,
Tool,
ToolCall,
} from "../types.js";
import { AssistantMessageEventStream } from "../utils/event-stream.js";
import { shortHash } from "../utils/hash.js";
import { parseStreamingJson } from "../utils/json-parse.js";
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
import { buildBaseOptions, clampReasoning } from "./simple-options.js";
import { transformMessages } from "./transform-messages.js";
// Target length for normalized tool call IDs (see deriveMistralToolCallId).
const MISTRAL_TOOL_CALL_ID_LENGTH = 9;
// Cap on response-body text included in formatted error messages.
const MAX_MISTRAL_ERROR_BODY_CHARS = 4000;
/**
 * Provider-specific options for the Mistral API.
 */
export interface MistralOptions extends StreamOptions {
  // Tool selection strategy, or a forced single function call.
  toolChoice?:
    | "auto"
    | "none"
    | "any"
    | "required"
    | { type: "function"; function: { name: string } };
  // Set to "reasoning" when the model supports reasoning output.
  promptMode?: "reasoning";
}
/**
 * Stream responses from Mistral using `chat.stream`.
 *
 * Normalizes tool call IDs to Mistral's format, builds the request payload,
 * and forwards the SDK stream through the shared event protocol
 * (`start` ... `done`, or `error` with a formatted message on failure).
 */
export const streamMistral: StreamFunction<
  "mistral-conversations",
  MistralOptions
> = (
  model: Model<"mistral-conversations">,
  context: Context,
  options?: MistralOptions,
): AssistantMessageEventStream => {
  const stream = new AssistantMessageEventStream();
  // Fire-and-forget worker; all results flow through `stream`.
  (async () => {
    const output = createOutput(model);
    try {
      const apiKey = options?.apiKey || getEnvApiKey(model.provider);
      if (!apiKey) {
        throw new Error(`No API key for provider: ${model.provider}`);
      }
      // Intentionally per-request: avoids shared SDK mutable state across concurrent consumers.
      const mistral = new Mistral({
        apiKey,
        serverURL: model.baseUrl,
      });
      // Stateful normalizer: maps arbitrary tool call IDs to Mistral-shaped
      // IDs, consistently within this request.
      const normalizeMistralToolCallId = createMistralToolCallIdNormalizer();
      const transformedMessages = transformMessages(
        context.messages,
        model,
        (id) => normalizeMistralToolCallId(id),
      );
      const payload = buildChatPayload(
        model,
        context,
        transformedMessages,
        options,
      );
      // Let callers observe the exact request payload (e.g. for logging).
      options?.onPayload?.(payload);
      const mistralStream = await mistral.chat.stream(
        payload,
        buildRequestOptions(model, options),
      );
      stream.push({ type: "start", partial: output });
      await consumeChatStream(model, output, stream, mistralStream);
      if (options?.signal?.aborted) {
        throw new Error("Request was aborted");
      }
      if (output.stopReason === "aborted" || output.stopReason === "error") {
        throw new Error("An unknown error occurred");
      }
      stream.push({ type: "done", reason: output.stopReason, message: output });
      stream.end();
    } catch (error) {
      output.stopReason = options?.signal?.aborted ? "aborted" : "error";
      output.errorMessage = formatMistralError(error);
      stream.push({ type: "error", reason: output.stopReason, error: output });
      stream.end();
    }
  })();
  return stream;
};
/**
 * Maps provider-agnostic `SimpleStreamOptions` to Mistral options.
 *
 * Reasoning prompt mode is enabled only when the model supports reasoning
 * AND the caller requested a (clamped) reasoning effort.
 */
export const streamSimpleMistral: StreamFunction<
  "mistral-conversations",
  SimpleStreamOptions
> = (
  model: Model<"mistral-conversations">,
  context: Context,
  options?: SimpleStreamOptions,
): AssistantMessageEventStream => {
  const key = options?.apiKey || getEnvApiKey(model.provider);
  if (!key) {
    throw new Error(`No API key for provider: ${model.provider}`);
  }
  const effort = clampReasoning(options?.reasoning);
  const useReasoning = Boolean(model.reasoning && effort);
  const mistralOptions: MistralOptions = {
    ...buildBaseOptions(model, options, key),
    promptMode: useReasoning ? "reasoning" : undefined,
  };
  return streamMistral(model, context, mistralOptions);
};
/**
 * Create an empty assistant-message skeleton used to accumulate stream
 * output. Usage and cost start zeroed; stopReason defaults to "stop".
 */
function createOutput(model: Model<"mistral-conversations">): AssistantMessage {
  const zeroCost = {
    input: 0,
    output: 0,
    cacheRead: 0,
    cacheWrite: 0,
    total: 0,
  };
  return {
    role: "assistant",
    content: [],
    api: model.api,
    provider: model.provider,
    model: model.id,
    usage: {
      input: 0,
      output: 0,
      cacheRead: 0,
      cacheWrite: 0,
      totalTokens: 0,
      cost: zeroCost,
    },
    stopReason: "stop",
    timestamp: Date.now(),
  };
}
/**
 * Create a stateful mapper that rewrites arbitrary tool call IDs into
 * Mistral-compatible ones. The mapping is stable (same input -> same output)
 * and collision-free within a single request.
 */
function createMistralToolCallIdNormalizer(): (id: string) => string {
  const forward = new Map<string, string>(); // original -> normalized
  const backward = new Map<string, string>(); // normalized -> original
  return (id: string): string => {
    const cached = forward.get(id);
    if (cached) return cached;
    // Probe successive derived candidates until one is unused or already ours.
    for (let attempt = 0; ; attempt++) {
      const candidate = deriveMistralToolCallId(id, attempt);
      const owner = backward.get(candidate);
      if (!owner || owner === id) {
        forward.set(id, candidate);
        backward.set(candidate, id);
        return candidate;
      }
    }
  };
}
/**
 * Derive a Mistral-compatible tool call ID from an arbitrary ID.
 *
 * Attempt 0 keeps an already-valid ID (9 alphanumerics) unchanged; otherwise
 * the ID — salted with the attempt number on retries — is hashed and trimmed
 * to the required length.
 */
function deriveMistralToolCallId(id: string, attempt: number): string {
  const alnum = id.replace(/[^a-zA-Z0-9]/g, "");
  if (attempt === 0 && alnum.length === MISTRAL_TOOL_CALL_ID_LENGTH) {
    return alnum;
  }
  const base = alnum || id;
  const seed = attempt === 0 ? base : `${base}:${attempt}`;
  const hashed = shortHash(seed).replace(/[^a-zA-Z0-9]/g, "");
  return hashed.slice(0, MISTRAL_TOOL_CALL_ID_LENGTH);
}
/**
 * Render an error thrown by the Mistral SDK as a readable message,
 * surfacing the HTTP status and a (truncated) response body when available.
 * Non-Error values are JSON-serialized defensively.
 */
function formatMistralError(error: unknown): string {
  if (!(error instanceof Error)) {
    return safeJsonStringify(error);
  }
  const sdkError = error as Error & { statusCode?: unknown; body?: unknown };
  const status =
    typeof sdkError.statusCode === "number" ? sdkError.statusCode : undefined;
  const body =
    typeof sdkError.body === "string" ? sdkError.body.trim() : undefined;
  if (status === undefined) {
    return error.message;
  }
  if (body) {
    return `Mistral API error (${status}): ${truncateErrorText(body, MAX_MISTRAL_ERROR_BODY_CHARS)}`;
  }
  return `Mistral API error (${status}): ${error.message}`;
}
/** Truncate `text` to `maxChars`, appending a note with the omitted length. */
function truncateErrorText(text: string, maxChars: number): string {
  if (text.length <= maxChars) {
    return text;
  }
  const omitted = text.length - maxChars;
  return `${text.slice(0, maxChars)}... [truncated ${omitted} chars]`;
}
/** JSON.stringify that never throws, falling back to String() conversion. */
function safeJsonStringify(value: unknown): string {
  try {
    const json = JSON.stringify(value);
    // JSON.stringify yields undefined for undefined/functions/symbols.
    return json !== undefined ? json : String(value);
  } catch {
    // e.g. circular structures or a throwing toJSON implementation.
    return String(value);
  }
}
/**
 * Build per-request SDK options: abort signal, disabled SDK retries, and
 * merged headers (request-level headers override model defaults).
 *
 * Mistral infrastructure uses `x-affinity` for KV-cache reuse (prefix
 * caching); it is set from `sessionId` unless the caller supplied one.
 */
function buildRequestOptions(
  model: Model<"mistral-conversations">,
  options?: MistralOptions,
): RequestOptions {
  const requestOptions: RequestOptions = { retries: { strategy: "none" } };
  if (options?.signal) {
    requestOptions.signal = options.signal;
  }
  const headers: Record<string, string> = {
    ...model.headers,
    ...options?.headers,
  };
  if (options?.sessionId && !headers["x-affinity"]) {
    headers["x-affinity"] = options.sessionId;
  }
  if (Object.keys(headers).length > 0) {
    requestOptions.headers = headers;
  }
  return requestOptions;
}
function buildChatPayload(
model: Model<"mistral-conversations">,
context: Context,
messages: Message[],
options?: MistralOptions,
): ChatCompletionStreamRequest {
const payload: ChatCompletionStreamRequest = {
model: model.id,
stream: true,
messages: toChatMessages(messages, model.input.includes("image")),
};
if (context.tools?.length) payload.tools = toFunctionTools(context.tools);
if (options?.temperature !== undefined)
payload.temperature = options.temperature;
if (options?.maxTokens !== undefined) payload.maxTokens = options.maxTokens;
if (options?.toolChoice)
payload.toolChoice = mapToolChoice(options.toolChoice);
if (options?.promptMode) payload.promptMode = options.promptMode as any;
if (context.systemPrompt) {
payload.messages.unshift({
role: "system",
content: sanitizeSurrogates(context.systemPrompt),
});
}
return payload;
}
/**
 * Drain the Mistral completion stream into `output`, emitting shared
 * streaming events along the way.
 *
 * Maintains one open text/thinking block at a time; tool calls are keyed by
 * `id:index` so argument fragments spread across chunks accumulate into a
 * single block. All tool calls are finalized with `toolcall_end` after the
 * stream ends.
 */
async function consumeChatStream(
  model: Model<"mistral-conversations">,
  output: AssistantMessage,
  stream: AssistantMessageEventStream,
  mistralStream: AsyncIterable<CompletionEvent>,
): Promise<void> {
  let currentBlock: TextContent | ThinkingContent | null = null;
  const blocks = output.content;
  // Index of the most recently pushed content block.
  const blockIndex = () => blocks.length - 1;
  // Maps "callId:index" -> position of the tool call block in output.content.
  const toolBlocksByKey = new Map<string, number>();
  // Emit the matching *_end event for the currently open block, if any.
  const finishCurrentBlock = (block?: typeof currentBlock) => {
    if (!block) return;
    if (block.type === "text") {
      stream.push({
        type: "text_end",
        contentIndex: blockIndex(),
        content: block.text,
        partial: output,
      });
      return;
    }
    if (block.type === "thinking") {
      stream.push({
        type: "thinking_end",
        contentIndex: blockIndex(),
        content: block.thinking,
        partial: output,
      });
    }
  };
  for await (const event of mistralStream) {
    const chunk = event.data;
    // Usage may arrive on any chunk; keep the latest snapshot.
    if (chunk.usage) {
      output.usage.input = chunk.usage.promptTokens || 0;
      output.usage.output = chunk.usage.completionTokens || 0;
      output.usage.cacheRead = 0;
      output.usage.cacheWrite = 0;
      output.usage.totalTokens =
        chunk.usage.totalTokens || output.usage.input + output.usage.output;
      calculateCost(model, output.usage);
    }
    const choice = chunk.choices[0];
    if (!choice) continue;
    if (choice.finishReason) {
      output.stopReason = mapChatStopReason(choice.finishReason);
    }
    const delta = choice.delta;
    if (delta.content !== null && delta.content !== undefined) {
      // Content may be a bare string or a list of typed chunks.
      const contentItems =
        typeof delta.content === "string" ? [delta.content] : delta.content;
      for (const item of contentItems) {
        if (typeof item === "string") {
          const textDelta = sanitizeSurrogates(item);
          // Open a fresh text block if none is open (or a thinking block is).
          if (!currentBlock || currentBlock.type !== "text") {
            finishCurrentBlock(currentBlock);
            currentBlock = { type: "text", text: "" };
            output.content.push(currentBlock);
            stream.push({
              type: "text_start",
              contentIndex: blockIndex(),
              partial: output,
            });
          }
          currentBlock.text += textDelta;
          stream.push({
            type: "text_delta",
            contentIndex: blockIndex(),
            delta: textDelta,
            partial: output,
          });
          continue;
        }
        if (item.type === "thinking") {
          // Flatten the nested thinking parts into one text delta.
          const deltaText = item.thinking
            .map((part) => ("text" in part ? part.text : ""))
            .filter((text) => text.length > 0)
            .join("");
          const thinkingDelta = sanitizeSurrogates(deltaText);
          if (!thinkingDelta) continue;
          if (!currentBlock || currentBlock.type !== "thinking") {
            finishCurrentBlock(currentBlock);
            currentBlock = { type: "thinking", thinking: "" };
            output.content.push(currentBlock);
            stream.push({
              type: "thinking_start",
              contentIndex: blockIndex(),
              partial: output,
            });
          }
          currentBlock.thinking += thinkingDelta;
          stream.push({
            type: "thinking_delta",
            contentIndex: blockIndex(),
            delta: thinkingDelta,
            partial: output,
          });
          continue;
        }
        if (item.type === "text") {
          const textDelta = sanitizeSurrogates(item.text);
          if (!currentBlock || currentBlock.type !== "text") {
            finishCurrentBlock(currentBlock);
            currentBlock = { type: "text", text: "" };
            output.content.push(currentBlock);
            stream.push({
              type: "text_start",
              contentIndex: blockIndex(),
              partial: output,
            });
          }
          currentBlock.text += textDelta;
          stream.push({
            type: "text_delta",
            contentIndex: blockIndex(),
            delta: textDelta,
            partial: output,
          });
        }
      }
    }
    const toolCalls = delta.toolCalls || [];
    for (const toolCall of toolCalls) {
      // A tool call ends any open text/thinking block.
      if (currentBlock) {
        finishCurrentBlock(currentBlock);
        currentBlock = null;
      }
      // The API may send no ID or the literal string "null"; derive one.
      const callId =
        toolCall.id && toolCall.id !== "null"
          ? toolCall.id
          : deriveMistralToolCallId(`toolcall:${toolCall.index ?? 0}`, 0);
      const key = `${callId}:${toolCall.index || 0}`;
      const existingIndex = toolBlocksByKey.get(key);
      let block: (ToolCall & { partialArgs?: string }) | undefined;
      if (existingIndex !== undefined) {
        const existing = output.content[existingIndex];
        if (existing?.type === "toolCall") {
          block = existing as ToolCall & { partialArgs?: string };
        }
      }
      if (!block) {
        // First fragment for this call: create the block and emit start.
        block = {
          type: "toolCall",
          id: callId,
          name: toolCall.function.name,
          arguments: {},
          partialArgs: "",
        };
        output.content.push(block);
        toolBlocksByKey.set(key, output.content.length - 1);
        stream.push({
          type: "toolcall_start",
          contentIndex: output.content.length - 1,
          partial: output,
        });
      }
      const argsDelta =
        typeof toolCall.function.arguments === "string"
          ? toolCall.function.arguments
          : JSON.stringify(toolCall.function.arguments || {});
      block.partialArgs = (block.partialArgs || "") + argsDelta;
      // Best-effort parse of the still-incomplete JSON so partials are usable.
      block.arguments = parseStreamingJson<Record<string, unknown>>(
        block.partialArgs,
      );
      stream.push({
        type: "toolcall_delta",
        contentIndex: toolBlocksByKey.get(key)!,
        delta: argsDelta,
        partial: output,
      });
    }
  }
  finishCurrentBlock(currentBlock);
  // Finalize every tool call block: parse args once more and emit end events.
  for (const index of toolBlocksByKey.values()) {
    const block = output.content[index];
    if (block.type !== "toolCall") continue;
    const toolBlock = block as ToolCall & { partialArgs?: string };
    toolBlock.arguments = parseStreamingJson<Record<string, unknown>>(
      toolBlock.partialArgs,
    );
    delete toolBlock.partialArgs;
    stream.push({
      type: "toolcall_end",
      contentIndex: index,
      toolCall: toolBlock,
      partial: output,
    });
  }
}
/** Convert provider-agnostic tools into Mistral function-tool declarations. */
function toFunctionTools(
  tools: Tool[],
): Array<FunctionTool & { type: "function" }> {
  const converted: Array<FunctionTool & { type: "function" }> = [];
  for (const tool of tools) {
    converted.push({
      type: "function",
      function: {
        name: tool.name,
        description: tool.description,
        parameters: tool.parameters as unknown as Record<string, unknown>,
        strict: false,
      },
    });
  }
  return converted;
}
/**
 * Convert internal messages into Mistral chat messages.
 *
 * User content is mapped to text/image chunks (images dropped with a
 * placeholder when the model lacks image input); assistant content becomes
 * text/thinking chunks plus serialized tool calls; remaining messages are
 * treated as tool results and emitted as `role: "tool"` messages.
 */
function toChatMessages(
  messages: Message[],
  supportsImages: boolean,
): ChatCompletionStreamRequestMessages[] {
  const result: ChatCompletionStreamRequestMessages[] = [];
  for (const msg of messages) {
    if (msg.role === "user") {
      if (typeof msg.content === "string") {
        result.push({ role: "user", content: sanitizeSurrogates(msg.content) });
        continue;
      }
      // Track image presence before filtering so we can emit a placeholder
      // if everything gets dropped.
      const hadImages = msg.content.some((item) => item.type === "image");
      const content: ContentChunk[] = msg.content
        .filter((item) => item.type === "text" || supportsImages)
        .map((item) => {
          if (item.type === "text")
            return { type: "text", text: sanitizeSurrogates(item.text) };
          return {
            type: "image_url",
            imageUrl: `data:${item.mimeType};base64,${item.data}`,
          };
        });
      if (content.length > 0) {
        result.push({ role: "user", content });
        continue;
      }
      // Message contained only unsupported images: keep a textual marker.
      if (hadImages && !supportsImages) {
        result.push({
          role: "user",
          content: "(image omitted: model does not support images)",
        });
      }
      continue;
    }
    if (msg.role === "assistant") {
      const contentParts: ContentChunk[] = [];
      const toolCalls: Array<{
        id: string;
        type: "function";
        function: { name: string; arguments: string };
      }> = [];
      for (const block of msg.content) {
        if (block.type === "text") {
          // Skip whitespace-only text blocks.
          if (block.text.trim().length > 0) {
            contentParts.push({
              type: "text",
              text: sanitizeSurrogates(block.text),
            });
          }
          continue;
        }
        if (block.type === "thinking") {
          if (block.thinking.trim().length > 0) {
            contentParts.push({
              type: "thinking",
              thinking: [
                { type: "text", text: sanitizeSurrogates(block.thinking) },
              ],
            });
          }
          continue;
        }
        // Remaining block type is a tool call; arguments go as JSON text.
        toolCalls.push({
          id: block.id,
          type: "function",
          function: {
            name: block.name,
            arguments: JSON.stringify(block.arguments || {}),
          },
        });
      }
      const assistantMessage: ChatCompletionStreamRequestMessages = {
        role: "assistant",
      };
      if (contentParts.length > 0) assistantMessage.content = contentParts;
      if (toolCalls.length > 0) assistantMessage.toolCalls = toolCalls;
      // Drop assistant messages that ended up completely empty.
      if (contentParts.length > 0 || toolCalls.length > 0)
        result.push(assistantMessage);
      continue;
    }
    // Tool result message: combine text parts and optional images.
    const toolContent: ContentChunk[] = [];
    const textResult = msg.content
      .filter((part) => part.type === "text")
      .map((part) =>
        part.type === "text" ? sanitizeSurrogates(part.text) : "",
      )
      .join("\n");
    const hasImages = msg.content.some((part) => part.type === "image");
    const toolText = buildToolResultText(
      textResult,
      hasImages,
      supportsImages,
      msg.isError,
    );
    toolContent.push({ type: "text", text: toolText });
    for (const part of msg.content) {
      if (!supportsImages) continue;
      if (part.type !== "image") continue;
      toolContent.push({
        type: "image_url",
        imageUrl: `data:${part.mimeType};base64,${part.data}`,
      });
    }
    result.push({
      role: "tool",
      toolCallId: msg.toolCallId,
      name: msg.toolName,
      content: toolContent,
    });
  }
  return result;
}
/**
 * Compose the text body of a tool-result message.
 *
 * Prefixes errors with "[tool error] ", and adds placeholder text when the
 * result has no text output or includes images the model cannot accept.
 */
function buildToolResultText(
  text: string,
  hasImages: boolean,
  supportsImages: boolean,
  isError: boolean,
): string {
  const prefix = isError ? "[tool error] " : "";
  const trimmed = text.trim();
  if (trimmed.length > 0) {
    const suffix =
      hasImages && !supportsImages
        ? "\n[tool image omitted: model does not support images]"
        : "";
    return `${prefix}${trimmed}${suffix}`;
  }
  if (!hasImages) {
    return `${prefix}(no tool output)`;
  }
  return supportsImages
    ? `${prefix}(see attached image)`
    : `${prefix}(image omitted: model does not support images)`;
}
function mapToolChoice(
choice: MistralOptions["toolChoice"],
):
| "auto"
| "none"
| "any"
| "required"
| { type: "function"; function: { name: string } }
| undefined {
if (!choice) return undefined;
if (
choice === "auto" ||
choice === "none" ||
choice === "any" ||
choice === "required"
) {
return choice as any;
}
return {
type: "function",
function: { name: choice.function.name },
};
}
/** Translate a Mistral finish reason into the shared StopReason value. */
function mapChatStopReason(reason: string | null): StopReason {
  switch (reason) {
    case "length":
    case "model_length":
      return "length";
    case "tool_calls":
      return "toolUse";
    case "error":
      return "error";
    default:
      // null, "stop", and unrecognized reasons all count as a normal stop.
      return "stop";
  }
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,949 @@
import OpenAI from "openai";
import type {
ChatCompletionAssistantMessageParam,
ChatCompletionChunk,
ChatCompletionContentPart,
ChatCompletionContentPartImage,
ChatCompletionContentPartText,
ChatCompletionMessageParam,
ChatCompletionToolMessageParam,
} from "openai/resources/chat/completions.js";
import { getEnvApiKey } from "../env-api-keys.js";
import { calculateCost, supportsXhigh } from "../models.js";
import type {
AssistantMessage,
Context,
Message,
Model,
OpenAICompletionsCompat,
SimpleStreamOptions,
StopReason,
StreamFunction,
StreamOptions,
TextContent,
ThinkingContent,
Tool,
ToolCall,
ToolResultMessage,
} from "../types.js";
import { AssistantMessageEventStream } from "../utils/event-stream.js";
import { parseStreamingJson } from "../utils/json-parse.js";
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
import {
buildCopilotDynamicHeaders,
hasCopilotVisionInput,
} from "./github-copilot-headers.js";
import { buildBaseOptions, clampReasoning } from "./simple-options.js";
import { transformMessages } from "./transform-messages.js";
/**
 * Check if conversation messages contain tool calls or tool results.
 * This is needed because Anthropic (via proxy) requires the tools param
 * to be present when messages include tool_calls or tool role messages.
 */
function hasToolHistory(messages: Message[]): boolean {
  return messages.some(
    (msg) =>
      msg.role === "toolResult" ||
      (msg.role === "assistant" &&
        msg.content.some((block) => block.type === "toolCall")),
  );
}
/**
 * Provider-specific options for OpenAI-compatible chat-completions endpoints.
 */
export interface OpenAICompletionsOptions extends StreamOptions {
  // Tool selection strategy, or a forced single function call.
  toolChoice?:
    | "auto"
    | "none"
    | "required"
    | { type: "function"; function: { name: string } };
  // Reasoning effort forwarded to reasoning-capable models.
  reasoningEffort?: "minimal" | "low" | "medium" | "high" | "xhigh";
}
export const streamOpenAICompletions: StreamFunction<
"openai-completions",
OpenAICompletionsOptions
> = (
model: Model<"openai-completions">,
context: Context,
options?: OpenAICompletionsOptions,
): AssistantMessageEventStream => {
const stream = new AssistantMessageEventStream();
(async () => {
const output: AssistantMessage = {
role: "assistant",
content: [],
api: model.api,
provider: model.provider,
model: model.id,
usage: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 0,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
},
stopReason: "stop",
timestamp: Date.now(),
};
try {
const apiKey = options?.apiKey || getEnvApiKey(model.provider) || "";
const client = createClient(model, context, apiKey, options?.headers);
const params = buildParams(model, context, options);
options?.onPayload?.(params);
const openaiStream = await client.chat.completions.create(params, {
signal: options?.signal,
});
stream.push({ type: "start", partial: output });
let currentBlock:
| TextContent
| ThinkingContent
| (ToolCall & { partialArgs?: string })
| null = null;
const blocks = output.content;
const blockIndex = () => blocks.length - 1;
const finishCurrentBlock = (block?: typeof currentBlock) => {
if (block) {
if (block.type === "text") {
stream.push({
type: "text_end",
contentIndex: blockIndex(),
content: block.text,
partial: output,
});
} else if (block.type === "thinking") {
stream.push({
type: "thinking_end",
contentIndex: blockIndex(),
content: block.thinking,
partial: output,
});
} else if (block.type === "toolCall") {
block.arguments = parseStreamingJson(block.partialArgs);
delete block.partialArgs;
stream.push({
type: "toolcall_end",
contentIndex: blockIndex(),
toolCall: block,
partial: output,
});
}
}
};
for await (const chunk of openaiStream) {
if (chunk.usage) {
const cachedTokens =
chunk.usage.prompt_tokens_details?.cached_tokens || 0;
const reasoningTokens =
chunk.usage.completion_tokens_details?.reasoning_tokens || 0;
const input = (chunk.usage.prompt_tokens || 0) - cachedTokens;
const outputTokens =
(chunk.usage.completion_tokens || 0) + reasoningTokens;
output.usage = {
// OpenAI includes cached tokens in prompt_tokens, so subtract to get non-cached input
input,
output: outputTokens,
cacheRead: cachedTokens,
cacheWrite: 0,
// Compute totalTokens ourselves since we add reasoning_tokens to output
// and some providers (e.g., Groq) don't include them in total_tokens
totalTokens: input + outputTokens + cachedTokens,
cost: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
total: 0,
},
};
calculateCost(model, output.usage);
}
const choice = chunk.choices?.[0];
if (!choice) continue;
if (choice.finish_reason) {
output.stopReason = mapStopReason(choice.finish_reason);
}
if (choice.delta) {
if (
choice.delta.content !== null &&
choice.delta.content !== undefined &&
choice.delta.content.length > 0
) {
if (!currentBlock || currentBlock.type !== "text") {
finishCurrentBlock(currentBlock);
currentBlock = { type: "text", text: "" };
output.content.push(currentBlock);
stream.push({
type: "text_start",
contentIndex: blockIndex(),
partial: output,
});
}
if (currentBlock.type === "text") {
currentBlock.text += choice.delta.content;
stream.push({
type: "text_delta",
contentIndex: blockIndex(),
delta: choice.delta.content,
partial: output,
});
}
}
// Some endpoints return reasoning in reasoning_content (llama.cpp),
// or reasoning (other openai compatible endpoints)
// Use the first non-empty reasoning field to avoid duplication
// (e.g., chutes.ai returns both reasoning_content and reasoning with same content)
const reasoningFields = [
"reasoning_content",
"reasoning",
"reasoning_text",
];
let foundReasoningField: string | null = null;
for (const field of reasoningFields) {
if (
(choice.delta as any)[field] !== null &&
(choice.delta as any)[field] !== undefined &&
(choice.delta as any)[field].length > 0
) {
if (!foundReasoningField) {
foundReasoningField = field;
break;
}
}
}
if (foundReasoningField) {
if (!currentBlock || currentBlock.type !== "thinking") {
finishCurrentBlock(currentBlock);
currentBlock = {
type: "thinking",
thinking: "",
thinkingSignature: foundReasoningField,
};
output.content.push(currentBlock);
stream.push({
type: "thinking_start",
contentIndex: blockIndex(),
partial: output,
});
}
if (currentBlock.type === "thinking") {
const delta = (choice.delta as any)[foundReasoningField];
currentBlock.thinking += delta;
stream.push({
type: "thinking_delta",
contentIndex: blockIndex(),
delta,
partial: output,
});
}
}
if (choice?.delta?.tool_calls) {
for (const toolCall of choice.delta.tool_calls) {
if (
!currentBlock ||
currentBlock.type !== "toolCall" ||
(toolCall.id && currentBlock.id !== toolCall.id)
) {
finishCurrentBlock(currentBlock);
currentBlock = {
type: "toolCall",
id: toolCall.id || "",
name: toolCall.function?.name || "",
arguments: {},
partialArgs: "",
};
output.content.push(currentBlock);
stream.push({
type: "toolcall_start",
contentIndex: blockIndex(),
partial: output,
});
}
if (currentBlock.type === "toolCall") {
if (toolCall.id) currentBlock.id = toolCall.id;
if (toolCall.function?.name)
currentBlock.name = toolCall.function.name;
let delta = "";
if (toolCall.function?.arguments) {
delta = toolCall.function.arguments;
currentBlock.partialArgs += toolCall.function.arguments;
currentBlock.arguments = parseStreamingJson(
currentBlock.partialArgs,
);
}
stream.push({
type: "toolcall_delta",
contentIndex: blockIndex(),
delta,
partial: output,
});
}
}
}
const reasoningDetails = (choice.delta as any).reasoning_details;
if (reasoningDetails && Array.isArray(reasoningDetails)) {
for (const detail of reasoningDetails) {
if (
detail.type === "reasoning.encrypted" &&
detail.id &&
detail.data
) {
const matchingToolCall = output.content.find(
(b) => b.type === "toolCall" && b.id === detail.id,
) as ToolCall | undefined;
if (matchingToolCall) {
matchingToolCall.thoughtSignature = JSON.stringify(detail);
}
}
}
}
}
}
finishCurrentBlock(currentBlock);
if (options?.signal?.aborted) {
throw new Error("Request was aborted");
}
if (output.stopReason === "aborted" || output.stopReason === "error") {
throw new Error("An unknown error occurred");
}
stream.push({ type: "done", reason: output.stopReason, message: output });
stream.end();
} catch (error) {
for (const block of output.content) delete (block as any).index;
output.stopReason = options?.signal?.aborted ? "aborted" : "error";
output.errorMessage =
error instanceof Error ? error.message : JSON.stringify(error);
// Some providers via OpenRouter give additional information in this field.
const rawMetadata = (error as any)?.error?.metadata?.raw;
if (rawMetadata) output.errorMessage += `\n${rawMetadata}`;
stream.push({ type: "error", reason: output.stopReason, error: output });
stream.end();
}
})();
return stream;
};
/**
 * Simplified streaming entry point for the OpenAI Chat Completions API.
 *
 * Resolves the API key (explicit option or provider environment variable),
 * clamps the reasoning effort for models that do not accept "xhigh", and
 * forwards everything to streamOpenAICompletions.
 *
 * @throws Error when no API key can be resolved for the model's provider.
 */
export const streamSimpleOpenAICompletions: StreamFunction<
  "openai-completions",
  SimpleStreamOptions
> = (
  model: Model<"openai-completions">,
  context: Context,
  options?: SimpleStreamOptions,
): AssistantMessageEventStream => {
  const resolvedKey = options?.apiKey || getEnvApiKey(model.provider);
  if (!resolvedKey) {
    throw new Error(`No API key for provider: ${model.provider}`);
  }
  // Models that accept "xhigh" keep the caller's value; others are clamped.
  const effort = supportsXhigh(model)
    ? options?.reasoning
    : clampReasoning(options?.reasoning);
  const choice = (options as OpenAICompletionsOptions | undefined)?.toolChoice;
  const merged = {
    ...buildBaseOptions(model, options, resolvedKey),
    reasoningEffort: effort,
    toolChoice: choice,
  } satisfies OpenAICompletionsOptions;
  return streamOpenAICompletions(model, context, merged);
};
/**
 * Construct an OpenAI SDK client for the given model.
 *
 * Falls back to the OPENAI_API_KEY environment variable when no key is
 * passed. Merges headers in precedence order: model defaults, then GitHub
 * Copilot dynamic headers (when applicable), then caller-supplied headers.
 *
 * @throws Error when neither an explicit key nor OPENAI_API_KEY is available.
 */
function createClient(
  model: Model<"openai-completions">,
  context: Context,
  apiKey?: string,
  optionsHeaders?: Record<string, string>,
) {
  let resolvedKey = apiKey;
  if (!resolvedKey) {
    resolvedKey = process.env.OPENAI_API_KEY;
    if (!resolvedKey) {
      throw new Error(
        "OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass it as an argument.",
      );
    }
  }
  const headers = { ...model.headers };
  if (model.provider === "github-copilot") {
    // Copilot needs per-request headers derived from the conversation content.
    const hasImages = hasCopilotVisionInput(context.messages);
    Object.assign(
      headers,
      buildCopilotDynamicHeaders({ messages: context.messages, hasImages }),
    );
  }
  // Caller-supplied headers are merged last so they override defaults.
  if (optionsHeaders) {
    Object.assign(headers, optionsHeaders);
  }
  return new OpenAI({
    apiKey: resolvedKey,
    baseURL: model.baseUrl,
    dangerouslyAllowBrowser: true,
    defaultHeaders: headers,
  });
}
/**
 * Assemble the streaming Chat Completions request payload for a model,
 * applying per-provider compatibility quirks (max-tokens field name,
 * store/usage support, thinking toggles, and gateway routing hints).
 */
function buildParams(
  model: Model<"openai-completions">,
  context: Context,
  options?: OpenAICompletionsOptions,
) {
  const compat = getCompat(model);
  const messages = convertMessages(model, context, compat);
  maybeAddOpenRouterAnthropicCacheControl(model, messages);
  const request: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
    model: model.id,
    messages,
    stream: true,
  };
  if (compat.supportsUsageInStreaming !== false) {
    (request as any).stream_options = { include_usage: true };
  }
  if (compat.supportsStore) {
    request.store = false;
  }
  if (options?.maxTokens) {
    // Some endpoints still expect the legacy max_tokens field name.
    if (compat.maxTokensField === "max_tokens") {
      (request as any).max_tokens = options.maxTokens;
    } else {
      request.max_completion_tokens = options.maxTokens;
    }
  }
  if (options?.temperature !== undefined) {
    request.temperature = options.temperature;
  }
  if (context.tools) {
    request.tools = convertTools(context.tools, compat);
  } else if (hasToolHistory(context.messages)) {
    // Anthropic (via LiteLLM/proxy) requires tools param when conversation has tool_calls/tool_results
    request.tools = [];
  }
  if (options?.toolChoice) {
    request.tool_choice = options.toolChoice;
  }
  const usesEnableThinking =
    compat.thinkingFormat === "zai" || compat.thinkingFormat === "qwen";
  if (usesEnableThinking && model.reasoning) {
    // Both Z.ai and Qwen use enable_thinking: boolean
    (request as any).enable_thinking = !!options?.reasoningEffort;
  } else if (
    options?.reasoningEffort &&
    model.reasoning &&
    compat.supportsReasoningEffort
  ) {
    // OpenAI-style reasoning_effort
    (request as any).reasoning_effort = mapReasoningEffort(
      options.reasoningEffort,
      compat.reasoningEffortMap,
    );
  }
  // OpenRouter provider routing preferences
  if (
    model.baseUrl.includes("openrouter.ai") &&
    model.compat?.openRouterRouting
  ) {
    (request as any).provider = model.compat.openRouterRouting;
  }
  // Vercel AI Gateway provider routing preferences
  if (
    model.baseUrl.includes("ai-gateway.vercel.sh") &&
    model.compat?.vercelGatewayRouting
  ) {
    const routing = model.compat.vercelGatewayRouting;
    if (routing.only || routing.order) {
      const gateway: Record<string, string[]> = {};
      if (routing.only) gateway.only = routing.only;
      if (routing.order) gateway.order = routing.order;
      (request as any).providerOptions = { gateway };
    }
  }
  return request;
}
/**
 * Translate a reasoning-effort level through the provider-specific override
 * map; levels without an override pass through unchanged.
 */
function mapReasoningEffort(
  effort: NonNullable<OpenAICompletionsOptions["reasoningEffort"]>,
  reasoningEffortMap: Partial<
    Record<NonNullable<OpenAICompletionsOptions["reasoningEffort"]>, string>
  >,
): string {
  const mapped = reasoningEffortMap[effort];
  return mapped === undefined || mapped === null ? effort : mapped;
}
/**
 * For Anthropic models routed through OpenRouter, mark a single cache
 * breakpoint: walk the history backwards and attach cache_control to the
 * last text part of the most recent user/assistant message. String content
 * is promoted to an array so the marker has a text part to live on.
 * No-op for every other provider/model combination.
 */
function maybeAddOpenRouterAnthropicCacheControl(
  model: Model<"openai-completions">,
  messages: ChatCompletionMessageParam[],
): void {
  const isOpenRouterAnthropic =
    model.provider === "openrouter" && model.id.startsWith("anthropic/");
  if (!isOpenRouterAnthropic) return;
  for (let idx = messages.length - 1; idx >= 0; idx--) {
    const message = messages[idx];
    if (message.role !== "user" && message.role !== "assistant") continue;
    const body = message.content;
    if (typeof body === "string") {
      // Promote plain-string content to a single cached text part.
      message.content = [
        Object.assign(
          { type: "text" as const, text: body },
          { cache_control: { type: "ephemeral" } },
        ),
      ];
      return;
    }
    if (!Array.isArray(body)) continue;
    // Mark the last text part in this message, if any, then stop.
    for (let p = body.length - 1; p >= 0; p--) {
      const part = body[p];
      if (part?.type === "text") {
        Object.assign(part, { cache_control: { type: "ephemeral" } });
        return;
      }
    }
  }
}
/**
 * Convert the internal message history into OpenAI Chat Completions messages.
 *
 * Responsibilities visible in this function:
 * - Normalizes pipe-separated tool-call ids (Responses API format) and
 *   enforces OpenAI's 40-char id limit.
 * - Emits the system prompt under the "developer" role for reasoning models
 *   when the provider supports it.
 * - Strips image parts for models whose input does not include "image".
 * - Serializes assistant text/thinking/toolCall blocks, honoring provider
 *   quirks (Copilot string content, thinking-as-text, signature fields).
 * - Batches consecutive toolResult messages and re-emits their images as a
 *   follow-up user message.
 * - Skips empty assistant messages and inserts synthetic assistant bridges
 *   where a provider forbids user-after-tool-result sequences.
 *
 * @param model  Target model; drives provider-specific formatting decisions.
 * @param context System prompt plus message history.
 * @param compat Fully-resolved compatibility flags for this model.
 * @returns Messages ready to send to the Chat Completions endpoint.
 */
export function convertMessages(
  model: Model<"openai-completions">,
  context: Context,
  compat: Required<OpenAICompletionsCompat>,
): ChatCompletionMessageParam[] {
  const params: ChatCompletionMessageParam[] = [];
  const normalizeToolCallId = (id: string): string => {
    // Handle pipe-separated IDs from OpenAI Responses API
    // Format: {call_id}|{id} where {id} can be 400+ chars with special chars (+, /, =)
    // These come from providers like github-copilot, openai-codex, opencode
    // Extract just the call_id part and normalize it
    if (id.includes("|")) {
      const [callId] = id.split("|");
      // Sanitize to allowed chars and truncate to 40 chars (OpenAI limit)
      return callId.replace(/[^a-zA-Z0-9_-]/g, "_").slice(0, 40);
    }
    if (model.provider === "openai")
      return id.length > 40 ? id.slice(0, 40) : id;
    return id;
  };
  const transformedMessages = transformMessages(context.messages, model, (id) =>
    normalizeToolCallId(id),
  );
  if (context.systemPrompt) {
    // Reasoning-capable models take the prompt under the "developer" role
    // when the provider supports it; otherwise fall back to "system".
    const useDeveloperRole = model.reasoning && compat.supportsDeveloperRole;
    const role = useDeveloperRole ? "developer" : "system";
    params.push({
      role: role,
      content: sanitizeSurrogates(context.systemPrompt),
    });
  }
  let lastRole: string | null = null;
  for (let i = 0; i < transformedMessages.length; i++) {
    const msg = transformedMessages[i];
    // Some providers don't allow user messages directly after tool results
    // Insert a synthetic assistant message to bridge the gap
    if (
      compat.requiresAssistantAfterToolResult &&
      lastRole === "toolResult" &&
      msg.role === "user"
    ) {
      params.push({
        role: "assistant",
        content: "I have processed the tool results.",
      });
    }
    if (msg.role === "user") {
      if (typeof msg.content === "string") {
        params.push({
          role: "user",
          content: sanitizeSurrogates(msg.content),
        });
      } else {
        const content: ChatCompletionContentPart[] = msg.content.map(
          (item): ChatCompletionContentPart => {
            if (item.type === "text") {
              return {
                type: "text",
                text: sanitizeSurrogates(item.text),
              } satisfies ChatCompletionContentPartText;
            } else {
              return {
                type: "image_url",
                image_url: {
                  url: `data:${item.mimeType};base64,${item.data}`,
                },
              } satisfies ChatCompletionContentPartImage;
            }
          },
        );
        // Drop image parts for models without image input support.
        const filteredContent = !model.input.includes("image")
          ? content.filter((c) => c.type !== "image_url")
          : content;
        if (filteredContent.length === 0) continue;
        params.push({
          role: "user",
          content: filteredContent,
        });
      }
    } else if (msg.role === "assistant") {
      // Some providers don't accept null content, use empty string instead
      const assistantMsg: ChatCompletionAssistantMessageParam = {
        role: "assistant",
        content: compat.requiresAssistantAfterToolResult ? "" : null,
      };
      const textBlocks = msg.content.filter(
        (b) => b.type === "text",
      ) as TextContent[];
      // Filter out empty text blocks to avoid API validation errors
      const nonEmptyTextBlocks = textBlocks.filter(
        (b) => b.text && b.text.trim().length > 0,
      );
      if (nonEmptyTextBlocks.length > 0) {
        // GitHub Copilot requires assistant content as a string, not an array.
        // Sending as array causes Claude models to re-answer all previous prompts.
        if (model.provider === "github-copilot") {
          assistantMsg.content = nonEmptyTextBlocks
            .map((b) => sanitizeSurrogates(b.text))
            .join("");
        } else {
          assistantMsg.content = nonEmptyTextBlocks.map((b) => {
            return { type: "text", text: sanitizeSurrogates(b.text) };
          });
        }
      }
      // Handle thinking blocks
      const thinkingBlocks = msg.content.filter(
        (b) => b.type === "thinking",
      ) as ThinkingContent[];
      // Filter out empty thinking blocks to avoid API validation errors
      const nonEmptyThinkingBlocks = thinkingBlocks.filter(
        (b) => b.thinking && b.thinking.trim().length > 0,
      );
      if (nonEmptyThinkingBlocks.length > 0) {
        if (compat.requiresThinkingAsText) {
          // Convert thinking blocks to plain text (no tags to avoid model mimicking them)
          const thinkingText = nonEmptyThinkingBlocks
            .map((b) => b.thinking)
            .join("\n\n");
          const textContent = assistantMsg.content as Array<{
            type: "text";
            text: string;
          }> | null;
          if (textContent) {
            textContent.unshift({ type: "text", text: thinkingText });
          } else {
            assistantMsg.content = [{ type: "text", text: thinkingText }];
          }
        } else {
          // Use the signature from the first thinking block if available (for llama.cpp server + gpt-oss)
          // The signature doubles as the field name the reasoning is sent back under.
          const signature = nonEmptyThinkingBlocks[0].thinkingSignature;
          if (signature && signature.length > 0) {
            (assistantMsg as any)[signature] = nonEmptyThinkingBlocks
              .map((b) => b.thinking)
              .join("\n");
          }
        }
      }
      const toolCalls = msg.content.filter(
        (b) => b.type === "toolCall",
      ) as ToolCall[];
      if (toolCalls.length > 0) {
        assistantMsg.tool_calls = toolCalls.map((tc) => ({
          id: tc.id,
          type: "function" as const,
          function: {
            name: tc.name,
            arguments: JSON.stringify(tc.arguments),
          },
        }));
        // Round-trip encrypted reasoning details stored on the tool calls
        // (thoughtSignature holds the JSON captured at stream time).
        const reasoningDetails = toolCalls
          .filter((tc) => tc.thoughtSignature)
          .map((tc) => {
            try {
              return JSON.parse(tc.thoughtSignature!);
            } catch {
              return null;
            }
          })
          .filter(Boolean);
        if (reasoningDetails.length > 0) {
          (assistantMsg as any).reasoning_details = reasoningDetails;
        }
      }
      // Skip assistant messages that have no content and no tool calls.
      // Some providers require "either content or tool_calls, but not none".
      // Other providers also don't accept empty assistant messages.
      // This handles aborted assistant responses that got no content.
      const content = assistantMsg.content;
      const hasContent =
        content !== null &&
        content !== undefined &&
        (typeof content === "string" ? content.length > 0 : content.length > 0);
      if (!hasContent && !assistantMsg.tool_calls) {
        continue;
      }
      params.push(assistantMsg);
    } else if (msg.role === "toolResult") {
      const imageBlocks: Array<{
        type: "image_url";
        image_url: { url: string };
      }> = [];
      // Consume the entire run of consecutive toolResult messages so any
      // images they carry can be emitted afterwards as one user message.
      let j = i;
      for (
        ;
        j < transformedMessages.length &&
        transformedMessages[j].role === "toolResult";
        j++
      ) {
        const toolMsg = transformedMessages[j] as ToolResultMessage;
        // Extract text and image content
        const textResult = toolMsg.content
          .filter((c) => c.type === "text")
          .map((c) => (c as any).text)
          .join("\n");
        const hasImages = toolMsg.content.some((c) => c.type === "image");
        // Always send tool result with text (or placeholder if only images)
        const hasText = textResult.length > 0;
        // Some providers require the 'name' field in tool results
        const toolResultMsg: ChatCompletionToolMessageParam = {
          role: "tool",
          content: sanitizeSurrogates(
            hasText ? textResult : "(see attached image)",
          ),
          tool_call_id: toolMsg.toolCallId,
        };
        if (compat.requiresToolResultName && toolMsg.toolName) {
          (toolResultMsg as any).name = toolMsg.toolName;
        }
        params.push(toolResultMsg);
        if (hasImages && model.input.includes("image")) {
          for (const block of toolMsg.content) {
            if (block.type === "image") {
              imageBlocks.push({
                type: "image_url",
                image_url: {
                  url: `data:${(block as any).mimeType};base64,${(block as any).data}`,
                },
              });
            }
          }
        }
      }
      // Resume the outer loop after the batched run of tool results.
      i = j - 1;
      if (imageBlocks.length > 0) {
        if (compat.requiresAssistantAfterToolResult) {
          params.push({
            role: "assistant",
            content: "I have processed the tool results.",
          });
        }
        params.push({
          role: "user",
          content: [
            {
              type: "text",
              text: "Attached image(s) from tool result:",
            },
            ...imageBlocks,
          ],
        });
        lastRole = "user";
      } else {
        lastRole = "toolResult";
      }
      continue;
    }
    lastRole = msg.role;
  }
  return params;
}
/**
 * Map internal tool definitions to Chat Completions function tools.
 * The `strict` field is emitted (as false) only when the provider tolerates
 * it, since some endpoints reject unknown fields.
 */
function convertTools(
  tools: Tool[],
  compat: Required<OpenAICompletionsCompat>,
): OpenAI.Chat.Completions.ChatCompletionTool[] {
  const includeStrict = compat.supportsStrictMode !== false;
  return tools.map((tool) => ({
    type: "function" as const,
    function: {
      name: tool.name,
      description: tool.description,
      parameters: tool.parameters as any, // TypeBox already generates JSON Schema
      ...(includeStrict ? { strict: false } : {}),
    },
  }));
}
/**
 * Translate an OpenAI finish_reason into the internal StopReason.
 * A null finish_reason is treated as a normal stop; an unrecognized value
 * throws (the `never` assignment keeps this exhaustive at compile time).
 */
function mapStopReason(
  reason: ChatCompletionChunk.Choice["finish_reason"],
): StopReason {
  if (reason === null || reason === "stop") return "stop";
  if (reason === "length") return "length";
  if (reason === "function_call" || reason === "tool_calls") return "toolUse";
  if (reason === "content_filter") return "error";
  const _exhaustive: never = reason;
  throw new Error(`Unhandled stop reason: ${_exhaustive}`);
}
/**
 * Derive compatibility settings for a model from its provider name and base
 * URL. Explicit provider checks take precedence over URL matching since the
 * provider is configured deliberately. The result always has every field of
 * OpenAICompletionsCompat populated.
 */
function detectCompat(
  model: Model<"openai-completions">,
): Required<OpenAICompletionsCompat> {
  const { provider, baseUrl } = model;
  const urlHas = (fragment: string) => baseUrl.includes(fragment);
  const isZai = provider === "zai" || urlHas("api.z.ai");
  const isGrok = provider === "xai" || urlHas("api.x.ai");
  const isGroq = provider === "groq" || urlHas("groq.com");
  const isChutes = urlHas("chutes.ai");
  // Endpoints that deviate from the OpenAI baseline (no store, no developer role).
  const isNonStandard =
    provider === "cerebras" ||
    urlHas("cerebras.ai") ||
    isGrok ||
    isChutes ||
    urlHas("deepseek.com") ||
    isZai ||
    provider === "opencode" ||
    urlHas("opencode.ai");
  // Groq's qwen3-32b only accepts a single "default" reasoning effort.
  const reasoningEffortMap =
    isGroq && model.id === "qwen/qwen3-32b"
      ? {
          minimal: "default",
          low: "default",
          medium: "default",
          high: "default",
          xhigh: "default",
        }
      : {};
  return {
    supportsStore: !isNonStandard,
    supportsDeveloperRole: !isNonStandard,
    supportsReasoningEffort: !isGrok && !isZai,
    reasoningEffortMap,
    supportsUsageInStreaming: true,
    maxTokensField: isChutes ? "max_tokens" : "max_completion_tokens",
    requiresToolResultName: false,
    requiresAssistantAfterToolResult: false,
    requiresThinkingAsText: false,
    thinkingFormat: isZai ? "zai" : "openai",
    openRouterRouting: {},
    vercelGatewayRouting: {},
    supportsStrictMode: true,
  };
}
/**
 * Resolve the effective compatibility settings for a model.
 * Starts from the auto-detected defaults and overlays any explicit
 * model.compat values field by field; null/undefined overrides fall back to
 * the detected value (via ??).
 */
function getCompat(
  model: Model<"openai-completions">,
): Required<OpenAICompletionsCompat> {
  const detected = detectCompat(model);
  const overrides = model.compat;
  if (!overrides) return detected;
  return {
    supportsStore: overrides.supportsStore ?? detected.supportsStore,
    supportsDeveloperRole:
      overrides.supportsDeveloperRole ?? detected.supportsDeveloperRole,
    supportsReasoningEffort:
      overrides.supportsReasoningEffort ?? detected.supportsReasoningEffort,
    reasoningEffortMap:
      overrides.reasoningEffortMap ?? detected.reasoningEffortMap,
    supportsUsageInStreaming:
      overrides.supportsUsageInStreaming ?? detected.supportsUsageInStreaming,
    maxTokensField: overrides.maxTokensField ?? detected.maxTokensField,
    requiresToolResultName:
      overrides.requiresToolResultName ?? detected.requiresToolResultName,
    requiresAssistantAfterToolResult:
      overrides.requiresAssistantAfterToolResult ??
      detected.requiresAssistantAfterToolResult,
    requiresThinkingAsText:
      overrides.requiresThinkingAsText ?? detected.requiresThinkingAsText,
    thinkingFormat: overrides.thinkingFormat ?? detected.thinkingFormat,
    openRouterRouting: overrides.openRouterRouting ?? {},
    vercelGatewayRouting:
      overrides.vercelGatewayRouting ?? detected.vercelGatewayRouting,
    supportsStrictMode:
      overrides.supportsStrictMode ?? detected.supportsStrictMode,
  };
}

// =============================================================================
// Second file: OpenAI Responses API shared helpers (openai-responses.ts).
// (Diff-viewer artifacts — "View file" and the hunk header — removed.)
// =============================================================================
import type OpenAI from "openai";
import type {
Tool as OpenAITool,
ResponseCreateParamsStreaming,
ResponseFunctionToolCall,
ResponseInput,
ResponseInputContent,
ResponseInputImage,
ResponseInputText,
ResponseOutputMessage,
ResponseReasoningItem,
ResponseStreamEvent,
} from "openai/resources/responses/responses.js";
import { calculateCost } from "../models.js";
import type {
Api,
AssistantMessage,
Context,
ImageContent,
Model,
StopReason,
TextContent,
TextSignatureV1,
ThinkingContent,
Tool,
ToolCall,
Usage,
} from "../types.js";
import type { AssistantMessageEventStream } from "../utils/event-stream.js";
import { shortHash } from "../utils/hash.js";
import { parseStreamingJson } from "../utils/json-parse.js";
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
import { transformMessages } from "./transform-messages.js";
// =============================================================================
// Utilities
// =============================================================================
/**
 * Serialize a v1 text signature ({ v: 1, id, phase? }) to its JSON wire form.
 * The phase key is omitted entirely when not provided, so the output stays
 * byte-stable for signatures without a phase.
 */
function encodeTextSignatureV1(
  id: string,
  phase?: TextSignatureV1["phase"],
): string {
  const payload: TextSignatureV1 = phase ? { v: 1, id, phase } : { v: 1, id };
  return JSON.stringify(payload);
}
/**
 * Decode a stored text signature. Accepts both the v1 JSON envelope and
 * legacy plain-string ids; anything that is not a valid v1 payload is
 * treated as a raw id. Returns undefined when no signature was stored.
 */
function parseTextSignature(
  signature: string | undefined,
): { id: string; phase?: TextSignatureV1["phase"] } | undefined {
  if (!signature) return undefined;
  if (signature.startsWith("{")) {
    let parsed: Partial<TextSignatureV1> | undefined;
    try {
      parsed = JSON.parse(signature) as Partial<TextSignatureV1>;
    } catch {
      // Legacy plain string that merely starts with "{" — fall through.
      parsed = undefined;
    }
    if (parsed && parsed.v === 1 && typeof parsed.id === "string") {
      const hasValidPhase =
        parsed.phase === "commentary" || parsed.phase === "final_answer";
      return hasValidPhase
        ? { id: parsed.id, phase: parsed.phase }
        : { id: parsed.id };
    }
  }
  return { id: signature };
}
/** Optional knobs for streaming through the OpenAI Responses API. */
export interface OpenAIResponsesStreamOptions {
  // Requested service tier, forwarded on the Responses API request.
  serviceTier?: ResponseCreateParamsStreaming["service_tier"];
  // Hook to adjust computed usage/cost for the service tier the response
  // actually ran under (which may differ from the requested tier).
  applyServiceTierPricing?: (
    usage: Usage,
    serviceTier: ResponseCreateParamsStreaming["service_tier"] | undefined,
  ) => void;
}
/** Options controlling message conversion for the Responses API. */
export interface ConvertResponsesMessagesOptions {
  // When false, the system/developer prompt is omitted from the converted
  // input. Defaults to true.
  includeSystemPrompt?: boolean;
}
/** Options controlling tool conversion for the Responses API. */
export interface ConvertResponsesToolsOptions {
  // Value for each function tool's `strict` field; undefined defaults to false.
  strict?: boolean | null;
}
// =============================================================================
// Message conversion
// =============================================================================
/**
 * Convert the internal message history into OpenAI Responses API input items.
 *
 * Responsibilities visible in this function:
 * - Normalizes pipe-separated tool-call ids ({call_id}|{item_id}) for the
 *   providers listed in allowedToolCallProviders: sanitizes characters,
 *   forces the item id to start with "fc", truncates both halves to 64
 *   chars, and strips trailing underscores.
 * - Emits the system prompt under the "developer" role for reasoning models.
 * - Strips image parts for models whose input does not include "image".
 * - Replays assistant thinking blocks as stored reasoning items, text blocks
 *   as completed output messages (capping ids at 64 chars via hashing), and
 *   tool calls as function_call items.
 * - Emits tool results as function_call_output, with images re-sent in a
 *   follow-up user message when the model supports image input.
 *
 * @returns The ResponseInput array to send to the Responses API.
 */
export function convertResponsesMessages<TApi extends Api>(
  model: Model<TApi>,
  context: Context,
  allowedToolCallProviders: ReadonlySet<string>,
  options?: ConvertResponsesMessagesOptions,
): ResponseInput {
  const messages: ResponseInput = [];
  const normalizeToolCallId = (id: string): string => {
    if (!allowedToolCallProviders.has(model.provider)) return id;
    if (!id.includes("|")) return id;
    const [callId, itemId] = id.split("|");
    const sanitizedCallId = callId.replace(/[^a-zA-Z0-9_-]/g, "_");
    let sanitizedItemId = itemId.replace(/[^a-zA-Z0-9_-]/g, "_");
    // OpenAI Responses API requires item id to start with "fc"
    if (!sanitizedItemId.startsWith("fc")) {
      sanitizedItemId = `fc_${sanitizedItemId}`;
    }
    // Truncate to 64 chars and strip trailing underscores (OpenAI Codex rejects them)
    let normalizedCallId =
      sanitizedCallId.length > 64
        ? sanitizedCallId.slice(0, 64)
        : sanitizedCallId;
    let normalizedItemId =
      sanitizedItemId.length > 64
        ? sanitizedItemId.slice(0, 64)
        : sanitizedItemId;
    normalizedCallId = normalizedCallId.replace(/_+$/, "");
    normalizedItemId = normalizedItemId.replace(/_+$/, "");
    return `${normalizedCallId}|${normalizedItemId}`;
  };
  const transformedMessages = transformMessages(
    context.messages,
    model,
    normalizeToolCallId,
  );
  const includeSystemPrompt = options?.includeSystemPrompt ?? true;
  if (includeSystemPrompt && context.systemPrompt) {
    // Reasoning models take the prompt under the "developer" role.
    const role = model.reasoning ? "developer" : "system";
    messages.push({
      role,
      content: sanitizeSurrogates(context.systemPrompt),
    });
  }
  let msgIndex = 0;
  for (const msg of transformedMessages) {
    if (msg.role === "user") {
      if (typeof msg.content === "string") {
        messages.push({
          role: "user",
          content: [
            { type: "input_text", text: sanitizeSurrogates(msg.content) },
          ],
        });
      } else {
        const content: ResponseInputContent[] = msg.content.map(
          (item): ResponseInputContent => {
            if (item.type === "text") {
              return {
                type: "input_text",
                text: sanitizeSurrogates(item.text),
              } satisfies ResponseInputText;
            }
            return {
              type: "input_image",
              detail: "auto",
              image_url: `data:${item.mimeType};base64,${item.data}`,
            } satisfies ResponseInputImage;
          },
        );
        // Drop image parts for models without image input support.
        const filteredContent = !model.input.includes("image")
          ? content.filter((c) => c.type !== "input_image")
          : content;
        if (filteredContent.length === 0) continue;
        messages.push({
          role: "user",
          content: filteredContent,
        });
      }
    } else if (msg.role === "assistant") {
      const output: ResponseInput = [];
      const assistantMsg = msg as AssistantMessage;
      // History produced by a different model id of the same provider/api;
      // used below to relax function_call id pairing validation.
      const isDifferentModel =
        assistantMsg.model !== model.id &&
        assistantMsg.provider === model.provider &&
        assistantMsg.api === model.api;
      for (const block of msg.content) {
        if (block.type === "thinking") {
          if (block.thinking.trim().length === 0) continue;
          if (block.thinkingSignature) {
            // thinkingSignature holds the original reasoning item as JSON.
            const reasoningItem = JSON.parse(
              block.thinkingSignature,
            ) as ResponseReasoningItem;
            output.push(reasoningItem);
          }
        } else if (block.type === "text") {
          const textBlock = block as TextContent;
          const parsedSignature = parseTextSignature(textBlock.textSignature);
          // OpenAI requires id to be max 64 characters
          let msgId = parsedSignature?.id;
          if (!msgId) {
            msgId = `msg_${msgIndex}`;
          } else if (msgId.length > 64) {
            msgId = `msg_${shortHash(msgId)}`;
          }
          output.push({
            type: "message",
            role: "assistant",
            content: [
              {
                type: "output_text",
                text: sanitizeSurrogates(textBlock.text),
                annotations: [],
              },
            ],
            status: "completed",
            id: msgId,
            phase: parsedSignature?.phase,
          } satisfies ResponseOutputMessage);
        } else if (block.type === "toolCall") {
          const toolCall = block as ToolCall;
          const [callId, itemIdRaw] = toolCall.id.split("|");
          let itemId: string | undefined = itemIdRaw;
          // For different-model messages, set id to undefined to avoid pairing validation.
          // OpenAI tracks which fc_xxx IDs were paired with rs_xxx reasoning items.
          // By omitting the id, we avoid triggering that validation (like cross-provider does).
          if (isDifferentModel && itemId?.startsWith("fc_")) {
            itemId = undefined;
          }
          output.push({
            type: "function_call",
            id: itemId,
            call_id: callId,
            name: toolCall.name,
            arguments: JSON.stringify(toolCall.arguments),
          });
        }
      }
      if (output.length === 0) continue;
      messages.push(...output);
    } else if (msg.role === "toolResult") {
      // Extract text and image content
      const textResult = msg.content
        .filter((c): c is TextContent => c.type === "text")
        .map((c) => c.text)
        .join("\n");
      const hasImages = msg.content.some(
        (c): c is ImageContent => c.type === "image",
      );
      // Always send function_call_output with text (or placeholder if only images)
      const hasText = textResult.length > 0;
      const [callId] = msg.toolCallId.split("|");
      messages.push({
        type: "function_call_output",
        call_id: callId,
        output: sanitizeSurrogates(
          hasText ? textResult : "(see attached image)",
        ),
      });
      // If there are images and model supports them, send a follow-up user message with images
      if (hasImages && model.input.includes("image")) {
        const contentParts: ResponseInputContent[] = [];
        // Add text prefix
        contentParts.push({
          type: "input_text",
          text: "Attached image(s) from tool result:",
        } satisfies ResponseInputText);
        // Add images
        for (const block of msg.content) {
          if (block.type === "image") {
            contentParts.push({
              type: "input_image",
              detail: "auto",
              image_url: `data:${block.mimeType};base64,${block.data}`,
            } satisfies ResponseInputImage);
          }
        }
        messages.push({
          role: "user",
          content: contentParts,
        });
      }
    }
    msgIndex++;
  }
  return messages;
}
// =============================================================================
// Tool conversion
// =============================================================================
/**
 * Map internal tool definitions to OpenAI Responses API function tools.
 * When the strict option is omitted it defaults to false; an explicit null
 * is forwarded unchanged.
 */
export function convertResponsesTools(
  tools: Tool[],
  options?: ConvertResponsesToolsOptions,
): OpenAITool[] {
  const strictOption = options?.strict;
  const strict = strictOption === undefined ? false : strictOption;
  return tools.map((tool) => ({
    type: "function" as const,
    name: tool.name,
    description: tool.description,
    parameters: tool.parameters as any, // TypeBox already generates JSON Schema
    strict,
  }));
}
// =============================================================================
// Stream processing
// =============================================================================
export async function processResponsesStream<TApi extends Api>(
openaiStream: AsyncIterable<ResponseStreamEvent>,
output: AssistantMessage,
stream: AssistantMessageEventStream,
model: Model<TApi>,
options?: OpenAIResponsesStreamOptions,
): Promise<void> {
let currentItem:
| ResponseReasoningItem
| ResponseOutputMessage
| ResponseFunctionToolCall
| null = null;
let currentBlock:
| ThinkingContent
| TextContent
| (ToolCall & { partialJson: string })
| null = null;
const blocks = output.content;
const blockIndex = () => blocks.length - 1;
for await (const event of openaiStream) {
if (event.type === "response.output_item.added") {
const item = event.item;
if (item.type === "reasoning") {
currentItem = item;
currentBlock = { type: "thinking", thinking: "" };
output.content.push(currentBlock);
stream.push({
type: "thinking_start",
contentIndex: blockIndex(),
partial: output,
});
} else if (item.type === "message") {
currentItem = item;
currentBlock = { type: "text", text: "" };
output.content.push(currentBlock);
stream.push({
type: "text_start",
contentIndex: blockIndex(),
partial: output,
});
} else if (item.type === "function_call") {
currentItem = item;
currentBlock = {
type: "toolCall",
id: `${item.call_id}|${item.id}`,
name: item.name,
arguments: {},
partialJson: item.arguments || "",
};
output.content.push(currentBlock);
stream.push({
type: "toolcall_start",
contentIndex: blockIndex(),
partial: output,
});
}
} else if (event.type === "response.reasoning_summary_part.added") {
if (currentItem && currentItem.type === "reasoning") {
currentItem.summary = currentItem.summary || [];
currentItem.summary.push(event.part);
}
} else if (event.type === "response.reasoning_summary_text.delta") {
if (
currentItem?.type === "reasoning" &&
currentBlock?.type === "thinking"
) {
currentItem.summary = currentItem.summary || [];
const lastPart = currentItem.summary[currentItem.summary.length - 1];
if (lastPart) {
currentBlock.thinking += event.delta;
lastPart.text += event.delta;
stream.push({
type: "thinking_delta",
contentIndex: blockIndex(),
delta: event.delta,
partial: output,
});
}
}
} else if (event.type === "response.reasoning_summary_part.done") {
if (
currentItem?.type === "reasoning" &&
currentBlock?.type === "thinking"
) {
currentItem.summary = currentItem.summary || [];
const lastPart = currentItem.summary[currentItem.summary.length - 1];
if (lastPart) {
currentBlock.thinking += "\n\n";
lastPart.text += "\n\n";
stream.push({
type: "thinking_delta",
contentIndex: blockIndex(),
delta: "\n\n",
partial: output,
});
}
}
} else if (event.type === "response.content_part.added") {
if (currentItem?.type === "message") {
currentItem.content = currentItem.content || [];
// Filter out ReasoningText, only accept output_text and refusal
if (
event.part.type === "output_text" ||
event.part.type === "refusal"
) {
currentItem.content.push(event.part);
}
}
} else if (event.type === "response.output_text.delta") {
if (currentItem?.type === "message" && currentBlock?.type === "text") {
if (!currentItem.content || currentItem.content.length === 0) {
continue;
}
const lastPart = currentItem.content[currentItem.content.length - 1];
if (lastPart?.type === "output_text") {
currentBlock.text += event.delta;
lastPart.text += event.delta;
stream.push({
type: "text_delta",
contentIndex: blockIndex(),
delta: event.delta,
partial: output,
});
}
}
} else if (event.type === "response.refusal.delta") {
if (currentItem?.type === "message" && currentBlock?.type === "text") {
if (!currentItem.content || currentItem.content.length === 0) {
continue;
}
const lastPart = currentItem.content[currentItem.content.length - 1];
if (lastPart?.type === "refusal") {
currentBlock.text += event.delta;
lastPart.refusal += event.delta;
stream.push({
type: "text_delta",
contentIndex: blockIndex(),
delta: event.delta,
partial: output,
});
}
}
} else if (event.type === "response.function_call_arguments.delta") {
if (
currentItem?.type === "function_call" &&
currentBlock?.type === "toolCall"
) {
currentBlock.partialJson += event.delta;
currentBlock.arguments = parseStreamingJson(currentBlock.partialJson);
stream.push({
type: "toolcall_delta",
contentIndex: blockIndex(),
delta: event.delta,
partial: output,
});
}
} else if (event.type === "response.function_call_arguments.done") {
if (
currentItem?.type === "function_call" &&
currentBlock?.type === "toolCall"
) {
currentBlock.partialJson = event.arguments;
currentBlock.arguments = parseStreamingJson(currentBlock.partialJson);
}
} else if (event.type === "response.output_item.done") {
const item = event.item;
if (item.type === "reasoning" && currentBlock?.type === "thinking") {
currentBlock.thinking =
item.summary?.map((s) => s.text).join("\n\n") || "";
currentBlock.thinkingSignature = JSON.stringify(item);
stream.push({
type: "thinking_end",
contentIndex: blockIndex(),
content: currentBlock.thinking,
partial: output,
});
currentBlock = null;
} else if (item.type === "message" && currentBlock?.type === "text") {
currentBlock.text = item.content
.map((c) => (c.type === "output_text" ? c.text : c.refusal))
.join("");
currentBlock.textSignature = encodeTextSignatureV1(
item.id,
item.phase ?? undefined,
);
stream.push({
type: "text_end",
contentIndex: blockIndex(),
content: currentBlock.text,
partial: output,
});
currentBlock = null;
} else if (item.type === "function_call") {
const args =
currentBlock?.type === "toolCall" && currentBlock.partialJson
? parseStreamingJson(currentBlock.partialJson)
: parseStreamingJson(item.arguments || "{}");
const toolCall: ToolCall = {
type: "toolCall",
id: `${item.call_id}|${item.id}`,
name: item.name,
arguments: args,
};
currentBlock = null;
stream.push({
type: "toolcall_end",
contentIndex: blockIndex(),
toolCall,
partial: output,
});
}
} else if (event.type === "response.completed") {
const response = event.response;
if (response?.usage) {
const cachedTokens =
response.usage.input_tokens_details?.cached_tokens || 0;
output.usage = {
// OpenAI includes cached tokens in input_tokens, so subtract to get non-cached input
input: (response.usage.input_tokens || 0) - cachedTokens,
output: response.usage.output_tokens || 0,
cacheRead: cachedTokens,
cacheWrite: 0,
totalTokens: response.usage.total_tokens || 0,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
};
}
calculateCost(model, output.usage);
if (options?.applyServiceTierPricing) {
const serviceTier = response?.service_tier ?? options.serviceTier;
options.applyServiceTierPricing(output.usage, serviceTier);
}
// Map status to stop reason
output.stopReason = mapStopReason(response?.status);
if (
output.content.some((b) => b.type === "toolCall") &&
output.stopReason === "stop"
) {
output.stopReason = "toolUse";
}
} else if (event.type === "error") {
throw new Error(
`Error Code ${event.code}: ${event.message}` || "Unknown error",
);
} else if (event.type === "response.failed") {
throw new Error("Unknown error");
}
}
}
/**
 * Translate an OpenAI Responses status into the unified StopReason.
 * A missing status is treated as a normal completion.
 */
function mapStopReason(
  status: OpenAI.Responses.ResponseStatus | undefined,
): StopReason {
  if (!status) return "stop";
  if (status === "completed") return "stop";
  if (status === "incomplete") return "length";
  if (status === "failed" || status === "cancelled") return "error";
  // These two should not appear on a finished response; fall back to a
  // normal stop rather than erroring out.
  if (status === "in_progress" || status === "queued") return "stop";
  // Exhaustiveness guard: a new status value fails compilation here.
  const _exhaustive: never = status;
  throw new Error(`Unhandled stop reason: ${_exhaustive}`);
}

View file

@ -0,0 +1,309 @@
import OpenAI from "openai";
import type { ResponseCreateParamsStreaming } from "openai/resources/responses/responses.js";
import { getEnvApiKey } from "../env-api-keys.js";
import { supportsXhigh } from "../models.js";
import type {
Api,
AssistantMessage,
CacheRetention,
Context,
Model,
SimpleStreamOptions,
StreamFunction,
StreamOptions,
Usage,
} from "../types.js";
import { AssistantMessageEventStream } from "../utils/event-stream.js";
import {
buildCopilotDynamicHeaders,
hasCopilotVisionInput,
} from "./github-copilot-headers.js";
import {
convertResponsesMessages,
convertResponsesTools,
processResponsesStream,
} from "./openai-responses-shared.js";
import { buildBaseOptions, clampReasoning } from "./simple-options.js";
// Providers whose Responses endpoint accepts OpenAI-native tool-call ids.
// NOTE(review): passed to convertResponsesMessages in buildParams below;
// confirm exact semantics against that helper.
const OPENAI_TOOL_CALL_PROVIDERS = new Set([
  "openai",
  "openai-codex",
  "opencode",
]);
/**
 * Resolve the effective cache retention preference.
 *
 * Explicit caller preference wins; otherwise the legacy PI_CACHE_RETENTION
 * environment variable may opt into "long"; the default is "short".
 */
function resolveCacheRetention(
  cacheRetention?: CacheRetention,
): CacheRetention {
  if (cacheRetention) return cacheRetention;
  const envOptsIntoLong =
    typeof process !== "undefined" &&
    process.env.PI_CACHE_RETENTION === "long";
  return envOptsIntoLong ? "long" : "short";
}
/**
 * Map the resolved cache retention onto OpenAI's prompt_cache_retention value.
 * Extended ("24h") retention is only honored on the first-party OpenAI
 * endpoint; any other base URL gets no explicit retention.
 */
function getPromptCacheRetention(
  baseUrl: string,
  cacheRetention: CacheRetention,
): "24h" | undefined {
  const isDirectOpenAI = baseUrl.includes("api.openai.com");
  return cacheRetention === "long" && isDirectOpenAI ? "24h" : undefined;
}
// OpenAI Responses-specific options
export interface OpenAIResponsesOptions extends StreamOptions {
  /** Reasoning effort sent as `reasoning.effort`; defaults to "medium" when a summary is requested. */
  reasoningEffort?: "minimal" | "low" | "medium" | "high" | "xhigh";
  /** Reasoning summary verbosity sent as `reasoning.summary`; falls back to "auto" when unset. */
  reasoningSummary?: "auto" | "detailed" | "concise" | null;
  /** OpenAI service tier (e.g. "flex", "priority"); also scales reported cost. */
  serviceTier?: ResponseCreateParamsStreaming["service_tier"];
}
/**
 * Stream a completion from the OpenAI Responses API.
 *
 * Immediately returns an event stream; the request runs in a detached async
 * task that pushes events into it and always terminates the stream, emitting
 * either a "done" or an "error" event as the final state.
 */
export const streamOpenAIResponses: StreamFunction<
  "openai-responses",
  OpenAIResponsesOptions
> = (
  model: Model<"openai-responses">,
  context: Context,
  options?: OpenAIResponsesOptions,
): AssistantMessageEventStream => {
  const stream = new AssistantMessageEventStream();
  // Start async processing
  (async () => {
    // Skeleton message that is mutated in place as stream events arrive.
    const output: AssistantMessage = {
      role: "assistant",
      content: [],
      api: model.api as Api,
      provider: model.provider,
      model: model.id,
      usage: {
        input: 0,
        output: 0,
        cacheRead: 0,
        cacheWrite: 0,
        totalTokens: 0,
        cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
      },
      stopReason: "stop",
      timestamp: Date.now(),
    };
    try {
      // Create OpenAI client
      const apiKey = options?.apiKey || getEnvApiKey(model.provider) || "";
      const client = createClient(model, context, apiKey, options?.headers);
      const params = buildParams(model, context, options);
      // Let callers inspect the outgoing payload (debugging/telemetry).
      options?.onPayload?.(params);
      const openaiStream = await client.responses.create(
        params,
        options?.signal ? { signal: options.signal } : undefined,
      );
      stream.push({ type: "start", partial: output });
      await processResponsesStream(openaiStream, output, stream, model, {
        serviceTier: options?.serviceTier,
        applyServiceTierPricing,
      });
      // Route aborts and error stop states through the catch block below so
      // the error path is handled in one place.
      if (options?.signal?.aborted) {
        throw new Error("Request was aborted");
      }
      if (output.stopReason === "aborted" || output.stopReason === "error") {
        throw new Error("An unknown error occurred");
      }
      stream.push({ type: "done", reason: output.stopReason, message: output });
      stream.end();
    } catch (error) {
      // Strip internal bookkeeping indices from partially-built blocks.
      for (const block of output.content)
        delete (block as { index?: number }).index;
      output.stopReason = options?.signal?.aborted ? "aborted" : "error";
      output.errorMessage =
        error instanceof Error ? error.message : JSON.stringify(error);
      stream.push({ type: "error", reason: output.stopReason, error: output });
      stream.end();
    }
  })();
  return stream;
};
/**
 * Simplified streaming entry point: resolves the API key, maps the unified
 * reasoning level onto Responses-specific options, and delegates to
 * streamOpenAIResponses.
 */
export const streamSimpleOpenAIResponses: StreamFunction<
  "openai-responses",
  SimpleStreamOptions
> = (
  model: Model<"openai-responses">,
  context: Context,
  options?: SimpleStreamOptions,
): AssistantMessageEventStream => {
  const apiKey = options?.apiKey || getEnvApiKey(model.provider);
  if (!apiKey) {
    throw new Error(`No API key for provider: ${model.provider}`);
  }
  // "xhigh" is only forwarded to models that understand it; otherwise clamp.
  const reasoningEffort = supportsXhigh(model)
    ? options?.reasoning
    : clampReasoning(options?.reasoning);
  const streamOptions: OpenAIResponsesOptions = {
    ...buildBaseOptions(model, options, apiKey),
    reasoningEffort,
  };
  return streamOpenAIResponses(model, context, streamOptions);
};
/**
 * Build an OpenAI SDK client for the Responses API.
 *
 * Header precedence (lowest to highest): model default headers, GitHub
 * Copilot dynamic headers (when applicable), caller-supplied headers.
 *
 * @throws If no API key is given and OPENAI_API_KEY is not set.
 */
function createClient(
  model: Model<"openai-responses">,
  context: Context,
  apiKey?: string,
  optionsHeaders?: Record<string, string>,
) {
  // Fall back to the environment key; an empty string counts as "not provided".
  // Use a local instead of reassigning the parameter.
  const resolvedApiKey = apiKey || process.env.OPENAI_API_KEY;
  if (!resolvedApiKey) {
    throw new Error(
      "OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass it as an argument.",
    );
  }
  const headers = { ...model.headers };
  if (model.provider === "github-copilot") {
    // Copilot needs request-dependent headers (e.g. vision hints when images are present).
    const hasImages = hasCopilotVisionInput(context.messages);
    const copilotHeaders = buildCopilotDynamicHeaders({
      messages: context.messages,
      hasImages,
    });
    Object.assign(headers, copilotHeaders);
  }
  // Merge options headers last so they can override defaults
  if (optionsHeaders) {
    Object.assign(headers, optionsHeaders);
  }
  return new OpenAI({
    apiKey: resolvedApiKey,
    baseURL: model.baseUrl,
    dangerouslyAllowBrowser: true,
    defaultHeaders: headers,
  });
}
/**
 * Assemble the Responses API request payload from the model, context, and
 * provider options: messages, caching keys, sampling limits, tools, and
 * reasoning configuration.
 */
function buildParams(
  model: Model<"openai-responses">,
  context: Context,
  options?: OpenAIResponsesOptions,
) {
  const messages = convertResponsesMessages(
    model,
    context,
    OPENAI_TOOL_CALL_PROVIDERS,
  );
  const cacheRetention = resolveCacheRetention(options?.cacheRetention);
  const params: ResponseCreateParamsStreaming = {
    model: model.id,
    input: messages,
    stream: true,
    // With retention "none" we deliberately omit the cache key entirely.
    prompt_cache_key:
      cacheRetention === "none" ? undefined : options?.sessionId,
    // "24h" retention only applies on api.openai.com; undefined elsewhere.
    prompt_cache_retention: getPromptCacheRetention(
      model.baseUrl,
      cacheRetention,
    ),
    store: false,
  };
  if (options?.maxTokens) {
    params.max_output_tokens = options?.maxTokens;
  }
  // Explicit check so temperature 0 is still forwarded.
  if (options?.temperature !== undefined) {
    params.temperature = options?.temperature;
  }
  if (options?.serviceTier !== undefined) {
    params.service_tier = options.serviceTier;
  }
  if (context.tools) {
    params.tools = convertResponsesTools(context.tools);
  }
  if (model.reasoning) {
    if (options?.reasoningEffort || options?.reasoningSummary) {
      params.reasoning = {
        effort: options?.reasoningEffort || "medium",
        summary: options?.reasoningSummary || "auto",
      };
      // Needed so reasoning content survives with store: false.
      params.include = ["reasoning.encrypted_content"];
    } else {
      if (model.name.startsWith("gpt-5")) {
        // Jesus Christ, see https://community.openai.com/t/need-reasoning-false-option-for-gpt-5/1351588/7
        messages.push({
          role: "developer",
          content: [
            {
              type: "input_text",
              text: "# Juice: 0 !important",
            },
          ],
        });
      }
    }
  }
  return params;
}
/**
 * Cost multiplier for an OpenAI service tier: flex is billed at half price,
 * priority at double, everything else at par.
 */
function getServiceTierCostMultiplier(
  serviceTier: ResponseCreateParamsStreaming["service_tier"] | undefined,
): number {
  if (serviceTier === "flex") return 0.5;
  if (serviceTier === "priority") return 2;
  return 1;
}
/**
 * Scale the usage cost in place by the service-tier multiplier and
 * recompute the total. No-op for tiers billed at the standard rate.
 */
function applyServiceTierPricing(
  usage: Usage,
  serviceTier: ResponseCreateParamsStreaming["service_tier"] | undefined,
) {
  const factor = getServiceTierCostMultiplier(serviceTier);
  if (factor === 1) return;
  const cost = usage.cost;
  cost.input *= factor;
  cost.output *= factor;
  cost.cacheRead *= factor;
  cost.cacheWrite *= factor;
  cost.total = cost.input + cost.output + cost.cacheRead + cost.cacheWrite;
}

View file

@ -0,0 +1,216 @@
import { clearApiProviders, registerApiProvider } from "../api-registry.js";
import type {
AssistantMessage,
AssistantMessageEvent,
Context,
Model,
SimpleStreamOptions,
StreamOptions,
} from "../types.js";
import { AssistantMessageEventStream } from "../utils/event-stream.js";
import { streamAnthropic, streamSimpleAnthropic } from "./anthropic.js";
import {
streamAzureOpenAIResponses,
streamSimpleAzureOpenAIResponses,
} from "./azure-openai-responses.js";
import { streamGoogle, streamSimpleGoogle } from "./google.js";
import {
streamGoogleGeminiCli,
streamSimpleGoogleGeminiCli,
} from "./google-gemini-cli.js";
import {
streamGoogleVertex,
streamSimpleGoogleVertex,
} from "./google-vertex.js";
import { streamMistral, streamSimpleMistral } from "./mistral.js";
import {
streamOpenAICodexResponses,
streamSimpleOpenAICodexResponses,
} from "./openai-codex-responses.js";
import {
streamOpenAICompletions,
streamSimpleOpenAICompletions,
} from "./openai-completions.js";
import {
streamOpenAIResponses,
streamSimpleOpenAIResponses,
} from "./openai-responses.js";
/**
 * Shape of the lazily-imported Bedrock provider module. Kept as a structural
 * interface so the heavy AWS SDK dependency is only loaded on first use.
 */
interface BedrockProviderModule {
  streamBedrock: (
    model: Model<"bedrock-converse-stream">,
    context: Context,
    options?: StreamOptions,
  ) => AsyncIterable<AssistantMessageEvent>;
  streamSimpleBedrock: (
    model: Model<"bedrock-converse-stream">,
    context: Context,
    options?: SimpleStreamOptions,
  ) => AsyncIterable<AssistantMessageEvent>;
}
type DynamicImport = (specifier: string) => Promise<unknown>;
// Indirection over dynamic import so the specifier stays a runtime value.
const dynamicImport: DynamicImport = (specifier) => import(specifier);
// The specifier is split so static analyzers/bundlers do not eagerly resolve
// and bundle the Bedrock provider (and its AWS SDK dependency tree).
// NOTE(review): presumed intent of the string concatenation — confirm.
const BEDROCK_PROVIDER_SPECIFIER = "./amazon-" + "bedrock.js";
// Test/DI hook: when set, the override is returned instead of importing.
let bedrockProviderModuleOverride: BedrockProviderModule | undefined;
export function setBedrockProviderModule(module: BedrockProviderModule): void {
  bedrockProviderModuleOverride = module;
}
/** Resolve the Bedrock provider module, preferring an injected override. */
async function loadBedrockProviderModule(): Promise<BedrockProviderModule> {
  if (bedrockProviderModuleOverride) {
    return bedrockProviderModuleOverride;
  }
  const module = await dynamicImport(BEDROCK_PROVIDER_SPECIFIER);
  return module as BedrockProviderModule;
}
/**
 * Pipe every event from `source` into `target`, then terminate `target`.
 *
 * The original version never called `target.end()` when the source iterable
 * threw, leaving consumers awaiting the stream hanging forever and the
 * rejection unhandled. `finally` guarantees termination; the trailing catch
 * prevents an unhandled promise rejection.
 */
function forwardStream(
  target: AssistantMessageEventStream,
  source: AsyncIterable<AssistantMessageEvent>,
): void {
  void (async () => {
    try {
      for await (const event of source) {
        target.push(event);
      }
    } finally {
      // Always terminate the outer stream, even if iteration throws.
      target.end();
    }
  })().catch(() => {
    // NOTE(review): provider streams are expected to surface failures as
    // "error" events rather than by throwing; swallowing here only avoids
    // an unhandled rejection. Confirm EventStream iteration semantics.
  });
}
/**
 * Build a synthetic error AssistantMessage for a failed lazy load of the
 * Bedrock provider module, with zeroed usage and the error text attached.
 */
function createLazyLoadErrorMessage(
  model: Model<"bedrock-converse-stream">,
  error: unknown,
): AssistantMessage {
  // Normalize the thrown value into a human-readable message.
  const errorMessage = error instanceof Error ? error.message : String(error);
  const zeroCost = { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 };
  return {
    role: "assistant",
    content: [],
    api: "bedrock-converse-stream",
    provider: model.provider,
    model: model.id,
    usage: {
      input: 0,
      output: 0,
      cacheRead: 0,
      cacheWrite: 0,
      totalTokens: 0,
      cost: zeroCost,
    },
    stopReason: "error",
    errorMessage,
    timestamp: Date.now(),
  };
}
/**
 * Lazily load the Bedrock provider and forward its stream; if the module
 * fails to load, the returned stream ends with a synthetic error message.
 */
function streamBedrockLazy(
  model: Model<"bedrock-converse-stream">,
  context: Context,
  options?: StreamOptions,
): AssistantMessageEventStream {
  const outer = new AssistantMessageEventStream();
  void (async () => {
    try {
      const providerModule = await loadBedrockProviderModule();
      forwardStream(outer, providerModule.streamBedrock(model, context, options));
    } catch (error) {
      const message = createLazyLoadErrorMessage(model, error);
      outer.push({ type: "error", reason: "error", error: message });
      outer.end(message);
    }
  })();
  return outer;
}
/**
 * Simple-options variant of streamBedrockLazy: lazily loads the Bedrock
 * provider and forwards streamSimpleBedrock, or ends with a synthetic error.
 */
function streamSimpleBedrockLazy(
  model: Model<"bedrock-converse-stream">,
  context: Context,
  options?: SimpleStreamOptions,
): AssistantMessageEventStream {
  const outer = new AssistantMessageEventStream();
  void (async () => {
    try {
      const providerModule = await loadBedrockProviderModule();
      forwardStream(
        outer,
        providerModule.streamSimpleBedrock(model, context, options),
      );
    } catch (error) {
      const message = createLazyLoadErrorMessage(model, error);
      outer.push({ type: "error", reason: "error", error: message });
      outer.end(message);
    }
  })();
  return outer;
}
/**
 * Register the stream/streamSimple pair for every built-in API.
 * Bedrock is registered through lazy wrappers so the AWS SDK is only
 * imported on first use.
 */
export function registerBuiltInApiProviders(): void {
  registerApiProvider({
    api: "anthropic-messages",
    stream: streamAnthropic,
    streamSimple: streamSimpleAnthropic,
  });
  registerApiProvider({
    api: "openai-completions",
    stream: streamOpenAICompletions,
    streamSimple: streamSimpleOpenAICompletions,
  });
  registerApiProvider({
    api: "mistral-conversations",
    stream: streamMistral,
    streamSimple: streamSimpleMistral,
  });
  registerApiProvider({
    api: "openai-responses",
    stream: streamOpenAIResponses,
    streamSimple: streamSimpleOpenAIResponses,
  });
  registerApiProvider({
    api: "azure-openai-responses",
    stream: streamAzureOpenAIResponses,
    streamSimple: streamSimpleAzureOpenAIResponses,
  });
  registerApiProvider({
    api: "openai-codex-responses",
    stream: streamOpenAICodexResponses,
    streamSimple: streamSimpleOpenAICodexResponses,
  });
  registerApiProvider({
    api: "google-generative-ai",
    stream: streamGoogle,
    streamSimple: streamSimpleGoogle,
  });
  registerApiProvider({
    api: "google-gemini-cli",
    stream: streamGoogleGeminiCli,
    streamSimple: streamSimpleGoogleGeminiCli,
  });
  registerApiProvider({
    api: "google-vertex",
    stream: streamGoogleVertex,
    streamSimple: streamSimpleGoogleVertex,
  });
  registerApiProvider({
    api: "bedrock-converse-stream",
    stream: streamBedrockLazy,
    streamSimple: streamSimpleBedrockLazy,
  });
}
/** Clear the registry and re-register only the built-in providers. */
export function resetApiProviders(): void {
  clearApiProviders();
  registerBuiltInApiProviders();
}
// Register built-ins on module load, so importing this module is sufficient
// to have all default providers available.
registerBuiltInApiProviders();

View file

@ -0,0 +1,59 @@
import type {
Api,
Model,
SimpleStreamOptions,
StreamOptions,
ThinkingBudgets,
ThinkingLevel,
} from "../types.js";
/**
 * Project the simple options onto the base StreamOptions shape shared by all
 * providers. An explicit apiKey argument takes precedence over the one in
 * options; the defaulted maxTokens is capped at 32k even for larger models.
 */
export function buildBaseOptions(
  model: Model<Api>,
  options?: SimpleStreamOptions,
  apiKey?: string,
): StreamOptions {
  const defaultMaxTokens = Math.min(model.maxTokens, 32000);
  const {
    temperature,
    signal,
    cacheRetention,
    sessionId,
    headers,
    onPayload,
    maxRetryDelayMs,
    metadata,
  } = options ?? {};
  return {
    temperature,
    maxTokens: options?.maxTokens || defaultMaxTokens,
    signal,
    apiKey: apiKey || options?.apiKey,
    cacheRetention,
    sessionId,
    headers,
    onPayload,
    maxRetryDelayMs,
    metadata,
  };
}
/**
 * Clamp the unified reasoning level for providers that do not support
 * "xhigh": it is mapped to "high", everything else passes through.
 */
export function clampReasoning(
  effort: ThinkingLevel | undefined,
): Exclude<ThinkingLevel, "xhigh"> | undefined {
  return effort === "xhigh" ? "high" : effort;
}
/**
 * Compute the (maxTokens, thinkingBudget) pair for a reasoning level on
 * token-budget providers.
 *
 * The thinking budget is added on top of the requested output tokens and the
 * sum is capped at the model's limit; if the cap leaves no room for output,
 * the budget is shrunk to reserve at least 1024 output tokens.
 *
 * Fixes vs. the original: no non-null assertions, and a custom budget entry
 * that is explicitly `undefined` no longer clobbers the default (the object
 * spread `{ ...defaults, ...custom }` let it through).
 */
export function adjustMaxTokensForThinking(
  baseMaxTokens: number,
  modelMaxTokens: number,
  reasoningLevel: ThinkingLevel,
  customBudgets?: ThinkingBudgets,
): { maxTokens: number; thinkingBudget: number } {
  const defaultBudgets: Required<ThinkingBudgets> = {
    minimal: 1024,
    low: 2048,
    medium: 8192,
    high: 16384,
  };
  const minOutputTokens = 1024;
  // "xhigh" shares the "high" budget; reasoningLevel is required, so the
  // clamped level is always defined.
  const level: Exclude<ThinkingLevel, "xhigh"> =
    reasoningLevel === "xhigh" ? "high" : reasoningLevel;
  // ?? guards against an explicitly-undefined custom entry.
  let thinkingBudget = customBudgets?.[level] ?? defaultBudgets[level];
  const maxTokens = Math.min(baseMaxTokens + thinkingBudget, modelMaxTokens);
  if (maxTokens <= thinkingBudget) {
    thinkingBudget = Math.max(0, maxTokens - minOutputTokens);
  }
  return { maxTokens, thinkingBudget };
}

View file

@ -0,0 +1,193 @@
import type {
Api,
AssistantMessage,
Message,
Model,
ToolCall,
ToolResultMessage,
} from "../types.js";
/**
 * Prepare a message history for replay against `model`.
 *
 * Pass 1 rewrites each message: thinking blocks from a different model are
 * converted to plain text (or dropped when empty/redacted), foreign thought
 * signatures are removed, and tool-call ids are normalized via
 * `normalizeToolCallId` — OpenAI Responses ids can exceed 450 chars and
 * contain `|`, while Anthropic requires ^[a-zA-Z0-9_-]+$ (max 64 chars).
 * Matching toolResult ids are remapped using the same mapping; this relies
 * on tool results appearing after their calls, which holds for chronological
 * histories.
 *
 * Pass 2 drops errored/aborted assistant turns entirely and inserts
 * synthetic error tool results for tool calls that never received one.
 */
export function transformMessages<TApi extends Api>(
  messages: Message[],
  model: Model<TApi>,
  normalizeToolCallId?: (
    id: string,
    model: Model<TApi>,
    source: AssistantMessage,
  ) => string,
): Message[] {
  // Build a map of original tool call IDs to normalized IDs
  const toolCallIdMap = new Map<string, string>();
  // First pass: transform messages (thinking blocks, tool call ID normalization)
  const transformed = messages.map((msg) => {
    // User messages pass through unchanged
    if (msg.role === "user") {
      return msg;
    }
    // Handle toolResult messages - normalize toolCallId if we have a mapping
    if (msg.role === "toolResult") {
      const normalizedId = toolCallIdMap.get(msg.toolCallId);
      if (normalizedId && normalizedId !== msg.toolCallId) {
        return { ...msg, toolCallId: normalizedId };
      }
      return msg;
    }
    // Assistant messages need transformation check
    if (msg.role === "assistant") {
      const assistantMsg = msg as AssistantMessage;
      // Blocks only survive untouched when provider, API, and model all match.
      const isSameModel =
        assistantMsg.provider === model.provider &&
        assistantMsg.api === model.api &&
        assistantMsg.model === model.id;
      const transformedContent = assistantMsg.content.flatMap((block) => {
        if (block.type === "thinking") {
          // Redacted thinking is opaque encrypted content, only valid for the same model.
          // Drop it for cross-model to avoid API errors.
          if (block.redacted) {
            return isSameModel ? block : [];
          }
          // For same model: keep thinking blocks with signatures (needed for replay)
          // even if the thinking text is empty (OpenAI encrypted reasoning)
          if (isSameModel && block.thinkingSignature) return block;
          // Skip empty thinking blocks, convert others to plain text
          if (!block.thinking || block.thinking.trim() === "") return [];
          if (isSameModel) return block;
          return {
            type: "text" as const,
            text: block.thinking,
          };
        }
        if (block.type === "text") {
          if (isSameModel) return block;
          // Rebuild as a bare text block so model-specific metadata
          // (e.g. textSignature) is not replayed cross-model.
          return {
            type: "text" as const,
            text: block.text,
          };
        }
        if (block.type === "toolCall") {
          const toolCall = block as ToolCall;
          let normalizedToolCall: ToolCall = toolCall;
          // Google-specific thought signatures are meaningless cross-model.
          if (!isSameModel && toolCall.thoughtSignature) {
            normalizedToolCall = { ...toolCall };
            delete (normalizedToolCall as { thoughtSignature?: string })
              .thoughtSignature;
          }
          if (!isSameModel && normalizeToolCallId) {
            const normalizedId = normalizeToolCallId(
              toolCall.id,
              model,
              assistantMsg,
            );
            if (normalizedId !== toolCall.id) {
              toolCallIdMap.set(toolCall.id, normalizedId);
              normalizedToolCall = { ...normalizedToolCall, id: normalizedId };
            }
          }
          return normalizedToolCall;
        }
        return block;
      });
      return {
        ...assistantMsg,
        content: transformedContent,
      };
    }
    return msg;
  });
  // Second pass: insert synthetic empty tool results for orphaned tool calls
  // This preserves thinking signatures and satisfies API requirements
  const result: Message[] = [];
  let pendingToolCalls: ToolCall[] = [];
  let existingToolResultIds = new Set<string>();
  for (let i = 0; i < transformed.length; i++) {
    const msg = transformed[i];
    if (msg.role === "assistant") {
      // If we have pending orphaned tool calls from a previous assistant, insert synthetic results now
      if (pendingToolCalls.length > 0) {
        for (const tc of pendingToolCalls) {
          if (!existingToolResultIds.has(tc.id)) {
            result.push({
              role: "toolResult",
              toolCallId: tc.id,
              toolName: tc.name,
              content: [{ type: "text", text: "No result provided" }],
              isError: true,
              timestamp: Date.now(),
            } as ToolResultMessage);
          }
        }
        pendingToolCalls = [];
        existingToolResultIds = new Set();
      }
      // Skip errored/aborted assistant messages entirely.
      // These are incomplete turns that shouldn't be replayed:
      // - May have partial content (reasoning without message, incomplete tool calls)
      // - Replaying them can cause API errors (e.g. OpenAI "reasoning without following item")
      // - The model should retry from the last valid state
      const assistantMsg = msg as AssistantMessage;
      if (
        assistantMsg.stopReason === "error" ||
        assistantMsg.stopReason === "aborted"
      ) {
        continue;
      }
      // Track tool calls from this assistant message
      const toolCalls = assistantMsg.content.filter(
        (b) => b.type === "toolCall",
      ) as ToolCall[];
      if (toolCalls.length > 0) {
        pendingToolCalls = toolCalls;
        existingToolResultIds = new Set();
      }
      result.push(msg);
    } else if (msg.role === "toolResult") {
      existingToolResultIds.add(msg.toolCallId);
      result.push(msg);
    } else if (msg.role === "user") {
      // User message interrupts tool flow - insert synthetic results for orphaned calls
      if (pendingToolCalls.length > 0) {
        for (const tc of pendingToolCalls) {
          if (!existingToolResultIds.has(tc.id)) {
            result.push({
              role: "toolResult",
              toolCallId: tc.id,
              toolName: tc.name,
              content: [{ type: "text", text: "No result provided" }],
              isError: true,
              timestamp: Date.now(),
            } as ToolResultMessage);
          }
        }
        pendingToolCalls = [];
        existingToolResultIds = new Set();
      }
      result.push(msg);
    } else {
      result.push(msg);
    }
  }
  return result;
}

59
packages/ai/src/stream.ts Normal file
View file

@ -0,0 +1,59 @@
import "./providers/register-builtins.js";
import { getApiProvider } from "./api-registry.js";
import type {
Api,
AssistantMessage,
AssistantMessageEventStream,
Context,
Model,
ProviderStreamOptions,
SimpleStreamOptions,
StreamOptions,
} from "./types.js";
export { getEnvApiKey } from "./env-api-keys.js";
/** Look up the registered stream provider for `api`, throwing when none is registered. */
function resolveApiProvider(api: Api) {
  const provider = getApiProvider(api);
  if (!provider) {
    throw new Error(`No API provider registered for api: ${api}`);
  }
  return provider;
}
/**
 * Start a streaming completion for `model` via its registered API provider.
 * Returns an event stream of assistant-message events.
 */
export function stream<TApi extends Api>(
  model: Model<TApi>,
  context: Context,
  options?: ProviderStreamOptions,
): AssistantMessageEventStream {
  const provider = resolveApiProvider(model.api);
  return provider.stream(model, context, options as StreamOptions);
}
/**
 * Convenience wrapper over `stream` that drains the event stream and
 * resolves with the final assistant message.
 */
export async function complete<TApi extends Api>(
  model: Model<TApi>,
  context: Context,
  options?: ProviderStreamOptions,
): Promise<AssistantMessage> {
  return stream(model, context, options).result();
}
/**
 * Start a streaming completion using the simplified, provider-agnostic
 * option set (unified reasoning levels, thinking budgets).
 */
export function streamSimple<TApi extends Api>(
  model: Model<TApi>,
  context: Context,
  options?: SimpleStreamOptions,
): AssistantMessageEventStream {
  const provider = resolveApiProvider(model.api);
  return provider.streamSimple(model, context, options);
}
/**
 * Convenience wrapper over `streamSimple` that drains the event stream and
 * resolves with the final assistant message.
 */
export async function completeSimple<TApi extends Api>(
  model: Model<TApi>,
  context: Context,
  options?: SimpleStreamOptions,
): Promise<AssistantMessage> {
  return streamSimple(model, context, options).result();
}

361
packages/ai/src/types.ts Normal file
View file

@ -0,0 +1,361 @@
import type { AssistantMessageEventStream } from "./utils/event-stream.js";
export type { AssistantMessageEventStream } from "./utils/event-stream.js";
/** Wire protocols with first-class support in this package. */
export type KnownApi =
  | "openai-completions"
  | "mistral-conversations"
  | "openai-responses"
  | "azure-openai-responses"
  | "openai-codex-responses"
  | "anthropic-messages"
  | "bedrock-converse-stream"
  | "google-generative-ai"
  | "google-gemini-cli"
  | "google-vertex";
// `string & {}` keeps editor completion for the known literals while still
// accepting arbitrary custom API ids registered at runtime.
export type Api = KnownApi | (string & {});
/** Model hosts with built-in configuration. */
export type KnownProvider =
  | "amazon-bedrock"
  | "anthropic"
  | "google"
  | "google-gemini-cli"
  | "google-antigravity"
  | "google-vertex"
  | "openai"
  | "azure-openai-responses"
  | "openai-codex"
  | "github-copilot"
  | "xai"
  | "groq"
  | "cerebras"
  | "openrouter"
  | "vercel-ai-gateway"
  | "zai"
  | "mistral"
  | "minimax"
  | "minimax-cn"
  | "huggingface"
  | "opencode"
  | "opencode-go"
  | "kimi-coding";
// Open-ended: custom providers may use any string id.
export type Provider = KnownProvider | string;
/** Unified reasoning-effort scale shared across providers. */
export type ThinkingLevel = "minimal" | "low" | "medium" | "high" | "xhigh";
/** Token budgets for each thinking level (token-based providers only) */
export interface ThinkingBudgets {
  // No "xhigh" entry: token-based providers clamp "xhigh" to "high".
  minimal?: number;
  low?: number;
  medium?: number;
  high?: number;
}
// Base options all providers share
export type CacheRetention = "none" | "short" | "long";
export type Transport = "sse" | "websocket" | "auto";
export interface StreamOptions {
  /** Sampling temperature; omitted entirely from the request when undefined. */
  temperature?: number;
  /** Maximum output tokens for the response. */
  maxTokens?: number;
  /** Abort signal used to cancel the in-flight request. */
  signal?: AbortSignal;
  /** API key override; falls back to provider-specific environment variables. */
  apiKey?: string;
  /**
   * Preferred transport for providers that support multiple transports.
   * Providers that do not support this option ignore it.
   */
  transport?: Transport;
  /**
   * Prompt cache retention preference. Providers map this to their supported values.
   * Default: "short".
   */
  cacheRetention?: CacheRetention;
  /**
   * Optional session identifier for providers that support session-based caching.
   * Providers can use this to enable prompt caching, request routing, or other
   * session-aware features. Ignored by providers that don't support it.
   */
  sessionId?: string;
  /**
   * Optional callback for inspecting provider payloads before sending.
   */
  onPayload?: (payload: unknown) => void;
  /**
   * Optional custom HTTP headers to include in API requests.
   * Merged with provider defaults; can override default headers.
   * Not supported by all providers (e.g., AWS Bedrock uses SDK auth).
   */
  headers?: Record<string, string>;
  /**
   * Maximum delay in milliseconds to wait for a retry when the server requests a long wait.
   * If the server's requested delay exceeds this value, the request fails immediately
   * with an error containing the requested delay, allowing higher-level retry logic
   * to handle it with user visibility.
   * Default: 60000 (60 seconds). Set to 0 to disable the cap.
   */
  maxRetryDelayMs?: number;
  /**
   * Optional metadata to include in API requests.
   * Providers extract the fields they understand and ignore the rest.
   * For example, Anthropic uses `user_id` for abuse tracking and rate limiting.
   */
  metadata?: Record<string, unknown>;
}
/** StreamOptions plus arbitrary provider-specific extension fields. */
export type ProviderStreamOptions = StreamOptions & Record<string, unknown>;
// Unified options with reasoning passed to streamSimple() and completeSimple()
export interface SimpleStreamOptions extends StreamOptions {
  reasoning?: ThinkingLevel;
  /** Custom token budgets for thinking levels (token-based providers only) */
  thinkingBudgets?: ThinkingBudgets;
}
// Generic StreamFunction with typed options
/** Signature every provider stream entry point implements. */
export type StreamFunction<
  TApi extends Api = Api,
  TOptions extends StreamOptions = StreamOptions,
> = (
  model: Model<TApi>,
  context: Context,
  options?: TOptions,
) => AssistantMessageEventStream;
/** Structured form of TextContent.textSignature (serialized as JSON). */
export interface TextSignatureV1 {
  v: 1;
  id: string;
  phase?: "commentary" | "final_answer";
}
export interface TextContent {
  type: "text";
  text: string;
  textSignature?: string; // e.g., for OpenAI responses, message metadata (legacy id string or TextSignatureV1 JSON)
}
export interface ThinkingContent {
  type: "thinking";
  thinking: string;
  thinkingSignature?: string; // e.g., for OpenAI responses, the reasoning item ID
  /** When true, the thinking content was redacted by safety filters. The opaque
   * encrypted payload is stored in `thinkingSignature` so it can be passed back
   * to the API for multi-turn continuity. */
  redacted?: boolean;
}
export interface ImageContent {
  type: "image";
  data: string; // base64 encoded image data
  mimeType: string; // e.g., "image/jpeg", "image/png"
}
export interface ToolCall {
  type: "toolCall";
  id: string;
  name: string;
  // NOTE(review): `any` keeps argument access unchecked for consumers;
  // tightening to `unknown` would break existing callers.
  arguments: Record<string, any>;
  thoughtSignature?: string; // Google-specific: opaque signature for reusing thought context
}
/** Token counts and dollar cost for a single assistant response. */
export interface Usage {
  // Non-cached input tokens (cached tokens are reported in cacheRead).
  input: number;
  output: number;
  cacheRead: number;
  cacheWrite: number;
  totalTokens: number;
  cost: {
    input: number;
    output: number;
    cacheRead: number;
    cacheWrite: number;
    total: number;
  };
}
export type StopReason = "stop" | "length" | "toolUse" | "error" | "aborted";
export interface UserMessage {
  role: "user";
  content: string | (TextContent | ImageContent)[];
  timestamp: number; // Unix timestamp in milliseconds
}
export interface AssistantMessage {
  role: "assistant";
  content: (TextContent | ThinkingContent | ToolCall)[];
  api: Api;
  provider: Provider;
  model: string;
  usage: Usage;
  stopReason: StopReason;
  // Populated when stopReason is "error" or "aborted".
  errorMessage?: string;
  timestamp: number; // Unix timestamp in milliseconds
}
export interface ToolResultMessage<TDetails = any> {
  role: "toolResult";
  toolCallId: string;
  toolName: string;
  content: (TextContent | ImageContent)[]; // Supports text and images
  details?: TDetails;
  isError: boolean;
  timestamp: number; // Unix timestamp in milliseconds
}
export type Message = UserMessage | AssistantMessage | ToolResultMessage;
import type { TSchema } from "@sinclair/typebox";
/** Tool definition advertised to the model; parameters are a TypeBox schema. */
export interface Tool<TParameters extends TSchema = TSchema> {
  name: string;
  description: string;
  parameters: TParameters;
}
/** Full conversational input for a single request. */
export interface Context {
  systemPrompt?: string;
  messages: Message[];
  tools?: Tool[];
}
/**
 * Discriminated union of streaming events. `partial` is the in-progress
 * assistant message (mutated as the stream advances); the stream always
 * terminates with exactly one "done" or "error" event.
 */
export type AssistantMessageEvent =
  | { type: "start"; partial: AssistantMessage }
  | { type: "text_start"; contentIndex: number; partial: AssistantMessage }
  | {
      type: "text_delta";
      contentIndex: number;
      delta: string;
      partial: AssistantMessage;
    }
  | {
      type: "text_end";
      contentIndex: number;
      content: string;
      partial: AssistantMessage;
    }
  | { type: "thinking_start"; contentIndex: number; partial: AssistantMessage }
  | {
      type: "thinking_delta";
      contentIndex: number;
      delta: string;
      partial: AssistantMessage;
    }
  | {
      type: "thinking_end";
      contentIndex: number;
      content: string;
      partial: AssistantMessage;
    }
  | { type: "toolcall_start"; contentIndex: number; partial: AssistantMessage }
  | {
      type: "toolcall_delta";
      contentIndex: number;
      delta: string;
      partial: AssistantMessage;
    }
  | {
      type: "toolcall_end";
      contentIndex: number;
      toolCall: ToolCall;
      partial: AssistantMessage;
    }
  | {
      type: "done";
      reason: Extract<StopReason, "stop" | "length" | "toolUse">;
      message: AssistantMessage;
    }
  | {
      type: "error";
      reason: Extract<StopReason, "aborted" | "error">;
      error: AssistantMessage;
    };
/**
* Compatibility settings for OpenAI-compatible completions APIs.
* Use this to override URL-based auto-detection for custom providers.
*/
export interface OpenAICompletionsCompat {
/** Whether the provider supports the `store` field. Default: auto-detected from URL. */
supportsStore?: boolean;
/** Whether the provider supports the `developer` role (vs `system`). Default: auto-detected from URL. */
supportsDeveloperRole?: boolean;
/** Whether the provider supports `reasoning_effort`. Default: auto-detected from URL. */
supportsReasoningEffort?: boolean;
/** Optional mapping from pi-ai reasoning levels to provider/model-specific `reasoning_effort` values. */
reasoningEffortMap?: Partial<Record<ThinkingLevel, string>>;
/** Whether the provider supports `stream_options: { include_usage: true }` for token usage in streaming responses. Default: true. */
supportsUsageInStreaming?: boolean;
/** Which field to use for max tokens. Default: auto-detected from URL. */
maxTokensField?: "max_completion_tokens" | "max_tokens";
/** Whether tool results require the `name` field. Default: auto-detected from URL. */
requiresToolResultName?: boolean;
/** Whether a user message after tool results requires an assistant message in between. Default: auto-detected from URL. */
requiresAssistantAfterToolResult?: boolean;
/** Whether thinking blocks must be converted to text blocks with <thinking> delimiters. Default: auto-detected from URL. */
requiresThinkingAsText?: boolean;
/** Format for reasoning/thinking parameter. "openai" uses reasoning_effort, "zai" uses thinking: { type: "enabled" }, "qwen" uses enable_thinking: boolean. Default: "openai". */
thinkingFormat?: "openai" | "zai" | "qwen";
/** OpenRouter-specific routing preferences. Only used when baseUrl points to OpenRouter. */
openRouterRouting?: OpenRouterRouting;
/** Vercel AI Gateway routing preferences. Only used when baseUrl points to Vercel AI Gateway. */
vercelGatewayRouting?: VercelGatewayRouting;
/** Whether the provider supports the `strict` field in tool definitions. Default: true. */
supportsStrictMode?: boolean;
}
/** Compatibility settings for OpenAI Responses APIs. */
export interface OpenAIResponsesCompat {
  // Reserved for future use — kept so Model.compat can select a
  // per-API override shape even though no overrides exist yet.
}
/**
 * OpenRouter provider routing preferences.
 * Controls which upstream providers OpenRouter routes requests to.
 * @see https://openrouter.ai/docs/provider-routing
 */
export interface OpenRouterRouting {
  /** List of provider slugs to exclusively use for this request (e.g., ["amazon-bedrock", "anthropic"]). */
  only?: string[];
  // NOTE(review): precedence when both `only` and `order` are set is decided
  // upstream by OpenRouter — see the provider-routing docs linked above.
  /** List of provider slugs to try in order (e.g., ["anthropic", "openai"]). */
  order?: string[];
}
/**
 * Vercel AI Gateway routing preferences.
 * Controls which upstream providers the gateway routes requests to.
 * @see https://vercel.com/docs/ai-gateway/models-and-providers/provider-options
 */
export interface VercelGatewayRouting {
  /** List of provider slugs to exclusively use for this request (e.g., ["bedrock", "anthropic"]). */
  only?: string[];
  // NOTE(review): precedence when both `only` and `order` are set is decided
  // by the gateway — see the provider-options docs linked above.
  /** List of provider slugs to try in order (e.g., ["anthropic", "openai"]). */
  order?: string[];
}
// Model interface for the unified model system
/**
 * A single model entry in the unified model system: identity, routing,
 * capability flags, and pricing. `TApi` selects which `compat` override
 * shape applies.
 */
export interface Model<TApi extends Api> {
  /** Model identifier sent to the API. */
  id: string;
  /** Human-readable display name. */
  name: string;
  /** API protocol this model is served over. */
  api: TApi;
  /** Provider that hosts this model. */
  provider: Provider;
  /** Base URL that requests are issued against. */
  baseUrl: string;
  /** Whether the model supports reasoning/thinking output. */
  reasoning: boolean;
  /** Input modalities the model accepts. */
  input: ("text" | "image")[];
  /** Pricing per million tokens, by token class. */
  cost: {
    input: number; // $/million tokens
    output: number; // $/million tokens
    cacheRead: number; // $/million tokens
    cacheWrite: number; // $/million tokens
  };
  /** Maximum context window, in tokens. */
  contextWindow: number;
  /** Maximum output tokens per response. */
  maxTokens: number;
  /** Extra HTTP headers to attach to every request. */
  headers?: Record<string, string>;
  /** Compatibility overrides for OpenAI-compatible APIs. If not set, auto-detected from baseUrl. */
  compat?: TApi extends "openai-completions"
    ? OpenAICompletionsCompat
    : TApi extends "openai-responses"
      ? OpenAIResponsesCompat
      : never;
}

View file

@ -0,0 +1,92 @@
import type { AssistantMessage, AssistantMessageEvent } from "../types.js";
// Generic event stream class for async iteration
/**
 * A push-based event stream consumable with `for await`. Events pushed
 * before a consumer attaches are buffered; once a "complete" event is seen
 * (per `isComplete`) or `end()` is called, iteration stops and `result()`
 * resolves with the extracted final value.
 *
 * Note: if the stream ends without ever producing a final value,
 * `result()` stays pending — callers should rely on the completion event
 * or pass a result to `end()`.
 */
export class EventStream<T, R = T> implements AsyncIterable<T> {
  private buffered: T[] = [];
  private consumers: ((value: IteratorResult<T>) => void)[] = [];
  private closed = false;
  private readonly finalResult: Promise<R>;
  private settleFinalResult!: (result: R) => void;

  constructor(
    private isComplete: (event: T) => boolean,
    private extractResult: (event: T) => R,
  ) {
    this.finalResult = new Promise((resolve) => {
      this.settleFinalResult = resolve;
    });
  }

  /** Publish an event. Ignored once the stream has completed. */
  push(event: T): void {
    if (this.closed) return;
    if (this.isComplete(event)) {
      // The terminal event is still delivered to consumers below.
      this.closed = true;
      this.settleFinalResult(this.extractResult(event));
    }
    const consumer = this.consumers.shift();
    if (consumer) {
      consumer({ value: event, done: false });
    } else {
      this.buffered.push(event);
    }
  }

  /** Close the stream, optionally supplying the final result. */
  end(result?: R): void {
    this.closed = true;
    if (result !== undefined) {
      this.settleFinalResult(result);
    }
    // Wake every blocked consumer with an end-of-stream marker.
    for (let c = this.consumers.shift(); c; c = this.consumers.shift()) {
      c({ value: undefined as any, done: true });
    }
  }

  async *[Symbol.asyncIterator](): AsyncIterator<T> {
    for (;;) {
      // Drain buffered events first so nothing is lost after close.
      if (this.buffered.length > 0) {
        yield this.buffered.shift()!;
        continue;
      }
      if (this.closed) return;
      const next = await new Promise<IteratorResult<T>>((resolve) => {
        this.consumers.push(resolve);
      });
      if (next.done) return;
      yield next.value;
    }
  }

  /** Promise for the stream's final extracted result. */
  result(): Promise<R> {
    return this.finalResult;
  }
}
/**
 * EventStream specialization for assistant message events: completes on a
 * "done" or "error" event and resolves the final result from whichever
 * terminal event arrived.
 */
export class AssistantMessageEventStream extends EventStream<
  AssistantMessageEvent,
  AssistantMessage
> {
  constructor() {
    super(
      (event) => event.type === "done" || event.type === "error",
      (event) => {
        switch (event.type) {
          case "done":
            return event.message;
          case "error":
            return event.error;
          default:
            throw new Error("Unexpected event type for final result");
        }
      },
    );
  }
}
/** Factory function for AssistantMessageEventStream (for use in extensions) */
export function createAssistantMessageEventStream(): AssistantMessageEventStream {
  const stream = new AssistantMessageEventStream();
  return stream;
}

View file

@ -0,0 +1,17 @@
/**
 * Fast deterministic hash to shorten long strings.
 * Two 32-bit accumulators are mixed per UTF-16 code unit, finalized with
 * avalanche steps, and emitted as a base-36 string. Not cryptographic.
 */
export function shortHash(str: string): string {
  let a = 0xdeadbeef;
  let b = 0x41c6ce57;
  // Indexed charCodeAt (not for..of) so surrogate halves hash individually,
  // matching the 32-bit mixing done per UTF-16 unit.
  for (let i = 0; i < str.length; i++) {
    const code = str.charCodeAt(i);
    a = Math.imul(a ^ code, 2654435761);
    b = Math.imul(b ^ code, 1597334677);
  }
  // Finalization: cross-mix both accumulators (second step uses the
  // already-finalized first accumulator).
  const mixedA =
    Math.imul(a ^ (a >>> 16), 2246822507) ^
    Math.imul(b ^ (b >>> 13), 3266489909);
  const mixedB =
    Math.imul(b ^ (b >>> 16), 2246822507) ^
    Math.imul(mixedA ^ (mixedA >>> 13), 3266489909);
  return (mixedB >>> 0).toString(36) + (mixedA >>> 0).toString(36);
}

View file

@ -0,0 +1,30 @@
import { parse as partialParse } from "partial-json";
/**
 * Attempts to parse potentially incomplete JSON during streaming.
 * Always returns a valid object, even if the JSON is incomplete.
 *
 * @param partialJson The partial JSON string from streaming
 * @returns Parsed object or empty object if parsing fails
 */
export function parseStreamingJson<T = any>(
  partialJson: string | undefined,
): T {
  const fallback = {} as T;
  if (!partialJson || partialJson.trim() === "") {
    return fallback;
  }
  try {
    // Fast path: complete JSON parses directly.
    return JSON.parse(partialJson) as T;
  } catch {
    try {
      // Slow path: tolerate truncated JSON from an in-flight stream.
      return (partialParse(partialJson) ?? fallback) as T;
    } catch {
      // Unparseable even as a prefix — hand back an empty object.
      return fallback;
    }
  }
}

View file

@ -0,0 +1,144 @@
/**
* Anthropic OAuth flow (Claude Pro/Max)
*/
import { generatePKCE } from "./pkce.js";
import type {
OAuthCredentials,
OAuthLoginCallbacks,
OAuthProviderInterface,
} from "./types.js";
// atob keeps the client id lightly obfuscated in source; decoded at load time.
const decode = (s: string) => atob(s);
const CLIENT_ID = decode("OWQxYzI1MGEtZTYxYi00NGQ5LTg4ZWQtNTk0NGQxOTYyZjVl");
// OAuth endpoints for the Claude Pro/Max flow.
const AUTHORIZE_URL = "https://claude.ai/oauth/authorize";
const TOKEN_URL = "https://console.anthropic.com/v1/oauth/token";
// Anthropic-hosted callback page; the user copies the code from it and
// pastes it back (see loginAnthropic's onPromptCode).
const REDIRECT_URI = "https://console.anthropic.com/oauth/code/callback";
const SCOPES = "org:create_api_key user:profile user:inference";
/**
 * Login with Anthropic OAuth (device code flow)
 *
 * Builds a PKCE (S256) authorization URL, hands it to `onAuthUrl`, then
 * waits for the user to paste back the code shown on Anthropic's callback
 * page (format: `code#state`) and exchanges it for tokens.
 *
 * @param onAuthUrl - Callback to handle the authorization URL (e.g., open browser)
 * @param onPromptCode - Callback to prompt user for the authorization code
 * @returns Credentials whose expiry is shifted 5 minutes early so callers
 *   refresh before the access token actually lapses.
 * @throws If the pasted code is empty/malformed or the token exchange fails.
 */
export async function loginAnthropic(
  onAuthUrl: (url: string) => void,
  onPromptCode: () => Promise<string>,
): Promise<OAuthCredentials> {
  const { verifier, challenge } = await generatePKCE();
  // Build authorization URL. The PKCE verifier doubles as the `state`
  // value, which Anthropic echoes back alongside the code.
  const authParams = new URLSearchParams({
    code: "true",
    client_id: CLIENT_ID,
    response_type: "code",
    redirect_uri: REDIRECT_URI,
    scope: SCOPES,
    code_challenge: challenge,
    code_challenge_method: "S256",
    state: verifier,
  });
  const authUrl = `${AUTHORIZE_URL}?${authParams.toString()}`;
  // Notify caller with URL to open
  onAuthUrl(authUrl);
  // Wait for user to paste authorization code (format: code#state).
  // Trim and validate up front so a stray paste fails with a clear local
  // error instead of an opaque server-side rejection.
  const authCode = (await onPromptCode()).trim();
  const [code, state] = authCode.split("#");
  if (!code) {
    throw new Error(
      "Empty or malformed authorization code: expected 'code#state'",
    );
  }
  // Exchange code for tokens
  const tokenResponse = await fetch(TOKEN_URL, {
    method: "POST",
    headers: {
      "Content-Type": "application/json",
    },
    body: JSON.stringify({
      grant_type: "authorization_code",
      client_id: CLIENT_ID,
      code: code,
      state: state,
      redirect_uri: REDIRECT_URI,
      code_verifier: verifier,
    }),
  });
  if (!tokenResponse.ok) {
    const error = await tokenResponse.text();
    throw new Error(`Token exchange failed: ${error}`);
  }
  const tokenData = (await tokenResponse.json()) as {
    access_token: string;
    refresh_token: string;
    expires_in: number;
  };
  // Calculate expiry time (current time + expires_in seconds - 5 min buffer)
  const expiresAt = Date.now() + tokenData.expires_in * 1000 - 5 * 60 * 1000;
  // Save credentials
  return {
    refresh: tokenData.refresh_token,
    access: tokenData.access_token,
    expires: expiresAt,
  };
}
/**
 * Refresh Anthropic OAuth token using a stored refresh token.
 * Returns new credentials with the expiry padded 5 minutes early.
 */
export async function refreshAnthropicToken(
  refreshToken: string,
): Promise<OAuthCredentials> {
  const response = await fetch(TOKEN_URL, {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({
      grant_type: "refresh_token",
      client_id: CLIENT_ID,
      refresh_token: refreshToken,
    }),
  });
  if (!response.ok) {
    throw new Error(
      `Anthropic token refresh failed: ${await response.text()}`,
    );
  }
  const payload = (await response.json()) as {
    access_token: string;
    refresh_token: string;
    expires_in: number;
  };
  // Pad the expiry so callers refresh before the token actually lapses.
  const paddedExpiry = Date.now() + payload.expires_in * 1000 - 5 * 60 * 1000;
  return {
    refresh: payload.refresh_token,
    access: payload.access_token,
    expires: paddedExpiry,
  };
}
/** OAuth provider registration for Anthropic (Claude Pro/Max). */
export const anthropicOAuthProvider: OAuthProviderInterface = {
  id: "anthropic",
  name: "Anthropic (Claude Pro/Max)",
  login: async (callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials> =>
    loginAnthropic(
      (url) => callbacks.onAuth({ url }),
      () => callbacks.onPrompt({ message: "Paste the authorization code:" }),
    ),
  refreshToken: async (
    credentials: OAuthCredentials,
  ): Promise<OAuthCredentials> => refreshAnthropicToken(credentials.refresh),
  // The OAuth access token is used directly as the API key.
  getApiKey: (credentials: OAuthCredentials): string => credentials.access,
};

View file

@ -0,0 +1,423 @@
/**
* GitHub Copilot OAuth flow
*/
import { getModels } from "../../models.js";
import type { Api, Model } from "../../types.js";
import type {
OAuthCredentials,
OAuthLoginCallbacks,
OAuthProviderInterface,
} from "./types.js";
// Copilot credentials optionally remember the GitHub Enterprise domain used
// at login so token refreshes hit the same deployment.
type CopilotCredentials = OAuthCredentials & {
  enterpriseUrl?: string;
};
// atob keeps the client id lightly obfuscated in source; decoded at load time.
const decode = (s: string) => atob(s);
const CLIENT_ID = decode("SXYxLmI1MDdhMDhjODdlY2ZlOTg=");
// Headers imitating the GitHub Copilot Chat client, sent on Copilot API calls.
const COPILOT_HEADERS = {
  "User-Agent": "GitHubCopilotChat/0.35.0",
  "Editor-Version": "vscode/1.107.0",
  "Editor-Plugin-Version": "copilot-chat/0.35.0",
  "Copilot-Integration-Id": "vscode-chat",
} as const;
/** Response from GitHub's device-code endpoint (RFC 8628 device flow). */
type DeviceCodeResponse = {
  device_code: string;
  user_code: string;
  verification_uri: string;
  // Minimum polling interval, in seconds.
  interval: number;
  // Lifetime of the device code, in seconds.
  expires_in: number;
};
/** Token-poll payload once the user has approved the device. */
type DeviceTokenSuccessResponse = {
  access_token: string;
  token_type?: string;
  scope?: string;
};
/** Token-poll payload while pending or on failure (e.g. "authorization_pending"). */
type DeviceTokenErrorResponse = {
  error: string;
  error_description?: string;
  // Optional replacement polling interval — not read by this module.
  interval?: number;
};
/**
 * Normalize a user-entered GitHub host (bare domain or full URL) to a
 * hostname; returns null for blank or unparseable input.
 */
export function normalizeDomain(input: string): string | null {
  const value = input.trim();
  if (!value) return null;
  // Bare domains get an https:// scheme so URL parsing can extract the host.
  const withScheme = value.includes("://") ? value : `https://${value}`;
  try {
    return new URL(withScheme).hostname;
  } catch {
    return null;
  }
}
/** Build the device-flow and Copilot-token endpoints for a GitHub domain. */
function getUrls(domain: string): {
  deviceCodeUrl: string;
  accessTokenUrl: string;
  copilotTokenUrl: string;
} {
  const base = `https://${domain}`;
  return {
    deviceCodeUrl: `${base}/login/device/code`,
    accessTokenUrl: `${base}/login/oauth/access_token`,
    // The Copilot token endpoint lives on the api. subdomain.
    copilotTokenUrl: `https://api.${domain}/copilot_internal/v2/token`,
  };
}
/**
 * Parse the proxy-ep from a Copilot token and convert to API base URL.
 * Token format: tid=...;exp=...;proxy-ep=proxy.individual.githubcopilot.com;...
 * Returns API URL like https://api.individual.githubcopilot.com,
 * or null when the token carries no proxy-ep field.
 */
function getBaseUrlFromToken(token: string): string | null {
  const proxyMatch = /proxy-ep=([^;]+)/.exec(token);
  if (proxyMatch === null) return null;
  // Convert proxy.xxx to api.xxx (hosts without the prefix pass through).
  const apiHost = proxyMatch[1].replace(/^proxy\./, "api.");
  return `https://${apiHost}`;
}
/**
 * Resolve the Copilot API base URL: the token's embedded proxy-ep wins,
 * then the enterprise domain, then the public individual endpoint.
 */
export function getGitHubCopilotBaseUrl(
  token?: string,
  enterpriseDomain?: string,
): string {
  if (token) {
    const fromToken = getBaseUrlFromToken(token);
    if (fromToken) return fromToken;
  }
  if (enterpriseDomain) {
    return `https://copilot-api.${enterpriseDomain}`;
  }
  return "https://api.individual.githubcopilot.com";
}
/** Fetch a URL and parse JSON; non-2xx responses raise with status + body. */
async function fetchJson(url: string, init: RequestInit): Promise<unknown> {
  const res = await fetch(url, init);
  if (res.ok) {
    return res.json();
  }
  const body = await res.text();
  throw new Error(`${res.status} ${res.statusText}: ${body}`);
}
/**
 * Request a device code from GitHub to begin the device authorization flow.
 * Validates the payload shape before trusting it.
 */
async function startDeviceFlow(domain: string): Promise<DeviceCodeResponse> {
  const { deviceCodeUrl } = getUrls(domain);
  const data = await fetchJson(deviceCodeUrl, {
    method: "POST",
    headers: {
      Accept: "application/json",
      "Content-Type": "application/json",
      "User-Agent": "GitHubCopilotChat/0.35.0",
    },
    body: JSON.stringify({
      client_id: CLIENT_ID,
      scope: "read:user",
    }),
  });
  if (!data || typeof data !== "object") {
    throw new Error("Invalid device code response");
  }
  // Narrow each field we rely on before returning the typed payload.
  const { device_code, user_code, verification_uri, interval, expires_in } =
    data as Record<string, unknown>;
  if (
    typeof device_code !== "string" ||
    typeof user_code !== "string" ||
    typeof verification_uri !== "string" ||
    typeof interval !== "number" ||
    typeof expires_in !== "number"
  ) {
    throw new Error("Invalid device code response fields");
  }
  return { device_code, user_code, verification_uri, interval, expires_in };
}
/**
 * Sleep that can be interrupted by an AbortSignal.
 *
 * The abort listener is removed once the timer fires. Without that cleanup,
 * every call left a dangling listener on the signal — and the device-flow
 * poll loop calls this repeatedly with the same long-lived signal, so
 * listeners accumulated (triggering Node's MaxListenersExceededWarning).
 *
 * @param ms - Milliseconds to sleep.
 * @param signal - Optional signal; aborting rejects with "Login cancelled".
 */
function abortableSleep(ms: number, signal?: AbortSignal): Promise<void> {
  return new Promise((resolve, reject) => {
    if (signal?.aborted) {
      reject(new Error("Login cancelled"));
      return;
    }
    const onAbort = () => {
      clearTimeout(timeout);
      reject(new Error("Login cancelled"));
    };
    const timeout = setTimeout(() => {
      // Detach the abort listener so long-lived signals don't accumulate one
      // listener per sleep.
      signal?.removeEventListener("abort", onAbort);
      resolve();
    }, ms);
    signal?.addEventListener("abort", onAbort, { once: true });
  });
}
/**
 * Poll GitHub's token endpoint until the user approves the device, the
 * device code expires, or `signal` aborts.
 *
 * Implements RFC 8628 device-flow semantics as handled below:
 * "authorization_pending" keeps polling at the current interval,
 * "slow_down" permanently adds 5s to the interval, and any other error
 * string aborts the flow.
 */
async function pollForGitHubAccessToken(
  domain: string,
  deviceCode: string,
  intervalSeconds: number,
  expiresIn: number,
  signal?: AbortSignal,
) {
  const urls = getUrls(domain);
  // Absolute cutoff — the device code is only valid for expiresIn seconds.
  const deadline = Date.now() + expiresIn * 1000;
  // Clamp to at least 1s so a bogus server interval can't hot-loop.
  let intervalMs = Math.max(1000, Math.floor(intervalSeconds * 1000));
  while (Date.now() < deadline) {
    if (signal?.aborted) {
      throw new Error("Login cancelled");
    }
    const raw = await fetchJson(urls.accessTokenUrl, {
      method: "POST",
      headers: {
        Accept: "application/json",
        "Content-Type": "application/json",
        "User-Agent": "GitHubCopilotChat/0.35.0",
      },
      body: JSON.stringify({
        client_id: CLIENT_ID,
        device_code: deviceCode,
        grant_type: "urn:ietf:params:oauth:grant-type:device_code",
      }),
    });
    // Success: payload carries a string access_token.
    if (
      raw &&
      typeof raw === "object" &&
      typeof (raw as DeviceTokenSuccessResponse).access_token === "string"
    ) {
      return (raw as DeviceTokenSuccessResponse).access_token;
    }
    // Error payload: decide between keep-waiting and hard failure.
    if (
      raw &&
      typeof raw === "object" &&
      typeof (raw as DeviceTokenErrorResponse).error === "string"
    ) {
      const err = (raw as DeviceTokenErrorResponse).error;
      if (err === "authorization_pending") {
        await abortableSleep(intervalMs, signal);
        continue;
      }
      if (err === "slow_down") {
        // Interval increase applies to this sleep and all later ones.
        intervalMs += 5000;
        await abortableSleep(intervalMs, signal);
        continue;
      }
      throw new Error(`Device flow failed: ${err}`);
    }
    // Unrecognized payload shape: wait and retry until the deadline.
    await abortableSleep(intervalMs, signal);
  }
  throw new Error("Device flow timed out");
}
/**
 * Refresh GitHub Copilot token: exchange the long-lived GitHub OAuth token
 * for a short-lived Copilot API token.
 */
export async function refreshGitHubCopilotToken(
  refreshToken: string,
  enterpriseDomain?: string,
): Promise<OAuthCredentials> {
  const { copilotTokenUrl } = getUrls(enterpriseDomain || "github.com");
  const raw = await fetchJson(copilotTokenUrl, {
    headers: {
      Accept: "application/json",
      Authorization: `Bearer ${refreshToken}`,
      ...COPILOT_HEADERS,
    },
  });
  if (!raw || typeof raw !== "object") {
    throw new Error("Invalid Copilot token response");
  }
  const { token, expires_at } = raw as Record<string, unknown>;
  if (typeof token !== "string" || typeof expires_at !== "number") {
    throw new Error("Invalid Copilot token response fields");
  }
  return {
    // The GitHub OAuth token doubles as the "refresh" credential.
    refresh: refreshToken,
    access: token,
    // expires_at is in seconds; convert and pad 5 minutes early.
    expires: expires_at * 1000 - 5 * 60 * 1000,
    enterpriseUrl: enterpriseDomain,
  };
}
/**
 * Enable a model for the user's GitHub Copilot account.
 * This is required for some models (like Claude, Grok) before they can be used.
 * Returns true when the policy endpoint accepts the change; any HTTP or
 * network failure is reported as false rather than thrown.
 */
async function enableGitHubCopilotModel(
  token: string,
  modelId: string,
  enterpriseDomain?: string,
): Promise<boolean> {
  const base = getGitHubCopilotBaseUrl(token, enterpriseDomain);
  const policyUrl = `${base}/models/${modelId}/policy`;
  try {
    const res = await fetch(policyUrl, {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
        Authorization: `Bearer ${token}`,
        ...COPILOT_HEADERS,
        "openai-intent": "chat-policy",
        "x-interaction-type": "chat-policy",
      },
      body: JSON.stringify({ state: "enabled" }),
    });
    return res.ok;
  } catch {
    return false;
  }
}
/**
* Enable all known GitHub Copilot models that may require policy acceptance.
* Called after successful login to ensure all models are available.
*/
async function enableAllGitHubCopilotModels(
token: string,
enterpriseDomain?: string,
onProgress?: (model: string, success: boolean) => void,
): Promise<void> {
const models = getModels("github-copilot");
await Promise.all(
models.map(async (model) => {
const success = await enableGitHubCopilotModel(
token,
model.id,
enterpriseDomain,
);
onProgress?.(model.id, success);
}),
);
}
/**
 * Login with GitHub Copilot OAuth (device code flow)
 *
 * @param options.onAuth - Callback with URL and optional instructions (user code)
 * @param options.onPrompt - Callback to prompt user for input
 * @param options.onProgress - Optional progress callback
 * @param options.signal - Optional AbortSignal for cancellation
 */
export async function loginGitHubCopilot(options: {
  onAuth: (url: string, instructions?: string) => void;
  onPrompt: (prompt: {
    message: string;
    placeholder?: string;
    allowEmpty?: boolean;
  }) => Promise<string>;
  onProgress?: (message: string) => void;
  signal?: AbortSignal;
}): Promise<OAuthCredentials> {
  // Optionally target a GitHub Enterprise deployment.
  const domainInput = await options.onPrompt({
    message: "GitHub Enterprise URL/domain (blank for github.com)",
    placeholder: "company.ghe.com",
    allowEmpty: true,
  });
  if (options.signal?.aborted) {
    throw new Error("Login cancelled");
  }
  const enterpriseDomain = normalizeDomain(domainInput);
  if (domainInput.trim() && !enterpriseDomain) {
    throw new Error("Invalid GitHub Enterprise URL/domain");
  }
  const domain = enterpriseDomain || "github.com";

  // Device flow: show the user a code, then poll until they approve.
  const device = await startDeviceFlow(domain);
  options.onAuth(device.verification_uri, `Enter code: ${device.user_code}`);
  const githubAccessToken = await pollForGitHubAccessToken(
    domain,
    device.device_code,
    device.interval,
    device.expires_in,
    options.signal,
  );

  // Trade the GitHub token for a Copilot API token.
  const credentials = await refreshGitHubCopilotToken(
    githubAccessToken,
    enterpriseDomain ?? undefined,
  );

  // Enable all models after successful login.
  options.onProgress?.("Enabling models...");
  await enableAllGitHubCopilotModels(
    credentials.access,
    enterpriseDomain ?? undefined,
  );
  return credentials;
}
/** OAuth provider registration for GitHub Copilot. */
export const githubCopilotOAuthProvider: OAuthProviderInterface = {
  id: "github-copilot",
  name: "GitHub Copilot",
  login: async (callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials> =>
    loginGitHubCopilot({
      onAuth: (url, instructions) => callbacks.onAuth({ url, instructions }),
      onPrompt: callbacks.onPrompt,
      onProgress: callbacks.onProgress,
      signal: callbacks.signal,
    }),
  refreshToken: async (
    credentials: OAuthCredentials,
  ): Promise<OAuthCredentials> => {
    const { refresh, enterpriseUrl } = credentials as CopilotCredentials;
    return refreshGitHubCopilotToken(refresh, enterpriseUrl);
  },
  // The short-lived Copilot token is used directly as the API key.
  getApiKey: (credentials: OAuthCredentials): string => credentials.access,
  // Point every github-copilot model at the base URL derived from the
  // current token (and enterprise domain, when present).
  modifyModels: (
    models: Model<Api>[],
    credentials: OAuthCredentials,
  ): Model<Api>[] => {
    const creds = credentials as CopilotCredentials;
    let domain: string | undefined;
    if (creds.enterpriseUrl) {
      domain = normalizeDomain(creds.enterpriseUrl) ?? undefined;
    }
    const baseUrl = getGitHubCopilotBaseUrl(creds.access, domain);
    return models.map((model) =>
      model.provider === "github-copilot" ? { ...model, baseUrl } : model,
    );
  },
};

View file

@ -0,0 +1,492 @@
/**
* Antigravity OAuth flow (Gemini 3, Claude, GPT-OSS via Google Cloud)
* Uses different OAuth credentials than google-gemini-cli for access to additional models.
*
* NOTE: This module uses Node.js http.createServer for the OAuth callback.
* It is only intended for CLI use, not browser environments.
*/
import type { Server } from "node:http";
import { generatePKCE } from "./pkce.js";
import type {
OAuthCredentials,
OAuthLoginCallbacks,
OAuthProviderInterface,
} from "./types.js";
// Antigravity credentials carry the Google Cloud project id discovered (or
// defaulted) at login, alongside the OAuth tokens.
type AntigravityCredentials = OAuthCredentials & {
  projectId: string;
};
let _createServer: typeof import("node:http").createServer | null = null;
let _httpImportPromise: Promise<void> | null = null;
// Import node:http lazily and only under Node/Bun; outside those runtimes
// the OAuth flow throws instead (see getNodeCreateServer).
if (
  typeof process !== "undefined" &&
  (process.versions?.node || process.versions?.bun)
) {
  _httpImportPromise = import("node:http").then((m) => {
    _createServer = m.createServer;
  });
}
// Antigravity OAuth credentials (different from Gemini CLI)
// atob keeps them lightly obfuscated in source; decoded at load time.
const decode = (s: string) => atob(s);
const CLIENT_ID = decode(
  "MTA3MTAwNjA2MDU5MS10bWhzc2luMmgyMWxjcmUyMzV2dG9sb2poNGc0MDNlcC5hcHBzLmdvb2dsZXVzZXJjb250ZW50LmNvbQ==",
);
const CLIENT_SECRET = decode(
  "R09DU1BYLUs1OEZXUjQ4NkxkTEoxbUxCOHNYQzR6NnFEQWY=",
);
// Must match the local callback server started in startCallbackServer.
const REDIRECT_URI = "http://localhost:51121/oauth-callback";
// Antigravity requires additional scopes
const SCOPES = [
  "https://www.googleapis.com/auth/cloud-platform",
  "https://www.googleapis.com/auth/userinfo.email",
  "https://www.googleapis.com/auth/userinfo.profile",
  "https://www.googleapis.com/auth/cclog",
  "https://www.googleapis.com/auth/experimentsandconfigs",
];
const AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth";
const TOKEN_URL = "https://oauth2.googleapis.com/token";
// Fallback project ID when discovery fails
const DEFAULT_PROJECT_ID = "rising-fact-p41fc";
// Handle returned by startCallbackServer: the live server, a way to cancel
// the wait loop, and a promise-returning getter for the received code/state.
type CallbackServerInfo = {
  server: Server;
  cancelWait: () => void;
  waitForCode: () => Promise<{ code: string; state: string } | null>;
};
/**
 * Resolve Node's `http.createServer`, waiting for the conditional dynamic
 * import above to finish; throws outside Node/Bun environments.
 */
async function getNodeCreateServer(): Promise<
  typeof import("node:http").createServer
> {
  // Fast path: the dynamic import already completed.
  if (_createServer) return _createServer;
  // Otherwise wait for the in-flight import, if one was started.
  if (_httpImportPromise) {
    await _httpImportPromise;
  }
  if (!_createServer) {
    throw new Error(
      "Antigravity OAuth is only available in Node.js environments",
    );
  }
  return _createServer;
}
/**
 * Start a local HTTP server on 127.0.0.1:51121 to receive the OAuth
 * redirect. Resolves once the server is listening; the caller then awaits
 * `waitForCode()` and is responsible for closing `server`.
 */
async function startCallbackServer(): Promise<CallbackServerInfo> {
  const createServer = await getNodeCreateServer();
  return new Promise((resolve, reject) => {
    // Filled in by the request handler once the provider redirects back.
    let result: { code: string; state: string } | null = null;
    let cancelled = false;
    const server = createServer((req, res) => {
      const url = new URL(req.url || "", `http://localhost:51121`);
      if (url.pathname === "/oauth-callback") {
        const code = url.searchParams.get("code");
        const state = url.searchParams.get("state");
        const error = url.searchParams.get("error");
        if (error) {
          // Provider-reported failure (e.g. user denied consent).
          res.writeHead(400, { "Content-Type": "text/html" });
          res.end(
            `<html><body><h1>Authentication Failed</h1><p>Error: ${error}</p><p>You can close this window.</p></body></html>`,
          );
          return;
        }
        if (code && state) {
          res.writeHead(200, { "Content-Type": "text/html" });
          res.end(
            `<html><body><h1>Authentication Successful</h1><p>You can close this window and return to the terminal.</p></body></html>`,
          );
          result = { code, state };
        } else {
          res.writeHead(400, { "Content-Type": "text/html" });
          res.end(
            `<html><body><h1>Authentication Failed</h1><p>Missing code or state parameter.</p></body></html>`,
          );
        }
      } else {
        res.writeHead(404);
        res.end();
      }
    });
    server.on("error", (err) => {
      // e.g. the port is already taken by another pending login.
      reject(err);
    });
    server.listen(51121, "127.0.0.1", () => {
      resolve({
        server,
        // Lets a competing manual-input path stop the wait loop below.
        cancelWait: () => {
          cancelled = true;
        },
        // Poll every 100ms until a code arrives or the wait is cancelled;
        // returns null when cancelled before any code arrived.
        waitForCode: async () => {
          const sleep = () => new Promise((r) => setTimeout(r, 100));
          while (!result && !cancelled) {
            await sleep();
          }
          return result;
        },
      });
    });
  });
}
/**
 * Parse a pasted redirect URL and extract the `code` and `state` query
 * parameters. Blank or unparseable input yields an empty object.
 */
function parseRedirectUrl(input: string): { code?: string; state?: string } {
  const trimmed = input.trim();
  if (!trimmed) return {};
  let url: URL;
  try {
    url = new URL(trimmed);
  } catch {
    // Not a URL — nothing to extract.
    return {};
  }
  const code = url.searchParams.get("code") ?? undefined;
  const state = url.searchParams.get("state") ?? undefined;
  return { code, state };
}
/** Subset of the `loadCodeAssist` response that this module reads. */
interface LoadCodeAssistPayload {
  // Either a bare project id or an object wrapping one (both observed).
  cloudaicompanionProject?: string | { id?: string };
  currentTier?: { id?: string };
  allowedTiers?: Array<{ id?: string; isDefault?: boolean }>;
}
/**
 * Discover or provision a project for the user.
 *
 * Queries Cloud Code Assist's `loadCodeAssist` endpoint — production first,
 * then the sandbox — for the user's companion project id, and falls back to
 * DEFAULT_PROJECT_ID when neither endpoint yields one. Endpoint failures
 * are swallowed so discovery degrades gracefully.
 */
async function discoverProject(
  accessToken: string,
  onProgress?: (message: string) => void,
): Promise<string> {
  const headers = {
    Authorization: `Bearer ${accessToken}`,
    "Content-Type": "application/json",
    "User-Agent": "google-api-nodejs-client/9.15.1",
    "X-Goog-Api-Client": "google-cloud-sdk vscode_cloudshelleditor/0.1",
    "Client-Metadata": JSON.stringify({
      ideType: "IDE_UNSPECIFIED",
      platform: "PLATFORM_UNSPECIFIED",
      pluginType: "GEMINI",
    }),
  };
  // Try endpoints in order: prod first, then sandbox
  const endpoints = [
    "https://cloudcode-pa.googleapis.com",
    "https://daily-cloudcode-pa.sandbox.googleapis.com",
  ];
  onProgress?.("Checking for existing project...");
  for (const endpoint of endpoints) {
    try {
      const loadResponse = await fetch(
        `${endpoint}/v1internal:loadCodeAssist`,
        {
          method: "POST",
          headers,
          body: JSON.stringify({
            metadata: {
              ideType: "IDE_UNSPECIFIED",
              platform: "PLATFORM_UNSPECIFIED",
              pluginType: "GEMINI",
            },
          }),
        },
      );
      if (loadResponse.ok) {
        const data = (await loadResponse.json()) as LoadCodeAssistPayload;
        // Handle both string and object formats
        if (
          typeof data.cloudaicompanionProject === "string" &&
          data.cloudaicompanionProject
        ) {
          return data.cloudaicompanionProject;
        }
        if (
          data.cloudaicompanionProject &&
          typeof data.cloudaicompanionProject === "object" &&
          data.cloudaicompanionProject.id
        ) {
          return data.cloudaicompanionProject.id;
        }
      }
    } catch {
      // Try next endpoint
    }
  }
  // Use fallback project ID
  onProgress?.("Using default project...");
  return DEFAULT_PROJECT_ID;
}
/**
 * Get user email from the access token via Google's userinfo endpoint.
 * Returns undefined on any failure — the email is optional metadata.
 */
async function getUserEmail(accessToken: string): Promise<string | undefined> {
  try {
    const res = await fetch(
      "https://www.googleapis.com/oauth2/v1/userinfo?alt=json",
      {
        headers: {
          Authorization: `Bearer ${accessToken}`,
        },
      },
    );
    if (!res.ok) return undefined;
    const payload = (await res.json()) as { email?: string };
    return payload.email;
  } catch {
    // Ignore errors, email is optional.
    return undefined;
  }
}
/**
 * Refresh Antigravity token via Google's OAuth token endpoint.
 * Google may omit refresh_token on refresh, in which case the previous
 * refresh token is carried forward.
 */
export async function refreshAntigravityToken(
  refreshToken: string,
  projectId: string,
): Promise<OAuthCredentials> {
  const form = new URLSearchParams({
    client_id: CLIENT_ID,
    client_secret: CLIENT_SECRET,
    refresh_token: refreshToken,
    grant_type: "refresh_token",
  });
  const response = await fetch(TOKEN_URL, {
    method: "POST",
    headers: { "Content-Type": "application/x-www-form-urlencoded" },
    body: form,
  });
  if (!response.ok) {
    throw new Error(
      `Antigravity token refresh failed: ${await response.text()}`,
    );
  }
  const payload = (await response.json()) as {
    access_token: string;
    expires_in: number;
    refresh_token?: string;
  };
  return {
    refresh: payload.refresh_token || refreshToken,
    access: payload.access_token,
    // Pad expiry 5 minutes early so callers refresh ahead of time.
    expires: Date.now() + payload.expires_in * 1000 - 5 * 60 * 1000,
    projectId,
  };
}
/**
 * Login with Antigravity OAuth
 *
 * Spins up a local callback server, opens the consent URL, and accepts the
 * authorization code from whichever source responds first: the browser
 * redirect or (when `onManualCodeInput` is supplied) a pasted redirect URL.
 *
 * @param onAuth - Callback with URL and optional instructions
 * @param onProgress - Optional progress callback
 * @param onManualCodeInput - Optional promise that resolves with user-pasted redirect URL.
 *   Races with browser callback - whichever completes first wins.
 * @returns Credentials including the discovered project id and (when
 *   available) the account email; expiry is padded 5 minutes early.
 * @throws On state mismatch, missing code/refresh token, or exchange failure.
 */
export async function loginAntigravity(
  onAuth: (info: { url: string; instructions?: string }) => void,
  onProgress?: (message: string) => void,
  onManualCodeInput?: () => Promise<string>,
): Promise<OAuthCredentials> {
  const { verifier, challenge } = await generatePKCE();
  // Start local server for callback
  onProgress?.("Starting local server for OAuth callback...");
  const server = await startCallbackServer();
  let code: string | undefined;
  try {
    // Build authorization URL. The PKCE verifier doubles as the `state`
    // value, checked against whatever comes back.
    const authParams = new URLSearchParams({
      client_id: CLIENT_ID,
      response_type: "code",
      redirect_uri: REDIRECT_URI,
      scope: SCOPES.join(" "),
      code_challenge: challenge,
      code_challenge_method: "S256",
      state: verifier,
      access_type: "offline",
      prompt: "consent",
    });
    const authUrl = `${AUTH_URL}?${authParams.toString()}`;
    // Notify caller with URL to open
    onAuth({
      url: authUrl,
      instructions: "Complete the sign-in in your browser.",
    });
    // Wait for the callback, racing with manual input if provided
    onProgress?.("Waiting for OAuth callback...");
    if (onManualCodeInput) {
      // Race between browser callback and manual input. Manual completion
      // cancels the server's wait loop so waitForCode() returns null.
      let manualInput: string | undefined;
      let manualError: Error | undefined;
      const manualPromise = onManualCodeInput()
        .then((input) => {
          manualInput = input;
          server.cancelWait();
        })
        .catch((err) => {
          manualError = err instanceof Error ? err : new Error(String(err));
          server.cancelWait();
        });
      const result = await server.waitForCode();
      // If manual input was cancelled, throw that error
      if (manualError) {
        throw manualError;
      }
      if (result?.code) {
        // Browser callback won - verify state
        if (result.state !== verifier) {
          throw new Error("OAuth state mismatch - possible CSRF attack");
        }
        code = result.code;
      } else if (manualInput) {
        // Manual input won
        const parsed = parseRedirectUrl(manualInput);
        if (parsed.state && parsed.state !== verifier) {
          throw new Error("OAuth state mismatch - possible CSRF attack");
        }
        code = parsed.code;
      }
      // If still no code, wait for manual promise and try that
      // (covers the case where the wait was cancelled before the manual
      // .then/.catch handlers had run).
      if (!code) {
        await manualPromise;
        if (manualError) {
          throw manualError;
        }
        if (manualInput) {
          const parsed = parseRedirectUrl(manualInput);
          if (parsed.state && parsed.state !== verifier) {
            throw new Error("OAuth state mismatch - possible CSRF attack");
          }
          code = parsed.code;
        }
      }
    } else {
      // Original flow: just wait for callback
      const result = await server.waitForCode();
      if (result?.code) {
        if (result.state !== verifier) {
          throw new Error("OAuth state mismatch - possible CSRF attack");
        }
        code = result.code;
      }
    }
    if (!code) {
      throw new Error("No authorization code received");
    }
    // Exchange code for tokens
    onProgress?.("Exchanging authorization code for tokens...");
    const tokenResponse = await fetch(TOKEN_URL, {
      method: "POST",
      headers: {
        "Content-Type": "application/x-www-form-urlencoded",
      },
      body: new URLSearchParams({
        client_id: CLIENT_ID,
        client_secret: CLIENT_SECRET,
        code,
        grant_type: "authorization_code",
        redirect_uri: REDIRECT_URI,
        code_verifier: verifier,
      }),
    });
    if (!tokenResponse.ok) {
      const error = await tokenResponse.text();
      throw new Error(`Token exchange failed: ${error}`);
    }
    const tokenData = (await tokenResponse.json()) as {
      access_token: string;
      refresh_token: string;
      expires_in: number;
    };
    if (!tokenData.refresh_token) {
      throw new Error("No refresh token received. Please try again.");
    }
    // Get user email
    onProgress?.("Getting user info...");
    const email = await getUserEmail(tokenData.access_token);
    // Discover project
    const projectId = await discoverProject(tokenData.access_token, onProgress);
    // Calculate expiry time (current time + expires_in seconds - 5 min buffer)
    const expiresAt = Date.now() + tokenData.expires_in * 1000 - 5 * 60 * 1000;
    const credentials: OAuthCredentials = {
      refresh: tokenData.refresh_token,
      access: tokenData.access_token,
      expires: expiresAt,
      projectId,
      email,
    };
    return credentials;
  } finally {
    // Always free the callback port, success or failure.
    server.server.close();
  }
}
/** OAuth provider registration for Antigravity (Google Cloud). */
export const antigravityOAuthProvider: OAuthProviderInterface = {
  id: "google-antigravity",
  name: "Antigravity (Gemini 3, Claude, GPT-OSS)",
  usesCallbackServer: true,
  login: async (callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials> =>
    loginAntigravity(
      callbacks.onAuth,
      callbacks.onProgress,
      callbacks.onManualCodeInput,
    ),
  refreshToken: async (
    credentials: OAuthCredentials,
  ): Promise<OAuthCredentials> => {
    const { refresh, projectId } = credentials as AntigravityCredentials;
    if (!projectId) {
      throw new Error("Antigravity credentials missing projectId");
    }
    return refreshAntigravityToken(refresh, projectId);
  },
  // The "API key" is a JSON envelope bundling the token with the project id.
  getApiKey: (credentials: OAuthCredentials): string => {
    const { access, projectId } = credentials as AntigravityCredentials;
    return JSON.stringify({ token: access, projectId });
  },
};

View file

@ -0,0 +1,648 @@
/**
* Gemini CLI OAuth flow (Google Cloud Code Assist)
* Standard Gemini models only (gemini-2.0-flash, gemini-2.5-*)
*
* NOTE: This module uses Node.js http.createServer for the OAuth callback.
* It is only intended for CLI use, not browser environments.
*/
import type { Server } from "node:http";
import { generatePKCE } from "./pkce.js";
import type {
OAuthCredentials,
OAuthLoginCallbacks,
OAuthProviderInterface,
} from "./types.js";
// Gemini credentials extend the base OAuth shape with the Cloud project id
// that Code Assist requests must be billed against.
type GeminiCredentials = OAuthCredentials & {
  projectId: string;
};
// node:http is imported lazily (and only under Node/Bun) so this module can
// be bundled for browser builds without pulling in Node built-ins. The
// promise is kept so getNodeCreateServer() can await an in-flight import.
let _createServer: typeof import("node:http").createServer | null = null;
let _httpImportPromise: Promise<void> | null = null;
if (
  typeof process !== "undefined" &&
  (process.versions?.node || process.versions?.bun)
) {
  _httpImportPromise = import("node:http").then((m) => {
    _createServer = m.createServer;
  });
}
// OAuth client credentials, stored base64-encoded and decoded at load time.
const decode = (s: string) => atob(s);
const CLIENT_ID = decode(
  "NjgxMjU1ODA5Mzk1LW9vOGZ0Mm9wcmRybnA5ZTNhcWY2YXYzaG1kaWIxMzVqLmFwcHMuZ29vZ2xldXNlcmNvbnRlbnQuY29t",
);
const CLIENT_SECRET = decode(
  "R09DU1BYLTR1SGdNUG0tMW83U2stZ2VWNkN1NWNsWEZzeGw=",
);
// Must match the port the local callback server listens on (8085).
const REDIRECT_URI = "http://localhost:8085/oauth2callback";
const SCOPES = [
  "https://www.googleapis.com/auth/cloud-platform",
  "https://www.googleapis.com/auth/userinfo.email",
  "https://www.googleapis.com/auth/userinfo.profile",
];
const AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth";
const TOKEN_URL = "https://oauth2.googleapis.com/token";
const CODE_ASSIST_ENDPOINT = "https://cloudcode-pa.googleapis.com";
// Handle returned by startCallbackServer: the raw server (for close()),
// a way to abort the wait loop, and a poll-based waiter for the code/state.
type CallbackServerInfo = {
  server: Server;
  cancelWait: () => void;
  waitForCode: () => Promise<{ code: string; state: string } | null>;
};
/**
* Start a local HTTP server to receive the OAuth callback
*/
async function getNodeCreateServer(): Promise<
typeof import("node:http").createServer
> {
if (_createServer) return _createServer;
if (_httpImportPromise) {
await _httpImportPromise;
}
if (_createServer) return _createServer;
throw new Error("Gemini CLI OAuth is only available in Node.js environments");
}
/**
 * Start a local HTTP server on 127.0.0.1:8085 to receive the OAuth redirect.
 *
 * Resolves once the server is listening. The returned handle exposes the
 * server itself (caller is responsible for close()), a cancelWait() that
 * aborts the waiter, and waitForCode() which polls until either a code/state
 * pair arrives or the wait is cancelled (then returns null).
 */
async function startCallbackServer(): Promise<CallbackServerInfo> {
  const createServer = await getNodeCreateServer();
  return new Promise((resolve, reject) => {
    // Shared mutable state between the request handler and waitForCode().
    let result: { code: string; state: string } | null = null;
    let cancelled = false;
    const server = createServer((req, res) => {
      const url = new URL(req.url || "", `http://localhost:8085`);
      if (url.pathname === "/oauth2callback") {
        const code = url.searchParams.get("code");
        const state = url.searchParams.get("state");
        const error = url.searchParams.get("error");
        // Google reports user denial / misconfig via an `error` parameter.
        if (error) {
          res.writeHead(400, { "Content-Type": "text/html" });
          res.end(
            `<html><body><h1>Authentication Failed</h1><p>Error: ${error}</p><p>You can close this window.</p></body></html>`,
          );
          return;
        }
        if (code && state) {
          res.writeHead(200, { "Content-Type": "text/html" });
          res.end(
            `<html><body><h1>Authentication Successful</h1><p>You can close this window and return to the terminal.</p></body></html>`,
          );
          // Publish the result; waitForCode()'s poll loop picks it up.
          result = { code, state };
        } else {
          res.writeHead(400, { "Content-Type": "text/html" });
          res.end(
            `<html><body><h1>Authentication Failed</h1><p>Missing code or state parameter.</p></body></html>`,
          );
        }
      } else {
        res.writeHead(404);
        res.end();
      }
    });
    // e.g. EADDRINUSE when port 8085 is taken — surfaces to the caller.
    server.on("error", (err) => {
      reject(err);
    });
    server.listen(8085, "127.0.0.1", () => {
      resolve({
        server,
        cancelWait: () => {
          cancelled = true;
        },
        // Simple 100ms poll; avoids wiring an event emitter for one value.
        waitForCode: async () => {
          const sleep = () => new Promise((r) => setTimeout(r, 100));
          while (!result && !cancelled) {
            await sleep();
          }
          return result;
        },
      });
    });
  });
}
/**
 * Extract the `code` and `state` query parameters from a pasted redirect
 * URL. Returns an empty object for blank input or anything that does not
 * parse as a URL.
 */
function parseRedirectUrl(input: string): { code?: string; state?: string } {
  const trimmed = input.trim();
  if (trimmed.length === 0) {
    return {};
  }
  let parsed: URL;
  try {
    parsed = new URL(trimmed);
  } catch {
    // Anything that is not a well-formed URL yields no code/state.
    return {};
  }
  const pick = (key: string): string | undefined =>
    parsed.searchParams.get(key) ?? undefined;
  return { code: pick("code"), state: pick("state") };
}
// Subset of the loadCodeAssist response that discoverProject() reads:
// an existing managed project, the user's current tier, and the tiers the
// account is allowed to onboard into.
interface LoadCodeAssistPayload {
  cloudaicompanionProject?: string;
  currentTier?: { id?: string };
  allowedTiers?: Array<{ id?: string; isDefault?: boolean }>;
}
/**
 * Long-running operation response from onboardUser
 * (polled until `done` is true; the provisioned project id, if any,
 * appears under response.cloudaicompanionProject.id).
 */
interface LongRunningOperationResponse {
  name?: string;
  done?: boolean;
  response?: {
    cloudaicompanionProject?: { id?: string };
  };
}
// Tier IDs as used by the Cloud Code API
const TIER_FREE = "free-tier";
const TIER_LEGACY = "legacy-tier";
const TIER_STANDARD = "standard-tier";
// Minimal shape of a google.rpc error body; only `details[].reason` is read.
interface GoogleRpcErrorResponse {
  error?: {
    details?: Array<{ reason?: string }>;
  };
}
/**
 * Resolve after `ms` milliseconds; used to pace onboarding retries.
 */
function wait(ms: number): Promise<void> {
  return new Promise((resolve) => {
    setTimeout(resolve, ms);
  });
}
/**
 * Pick the tier flagged as default, falling back to the legacy tier when
 * the list is missing, empty, or contains no default entry.
 */
function getDefaultTier(
  allowedTiers?: Array<{ id?: string; isDefault?: boolean }>,
): { id?: string } {
  const fallback = { id: TIER_LEGACY };
  if (!allowedTiers?.length) {
    return fallback;
  }
  for (const tier of allowedTiers) {
    if (tier.isDefault) {
      return tier;
    }
  }
  return fallback;
}
/**
 * Detect the VPC Service Controls rejection shape in a Google RPC error
 * payload: any error detail whose reason is SECURITY_POLICY_VIOLATED.
 */
function isVpcScAffectedUser(payload: unknown): boolean {
  if (typeof payload !== "object" || payload === null) {
    return false;
  }
  if (!("error" in payload)) {
    return false;
  }
  const details = (payload as GoogleRpcErrorResponse).error?.details;
  if (!Array.isArray(details)) {
    return false;
  }
  return details.some((d) => d.reason === "SECURITY_POLICY_VIOLATED");
}
/**
 * Poll a long-running operation until completion.
 *
 * @param operationName - Fully-qualified operation name returned by onboardUser
 * @param headers - Auth headers to send with each poll request
 * @param onProgress - Optional callback for user-facing status messages
 * @returns The completed operation payload (`done === true`)
 * @throws Error when a poll request fails, or when the operation does not
 *   complete within the retry budget
 */
async function pollOperation(
  operationName: string,
  headers: Record<string, string>,
  onProgress?: (message: string) => void,
): Promise<LongRunningOperationResponse> {
  // Bound the loop so a stuck operation cannot hang the login flow forever:
  // 60 attempts spaced 5s apart allows roughly five minutes of provisioning.
  const MAX_ATTEMPTS = 60;
  for (let attempt = 0; attempt < MAX_ATTEMPTS; attempt += 1) {
    if (attempt > 0) {
      onProgress?.(
        `Waiting for project provisioning (attempt ${attempt + 1})...`,
      );
      await wait(5000);
    }
    const response = await fetch(
      `${CODE_ASSIST_ENDPOINT}/v1internal/${operationName}`,
      {
        method: "GET",
        headers,
      },
    );
    if (!response.ok) {
      throw new Error(
        `Failed to poll operation: ${response.status} ${response.statusText}`,
      );
    }
    const data = (await response.json()) as LongRunningOperationResponse;
    if (data.done) {
      return data;
    }
  }
  throw new Error(
    `Operation ${operationName} did not complete after ${MAX_ATTEMPTS} attempts`,
  );
}
/**
 * Discover or provision a Google Cloud project for the user.
 *
 * Flow: (1) honor an explicit project from the environment where required,
 * (2) ask loadCodeAssist for the user's existing tier/project, (3) if the
 * user is not yet onboarded, start onboardUser and poll the resulting
 * long-running operation for the provisioned project id.
 *
 * @param accessToken - OAuth access token for the Cloud Code API
 * @param onProgress - Optional callback for user-facing status messages
 * @returns The resolved Cloud project id
 * @throws Error when the API calls fail or no project can be determined
 */
async function discoverProject(
  accessToken: string,
  onProgress?: (message: string) => void,
): Promise<string> {
  // Check for user-provided project ID via environment variable
  const envProjectId =
    process.env.GOOGLE_CLOUD_PROJECT || process.env.GOOGLE_CLOUD_PROJECT_ID;
  const headers = {
    Authorization: `Bearer ${accessToken}`,
    "Content-Type": "application/json",
    // Mimics the client identifiers the official Gemini CLI sends.
    "User-Agent": "google-api-nodejs-client/9.15.1",
    "X-Goog-Api-Client": "gl-node/22.17.0",
  };
  // Try to load existing project via loadCodeAssist
  onProgress?.("Checking for existing Cloud Code Assist project...");
  const loadResponse = await fetch(
    `${CODE_ASSIST_ENDPOINT}/v1internal:loadCodeAssist`,
    {
      method: "POST",
      headers,
      body: JSON.stringify({
        cloudaicompanionProject: envProjectId,
        metadata: {
          ideType: "IDE_UNSPECIFIED",
          platform: "PLATFORM_UNSPECIFIED",
          pluginType: "GEMINI",
          duetProject: envProjectId,
        },
      }),
    },
  );
  let data: LoadCodeAssistPayload;
  if (!loadResponse.ok) {
    // Body is cloned so it can still be read as text on the error path below.
    let errorPayload: unknown;
    try {
      errorPayload = await loadResponse.clone().json();
    } catch {
      errorPayload = undefined;
    }
    if (isVpcScAffectedUser(errorPayload)) {
      // VPC-SC-restricted accounts can't call loadCodeAssist; treat them as
      // standard tier and rely on an env-provided project below.
      data = { currentTier: { id: TIER_STANDARD } };
    } else {
      const errorText = await loadResponse.text();
      throw new Error(
        `loadCodeAssist failed: ${loadResponse.status} ${loadResponse.statusText}: ${errorText}`,
      );
    }
  } else {
    data = (await loadResponse.json()) as LoadCodeAssistPayload;
  }
  // If user already has a current tier and project, use it
  if (data.currentTier) {
    if (data.cloudaicompanionProject) {
      return data.cloudaicompanionProject;
    }
    // User has a tier but no managed project - they need to provide one via env var
    if (envProjectId) {
      return envProjectId;
    }
    throw new Error(
      "This account requires setting the GOOGLE_CLOUD_PROJECT or GOOGLE_CLOUD_PROJECT_ID environment variable. " +
        "See https://goo.gle/gemini-cli-auth-docs#workspace-gca",
    );
  }
  // User needs to be onboarded - get the default tier
  const tier = getDefaultTier(data.allowedTiers);
  const tierId = tier?.id ?? TIER_FREE;
  // Non-free tiers are billed against the user's own project, so one must
  // be supplied explicitly.
  if (tierId !== TIER_FREE && !envProjectId) {
    throw new Error(
      "This account requires setting the GOOGLE_CLOUD_PROJECT or GOOGLE_CLOUD_PROJECT_ID environment variable. " +
        "See https://goo.gle/gemini-cli-auth-docs#workspace-gca",
    );
  }
  onProgress?.(
    "Provisioning Cloud Code Assist project (this may take a moment)...",
  );
  // Build onboard request - for free tier, don't include project ID (Google provisions one)
  // For other tiers, include the user's project ID if available
  const onboardBody: Record<string, unknown> = {
    tierId,
    metadata: {
      ideType: "IDE_UNSPECIFIED",
      platform: "PLATFORM_UNSPECIFIED",
      pluginType: "GEMINI",
    },
  };
  if (tierId !== TIER_FREE && envProjectId) {
    onboardBody.cloudaicompanionProject = envProjectId;
    (onboardBody.metadata as Record<string, unknown>).duetProject =
      envProjectId;
  }
  // Start onboarding - this returns a long-running operation
  const onboardResponse = await fetch(
    `${CODE_ASSIST_ENDPOINT}/v1internal:onboardUser`,
    {
      method: "POST",
      headers,
      body: JSON.stringify(onboardBody),
    },
  );
  if (!onboardResponse.ok) {
    const errorText = await onboardResponse.text();
    throw new Error(
      `onboardUser failed: ${onboardResponse.status} ${onboardResponse.statusText}: ${errorText}`,
    );
  }
  let lroData = (await onboardResponse.json()) as LongRunningOperationResponse;
  // If the operation isn't done yet, poll until completion
  if (!lroData.done && lroData.name) {
    lroData = await pollOperation(lroData.name, headers, onProgress);
  }
  // Try to get project ID from the response
  const projectId = lroData.response?.cloudaicompanionProject?.id;
  if (projectId) {
    return projectId;
  }
  // If no project ID from onboarding, fall back to env var
  if (envProjectId) {
    return envProjectId;
  }
  throw new Error(
    "Could not discover or provision a Google Cloud project. " +
      "Try setting the GOOGLE_CLOUD_PROJECT or GOOGLE_CLOUD_PROJECT_ID environment variable. " +
      "See https://goo.gle/gemini-cli-auth-docs#workspace-gca",
  );
}
/**
 * Best-effort lookup of the signed-in user's email via the Google userinfo
 * endpoint. Returns undefined on any failure — the email is optional.
 */
async function getUserEmail(accessToken: string): Promise<string | undefined> {
  const endpoint = "https://www.googleapis.com/oauth2/v1/userinfo?alt=json";
  try {
    const response = await fetch(endpoint, {
      headers: { Authorization: `Bearer ${accessToken}` },
    });
    if (!response.ok) {
      return undefined;
    }
    const payload = (await response.json()) as { email?: string };
    return payload.email;
  } catch {
    // Network/parse failures are swallowed on purpose; email is optional.
    return undefined;
  }
}
/**
 * Exchange a refresh token for a fresh access token at Google's OAuth token
 * endpoint, carrying the Cloud project id through onto the returned
 * credentials. The stored expiry is shortened by a five-minute safety buffer.
 *
 * @param refreshToken - Long-lived refresh token from the initial login
 * @param projectId - Cloud project id to keep on the credentials
 * @throws Error when the token endpoint responds with a non-2xx status
 */
export async function refreshGoogleCloudToken(
  refreshToken: string,
  projectId: string,
): Promise<OAuthCredentials> {
  const form = new URLSearchParams({
    client_id: CLIENT_ID,
    client_secret: CLIENT_SECRET,
    refresh_token: refreshToken,
    grant_type: "refresh_token",
  });
  const response = await fetch(TOKEN_URL, {
    method: "POST",
    headers: { "Content-Type": "application/x-www-form-urlencoded" },
    body: form,
  });
  if (!response.ok) {
    throw new Error(
      `Google Cloud token refresh failed: ${await response.text()}`,
    );
  }
  const payload = (await response.json()) as {
    access_token: string;
    expires_in: number;
    refresh_token?: string;
  };
  const fiveMinutes = 5 * 60 * 1000;
  return {
    // Google may rotate the refresh token; fall back to the one we sent.
    refresh: payload.refresh_token || refreshToken,
    access: payload.access_token,
    expires: Date.now() + payload.expires_in * 1000 - fiveMinutes,
    projectId,
  };
}
/**
 * Login with Gemini CLI (Google Cloud Code Assist) OAuth.
 *
 * Opens a PKCE authorization-code flow: starts a local callback server,
 * hands the auth URL to the caller, waits for the redirect (optionally
 * racing a manual paste of the redirect URL), exchanges the code for
 * tokens, then resolves the user's email and Cloud project.
 *
 * @param onAuth - Callback with URL and optional instructions
 * @param onProgress - Optional progress callback
 * @param onManualCodeInput - Optional promise that resolves with user-pasted redirect URL.
 * Races with browser callback - whichever completes first wins.
 * @returns Credentials including projectId and (best-effort) email
 * @throws Error on state mismatch, missing code, or failed token exchange
 */
export async function loginGeminiCli(
  onAuth: (info: { url: string; instructions?: string }) => void,
  onProgress?: (message: string) => void,
  onManualCodeInput?: () => Promise<string>,
): Promise<OAuthCredentials> {
  const { verifier, challenge } = await generatePKCE();
  // Start local server for callback
  onProgress?.("Starting local server for OAuth callback...");
  const server = await startCallbackServer();
  let code: string | undefined;
  try {
    // Build authorization URL.
    // NOTE: the PKCE verifier doubles as the OAuth `state` parameter, so the
    // state checks below compare against `verifier`.
    const authParams = new URLSearchParams({
      client_id: CLIENT_ID,
      response_type: "code",
      redirect_uri: REDIRECT_URI,
      scope: SCOPES.join(" "),
      code_challenge: challenge,
      code_challenge_method: "S256",
      state: verifier,
      access_type: "offline",
      prompt: "consent",
    });
    const authUrl = `${AUTH_URL}?${authParams.toString()}`;
    // Notify caller with URL to open
    onAuth({
      url: authUrl,
      instructions: "Complete the sign-in in your browser.",
    });
    // Wait for the callback, racing with manual input if provided
    onProgress?.("Waiting for OAuth callback...");
    if (onManualCodeInput) {
      // Race between browser callback and manual input. The manual promise
      // cancels the server's wait loop when it settles, so waitForCode()
      // below returns either a browser result or null.
      let manualInput: string | undefined;
      let manualError: Error | undefined;
      const manualPromise = onManualCodeInput()
        .then((input) => {
          manualInput = input;
          server.cancelWait();
        })
        .catch((err) => {
          manualError = err instanceof Error ? err : new Error(String(err));
          server.cancelWait();
        });
      const result = await server.waitForCode();
      // If manual input was cancelled, throw that error
      if (manualError) {
        throw manualError;
      }
      if (result?.code) {
        // Browser callback won - verify state
        if (result.state !== verifier) {
          throw new Error("OAuth state mismatch - possible CSRF attack");
        }
        code = result.code;
      } else if (manualInput) {
        // Manual input won
        const parsed = parseRedirectUrl(manualInput);
        if (parsed.state && parsed.state !== verifier) {
          throw new Error("OAuth state mismatch - possible CSRF attack");
        }
        code = parsed.code;
      }
      // If still no code, wait for manual promise and try that
      if (!code) {
        await manualPromise;
        if (manualError) {
          throw manualError;
        }
        if (manualInput) {
          const parsed = parseRedirectUrl(manualInput);
          if (parsed.state && parsed.state !== verifier) {
            throw new Error("OAuth state mismatch - possible CSRF attack");
          }
          code = parsed.code;
        }
      }
    } else {
      // Original flow: just wait for callback
      const result = await server.waitForCode();
      if (result?.code) {
        if (result.state !== verifier) {
          throw new Error("OAuth state mismatch - possible CSRF attack");
        }
        code = result.code;
      }
    }
    if (!code) {
      throw new Error("No authorization code received");
    }
    // Exchange code for tokens
    onProgress?.("Exchanging authorization code for tokens...");
    const tokenResponse = await fetch(TOKEN_URL, {
      method: "POST",
      headers: {
        "Content-Type": "application/x-www-form-urlencoded",
      },
      body: new URLSearchParams({
        client_id: CLIENT_ID,
        client_secret: CLIENT_SECRET,
        code,
        grant_type: "authorization_code",
        redirect_uri: REDIRECT_URI,
        code_verifier: verifier,
      }),
    });
    if (!tokenResponse.ok) {
      const error = await tokenResponse.text();
      throw new Error(`Token exchange failed: ${error}`);
    }
    const tokenData = (await tokenResponse.json()) as {
      access_token: string;
      refresh_token: string;
      expires_in: number;
    };
    // A refresh token is only issued with access_type=offline + consent;
    // without it the credentials could never be refreshed.
    if (!tokenData.refresh_token) {
      throw new Error("No refresh token received. Please try again.");
    }
    // Get user email
    onProgress?.("Getting user info...");
    const email = await getUserEmail(tokenData.access_token);
    // Discover project
    const projectId = await discoverProject(tokenData.access_token, onProgress);
    // Calculate expiry time (current time + expires_in seconds - 5 min buffer)
    const expiresAt = Date.now() + tokenData.expires_in * 1000 - 5 * 60 * 1000;
    const credentials: OAuthCredentials = {
      refresh: tokenData.refresh_token,
      access: tokenData.access_token,
      expires: expiresAt,
      projectId,
      email,
    };
    return credentials;
  } finally {
    // Always free port 8085, even when the flow throws.
    server.server.close();
  }
}
// Provider adapter wiring the Gemini CLI login/refresh helpers into the
// generic OAuthProviderInterface used by the provider registry.
export const geminiCliOAuthProvider: OAuthProviderInterface = {
  id: "google-gemini-cli",
  name: "Google Cloud Code Assist (Gemini CLI)",
  usesCallbackServer: true,
  async login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials> {
    return loginGeminiCli(
      callbacks.onAuth,
      callbacks.onProgress,
      callbacks.onManualCodeInput,
    );
  },
  async refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials> {
    const creds = credentials as GeminiCredentials;
    // projectId is required to refresh — credentials stored without it
    // cannot be refreshed and require a fresh login.
    if (!creds.projectId) {
      throw new Error("Google Cloud credentials missing projectId");
    }
    return refreshGoogleCloudToken(creds.refresh, creds.projectId);
  },
  getApiKey(credentials: OAuthCredentials): string {
    const creds = credentials as GeminiCredentials;
    // The "API key" for this provider is a JSON envelope carrying both the
    // access token and the Cloud project id.
    return JSON.stringify({ token: creds.access, projectId: creds.projectId });
  },
};

View file

@ -0,0 +1,187 @@
/**
* OAuth credential management for AI providers.
*
* This module handles login, token refresh, and credential storage
* for OAuth-based providers:
* - Anthropic (Claude Pro/Max)
* - GitHub Copilot
* - Google Cloud Code Assist (Gemini CLI)
* - Antigravity (Gemini 3, Claude, GPT-OSS via Google Cloud)
*/
// Anthropic
export {
anthropicOAuthProvider,
loginAnthropic,
refreshAnthropicToken,
} from "./anthropic.js";
// GitHub Copilot
export {
getGitHubCopilotBaseUrl,
githubCopilotOAuthProvider,
loginGitHubCopilot,
normalizeDomain,
refreshGitHubCopilotToken,
} from "./github-copilot.js";
// Google Antigravity
export {
antigravityOAuthProvider,
loginAntigravity,
refreshAntigravityToken,
} from "./google-antigravity.js";
// Google Gemini CLI
export {
geminiCliOAuthProvider,
loginGeminiCli,
refreshGoogleCloudToken,
} from "./google-gemini-cli.js";
// OpenAI Codex (ChatGPT OAuth)
export {
loginOpenAICodex,
openaiCodexOAuthProvider,
refreshOpenAICodexToken,
} from "./openai-codex.js";
export * from "./types.js";
// ============================================================================
// Provider Registry
// ============================================================================
import { anthropicOAuthProvider } from "./anthropic.js";
import { githubCopilotOAuthProvider } from "./github-copilot.js";
import { antigravityOAuthProvider } from "./google-antigravity.js";
import { geminiCliOAuthProvider } from "./google-gemini-cli.js";
import { openaiCodexOAuthProvider } from "./openai-codex.js";
import type {
OAuthCredentials,
OAuthProviderId,
OAuthProviderInfo,
OAuthProviderInterface,
} from "./types.js";
// Providers shipped with the library; unregisterOAuthProvider() and
// resetOAuthProviders() restore entries from this list.
const BUILT_IN_OAUTH_PROVIDERS: OAuthProviderInterface[] = [
  anthropicOAuthProvider,
  githubCopilotOAuthProvider,
  geminiCliOAuthProvider,
  antigravityOAuthProvider,
  openaiCodexOAuthProvider,
];
// Mutable registry keyed by provider id, seeded with the built-ins.
const oauthProviderRegistry = new Map<string, OAuthProviderInterface>(
  BUILT_IN_OAUTH_PROVIDERS.map((provider) => [provider.id, provider]),
);
/**
 * Look up an OAuth provider by id, or undefined when none is registered.
 */
export function getOAuthProvider(
  id: OAuthProviderId,
): OAuthProviderInterface | undefined {
  const provider = oauthProviderRegistry.get(id);
  return provider;
}
/**
 * Register (or replace) an OAuth provider under its own id.
 */
export function registerOAuthProvider(provider: OAuthProviderInterface): void {
  oauthProviderRegistry.set(provider.id, provider);
}
/**
 * Unregister an OAuth provider.
 *
 * Built-in ids are restored to their built-in implementation rather than
 * removed; custom ids are deleted outright.
 */
export function unregisterOAuthProvider(id: string): void {
  for (const builtIn of BUILT_IN_OAUTH_PROVIDERS) {
    if (builtIn.id === id) {
      oauthProviderRegistry.set(id, builtIn);
      return;
    }
  }
  oauthProviderRegistry.delete(id);
}
/**
 * Drop every registration and re-install the built-in providers.
 */
export function resetOAuthProviders(): void {
  oauthProviderRegistry.clear();
  BUILT_IN_OAUTH_PROVIDERS.forEach((provider) => {
    oauthProviderRegistry.set(provider.id, provider);
  });
}
/**
 * Snapshot of all currently registered OAuth providers.
 */
export function getOAuthProviders(): OAuthProviderInterface[] {
  return [...oauthProviderRegistry.values()];
}
/**
 * @deprecated Use getOAuthProviders() which returns OAuthProviderInterface[]
 */
export function getOAuthProviderInfoList(): OAuthProviderInfo[] {
  const providers = getOAuthProviders();
  return providers.map((provider) => ({
    id: provider.id,
    name: provider.name,
    available: true,
  }));
}
// ============================================================================
// High-level API (uses provider registry)
// ============================================================================
/**
 * Refresh token for any OAuth provider.
 * @deprecated Use getOAuthProvider(id).refreshToken() instead
 * @throws Error when no provider is registered under `providerId`
 */
export async function refreshOAuthToken(
  providerId: OAuthProviderId,
  credentials: OAuthCredentials,
): Promise<OAuthCredentials> {
  const provider = getOAuthProvider(providerId);
  if (provider === undefined) {
    throw new Error(`Unknown OAuth provider: ${providerId}`);
  }
  return provider.refreshToken(credentials);
}
/**
 * Get API key for a provider from OAuth credentials.
 * Automatically refreshes expired tokens.
 *
 * @param providerId - Registered OAuth provider id
 * @param credentials - Map of provider id -> stored credentials
 * @returns API key string and (possibly refreshed) credentials, or null if
 *   no credentials are stored for the provider
 * @throws Error if the provider is unknown or refresh fails
 */
export async function getOAuthApiKey(
  providerId: OAuthProviderId,
  credentials: Record<string, OAuthCredentials>,
): Promise<{ newCredentials: OAuthCredentials; apiKey: string } | null> {
  const provider = getOAuthProvider(providerId);
  if (!provider) {
    throw new Error(`Unknown OAuth provider: ${providerId}`);
  }
  let creds = credentials[providerId];
  if (!creds) {
    return null;
  }
  // Refresh if expired
  if (Date.now() >= creds.expires) {
    try {
      creds = await provider.refreshToken(creds);
    } catch (error) {
      // Surface the underlying failure instead of discarding it — refresh
      // errors (revoked grant, network failure, …) need to be diagnosable.
      const detail = error instanceof Error ? `: ${error.message}` : "";
      throw new Error(
        `Failed to refresh OAuth token for ${providerId}${detail}`,
      );
    }
  }
  const apiKey = provider.getApiKey(creds);
  return { newCredentials: creds, apiKey };
}

View file

@ -0,0 +1,499 @@
/**
* OpenAI Codex (ChatGPT OAuth) flow
*
* NOTE: This module uses Node.js crypto and http for the OAuth callback.
* It is only intended for CLI use, not browser environments.
*/
// NEVER convert to top-level imports - breaks browser/Vite builds (web-ui)
// Node built-ins are imported lazily and only under Node/Bun.
// NOTE(review): unlike google-gemini-cli.ts, these import promises are not
// tracked/awaited, so a login started immediately after module load could in
// principle see _randomBytes/_http still null — confirm whether that race
// matters for callers.
let _randomBytes: typeof import("node:crypto").randomBytes | null = null;
let _http: typeof import("node:http") | null = null;
if (
  typeof process !== "undefined" &&
  (process.versions?.node || process.versions?.bun)
) {
  import("node:crypto").then((m) => {
    _randomBytes = m.randomBytes;
  });
  import("node:http").then((m) => {
    _http = m;
  });
}
import { generatePKCE } from "./pkce.js";
import type {
  OAuthCredentials,
  OAuthLoginCallbacks,
  OAuthPrompt,
  OAuthProviderInterface,
} from "./types.js";
// OAuth client configuration for the ChatGPT/Codex flow.
const CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann";
const AUTHORIZE_URL = "https://auth.openai.com/oauth/authorize";
const TOKEN_URL = "https://auth.openai.com/oauth/token";
// Must match the port the local callback server listens on (1455).
const REDIRECT_URI = "http://localhost:1455/auth/callback";
const SCOPE = "openid profile email offline_access";
// JWT claim namespace under which the ChatGPT account id is stored.
const JWT_CLAIM_PATH = "https://api.openai.com/auth";
// Page served to the browser after a successful callback.
const SUCCESS_HTML = `<!doctype html>
<html lang="en">
  <head>
    <meta charset="utf-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1" />
    <title>Authentication successful</title>
  </head>
  <body>
    <p>Authentication successful. Return to your terminal to continue.</p>
  </body>
</html>`;
// Discriminated union for token-endpoint outcomes; failures are logged at
// the call site and reported via type "failed" rather than thrown.
type TokenSuccess = {
  type: "success";
  access: string;
  refresh: string;
  expires: number;
};
type TokenFailure = { type: "failed" };
type TokenResult = TokenSuccess | TokenFailure;
// Decoded JWT payload; only the namespaced auth claim is read.
type JwtPayload = {
  [JWT_CLAIM_PATH]?: {
    chatgpt_account_id?: string;
  };
  [key: string]: unknown;
};
function createState(): string {
if (!_randomBytes) {
throw new Error(
"OpenAI Codex OAuth is only available in Node.js environments",
);
}
return _randomBytes(16).toString("hex");
}
/**
 * Normalize whatever the user pasted into an authorization { code, state }.
 * Accepts, in order of preference: a full redirect URL, a `code#state`
 * pair, a raw query string containing `code=`, or a bare authorization code.
 */
function parseAuthorizationInput(input: string): {
  code?: string;
  state?: string;
} {
  const value = input.trim();
  if (!value) return {};
  // 1) Full redirect URL.
  try {
    const url = new URL(value);
    return {
      code: url.searchParams.get("code") ?? undefined,
      state: url.searchParams.get("state") ?? undefined,
    };
  } catch {
    // not a URL — fall through to the other formats
  }
  // 2) "code#state" pair from the simplified flow.
  if (value.includes("#")) {
    const segments = value.split("#");
    return { code: segments[0], state: segments[1] };
  }
  // 3) Raw query-string fragment.
  if (value.includes("code=")) {
    const params = new URLSearchParams(value);
    return {
      code: params.get("code") ?? undefined,
      state: params.get("state") ?? undefined,
    };
  }
  // 4) Bare authorization code.
  return { code: value };
}
/**
 * Decode the payload segment of a JWT without verifying its signature.
 *
 * JWT segments are *unpadded base64url* (RFC 7515 §2), while atob() only
 * accepts standard base64 — so the alphabet must be translated back
 * ('-' -> '+', '_' -> '/') and padding restored before decoding, otherwise
 * any payload containing those characters fails to decode.
 *
 * @param token - Raw JWT string ("header.payload.signature")
 * @returns Parsed payload object, or null for malformed input
 */
function decodeJwt(token: string): JwtPayload | null {
  try {
    const parts = token.split(".");
    if (parts.length !== 3) return null;
    const segment = parts[1] ?? "";
    const base64 = segment.replace(/-/g, "+").replace(/_/g, "/");
    // Restore '=' padding to a multiple of 4 characters.
    const padded = base64.padEnd(
      base64.length + ((4 - (base64.length % 4)) % 4),
      "=",
    );
    const decoded = atob(padded);
    return JSON.parse(decoded) as JwtPayload;
  } catch {
    return null;
  }
}
/**
 * Exchange an authorization code (plus PKCE verifier) for access/refresh
 * tokens. Returns a TokenResult instead of throwing so callers decide how
 * to surface failures; details are logged to stderr.
 */
async function exchangeAuthorizationCode(
  code: string,
  verifier: string,
  redirectUri: string = REDIRECT_URI,
): Promise<TokenResult> {
  const response = await fetch(TOKEN_URL, {
    method: "POST",
    headers: { "Content-Type": "application/x-www-form-urlencoded" },
    body: new URLSearchParams({
      grant_type: "authorization_code",
      client_id: CLIENT_ID,
      code,
      code_verifier: verifier,
      redirect_uri: redirectUri,
    }),
  });
  if (!response.ok) {
    const text = await response.text().catch(() => "");
    console.error("[openai-codex] code->token failed:", response.status, text);
    return { type: "failed" };
  }
  const json = (await response.json()) as {
    access_token?: string;
    refresh_token?: string;
    expires_in?: number;
  };
  // All three fields are required; a partial response is treated as failure.
  if (
    !json.access_token ||
    !json.refresh_token ||
    typeof json.expires_in !== "number"
  ) {
    console.error("[openai-codex] token response missing fields:", json);
    return { type: "failed" };
  }
  return {
    type: "success",
    access: json.access_token,
    refresh: json.refresh_token,
    // expires_in is seconds; store an absolute epoch-millis deadline.
    expires: Date.now() + json.expires_in * 1000,
  };
}
/**
 * Refresh the ChatGPT OAuth tokens using a stored refresh token.
 * Never throws: every failure path (HTTP error, malformed body, network
 * exception) is logged and reported as { type: "failed" }.
 */
async function refreshAccessToken(refreshToken: string): Promise<TokenResult> {
  try {
    const response = await fetch(TOKEN_URL, {
      method: "POST",
      headers: { "Content-Type": "application/x-www-form-urlencoded" },
      body: new URLSearchParams({
        grant_type: "refresh_token",
        refresh_token: refreshToken,
        client_id: CLIENT_ID,
      }),
    });
    if (!response.ok) {
      const text = await response.text().catch(() => "");
      console.error(
        "[openai-codex] Token refresh failed:",
        response.status,
        text,
      );
      return { type: "failed" };
    }
    const json = (await response.json()) as {
      access_token?: string;
      refresh_token?: string;
      expires_in?: number;
    };
    // All three fields are required; a partial response counts as failure.
    if (
      !json.access_token ||
      !json.refresh_token ||
      typeof json.expires_in !== "number"
    ) {
      console.error(
        "[openai-codex] Token refresh response missing fields:",
        json,
      );
      return { type: "failed" };
    }
    return {
      type: "success",
      access: json.access_token,
      refresh: json.refresh_token,
      // expires_in is seconds; store an absolute epoch-millis deadline.
      expires: Date.now() + json.expires_in * 1000,
    };
  } catch (error) {
    console.error("[openai-codex] Token refresh error:", error);
    return { type: "failed" };
  }
}
/**
 * Build the authorization URL plus the PKCE verifier and state value that
 * the rest of the flow needs to validate the callback.
 *
 * @param originator - Value of the `originator` query parameter ("pi" by default)
 */
async function createAuthorizationFlow(
  originator: string = "pi",
): Promise<{ verifier: string; state: string; url: string }> {
  const { verifier, challenge } = await generatePKCE();
  const state = createState();
  const query: Record<string, string> = {
    response_type: "code",
    client_id: CLIENT_ID,
    redirect_uri: REDIRECT_URI,
    scope: SCOPE,
    code_challenge: challenge,
    code_challenge_method: "S256",
    state,
    id_token_add_organizations: "true",
    codex_cli_simplified_flow: "true",
    originator,
  };
  const url = new URL(AUTHORIZE_URL);
  for (const [key, value] of Object.entries(query)) {
    url.searchParams.set(key, value);
  }
  return { verifier, state, url: url.toString() };
}
// Handle for the local callback server: close() tears it down, cancelWait()
// aborts the waiter, waitForCode() polls for the authorization code.
type OAuthServerInfo = {
  close: () => void;
  cancelWait: () => void;
  waitForCode: () => Promise<{ code: string } | null>;
};
/**
 * Start a local HTTP server on 127.0.0.1:1455 to receive the OAuth redirect.
 *
 * If the port cannot be bound (e.g. already in use), resolves with a no-op
 * handle whose waitForCode() returns null immediately, so the caller falls
 * back to manual code paste. waitForCode() otherwise polls every 100ms for
 * up to ~60s.
 */
function startLocalOAuthServer(state: string): Promise<OAuthServerInfo> {
  if (!_http) {
    throw new Error(
      "OpenAI Codex OAuth is only available in Node.js environments",
    );
  }
  // Shared mutable state between the request handler and waitForCode().
  let lastCode: string | null = null;
  let cancelled = false;
  const server = _http.createServer((req, res) => {
    try {
      const url = new URL(req.url || "", "http://localhost");
      if (url.pathname !== "/auth/callback") {
        res.statusCode = 404;
        res.end("Not found");
        return;
      }
      // Reject callbacks whose state doesn't match ours (CSRF protection).
      if (url.searchParams.get("state") !== state) {
        res.statusCode = 400;
        res.end("State mismatch");
        return;
      }
      const code = url.searchParams.get("code");
      if (!code) {
        res.statusCode = 400;
        res.end("Missing authorization code");
        return;
      }
      res.statusCode = 200;
      res.setHeader("Content-Type", "text/html; charset=utf-8");
      res.end(SUCCESS_HTML);
      // Publish the code; waitForCode()'s poll loop picks it up.
      lastCode = code;
    } catch {
      res.statusCode = 500;
      res.end("Internal error");
    }
  });
  return new Promise((resolve) => {
    server
      .listen(1455, "127.0.0.1", () => {
        resolve({
          close: () => server.close(),
          cancelWait: () => {
            cancelled = true;
          },
          // Poll every 100ms for up to 600 iterations (~60 seconds).
          waitForCode: async () => {
            const sleep = () => new Promise((r) => setTimeout(r, 100));
            for (let i = 0; i < 600; i += 1) {
              if (lastCode) return { code: lastCode };
              if (cancelled) return null;
              await sleep();
            }
            return null;
          },
        });
      })
      .on("error", (err: NodeJS.ErrnoException) => {
        console.error(
          "[openai-codex] Failed to bind http://127.0.0.1:1455 (",
          err.code,
          ") Falling back to manual paste.",
        );
        // Degraded handle: waitForCode() yields null so the caller prompts
        // the user to paste the code manually.
        resolve({
          close: () => {
            try {
              server.close();
            } catch {
              // ignore
            }
          },
          cancelWait: () => {},
          waitForCode: async () => null,
        });
      });
  });
}
/**
 * Pull the ChatGPT account id out of the access token's JWT claims.
 * Returns null when the token is malformed or the claim is absent/empty.
 */
function getAccountId(accessToken: string): string | null {
  const claims = decodeJwt(accessToken);
  const accountId = claims?.[JWT_CLAIM_PATH]?.chatgpt_account_id;
  if (typeof accountId !== "string" || accountId.length === 0) {
    return null;
  }
  return accountId;
}
/**
 * Login with OpenAI Codex OAuth
 *
 * @param options.onAuth - Called with URL and instructions when auth starts
 * @param options.onPrompt - Called to prompt user for manual code paste (fallback if no onManualCodeInput)
 * @param options.onProgress - Optional progress messages
 * @param options.onManualCodeInput - Optional promise that resolves with user-pasted code.
 *   Races with browser callback - whichever completes first wins.
 *   Useful for showing paste input immediately alongside browser flow.
 * @param options.originator - OAuth originator parameter (defaults to "pi")
 * @returns Credentials including the ChatGPT accountId derived from the JWT
 * @throws Error on state mismatch, missing code, or failed token exchange
 */
export async function loginOpenAICodex(options: {
  onAuth: (info: { url: string; instructions?: string }) => void;
  onPrompt: (prompt: OAuthPrompt) => Promise<string>;
  onProgress?: (message: string) => void;
  onManualCodeInput?: () => Promise<string>;
  originator?: string;
}): Promise<OAuthCredentials> {
  const { verifier, state, url } = await createAuthorizationFlow(
    options.originator,
  );
  const server = await startLocalOAuthServer(state);
  options.onAuth({
    url,
    instructions: "A browser window should open. Complete login to finish.",
  });
  let code: string | undefined;
  try {
    if (options.onManualCodeInput) {
      // Race between browser callback and manual input. The manual promise
      // cancels the server's wait loop when it settles, so waitForCode()
      // below returns either a browser result or null.
      let manualCode: string | undefined;
      let manualError: Error | undefined;
      const manualPromise = options
        .onManualCodeInput()
        .then((input) => {
          manualCode = input;
          server.cancelWait();
        })
        .catch((err) => {
          manualError = err instanceof Error ? err : new Error(String(err));
          server.cancelWait();
        });
      const result = await server.waitForCode();
      // If manual input was cancelled, throw that error
      if (manualError) {
        throw manualError;
      }
      if (result?.code) {
        // Browser callback won (state was already checked by the server)
        code = result.code;
      } else if (manualCode) {
        // Manual input won (or callback timed out and user had entered code)
        const parsed = parseAuthorizationInput(manualCode);
        if (parsed.state && parsed.state !== state) {
          throw new Error("State mismatch");
        }
        code = parsed.code;
      }
      // If still no code, wait for manual promise to complete and try that
      if (!code) {
        await manualPromise;
        if (manualError) {
          throw manualError;
        }
        if (manualCode) {
          const parsed = parseAuthorizationInput(manualCode);
          if (parsed.state && parsed.state !== state) {
            throw new Error("State mismatch");
          }
          code = parsed.code;
        }
      }
    } else {
      // Original flow: wait for callback, then prompt if needed
      const result = await server.waitForCode();
      if (result?.code) {
        code = result.code;
      }
    }
    // Fallback to onPrompt if still no code
    if (!code) {
      const input = await options.onPrompt({
        message: "Paste the authorization code (or full redirect URL):",
      });
      const parsed = parseAuthorizationInput(input);
      if (parsed.state && parsed.state !== state) {
        throw new Error("State mismatch");
      }
      code = parsed.code;
    }
    if (!code) {
      throw new Error("Missing authorization code");
    }
    const tokenResult = await exchangeAuthorizationCode(code, verifier);
    if (tokenResult.type !== "success") {
      throw new Error("Token exchange failed");
    }
    // The ChatGPT account id is required downstream; reject tokens without it.
    const accountId = getAccountId(tokenResult.access);
    if (!accountId) {
      throw new Error("Failed to extract accountId from token");
    }
    return {
      access: tokenResult.access,
      refresh: tokenResult.refresh,
      expires: tokenResult.expires,
      accountId,
    };
  } finally {
    // Always free port 1455, even when the flow throws.
    server.close();
  }
}
/**
 * Refresh the ChatGPT OAuth access token and re-derive the account id.
 *
 * @param refreshToken - Refresh token from a previous login
 * @throws Error when the refresh request fails or the new token carries no
 *   ChatGPT account id claim
 */
export async function refreshOpenAICodexToken(
  refreshToken: string,
): Promise<OAuthCredentials> {
  const outcome = await refreshAccessToken(refreshToken);
  if (outcome.type === "failed") {
    throw new Error("Failed to refresh OpenAI Codex token");
  }
  const accountId = getAccountId(outcome.access);
  if (!accountId) {
    throw new Error("Failed to extract accountId from token");
  }
  const { access, refresh, expires } = outcome;
  return { access, refresh, expires, accountId };
}
/**
 * OAuth provider descriptor for ChatGPT Plus/Pro (Codex subscription)
 * accounts. Delegates the interactive flow to loginOpenAICodex and token
 * renewal to refreshOpenAICodexToken.
 */
export const openaiCodexOAuthProvider: OAuthProviderInterface = {
  id: "openai-codex",
  name: "ChatGPT Plus/Pro (Codex Subscription)",
  usesCallbackServer: true,
  login: (callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials> => {
    // Forward only the callbacks the login flow understands.
    const { onAuth, onPrompt, onProgress, onManualCodeInput } = callbacks;
    return loginOpenAICodex({
      onAuth,
      onPrompt,
      onProgress,
      onManualCodeInput,
    });
  },
  refreshToken: (credentials: OAuthCredentials): Promise<OAuthCredentials> =>
    refreshOpenAICodexToken(credentials.refresh),
  // The access token doubles as the bearer API key.
  getApiKey: (credentials: OAuthCredentials): string => credentials.access,
};

View file

@ -0,0 +1,37 @@
/**
* PKCE utilities using Web Crypto API.
* Works in both Node.js 20+ and browsers.
*/
/**
 * Encode raw bytes as an RFC 4648 base64url string (padding stripped).
 */
function base64urlEncode(bytes: Uint8Array): string {
  const chars: string[] = [];
  for (let i = 0; i < bytes.length; i++) {
    chars.push(String.fromCharCode(bytes[i]));
  }
  return btoa(chars.join(""))
    .replace(/\+/g, "-")
    .replace(/\//g, "_")
    .replace(/=/g, "");
}
/**
 * Generate a PKCE verifier/challenge pair (S256 method).
 * Uses only the Web Crypto API so it works in both Node.js 20+ and browsers.
 */
export async function generatePKCE(): Promise<{
  verifier: string;
  challenge: string;
}> {
  // 32 random bytes -> 43-character base64url verifier.
  const seed = crypto.getRandomValues(new Uint8Array(32));
  const verifier = base64urlEncode(seed);
  // The challenge is the base64url-encoded SHA-256 digest of the verifier text.
  const digest = await crypto.subtle.digest(
    "SHA-256",
    new TextEncoder().encode(verifier),
  );
  return { verifier, challenge: base64urlEncode(new Uint8Array(digest)) };
}

View file

@ -0,0 +1,62 @@
import type { Api, Model } from "../../types.js";
/**
 * OAuth tokens persisted for a provider account. The index signature lets
 * individual providers stash extra fields (e.g. an account id) alongside the
 * core token data.
 */
export type OAuthCredentials = {
  // Long-lived token used to obtain new access tokens.
  refresh: string;
  // Short-lived bearer token sent with API requests.
  access: string;
  // Access-token expiry timestamp; presumably epoch milliseconds — TODO
  // confirm against the provider code that populates it.
  expires: number;
  [key: string]: unknown;
};
/** Identifier of an OAuth provider implementation (e.g. "openai-codex"). */
export type OAuthProviderId = string;
/** @deprecated Use OAuthProviderId instead */
export type OAuthProvider = OAuthProviderId;
/** A request for free-form user input during a login flow. */
export type OAuthPrompt = {
  message: string;
  placeholder?: string;
  // When true, an empty response is acceptable.
  allowEmpty?: boolean;
};
/** Information the host UI needs to send the user to the authorization page. */
export type OAuthAuthInfo = {
  url: string;
  instructions?: string;
};
/** Callbacks supplied by the host UI to drive an interactive login flow. */
export interface OAuthLoginCallbacks {
  // Called with the authorization URL the user must visit.
  onAuth: (info: OAuthAuthInfo) => void;
  // Ask the user for input (e.g. pasting an authorization code).
  onPrompt: (prompt: OAuthPrompt) => Promise<string>;
  // Optional progress messages for display.
  onProgress?: (message: string) => void;
  // Optional manual-entry fallback used alongside a local callback server.
  onManualCodeInput?: () => Promise<string>;
  // Optional cancellation signal for aborting the flow.
  signal?: AbortSignal;
}
/** Contract every OAuth provider implementation must satisfy. */
export interface OAuthProviderInterface {
  readonly id: OAuthProviderId;
  readonly name: string;
  /** Run the login flow, return credentials to persist */
  login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials>;
  /** Whether login uses a local callback server and supports manual code input. */
  usesCallbackServer?: boolean;
  /** Refresh expired credentials, return updated credentials to persist */
  refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials>;
  /** Convert credentials to API key string for the provider */
  getApiKey(credentials: OAuthCredentials): string;
  /** Optional: modify models for this provider (e.g., update baseUrl) */
  modifyModels?(
    models: Model<Api>[],
    credentials: OAuthCredentials,
  ): Model<Api>[];
}
/** @deprecated Use OAuthProviderInterface instead */
export interface OAuthProviderInfo {
  id: OAuthProviderId;
  name: string;
  available: boolean;
}

View file

@ -0,0 +1,127 @@
import type { AssistantMessage } from "../types.js";
/**
 * Error-text patterns that identify a context-window overflow, one per known
 * provider plus a few generic fallbacks. Example provider messages:
 *
 * - Anthropic: "prompt is too long: 213462 tokens > 200000 maximum"
 * - OpenAI: "Your input exceeds the context window of this model"
 * - Google: "The input token count (1196265) exceeds the maximum number of tokens allowed (1048575)"
 * - xAI: "This model's maximum prompt length is 131072 but the request contains 537812 tokens"
 * - Groq: "Please reduce the length of the messages or completion"
 * - OpenRouter: "This endpoint's maximum context length is X tokens. However, you requested about Y tokens"
 * - llama.cpp: "the request exceeds the available context size, try increasing it"
 * - LM Studio: "tokens to keep from the initial prompt is greater than the context length"
 * - GitHub Copilot: "prompt token count of X exceeds the limit of Y"
 * - MiniMax: "invalid params, context window exceeds limit"
 * - Kimi For Coding: "Your request exceeded model token limit: X (requested: Y)"
 * - Mistral: "Prompt contains X tokens ... too large for model with Y maximum context length"
 *
 * Not covered by these patterns:
 * - Cerebras answers with a bare "400/413 status code (no body)" — handled by
 *   a dedicated check inside isContextOverflow.
 * - z.ai may accept overflow silently — detected via usage vs. contextWindow.
 * - Ollama silently truncates input and is not detectable here at all.
 */
const OVERFLOW_PATTERNS = [
  /prompt is too long/i, // Anthropic
  /input is too long for requested model/i, // Amazon Bedrock
  /exceeds the context window/i, // OpenAI (Completions & Responses API)
  /input token count.*exceeds the maximum/i, // Google (Gemini)
  /maximum prompt length is \d+/i, // xAI (Grok)
  /reduce the length of the messages/i, // Groq
  /maximum context length is \d+ tokens/i, // OpenRouter (all backends)
  /exceeds the limit of \d+/i, // GitHub Copilot
  /exceeds the available context size/i, // llama.cpp server
  /greater than the context length/i, // LM Studio
  /context window exceeds limit/i, // MiniMax
  /exceeded model token limit/i, // Kimi For Coding
  /too large for model with \d+ maximum context length/i, // Mistral
  /context[_ ]length[_ ]exceeded/i, // Generic fallback
  /too many tokens/i, // Generic fallback
  /token limit exceeded/i, // Generic fallback
];
/**
 * Decide whether an assistant message indicates a context-window overflow.
 *
 * Two detection paths are covered:
 * 1. Error-based overflow: stopReason "error" with a message matching a known
 *    provider pattern (or Cerebras's bodyless 400/413 response).
 * 2. Silent overflow (z.ai style): a successful response whose reported input
 *    usage exceeds the supplied context window.
 *
 * Detection is reliable for providers listed in OVERFLOW_PATTERNS. It is
 * unreliable for z.ai (pass contextWindow to catch the silent case) and
 * impossible for Ollama, which truncates input without any signal. For custom
 * providers added via settings.json, trigger an overflow once, inspect the
 * errorMessage, and either extend OVERFLOW_PATTERNS or pre-check the message
 * yourself before calling this function.
 *
 * @param message - The assistant message to check
 * @param contextWindow - Optional context window size for detecting silent overflow (z.ai)
 * @returns true if the message indicates a context overflow
 */
export function isContextOverflow(
  message: AssistantMessage,
  contextWindow?: number,
): boolean {
  const { stopReason, errorMessage } = message;
  // Path 1: the provider reported an error whose text we recognize.
  if (stopReason === "error" && errorMessage) {
    if (OVERFLOW_PATTERNS.some((pattern) => pattern.test(errorMessage))) {
      return true;
    }
    // Cerebras signals overflow as a bare 400/413 with an empty body.
    // 429 is rate limiting (requests/tokens per time) and is NOT overflow.
    if (/^4(00|13)\s*(status code)?\s*\(no body\)/i.test(errorMessage)) {
      return true;
    }
  }
  // Path 2: request "succeeded" but the reported prompt tokens are larger
  // than the model's context window (silent overflow, z.ai style).
  if (contextWindow && stopReason === "stop") {
    return message.usage.input + message.usage.cacheRead > contextWindow;
  }
  return false;
}
/**
 * Expose a defensive copy of the overflow patterns for testing purposes.
 */
export function getOverflowPatterns(): RegExp[] {
  return OVERFLOW_PATTERNS.slice();
}

View file

@ -0,0 +1,28 @@
/**
 * Strip unpaired Unicode surrogate code units from a string.
 *
 * A high surrogate (0xD800-0xDBFF) is only valid when immediately followed by
 * a low surrogate (0xDC00-0xDFFF), and vice versa; lone surrogates break JSON
 * serialization in many API providers. Properly paired surrogates — i.e. any
 * valid emoji or other character outside the Basic Multilingual Plane — are
 * left untouched.
 *
 * @param text - The text to sanitize
 * @returns The input with every unpaired surrogate removed
 *
 * @example
 * // Valid emoji (properly paired surrogates) are preserved
 * sanitizeSurrogates("Hello 🙈 World") // => "Hello 🙈 World"
 *
 * // Unpaired high surrogate is removed
 * const unpaired = String.fromCharCode(0xD83D); // high surrogate without low
 * sanitizeSurrogates(`Text ${unpaired} here`) // => "Text  here"
 */
export function sanitizeSurrogates(text: string): string {
  // Left alternative: high surrogate NOT followed by a low surrogate.
  // Right alternative: low surrogate NOT preceded by a high surrogate.
  const unpairedSurrogate =
    /[\uD800-\uDBFF](?![\uDC00-\uDFFF])|(?<![\uD800-\uDBFF])[\uDC00-\uDFFF]/g;
  return text.replace(unpairedSurrogate, "");
}

View file

@ -0,0 +1,24 @@
import { type TUnsafe, Type } from "@sinclair/typebox";
/**
 * Creates a string enum schema compatible with Google's API and other providers
 * that don't support anyOf/const patterns.
 *
 * @param values - The allowed string literals, in order.
 * @param options - Optional schema metadata (`description`, `default`).
 * @returns A TypeBox schema whose static type is the union of `values`.
 *
 * @example
 * const OperationSchema = StringEnum(["add", "subtract", "multiply", "divide"], {
 *   description: "The operation to perform"
 * });
 *
 * type Operation = Static<typeof OperationSchema>; // "add" | "subtract" | "multiply" | "divide"
 */
export function StringEnum<T extends readonly string[]>(
  values: T,
  options?: { description?: string; default?: T[number] },
): TUnsafe<T[number]> {
  return Type.Unsafe<T[number]>({
    type: "string",
    // Copy into a mutable array; SchemaOptions accepts extra keywords through
    // its index signature, so no `as any` cast is needed.
    enum: [...values],
    // Compare against undefined instead of relying on truthiness: a falsy but
    // legitimate value (e.g. a default of "") must not be silently dropped.
    ...(options?.description !== undefined && {
      description: options.description,
    }),
    ...(options?.default !== undefined && { default: options.default }),
  });
}

View file

@ -0,0 +1,88 @@
import AjvModule from "ajv";
import addFormatsModule from "ajv-formats";
// Normalize CommonJS/ESM interop: depending on the bundler these packages are
// exposed either directly or under a `.default` property.
const Ajv = (AjvModule as any).default || AjvModule;
const addFormats = (addFormatsModule as any).default || addFormatsModule;
import type { Tool, ToolCall } from "../types.js";
// Detect a browser-extension environment with a strict Content Security
// Policy. Chrome extensions under Manifest V3 forbid eval/new Function, which
// AJV's compiled validators require.
const isBrowserExtension =
  typeof globalThis !== "undefined" &&
  (globalThis as any).chrome?.runtime?.id !== undefined;
// Singleton AJV instance shared by all validations. Stays null inside a
// browser extension (or if construction throws), in which case validation is
// skipped entirely by validateToolArguments below.
let ajv: any = null;
if (!isBrowserExtension) {
  try {
    ajv = new Ajv({
      allErrors: true, // collect every violation, not just the first
      strict: false, // tolerate unknown schema keywords in tool schemas
      coerceTypes: true, // coerce e.g. "1" -> 1 for lenient LLM output
    });
    addFormats(ajv);
  } catch (_e) {
    // AJV initialization failed (likely CSP restriction)
    console.warn("AJV validation disabled due to CSP restrictions");
  }
}
/**
 * Finds the tool named by a tool call and validates the call's arguments
 * against that tool's TypeBox schema.
 *
 * @param tools Array of tool definitions
 * @param toolCall The tool call from the LLM
 * @returns The validated arguments
 * @throws Error if no tool matches the call's name, or validation fails
 */
export function validateToolCall(tools: Tool[], toolCall: ToolCall): any {
  const tool = tools.find((candidate) => candidate.name === toolCall.name);
  if (tool === undefined) {
    throw new Error(`Tool "${toolCall.name}" not found`);
  }
  return validateToolArguments(tool, toolCall);
}
/**
 * Validates tool call arguments against the tool's TypeBox schema.
 *
 * @param tool The tool definition with TypeBox schema
 * @param toolCall The tool call from the LLM
 * @returns The validated (and potentially coerced) arguments
 * @throws Error with a formatted per-field message when validation fails
 */
export function validateToolArguments(tool: Tool, toolCall: ToolCall): any {
  // Without a working AJV instance (Manifest V3 CSP blocks AJV's compiled
  // validators), trust the LLM's output and pass the arguments through.
  if (isBrowserExtension || !ajv) {
    return toolCall.arguments;
  }
  const validate = ajv.compile(tool.parameters);
  // Work on a copy: AJV mutates the value in place when coercing types.
  const candidate = structuredClone(toolCall.arguments);
  if (validate(candidate)) {
    return candidate;
  }
  // Build a readable " - path: message" line per violation.
  const details = (validate.errors ?? []).map((err: any) => {
    const where = err.instancePath
      ? err.instancePath.substring(1)
      : err.params.missingProperty || "root";
    return ` - ${where}: ${err.message}`;
  });
  const errors =
    details.length > 0 ? details.join("\n") : "Unknown validation error";
  throw new Error(
    `Validation failed for tool "${toolCall.name}":\n${errors}\n\nReceived arguments:\n${JSON.stringify(toolCall.arguments, null, 2)}`,
  );
}