co-mono/packages/ai/src/providers/google.ts
Mario Zechner ecd240f636 Define own GoogleThinkingLevel type instead of importing from @google/genai
- Add GoogleThinkingLevel type mirroring Google's ThinkingLevel enum
- Update GoogleGeminiCliOptions and GoogleOptions to use our type
- Cast to any when assigning to Google SDK's ThinkingConfig
2025-12-30 22:42:25 +01:00

324 lines
9.6 KiB
TypeScript

import {
type GenerateContentConfig,
type GenerateContentParameters,
GoogleGenAI,
type ThinkingConfig,
} from "@google/genai";
import { calculateCost } from "../models.js";
import { getEnvApiKey } from "../stream.js";
import type {
Api,
AssistantMessage,
Context,
Model,
StreamFunction,
StreamOptions,
TextContent,
ThinkingContent,
ToolCall,
} from "../types.js";
import { AssistantMessageEventStream } from "../utils/event-stream.js";
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
import type { GoogleThinkingLevel } from "./google-gemini-cli.js";
import { convertMessages, convertTools, mapStopReason, mapToolChoice } from "./google-shared.js";
/**
 * Options accepted by the Google Generative AI streaming provider, on top of the common StreamOptions.
 */
export interface GoogleOptions extends StreamOptions {
/** Tool-calling mode. Only sent to the API when the context supplies at least one tool. */
toolChoice?: "auto" | "none" | "any";
/** Thinking settings. Only forwarded when `enabled` is true and the model is flagged as a reasoning model. */
thinking?: {
/** When false (or when `thinking` is omitted) no thinking config is sent with the request. */
enabled: boolean;
/** Thinking token budget. Ignored when `level` is set. */
budgetTokens?: number; // -1 for dynamic, 0 to disable
/** Thinking level. Takes precedence over `budgetTokens` when both are provided. */
level?: GoogleThinkingLevel;
};
}
// Counter for generating unique tool call IDs.
// It lives at module level, so it is shared by every streamGoogle call in this module and is
// incremented each time an ID has to be synthesized (provider ID missing or duplicated).
let toolCallCounter = 0;
/**
 * Streams a completion from the Google Generative AI API and re-emits it as assistant message events.
 *
 * The returned stream is handed back synchronously; the request itself runs in a detached async task.
 * Events are pushed in this order: `start`, then per content block `*_start` / `*_delta` / `*_end`
 * (text, thinking, or tool call), and finally either `done` or `error`. Every event carries the same
 * `output` message object, which is mutated in place as chunks arrive.
 *
 * Failures (including aborts) never reject: they are reported as a single `error` event whose
 * `reason` is "aborted" when the caller's signal is aborted and "error" otherwise.
 *
 * @param model - Model descriptor; its id, provider, baseUrl and headers drive the request.
 * @param context - Conversation context (messages, system prompt, tools) to send.
 * @param options - Optional API key, abort signal, sampling, tool-choice and thinking settings.
 * @returns The event stream that receives the events described above.
 */
export const streamGoogle: StreamFunction<"google-generative-ai"> = (
	model: Model<"google-generative-ai">,
	context: Context,
	options?: GoogleOptions,
): AssistantMessageEventStream => {
	const stream = new AssistantMessageEventStream();

	(async () => {
		const output: AssistantMessage = {
			role: "assistant",
			content: [],
			api: "google-generative-ai" as Api,
			provider: model.provider,
			model: model.id,
			usage: {
				input: 0,
				output: 0,
				cacheRead: 0,
				cacheWrite: 0,
				totalTokens: 0,
				cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
			},
			stopReason: "stop",
			timestamp: Date.now(),
		};

		try {
			const apiKey = options?.apiKey || getEnvApiKey(model.provider) || "";
			const client = createClient(model, apiKey);
			const params = buildParams(model, context, options);
			const googleStream = await client.models.generateContentStream(params);

			stream.push({ type: "start", partial: output });
			let currentBlock: TextContent | ThinkingContent | null = null;
			const blocks = output.content;
			// Index of the most recently pushed content block.
			const blockIndex = () => blocks.length - 1;

			// Emits the matching *_end event for an open text/thinking block.
			// Must be called while that block is still the last entry in `blocks`.
			const endBlock = (block: NonNullable<typeof currentBlock>): void => {
				if (block.type === "text") {
					stream.push({
						type: "text_end",
						contentIndex: blockIndex(),
						content: block.text,
						partial: output,
					});
				} else {
					stream.push({
						type: "thinking_end",
						contentIndex: blockIndex(),
						content: block.thinking,
						partial: output,
					});
				}
			};

			for await (const chunk of googleStream) {
				// Only the first candidate is consumed.
				const candidate = chunk.candidates?.[0];
				if (candidate?.content?.parts) {
					for (const part of candidate.content.parts) {
						if (part.text !== undefined) {
							const isThinking = part.thought === true;
							// Open a new block whenever there is none yet or the part kind
							// (thinking vs. text) differs from the block currently open.
							if (
								!currentBlock ||
								(isThinking && currentBlock.type !== "thinking") ||
								(!isThinking && currentBlock.type !== "text")
							) {
								if (currentBlock) {
									endBlock(currentBlock);
								}
								if (isThinking) {
									currentBlock = { type: "thinking", thinking: "", thinkingSignature: undefined };
									output.content.push(currentBlock);
									stream.push({ type: "thinking_start", contentIndex: blockIndex(), partial: output });
								} else {
									currentBlock = { type: "text", text: "" };
									output.content.push(currentBlock);
									stream.push({ type: "text_start", contentIndex: blockIndex(), partial: output });
								}
							}
							if (currentBlock.type === "thinking") {
								currentBlock.thinking += part.text;
								// Keep the last signature seen: a later delta without a
								// signature must not erase one captured from an earlier part.
								if (part.thoughtSignature) {
									currentBlock.thinkingSignature = part.thoughtSignature;
								}
								stream.push({
									type: "thinking_delta",
									contentIndex: blockIndex(),
									delta: part.text,
									partial: output,
								});
							} else {
								currentBlock.text += part.text;
								stream.push({
									type: "text_delta",
									contentIndex: blockIndex(),
									delta: part.text,
									partial: output,
								});
							}
						}

						if (part.functionCall) {
							// A tool call terminates whatever text/thinking block was open.
							if (currentBlock) {
								endBlock(currentBlock);
								currentBlock = null;
							}

							// Generate unique ID if not provided or if it's a duplicate
							const providedId = part.functionCall.id;
							const needsNewId =
								!providedId || output.content.some((b) => b.type === "toolCall" && b.id === providedId);
							const toolCallId = needsNewId
								? `${part.functionCall.name}_${Date.now()}_${++toolCallCounter}`
								: providedId;

							const toolCall: ToolCall = {
								type: "toolCall",
								id: toolCallId,
								name: part.functionCall.name || "",
								// Default to an empty object so `arguments` is always defined and the
								// serialized delta below is always a string.
								arguments: (part.functionCall.args ?? {}) as Record<string, any>,
								...(part.thoughtSignature && { thoughtSignature: part.thoughtSignature }),
							};

							// The whole call arrives in one part, so start/delta/end are emitted back to back.
							output.content.push(toolCall);
							stream.push({ type: "toolcall_start", contentIndex: blockIndex(), partial: output });
							stream.push({
								type: "toolcall_delta",
								contentIndex: blockIndex(),
								delta: JSON.stringify(toolCall.arguments),
								partial: output,
							});
							stream.push({ type: "toolcall_end", contentIndex: blockIndex(), toolCall, partial: output });
						}
					}
				}

				if (candidate?.finishReason) {
					output.stopReason = mapStopReason(candidate.finishReason);
					// Any tool call in the message overrides the mapped stop reason.
					if (output.content.some((b) => b.type === "toolCall")) {
						output.stopReason = "toolUse";
					}
				}

				if (chunk.usageMetadata) {
					// Each chunk's usage replaces the previous value; thought tokens count as output.
					output.usage = {
						input: chunk.usageMetadata.promptTokenCount || 0,
						output:
							(chunk.usageMetadata.candidatesTokenCount || 0) + (chunk.usageMetadata.thoughtsTokenCount || 0),
						cacheRead: chunk.usageMetadata.cachedContentTokenCount || 0,
						cacheWrite: 0,
						totalTokens: chunk.usageMetadata.totalTokenCount || 0,
						cost: {
							input: 0,
							output: 0,
							cacheRead: 0,
							cacheWrite: 0,
							total: 0,
						},
					};
					calculateCost(model, output.usage);
				}
			}

			// Close the block that was still open when the stream finished.
			if (currentBlock) {
				endBlock(currentBlock);
			}

			if (options?.signal?.aborted) {
				throw new Error("Request was aborted");
			}

			if (output.stopReason === "aborted" || output.stopReason === "error") {
				throw new Error("An unkown error ocurred");
			}

			stream.push({ type: "done", reason: output.stopReason, message: output });
			stream.end();
		} catch (error) {
			// Remove internal index property used during streaming
			for (const block of output.content) {
				if ("index" in block) {
					delete (block as { index?: number }).index;
				}
			}
			output.stopReason = options?.signal?.aborted ? "aborted" : "error";
			output.errorMessage = error instanceof Error ? error.message : JSON.stringify(error);
			stream.push({ type: "error", reason: output.stopReason, error: output });
			stream.end();
		}
	})();

	return stream;
};
/**
 * Builds the GoogleGenAI client used for a single request.
 *
 * HTTP options are only passed to the SDK when the model overrides the base URL and/or
 * supplies custom headers; otherwise `httpOptions` is left undefined.
 *
 * @param model - Model descriptor providing the optional `baseUrl` and `headers` overrides.
 * @param apiKey - API key handed to the SDK.
 * @returns A configured GoogleGenAI instance.
 */
function createClient(model: Model<"google-generative-ai">, apiKey?: string): GoogleGenAI {
	const { baseUrl, headers } = model;

	// No overrides at all: let the SDK use its own defaults.
	if (!baseUrl && !headers) {
		return new GoogleGenAI({ apiKey, httpOptions: undefined });
	}

	const httpOptions: { baseUrl?: string; apiVersion?: string; headers?: Record<string, string> } = {
		// baseUrl already includes version path, so apiVersion is blanked to avoid appending one
		...(baseUrl ? { baseUrl, apiVersion: "" } : {}),
		...(headers ? { headers } : {}),
	};

	return new GoogleGenAI({ apiKey, httpOptions });
}
/**
 * Translates the provider-neutral context and options into `generateContentStream` parameters.
 *
 * - Temperature and max tokens are only set when explicitly provided.
 * - The system prompt is sanitized before being used as the system instruction.
 * - Tools are only attached when the context has at least one; a tool-choice mode is only
 *   attached when there are tools AND a `toolChoice` option.
 * - Thinking config is only attached when thinking is enabled and the model is a reasoning model;
 *   `level` wins over `budgetTokens` when both are given.
 *
 * @param model - Target model; `id` becomes the request model and `reasoning` gates thinking.
 * @param context - Messages, optional system prompt and optional tools.
 * @param options - Per-request options; defaults to an empty object.
 * @returns The parameters object to pass to the SDK.
 * @throws Error with message "Request aborted" when the supplied signal is already aborted.
 */
function buildParams(
	model: Model<"google-generative-ai">,
	context: Context,
	options: GoogleOptions = {},
): GenerateContentParameters {
	const contents = convertMessages(model, context);

	const config: GenerateContentConfig = {};

	// Sampling settings
	if (options.temperature !== undefined) config.temperature = options.temperature;
	if (options.maxTokens !== undefined) config.maxOutputTokens = options.maxTokens;

	// System prompt
	if (context.systemPrompt) {
		config.systemInstruction = sanitizeSurrogates(context.systemPrompt);
	}

	// Tools and tool-choice mode (an empty tool list is treated like no tools)
	const toolList = context.tools && context.tools.length > 0 ? context.tools : undefined;
	if (toolList) {
		config.tools = convertTools(toolList);
	}
	config.toolConfig =
		toolList && options.toolChoice
			? { functionCallingConfig: { mode: mapToolChoice(options.toolChoice) } }
			: undefined;

	// Thinking
	const thinking = options.thinking;
	if (thinking?.enabled && model.reasoning) {
		const thinkingConfig: ThinkingConfig = { includeThoughts: true };
		if (thinking.level !== undefined) {
			// Cast to any since our GoogleThinkingLevel mirrors Google's ThinkingLevel enum values
			thinkingConfig.thinkingLevel = thinking.level as any;
		} else if (thinking.budgetTokens !== undefined) {
			thinkingConfig.thinkingBudget = thinking.budgetTokens;
		}
		config.thinkingConfig = thinkingConfig;
	}

	// Cancellation: refuse to build a request for a signal that has already fired
	const signal = options.signal;
	if (signal) {
		if (signal.aborted) {
			throw new Error("Request aborted");
		}
		config.abortSignal = signal;
	}

	return { model: model.id, contents, config };
}