mirror of
https://github.com/getcompanion-ai/co-mono.git
synced 2026-04-16 21:03:42 +00:00
Google streaming may emit thoughtSignature without thought=true (including empty-text signature-only parts). Treat non-empty thoughtSignature as thinking to avoid leaking reasoning into normal text and retain signature across streaming deltas. Add unit test coverage.
365 lines
10 KiB
TypeScript
import {
|
|
type GenerateContentConfig,
|
|
type GenerateContentParameters,
|
|
GoogleGenAI,
|
|
type ThinkingConfig,
|
|
ThinkingLevel,
|
|
} from "@google/genai";
|
|
import { calculateCost } from "../models.js";
|
|
import type {
|
|
Api,
|
|
AssistantMessage,
|
|
Context,
|
|
Model,
|
|
StreamFunction,
|
|
StreamOptions,
|
|
TextContent,
|
|
ThinkingContent,
|
|
ToolCall,
|
|
} from "../types.js";
|
|
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
|
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
|
import type { GoogleThinkingLevel } from "./google-gemini-cli.js";
|
|
import {
|
|
convertMessages,
|
|
convertTools,
|
|
isThinkingPart,
|
|
mapStopReason,
|
|
mapToolChoice,
|
|
retainThoughtSignature,
|
|
} from "./google-shared.js";
|
|
|
|
/**
 * Options accepted by `streamGoogleVertex`, layered on top of the shared `StreamOptions`.
 */
export interface GoogleVertexOptions extends StreamOptions {
	/** Function-calling mode; only forwarded when the request context declares at least one tool. */
	toolChoice?: "auto" | "none" | "any";
	/** Thinking settings; applied only when `enabled` is true and the model is marked as reasoning-capable. */
	thinking?: {
		enabled: boolean;
		/** Thinking token budget: -1 for dynamic, 0 to disable. Ignored when `level` is provided. */
		budgetTokens?: number;
		/** Named thinking level; takes precedence over `budgetTokens`. */
		level?: GoogleThinkingLevel;
	};
	/** GCP project ID; when empty/absent, GOOGLE_CLOUD_PROJECT and then GCLOUD_PROJECT are consulted. */
	project?: string;
	/** Vertex AI location; when empty/absent, GOOGLE_CLOUD_LOCATION is consulted. */
	location?: string;
}
|
|
|
|
/** API version string handed to the GoogleGenAI client when it is constructed in Vertex mode. */
const API_VERSION = "v1" as const;
|
|
|
|
/**
 * Lookup from this package's `GoogleThinkingLevel` names to the SDK's `ThinkingLevel` values,
 * used to populate `thinkingConfig.thinkingLevel` when a caller requests a named level.
 * NOTE(review): the `Record` annotation forces an entry for every key of `GoogleThinkingLevel`,
 * presuming it is a string-literal union — confirm against its declaration.
 */
const THINKING_LEVEL_MAP: Record<GoogleThinkingLevel, ThinkingLevel> = {
	THINKING_LEVEL_UNSPECIFIED: ThinkingLevel.THINKING_LEVEL_UNSPECIFIED,
	MINIMAL: ThinkingLevel.MINIMAL,
	LOW: ThinkingLevel.LOW,
	MEDIUM: ThinkingLevel.MEDIUM,
	HIGH: ThinkingLevel.HIGH,
};
|
|
|
|
// Counter for generating unique tool call IDs when Gemini omits `functionCall.id` or repeats one
// already present in the message. Module-level, so it keeps increasing across calls in this process.
let toolCallCounter = 0;

/**
 * Streams an assistant response from Gemini on Vertex AI.
 *
 * The returned event stream is handed back synchronously; the request itself runs in a detached
 * async task that pushes `start`, then `text_*` / `thinking_*` / `toolcall_*` events per content
 * block, and finally either `done` or `error`.
 *
 * Failures never reject: missing project/location, SDK errors, aborts, and finish reasons that map
 * to "aborted"/"error" are all reported as an `error` event whose `reason` is "aborted" when
 * `options.signal` was aborted and "error" otherwise.
 *
 * Text parts are routed into thinking blocks whenever `isThinkingPart` says so, and a thinking
 * block's signature is merged across deltas via `retainThoughtSignature`. When a finish reason
 * arrives and any tool call has been emitted, the stop reason is forced to "toolUse".
 *
 * @param model - Vertex model descriptor (id, provider, headers, reasoning flag, pricing).
 * @param context - Conversation, optional system prompt, and optional tools.
 * @param options - Sampling, thinking, tool-choice, project/location, and abort options.
 * @returns An event stream that is always terminated by `done` or `error`.
 */
export const streamGoogleVertex: StreamFunction<"google-vertex"> = (
	model: Model<"google-vertex">,
	context: Context,
	options?: GoogleVertexOptions,
): AssistantMessageEventStream => {
	const stream = new AssistantMessageEventStream();

	(async () => {
		const output: AssistantMessage = {
			role: "assistant",
			content: [],
			api: "google-vertex" as Api,
			provider: model.provider,
			model: model.id,
			usage: {
				input: 0,
				output: 0,
				cacheRead: 0,
				cacheWrite: 0,
				totalTokens: 0,
				cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
			},
			stopReason: "stop",
			timestamp: Date.now(),
		};

		try {
			const project = resolveProject(options);
			const location = resolveLocation(options);
			const client = createClient(model, project, location);
			const params = buildParams(model, context, options);
			const googleStream = await client.models.generateContentStream(params);

			stream.push({ type: "start", partial: output });

			// The text/thinking block currently receiving deltas; null before the first block and after a tool call.
			let currentBlock: TextContent | ThinkingContent | null = null;
			const blocks = output.content;
			const blockIndex = () => blocks.length - 1;

			for await (const chunk of googleStream) {
				const candidate = chunk.candidates?.[0];
				if (candidate?.content?.parts) {
					for (const part of candidate.content.parts) {
						if (part.text !== undefined) {
							const isThinking = isThinkingPart(part);
							const wantedType = isThinking ? "thinking" : "text";
							// Start a new block when none is open or the part flips between text and thinking.
							if (!currentBlock || currentBlock.type !== wantedType) {
								if (currentBlock) {
									emitBlockEnd(stream, output, currentBlock, blockIndex());
								}
								if (isThinking) {
									currentBlock = { type: "thinking", thinking: "", thinkingSignature: undefined };
									output.content.push(currentBlock);
									stream.push({ type: "thinking_start", contentIndex: blockIndex(), partial: output });
								} else {
									currentBlock = { type: "text", text: "" };
									output.content.push(currentBlock);
									stream.push({ type: "text_start", contentIndex: blockIndex(), partial: output });
								}
							}

							if (currentBlock.type === "thinking") {
								currentBlock.thinking += part.text;
								// Keep the block's signature across deltas (see retainThoughtSignature in google-shared).
								currentBlock.thinkingSignature = retainThoughtSignature(
									currentBlock.thinkingSignature,
									part.thoughtSignature,
								);
								stream.push({
									type: "thinking_delta",
									contentIndex: blockIndex(),
									delta: part.text,
									partial: output,
								});
							} else {
								currentBlock.text += part.text;
								stream.push({
									type: "text_delta",
									contentIndex: blockIndex(),
									delta: part.text,
									partial: output,
								});
							}
						}

						if (part.functionCall) {
							// A tool call always closes whatever text/thinking block was open.
							if (currentBlock) {
								emitBlockEnd(stream, output, currentBlock, blockIndex());
								currentBlock = null;
							}

							const name = part.functionCall.name || "";
							const providedId = part.functionCall.id;
							// Synthesize an ID when none is provided or the provided one is already used in this message.
							const needsNewId =
								!providedId || output.content.some((b) => b.type === "toolCall" && b.id === providedId);
							const toolCallId = needsNewId ? `${name}_${Date.now()}_${++toolCallCounter}` : providedId;

							const toolCall: ToolCall = {
								type: "toolCall",
								id: toolCallId,
								name,
								// `args` may be absent on a function call; fall back to an empty object so
								// consumers (and the JSON delta below) always see a real arguments object.
								arguments: (part.functionCall.args ?? {}) as Record<string, any>,
								...(part.thoughtSignature && { thoughtSignature: part.thoughtSignature }),
							};

							output.content.push(toolCall);
							stream.push({ type: "toolcall_start", contentIndex: blockIndex(), partial: output });
							stream.push({
								type: "toolcall_delta",
								contentIndex: blockIndex(),
								delta: JSON.stringify(toolCall.arguments),
								partial: output,
							});
							stream.push({ type: "toolcall_end", contentIndex: blockIndex(), toolCall, partial: output });
						}
					}
				}

				if (candidate?.finishReason) {
					output.stopReason = mapStopReason(candidate.finishReason);
					if (output.content.some((b) => b.type === "toolCall")) {
						output.stopReason = "toolUse";
					}
				}

				// Usage is replaced wholesale by the most recent chunk that reports it; thought tokens
				// are counted as output tokens.
				if (chunk.usageMetadata) {
					output.usage = {
						input: chunk.usageMetadata.promptTokenCount || 0,
						output:
							(chunk.usageMetadata.candidatesTokenCount || 0) + (chunk.usageMetadata.thoughtsTokenCount || 0),
						cacheRead: chunk.usageMetadata.cachedContentTokenCount || 0,
						cacheWrite: 0,
						totalTokens: chunk.usageMetadata.totalTokenCount || 0,
						cost: {
							input: 0,
							output: 0,
							cacheRead: 0,
							cacheWrite: 0,
							total: 0,
						},
					};
					calculateCost(model, output.usage);
				}
			}

			if (currentBlock) {
				emitBlockEnd(stream, output, currentBlock, blockIndex());
			}

			if (options?.signal?.aborted) {
				throw new Error("Request was aborted");
			}

			if (output.stopReason === "aborted" || output.stopReason === "error") {
				throw new Error("An unknown error occurred");
			}

			stream.push({ type: "done", reason: output.stopReason, message: output });
			stream.end();
		} catch (error) {
			// Defensive cleanup: none of the blocks built in this provider carry an `index`
			// property, but strip it if present so it never leaks into the final message.
			for (const block of output.content) {
				if ("index" in block) {
					delete (block as { index?: number }).index;
				}
			}
			output.stopReason = options?.signal?.aborted ? "aborted" : "error";
			output.errorMessage = error instanceof Error ? error.message : JSON.stringify(error);
			stream.push({ type: "error", reason: output.stopReason, error: output });
			stream.end();
		}
	})();

	return stream;
};

/**
 * Pushes the closing event (`text_end` or `thinking_end`) for a finished text or thinking block.
 *
 * @param stream - Event stream to push onto.
 * @param output - The in-progress assistant message, sent as `partial`.
 * @param block - The block being closed.
 * @param contentIndex - Position of `block` within `output.content`.
 */
function emitBlockEnd(
	stream: AssistantMessageEventStream,
	output: AssistantMessage,
	block: TextContent | ThinkingContent,
	contentIndex: number,
): void {
	if (block.type === "text") {
		stream.push({ type: "text_end", contentIndex, content: block.text, partial: output });
	} else {
		stream.push({ type: "thinking_end", contentIndex, content: block.thinking, partial: output });
	}
}
|
|
|
|
/**
 * Constructs a GoogleGenAI client in Vertex AI mode for the given project and location.
 *
 * When the model declares custom headers, a shallow copy of them is sent via
 * `httpOptions.headers`; otherwise `httpOptions` is left undefined.
 *
 * @param model - Vertex model descriptor; only `headers` is read here.
 * @param project - GCP project ID.
 * @param location - Vertex AI location.
 * @returns A configured GoogleGenAI client.
 */
function createClient(model: Model<"google-vertex">, project: string, location: string): GoogleGenAI {
	const extraHeaders = model.headers ? { ...model.headers } : undefined;

	return new GoogleGenAI({
		vertexai: true,
		project,
		location,
		apiVersion: API_VERSION,
		httpOptions: extraHeaders ? { headers: extraHeaders } : undefined,
	});
}
|
|
|
|
/**
 * Picks the GCP project ID for a Vertex request.
 *
 * Candidates are checked in order — `options.project`, then GOOGLE_CLOUD_PROJECT, then
 * GCLOUD_PROJECT — and the first non-empty one wins.
 *
 * @param options - Optional per-call Vertex options.
 * @returns The resolved project ID.
 * @throws Error when no candidate yields a non-empty value.
 */
function resolveProject(options?: GoogleVertexOptions): string {
	const candidates = [options?.project, process.env.GOOGLE_CLOUD_PROJECT, process.env.GCLOUD_PROJECT];
	const chosen = candidates.find((value) => Boolean(value));
	if (chosen) {
		return chosen;
	}
	throw new Error(
		"Vertex AI requires a project ID. Set GOOGLE_CLOUD_PROJECT/GCLOUD_PROJECT or pass project in options.",
	);
}
|
|
|
|
/**
 * Picks the Vertex AI location for a request: `options.location` if non-empty, otherwise the
 * GOOGLE_CLOUD_LOCATION environment variable.
 *
 * @param options - Optional per-call Vertex options.
 * @returns The resolved location.
 * @throws Error when neither source yields a non-empty value.
 */
function resolveLocation(options?: GoogleVertexOptions): string {
	const resolved = options?.location || process.env.GOOGLE_CLOUD_LOCATION;
	if (resolved) {
		return resolved;
	}
	throw new Error("Vertex AI requires a location. Set GOOGLE_CLOUD_LOCATION or pass location in options.");
}
|
|
|
|
/**
 * Assembles the `generateContentStream` request for a Vertex model.
 *
 * - `temperature` / `maxTokens` become `temperature` / `maxOutputTokens` when defined.
 * - A non-empty system prompt is surrogate-sanitized and sent as `systemInstruction`.
 * - Tools are converted only when the context has at least one; `toolConfig` is populated only
 *   when there are tools AND a `toolChoice`, and is otherwise explicitly set to undefined.
 * - Thinking is requested only when `options.thinking.enabled` and `model.reasoning` are both true;
 *   a named `level` wins over `budgetTokens`.
 * - An abort signal is attached to the config.
 *
 * @param model - Target Vertex model.
 * @param context - Messages, optional system prompt, and optional tools.
 * @param options - Per-call Vertex options (defaults to `{}`).
 * @returns Parameters ready for `client.models.generateContentStream`.
 * @throws Error("Request aborted") when the supplied signal is already aborted.
 */
function buildParams(
	model: Model<"google-vertex">,
	context: Context,
	options: GoogleVertexOptions = {},
): GenerateContentParameters {
	const contents = convertMessages(model, context);
	const tools = context.tools && context.tools.length > 0 ? context.tools : undefined;

	// Keys are assigned in the same order the request has always been built in.
	const config: GenerateContentConfig = {};
	if (options.temperature !== undefined) config.temperature = options.temperature;
	if (options.maxTokens !== undefined) config.maxOutputTokens = options.maxTokens;
	if (context.systemPrompt) config.systemInstruction = sanitizeSurrogates(context.systemPrompt);
	if (tools) config.tools = convertTools(tools);

	config.toolConfig =
		tools && options.toolChoice
			? { functionCallingConfig: { mode: mapToolChoice(options.toolChoice) } }
			: undefined;

	const thinking = options.thinking;
	if (thinking?.enabled && model.reasoning) {
		const thinkingConfig: ThinkingConfig = { includeThoughts: true };
		if (thinking.level !== undefined) {
			thinkingConfig.thinkingLevel = THINKING_LEVEL_MAP[thinking.level];
		} else if (thinking.budgetTokens !== undefined) {
			thinkingConfig.thinkingBudget = thinking.budgetTokens;
		}
		config.thinkingConfig = thinkingConfig;
	}

	const signal = options.signal;
	if (signal?.aborted) {
		throw new Error("Request aborted");
	}
	if (signal) {
		config.abortSignal = signal;
	}

	return { model: model.id, contents, config };
}
|