diff --git a/packages/ai/CHANGELOG.md b/packages/ai/CHANGELOG.md index 88e15624..4c5e8377 100644 --- a/packages/ai/CHANGELOG.md +++ b/packages/ai/CHANGELOG.md @@ -30,6 +30,10 @@ - **OpenAI completions empty content blocks**: Empty text or thinking blocks in assistant messages are now filtered out before sending to the OpenAI completions API, preventing validation errors. ([#344](https://github.com/badlogic/pi-mono/pull/344) by [@default-anton](https://github.com/default-anton)) - **zAi provider API mapping**: Fixed zAi models to use `openai-completions` API with correct base URL (`https://api.z.ai/api/coding/paas/v4`) instead of incorrect Anthropic API mapping. ([#344](https://github.com/badlogic/pi-mono/pull/344), [#358](https://github.com/badlogic/pi-mono/pull/358) by [@default-anton](https://github.com/default-anton)) +### Added + +- Added Vertex AI provider with ADC support, Gemini model catalog, and test coverage. + ## [0.28.0] - 2025-12-25 ### Breaking Changes diff --git a/packages/ai/README.md b/packages/ai/README.md index 89cdcfc4..1c3821e2 100644 --- a/packages/ai/README.md +++ b/packages/ai/README.md @@ -9,6 +9,7 @@ Unified LLM API with automatic model discovery, provider configuration, token an - **OpenAI** - **Anthropic** - **Google** +- **Vertex AI** (Gemini via Vertex AI) - **Mistral** - **Groq** - **Cerebras** @@ -848,6 +849,10 @@ Several providers require OAuth authentication instead of static API keys: - **Google Gemini CLI** (Free Gemini 2.0/2.5 via Google Cloud Code Assist) - **Antigravity** (Free Gemini 3, Claude, GPT-OSS via Google Cloud) +### Vertex AI (ADC) + +Vertex AI models use Application Default Credentials. Run `gcloud auth application-default login`, set `GOOGLE_CLOUD_PROJECT` (or `GCLOUD_PROJECT`), and `GOOGLE_CLOUD_LOCATION`. You can also pass `project`/`location` in the call options. + ### CLI Login The quickest way to authenticate: @@ -871,11 +876,11 @@ import { loginGitHubCopilot, loginGeminiCli, loginAntigravity, - + // Token management refreshOAuthToken, // (provider, credentials) => new credentials getOAuthApiKey, // (provider, credentialsMap) => { newCredentials, apiKey } | null - + // Types type OAuthProvider, // 'anthropic' | 'github-copilot' | 'google-gemini-cli' | 'google-antigravity' type OAuthCredentials, diff --git a/packages/ai/scripts/generate-models.ts b/packages/ai/scripts/generate-models.ts index d1164d6f..36128519 100644 --- a/packages/ai/scripts/generate-models.ts +++ b/packages/ai/scripts/generate-models.ts @@ -644,6 +644,143 @@ async function generateModels() { ]; allModels.push(...antigravityModels); + const VERTEX_BASE_URL = "https://{location}-aiplatform.googleapis.com"; + const vertexModels: Model<"google-vertex">[] = [ + { + id: "gemini-3-pro-preview", + name: "Gemini 3 Pro Preview (Vertex)", + api: "google-vertex", + provider: "google-vertex", + baseUrl: VERTEX_BASE_URL, + reasoning: true, + input: ["text", "image"], + cost: { input: 2, output: 12, cacheRead: 0.2, cacheWrite: 0 }, + contextWindow: 1000000, + maxTokens: 64000, + }, + { + id: "gemini-3-flash-preview", + name: "Gemini 3 Flash Preview (Vertex)", + api: "google-vertex", + provider: "google-vertex", + baseUrl: VERTEX_BASE_URL, + reasoning: true, + input: ["text", "image"], + cost: { input: 0.5, output: 3, cacheRead: 0.05, cacheWrite: 0 }, + contextWindow: 1048576, + maxTokens: 65536, + }, + { + id: "gemini-2.0-flash", + name: "Gemini 2.0 Flash (Vertex)", + api: "google-vertex", + provider: "google-vertex", + baseUrl: VERTEX_BASE_URL, + reasoning: false, + input: ["text", "image"], + cost: { input: 0.1, output: 0.4, cacheRead: 0.025, cacheWrite: 0 }, + contextWindow: 1048576, + maxTokens: 8192, + }, + { + id: "gemini-2.0-flash-lite", + name: "Gemini 2.0 Flash Lite (Vertex)", + api: "google-vertex", + provider: "google-vertex", + baseUrl: VERTEX_BASE_URL, + reasoning: true, + input: ["text", "image"], + cost: { input: 0.1, output: 0.4, cacheRead: 0.025, cacheWrite: 0 }, + contextWindow: 1048576, + maxTokens: 65536, + }, + { + id: "gemini-2.5-pro", + name: "Gemini 2.5 Pro (Vertex)", + api: "google-vertex", + provider: "google-vertex", + baseUrl: VERTEX_BASE_URL, + reasoning: true, + input: ["text", "image"], + cost: { input: 1.25, output: 10, cacheRead: 0.31, cacheWrite: 0 }, + contextWindow: 1048576, + maxTokens: 65536, + }, + { + id: "gemini-2.5-flash", + name: "Gemini 2.5 Flash (Vertex)", + api: "google-vertex", + provider: "google-vertex", + baseUrl: VERTEX_BASE_URL, + reasoning: true, + input: ["text", "image"], + cost: { input: 0.3, output: 2.5, cacheRead: 0.075, cacheWrite: 0 }, + contextWindow: 1048576, + maxTokens: 65536, + }, + { + id: "gemini-2.5-flash-lite-preview-09-2025", + name: "Gemini 2.5 Flash Lite Preview 09-25 (Vertex)", + api: "google-vertex", + provider: "google-vertex", + baseUrl: VERTEX_BASE_URL, + reasoning: true, + input: ["text", "image"], + cost: { input: 0.1, output: 0.4, cacheRead: 0.025, cacheWrite: 0 }, + contextWindow: 1048576, + maxTokens: 65536, + }, + { + id: "gemini-2.5-flash-lite", + name: "Gemini 2.5 Flash Lite (Vertex)", + api: "google-vertex", + provider: "google-vertex", + baseUrl: VERTEX_BASE_URL, + reasoning: true, + input: ["text", "image"], + cost: { input: 0.1, output: 0.4, cacheRead: 0.025, cacheWrite: 0 }, + contextWindow: 1048576, + maxTokens: 65536, + }, + { + id: "gemini-1.5-pro", + name: "Gemini 1.5 Pro (Vertex)", + api: "google-vertex", + provider: "google-vertex", + baseUrl: VERTEX_BASE_URL, + reasoning: false, + input: ["text", "image"], + cost: { input: 1.25, output: 5, cacheRead: 0.3125, cacheWrite: 0 }, + contextWindow: 1000000, + maxTokens: 8192, + }, + { + id: "gemini-1.5-flash", + name: "Gemini 1.5 Flash (Vertex)", + api: "google-vertex", + provider: "google-vertex", + baseUrl: VERTEX_BASE_URL, + reasoning: false, + input: ["text", "image"], + cost: { input: 0.075, output: 0.3, cacheRead: 0.01875, cacheWrite: 0 }, + contextWindow: 1000000, + maxTokens: 8192, + }, + { + id: "gemini-1.5-flash-8b", + name: "Gemini 1.5 Flash-8B (Vertex)", + api: "google-vertex", + provider: "google-vertex", + baseUrl: VERTEX_BASE_URL, + reasoning: false, + input: ["text", "image"], + cost: { input: 0.0375, output: 0.15, cacheRead: 0.01, cacheWrite: 0 }, + contextWindow: 1000000, + maxTokens: 8192, + }, + ]; + allModels.push(...vertexModels); + // Group by provider and deduplicate by model ID const providers: Record>> = {}; for (const model of allModels) { diff --git a/packages/ai/src/index.ts b/packages/ai/src/index.ts index 5ff971de..0fab4dbd 100644 --- a/packages/ai/src/index.ts +++ b/packages/ai/src/index.ts @@ -2,6 +2,7 @@ export * from "./models.js"; export * from "./providers/anthropic.js"; export * from "./providers/google.js"; export * from "./providers/google-gemini-cli.js"; +export * from "./providers/google-vertex.js"; export * from "./providers/openai-completions.js"; export * from "./providers/openai-responses.js"; export * from "./stream.js"; diff --git a/packages/ai/src/models.generated.ts b/packages/ai/src/models.generated.ts index b4c7d35c..81dda5e2 100644 --- a/packages/ai/src/models.generated.ts +++ b/packages/ai/src/models.generated.ts @@ -7102,4 +7102,193 @@ export const MODELS = { maxTokens: 131072, } satisfies Model<"openai-completions">, }, + "google-vertex": { + "gemini-3-pro-preview": { + id: "gemini-3-pro-preview", + name: "Gemini 3 Pro Preview (Vertex)", + api: "google-vertex", + provider: "google-vertex", + baseUrl: "https://{location}-aiplatform.googleapis.com", + reasoning: true, + input: ["text", "image"], + cost: { + input: 2, + output: 12, + cacheRead: 0.2, + cacheWrite: 0, + }, + contextWindow: 1000000, + maxTokens: 64000, + } satisfies Model<"google-vertex">, + "gemini-3-flash-preview": { + id: "gemini-3-flash-preview", + name: "Gemini 3 Flash Preview (Vertex)", + api: "google-vertex", + provider: "google-vertex", + baseUrl: "https://{location}-aiplatform.googleapis.com", + reasoning: true, + input: ["text", "image"], + cost: { + input: 0.5, + output: 3, + cacheRead: 0.05, + cacheWrite: 0, + }, + contextWindow: 1048576, + maxTokens: 65536, + } satisfies Model<"google-vertex">, + "gemini-2.0-flash": { + id: "gemini-2.0-flash", + name: "Gemini 2.0 Flash (Vertex)", + api: "google-vertex", + provider: "google-vertex", + baseUrl: "https://{location}-aiplatform.googleapis.com", + reasoning: false, + input: ["text", "image"], + cost: { + input: 0.1, + output: 0.4, + cacheRead: 0.025, + cacheWrite: 0, + }, + contextWindow: 1048576, + maxTokens: 8192, + } satisfies Model<"google-vertex">, + "gemini-2.0-flash-lite": { + id: "gemini-2.0-flash-lite", + name: "Gemini 2.0 Flash Lite (Vertex)", + api: "google-vertex", + provider: "google-vertex", + baseUrl: "https://{location}-aiplatform.googleapis.com", + reasoning: true, + input: ["text", "image"], + cost: { + input: 0.1, + output: 0.4, + cacheRead: 0.025, + cacheWrite: 0, + }, + contextWindow: 1048576, + maxTokens: 65536, + } satisfies Model<"google-vertex">, + "gemini-2.5-pro": { + id: "gemini-2.5-pro", + name: "Gemini 2.5 Pro (Vertex)", + api: "google-vertex", + provider: "google-vertex", + baseUrl: "https://{location}-aiplatform.googleapis.com", + reasoning: true, + input: ["text", "image"], + cost: { + input: 1.25, + output: 10, + cacheRead: 0.31, + cacheWrite: 0, + }, + contextWindow: 1048576, + maxTokens: 65536, + } satisfies Model<"google-vertex">, + "gemini-2.5-flash": { + id: "gemini-2.5-flash", + name: "Gemini 2.5 Flash (Vertex)", + api: "google-vertex", + provider: "google-vertex", + baseUrl: "https://{location}-aiplatform.googleapis.com", + reasoning: true, + input: ["text", "image"], + cost: { + input: 0.3, + output: 2.5, + cacheRead: 0.075, + cacheWrite: 0, + }, + contextWindow: 1048576, + maxTokens: 65536, + } satisfies Model<"google-vertex">, + "gemini-2.5-flash-lite-preview-09-2025": { + id: "gemini-2.5-flash-lite-preview-09-2025", + name: "Gemini 2.5 Flash Lite Preview 09-25 (Vertex)", + api: "google-vertex", + provider: "google-vertex", + baseUrl: "https://{location}-aiplatform.googleapis.com", + reasoning: true, + input: ["text", "image"], + cost: { + input: 0.1, + output: 0.4, + cacheRead: 0.025, + cacheWrite: 0, + }, + contextWindow: 1048576, + maxTokens: 65536, + } satisfies Model<"google-vertex">, + "gemini-2.5-flash-lite": { + id: "gemini-2.5-flash-lite", + name: "Gemini 2.5 Flash Lite (Vertex)", + api: "google-vertex", + provider: "google-vertex", + baseUrl: "https://{location}-aiplatform.googleapis.com", + reasoning: true, + input: ["text", "image"], + cost: { + input: 0.1, + output: 0.4, + cacheRead: 0.025, + cacheWrite: 0, + }, + contextWindow: 1048576, + maxTokens: 65536, + } satisfies Model<"google-vertex">, + "gemini-1.5-pro": { + id: "gemini-1.5-pro", + name: "Gemini 1.5 Pro (Vertex)", + api: "google-vertex", + provider: "google-vertex", + baseUrl: "https://{location}-aiplatform.googleapis.com", + reasoning: false, + input: ["text", "image"], + cost: { + input: 1.25, + output: 5, + cacheRead: 0.3125, + cacheWrite: 0, + }, + contextWindow: 1000000, + maxTokens: 8192, + } satisfies Model<"google-vertex">, + "gemini-1.5-flash": { + id: "gemini-1.5-flash", + name: "Gemini 1.5 Flash (Vertex)", + api: "google-vertex", + provider: "google-vertex", + baseUrl: "https://{location}-aiplatform.googleapis.com", + reasoning: false, + input: ["text", "image"], + cost: { + input: 0.075, + output: 0.3, + cacheRead: 0.01875, + cacheWrite: 0, + }, + contextWindow: 1000000, + maxTokens: 8192, + } satisfies Model<"google-vertex">, + "gemini-1.5-flash-8b": { + id: "gemini-1.5-flash-8b", + name: "Gemini 1.5 Flash-8B (Vertex)", + api: "google-vertex", + provider: "google-vertex", + baseUrl: "https://{location}-aiplatform.googleapis.com", + reasoning: false, + input: ["text", "image"], + cost: { + input: 0.0375, + output: 0.15, + cacheRead: 0.01, + cacheWrite: 0, + }, + contextWindow: 1000000, + maxTokens: 8192, + } satisfies Model<"google-vertex">, + }, } as const; diff --git a/packages/ai/src/providers/google-shared.ts b/packages/ai/src/providers/google-shared.ts index e8b00d2a..f0b2c3e4 100644 --- a/packages/ai/src/providers/google-shared.ts +++ b/packages/ai/src/providers/google-shared.ts @@ -7,7 +7,7 @@ import type { Context, ImageContent, Model, StopReason, TextContent, Tool } from import { sanitizeSurrogates } from "../utils/sanitize-unicode.js"; import { transformMessages } from "./transorm-messages.js"; -type GoogleApiType = "google-generative-ai" | "google-gemini-cli"; +type GoogleApiType = "google-generative-ai" | "google-gemini-cli" | "google-vertex"; /** * Convert internal messages to Gemini Content[] format. @@ -73,6 +73,9 @@ export function convertMessages(model: Model, contex args: block.arguments, }, }; + if (model.provider === "google-vertex" && part?.functionCall?.id) { + delete part.functionCall.id; // Vertex AI does not support 'id' in functionCall + } if (block.thoughtSignature) { part.thoughtSignature = block.thoughtSignature; } @@ -121,6 +124,10 @@ export function convertMessages(model: Model, contex }, }; + if (model.provider === "google-vertex" && functionResponsePart.functionResponse?.id) { + delete functionResponsePart.functionResponse.id; // Vertex AI does not support 'id' in functionResponse + } + // Cloud Code Assist API requires all function responses to be in a single user turn. // Check if the last content is already a user turn with function responses and merge. const lastContent = contents[contents.length - 1]; diff --git a/packages/ai/src/providers/google-vertex.ts b/packages/ai/src/providers/google-vertex.ts new file mode 100644 index 00000000..4c136d61 --- /dev/null +++ b/packages/ai/src/providers/google-vertex.ts @@ -0,0 +1,346 @@ +import { + type GenerateContentConfig, + type GenerateContentParameters, + GoogleGenAI, + type ThinkingConfig, + type ThinkingLevel, +} from "@google/genai"; +import { calculateCost } from "../models.js"; +import type { + Api, + AssistantMessage, + Context, + Model, + StreamFunction, + StreamOptions, + TextContent, + ThinkingContent, + ToolCall, +} from "../types.js"; +import { AssistantMessageEventStream } from "../utils/event-stream.js"; +import { sanitizeSurrogates } from "../utils/sanitize-unicode.js"; +import { convertMessages, convertTools, mapStopReason, mapToolChoice } from "./google-shared.js"; + +export interface GoogleVertexOptions extends StreamOptions { + toolChoice?: "auto" | "none" | "any"; + thinking?: { + enabled: boolean; + budgetTokens?: number; // -1 for dynamic, 0 to disable + level?: ThinkingLevel; + }; + project?: string; + location?: string; +} + +const API_VERSION = "v1"; + +// Counter for generating unique tool call IDs +let toolCallCounter = 0; + +export const streamGoogleVertex: StreamFunction<"google-vertex"> = ( + model: Model<"google-vertex">, + context: Context, + options?: GoogleVertexOptions, +): AssistantMessageEventStream => { + const stream = new AssistantMessageEventStream(); + + (async () => { + const output: AssistantMessage = { + role: "assistant", + content: [], + api: "google-vertex" as Api, + provider: model.provider, + model: model.id, + usage: { + input: 0, + output: 0, + cacheRead: 0, + cacheWrite: 0, + totalTokens: 0, + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, + }, + stopReason: "stop", + timestamp: Date.now(), + }; + + try { + const project = resolveProject(options); + const location = resolveLocation(options); + const client = createClient(model, project, location); + const params = buildParams(model, context, options); + const googleStream = await client.models.generateContentStream(params); + + stream.push({ type: "start", partial: output }); + let currentBlock: TextContent | ThinkingContent | null = null; + const blocks = output.content; + const blockIndex = () => blocks.length - 1; + for await (const chunk of googleStream) { + const candidate = chunk.candidates?.[0]; + if (candidate?.content?.parts) { + for (const part of candidate.content.parts) { + if (part.text !== undefined) { + const isThinking = part.thought === true; + if ( + !currentBlock || + (isThinking && currentBlock.type !== "thinking") || + (!isThinking && currentBlock.type !== "text") + ) { + if (currentBlock) { + if (currentBlock.type === "text") { + stream.push({ + type: "text_end", + contentIndex: blocks.length - 1, + content: currentBlock.text, + partial: output, + }); + } else { + stream.push({ + type: "thinking_end", + contentIndex: blockIndex(), + content: currentBlock.thinking, + partial: output, + }); + } + } + if (isThinking) { + currentBlock = { type: "thinking", thinking: "", thinkingSignature: undefined }; + output.content.push(currentBlock); + stream.push({ type: "thinking_start", contentIndex: blockIndex(), partial: output }); + } else { + currentBlock = { type: "text", text: "" }; + output.content.push(currentBlock); + stream.push({ type: "text_start", contentIndex: blockIndex(), partial: output }); + } + } + if (currentBlock.type === "thinking") { + currentBlock.thinking += part.text; + currentBlock.thinkingSignature = part.thoughtSignature; + stream.push({ + type: "thinking_delta", + contentIndex: blockIndex(), + delta: part.text, + partial: output, + }); + } else { + currentBlock.text += part.text; + stream.push({ + type: "text_delta", + contentIndex: blockIndex(), + delta: part.text, + partial: output, + }); + } + } + + if (part.functionCall) { + if (currentBlock) { + if (currentBlock.type === "text") { + stream.push({ + type: "text_end", + contentIndex: blockIndex(), + content: currentBlock.text, + partial: output, + }); + } else { + stream.push({ + type: "thinking_end", + contentIndex: blockIndex(), + content: currentBlock.thinking, + partial: output, + }); + } + currentBlock = null; + } + + const providedId = part.functionCall.id; + const needsNewId = + !providedId || output.content.some((b) => b.type === "toolCall" && b.id === providedId); + const toolCallId = needsNewId + ? `${part.functionCall.name}_${Date.now()}_${++toolCallCounter}` + : providedId; + + const toolCall: ToolCall = { + type: "toolCall", + id: toolCallId, + name: part.functionCall.name || "", + arguments: part.functionCall.args as Record, + ...(part.thoughtSignature && { thoughtSignature: part.thoughtSignature }), + }; + + output.content.push(toolCall); + stream.push({ type: "toolcall_start", contentIndex: blockIndex(), partial: output }); + stream.push({ + type: "toolcall_delta", + contentIndex: blockIndex(), + delta: JSON.stringify(toolCall.arguments), + partial: output, + }); + stream.push({ type: "toolcall_end", contentIndex: blockIndex(), toolCall, partial: output }); + } + } + } + + if (candidate?.finishReason) { + output.stopReason = mapStopReason(candidate.finishReason); + if (output.content.some((b) => b.type === "toolCall")) { + output.stopReason = "toolUse"; + } + } + + if (chunk.usageMetadata) { + output.usage = { + input: chunk.usageMetadata.promptTokenCount || 0, + output: + (chunk.usageMetadata.candidatesTokenCount || 0) + (chunk.usageMetadata.thoughtsTokenCount || 0), + cacheRead: chunk.usageMetadata.cachedContentTokenCount || 0, + cacheWrite: 0, + totalTokens: chunk.usageMetadata.totalTokenCount || 0, + cost: { + input: 0, + output: 0, + cacheRead: 0, + cacheWrite: 0, + total: 0, + }, + }; + calculateCost(model, output.usage); + } + } + + if (currentBlock) { + if (currentBlock.type === "text") { + stream.push({ + type: "text_end", + contentIndex: blockIndex(), + content: currentBlock.text, + partial: output, + }); + } else { + stream.push({ + type: "thinking_end", + contentIndex: blockIndex(), + content: currentBlock.thinking, + partial: output, + }); + } + } + + if (options?.signal?.aborted) { + throw new Error("Request was aborted"); + } + + if (output.stopReason === "aborted" || output.stopReason === "error") { + throw new Error("An unknown error occurred"); + } + + stream.push({ type: "done", reason: output.stopReason, message: output }); + stream.end(); + } catch (error) { + // Remove internal index property used during streaming + for (const block of output.content) { + if ("index" in block) { + delete (block as { index?: number }).index; + } + } + output.stopReason = options?.signal?.aborted ? "aborted" : "error"; + output.errorMessage = error instanceof Error ? error.message : JSON.stringify(error); + stream.push({ type: "error", reason: output.stopReason, error: output }); + stream.end(); + } + })(); + + return stream; +}; + +function createClient(model: Model<"google-vertex">, project: string, location: string): GoogleGenAI { + const httpOptions: { headers?: Record } = {}; + + if (model.headers) { + httpOptions.headers = { ...model.headers }; + } + + const hasHttpOptions = Object.values(httpOptions).some(Boolean); + + return new GoogleGenAI({ + vertexai: true, + project, + location, + apiVersion: API_VERSION, + httpOptions: hasHttpOptions ? httpOptions : undefined, + }); +} + +function resolveProject(options?: GoogleVertexOptions): string { + const project = options?.project || process.env.GOOGLE_CLOUD_PROJECT || process.env.GCLOUD_PROJECT; + if (!project) { + throw new Error( + "Vertex AI requires a project ID. Set GOOGLE_CLOUD_PROJECT/GCLOUD_PROJECT or pass project in options.", + ); + } + return project; +} + +function resolveLocation(options?: GoogleVertexOptions): string { + const location = options?.location || process.env.GOOGLE_CLOUD_LOCATION; + if (!location) { + throw new Error("Vertex AI requires a location. Set GOOGLE_CLOUD_LOCATION or pass location in options."); + } + return location; +} + +function buildParams( + model: Model<"google-vertex">, + context: Context, + options: GoogleVertexOptions = {}, +): GenerateContentParameters { + const contents = convertMessages(model, context); + + const generationConfig: GenerateContentConfig = {}; + if (options.temperature !== undefined) { + generationConfig.temperature = options.temperature; + } + if (options.maxTokens !== undefined) { + generationConfig.maxOutputTokens = options.maxTokens; + } + + const config: GenerateContentConfig = { + ...(Object.keys(generationConfig).length > 0 && generationConfig), + ...(context.systemPrompt && { systemInstruction: sanitizeSurrogates(context.systemPrompt) }), + ...(context.tools && context.tools.length > 0 && { tools: convertTools(context.tools) }), + }; + + if (context.tools && context.tools.length > 0 && options.toolChoice) { + config.toolConfig = { + functionCallingConfig: { + mode: mapToolChoice(options.toolChoice), + }, + }; + } else { + config.toolConfig = undefined; + } + + if (options.thinking?.enabled && model.reasoning) { + const thinkingConfig: ThinkingConfig = { includeThoughts: true }; + if (options.thinking.level !== undefined) { + thinkingConfig.thinkingLevel = options.thinking.level; + } else if (options.thinking.budgetTokens !== undefined) { + thinkingConfig.thinkingBudget = options.thinking.budgetTokens; + } + config.thinkingConfig = thinkingConfig; + } + + if (options.signal) { + if (options.signal.aborted) { + throw new Error("Request aborted"); + } + config.abortSignal = options.signal; + } + + const params: GenerateContentParameters = { + model: model.id, + contents, + config, + }; + + return params; +} diff --git a/packages/ai/src/stream.ts b/packages/ai/src/stream.ts index f68d5f60..d7499f7a 100644 --- a/packages/ai/src/stream.ts +++ b/packages/ai/src/stream.ts @@ -6,6 +6,7 @@ import { type GoogleThinkingLevel, streamGoogleGeminiCli, } from "./providers/google-gemini-cli.js"; +import { type GoogleVertexOptions, streamGoogleVertex } from "./providers/google-vertex.js"; import { type OpenAICompletionsOptions, streamOpenAICompletions } from "./providers/openai-completions.js"; import { type OpenAIResponsesOptions, streamOpenAIResponses } from "./providers/openai-responses.js"; import type { @@ -38,6 +39,14 @@ export function getEnvApiKey(provider: any): string | undefined { return process.env.ANTHROPIC_OAUTH_TOKEN || process.env.ANTHROPIC_API_KEY; } + // Vertex AI doesn't use API keys. + // It relies on Google Cloud auth: `gcloud auth application-default login`. + // @google/genai library picks up and manages the auth automatically. + // Return a dummy value to maintain consistency. + if (provider === "google-vertex") { + return "vertex-ai-authenticated"; + } + const envMap: Record = { openai: "OPENAI_API_KEY", google: "GEMINI_API_KEY", @@ -85,6 +94,9 @@ export function stream( providerOptions as GoogleGeminiCliOptions, ); + case "google-vertex": + return streamGoogleVertex(model as Model<"google-vertex">, context, providerOptions as GoogleVertexOptions); + default: { // This should never be reached if all Api cases are handled const _exhaustive: never = api; @@ -239,6 +251,44 @@ function mapOptionsForApi( } satisfies GoogleGeminiCliOptions; } + case "google-vertex": { + // Explicitly disable thinking when reasoning is not specified + if (!options?.reasoning) { + return { ...base, thinking: { enabled: false } } satisfies GoogleVertexOptions; + } + + const vertexModel = model as Model<"google-vertex">; + const effort = clampReasoning(options.reasoning)!; + + if (isGemini3ProModel(vertexModel as unknown as Model<"google-generative-ai">)) { + return { + ...base, + thinking: { + enabled: true, + level: getGemini3ThinkingLevel(effort, vertexModel as unknown as Model<"google-generative-ai">), + }, + } satisfies GoogleVertexOptions; + } + + if (isGemini3FlashModel(vertexModel as unknown as Model<"google-generative-ai">)) { + return { + ...base, + thinking: { + enabled: true, + level: getGemini3ThinkingLevel(effort, vertexModel as unknown as Model<"google-generative-ai">), + }, + } satisfies GoogleVertexOptions; + } + + return { + ...base, + thinking: { + enabled: true, + budgetTokens: getGoogleBudget(vertexModel as unknown as Model<"google-generative-ai">, effort), + }, + } satisfies GoogleVertexOptions; + } + default: { // Exhaustiveness check const _exhaustive: never = model.api; diff --git a/packages/ai/src/types.ts b/packages/ai/src/types.ts index 54944852..07988e30 100644 --- a/packages/ai/src/types.ts +++ b/packages/ai/src/types.ts @@ -1,6 +1,7 @@ import type { AnthropicOptions } from "./providers/anthropic.js"; import type { GoogleOptions } from "./providers/google.js"; import type { GoogleGeminiCliOptions } from "./providers/google-gemini-cli.js"; +import type { GoogleVertexOptions } from "./providers/google-vertex.js"; import type { OpenAICompletionsOptions } from "./providers/openai-completions.js"; import type { OpenAIResponsesOptions } from "./providers/openai-responses.js"; import type { AssistantMessageEventStream } from "./utils/event-stream.js"; @@ -12,7 +13,8 @@ export type Api = | "openai-responses" | "anthropic-messages" | "google-generative-ai" - | "google-gemini-cli"; + | "google-gemini-cli" + | "google-vertex"; export interface ApiOptionsMap { "anthropic-messages": AnthropicOptions; @@ -20,6 +22,7 @@ export interface ApiOptionsMap { "openai-responses": OpenAIResponsesOptions; "google-generative-ai": GoogleOptions; "google-gemini-cli": GoogleGeminiCliOptions; + "google-vertex": GoogleVertexOptions; } // Compile-time exhaustiveness check - this will fail if ApiOptionsMap doesn't have all KnownApi keys @@ -38,6 +41,7 @@ export type KnownProvider = | "google" | "google-gemini-cli" | "google-antigravity" + | "google-vertex" | "openai" | "github-copilot" | "xai" diff --git a/packages/ai/test/stream.test.ts b/packages/ai/test/stream.test.ts index 02229d66..992132b3 100644 --- a/packages/ai/test/stream.test.ts +++ b/packages/ai/test/stream.test.ts @@ -368,6 +368,46 @@ describe("Generate E2E Tests", () => { }); }); + describe("Google Vertex Provider (gemini-3-flash-preview)", () => { + const vertexProject = process.env.GOOGLE_CLOUD_PROJECT || process.env.GCLOUD_PROJECT; + const vertexLocation = process.env.GOOGLE_CLOUD_LOCATION; + const isVertexConfigured = Boolean(vertexProject && vertexLocation); + const vertexOptions = { project: vertexProject, location: vertexLocation } as const; + const llm = getModel("google-vertex", "gemini-3-flash-preview"); + + it.skipIf(!isVertexConfigured)("should complete basic text generation", { retry: 3 }, async () => { + await basicTextGeneration(llm, vertexOptions); + }); + + it.skipIf(!isVertexConfigured)("should handle tool calling", { retry: 3 }, async () => { + await handleToolCall(llm, vertexOptions); + }); + + it.skipIf(!isVertexConfigured)("should handle thinking", { retry: 3 }, async () => { + const { ThinkingLevel } = await import("@google/genai"); + await handleThinking(llm, { + ...vertexOptions, + thinking: { enabled: true, budgetTokens: 1024, level: ThinkingLevel.LOW }, + }); + }); + + it.skipIf(!isVertexConfigured)("should handle streaming", { retry: 3 }, async () => { + await handleStreaming(llm, vertexOptions); + }); + + it.skipIf(!isVertexConfigured)("should handle multi-turn with thinking and tools", { retry: 3 }, async () => { + const { ThinkingLevel } = await import("@google/genai"); + await multiTurn(llm, { + ...vertexOptions, + thinking: { enabled: true, budgetTokens: 1024, level: ThinkingLevel.MEDIUM }, + }); + }); + + it.skipIf(!isVertexConfigured)("should handle image input", { retry: 3 }, async () => { + await handleImage(llm, vertexOptions); + }); + }); + describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Completions Provider (gpt-4o-mini)", () => { const llm: Model<"openai-completions"> = { ...getModel("openai", "gpt-4o-mini"), api: "openai-completions" }; diff --git a/packages/coding-agent/src/core/model-resolver.ts b/packages/coding-agent/src/core/model-resolver.ts index d2124413..4a8e8c22 100644 --- a/packages/coding-agent/src/core/model-resolver.ts +++ b/packages/coding-agent/src/core/model-resolver.ts @@ -16,6 +16,7 @@ export const defaultModelPerProvider: Record = { google: "gemini-2.5-pro", "google-gemini-cli": "gemini-2.5-pro", "google-antigravity": "gemini-3-pro-high", + "google-vertex": "gemini-3-pro-preview", "github-copilot": "gpt-4o", openrouter: "openai/gpt-5.1-codex", xai: "grok-4-fast-non-reasoning",