mirror of
https://github.com/getcompanion-ai/co-mono.git
synced 2026-04-16 15:02:32 +00:00
505 lines
14 KiB
TypeScript
505 lines
14 KiB
TypeScript
import {
|
|
type Content,
|
|
FinishReason,
|
|
FunctionCallingConfigMode,
|
|
type GenerateContentConfig,
|
|
type GenerateContentParameters,
|
|
GoogleGenAI,
|
|
type Part,
|
|
type Schema,
|
|
type ThinkingConfig,
|
|
type ThinkingLevel,
|
|
} from "@google/genai";
|
|
import { calculateCost } from "../models.js";
|
|
import type {
|
|
Api,
|
|
AssistantMessage,
|
|
Context,
|
|
ImageContent,
|
|
Model,
|
|
StopReason,
|
|
StreamFunction,
|
|
StreamOptions,
|
|
TextContent,
|
|
ThinkingContent,
|
|
Tool,
|
|
ToolCall,
|
|
} from "../types.js";
|
|
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
|
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
|
|
|
import { transformMessages } from "./transorm-messages.js";
|
|
|
|
/**
 * Provider-specific options for the Google Generative AI (Gemini) stream, layered on top of
 * the shared {@link StreamOptions}.
 */
export interface GoogleOptions extends StreamOptions {
	/**
	 * Thinking (reasoning) settings. Only sent when `enabled` is true AND the model is marked as
	 * supporting reasoning. If both `level` and `budgetTokens` are given, `level` wins.
	 */
	thinking?: {
		/** Request thoughts from the model (sent as `includeThoughts: true`). */
		enabled: boolean;
		/** Named thinking level, forwarded as `thinkingLevel`. */
		level?: ThinkingLevel;
		/** Thinking token budget, forwarded as `thinkingBudget`: -1 for dynamic, 0 to disable. */
		budgetTokens?: number;
	};
	/**
	 * Function-calling mode: "auto" → AUTO, "none" → NONE, "any" → ANY.
	 * Ignored unless the context defines at least one tool.
	 */
	toolChoice?: "auto" | "none" | "any";
}
|
|
|
|
// Module-level counter appended to synthesized tool call IDs (`<name>_<Date.now()>_<n>`) so
// that IDs generated within the same millisecond stay unique. It is never reset and is shared
// by every stream created in this module.
let toolCallCounter = 0;
|
|
|
|
/**
 * Streams an assistant response from the Google Generative AI (Gemini) API.
 *
 * The returned stream is filled by a detached async task. It pushes, in order:
 * - `start`;
 * - for each run of text or thinking parts, a `*_start`, one `*_delta` per streamed part, and a `*_end`;
 * - for each function call, a `toolcall_start` / `toolcall_delta` / `toolcall_end` triple;
 * - finally `done`, or `error` on failure.
 *
 * Failures are never thrown to the caller. The stream ends with an `error` event whose reason
 * is "aborted" when `options.signal` was aborted, and "error" otherwise.
 *
 * @param model - Model descriptor; `id` selects the Gemini model, `baseUrl`/`headers` configure the client.
 * @param context - System prompt, message history and tool definitions.
 * @param options - API key, sampling limits, thinking / tool-choice settings and abort signal.
 * @returns The event stream, returned synchronously before the request completes.
 */
export const streamGoogle: StreamFunction<"google-generative-ai"> = (
	model: Model<"google-generative-ai">,
	context: Context,
	options?: GoogleOptions,
): AssistantMessageEventStream => {
	const stream = new AssistantMessageEventStream();

	(async () => {
		const output: AssistantMessage = {
			role: "assistant",
			content: [],
			api: "google-generative-ai" as Api,
			provider: model.provider,
			model: model.id,
			usage: {
				input: 0,
				output: 0,
				cacheRead: 0,
				cacheWrite: 0,
				totalTokens: 0,
				cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
			},
			stopReason: "stop",
			timestamp: Date.now(),
		};

		try {
			const client = createClient(model, options?.apiKey);
			const params = buildParams(model, context, options);
			const googleStream = await client.models.generateContentStream(params);

			stream.push({ type: "start", partial: output });

			const blocks = output.content;
			const blockIndex = () => blocks.length - 1;
			// The text or thinking block currently receiving deltas; null between blocks.
			let currentBlock: TextContent | ThinkingContent | null = null;
			// Last raw finish reason reported by the API, kept so a failure message can name it.
			let lastFinishReason: string | undefined;

			// Emits the matching `*_end` event for the open text/thinking block (if any) and closes it.
			const closeCurrentBlock = () => {
				if (!currentBlock) return;
				if (currentBlock.type === "text") {
					stream.push({
						type: "text_end",
						contentIndex: blockIndex(),
						content: currentBlock.text,
						partial: output,
					});
				} else {
					stream.push({
						type: "thinking_end",
						contentIndex: blockIndex(),
						content: currentBlock.thinking,
						partial: output,
					});
				}
				currentBlock = null;
			};

			for await (const chunk of googleStream) {
				const candidate = chunk.candidates?.[0];

				for (const part of candidate?.content?.parts ?? []) {
					if (part.text !== undefined) {
						const isThinking = part.thought === true;
						const wantedType = isThinking ? "thinking" : "text";

						// A switch between thought and visible text starts a new content block.
						if (!currentBlock || currentBlock.type !== wantedType) {
							closeCurrentBlock();
							if (isThinking) {
								currentBlock = { type: "thinking", thinking: "", thinkingSignature: undefined };
								blocks.push(currentBlock);
								stream.push({ type: "thinking_start", contentIndex: blockIndex(), partial: output });
							} else {
								currentBlock = { type: "text", text: "" };
								blocks.push(currentBlock);
								stream.push({ type: "text_start", contentIndex: blockIndex(), partial: output });
							}
						}

						if (currentBlock.type === "thinking") {
							currentBlock.thinking += part.text;
							// `thoughtSignature` is optional per part. Only replace the stored signature
							// when a new one arrives, so a later unsigned part does not erase it.
							if (part.thoughtSignature) {
								currentBlock.thinkingSignature = part.thoughtSignature;
							}
							stream.push({
								type: "thinking_delta",
								contentIndex: blockIndex(),
								delta: part.text,
								partial: output,
							});
						} else {
							currentBlock.text += part.text;
							stream.push({
								type: "text_delta",
								contentIndex: blockIndex(),
								delta: part.text,
								partial: output,
							});
						}
					}

					if (part.functionCall) {
						closeCurrentBlock();

						// Synthesize a unique ID when the API omits one or repeats an ID already in this message.
						const providedId = part.functionCall.id;
						const needsNewId =
							!providedId || blocks.some((b) => b.type === "toolCall" && b.id === providedId);
						const toolCallId = needsNewId
							? `${part.functionCall.name}_${Date.now()}_${++toolCallCounter}`
							: providedId;

						const toolCall: ToolCall = {
							type: "toolCall",
							id: toolCallId,
							name: part.functionCall.name || "",
							// `args` is optional in the SDK. Default to an empty object so consumers
							// and the JSON delta below always see an object.
							arguments: (part.functionCall.args ?? {}) as Record<string, any>,
							...(part.thoughtSignature && { thoughtSignature: part.thoughtSignature }),
						};

						blocks.push(toolCall);
						stream.push({ type: "toolcall_start", contentIndex: blockIndex(), partial: output });
						stream.push({
							type: "toolcall_delta",
							contentIndex: blockIndex(),
							delta: JSON.stringify(toolCall.arguments),
							partial: output,
						});
						stream.push({ type: "toolcall_end", contentIndex: blockIndex(), toolCall, partial: output });
					}
				}

				if (candidate?.finishReason) {
					lastFinishReason = candidate.finishReason;
					output.stopReason = mapStopReason(candidate.finishReason);
					// Any emitted tool call means the caller must run tools, whatever the raw reason.
					if (blocks.some((b) => b.type === "toolCall")) {
						output.stopReason = "toolUse";
					}
				}

				if (chunk.usageMetadata) {
					output.usage = {
						input: chunk.usageMetadata.promptTokenCount || 0,
						// Thought tokens are billed as output, so they are included here.
						output:
							(chunk.usageMetadata.candidatesTokenCount || 0) + (chunk.usageMetadata.thoughtsTokenCount || 0),
						cacheRead: chunk.usageMetadata.cachedContentTokenCount || 0,
						cacheWrite: 0,
						totalTokens: chunk.usageMetadata.totalTokenCount || 0,
						cost: {
							input: 0,
							output: 0,
							cacheRead: 0,
							cacheWrite: 0,
							total: 0,
						},
					};
					calculateCost(model, output.usage);
				}
			}

			closeCurrentBlock();

			if (options?.signal?.aborted) {
				throw new Error("Request was aborted");
			}

			if (output.stopReason === "aborted" || output.stopReason === "error") {
				throw new Error(
					lastFinishReason
						? `Generation stopped with finish reason: ${lastFinishReason}`
						: "An unknown error occurred",
				);
			}

			stream.push({ type: "done", reason: output.stopReason, message: output });
			stream.end();
		} catch (error) {
			output.stopReason = options?.signal?.aborted ? "aborted" : "error";
			output.errorMessage = error instanceof Error ? error.message : JSON.stringify(error);
			stream.push({ type: "error", reason: output.stopReason, error: output });
			stream.end();
		}
	})();

	return stream;
};
|
|
|
|
/**
 * Builds a GoogleGenAI client for the given model.
 *
 * The API key comes from the explicit `apiKey` argument, falling back to the
 * `GEMINI_API_KEY` environment variable. When the model declares a `baseUrl`, the SDK's
 * `apiVersion` is blanked because the base URL is expected to carry the version path already.
 * Model `headers` are forwarded as HTTP headers.
 *
 * @throws Error when neither `apiKey` nor `GEMINI_API_KEY` yields a non-empty key.
 */
function createClient(model: Model<"google-generative-ai">, apiKey?: string): GoogleGenAI {
	const resolvedKey = apiKey || process.env.GEMINI_API_KEY;
	if (!resolvedKey) {
		throw new Error(
			"Gemini API key is required. Set GEMINI_API_KEY environment variable or pass it as an argument.",
		);
	}

	const httpOptions: { baseUrl?: string; apiVersion?: string; headers?: Record<string, string> } = {
		// baseUrl already includes the version path, so the SDK must not append one.
		...(model.baseUrl ? { baseUrl: model.baseUrl, apiVersion: "" } : {}),
		...(model.headers ? { headers: model.headers } : {}),
	};
	const hasHttpOptions = Object.keys(httpOptions).length > 0;

	return new GoogleGenAI({
		apiKey: resolvedKey,
		httpOptions: hasHttpOptions ? httpOptions : undefined,
	});
}
|
|
|
|
/**
 * Assembles the `generateContentStream` request for a model/context pair.
 *
 * - `temperature` → `temperature`; `maxTokens` → `maxOutputTokens`.
 * - A non-empty system prompt becomes the surrogate-sanitized `systemInstruction`.
 * - `tools` and a `toolConfig` are only populated when the context has at least one tool.
 *   `toolConfig` additionally requires `toolChoice`, and the key is always present (possibly undefined).
 * - `thinkingConfig` is only set when thinking is enabled and `model.reasoning` is true;
 *   `level` takes precedence over `budgetTokens`.
 *
 * @throws Error("Request aborted") if `options.signal` is already aborted.
 */
function buildParams(
	model: Model<"google-generative-ai">,
	context: Context,
	options: GoogleOptions = {},
): GenerateContentParameters {
	const contents = convertMessages(model, context);
	const { temperature, maxTokens, toolChoice, thinking, signal } = options;
	const tools = context.tools;

	const config: GenerateContentConfig = {};
	if (temperature !== undefined) config.temperature = temperature;
	if (maxTokens !== undefined) config.maxOutputTokens = maxTokens;
	if (context.systemPrompt) config.systemInstruction = sanitizeSurrogates(context.systemPrompt);
	if (tools && tools.length > 0) config.tools = convertTools(tools);

	config.toolConfig =
		tools && tools.length > 0 && toolChoice
			? { functionCallingConfig: { mode: mapToolChoice(toolChoice) } }
			: undefined;

	if (thinking?.enabled && model.reasoning) {
		const thinkingConfig: ThinkingConfig = { includeThoughts: true };
		if (thinking.level !== undefined) {
			thinkingConfig.thinkingLevel = thinking.level;
		} else if (thinking.budgetTokens !== undefined) {
			thinkingConfig.thinkingBudget = thinking.budgetTokens;
		}
		config.thinkingConfig = thinkingConfig;
	}

	if (signal) {
		// Fail fast instead of opening a request that is already cancelled.
		if (signal.aborted) throw new Error("Request aborted");
		config.abortSignal = signal;
	}

	return { model: model.id, contents, config };
}
|
|
/**
 * Converts the conversation in `context` into the `Content[]` list sent to Gemini.
 *
 * Messages first pass through `transformMessages`. Then:
 * - user → role "user". Text is surrogate-sanitized; images become `inlineData` parts and are
 *   dropped when the model does not accept "image" input. A message left with no parts is skipped.
 * - assistant → role "model" with text, thinking (`thought: true` + signature) and tool-call
 *   (`functionCall` + optional `thoughtSignature`) parts. A message with no parts is skipped.
 * - toolResult → role "user" with a `functionResponse` (`output` on success, `error` on
 *   failure), followed by any images as `inlineData` when the model accepts images.
 *   Consecutive tool results are merged into a single user turn. Gemini's function-calling
 *   guide returns all responses to one turn's parallel function calls together in one
 *   turn, rather than one turn per response.
 */
function convertMessages(model: Model<"google-generative-ai">, context: Context): Content[] {
	const contents: Content[] = [];
	const supportsImages = model.input.includes("image");
	const transformedMessages = transformMessages(context.messages, model);

	for (const msg of transformedMessages) {
		if (msg.role === "user") {
			if (typeof msg.content === "string") {
				contents.push({
					role: "user",
					parts: [{ text: sanitizeSurrogates(msg.content) }],
				});
				continue;
			}

			const parts: Part[] = msg.content.map((item) => {
				if (item.type === "text") {
					return { text: sanitizeSurrogates(item.text) };
				}
				return {
					inlineData: {
						mimeType: item.mimeType,
						data: item.data,
					},
				};
			});
			// Text-only models receive just the text parts.
			const filteredParts = supportsImages ? parts : parts.filter((p) => p.text !== undefined);
			if (filteredParts.length === 0) continue;
			contents.push({ role: "user", parts: filteredParts });
		} else if (msg.role === "assistant") {
			const parts: Part[] = [];

			for (const block of msg.content) {
				if (block.type === "text") {
					parts.push({ text: sanitizeSurrogates(block.text) });
				} else if (block.type === "thinking") {
					parts.push({
						thought: true,
						thoughtSignature: block.thinkingSignature,
						text: sanitizeSurrogates(block.thinking),
					});
				} else if (block.type === "toolCall") {
					const part: Part = {
						functionCall: {
							id: block.id,
							name: block.name,
							args: block.arguments,
						},
					};
					if (block.thoughtSignature) {
						part.thoughtSignature = block.thoughtSignature;
					}
					parts.push(part);
				}
			}

			if (parts.length === 0) continue;
			contents.push({ role: "model", parts });
		} else if (msg.role === "toolResult") {
			const textResult = msg.content
				.filter((c): c is TextContent => c.type === "text")
				.map((c) => c.text)
				.join("\n");
			const imageContent = supportsImages
				? msg.content.filter((c): c is ImageContent => c.type === "image")
				: [];

			// A functionResponse is always sent; use a placeholder when the result is images only.
			const hasText = textResult.length > 0;
			const hasImages = imageContent.length > 0;
			const responseValue = hasText ? sanitizeSurrogates(textResult) : hasImages ? "(see attached image)" : "";

			const parts: Part[] = [
				{
					functionResponse: {
						id: msg.toolCallId,
						name: msg.toolName,
						// "output" key for success, "error" key for failures, as per SDK documentation.
						response: msg.isError ? { error: responseValue } : { output: responseValue },
					},
				},
				...imageContent.map(
					(imageBlock): Part => ({
						inlineData: {
							mimeType: imageBlock.mimeType,
							data: imageBlock.data,
						},
					}),
				),
			];

			// Append to the previous turn when it already carries tool responses. User messages
			// never produce functionResponse parts, so such a turn always came from tool results.
			const previous = contents[contents.length - 1];
			const previousToolTurnParts =
				previous?.role === "user" && previous.parts?.some((p) => p.functionResponse !== undefined)
					? previous.parts
					: undefined;
			if (previousToolTurnParts) {
				previousToolTurnParts.push(...parts);
			} else {
				contents.push({ role: "user", parts });
			}
		}
	}

	return contents;
}
|
|
|
|
/**
 * Wraps tool definitions into the single Gemini `Tool` entry that lists them as function
 * declarations. Returns undefined for an empty list so the caller can omit the `tools` field.
 */
function convertTools(
	tools: Tool[],
): { functionDeclarations: { name: string; description?: string; parameters: Schema }[] }[] | undefined {
	if (!tools.length) {
		return undefined;
	}

	const functionDeclarations = tools.map(({ name, description, parameters }) => ({
		name,
		description,
		// TypeBox produces JSON Schema, which is compatible with the SDK's Schema type.
		parameters: parameters as Schema,
	}));

	return [{ functionDeclarations }];
}
|
|
|
|
/**
 * Translates the provider-neutral tool-choice string into Gemini's function-calling mode.
 * "none" → NONE and "any" → ANY; "auto" and any unrecognized value fall back to AUTO.
 */
function mapToolChoice(choice: string): FunctionCallingConfigMode {
	if (choice === "none") {
		return FunctionCallingConfigMode.NONE;
	}
	if (choice === "any") {
		return FunctionCallingConfigMode.ANY;
	}
	return FunctionCallingConfigMode.AUTO;
}
|
|
|
|
/**
 * Maps a Gemini `FinishReason` onto this library's `StopReason`.
 *
 * STOP → "stop" and MAX_TOKENS → "length". Every other listed reason (blocklist/safety/
 * recitation blocks, image-related failures, malformed or unexpected tool calls,
 * unspecified/other, language, no image) → "error".
 *
 * @throws Error("Unhandled stop reason: ...") for a value not listed below. The `never`
 *   assignment also turns a new, unmapped SDK enum member into a compile-time error here.
 */
function mapStopReason(reason: FinishReason): StopReason {
	switch (reason) {
		case FinishReason.STOP:
			return "stop";
		case FinishReason.MAX_TOKENS:
			return "length";
		// Every remaining known reason means the response did not complete normally.
		case FinishReason.BLOCKLIST:
		case FinishReason.PROHIBITED_CONTENT:
		case FinishReason.SPII:
		case FinishReason.SAFETY:
		case FinishReason.IMAGE_SAFETY:
		case FinishReason.IMAGE_PROHIBITED_CONTENT:
		case FinishReason.IMAGE_RECITATION:
		case FinishReason.IMAGE_OTHER:
		case FinishReason.RECITATION:
		case FinishReason.FINISH_REASON_UNSPECIFIED:
		case FinishReason.OTHER:
		case FinishReason.LANGUAGE:
		case FinishReason.MALFORMED_FUNCTION_CALL:
		case FinishReason.UNEXPECTED_TOOL_CALL:
		case FinishReason.NO_IMAGE:
			return "error";
		default: {
			// Exhaustiveness guard: compile error if FinishReason gains an unmapped member;
			// at runtime, an unknown string from the API throws instead.
			const _exhaustive: never = reason;
			throw new Error(`Unhandled stop reason: ${_exhaustive}`);
		}
	}
}
|