co-mono/packages/ai/src/providers/google-gemini-cli.ts
Ahmed Kamal e42e9e6305 fix(ai): classify Google thoughtSignature as thinking
Google streaming may emit thoughtSignature without thought=true (including empty-text signature-only parts). Treat non-empty thoughtSignature as thinking to avoid leaking reasoning into normal text and retain signature across streaming deltas. Add unit test coverage.
2026-01-06 20:47:19 +02:00

613 lines
18 KiB
TypeScript

/**
* Google Gemini CLI / Antigravity provider.
* Shared implementation for both google-gemini-cli and google-antigravity providers.
* Uses the Cloud Code Assist API endpoint to access Gemini and Claude models.
*/
import type { Content, ThinkingConfig } from "@google/genai";
import { calculateCost } from "../models.js";
import type {
Api,
AssistantMessage,
Context,
Model,
StreamFunction,
StreamOptions,
TextContent,
ThinkingContent,
ToolCall,
} from "../types.js";
import { AssistantMessageEventStream } from "../utils/event-stream.js";
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
import {
convertMessages,
convertTools,
isThinkingPart,
mapStopReasonString,
mapToolChoice,
retainThoughtSignature,
} from "./google-shared.js";
/**
* Thinking level for Gemini 3 models.
* Mirrors Google's ThinkingLevel enum values.
*/
export type GoogleThinkingLevel = "THINKING_LEVEL_UNSPECIFIED" | "MINIMAL" | "LOW" | "MEDIUM" | "HIGH";
export interface GoogleGeminiCliOptions extends StreamOptions {
toolChoice?: "auto" | "none" | "any";
/**
* Thinking/reasoning configuration.
* - Gemini 2.x models: use `budgetTokens` to set the thinking budget
* - Gemini 3 models (gemini-3-pro-*, gemini-3-flash-*): use `level` instead
*
* When using `streamSimple`, this is handled automatically based on the model.
*/
thinking?: {
enabled: boolean;
/** Thinking budget in tokens. Use for Gemini 2.x models. */
budgetTokens?: number;
/** Thinking level. Use for Gemini 3 models (LOW/HIGH for Pro, MINIMAL/LOW/MEDIUM/HIGH for Flash). */
level?: GoogleThinkingLevel;
};
projectId?: string;
}
const DEFAULT_ENDPOINT = "https://cloudcode-pa.googleapis.com";
// Headers for Gemini CLI (prod endpoint)
const GEMINI_CLI_HEADERS = {
"User-Agent": "google-cloud-sdk vscode_cloudshelleditor/0.1",
"X-Goog-Api-Client": "gl-node/22.17.0",
"Client-Metadata": JSON.stringify({
ideType: "IDE_UNSPECIFIED",
platform: "PLATFORM_UNSPECIFIED",
pluginType: "GEMINI",
}),
};
// Headers for Antigravity (sandbox endpoint) - requires specific User-Agent
const ANTIGRAVITY_HEADERS = {
"User-Agent": "antigravity/1.11.5 darwin/arm64",
"X-Goog-Api-Client": "google-cloud-sdk vscode_cloudshelleditor/0.1",
"Client-Metadata": JSON.stringify({
ideType: "IDE_UNSPECIFIED",
platform: "PLATFORM_UNSPECIFIED",
pluginType: "GEMINI",
}),
};
// Counter for generating unique tool call IDs
let toolCallCounter = 0;
// Retry configuration
const MAX_RETRIES = 3;
const BASE_DELAY_MS = 1000;
/**
* Extract retry delay from Gemini error response (in milliseconds).
* Parses patterns like:
* - "Your quota will reset after 39s"
* - "Your quota will reset after 18h31m10s"
* - "Please retry in Xs" or "Please retry in Xms"
* - "retryDelay": "34.074824224s" (JSON field)
*/
function extractRetryDelay(errorText: string): number | undefined {
// Pattern 1: "Your quota will reset after ..." (formats: "18h31m10s", "10m15s", "6s", "39s")
const durationMatch = errorText.match(/reset after (?:(\d+)h)?(?:(\d+)m)?(\d+(?:\.\d+)?)s/i);
if (durationMatch) {
const hours = durationMatch[1] ? parseInt(durationMatch[1], 10) : 0;
const minutes = durationMatch[2] ? parseInt(durationMatch[2], 10) : 0;
const seconds = parseFloat(durationMatch[3]);
if (!Number.isNaN(seconds)) {
const totalMs = ((hours * 60 + minutes) * 60 + seconds) * 1000;
if (totalMs > 0) {
return Math.ceil(totalMs + 1000); // Add 1s buffer
}
}
}
// Pattern 2: "Please retry in X[ms|s]"
const retryInMatch = errorText.match(/Please retry in ([0-9.]+)(ms|s)/i);
if (retryInMatch?.[1]) {
const value = parseFloat(retryInMatch[1]);
if (!Number.isNaN(value) && value > 0) {
const ms = retryInMatch[2].toLowerCase() === "ms" ? value : value * 1000;
return Math.ceil(ms + 1000);
}
}
// Pattern 3: "retryDelay": "34.074824224s" (JSON field in error details)
const retryDelayMatch = errorText.match(/"retryDelay":\s*"([0-9.]+)(ms|s)"/i);
if (retryDelayMatch?.[1]) {
const value = parseFloat(retryDelayMatch[1]);
if (!Number.isNaN(value) && value > 0) {
const ms = retryDelayMatch[2].toLowerCase() === "ms" ? value : value * 1000;
return Math.ceil(ms + 1000);
}
}
return undefined;
}
/**
* Check if an error is retryable (rate limit, server error, etc.)
*/
function isRetryableError(status: number, errorText: string): boolean {
if (status === 429 || status === 500 || status === 502 || status === 503 || status === 504) {
return true;
}
return /resource.?exhausted|rate.?limit|overloaded|service.?unavailable/i.test(errorText);
}
/**
* Sleep for a given number of milliseconds, respecting abort signal.
*/
function sleep(ms: number, signal?: AbortSignal): Promise<void> {
return new Promise((resolve, reject) => {
if (signal?.aborted) {
reject(new Error("Request was aborted"));
return;
}
const timeout = setTimeout(resolve, ms);
signal?.addEventListener("abort", () => {
clearTimeout(timeout);
reject(new Error("Request was aborted"));
});
});
}
interface CloudCodeAssistRequest {
project: string;
model: string;
request: {
contents: Content[];
systemInstruction?: { parts: { text: string }[] };
generationConfig?: {
maxOutputTokens?: number;
temperature?: number;
thinkingConfig?: ThinkingConfig;
};
tools?: ReturnType<typeof convertTools>;
toolConfig?: {
functionCallingConfig: {
mode: ReturnType<typeof mapToolChoice>;
};
};
};
userAgent?: string;
requestId?: string;
}
interface CloudCodeAssistResponseChunk {
response?: {
candidates?: Array<{
content?: {
role: string;
parts?: Array<{
text?: string;
thought?: boolean;
thoughtSignature?: string;
functionCall?: {
name: string;
args: Record<string, unknown>;
id?: string;
};
}>;
};
finishReason?: string;
}>;
usageMetadata?: {
promptTokenCount?: number;
candidatesTokenCount?: number;
thoughtsTokenCount?: number;
totalTokenCount?: number;
cachedContentTokenCount?: number;
};
modelVersion?: string;
responseId?: string;
};
traceId?: string;
}
export const streamGoogleGeminiCli: StreamFunction<"google-gemini-cli"> = (
model: Model<"google-gemini-cli">,
context: Context,
options?: GoogleGeminiCliOptions,
): AssistantMessageEventStream => {
const stream = new AssistantMessageEventStream();
(async () => {
const output: AssistantMessage = {
role: "assistant",
content: [],
api: "google-gemini-cli" as Api,
provider: model.provider,
model: model.id,
usage: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 0,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
},
stopReason: "stop",
timestamp: Date.now(),
};
try {
// apiKey is JSON-encoded: { token, projectId }
const apiKeyRaw = options?.apiKey;
if (!apiKeyRaw) {
throw new Error("Google Cloud Code Assist requires OAuth authentication. Use /login to authenticate.");
}
let accessToken: string;
let projectId: string;
try {
const parsed = JSON.parse(apiKeyRaw) as { token: string; projectId: string };
accessToken = parsed.token;
projectId = parsed.projectId;
} catch {
throw new Error("Invalid Google Cloud Code Assist credentials. Use /login to re-authenticate.");
}
if (!accessToken || !projectId) {
throw new Error("Missing token or projectId in Google Cloud credentials. Use /login to re-authenticate.");
}
const requestBody = buildRequest(model, context, projectId, options);
const endpoint = model.baseUrl || DEFAULT_ENDPOINT;
const url = `${endpoint}/v1internal:streamGenerateContent?alt=sse`;
// Use Antigravity headers for sandbox endpoint, otherwise Gemini CLI headers
const isAntigravity = endpoint.includes("sandbox.googleapis.com");
const headers = isAntigravity ? ANTIGRAVITY_HEADERS : GEMINI_CLI_HEADERS;
// Fetch with retry logic for rate limits and transient errors
let response: Response | undefined;
let lastError: Error | undefined;
for (let attempt = 0; attempt <= MAX_RETRIES; attempt++) {
if (options?.signal?.aborted) {
throw new Error("Request was aborted");
}
try {
response = await fetch(url, {
method: "POST",
headers: {
Authorization: `Bearer ${accessToken}`,
"Content-Type": "application/json",
Accept: "text/event-stream",
...headers,
},
body: JSON.stringify(requestBody),
signal: options?.signal,
});
if (response.ok) {
break; // Success, exit retry loop
}
const errorText = await response.text();
// Check if retryable
if (attempt < MAX_RETRIES && isRetryableError(response.status, errorText)) {
// Use server-provided delay or exponential backoff
const serverDelay = extractRetryDelay(errorText);
const delayMs = serverDelay ?? BASE_DELAY_MS * 2 ** attempt;
await sleep(delayMs, options?.signal);
continue;
}
// Not retryable or max retries exceeded
throw new Error(`Cloud Code Assist API error (${response.status}): ${errorText}`);
} catch (error) {
if (error instanceof Error && error.message === "Request was aborted") {
throw error;
}
lastError = error instanceof Error ? error : new Error(String(error));
// Network errors are retryable
if (attempt < MAX_RETRIES) {
const delayMs = BASE_DELAY_MS * 2 ** attempt;
await sleep(delayMs, options?.signal);
continue;
}
throw lastError;
}
}
if (!response || !response.ok) {
throw lastError ?? new Error("Failed to get response after retries");
}
if (!response.body) {
throw new Error("No response body");
}
stream.push({ type: "start", partial: output });
let currentBlock: TextContent | ThinkingContent | null = null;
const blocks = output.content;
const blockIndex = () => blocks.length - 1;
// Read SSE stream
const reader = response.body.getReader();
const decoder = new TextDecoder();
let buffer = "";
while (true) {
const { done, value } = await reader.read();
if (done) break;
buffer += decoder.decode(value, { stream: true });
const lines = buffer.split("\n");
buffer = lines.pop() || "";
for (const line of lines) {
if (!line.startsWith("data:")) continue;
const jsonStr = line.slice(5).trim();
if (!jsonStr) continue;
let chunk: CloudCodeAssistResponseChunk;
try {
chunk = JSON.parse(jsonStr);
} catch {
continue;
}
// Unwrap the response
const responseData = chunk.response;
if (!responseData) continue;
const candidate = responseData.candidates?.[0];
if (candidate?.content?.parts) {
for (const part of candidate.content.parts) {
if (part.text !== undefined) {
const isThinking = isThinkingPart(part);
if (
!currentBlock ||
(isThinking && currentBlock.type !== "thinking") ||
(!isThinking && currentBlock.type !== "text")
) {
if (currentBlock) {
if (currentBlock.type === "text") {
stream.push({
type: "text_end",
contentIndex: blocks.length - 1,
content: currentBlock.text,
partial: output,
});
} else {
stream.push({
type: "thinking_end",
contentIndex: blockIndex(),
content: currentBlock.thinking,
partial: output,
});
}
}
if (isThinking) {
currentBlock = { type: "thinking", thinking: "", thinkingSignature: undefined };
output.content.push(currentBlock);
stream.push({ type: "thinking_start", contentIndex: blockIndex(), partial: output });
} else {
currentBlock = { type: "text", text: "" };
output.content.push(currentBlock);
stream.push({ type: "text_start", contentIndex: blockIndex(), partial: output });
}
}
if (currentBlock.type === "thinking") {
currentBlock.thinking += part.text;
currentBlock.thinkingSignature = retainThoughtSignature(
currentBlock.thinkingSignature,
part.thoughtSignature,
);
stream.push({
type: "thinking_delta",
contentIndex: blockIndex(),
delta: part.text,
partial: output,
});
} else {
currentBlock.text += part.text;
stream.push({
type: "text_delta",
contentIndex: blockIndex(),
delta: part.text,
partial: output,
});
}
}
if (part.functionCall) {
if (currentBlock) {
if (currentBlock.type === "text") {
stream.push({
type: "text_end",
contentIndex: blockIndex(),
content: currentBlock.text,
partial: output,
});
} else {
stream.push({
type: "thinking_end",
contentIndex: blockIndex(),
content: currentBlock.thinking,
partial: output,
});
}
currentBlock = null;
}
const providedId = part.functionCall.id;
const needsNewId =
!providedId || output.content.some((b) => b.type === "toolCall" && b.id === providedId);
const toolCallId = needsNewId
? `${part.functionCall.name}_${Date.now()}_${++toolCallCounter}`
: providedId;
const toolCall: ToolCall = {
type: "toolCall",
id: toolCallId,
name: part.functionCall.name || "",
arguments: part.functionCall.args as Record<string, unknown>,
...(part.thoughtSignature && { thoughtSignature: part.thoughtSignature }),
};
output.content.push(toolCall);
stream.push({ type: "toolcall_start", contentIndex: blockIndex(), partial: output });
stream.push({
type: "toolcall_delta",
contentIndex: blockIndex(),
delta: JSON.stringify(toolCall.arguments),
partial: output,
});
stream.push({ type: "toolcall_end", contentIndex: blockIndex(), toolCall, partial: output });
}
}
}
if (candidate?.finishReason) {
output.stopReason = mapStopReasonString(candidate.finishReason);
if (output.content.some((b) => b.type === "toolCall")) {
output.stopReason = "toolUse";
}
}
if (responseData.usageMetadata) {
// promptTokenCount includes cachedContentTokenCount, so subtract to get fresh input
const promptTokens = responseData.usageMetadata.promptTokenCount || 0;
const cacheReadTokens = responseData.usageMetadata.cachedContentTokenCount || 0;
output.usage = {
input: promptTokens - cacheReadTokens,
output:
(responseData.usageMetadata.candidatesTokenCount || 0) +
(responseData.usageMetadata.thoughtsTokenCount || 0),
cacheRead: cacheReadTokens,
cacheWrite: 0,
totalTokens: responseData.usageMetadata.totalTokenCount || 0,
cost: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
total: 0,
},
};
calculateCost(model, output.usage);
}
}
}
if (currentBlock) {
if (currentBlock.type === "text") {
stream.push({
type: "text_end",
contentIndex: blockIndex(),
content: currentBlock.text,
partial: output,
});
} else {
stream.push({
type: "thinking_end",
contentIndex: blockIndex(),
content: currentBlock.thinking,
partial: output,
});
}
}
if (options?.signal?.aborted) {
throw new Error("Request was aborted");
}
if (output.stopReason === "aborted" || output.stopReason === "error") {
throw new Error("An unknown error occurred");
}
stream.push({ type: "done", reason: output.stopReason, message: output });
stream.end();
} catch (error) {
for (const block of output.content) {
if ("index" in block) {
delete (block as { index?: number }).index;
}
}
output.stopReason = options?.signal?.aborted ? "aborted" : "error";
output.errorMessage = error instanceof Error ? error.message : JSON.stringify(error);
stream.push({ type: "error", reason: output.stopReason, error: output });
stream.end();
}
})();
return stream;
};
function buildRequest(
model: Model<"google-gemini-cli">,
context: Context,
projectId: string,
options: GoogleGeminiCliOptions = {},
): CloudCodeAssistRequest {
const contents = convertMessages(model, context);
const generationConfig: CloudCodeAssistRequest["request"]["generationConfig"] = {};
if (options.temperature !== undefined) {
generationConfig.temperature = options.temperature;
}
if (options.maxTokens !== undefined) {
generationConfig.maxOutputTokens = options.maxTokens;
}
// Thinking config
if (options.thinking?.enabled && model.reasoning) {
generationConfig.thinkingConfig = {
includeThoughts: true,
};
// Gemini 3 models use thinkingLevel, older models use thinkingBudget
if (options.thinking.level !== undefined) {
// Cast to any since our GoogleThinkingLevel mirrors Google's ThinkingLevel enum values
generationConfig.thinkingConfig.thinkingLevel = options.thinking.level as any;
} else if (options.thinking.budgetTokens !== undefined) {
generationConfig.thinkingConfig.thinkingBudget = options.thinking.budgetTokens;
}
}
const request: CloudCodeAssistRequest["request"] = {
contents,
};
// System instruction must be object with parts, not plain string
if (context.systemPrompt) {
request.systemInstruction = {
parts: [{ text: sanitizeSurrogates(context.systemPrompt) }],
};
}
if (Object.keys(generationConfig).length > 0) {
request.generationConfig = generationConfig;
}
if (context.tools && context.tools.length > 0) {
request.tools = convertTools(context.tools);
if (options.toolChoice) {
request.toolConfig = {
functionCallingConfig: {
mode: mapToolChoice(options.toolChoice),
},
};
}
}
return {
project: projectId,
model: model.id,
request,
userAgent: "pi-coding-agent",
requestId: `pi-${Date.now()}-${Math.random().toString(36).slice(2, 11)}`,
};
}