import {
  type Content,
  FinishReason,
  FunctionCallingConfigMode,
  type GenerateContentConfig,
  type GenerateContentParameters,
  GoogleGenAI,
  type Part,
} from "@google/genai";
import { calculateCost } from "../models.js";
import type {
  Api,
  AssistantMessage,
  Context,
  Model,
  StopReason,
  StreamFunction,
  StreamOptions,
  TextContent,
  ThinkingContent,
  Tool,
  ToolCall,
} from "../types.js";
import { AssistantMessageEventStream } from "../utils/event-stream.js";
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
import { validateToolArguments } from "../utils/validation.js";
import { transformMessages } from "./transorm-messages.js";

/** Options accepted by {@link streamGoogle} on top of the shared stream options. */
export interface GoogleOptions extends StreamOptions {
  /** Function-calling mode. Only applied when the context supplies at least one tool. */
  toolChoice?: "auto" | "none" | "any";
  /** Thinking settings. Only applied when `enabled` is true and `model.reasoning` is truthy. */
  thinking?: {
    enabled: boolean;
    budgetTokens?: number; // -1 for dynamic, 0 to disable
  };
}

// Counter for generating unique tool call IDs
let toolCallCounter = 0;

/**
 * Streams a response from the Google Generative AI API and re-emits it as
 * assistant message events.
 *
 * The returned stream is handed back synchronously; the request runs in a
 * detached async task that pushes events onto it. The sequence is `start`,
 * then `text_*` / `thinking_*` / `toolcall_*` events as parts arrive, and
 * finally either `done` or `error`. Failures are never thrown to the caller:
 * they are reported as an `error` event carrying the partial message, with
 * `stopReason` set to `"aborted"` when `options.signal` is aborted and
 * `"error"` otherwise.
 *
 * @param model - Model descriptor; `id`, `provider`, `headers`, `input` and `reasoning` are read.
 * @param context - System prompt, message history and optional tool definitions.
 * @param options - Optional API key, abort signal, sampling, tool-choice and thinking settings.
 * @returns The event stream that receives the assistant message events.
 */
export const streamGoogle: StreamFunction<"google-generative-ai"> = (
  model: Model<"google-generative-ai">,
  context: Context,
  options?: GoogleOptions,
): AssistantMessageEventStream => {
  const stream = new AssistantMessageEventStream();

  (async () => {
    const output: AssistantMessage = {
      role: "assistant",
      content: [],
      api: "google-generative-ai" as Api,
      provider: model.provider,
      model: model.id,
      usage: {
        input: 0,
        output: 0,
        cacheRead: 0,
        cacheWrite: 0,
        totalTokens: 0,
        cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
      },
      stopReason: "stop",
      timestamp: Date.now(),
    };

    try {
      const client = createClient(model, options?.apiKey);
      const params = buildParams(model, context, options);
      const googleStream = await client.models.generateContentStream(params);

      stream.push({ type: "start", partial: output });

      // The text/thinking block currently receiving deltas, if any.
      let currentBlock: TextContent | ThinkingContent | null = null;
      const blocks = output.content;
      const blockIndex = () => blocks.length - 1;

      // Emits the matching *_end event for an open text/thinking block. The open
      // block is always the last entry of `blocks`, so its index is blockIndex().
      const endBlock = (block: TextContent | ThinkingContent): void => {
        if (block.type === "text") {
          stream.push({
            type: "text_end",
            contentIndex: blockIndex(),
            content: block.text,
            partial: output,
          });
        } else {
          stream.push({
            type: "thinking_end",
            contentIndex: blockIndex(),
            content: block.thinking,
            partial: output,
          });
        }
      };

      for await (const chunk of googleStream) {
        const candidate = chunk.candidates?.[0];

        if (candidate?.content?.parts) {
          for (const part of candidate.content.parts) {
            if (part.text !== undefined) {
              const isThinking = part.thought === true;
              const wantedType = isThinking ? "thinking" : "text";

              // Open a new block when none is open or the part kind has switched.
              if (!currentBlock || currentBlock.type !== wantedType) {
                if (currentBlock) {
                  endBlock(currentBlock);
                }
                if (isThinking) {
                  currentBlock = { type: "thinking", thinking: "", thinkingSignature: undefined };
                  output.content.push(currentBlock);
                  stream.push({ type: "thinking_start", contentIndex: blockIndex(), partial: output });
                } else {
                  currentBlock = { type: "text", text: "" };
                  output.content.push(currentBlock);
                  stream.push({ type: "text_start", contentIndex: blockIndex(), partial: output });
                }
              }

              if (currentBlock.type === "thinking") {
                currentBlock.thinking += part.text;
                // Keep the latest signature seen; a delta that carries none must
                // not erase one captured from an earlier delta.
                if (part.thoughtSignature) {
                  currentBlock.thinkingSignature = part.thoughtSignature;
                }
                stream.push({
                  type: "thinking_delta",
                  contentIndex: blockIndex(),
                  delta: part.text,
                  partial: output,
                });
              } else {
                currentBlock.text += part.text;
                stream.push({
                  type: "text_delta",
                  contentIndex: blockIndex(),
                  delta: part.text,
                  partial: output,
                });
              }
            }

            if (part.functionCall) {
              // A tool call terminates whatever text/thinking block was open.
              if (currentBlock) {
                endBlock(currentBlock);
                currentBlock = null;
              }

              // Generate unique ID if not provided or if it's a duplicate
              const providedId = part.functionCall.id;
              const needsNewId =
                !providedId || output.content.some((b) => b.type === "toolCall" && b.id === providedId);
              const toolCallId = needsNewId
                ? `${part.functionCall.name}_${Date.now()}_${++toolCallCounter}`
                : providedId;

              const toolCall: ToolCall = {
                type: "toolCall",
                id: toolCallId,
                name: part.functionCall.name || "",
                // Default to an empty object so a call without args still yields
                // a well-formed arguments record and a string delta below.
                arguments: (part.functionCall.args ?? {}) as Record<string, any>,
                ...(part.thoughtSignature && { thoughtSignature: part.thoughtSignature }),
              };

              // Validate tool arguments if tool definition is available
              if (context.tools) {
                const tool = context.tools.find((t) => t.name === toolCall.name);
                if (tool) {
                  toolCall.arguments = validateToolArguments(tool, toolCall);
                }
              }

              output.content.push(toolCall);
              // The call arrives whole, so start/delta/end are emitted back to back.
              stream.push({ type: "toolcall_start", contentIndex: blockIndex(), partial: output });
              stream.push({
                type: "toolcall_delta",
                contentIndex: blockIndex(),
                delta: JSON.stringify(toolCall.arguments),
                partial: output,
              });
              stream.push({ type: "toolcall_end", contentIndex: blockIndex(), toolCall, partial: output });
            }
          }
        }

        if (candidate?.finishReason) {
          output.stopReason = mapStopReason(candidate.finishReason);
          // Any tool call in the message overrides the mapped reason.
          if (output.content.some((b) => b.type === "toolCall")) {
            output.stopReason = "toolUse";
          }
        }

        if (chunk.usageMetadata) {
          // Each usage report replaces the previous one rather than accumulating.
          output.usage = {
            input: chunk.usageMetadata.promptTokenCount ?? 0,
            output:
              (chunk.usageMetadata.candidatesTokenCount ?? 0) + (chunk.usageMetadata.thoughtsTokenCount ?? 0),
            cacheRead: chunk.usageMetadata.cachedContentTokenCount ?? 0,
            cacheWrite: 0,
            totalTokens: chunk.usageMetadata.totalTokenCount ?? 0,
            cost: {
              input: 0,
              output: 0,
              cacheRead: 0,
              cacheWrite: 0,
              total: 0,
            },
          };
          calculateCost(model, output.usage);
        }
      }

      // Close a block that was still open when the stream finished.
      if (currentBlock) {
        endBlock(currentBlock);
      }

      if (options?.signal?.aborted) {
        throw new Error("Request was aborted");
      }

      if (output.stopReason === "aborted" || output.stopReason === "error") {
        throw new Error("An unkown error ocurred");
      }

      stream.push({ type: "done", reason: output.stopReason, message: output });
      stream.end();
    } catch (error) {
      for (const block of output.content) delete (block as any).index;
      output.stopReason = options?.signal?.aborted ? "aborted" : "error";
      output.errorMessage = error instanceof Error ? error.message : JSON.stringify(error);
      stream.push({ type: "error", reason: output.stopReason, error: output });
      stream.end();
    }
  })();

  return stream;
};

/**
 * Creates a Google GenAI client for the given model.
 *
 * @param model - Model descriptor; its `headers`, when set, are passed as HTTP headers.
 * @param apiKey - Explicit API key. Falls back to the `GEMINI_API_KEY` environment variable.
 * @returns A configured `GoogleGenAI` client.
 * @throws Error when neither `apiKey` nor `GEMINI_API_KEY` is available.
 */
function createClient(model: Model<"google-generative-ai">, apiKey?: string): GoogleGenAI {
  if (!apiKey) {
    if (!process.env.GEMINI_API_KEY) {
      throw new Error(
        "Gemini API key is required. Set GEMINI_API_KEY environment variable or pass it as an argument.",
      );
    }
    apiKey = process.env.GEMINI_API_KEY;
  }

  return new GoogleGenAI({
    apiKey,
    httpOptions: model.headers ? { headers: model.headers } : undefined,
  });
}

/**
 * Builds the `generateContentStream` request from the context and options.
 *
 * @param model - Target model; `id` selects the model and `reasoning` gates the thinking config.
 * @param context - Supplies the messages, the optional system prompt and the optional tools.
 * @param options - Temperature, max tokens, tool choice, thinking settings and abort signal.
 * @returns The request parameters to pass to the client.
 * @throws Error when `options.signal` is already aborted.
 */
function buildParams(
  model: Model<"google-generative-ai">,
  context: Context,
  options: GoogleOptions = {},
): GenerateContentParameters {
  const contents = convertMessages(model, context);

  const generationConfig: GenerateContentConfig = {};
  if (options.temperature !== undefined) {
    generationConfig.temperature = options.temperature;
  }
  if (options.maxTokens !== undefined) {
    generationConfig.maxOutputTokens = options.maxTokens;
  }

  const config: GenerateContentConfig = {
    ...(Object.keys(generationConfig).length > 0 && generationConfig),
    ...(context.systemPrompt && { systemInstruction: sanitizeSurrogates(context.systemPrompt) }),
    ...(context.tools && context.tools.length > 0 && { tools: convertTools(context.tools) }),
  };

  // A tool choice is only meaningful when tools are actually being sent.
  if (context.tools && context.tools.length > 0 && options.toolChoice) {
    config.toolConfig = {
      functionCallingConfig: {
        mode: mapToolChoice(options.toolChoice),
      },
    };
  } else {
    config.toolConfig = undefined;
  }

  if (options.thinking?.enabled && model.reasoning) {
    config.thinkingConfig = {
      includeThoughts: true,
      ...(options.thinking.budgetTokens !== undefined && { thinkingBudget: options.thinking.budgetTokens }),
    };
  }

  if (options.signal) {
    // Fail fast instead of issuing a request that is already cancelled.
    if (options.signal.aborted) {
      throw new Error("Request aborted");
    }
    config.abortSignal = options.signal;
  }

  const params: GenerateContentParameters = {
    model: model.id,
    contents,
    config,
  };

  return params;
}

/**
 * Converts the context's messages into Google `Content` entries.
 *
 * Messages are first passed through `transformMessages`. User messages become
 * `user` content, assistant messages become `model` content, and tool results
 * become `user` content holding a `functionResponse` part plus any images.
 * Image parts are dropped when `model.input` does not include `"image"`, and
 * user/assistant messages left with no parts are skipped.
 *
 * @param model - Target model; `input` decides whether image parts are kept.
 * @param context - Source of the messages to convert.
 * @returns The converted contents, in message order.
 */
function convertMessages(model: Model<"google-generative-ai">, context: Context): Content[] {
  const contents: Content[] = [];
  const transformedMessages = transformMessages(context.messages, model);

  for (const msg of transformedMessages) {
    if (msg.role === "user") {
      if (typeof msg.content === "string") {
        contents.push({
          role: "user",
          parts: [{ text: sanitizeSurrogates(msg.content) }],
        });
      } else {
        const parts: Part[] = msg.content.map((item) => {
          if (item.type === "text") {
            return { text: sanitizeSurrogates(item.text) };
          } else {
            return {
              inlineData: {
                mimeType: item.mimeType,
                data: item.data,
              },
            };
          }
        });
        // Keep only text parts when the model does not accept image input.
        const filteredParts = !model.input.includes("image") ? parts.filter((p) => p.text !== undefined) : parts;
        if (filteredParts.length === 0) continue;
        contents.push({
          role: "user",
          parts: filteredParts,
        });
      }
    } else if (msg.role === "assistant") {
      const parts: Part[] = [];

      for (const block of msg.content) {
        if (block.type === "text") {
          parts.push({ text: sanitizeSurrogates(block.text) });
        } else if (block.type === "thinking") {
          const thinkingPart: Part = {
            thought: true,
            thoughtSignature: block.thinkingSignature,
            text: sanitizeSurrogates(block.thinking),
          };
          parts.push(thinkingPart);
        } else if (block.type === "toolCall") {
          const part: Part = {
            functionCall: {
              id: block.id,
              name: block.name,
              args: block.arguments,
            },
          };
          if (block.thoughtSignature) {
            part.thoughtSignature = block.thoughtSignature;
          }
          parts.push(part);
        }
      }

      if (parts.length === 0) continue;
      contents.push({
        role: "model",
        parts,
      });
    } else if (msg.role === "toolResult") {
      // Build parts array with functionResponse and/or images
      const parts: Part[] = [];

      // Extract text and image content
      const textResult = msg.content
        .filter((c) => c.type === "text")
        .map((c) => (c as any).text)
        .join("\n");
      const imageBlocks = model.input.includes("image") ? msg.content.filter((c) => c.type === "image") : [];

      // Always add functionResponse with text result (or placeholder if only images)
      const hasText = textResult.length > 0;
      const hasImages = imageBlocks.length > 0;

      parts.push({
        functionResponse: {
          id: msg.toolCallId,
          name: msg.toolName,
          response: {
            result: hasText ? sanitizeSurrogates(textResult) : hasImages ? "(see attached image)" : "",
            isError: msg.isError,
          },
        },
      });

      // Add any images as inlineData parts
      for (const imageBlock of imageBlocks) {
        parts.push({
          inlineData: {
            mimeType: (imageBlock as any).mimeType,
            data: (imageBlock as any).data,
          },
        });
      }

      contents.push({
        role: "user",
        parts,
      });
    }
  }

  return contents;
}

/**
 * Wraps tool definitions in a single `functionDeclarations` entry.
 *
 * @param tools - Tool definitions to expose to the model.
 * @returns A one-element tools array, or `undefined` when `tools` is empty.
 */
function convertTools(tools: Tool[]): any[] | undefined {
  if (tools.length === 0) return undefined;
  return [
    {
      functionDeclarations: tools.map((tool) => ({
        name: tool.name,
        description: tool.description,
        parameters: tool.parameters as any, // TypeBox already generates JSON Schema
      })),
    },
  ];
}

/**
 * Maps a tool-choice string to the SDK's function-calling mode.
 *
 * @param choice - `"auto"`, `"none"` or `"any"`.
 * @returns The matching mode; unrecognized values map to `AUTO`.
 */
function mapToolChoice(choice: string): FunctionCallingConfigMode {
  switch (choice) {
    case "auto":
      return FunctionCallingConfigMode.AUTO;
    case "none":
      return FunctionCallingConfigMode.NONE;
    case "any":
      return FunctionCallingConfigMode.ANY;
    default:
      return FunctionCallingConfigMode.AUTO;
  }
}

/**
 * Maps a Google finish reason to this library's stop reason.
 *
 * `STOP` maps to `"stop"` and `MAX_TOKENS` to `"length"`; every other listed
 * reason maps to `"error"`.
 *
 * @param reason - Finish reason reported on a candidate.
 * @returns The corresponding stop reason.
 * @throws Error when `reason` is a value not covered by any case.
 */
function mapStopReason(reason: FinishReason): StopReason {
  switch (reason) {
    case FinishReason.STOP:
      return "stop";
    case FinishReason.MAX_TOKENS:
      return "length";
    case FinishReason.BLOCKLIST:
    case FinishReason.PROHIBITED_CONTENT:
    case FinishReason.SPII:
    case FinishReason.SAFETY:
    case FinishReason.IMAGE_SAFETY:
    case FinishReason.IMAGE_PROHIBITED_CONTENT:
    case FinishReason.RECITATION:
    case FinishReason.FINISH_REASON_UNSPECIFIED:
    case FinishReason.OTHER:
    case FinishReason.LANGUAGE:
    case FinishReason.MALFORMED_FUNCTION_CALL:
    case FinishReason.UNEXPECTED_TOOL_CALL:
    case FinishReason.NO_IMAGE:
      return "error";
    default: {
      // Compile-time exhaustiveness check over FinishReason.
      const _exhaustive: never = reason;
      throw new Error(`Unhandled stop reason: ${_exhaustive}`);
    }
  }
}