diff --git a/packages/ai/src/providers/openai-completions.ts b/packages/ai/src/providers/openai-completions.ts index 8e80723f..c4332b3c 100644 --- a/packages/ai/src/providers/openai-completions.ts +++ b/packages/ai/src/providers/openai-completions.ts @@ -333,10 +333,12 @@ export const streamSimpleOpenAICompletions: StreamFunction<"openai-completions", const base = buildBaseOptions(model, options, apiKey); const reasoningEffort = supportsXhigh(model) ? options?.reasoning : clampReasoning(options?.reasoning); + const toolChoice = (options as OpenAICompletionsOptions | undefined)?.toolChoice; return streamOpenAICompletions(model, context, { ...base, reasoningEffort, + toolChoice, } satisfies OpenAICompletionsOptions); }; diff --git a/packages/ai/test/openai-completions-tool-choice.test.ts b/packages/ai/test/openai-completions-tool-choice.test.ts new file mode 100644 index 00000000..9bfeead1 --- /dev/null +++ b/packages/ai/test/openai-completions-tool-choice.test.ts @@ -0,0 +1,75 @@ +import { Type } from "@sinclair/typebox"; +import { describe, expect, it, vi } from "vitest"; +import type { Tool } from "../src/types.js"; + +let lastParams: unknown; + +class FakeOpenAI { + chat = { + completions: { + create: async (params: unknown) => { + lastParams = params; + return { + async *[Symbol.asyncIterator]() { + yield { + choices: [{ delta: {}, finish_reason: "stop" }], + usage: { + prompt_tokens: 1, + completion_tokens: 1, + prompt_tokens_details: { cached_tokens: 0 }, + completion_tokens_details: { reasoning_tokens: 0 }, + }, + }; + }, + }; + }, + }, + }; +} + +vi.mock("openai", () => ({ default: FakeOpenAI })); + +describe("openai-completions tool_choice", () => { + it("forwards toolChoice from simple options to payload", async () => { + const { streamSimple } = await import("../src/stream.js"); + const { getModel } = await import("../src/models.js"); + const { compat: _compat, ...baseModel } = getModel("openai", "gpt-4o-mini")!; + const model = { ...baseModel, api: "openai-completions" } as const; + const tools: Tool[] = [ + { + name: "ping", + description: "Ping tool", + parameters: Type.Object({ + ok: Type.Boolean(), + }), + }, + ]; + let payload: unknown; + + await streamSimple( + model, + { + messages: [ + { + role: "user", + content: "Call ping with ok=true", + timestamp: Date.now(), + }, + ], + tools, + }, + { + apiKey: "test", + toolChoice: "required", + onPayload: (params: unknown) => { + payload = params; + }, + } as unknown as Parameters[2], + ).result(); + + const params = (payload ?? lastParams) as { tool_choice?: string; tools?: unknown[] }; + expect(params.tool_choice).toBe("required"); + expect(Array.isArray(params.tools)).toBe(true); + expect(params.tools?.length ?? 0).toBeGreaterThan(0); + }); +});