diff --git a/package-lock.json b/package-lock.json index 36d62149..3dfb8aaa 100644 --- a/package-lock.json +++ b/package-lock.json @@ -654,6 +654,15 @@ } } }, + "node_modules/@google/generative-ai": { + "version": "0.24.1", + "resolved": "https://registry.npmjs.org/@google/generative-ai/-/generative-ai-0.24.1.tgz", + "integrity": "sha512-MqO+MLfM6kjxcKoy0p1wRzG3b4ZZXtPI+z2IE26UogS2Cm/XHO+7gGRBh6gcJsOiIVoH93UwKvW4HdgiOZCy9Q==", + "license": "Apache-2.0", + "engines": { + "node": ">=18.0.0" + } + }, "node_modules/@mariozechner/ai": { "resolved": "packages/ai", "link": true @@ -1604,6 +1613,7 @@ "dependencies": { "@anthropic-ai/sdk": "0.60.0", "@google/genai": "1.14.0", + "@google/generative-ai": "^0.24.1", "chalk": "^5.5.0", "openai": "5.12.2" }, diff --git a/packages/ai/package.json b/packages/ai/package.json index 7af2e7a0..fe2db351 100644 --- a/packages/ai/package.json +++ b/packages/ai/package.json @@ -5,7 +5,10 @@ "type": "module", "main": "./dist/index.js", "types": "./dist/index.d.ts", - "files": ["dist", "README.md"], + "files": [ + "dist", + "README.md" + ], "scripts": { "clean": "rm -rf dist", "build": "tsc -p tsconfig.build.json", @@ -13,13 +16,21 @@ "prepublishOnly": "npm run clean && npm run build" }, "dependencies": { - "openai": "5.12.2", "@anthropic-ai/sdk": "0.60.0", "@google/genai": "1.14.0", - "chalk": "^5.5.0" + "@google/generative-ai": "^0.24.1", + "chalk": "^5.5.0", + "openai": "5.12.2" }, - "devDependencies": {}, - "keywords": ["ai", "llm", "openai", "anthropic", "gemini", "unified", "api"], + "keywords": [ + "ai", + "llm", + "openai", + "anthropic", + "gemini", + "unified", + "api" + ], "author": "Mario Zechner", "license": "MIT", "repository": { @@ -30,4 +41,4 @@ "engines": { "node": ">=20.0.0" } -} \ No newline at end of file +} diff --git a/packages/ai/src/providers/anthropic.ts b/packages/ai/src/providers/anthropic.ts index 8b1d7bc8..d7a406b0 100644 --- a/packages/ai/src/providers/anthropic.ts +++ b/packages/ai/src/providers/anthropic.ts 
@@ -186,7 +186,7 @@ export class AnthropicLLM implements LLM {
 				toolCalls,
 				model: this.model,
 				usage,
-				stopResaon: this.mapStopReason(msg.stop_reason),
+				stopReason: this.mapStopReason(msg.stop_reason),
 			};
 		} catch (error) {
 			return {
@@ -198,7 +198,7 @@ export class AnthropicLLM implements LLM {
 					cacheRead: 0,
 					cacheWrite: 0,
 				},
-				stopResaon: "error",
+				stopReason: "error",
 				error: error instanceof Error ? error.message : String(error),
 			};
 		}
diff --git a/packages/ai/src/providers/gemini.ts b/packages/ai/src/providers/gemini.ts
new file mode 100644
index 00000000..a3910fb0
--- /dev/null
+++ b/packages/ai/src/providers/gemini.ts
@@ -0,0 +1,310 @@
+import { FunctionCallingMode, GoogleGenerativeAI } from "@google/generative-ai";
+import type {
+	AssistantMessage,
+	Context,
+	LLM,
+	LLMOptions,
+	Message,
+	StopReason,
+	TokenUsage,
+	Tool,
+	ToolCall,
+} from "../types.js";
+
+/** Options accepted by `GeminiLLM.complete()`. */
+export interface GeminiLLMOptions extends LLMOptions {
+	/** Function-calling mode; when omitted, no `toolConfig` is sent with the request. */
+	toolChoice?: "auto" | "none" | "any";
+}
+
+/**
+ * `LLM` implementation backed by the `@google/generative-ai` SDK.
+ *
+ * Each `complete()` call streams one response, forwards text and thinking deltas to the
+ * optional callbacks, and returns the aggregated assistant message.
+ */
+export class GeminiLLM implements LLM {
+	private client: GoogleGenerativeAI;
+	private model: string;
+
+	/**
+	 * @param model Model name handed to `getGenerativeModel`.
+	 * @param apiKey API key; when omitted, the `GEMINI_API_KEY` environment variable is used.
+	 * @throws Error if neither `apiKey` nor `GEMINI_API_KEY` is set.
+	 */
+	constructor(model: string, apiKey?: string) {
+		if (!apiKey) {
+			if (!process.env.GEMINI_API_KEY) {
+				throw new Error(
+					"Gemini API key is required. Set GEMINI_API_KEY environment variable or pass it as an argument.",
+				);
+			}
+			apiKey = process.env.GEMINI_API_KEY;
+		}
+		this.client = new GoogleGenerativeAI(apiKey);
+		this.model = model;
+	}
+
+	/**
+	 * Streams a completion for `context` and collects it into a single message.
+	 *
+	 * @param context System prompt, conversation history and optional tool definitions.
+	 * @param options Sampling settings, tool choice and streaming callbacks.
+	 * @returns The assistant message. Errors raised while sending or streaming the request are
+	 *   caught and returned as `stopReason: "error"` with the message in `error`.
+	 */
+	async complete(context: Context, options?: GeminiLLMOptions): Promise<AssistantMessage> {
+		try {
+			const model = this.client.getGenerativeModel({
+				model: this.model,
+				systemInstruction: context.systemPrompt,
+				tools: context.tools ? this.convertTools(context.tools) : undefined,
+				toolConfig: options?.toolChoice
+					? {
+							functionCallingConfig: {
+								mode: this.mapToolChoice(options.toolChoice),
+							},
+						}
+					: undefined,
+			});
+
+			const contents = this.convertMessages(context.messages);
+
+			const stream = await model.generateContentStream({
+				contents,
+				generationConfig: {
+					temperature: options?.temperature,
+					maxOutputTokens: options?.maxTokens,
+				},
+			});
+
+			let content = "";
+			let thinking = "";
+			const toolCalls: ToolCall[] = [];
+			let usage: TokenUsage = {
+				input: 0,
+				output: 0,
+				cacheRead: 0,
+				cacheWrite: 0,
+			};
+			let stopReason: StopReason = "stop";
+			let inTextBlock = false;
+			let inThinkingBlock = false;
+
+			// Process the stream
+			for await (const chunk of stream.stream) {
+				// Only the first candidate of each chunk is read.
+				const candidate = chunk.candidates?.[0];
+				if (candidate?.content?.parts) {
+					for (const part of candidate.content.parts) {
+						if (part.text) {
+							// `thought` is read through a cast; presumably the SDK part type does not declare it.
+							if ((part as any).thought) {
+								// Close an open text block before the first thinking delta is emitted.
+								if (inTextBlock) {
+									options?.onText?.("", true);
+									inTextBlock = false;
+								}
+								thinking += part.text;
+								options?.onThinking?.(part.text, false);
+								inThinkingBlock = true;
+							} else {
+								// Close an open thinking block before the first text delta is emitted.
+								if (inThinkingBlock) {
+									options?.onThinking?.("", true);
+									inThinkingBlock = false;
+								}
+								content += part.text;
+								options?.onText?.(part.text, false);
+								inTextBlock = true;
+							}
+						}
+
+						// Handle function calls
+						if (part.functionCall) {
+							if (inTextBlock) {
+								options?.onText?.("", true);
+								inTextBlock = false;
+							}
+							if (inThinkingBlock) {
+								options?.onThinking?.("", true);
+								inThinkingBlock = false;
+							}
+
+							toolCalls.push({
+								// Locally generated ID; convertMessages() maps it back to the function name.
+								id: `call_${Date.now()}_${Math.random().toString(36).slice(2, 11)}`,
+								name: part.functionCall.name,
+								arguments: part.functionCall.args as Record<string, unknown>,
+							});
+						}
+					}
+				}
+
+				// Map finish reason
+				if (candidate?.finishReason) {
+					stopReason = this.mapStopReason(candidate.finishReason);
+				}
+			}
+
+			// Signal end of blocks
+			if (inTextBlock) {
+				options?.onText?.("", true);
+			}
+			if (inThinkingBlock) {
+				options?.onThinking?.("", true);
+			}
+
+			// Get final response for usage metadata
+			const response = await stream.response;
+			if (response.usageMetadata) {
+				usage = {
+					input: response.usageMetadata.promptTokenCount || 0,
+					output: response.usageMetadata.candidatesTokenCount || 0,
+					cacheRead: response.usageMetadata.cachedContentTokenCount || 0,
+					cacheWrite: 0,
+				};
+			}
+
+			return {
+				role: "assistant",
+				content: content || undefined,
+				thinking: thinking || undefined,
+				toolCalls: toolCalls.length > 0 ? toolCalls : undefined,
+				model: this.model,
+				usage,
+				stopReason,
+			};
+		} catch (error) {
+			return {
+				role: "assistant",
+				model: this.model,
+				usage: {
+					input: 0,
+					output: 0,
+					cacheRead: 0,
+					cacheWrite: 0,
+				},
+				stopReason: "error",
+				error: error instanceof Error ? error.message : String(error),
+			};
+		}
+	}
+
+	/**
+	 * Converts the provider-neutral history into Gemini `contents`.
+	 *
+	 * Assistant messages become `model` turns (those with neither text nor tool calls are
+	 * skipped). Consecutive tool results are grouped into one `user` turn, and every
+	 * `functionResponse` is named after the tool call it answers.
+	 */
+	private convertMessages(messages: Message[]): any[] {
+		const contents: any[] = [];
+		// Tool call ID -> function name, recorded from assistant messages. The generated IDs
+		// (`call_<timestamp>_<random>`) do not contain the function name, so it is looked up.
+		const toolNames = new Map<string, string>();
+		// Parts of the function-response turn currently being filled; set only while the
+		// previous message was a tool result. NOTE(review): grouping follows Gemini's parallel-call guide.
+		let responseParts: any[] | undefined;
+
+		for (const msg of messages) {
+			if (msg.role === "toolResult") {
+				const part = {
+					functionResponse: {
+						// Fall back to the raw ID when the originating call is not in the history.
+						name: toolNames.get(msg.toolCallId) ?? msg.toolCallId,
+						response: {
+							result: msg.content,
+							isError: msg.isError || false,
+						},
+					},
+				};
+				if (responseParts) {
+					responseParts.push(part);
+				} else {
+					responseParts = [part];
+					contents.push({ role: "user", parts: responseParts });
+				}
+				continue;
+			}
+			responseParts = undefined;
+
+			if (msg.role === "user") {
+				contents.push({
+					role: "user",
+					parts: [{ text: msg.content }],
+				});
+			} else if (msg.role === "assistant") {
+				const parts: any[] = [];
+
+				if (msg.content) {
+					parts.push({ text: msg.content });
+				}
+
+				if (msg.toolCalls) {
+					for (const toolCall of msg.toolCalls) {
+						toolNames.set(toolCall.id, toolCall.name);
+						parts.push({
+							functionCall: {
+								name: toolCall.name,
+								args: toolCall.arguments,
+							},
+						});
+					}
+				}
+
+				if (parts.length > 0) {
+					contents.push({
+						role: "model",
+						parts,
+					});
+				}
+			}
+		}
+
+		return contents;
+	}
+
+	/** Wraps all tools into the single `functionDeclarations` entry passed as `tools`. */
+	private convertTools(tools: Tool[]): any[] {
+		return [
+			{
+				functionDeclarations: tools.map((tool) => ({
+					name: tool.name,
+					description: tool.description,
+					parameters: tool.parameters,
+				})),
+			},
+		];
+	}
+
+	/** Maps the `toolChoice` option to the SDK enum; unrecognized values fall back to AUTO. */
+	private mapToolChoice(choice: string): FunctionCallingMode {
+		switch (choice) {
+			case "auto":
+				return FunctionCallingMode.AUTO;
+			case "none":
+				return FunctionCallingMode.NONE;
+			case "any":
+				return FunctionCallingMode.ANY;
+			default:
+				return FunctionCallingMode.AUTO;
+		}
+	}
+
+	/** Maps a Gemini finish reason to a `StopReason`; unrecognized reasons map to "stop". */
+	private mapStopReason(reason: string): StopReason {
+		switch (reason) {
+			case "STOP":
+				return "stop";
+			case "MAX_TOKENS":
+				return "length";
+			case "SAFETY":
+				return "safety";
+			case "RECITATION":
+				return "safety";
+			default:
+				return "stop";
+		}
+	}
+}
diff --git a/packages/ai/src/providers/openai-completions.ts b/packages/ai/src/providers/openai-completions.ts
index c2129a9c..e2253ba8 100644
--- a/packages/ai/src/providers/openai-completions.ts
+++ b/packages/ai/src/providers/openai-completions.ts
@@ -163,7 +163,7 @@ export class OpenAICompletionsLLM implements LLM {
 				toolCalls: toolCalls.length > 0 ? toolCalls : undefined,
 				model: this.model,
 				usage,
-				stopResaon: this.mapStopReason(finishReason),
+				stopReason: this.mapStopReason(finishReason),
 			};
 		} catch (error) {
 			return {
@@ -175,7 +175,7 @@ export class OpenAICompletionsLLM implements LLM {
 					cacheRead: 0,
 					cacheWrite: 0,
 				},
-				stopResaon: "error",
+				stopReason: "error",
 				error: error instanceof Error ? error.message : String(error),
 			};
 		}
diff --git a/packages/ai/src/providers/openai-responses.ts b/packages/ai/src/providers/openai-responses.ts
index 0ef453c5..e6d3eed5 100644
--- a/packages/ai/src/providers/openai-responses.ts
+++ b/packages/ai/src/providers/openai-responses.ts
@@ -144,7 +144,7 @@ export class OpenAIResponsesLLM implements LLM {
 						role: "assistant",
 						model: this.model,
 						usage,
-						stopResaon: "error",
+						stopReason: "error",
 						error: `Code ${event.code}: ${event.message}` || "Unknown error",
 					};
 				}
@@ -158,7 +158,7 @@ export class OpenAIResponsesLLM implements LLM {
 				toolCalls: toolCalls.length > 0 ? 
toolCalls : undefined,
 				model: this.model,
 				usage,
-				stopResaon: stopReason,
+				stopReason,
 			};
 		} catch (error) {
 			return {
@@ -170,7 +170,7 @@ export class OpenAIResponsesLLM implements LLM {
 					cacheRead: 0,
 					cacheWrite: 0,
 				},
-				stopResaon: "error",
+				stopReason: "error",
 				error: error instanceof Error ? error.message : String(error),
 			};
 		}
diff --git a/packages/ai/src/types.ts b/packages/ai/src/types.ts
index 1d5e6703..74bb095f 100644
--- a/packages/ai/src/types.ts
+++ b/packages/ai/src/types.ts
@@ -51,7 +51,7 @@ export interface AssistantMessage {
 
 	model: string;
 	usage: TokenUsage;
-	stopResaon: StopReason;
+	stopReason: StopReason;
 	error?: string | Error;
 }
 
diff --git a/packages/ai/test/examples/gemini.ts b/packages/ai/test/examples/gemini.ts
new file mode 100644
index 00000000..813358fb
--- /dev/null
+++ b/packages/ai/test/examples/gemini.ts
@@ -0,0 +1,77 @@
+import chalk from "chalk";
+import { GeminiLLM, type GeminiLLMOptions } from "../../src/providers/gemini.js";
+import type { Context, Tool } from "../../src/types.js";
+
+// Define a simple calculator tool
+const tools: Tool[] = [
+	{
+		name: "calculate",
+		description: "Perform a mathematical calculation",
+		parameters: {
+			type: "object" as const,
+			properties: {
+				expression: {
+					type: "string",
+					description: "The mathematical expression to evaluate",
+				},
+			},
+			required: ["expression"],
+		},
+	},
+];
+
+// Print text as-is and thinking dimmed; a newline is written when a block completes.
+const options: GeminiLLMOptions = {
+	onText: (t, complete) => process.stdout.write(t + (complete ? "\n" : "")),
+	onThinking: (t, complete) => process.stdout.write(chalk.dim(t + (complete ? "\n" : ""))),
+	toolChoice: "auto",
+};
+
+const ai = new GeminiLLM("gemini-2.0-flash-exp", process.env.GEMINI_API_KEY || "fake-api-key-for-testing");
+const context: Context = {
+	systemPrompt: "You are a helpful assistant that can use tools to answer questions.",
+	messages: [
+		{
+			role: "user",
+			content: "Think about birds briefly. Then give me a list of 10 birds. Finally, calculate 42 * 17 + 123 and 453 + 434 in parallel using the calculator tool.",
+		},
+	],
+	tools,
+};
+
+// First turn: send the prompt and keep the reply in the conversation history.
+let msg = await ai.complete(context, options);
+context.messages.push(msg);
+console.log();
+console.log(chalk.yellow(JSON.stringify(msg, null, 2)));
+
+// Run every requested calculation and append its result to the conversation.
+for (const toolCall of msg.toolCalls || []) {
+	if (toolCall.name === "calculate") {
+		const expression = toolCall.arguments.expression;
+		try {
+			// SECURITY: eval() executes model-provided text as JavaScript. Tolerable only in
+			// this local example; do not copy it into code that handles untrusted input.
+			const result = eval(expression);
+			context.messages.push({
+				role: "toolResult",
+				content: `The result of ${expression} is ${result}.`,
+				toolCallId: toolCall.id,
+				isError: false,
+			});
+		} catch (error) {
+			// Report a failed evaluation to the model instead of crashing the example.
+			context.messages.push({
+				role: "toolResult",
+				content: `Failed to evaluate ${expression}: ${error instanceof Error ? error.message : String(error)}`,
+				toolCallId: toolCall.id,
+				isError: true,
+			});
+		}
+	}
+}
+
+// Second turn: let the model continue with the tool results.
+msg = await ai.complete(context, options);
+console.log();
+console.log(chalk.yellow(JSON.stringify(msg, null, 2)));
\ No newline at end of file